├── prompts ├── short_prompt_short_output.txt ├── short_prompt_long_output.txt ├── reverse_list.txt ├── README.md ├── noise_qa.txt ├── long_prompt_long_output.txt └── long_prompt_short_output.txt ├── cache_configs ├── full.yaml ├── keep_it_odd.yaml ├── random.yaml ├── recent_global.yaml ├── l2.yaml ├── debug_heavy_hitter.yaml ├── heavy_hitter.yaml ├── local_global.yaml ├── heavy_hitter_funnel.yaml ├── heavy_hitter_pyramid.yaml ├── task_stats.csv ├── fastgen.yaml └── hybrid.yaml ├── charts ├── attention_loss.png ├── llama3_performance_graphs │ ├── qmsum_BertScore_f1.jpg │ ├── qmsum_Rouge_rougeL.jpg │ ├── truthfulqa_Accuracy.jpg │ ├── musique_BertScore_f1.jpg │ ├── musique_Rouge_rougeL.jpg │ ├── reprobench_ExactMatch.jpg │ ├── squality_BertScore_f1.jpg │ ├── squality_Rouge_rougeL.jpg │ ├── dolomites_BertScore_f1.jpg │ ├── dolomites_Rouge_rougeL.jpg │ ├── reprobench_Levenshtein.jpg │ ├── scrollsquality_Accuracy.jpg │ ├── qmsum_LLM-Rouge_llm_rouge.jpg │ ├── rulercwe_StringMatch_score.jpg │ ├── rulerqa_StringMatch_score.jpg │ ├── rulervt_StringMatch_score.jpg │ ├── dolomites_LLM-Rouge_llm_rouge.jpg │ ├── musique_LLM-Rouge_llm_rouge.jpg │ ├── rulerniah_StringMatch_score.jpg │ └── squality_LLM-Rouge_llm_rouge.jpg ├── qwen2_performance_graphs │ ├── musique_BertScore_f1.jpg │ ├── musique_Rouge_rougeL.jpg │ ├── qmsum_BertScore_f1.jpg │ ├── qmsum_Rouge_rougeL.jpg │ ├── truthfulqa_Accuracy.jpg │ ├── dolomites_BertScore_f1.jpg │ ├── dolomites_Rouge_rougeL.jpg │ ├── reprobench_ExactMatch.jpg │ ├── reprobench_Levenshtein.jpg │ ├── squality_BertScore_f1.jpg │ ├── squality_Rouge_rougeL.jpg │ ├── qmsum_LLM-Rouge_llm_rouge.jpg │ ├── rulerqa_StringMatch_score.jpg │ ├── rulervt_StringMatch_score.jpg │ ├── scrollsquality_Accuracy.jpg │ ├── musique_LLM-Rouge_llm_rouge.jpg │ ├── rulercwe_StringMatch_score.jpg │ ├── rulerniah_StringMatch_score.jpg │ ├── dolomites_LLM-Rouge_llm_rouge.jpg │ └── squality_LLM-Rouge_llm_rouge.jpg ├── attention_loss.py └── blogpost_perf.py ├── images ├── kv_cache_flow.png ├── attention_loss_pg19.png ├── cold_compress_logo.jpg ├── kv_cache_compression.png ├── attention_loss_concept.png └── local_global_from_character_ai.png ├── scripts ├── prepare.sh ├── prepare_qwen2.sh ├── prepare_llama2.sh ├── prepare_llama3.sh ├── prepare_llama31.sh ├── download.py └── convert_hf_checkpoint.py ├── .gitignore ├── requirements.txt ├── setup.py ├── experiments ├── eval_all.sh ├── attention_loss.sh ├── variable_compression.sh ├── multi_strategy.sh └── multi_strategy.txt ├── LICENSE ├── attention_utils.py ├── BENCHMARK.md ├── eval_multi.py ├── DISCLAIMER.md ├── quantization_utils.py ├── tp.py ├── generate.py ├── parallelize_evals.py ├── prompt_compression.py ├── metric.py ├── tokenizer.py ├── GPTQ.py ├── model.py ├── eval.py └── generation_utils.py /prompts/short_prompt_short_output.txt: -------------------------------------------------------------------------------- 1 | Which architect designed the Guggenheim? 
-------------------------------------------------------------------------------- /cache_configs/full.yaml: -------------------------------------------------------------------------------- 1 | cache_strategy: ["full"] 2 | prompt_compression_strategy: ["full"] -------------------------------------------------------------------------------- /charts/attention_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/attention_loss.png -------------------------------------------------------------------------------- /images/kv_cache_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/images/kv_cache_flow.png -------------------------------------------------------------------------------- /images/attention_loss_pg19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/images/attention_loss_pg19.png -------------------------------------------------------------------------------- /images/cold_compress_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/images/cold_compress_logo.jpg -------------------------------------------------------------------------------- /images/kv_cache_compression.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/images/kv_cache_compression.png -------------------------------------------------------------------------------- /images/attention_loss_concept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/images/attention_loss_concept.png -------------------------------------------------------------------------------- /images/local_global_from_character_ai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/images/local_global_from_character_ai.png -------------------------------------------------------------------------------- /cache_configs/keep_it_odd.yaml: -------------------------------------------------------------------------------- 1 | cache_strategy: ["keep_it_odd"] 2 | prompt_compression_strategy: ["keep_it_odd"] 3 | global_tokens: 4 4 | recent_window: 10 -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/qmsum_BertScore_f1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/qmsum_BertScore_f1.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/qmsum_Rouge_rougeL.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/qmsum_Rouge_rougeL.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/truthfulqa_Accuracy.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/truthfulqa_Accuracy.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/musique_BertScore_f1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/musique_BertScore_f1.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/musique_Rouge_rougeL.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/musique_Rouge_rougeL.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/qmsum_BertScore_f1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/qmsum_BertScore_f1.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/qmsum_Rouge_rougeL.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/qmsum_Rouge_rougeL.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/truthfulqa_Accuracy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/truthfulqa_Accuracy.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/musique_BertScore_f1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/musique_BertScore_f1.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/musique_Rouge_rougeL.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/musique_Rouge_rougeL.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/reprobench_ExactMatch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/reprobench_ExactMatch.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/squality_BertScore_f1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/squality_BertScore_f1.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/squality_Rouge_rougeL.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/squality_Rouge_rougeL.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/dolomites_BertScore_f1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/dolomites_BertScore_f1.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/dolomites_Rouge_rougeL.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/dolomites_Rouge_rougeL.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/reprobench_ExactMatch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/reprobench_ExactMatch.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/reprobench_Levenshtein.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/reprobench_Levenshtein.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/squality_BertScore_f1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/squality_BertScore_f1.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/squality_Rouge_rougeL.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/squality_Rouge_rougeL.jpg -------------------------------------------------------------------------------- /scripts/prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | python scripts/download.py --repo_id $1 6 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$1 7 | -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/dolomites_BertScore_f1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/dolomites_BertScore_f1.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/dolomites_Rouge_rougeL.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/dolomites_Rouge_rougeL.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/reprobench_Levenshtein.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/reprobench_Levenshtein.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/scrollsquality_Accuracy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/scrollsquality_Accuracy.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/qmsum_LLM-Rouge_llm_rouge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/qmsum_LLM-Rouge_llm_rouge.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/rulerqa_StringMatch_score.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/rulerqa_StringMatch_score.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/rulervt_StringMatch_score.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/rulervt_StringMatch_score.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/scrollsquality_Accuracy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/scrollsquality_Accuracy.jpg -------------------------------------------------------------------------------- /prompts/short_prompt_long_output.txt: -------------------------------------------------------------------------------- 1 | Write a detailed textbook on how to build a house from scratch. 2 | Write a separate chapter for each stage from initial planning to furnishing. 
-------------------------------------------------------------------------------- /cache_configs/random.yaml: -------------------------------------------------------------------------------- 1 | # Keeps Recent & Global tokens + random tokens in the middle 2 | cache_strategy: ["random"] 3 | prompt_compression_strategy: ["random"] 4 | global_tokens: 4 5 | -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/qmsum_LLM-Rouge_llm_rouge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/qmsum_LLM-Rouge_llm_rouge.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/rulercwe_StringMatch_score.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/rulercwe_StringMatch_score.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/rulerqa_StringMatch_score.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/rulerqa_StringMatch_score.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/rulervt_StringMatch_score.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/rulervt_StringMatch_score.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/musique_LLM-Rouge_llm_rouge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/musique_LLM-Rouge_llm_rouge.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/rulercwe_StringMatch_score.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/rulercwe_StringMatch_score.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/rulerniah_StringMatch_score.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/rulerniah_StringMatch_score.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/dolomites_LLM-Rouge_llm_rouge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/dolomites_LLM-Rouge_llm_rouge.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/musique_LLM-Rouge_llm_rouge.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/musique_LLM-Rouge_llm_rouge.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/rulerniah_StringMatch_score.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/rulerniah_StringMatch_score.jpg -------------------------------------------------------------------------------- /charts/llama3_performance_graphs/squality_LLM-Rouge_llm_rouge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/llama3_performance_graphs/squality_LLM-Rouge_llm_rouge.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/dolomites_LLM-Rouge_llm_rouge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/dolomites_LLM-Rouge_llm_rouge.jpg -------------------------------------------------------------------------------- /charts/qwen2_performance_graphs/squality_LLM-Rouge_llm_rouge.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AnswerDotAI/cold-compress/HEAD/charts/qwen2_performance_graphs/squality_LLM-Rouge_llm_rouge.jpg -------------------------------------------------------------------------------- /cache_configs/recent_global.yaml: -------------------------------------------------------------------------------- 1 | # Sliding Window with support for global (lead) tokens 2 | cache_strategy: ["recent_global"] 3 | prompt_compression_strategy: ["recent_global"] 4 | global_tokens: 4 5 | -------------------------------------------------------------------------------- /prompts/reverse_list.txt: -------------------------------------------------------------------------------- 1 | Write this list in reverse order: a dog named Remy, a homemade hat, a wicked witch, a head of broccoli, a really smelly fried egg, a famous baseball player named Aaron Judge, and a mysterious aunt. 
-------------------------------------------------------------------------------- /cache_configs/l2.yaml: -------------------------------------------------------------------------------- 1 | # Evicts tokens with high L2-Norm key vectors based on 2 | # https://arxiv.org/abs/2406.11430 3 | cache_strategy: ["l2"] 4 | prompt_compression_strategy: ["l2"] 5 | global_tokens: 4 6 | recent_window: 10 -------------------------------------------------------------------------------- /cache_configs/debug_heavy_hitter.yaml: -------------------------------------------------------------------------------- 1 | cache_strategy: ["debug_heavy_hitter"] 2 | prompt_compression_strategy: ["heavy_hitter"] 3 | global_tokens: 4 4 | recent_window: 10 5 | history_window_size: 400 6 | attn_thresholding: False 7 | -------------------------------------------------------------------------------- /prompts/README.md: -------------------------------------------------------------------------------- 1 | # Prompts 2 | 3 | This directory contains a collection of toy prompts which can be run with: 4 | 5 | ``` 6 | python generate.py --prompt {noise_qa,reverse_list,...}.txt 7 | ``` 8 | 9 | For benchmarking, please use `eval.py` which pulls in long-context benchmarks from `tasks.py`. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .DS_Store 4 | *.egg-info 5 | build 6 | 7 | # data 8 | data 9 | checkpoints 10 | out 11 | !data/shakespeare/prepare.py 12 | wandb 13 | 14 | # downloaded by our tests 15 | original_model.py 16 | original_adapter.py 17 | 18 | .vscode 19 | 20 | torch_compile_debug 21 | results -------------------------------------------------------------------------------- /cache_configs/heavy_hitter.yaml: -------------------------------------------------------------------------------- 1 | cache_strategy: ["heavy_hitter"] 2 | prompt_compression_strategy: ["heavy_hitter"] 3 | global_tokens: 4 4 | recent_window: 10 5 | history_window_size: 1 # If history_window_size is 1, then the history window is disabled (we use all historical attentions which is faster). 
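# Note (assumed semantics): the 4 global tokens and the 10 most recent tokens are always kept, and the
# remaining cache slots go to the tokens with the highest accumulated attention ("heavy hitters").
# attn_thresholding below is assumed to toggle whether per-step attention probabilities are binarized
# against a threshold before being added to those importance scores, rather than summed directly.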
6 | attn_thresholding: False 7 | -------------------------------------------------------------------------------- /cache_configs/local_global.yaml: -------------------------------------------------------------------------------- 1 | # Alternates between local and sliding-window attention layers 2 | cache_strategy: ["full", "window"] 3 | cache_strategy_pattern: "repeat" # or tile 4 | max_cache_length: [1.0, 0.25] 5 | cache_length_pattern: "repeat" # or tile 6 | prompt_compression_strategy: [None, "recent_global"] 7 | global_tokens: 4 -------------------------------------------------------------------------------- /scripts/prepare_qwen2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Use env vars if they exist, otherwise set defaults 6 | : "${HF:=Qwen/Qwen2-1.5B-Instruct}" 7 | 8 | # Export the variables 9 | export HF 10 | 11 | python scripts/download.py --repo_id $HF 12 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$HF 13 | -------------------------------------------------------------------------------- /scripts/prepare_llama2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Use env vars if they exist, otherwise set defaults 6 | : "${HF:=meta-llama/Llama-2-7b-chat-hf}" 7 | 8 | # Export the variables 9 | export HF 10 | 11 | python scripts/download.py --repo_id $HF 12 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$HF 13 | -------------------------------------------------------------------------------- /scripts/prepare_llama3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Use env vars if they exist, otherwise set defaults 6 | : "${HF:=meta-llama/Meta-Llama-3-8B-Instruct}" 7 | 8 | # Export the variables 9 | export HF 10 | 11 | python scripts/download.py --repo_id $HF 12 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$HF 13 | -------------------------------------------------------------------------------- /scripts/prepare_llama31.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | # Use env vars if they exist, otherwise set defaults 6 | : "${HF:=meta-llama/Meta-Llama-3.1-8B-Instruct}" 7 | 8 | # Export the variables 9 | export HF 10 | 11 | python scripts/download.py --repo_id $HF 12 | python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$HF 13 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.0.dev20240723+cu121 2 | pytorch-triton==3.0.0+dedb7bdf33 3 | absl-py 4 | accelerate 5 | bert-score 6 | blobfile 7 | claudette 8 | datasets 9 | evaluate 10 | fuzzywuzzy 11 | huggingface_hub 12 | nltk 13 | python-Levenshtein 14 | regex 15 | rouge-score 16 | ruff 17 | safetensors 18 | scikit-learn 19 | sentencepiece 20 | tiktoken 21 | git+https://github.com/google-research/bleurt.git 22 | -------------------------------------------------------------------------------- /cache_configs/heavy_hitter_funnel.yaml: -------------------------------------------------------------------------------- 1 | cache_strategy: ["heavy_hitter"] 2 | prompt_compression_strategy: ["heavy_hitter"] 3 | cache_length_pattern: "funnel" # More compression at lower layers 4 | global_tokens: 4 5 | 
recent_window: 10 6 | history_window_size: 1 # If history_window_size is 1, then the history window is disabled (we use all historical attentions which is faster). 7 | attn_thresholding: False 8 | max_cache_length: [1024] -------------------------------------------------------------------------------- /cache_configs/heavy_hitter_pyramid.yaml: -------------------------------------------------------------------------------- 1 | cache_strategy: ["heavy_hitter"] 2 | prompt_compression_strategy: ["heavy_hitter"] 3 | cache_length_pattern: "pyramid" # More compression at higher layers 4 | global_tokens: 4 5 | recent_window: 10 6 | history_window_size: 1 # If history_window_size is 1, then the history window is disabled (we use all historical attentions which is faster). 7 | attn_thresholding: False 8 | max_cache_length: [1024] -------------------------------------------------------------------------------- /cache_configs/task_stats.csv: -------------------------------------------------------------------------------- 1 | task,n,is_mcqa,prompt_tokens,label_tokens,n_choices 2 | dolomites,664,False,780.5105421686746,468.89006024096386, 3 | musique,2417,False,2469.275134464212,14.035579328959543, 4 | qmsum,281,False,14065.02846975089,84.60854092526691, 5 | rulercwe,500,False,3791.214,11.924400000000007, 6 | rulerniah,500,False,3819.522,13.0, 7 | rulerqa,500,False,3333.914,13.738, 8 | rulervt,500,False,3847.114,13.107199999999976, 9 | scrollsquality,2086,True,5986.950623202301,11.0, 10 | squality,260,False,6879.084615384615,283.7625, 11 | triviaqa,17210,False,10643.657989540965,13.0, 12 | truthfulqa,817,True,152.84944920440637,11.0, 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | from setuptools import setup, find_packages 7 | 8 | setup( 9 | name="gpt-fast", 10 | version="0.1", 11 | packages=find_packages(), 12 | install_requires=[ 13 | "torch", 14 | ], 15 | description="A simple, fast, pure PyTorch Llama inference engine", 16 | long_description=open("README.md").read(), 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/pytorch-labs/gpt-fast", 19 | ) 20 | -------------------------------------------------------------------------------- /experiments/eval_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | MODELS=( 6 | "checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth" 7 | "checkpoints/Qwen/Qwen2-7B-Instruct/model.pth" 8 | ) 9 | NUM_SAMPLES=500 10 | CACHE_SIZES="0.75 0.5 0.25 0.1 0.05" 11 | TASKS="truthfulqa rulerqa rulerniah rulervt rulercwe scrollsquality musique squality dolomites qmsum repobench" 12 | CACHE_CONFIGS="random l2 heavy_hitter recent_global" 13 | 14 | for MODEL in ${MODELS[@]}; do 15 | echo "Starting evals for ${MODEL}" 16 | python parallelize_evals.py \ 17 | --checkpoint_path $MODEL \ 18 | --config_names $CACHE_CONFIGS \ 19 | --tasks $TASKS \ 20 | --cache_sizes $CACHE_SIZES \ 21 | --num_samples $NUM_SAMPLES \ 22 | --add_full 23 | done 24 | -------------------------------------------------------------------------------- /experiments/attention_loss.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | DIR=$(dirname $(dirname "$0")) 6 | export CKPT=$DIR/checkpoints/Qwen/Qwen2-1.5B-Instruct/model.pth 7 | RATIOS=(0.25 0.5 0.75) 8 | GLOBAL_TOKENS=4 9 | 10 | # For loop over ratios 11 | for RATIO in $RATIOS 12 | do 13 | echo "Running Attention Loss Experiments for ${CKPT} at Ratio ${RATIO}" 14 | python eval.py --compile --tasks pg19 --global_tokens $GLOBAL_TOKENS --checkpoint_path $CKPT --cache_strategy heavy_hitter --prompt_compression_strategy heavy_hitter --max_cache_length $RATIO 15 | python eval.py --compile --tasks pg19 --global_tokens $GLOBAL_TOKENS --checkpoint_path $CKPT --cache_strategy debug_heavy_hitter --prompt_compression_strategy heavy_hitter --max_cache_length $RATIO 16 | done 17 | -------------------------------------------------------------------------------- /cache_configs/fastgen.yaml: -------------------------------------------------------------------------------- 1 | # FastGen hybrid strategies 2 | # https://arxiv.org/abs/2310.01801 3 | cache_strategy: ["hybrid"] 4 | max_cache_length: [1.0] # [Do not Change] Control compression ratio with min_recovery_frac 5 | global_tokens: 4 6 | # min_recovery_frac: 0.85 # Higher is less compression (0.85 means we choose the policy which compresses the most tokens AND recovers 85% of the full attention matrix) 7 | hybrid_strategies: 8 | - strategy: "special" 9 | - strategy: "special_punc" 10 | - strategy: "special_punc_heavy_hitter" 11 | heavy_hitter_frac: 0.3 # Fraction of important tokens to keep 12 | - strategy: "special_punc_heavy_hitter_window" 13 | recent_window: 0.3 # Fraction of recent tokens to keep 14 | heavy_hitter_frac: 0.3 # Fraction of important tokens to keep 15 | - strategy: "full" -------------------------------------------------------------------------------- /cache_configs/hybrid.yaml: -------------------------------------------------------------------------------- 1 | # Based on FastGen but with custom hybrid strategies 2 | # https://arxiv.org/abs/2310.01801 3 | cache_strategy: ["hybrid"] 4 | 
prompt_compression_strategy: ["full"] 5 | max_cache_length: [1.0] # [Do not Change] Control compression ratio with min_recovery_frac 6 | global_tokens: 4 7 | # min_recovery_frac: 0.85 # Higher is less compression (0.85 means we choose the policy which compresses the most tokens AND recovers 85% of the full attention matrix) 8 | hybrid_strategies: # 9 | - strategy: "window" 10 | recent_window: 0.1 # Fraction of recent tokens to keep 11 | - strategy: "window_heavy_hitter" 12 | heavy_hitter_frac: 0.25 # Fraction of important tokens to keep 13 | recent_window: 0.1 # Fraction of recent tokens to keep 14 | - strategy: "window_heavy_hitter" 15 | heavy_hitter_frac: 0.5 # Fraction of important tokens to keep 16 | recent_window: 0.1 # Fraction of recent tokens to keep 17 | - strategy: "full" -------------------------------------------------------------------------------- /experiments/variable_compression.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | DIR=$(dirname $(dirname "$0")) 6 | export CKPT=$DIR/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth 7 | NUM_SAMPLES=500 8 | GLOBAL_TOKENS=4 9 | CACHE_STRATEGY="heavy_hitter" 10 | PROMPT_STRATEGY="heavy_hitter" 11 | TASKS="rulerniah musique dolomites" 12 | 13 | SHARED_ARGS="--compile --tasks ${TASKS} --global_tokens ${GLOBAL_TOKENS} --checkpoint_path ${CKPT} --num_samples ${NUM_SAMPLES}" 14 | 15 | MAX_CACHE_LENGTHS=(0.1 0.25 0.5) 16 | 17 | for MAX_CACHE_LENGTH in $MAX_CACHE_LENGTHS 18 | do 19 | echo "Starting experiments with Max Cache Length=${MAX_CACHE_LENGTH}." 20 | python eval.py $SHARED_ARGS --max_cache_length $MAX_CACHE_LENGTH --cache_length_pattern pyramid 21 | python eval.py $SHARED_ARGS --max_cache_length $MAX_CACHE_LENGTH --cache_length_pattern repeat 22 | python eval.py $SHARED_ARGS --max_cache_length $MAX_CACHE_LENGTH --cache_length_pattern tile 23 | done 24 | -------------------------------------------------------------------------------- /prompts/noise_qa.txt: -------------------------------------------------------------------------------- 1 | Answer the Question below. Ignore the "*". 2 | ********** 3 | ********** 4 | ********** 5 | ********** 6 | ********** 7 | ********** 8 | ********** 9 | ********** 10 | ********** 11 | ********** 12 | ********** 13 | ********** 14 | ********** 15 | ********** 16 | ********** 17 | ********** 18 | ********** 19 | ********** 20 | ********** 21 | ********** 22 | ********** 23 | ********** 24 | ********** 25 | ********** 26 | ********** 27 | ********** 28 | ********** 29 | ********** 30 | ********** 31 | ********** 32 | ********** 33 | ********** 34 | ********** 35 | ********** 36 | ********** 37 | ********** 38 | ********** 39 | ********** 40 | ********** 41 | ********** 42 | What is (10 * 10) - 5? Explain how you arrived at the answer as if you were helping a 5 year old just starting to learn math. 
43 | ********** 44 | ********** 45 | ********** 46 | ********** 47 | ********** 48 | ********** 49 | ********** 50 | ********** 51 | ********** 52 | ********** 53 | ********** 54 | ********** 55 | ********** 56 | ********** 57 | ********** 58 | ********** 59 | ********** 60 | ********** 61 | ********** 62 | ********** 63 | ********** 64 | ********** 65 | ********** 66 | ********** 67 | ********** 68 | ********** 69 | ********** 70 | ********** 71 | ********** 72 | ********** 73 | ********** 74 | ********** 75 | ********** 76 | ********** 77 | ********** 78 | ********** 79 | ********** 80 | ********** 81 | ********** 82 | ********** -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Meta 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 10 | 11 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /experiments/multi_strategy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | DIR=$(dirname $(dirname "$0")) 6 | export CKPT=$DIR/checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth 7 | NUM_SAMPLES=500 8 | GLOBAL_TOKENS=4 9 | TASKS="rulerniah musique dolomites" 10 | 11 | SHARED_ARGS="--compile --tasks ${TASKS} --global_tokens ${GLOBAL_TOKENS} --checkpoint_path ${CKPT} --num_samples ${NUM_SAMPLES}" 12 | 13 | MAX_CACHE_LENGTHS=(0.25 0.5 0.75) # fraction of full cache to set for "local" layers 14 | 15 | for MAX_CACHE_LENGTH in $MAX_CACHE_LENGTHS 16 | do 17 | # Select "local", e.g., compressed, strategy 18 | COMPRESS_STRAT="window" 19 | COMPRESS_PROMPT_STRAT="recent_global" 20 | 21 | LOCAL2GLOBAL_ARGS="--cache_strategy ${COMPRESS_STRAT} full \ 22 | --prompt_compression_strategy ${COMPRESS_PROMPT_STRAT} recent_global \ 23 | --max_cache_length ${MAX_CACHE_LENGTH} 1.0" 24 | 25 | GLOBAL2LOCAL_ARGS="--cache_strategy full ${COMPRESS_STRAT} \ 26 | --prompt_compression_strategy recent_global ${COMPRESS_PROMPT_STRAT} \ 27 | --max_cache_length 1.0 ${MAX_CACHE_LENGTH}" 28 | 29 | ALTERNATING_ARGS="--cache_length_pattern repeat --cache_strategy_pattern repeat" 30 | REPEATING_ARGS="--cache_length_pattern tile --cache_strategy_pattern tile" 31 | 32 | A="${SHARED_ARGS} ${LOCAL2GLOBAL_ARGS} ${ALTERNATING_ARGS}" 33 | B="${SHARED_ARGS} ${LOCAL2GLOBAL_ARGS} ${REPEATING_ARGS}" 34 | C="${SHARED_ARGS} ${GLOBAL2LOCAL_ARGS} ${ALTERNATING_ARGS}" 35 | D="${SHARED_ARGS} ${GLOBAL2LOCAL_ARGS} ${REPEATING_ARGS}" 36 | 37 | echo $A 38 | python eval.py $A 39 | 40 | echo $B 41 | python eval.py $B 42 | 43 | echo $C 44 | python eval.py $C 45 | 46 | echo $D 47 | python eval.py $D 48 | done 49 | -------------------------------------------------------------------------------- /scripts/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | from typing import Optional 8 | 9 | from requests.exceptions import HTTPError 10 | 11 | 12 | def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: 13 | from huggingface_hub import snapshot_download 14 | 15 | os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) 16 | 17 | # if directory is not empty, don't download 18 | if os.listdir(f"checkpoints/{repo_id}"): 19 | print( 20 | f'Directory checkpoints/{repo_id} is not empty, skipping download. First, "rm -rf checkpoints/{repo_id}" if you want to re-download.' 21 | ) 22 | return 23 | 24 | try: 25 | snapshot_download( 26 | repo_id, 27 | local_dir=f"checkpoints/{repo_id}", 28 | local_dir_use_symlinks=False, 29 | token=hf_token, 30 | ) 31 | except HTTPError as e: 32 | if e.response.status_code == 401: 33 | print( 34 | "You need to pass a valid `--hf_token=...` to download private checkpoints." 
35 | ) 36 | else: 37 | raise e 38 | 39 | 40 | if __name__ == "__main__": 41 | import argparse 42 | 43 | parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") 44 | parser.add_argument( 45 | "--repo_id", 46 | type=str, 47 | default="karpathy/tinyllamas", 48 | help="Repository ID to download from.", 49 | ) 50 | parser.add_argument( 51 | "--hf_token", type=str, default=None, help="HuggingFace API token." 52 | ) 53 | 54 | args = parser.parse_args() 55 | hf_download(args.repo_id, args.hf_token) 56 | -------------------------------------------------------------------------------- /attention_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | def scaled_dot_product_attention( 9 | query, 10 | key, 11 | value, 12 | attn_mask=None, 13 | dropout_p=0.0, 14 | scale=None, 15 | return_attn=False, 16 | attn_top_k=1.0, 17 | ) -> Tuple[torch.Tensor, torch.Tensor | None]: 18 | """ 19 | Uses naive PyTorch sdpa implementation if we need to return_attn. Otherwise use the optimized version. 20 | 21 | The naive implementation will be optimized later. 22 | """ 23 | B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2) 24 | top_k = ( 25 | S if L > 1 else int(attn_top_k * S) 26 | ) # We use full attention during prefill (L > 1) 27 | if not return_attn and top_k == S: 28 | return F.scaled_dot_product_attention( 29 | query, 30 | key, 31 | value, 32 | attn_mask=attn_mask, 33 | dropout_p=dropout_p, 34 | scale=scale, 35 | ), None 36 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 37 | attn_weight = query @ key.transpose(-2, -1) * scale_factor 38 | 39 | if attn_mask is not None: 40 | assert top_k == S, "Top-k attention not supported with masks." 41 | attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device) 42 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) 43 | attn_weight += attn_bias 44 | 45 | if top_k < S: 46 | _, top_k_idxs = attn_weight.topk(top_k, dim=-1) 47 | value = value.gather( 48 | -2, top_k_idxs.view(B, H, top_k, 1).expand(-1, -1, -1, value.shape[-1]) 49 | ) 50 | attn_weight = attn_weight.gather(-1, top_k_idxs) 51 | 52 | attn_prob = torch.softmax(attn_weight, dim=-1) 53 | attn_prob = torch.dropout(attn_prob, dropout_p, train=True) 54 | return attn_prob @ value, attn_prob 55 | -------------------------------------------------------------------------------- /BENCHMARK.md: -------------------------------------------------------------------------------- 1 | # Description of Long-Context Tasks in the Eval Harness 2 | 3 | These tasks can be found in `./tasks.py` and are invoked from the `eval.py` harness with the `--tasks` parameter. 4 | 5 | ## Synthetic 6 | 7 | ### [RULER](https://arxiv.org/abs/2404.06654) 8 | 9 | RULER defines a set of synthetic tasks designed to test a model’s long-context understanding. 10 | 11 | Tasks include needle in a haystack (NIAH), variable tracking (VT), question answering (QA), and common word extraction (CWE). 12 | 13 | ## Domain-Specific 14 | 15 | ### [Dolomites](https://arxiv.org/abs/2405.05938) 16 | 17 | Evaluates the model’s ability to perform domain-specific methodical writing tasks such as writing a differential diagnosis for a patient, or writing a lesson plan for students. 
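For example, a Dolomites evaluation can be launched through the harness with a command along these lines (the checkpoint path and cache settings are illustrative and mirror the experiment scripts in this repository):

```
python eval.py --compile --tasks dolomites \
  --checkpoint_path checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth \
  --cache_strategy heavy_hitter --prompt_compression_strategy heavy_hitter \
  --max_cache_length 0.25 --global_tokens 4
```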
18 | 19 | ## Coding 20 | 21 | ### [RepoBench](https://arxiv.org/abs/2306.03091) 22 | 23 | This task tests the model’s ability to understand coding repositories and make correct predictions for code completion. 24 | 25 | ## QA 26 | 27 | ### [MuSiQue](https://arxiv.org/abs/2108.00573) 28 | 29 | MuSiQue is a question-answering dataset that tests the model’s ability to perform multihop reasoning over a long input context. 30 | 31 | ## [TruthfulQA](https://arxiv.org/abs/2109.07958) 32 | 33 | TruthfulQA tests the models ability to answer questions truthfully across a broad set of categories such as health, law, finance, and politics. 34 | 35 | ## Language Modeling 36 | 37 | ### [PG19](https://github.com/google-deepmind/pg19) 38 | 39 | This task tests the model’s ability to generate longform text (~8K tokens) by providing a title and first initial words of a book. 40 | 41 | ## Summarization 42 | 43 | ## [QMSum](https://arxiv.org/abs/2104.05938) 44 | 45 | A meeting summarization dataset that evaluates the model’s ability to select and summarize content that is relevant to the given query. 46 | 47 | ## [SQuALITY](https://arxiv.org/abs/2205.11465) 48 | 49 | SQuALITY is a question-focused summarization dataset, which tests the models ability to understand long narratives and select and summarize content relevant to the provided question. 50 | 51 | ## [QuALITY](https://arxiv.org/abs/2112.08608v2) 52 | 53 | QuALITY tests the model’s ability to understand and answer questions about long narratives. 54 | -------------------------------------------------------------------------------- /eval_multi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import sys 7 | import argparse 8 | from pathlib import Path 9 | 10 | import torch 11 | import torch._dynamo.config 12 | import torch._inductor.config 13 | 14 | 15 | from cache import add_cache_arguments 16 | from generation_utils import add_generation_arguments 17 | 18 | torch._inductor.config.coordinate_descent_tuning = True 19 | torch._inductor.config.triton.unique_kernel_names = True 20 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 21 | 22 | default_device = "cuda" if torch.cuda.is_available() else "cpu" 23 | 24 | # support running without installing as a package 25 | wd = Path(__file__).parent.parent.resolve() 26 | sys.path.append(str(wd)) 27 | 28 | from eval import add_eval_args, setup, merge_cache_config, main as eval_main 29 | 30 | 31 | HPARAMS = { 32 | "max_cache_length": [[8192], [4096], [2048], [1024], [512], [256], [128]], 33 | "min_recovery_frac": [0.5, 0.6, 0.7, 0.8, 0.9, 0.95], 34 | } 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser( 39 | description="Sweep a hyper-parameter for a KV-Cache Compression Algorithms." 40 | ) 41 | 42 | parser.add_argument( 43 | "--hparam", 44 | default="max_cache_length", 45 | help="The hyper-parameter to sweep.", 46 | ) 47 | 48 | add_eval_args(parser) 49 | add_generation_arguments(parser) 50 | add_cache_arguments(parser) 51 | 52 | args = merge_cache_config(parser.parse_args()) 53 | 54 | assert args.hparam in HPARAMS, f"Set {args.hparam} in HPARAMS dictionary first." 
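    # Sweep (below): each value of the chosen hyper-parameter gets its own copy of the parsed args,
    # its own output directory from setup(), and a separate eval_main() run.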
55 | 56 | for v in HPARAMS[args.hparam]: 57 | # Copy the args object to avoid modifying the original 58 | exp_args = argparse.Namespace(**vars(args)) 59 | print(f"Setting {args.hparam} to {v}") 60 | setattr(exp_args, args.hparam, v) 61 | 62 | out_dir = setup(exp_args) 63 | 64 | eval_main( 65 | args, 66 | args.tasks, 67 | args.debug, 68 | args.checkpoint_path, 69 | args.profile, 70 | args.compile, 71 | args.feed_long_prompts, 72 | args.device, 73 | cache_kwargs=vars(exp_args), 74 | out_dir=out_dir, 75 | ) 76 | -------------------------------------------------------------------------------- /charts/attention_loss.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | from pathlib import Path 5 | 6 | 7 | if __name__ == "__main__": 8 | # Define the data 9 | df = pd.read_csv("/workspace/attention_loss.csv") 10 | 11 | decoding_steps = np.arange(500, 8500, 500) 12 | n = len(decoding_steps) 13 | models = ['Low Compression', 'Medium Compression', 'High Compression'] 14 | 15 | # Sample data - replace with your actual data 16 | attention_loss = { 17 | 'Low Compression': df["25_attention_loss"][:n], 18 | 'Medium Compression': df["50_attention_loss"][:n], 19 | 'High Compression': df["75_attention_loss"][:n], 20 | } 21 | 22 | ppl_delta = { 23 | 'Low Compression': df["25_ppl_delta"][:n], 24 | 'Medium Compression': df["50_ppl_delta"][:n], 25 | 'High Compression': df["75_ppl_delta"][:n], 26 | } 27 | 28 | # Create the plot 29 | plt.rcParams.update({'font.size': 20}) 30 | 31 | fig, ax1 = plt.subplots(figsize=(20, 10)) 32 | 33 | # Colors for each model 34 | colors = ["#006AA7", '#16a085', '#8e44ad', '#d35400'] 35 | 36 | # Plot Attention Loss 37 | for model, color in zip(models, colors): 38 | ax1.plot(decoding_steps, attention_loss[model], color=color, label=f'{model} (Attention Loss)', linewidth=6) 39 | ax1.scatter(decoding_steps, attention_loss[model], color=color, s=400) 40 | 41 | ax1.set_xlabel('Decoding Steps', fontsize=32) 42 | ax1.set_ylabel('Attention Loss', fontsize=32) 43 | 44 | ax1.tick_params(axis='y', labelsize=32) 45 | ax1.tick_params(axis='x', labelsize=32) 46 | 47 | # Create a second y-axis for PPL 48 | ax2 = ax1.twinx() 49 | 50 | # Plot Perplexity (PPL) 51 | for model, color in zip(models, colors): 52 | ax2.plot(decoding_steps, ppl_delta[model], color=color, linestyle='--', label=f'{model} (PPL Δ)', linewidth=6) 53 | ax2.scatter(decoding_steps, ppl_delta[model], color=color, marker='s', s=400) 54 | 55 | ax2.set_ylabel("Perplexity Delta (PPL Δ)", fontsize=32) 56 | ax2.tick_params(axis="y", labelsize=32) 57 | 58 | # Combine legends 59 | lines1, labels1 = ax1.get_legend_handles_labels() 60 | lines2, labels2 = ax2.get_legend_handles_labels() 61 | ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left', bbox_to_anchor=(0.05, 0.95), borderaxespad=0.25, fontsize=24) 62 | 63 | plt.title("Attention Loss & Perplexity vs Decoding Steps", fontsize=32) 64 | plt.grid(True) 65 | plt.tight_layout() 66 | # Save a plot to ../images directory 67 | # Get the current directory 68 | current_dir = Path(__file__).resolve().parent 69 | # Save the plot to the desired path 70 | plt.savefig(current_dir.parent / "images" / "attention_loss_pg19.png") 71 | -------------------------------------------------------------------------------- /charts/blogpost_perf.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | 4 | if __name__ 
== "__main__": 5 | sequence_lengths = [8192, 16384, 32768, 65536] 6 | baseline_tokens = [20.82, 19.13, 11.77, 6.60] 7 | compile_tokens = [69.61, 50.18, 30.03, 17.97] 8 | compress_tokens = [73.26, 71.94, 71.93, 71.81] 9 | 10 | baseline_memory = [1.04, 2.12, 4.24, 8.79] 11 | compile_memory = [1.04, 2.12, 4.24, 8.79] 12 | compress_memory = [0.52, 0.52, 0.52, 0.52] 13 | 14 | baseline_perplexity = [10.69, 9.53, 10.45, 10.52] 15 | compile_perplexity = [10.69, 10.63, 10.45, 10.52] 16 | compress_perplexity = [10.70, 9.69, 10.59, 10.70] 17 | 18 | # Set up the plot 19 | plt.figure(figsize=(20, 8)) 20 | # Custom style with larger fonts, especially for axes 21 | plt.rcParams.update({ 22 | 'font.size': 18, 23 | 'axes.labelsize': 28, 24 | 'axes.titlesize': 30, 25 | 'xtick.labelsize': 24, 26 | 'ytick.labelsize': 24, 27 | 'legend.fontsize': 20, 28 | 'figure.titlesize': 32, 29 | 'axes.grid': True, 30 | 'grid.alpha': 0.3, 31 | 'axes.axisbelow': True, 32 | 'axes.edgecolor': '#888888', 33 | 'axes.linewidth': 1.5, 34 | }) 35 | # plt.style.use('fivethirtyeight') 36 | 37 | # Colors 38 | colors = ['#1f77b4', '#ff7f0e', '#2ca02c'] 39 | 40 | # First subplot: Tokens / Second 41 | plt.subplot(121) 42 | plt.plot(sequence_lengths, baseline_tokens, label='Baseline', color=colors[0], linewidth=5) 43 | plt.plot(sequence_lengths, compile_tokens, label='+ Compile', color=colors[1], linewidth=5) 44 | plt.plot(sequence_lengths, compress_tokens, label='+ Compile + Compression', color=colors[2], linewidth=5) 45 | plt.xlabel('Sequence Length', fontsize=22) 46 | plt.ylabel('Tokens / Second', fontsize=22) 47 | plt.title('Tokens / Second vs Sequence Length', fontsize=26) 48 | plt.legend(fontsize=20) 49 | 50 | # Second subplot: KV Cache Memory GB 51 | plt.subplot(122) 52 | plt.plot(sequence_lengths, baseline_memory, label='Baseline', color=colors[0], linewidth=5) 53 | plt.plot(sequence_lengths, compile_memory, label='+ Compile', color=colors[1], linewidth=5) 54 | plt.plot(sequence_lengths, compress_memory, label='+ Compile + Compression', color=colors[2], linewidth=5) 55 | plt.xlabel('Sequence Length', fontsize=22) 56 | plt.ylabel('KV Cache Memory (GB)', fontsize=22) 57 | plt.title('KV Cache Memory vs Sequence Length', fontsize=26) 58 | plt.legend(fontsize=20) 59 | 60 | # Uncomment to add the third plot and change plt.subplot numbers above to 131 and 132 61 | # Third subplot: Perplexity 62 | # plt.subplot(133) 63 | # plt.plot(sequence_lengths, baseline_perplexity, label='Baseline', color=colors[0], linewidth=3) 64 | # plt.plot(sequence_lengths, compile_perplexity, label='+ Compile', color=colors[1], linewidth=3) 65 | # plt.plot(sequence_lengths, compress_perplexity, label='+ Compile + Compression', color=colors[2], linewidth=3) 66 | # plt.xlabel('Sequence Length', fontsize=22) 67 | # plt.ylabel('Perplexity', fontsize=22) 68 | # plt.title('Perplexity vs Sequence Length', fontsize=26) 69 | # plt.legend(fontsize=20) 70 | 71 | # Adjust layout and save 72 | plt.tight_layout() 73 | plt.savefig('performance_graphs.png', dpi=300, bbox_inches='tight') 74 | -------------------------------------------------------------------------------- /prompts/long_prompt_long_output.txt: -------------------------------------------------------------------------------- 1 | You are an architect tasked with drawing up plans for a modern residential house. 2 | 3 | Architectural Plan Creation Instructions 4 | 5 | Objective: 6 | Create a comprehensive set of architectural plans for a modern residential house. 
The plans should include detailed layouts, elevations, sections, and necessary annotations to guide the construction process. The design should focus on functionality, aesthetics, sustainability, and compliance with local building codes. 7 | 8 | Requirements: 9 | 10 | General Layout: 11 | 12 | Total area: Approximately 2,500 square feet. 13 | Number of floors: Two. 14 | Number of bedrooms: Four (including a master suite). 15 | Number of bathrooms: Three full bathrooms and one half bathroom. 16 | Common areas: Open-plan kitchen, dining area, living room, and a study/office. 17 | Additional spaces: Laundry room, garage (for two cars), storage rooms, and a small basement. 18 | Site Plan: 19 | 20 | Include property boundaries, adjacent streets, and any existing structures. 21 | Show the placement of the house, driveway, pathways, garden, and outdoor living spaces (e.g., patio, deck). 22 | Include landscaping elements like trees, shrubs, and lawn areas. 23 | Floor Plans: 24 | 25 | Ground Floor: Include entryway, living spaces, kitchen, one bedroom (guest room), one full bathroom, and access to the garage. 26 | Second Floor: Include master suite with attached bathroom and walk-in closet, two additional bedrooms, one full bathroom, and a study/office. 27 | Indicate all door and window placements, furniture layouts, and circulation paths. 28 | Elevations: 29 | 30 | Provide front, rear, and side elevations. 31 | Show the external appearance, including the roof design, facade materials, window and door placements, and any architectural features (e.g., balconies, porches). 32 | Sections: 33 | 34 | Include at least two sections (one longitudinal and one cross-sectional) showing internal details. 35 | Highlight the relationship between different floors and ceiling heights. 36 | Show structural elements like beams, columns, and floor slabs. 37 | Roof Plan: 38 | 39 | Indicate the roof slope, materials, drainage system, and any roof features (e.g., skylights, chimneys). 40 | Electrical and Plumbing Plans: 41 | 42 | Show the layout of electrical outlets, switches, lighting fixtures, and major appliances. 43 | Include the plumbing layout for water supply and drainage, showing the location of pipes, fixtures, and connections. 44 | Materials and Finishes: 45 | 46 | Specify the materials for walls, floors, ceilings, and roofs. 47 | Include details on interior and exterior finishes (e.g., paint, tiles, cladding). 48 | Sustainability Features: 49 | 50 | Incorporate energy-efficient systems (e.g., HVAC, solar panels). 51 | Use sustainable building materials. 52 | Plan for natural lighting and ventilation. 53 | Include rainwater harvesting and greywater recycling systems if possible. 54 | Compliance: 55 | 56 | Ensure the design complies with local building codes and regulations. 57 | Include necessary annotations and notes for construction guidelines. 58 | 59 | You must return the following: 60 | - Include a detailed list of materials and specifications. 61 | - Add a cover sheet with project title, address, date, and designer's name. 62 | - Add a sheet for each component with detailed plans. 63 | - Ensure all documents are clearly labeled and organized. -------------------------------------------------------------------------------- /DISCLAIMER.md: -------------------------------------------------------------------------------- 1 | # Cold Compress Disclaimer & Attribution 2 | 3 | This toolkit, "Cold Compress," is provided for research purposes only and is not intended for commercial use. 
It builds upon the open-source project GPT-Fast, which is copyright © 2023 Meta. The original source code from GPT-Fast is redistributed under the terms of its original BSD-style license, which permits use and redistribution with or without modification under the conditions listed in the license. 4 | 5 | The Cold Compress toolkit includes additional code that is the original work of the contributors to this toolkit. All new contributions made specifically for the Cold Compress toolkit are subject to the same BSD-style license as the GPT-Fast project to maintain consistency and compliance with the original work's licensing terms. 6 | 7 | Users of the Cold Compress toolkit should adhere to the following conditions: 8 | - Redistributions of the original source code must retain the copyright notice, this list of conditions, and the following disclaimer as provided by the GPT-Fast project. 9 | - Redistributions in binary form must reproduce the original copyright notice, this list of conditions, and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | - Neither the name of the Cold Compress toolkit nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 11 | 12 | THE COLD COMPRESS TOOLKIT IS PROVIDED "AS IS," WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. IN NO EVENT SHALL THE CONTRIBUTORS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM, OUT OF, OR IN CONNECTION WITH THE TOOLKIT OR THE USE OR OTHER DEALINGS IN THE TOOLKIT. 13 | 14 | By using the Cold Compress toolkit, you acknowledge and agree to the terms of this disclaimer. 15 | 16 | ----------------------------- 17 | Make sure to also include the license below in a file named “LICENSE” in the root directory of your project repo. 18 | 19 | ## BSD 3-Clause License 20 | 21 | ``` 22 | Copyright 2023 Meta 23 | 24 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 25 | 26 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 27 | 28 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 29 | 30 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 | ``` -------------------------------------------------------------------------------- /quantization_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def quantize_tensor(x, n_bit=8, axis=0): 5 | assert n_bit in [2, 4, 8], "Only 2-bit, 4-bit, and 8-bit quantization are supported" 6 | # Move the quantization axis to the first dimension 7 | x = x.transpose(0, axis) 8 | 9 | min_val, max_val = torch.aminmax(x.reshape(x.shape[0], -1), dim=1) 10 | max_int = 2**n_bit - 1 11 | min_int = 0 12 | scales = (max_val - min_val).clamp(min=1e-6) / max_int 13 | zeros = min_val + scales * (2 ** (n_bit - 1)) 14 | 15 | x_int8 = ( 16 | x.sub(min_val.reshape(-1, *([1] * (x.dim() - 1)))) 17 | .div(scales.reshape(-1, *([1] * (x.dim() - 1)))) 18 | .round() 19 | .clamp_(min_int, max_int) 20 | .to(torch.int8) 21 | .reshape_as(x) 22 | ).transpose(0, axis) 23 | 24 | # Pack low-bit tensors into 8-bit dtype 25 | if n_bit < 8: 26 | x_int8 = pack_low_bit_tensor(x_int8, n_bit) 27 | 28 | return x_int8, scales, zeros 29 | 30 | 31 | def dequantize_tensor(x, scales, zeros, orig_shape, n_bit=8, axis=0): 32 | assert n_bit in [2, 4, 8], "Only 2-bit, 4-bit, and 8-bit quantization are supported" 33 | # Unpack low-bit tensor from 8-bit dtype 34 | if n_bit < 8: 35 | x = unpack_low_bit_tensor(x, n_bit, orig_shape) 36 | 37 | # Move the quantization axis to the first dimension 38 | x = x.transpose(0, axis) 39 | 40 | return ( 41 | x.sub(2 ** (n_bit - 1)) 42 | .mul(scales.reshape(-1, *([1] * (x.dim() - 1)))) 43 | .add(zeros.reshape(-1, *([1] * (x.dim() - 1)))) 44 | .reshape_as(x) 45 | .transpose(0, axis) 46 | ) 47 | 48 | 49 | def pack_low_bit_tensor(tensor, n_bit): 50 | assert n_bit in [2, 4], "Only 2-bit and 4-bit packing are supported" 51 | 52 | if n_bit == 4: 53 | assert torch.all(tensor < 16) and torch.all( 54 | tensor >= 0 55 | ), "All values must be in [0, 15] range for 4-bit packing" 56 | else: 57 | # 2-bit packing 58 | assert torch.all(tensor < 4) and torch.all( 59 | tensor >= 0 60 | ), "All values must be in [0, 3] range for 2-bit packing" 61 | 62 | values_per_byte = 8 // n_bit 63 | 64 | # Flatten the tensor 65 | flat_tensor = tensor.flatten() 66 | 67 | # Pad the tensor if necessary 68 | if flat_tensor.numel() % values_per_byte != 0: 69 | padding_size = values_per_byte - (flat_tensor.numel() % values_per_byte) 70 | flat_tensor = torch.cat([flat_tensor, flat_tensor.new_zeros(padding_size)]) 71 | 72 | # Reshape to 2D tensor 73 | reshaped = flat_tensor.reshape(-1, values_per_byte) 74 | 75 | shifts = torch.arange(0, 8, n_bit, device=tensor.device) 76 | packed = (reshaped << shifts).sum(dim=1).byte() 77 | 78 | return packed 79 | 80 | 81 | def unpack_low_bit_tensor(packed_tensor, n_bit, original_shape): 82 | assert n_bit in [2, 4], "Only 2-bit and 4-bit unpacking are supported" 83 | 84 | mask = (1 << n_bit) - 1 85 | 86 | # Calculate the total number of elements in the original tensor 87 | original_numel = torch.prod(torch.tensor(original_shape)) 88 | 89 | shifts = torch.arange(0, 8, n_bit, device=packed_tensor.device) 90 | unpacked = ((packed_tensor.unsqueeze(1) >> shifts) & mask).flatten() 91 | 92 | # Flatten and truncate to the original number of elements 93 | original = unpacked.reshape(-1)[:original_numel] 94 | 95 | # Reshape back to original shape 96 | original = original.reshape(original_shape) 97 | 98 | return original 99 | -------------------------------------------------------------------------------- /prompts/long_prompt_short_output.txt: 
-------------------------------------------------------------------------------- 1 | Carefully read the beginning of the Wikipedia page on the Guggenheim museum. You will be asked to answer a question at the end. 2 | 3 | # Introduction 4 | 5 | The Solomon R. Guggenheim Museum, often referred to as The Guggenheim, is an art museum at 1071 Fifth Avenue between 88th and 89th Streets on the Upper East Side of Manhattan in New York City. It hosts a permanent collection of Impressionist, Post-Impressionist, early Modern, and contemporary art and also features special exhibitions throughout the year. It was established by the Solomon R. Guggenheim Foundation in 1939 as the Museum of Non-Objective Painting, under the guidance of its first director, Hilla von Rebay. The museum adopted its current name in 1952, three years after the death of its founder Solomon R. Guggenheim. It continues to be operated and owned by the Solomon R. Guggenheim Foundation. 6 | The museum's building, a landmark work of 20th-century architecture designed by Frank Lloyd Wright, drew controversy for the unusual shape of its display spaces and took 15 years to design and build; it was completed in 1959. It consists of a six-story, bowl-shaped main gallery to the south, a four-story "monitor" to the north, and a ten-story annex to the northeast. A six-story helical ramp extends along the main gallery's perimeter, under a central ceiling skylight. The Thannhauser Collection is housed within the top three stories of the monitor, and there are additional galleries in the annex and a learning center in the basement. The museum building's design was controversial when it was completed but was widely praised afterward. The building underwent extensive renovations from 1990 to 1992, when the annex was built, and it was renovated again from 2005 to 2008. 7 | The museum's collection has grown over the decades and is founded upon several important private collections, including those of Guggenheim, Karl Nierendorf, Katherine Sophie Dreier, Justin Thannhauser, Rebay, Giuseppe Panza, Robert Mapplethorpe, and the Bohen Foundation. The collection, which includes around 8,000 works as of 2022, is shared with sister museums in Bilbao, Spain, and Venice, Italy. In 2023, nearly 861,000 people visited the museum. 8 | 9 | # History 10 | 11 | ## Early years and Hilla Rebay 12 | Solomon R. Guggenheim, a member of a wealthy mining family, began collecting works of the old masters in the 1890s. In 1926, he met artist Hilla von Rebay, who introduced him to European avant-garde art, in particular abstract art that she felt had a spiritual and utopian aspect (non-objective art). Guggenheim completely changed his collecting strategy, turning to the work of Wassily Kandinsky, among others. He began to display his collection to the public at his apartment in the Plaza Hotel in New York City. Guggenheim and Rebay initially considered building a museum at Rockefeller Center in Manhattan. As the collection grew, Guggenheim established the Solomon R. Guggenheim Foundation, in 1937, to foster the appreciation of modern art. 13 | The foundation's first venue, the Museum of Non-Objective Painting, opened in 1939, under Rebay's direction, at 24 East 54th Street in midtown Manhattan. Under her guidance, Guggenheim sought to include in the collection the most important examples of non-objective art by early modernists. He wanted to display the collection at the 1939 New York World's Fair in Queens, but Rebay advocated for a more permanent location in Manhattan.
By the early 1940s, the foundation had accumulated such a large collection of avant-garde paintings that the need for a permanent museum was apparent, and Rebay wanted to establish it before Guggenheim died. 14 | 15 | ## Design process 16 | In 1943, Rebay and Guggenheim wrote a letter to Frank Lloyd Wright asking him to design a structure to house and display the collection. Rebay thought the 76-year-old Wright was dead, but Guggenheim's wife Irene Rothschild Guggenheim knew better and suggested that Rebay contact him. Wright accepted the opportunity to experiment with his "organic" style in an urban setting, saying that he had never seen a museum that was "properly designed". He was hired to design the building in June 1943. He was to receive a 10 percent commission on the project, which was expected to cost at least $1 million. It took him 15 years, more than 700 sketches, and six sets of working drawings to create and complete the museum, after a series of difficulties and delays; the cost eventually doubled from the initial estimate. 17 | Rebay envisioned a space that would facilitate a new way of seeing modern art. She wrote Wright that "each of these great masterpieces should be organized into space, and only you ... would test the possibilities to do so. ... I want a temple of spirit, a monument!" Critic Paul Goldberger later wrote that Wright's modernist building was a catalyst for change, making it "socially and culturally acceptable for an architect to design a highly expressive, intensely personal museum. In this sense almost every museum of our time is a child of the Guggenheim." The Guggenheim is the only museum Wright designed; its urban location required him to design it in a vertical rather than horizontal form, far different from his earlier, rural works. Since he was not licensed as an architect in New York, he relied on Arthur Cort Holden, of the architectural firm Holden, McLaughlin & Associates, to deal with New York City's Board of Standards and Appeals. 18 | From 1943 to early 1944, Wright produced four differing designs. One had a hexagonal shape and level floors for the galleries, though all the others had circular schemes and used a ramp continuing around the building. In his notes, he indicated that he wanted a "well proportioned floor space from bottom to top—a wheel chair going around and up and down". His original concept was called an inverted "ziggurat", because it resembled the steep steps on the ziggurats built in ancient Mesopotamia. Several architecture professors have speculated that the helical ramp and glass dome of Giuseppe Momo's 1932 staircase at the Vatican Museums was an inspiration for Wright's ramp and atrium. 19 | 20 | Question: Which is the largest number? 21 | A) Frank Lloyd Wright's age in 1943. 22 | B) The size of the collection at the Guggenheim. 23 | C) The building number of the museum's first venue. 24 | D) The number of sketches it took Frank Lloyd Wright to create the museum. 
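The prompt files in this directory are loaded by filename rather than passed inline: generate.py reads any --prompt argument ending in .txt from the ./prompts folder. A minimal sketch of that loading step, assuming the repository layout in this dump (the cache config named in the comment is only an illustration):

from pathlib import Path

# Read a stored prompt the same way generate.py does when --prompt ends in .txt
prompt_file = Path("prompts") / "long_prompt_short_output.txt"
prompt = prompt_file.read_text().strip()

# The resulting string is then handed to the generation entry point, e.g.
#   python generate.py --prompt long_prompt_short_output.txt --cache_config recent_global.yaml
print(prompt[:80])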
-------------------------------------------------------------------------------- /experiments/multi_strategy.txt: -------------------------------------------------------------------------------- 1 | python eval.py --compile --tasks squality --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global full --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 0.14 0.14 0.14 0.14 0.14 0.14 0.14 1.0 --cache_length_pattern tile --cache_strategy_pattern tile 2 | python eval.py --compile --tasks squality --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global full --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 0.14 0.14 0.14 0.14 0.14 0.14 0.14 1.0 --cache_length_pattern repeat --cache_strategy_pattern repeat 3 | python eval.py --compile --tasks squality --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy full recent_global recent_global recent_global recent_global recent_global recent_global recent_global --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 1.0 0.14 0.14 0.14 0.14 0.14 0.14 0.14 --cache_length_pattern tile --cache_strategy_pattern tile 4 | python eval.py --compile --tasks squality --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy full recent_global recent_global recent_global recent_global recent_global recent_global recent_global --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 1.0 0.14 0.14 0.14 0.14 0.14 0.14 0.14 --cache_length_pattern repeat --cache_strategy_pattern repeat 5 | python eval.py --compile --tasks musique --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global full --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 0.14 0.14 0.14 0.14 0.14 0.14 0.14 1.0 --cache_length_pattern tile --cache_strategy_pattern tile 6 | python eval.py --compile --tasks musique --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global full --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 0.14 0.14 0.14 0.14 0.14 0.14 0.14 1.0 --cache_length_pattern repeat --cache_strategy_pattern repeat 7 | python eval.py --compile --tasks musique --global_tokens 4 --checkpoint_path 
./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy full recent_global recent_global recent_global recent_global recent_global recent_global recent_global --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 1.0 0.14 0.14 0.14 0.14 0.14 0.14 0.14 --cache_length_pattern tile --cache_strategy_pattern tile 8 | python eval.py --compile --tasks musique --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy full recent_global recent_global recent_global recent_global recent_global recent_global recent_global --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 1.0 0.14 0.14 0.14 0.14 0.14 0.14 0.14 --cache_length_pattern repeat --cache_strategy_pattern repeat 9 | python eval.py --compile --tasks rulerniah --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global full --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 0.14 0.14 0.14 0.14 0.14 0.14 0.14 1.0 --cache_length_pattern tile --cache_strategy_pattern tile 10 | python eval.py --compile --tasks rulerniah --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global full --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 0.14 0.14 0.14 0.14 0.14 0.14 0.14 1.0 --cache_length_pattern repeat --cache_strategy_pattern repeat 11 | python eval.py --compile --tasks rulerniah --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy full recent_global recent_global recent_global recent_global recent_global recent_global recent_global --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 1.0 0.14 0.14 0.14 0.14 0.14 0.14 0.14 --cache_length_pattern tile --cache_strategy_pattern tile 12 | python eval.py --compile --tasks rulerniah --global_tokens 4 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --num_samples 500 --cache_strategy full recent_global recent_global recent_global recent_global recent_global recent_global recent_global --prompt_compression_strategy recent_global recent_global recent_global recent_global recent_global recent_global recent_global recent_global --max_cache_length 1.0 0.14 0.14 0.14 0.14 0.14 0.14 0.14 --cache_length_pattern repeat --cache_strategy_pattern repeat 13 | -------------------------------------------------------------------------------- /tp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
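# Rough usage sketch, under assumptions: this module is driven indirectly by the
# generation scripts rather than run on its own. When launched via torchrun,
# LOCAL_RANK/LOCAL_WORLD_SIZE are set, maybe_init_dist() joins the NCCL process group,
# and apply_tp(model) shards each block's attention and feed-forward weights
# column-/row-wise across ranks, e.g.:
#
#   torchrun --standalone --nproc_per_node=2 generate.py --compile ...
#
# The launcher flags above are an assumption; only maybe_init_dist() and apply_tp()
# defined in this file are guaranteed by the code below.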
6 | import os 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch import nn 12 | 13 | if os.uname().sysname != "Darwin": 14 | from torch.distributed import _functional_collectives as funcol 15 | else: 16 | # Distributed is not supported on MacOS 17 | funcol = None 18 | 19 | from model import Attention, FeedForward, Transformer 20 | from quantize import WeightOnlyInt4Linear 21 | 22 | 23 | def _get_rank() -> int: 24 | return int(os.environ.get("LOCAL_RANK", "0")) 25 | 26 | 27 | def is_local(): 28 | return _get_rank() == 0 29 | 30 | 31 | def local_break(): 32 | if is_local(): 33 | breakpoint() 34 | dist.barrier() 35 | 36 | 37 | def _get_world_size() -> int: 38 | return int(os.environ.get("LOCAL_WORLD_SIZE", "1")) 39 | 40 | 41 | def maybe_init_dist() -> Optional[int]: 42 | try: 43 | # provided by torchrun 44 | rank = _get_rank() 45 | world_size = _get_world_size() 46 | 47 | if world_size < 2: 48 | # too few gpus to parallelize, tp is no-op 49 | return None 50 | except KeyError: 51 | # not run via torchrun, no-op 52 | return None 53 | 54 | torch.cuda.set_device(rank) 55 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 56 | return rank 57 | 58 | 59 | def _apply_tp_linear( 60 | linear: nn.Linear, style: str, weight_splits: List[int] = [] 61 | ) -> None: 62 | rank = _get_rank() 63 | world_size = _get_world_size() 64 | 65 | # Linear's weight matrix is transposed, and is of shape 66 | # (linear.out_features, linear.in_features) 67 | dim_lookup = {"colwise": (0, "out_features"), "rowwise": (1, "in_features")} 68 | assert style in dim_lookup 69 | shard_dim, size_attr = dim_lookup[style] 70 | 71 | # ensure we can shard evenly 72 | assert getattr(linear, size_attr) % world_size == 0 73 | 74 | def shard(x, dim): 75 | assert x.size(dim=dim) % world_size == 0 76 | return torch.tensor_split(x, world_size, dim=dim)[rank] 77 | 78 | def shard_qkv(qkv, dim, weight_splits): 79 | q, k, v = qkv.split(weight_splits, dim=dim) 80 | q = shard(q, dim) 81 | k = shard(k, dim) 82 | v = shard(v, dim) 83 | return torch.cat((q, k, v), dim=dim) 84 | 85 | # shard 86 | if weight_splits: 87 | # attention 88 | assert len(weight_splits) == 3 89 | 90 | if isinstance(linear, WeightOnlyInt4Linear): 91 | sharded_weight = shard_qkv( 92 | linear.weight, shard_dim, [i // 8 for i in weight_splits] 93 | ) 94 | linear.scales_and_zeros = shard_qkv( 95 | linear.scales_and_zeros, 1 - shard_dim, weight_splits 96 | ) 97 | else: 98 | sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits) 99 | if hasattr(linear, "scales") and style == "colwise": 100 | linear.scales = shard_qkv(linear.scales, 0, weight_splits) 101 | else: 102 | sharded_weight = shard(linear.weight, shard_dim) 103 | if isinstance(linear, WeightOnlyInt4Linear): 104 | linear.scales_and_zeros = shard(linear.scales_and_zeros, 1 - shard_dim) 105 | if style == "rowwise": 106 | assert ( 107 | linear.scales_and_zeros.shape[0] * 32 108 | == sharded_weight.shape[1] 109 | * sharded_weight.shape[2] 110 | * sharded_weight.shape[3] 111 | ) 112 | assert linear.scales_and_zeros.shape[1] == sharded_weight.shape[0] * 8 113 | if hasattr(linear, "scales") and style == "colwise": 114 | linear.scales = shard(linear.scales, 0) 115 | 116 | # local_break() 117 | linear.weight = nn.Parameter(sharded_weight, requires_grad=False) 118 | setattr(linear, size_attr, getattr(linear, size_attr) // world_size) 119 | 120 | # shape info should still be synced 121 | # assert linear.weight.shape == (linear.out_features, 
linear.in_features) 122 | 123 | 124 | def _apply_tp_ffn(mlp: FeedForward) -> None: 125 | assert hasattr(mlp, "w1") 126 | assert hasattr(mlp, "w3") 127 | assert hasattr(mlp, "w2") 128 | 129 | _apply_tp_linear(mlp.w1, "colwise") 130 | _apply_tp_linear(mlp.w3, "colwise") 131 | _apply_tp_linear(mlp.w2, "rowwise") 132 | 133 | world_size = _get_world_size() 134 | mlp.register_forward_hook( 135 | lambda _module, _input, output: funcol.all_reduce( 136 | output, "sum", list(range(world_size)) 137 | ) 138 | ) 139 | 140 | 141 | def _apply_tp_attn(attn: Attention) -> None: 142 | assert hasattr(attn, "wqkv") 143 | assert hasattr(attn, "wo") 144 | 145 | kv_size = attn.n_local_heads * attn.head_dim 146 | _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_size, kv_size]) 147 | _apply_tp_linear(attn.wo, "rowwise") 148 | 149 | # overwrite 150 | world_size = _get_world_size() 151 | attn.n_head = attn.n_head // world_size 152 | attn.dim = attn.dim // world_size 153 | attn.head_dim = attn.dim // attn.n_head 154 | attn.n_local_heads = attn.n_local_heads // world_size 155 | 156 | attn.register_forward_hook( 157 | lambda _module, _input, output: funcol.all_reduce( 158 | output[0], "sum", list(range(world_size)) 159 | ) 160 | ) 161 | 162 | 163 | def _apply_tp_Transformer(Transformer: Transformer) -> None: 164 | # overwrite config before Transformer.setup_cache is called 165 | world_size = _get_world_size() 166 | Transformer.config.n_head = Transformer.config.n_head // world_size 167 | Transformer.config.dim = Transformer.config.dim // world_size 168 | Transformer.config.n_local_heads = Transformer.config.n_local_heads // world_size 169 | 170 | 171 | def apply_tp(model: Transformer) -> None: 172 | _apply_tp_Transformer(model) 173 | for block in model.layers: 174 | # Apply to MLP 175 | _apply_tp_ffn(block.feed_forward) 176 | _apply_tp_attn(block.attention) 177 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
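# Usage sketch for a typical single-prompt run, assuming the checkpoint has been
# prepared by the scripts/ helpers and that add_generation_arguments/add_cache_arguments
# register the remaining flags referenced here:
#
#   python generate.py --compile \
#       --checkpoint_path checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth \
#       --prompt long_prompt_short_output.txt \
#       --cache_config recent_global.yaml \
#       --max_new_tokens 512
#
# --prompt, --max_new_tokens, and --cache_config are defined in this file; prompts
# ending in .txt are read from ./prompts before generation.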
6 | import sys 7 | import time 8 | import contextlib 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import torch 13 | import torch._dynamo.config 14 | import torch._inductor.config 15 | import logging 16 | 17 | from cache import add_cache_arguments 18 | from generation_utils import ( 19 | add_generation_arguments, 20 | compile_funcs, 21 | compute_max_seq_length, 22 | device_sync, 23 | print_stats, 24 | ) 25 | 26 | torch._inductor.config.coordinate_descent_tuning = True 27 | torch._inductor.config.triton.unique_kernel_names = True 28 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 29 | DEBUG_COMPILE = False 30 | if DEBUG_COMPILE: 31 | import logging 32 | 33 | level = logging.DEBUG 34 | torch._logging.set_logs(dynamo=level, inductor=level) 35 | torch._dynamo.config.verbose = True 36 | 37 | default_device = "cuda" if torch.cuda.is_available() else "cpu" 38 | 39 | # support running without installing as a package 40 | wd = Path(__file__).parent.parent.resolve() 41 | sys.path.append(str(wd)) 42 | 43 | from tokenizer import get_tokenizer, encode 44 | from generation_utils import ( 45 | generate, 46 | get_model_size, 47 | load_model, 48 | merge_cache_config, 49 | setup_caches, 50 | ) 51 | from cache import add_cache_arguments, cache_compatibility 52 | 53 | 54 | def main( 55 | prompt: str = "Hello, my name is", 56 | max_new_tokens: int = 100, 57 | checkpoint_path: Path = Path( 58 | "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" 59 | ), 60 | compile: bool = True, 61 | feed_long_prompts: bool = False, 62 | attn_top_k: float = 1.0, 63 | profile: Optional[Path] = None, 64 | device=default_device, 65 | cache_kwargs: dict = {}, 66 | ) -> None: 67 | """Generates text samples based on a pre-trained Transformer model and tokenizer.""" 68 | assert checkpoint_path.is_file(), checkpoint_path 69 | 70 | # pytorch_logs_to_file() 71 | 72 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 73 | if not tokenizer_path.is_file(): 74 | # If there's no tokenizer.model, try to load the tokenizer from the parent directory 75 | # NOTE: We assume the tokenizer in the parent directory is compatible with huggingface transformers 76 | tokenizer_path = checkpoint_path.parent 77 | 78 | global print 79 | from tp import maybe_init_dist 80 | 81 | rank = maybe_init_dist() 82 | use_tp = rank is not None 83 | if use_tp: 84 | if rank != 0: 85 | # only print on rank 0 86 | print = lambda *args, **kwargs: None 87 | 88 | print(f"Using device={device}") 89 | precision = torch.bfloat16 90 | is_chat = ( 91 | "chat" in str(checkpoint_path).lower() 92 | or "instruct" in str(checkpoint_path).lower() 93 | ) 94 | 95 | print("Loading model ...") 96 | t0 = time.time() 97 | model = load_model(checkpoint_path, device, precision, use_tp) 98 | 99 | device_sync(device=device) # MKG 100 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 101 | 102 | tokenizer = get_tokenizer(tokenizer_path, checkpoint_path, is_chat=is_chat) 103 | 104 | inputs = [encode(tokenizer, prompt, device=device, is_chat=is_chat)] 105 | 106 | terminator_ids = tokenizer.get_terminator_ids() 107 | 108 | torch.manual_seed(1234) 109 | model_size = get_model_size(model) 110 | print(f"{model_size / 1e9:.02f} billion parameters in model.") 111 | 112 | prefill, decode_one_token = compile_funcs(compile) 113 | 114 | device_sync(device=device) # MKG 115 | 116 | max_prompt_length, max_seq_length = compute_max_seq_length( 117 | model, inputs, 
None, max_new_tokens 118 | ) 119 | max_new_tokens = min(max_new_tokens, max_seq_length - max_prompt_length) 120 | setup_caches(model, tokenizer, inputs[0].device, max_seq_length, cache_kwargs) 121 | 122 | y, _, perf_stats = generate( 123 | model, 124 | inputs[0], 125 | prefill, 126 | decode_one_token, 127 | max_new_tokens=max_new_tokens, 128 | terminator_ids=terminator_ids, 129 | attn_top_k=attn_top_k, 130 | feed_long_prompts=feed_long_prompts, 131 | ) 132 | 133 | device_sync(device=device) # MKG 134 | print("\n==========\n") 135 | print("GENERATION:") 136 | print(tokenizer.decode(y.tolist())) 137 | print("\n==========\n") 138 | print("PERFORMANCE:") 139 | tokens_per_second = perf_stats["total_toks_per_sec"] 140 | decode_tokens = perf_stats["decode_tokens"] 141 | total_seconds = perf_stats["total_seconds"] 142 | memory_used_gb = perf_stats["memory_used_gb"] 143 | 144 | print( 145 | f"Time: {total_seconds:.02f} sec total, {tokens_per_second:.02f} tokens/sec, {decode_tokens} tokens" 146 | ) 147 | print(f"Bandwidth: {model_size * tokens_per_second / 1e9:.02f} GB/s") 148 | print(f"Memory used: {memory_used_gb} GB") 149 | print("\n==========\n") 150 | print("DETAILED PERFORMANCE:") 151 | print_stats(perf_stats) 152 | 153 | print("\n==========\n") 154 | print("KV CACHE STATISTICS:") 155 | cache_stats = model.get_cache_stats(max_prompt_length, decode_tokens) 156 | print_stats(cache_stats) 157 | 158 | 159 | if __name__ == "__main__": 160 | import argparse 161 | 162 | parser = argparse.ArgumentParser( 163 | description="Run Simple Single Prompt Generation (for development and debugging purposes)." 164 | ) 165 | parser.add_argument( 166 | "--prompt", 167 | type=str, 168 | default="long_prompt_short_output.txt", 169 | help="Input prompt. If it ends in .txt, we will load the prompt from the ./prompts dir.", 170 | ) 171 | parser.add_argument( 172 | "--max_new_tokens", type=int, default=512, help="Maximum number of new tokens." 173 | ) 174 | 175 | parser.add_argument( 176 | "--cache_config", 177 | type=str, 178 | default=None, 179 | help="Name of YAML file in ./cache_configs.", 180 | ) 181 | 182 | add_generation_arguments(parser) 183 | add_cache_arguments(parser) 184 | 185 | args = merge_cache_config(parser.parse_args()) 186 | 187 | if args.prompt.endswith(".txt"): 188 | prompt_fn = Path(__file__).resolve().parent / "prompts" / args.prompt 189 | with open(prompt_fn) as fd: 190 | args.prompt = fd.read().strip() 191 | 192 | cache_compatibility(args) 193 | 194 | main( 195 | args.prompt, 196 | args.max_new_tokens, 197 | args.checkpoint_path, 198 | args.compile, 199 | args.feed_long_prompts, 200 | args.attn_top_k, 201 | args.profile, 202 | args.device, 203 | cache_kwargs=vars(args), 204 | ) 205 | -------------------------------------------------------------------------------- /scripts/convert_hf_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
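# Usage sketch: convert a downloaded HuggingFace checkpoint into the single model.pth
# file expected by this repo. The directory below is an example; any folder produced
# by scripts/download.py should work:
#
#   python scripts/convert_hf_checkpoint.py \
#       --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B-Instruct
#
# For Llama 3, original/consolidated.00.pth is re-saved as model.pth and the tiktoken
# tokenizer.model is copied next to it; other models are remapped via weight_map below.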
6 | import json 7 | import os 8 | import re 9 | import shutil 10 | import sys 11 | from pathlib import Path 12 | from typing import Optional 13 | 14 | import torch 15 | from safetensors.torch import load_file 16 | 17 | # support running without installing as a package 18 | wd = Path(__file__).parent.parent.resolve() 19 | sys.path.append(str(wd)) 20 | 21 | from model import ModelArgs 22 | 23 | 24 | @torch.inference_mode() 25 | def convert_hf_checkpoint( 26 | *, 27 | checkpoint_dir: Path = Path( 28 | "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf" 29 | ), 30 | model_name: Optional[str] = None, 31 | ) -> None: 32 | out_model_path = checkpoint_dir / "model.pth" 33 | if os.path.exists(out_model_path): 34 | print(f"Model already exists at {out_model_path}") 35 | return 36 | 37 | if model_name is None: 38 | model_name = checkpoint_dir.name 39 | 40 | # Llama 3 8B doesn't need conversion; instead, the original/consolidated.NN.pth files 41 | # need to be copied into model.pth. 42 | # Llama 3 70B can't be easily merged into one model.pth file, though, since the names of the 43 | # weights in the state dict are the same in each consolidated.NN.pth file. Thus, it is not 44 | # currently supported. 45 | # Along with this, we need to copy the original/tokenizer.model file to tokenizer.model.tiktoken 46 | is_llama3 = "Llama-3" in model_name 47 | if is_llama3: 48 | # Check if we have multiple original/consolidated.NN.pth files and report error 49 | # if we do for Llama 3. 50 | original_dir = checkpoint_dir / "original" 51 | pattern = re.compile(r"^consolidated\.\d{2}\.pth$") 52 | bin_files = [bin for bin in original_dir.iterdir() if pattern.match(bin.name)] 53 | if len(bin_files) > 1: 54 | raise ValueError( 55 | f"Multiple consolidated.NN.pth files found in {original_dir}. " 56 | "Merging them into one model.pth file is not supported for Llama 3." 57 | ) 58 | 59 | config = ModelArgs.from_name(model_name) 60 | print(f"Model config {config.__dict__}") 61 | 62 | # Load the json file containing weight mapping 63 | if not is_llama3: 64 | # Check for index file 65 | index_files = list(checkpoint_dir.glob("*.index.json")) 66 | assert len(index_files) <= 1, "There should be at most one index file." 67 | 68 | if len(index_files) == 1: 69 | # For larger models, the weights are stored in separate files, so we need to load the index. 70 | with open(index_files[0]) as json_map: 71 | bin_index = json.load(json_map) 72 | bin_files = { 73 | checkpoint_dir / bin for bin in bin_index["weight_map"].values() 74 | } 75 | else: 76 | # For smaller models, the weights are stored in a single file. 77 | # Note it could be a bin file or a safetensors file.
78 | if (checkpoint_dir / "pytorch_model.bin").exists(): 79 | bin_files = {checkpoint_dir / "pytorch_model.bin"} 80 | else: 81 | bin_files = {checkpoint_dir / "model.safetensors"} 82 | weight_map = { 83 | "model.embed_tokens.weight": "tok_embeddings.weight", 84 | "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", 85 | "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", 86 | "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", 87 | "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", 88 | "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias", 89 | "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias", 90 | "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias", 91 | "model.layers.{}.self_attn.rotary_emb.inv_freq": None, 92 | "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", 93 | "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", 94 | "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", 95 | "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", 96 | "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", 97 | "model.norm.weight": "norm.weight", 98 | "lm_head.weight": "output.weight", 99 | } 100 | else: 101 | # There is no separate pytorch_model.bin.index.json file for llama3. 102 | # Instead, we will just use all original/consolidated.NN.pth files. 103 | # so, we use model.safetensors.index.json 104 | weight_map = None 105 | original_dir = checkpoint_dir / "original" 106 | pattern = re.compile(r"^consolidated\.\d{2}\.pth$") 107 | bin_files = {bin for bin in original_dir.iterdir() if pattern.match(bin.name)} 108 | 109 | def permute(w, n_head, dim=config.dim): 110 | return ( 111 | w.view(n_head, 2, config.head_dim // 2, dim) 112 | .transpose(1, 2) 113 | .reshape(config.head_dim * n_head, dim) 114 | ) 115 | 116 | merged_result = {} 117 | for file in sorted(bin_files): 118 | if str(file).endswith(".safetensors"): 119 | state_dict = load_file(str(file)) 120 | else: 121 | state_dict = torch.load( 122 | str(file), map_location="cpu", mmap=True, weights_only=True 123 | ) 124 | merged_result.update(state_dict) 125 | final_result = {} 126 | if weight_map is not None: 127 | for key, value in merged_result.items(): 128 | if "layers" in key: 129 | abstract_key = re.sub(r"(\d+)", "{}", key) 130 | layer_num = re.search(r"\d+", key).group(0) 131 | new_key = weight_map[abstract_key] 132 | if new_key is None: 133 | continue 134 | new_key = new_key.format(layer_num) 135 | else: 136 | new_key = weight_map[key] 137 | 138 | final_result[new_key] = value 139 | 140 | for key in tuple(final_result.keys()): 141 | if "wq" in key: 142 | q = final_result[key] 143 | k = final_result[key.replace("wq", "wk")] 144 | v = final_result[key.replace("wq", "wv")] 145 | if key.endswith("weight"): 146 | q = permute(q, config.n_head) 147 | k = permute(k, config.n_local_heads) 148 | else: 149 | # Permute bias to be compatible with the weight permutation 150 | q = permute(q, config.n_head, dim=1).view(-1) 151 | k = permute(k, config.n_local_heads, dim=1).view(-1) 152 | final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v]) 153 | del final_result[key] 154 | del final_result[key.replace("wq", "wk")] 155 | del final_result[key.replace("wq", "wv")] 156 | if "output.weight" not in final_result: 157 | # lm_head.weight may not be explicitly stored in the HF checkpoint if 
input and output embeddings are shared 158 | final_result["output.weight"] = final_result[ 159 | "tok_embeddings.weight" 160 | ].clone() 161 | else: 162 | final_result = merged_result 163 | if is_llama3: 164 | original_dir = checkpoint_dir / "original" 165 | tokenizer_model = original_dir / "tokenizer.model" 166 | tokenizer_model_tiktoken = checkpoint_dir / "tokenizer.model" 167 | print(f"Copying {tokenizer_model} to {tokenizer_model_tiktoken}") 168 | shutil.copy(tokenizer_model, tokenizer_model_tiktoken) 169 | print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") 170 | torch.save(final_result, out_model_path) 171 | 172 | 173 | if __name__ == "__main__": 174 | import argparse 175 | 176 | parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.") 177 | parser.add_argument( 178 | "--checkpoint_dir", 179 | type=Path, 180 | default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"), 181 | ) 182 | parser.add_argument("--model_name", type=str, default=None) 183 | 184 | args = parser.parse_args() 185 | convert_hf_checkpoint( 186 | checkpoint_dir=args.checkpoint_dir, 187 | model_name=args.model_name, 188 | ) 189 | 190 | # Remove unused files 191 | # shutil.rmtree(args.checkpoint_dir / "original", ignore_errors=True) 192 | 193 | # remove any files in args.checkpoint_dir not named model.pth or tokenizer.model 194 | # for file in args.checkpoint_dir.iterdir(): 195 | # if file.is_file() and file.name not in ["model.pth", "tokenizer.model"]: 196 | # os.remove(file) 197 | # else: 198 | # shutil.rmtree(file, ignore_errors=True) 199 | -------------------------------------------------------------------------------- /parallelize_evals.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import queue 3 | import threading 4 | import time 5 | import os 6 | import sys 7 | import json 8 | import argparse 9 | import itertools 10 | from datetime import datetime 11 | from task import TASK_MAPPING 12 | from pathlib import Path 13 | 14 | 15 | class GPUJobQueue: 16 | def __init__(self, num_gpus=8, log_dir="job_logs"): 17 | self.num_gpus = num_gpus 18 | self.job_queue = queue.Queue() 19 | self.gpu_locks = [threading.Lock() for _ in range(num_gpus)] 20 | self.running_processes = [None] * num_gpus 21 | self.log_dir = log_dir 22 | os.makedirs(self.log_dir, exist_ok=True) 23 | self.queue_file = os.path.join(self.log_dir, "queued_commands.json") 24 | self.completed_file = os.path.join(self.log_dir, "completed_commands.json") 25 | self.log_files = [ 26 | os.path.join(self.log_dir, f"gpu{i}.log") for i in range(num_gpus) 27 | ] 28 | self.queue_lock = threading.Lock() 29 | 30 | # Initialize completed jobs with empty list 31 | with open(self.completed_file, "w") as f: 32 | json.dump([], f, indent=4) 33 | 34 | def _save_queue(self): 35 | with self.queue_lock: 36 | try: 37 | with open(self.queue_file, "w") as f: 38 | json.dump(list(self.job_queue.queue), f, indent=4) 39 | except Exception as e: 40 | print(f"Error saving queue to {self.queue_file}: {str(e)}") 41 | 42 | def _save_completed(self, command): 43 | with self.queue_lock: 44 | try: 45 | with open(self.completed_file, "r+") as f: 46 | completed = json.load(f) 47 | completed.append(command) 48 | f.seek(0) 49 | json.dump(completed, f, indent=4) 50 | f.truncate() 51 | except Exception as e: 52 | print(f"Error updating {self.completed_file}: {str(e)}") 53 | 54 | def add_job(self, bash_command): 55 | self.job_queue.put(bash_command) 56 | self._save_queue() 57 | 58 | def run_job(self, gpu_id,
bash_command): 59 | env = os.environ.copy() 60 | env["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 61 | 62 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 63 | log_file = self.log_files[gpu_id] 64 | 65 | try: 66 | with open(log_file, "a") as log: 67 | log.write(f"Running command: {bash_command}\n") 68 | log.write(f"GPU: {gpu_id}\n") 69 | log.write( 70 | f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n" 71 | ) 72 | log.write("-" * 50 + "\n") 73 | log.flush() 74 | 75 | process = subprocess.Popen( 76 | bash_command, 77 | shell=True, 78 | env=env, 79 | stdout=log, 80 | stderr=subprocess.STDOUT, 81 | universal_newlines=True, 82 | ) 83 | self.running_processes[gpu_id] = process 84 | 85 | process.wait() 86 | 87 | log.write("\n" + "-" * 50 + "\n") 88 | log.write(f"End time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") 89 | log.write(f"Exit code: {process.returncode}\n") 90 | 91 | self._save_completed(bash_command) 92 | except Exception as e: 93 | print(f"Error running job on GPU {gpu_id}: {str(e)}") 94 | finally: 95 | self.gpu_locks[gpu_id].release() 96 | self.running_processes[gpu_id] = None 97 | 98 | def process_queue(self): 99 | while True: 100 | if self.job_queue.empty() and all( 101 | proc is None for proc in self.running_processes 102 | ): 103 | break 104 | 105 | for gpu_id in range(self.num_gpus): 106 | if self.running_processes[gpu_id] is None and self.gpu_locks[ 107 | gpu_id 108 | ].acquire(blocking=False): 109 | if not self.job_queue.empty(): 110 | bash_command = self.job_queue.get() 111 | threading.Thread( 112 | target=self.run_job, args=(gpu_id, bash_command) 113 | ).start() 114 | self._save_queue() 115 | else: 116 | self.gpu_locks[gpu_id].release() 117 | 118 | time.sleep(1) # Small delay to prevent busy-waiting 119 | 120 | def terminate_all_jobs(self): 121 | print("Terminating all running jobs...") 122 | for gpu_id, process in enumerate(self.running_processes): 123 | if process is not None: 124 | process.terminate() 125 | with open(self.log_files[gpu_id], "a") as log: 126 | log.write( 127 | f"\nJob terminated at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n" 128 | ) 129 | log.write("=" * 50 + "\n\n") 130 | print("All jobs terminated.") 131 | 132 | 133 | if __name__ == "__main__": 134 | parser = argparse.ArgumentParser(description="Run eval jobs for given yaml configs") 135 | parser.add_argument( 136 | "--command_file", 137 | type=str, 138 | help="text file consisting of commands (1 per line) to be run", 139 | ) 140 | parser.add_argument( 141 | "--config_names", 142 | nargs="+", 143 | help="YAML configuration files that need to be evaluated", 144 | required="--command_file" not in sys.argv, 145 | ) 146 | parser.add_argument( 147 | "--tasks", 148 | type=str, 149 | nargs="+", 150 | required="--command_file" not in sys.argv, 151 | choices=list(TASK_MAPPING.keys()), 152 | help="List of tasks to be evaluated.", 153 | ) 154 | parser.add_argument( 155 | "--cache_sizes", 156 | type=float, 157 | nargs="+", 158 | default=[8192, 4096, 2048, 1024, 512, 256, 128], 159 | help="Cache sizes to be evaluated.", 160 | ) 161 | parser.add_argument( 162 | "--num_samples", 163 | type=int, 164 | default=-1, 165 | help="Number of examples to sample for evaluation. 
Defaults to -1, which uses the full dataset.", 166 | ) 167 | parser.add_argument( 168 | "--add_full", 169 | default=False, 170 | action="store_true", 171 | help="Run the full attention model in addition to the compressed models.", 172 | ) 173 | parser.add_argument( 174 | "--checkpoint_path", 175 | type=Path, 176 | default=Path(__file__).resolve().parent 177 | / "checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth", 178 | help="Model checkpoint path.", 179 | ) 180 | parser.add_argument( 181 | "--num_gpus", type=int, default=8, help="Number of GPUs available" 182 | ) 183 | parser.add_argument( 184 | "--log_dir", default="eval_job_logs", help="Directory for job logs" 185 | ) 186 | args = parser.parse_args() 187 | 188 | gpu_queue = GPUJobQueue(num_gpus=args.num_gpus, log_dir=args.log_dir) 189 | 190 | if args.command_file: 191 | with open(args.command_file) as fin: 192 | lines = [line.strip() for line in fin] 193 | for line in lines: 194 | if line: 195 | gpu_queue.add_job(line) 196 | 197 | else: 198 | configs = [] 199 | for config in args.config_names: 200 | if not config.endswith(".yaml"): 201 | config = config + ".yaml" 202 | assert os.path.exists(os.path.join( 203 | os.path.dirname(os.path.abspath(__file__)), "cache_configs", config 204 | )), f"{config} not found in cache_configs" 205 | configs.append(config) 206 | 207 | base_command = "python eval.py --task {task} --checkpoint {chkpt} --cache_config {config} --num_samples {ns} --compile --max_cache_length {cs}" 208 | 209 | # Create tasks and add them to the task queue. 210 | tasks = list(itertools.product(args.tasks, args.cache_sizes, configs)) 211 | for task, cs, config in itertools.product( 212 | args.tasks, args.cache_sizes, configs 213 | ): 214 | gpu_queue.add_job( 215 | base_command.format( 216 | task=task, 217 | chkpt=args.checkpoint_path, 218 | config=config, 219 | ns=args.num_samples, 220 | cs=cs, 221 | ) 222 | ) 223 | 224 | if args.add_full: 225 | for task in args.tasks: 226 | gpu_queue.add_job( 227 | base_command.format( 228 | task=task, 229 | chkpt=args.checkpoint_path, 230 | config="full.yaml", 231 | ns=args.num_samples, 232 | cs=1.0, 233 | ) 234 | ) 235 | 236 | print(f"Adding {gpu_queue.job_queue.qsize()} tasks into the job queue") 237 | 238 | try: 239 | gpu_queue.process_queue() 240 | except KeyboardInterrupt: 241 | print("\nKeyboardInterrupt received, terminating all jobs...") 242 | gpu_queue.terminate_all_jobs() 243 | print("Exiting.") 244 | 245 | print("All jobs completed or terminated") 246 | -------------------------------------------------------------------------------- /prompt_compression.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class PromptCompressor(ABC): 6 | def __init__(self, head_specific, **kwargs) -> None: 7 | # Assign each kwarg as an attribute of the class 8 | for key, value in kwargs.items(): 9 | setattr(self, key, value) 10 | 11 | self.head_specific = head_specific 12 | assert self.is_compatible(), f"Prompt compressor ({self.__class__.__name__}) is not compatible with the chosen cache strategy."
13 | 14 | def _recent_global_mask(self, input_pos): 15 | seq_len = input_pos.shape[-1] 16 | return torch.logical_or( 17 | input_pos < self.global_tokens, 18 | input_pos >= seq_len - self.recent_window, 19 | ) 20 | 21 | def _keep_idxs(self, priority): 22 | return ( 23 | priority.topk(self.max_cache_length, dim=-1) 24 | .indices.sort(dim=-1) 25 | .values.squeeze(0) 26 | ) 27 | 28 | def __call__(self, input_pos, k_val, v_val, **kwargs): 29 | # Assign a score to each token in the prompt to determine filtering priority 30 | priority = self._token_importances(input_pos, k_val, v_val, **kwargs) 31 | 32 | # Get the self.max_cache_length indices with the highest priority 33 | keep_idxs = self._keep_idxs(priority) 34 | 35 | # Compress the prompt based on these indices 36 | k_val, v_val = self._filter_kv(keep_idxs, k_val, v_val) 37 | 38 | return ( 39 | keep_idxs, 40 | k_val, 41 | v_val, 42 | self._update_state(keep_idxs, input_pos, **kwargs), 43 | ) 44 | 45 | def _update_state(self, keep_idxs, input_pos, **kwargs): 46 | # [Optional] Over-write to return attention scores corresponding to keep_idxs 47 | return None 48 | 49 | @abstractmethod 50 | def _filter_kv(self, keep_idxs, k_val, v_val): 51 | raise NotImplementedError 52 | 53 | @abstractmethod 54 | def _token_importances(self, input_pos, k_val, v_val, **kwargs): 55 | raise NotImplementedError 56 | 57 | @abstractmethod 58 | def is_compatible(self) -> bool: 59 | raise NotImplementedError 60 | 61 | 62 | class PromptCompressorHeadConstant(PromptCompressor): 63 | def __init__(self, head_specific, **kwargs) -> None: 64 | super().__init__(head_specific, **kwargs) 65 | 66 | def is_compatible(self) -> bool: 67 | return True 68 | 69 | def _filter_kv(self, keep_idxs, k_val, v_val): 70 | k_val = k_val[:, :, keep_idxs] 71 | v_val = v_val[:, :, keep_idxs] 72 | return k_val, v_val 73 | 74 | 75 | class PromptCompressorHeadSpecific(PromptCompressor): 76 | def __init__(self, head_specific, **kwargs) -> None: 77 | super().__init__(head_specific, **kwargs) 78 | 79 | def is_compatible(self) -> bool: 80 | return self.head_specific 81 | 82 | def _filter_kv(self, keep_idxs, k_val, v_val): 83 | keep_idxs_rep = keep_idxs.view(1, -1, self.max_cache_length, 1).expand( 84 | -1, -1, -1, k_val.shape[-1] 85 | ) 86 | k_val = k_val.gather(2, keep_idxs_rep) 87 | v_val = v_val.gather(2, keep_idxs_rep) 88 | return k_val, v_val 89 | 90 | 91 | class PromptCompressorFull(PromptCompressorHeadConstant): 92 | """ 93 | This is a dummy (pass through) method which returns its inputs 94 | """ 95 | 96 | def __init__(self, head_specific, **kwargs) -> None: 97 | super().__init__(head_specific, **kwargs) 98 | 99 | def is_compatible(self) -> bool: 100 | return True 101 | 102 | def __call__(self, input_pos, k_val, v_val, **kwargs): 103 | return input_pos, k_val, v_val, None # noop 104 | 105 | def _token_importances(self, input_pos, k_val, v_val, **kwargs): 106 | raise Exception("This method should not be called!") 107 | 108 | 109 | class PromptCompressorRandom(PromptCompressorHeadConstant): 110 | def __init__(self, head_specific, **kwargs) -> None: 111 | super().__init__(head_specific, **kwargs) 112 | 113 | def is_compatible(self) -> bool: 114 | # Can be used with any cache 115 | return True 116 | 117 | def _token_importances(self, input_pos, k_val, v_val, **kwargs): 118 | seq_len = input_pos.shape[-1] 119 | save_mask = self._recent_global_mask(input_pos) 120 | priority = input_pos.masked_fill(save_mask, seq_len) 121 | # Assign positions in the middle uniform low priority 122 | priority = 
priority.masked_fill(~save_mask, -seq_len) 123 | # Add random noise to randomize the middle priorities 124 | priority += torch.randperm(seq_len, device=priority.device) 125 | return priority 126 | 127 | 128 | class PromptCompressorRecentGlobal(PromptCompressorHeadConstant): 129 | def __init__(self, head_specific, **kwargs) -> None: 130 | super().__init__(head_specific, **kwargs) 131 | 132 | window_size = self.max_cache_length - self.global_tokens 133 | assert ( 134 | window_size > 0 135 | ), f"Number of global tokens ({self.global_tokens}) cannot exceed the max cache length ({self.max_cache_length})" 136 | 137 | def is_compatible(self) -> bool: 138 | # Can be used with any cache 139 | return True 140 | 141 | def _token_importances(self, input_pos, k_val, v_val, **kwargs): 142 | # Assign Global tokens to max seq length so they are always saved 143 | return input_pos.masked_fill( 144 | input_pos < self.global_tokens, input_pos.shape[-1] 145 | ) 146 | 147 | 148 | class PromptCompressorHeavyHitter(PromptCompressorHeadSpecific): 149 | """ 150 | Use SnapKV to compress the prompt 151 | Based on the pseudo code on Page 7 of https://arxiv.org/abs/2404.14469 152 | """ 153 | 154 | def __init__(self, head_specific, **kwargs) -> None: 155 | super().__init__(head_specific, **kwargs) 156 | 157 | self.kernel_size = 5 158 | self.observation_len = 16 159 | 160 | # Pooling layer to smooth out the attention distribution 161 | # Feel free to remove this or optimize the kernel size 162 | self.pool = torch.nn.AvgPool1d( 163 | self.kernel_size, 164 | stride=1, 165 | padding=self.kernel_size // 2, 166 | ceil_mode=False, 167 | count_include_pad=False, 168 | ) 169 | 170 | def _token_importances(self, input_pos, k_val, v_val, **kwargs): 171 | attn = kwargs["attn"] 172 | seq_len = input_pos.shape[-1] 173 | obs_len = min(self.observation_len, seq_len) 174 | 175 | priority = attn[:, :, -obs_len:, :].mean(dim=2) 176 | prev_shape = priority.shape 177 | 178 | # We'll be returning the attention history so we need to keep a copy before it's modified 179 | priority = self.pool(priority) 180 | assert ( 181 | priority.shape == prev_shape 182 | ), f"Pooling operation should not change the dimension: {prev_shape} -> {priority.shape}" 183 | priority[:, :, -obs_len:] = 1.0 # Ensure the observation window is selected 184 | priority[:, :, : self.global_tokens] = ( 185 | 1.0 # Ensure the global tokens are selected 186 | ) 187 | return priority 188 | 189 | def _update_state(self, keep_idxs, input_pos, **kwargs): 190 | seq_len = input_pos.shape[-1] 191 | # Return average attention across prompt to insert into KV Cache's attention history tracker 192 | cum_attn = kwargs["attn"].sum(dim=2) / (seq_len - input_pos) 193 | cum_attn = cum_attn.gather(2, keep_idxs.view(1, -1, self.max_cache_length)) 194 | return cum_attn 195 | 196 | 197 | class PromptCompressorL2(PromptCompressorHeadSpecific): 198 | def __init__(self, head_specific, **kwargs) -> None: 199 | super().__init__(head_specific, **kwargs) 200 | 201 | def _token_importances(self, input_pos, k_val, v_val, **kwargs): 202 | # We want to prioritize the lowest L2 norm tokens so we negate the L2 norm 203 | priority = -torch.linalg.vector_norm(k_val, ord=2, dim=-1) 204 | 205 | # Give low score to global and recent tokens 206 | save_mask = self._recent_global_mask(input_pos).view(1, 1, -1) 207 | priority = priority.masked_fill(save_mask, float("inf")) 208 | 209 | return priority 210 | 211 | 212 | class PromptCompressorKeepItOdd(PromptCompressorHeadConstant): 213 | """ 214 | A toy example of a 
prompt compressor that keeps the odd positions indices of the prompt. 215 | """ 216 | 217 | def __init__(self, head_specific, **kwargs) -> None: 218 | super().__init__(head_specific, **kwargs) 219 | 220 | def _token_importances(self, input_pos, k_val, v_val, **kwargs): 221 | seq_len = input_pos.shape[-1] 222 | # Compute odd indices from keep_idxs to input_pos.shape[-1] - window 223 | priority = input_pos.masked_fill( 224 | self._recent_global_mask(input_pos), seq_len * 2 225 | ) 226 | 227 | # Lower the priority of even tokens 228 | priority[input_pos % 2 == 0] -= seq_len 229 | 230 | return priority 231 | 232 | 233 | def get_prompt_compressor_constructor(strategy): 234 | if strategy == "full": 235 | return PromptCompressorFull 236 | if strategy == "recent_global": 237 | return PromptCompressorRecentGlobal 238 | elif strategy == "heavy_hitter": 239 | return PromptCompressorHeavyHitter 240 | elif strategy == "l2": 241 | return PromptCompressorL2 242 | elif strategy == "random": 243 | return PromptCompressorRandom 244 | elif strategy == "keep_it_odd": 245 | return PromptCompressorKeepItOdd 246 | else: 247 | raise ValueError(f"Unknown prompt compression strategy: {strategy}") 248 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import numpy as np 5 | import regex as re 6 | from claudette import Chat, models 7 | from evaluate import load 8 | from anthropic import RateLimitError 9 | import regex as re 10 | 11 | 12 | class Metric: 13 | def __init__(self, **kwargs): 14 | self._load_metric(**kwargs) 15 | 16 | def _load_metric(self, **kwargs): 17 | raise NotImplementedError("This method should be overridden by subclasses.") 18 | 19 | def compute(self, prompts, predictions, references): 20 | raise NotImplementedError("This method should be overridden by subclasses.") 21 | 22 | 23 | class Rouge(Metric): 24 | def __init__(self, **kwargs): 25 | super().__init__(**kwargs) 26 | 27 | def _load_metric(self, **kwargs): 28 | self.metric = load("rouge", keep_in_memory=True) 29 | 30 | def compute(self, prompts, predictions, references): 31 | return self.metric.compute(predictions=predictions, references=references) 32 | 33 | 34 | class Bleurt(Metric): 35 | def __init__(self, **kwargs): 36 | super().__init__(**kwargs) 37 | 38 | def _load_metric(self, **kwargs): 39 | self.metric = load("bleurt", keep_in_memory=True) 40 | 41 | def compute(self, prompts, predictions, references): 42 | return np.mean( 43 | self.metric.compute(predictions=predictions, references=references)[ 44 | "scores" 45 | ] 46 | ) 47 | 48 | 49 | class BertScore(Metric): 50 | def __init__(self, **kwargs): 51 | super().__init__(**kwargs) 52 | 53 | def _load_metric(self, **kwargs): 54 | self.metric = load("bertscore", keep_in_memory=True) 55 | 56 | def compute(self, prompts, predictions, references): 57 | result = self.metric.compute( 58 | predictions=predictions, references=references, lang="en" 59 | ) 60 | return { 61 | "precision": np.mean(result["precision"]), 62 | "recall": np.mean(result["recall"]), 63 | "f1": np.mean(result["f1"]), 64 | } 65 | 66 | 67 | class Accuracy(Metric): 68 | def __init__(self, **kwargs): 69 | super().__init__(**kwargs) 70 | 71 | def _load_metric(self, **kwargs): 72 | from sklearn.metrics import accuracy_score 73 | 74 | self.metric = accuracy_score 75 | 76 | def compute(self, prompts, predictions, references): 77 | return self.metric(references, predictions) 
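# Every metric in this file follows the same interface: construct once, then call
# compute(prompts, predictions, references). A minimal sketch with illustrative
# strings (not taken from any benchmark in this repo):
#
#   rouge = Rouge()
#   scores = rouge.compute(
#       prompts=["Summarize the meeting."],
#       predictions=["The team agreed to ship on Friday."],
#       references=["The team decided to release on Friday."],
#   )
#   # -> a dict of ROUGE scores from the HuggingFace `evaluate` package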
78 | 79 | 80 | class ExactMatchScore(Metric): 81 | def __init__(self, **kwargs): 82 | super().__init__(**kwargs) 83 | 84 | def _load_metric(self, **kwargs): 85 | pass 86 | 87 | def compute(self, prompts, predictions, references): 88 | return np.mean( 89 | [ 90 | 1 if p.split() == r.split() else 0 91 | for p, r in zip(predictions, references) 92 | ] 93 | ) 94 | 95 | 96 | class LevenshteinDistance(Metric): 97 | def __init__(self, **kwargs): 98 | super().__init__(**kwargs) 99 | 100 | def _load_metric(self, **kwargs): 101 | from fuzzywuzzy import fuzz 102 | 103 | self.metric = fuzz.ratio 104 | 105 | def compute(self, prompts, predictions, references): 106 | return np.mean([self.metric(p, r) for p, r in zip(predictions, references)]) 107 | 108 | 109 | class RulerStringMatch(Metric): 110 | """ 111 | Metric used in RULER. 112 | Reference: https://github.com/hsiehjackson/RULER/blob/main/scripts/eval/synthetic/constants.py 113 | """ 114 | 115 | def __init__(self, **kwargs): 116 | super().__init__(**kwargs) 117 | 118 | @staticmethod 119 | def postprocess_pred(predict_str: str): 120 | predict_str = predict_str.strip() 121 | 122 | # Remove all non-printable characters 123 | np_pattern = re.compile(r"[\x00-\x1f]") 124 | predict_str = np_pattern.sub("\n", predict_str).strip() 125 | 126 | return predict_str 127 | 128 | @staticmethod 129 | def string_match_part(refs, preds): 130 | scores = [ 131 | max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) 132 | for pred, ref in zip(preds, refs) 133 | ] 134 | score = sum(scores) / len(preds) * 100 135 | return {"score": round(score, 4)} 136 | 137 | @staticmethod 138 | def string_match_all(refs, preds): 139 | scores = [ 140 | sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref) 141 | for pred, ref in zip(preds, refs) 142 | ] 143 | score = sum(scores) / len(preds) * 100 144 | return {"score": round(score, 4)} 145 | 146 | def _load_metric(self, **kwargs): 147 | if kwargs.get("match_part", False): 148 | self.metric = self.string_match_part 149 | else: 150 | self.metric = self.string_match_all 151 | 152 | def compute(self, prompts, predictions, references): 153 | predictions = [self.postprocess_pred(pred) for pred in predictions] 154 | return self.metric(references, predictions) 155 | 156 | 157 | REFERENCE_TEMPLATE = """You are shown ground-truth answer(s) and asked to judge the quality of an LLM-generated answer. 158 | Assign it a score from 1-5 where 1 is the worst and 5 is the best based on how similar it is to the ground-truth(s). 159 | Do NOT explain your choice. Simply return a number from 1-5. 160 | 161 | ====GROUND TRUTHS==== 162 | {labels} 163 | 164 | ====ANSWER==== 165 | {prediction}""" 166 | 167 | PREFILL = "The score (1-5) is:" 168 | 169 | 170 | class LLMRouge(Metric): 171 | def __init__(self, num_retries=5, **kwargs) -> None: 172 | assert ( 173 | "ANTHROPIC_API_KEY" in os.environ 174 | ), "Please set the ANTHROPIC_API_KEY environment variable." 
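        # How this judge works (summary of the surrounding code): each prediction is
        # scored by a Claude model via `claudette.Chat`. The request is built from
        # REFERENCE_TEMPLATE above, and the reply is prefilled with PREFILL
        # ("The score (1-5) is:") so that `parse_int` can extract a single 1-5 score.
        # Calls that raise `anthropic.RateLimitError` are retried up to `num_retries`
        # times with a 10-second sleep between attempts (see `compute`).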
175 | super().__init__(**kwargs) 176 | self.num_retries = num_retries 177 | 178 | def _load_metric(self, **kwargs): 179 | name = kwargs.get("name", "haiku") 180 | matching_names = [m for m in models if name in m] 181 | assert len(matching_names) > 0, f"Model name {name} not found in {models}" 182 | assert ( 183 | len(matching_names) == 1 184 | ), f"Model name {name} found x{len(matching_names)} in {models}" 185 | self.chat = Chat( 186 | matching_names[0], sp="""You are a helpful and concise assistant.""" 187 | ) 188 | 189 | def parse_int(self, text): 190 | return int(re.search(r"\d+", text).group()) 191 | 192 | def compute(self, prompts, predictions, labels): 193 | scores = [] 194 | for p, ls in zip(predictions, labels): 195 | prompt = REFERENCE_TEMPLATE.format(labels="\n---\n".join(ls), prediction=p) 196 | # Clear conversation history 197 | self.chat.h = [] 198 | try: 199 | score = ( 200 | self.chat(prompt, prefill=PREFILL) 201 | .content[0] 202 | .text[len(PREFILL) :] 203 | .strip() 204 | ) 205 | except RateLimitError: 206 | retries = 0 207 | while retries < self.num_retries: 208 | time.sleep(10) 209 | try: 210 | score = ( 211 | self.chat(prompt, prefill=PREFILL) 212 | .content[0] 213 | .text[len(PREFILL) :] 214 | .strip() 215 | ) 216 | break 217 | except RateLimitError: 218 | retries += 1 219 | if retries == self.num_retries: 220 | raise RateLimitError("Exceeded maximum number of retries.") 221 | 222 | score = self.parse_int(score) 223 | scores.append(score) 224 | return {"llm_rouge": sum(scores) / len(scores)} 225 | 226 | 227 | LLM_JUDGE_TEMPLATE = """You are shown a prompt and asked to assess the quality of an LLM-generated answer on the following dimensions: 228 | 229 | ===CRITERIA=== 230 | {criteria} 231 | 232 | Respond with "criteria: score" for each criteria with a newline for each criteria. 233 | Assign a score from 1-5 where 1 is the worst and 5 is the best based on how well the answer meets the criteria. 234 | 235 | ====PROMPT==== 236 | {prompt} 237 | 238 | ====ANSWER==== 239 | {prediction}""" 240 | 241 | 242 | CRITERIA = { 243 | "helpful": "The answer executes the action requested by the prompt without extraneous detail.", 244 | "coherent": "The answer is logically structured and coherent (ignore the prompt).", 245 | "faithful": "The answer is faithful to the prompt and does not contain false information.", 246 | } 247 | 248 | 249 | class LLMJudge(LLMRouge): 250 | def __init__(self, **kwargs) -> None: 251 | assert ( 252 | "ANTHROPIC_API_KEY" in os.environ 253 | ), "Please set the ANTHROPIC_API_KEY environment variable." 
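        # The judge is prompted with LLM_JUDGE_TEMPLATE above and prefilled so that,
        # illustratively, its raw reply looks something like:
        #
        #   coherent: 4
        #   faithful: 5
        #   helpful: 3
        #
        # (the example scores are made up). `parse_scorecard` below regex-extracts one
        # integer per criterion, and `compute` averages each criterion over the test set.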
254 | super().__init__(**kwargs) 255 | 256 | self.criteria = list(sorted([k for k in CRITERIA])) 257 | self.criteria_def = "\n".join([f"{k}: {CRITERIA[k]}" for k in self.criteria]) 258 | self.prefill = ( 259 | f"\n\n====SCORES for {', '.join(self.criteria)}====\n\n{self.criteria[0]}:" 260 | ) 261 | 262 | def parse_scorecard(self, scorecard): 263 | try: 264 | return { 265 | k: int(v) 266 | for k, v in dict( 267 | re.findall(rf"({'|'.join(self.criteria)})\W+(\d+)", scorecard) 268 | ).items() 269 | } 270 | except Exception as e: 271 | print(e) 272 | raise Exception( 273 | f"Could not parse LLM-generated scorecard for {self.__class__}:\n{scorecard}" 274 | ) 275 | 276 | def claudette_scorecard(self, prompt, prediction): 277 | prompt = LLM_JUDGE_TEMPLATE.format( 278 | criteria=self.criteria_def, prompt=prompt, prediction=prediction 279 | ) 280 | # Clear conversation history 281 | self.chat.h = [] 282 | scorecard = ( 283 | self.chat(prompt, prefill=self.prefill) 284 | .content[0] 285 | .text[len(self.prefill) - len(self.criteria[0]) - 1 :] 286 | .strip() 287 | ) 288 | return scorecard 289 | 290 | def compute(self, prompts, predictions, labels): 291 | scores = [] 292 | 293 | for prompt, pred in zip(prompts, predictions): 294 | scorecard = self.claudette_scorecard(prompt, pred) 295 | score_dict = self.parse_scorecard(scorecard) 296 | scores.append(score_dict) 297 | 298 | return {k: np.mean([s[k] for s in scores]) for k in self.criteria} 299 | 300 | 301 | METRIC_MAPPING = { 302 | "accuracy": Accuracy, 303 | "bertscore": BertScore, 304 | "bleurt": Bleurt, 305 | "exact_match": ExactMatchScore, 306 | "levenshtein": LevenshteinDistance, 307 | "llm-rouge": LLMRouge, 308 | "llm-as-a-judge": LLMJudge, 309 | "rouge": Rouge, 310 | "ruler-string-match": RulerStringMatch, 311 | } 312 | 313 | 314 | class AutoMetric: 315 | def __init__(self): 316 | raise EnvironmentError( 317 | "This class is designed to be instantiated only through the from_name method" 318 | ) 319 | 320 | def from_name(metric_name, **kwargs): 321 | if metric_name not in METRIC_MAPPING: 322 | raise ValueError(f"Invalid metric name: {metric_name}") 323 | return METRIC_MAPPING[metric_name](**kwargs) 324 | 325 | 326 | if __name__ == "__main__": 327 | metric = AutoMetric.from_name("llm-as-a-judge") 328 | predictions = [ 329 | "The answer to 2x2 is 4.", 330 | "The answer to 2x2 is 5.", 331 | ] 332 | labels = [["4"], ["4"]] 333 | prompts = [ 334 | "What is 2x2?", 335 | "What is 2x2?", 336 | ] 337 | print(metric.compute(prompts=prompts, predictions=predictions, labels=None)) 338 | -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import itertools 3 | import os 4 | import regex as re 5 | import string 6 | import sentencepiece as spm 7 | import tiktoken 8 | import torch 9 | from tiktoken.load import load_tiktoken_bpe 10 | from transformers import AutoTokenizer 11 | from pathlib import Path 12 | from typing import ( 13 | Dict, 14 | List, 15 | Literal, 16 | TypedDict, 17 | ) 18 | 19 | 20 | default_device = "cuda" if torch.cuda.is_available() else "cpu" 21 | 22 | 23 | def is_punc_id(text): 24 | # Define a regex pattern that matches any character that is not whitespace or punctuation 25 | pattern = rf"^[\s{re.escape(string.punctuation)}]*$" 26 | return bool(re.match(pattern, text)) 27 | 28 | 29 | class TokenizerInterface(ABC): 30 | def __init__(self, model_path): 31 | self.model_path = 
model_path 32 | self.vocab = None 33 | 34 | @abstractmethod 35 | def encode(self, text): 36 | pass 37 | 38 | @abstractmethod 39 | def decode(self, tokens): 40 | pass 41 | 42 | @abstractmethod 43 | def bos_id(self): 44 | pass 45 | 46 | @abstractmethod 47 | def eos_id(self): 48 | pass 49 | 50 | @abstractmethod 51 | def get_terminator_ids(self): 52 | pass 53 | 54 | @abstractmethod 55 | def special_ids(self) -> List[List[int]]: 56 | pass 57 | 58 | @abstractmethod 59 | def __len__(self): 60 | pass 61 | 62 | def punctuation_ids(self): 63 | return [i for i, wp in enumerate(self.vocab) if is_punc_id(wp)] 64 | 65 | def get_vocab(self): 66 | assert ( 67 | self.vocab is not None 68 | ), "Subclasses should set the vocab attribute during initialization." 69 | return self.vocab 70 | 71 | 72 | class SentencePieceWrapper(TokenizerInterface): 73 | def __init__(self, model_path): 74 | super().__init__(model_path) 75 | self.model_path = model_path 76 | self.processor = spm.SentencePieceProcessor(str(model_path)) 77 | self.terminator_ids = [self.processor.eos_id()] 78 | self.vocab = [ 79 | self.processor.id_to_piece(id) 80 | for id in range(self.processor.get_piece_size()) 81 | ] 82 | 83 | def addl_special_ids(self): 84 | # If llama-2 in model path, return special tokens for llama-2 85 | if "llama-2" in str(self.model_path).lower(): 86 | special_tokens = ["[INST]", "[/INST]"] 87 | else: 88 | raise ValueError(f"Unknown model path: {self.model_path}") 89 | 90 | def _encode_special(token): 91 | ids = self.processor.EncodeAsIds(token) 92 | if len(ids) > 1: 93 | print(f"Special token {token} was tokenized into {len(ids)} tokens") 94 | return ids 95 | 96 | return list(map(_encode_special, special_tokens)) 97 | 98 | def special_ids(self) -> List[List[int]]: 99 | # Some of the chat templates aren't given a singular special token so we return a list of lists 100 | return [ 101 | [self.processor.bos_id()], 102 | [self.processor.eos_id()], 103 | *self.addl_special_ids(), 104 | ] 105 | 106 | def encode(self, text): 107 | return self.processor.EncodeAsIds(text) 108 | 109 | def decode(self, tokens): 110 | return self.processor.DecodeIds(tokens) 111 | 112 | def bos_id(self): 113 | return self.processor.bos_id() 114 | 115 | def eos_id(self): 116 | return self.processor.eos_id() 117 | 118 | def get_terminator_ids(self): 119 | return self.terminator_ids 120 | 121 | def __len__(self): 122 | return self.processor.get_piece_size() 123 | 124 | 125 | class TiktokenWrapper(TokenizerInterface): 126 | """ 127 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 
128 | """ 129 | 130 | special_tokens: Dict[str, int] 131 | 132 | num_reserved_special_tokens = 256 133 | 134 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 135 | 136 | def __init__(self, model_path): 137 | super().__init__(model_path) 138 | assert os.path.isfile(model_path), str(model_path) 139 | mergeable_ranks = load_tiktoken_bpe(str(model_path)) 140 | num_base_tokens = len(mergeable_ranks) 141 | special_tokens = [ 142 | "<|begin_of_text|>", 143 | "<|end_of_text|>", 144 | "<|reserved_special_token_0|>", 145 | "<|reserved_special_token_1|>", 146 | "<|reserved_special_token_2|>", 147 | "<|reserved_special_token_3|>", 148 | "<|start_header_id|>", 149 | "<|end_header_id|>", 150 | "<|reserved_special_token_4|>", 151 | "<|eot_id|>", # end of turn 152 | ] + [ 153 | f"<|reserved_special_token_{i}|>" 154 | for i in range(5, self.num_reserved_special_tokens - 5) 155 | ] 156 | self.special_tokens = { 157 | token: num_base_tokens + i for i, token in enumerate(special_tokens) 158 | } 159 | self.model = tiktoken.Encoding( 160 | name=Path(model_path).name, 161 | pat_str=self.pat_str, 162 | mergeable_ranks=mergeable_ranks, 163 | special_tokens=self.special_tokens, 164 | ) 165 | # BOS / EOS token IDs 166 | self._bos_id: int = self.special_tokens["<|begin_of_text|>"] 167 | self._eos_id: int = self.special_tokens["<|end_of_text|>"] 168 | self.terminator_ids = [self._eos_id, self.special_tokens["<|eot_id|>"]] 169 | self.vocab = [self.model.decode([i]) for i in range(self.model.n_vocab)] 170 | 171 | def encode(self, text): 172 | return self.model.encode(text) 173 | 174 | def special_ids(self) -> List[List[int]]: 175 | # Some of the chat templates aren't given a singular special token so we return a list of lists 176 | return [[x] for x in list(sorted(self.special_tokens.values()))] 177 | 178 | def decode(self, tokens): 179 | return self.model.decode(tokens) 180 | 181 | def bos_id(self): 182 | return self._bos_id 183 | 184 | def eos_id(self): 185 | return self._eos_id 186 | 187 | def get_terminator_ids(self): 188 | return self.terminator_ids 189 | 190 | def __len__(self): 191 | return self.model.n_vocab 192 | 193 | 194 | class TokenizersWrapper(TokenizerInterface): 195 | def __init__(self, model_path): 196 | super().__init__(model_path) 197 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 198 | self.terminator_ids = [self.tokenizer.eos_token_id] 199 | self.vocab = [ 200 | self.tokenizer.decode(i) for i in range(self.tokenizer.vocab_size) 201 | ] 202 | 203 | def special_ids(self) -> List[List[int]]: 204 | if hasattr(self.tokenizer, "special_token_ids"): 205 | return [[x] for x in self.tokenizer.special_token_ids] 206 | 207 | # Its likely a tokenizer that has a special_tokens_map attribute 208 | special_tokens_ = list(self.tokenizer.special_tokens_map.values()) 209 | special_tokens = [] 210 | for t in special_tokens_: 211 | if type(t) == list: 212 | special_tokens.extend(t) 213 | else: 214 | special_tokens.append(t) 215 | special_tokens = list(set(special_tokens)) 216 | return [[self.tokenizer.convert_tokens_to_ids(t)] for t in special_tokens] 217 | 218 | def encode(self, text): 219 | return self.tokenizer.encode(text, add_special_tokens=False) 220 | 221 | def decode(self, tokens): 222 | return self.tokenizer.decode(tokens) 223 | 224 | def bos_id(self): 225 | return self.tokenizer.bos_token_id 226 | 227 | def eos_id(self): 228 | return self.tokenizer.eos_token_id 229 | 230 | def get_terminator_ids(self): 
231 | return self.terminator_ids 232 | 233 | def __len__(self): 234 | return len(self.tokenizer) 235 | 236 | 237 | def get_tokenizer(tokenizer_model_path, model_name, is_chat=False): 238 | """ 239 | Factory function to get the appropriate tokenizer based on the model name. 240 | 241 | Args: 242 | - tokenizer_model_path (str): The file path to the tokenizer model. 243 | - model_name (str): The name of the model, used to determine the tokenizer type. 244 | 245 | Returns: 246 | - TokenizerInterface: An instance of a tokenizer. 247 | """ 248 | if "llama-3" in str(model_name).lower(): 249 | return ( 250 | Llama3ChatFormat(tokenizer_model_path) 251 | if is_chat 252 | else TiktokenWrapper(tokenizer_model_path) 253 | ) 254 | elif "llama-2" in str(model_name).lower(): 255 | return ( 256 | Llama2ChatFormat(tokenizer_model_path) 257 | if is_chat 258 | else SentencePieceWrapper(tokenizer_model_path) 259 | ) 260 | else: 261 | return ( 262 | TokenizersChatFormat(tokenizer_model_path) 263 | if is_chat 264 | else TokenizersWrapper(tokenizer_model_path) 265 | ) 266 | 267 | 268 | Role = Literal["system", "user", "assistant"] 269 | 270 | 271 | class Message(TypedDict): 272 | role: Role 273 | content: str 274 | 275 | 276 | class Llama3ChatFormat(TiktokenWrapper): 277 | def __init__(self, model_path): 278 | super().__init__(model_path) 279 | 280 | def encode_header(self, message: Message) -> List[int]: 281 | return [ 282 | self.special_tokens["<|start_header_id|>"], 283 | *self.encode(message["role"]), 284 | self.special_tokens["<|end_header_id|>"], 285 | *self.encode("\n\n"), 286 | ] 287 | 288 | def encode_prompt(self, prompt: str): 289 | return self.encode_dialog_prompt([{"role": "user", "content": prompt}]) 290 | 291 | def encode_message(self, message: Message) -> List[int]: 292 | tokens = self.encode_header(message) 293 | tokens.extend(self.encode(message["content"].strip())) 294 | tokens.append(self.special_tokens["<|eot_id|>"]) 295 | return tokens 296 | 297 | def encode_dialog_prompt(self, dialog: List[Message]) -> List[int]: 298 | return [ 299 | self.special_tokens["<|begin_of_text|>"], 300 | *list(itertools.chain(*map(self.encode_message, dialog))), 301 | # Add the start of an assistant message for the model to complete. 
302 | *self.encode_header({"role": "assistant", "content": ""}), 303 | ] 304 | 305 | 306 | class Llama2ChatFormat(SentencePieceWrapper): 307 | B_INST = "[INST]" 308 | E_INST = "[/INST]" 309 | 310 | def __init__(self, model_path): 311 | super().__init__(model_path) 312 | 313 | def encode_prompt(self, prompt: str): 314 | ids = [self.bos_id()] 315 | ids += self.encode(Llama2ChatFormat.B_INST + "\n\n") 316 | ids += self.encode(prompt + " " + Llama2ChatFormat.E_INST) 317 | return ids 318 | 319 | 320 | class TokenizersChatFormat(TokenizersWrapper): 321 | def __init__(self, model_path): 322 | super().__init__(model_path) 323 | 324 | def encode_prompt(self, prompt: str): 325 | messages = [{"role": "user", "content": prompt}] 326 | return self.encode_dialog_prompt(messages) 327 | 328 | def encode_dialog_prompt(self, dialog: List[Message]) -> List[int]: 329 | text = self.tokenizer.apply_chat_template( 330 | dialog, tokenize=False, add_generation_prompt=True 331 | ) 332 | return self.encode(text) 333 | 334 | 335 | def encode_tokens(tokenizer, string, bos=True, device=default_device): 336 | tokens = tokenizer.encode(string) 337 | if bos: 338 | tokens = [tokenizer.bos_id()] + tokens 339 | return torch.tensor(tokens, dtype=torch.int, device=device) 340 | 341 | 342 | def encode(tokenizer, prompt, device=default_device, bos=True, is_chat=True): 343 | if is_chat: 344 | tokens = tokenizer.encode_prompt(prompt) 345 | encoded = torch.tensor(tokens, dtype=torch.int, device=device) 346 | else: 347 | encoded = encode_tokens(tokenizer, prompt, device=device, bos=bos) 348 | 349 | return encoded 350 | -------------------------------------------------------------------------------- /GPTQ.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | import torch.fx as fx 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.utils._pytree import tree_flatten, tree_unflatten 13 | 14 | aten = torch.ops.aten 15 | 16 | from eval import ( 17 | setup_cache_padded_seq_input_pos_max_seq_length_for_prefill, 18 | GPTFastEvalWrapper, 19 | ) 20 | 21 | 22 | class InputRecorder(GPTFastEvalWrapper): 23 | """ 24 | This is a fake evaluation wrapper that just records the inputs 25 | so that they can be used in calibration. 26 | 27 | If pad_calibration_inputs is enabled, the input recorder will take 28 | each input and pad/truncate it down to the calibration_seq_length. 29 | It will also edit the model embeddings to be zero for the 0 token used 30 | in padding and avoid any inputs with the 0 token. 31 | 32 | If not, it will only truncate inputs to the desired length. 
33 |     """
34 | 
35 |     def __init__(
36 |         self,
37 |         model,
38 |         tokenizer,
39 |         calibration_seq_length,
40 |         pad_calibration_inputs=False,
41 |     ):
42 |         super().__init__(model, tokenizer, calibration_seq_length)
43 |         self._model = model
44 |         self._tokenizer = tokenizer
45 |         self._device = torch.device("cpu")
46 |         self.vocab_size = model.config.vocab_size
47 |         self.calibration_seq_length = calibration_seq_length
48 |         self.pad_calibration_inputs = pad_calibration_inputs
49 |         self.inputs = None
50 | 
51 |         if self.pad_calibration_inputs:
52 |             # This is needed for the pad_calibration_inputs option
53 |             # to work properly, the 0 token's embeddings are set to 0 so that
54 |             # the padded inputs will not affect the model numerics. This token isn't used
55 |             # commonly in the eval tasks for the meta-llama tokenizer and we skip any inputs
56 |             # where it appears
57 |             try:
58 |                 if isinstance(self._model.transformer.wte, nn.Embedding):
59 |                     self._model.transformer.wte.weight.data[0, :] *= 0
60 |             except:
61 |                 print(
62 |                     "Did not find embeddings in model.transformer.wte, disabling padding"
63 |                 )
64 |                 self.pad_calibration_inputs = False
65 | 
66 |     def add_input(self, args):
67 |         if self.inputs is None:
68 |             self.inputs = [MultiInput([arg]) for arg in args]
69 |         else:
70 |             self.inputs = [
71 |                 multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
72 |             ]
73 | 
74 |     def get_recorded_inputs(self):
75 |         return self.inputs
76 | 
77 |     def _model_call(self, inps):
78 |         inps = inps.squeeze(0)
79 |         T = len(inps)
80 |         if (
81 |             # can't use inputs that are too short when padding disabled
82 |             (T < self.calibration_seq_length and not self.pad_calibration_inputs)
83 |             or
84 |             # can't use inputs that actually use token we use for padding
85 |             (self.pad_calibration_inputs and 0 in inps)
86 |         ):
87 |             # give random output
88 |             return torch.randn(
89 |                 (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
90 |             )
91 | 
92 |         # pad or truncate to the right size
93 |         if T >= self.calibration_seq_length:
94 |             inps = inps[: self.calibration_seq_length]
95 |         else:
96 |             inps = F.pad(inps, (0, self.calibration_seq_length - T))
97 | 
98 |         max_new_tokens = 1
99 |         (
100 |             seq,
101 |             input_pos,
102 |             max_seq_length,
103 |         ) = setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
104 |             self._model, inps, max_new_tokens, self.max_length
105 |         )
106 |         x = seq.index_select(0, input_pos).view(1, -1)
107 |         self.add_input((x, input_pos))
108 | 
109 |         # output `something` with correct shape to keep eval going
110 |         return torch.randn(
111 |             (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
112 |         )
113 | 
114 | 
115 | class MultiInput:
116 |     def __init__(self, inputs):
117 |         self.values = list(inputs)
118 | 
119 |     def add_input(self, input):
120 |         self.values.append(input)
121 |         return self
122 | 
123 |     def __getitem__(self, slice):
124 |         return MultiInput(self.values[slice])
125 | 
126 |     def cuda(self):
127 |         self.values = [
128 |             val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
129 |         ]
130 | 
131 | 
132 | class GenericGPTQRunner(fx.Interpreter):
133 |     """
134 |     This is a generic GPTQ runner that takes an existing model and applies GPTQ.
135 |     It uses torch._dynamo.export to obtain a graph of the model and then hooks
136 |     into function calls and when it detects a linear, it applies GPTQ to the weight
137 |     given the calibration of inputs passed in at initialization.
It puts the results 138 | into the state_dict so that the quantized model weights/qparams can be loaded 139 | directly into the model. 140 | 141 | This class is expected to work in concert with a GPTQSimpleQuantizer 142 | class to define the specific type of quantization being done. 143 | """ 144 | 145 | def __init__( 146 | self, model, inputs: MultiInput, blocksize=128, percdamp=0.01, groupsize=128 147 | ): 148 | self.id_to_name = { 149 | id(value): name for name, value in dict(model.named_parameters()).items() 150 | } 151 | 152 | # trace model for one input 153 | one_input = [multi.values[0].cpu() for multi in inputs] 154 | exported_model = torch._dynamo.export( 155 | model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake" 156 | )(*one_input) 157 | super().__init__(exported_model.graph_module) 158 | self.new_state_dict = model.state_dict() 159 | self.blocksize = blocksize 160 | self.percdamp = percdamp 161 | self.groupsize = groupsize 162 | self.inputs = inputs 163 | self.gptq_done = False 164 | self.debug = False 165 | 166 | def configure_quantization_mode( 167 | self, 168 | get_qparams_func, 169 | quantize_func, 170 | dequantize_func, 171 | combine_qparams_list_func, 172 | make_names_and_values_dict_func, 173 | skip_layer_func, 174 | ): 175 | # these functions need to already be curried with all inputs other than weight, qparams 176 | self.get_qparams_func = ( 177 | get_qparams_func # accepts [2d weight tensor], outputs qparams. 178 | ) 179 | 180 | self.quantize_func = quantize_func # accepts [2d weight tensor], [qparams], outputs a 2d quantized tensor of desired dtype 181 | 182 | self.dequantize_func = dequantize_func 183 | # accepts [quantized] tensor and [qparams], outputs a 2d dequantized tensor of type float, 184 | # assumes this output .to(w_orig_dtype) is ~eventual desired dequant behavior 185 | 186 | self.combine_qparams_list_func = combine_qparams_list_func 187 | # accepts [`list` of qparams] from quantizing one group at a time, 188 | # outputs a qparams object that could be passed into quant/dequantize_func 189 | 190 | self.skip_layer_func = skip_layer_func # accepts [weight tensor], outputs a bool on whether or not to apply gptq to this layer 191 | 192 | self.make_names_and_values_dict_func = make_names_and_values_dict_func # accepts [2d quantized tensor], [qparams], returns a dict of names, values to put in state_dict 193 | # note any final packing for storage should happen here 194 | return self 195 | 196 | def run(self): 197 | assert ( 198 | self.get_qparams_func is not None 199 | ), "need to configure quantization mode before running" 200 | self.gptq_done = True 201 | super().run(*self.inputs) 202 | 203 | def get_quantized_state_dict(self): 204 | assert ( 205 | self.gptq_done 206 | ), "need to run GPTQRunner before you can get_quantized_state_dict" 207 | quantized_state_dict = self.new_state_dict 208 | # Don't want to store/load the kv_cache so remove it from the state_dict 209 | del_list = [] 210 | for param_fqn in quantized_state_dict: 211 | if "kv_cache" in param_fqn: 212 | del_list.append(param_fqn) 213 | for param_fqn in del_list: 214 | quantized_state_dict.pop(param_fqn) 215 | return quantized_state_dict 216 | 217 | def call_function(self, target, args, kwargs, skip_quant=False): 218 | def tensors_to_cuda(args): 219 | new_args = [] 220 | for x in args: 221 | new_args.append(x.cuda() if isinstance(x, torch.Tensor) else x) 222 | return new_args 223 | 224 | # flatten args and kwargs together 225 | flat_args, spec = tree_flatten((args, kwargs)) 226 | # 
move all single tensors to cuda, will move MultiInputs to cuda one at a time 227 | flat_args = tensors_to_cuda(flat_args) 228 | 229 | has_multi_input = MultiInput in [type(x) for x in flat_args] 230 | if has_multi_input: 231 | # Just some trickery to convert 232 | # [MultiInput[a, a, a], MultiInput(b, b, b)] => [a, b], [a, b], [a, b] 233 | multi_input_count = max( 234 | [len(x.values) if isinstance(x, MultiInput) else 1 for x in flat_args] 235 | ) 236 | transposed_args = list( 237 | zip( 238 | *[ 239 | x.values 240 | if isinstance(x, MultiInput) 241 | else [x] * multi_input_count 242 | for x in flat_args 243 | ] 244 | ) 245 | ) 246 | else: 247 | transposed_args = [flat_args] 248 | outputs = [] 249 | 250 | # check whether we apply GPTQ to this module 251 | quantize_linear = ( 252 | (target == aten.linear.default) # if its a linear 253 | and id(args[1]) in self.id_to_name # and if we know the layer name 254 | and not skip_quant # and if we weren't told to skip quantization 255 | # and if the skip_layer_func doesn't say we should skip 256 | and not (self.skip_layer_func is not None and self.skip_layer_func(args[1])) 257 | ) # then we will quantize this linear layer/weight 258 | 259 | if quantize_linear: # instantiate variables for GPTQ 260 | H = 0 261 | total_batches = 0 262 | 263 | for inp in transposed_args: 264 | inp = tensors_to_cuda(inp) 265 | cur_args, cur_kwargs = tree_unflatten(inp, spec) 266 | 267 | if quantize_linear: # calculate H instead of output (will run the linear eventually with updated weight) 268 | x = cur_args[0].float() 269 | shape = x.shape 270 | n = 1 if len(shape) == 2 else shape[0] 271 | H *= total_batches / (total_batches + n) 272 | total_batches += n 273 | x = ((2 / total_batches) ** (1 / 2)) * x.reshape( 274 | -1, shape[-1] 275 | ).t().float() 276 | H += x.matmul(x.t()) 277 | else: 278 | # get output if its not a linear 279 | out = super().call_function(target, cur_args, cur_kwargs) 280 | 281 | if isinstance(out, torch.Tensor): 282 | outputs.append(out.cpu()) 283 | else: 284 | outputs.append(out) 285 | 286 | if quantize_linear: 287 | mod_fqn = ".".join(self.id_to_name[id(args[1])].split(".")[:-1]) 288 | W = args[1].to(H.device) 289 | Q, DQ, qparams = self.faster_quant(H, W.detach()) 290 | print(mod_fqn) 291 | names_and_values_dict = self.make_names_and_values_dict_func(Q, qparams) 292 | 293 | # delete old weight 294 | if mod_fqn + ".weight" in self.new_state_dict: 295 | self.new_state_dict.pop(mod_fqn + ".weight") 296 | if len(args) > 2: 297 | self.new_state_dict[mod_fqn + ".bias"] = args[2] 298 | for name, value in names_and_values_dict.items(): 299 | self.new_state_dict[mod_fqn + "." 
+ name] = value 300 | 301 | # run linear with new weight to get corrected output 302 | new_out = self.call_function( 303 | target, (args[0], DQ, *args[2:]), kwargs, skip_quant=True 304 | ) 305 | 306 | if self.debug: 307 | old_out = self.call_function( 308 | target, (args[0][:2], args[1], *args[2:]), kwargs, skip_quant=True 309 | ) 310 | 311 | def SQNR(x, y): 312 | return 20 * torch.log10(torch.norm(x) / torch.norm(x - y)) 313 | 314 | DQ_after = self.dequantize_func(Q, qparams).to(W.dtype) 315 | print( 316 | "SQNR for QDQ (this should be inf)", SQNR(DQ, DQ_after) 317 | ) # matches 318 | 319 | print( 320 | "SQNR for weight (can be low)", SQNR(W, DQ.cuda()) 321 | ) # fine to not match 322 | print( 323 | "SQNR for output with GPTQ (hopefully 35+)", 324 | torch.cat( 325 | [ 326 | SQNR(old.cpu(), new.cpu()).unsqueeze(0) 327 | for (old, new) in zip(old_out.values, new_out.values[:2]) 328 | ] 329 | ).mean(), 330 | ) 331 | 332 | qparams2 = self.get_qparams_func(W) 333 | Q2 = self.quantize_func(W, qparams2) 334 | DQ2 = self.dequantize_func(Q2, qparams2).to(W.dtype) 335 | old_q_out = self.call_function( 336 | target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True 337 | ) 338 | 339 | print( 340 | "SQNR for output without GPTQ (should be less than above)", 341 | torch.cat( 342 | [ 343 | SQNR(old.cpu(), old_q.cpu()).unsqueeze(0) 344 | for (old, old_q) in zip(old_out.values, old_q_out.values) 345 | ] 346 | ).mean(), 347 | ) 348 | return new_out 349 | 350 | return MultiInput(outputs) if has_multi_input else outputs[0] 351 | 352 | def faster_quant(self, H, W): 353 | percdamp = self.percdamp 354 | blocksize = self.blocksize 355 | groupsize = self.groupsize 356 | orig_dtype = W.dtype 357 | W = W.detach().float() 358 | rows, columns = W.shape[0], W.shape[1] 359 | device = W.device 360 | 361 | if groupsize == -1: 362 | cur_qparams = self.get_qparams_func(W) 363 | dead = torch.diag(H) == 0 364 | H[dead, dead] = 1 365 | W[:, dead] = 0 366 | 367 | Losses = torch.zeros_like(W) 368 | DQ = torch.zeros_like(W) 369 | 370 | damp = percdamp * torch.mean(torch.diag(H)) 371 | diag = torch.arange(columns, device=device) 372 | H[diag, diag] += damp 373 | H = torch.linalg.cholesky(H) 374 | H = torch.cholesky_inverse(H) 375 | H = torch.linalg.cholesky(H, upper=True) 376 | Hinv = H 377 | 378 | all_qparams = [] 379 | for i1 in range(0, columns, blocksize): 380 | i2 = min(i1 + blocksize, columns) 381 | count = i2 - i1 382 | W1 = W[:, i1:i2].clone() 383 | DQ1 = torch.zeros_like(W1) 384 | Err1 = torch.zeros_like(W1) 385 | Losses1 = torch.zeros_like(W1) 386 | Hinv1 = Hinv[i1:i2, i1:i2] 387 | for i in range(count): 388 | w = W1[:, i] 389 | d = Hinv1[i, i] 390 | 391 | if groupsize != -1 and (i1 + i) % groupsize == 0: # start of new group 392 | cur_qparams = self.get_qparams_func( 393 | W[:, (i1 + i) : (i1 + i + groupsize)] 394 | ) 395 | all_qparams.append(cur_qparams) 396 | 397 | q = self.quantize_func(w.unsqueeze(1), cur_qparams).flatten() 398 | dq = self.dequantize_func(q.unsqueeze(1), cur_qparams).flatten() 399 | 400 | DQ1[:, i] = dq 401 | Losses1[:, i] = (w - dq) ** 2 / d**2 402 | 403 | err1 = (w - dq) / d 404 | W1[:, i:] -= ( 405 | err1.to(Hinv1.dtype).unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) 406 | ) 407 | Err1[:, i] = err1 408 | 409 | DQ[:, i1:i2] = DQ1 410 | Losses[:, i1:i2] = Losses1 / 2 411 | 412 | W[:, i2:] -= Err1.to(Hinv.dtype).matmul(Hinv[i1:i2, i2:]) 413 | 414 | torch.cuda.synchronize() 415 | 416 | if all_qparams == []: 417 | all_qparams.append(cur_qparams) 418 | 419 | # convert a list of qparams objects 
into a single one. enerally by 420 | # concatenating a bunch of n,1 scale/zeros tensors into a n,num_groups tensor 421 | all_qparams = self.combine_qparams_list_func(all_qparams) 422 | Q = self.quantize_func(DQ, all_qparams) 423 | return Q, DQ.to(orig_dtype), all_qparams 424 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from dataclasses import dataclass 7 | from collections import defaultdict 8 | from typing import Optional, Dict, Any 9 | 10 | import math 11 | import torch 12 | import torch.nn as nn 13 | from torch import Tensor 14 | from torch.nn import functional as F 15 | 16 | from attention_utils import scaled_dot_product_attention 17 | from cache import get_cache_constructor 18 | from prompt_compression import get_prompt_compressor_constructor 19 | 20 | 21 | def find_multiple(n: int, k: int) -> int: 22 | if n % k == 0: 23 | return n 24 | return n + k - (n % k) 25 | 26 | 27 | @dataclass 28 | class ModelArgs: 29 | block_size: int = 2048 30 | vocab_size: int = 32000 31 | n_layer: int = 32 32 | n_head: int = 32 33 | dim: int = 4096 34 | intermediate_size: int = None 35 | n_local_heads: int = -1 36 | head_dim: int = 64 37 | rope_base: float = 10000 38 | norm_eps: float = 1e-5 39 | attention_bias: bool = False 40 | max_length: int = 4096 41 | rope_scaling: Optional[Dict[str, Any]] = None 42 | 43 | def __post_init__(self): 44 | if self.n_local_heads == -1: 45 | self.n_local_heads = self.n_head 46 | if self.intermediate_size is None: 47 | hidden_dim = 4 * self.dim 48 | n_hidden = int(2 * hidden_dim / 3) 49 | self.intermediate_size = find_multiple(n_hidden, 256) 50 | self.head_dim = self.dim // self.n_head 51 | 52 | @classmethod 53 | def from_name(cls, name: str): 54 | if name in transformer_configs: 55 | return cls(**transformer_configs[name]) 56 | # fuzzy search 57 | config = [ 58 | config 59 | for config in transformer_configs 60 | if config in str(name).upper() or config in str(name) 61 | ] 62 | 63 | # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). 
Find the best config match, 64 | # take longer name (as it have more symbols matched) 65 | if len(config) > 1: 66 | config.sort(key=len, reverse=True) 67 | assert len(config[0]) != len( 68 | config[1] 69 | ), name # make sure only one 'best' match 70 | 71 | return cls(**transformer_configs[config[0]]) 72 | 73 | 74 | transformer_configs = { 75 | "CodeLlama-7b-Python-hf": dict( 76 | block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000 77 | ), 78 | "7B": dict(n_layer=32, n_head=32, dim=4096), 79 | "13B": dict(n_layer=40, n_head=40, dim=5120), 80 | "30B": dict(n_layer=60, n_head=52, dim=6656), 81 | "34B": dict( 82 | n_layer=48, 83 | n_head=64, 84 | dim=8192, 85 | vocab_size=32000, 86 | n_local_heads=8, 87 | intermediate_size=22016, 88 | rope_base=1000000, 89 | ), # CodeLlama-34B-Python-hf 90 | "70B": dict( 91 | n_layer=80, n_head=64, dim=8192, n_local_heads=8, intermediate_size=28672 92 | ), 93 | "Mistral-7B": dict( 94 | n_layer=32, 95 | n_head=32, 96 | n_local_heads=8, 97 | dim=4096, 98 | intermediate_size=14336, 99 | vocab_size=32000, 100 | ), 101 | "stories15M": dict(n_layer=6, n_head=6, dim=288), 102 | "stories110M": dict(n_layer=12, n_head=12, dim=768), 103 | "Meta-Llama-3-8B-Instruct": dict( 104 | block_size=8192, 105 | n_layer=32, 106 | n_head=32, 107 | n_local_heads=8, 108 | dim=4096, 109 | intermediate_size=14336, 110 | vocab_size=128256, 111 | rope_base=500000, 112 | max_length=8192, 113 | ), 114 | "Meta-Llama-3.1-8B-Instruct": dict( 115 | block_size=131072, 116 | n_layer=32, 117 | n_head=32, 118 | n_local_heads=8, 119 | dim=4096, 120 | intermediate_size=14336, 121 | vocab_size=128256, 122 | rope_base=500000, 123 | max_length=131072, 124 | rope_scaling={ 125 | "factor": 8.0, 126 | "low_freq_factor": 1.0, 127 | "high_freq_factor": 4.0, 128 | "original_max_position_embeddings": 8192, 129 | "rope_type": "llama3", 130 | }, 131 | ), 132 | "Qwen2-1.5B-Instruct": dict( 133 | block_size=32768, 134 | n_layer=28, 135 | n_head=12, 136 | n_local_heads=2, 137 | dim=1536, 138 | intermediate_size=8960, 139 | vocab_size=151936, 140 | rope_base=1000000, 141 | attention_bias=True, 142 | norm_eps=1e-6, 143 | max_length=32768, 144 | ), 145 | "Qwen2-0.5B-Instruct": dict( 146 | block_size=32768, 147 | n_layer=24, 148 | n_head=14, 149 | n_local_heads=2, 150 | dim=896, 151 | intermediate_size=4864, 152 | vocab_size=151936, 153 | rope_base=1000000, 154 | attention_bias=True, 155 | norm_eps=1e-6, 156 | max_length=32768, 157 | ), 158 | "Qwen2-7B-Instruct": dict( 159 | block_size=32768, 160 | n_layer=28, 161 | n_head=28, 162 | n_local_heads=4, 163 | dim=3584, 164 | intermediate_size=18944, 165 | vocab_size=152064, 166 | rope_base=1000000, 167 | attention_bias=True, 168 | norm_eps=1e-6, 169 | max_length=32768, 170 | ), 171 | } 172 | 173 | 174 | class Transformer(nn.Module): 175 | def __init__(self, config: ModelArgs) -> None: 176 | super().__init__() 177 | self.config = config 178 | 179 | self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) 180 | self.layers = nn.ModuleList( 181 | TransformerBlock(config) for _ in range(config.n_layer) 182 | ) 183 | self.norm = RMSNorm(config.dim, eps=config.norm_eps) 184 | self.output = nn.Linear(config.dim, config.vocab_size, bias=False) 185 | 186 | self.freqs_cis: Optional[Tensor] = None 187 | 188 | # Fixed for now 189 | self.max_batch_size = 1 190 | 191 | def setup_caches(self, **kwargs): 192 | cache_strategy = kwargs.pop("cache_strategy") 193 | 194 | head_dim = self.config.dim // self.config.n_head 195 | 196 | dtype = 
self.output.weight.dtype 197 | # For quantized layers, dtype is encoded in scales 198 | if hasattr(self.output, "scales"): 199 | dtype = self.output.scales.dtype 200 | elif hasattr(self.output, "scales_and_zeros"): 201 | dtype = self.output.scales_and_zeros.dtype 202 | for layer_idx, b in enumerate(self.layers): 203 | cache_constructor, relevant_kwargs = get_cache_constructor( 204 | cache_strategy=cache_strategy[layer_idx] 205 | ) 206 | # Only pass in the kwargs we need for the cache we chose (useful especially for debugging) 207 | layerwise_keys = { 208 | "max_cache_length", 209 | "recent_window", 210 | "prompt_compression_strategy", 211 | } 212 | layer_kwargs = { 213 | k: kwargs[k][layer_idx] if k in layerwise_keys else kwargs[k] 214 | for k in relevant_kwargs 215 | } 216 | b.attention.kv_cache = cache_constructor( 217 | self.max_batch_size, 218 | self.config.n_local_heads, 219 | head_dim, 220 | dtype, 221 | **layer_kwargs, 222 | ) 223 | b.attention.prompt_compressor = get_prompt_compressor_constructor( 224 | kwargs["prompt_compression_strategy"][layer_idx] 225 | )(head_specific=b.attention.kv_cache.head_specific, **layer_kwargs) 226 | 227 | self.freqs_cis = precompute_freqs_cis( 228 | self.config.block_size, 229 | self.config.dim // self.config.n_head, 230 | self.config.rope_base, 231 | dtype, 232 | self.config.rope_scaling, 233 | ) 234 | 235 | def reset_caches(self): 236 | for layer in self.layers: 237 | layer.attention.kv_cache.reset() 238 | 239 | def prompt_cache_overflow(self, prompt_length: int): 240 | return [ 241 | prompt_length > layer.attention.kv_cache.max_cache_length 242 | for layer in self.layers 243 | ] 244 | 245 | def get_cache_stats(self, prompt_len, gen_len): 246 | stats = {} 247 | final_seq_len = prompt_len + gen_len 248 | avgs = defaultdict(list) 249 | mem_total = 0 250 | for layer_idx, layer in enumerate(self.layers): 251 | stat = layer.attention.kv_cache.compute_statistics( 252 | seq_len=torch.tensor(final_seq_len) 253 | ) 254 | mem_total += stat.pop("cache_memory_gb") 255 | for k, v in stat.items(): 256 | stats[f"{k}_{layer_idx}"] = v 257 | avgs[k].append(v) 258 | 259 | for k, v in avgs.items(): 260 | stats[f"{k}_avg"] = sum(v) / len(v) 261 | 262 | stats["cache_memory_gb"] = mem_total 263 | return stats 264 | 265 | def min_cache_length(self): 266 | return min([layer.attention.kv_cache.max_cache_length for layer in self.layers]) 267 | 268 | def forward( 269 | self, 270 | idx: Tensor, 271 | input_pos: Tensor, 272 | is_prefill: Tensor, 273 | mask: Optional[Tensor] = None, 274 | attn_top_k: Optional[float] = 1.0, 275 | ) -> Tensor: 276 | assert self.freqs_cis is not None, "Caches must be initialized first" 277 | freqs_cis = self.freqs_cis[input_pos] 278 | x = self.tok_embeddings(idx) 279 | 280 | for i, layer in enumerate(self.layers): 281 | x = layer( 282 | x, 283 | idx, 284 | input_pos, 285 | is_prefill, 286 | freqs_cis, 287 | mask, 288 | attn_top_k=attn_top_k, 289 | ) 290 | x = self.norm(x) 291 | logits = self.output(x) 292 | return logits 293 | 294 | @classmethod 295 | def from_name(cls, name: str): 296 | return cls(ModelArgs.from_name(name)) 297 | 298 | 299 | class TransformerBlock(nn.Module): 300 | def __init__(self, config: ModelArgs) -> None: 301 | super().__init__() 302 | self.attention = Attention(config) 303 | self.feed_forward = FeedForward(config) 304 | self.ffn_norm = RMSNorm(config.dim, config.norm_eps) 305 | self.attention_norm = RMSNorm(config.dim, config.norm_eps) 306 | 307 | def forward( 308 | self, 309 | x: Tensor, 310 | input_ids: Tensor, 311 | 
input_pos: Tensor, 312 | is_prefill: Tensor, 313 | freqs_cis: Tensor, 314 | mask: Tensor, 315 | attn_top_k: Optional[float] = 1.0, 316 | ) -> Tensor: 317 | h = x + self.attention( 318 | self.attention_norm(x), 319 | input_ids, 320 | freqs_cis, 321 | mask, 322 | is_prefill, 323 | input_pos, 324 | attn_top_k=attn_top_k, 325 | ) 326 | out = h + self.feed_forward(self.ffn_norm(h)) 327 | return out 328 | 329 | 330 | class Attention(nn.Module): 331 | def __init__(self, config: ModelArgs): 332 | super().__init__() 333 | assert config.dim % config.n_head == 0 334 | 335 | total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim 336 | # key, query, value projections for all heads, but in a batch 337 | self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.attention_bias) 338 | self.wo = nn.Linear(config.dim, config.dim, bias=False) 339 | self.kv_cache = None 340 | self.prompt_compressor = None 341 | 342 | self.n_head = config.n_head 343 | self.head_dim = config.head_dim 344 | self.n_local_heads = config.n_local_heads 345 | self.dim = config.dim 346 | self._register_load_state_dict_pre_hook(self.load_hook) 347 | 348 | def load_hook(self, state_dict, prefix, *args): 349 | if prefix + "wq.weight" in state_dict: 350 | wq = state_dict.pop(prefix + "wq.weight") 351 | wk = state_dict.pop(prefix + "wk.weight") 352 | wv = state_dict.pop(prefix + "wv.weight") 353 | state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) 354 | 355 | def compress_prompt(self, input_pos, k_val, v_val, attn): 356 | seq_len = input_pos.shape[0] 357 | if self.kv_cache.max_cache_length < seq_len: 358 | kwargs = {"attn": attn} 359 | return self.prompt_compressor(input_pos, k_val, v_val, **kwargs) 360 | 361 | return input_pos, k_val, v_val, attn 362 | 363 | def forward( 364 | self, 365 | x: Tensor, 366 | input_ids: Tensor, 367 | freqs_cis: Tensor, 368 | mask: Tensor, 369 | is_prefill: bool, 370 | input_pos: Optional[Tensor] = None, 371 | attn_top_k: Optional[float] = 1.0, 372 | ) -> Tensor: 373 | bsz, seqlen, _ = x.shape 374 | 375 | kv_size = self.n_local_heads * self.head_dim 376 | q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) 377 | 378 | q = q.view(bsz, seqlen, self.n_head, self.head_dim) 379 | k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) 380 | v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) 381 | 382 | q = apply_rotary_emb(q, freqs_cis) 383 | k = apply_rotary_emb(k, freqs_cis) 384 | 385 | q = q.transpose(1, 2) 386 | k = k.transpose(1, 2) 387 | v = v.transpose(1, 2) 388 | 389 | kv_mask = None 390 | cache_kwargs = {"input_ids": input_ids} 391 | if not is_prefill: 392 | k, v, kv_mask = self.kv_cache.update_kv( 393 | input_pos, k, v, is_prefill, **cache_kwargs 394 | ) 395 | kv_mask = kv_mask.repeat_interleave( 396 | self.n_head // self.n_local_heads, dim=1 397 | ) 398 | 399 | k_rep = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 400 | v_rep = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) 401 | 402 | y, attn = scaled_dot_product_attention( 403 | q, 404 | k_rep, 405 | v_rep, 406 | attn_mask=kv_mask if mask is None else mask, 407 | dropout_p=0.0, 408 | attn_top_k=attn_top_k, 409 | # Ask the cache if needs attention scores returned (we cannot use FlexAttention if so) 410 | return_attn=self.kv_cache.return_attn(), 411 | ) 412 | 413 | if ( 414 | attn is not None 415 | ): # Mean pool over the grouped queries (average over self.n_head // self.n_local_heads) 416 | attn = attn.view( 417 | bsz, self.n_local_heads, self.n_head // 
self.n_local_heads, seqlen, -1 418 | ).mean(dim=2) 419 | 420 | # Prefill updates happen after since we don't use the KV cache for prefill attention 421 | if is_prefill: 422 | input_pos, k, v, attn = self.compress_prompt(input_pos, k, v, attn) 423 | self.kv_cache.update_kv(input_pos, k, v, is_prefill, **cache_kwargs) 424 | 425 | # [Optional] Update the KV Cache internal state now that we have attention probabilities 426 | # This is a no-op for most cache classes 427 | self.kv_cache.update_state(input_pos, k, v, is_prefill, attn, **cache_kwargs) 428 | 429 | y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) 430 | 431 | y = self.wo(y) 432 | return y 433 | 434 | 435 | class FeedForward(nn.Module): 436 | def __init__(self, config: ModelArgs) -> None: 437 | super().__init__() 438 | self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=False) 439 | self.w3 = nn.Linear(config.dim, config.intermediate_size, bias=False) 440 | self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) 441 | 442 | def forward(self, x: Tensor) -> Tensor: 443 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 444 | 445 | 446 | class RMSNorm(nn.Module): 447 | def __init__(self, dim: int, eps: float = 1e-5): 448 | super().__init__() 449 | self.eps = eps 450 | self.weight = nn.Parameter(torch.ones(dim)) 451 | 452 | def _norm(self, x): 453 | return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) 454 | 455 | def forward(self, x: Tensor) -> Tensor: 456 | output = self._norm(x.float()).type_as(x) 457 | return output * self.weight 458 | 459 | 460 | def precompute_freqs_cis( 461 | seq_len: int, 462 | n_elem: int, 463 | base: int = 10000, 464 | dtype: torch.dtype = torch.bfloat16, 465 | rope_scaling: Optional[Dict[str, Any]] = None, 466 | ) -> Tensor: 467 | freqs = 1.0 / ( 468 | base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem) 469 | ) 470 | t = torch.arange(seq_len, device=freqs.device) 471 | if rope_scaling is not None: 472 | assert ( 473 | rope_scaling["rope_type"] == "llama3" 474 | ), "Only Llama 3.1 scaling is supported" 475 | # Apply Llama 3.1 scaling 476 | low_freq_wavelen = ( 477 | rope_scaling["original_max_position_embeddings"] 478 | / rope_scaling["low_freq_factor"] 479 | ) 480 | high_freq_wavelen = ( 481 | rope_scaling["original_max_position_embeddings"] 482 | / rope_scaling["high_freq_factor"] 483 | ) 484 | new_freqs = [] 485 | for freq in freqs: 486 | wavelen = 2 * math.pi / freq 487 | if wavelen < high_freq_wavelen: 488 | new_freqs.append(freq) 489 | elif wavelen > low_freq_wavelen: 490 | new_freqs.append(freq / rope_scaling["factor"]) 491 | else: 492 | smooth = ( 493 | rope_scaling["original_max_position_embeddings"] / wavelen 494 | - rope_scaling["low_freq_factor"] 495 | ) / (rope_scaling["high_freq_factor"] - rope_scaling["low_freq_factor"]) 496 | new_freqs.append( 497 | (1 - smooth) * freq / rope_scaling["factor"] + smooth * freq 498 | ) 499 | freqs = torch.tensor(new_freqs, device=t.device) 500 | 501 | freqs = torch.outer(t, freqs) 502 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) 503 | cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) 504 | return cache.to(dtype=dtype) 505 | 506 | 507 | def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: 508 | xshaped = x.float().reshape(*x.shape[:-1], -1, 2) 509 | freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) 510 | x_out2 = torch.stack( 511 | [ 512 | xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], 513 | xshaped[..., 1] * 
freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], 514 | ], 515 | -1, 516 | ) 517 | 518 | x_out2 = x_out2.flatten(3) 519 | return x_out2.type_as(x) 520 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import sys 7 | import time 8 | import argparse 9 | import math 10 | import json 11 | import regex as re 12 | import contextlib 13 | import shutil 14 | import itertools 15 | import pandas as pd 16 | import numpy as np 17 | from pathlib import Path 18 | from typing import Optional, List 19 | from collections import defaultdict, Counter 20 | from tqdm.auto import tqdm 21 | 22 | import torch 23 | import torch._dynamo.config 24 | import torch._inductor.config 25 | 26 | from cache import add_cache_arguments, cache_compatibility, get_cache_constructor 27 | from model import Transformer 28 | from generation_utils import ( 29 | add_generation_arguments, 30 | compile_funcs, 31 | compute_max_seq_length, 32 | device_sync, 33 | get_cache_stats, 34 | merge_cache_config, 35 | reset_caches, 36 | setup_caches, 37 | ) 38 | from tokenizer import encode, TokenizerInterface 39 | 40 | torch._inductor.config.coordinate_descent_tuning = True 41 | torch._inductor.config.triton.unique_kernel_names = True 42 | torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 43 | DEBUG_COMPILE = False 44 | if DEBUG_COMPILE: 45 | import logging 46 | 47 | level = logging.DEBUG 48 | torch._logging.set_logs(dynamo=level, inductor=level) 49 | torch._dynamo.config.verbose = True 50 | 51 | default_device = "cuda" if torch.cuda.is_available() else "cpu" 52 | 53 | # support running without installing as a package 54 | wd = Path(__file__).parent.parent.resolve() 55 | sys.path.append(str(wd)) 56 | 57 | from tokenizer import get_tokenizer 58 | from generation_utils import load_model, generate 59 | from task import TASK_MAPPING, AutoTask 60 | 61 | 62 | def flatten_dict(in_dict: dict) -> dict: 63 | out_dict = {} 64 | for k, v in in_dict.items(): 65 | if type(v) == dict: 66 | for kk, vv in v.items(): 67 | out_dict[f"{k}_{kk}"] = vv 68 | else: 69 | out_dict[k] = v 70 | return out_dict 71 | 72 | 73 | def compress_list(l): 74 | if len(l) < 3: 75 | return l 76 | else: 77 | counter = Counter(l) 78 | return [f"{k}x{v}" for k, v in counter.items()] 79 | 80 | 81 | def args_to_str(args): 82 | if "debug" in args.cache_strategy[0]: 83 | debug_suffix = "__debug" 84 | cache_strategy = [ 85 | re.sub(r"debug_+", "", cs).strip() for cs in args.cache_strategy 86 | ] 87 | else: 88 | cache_strategy = args.cache_strategy 89 | debug_suffix = "" 90 | RELEVANT_CACHE_KWARGS = list( 91 | sorted( 92 | set( 93 | itertools.chain( 94 | *[get_cache_constructor(cs)[1] for cs in cache_strategy] 95 | ) 96 | ) 97 | ) 98 | ) 99 | 100 | def process_num(n): 101 | # Return integer floats as "1" not 1.0 102 | # Otherwise, no op 103 | if type(n) == float and int(n) == n: 104 | return int(n) 105 | return n 106 | 107 | RELEVANT_CACHE_KWARGS.append("cache_length_pattern") 108 | RELEVANT_CACHE_KWARGS.append("cache_strategy_pattern") 109 | if hasattr(args, "attn_top_k") and args.attn_top_k != 1.0: 110 | RELEVANT_CACHE_KWARGS.append("attn_top_k") 111 | 112 | args_dict = 
vars(args).copy() 113 | 114 | # Hybrid Strategies will be too long to save in a file name so just need to pick the strategy 115 | if "hybrid_strategies" in args_dict: 116 | args_dict["hybrid_strategies"] = [ 117 | x["strategy"] for x in args_dict["hybrid_strategies"] 118 | ] 119 | 120 | return ( 121 | "__".join( 122 | sorted( 123 | [ 124 | f"{k}=" + ",".join(compress_list([str(process_num(m)) for m in v])) 125 | if type(v) == list 126 | else f"{k}={process_num(v)}" 127 | for k, v in args_dict.items() 128 | if k in RELEVANT_CACHE_KWARGS 129 | ] 130 | ) 131 | ) 132 | + debug_suffix 133 | ) 134 | 135 | 136 | def run_task( 137 | args: argparse.Namespace, 138 | task: AutoTask, 139 | model: Transformer, 140 | prefill: callable, 141 | decode_one_token: callable, 142 | tokenizer: TokenizerInterface, 143 | is_chat: bool = False, 144 | profile: Optional[Path] = None, 145 | feed_long_prompts=False, 146 | decode_first_token=False, 147 | device=default_device, 148 | cache_kwargs: dict = {}, 149 | use_tp: bool = False, 150 | rank: int = None, 151 | terminator_ids: List[int] = None, 152 | ): 153 | aggregate_metrics = defaultdict(list) 154 | predictions = [] 155 | all_probs = [] 156 | task_metrics = {} 157 | 158 | test = task.get_test() 159 | 160 | if len(test) == 0: 161 | print( 162 | f"No test data found for {task.__class__.__name__}. Skipping. Possibly all filtered out by tokenizer for being too long." 163 | ) 164 | return None, None, None 165 | 166 | prompts = test["prompt"] 167 | 168 | inputs = [ 169 | encode(tokenizer, prompt, device="cpu", is_chat=is_chat) 170 | for prompt in tqdm(prompts, desc="Encoding Prompts") 171 | ] 172 | 173 | if task.requires_perplexity: 174 | assert ( 175 | len(test["labels"][0]) == 1 176 | ), "Only one label supported for perplexity tasks" 177 | label_ids = [ 178 | encode(tokenizer, label[0], device="cpu", is_chat=False, bos=False) 179 | for label in tqdm(test["labels"], desc="Encoding Labels") 180 | ] 181 | _, max_seq_length = compute_max_seq_length(model, inputs, label_ids, 0) 182 | else: 183 | label_ids = None 184 | _, max_seq_length = compute_max_seq_length(model, inputs, None, task.max_tokens) 185 | 186 | # Estimate median sequence length 187 | median_seq_length = int(np.median([len(i) for i in inputs]) + task.max_tokens / 2) 188 | 189 | target_length = ( 190 | max_seq_length 191 | if any([x in {"full", "hybrid"} or "debug" in x for x in args.cache_strategy]) 192 | else median_seq_length 193 | ) 194 | 195 | task_cache_kwargs = setup_caches( 196 | model, tokenizer, device, target_length, cache_kwargs.copy() 197 | ) 198 | 199 | for i in tqdm(range(len(inputs))): 200 | input = inputs[i].to(device) 201 | next_tokens = None if label_ids is None else label_ids[i].to(device) 202 | prompt_length = input.size(0) 203 | max_new_tokens = min(task.max_tokens, max_seq_length - prompt_length) 204 | assert max_new_tokens > 0, f"Prompt too long for model: {prompt_length}" 205 | 206 | device_sync(device=device) # MKG 207 | 208 | if not profile or (use_tp and rank != 0): 209 | prof = contextlib.nullcontext() 210 | else: 211 | torch.profiler._utils._init_for_cuda_graphs() 212 | prof = torch.profiler.profile() 213 | with prof: 214 | y, probs, perf_stats = generate( 215 | model, 216 | input, 217 | prefill, 218 | decode_one_token, 219 | max_new_tokens=max_new_tokens, 220 | next_tokens=next_tokens, 221 | terminator_ids=terminator_ids if next_tokens is None else None, 222 | attn_top_k=args.attn_top_k, 223 | feed_long_prompts=feed_long_prompts, 224 | decode_first_token=decode_first_token, 225 
| ) 226 | 227 | for k, v in perf_stats.items(): 228 | aggregate_metrics[k].append(v) 229 | 230 | if next_tokens is not None: 231 | nll = -torch.tensor( 232 | [ 233 | torch.log(probs[j][next_tokens[j]]) 234 | for j in range(next_tokens.size(0)) 235 | ] 236 | ) 237 | for k in range(500, len(nll), 500): 238 | aggregate_metrics[f"ppl@{k}"].append( 239 | float(torch.exp(torch.mean(nll[:k])).item()) 240 | ) 241 | aggregate_metrics["ppl"].append(float(torch.exp(torch.mean(nll)).item())) 242 | 243 | if hasattr(prof, "export_chrome_trace"): 244 | if use_tp: 245 | prof.export_chrome_trace(f"{profile}_rank_{rank}.json") 246 | else: 247 | prof.export_chrome_trace(f"{profile}.json") 248 | device_sync(device=device) # MKG 249 | 250 | cache_stats = get_cache_stats(model, prompt_length, perf_stats["decode_tokens"]) 251 | for k, v in cache_stats.items(): 252 | aggregate_metrics[k].append(v) 253 | 254 | if ( 255 | not task.requires_perplexity 256 | ): # Perplexity tasks don't decode from model so don't save predictions 257 | # Decode: remove EoT and prompt 258 | end = y.size(0) 259 | if y[-1] in terminator_ids: 260 | end = -1 261 | pred = tokenizer.decode(y[prompt_length:end].tolist()) 262 | 263 | if args.debug: 264 | print(f"Prediction: {pred}") 265 | 266 | predictions.append(pred) 267 | if task.requires_logits: 268 | all_probs.append( 269 | {k: v for k, v in zip(tokenizer.get_vocab(), probs[-1].tolist())} 270 | ) 271 | 272 | # Reset KV Cache state 273 | reset_caches(model) 274 | 275 | print( 276 | f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['total_toks_per_sec'])).item():.2f}" 277 | ) 278 | max_mem_gb = torch.cuda.max_memory_reserved() / 1e9 279 | print(f"Memory used: {max_mem_gb} GB") 280 | task_metrics["max_memory_gb"] = max_mem_gb 281 | 282 | for k, v in aggregate_metrics.items(): 283 | task_metrics[k] = sum(v) / len(v) 284 | 285 | # For toks_per_sec, we also want to report the average of the highest 10% toks/second 286 | # This is useful to get a sense of toks / second without the one-time impact of compilation 287 | if "toks_per_sec" in k: 288 | # Useful to save toks_per_sec for each example for better understanding of how it changes over time with compile 289 | task_metrics[k] = v 290 | # Also save the top 10% average (likely unaffected by compile) 291 | v.sort() 292 | cutoff = math.ceil(len(v) / 10) 293 | task_metrics[f"{k}_top_10p"] = sum(v[-cutoff:]) / cutoff 294 | 295 | if k == "total_seconds": 296 | task_metrics[f"{k}_min"] = min(aggregate_metrics[k]) 297 | task_metrics[f"{k}_max"] = max(aggregate_metrics[k]) 298 | task_metrics[f"{k}_median"] = float(np.median(aggregate_metrics[k])) 299 | 300 | if task.requires_perplexity: 301 | pred_df = None 302 | else: 303 | pred_units = all_probs if task.requires_logits else predictions 304 | task_metrics.update(flatten_dict(task.test_metrics(pred_units))) 305 | pred_df = pd.DataFrame({"prompt": prompts, "prediction": predictions}) 306 | 307 | return task_metrics, pred_df, task_cache_kwargs 308 | 309 | 310 | def main( 311 | args: argparse.Namespace, 312 | tasks: List[str], 313 | debug: bool = False, 314 | checkpoint_path: Path = Path( 315 | "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" 316 | ), 317 | profile: Optional[Path] = None, 318 | compile=True, 319 | feed_long_prompts=False, 320 | decode_first_token=False, 321 | device=default_device, 322 | cache_kwargs: dict = {}, 323 | out_dir: Path = None, 324 | ) -> None: 325 | """Generates text samples based on a pre-trained Transformer model and tokenizer.""" 326 | 
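    # High-level flow (summary of the code below): load the model and tokenizer from
    # checkpoint_path, build the requested AutoTask objects, then for each task compile
    # prefill/decode, call run_task(), and write {task}_metrics.json, {task}_args.json,
    # and {task}_predictions.csv before dumping the combined all_metrics.json and args.json
    # to out_dir.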
assert checkpoint_path.is_file(), checkpoint_path 327 | 328 | tokenizer_path = checkpoint_path.parent / "tokenizer.model" 329 | if not tokenizer_path.is_file(): 330 | # If there's no tokenizer.model, try to load the tokenizer from the parent directory 331 | # NOTE: We assume the tokenizer in the parent directory is compatible with huggingface transformers 332 | tokenizer_path = checkpoint_path.parent 333 | 334 | global print 335 | from tp import maybe_init_dist 336 | 337 | rank = maybe_init_dist() 338 | use_tp = rank is not None 339 | if use_tp: 340 | if rank != 0: 341 | # only print on rank 0 342 | print = lambda *args, **kwargs: None 343 | 344 | print(f"Using device={device}") 345 | precision = torch.bfloat16 346 | is_chat = ( 347 | "chat" in str(checkpoint_path).lower() 348 | or "instruct" in str(checkpoint_path).lower() 349 | ) 350 | 351 | print("Loading model ...") 352 | t0 = time.time() 353 | model = load_model(checkpoint_path, device, precision, use_tp) 354 | 355 | device_sync(device=device) # MKG 356 | print(f"Time to load model: {time.time() - t0:.02f} seconds") 357 | 358 | tokenizer = get_tokenizer(tokenizer_path, checkpoint_path, is_chat=is_chat) 359 | 360 | if cache_kwargs["cache_strategy"] == "hybrid": 361 | # We need to pass the special and punctuation token ids to the cache via cache_kwargs 362 | cache_kwargs["token_ids"] = { 363 | "special": tokenizer.special_ids(), 364 | "punctuation": tokenizer.punctuation_ids(), 365 | } 366 | 367 | terminator_ids = tokenizer.get_terminator_ids() 368 | 369 | torch.manual_seed(1234) 370 | 371 | task_kwargs = { 372 | "model_max_length": model.config.max_length, 373 | "num_samples": args.num_samples, 374 | "tokenizer": tokenizer.encode_prompt if is_chat else tokenizer.encode, 375 | "seq_length": args.seq_length, 376 | } 377 | if tasks == ["all"]: 378 | # Evaluate all tasks 379 | tasks = list(TASK_MAPPING.keys()) 380 | eval_tasks = {task: AutoTask.from_name(task, **task_kwargs) for task in tasks} 381 | 382 | task_metrics = defaultdict(dict) 383 | args_fn = out_dir / "args.json" 384 | all_out_fn = out_dir / "all_metrics.json" 385 | for task_name, task in eval_tasks.items(): 386 | print(f"Running task {task_name} ...") 387 | task_out_fn = out_dir / f"{task_name}_metrics.json" 388 | task_args_out_fn = out_dir / f"{task_name}_args.json" 389 | pred_out_fn = out_dir / f"{task_name}_predictions.csv" 390 | if task_out_fn.exists() and not cache_kwargs["overwrite"]: 391 | print(f"Task {task_name} already evaluated. 
Skipping.") 392 | with open(task_out_fn, "r") as fd: 393 | task_metrics[task_name] = json.load(fd) 394 | else: 395 | prefill, decode_one_token = compile_funcs(compile) 396 | task_metrics[task_name], predictions, task_args = run_task( 397 | args, 398 | task, 399 | model, 400 | prefill, 401 | decode_one_token, 402 | tokenizer, 403 | is_chat, 404 | profile, 405 | feed_long_prompts, 406 | decode_first_token, 407 | device, 408 | cache_kwargs, 409 | use_tp, 410 | rank, 411 | terminator_ids, 412 | ) 413 | 414 | if task_metrics[task_name] is None: 415 | continue 416 | 417 | if predictions is not None: 418 | predictions.to_csv(pred_out_fn, index=False) 419 | 420 | if debug: 421 | print(f"Results for {task_name}:") 422 | print(task_metrics[task_name]) 423 | 424 | with open(task_out_fn, "w") as fd: 425 | print(f"Saving results for {task_name} to {task_out_fn}") 426 | json.dump(task_metrics[task_name], fd, indent=4) 427 | 428 | with open(task_args_out_fn, "w") as fd: 429 | print(f"Saving dynamic args for {task_name} to {task_args_out_fn}") 430 | # Convert Path objects to strings 431 | task_args_json = { 432 | k: str(v) if isinstance(v, Path) else v 433 | for k, v in task_args.items() 434 | } 435 | json.dump(task_args_json, fd, indent=4) 436 | 437 | if not args_fn.exists(): 438 | # Only save args once and only save if we've gotten through a full eval and are ready to dump metrics 439 | with open(args_fn, "w") as fd: 440 | # Convert Path objects to strings 441 | cache_kwargs_json = { 442 | k: str(v) if isinstance(v, Path) else v 443 | for k, v in cache_kwargs.items() 444 | } 445 | json.dump(cache_kwargs_json, fd, indent=4) 446 | 447 | with open(all_out_fn, "w") as fd: 448 | json.dump(task_metrics, fd, indent=4) 449 | 450 | 451 | def setup(args) -> Path: 452 | sub_dir = args_to_str(args) if args.out_dir is None else args.out_dir 453 | out_dir = ( 454 | Path(__file__).parent 455 | / "results" 456 | / args.checkpoint_path.parent.name 457 | / "__".join(compress_list(args.cache_strategy)) 458 | / sub_dir 459 | ) 460 | 461 | print(f"Saving to {out_dir}") 462 | # Make out_dir and don't err out if it already exists 463 | if out_dir.exists(): 464 | print(f"Output directory {out_dir} already exists.") 465 | if args.overwrite: 466 | print(f"Removing {out_dir}.") 467 | shutil.rmtree(out_dir) 468 | out_dir.mkdir(parents=True, exist_ok=True) 469 | 470 | cache_compatibility(args) 471 | 472 | for k, v in vars(args).items(): 473 | print(f"{k} -> {v}") 474 | 475 | return out_dir 476 | 477 | 478 | def add_eval_args(parser): 479 | parser.add_argument( 480 | "--tasks", 481 | type=str, 482 | nargs="+", 483 | default=["truthfulqa"], 484 | choices=list(TASK_MAPPING.keys()) + ["all"], 485 | help="List of tasks to be evaluated.", 486 | ) 487 | 488 | parser.add_argument( 489 | "--out_dir", 490 | type=Path, 491 | default=None, 492 | help="Output directory for results. If not specified, will be a concatenation of the program args.", 493 | ) 494 | 495 | parser.add_argument( 496 | "--debug", 497 | default=False, 498 | action="store_true", 499 | help="Debug mode uses first 10 examples in dataset.", 500 | ) 501 | 502 | parser.add_argument( 503 | "--num_samples", 504 | type=int, 505 | default=-1, 506 | help="Number of examples to sample for evaluation. 
Defaults to None, which uses the full dataset.", 507 | ) 508 | 509 | parser.add_argument( 510 | "--overwrite", 511 | default=False, 512 | action="store_true", 513 | help="Whether to over-write existing results if they exist.", 514 | ) 515 | 516 | # Only for --tasks PG19 517 | parser.add_argument( 518 | "--seq_length", 519 | type=int, 520 | default=None, 521 | help="Specify the number of tokens for the dataset.", 522 | ) 523 | 524 | parser.add_argument( 525 | "--cache_config", 526 | type=str, 527 | default=None, 528 | help="Name of YAML file in ./cache_configs.", 529 | ) 530 | 531 | parser.add_argument( 532 | "--decode_first_token", 533 | default=False, 534 | action="store_true", 535 | help="If True will truncate cache after prefill and then decode the first token.", 536 | ) 537 | 538 | 539 | if __name__ == "__main__": 540 | parser = argparse.ArgumentParser( 541 | description="Evaluation script for different KV-Cache Compression Algorithms." 542 | ) 543 | 544 | add_eval_args(parser) 545 | add_generation_arguments(parser) 546 | add_cache_arguments(parser) 547 | 548 | args = merge_cache_config(parser.parse_args()) 549 | 550 | if args.tasks[0] == "all": 551 | args.tasks = list(TASK_MAPPING.keys()) 552 | print(f"Running all tasks: {', '.join(args.tasks)}") 553 | 554 | out_dir = setup(args) 555 | 556 | main( 557 | args, 558 | args.tasks, 559 | args.debug, 560 | args.checkpoint_path, 561 | args.profile, 562 | args.compile, 563 | args.feed_long_prompts, 564 | args.decode_first_token, 565 | args.device, 566 | cache_kwargs=vars(args), 567 | out_dir=out_dir, 568 | ) 569 | -------------------------------------------------------------------------------- /generation_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import time 3 | from typing import Optional, Tuple 4 | from pathlib import Path 5 | 6 | import torch 7 | import torch._dynamo.config 8 | import torch._inductor.config 9 | from torch.nn.attention import SDPBackend, sdpa_kernel 10 | 11 | import argparse 12 | import yaml 13 | from model import Transformer, find_multiple 14 | from tokenizer import TokenizerInterface 15 | 16 | default_device = "cuda" if torch.cuda.is_available() else "cpu" 17 | 18 | 19 | def snake_to_capitalized(s): 20 | return " ".join(word.capitalize() for word in s.split("_")) 21 | 22 | 23 | def print_stats(stats_dict): 24 | # Separate the stats into layered and non-layered 25 | layered_stats = {} 26 | non_layered_stats = {} 27 | 28 | for key, value in stats_dict.items(): 29 | parts = key.rsplit("_", 1) 30 | if len(parts) == 2 and parts[1].isdigit(): 31 | stat = snake_to_capitalized(parts[0]) 32 | layer = int(parts[1]) 33 | if stat not in layered_stats: 34 | layered_stats[stat] = [] 35 | layered_stats[stat].append((layer, value)) 36 | else: 37 | non_layered_stats[snake_to_capitalized(key)] = value 38 | 39 | # Print non-layered stats 40 | for key, value in non_layered_stats.items(): 41 | print(f"{key}: {value:.02f}") 42 | 43 | # Print layered stats 44 | for stat in sorted(layered_stats.keys()): 45 | layers_list = sorted(layered_stats[stat]) 46 | layers_str = ", ".join(f"{layer}={value:.02f}" for layer, value in layers_list) 47 | print(f"{stat} By Layer: {layers_str}") 48 | 49 | 50 | def add_generation_arguments(parser: argparse.ArgumentParser): 51 | group = parser.add_argument_group("generation_args") 52 | # Generation hparams 53 | group.add_argument( 54 | "--checkpoint_path", 55 | type=Path, 56 | default=Path(__file__).resolve().parent 57 | / 
"checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth", 58 | help="Model checkpoint path.", 59 | ) 60 | 61 | group.add_argument("--profile", type=Path, default=None, help="Profile path.") 62 | 63 | group.add_argument( 64 | "--compile", action="store_true", help="Whether to compile the model." 65 | ) 66 | 67 | group.add_argument( 68 | "--device", type=str, default=default_device, help="Device to use" 69 | ) 70 | 71 | group.add_argument( 72 | "--attn_top_k", 73 | type=float, 74 | default=1.0, 75 | help="Fraction of top-K attentions over which to compute values. 1.0 means all V are used regardless of attention weight (QK).", 76 | ) 77 | 78 | 79 | def merge_cache_config(args): 80 | if not args.cache_config: 81 | return args 82 | # Get parent directory of current file 83 | if not args.cache_config.endswith(".yaml"): 84 | args.cache_config = args.cache_config + ".yaml" 85 | yaml_fn = Path(__file__).parent / "cache_configs" / args.cache_config 86 | assert yaml_fn.exists(), f"Cache config file {yaml_fn} does not exist." 87 | with open(yaml_fn, "r") as f: 88 | cache_kwargs = yaml.safe_load(f) 89 | # Over-write args with cache_kwargs 90 | args = argparse.Namespace(**{**vars(args), **cache_kwargs}) 91 | return args 92 | 93 | 94 | def compute_max_seq_length( 95 | model, prompt_lens: list[int], target_lens: list[int], max_new_tokens: int 96 | ) -> int: 97 | max_prompt_length = max(len(prompt_lens[i]) for i in range(len(prompt_lens))) 98 | # Should either pass target_lens or max_new_tokens 99 | max_target_lens = ( 100 | 0 101 | if target_lens is None 102 | else max(len(target_lens[i]) for i in range(len(target_lens))) 103 | ) 104 | max_new_tokens = max(max_new_tokens, max_target_lens) 105 | max_seq_length = max_prompt_length + max_new_tokens 106 | if max_seq_length > model.config.block_size: 107 | print( 108 | f"Warning: The longest prompt puts the desired max_seq_length at {max_seq_length}, which is greater than models max of {model.config.block_size}." 
109 | ) 110 | print(f"Setting to model's max_seq_length of {model.config.block_size}.") 111 | max_seq_length = model.config.block_size 112 | print(f"Maximum context length of {max_seq_length} tokens.") 113 | return max_prompt_length, max_seq_length 114 | 115 | 116 | def device_sync(device): 117 | if "cuda" in device: 118 | torch.cuda.synchronize(device) 119 | elif ("cpu" in device) or ("mps" in device): 120 | pass 121 | else: 122 | print(f"device={device} is not yet suppported") 123 | 124 | 125 | def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): 126 | logits = logits / max(temperature, 1e-5) 127 | 128 | if top_k is not None: 129 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 130 | pivot = v.select(-1, -1).unsqueeze(-1) 131 | logits = torch.where(logits < pivot, -float("Inf"), logits) 132 | probs = torch.nn.functional.softmax(logits, dim=-1) 133 | return probs 134 | 135 | 136 | def greedy(logits, next_token): 137 | probs = torch.nn.functional.softmax(logits[0, -1], dim=-1) 138 | if next_token is None: 139 | idx_next = torch.argmax(probs, keepdim=True).to(dtype=torch.int) 140 | else: 141 | idx_next = next_token 142 | return idx_next, probs 143 | 144 | 145 | def prefill( 146 | model: Transformer, 147 | x: torch.Tensor, 148 | input_pos: torch.Tensor, 149 | next_token: torch.Tensor = None, 150 | **sampling_kwargs, 151 | ) -> torch.Tensor: 152 | # input_pos: [B, S] 153 | causal_mask = ( 154 | torch.tril(torch.ones(len(input_pos), len(input_pos), dtype=torch.bool)) 155 | .unsqueeze(0) 156 | .unsqueeze(0) 157 | .to(x.device) 158 | ) 159 | logits = model(x, input_pos, mask=causal_mask, is_prefill=True) 160 | return greedy(logits, next_token) 161 | 162 | 163 | def decode_one_token( 164 | model: Transformer, 165 | x: torch.Tensor, 166 | input_pos: torch.Tensor, 167 | next_token: torch.Tensor = None, 168 | attn_top_k: float = 1, 169 | **sampling_kwargs, 170 | ) -> Tuple[torch.Tensor, torch.Tensor]: 171 | # input_pos: [B, 1] 172 | logits = model( 173 | x, 174 | input_pos, 175 | is_prefill=False, 176 | attn_top_k=attn_top_k, 177 | ) 178 | return greedy(logits, next_token=next_token) 179 | 180 | 181 | def decode_n_tokens( 182 | model: Transformer, 183 | cur_token: torch.Tensor, 184 | input_pos: torch.Tensor, 185 | decode_one_token: callable, 186 | num_new_tokens: int, 187 | terminator_ids: Optional[list] = None, 188 | attn_top_k: float = 1, 189 | prefix: Optional[torch.Tensor] = None, 190 | **sampling_kwargs, 191 | ): 192 | new_tokens, new_probs = [], [] 193 | for i in range(num_new_tokens): 194 | with sdpa_kernel( 195 | [SDPBackend.MATH] 196 | ): # Actually better for Inductor to codegen attention here 197 | teacher_force = prefix is not None and i < len(prefix) 198 | next_token = prefix[i].view(1) if teacher_force else None 199 | next_token, next_prob = decode_one_token( 200 | model, 201 | cur_token, 202 | input_pos, 203 | next_token=next_token, 204 | attn_top_k=attn_top_k, 205 | **sampling_kwargs, 206 | ) 207 | 208 | new_tokens.append(next_token.clone()) 209 | new_probs.append(next_prob.clone()) 210 | 211 | if terminator_ids and next_token in terminator_ids and not teacher_force: 212 | break 213 | 214 | input_pos += 1 215 | cur_token = next_token.view(1, -1) 216 | 217 | return new_tokens, new_probs 218 | 219 | 220 | def model_forward(model, x, input_pos): 221 | return model(x, input_pos) 222 | 223 | 224 | def apply_pattern( 225 | pattern: list[str | int], 226 | out_size: int, 227 | extension_strategy: str = "tile", 228 | max_seq_length: int = None, 229 | ): 
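    # Worked example (hypothetical values): with pattern=[256, 4096] and an 8-layer model,
    #   "tile"   -> [256, 256, 256, 256, 4096, 4096, 4096, 4096]
    #   "repeat" -> [256, 4096, 256, 4096, 256, 4096, 256, 4096]
    # "pyramid"/"funnel" instead take a single length and spread a decreasing/increasing
    # budget across layers (see apply_pyramid_pattern below).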
230 | """ 231 | Extend a given pattern across n_layers of the model. 232 | """ 233 | assert extension_strategy in { 234 | "tile", 235 | "repeat", 236 | "pyramid", 237 | "funnel", 238 | }, "extension_strategy must be one of 'tile', 'repeat', 'pyramid', or 'funnel'." 239 | assert ( 240 | out_size % len(pattern) == 0 241 | ), f"{len(pattern)} must be a divisible factor of the number of layers ({out_size})." 242 | factor = out_size // len(pattern) 243 | 244 | if extension_strategy in {"funnel", "pyramid"}: 245 | assert ( 246 | len(pattern) == 1 247 | ), "Funnel and pyramid patterns must have a single element." 248 | return apply_pyramid_pattern( 249 | pattern[0], 250 | max_seq_length, 251 | out_size, 252 | decreasing=extension_strategy == "pyramid", 253 | ) 254 | elif extension_strategy == "tile": 255 | return [item for item in pattern for _ in range(factor)] 256 | else: # Repeat 257 | return pattern * factor 258 | 259 | 260 | def normalize_cache_length( 261 | max_cache_length: float, max_seq_length: int, multiple_of: int = 8 262 | ) -> int: 263 | """ 264 | Computes the absolute cache length given the max_cache_length and max_seq_length. 265 | """ 266 | if 0 < max_cache_length <= 1: 267 | max_cache_length = round(max_seq_length * max_cache_length) 268 | else: 269 | assert int(max_cache_length) == max_cache_length 270 | max_cache_length = int(max_cache_length) 271 | if max_cache_length > max_seq_length: 272 | print( 273 | f"FYI: max_cache_length ({max_cache_length}) is greater than max_seq_length ({max_seq_length}). Setting to {max_seq_length}" 274 | ) 275 | max_cache_length = max_seq_length 276 | return min(find_multiple(max_cache_length, multiple_of), max_seq_length) 277 | 278 | 279 | def apply_pyramid_pattern( 280 | max_cache_length: int, 281 | max_seq_length: int, 282 | model_n_layer: int, 283 | decreasing: bool = True, 284 | min_cache_length: int = 256, 285 | ): 286 | # Implements https://arxiv.org/abs/2406.02069 287 | # Paper finds best beta of 14 288 | beta = 14 289 | min_allowable = min(min_cache_length, max_cache_length) 290 | total_len = max_cache_length * model_n_layer 291 | min_cache_length = total_len / (model_n_layer * beta) 292 | max_cache_length = 2 * total_len / model_n_layer 293 | diff = (max_cache_length - min_cache_length) / model_n_layer 294 | cache_lens = [min_cache_length] 295 | for l in range(1, model_n_layer - 1): 296 | cache_lens.append(min_cache_length + diff * l) 297 | cache_lens.append(max_cache_length) 298 | cache_lens = [normalize_cache_length(int(l), max_seq_length) for l in cache_lens] 299 | 300 | overflow = 0 301 | num_overflow = 0 302 | for i in range(len(cache_lens)): 303 | if cache_lens[i] < min_allowable: 304 | overflow += min_allowable - cache_lens[i] 305 | cache_lens[i] = min_allowable 306 | num_overflow += 1 307 | 308 | if num_overflow < len(cache_lens): 309 | decr_amount = overflow // (len(cache_lens) - num_overflow) 310 | for i in range(len(cache_lens)): 311 | if cache_lens[i] > min_allowable: 312 | # This will change the overall cache length slightly if min_allowable threshold is hit but should be very minor 313 | cache_lens[i] = max(min_allowable, cache_lens[i] - decr_amount) 314 | 315 | if decreasing: 316 | cache_lens = cache_lens[::-1] 317 | assert cache_lens[-1] < cache_lens[0], "Cache lengths should be decreasing." 318 | else: 319 | assert cache_lens[0] < cache_lens[-1], "Cache lengths should be increasing." 
320 | 321 | return cache_lens 322 | 323 | 324 | def setup_caches( 325 | model: Transformer, 326 | tokenizer: TokenizerInterface, 327 | device: torch.device, 328 | max_seq_length: int, 329 | cache_kwargs: dict = None, 330 | ) -> dict: 331 | # Normalize max_cache_length to absolute cache length if provided as a fraction of the max seq sequence length 332 | cache_kwargs["max_seq_length"] = max_seq_length 333 | cache_kwargs["max_cache_length"] = list( 334 | map( 335 | lambda l: normalize_cache_length(l, max_seq_length), 336 | cache_kwargs["max_cache_length"], 337 | ) 338 | ) 339 | 340 | cache_kwargs["max_cache_length"] = apply_pattern( 341 | pattern=cache_kwargs["max_cache_length"], 342 | out_size=model.config.n_layer, 343 | extension_strategy=cache_kwargs["cache_length_pattern"], 344 | max_seq_length=max_seq_length, 345 | ) 346 | 347 | assert len(cache_kwargs["cache_strategy"]) == len( 348 | cache_kwargs["prompt_compression_strategy"] 349 | ), "You must specify a prompt_compression_strategy for each cache_strategy." 350 | 351 | cache_kwargs["cache_strategy"] = apply_pattern( 352 | pattern=cache_kwargs["cache_strategy"], 353 | out_size=model.config.n_layer, 354 | extension_strategy=cache_kwargs["cache_strategy_pattern"], 355 | ) 356 | cache_kwargs["prompt_compression_strategy"] = apply_pattern( 357 | pattern=cache_kwargs["prompt_compression_strategy"], 358 | out_size=model.config.n_layer, 359 | extension_strategy=cache_kwargs["cache_strategy_pattern"], 360 | ) 361 | 362 | if type(cache_kwargs["recent_window"]) != list: 363 | if cache_kwargs["recent_window"] <= 1: 364 | cache_kwargs["recent_window"] = [ 365 | max(1, int(cache_kwargs["recent_window"] * l)) 366 | for l in cache_kwargs["max_cache_length"] 367 | ] 368 | else: 369 | cache_kwargs["recent_window"] = [ 370 | max(1, min(cache_kwargs["recent_window"], l)) 371 | for l in cache_kwargs["max_cache_length"] 372 | ] 373 | 374 | assert cache_kwargs["global_tokens"] <= min( 375 | cache_kwargs["max_cache_length"] 376 | ), "Global tokens must be less than max_cache_length." 377 | 378 | if cache_kwargs["cache_strategy"][0] == "hybrid": 379 | # We need to pass the special and punctuation token ids to the cache via cache_kwargs 380 | cache_kwargs["token_ids"] = { 381 | "special": tokenizer.special_ids(), 382 | "punctuation": tokenizer.punctuation_ids(), 383 | } 384 | 385 | with torch.device(device): 386 | model.setup_caches(max_batch_size=1, **cache_kwargs) 387 | 388 | return cache_kwargs 389 | 390 | 391 | def reset_caches(model: Transformer): 392 | model.reset_caches() 393 | 394 | 395 | def get_cache_stats(model: Transformer, prompt_len: int, gen_len: int): 396 | return model.get_cache_stats(prompt_len, gen_len) 397 | 398 | 399 | @torch.no_grad() 400 | def generate( 401 | model: Transformer, 402 | prompt: torch.Tensor, 403 | prefill: callable, 404 | decode_one_token: callable, 405 | max_new_tokens: int, 406 | next_tokens: Optional[torch.Tensor] = None, 407 | terminator_ids: Optional[list] = None, 408 | feed_long_prompts: bool = False, 409 | decode_first_token: bool = False, 410 | attn_top_k: float = 1, 411 | **sampling_kwargs, 412 | ) -> torch.Tensor: 413 | """ 414 | Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. 
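    Three modes are supported: normal generation; full teacher forcing when next_tokens is
    provided (used for perplexity tasks); and partial teacher forcing when feed_long_prompts
    truncates an over-long prompt and feeds the overflow back one token at a time as a prefix.
    If decode_first_token is set, the last prompt token is held back so the first token is
    produced by the decode path rather than by prefill.
    Returns the generated sequence, the per-step probabilities, and a perf_stats dict.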
415 | """ 416 | 417 | # create an empty tensor of the expected final shape and fill in the current tokens 418 | prompt_length = prompt.size(0) 419 | 420 | device, dtype = prompt.device, prompt.dtype 421 | 422 | min_cache_length = model.min_cache_length() 423 | # Subtract 1 in case we need one generation step over which to compute attention, etc. 424 | max_prompt_len = min_cache_length - 1 425 | prefix = None 426 | # If we asked to have prompt truncated and fed, we need to do split prompt into prompt and prefix 427 | # We also define a rare yet important edge case: if |prompt| is exactly cache length 428 | # We might have to start evictions before having had a change to record any state (attentions). 429 | # In this scenario let's decrement prompt by 1 and start "generating" on the prefix 430 | if ( 431 | feed_long_prompts and prompt_length > max_prompt_len 432 | ) or prompt_length == min_cache_length: 433 | prompt, prefix = prompt[:max_prompt_len], prompt[max_prompt_len:] 434 | max_new_tokens += len(prefix) 435 | prompt_length = max_prompt_len 436 | 437 | if decode_first_token: 438 | prompt, prefix = prompt[:-1], prompt[-1:] 439 | max_new_tokens += 1 440 | prompt_length -= 1 441 | 442 | # create an empty tensor (all -1) of the expected final shape and fill in the current tokens 443 | # GPT-Fast had this as empty but the values of empty are non-deterministic 444 | seq = torch.full((prompt_length + max_new_tokens,), -1, dtype=dtype, device=device) 445 | seq[:prompt_length] = prompt 446 | input_pos = torch.arange(0, prompt_length, device=device) 447 | 448 | if next_tokens is not None: # We are in teacher forcing mode for Perplexity task 449 | max_new_tokens = len(next_tokens) 450 | next_token = next_tokens[0].view(1) 451 | prefix = next_tokens[1:] 452 | elif prefix is not None: # We are in partial teacher forcing due to a long prompt 453 | next_token = prefix[0].view(1) 454 | prefix = prefix[1:] 455 | else: 456 | next_token = prefix = None # We are in normal generation mode 457 | 458 | # create an empty tensor (all -1) of the expected final shape and fill in the current tokens 459 | # GPT-Fast had this as empty but the values of empty are non-deterministic 460 | seq = torch.full((prompt_length + max_new_tokens,), -1, dtype=dtype, device=device) 461 | seq[:prompt_length] = prompt 462 | input_pos = torch.arange(0, prompt_length, device=device) 463 | 464 | t0 = time.perf_counter() 465 | 466 | ret = prefill( 467 | model, 468 | prompt.view(1, -1), 469 | input_pos, 470 | next_token=next_token, 471 | **sampling_kwargs, 472 | ) 473 | 474 | t1 = time.perf_counter() 475 | 476 | prefill_seconds = t1 - t0 477 | 478 | next_token = ret[0].clone() 479 | next_tok_probs = ret[1].clone() 480 | seq[prompt_length] = next_token 481 | 482 | input_pos = torch.tensor([prompt_length], device=device, dtype=torch.int) 483 | generated_tokens, generated_tok_probs = decode_n_tokens( 484 | model, 485 | next_token.view(1, -1), 486 | input_pos, 487 | decode_one_token, 488 | max_new_tokens - 1, 489 | terminator_ids=terminator_ids, 490 | prefix=prefix, 491 | attn_top_k=attn_top_k, 492 | **sampling_kwargs, 493 | ) 494 | 495 | t2 = time.perf_counter() 496 | decode_seconds = t2 - t1 497 | 498 | total_seconds = t2 - t0 499 | 500 | prefill_tokens = prompt_length 501 | decode_tokens = ( 502 | len(generated_tokens) + 1 503 | ) # +1 because we generate 1 token from prefill 504 | 505 | decode_toks_per_sec = decode_tokens / decode_seconds 506 | prefill_toks_per_sec = prefill_tokens / prefill_seconds 507 | total_toks_per_sec = 
decode_tokens / total_seconds 508 | 509 | perf_stats = { 510 | "prefill_tokens": prefill_tokens, 511 | "decode_tokens": decode_tokens, 512 | "prefill_toks_per_sec": prefill_toks_per_sec, 513 | "decode_toks_per_sec": decode_toks_per_sec, 514 | "total_toks_per_sec": total_toks_per_sec, 515 | "total_seconds": total_seconds, 516 | "prefill_seconds": prefill_seconds, 517 | "decode_seconds": decode_seconds, 518 | "decode_seconds_frac_of_total": decode_seconds / total_seconds, 519 | "memory_used_gb": torch.cuda.max_memory_reserved() / 1e9, 520 | } 521 | 522 | if len(generated_tokens) > 0: 523 | seq[prompt_length + 1 : prompt_length + 1 + len(generated_tokens)] = torch.cat( 524 | generated_tokens 525 | ) 526 | 527 | # Truncate seq to first instance of -1 if -1 is present 528 | if -1 in seq: 529 | seq = seq[: torch.where(seq == -1)[0][0]] 530 | 531 | return seq, [next_tok_probs] + generated_tok_probs, perf_stats 532 | 533 | 534 | def load_model(checkpoint_path, device, precision, use_tp): 535 | use_cuda = "cuda" in device 536 | with torch.device("meta"): 537 | model = Transformer.from_name(checkpoint_path.parent.name) 538 | 539 | if "int8" in str(checkpoint_path): 540 | print("Using int8 weight-only quantization!") 541 | from quantize import WeightOnlyInt8QuantHandler 542 | 543 | simple_quantizer = WeightOnlyInt8QuantHandler(model) 544 | model = simple_quantizer.convert_for_runtime() 545 | 546 | if "int4" in str(checkpoint_path): 547 | print("Using int4 weight-only quantization!") 548 | path_comps = checkpoint_path.name.split(".") 549 | groupsize = int(path_comps[-2][1:]) 550 | from quantize import WeightOnlyInt4QuantHandler 551 | 552 | simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize) 553 | model = simple_quantizer.convert_for_runtime() 554 | 555 | checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) 556 | if "model" in checkpoint and "stories" in str(checkpoint_path): 557 | checkpoint = checkpoint["model"] 558 | model.load_state_dict(checkpoint, assign=True) 559 | if use_tp: 560 | from tp import apply_tp 561 | 562 | print("Applying tensor parallel to model ...") 563 | apply_tp(model) 564 | 565 | model = model.to(device=device, dtype=precision) 566 | return model.eval() 567 | 568 | 569 | def get_model_size(model): 570 | model_size = 0 571 | for name, child in model.named_children(): 572 | if not isinstance(child, torch.nn.Embedding): 573 | for p in itertools.chain(child.parameters(), child.buffers()): 574 | model_size += p.numel() * p.dtype.itemsize 575 | return model_size 576 | 577 | 578 | def compile_funcs(compile=True): 579 | if compile: 580 | global decode_one_token, prefill 581 | decode_one_token = torch.compile( 582 | decode_one_token, 583 | fullgraph=True, 584 | # dynamic=True, 585 | mode="reduce-overhead", 586 | # options={"trace.graph_diagram": True, "trace.enabled": True} 587 | ) 588 | prefill = torch.compile( 589 | prefill, 590 | fullgraph=True, 591 | dynamic=True, 592 | # options={"trace.graph_diagram": True, "trace.enabled": True} 593 | ) 594 | return prefill, decode_one_token 595 | --------------------------------------------------------------------------------
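The cache-length plumbing in generation_utils.py is easiest to see with concrete numbers. Below is a minimal sketch, assuming the script is run from the repo root so that generation_utils (and its model/tokenizer dependencies) are importable; the 4-layer model and the specific lengths are illustrative values, not anything prescribed by the repo.

from generation_utils import normalize_cache_length, apply_pattern, apply_pyramid_pattern

max_seq_length = 4096

# Fractions in (0, 1] are scaled by max_seq_length, then rounded up to a multiple of 8.
lengths = [normalize_cache_length(frac, max_seq_length) for frac in (0.25, 1.0)]
print(lengths)  # [1024, 4096]

# Expand the two-entry pattern across a hypothetical 4-layer model.
print(apply_pattern(lengths, out_size=4, extension_strategy="tile"))    # [1024, 1024, 4096, 4096]
print(apply_pattern(lengths, out_size=4, extension_strategy="repeat"))  # [1024, 4096, 1024, 4096]

# Pyramid allocation spreads roughly the same total budget unevenly across layers,
# giving earlier layers more cache when decreasing=True
# (roughly [1990, 1006, 510, 256] for these inputs).
print(apply_pyramid_pattern(1024, max_seq_length, model_n_layer=4, decreasing=True))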