├── .python-version ├── src └── caldera │ ├── __init__.py │ ├── utils │ ├── enums.py │ └── quantization.py │ └── decomposition │ ├── dataclasses.py │ ├── weight_compression.py │ ├── layer_quantization.py │ ├── quantized_layer.py │ └── alg.py ├── assets └── caldera_decomposition.png ├── .gitmodules ├── .gitignore ├── setup_quip_sharp.sh ├── shell_scripts ├── accelerate_config.yaml ├── deepspeed_config.json ├── run_eval_ppl.sh ├── run_save_hessians.sh ├── run_eval_zeroshot.sh ├── run_finetune_RHT.sh ├── run_finetune_glue.sh ├── run_finetune_winogrande.sh ├── run_finetune_wikitext.sh └── run_quantize_save_caldera.sh ├── scripts ├── get_sv_info.py ├── save_llama_hessians.py ├── eval_zero_shot.py ├── eval_ppl.py ├── finetune_RHT.py ├── quantize_save_llama.py ├── finetune_wikitext.py ├── finetune_winogrande.py └── finetune_glue.py ├── requirements.txt ├── quip-sharp-pyproject.toml ├── notebooks ├── test_caldera.ipynb └── eval_throughput.ipynb ├── pyproject.toml └── README.md /.python-version: -------------------------------------------------------------------------------- 1 | 3.9.0 2 | -------------------------------------------------------------------------------- /src/caldera/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1.dev0" 2 | -------------------------------------------------------------------------------- /assets/caldera_decomposition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pilancilab/caldera/HEAD/assets/caldera_decomposition.png -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "quip-sharp"] 2 | path = quip-sharp 3 | url = git@github.com:Cornell-RelaxML/quip-sharp.git 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .python_version 3 | .conda 4 | data/ 5 | *.egg-info 6 | naomi/results 7 | .gitconfig 8 | .vscode/settings.json 9 | artifacts/ 10 | .vscode 11 | shell_scripts/custom 12 | wandb 13 | .python-version -------------------------------------------------------------------------------- /setup_quip_sharp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | git submodule init 3 | git submodule update 4 | cp quip-sharp-pyproject.toml quip-sharp/pyproject.toml 5 | cd quip-sharp && pip install --editable . && cd .. 6 | cd quip-sharp/quiptools && python setup.py install && cd ../.. 
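# The steps above initialize the quip-sharp submodule, give it a pyproject.toml
# so that it can be pip-installed in editable mode, and then build the
# quiptools CUDA kernels that provide dequantization routines at inference time.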
7 | 8 | -------------------------------------------------------------------------------- /src/caldera/utils/enums.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | class DevSet(IntEnum): 5 | RP1T = 0 6 | FALCON = 1 7 | 8 | class TransformerSubLayers(IntEnum): 9 | QUERY = 0 10 | KEY = 1 11 | VALUE = 2 12 | O = 3 13 | UP = 4 14 | GATE = 5 15 | DOWN = 6 -------------------------------------------------------------------------------- /shell_scripts/accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: ./shell_scripts/deepspeed_config.json 5 | zero3_init_flag: false 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | machine_rank: 0 9 | main_training_function: main 10 | num_machines: 1 11 | num_processes: 4 12 | rdzv_backend: static 13 | same_network: true 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | main_process_port: 29500 19 | -------------------------------------------------------------------------------- /shell_scripts/deepspeed_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "stage3_max_live_parameters": 5e8, 7 | "offload_optimizer": { 8 | "device": "cpu" 9 | }, 10 | "offload_param": { 11 | "device": "cpu" 12 | } 13 | }, 14 | "bf16": { 15 | "enabled": true, 16 | "auto_cast": true 17 | }, 18 | "train_micro_batch_size_per_gpu": "auto", 19 | "gradient_accumulation_steps": "auto" 20 | } 21 | -------------------------------------------------------------------------------- /shell_scripts/run_eval_ppl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_FILLED_IN=0 4 | 5 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 6 | CALDERA_MODEL_SAVE_PATH="YOUR_MODEL_PATH" 7 | DEVICE="cuda:0" 8 | OUTPUT_FILENAME="YOUR FILENAME HERE" 9 | 10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then 11 | echo -e "This script is meant as a template for running scripts/eval_ppl.py. \ 12 | Please go into shell_scripts/run_eval_ppl.sh and replace BASE_MODEL, \ 13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file." 14 | exit -0 15 | fi 16 | 17 | python scripts/eval_ppl.py \ 18 | --model_save_path $CALDERA_MODEL_SAVE_PATH \ 19 | --base_model $BASE_MODEL \ 20 | --seed 0 \ 21 | --seqlen 4096 \ 22 | --datasets wikitext2 c4 \ 23 | --device $DEVICE \ 24 | --output_path $OUTPUT_FILENAME 25 | -------------------------------------------------------------------------------- /shell_scripts/run_save_hessians.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_FILLED_IN=0 4 | 5 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 6 | HESSIAN_SAVE_DIR="YOUR DIRECTORY HERE" 7 | DEVICES="cuda:0 cuda:1 cuda:2 cuda:3" 8 | 9 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then 10 | echo -e "This script is meant as a template for running scripts/save_llama_hessians.py. \ 11 | Please go into shell_scripts/run_save_hessians.sh and replace BASE_MODEL, HESSIAN_SAVE_DIR, etc., \ 12 | and then set SCRIPT_FILLED_IN=1 at the top of the file." 
13 | exit -0 14 | fi 15 | 16 | python scripts/save_llama_hessians.py \ 17 | --base_model $BASE_MODEL \ 18 | --hessian_save_path $HESSIAN_SAVE_DIR \ 19 | --devset rp1t \ 20 | --context_length 4096 \ 21 | --devset_size 256 \ 22 | --chunk_size 64 \ 23 | --batch_size 32 \ 24 | --devices $DEVICES -------------------------------------------------------------------------------- /shell_scripts/run_eval_zeroshot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_FILLED_IN=0 4 | 5 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 6 | CALDERA_MODEL_SAVE_PATH="PATH OF .pt FILE WITH MODEL" 7 | DEVICE="cuda:0" 8 | OUTPUT_FILENAME="YOUR FILE HERE" 9 | 10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then 11 | echo -e "This script is meant as a template for running scripts/eval_zeroshot.py. \ 12 | Please go into shell_scripts/run_eval_zeroshot.sh and replace BASE_MODEL, \ 13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file." 14 | exit -0 15 | fi 16 | 17 | python scripts/eval_zero_shot.py \ 18 | --model_save_path $CALDERA_MODEL_SAVE_PATH \ 19 | --device $DEVICE \ 20 | --tasks winogrande rte piqa arc_easy arc_challenge \ 21 | --base_model $BASE_MODEL \ 22 | --batch_size 8 \ 23 | --output_path $OUTPUT_FILENAME -------------------------------------------------------------------------------- /shell_scripts/run_finetune_RHT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_FILLED_IN=0 4 | 5 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 6 | CALDERA_MODEL_SAVE_PATH_NO_RHT_FT="PATH OF .pt FILE WITH MODEL TO FINETUNED" 7 | CALDERA_MODEL_SAVE_PATH_WITH_RHT_FT="PATH OF .pt FILE TO SAVE FINETUNED MODEL" 8 | 9 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then 10 | echo -e "This script is meant as a template for running scripts/finetune_RHT.py. \ 11 | Please go into shell_scripts/run_finetune_glue.sh and replace BASE_MODEL, \ 12 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file." 13 | exit -0 14 | fi 15 | 16 | python scripts/finetune_RHT.py \ 17 | --base_model $BASE_MODEL \ 18 | --model_path $CALDERA_MODEL_SAVE_PATH_NO_RHT_FT \ 19 | --finetuned_save_path $CALDERA_MODEL_SAVE_PATH_WITH_RHT_FT \ 20 | --devset_size 256 \ 21 | --ctx_size 512 \ 22 | --device cuda:0 \ 23 | --ft_bs 2 \ 24 | --ft_valid_size 64 \ 25 | --RHT_learning_rate 1e-3 \ 26 | --epochs 1 -------------------------------------------------------------------------------- /shell_scripts/run_finetune_glue.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_FILLED_IN=0 4 | 5 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 6 | CALDERA_MODEL_SAVE_PATH="PATH OF .pt FILE WITH MODEL" 7 | OUTPUT_DIR="FINETUNING OUTPUT DIRECTORY" 8 | GLUE_TASK="rte" 9 | 10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then 11 | echo -e "This script is meant as a template for running scripts/finetune_glue.py. \ 12 | Please go into shell_scripts/run_finetune_glue.sh and replace BASE_MODEL, \ 13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file." 
14 | exit -0 15 | fi 16 | 17 | accelerate launch --config_file shell_scripts/accelerate_config.yaml \ 18 | scripts/finetune_glue.py \ 19 | --task_name $GLUE_TASK \ 20 | --model_name_or_path $CALDERA_MODEL_SAVE_PATH \ 21 | --base_model $BASE_MODEL \ 22 | --output_dir $OUTPUT_DIR \ 23 | --learning_rate 3e-5 \ 24 | --num_train_epochs 15 \ 25 | --per_device_train_batch_size 1 \ 26 | --per_device_eval_batch_size 1 \ 27 | --gradient_accumulation_steps 2 \ 28 | --weight_decay 0.01 \ 29 | --num_warmup_steps 100 \ 30 | --lr_scheduler_type linear \ 31 | --report_to tensorboard \ 32 | --with_tracking \ 33 | --checkpointing_steps epoch \ 34 | --pad_to_max_length \ 35 | --seed 314 -------------------------------------------------------------------------------- /shell_scripts/run_finetune_winogrande.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_FILLED_IN=0 4 | 5 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 6 | CALDERA_MODEL_SAVE_PATH="PATH OF .pt FILE WITH MODEL" 7 | OUTPUT_DIR="FINETUNING OUTPUT DIRECTORY" 8 | SEQ_LEN=256 9 | 10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then 11 | echo -e "This script is meant as a template for running scripts/finetune_winogrande.py. \ 12 | Please go into shell_scripts/run_finetune_winogrande.sh and replace BASE_MODEL, \ 13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file." 14 | exit -0 15 | fi 16 | 17 | 18 | CUDA_VISIBLE_DEVICES=$DEVICES accelerate launch --config_file ./shell_scripts/accelerate_config.yaml \ 19 | scripts/finetune_winogrande.py \ 20 | --model_name_or_path $CALDERA_MODEL_SAVE_PATH \ 21 | --base_model $BASE_MODEL \ 22 | --output_dir $OUTPUT_DIR \ 23 | --learning_rate 1e-5 \ 24 | --weight_decay 0.01 \ 25 | --lr_scheduler_type linear \ 26 | --warmup_ratio 0.033 \ 27 | --num_warmup_steps 100 \ 28 | --seed 202 \ 29 | --max_seq_length $SEQ_LEN \ 30 | --num_train_epochs 1 \ 31 | --per_device_train_batch_size 10 \ 32 | --per_device_eval_batch_size 10 \ 33 | --gradient_accumulation_steps 1 \ 34 | --logging_steps 10 \ 35 | --save_steps 200 \ 36 | --bf16 \ 37 | --with_tracking true \ 38 | --report_to tensorboard -------------------------------------------------------------------------------- /shell_scripts/run_finetune_wikitext.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_FILLED_IN=0 4 | 5 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 6 | CALDERA_MODEL_SAVE_PATH="PATH OF .pt FILE WITH MODEL" 7 | OUTPUT_DIR="FINETUNING OUTPUT DIRECTORY" 8 | BLOCK_SIZE=256 9 | 10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then 11 | echo -e "This script is meant as a template for running scripts/finetune_wikitext.py. \ 12 | Please go into shell_scripts/run_finetune_wikitext.sh and replace BASE_MODEL, \ 13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file." 
14 | exit -0 15 | fi 16 | 17 | accelerate launch --config_file shell_scripts/accelerate_config.yaml \ 18 | scripts/finetune_wikitext.py \ 19 | --model_save_path $CALDERA_MODEL_SAVE_PATH \ 20 | --base_model $BASE_MODEL \ 21 | --output_dir $OUTPUT_DIR \ 22 | --dataset_name wikitext \ 23 | --dataset_config_name wikitext-2-raw-v1 \ 24 | --block_size $BLOCK_SIZE \ 25 | --learning_rate 3e-6 \ 26 | --num_train_epochs 3 \ 27 | --per_device_train_batch_size 1 \ 28 | --per_device_eval_batch_size 1 \ 29 | --gradient_accumulation_steps 1 \ 30 | --weight_decay 0.001 \ 31 | --warmup_ratio 0.02 \ 32 | --lr_scheduler_type linear \ 33 | --bf16 \ 34 | --logging_steps 10 \ 35 | --save_steps 200 \ 36 | --eval_steps 100 \ 37 | --evaluation_strategy steps \ 38 | --prediction_loss_only 39 | -------------------------------------------------------------------------------- /shell_scripts/run_quantize_save_caldera.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_FILLED_IN=0 4 | 5 | # CALDERA PARAMETERS 6 | RANK=256 7 | LR_BITS=4 8 | CALDERA_ITERS=20 9 | LPLR_ITERS=50 10 | # END CALDERA PARAMETERS 11 | 12 | BASE_MODEL="meta-llama/Llama-2-7b-hf" 13 | DEVICES="cuda:0 cuda:1 cuda:2 cuda:3" 14 | 15 | CALDERA_MODEL_SAVE_PATH="YOUR_MODEL_OUTPUT_FILENAME" 16 | HESSIAN_SAVE_DIR="~/.cache/huggingface/hub/models--relaxml--Hessians-Llama-2-7b-6144/snapshots/SNAPSHOT_ID_HERE" 17 | 18 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then 19 | echo -e "This script is meant as a template for running scripts/quantize_save_llama.py. \ 20 | Please go into shell_scripts/run_quantize_save_caldera.sh and replace the CALDERA parameters, etc., \ 21 | and then set SCRIPT_FILLED_IN=1 at the top of the file." 22 | exit -0 23 | fi 24 | 25 | QUANT_PARAMS="--Q_bits 2 \ 26 | --compute_low_rank_factors true \ 27 | --compute_quantized_component true \ 28 | --L_bits $LR_BITS \ 29 | --R_bits $LR_BITS \ 30 | --lattice_quant_LR true \ 31 | --rank $RANK \ 32 | --activation_aware_LR true \ 33 | --activation_aware_Q true \ 34 | --hadamard_transform true \ 35 | --iters $CALDERA_ITERS \ 36 | --lplr_iters $LPLR_ITERS \ 37 | --rand_svd false \ 38 | --update_order LR Q \ 39 | --Q_hessian_downdate true \ 40 | --ft_rank 0 \ 41 | --random_seed 42" 42 | 43 | python scripts/quantize_save_llama.py \ 44 | --hessian_save_path $HESSIAN_SAVE_DIR \ 45 | --model_save_path $CALDERA_MODEL_SAVE_PATH \ 46 | --devices $DEVICES \ 47 | --base_model $BASE_MODEL \ 48 | $QUANT_PARAMS -------------------------------------------------------------------------------- /scripts/get_sv_info.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, HfArgumentParser 3 | import gc 4 | from tqdm import tqdm 5 | from dataclasses import dataclass, field 6 | import os 7 | 8 | 9 | @dataclass 10 | class Arguments: 11 | base_model: str = field(default="meta-llama/Llama-2-7b-hf", metadata={ 12 | "help": ("Path of the original model, as either a local or a " 13 | "Huggingface path") 14 | }) 15 | device: str = field(default="cuda:0", metadata={ 16 | "help": "Device on which to run evaluation." 
17 | }) 18 | save_dir: str = field(default="./data") 19 | 20 | SUBLAYERS = { 21 | "q_proj": lambda layer: layer.self_attn.q_proj, 22 | "k_proj": lambda layer: layer.self_attn.k_proj, 23 | "v_proj": lambda layer: layer.self_attn.v_proj, 24 | "o_proj": lambda layer: layer.self_attn.o_proj, 25 | "up_proj": lambda layer: layer.mlp.up_proj, 26 | "gate_proj": lambda layer: layer.mlp.gate_proj, 27 | "down_proj": lambda layer: layer.mlp.down_proj 28 | } 29 | 30 | def main(base_model, device, save_dir): 31 | model = AutoModelForCausalLM.from_pretrained( 32 | base_model, torch_dtype='auto', low_cpu_mem_usage=True 33 | ).cpu() 34 | 35 | for label in SUBLAYERS: 36 | A = SUBLAYERS[label](model.model.layers[0]).weight 37 | n = min(A.shape[0], A.shape[1]) 38 | SVs = torch.zeros(n, 32, requires_grad=False) 39 | for i, layer in tqdm(enumerate(model.model.layers)): 40 | A = SUBLAYERS[label](layer).weight.to(device).float().detach() 41 | _, S, _ = torch.linalg.svd(A) 42 | S = S.cpu() 43 | del A 44 | gc.collect() 45 | torch.cuda.empty_cache() 46 | 47 | SVs[:, i] = S 48 | 49 | torch.save({ 50 | "SV_data": SVs.detach(), 51 | "means": torch.mean(SVs.detach(), dim=1), 52 | "stdevs": torch.std(SVs.detach(), dim=1) 53 | }, f"{save_dir}/{label}_sv_info.pt") 54 | 55 | if __name__ == "__main__": 56 | parser = HfArgumentParser([Arguments]) 57 | args = parser.parse_args_into_dataclasses()[0] 58 | os.makedirs(os.path.dirname(args.save_dir), exist_ok=True) 59 | main(args.base_model, args.device, args.save_dir) 60 | 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | packaging==24.0 2 | torch==2.2.1 3 | torchaudio==2.2.1 4 | torchvision==0.17.1 5 | absl-py==2.1.0 6 | accelerate==0.28.0 7 | aiohttp==3.9.3 8 | aiosignal==1.3.1 9 | annotated-types==0.6.0 10 | anyio==4.3.0 11 | argparse-dataclass==2.0.0 12 | asttokens==2.4.1 13 | attrs==23.2.0 14 | chardet==5.2.0 15 | click==8.1.7 16 | colorama==0.4.6 17 | comm==0.2.2 18 | contourpy==1.2.0 19 | cycler==0.12.1 20 | DataProperty==1.0.1 21 | datasets==2.18.0 22 | debugpy==1.8.1 23 | decorator==5.1.1 24 | deepspeed==0.14.0 25 | dill==0.3.8 26 | distro==1.9.0 27 | evaluate==0.4.1 28 | executing==2.0.1 29 | fonttools==4.50.0 30 | frozenlist==1.4.1 31 | fsspec==2024.2.0 32 | glog==0.3.1 33 | h11==0.14.0 34 | hjson==3.1.0 35 | httpcore==1.0.4 36 | httpx==0.27.0 37 | huggingface-hub[cli]==0.21.4 38 | ipykernel==6.29.3 39 | ipython==8.22.2 40 | jedi==0.19.1 41 | joblib==1.3.2 42 | jsonlines==4.0.0 43 | jupyter_client==8.6.1 44 | jupyter_core==5.7.2 45 | kiwisolver==1.4.5 46 | lm-eval==0.3.0 47 | lxml==5.1.0 48 | matplotlib==3.8.3 49 | matplotlib-inline==0.1.6 50 | mbstrdecoder==1.1.3 51 | more-itertools==10.2.0 52 | multidict==6.0.5 53 | multiprocess==0.70.16 54 | nest-asyncio==1.6.0 55 | ninja==1.11.1.1 56 | nltk==3.8.1 57 | numexpr==2.9.0 58 | openai==1.14.2 59 | pandas==2.2.1 60 | parso==0.8.3 61 | pathvalidate==3.2.0 62 | peft==0.10.0 63 | pexpect==4.9.0 64 | platformdirs==4.2.0 65 | portalocker==2.8.2 66 | primefac==2.0.12 67 | prompt-toolkit==3.0.43 68 | psutil==5.9.8 69 | ptyprocess==0.7.0 70 | pure-eval==0.2.2 71 | py-cpuinfo==9.0.0 72 | pyarrow==15.0.2 73 | pyarrow-hotfix==0.6 74 | pybind11==2.11.1 75 | pycountry==23.12.11 76 | pydantic==2.6.4 77 | pydantic_core==2.16.3 78 | Pygments==2.17.2 79 | pynvml==11.5.0 80 | pyparsing==3.1.2 81 | pytablewriter==1.2.0 82 | python-dateutil==2.9.0.post0 83 | python-gflags==3.1.2 84 | pytz==2024.1 85 | 
pyzmq==25.1.2 86 | regex==2023.12.25 87 | responses==0.18.0 88 | rouge_score==0.1.2 89 | sacrebleu==1.5.0 90 | safetensors==0.4.2 91 | scikit-learn==1.4.1.post1 92 | scipy==1.12.0 93 | six==1.16.0 94 | sniffio==1.3.1 95 | sqlitedict==2.1.0 96 | stack-data==0.6.3 97 | tabledata==1.3.3 98 | tabulate==0.9.0 99 | tcolorpy==0.1.4 100 | threadpoolctl==3.4.0 101 | tokenizers==0.15.2 102 | tornado==6.4 103 | tqdm==4.66.2 104 | tqdm-multiprocess==0.0.11 105 | traitlets==5.14.2 106 | transformers==4.39.1 107 | triton==2.2.0 108 | typepy==1.3.2 109 | tzdata==2024.1 110 | wcwidth==0.2.13 111 | word2number==1.1 112 | xxhash==3.4.1 113 | yarl==1.9.4 114 | zstandard==0.22.0 115 | sentencepiece==0.2.0 116 | protobuf==5.26.0 117 | tensorboard==2.16.2 118 | -------------------------------------------------------------------------------- /scripts/save_llama_hessians.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from caldera.decomposition.weight_compression import * 3 | 4 | 5 | @dataclass 6 | class Arguments: 7 | base_model: str = field(default="meta-llama/Llama-2-7b-hf", metadata={ 8 | "help": ("Path of the model that will eventually be quantized, as " 9 | "either a local or a Huggingface path.") 10 | }) 11 | token: str = field(default=None, metadata={ 12 | "help": "Huggingface access token for private models." 13 | }) 14 | hessian_save_path: str = field( 15 | default="./data/hessians/llama-2-7b", metadata={ 16 | "help": ("Directory in which to save Hessians.") 17 | }) 18 | n_sample_proc: int = field(default=4, metadata={ 19 | "help": "Number of processes used to sample calibration data." 20 | }) 21 | 22 | 23 | @dataclass 24 | class DataParametersCommandLine: 25 | """ 26 | Parameters for loading the calibration dataset and computing the 27 | inputs to each layer. 28 | """ 29 | devset: str = field( 30 | default="rp1t", metadata={"help": ( 31 | "Calibration dataset; either rp1t or falcon" 32 | ), "choices": ["rp1t", "falcon"]} 33 | ) 34 | devset_size: int = field( 35 | default=256, metadata={"help": ( 36 | "Number of calibration samples to use." 37 | )} 38 | ) 39 | context_length: int = field( 40 | default=4096, metadata={"help": ( 41 | "Length of context window." 42 | )} 43 | ) 44 | batch_size: int = field( 45 | default=2, metadata={"help": ( 46 | "Number of datapoints to pass into the model at once." 47 | )} 48 | ) 49 | chunk_size: int = field( 50 | default=256, metadata={"help": ( 51 | "Number of batches sent to each GPU at a time." 52 | )} 53 | ) 54 | devices: list[str] = field( 55 | default=None, metadata={"help": ( 56 | "Specific CUDA devices to use for Hessian computation. Defaults " 57 | "to None, which means that all available devices are used." 
58 | )} 59 | ) 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = transformers.HfArgumentParser([Arguments, DataParametersCommandLine]) 64 | args, data_params_command_line = parser.parse_args_into_dataclasses() 65 | 66 | devset = DevSet.RP1T if data_params_command_line.devset == "rp1t" \ 67 | else DevSet.FALCON 68 | 69 | data_params = DataParameters( 70 | devset=devset, 71 | devset_size=data_params_command_line.devset_size, 72 | context_length=data_params_command_line.context_length, 73 | batch_size=data_params_command_line.batch_size, 74 | chunk_size=data_params_command_line.chunk_size, 75 | devices=data_params_command_line.devices 76 | ) 77 | 78 | ActivationAwareWeightCompressor( 79 | model_params=ModelParameters( 80 | base_model=args.base_model, 81 | token=args.token 82 | ), 83 | data_params=data_params, 84 | hessian_save_path=args.hessian_save_path, 85 | quant_device="cuda", 86 | n_sample_proc=args.n_sample_proc, 87 | compute_hessians=True 88 | ) 89 | -------------------------------------------------------------------------------- /quip-sharp-pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "quip-sharp" 7 | version = "0.0.1" 8 | requires-python = ">=3.8.0" 9 | 10 | dependencies = [ 11 | "packaging==24.0", 12 | "torch==2.2.1", 13 | "torchaudio==2.2.1", 14 | "torchvision==0.17.1", 15 | "absl-py==2.1.0", 16 | "accelerate==0.28.0", 17 | "aiohttp==3.9.3", 18 | "aiosignal==1.3.1", 19 | "annotated-types==0.6.0", 20 | "anyio==4.3.0", 21 | "argparse-dataclass==2.0.0", 22 | "asttokens==2.4.1", 23 | "attrs==23.2.0", 24 | "chardet==5.2.0", 25 | "click==8.1.7", 26 | "colorama==0.4.6", 27 | "comm==0.2.2", 28 | "contourpy==1.2.0", 29 | "cycler==0.12.1", 30 | "DataProperty==1.0.1", 31 | "datasets==2.18.0", 32 | "debugpy==1.8.1", 33 | "decorator==5.1.1", 34 | "deepspeed==0.14.0", 35 | "dill==0.3.8", 36 | "distro==1.9.0", 37 | "evaluate==0.4.1", 38 | "executing==2.0.1", 39 | "fonttools==4.50.0", 40 | "frozenlist==1.4.1", 41 | "fsspec==2024.2.0", 42 | "glog==0.3.1", 43 | "h11==0.14.0", 44 | "hjson==3.1.0", 45 | "httpcore==1.0.4", 46 | "httpx==0.27.0", 47 | "huggingface-hub==0.21.4", 48 | "ipykernel==6.29.3", 49 | "ipython==8.22.2", 50 | "jedi==0.19.1", 51 | "joblib==1.3.2", 52 | "jsonlines==4.0.0", 53 | "jupyter_client==8.6.1", 54 | "jupyter_core==5.7.2", 55 | "kiwisolver==1.4.5", 56 | "lm-eval==0.3.0", 57 | "lxml==5.1.0", 58 | "matplotlib==3.8.3", 59 | "matplotlib-inline==0.1.6", 60 | "mbstrdecoder==1.1.3", 61 | "more-itertools==10.2.0", 62 | "multidict==6.0.5", 63 | "multiprocess==0.70.16", 64 | "nest-asyncio==1.6.0", 65 | "ninja==1.11.1.1", 66 | "nltk==3.8.1", 67 | "numexpr==2.9.0", 68 | "openai==1.14.2", 69 | "pandas==2.2.1", 70 | "parso==0.8.3", 71 | "pathvalidate==3.2.0", 72 | "peft==0.10.0", 73 | "pexpect==4.9.0", 74 | "platformdirs==4.2.0", 75 | "portalocker==2.8.2", 76 | "primefac==2.0.12", 77 | "prompt-toolkit==3.0.43", 78 | "psutil==5.9.8", 79 | "ptyprocess==0.7.0", 80 | "pure-eval==0.2.2", 81 | "py-cpuinfo==9.0.0", 82 | "pyarrow==15.0.2", 83 | "pyarrow-hotfix==0.6", 84 | "pybind11==2.11.1", 85 | "pycountry==23.12.11", 86 | "pydantic==2.6.4", 87 | "pydantic_core==2.16.3", 88 | "Pygments==2.17.2", 89 | "pynvml==11.5.0", 90 | "pyparsing==3.1.2", 91 | "pytablewriter==1.2.0", 92 | "python-dateutil==2.9.0.post0", 93 | "python-gflags==3.1.2", 94 | "pytz==2024.1", 95 | "pyzmq==25.1.2", 96 | "regex==2023.12.25", 97 | 
"responses==0.18.0", 98 | "rouge_score==0.1.2", 99 | "sacrebleu==1.5.0", 100 | "safetensors==0.4.2", 101 | "scikit-learn==1.4.1.post1", 102 | "scipy==1.12.0", 103 | "six==1.16.0", 104 | "sniffio==1.3.1", 105 | "sqlitedict==2.1.0", 106 | "stack-data==0.6.3", 107 | "tabledata==1.3.3", 108 | "tabulate==0.9.0", 109 | "tcolorpy==0.1.4", 110 | "threadpoolctl==3.4.0", 111 | "tokenizers==0.15.2", 112 | "tornado==6.4", 113 | "tqdm==4.66.2", 114 | "tqdm-multiprocess==0.0.11", 115 | "traitlets==5.14.2", 116 | "transformers==4.39.1", 117 | "triton==2.2.0", 118 | "typepy==1.3.2", 119 | "tzdata==2024.1", 120 | "wcwidth==0.2.13", 121 | "word2number==1.1", 122 | "xxhash==3.4.1", 123 | "yarl==1.9.4", 124 | "zstandard==0.22.0", 125 | "sentencepiece==0.2.0", 126 | "protobuf==5.26.0", 127 | "tensorboard==2.16.2", 128 | ] 129 | 130 | [tool.setuptools.packages.find] 131 | where = ["."] -------------------------------------------------------------------------------- /scripts/eval_zero_shot.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from lib.utils import LMEvalAdaptor 3 | from lm_eval import evaluator 4 | import os 5 | import json 6 | from quantize_save_llama import load_quantized_model 7 | from dataclasses import field, dataclass 8 | import transformers 9 | import glob 10 | 11 | # Adapted from https://github.com/Cornell-RelaxML/quip-sharp 12 | 13 | class LMEvalAdaptorWithDevice(LMEvalAdaptor): 14 | def __init__(self, 15 | model_name, 16 | model, 17 | tokenizer, 18 | batch_size=1, 19 | max_length=-1, 20 | device="cuda"): 21 | super().__init__( 22 | model_name, model, tokenizer, batch_size, max_length 23 | ) 24 | self._device = device 25 | 26 | @property 27 | def device(self): 28 | return self._device 29 | 30 | 31 | @dataclass 32 | class Arguments: 33 | model_save_path: str = field(metadata={ 34 | "help": ("Path of the .pt file in which the model can be found.") 35 | }) 36 | finetune_save_dir: str = field(default=None, metadata={ 37 | "help": ("If using a finetuned model, the directory in which the " 38 | "model.safetensors file is stored") 39 | }) 40 | output_path: str = field(default=None, metadata={ 41 | "help": ("Path in which to save a JSON file with zero-shot results.") 42 | }) 43 | base_model: str = field(default="meta-llama/Llama-2-7b-hf", metadata={ 44 | "help": ("Path of the original model, as either a local or a " 45 | "Huggingface path") 46 | }) 47 | tasks: list[str] = field(default_factory=list, metadata={ 48 | "help": ("Task on which to measure zero-shot accuracy, e.g." 49 | "wingorande, piqa, arc_easy, arc_challenge, rte, cola...") 50 | }) 51 | batch_size: int = field(default=1, metadata={ 52 | "help": "Number of datapoints processed at once" 53 | }) 54 | device: str = field(default="cuda:0", metadata={ 55 | "help": "Device on which to run evaluation." 56 | }) 57 | cuda_graph: bool = field(default=False, metadata={ 58 | "help": "Whether to use CUDA graphs and flash attention to speed up evaluation." 
59 | }) 60 | 61 | 62 | def eval_zero_shot(args: Arguments): 63 | print(args.base_model) 64 | model = load_quantized_model(args.model_save_path, args.base_model, args.device, cuda_graph=args.cuda_graph) 65 | model = model.to(args.device) 66 | if args.finetune_save_dir is not None: 67 | from safetensors.torch import load_model 68 | for safetensor_file in glob.glob(args.finetune_save_dir + "/model*.safetensors"): 69 | print("Loading ", safetensor_file) 70 | load_model(model, safetensor_file, strict=False) 71 | tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=False) 72 | print("Loaded model!") 73 | 74 | tokenizer.pad_token = tokenizer.eos_token 75 | 76 | lm_eval_model = LMEvalAdaptorWithDevice( 77 | args.base_model, model, tokenizer, args.batch_size, device=args.device) 78 | lm_eval_model.device 79 | results = evaluator.simple_evaluate( 80 | model=lm_eval_model, 81 | tasks=args.tasks, 82 | batch_size=args.batch_size, 83 | no_cache=True, 84 | num_fewshot=0, 85 | device=args.device 86 | ) 87 | 88 | print(evaluator.make_table(results)) 89 | 90 | if args.output_path is not None: 91 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 92 | # otherwise cannot save 93 | results["config"]["model"] = args.base_model 94 | with open(args.output_path + ".json", "w") as f: 95 | json.dump(results, f, indent=2) 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = transformers.HfArgumentParser([Arguments]) 100 | args = parser.parse_args_into_dataclasses()[0] 101 | eval_zero_shot(args) -------------------------------------------------------------------------------- /notebooks/test_caldera.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Notebook to try out CALDERA decomposition on a random matrix" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "\n", 18 | "import sys\n", 19 | "import os" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "src_dir = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n", 29 | "\n", 30 | "if src_dir not in sys.path:\n", 31 | " sys.path.append(src_dir)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "from src.caldera.decomposition.dataclasses import CalderaParams\n", 41 | "from src.caldera.utils.quantization import QuantizerFactory\n", 42 | "from src.caldera.decomposition.alg import caldera" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "quant_factory_Q = QuantizerFactory(method=\"uniform\", block_size=64)\n", 52 | "quant_factor_LR = QuantizerFactory(method=\"uniform\", block_size=64)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "quant_params = CalderaParams(\n", 62 | " compute_quantized_component=True,\n", 63 | " compute_low_rank_factors=True,\n", 64 | " Q_bits=4,\n", 65 | " L_bits=4,\n", 66 | " R_bits=4,\n", 67 | " rank=16,\n", 68 | " iters=20,\n", 69 | " lplr_iters=5,\n", 70 | " activation_aware_Q=False,\n", 71 | " activation_aware_LR=True,\n", 72 | " lattice_quant_Q=False,\n", 73 | " lattice_quant_LR=False,\n", 74 | " 
update_order=[\"Q\", \"LR\"],\n", 75 | " quant_factory_Q=quant_factory_Q,\n", 76 | " quant_factory_LR=quant_factor_LR,\n", 77 | " rand_svd=False\n", 78 | ")" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "torch.manual_seed(42)\n", 88 | "\n", 89 | "W = torch.rand(1024, 1024)\n", 90 | "X = torch.randn(1024, 2048)\n", 91 | "H = torch.matmul(X, X.T)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "metadata": {}, 98 | "outputs": [], 99 | "source": [ 100 | "caldera_decomposition = caldera(\n", 101 | " quant_params=quant_params,\n", 102 | " W=W,\n", 103 | " H=H,\n", 104 | " device=\"cpu\",\n", 105 | " use_tqdm=True,\n", 106 | " scale_W=True\n", 107 | ")" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "print(f\"caldera_decomposition.Q.shape: {caldera_decomposition.Q.shape}\")\n", 117 | "print(f\"caldera_decomposition.L.shape: {caldera_decomposition.L.shape}\")\n", 118 | "print(f\"caldera_decomposition.R.shape: {caldera_decomposition.R.shape}\")" 119 | ] 120 | } 121 | ], 122 | "metadata": { 123 | "kernelspec": { 124 | "display_name": "Python 3.11 (venv)", 125 | "language": "python", 126 | "name": "venv" 127 | }, 128 | "language_info": { 129 | "codemirror_mode": { 130 | "name": "ipython", 131 | "version": 3 132 | }, 133 | "file_extension": ".py", 134 | "mimetype": "text/x-python", 135 | "name": "python", 136 | "nbconvert_exporter": "python", 137 | "pygments_lexer": "ipython3", 138 | "version": "3.10.15" 139 | } 140 | }, 141 | "nbformat": 4, 142 | "nbformat_minor": 2 143 | } 144 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "caldera" 7 | version = "0.0.1" 8 | description = "Decomposition of neural network (e.g., LLM) weight matrices into Q + LR, where L and R are low-rank factors and all matrices are quantized to low-bit precision." 
9 | readme = "README.md" 10 | requires-python = ">=3.8.0" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: MIT License", 14 | "Operating System :: OS Independent", 15 | ] 16 | 17 | dependencies = [ 18 | "packaging==24.0", 19 | "torch==2.2.1", 20 | "torchaudio==2.2.1", 21 | "torchvision==0.17.1", 22 | "absl-py==2.1.0", 23 | "accelerate==0.28.0", 24 | "aiohttp==3.9.3", 25 | "aiosignal==1.3.1", 26 | "annotated-types==0.6.0", 27 | "anyio==4.3.0", 28 | "argparse-dataclass==2.0.0", 29 | "asttokens==2.4.1", 30 | "attrs==23.2.0", 31 | "chardet==5.2.0", 32 | "click==8.1.7", 33 | "colorama==0.4.6", 34 | "comm==0.2.2", 35 | "contourpy==1.2.0", 36 | "cycler==0.12.1", 37 | "DataProperty==1.0.1", 38 | "datasets==2.18.0", 39 | "debugpy==1.8.1", 40 | "decorator==5.1.1", 41 | "deepspeed==0.14.0", 42 | "dill==0.3.8", 43 | "distro==1.9.0", 44 | "evaluate==0.4.1", 45 | "executing==2.0.1", 46 | "fonttools==4.50.0", 47 | "frozenlist==1.4.1", 48 | "fsspec==2024.2.0", 49 | "glog==0.3.1", 50 | "h11==0.14.0", 51 | "hjson==3.1.0", 52 | "httpcore==1.0.4", 53 | "httpx==0.27.0", 54 | "huggingface-hub==0.21.4", 55 | "ipykernel==6.29.3", 56 | "ipython==8.22.2", 57 | "jedi==0.19.1", 58 | "joblib==1.3.2", 59 | "jsonlines==4.0.0", 60 | "jupyter_client==8.6.1", 61 | "jupyter_core==5.7.2", 62 | "kiwisolver==1.4.5", 63 | "lm-eval==0.3.0", 64 | "lxml==5.1.0", 65 | "matplotlib==3.8.3", 66 | "matplotlib-inline==0.1.6", 67 | "mbstrdecoder==1.1.3", 68 | "more-itertools==10.2.0", 69 | "multidict==6.0.5", 70 | "multiprocess==0.70.16", 71 | "nest-asyncio==1.6.0", 72 | "ninja==1.11.1.1", 73 | "nltk==3.8.1", 74 | "numexpr==2.9.0", 75 | "openai==1.14.2", 76 | "pandas==2.2.1", 77 | "parso==0.8.3", 78 | "pathvalidate==3.2.0", 79 | "peft==0.10.0", 80 | "pexpect==4.9.0", 81 | "platformdirs==4.2.0", 82 | "portalocker==2.8.2", 83 | "primefac==2.0.12", 84 | "prompt-toolkit==3.0.43", 85 | "psutil==5.9.8", 86 | "ptyprocess==0.7.0", 87 | "pure-eval==0.2.2", 88 | "py-cpuinfo==9.0.0", 89 | "pyarrow==15.0.2", 90 | "pyarrow-hotfix==0.6", 91 | "pybind11==2.11.1", 92 | "pycountry==23.12.11", 93 | "pydantic==2.6.4", 94 | "pydantic_core==2.16.3", 95 | "Pygments==2.17.2", 96 | "pynvml==11.5.0", 97 | "pyparsing==3.1.2", 98 | "pytablewriter==1.2.0", 99 | "python-dateutil==2.9.0.post0", 100 | "python-gflags==3.1.2", 101 | "pytz==2024.1", 102 | "pyzmq==25.1.2", 103 | "regex==2023.12.25", 104 | "responses==0.18.0", 105 | "rouge_score==0.1.2", 106 | "sacrebleu==1.5.0", 107 | "safetensors==0.4.2", 108 | "scikit-learn==1.4.1.post1", 109 | "scipy==1.12.0", 110 | "six==1.16.0", 111 | "sniffio==1.3.1", 112 | "sqlitedict==2.1.0", 113 | "stack-data==0.6.3", 114 | "tabledata==1.3.3", 115 | "tabulate==0.9.0", 116 | "tcolorpy==0.1.4", 117 | "threadpoolctl==3.4.0", 118 | "tokenizers==0.15.2", 119 | "tornado==6.4", 120 | "tqdm==4.66.2", 121 | "tqdm-multiprocess==0.0.11", 122 | "traitlets==5.14.2", 123 | "transformers==4.39.1", 124 | "triton==2.2.0", 125 | "typepy==1.3.2", 126 | "tzdata==2024.1", 127 | "wcwidth==0.2.13", 128 | "word2number==1.1", 129 | "xxhash==3.4.1", 130 | "yarl==1.9.4", 131 | "zstandard==0.22.0", 132 | "sentencepiece==0.2.0", 133 | "protobuf==5.26.0", 134 | "tensorboard==2.16.2", 135 | ] 136 | 137 | [tool.setuptools.packages.find] 138 | where = ["src"] -------------------------------------------------------------------------------- /scripts/eval_ppl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib.utils import gptq_data_utils 3 | from tqdm 
import tqdm 4 | from quantize_save_llama import load_quantized_model 5 | from dataclasses import dataclass, field 6 | import transformers 7 | import os 8 | import json 9 | import glob 10 | 11 | # Adapted from https://github.com/Cornell-RelaxML/quip-sharp 12 | 13 | @dataclass 14 | class Arguments: 15 | model_save_path: str = field(metadata={ 16 | "help": ("Path of the .pt file in which the model can be found.") 17 | }) 18 | finetune_save_dir: str = field(default=None, metadata={ 19 | "help": ("If using a finetuned model, the directory in which the " 20 | "model.safetensors file is stored") 21 | }) 22 | output_path: str = field(default=None, metadata={ 23 | "help": ("Path in which to save a JSON file with zero-shot results.") 24 | }) 25 | base_model: str = field(default="meta-llama/Llama-2-7b-hf", metadata={ 26 | "help": ("Path of the original model, as either a local or a " 27 | "Huggingface path") 28 | }) 29 | seed: int = field(default=0, metadata={ 30 | "help": "Random seed for selecting test points from the dataset" 31 | }) 32 | seqlen: int = field(default=4096, metadata={ 33 | "help": "Sequence length of model inputs" 34 | }) 35 | device: str = field(default="cuda:0", metadata={ 36 | "help": "Device on which to run evaluation." 37 | }) 38 | datasets: list[str] = field(default_factory=list, metadata={ 39 | "help": ("Which datasets, out of \"wikitext2\" and \"c4\" to compute " 40 | "perplexity. Defaults to both datasets")}) 41 | cuda_graph: bool = field(default=False, metadata={ 42 | "help": "Whether to use CUDA graphs and flash attention to speed up evaluation." 43 | }) 44 | 45 | 46 | def eval_ppl(args: Arguments): 47 | 48 | with torch.no_grad(): 49 | model = load_quantized_model(args.model_save_path, args.base_model, args.device, cuda_graph=args.cuda_graph) 50 | 51 | if args.finetune_save_dir is not None: 52 | from safetensors.torch import load_model 53 | for safetensor_file in glob.glob(args.finetune_save_dir + "/model*.safetensors"): 54 | print("Loading ", safetensor_file) 55 | load_model(model, safetensor_file, strict=False) 56 | 57 | 58 | if not args.datasets: 59 | args.datasets = ["wikitext2", "c4"] 60 | 61 | ppls = {} 62 | 63 | for dataset in args.datasets: 64 | input_tok = gptq_data_utils.get_test_tokens( 65 | dataset, seed=args.seed, seqlen=args.seqlen, model=args.base_model) 66 | nsamples = input_tok.numel() // args.seqlen 67 | input_tok = input_tok[0, :(args.seqlen * nsamples)].view( 68 | nsamples, args.seqlen) 69 | 70 | loss_fct = torch.nn.CrossEntropyLoss().cuda() 71 | acc_loss = 0.0 72 | progress = tqdm(range(nsamples)) 73 | for ii in progress: 74 | input = input_tok[ii, :].to(args.device).view(1, -1) 75 | output = model(input, 76 | use_cache=False, 77 | output_hidden_states=False, 78 | output_attentions=False)[0] 79 | 80 | shift_logits = output[:, :-1, :].contiguous() 81 | shift_labels = input[:, 1:] 82 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), 83 | shift_labels.view(-1)) 84 | acc_loss += loss.item() 85 | progress.set_description(f"avg_loss = {acc_loss/(ii+1)}") 86 | 87 | avg_loss = acc_loss / nsamples 88 | 89 | ppl = torch.exp(torch.tensor(avg_loss)).item() 90 | print(f'{dataset} perplexity: {ppl}') 91 | 92 | ppls[dataset] = ppl 93 | 94 | if args.output_path is not None: 95 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 96 | with open(args.output_path + ".json", "w") as f: 97 | json.dump(ppls, f, indent=2) 98 | 99 | 100 | if __name__ == "__main__": 101 | parser = transformers.HfArgumentParser([Arguments]) 102 | args = 
parser.parse_args_into_dataclasses()[0] 103 | eval_ppl(args) -------------------------------------------------------------------------------- /notebooks/eval_throughput.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import gc\n", 20 | "import torch\n", 21 | "from lib.utils import graph_wrapper\n", 22 | "from transformers import AutoTokenizer, LlamaForCausalLM\n", 23 | "import time" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "metadata": {}, 30 | "outputs": [], 31 | "source": [ 32 | "def load_quantized_model(\n", 33 | " model_save_path,\n", 34 | " base_model,\n", 35 | " device,\n", 36 | "):\n", 37 | " model = torch.load(model_save_path, map_location=device).to(device) # Llama with Caldera\n", 38 | " graph_model = graph_wrapper.get_graph_wrapper(LlamaForCausalLM, device=\"cpu\").from_pretrained(\n", 39 | " base_model, torch_dtype='auto', device_map=\"cpu\", low_cpu_mem_usage=True,\n", 40 | " use_flash_attention_2=True\n", 41 | " ).to(\"cpu\") # base Llama\n", 42 | "\n", 43 | " for i in range(len(graph_model.model.layers)):\n", 44 | " graph_model.model.layers[i].self_attn.q_proj = model.model.layers[i].self_attn.q_proj\n", 45 | " graph_model.model.layers[i].self_attn.k_proj = model.model.layers[i].self_attn.k_proj\n", 46 | " graph_model.model.layers[i].self_attn.v_proj = model.model.layers[i].self_attn.v_proj\n", 47 | " graph_model.model.layers[i].self_attn.o_proj = model.model.layers[i].self_attn.o_proj\n", 48 | " graph_model.model.layers[i].mlp = model.model.layers[i].mlp\n", 49 | " graph_model.model.layers[i].post_attention_layernorm = graph_model.model.layers[i].post_attention_layernorm.to(device)\n", 50 | " graph_model.model.layers[i].input_layernorm = graph_model.model.layers[i].input_layernorm.to(device)\n", 51 | " graph_model.model.norm = graph_model.model.norm.to(device)\n", 52 | " graph_model.model.embed_tokens = graph_model.model.embed_tokens.to(device)\n", 53 | " graph_model.lm_head = graph_model.lm_head.to(device)\n", 54 | " graph_model.graph_device = device\n", 55 | " return graph_model.to(device)\n", 56 | " " 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Test Throughput of CALDERA Model" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "MODEL_PATH = \"/media/hdd1/caldera-full-models/llama-2-7b/caldera-rank-256-4B-factors-downdate-no-RHT-ft.pt\"\n", 73 | "BASE_MODEL = \"meta-llama/Llama-2-7b-hf\"\n", 74 | "DEVICE = \"cuda:2\"\n", 75 | "SAMPLES = 500" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "model = load_quantized_model(MODEL_PATH, BASE_MODEL, DEVICE)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "def eval_throughput(model, samples, base_model, device, batch_size=1, seq_len=1):\n", 94 | " tokenizer = AutoTokenizer.from_pretrained(base_model)\n", 95 | "\n", 96 | " prompt = 'It is a truth universally acknowledged that'\n", 97 | " inputs = 
tokenizer(prompt, return_tensors='pt')\n", 98 | " token = inputs['input_ids'][0:1, 0:1].to(device).repeat(batch_size, seq_len)\n", 99 | " model(token)\n", 100 | "\n", 101 | " torch.cuda.synchronize()\n", 102 | " start = time.time()\n", 103 | " for _ in range(samples):\n", 104 | " model(token)\n", 105 | " torch.cuda.synchronize()\n", 106 | " end = time.time()\n", 107 | " print('TIME:', (end - start) / samples, 's/tok')\n", 108 | " print (f'THROUGHPUT: {samples / (end - start)} tok/s')" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "eval_throughput(model, SAMPLES, BASE_MODEL, DEVICE)" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "## Compare with Unquantized" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "del model\n", 134 | "gc.collect()\n", 135 | "torch.cuda.empty_cache()" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "model = graph_wrapper.get_graph_wrapper(LlamaForCausalLM, device=DEVICE).from_pretrained(\n", 145 | " BASE_MODEL, torch_dtype='auto', device_map=DEVICE, low_cpu_mem_usage=True,\n", 146 | " use_flash_attention_2=True\n", 147 | " )" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": null, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "eval_throughput(model, SAMPLES, BASE_MODEL, DEVICE)" 157 | ] 158 | } 159 | ], 160 | "metadata": { 161 | "kernelspec": { 162 | "display_name": "Python 3", 163 | "language": "python", 164 | "name": "python3" 165 | }, 166 | "language_info": { 167 | "codemirror_mode": { 168 | "name": "ipython", 169 | "version": 3 170 | }, 171 | "file_extension": ".py", 172 | "mimetype": "text/x-python", 173 | "name": "python", 174 | "nbconvert_exporter": "python", 175 | "pygments_lexer": "ipython3", 176 | "version": "3.11.9" 177 | } 178 | }, 179 | "nbformat": 4, 180 | "nbformat_minor": 2 181 | } 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CALDERA (Calibration Aware Low-Precision DEcomposition with Low-Rank Adaptation) 2 | 3 | CALDERA is a post-training compression method that represents the weights of LLM matrices via a low-rank, low-precision decomposition $\mathbf{W} \approx \mathbf{Q} + \mathbf{L} \mathbf{R}$, where $\mathbf{L}$ and $\mathbf{R}$ are low-rank factors and $\mathbf{Q}, \mathbf{L}$ and $\mathbf{R}$ are all quantized to low-precision formats. 4 | By formulating this decomposition as an optimization problem and solving it via alternating minimization, CALDERA outperforms existing compression techniques in the regime of less than 2.5 bits per parameter. 5 | To enhance performance on specific tasks, CALDERA also supports Low Rank Adaptation (LoRA) fine tuning ([Hu et al, 2021](https://arxiv.org/pdf/2106.09685)) of a portion of the low-rank factors. 6 | 7 | 🔗 Paper link: [Compressing Large Language Models using Low Rank 8 | and Low Precision Decomposition](https://openreview.net/pdf?id=lkx3OpcqSZ) 9 | 10 |

11 | ![Alt Text](assets/caldera_decomposition.png) 12 | 13 |
14 | *CALDERA decomposes a full-precision weight matrix into a low-rank component $\mathbf{L}\mathbf{R}$, which captures the contribution of the top singular values using $\mathrm{B_L}$, $\mathrm{B_R}$ bits, and $\mathbf{Q}$ for the trailing singular values with $\mathrm{B_Q}$ bits, enabling flexible precision settings for each component.*
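Concretely, if the columns of $\mathbf{X}$ are the calibration inputs to a given layer and $\mathbf{H} = \mathbf{X}\mathbf{X}^\top$ is the corresponding Hessian (the quantity computed by the Hessian scripts below), the decomposition approximately solves a calibration-aware problem of the form shown next. This is only a schematic statement of the objective; the precise formulation (scaling, regularization, and bit-budget constraints) is given in the paper.

$$
\min_{\mathbf{Q},\,\mathbf{L},\,\mathbf{R}} \; \big\lVert (\mathbf{Q} + \mathbf{L}\mathbf{R} - \mathbf{W})\,\mathbf{X} \big\rVert_F^2
\;=\;
\min_{\mathbf{Q},\,\mathbf{L},\,\mathbf{R}} \; \operatorname{tr}\!\left[ (\mathbf{Q} + \mathbf{L}\mathbf{R} - \mathbf{W})\,\mathbf{H}\,(\mathbf{Q} + \mathbf{L}\mathbf{R} - \mathbf{W})^\top \right],
$$

where $\mathbf{Q}$, $\mathbf{L}$, and $\mathbf{R}$ are constrained to their $\mathrm{B_Q}$-, $\mathrm{B_L}$-, and $\mathrm{B_R}$-bit representations, and the minimization is carried out by alternating updates of $\mathbf{Q}$ and of the pair $(\mathbf{L}, \mathbf{R})$.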

15 | 16 | 17 | ## Setup Instructions 18 | 19 | ### Note on Python, CUDA, and PyTorch Versions 20 | These setup instructions have been tested on Python 3.10 and 3.11, CUDA 12.1 and 12.2, and PyTorch 2.2. 21 | 22 | In particular, the package `fast-hadamard-transform` lacks wheels for newer versions of these dependencies; the available wheels can be found [here](https://github.com/Dao-AILab/fast-hadamard-transform) (the wheel filenames are of the form `fast_hadamard_transform-1.0.4.post1+cutorchxx11abiTRUE-cp-cp-linux_x86_64.whl`). 23 | 24 | ### 🛠 Instructions 25 | 1. Install `caldera` as a submodule (named `caldera`). 26 | From the home directory of this repository, run 27 | ``` 28 | pip install . 29 | ``` 30 | This will automatically install all dependencies, except `fast-hadamard-transform`, which has dependency issues. 31 | 32 | 2. While CALDERA can be used with any quantizer, we demonstrate the results using QuIP#'s quantizer. Setup the QuIP# ([Tseng et al, 2024](https://arxiv.org/pdf/2402.04396)) submodule: 33 | ``` 34 | ./setup_quip_sharp.sh 35 | ``` 36 | 37 | This script first sets up the QuIP# Python library, and then builds the `quiptools` CUDA library, which provides dequantization kernels for inference. 38 | 39 | QuIP# is used for the quantization of the $\mathbf{Q}$ matrix (backbone), and also provides useful subroutines for Hessian computation. 40 | 41 | 3. Install `fast-hadamard-transform`: `pip install fast-hadamard-transform`. 42 | 43 | **Note**: If you get the error `package 'wheel' is not installed`, you can install it using `pip install wheel`. 44 | 45 | 46 | ## Repo structure 47 | 48 | ### `src/caldera` 49 | This folder contains the bulk of the code for CALDERA. Via step 1 above, everything in this folder is contained in the editable python package `caldera`. 50 | 51 | **`src/caldera/utils`**: utils for CALDERA. Some relevant utils files are listed below: 52 | - `enums.py`: `Enum` objects, e.g., for specifying transformer sublayers (query, key, etc.) and the name of the calibration dataset. 53 | - `quantization.py`: Uniform and Normal Float ([Dettmers et al, 2023](https://arxiv.org/pdf/2305.14314)) quantizers. 54 | Generally, these are not recommended; E8 Lattice quantizers from QuIP# typically perform better. 55 | 56 | **`src/caldera/decomposition`**: code for the CALDERA decomposition algorithm, as well as its application to transformer layers. 57 | 58 | - `dataclasses.py`: classes for storing parameters of the CALDERA algorithm, as well as information about quantized layers. 59 | - `weight_compression.py`: code for the `ActivationAwareWeightCompressor` class. Unless Hessians have already been computed, this performs Hessian computation upon instantiation. The method `get_layer_quantizer`, called on a layer index, instantiates an `ActivationAwareLayerQuant` object. 60 | - `layer_quantization.py`: code for the `ActivationAwareLayerQuant` class. The `compress_sublayer` compresses the specified sublayer, calling the `caldera` method from `alg.py`. 61 | There are also methods for plotting the data-aware error, saving errors and quantization parameters to a JSON file, and instantiating a quantized linear layer. 62 | - `alg.py`: the CALDERA algorithm. 63 | - `quantized_layer.py`: code for the `CalderaQuantizedLinear` class, which is a neural network module that computes $X^\top (Q + LR)^\top$ on layer input $X$, performing dequantization on the fly. 
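For intuition, the forward pass of such a layer is numerically equivalent to the short sketch below (illustrative only, with made-up shapes; the actual `CalderaQuantizedLinear` keeps $Q$, $L$, and $R$ in their low-bit representations and dequantizes them on the fly rather than materializing dense floating-point matrices):

```python
import torch

def caldera_linear_forward(x, Q, L, R):
    # x: (batch, in_features) layer input
    # Q: (out_features, in_features) quantized backbone
    # L: (out_features, rank), R: (rank, in_features) low-rank factors
    W_hat = Q + L @ R      # reconstructed weight approximation
    return x @ W_hat.T     # same as a bias-free nn.Linear forward pass

# Toy example with random stand-ins for the dequantized tensors
x = torch.randn(2, 1024)
Q = torch.randn(1024, 1024)
L, R = torch.randn(1024, 256), torch.randn(256, 1024)
y = caldera_linear_forward(x, Q, L, R)   # shape: (2, 1024)
```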
64 | 65 | 66 | ### `scripts` 67 | This folder contains python scripts for running zero-shot, perplexity, and finetuning experiments. 68 | 69 | Parameters for all of these scripts are specified via command-line arguments. 70 | 71 | ### `shell_scripts` 72 | These are Bash scripts for running experiments in the `scripts` folder with some reasonable parameters. 73 | 74 | Each shell script has variables at the top specifying, e.g., directories in which to save script outputs. 75 | Make sure you set those variables as appropriate. 76 | 77 | **Note**: all shell scripts are meant to be run from the root directory of this repo, i.e., `./shell_scripts/run_eval_ppl.py` instead of `cd shell_scripts && ./run_eval_ppl.py`. 78 | 79 | ### `quip_sharp` 80 | This is the quip-sharp submodule, which is initialized in step 2 of the setup instructions. 81 | 82 | ### `notebooks` 83 | 84 | - `test_caldera.ipynb` to obtain the quickly try out CALDERA decomposition on a random matrix. 85 | - `eval_throughput.ipynb` obtains the autoregressive generation throughput of the model. 86 | 87 | ## 🚀 Trying out CALDERA: Example experiment workflow 88 | 89 | **Note**: Edit each script before running it to make sure desired parameters are used. 90 | 91 | 1. **Compute the Hessians** using `./shell_scripts/run_save_hessians.sh`, which will store Hessian matrices for each layer to files. 92 | 93 | 2. **Quantize the full model** using `shell_scripts/run_quantize_save_caldera.sh`. This stores each quantized transformer layer. 94 | The quantized model can later be loaded in using the `load_quantized_model` function in `scripts/quantize_save_llama.py`. 95 | 96 | 3. **Run zero-shot/perplexity experiments** using `shell_scripts/run_eval_zeroshot.sh` or `shell_scripts/run_eval_ppl.sh`. 97 | 98 | 4. 
**Finetune** using, e.g., `shell_scripts/run_finetune_wikitext.sh` 99 | 100 | ## Citation 101 | If you find our work useful, consider citing it as: 102 | ```bibtex 103 | @inproceedings{ 104 | saha2024compressing, 105 | title={Compressing Large Language Models using Low Rank and Low Precision Decomposition}, 106 | author={Rajarshi Saha and Naomi Sagan and Varun Srivastava and Andrea Goldsmith and Mert Pilanci}, 107 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 108 | year={2024}, 109 | url={https://openreview.net/forum?id=lkx3OpcqSZ} 110 | } 111 | ``` 112 | -------------------------------------------------------------------------------- /scripts/finetune_RHT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser 4 | from lib.utils import sample_rp1t 5 | from lib.utils.data_utils import split_data 6 | import gc 7 | from quantize_save_llama import load_quantized_model 8 | from dataclasses import dataclass, field 9 | from tqdm import tqdm 10 | from safetensors.torch import save_model, load_model 11 | import os 12 | 13 | # Adapted from https://github.com/Cornell-RelaxML/quip-sharp 14 | 15 | 16 | @dataclass 17 | class Arguments: 18 | base_model: str = field(metadata={ 19 | "help": ("Path of the original model, as either a local or a " 20 | "Huggingface path") 21 | }) 22 | model_path: str = field(metadata={ 23 | "help": ("Path of the .pt file in which the model can be found.") 24 | }) 25 | finetuned_save_path: str = field(metadata={ 26 | "help": ("Path in which to save the final finetuned model") 27 | }) 28 | devset_size: int = field(default=256, metadata={ 29 | "help": ("Number of datapoints to sample from the calibration set " 30 | "for finetuning.") 31 | }) 32 | ctx_size: int = field(default=512, metadata={ 33 | "help": ("Length of each input data sequence.") 34 | }) 35 | device: str = field(default="cuda", metadata={ 36 | "help": "Device to use for finetuning." 37 | }) 38 | ft_bs: int = field(default=2, metadata={ 39 | "help": "Batch size for finetuning." 40 | }) 41 | ft_valid_size: int = field(default=64, metadata={ 42 | "help": ("Number of datapoints to set aside for validation. " 43 | "The number of training datapoints is devset_size, minus " 44 | "ft_valid_size.") 45 | }) 46 | finetune_factors: bool = field(default=False, metadata={ 47 | "help": ("Whether to finetune L and R in addition to the randomized " 48 | "Hadamard transform diagonal matrices.") 49 | }) 50 | RHT_learning_rate: float = field(default=1e-3, metadata={ 51 | "help": "Learning rate for the randomized Hadamard transform parameters." 52 | }) 53 | factors_learning_rate: float = field(default=1e-4, metadata={ 54 | "help": "Learning rate for L andsR, if finetune_factors is set True." 55 | }) 56 | epochs: int = field(default=5, metadata={ 57 | "help": "Number of epochs of finetuning." 58 | }) 59 | 60 | 61 | def main(args: Arguments): 62 | torch.set_grad_enabled(False) 63 | tokenizer = AutoTokenizer.from_pretrained(args.base_model) 64 | tokenizer.pad_token = tokenizer.eos_token 65 | devset = sample_rp1t(tokenizer, args.devset_size, args.ctx_size, 1) 66 | 67 | # Get the logits for the calibration set from the original model. The loss 68 | # function for finetuning will be the cross-entropy loss between the quantized 69 | # model logits and these logits. 
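    # Note: since these logits are converted to softmax probabilities, the
    # cross-entropy loss used during finetuning acts as a soft-label
    # distillation loss against the original model's output distribution.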
70 | orig_model = AutoModelForCausalLM.from_pretrained( 71 | args.base_model, torch_dtype='auto', device_map=args.device, low_cpu_mem_usage=True 72 | ).to(args.device) 73 | 74 | orig_logits = torch.zeros(args.devset_size, args.ctx_size, orig_model.config.vocab_size) 75 | for i in range(args.devset_size): 76 | input = devset[i:i+1].to(args.device) 77 | output = orig_model(input).logits.cpu() 78 | orig_logits[i:i+1, :, :] = output 79 | 80 | orig_logits = orig_logits[:, :-1].contiguous().softmax(dim=-1).float() 81 | del orig_model 82 | gc.collect() 83 | torch.cuda.empty_cache() 84 | 85 | torch.set_grad_enabled(True) 86 | quant_model = load_quantized_model( 87 | args.model_path, args.base_model, args.device 88 | ).to(args.device).float() 89 | 90 | factor_params = [] 91 | RHT_params = [] 92 | 93 | for name, param in quant_model.named_parameters(): 94 | if 'L_ft' in name or 'R_ft' in name: 95 | if not args.finetune_factors: 96 | param.requires_grad = False 97 | else: 98 | factor_params.append(param) 99 | elif 'SU' in name or 'SV' in name: 100 | RHT_params.append(param) 101 | train_dataloader, valid_dataloader = split_data(devset, orig_logits, args) 102 | 103 | adam_params = [{ 104 | 'params': RHT_params, 105 | 'lr': args.RHT_learning_rate 106 | }] 107 | if args.finetune_factors: 108 | adam_params.append({ 109 | 'params': factor_params, 110 | 'lr': args.factors_learning_rate 111 | }) 112 | optim = torch.optim.AdamW(adam_params) 113 | scaler = torch.cuda.amp.GradScaler(enabled=True) 114 | 115 | quant_model.eval() 116 | print("Running eval.") 117 | with torch.no_grad(): 118 | val_loss = 0 119 | for _, (input, targets) in enumerate(valid_dataloader): 120 | input, targets = input.to(args.device), targets.to(args.device) 121 | output = quant_model(input).logits[:, :-1].contiguous() 122 | 123 | val_loss += nn.CrossEntropyLoss()( 124 | output.view(-1, output.shape[-1]), 125 | targets.view(-1, targets.shape[-1])).item() 126 | val_loss /= len(valid_dataloader) 127 | print("Validation loss: ", val_loss) 128 | best_val_loss = val_loss 129 | save_model(quant_model, args.finetuned_save_path) 130 | 131 | progress_bar = tqdm(range(len(train_dataloader)*args.epochs)) 132 | for _ in range(args.epochs): 133 | for _, (input, targets) in enumerate(train_dataloader): 134 | input, targets = input.to(args.device), targets.to(args.device) 135 | output = quant_model(input).logits[:, :-1].contiguous() 136 | 137 | loss = nn.CrossEntropyLoss()(output.view(-1, output.shape[-1]), 138 | targets.view(-1, targets.shape[-1])) 139 | scaler.scale(loss).backward() 140 | scaler.step(optim) 141 | scaler.update() 142 | optim.zero_grad() 143 | progress_bar.update(1) 144 | 145 | # validation 146 | quant_model.eval() 147 | print("Running eval.") 148 | with torch.no_grad(): 149 | val_loss = 0 150 | for _, (input, targets) in enumerate(valid_dataloader): 151 | input, targets = input.to(args.device), targets.to(args.device) 152 | output = quant_model(input).logits[:, :-1].contiguous() 153 | 154 | val_loss += nn.CrossEntropyLoss()( 155 | output.view(-1, output.shape[-1]), 156 | targets.view(-1, targets.shape[-1])).item() 157 | val_loss /= len(valid_dataloader) 158 | print("Validation loss: ", val_loss) 159 | if val_loss < best_val_loss: 160 | save_model(quant_model, args.finetuned_save_path) 161 | best_val_loss = val_loss 162 | quant_model.train() 163 | 164 | quant_model = load_quantized_model( 165 | args.model_path, args.base_model, args.device 166 | ).to(args.device) 167 | load_model(quant_model, args.finetuned_save_path) 168 | 
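    # Reload the best checkpoint (saved above with safetensors' save_model) into
    # a fresh quantized model, then overwrite finetuned_save_path with a regular
    # torch.save pickle so it can be loaded with torch.load downstream.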
torch.save(quant_model, args.finetuned_save_path) 169 | 170 | if __name__ == "__main__": 171 | parser = HfArgumentParser([Arguments]) 172 | args = parser.parse_args() 173 | main(args) -------------------------------------------------------------------------------- /src/caldera/decomposition/dataclasses.py: -------------------------------------------------------------------------------- 1 | from caldera.utils.enums import DevSet 2 | from caldera.utils.quantization import QuantizerFactory, \ 3 | AbstractQuantizer, LowMemoryQuantizer 4 | 5 | from dataclasses import field, dataclass 6 | import torch 7 | 8 | 9 | @dataclass 10 | class DataParameters: 11 | """ 12 | Parameters for loading the calibration dataset and computing the 13 | inputs to each layer. 14 | """ 15 | devset: int = field( 16 | default=DevSet.RP1T, metadata={"help": ( 17 | "Calibration dataset, as a member of the DevSet enum." 18 | )} 19 | ) 20 | devset_size: int = field( 21 | default=256, metadata={"help": ( 22 | "Number of calibration samples to use." 23 | )} 24 | ) 25 | context_length: int = field( 26 | default=4096, metadata={"help": ( 27 | "Length of context window." 28 | )} 29 | ) 30 | batch_size: int = field( 31 | default=2, metadata={"help": ( 32 | "Number of datapoints to pass into the model at once." 33 | )} 34 | ) 35 | chunk_size: int = field( 36 | default=256, metadata={"help": ( 37 | "Number of datapoints sent to each GPU at a time. " 38 | "Must be a multiple of batch_size" 39 | )} 40 | ) 41 | devices: list[str] = field( 42 | default=None, metadata={"help": ( 43 | "Specific CUDA devices to use for Hessian computation. Defaults " 44 | "to None, which means that all available devices are used." 45 | )} 46 | ) 47 | 48 | 49 | @dataclass 50 | class ModelParameters: 51 | """ 52 | Parameters for loading in a transformer model and simulating forward 53 | passes. 54 | """ 55 | base_model: str = field( 56 | default="meta-llama/Llama-2-7b-hf", metadata={"help": ( 57 | "Model to quantize." 58 | )} 59 | ) 60 | token: str = field(default=None, metadata={ 61 | "help": "Huggingface access token for private models." 62 | }) 63 | 64 | 65 | @dataclass 66 | class AccumulatorArgs: 67 | """ 68 | These arguments will be passed into the `accumulator` function in 69 | quip-sharp/quantize_llama/hessian_offline_llama.py. This class will be 70 | automatically instantiated by ActivationAwareWeightCompressor. 71 | """ 72 | scratch_path = None 73 | save_path: str = field(default="./hessians/") 74 | 75 | 76 | @dataclass 77 | class QuIPArgs: 78 | """ 79 | Parameters for QuIP. See the documentation of the quip-sharp repository 80 | for descriptions of these (or some of these) parameters. 81 | """ 82 | lora_rank: int = field(default=0) 83 | full_svd: bool = field(default=False) 84 | use_fp64: bool = False 85 | lowmem_ldlq: bool = field(default=False) 86 | scale_override: float = field(default=0.9) 87 | resid_scale_override: float = field(default=0.9) 88 | no_use_buffered: bool = field(default=False) 89 | sigma_reg: float = field(default=1e-2) 90 | sigma_reg2: float = field(default=1e-2) 91 | incoh_mode: str = field(default="had", metadata={ 92 | "help": ("Which form of incoherence processing to use. 
Either \"had\"" 93 | "for a randomized Hadamard transform, or \"kron\" for a " 94 | "randomized Kronecker product of 2x2 matrices.") 95 | }) 96 | rescale_WH: bool = field(default=False) 97 | quip_tune_iters: int = field(default=10, metadata={ 98 | "help": ("Number of iterations in the LDLQ step.") 99 | }) 100 | 101 | 102 | @dataclass 103 | class CalderaParams: 104 | """ 105 | Parameters for the CALDERA decomposition. 106 | """ 107 | quip_args: QuIPArgs = field(default_factory=QuIPArgs) 108 | compute_quantized_component: bool = field( 109 | default=True, metadata={"help": ( 110 | "Whether the decomposition should include a quantized full-size" 111 | "component (denoted Q)." 112 | )} 113 | ) 114 | compute_low_rank_factors: bool = field( 115 | default=True, metadata={"help": ( 116 | "Whether the decomposition should include low-rank factors (L, R)." 117 | )} 118 | ) 119 | Q_bits: int = field(default=2, metadata={ 120 | "help": "Either 2, 3, or 4 bit lattice quantization" 121 | }) 122 | L_bits: int = field(default=2, metadata={ 123 | "help": "Either 2, 3, or 4 bit lattice quantization" 124 | }) 125 | R_bits: int = field(default=2, metadata={ 126 | "help": "Either 2, 3, or 4 bit lattice quantization" 127 | }) 128 | rank: int = field(default=64, metadata={ 129 | "help": "Rank of L and R factors" 130 | }) 131 | iters: int = field(default=20) 132 | lplr_iters: int = field(default=5) 133 | activation_aware_Q: bool = field(default=True, metadata={ 134 | "help": ("Use QuIP# activation-aware quantization for Q, as opposed " 135 | "to naive quantization.") 136 | }) 137 | activation_aware_LR: bool = field(default=True, metadata={ 138 | "help": "Use activation-aware LPLR for computing the factors." 139 | }) 140 | lattice_quant_Q: bool = field(default=True, metadata={ 141 | "help": ("If Q is not data-aware, this determines whether to use " 142 | "lattice quantization, as opposed to unif/normal float quantization " 143 | "implementations.") 144 | }) 145 | lattice_quant_LR: bool = field(default=True, metadata={ 146 | "help": ("Use lattice quantization from the QuIP# codebase, as opposed" 147 | " to uniform or normal float, for L and R") 148 | }) 149 | hadamard_transform: bool = field(default=False, metadata={ 150 | "help": ("Whether to perform a randomized Hadamard transform on W " 151 | "before computing the decomposition W = Q + LR.") 152 | }) 153 | full_quip_sharp: bool = field(default=False, metadata={ 154 | "help": ("If Q is activation-aware and this parameter is True, then " 155 | "Q is computed using the full quip-sharp algorithm. " 156 | "Otherwise, we only use LDLQ.") 157 | }) 158 | update_order: list[str] = field(default_factory=list, metadata={ 159 | "help": ("List specifying whether to update the \"LR\" factors before " 160 | "\"q\" or vice versa. The default is [\"LR\", \"Q\"]; pass " 161 | "in [\"Q\", \"LR\"] to swap the update order.") 162 | }) 163 | quant_factory_Q: QuantizerFactory = field( 164 | default_factory=QuantizerFactory, metadata={"help": ( 165 | "(Non-data-aware only) QuantizerFactory (from caldera.utils.quantizers)" 166 | " object used to instantiate quantizer for Q. Only used if " 167 | "activation_aware_Q is False." 168 | )} 169 | ) 170 | quant_factory_LR: QuantizerFactory = field( 171 | default_factory=QuantizerFactory, metadata={"help": ( 172 | "(Non-lattice quant only) QuantizerFactory (from " 173 | "caldera.utils.quantizers) object used to instantiate quantizer for L " 174 | "and R. Only used if lattice_quant_LR is False." 
175 | )} 176 | ) 177 | rand_svd: bool = field(default=True, metadata={ 178 | "help": "Whether to use randomized SVD for LPLR initialization" 179 | }) 180 | Q_hessian_downdate: bool = field(default=False, metadata={ 181 | "help": ("Whether to do quip-sharp's heuristic Hessian correction" 182 | "via Cholesky downdating before updating Q.") 183 | }) 184 | lattice_quant_block_size: int = field(default=32000, metadata={ 185 | "help": ("For lattice quantization, quantize parameters in groups of " 186 | "(codesize * lattice_quant_block_size) to reduce memory " 187 | "usage") 188 | }) 189 | 190 | @dataclass 191 | class CalderaDecomposition: 192 | Q: torch.Tensor = field(default=None) 193 | L: torch.Tensor = field(default=None) 194 | R: torch.Tensor = field(default=None) 195 | W: torch.Tensor = field(default=None) 196 | Q_idxs: torch.Tensor = field(default=None) 197 | L_idxs: torch.Tensor = field(default=None) 198 | R_idxs: torch.Tensor = field(default=None) 199 | Q_scale: float = field(default=1) 200 | L_scale: float = field(default=1) 201 | R_scale: float = field(default=1) 202 | global_scale: float = field(default=1) 203 | SU: torch.Tensor = field(default=None) 204 | SV: torch.Tensor = field(default=None) 205 | scaleWH: torch.Tensor = field(default=None) 206 | errors: dict[str,list[float]] = field(default_factory=dict) 207 | 208 | 209 | @dataclass 210 | class SubLayerInfo: 211 | """ 212 | Class for storing information about a transformer sub-layer (i.e., one of 213 | {query, key, value, out, gate, up, down}), including the computed 214 | decomposition Q + LR, and the activation-aware error at each iteration of 215 | the CALDERA algorithm. 216 | """ 217 | sublayer: torch.nn.Module = field(default=None) 218 | key: str = field(default="") 219 | out_key: str = field(default="") 220 | started_quant: bool = field(default=False) 221 | caldera: CalderaDecomposition = field(default_factory=CalderaDecomposition) 222 | 223 | 224 | @dataclass 225 | class QuantInfo: 226 | """ 227 | Stores information necessary for quantizing a specific matrix: 228 | 1. Whether to use lattice quantization (QuIP#) or Unif./NormalFloat 229 | quantization. 230 | 2. If lattice quantization is used, the codebook. 231 | 3. If our quantization methods are used, the quantizer object. 232 | """ 233 | lattice_quant: bool = field(default=True) 234 | lattice_cb: torch.nn.Module = field(default=None) 235 | quant: AbstractQuantizer = field(default_factory=LowMemoryQuantizer) 236 | 237 | -------------------------------------------------------------------------------- /src/caldera/utils/quantization.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from scipy import stats 4 | from abc import ABC, abstractmethod 5 | import numpy as np 6 | 7 | 8 | class AbstractQuantizer(ABC): 9 | @abstractmethod 10 | def quantize_block(self, weight): 11 | ... 12 | 13 | @abstractmethod 14 | def dequantize_block(self, weight_quant, weight_max, weight_shape): 15 | ... 
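# Implementations return a (weight_quant, weight_max, weight_shape) tuple from
# quantize_block, which is unpacked straight back into dequantize_block.
# Illustrative round trip (simulated quantization, as in simulated_quant below):
#     q = LowMemoryQuantizer(num_bits=4, method="uniform", block_size=64)
#     W_hat = q.dequantize_block(*q.quantize_block(W))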
16 | 17 | class LowMemoryQuantizer(AbstractQuantizer): 18 | def __init__(self, num_bits=2, method="normal", block_size=64): 19 | self.num_bits = num_bits 20 | assert num_bits in [2, 4, 8, 16] 21 | self.method = method 22 | self.block_size = block_size 23 | 24 | if self.method != "normal" and self.method != "uniform" and self.method != "uniform_clipped": 25 | raise NotImplementedError("Other quantization methods not supported yet.") 26 | 27 | def _quantize_uniform(self, weight_divabs): 28 | # weight_divabs is between -1 and 1, inclusive 29 | weight_scaled = weight_divabs * (2**(self.num_bits - 1) - 1) 30 | weight_scaled = weight_scaled.round() 31 | if self.num_bits <= 8: 32 | return weight_scaled.to(torch.int8) 33 | else: 34 | return weight_scaled.to(torch.int16) 35 | 36 | def _quantize_nf(self, weight_divabs): 37 | Normal = torch.distributions.Normal(0, 1) 38 | # We quantize the range [-1, 0) and the range [0, 1] separately, with 39 | # each having 2^{b-1} levels. 40 | # 41 | # The quantization levels are found as follows: take 2^{b-1} evenly-spaced 42 | # points from [delta, 1/2] and 2{b-1} + 1 from [1/2, 1-delta], where delta 43 | # is as defined below. The quantization levels are the corresponding 44 | # quantiles of a standard normal distribution, scaled such that they lie 45 | # in the range [-1, 1]. 46 | M = 2**(self.num_bits-1) 47 | delta = 1/2 * (1/30 + 1/32) # as described above 48 | res_neg = (1/2 - delta) / (M - 1) # resolution for [delta, 1/2] 49 | res_pos = (1/2 - delta) / M # resolution for [1/2, 1-delta] 50 | # levels to be in [-1, 1] 51 | 52 | # We index into q_neg and q_pos with these indices to get the quantized 53 | # values for the negative and positive parts of A, respectively. 54 | q_neg = Normal.icdf(res_neg * torch.arange(M).to(weight_divabs.device) + delta) / stats.norm.ppf(1-delta) 55 | q_pos = Normal.icdf(res_pos * torch.arange(M + 1).to(weight_divabs.device) + 1/2) / stats.norm.ppf(1-delta) 56 | 57 | neg_quantiles = (weight_divabs < 0) * \ 58 | ((Normal.cdf(weight_divabs * stats.norm.ppf(1-delta)) - delta) / res_neg) 59 | neg_quantiles_round_down = neg_quantiles.floor().long() 60 | neg_quantiles_round_up = torch.minimum(neg_quantiles.ceil().long(), torch.tensor(M-1)) 61 | mask = (torch.abs(weight_divabs - q_neg[neg_quantiles_round_down]) <= torch.abs(weight_divabs - q_neg[neg_quantiles_round_up])) 62 | neg_quant_idxs = neg_quantiles_round_down * mask + neg_quantiles_round_up * (~mask) 63 | 64 | pos_quantiles = (weight_divabs >= 0) * \ 65 | ((Normal.cdf(weight_divabs * stats.norm.ppf(1-delta)) - 1/2) / res_pos) 66 | pos_quantiles_round_down = pos_quantiles.floor().long() 67 | pos_quantiles_round_up = torch.minimum(pos_quantiles.ceil().long(), torch.tensor(M)) 68 | mask = (torch.abs(weight_divabs - q_pos[pos_quantiles_round_down]) <= torch.abs(weight_divabs - q_pos[pos_quantiles_round_up])) 69 | pos_quant_idxs = pos_quantiles_round_down * mask + pos_quantiles_round_up * (~mask) 70 | 71 | idxs = neg_quant_idxs + (weight_divabs >= 0) * (pos_quant_idxs + M - 1) 72 | 73 | if self.num_bits <= 8: 74 | return idxs.to(torch.uint8) 75 | else: 76 | return idxs 77 | 78 | def _dequantize_uniform(self, weight_quant): 79 | return weight_quant.float() / (2**(self.num_bits - 1) - 1) 80 | 81 | def _dequantize_nf(self, weight_quant): 82 | Normal = torch.distributions.Normal(0, 1) 83 | M = 2**(self.num_bits-1) 84 | delta = 1/2 * (1/30 + 1/32) # as described above 85 | res_neg = (1/2 - delta) / (M - 1) # resolution for [delta, 1/2] 86 | res_pos = (1/2 - delta) / M # resolution for 
[1/2, 1-delta] 87 | # levels to be in [-1, 1] 88 | # quantization levels for the negative and positive halves, respectively 89 | q_neg = Normal.icdf(res_neg * torch.arange(M - 1).to(weight_quant.device) + delta) / stats.norm.ppf(1-delta) 90 | q_pos = Normal.icdf(res_pos * torch.arange(M + 1).to(weight_quant.device) + 1/2) / stats.norm.ppf(1-delta) 91 | q_levels = torch.cat((q_neg, q_pos)) 92 | return q_levels[weight_quant.long()] 93 | 94 | def quantize_block(self, weight, epsilon=1e-8): 95 | if len(weight.shape) != 2: 96 | raise ValueError(f"Only support 2D matrix, but your input has {len(weight.shape)} dimensions.") 97 | if weight.shape[0] * weight.shape[1] % self.block_size != 0: 98 | raise ValueError( 99 | f"Weight with shape ({weight.shape[0]} x {weight.shape[1]}) " 100 | f"is not dividable by block size {self.block_size}." 101 | ) 102 | 103 | weight_reshape = weight.flatten().reshape(-1, self.block_size) # (L, M*N/B) 104 | weight_max = weight_reshape.abs().max(dim=-1)[0].unsqueeze(-1) 105 | if self.method == "uniform_clipped": 106 | weight_max = weight_reshape.mean(dim=1) + 2.5 * weight_reshape.std(dim=1) 107 | weight_max = weight_max.unsqueeze(-1) 108 | weight_reshape = torch.minimum(weight_reshape, weight_max) 109 | weight_reshape = torch.maximum(weight_reshape, -weight_max) 110 | weight_max = torch.maximum(weight_max, torch.Tensor([epsilon]).to(weight.device)) 111 | weight_divabs = weight_reshape / weight_max 112 | if self.method == "normal": 113 | weight_quant = self._quantize_nf(weight_divabs) 114 | else: 115 | weight_quant = self._quantize_uniform(weight_divabs) 116 | return weight_quant, weight_max, weight.shape 117 | 118 | def dequantize_block(self, weight_quant, weight_max, weight_shape): 119 | if self.method == "normal": 120 | weight = self._dequantize_nf(weight_quant) 121 | else: 122 | weight = self._dequantize_uniform(weight_quant) 123 | 124 | return (weight * weight_max).reshape(weight_shape) 125 | 126 | def simulated_quant( 127 | quant: AbstractQuantizer, 128 | X: torch.Tensor 129 | ): 130 | if quant.num_bits >= 16: 131 | if X.dtype == torch.float32: 132 | return X.to(dtype=torch.bfloat16).float() 133 | return X 134 | return quant.dequantize_block(*quant.quantize_block(X)) 135 | 136 | class QuantizerFactory: 137 | def __init__(self, method="uniform", block_size=64): 138 | self.method = method 139 | self.block_size = block_size 140 | 141 | def get_quantizer(self, num_bits, device="cpu"): 142 | return LowMemoryQuantizer(num_bits=num_bits, method=self.method, block_size=self.block_size) 143 | 144 | def __str__(self): 145 | return f"QuantizerFactory(method={self.method}, block_size={self.block_size})" 146 | 147 | def mixed_precision_quantize( 148 | X: torch.Tensor = None, 149 | strip_widths: list[int] = 0, 150 | quantizers: list[AbstractQuantizer] = None, 151 | transposed: bool = False 152 | ) -> torch.Tensor: 153 | """ 154 | Given a matrix X, quantizes strips of columns with different bit levels. 155 | The width of each quantized strip is given by the argument strip_widths, 156 | and the corresponding quantizer object for each strip is given by the 157 | argument quantizers. 158 | 159 | There is a one-to-one correspondence between the elements of strip_widths 160 | and the elements of quantizers, so the lists must be the same length. 161 | If sum(strip_widths) is less than the total number of columns in X, the 162 | remaining columns are dropped. 163 | 164 | If transposed is set to True, it quantizes strips of rows instead. 
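    Illustrative example, using two LowMemoryQuantizer objects from this module:
        q2 = LowMemoryQuantizer(num_bits=2, method="uniform", block_size=64)
        q4 = LowMemoryQuantizer(num_bits=4, method="uniform", block_size=64)
        X_hat = mixed_precision_quantize(X, strip_widths=[128, 128],
                                         quantizers=[q2, q4])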
165 | """ 166 | assert len(strip_widths) == len(quantizers) 167 | if transposed: 168 | X = X.T 169 | 170 | assert sum(strip_widths) <= X.shape[1], "sum(widths) should be less than X.shape[1]" 171 | 172 | idxs = np.hstack(([0], np.cumsum(strip_widths))) 173 | # Perform simulated quantization by quantizing and the dequantizing 174 | quantized_components = [ 175 | simulated_quant(quantizers[i], X[:, idxs[i]:idxs[i+1]]) \ 176 | for i in range(len(strip_widths)) \ 177 | if idxs[i+1] > idxs[i] 178 | ] 179 | X_quant = torch.cat(quantized_components, dim=1) 180 | return X_quant.T if transposed else X_quant 181 | 182 | def quantize_small_sv_components( 183 | X: torch.Tensor = None, 184 | r: int = 0, 185 | quantizer:AbstractQuantizer = None, 186 | transposed: bool = False 187 | ) -> torch.Tensor: 188 | """ 189 | Keep the first r columns in original dtype and quantize the last 190 | (X.shape[1] - r) columns. 191 | If "transposed" is True, then quantize rows instead of columns. 192 | 193 | The parameter `quantization_fn` allows you to specify uniform quantization 194 | (via the `quantize` function) or normal float quantization (via the 195 | `quantize_nf` function). 196 | """ 197 | 198 | if transposed: 199 | X = X.T 200 | assert r <= X.shape[1], "r should be less than X.shape[1]" 201 | 202 | if r == X.shape[1]: 203 | return X 204 | 205 | # Perform simulated quantization by quantizing and the dequantizing 206 | quantized_component = simulated_quant(quantizer, X[:, r:]) 207 | X_quant = torch.cat((X[:, :r], quantized_component), dim=1) 208 | return X_quant.T if transposed else X_quant 209 | 210 | def absmax_quantize_int8(X: torch.Tensor) -> tuple[torch.Tensor, torch.float16]: 211 | """Quantize each float16/32 data type to int8 and return the maximum value in float16""" 212 | scale = X.abs().max().item() / 127.0 213 | int8_tensor = (X / scale).round().to(torch.int8) 214 | return scale, int8_tensor 215 | 216 | def absmax_dequantize_int8(Xq: torch.Tensor, scale: torch.float16) -> torch.Tensor: 217 | """Dequantize int8 data type to float16/32""" 218 | return Xq.to(torch.float16) * scale -------------------------------------------------------------------------------- /src/caldera/decomposition/weight_compression.py: -------------------------------------------------------------------------------- 1 | from quantize_llama.hessian_offline_llama import forward_layer, accumulate 2 | from lib.utils.data_utils import sample_rp1t, sample_falcon_refinedweb 3 | 4 | from transformers import AutoModelForCausalLM, AutoTokenizer 5 | from transformers.modeling_attn_mask_utils import \ 6 | _prepare_4d_causal_attention_mask 7 | 8 | import torch 9 | import torch.multiprocessing as mp 10 | import gc 11 | 12 | from caldera.decomposition.dataclasses import * 13 | from caldera.utils.enums import DevSet 14 | from caldera.decomposition.layer_quantization import \ 15 | ActivationAwareLayerQuant 16 | 17 | from caldera.utils.enums import DevSet 18 | import os 19 | 20 | 21 | # Partially adapted from https://github.com/Cornell-RelaxML/quip-sharp 22 | 23 | 24 | class ActivationAwareWeightCompressor: 25 | """ 26 | Sets up the framework for activation aware weight compression: loads in 27 | the model and calibration dataset, and then computes the inputs to each 28 | layer and the corresponding Hessian matrix (the second moment of the 29 | inputs). The inputs and Hessians are stored in files at 30 | `hessian_save_path`. 
31 | 32 | Yor can instantiate an `ActivationAwareLayerQuant` for a given layer by 33 | calling the `get_layer_quantizer` method. See `ActivationAwareLayerQuant` 34 | for more usage details. 35 | 36 | Note: if you have already done Hessian computation and the data is stored 37 | in the appropriate files w.r.t. `Hessian_save_path`, you can pass in 38 | `compute_Hessians=False` to skip the data sampling and Hessian computation. 39 | """ 40 | def __init__( 41 | self, 42 | model_params: ModelParameters = ModelParameters(), 43 | data_params: DataParameters = DataParameters(), 44 | quant_params: CalderaParams = CalderaParams(), 45 | quant_device: str = "cuda", 46 | hessian_save_path: str = "", 47 | start_at_layer: int = 0, 48 | stop_at_layer: int = None, 49 | n_sample_proc: int = 4, 50 | compute_hessians: bool = True 51 | ): 52 | try: 53 | mp.set_start_method('spawn') 54 | except RuntimeError: 55 | print(("Warning: an exception occured when setting the " 56 | "multiprocessing context. If you have previously " 57 | "instantiated an ActivationAwareWeightCompressor, you can " 58 | "ignore this.")) 59 | 60 | self.compute_hessians = compute_hessians 61 | 62 | torch.set_grad_enabled(False) 63 | os.makedirs(hessian_save_path, exist_ok=True) 64 | 65 | # passed into `ActivationAwareLayerQuant` when `get_layer_quantizer` 66 | # is called: 67 | self.hessian_save_path = hessian_save_path 68 | self.quant_params = quant_params 69 | self.quant_device = quant_device 70 | self.data_params = data_params 71 | 72 | self._setup_model_and_data( 73 | model_params, 74 | data_params, 75 | n_sample_proc 76 | ) 77 | 78 | if stop_at_layer is None: 79 | stop_at_layer = float('inf') 80 | 81 | if self.compute_hessians: 82 | # Loop through transformer layers and comput + save the Hessians 83 | for transformer_layer_index, transformer_layer \ 84 | in enumerate(self.model.model.layers): 85 | if transformer_layer_index < start_at_layer: 86 | continue 87 | if transformer_layer_index >= stop_at_layer: 88 | break 89 | self._process_layer( 90 | transformer_layer_index=transformer_layer_index, 91 | transformer_layer=transformer_layer, 92 | data_params=data_params, 93 | hessian_save_path=hessian_save_path 94 | ) 95 | 96 | def get_layer_quantizer( 97 | self, 98 | layer_idx: int, 99 | device: str = None, 100 | label: str = None 101 | ): 102 | """ 103 | Instantiates an `ActivationAwareLayerQuant` object for a given 104 | transformer layer. 105 | """ 106 | assert layer_idx >= 0 and layer_idx < len(self.model.model.layers) 107 | 108 | if device is None: 109 | device = self.quant_device 110 | 111 | return ActivationAwareLayerQuant( 112 | layer=self.model.model.layers[layer_idx], 113 | layer_idx=layer_idx, 114 | hessian_save_path=self.hessian_save_path, 115 | quant_params=self.quant_params, 116 | device=device, 117 | label=label 118 | ) 119 | 120 | def _setup_model_and_data( 121 | self, 122 | model_params: ModelParameters, 123 | data_params: DataParameters, 124 | n_sample_proc: int # Number of processes used to sample the 125 | # calibration dataset. Unrelated to Hessian 126 | # computation. 127 | ): 128 | """ 129 | Loads in the model and calibration dataset. 
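        Also precomputes the position IDs and the 4D causal attention mask that
        are reused when simulating forward passes through each transformer layer.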
130 | """ 131 | # Model 132 | self.model = AutoModelForCausalLM.from_pretrained( 133 | model_params.base_model, torch_dtype="auto", low_cpu_mem_usage=True, 134 | token=model_params.token 135 | ) 136 | 137 | if self.compute_hessians: 138 | # Tokenizer 139 | tokenizer = AutoTokenizer.from_pretrained( 140 | model_params.base_model, use_fast=True, 141 | token=model_params.token 142 | ) 143 | tokenizer.pad_token = tokenizer.eos_token 144 | 145 | # Calibration dataset 146 | if data_params.devset == DevSet.RP1T: 147 | self.devset = sample_rp1t(tokenizer, 148 | data_params.devset_size, 149 | data_params.context_length, 150 | nproc=n_sample_proc) 151 | self.dev_emb = self.model.model.embed_tokens(self.devset) 152 | elif data_params.devset == DevSet.FALCON: 153 | self.devset = sample_falcon_refinedweb( 154 | tokenizer, data_params.devset_size, 155 | data_params.context_length, 156 | nproc=n_sample_proc 157 | ) 158 | self.dev_emb = self.model.model.embed_tokens(self.devset) 159 | else: 160 | raise NotImplementedError("Dataset not implemented yet") 161 | self.dev_emb.share_memory_() 162 | 163 | # Attention mask and position IDs 164 | self.position_ids = torch.arange( 165 | data_params.context_length, dtype=torch.int64 166 | )[None, :] + torch.zeros( 167 | data_params.batch_size, 168 | data_params.context_length, 169 | dtype=torch.int64 170 | ) 171 | 172 | if hasattr(self.model.config, 'sliding_window'): 173 | self.attention_mask = _prepare_4d_causal_attention_mask( 174 | None, (data_params.batch_size, data_params.context_length), 175 | self.dev_emb[0:data_params.batch_size], 0, 176 | sliding_window=self.model.config.sliding_window 177 | ) 178 | else: 179 | self.attention_mask = _prepare_4d_causal_attention_mask( 180 | None, (data_params.batch_size, data_params.context_length), 181 | self.dev_emb[0:data_params.batch_size], 0 182 | ) 183 | 184 | def _process_layer( 185 | self, 186 | transformer_layer_index: int, 187 | transformer_layer: torch.nn.Module, 188 | data_params: DataParameters, 189 | hessian_save_path: str 190 | ): 191 | """ 192 | Compute the layer inputs and Hessians via the same process as 193 | quip-sharp/quantize_llama/hessian_offline_llama.py. 
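        One forward_layer worker process is spawned per device, and a single
        accumulate process aggregates the per-chunk results and writes the
        Hessians for this layer to files under hessian_save_path.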
194 | """ 195 | # Check that there are four layers (QKV + 4 MLP), as expected 196 | assert (len([ 197 | m for m in transformer_layer.modules() 198 | if isinstance(m, torch.nn.Linear) 199 | ]) == 7) 200 | 201 | chunk_size = min(data_params.chunk_size, len(self.dev_emb)) 202 | 203 | devices_available = data_params.devices if \ 204 | data_params.devices is not None else \ 205 | range(torch.cuda.device_count()) 206 | ngpus = min(len(devices_available), len(self.dev_emb) // chunk_size) 207 | devices = devices_available[:ngpus] 208 | print(f"Computing hessians on {devices}") 209 | 210 | manager = mp.get_context('spawn').Manager() 211 | in_q = manager.Queue() 212 | out_q = manager.Queue() 213 | 214 | accumulate_proc = mp.Process( 215 | target=accumulate, 216 | args=( 217 | out_q, None, ngpus, 218 | AccumulatorArgs(save_path=hessian_save_path), 219 | transformer_layer_index 220 | ) 221 | ) 222 | accumulate_proc.start() 223 | 224 | forward_procs = [] 225 | for device in devices: 226 | p = mp.Process( 227 | target=forward_layer, 228 | args=( 229 | transformer_layer, 230 | self.position_ids, 231 | self.attention_mask, 232 | data_params.batch_size, 233 | device, in_q, out_q 234 | ) 235 | ) 236 | p.start() 237 | forward_procs.append(p) 238 | 239 | assert len(self.dev_emb) % data_params.batch_size == 0 and \ 240 | chunk_size % data_params.batch_size == 0 241 | i = 0 242 | while i < len(self.dev_emb): 243 | next = min(i + chunk_size, len(self.dev_emb)) 244 | in_q.put(self.dev_emb[i:next]) 245 | i = next 246 | 247 | for device in devices: 248 | in_q.put(None) 249 | 250 | for p in forward_procs: 251 | p.join() 252 | 253 | accumulate_proc.join() 254 | 255 | transformer_layer.cpu() 256 | # self.model.model.layers[transformer_layer_index] = None 257 | gc.collect() 258 | torch.cuda.empty_cache() 259 | 260 | print(f"done processing layer {transformer_layer_index}") 261 | -------------------------------------------------------------------------------- /scripts/quantize_save_llama.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | 4 | import numpy as np 5 | from caldera.decomposition.dataclasses import * 6 | from caldera.decomposition.weight_compression import ActivationAwareWeightCompressor 7 | from caldera.utils.enums import TransformerSubLayers 8 | import gc 9 | from transformers import ( 10 | AutoModelForCausalLM, 11 | HfArgumentParser, 12 | AutoModelForSequenceClassification, 13 | LlamaForCausalLM 14 | ) 15 | import torch 16 | import os 17 | import torch.multiprocessing as mp 18 | import glog 19 | from lib.utils import graph_wrapper 20 | import shutil 21 | 22 | SUBLAYER_TO_STRING = { 23 | TransformerSubLayers.KEY: "Key Projection (attn)", 24 | TransformerSubLayers.QUERY: "Query Projection (attn)", 25 | TransformerSubLayers.VALUE: "Value Projection (attn)", 26 | TransformerSubLayers.O: "O Projection (attn)", 27 | TransformerSubLayers.GATE: "Gate Projection (mlp)", 28 | TransformerSubLayers.UP: "Up Projection (mlp)", 29 | TransformerSubLayers.DOWN: "Down Projection (mlp)", 30 | } 31 | 32 | 33 | @dataclass 34 | class Arguments: 35 | hessian_save_path: str = field( 36 | metadata={"help": "Path in which the Hessians were stored"} 37 | ) 38 | model_save_path: str = field( 39 | metadata={"help": ("Path in which to save the quantized model, e.g., artifacts/model.pt")} 40 | ) 41 | base_model: str = field( 42 | metadata={ 43 | "help": ( 44 | "Path of the model that is being quantized, as " 45 | "either a local or a Huggingface path" 46 | ) 47 | } 
48 | ) 49 | devices: list[str] = field( 50 | metadata={ 51 | "help": ( 52 | "List of devices to use for quantization, e.g. " 53 | '"cuda:0 cuda:1 cuda:2 cuda:3"' 54 | ) 55 | } 56 | ) 57 | ft_rank: int = field( 58 | default=64, 59 | metadata={ 60 | "help": ( 61 | "Number of columns of L and rows of R, in the decomposition" 62 | "W approx. Q + LR to finetune. The remaining columns will " 63 | "remain fixed." 64 | ) 65 | }, 66 | ) 67 | token: str = field( 68 | default="", metadata={"help": "Huggingface token for private models."} 69 | ) 70 | start_layer: int = field( 71 | default=0, 72 | metadata={ 73 | "help": "Layer index to start quantizing from (to resume quantization from an interrupt)" 74 | }, 75 | ) 76 | stop_layer: int = field( 77 | default=int(sys.maxsize), 78 | metadata={ 79 | "help": "Layer index to stop quantizing at (to resume quantization from an interrupt)" 80 | }, 81 | ) 82 | 83 | random_seed: int = field( 84 | default=42, 85 | metadata={ 86 | "help": "Random seed for reproducibility." 87 | }, 88 | ) 89 | 90 | 91 | def quant_layer( 92 | in_q, 93 | model_save_path, 94 | base_model, 95 | config, 96 | ft_rank, 97 | grad_ckpt, 98 | device, 99 | data_params, 100 | quant_params, 101 | hessian_save_path, 102 | random_seed 103 | ): 104 | np.random.seed(random_seed) 105 | torch.manual_seed(random_seed) 106 | torch.cuda.manual_seed_all(random_seed) 107 | random.seed(random_seed) 108 | 109 | model = AutoModelForCausalLM.from_pretrained( 110 | base_model, torch_dtype="auto", low_cpu_mem_usage=True 111 | ).cpu() 112 | 113 | while True: 114 | layer_idx = in_q.get() 115 | 116 | if layer_idx is None: 117 | return 118 | 119 | weight_compressor = ActivationAwareWeightCompressor( 120 | model_params=ModelParameters(base_model), 121 | data_params=data_params, 122 | hessian_save_path=hessian_save_path, 123 | quant_params=quant_params, 124 | compute_hessians=False, 125 | ) 126 | layer_quant = weight_compressor.get_layer_quantizer(layer_idx, device) 127 | 128 | with torch.no_grad(): 129 | layer = model.model.layers[layer_idx] 130 | 131 | for sublayer in layer_quant.sublayer_info.keys(): 132 | print(f"Quantizing layer {layer_idx}, {SUBLAYER_TO_STRING[sublayer]}") 133 | layer_quant.compress_sublayer(sublayer) 134 | 135 | attr_names = layer_quant.sublayer_info[sublayer].out_key.split(".") 136 | setattr( 137 | getattr(layer, attr_names[0]), 138 | attr_names[1], 139 | layer_quant.get_quantized_linear_layer( 140 | sublayer, ft_rank, grad_ckpt 141 | ), 142 | ) 143 | layer_quant.clean_up_sublayer(sublayer) 144 | layer = layer.cpu() 145 | torch.save(layer, f"{model_save_path}/layers/quant_layer_{layer_idx}.pt") 146 | del layer_quant 147 | gc.collect() 148 | torch.cuda.empty_cache() 149 | 150 | 151 | def quantize_save_llama( 152 | base_model: str, 153 | hessian_save_path: str, 154 | model_save_path: str, 155 | token: str = "", 156 | ft_rank: int = 64, 157 | grad_ckpt: bool = True, 158 | data_params: DataParameters = DataParameters(), 159 | quant_params: CalderaParams = CalderaParams(), 160 | quant_devices=["cuda"], 161 | start_layer=0, 162 | stop_layer=int(sys.maxsize), 163 | random_seed: int = 42, 164 | ): 165 | os.makedirs(f"{model_save_path}/layers", exist_ok=True) 166 | mp.set_start_method("spawn") 167 | 168 | if token: 169 | model = AutoModelForCausalLM.from_pretrained( 170 | base_model, torch_dtype="auto", low_cpu_mem_usage=True, token=token 171 | ).cpu() 172 | else: 173 | model = AutoModelForCausalLM.from_pretrained( 174 | base_model, torch_dtype="auto", low_cpu_mem_usage=True 175 | ).cpu() 176 | 177 | 
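    # Only the config and the layer count are needed from the full-precision
    # model here; it is deleted below to free memory before the per-device
    # quantization workers are spawned.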
model_config = model.config 178 | n_layers = len(model.model.layers) 179 | del model 180 | gc.collect() 181 | torch.cuda.empty_cache() 182 | 183 | manager = mp.get_context("spawn").Manager() 184 | in_q = manager.Queue() 185 | quant_procs = [] 186 | 187 | for device in quant_devices: 188 | p = mp.Process( 189 | target=quant_layer, 190 | args=( 191 | in_q, 192 | model_save_path, 193 | base_model, 194 | model_config, 195 | ft_rank, 196 | grad_ckpt, 197 | device, 198 | data_params, 199 | quant_params, 200 | hessian_save_path, 201 | random_seed, 202 | ), 203 | ) 204 | p.start() 205 | quant_procs.append(p) 206 | 207 | stop_layer: int = min(stop_layer, n_layers) 208 | for layer_idx in range(start_layer, stop_layer): 209 | in_q.put(layer_idx) 210 | 211 | for _ in quant_devices: 212 | in_q.put(None) 213 | 214 | for p in quant_procs: 215 | p.join() 216 | 217 | # now save the full model 218 | model = load_layers_cpu(model_save_path, base_model) 219 | shutil.rmtree(f'{model_save_path}/') 220 | torch.save(model, f"{model_save_path}") 221 | 222 | def load_layers_cpu( 223 | model_save_path, 224 | base_model, 225 | ): 226 | model = AutoModelForCausalLM.from_pretrained( 227 | base_model, torch_dtype='auto', device_map="cpu", low_cpu_mem_usage=True 228 | ).cpu() 229 | 230 | for layer_idx in range(len(model.model.layers)): 231 | layer = torch.load( 232 | f"{model_save_path}/layers/quant_layer_{layer_idx}.pt", 233 | map_location="cpu" 234 | ) 235 | layer.post_attention_layernorm.weight.requires_grad = False 236 | layer.input_layernorm.weight.requires_grad = False 237 | 238 | for sublayer in [ 239 | layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj, 240 | layer.self_attn.o_proj, layer.mlp.gate_proj, layer.mlp.up_proj, 241 | layer.mlp.down_proj 242 | ]: 243 | if sublayer.ft_rank > 0: 244 | sublayer.L_ft = torch.nn.Parameter(sublayer.L_ft.contiguous(), requires_grad=True) 245 | sublayer.R_ft = torch.nn.Parameter(sublayer.R_ft.contiguous(), requires_grad=True) 246 | 247 | model.model.layers[layer_idx] = layer 248 | 249 | return model 250 | 251 | def load_quantized_model( 252 | model_save_path, 253 | base_model, 254 | device, 255 | sequence_classification=False, 256 | seq_class_num_labels=2, 257 | cuda_graph=False, 258 | ): 259 | model = torch.load(model_save_path, map_location=device).to(device) 260 | if cuda_graph: 261 | graph_model = graph_wrapper.get_graph_wrapper(AutoModelForCausalLM, device="cpu").from_pretrained( 262 | base_model, torch_dtype='auto', device_map="cpu", low_cpu_mem_usage=True, 263 | use_flash_attention_2=True 264 | ).to("cpu") 265 | for i in range(len(graph_model.model.layers)): 266 | graph_model.model.layers[i].self_attn.q_proj = model.model.layers[i].self_attn.q_proj 267 | graph_model.model.layers[i].self_attn.k_proj = model.model.layers[i].self_attn.k_proj 268 | graph_model.model.layers[i].self_attn.v_proj = model.model.layers[i].self_attn.v_proj 269 | graph_model.model.layers[i].self_attn.o_proj = model.model.layers[i].self_attn.o_proj 270 | graph_model.model.layers[i].mlp = model.model.layers[i].mlp 271 | graph_model.model.layers[i].post_attention_layernorm = graph_model.model.layers[i].post_attention_layernorm.to(device) 272 | graph_model.model.layers[i].input_layernorm = graph_model.model.layers[i].input_layernorm.to(device) 273 | graph_model.model.norm = graph_model.model.norm.to(device) 274 | graph_model.model.embed_tokens = graph_model.model.embed_tokens.to(device) 275 | graph_model.lm_head = graph_model.lm_head.to(device) 276 | graph_model.graph_device = device 
277 | model = graph_model.to(device ) 278 | 279 | elif sequence_classification: 280 | seq_model = AutoModelForSequenceClassification.from_pretrained( 281 | base_model, torch_dtype='auto', device_map="cpu", low_cpu_mem_usage=True, num_labels=seq_class_num_labels 282 | ).cpu() 283 | seq_model.score = seq_model.score.to(device) 284 | seq_model.score.weight.requires_grad = True 285 | model.model.embed_tokens = model.model.embed_tokens.to(device) 286 | seq_model.model.layers = model.model.layers 287 | model = seq_model.to(device) 288 | 289 | if not sequence_classification: 290 | model.lm_head.weight.requires_grad = False 291 | model.model.embed_tokens.weight.requires_grad = False 292 | model.model.norm.weight.requires_grad = False 293 | for layer in model.model.layers: 294 | layer.post_attention_layernorm.weight.requires_grad = False 295 | layer.input_layernorm.weight.requires_grad = False 296 | 297 | return model 298 | 299 | 300 | if __name__ == "__main__": 301 | glog.setLevel("WARN") 302 | 303 | parser = HfArgumentParser([Arguments, CalderaParams, QuIPArgs]) 304 | 305 | args, quant_params, quip_args = parser.parse_args_into_dataclasses() 306 | quant_params.quip_args = quip_args 307 | quantize_save_llama( 308 | base_model=args.base_model, 309 | hessian_save_path=args.hessian_save_path, 310 | model_save_path=args.model_save_path, 311 | token=args.token, 312 | ft_rank=args.ft_rank, 313 | grad_ckpt=False, 314 | data_params=DataParameters(), 315 | quant_params=quant_params, 316 | quant_devices=args.devices, 317 | start_layer=args.start_layer, 318 | stop_layer=args.stop_layer, 319 | random_seed=args.random_seed, 320 | ) 321 | -------------------------------------------------------------------------------- /scripts/finetune_wikitext.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import transformers 3 | from transformers import ( 4 | AutoConfig, 5 | AutoTokenizer, 6 | Trainer, 7 | default_data_collator, 8 | TrainingArguments 9 | ) 10 | from transformers.testing_utils import CaptureLogger 11 | import math 12 | import logging 13 | from itertools import chain 14 | import evaluate 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | from transformers.utils.versions import require_version 18 | from quantize_save_llama import load_quantized_model 19 | from accelerate import Accelerator 20 | import os 21 | 22 | 23 | @dataclass 24 | class Arguments: 25 | model_save_path: str = field(metadata={ 26 | "help": ("Path in which the quantized model was saved via " 27 | "quantize_save_llama.py") 28 | }) 29 | base_model: str = field(metadata={ 30 | "help": ("Path of the original model, as either a local or a " 31 | "Huggingface path") 32 | }) 33 | 34 | @dataclass 35 | class DataTrainingArguments: 36 | """ 37 | Arguments pertaining to what data we are going to input our model for training and eval. 
38 | """ 39 | 40 | dataset_name: Optional[str] = field( 41 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 42 | ) 43 | dataset_config_name: Optional[str] = field( 44 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 45 | ) 46 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 47 | validation_file: Optional[str] = field( 48 | default=None, 49 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 50 | ) 51 | max_train_samples: Optional[int] = field( 52 | default=None, 53 | metadata={ 54 | "help": ( 55 | "For debugging purposes or quicker training, truncate the number of training examples to this " 56 | "value if set." 57 | ) 58 | }, 59 | ) 60 | max_eval_samples: Optional[int] = field( 61 | default=None, 62 | metadata={ 63 | "help": ( 64 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 65 | "value if set." 66 | ) 67 | }, 68 | ) 69 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) 70 | block_size: Optional[int] = field( 71 | default=None, 72 | metadata={ 73 | "help": ( 74 | "Optional input sequence length after tokenization. " 75 | "The training dataset will be truncated in block of this size for training. " 76 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 77 | ) 78 | }, 79 | ) 80 | overwrite_cache: bool = field( 81 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 82 | ) 83 | validation_split_percentage: Optional[int] = field( 84 | default=5, 85 | metadata={ 86 | "help": "The percentage of the train set used as validation set in case there's no validation split" 87 | }, 88 | ) 89 | preprocessing_num_workers: Optional[int] = field( 90 | default=None, 91 | metadata={"help": "The number of processes to use for the preprocessing."}, 92 | ) 93 | keep_linebreaks: bool = field( 94 | default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} 95 | ) 96 | 97 | def __post_init__(self): 98 | if self.streaming: 99 | require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") 100 | 101 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 102 | raise ValueError("Need either a dataset name or a training/validation file.") 103 | else: 104 | if self.train_file is not None: 105 | extension = self.train_file.split(".")[-1] 106 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." 107 | if self.validation_file is not None: 108 | extension = self.validation_file.split(".")[-1] 109 | assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." 
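# Illustrative invocation (paths, dataset name, and hyperparameters are
# placeholders; see shell_scripts/run_finetune_wikitext.sh for a full template):
#   accelerate launch scripts/finetune_wikitext.py \
#       --model_save_path artifacts/model.pt \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
#       --output_dir artifacts/wikitext_ft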
110 | 111 | 112 | def main( 113 | model_save_path: str, 114 | base_model: str, 115 | output_dir: str, 116 | data_args: DataTrainingArguments, 117 | training_args: TrainingArguments 118 | ): 119 | os.makedirs(output_dir, exist_ok=True) 120 | accelerator = Accelerator() 121 | model = load_quantized_model(model_save_path, base_model, accelerator.device) 122 | for name, param in model.named_parameters(): 123 | if 'SU' in name or 'SV' in name: 124 | param.requires_grad = False 125 | 126 | logger = logging.getLogger(__name__) 127 | 128 | config = AutoConfig.from_pretrained(base_model) 129 | 130 | raw_datasets = load_dataset( 131 | data_args.dataset_name, 132 | data_args.dataset_config_name, 133 | streaming=data_args.streaming, 134 | ) 135 | 136 | tokenizer = AutoTokenizer.from_pretrained( 137 | base_model, use_fast=True) 138 | embedding_size = model.get_input_embeddings().weight.shape[0] 139 | 140 | if len(tokenizer) > embedding_size: 141 | model.resize_token_embeddings(len(tokenizer)) 142 | 143 | column_names = list(raw_datasets["validation"].features) 144 | text_column_name = "text" if "text" in column_names else column_names[0] 145 | 146 | tok_logger = transformers.utils.logging.get_logger( 147 | "transformers.tokenization_utils_base" 148 | ) 149 | 150 | def tokenize_function(examples): 151 | with CaptureLogger(tok_logger) as cl: 152 | output = tokenizer(examples[text_column_name]) 153 | # clm input could be much much longer than block_size 154 | if "Token indices sequence length is longer than the" in cl.out: 155 | tok_logger.warning( 156 | "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long " 157 | "input will be chunked into smaller bits before being passed " 158 | "to the model." 159 | ) 160 | return output 161 | 162 | with training_args.main_process_first(desc="dataset map tokenization"): 163 | if not data_args.streaming: 164 | tokenized_datasets = raw_datasets.map( 165 | tokenize_function, 166 | batched=True, 167 | num_proc=data_args.preprocessing_num_workers, 168 | remove_columns=column_names, 169 | load_from_cache_file=not data_args.overwrite_cache, 170 | desc="Running tokenizer on dataset", 171 | ) 172 | else: 173 | tokenized_datasets = raw_datasets.map( 174 | tokenize_function, 175 | batched=True, 176 | remove_columns=column_names, 177 | ) 178 | 179 | if hasattr(config, "max_position_embeddings"): 180 | max_pos_embeddings = config.max_position_embeddings 181 | else: 182 | # Define a default value if the attribute is missing in the config. 183 | max_pos_embeddings = 1024 184 | 185 | if data_args.block_size is None: 186 | block_size = tokenizer.model_max_length 187 | if block_size > max_pos_embeddings: 188 | logger.warning( 189 | "The tokenizer picked seems to have a very large " 190 | f"`model_max_length` ({tokenizer.model_max_length}). " 191 | f"Using block_size={min(1024, max_pos_embeddings)} instead. " 192 | "You can change that default value by passing --block_size xxx." 193 | ) 194 | if max_pos_embeddings > 0: 195 | block_size = min(1024, max_pos_embeddings) 196 | else: 197 | block_size = 1024 198 | else: 199 | if data_args.block_size > tokenizer.model_max_length: 200 | logger.warning( 201 | f"The block_size passed ({data_args.block_size}) is larger " 202 | "than the maximum length for the model " 203 | f"({tokenizer.model_max_length}). Using " 204 | f"block_size={tokenizer.model_max_length}." 
205 | ) 206 | block_size = min(data_args.block_size, tokenizer.model_max_length) 207 | 208 | # Main data processing function that will concatenate all texts from 209 | # our dataset and generate chunks of block_size. 210 | def group_texts(examples): 211 | # Concatenate all texts. 212 | concatenated_examples = {k: list(chain(*examples[k])) 213 | for k in examples.keys()} 214 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 215 | # We drop the small remainder, and if the total_length < block_size 216 | # we exclude this batch and return an empty dict. 217 | # We could add padding if the model supported it instead of this drop, 218 | # you can customize this part to your needs. 219 | total_length = (total_length // block_size) * block_size 220 | # Split by chunks of max_len. 221 | result = { 222 | k: [t[i: i + block_size] 223 | for i in range(0, total_length, block_size)] 224 | for k, t in concatenated_examples.items() 225 | } 226 | result["labels"] = result["input_ids"].copy() 227 | return result 228 | 229 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder 230 | # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower 231 | # to preprocess. 232 | # 233 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: 234 | # https://huggingface.co/docs/datasets/process#map 235 | 236 | with training_args.main_process_first(desc="grouping texts together"): 237 | if not data_args.streaming: 238 | lm_datasets = tokenized_datasets.map( 239 | group_texts, 240 | batched=True, 241 | num_proc=data_args.preprocessing_num_workers, 242 | load_from_cache_file=not data_args.overwrite_cache, 243 | desc=f"Grouping texts in chunks of {block_size}", 244 | ) 245 | else: 246 | lm_datasets = tokenized_datasets.map( 247 | group_texts, 248 | batched=True, 249 | ) 250 | eval_dataset = lm_datasets["validation"] 251 | if data_args.max_eval_samples is not None: 252 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 253 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 254 | train_dataset = lm_datasets["train"] 255 | if data_args.max_train_samples is not None: 256 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 257 | train_dataset = train_dataset.select(range(max_train_samples)) 258 | 259 | trainer = Trainer( 260 | model=model, 261 | args=training_args, 262 | train_dataset=train_dataset, 263 | eval_dataset=eval_dataset, 264 | tokenizer=tokenizer, 265 | # Data collator will default to DataCollatorWithPadding 266 | data_collator=default_data_collator 267 | ) 268 | 269 | # train? 
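    # Train, save the finetuned model to output_dir, and report validation
    # perplexity.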
270 | trainer.train() 271 | trainer.save_model() 272 | 273 | metrics = trainer.evaluate() 274 | 275 | max_eval_samples = len(eval_dataset) 276 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 277 | try: 278 | perplexity = math.exp(metrics["eval_loss"]) 279 | except OverflowError: 280 | perplexity = float("inf") 281 | print(f"perplexity: {perplexity}") 282 | 283 | 284 | if __name__ == "__main__": 285 | parser = transformers.HfArgumentParser([ 286 | Arguments, DataTrainingArguments, TrainingArguments]) 287 | args, data_args, training_args = parser.parse_args_into_dataclasses() 288 | main( 289 | args.model_save_path, 290 | args.base_model, 291 | training_args.output_dir, 292 | data_args=data_args, 293 | training_args=training_args 294 | ) 295 | -------------------------------------------------------------------------------- /src/caldera/decomposition/layer_quantization.py: -------------------------------------------------------------------------------- 1 | from lib.utils.data_utils import flat_to_sym 2 | from lib.utils.math_utils import regularize_H 3 | 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import gc 7 | 8 | from caldera.decomposition.dataclasses import * 9 | from caldera.utils.enums import TransformerSubLayers 10 | from caldera.decomposition.alg import caldera 11 | 12 | from dataclasses import asdict 13 | import json 14 | import datetime 15 | 16 | 17 | # Maps number of bits to name of the QuIP# lattice quantizer 18 | BITS_TO_CODEBOOK = { 19 | 2: 'E8P12', 20 | 3: 'E8P12RVQ3B', 21 | 4: 'E8P12RVQ4B' 22 | } 23 | 24 | 25 | class ActivationAwareLayerQuant: 26 | """ 27 | For a given transformer layer, decomposes each sublayer (i.e., one of 28 | {query, key, value, out, gate, up, down}) into Q + LR, where Q is a 29 | matrix quantized according to QuIP# and L, R are low-rank and quantized 30 | factors via an iterative, activation-aware procedure. 31 | 32 | This class is instantiated by `ActivationAwareWeightCompressor` upon 33 | calling the method `get_layer_quantizer`. 34 | 35 | Usage example: 36 | ``` 37 | # Instantiate ActivationAwareWeightCompressor. 
This will automatically 38 | # compute all the Hessians upon initialization, unless you pass in the 39 | # `stop_at_layer` keyword argument 40 | weight_compressor = ActivationAwareWeightCompressor( 41 | model_params=ModelParameters( 42 | base_model="meta-llama/Llama-2-7b-hf" 43 | ), 44 | data_params=DataParameters( 45 | devset=DevSet.FALCON, 46 | devset_size=128, 47 | context_length=4096, 48 | batch_size=2, 49 | devices=["cuda", "cuda:2"] 50 | ), 51 | hessian_save_path="./data/hessians/llama-7b", 52 | quant_params=ActivationAwareQuantParams( 53 | Q_bits=2, 54 | L_bits=3, 55 | R_bits=4, 56 | rank=256, 57 | iters=25, 58 | lplr_iters=3 59 | ), 60 | quant_device="cuda:2" 61 | ) 62 | 63 | # Instantiate ActivationAwareLayerQuant for a layer 64 | layer_quant = weight_compressor.get_layer_quantizer(layer_idx=12) 65 | 66 | # Quantize one sub-layer using CALDERA 67 | layer_quant.compress_sublayer(TransformerSubLayers.VALUE) 68 | 69 | # Plot the quantization error 70 | best_error = layer_quant.min_error(TransformerSubLayers.VALUE) 71 | layer_quant.plot_errors(TransformerSubLayers.VALUE) 72 | 73 | # Delete Q, L, and R matrices to free GPU memory before quantizing another 74 | # sublayer (optional) 75 | layer_quant.clean_up_sublayer(TransformerSubLayers.VALUE) 76 | ``` 77 | """ 78 | def __init__( 79 | self, 80 | layer: torch.nn.Module, 81 | layer_idx: int, 82 | hessian_save_path: str = "", 83 | quant_params: CalderaParams = CalderaParams(), 84 | label: str = "CALDERA", 85 | device: str = "cuda", 86 | ): 87 | self.hessian_save_path = hessian_save_path 88 | self.layer = layer.to(device) 89 | self.label = label 90 | if label is None: 91 | self.label = "CALDERA" 92 | 93 | self.layer_idx = layer_idx 94 | self.quant_params = quant_params 95 | if not self.quant_params.update_order: 96 | if not self.quant_params.compute_low_rank_factors: 97 | self.quant_params.update_order = ["Q"] 98 | elif not self.quant_params.compute_quantized_component: 99 | self.quant_params.update_order = ["LR"] 100 | else: 101 | self.quant_params.update_order = ["LR", "Q"] 102 | if not self.quant_params.compute_low_rank_factors: 103 | self.quant_params.rank = 0 104 | self.device = device 105 | 106 | self._set_sublayer_weights_and_info() 107 | 108 | def compress_sublayer(self, sublayer): 109 | """ 110 | Decomposes a sublayer (e.g., query, key, value, etc.) by alternating 111 | between LDLQ and LPLR. 112 | 113 | The sublayer argument must be a member of the TransformerSubLayers 114 | enum. 115 | """ 116 | assert sublayer in self.sublayer_info.keys(), \ 117 | ("Invalid sublayer! 
Please use a member of the " 118 | "TransformerSubLayers enum.") 119 | sublayer_info = self.sublayer_info[sublayer] 120 | sublayer_info.started_quant = True 121 | 122 | sublayer_info.caldera.W = sublayer_info.sublayer.weight.to(self.device).float() 123 | 124 | W = sublayer_info.caldera.W 125 | H = self._get_H(sublayer) 126 | 127 | sublayer_info.caldera = caldera(self.quant_params, W, H, self.device) 128 | 129 | def get_quantized_linear_layer(self, sublayer, ft_rank, grad_ckpt=True): 130 | from caldera.decomposition.quantized_layer import \ 131 | CalderaQuantizedLinear 132 | 133 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer) 134 | 135 | have_L_codebook = self.quant_params.lattice_quant_LR and self.quant_params.compute_low_rank_factors and \ 136 | self.quant_params.L_bits < 16 137 | have_R_codebook = self.quant_params.lattice_quant_LR and self.quant_params.compute_low_rank_factors and \ 138 | self.quant_params.R_bits < 16 139 | 140 | L_codebook_version = BITS_TO_CODEBOOK[self.quant_params.L_bits] if have_L_codebook else None 141 | 142 | R_codebook_version = BITS_TO_CODEBOOK[self.quant_params.R_bits] if have_R_codebook else None 143 | 144 | if ft_rank < self.quant_params.rank and self.quant_params.compute_low_rank_factors and \ 145 | ((L_codebook_version is None and self.quant_params.L_bits < 16) 146 | or (R_codebook_version is None and 147 | self.quant_params.R_bits < 16)): 148 | raise NotImplementedError( 149 | "Only lattice quantization for L and R implemented so far" 150 | ) 151 | 152 | return CalderaQuantizedLinear( 153 | # Dimensions 154 | in_features=sublayer_info.caldera.W.shape[1], 155 | out_features=sublayer_info.caldera.W.shape[0], 156 | # Codebooks 157 | Q_codebook_version=BITS_TO_CODEBOOK[self.quant_params.Q_bits], 158 | L_codebook_version=L_codebook_version, 159 | R_codebook_version=R_codebook_version, 160 | # L and R 161 | L=sublayer_info.caldera.L, 162 | R=sublayer_info.caldera.R, 163 | # Quantized idxs 164 | L_idxs=sublayer_info.caldera.L_idxs, 165 | R_idxs=sublayer_info.caldera.R_idxs, 166 | Q_idxs=sublayer_info.caldera.Q_idxs, 167 | # Scaling 168 | L_scale=sublayer_info.caldera.L_scale, 169 | R_scale=sublayer_info.caldera.R_scale, 170 | Q_scale=sublayer_info.caldera.Q_scale, 171 | global_scale=sublayer_info.caldera.global_scale, 172 | scaleWH=sublayer_info.caldera.scaleWH, 173 | # Hadamard 174 | hadamard=(self.quant_params.hadamard_transform 175 | or self.quant_params.full_quip_sharp), 176 | # SU and SV 177 | SU=sublayer_info.caldera.SU, 178 | SV=sublayer_info.caldera.SV, 179 | # Rank and fine-tuning 180 | rank=max(self.quant_params.rank, 181 | self.quant_params.quip_args.lora_rank), 182 | ft_rank=ft_rank, 183 | grad_ckpt=grad_ckpt 184 | ) 185 | 186 | def plot_errors(self, sublayer, plot_first_iter=True, savefile=None): 187 | """ 188 | Plot the per-iteration approximation errors for a given sublayer 189 | (i.e., a member of the TransformerSubLayers enum). 
190 | """ 191 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer) 192 | self._plot(sublayer_info.caldera.errors, plot_first_iter, savefile=savefile) 193 | 194 | def _plot(self, errors, plot_first_iter, savefile=None): 195 | COLORS = ['b', 'r', 'm'] 196 | plt.figure(figsize=(12, 4)) 197 | title = f"Activation-Aware Error per iteration: {self.label}" 198 | plt.title(title) 199 | for i, key in enumerate(errors.keys()): 200 | if plot_first_iter or len(errors[key] )== 1: 201 | plt.plot(range(len(errors[key])), errors[key], 202 | marker='o', linestyle='-', color=COLORS[i], label=key) 203 | else: 204 | plt.plot(range(1, len(errors[key])), errors[key][1:], 205 | marker='o', linestyle='-', color=COLORS[i], label=key) 206 | plt.yscale('log') 207 | plt.legend() 208 | plt.grid(True) 209 | if savefile is not None: 210 | plt.savefig(savefile) 211 | plt.close() 212 | else: 213 | plt.show() 214 | 215 | def plot_errors_json(self, file, plot_first_iter=True, savefile=None): 216 | with open(file, "r") as infile: 217 | json_str = infile.read() 218 | data = json.loads(json_str) 219 | self._plot(data["per_iter_errors"], plot_first_iter, savefile=savefile) 220 | 221 | def min_error(self, sublayer): 222 | """ 223 | Returns the minimum value of the activation-aware loss for a given 224 | sublayer (i.e., a member of the TransformerSubLayers enum). 225 | """ 226 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer) 227 | min_error = float('inf') 228 | for key in sublayer_info.caldera.errors: 229 | min_error = min(min_error, min(sublayer_info.caldera.errors[key])) 230 | return min_error 231 | 232 | def export_errors_json(self, sublayer, savefile): 233 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer) 234 | param_dict = asdict(self.quant_params) 235 | del param_dict["quant_factory_Q"] 236 | del param_dict["quant_factory_LR"] 237 | 238 | now = datetime.datetime.now() 239 | 240 | data = { 241 | "per_iter_errors": sublayer_info.caldera.errors, 242 | "layer_idx": self.layer_idx, 243 | "sublayer": sublayer_info.key, 244 | "datetime": str(now), 245 | "timestamp": datetime.datetime.timestamp(now), 246 | "params": param_dict 247 | } 248 | json_object = json.dumps(data) 249 | with open(savefile + ".json", "w") as out: 250 | out.write(json_object) 251 | 252 | def clean_up_sublayer(self, sublayer): 253 | """ 254 | Delete Q, L, and R matrices for a sublayer to free GPU memory. 255 | """ 256 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer) 257 | 258 | self.sublayer_info[sublayer] = SubLayerInfo( 259 | sublayer=sublayer_info.sublayer, key=sublayer_info.key 260 | ) 261 | gc.collect() 262 | torch.cuda.empty_cache() 263 | 264 | def _set_sublayer_weights_and_info(self): 265 | """ 266 | Initializes a SubLayerInfo object for each of the seven transformer 267 | sublayers. Called upon instantiation. 
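Note: the `key` field of each SubLayerInfo ("qkv", "o", "up", or "down") matches the filename suffix used when the Hessians were saved, so `_get_H` loads `{layer_idx}_{key}.pt`; the query, key, and value projections therefore share a single Hessian, as do the up and gate projections.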
268 | """ 269 | self.sublayer_info = { 270 | TransformerSubLayers.KEY: SubLayerInfo( 271 | sublayer=self.layer.self_attn.k_proj, key="qkv", 272 | out_key="self_attn.k_proj"), 273 | TransformerSubLayers.QUERY: SubLayerInfo( 274 | sublayer=self.layer.self_attn.q_proj, key="qkv", 275 | out_key="self_attn.q_proj"), 276 | TransformerSubLayers.VALUE: SubLayerInfo( 277 | sublayer=self.layer.self_attn.v_proj, key="qkv", 278 | out_key="self_attn.v_proj"), 279 | TransformerSubLayers.O: SubLayerInfo( 280 | sublayer=self.layer.self_attn.o_proj, key="o", 281 | out_key="self_attn.o_proj"), 282 | TransformerSubLayers.UP: SubLayerInfo( 283 | sublayer=self.layer.mlp.up_proj, key="up", 284 | out_key="mlp.up_proj"), 285 | TransformerSubLayers.GATE: SubLayerInfo( 286 | sublayer=self.layer.mlp.gate_proj, key="up", 287 | out_key="mlp.gate_proj"), 288 | TransformerSubLayers.DOWN: SubLayerInfo( 289 | sublayer=self.layer.mlp.down_proj, key="down", 290 | out_key="mlp.down_proj") 291 | } 292 | 293 | def _get_H(self, sublayer): 294 | """ 295 | Reads the Hessian (sum X_i X_i^T) for a specific sublayer from the 296 | corresponding file (in which it was saved by 297 | ActivationAwareWeightCompressor). 298 | """ 299 | sublayer_key = self.sublayer_info[sublayer].key 300 | H_data = torch.load( 301 | f'{self.hessian_save_path}/{self.layer_idx}_{sublayer_key}.pt', 302 | map_location=torch.device(self.device), 303 | ) 304 | H = flat_to_sym(H_data['flatH'], H_data['n']) 305 | 306 | # Add back in the mean 307 | mu = H_data['mu'] 308 | H.add_(mu[None, :] * mu[:, None]) 309 | H = regularize_H(H, H_data['n'], self.quant_params.quip_args.sigma_reg) 310 | 311 | return H 312 | 313 | def _get_sublayer_info_and_check_sublayer(self, sublayer): 314 | assert sublayer in self.sublayer_info.keys(), \ 315 | ("Invalid sublayer! Please use a member of the " 316 | "TransformerSubLayers enum.") 317 | 318 | sublayer_info = self.sublayer_info[sublayer] 319 | assert sublayer_info.started_quant, \ 320 | "Sublayer has't been quantized yet!" 
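# (started_quant is set to True at the start of compress_sublayer, so min_error,
#  plot_errors, export_errors_json, and get_quantized_linear_layer can only be
#  called on sublayers that have actually been decomposed.)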
321 | return sublayer_info 322 | -------------------------------------------------------------------------------- /src/caldera/decomposition/quantized_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from lib import codebook 5 | from lib.utils import get_hadK 6 | 7 | import quiptools_cuda 8 | from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda 9 | 10 | 11 | # Adapted from https://github.com/Cornell-RelaxML/quip-sharp 12 | 13 | 14 | class LatticeQuantizedParameter(nn.Module): 15 | def __init__( 16 | self, 17 | in_features, 18 | out_features, 19 | idxs, 20 | scale, 21 | codebook_version, 22 | transposed=False 23 | ): 24 | super(LatticeQuantizedParameter, self).__init__() 25 | 26 | self.in_features = in_features 27 | self.out_features = out_features 28 | self.transposed = transposed 29 | self.scale = scale 30 | 31 | self.codebook_version = codebook_version 32 | self.codebook = codebook.codebook_id[codebook_version][1](inference=True).to(torch.float16).to(idxs.device) 33 | 34 | idxs_dev = idxs.device 35 | self.idxs = idxs.cpu() 36 | codebook_class = codebook.get_quantized_class( 37 | codebook.get_id(codebook_version) 38 | )(self.idxs.device) 39 | 40 | split_idxs = codebook_class.maybe_unpack_idxs( 41 | self.idxs 42 | ) 43 | self.idxs_list = [] 44 | for i in range(len(split_idxs)): 45 | self.register_buffer(f'idxs_{i}', split_idxs[i].to(idxs_dev)) 46 | exec(f'self.idxs_list.append(self.idxs_{i})') 47 | 48 | self.idxs = None 49 | 50 | def get_W_decompressed(self): 51 | n = self.in_features 52 | m = self.out_features 53 | if self.codebook_version == 'E8P12': 54 | return quiptools_cuda.decompress_packed_e8p( 55 | self.idxs_list[0].view(m // 16, n // 64, 8, 4), 56 | self.codebook.grid_packed_abs) * self.scale 57 | elif self.codebook_version == 'E8P12RVQ4B': 58 | resid_scale = self.codebook.opt_resid_scale 59 | return (quiptools_cuda.decompress_packed_e8p( 60 | self.idxs_list[0].view(m // 16, n // 64, 8, 4), 61 | self.codebook.grid_packed_abs) + \ 62 | quiptools_cuda.decompress_packed_e8p( 63 | self.idxs_list[1].view(m // 16, n // 64, 8, 4), 64 | self.codebook.grid_packed_abs 65 | ) / resid_scale) * self.scale 66 | 67 | else: 68 | W_decompressed = quiptools_cuda.decompress_packed_e8p( 69 | self.idxs_list[0].view(m // 16, n // 64, 8, 4), 70 | self.codebook.grid_packed_abs) 71 | 72 | W_resid_decompressed = torch.zeros( 73 | self.idxs_list[1].shape[0], 74 | 64 * self.idxs_list[1].shape[-1], 75 | device=self.idxs_list[1].device, dtype=torch.float16 76 | ) 77 | return (W_decompressed + W_resid_decompressed / resid_scale) * self.scale 78 | 79 | def forward(self, x, float_precision=False): 80 | dtype = x.dtype 81 | n = self.in_features 82 | m = self.out_features 83 | 84 | if self.idxs_list[0].device != x.device: 85 | for i in range(len(self.idxs_list)): 86 | self.idxs_list[i] = self.idxs_list[0].to(x.device) 87 | self.codebook = self.codebook.to(x.device) 88 | x = x / 32 89 | if not float_precision: 90 | x = x.half() 91 | else: 92 | x = x.float() 93 | 94 | if self.codebook_version == 'E8P12': 95 | if x.size(0) == 1 and not self.transposed and not float_precision: 96 | x = quiptools_cuda.decode_matvec_e8p( 97 | x[0].to(torch.float16), 98 | self.idxs_list[0].view(m // 16, n // 64, 8, 4), 99 | self.codebook.grid_packed_abs).unsqueeze(0) 100 | else: 101 | W_decompressed = quiptools_cuda.decompress_packed_e8p( 102 | self.idxs_list[0].view(m // 16, n // 64, 8, 4), 103 | self.codebook.grid_packed_abs) 104 
| if float_precision: 105 | W_decompressed = W_decompressed.float() 106 | if self.transposed: 107 | x = (x @ W_decompressed) 108 | else: 109 | x = (x @ W_decompressed.T) 110 | 111 | elif self.codebook_version == 'E8P12RVQ3B': 112 | resid_scale = self.codebook.opt_resid_scale 113 | x16 = x.to(torch.float16) 114 | if x.shape[0] == 1 and not self.transposed and not float_precision: 115 | x_padded = torch.zeros( 116 | 8, x16.shape[1], dtype=torch.float16, device=x16.device) 117 | x_padded[0] = x16[0] 118 | z = torch.zeros( 119 | 8, m, dtype=torch.float16, device=x_padded.device) 120 | quiptools_cuda.lookupmatmul_e81b_k8( 121 | x_padded / resid_scale, self.idxs_list[1], 122 | self.codebook.e81b_grid, z 123 | ) 124 | 125 | x = quiptools_cuda.decode_matvec_e8p( 126 | x16[0], self.idxs_list[0].view(m // 16, n // 64, 8, 4), 127 | self.codebook.grid_packed_abs) + z[0] 128 | x = x.unsqueeze(0) 129 | 130 | else: 131 | W_decompressed = quiptools_cuda.decompress_packed_e8p( 132 | self.idxs_list[0].view(m // 16, n // 64, 8, 4), 133 | self.codebook.grid_packed_abs) 134 | 135 | W_resid_decompressed = torch.zeros( 136 | self.idxs_list[1].shape[0], 137 | 64 * self.idxs_list[1].shape[-1], 138 | device=self.idxs_list[1].device, dtype=torch.float16 139 | ) 140 | 141 | quiptools_cuda.decompress_e81b_packed( 142 | self.idxs_list[1], self.codebook.e81b_grid, 143 | W_resid_decompressed 144 | ) 145 | 146 | if float_precision: 147 | W_decompressed = W_decompressed.to(torch.bfloat16) 148 | W_resid_decompressed = W_resid_decompressed.to(torch.bfloat16) 149 | 150 | if self.transposed: 151 | x = (x @ (W_decompressed + 152 | W_resid_decompressed / resid_scale)) 153 | else: 154 | x = (x @ (W_decompressed + 155 | W_resid_decompressed / resid_scale).T) 156 | else: 157 | resid_scale = self.codebook.opt_resid_scale 158 | if x.size(0) == 1 and not self.transposed and not float_precision: 159 | x16 = x[0].to(torch.float16) 160 | x = (quiptools_cuda.decode_matvec_e8p( 161 | x16, self.idxs_list[0].view(m // 16, n // 64, 8, 4), 162 | self.codebook.grid_packed_abs) + 163 | quiptools_cuda.decode_matvec_e8p( 164 | x16 / resid_scale, self.idxs_list[1].view( 165 | m // 16, n // 64, 8, 4), 166 | self.codebook.grid_packed_abs)).unsqueeze(0) 167 | else: 168 | W_decompressed = quiptools_cuda.decompress_packed_e8p( 169 | self.idxs_list[0].view(m // 16, n // 64, 8, 4), 170 | self.codebook.grid_packed_abs) + \ 171 | quiptools_cuda.decompress_packed_e8p( 172 | self.idxs_list[1].view(m // 16, n // 64, 8, 4), 173 | self.codebook.grid_packed_abs 174 | ) / resid_scale 175 | if float_precision: 176 | W_decompressed = W_decompressed.float() 177 | if self.transposed: 178 | x = (x @ W_decompressed) 179 | else: 180 | x = (x @ W_decompressed.T) 181 | x = x.to(dtype) 182 | x *= self.scale * 32 183 | return x 184 | 185 | 186 | class CalderaQuantizedLinear(nn.Module): 187 | 188 | def __init__( 189 | self, 190 | in_features, 191 | out_features, 192 | Q_codebook_version, 193 | L_codebook_version, 194 | R_codebook_version, 195 | L, R, 196 | L_idxs, R_idxs, Q_idxs, 197 | L_scale, R_scale, Q_scale, 198 | global_scale, 199 | scaleWH, 200 | hadamard, 201 | SU, SV, 202 | rank=64, 203 | ft_rank=64, 204 | grad_ckpt=True 205 | ): 206 | super(CalderaQuantizedLinear, self).__init__() 207 | 208 | self.rank = rank 209 | self.ft_rank = ft_rank 210 | 211 | self.in_features = in_features 212 | self.out_features = out_features 213 | 214 | self.hadamard = hadamard 215 | self.global_scale = global_scale 216 | self.scaleWH = scaleWH 217 | if self.scaleWH is not None: 218 | 
self.scaleWH = nn.Parameter(self.scaleWH, requires_grad=False) 219 | 220 | if Q_idxs is not None: 221 | self.Q = LatticeQuantizedParameter( 222 | in_features=in_features, 223 | out_features=out_features, 224 | idxs=Q_idxs, 225 | scale=Q_scale, 226 | codebook_version=Q_codebook_version 227 | ) 228 | else: 229 | self.Q = None 230 | 231 | self.L_idxs = L_idxs 232 | self.R_idxs = R_idxs 233 | self.L_codebook_version = L_codebook_version 234 | self.R_codebook_version = R_codebook_version 235 | self.L_scale = L_scale 236 | self.R_scale = R_scale 237 | 238 | self.split_L_and_R_for_LoRA(ft_rank, L, R) 239 | 240 | self.SU = nn.Parameter(SU, requires_grad=True) 241 | self.SV = nn.Parameter(SV, requires_grad=True) 242 | 243 | had_left, K_left = get_hadK(in_features) 244 | had_right, K_right = get_hadK(out_features) 245 | 246 | self.had_left = nn.Parameter(had_left, requires_grad=False) 247 | self.had_right = nn.Parameter(had_right, requires_grad=False) 248 | 249 | self.K_left = K_left 250 | self.K_right = K_right 251 | 252 | self.grad_ckpt = grad_ckpt 253 | 254 | def split_L_and_R_for_LoRA(self, ft_rank, L, R): 255 | if ft_rank > 0: 256 | self.L_ft = nn.Parameter( 257 | L[:, :ft_rank], requires_grad=True) 258 | self.R_ft = nn.Parameter( 259 | R[:ft_rank, :], requires_grad=True) 260 | assert self.L_ft != [] and self.R_ft != [] 261 | 262 | if self.rank > ft_rank: 263 | if self.L_codebook_version is not None: 264 | self.L = LatticeQuantizedParameter( 265 | in_features=self.out_features, 266 | out_features=self.rank - ft_rank, 267 | idxs=self.L_idxs[ft_rank:, :], 268 | scale=self.L_scale, 269 | codebook_version=self.L_codebook_version, 270 | transposed=True 271 | ) 272 | self.L_idxs = None 273 | self.quant_L = True 274 | else: 275 | self.L = nn.Parameter(L[:, ft_rank:], requires_grad=False) 276 | self.quant_L = False 277 | 278 | if self.R_codebook_version is not None: 279 | self.R = LatticeQuantizedParameter( 280 | in_features=self.in_features, 281 | out_features=self.rank - ft_rank, 282 | idxs=self.R_idxs[ft_rank:, :], 283 | scale=self.R_scale, 284 | codebook_version=self.R_codebook_version 285 | ) 286 | self.R_idxs = None 287 | self.quant_R = True 288 | else: 289 | self.R = nn.Parameter(R[ft_rank:, :], requires_grad=False) 290 | self.quant_R = False 291 | else: 292 | self.L = None 293 | self.R = None 294 | 295 | 296 | def forward(self, x): 297 | old_dtype = x.dtype 298 | x = x.float() 299 | shape = x.shape 300 | n, m = len(self.SU), len(self.SV) 301 | x = x.view(-1, n) 302 | # Preprocessing 303 | if self.scaleWH is not None: 304 | x /= self.scaleWH 305 | x = x * self.SU 306 | x = matmul_hadUt_cuda(x, self.had_left, self.K_left) 307 | 308 | # Apply Q 309 | output_no_ft = self.Q.forward(x) 310 | 311 | # Apply quantized L and R 312 | if self.L is not None: 313 | if self.quant_R: 314 | xR = self.R.forward(x, float_precision=True) 315 | else: 316 | xR = (x.float() @ self.R.T.float()) 317 | 318 | if self.quant_L: 319 | output_no_ft += self.L.forward(xR, float_precision=True) 320 | else: 321 | output_no_ft += xR.float() @ self.L.T.float() 322 | 323 | # Apply LoRA factors 324 | if self.ft_rank > 0: 325 | output = output_no_ft + x @ self.R_ft.T.float() @ self.L_ft.T.float() 326 | else: 327 | output = output_no_ft 328 | 329 | output = matmul_hadU_cuda(output, self.had_right, self.K_right) 330 | 331 | output = output * self.SV * self.global_scale 332 | if self.scaleWH is not None: 333 | output *= self.scaleWH 334 | return output.view(*shape[:-1], m).to(old_dtype) 335 | 336 | def compare_outputs(self, input, 
W_hat): 337 | output = self.no_ckpt_forward(input) 338 | comparison = input @ W_hat.T 339 | return (torch.linalg.matrix_norm(output - comparison, ord='fro') / 340 | torch.linalg.matrix_norm(comparison, ord='fro')).mean().item() -------------------------------------------------------------------------------- /src/caldera/decomposition/alg.py: -------------------------------------------------------------------------------- 1 | from lib.utils.math_utils import block_LDL 2 | import lib.algo.quip as quip 3 | from lib import codebook 4 | 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from caldera.decomposition.dataclasses import * 9 | from caldera.utils.quantization import QuantizerFactory 10 | 11 | from collections import namedtuple 12 | from copy import deepcopy 13 | 14 | 15 | # Maps number of bits to name of the QuIP# lattice quantizer 16 | BITS_TO_CODEBOOK = { 17 | 2: 'E8P12', 18 | 3: 'E8P12RVQ3B', 19 | 4: 'E8P12RVQ4B' 20 | } 21 | 22 | def caldera( 23 | quant_params: CalderaParams, 24 | W: torch.Tensor, 25 | H: torch.Tensor = None, 26 | device: str = "cuda", 27 | use_tqdm: bool = True, 28 | scale_W: bool = True, 29 | ): 30 | """ 31 | Runs the CALDERA algorithm, to decompose a weight matrix into Q + LR, where 32 | Q is full-rank, L and R are low-rank factors, and all matrices are in a low- 33 | precision format. 34 | """ 35 | # scaling 36 | if scale_W: 37 | global_scale = W.square().mean().sqrt().item() 38 | else: 39 | global_scale = 1 40 | W = W / global_scale 41 | 42 | if H is None: 43 | H = torch.eye(W.shape[1]).to(device) 44 | 45 | # Compute the symmetric square root of H, because the data-aware 46 | # objective can be formulated as 47 | # min_{L, R} ||(W - LR - Q)H^{1/2}||_F^2. 48 | EigTuple = namedtuple("EigTuple", ["eigenvalues", "eigenvectors"]) 49 | if not quant_params.activation_aware_LR and not quant_params.activation_aware_Q \ 50 | and not quant_params.full_quip_sharp: 51 | H_sqrt = H 52 | eigH = EigTuple(torch.ones(W.shape[1]).to(device), H) 53 | else: 54 | eigH = torch.linalg.eigh(H) 55 | 56 | eigvals = eigH.eigenvalues 57 | # if eigvals.min() < quant_params.quip_args.sigma_reg: 58 | # H = H + (quant_params.quip_args.sigma_reg - eigvals.min()) * \ 59 | # torch.eye(H.shape[0], device=H.device, dtype=H.dtype) 60 | # eigvals += quant_params.quip_args.sigma_reg - eigvals.min() 61 | # eigH = EigTuple(eigvals, eigH.eigenvectors) 62 | 63 | H_sqrt = (eigH.eigenvectors @ 64 | torch.diag(torch.sqrt(eigvals)) @ 65 | eigH.eigenvectors.T) 66 | 67 | # Initialization and Hadamard transform 68 | best_decomp = CalderaDecomposition( 69 | Q=torch.zeros_like(W).float(), 70 | L=torch.zeros(W.shape[0], quant_params.rank).to(device), 71 | R=torch.zeros(quant_params.rank, W.shape[1]).to(device)) 72 | 73 | if quant_params.hadamard_transform: 74 | _, H, W, SU, SV, scaleWH = quip.incoherence_preprocess( 75 | H, W, quant_params.quip_args 76 | ) 77 | best_decomp.SU = SU.to(W.device) 78 | best_decomp.SV = SV.to(W.device) 79 | best_decomp.scaleWH = scaleWH 80 | 81 | eigH = torch.linalg.eigh(H) 82 | H_sqrt = (eigH.eigenvectors @ 83 | torch.diag(torch.sqrt(eigH.eigenvalues)) @ 84 | eigH.eigenvectors.T) 85 | else: 86 | best_decomp.scaleWH = None 87 | best_decomp.SU = torch.ones(W.shape[1]).to(W.dtype).to(W.device) 88 | best_decomp.SV = torch.ones(W.shape[0]).to(W.dtype).to(W.device) 89 | 90 | best_decomp.W = W.cpu() 91 | errors = {} 92 | for mtx in quant_params.update_order: 93 | errors[mtx] = [] 94 | 95 | min_error = float('inf') 96 | curr_decomp = deepcopy(best_decomp) 97 | 98 | updated = {mtx: False for 
mtx in quant_params.update_order} 99 | 100 | to_iter = range(quant_params.iters) 101 | if use_tqdm: 102 | to_iter = tqdm(to_iter) 103 | for _ in to_iter: 104 | for mtx in quant_params.update_order: 105 | if mtx == "LR": 106 | maybe_update_LR(curr_decomp, quant_params, W, H_sqrt, eigH, device) 107 | elif mtx == "Q": 108 | maybe_update_Q(curr_decomp, quant_params, W, H, device) 109 | updated[mtx] = True 110 | 111 | errors[mtx].append( 112 | activation_aware_error(W, H, curr_decomp, device) 113 | ) 114 | if errors[mtx][-1] < min_error and all(updated.values()): 115 | min_error = errors[mtx][-1] 116 | best_decomp = deepcopy(curr_decomp) 117 | best_decomp.errors = errors 118 | 119 | # Update scales 120 | best_decomp.global_scale = global_scale 121 | return best_decomp 122 | 123 | 124 | def activation_aware_error( 125 | W: torch.Tensor, 126 | H: torch.Tensor, 127 | caldera_info: CalderaDecomposition, 128 | device: str 129 | ): 130 | """ 131 | Computes the activation-aware loss for a sublayer as 132 | tr((W - W_hat) H (W - W_hat).T) / tr(W H^1/2), 133 | where H^1/2 is the symmetric square root. 134 | """ 135 | 136 | W = W.to(device).float() 137 | W_hat = caldera_info.Q + caldera_info.L @ caldera_info.R 138 | W_hat *= caldera_info.global_scale 139 | 140 | error = (torch.trace((W_hat - W) @ H @ (W_hat - W).T) / 141 | torch.trace(W @ H @ W.T)).sqrt().item() 142 | return error 143 | 144 | 145 | def get_quant_info( 146 | use_lattice_quant: bool, 147 | quant_factory: QuantizerFactory, 148 | bits: int, 149 | device: str 150 | ): 151 | cb = None 152 | quantizer = None 153 | if use_lattice_quant: 154 | cb = codebook.get_codebook(BITS_TO_CODEBOOK[bits]).to(device) 155 | else: 156 | quantizer = quant_factory.get_quantizer(bits, device) 157 | return QuantInfo( 158 | lattice_quant=use_lattice_quant, 159 | lattice_cb=cb, 160 | quant=quantizer 161 | ) 162 | 163 | 164 | def quantize_matrix( 165 | A, quant_params, 166 | quant_info: QuantInfo = None 167 | ): 168 | QuantReturn = namedtuple( 169 | 'QuantReturn', ['A_hat', 'A_idxs', 'scale'] 170 | ) 171 | if not quant_info.lattice_quant: 172 | quant_info.quant.block_size = A.shape[0] * A.shape[1] 173 | A_idxs, scales, shape = quant_info.quant.quantize_block(A) 174 | A_hat = quant_info.quant.dequantize_block(A_idxs, scales, shape) 175 | return QuantReturn(A_hat, A_idxs, scales) 176 | 177 | # Scale before quantization, as in QuIP# 178 | scale = A.square().mean().sqrt().item() 179 | 180 | m, n = A.shape 181 | 182 | A = A.reshape(-1, quant_info.lattice_cb.codesz).clone() / scale 183 | A_idxs = torch.zeros( 184 | m * n // quant_info.lattice_cb.codesz, 185 | dtype=quant_info.lattice_cb.idx_dtype, 186 | device=A.device 187 | ) 188 | K = quant_params.lattice_quant_block_size 189 | for i in range(0, A.shape[0], K): 190 | A[i:i+K], A_idxs[i:i+K] = \ 191 | quant_info.lattice_cb.quantize(A.float()[i:i+K]) 192 | A = A.reshape(m, n) 193 | A_idxs = A_idxs.reshape(m, n // quant_info.lattice_cb.codesz) 194 | 195 | A = A * scale 196 | 197 | A_idxs = quant_info.lattice_cb.maybe_pack_idxs(A_idxs) 198 | return QuantReturn(A, A_idxs, scale) 199 | 200 | 201 | def maybe_update_Q( 202 | caldera_info: CalderaDecomposition, 203 | quant_params: CalderaParams, 204 | W: torch.Tensor, 205 | H: torch.Tensor, 206 | device: str 207 | ): 208 | 209 | if quant_params.compute_quantized_component: 210 | residual = W - caldera_info.L @ caldera_info.R 211 | if not quant_params.compute_low_rank_factors: 212 | residual = W 213 | if quant_params.activation_aware_Q: 214 | update_Q_data_aware(caldera_info, 
quant_params, H, residual, device) 215 | else: 216 | update_Q_non_data_aware(caldera_info, quant_params, residual, device) 217 | 218 | 219 | def update_Q_non_data_aware( 220 | caldera_info: CalderaDecomposition, 221 | quant_params: CalderaParams, 222 | residual: torch.Tensor, 223 | device: str 224 | ): 225 | quant_info = get_quant_info( 226 | use_lattice_quant=quant_params.lattice_quant_Q, 227 | quant_factory=quant_params.quant_factory_Q, 228 | bits=quant_params.Q_bits, 229 | device=device 230 | ) 231 | 232 | quant_return = quantize_matrix(residual, quant_params, quant_info) 233 | caldera_info.Q = quant_return.A_hat 234 | caldera_info.Q_idxs = quant_return.A_idxs 235 | caldera_info.Q_scale = quant_return.scale 236 | 237 | 238 | def update_Q_data_aware( 239 | caldera_info: CalderaDecomposition, 240 | quant_params: CalderaParams, 241 | H: torch.Tensor, 242 | residual: torch.Tensor, 243 | device: str 244 | ): 245 | """ 246 | Performs an LDLQ update on the residual (W - LR) 247 | """ 248 | 249 | # Scale the residual, as done in the quantize_linear function of QuIP# 250 | scale = residual.square().mean().sqrt().item() 251 | residual /= scale 252 | 253 | codebook_str = BITS_TO_CODEBOOK[quant_params.Q_bits] 254 | cb = codebook.get_codebook(codebook_str).to(residual.device) 255 | 256 | if quant_params.compute_low_rank_factors and \ 257 | quant_params.Q_hessian_downdate: 258 | 259 | M = torch.linalg.cholesky(H) 260 | if quant_params.rand_svd: 261 | _, _, V = torch.svd_lowrank( 262 | caldera_info.L @ caldera_info.R @ M, 263 | quant_params.rank * 3, niter=10) 264 | V = V[:, :quant_params.rank] 265 | else: 266 | _, _, Vh = torch.linalg.svd( 267 | caldera_info.L @ caldera_info.R @ M, full_matrices=False) 268 | V = Vh.T[:, :quant_params.rank] 269 | 270 | H = H - (M @ V @ V.T @ M.T).to(H.dtype) 271 | min_eigval = torch.linalg.eigh(H).eigenvalues.min() 272 | H = H + (quant_params.quip_args.sigma_reg2 + max(-min_eigval, 0)) * torch.eye(H.shape[0], device=H.device, dtype=H.dtype) 273 | alpha = torch.diag(H).mean().abs() * quant_params.quip_args.sigma_reg2 274 | H = H + alpha * torch.eye(H.shape[0], device=H.device, dtype=H.dtype) 275 | 276 | if quant_params.full_quip_sharp: 277 | assert not quant_params.hadamard_transform, \ 278 | ("Full QuIP# incompatible with performing Hadamard transform " 279 | "on our end") 280 | assert quant_params.rank == 0 or \ 281 | not quant_params.compute_low_rank_factors, \ 282 | ("Full QuIP# incompatible with separately computing low-rank " 283 | "factors.") 284 | 285 | caldera_info.Q, attr = quip.quantize( 286 | H_orig=H, 287 | W_orig=residual, 288 | rank=0, 289 | codebook_orig=cb, 290 | args=quant_params.quip_args, 291 | device=device 292 | ) 293 | caldera_info.Q_idxs = attr['Qidxs'].to(device) 294 | 295 | caldera_info.scaleWH = attr['scaleWH'] 296 | caldera_info.SU = attr['SU'] 297 | caldera_info.SV = attr['SV'] 298 | if quant_params.quip_args.lora_rank != 0: 299 | caldera_info.L = attr['A'].to(device) / caldera_info.SV[0].abs().sqrt() 300 | caldera_info.R = attr['B'].to(device) / caldera_info.SV[0].abs().sqrt() 301 | caldera_info.L_scale = scale 302 | caldera_info.R_scale = scale 303 | caldera_info.Q -= caldera_info.L @ caldera_info.R 304 | 305 | else: 306 | # Just do LDLQ 307 | block_LDL_out = block_LDL(H, cb.codesz) 308 | assert block_LDL_out is not None 309 | 310 | L, D = block_LDL_out 311 | del block_LDL_out 312 | 313 | scale /= cb.opt_scale 314 | residual *= cb.opt_scale 315 | 316 | if quant_params.quip_args.no_use_buffered: 317 | Q, Qidxs = quip.LDLQ( 318 | residual, 
H, L, D, cb, quant_params.quip_args) 319 | elif quant_params.quip_args.lowmem_ldlq or \ 320 | quant_params.quip_args.use_fp64: 321 | Q, Qidxs = quip.LDLQ_buffered_lowmem( 322 | residual, H, L, D, cb, quant_params.quip_args, 323 | buf_cols=128) 324 | else: 325 | Q, Qidxs = quip.LDLQ_buffered( 326 | residual, H, L, D, cb, quant_params.quip_args, 327 | buf_cols=128) 328 | caldera_info.Q_idxs = Qidxs 329 | caldera_info.Q = Q 330 | caldera_info.Q_idxs = cb.maybe_pack_idxs(caldera_info.Q_idxs) 331 | 332 | caldera_info.Q_scale = scale 333 | caldera_info.Q *= scale 334 | 335 | 336 | def LR_init( 337 | caldera_info: CalderaDecomposition, 338 | quant_params: CalderaParams, 339 | H_sqrt: torch.Tensor, 340 | eigH: torch.Tensor, 341 | residual: torch.Tensor 342 | ): 343 | """ 344 | Runs rank-constrained regression to minimize 345 | ||(residual - LR) eigH||_F^2 346 | over L, R in closed-form. 347 | """ 348 | if quant_params.activation_aware_LR: 349 | Y = residual @ H_sqrt @ eigH.eigenvectors 350 | if quant_params.rand_svd: 351 | q = min(quant_params.rank*2, min(*caldera_info.W.shape)) 352 | U, Sigma, V = torch.svd_lowrank(Y, q) 353 | Vh = V.T 354 | else: 355 | U, Sigma, Vh = torch.linalg.svd(Y, full_matrices=False) 356 | 357 | L = U[:, :quant_params.rank] 358 | R = torch.diag(Sigma[:quant_params.rank]) @ \ 359 | Vh[:quant_params.rank, :] @ \ 360 | torch.diag(1 / eigH.eigenvalues.sqrt()) @ eigH.eigenvectors.T 361 | else: 362 | if quant_params.rand_svd: 363 | q = min(quant_params.rank*2, 364 | min(*caldera_info.W.shape)) 365 | U, Sigma, V = torch.svd_lowrank(residual, q) 366 | Vh = V.T 367 | else: 368 | U, Sigma, Vh = torch.linalg.svd(residual, full_matrices=False) 369 | L = U[:, :quant_params.rank] @ \ 370 | torch.diag(Sigma[:quant_params.rank].sqrt()) 371 | R = torch.diag(Sigma[:quant_params.rank].sqrt()) @ \ 372 | Vh[:quant_params.rank, :] 373 | return L, R 374 | 375 | def maybe_update_LR( 376 | caldera_info: CalderaDecomposition, 377 | quant_params: CalderaParams, 378 | W: torch.Tensor, 379 | H_sqrt: torch.Tensor, 380 | eigH, 381 | device 382 | ): 383 | if quant_params.compute_low_rank_factors: 384 | residual = W - caldera_info.Q 385 | update_LR(caldera_info, quant_params, residual, H_sqrt, eigH, device) 386 | 387 | 388 | def update_LR( 389 | caldera_info: CalderaDecomposition, 390 | quant_params: CalderaParams, 391 | residual: torch.Tensor, 392 | H_sqrt: torch.Tensor, 393 | eigH, 394 | device 395 | ): 396 | """ 397 | Run LPLR on the residual (W - Q) 398 | """ 399 | data_aware = quant_params.activation_aware_LR 400 | 401 | # Initialization of L, R 402 | L, R = LR_init(caldera_info, quant_params, H_sqrt, eigH, residual) 403 | 404 | if quant_params.L_bits < 16 or quant_params.R_bits < 16: 405 | quant_info_L = get_quant_info( 406 | use_lattice_quant=quant_params.lattice_quant_LR, 407 | quant_factory=quant_params.quant_factory_LR, 408 | bits=quant_params.L_bits, 409 | device=device 410 | ) 411 | quant_info_R = get_quant_info( 412 | use_lattice_quant=quant_params.lattice_quant_LR, 413 | quant_factory=quant_params.quant_factory_LR, 414 | bits=quant_params.R_bits, 415 | device=device 416 | ) 417 | 418 | best_L, best_R = L, R 419 | best_L_quant_out, best_R_quant_out = None, None 420 | best_error = float('inf') 421 | 422 | for _ in range(quant_params.lplr_iters): 423 | # L 424 | if data_aware: 425 | L = torch.linalg.lstsq((R @ H_sqrt).T, (residual @ H_sqrt).T)[0].T 426 | if torch.isnan(L).any(): 427 | L = (residual @ H_sqrt) @ torch.linalg.pinv(R @ H_sqrt) 428 | else: 429 | L = torch.linalg.lstsq(R.T, 
residual.T)[0].T 430 | if torch.isnan(R).any(): 431 | L = residual @ torch.linalg.pinv(R) 432 | 433 | quant_out_L = quantize_matrix(L.T, quant_params, quant_info_L) 434 | L = quant_out_L.A_hat.T 435 | 436 | # R 437 | R = torch.linalg.lstsq(L, residual)[0] 438 | if torch.isnan(R).any(): 439 | R = torch.linalg.pinv(L) @ residual 440 | 441 | quant_out_R = quantize_matrix(R, quant_params, quant_info_R) 442 | R = quant_out_R.A_hat 443 | 444 | error = torch.linalg.matrix_norm((residual - L @ R) @ H_sqrt) #/ \ 445 | # torch.linalg.matrix_norm((residual + caldera_info.Q) @ H_sqrt) 446 | if error < best_error: 447 | best_L, best_R = L, R 448 | best_L_quant_out = quant_out_L 449 | best_R_quant_out = quant_out_R 450 | best_error = error 451 | 452 | caldera_info.L_idxs = best_L_quant_out.A_idxs 453 | caldera_info.R_idxs = best_R_quant_out.A_idxs 454 | caldera_info.L_scale = best_L_quant_out.scale 455 | caldera_info.R_scale = best_R_quant_out.scale 456 | 457 | L, R = best_L, best_R 458 | 459 | caldera_info.L = L 460 | caldera_info.R = R -------------------------------------------------------------------------------- /scripts/finetune_winogrande.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py 2 | 3 | import copy 4 | import logging 5 | import os 6 | from dataclasses import dataclass, field 7 | from typing import Dict, Optional, Sequence 8 | 9 | import torch 10 | import transformers 11 | from transformers import default_data_collator, DataCollatorWithPadding, get_scheduler 12 | 13 | from datasets import load_dataset, Dataset 14 | from quantize_save_llama import load_quantized_model 15 | from torch.utils.data import DataLoader 16 | 17 | from accelerate import Accelerator 18 | from accelerate.logging import get_logger 19 | from accelerate.utils import set_seed 20 | import datasets 21 | 22 | import math 23 | import evaluate 24 | from tqdm import tqdm 25 | 26 | 27 | logger = get_logger(__name__) 28 | 29 | 30 | @dataclass 31 | class ModelArguments: 32 | base_model: str = field() 33 | model_name_or_path: str = field() 34 | token: Optional[str] = field( 35 | default=None, 36 | metadata={"help": "HF token to access to private models, e.g., meta-llama"}, 37 | ) 38 | 39 | 40 | @dataclass 41 | class TrainingArguments(transformers.TrainingArguments): 42 | max_seq_length: int = field(default=128, metadata={ 43 | "help": ("The maximum total input sequence length after tokenization. Sequences longer " 44 | "than this will be truncated, sequences shorter will be padded.") 45 | }) 46 | with_tracking: bool = field(default=False, metadata={ 47 | "help": "Whether to report eval accuracies to, e.g., tensorboard." 
48 | }) 49 | num_warmup_steps: int = field(default=0) 50 | 51 | def run_eval(eval_dataloader, model, accelerator, metric): 52 | samples_seen = 0 53 | 54 | for step, batch in enumerate(eval_dataloader): 55 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16): 56 | with torch.no_grad(): 57 | # batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()} 58 | predictions = model(**batch) 59 | 60 | references = batch["labels"] 61 | predictions, references = accelerator.gather_for_metrics((predictions, references)) 62 | 63 | # print(predictions.logits) 64 | if predictions.logits.isnan().any(): 65 | print("WARNING NaN OUTPUT LOGITS") 66 | predictions = predictions.logits.argmax(dim=-1) 67 | # If we are in a multiprocess environment, the last batch has duplicates 68 | if step == len(eval_dataloader) - 1: 69 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] 70 | references = references[: len(eval_dataloader.dataset) - samples_seen] 71 | else: 72 | samples_seen += references.shape[0] 73 | metric.add_batch( 74 | predictions=predictions, 75 | references=references, 76 | ) 77 | return metric.compute() 78 | 79 | 80 | def train(): 81 | parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments)) 82 | model_args, training_args = parser.parse_args_into_dataclasses() 83 | 84 | transformers.utils.send_example_telemetry("run_clm_no_trainer", training_args) 85 | 86 | training_args.output_dir = os.path.join( 87 | training_args.output_dir, 88 | f"ep_{int(training_args.num_train_epochs)}_lr_{training_args.learning_rate}_seed_{training_args.seed}" 89 | ) 90 | 91 | accelerator_log_kwargs = {} 92 | 93 | if training_args.with_tracking: 94 | accelerator_log_kwargs["log_with"] = training_args.report_to 95 | accelerator_log_kwargs["project_dir"] = training_args.output_dir 96 | 97 | accelerator = Accelerator( 98 | gradient_accumulation_steps=training_args.gradient_accumulation_steps, 99 | **accelerator_log_kwargs) 100 | 101 | logging.basicConfig( 102 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 103 | datefmt="%m/%d/%Y %H:%M:%S", 104 | level=logging.INFO, 105 | ) 106 | 107 | logger.info(accelerator.state, main_process_only=False) 108 | if accelerator.is_local_main_process: 109 | datasets.utils.logging.set_verbosity_warning() 110 | transformers.utils.logging.set_verbosity_info() 111 | else: 112 | datasets.utils.logging.set_verbosity_error() 113 | transformers.utils.logging.set_verbosity_error() 114 | 115 | # If passed along, set the training seed now. 116 | if accelerator.is_main_process: 117 | if training_args.seed is not None: 118 | set_seed(training_args.seed) 119 | 120 | if training_args.output_dir is not None: 121 | os.makedirs(training_args.output_dir, exist_ok=True) 122 | # writer = SummaryWriter(args.output_dir) 123 | accelerator.wait_for_everyone() 124 | 125 | model = load_quantized_model( 126 | model_args.model_name_or_path, model_args.base_model, 127 | accelerator.device, sequence_classification=True 128 | ).to(torch.bfloat16) 129 | 130 | for name, param in model.named_parameters(): 131 | if 'SU' in name or 'SV' in name: 132 | param.requires_grad = False 133 | 134 | tokenizer = transformers.AutoTokenizer.from_pretrained( 135 | model_args.base_model, 136 | token=model_args.token, 137 | use_fast=True, 138 | ) 139 | 140 | model.config.pad_token_id = 0 141 | tokenizer.pad_token_id = 0 # unk. 
we want this to be different from the eos token 142 | tokenizer.padding_side = "left" # Allow batched inference 143 | tokenizer.truncation_side = "left" 144 | 145 | embedding_size = model.get_input_embeddings().weight.shape[0] 146 | if len(tokenizer) > embedding_size: 147 | model.resize_token_embeddings(len(tokenizer)) 148 | 149 | train_data = load_dataset("winogrande", "winogrande_xl", split="train").to_pandas() 150 | eval_data = load_dataset("winogrande", "winogrande_xl", split="validation").to_pandas() 151 | 152 | def preprocess(data): 153 | data["text1"] = data.apply(lambda x: x["sentence"].replace("_", x["option1"]), axis=1) 154 | data["text2"] = data.apply(lambda x: x["sentence"].replace("_", x["option2"]), axis=1) 155 | data["label"] = data.apply(lambda x: int(x["answer"]) - 1, axis=1) 156 | return Dataset.from_pandas(data) 157 | 158 | def tokenize(sample): 159 | model_inps = tokenizer(sample["text1"], sample["text2"], padding="max_length", 160 | truncation=True, max_length=training_args.max_seq_length) 161 | model_inps["labels"] = sample["label"] 162 | return model_inps 163 | 164 | with accelerator.main_process_first(): 165 | train_data = preprocess(train_data) 166 | eval_data = preprocess(eval_data) 167 | tokenized_train_data = train_data.map(tokenize, batched=True, desc="Tokenizing training data", 168 | remove_columns=train_data.column_names) 169 | tokenized_eval_data = eval_data.map(tokenize, batched=True, desc="Tokenizing eval data", 170 | remove_columns=train_data.column_names) 171 | 172 | print(tokenized_train_data) 173 | # print(tokenized_train_data["labels"]) 174 | 175 | train_dataloader = DataLoader( 176 | tokenized_train_data, collate_fn=default_data_collator, 177 | batch_size=training_args.per_device_train_batch_size, shuffle=True 178 | ) 179 | eval_dataloader = DataLoader( 180 | tokenized_eval_data, collate_fn=default_data_collator, 181 | batch_size=training_args.per_device_eval_batch_size 182 | ) 183 | 184 | # Optimizer 185 | # Split weights in two groups, one with weight decay and the other not. 186 | no_decay = ["bias", "LayerNorm.weight"] 187 | optimizer_grouped_parameters = [ 188 | { 189 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 190 | "weight_decay": training_args.weight_decay, 191 | }, 192 | { 193 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 194 | "weight_decay": 0.0, 195 | }, 196 | ] 197 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=training_args.learning_rate) 198 | 199 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) 200 | max_train_steps = training_args.num_train_epochs * num_update_steps_per_epoch 201 | 202 | lr_scheduler = get_scheduler( 203 | name=training_args.lr_scheduler_type, 204 | optimizer=optimizer, 205 | num_warmup_steps=training_args.num_warmup_steps, 206 | num_training_steps=max_train_steps, 207 | ) 208 | 209 | # Prepare everything with our `accelerator`. 
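# accelerator.prepare wraps the model, optimizer, dataloaders, and scheduler for
# whatever distributed backend the script was launched with (e.g., DeepSpeed via
# `accelerate launch --config_file ...`) and returns them in the same order; only
# the returned objects should be used from this point on.
# A typical launch might look like the following (hypothetical paths/arguments):
#   accelerate launch --config_file shell_scripts/accelerate_config.yaml \
#       scripts/finetune_winogrande.py \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --model_name_or_path /path/to/caldera-quantized-model \
#       --output_dir /path/to/outputs --num_train_epochs 3 --learning_rate 3e-5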
210 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 211 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 212 | ) 213 | 214 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) 215 | max_train_steps = training_args.num_train_epochs * num_update_steps_per_epoch 216 | # Afterwards we recalculate our number of training epochs 217 | training_args.num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 218 | 219 | # Figure out how many steps we should save the Accelerator states 220 | save_steps = training_args.save_steps 221 | if save_steps is not None and isinstance(save_steps, str) and save_steps.isdigit(): 222 | save_steps = int(save_steps) 223 | 224 | # We need to initialize the trackers we use, and also store our configuration. 225 | # The trackers initializes automatically on the main process. 226 | if training_args.with_tracking: 227 | experiment_config = vars(training_args) 228 | # TensorBoard cannot log Enums, need the raw value 229 | # experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 230 | accelerator.init_trackers("winogrande_no_trainer", {}) 231 | 232 | metric = evaluate.load("accuracy") 233 | max_train_steps = int(max_train_steps) 234 | 235 | # Train! 236 | total_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps 237 | 238 | print("***** Running training *****") 239 | print(f" Num examples = {len(tokenized_train_data)}") 240 | print(f" Num Epochs = {training_args.num_train_epochs}") 241 | print(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") 242 | print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 243 | print(f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}") 244 | print(f" Total optimization steps = {max_train_steps}") 245 | # Only show the progress bar once on each machine. 
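# Worked example (hypothetical settings): with per_device_train_batch_size=8 and
# gradient_accumulation_steps=4, the total train batch size computed above is
# 8 * 4 = 32 per process, and max_train_steps equals num_train_epochs *
# ceil(len(train_dataloader) / 4) optimizer updates.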
246 | progress_bar = tqdm(range(max_train_steps)) 247 | completed_steps = 0 248 | starting_epoch = 0 249 | # Potentially load in the weights and states from a previous save 250 | if training_args.resume_from_checkpoint: 251 | if training_args.resume_from_checkpoint is not None or training_args.resume_from_checkpoint != "": 252 | accelerator.print(f"Resumed from checkpoint: {training_args.resume_from_checkpoint}") 253 | accelerator.load_state(training_args.resume_from_checkpoint) 254 | path = os.path.basename(training_args.resume_from_checkpoint) 255 | else: 256 | # Get the most recent checkpoint 257 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 258 | dirs.sort(key=os.path.getctime) 259 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 260 | # Extract `epoch_{i}` or `step_{i}` 261 | training_difference = os.path.splitext(path)[0] 262 | 263 | if "epoch" in training_difference: 264 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 265 | resume_step = None 266 | else: 267 | resume_step = int(training_difference.replace("step_", "")) 268 | starting_epoch = resume_step // len(train_dataloader) 269 | resume_step -= starting_epoch * len(train_dataloader) 270 | completed_steps = starting_epoch * len(train_dataloader) + resume_step 271 | progress_bar.update(completed_steps) 272 | 273 | performace_dict = {} 274 | for epoch in range(starting_epoch, training_args.num_train_epochs): 275 | model.train() 276 | if training_args.with_tracking: 277 | total_loss = 0 278 | # if training_args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: 279 | # # We skip the first `n` batches in the dataloader when resuming from a checkpoint 280 | # train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) 281 | 282 | for step, batch in enumerate(train_dataloader): 283 | # print(batch["attention_mask"]) 284 | # # We need to skip steps until we reach the resumed step 285 | if training_args.resume_from_checkpoint and epoch == starting_epoch: 286 | if resume_step is not None and step < resume_step: 287 | # completed_steps += 1 288 | continue 289 | 290 | # print(batch["labels"], batch["labels"].shape) 291 | with accelerator.accumulate(model): 292 | # print(batch) 293 | # return 294 | outputs = model(**batch) 295 | loss = outputs.loss 296 | # We keep track of the loss at each epoch 297 | if training_args.with_tracking: 298 | total_loss += loss.detach().float() 299 | accelerator.backward(loss) 300 | if completed_steps % 50: 301 | accelerator.print(f"Epoch: {epoch} | Step: {completed_steps} | Loss: {loss}") 302 | optimizer.step() 303 | lr_scheduler.step() 304 | optimizer.zero_grad() 305 | 306 | # Checks if the accelerator has performed an optimization step behind the scenes 307 | if accelerator.sync_gradients: 308 | progress_bar.update(1) 309 | completed_steps += 1 310 | 311 | if isinstance(save_steps, int): 312 | if completed_steps % save_steps == 0: 313 | output_dir = f"step_{completed_steps}" 314 | if output_dir is not None: 315 | output_dir = os.path.join(training_args.output_dir, output_dir) 316 | accelerator.save_state(output_dir) 317 | 318 | if completed_steps % 200 == 0: 319 | model.eval() 320 | eval_metric = run_eval(eval_dataloader, model, accelerator, metric) 321 | 322 | # for k, v in eval_metric.items(): 323 | # writer.add_scalar(f"eval/{args.output_dir}/{k}", v, global_step=completed_steps) 324 | logger.info( 325 | f"seed {training_args.seed} learning rate {training_args.learning_rate} " 
326 | + f"epoch {epoch}: {eval_metric}") 327 | performace_dict[completed_steps]=eval_metric["accuracy"] 328 | 329 | if training_args.with_tracking and total_loss != 0: 330 | accelerator.log( 331 | { 332 | "accuracy": eval_metric, 333 | "train_loss": total_loss.item() / len(train_dataloader), 334 | "epoch": epoch, 335 | "step": completed_steps, 336 | }, 337 | step=completed_steps, 338 | ) 339 | 340 | if completed_steps >= max_train_steps: 341 | break 342 | 343 | if completed_steps % 500 == 0 and step % training_args.gradient_accumulation_steps == 0 : 344 | logger.info(f"The current loss is {loss}") 345 | 346 | 347 | if training_args.save_steps == "epoch": 348 | output_dir = f"epoch_{epoch}" 349 | if output_dir is not None: 350 | output_dir = os.path.join(training_args.output_dir, output_dir) 351 | 352 | accelerator.save_state(output_dir) 353 | 354 | accelerator.wait_for_everyone() 355 | if accelerator.is_main_process: 356 | from safetensors.torch import save_file 357 | save_file(accelerator.get_state_dict(model), output_dir + "/model.safetensors") 358 | print("Saved checkpoint in ", output_dir + "/model.safetensors") 359 | accelerator.wait_for_everyone() 360 | 361 | model.eval() 362 | eval_metric = run_eval(eval_dataloader, model, accelerator, metric) 363 | 364 | # for k, v in eval_metric.items(): 365 | # writer.add_scalar(f"eval/{args.output_dir}/{k}", v, global_step=epoch) 366 | logger.info(f"{training_args.output_dir} | epoch {epoch}: {eval_metric}") 367 | if training_args.with_tracking and total_loss != 0: 368 | accelerator.log( 369 | { 370 | "accuracy": eval_metric, 371 | "train_loss": total_loss.item() / len(train_dataloader), 372 | "epoch": epoch, 373 | "step": completed_steps, 374 | }, 375 | step=completed_steps, 376 | ) 377 | performace_dict[epoch] = eval_metric["accuracy"] 378 | 379 | # torch.save(model.state_dict(), args.output_dir + f"/{completed_steps}.bin") 380 | 381 | model.eval() 382 | eval_metric = run_eval(eval_dataloader, model, accelerator, metric) 383 | 384 | # for k, v in eval_metric.items(): 385 | # writer.add_scalar(f"eval/{args.output_dir}/{k}", v, global_step=completed_steps) 386 | print(f"{training_args.output_dir} | step {completed_steps}: {eval_metric}") 387 | if not eval: 388 | best_performance = max(performace_dict.values()) 389 | max_keys = [k for k, v in performace_dict.items() if 390 | v == best_performance] # getting all keys containing the `maximum` 391 | print(f"seed {training_args.seed} learning rate {args.learning_rate} " 392 | + f"The best performance is at {max_keys[0]} with {best_performance}") 393 | 394 | 395 | if __name__ == "__main__": 396 | train() 397 | -------------------------------------------------------------------------------- /scripts/finetune_glue.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.f 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
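# Example invocation (hypothetical paths; all flags below are defined in parse_args):
#   python scripts/finetune_glue.py \
#       --task_name rte \
#       --model_name_or_path /path/to/caldera-quantized-model \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --learning_rate 5e-5 \
#       --num_train_epochs 3 \
#       --output_dir /path/to/outputs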
15 | """ Finetuning a 🤗 Transformers model for sequence classification on GLUE.""" 16 | import argparse 17 | import json 18 | import logging 19 | import math 20 | import os 21 | import random 22 | import numpy as np 23 | 24 | 25 | import datasets 26 | import evaluate 27 | import torch 28 | from accelerate import Accelerator 29 | from accelerate.logging import get_logger 30 | from accelerate.utils import set_seed 31 | 32 | from datasets import load_dataset 33 | from torch.utils.data import DataLoader 34 | from tqdm.auto import tqdm 35 | 36 | import transformers 37 | from transformers import ( 38 | AutoTokenizer, 39 | AutoConfig, 40 | DataCollatorWithPadding, 41 | PretrainedConfig, 42 | SchedulerType, 43 | default_data_collator, 44 | get_scheduler, 45 | ) 46 | from quantize_save_llama import load_quantized_model 47 | 48 | logger = get_logger(__name__) 49 | 50 | task_to_keys = { 51 | "cola": ("sentence", None), 52 | "mnli": ("premise", "hypothesis"), 53 | "mrpc": ("sentence1", "sentence2"), 54 | "qnli": ("question", "sentence"), 55 | "qqp": ("question1", "question2"), 56 | "rte": ("sentence1", "sentence2"), 57 | "sst2": ("sentence", None), 58 | "stsb": ("sentence1", "sentence2"), 59 | "wnli": ("sentence1", "sentence2"), 60 | } 61 | 62 | task_to_metrics = { 63 | "cola": "matthews_correlation", 64 | "mnli": "accuracy", 65 | "mrpc": "f1", 66 | "qnli": "accuracy", 67 | "qqp": "f1", 68 | "rte": "accuracy", 69 | "sst2": "accuracy", 70 | "stsb": "pearson", 71 | } 72 | 73 | device = torch.device(0) 74 | DEBUG = False 75 | 76 | def set_seed(seed: int): 77 | random.seed(seed) 78 | np.random.seed(seed) 79 | torch.manual_seed(seed) 80 | torch.cuda.manual_seed_all(seed) 81 | 82 | 83 | def parse_args(): 84 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task") 85 | parser.add_argument( 86 | "--task_name", 87 | type=str, 88 | default=None, 89 | help="The name of the glue task to train on.", 90 | choices=list(task_to_keys.keys()), 91 | ) 92 | parser.add_argument( 93 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data." 94 | ) 95 | parser.add_argument( 96 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data." 97 | ) 98 | parser.add_argument( 99 | "--max_length", 100 | type=int, 101 | default=128, 102 | help=( 103 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated," 104 | " sequences shorter will be padded if `--pad_to_max_length` is passed." 105 | ), 106 | ) 107 | parser.add_argument( 108 | "--pad_to_max_length", 109 | action="store_true", 110 | help="If passed, pad all samples to `max_length`. 
Otherwise, dynamic padding is used.", 111 | ) 112 | parser.add_argument( 113 | "--model_name_or_path", 114 | type=str, 115 | help="Path to pretrained model or model identifier from huggingface.co/models.", 116 | required=True, 117 | ) 118 | parser.add_argument( 119 | "--base_model", 120 | type=str, 121 | help="Huggingface identifier of original model", 122 | required=True 123 | ) 124 | parser.add_argument( 125 | "--tokenizer_name", 126 | type=str, 127 | default=None, 128 | help="Pretrained tokenizer name or path if not the same as model_name", 129 | ) 130 | parser.add_argument( 131 | "--use_slow_tokenizer", 132 | action="store_true", 133 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).", 134 | ) 135 | parser.add_argument( 136 | "--per_device_train_batch_size", 137 | type=int, 138 | default=8, 139 | help="Batch size (per device) for the training dataloader.", 140 | ) 141 | parser.add_argument( 142 | "--per_device_eval_batch_size", 143 | type=int, 144 | default=8, 145 | help="Batch size (per device) for the evaluation dataloader.", 146 | ) 147 | parser.add_argument( 148 | "--learning_rate", 149 | type=float, 150 | default=5e-5, 151 | help="Initial learning rate (after the potential warmup period) to use.", 152 | ) 153 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.") 154 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.") 155 | parser.add_argument( 156 | "--max_train_steps", 157 | type=int, 158 | default=None, 159 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 160 | ) 161 | parser.add_argument( 162 | "--gradient_accumulation_steps", 163 | type=int, 164 | default=1, 165 | help="Number of updates steps to accumulate before performing a backward/update pass.", 166 | ) 167 | parser.add_argument( 168 | "--lr_scheduler_type", 169 | type=SchedulerType, 170 | default="linear", 171 | help="The scheduler type to use.", 172 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"], 173 | ) 174 | parser.add_argument( 175 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 176 | ) 177 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.") 178 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 179 | parser.add_argument( 180 | "--trust_remote_code", 181 | type=bool, 182 | default=False, 183 | help=( 184 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option" 185 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will" 186 | "execute code present on the Hub on your local machine." 
187 | ), 188 | ) 189 | parser.add_argument( 190 | "--checkpointing_steps", 191 | type=str, 192 | default=None, 193 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.", 194 | ) 195 | parser.add_argument( 196 | "--resume_from_checkpoint", 197 | type=str, 198 | default=None, 199 | help="If the training should continue from a checkpoint folder.", 200 | ) 201 | parser.add_argument( 202 | "--with_tracking", 203 | action="store_true", 204 | help="Whether to enable experiment trackers for logging.", 205 | ) 206 | parser.add_argument( 207 | "--report_to", 208 | type=str, 209 | default="all", 210 | help=( 211 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 212 | ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.' 213 | "Only applicable when `--with_tracking` is passed." 214 | ), 215 | ) 216 | parser.add_argument( 217 | "--config_name", 218 | type=str, 219 | default=None, 220 | help="Pretrained config name or path if not the same as model_name", 221 | ) 222 | 223 | args = parser.parse_args() 224 | 225 | # Sanity checks 226 | if args.task_name is None and args.train_file is None and args.validation_file is None: 227 | raise ValueError("Need either a task name or a training/validation file.") 228 | else: 229 | if args.train_file is not None: 230 | extension = args.train_file.split(".")[-1] 231 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 232 | if args.validation_file is not None: 233 | extension = args.validation_file.split(".")[-1] 234 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 235 | 236 | return args 237 | 238 | 239 | def run_eval(eval_dataloader, model, accelerator, is_regression, metric): 240 | samples_seen = 0 241 | 242 | for step, batch in enumerate(eval_dataloader): 243 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16): 244 | with torch.no_grad(): 245 | predictions = model(**batch) 246 | 247 | references = batch["labels"] 248 | predictions, references = accelerator.gather_for_metrics((predictions, references)) 249 | 250 | if predictions.logits.isnan().any(): 251 | print("Warning: some of the output logits for evaluation were NaN!") 252 | 253 | predictions = predictions.logits.argmax(dim=-1) if not is_regression else predictions.logits.squeeze() 254 | # If we are in a multiprocess environment, the last batch has duplicates 255 | if step == len(eval_dataloader) - 1: 256 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] 257 | references = references[: len(eval_dataloader.dataset) - samples_seen] 258 | else: 259 | samples_seen += references.shape[0] 260 | metric.add_batch( 261 | predictions=predictions, 262 | references=references, 263 | ) 264 | return metric.compute() 265 | 266 | 267 | def main(): 268 | args = parse_args() 269 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 270 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 271 | # send_example_telemetry("run_glue_no_trainer", args) 272 | 273 | transformers.utils.send_example_telemetry("run_clm_no_trainer", args) 274 | 275 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. 
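# (Tracking is opt-in: pass --with_tracking, and optionally --report_to
#  tensorboard / wandb / comet_ml / clearml, so that the chosen tracker is handed
#  to the Accelerator via accelerator_log_kwargs below.)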
276 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers 277 | # in the environment 278 | accelerator_log_kwargs = {} 279 | 280 | if args.with_tracking: 281 | accelerator_log_kwargs["log_with"] = args.report_to 282 | accelerator_log_kwargs["project_dir"] = args.output_dir 283 | 284 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs) 285 | 286 | # Make one log on every process with the configuration for debugging. 287 | logging.basicConfig( 288 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 289 | datefmt="%m/%d/%Y %H:%M:%S", 290 | level=logging.INFO, 291 | ) 292 | 293 | logger.info(accelerator.state, main_process_only=False) 294 | if accelerator.is_local_main_process: 295 | datasets.utils.logging.set_verbosity_warning() 296 | transformers.utils.logging.set_verbosity_info() 297 | else: 298 | datasets.utils.logging.set_verbosity_error() 299 | transformers.utils.logging.set_verbosity_error() 300 | 301 | # If passed along, set the training seed now. 302 | if accelerator.is_main_process: 303 | if args.seed is not None: 304 | set_seed(args.seed) 305 | 306 | if args.output_dir is not None: 307 | os.makedirs(args.output_dir, exist_ok=True) 308 | # writer = SummaryWriter(args.output_dir) 309 | accelerator.wait_for_everyone() 310 | 311 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 312 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). 313 | 314 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the 315 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named 316 | # label if at least two columns are provided. 317 | 318 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this 319 | # single column. You can easily tweak this behavior (see below) 320 | 321 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 322 | # download the dataset. 323 | if args.task_name is not None: 324 | # Downloading and loading a dataset from the hub. 325 | raw_datasets = load_dataset("glue", args.task_name) 326 | else: 327 | # Loading the dataset from local csv or json file. 328 | data_files = {} 329 | if args.train_file is not None: 330 | data_files["train"] = args.train_file 331 | if args.validation_file is not None: 332 | data_files["validation"] = args.validation_file 333 | extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1] 334 | raw_datasets = load_dataset(extension, data_files=data_files) 335 | # See more about loading any type of standard or custom dataset at 336 | # https://huggingface.co/docs/datasets/loading_datasets.html. 337 | 338 | # Labels 339 | if args.task_name is not None: 340 | is_regression = args.task_name == "stsb" 341 | if not is_regression: 342 | label_list = raw_datasets["train"].features["label"].names 343 | num_labels = len(label_list) 344 | else: 345 | num_labels = 1 346 | else: 347 | # Trying to have good defaults here, don't hesitate to tweak to your needs. 
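        # With a custom CSV/JSON dataset, the problem type is inferred from the raw "label" column:
        # a float-typed column is treated as regression (num_labels = 1); otherwise the unique label
        # values are collected and sorted so the label-to-id mapping is deterministic across runs.
        # For example, a file whose labels are {"negative", "positive"} yields num_labels == 2.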
348 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"] 349 | if is_regression: 350 | num_labels = 1 351 | else: 352 | # A useful fast method: 353 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique 354 | label_list = raw_datasets["train"].unique("label") 355 | label_list.sort() # Let's sort it for determinism 356 | num_labels = len(label_list) 357 | 358 | # Load pretrained model and tokenizer 359 | # 360 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently 361 | # download model & vocab. 362 | if args.config_name: 363 | config = AutoConfig.from_pretrained( 364 | args.config_name, 365 | trust_remote_code=args.trust_remote_code, 366 | ) 367 | elif args.model_name_or_path: 368 | config = AutoConfig.from_pretrained( 369 | args.base_model, 370 | trust_remote_code=args.trust_remote_code, 371 | ) 372 | 373 | # Model and tokenizer 374 | model = load_quantized_model( 375 | args.model_name_or_path, args.base_model, 376 | accelerator.device, sequence_classification=True, 377 | ).to(torch.bfloat16) 378 | for name, param in model.named_parameters(): 379 | if 'SU' in name or 'SV' in name: 380 | param.requires_grad = False 381 | 382 | if args.tokenizer_name: 383 | tokenizer = AutoTokenizer.from_pretrained( 384 | args.tokenizer_name, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code 385 | ) 386 | elif args.base_model: 387 | tokenizer = AutoTokenizer.from_pretrained( 388 | args.base_model, 389 | use_fast=not args.use_slow_tokenizer, 390 | ) 391 | else: 392 | raise ValueError( 393 | "You are instantiating a new tokenizer from scratch. This is not supported by this script. " 394 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 395 | ) 396 | model.config.pad_token_id = 0 397 | tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token 398 | tokenizer.padding_side = "left" # Allow batched inference 399 | tokenizer.truncation_side = "left" 400 | 401 | embedding_size = model.get_input_embeddings().weight.shape[0] 402 | if len(tokenizer) > embedding_size: 403 | model.resize_token_embeddings(len(tokenizer)) 404 | 405 | model = model.to(accelerator.device) 406 | 407 | # Preprocessing the datasets 408 | if args.task_name is not None: 409 | sentence1_key, sentence2_key = task_to_keys[args.task_name] 410 | else: 411 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 412 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"] 413 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: 414 | sentence1_key, sentence2_key = "sentence1", "sentence2" 415 | else: 416 | if len(non_label_column_names) >= 2: 417 | sentence1_key, sentence2_key = non_label_column_names[:2] 418 | else: 419 | sentence1_key, sentence2_key = non_label_column_names[0], None 420 | 421 | # Some models have set the order of the labels to use, so let's make sure we do use it. 422 | label_to_id = None 423 | if ( 424 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id 425 | and args.task_name is not None 426 | and not is_regression 427 | ): 428 | # Some have all caps in their config, some don't.
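        # Reconcile the two by lower-casing the checkpoint's label names and comparing them to the
        # dataset's label list: if they match, reuse the model's ordering so the classifier head's
        # output indices keep their original meaning; otherwise fall back to enumerating the dataset labels.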
429 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} 430 | if sorted(label_name_to_id.keys()) == sorted(label_list): 431 | print( 432 | f"The configuration of the model provided the following label correspondence: {label_name_to_id}. " 433 | "Using it!" 434 | ) 435 | label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)} 436 | else: 437 | print( 438 | "Your model seems to have been trained with labels, but they don't match the dataset: ", 439 | f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}." 440 | "\nIgnoring the model labels as a result.", 441 | ) 442 | elif args.task_name is None and not is_regression: 443 | label_to_id = {v: i for i, v in enumerate(label_list)} 444 | 445 | if label_to_id is not None: 446 | model.config.label2id = label_to_id 447 | model.config.id2label = {id: label for label, id in model.config.label2id.items()} 448 | elif args.task_name is not None and not is_regression: 449 | model.config.label2id = {l: i for i, l in enumerate(label_list)} 450 | model.config.id2label = {id: label for label, id in model.config.label2id.items()} 451 | 452 | padding = "max_length" if args.pad_to_max_length else False 453 | 454 | def preprocess_function(examples): 455 | # Tokenize the texts 456 | texts = ( 457 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) 458 | ) 459 | result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True) 460 | 461 | if "label" in examples: 462 | if label_to_id is not None: 463 | # Map labels to IDs (not necessary for GLUE tasks) 464 | result["labels"] = [label_to_id[l] for l in examples["label"]] 465 | else: 466 | # In all cases, rename the column to labels because the model will expect that. 467 | result["labels"] = examples["label"] 468 | 469 | return result 470 | 471 | with accelerator.main_process_first(): 472 | processed_datasets = raw_datasets.map( 473 | preprocess_function, 474 | batched=True, 475 | remove_columns=raw_datasets["train"].column_names, 476 | desc="Running tokenizer on dataset", 477 | ) 478 | 479 | train_dataset = processed_datasets["train"] 480 | eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"] 481 | 482 | # DataLoaders creation: 483 | if args.pad_to_max_length: 484 | # If padding was already done to max length, we use the default data collator that will just convert everything 485 | # to tensors. 486 | data_collator = default_data_collator 487 | else: 488 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of 489 | # the samples passed). Passing `pad_to_multiple_of=8` would additionally pad all tensors to a multiple of 8, 490 | # enabling Tensor Cores on NVIDIA hardware with compute capability >= 7.5 (Volta); this script leaves it unset. 491 | data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=None) 492 | 493 | train_dataloader = DataLoader( 494 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size 495 | ) 496 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size) 497 | 498 | # Optimizer 499 | # Split weights in two groups, one with weight decay and the other not.
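    # Standard AdamW convention: biases and LayerNorm weights are excluded from weight decay, while all
    # other trainable parameters use args.weight_decay. Note that the SU/SV parameters frozen above still
    # appear in these groups, but since they never receive gradients the optimizer leaves them untouched.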
500 | no_decay = ["bias", "LayerNorm.weight"] 501 | optimizer_grouped_parameters = [ 502 | { 503 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 504 | "weight_decay": args.weight_decay, 505 | }, 506 | { 507 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 508 | "weight_decay": 0.0, 509 | }, 510 | ] 511 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate) 512 | 513 | # optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate) 514 | # Scheduler and math around the number of training steps. 515 | overrode_max_train_steps = False 516 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 517 | if args.max_train_steps is None: 518 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 519 | overrode_max_train_steps = True 520 | 521 | lr_scheduler = get_scheduler( 522 | name=args.lr_scheduler_type, 523 | optimizer=optimizer, 524 | num_warmup_steps=args.num_warmup_steps, 525 | num_training_steps=args.max_train_steps, 526 | ) 527 | 528 | # Prepare everything with our `accelerator`. 529 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( 530 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler 531 | ) 532 | 533 | # We need to recalculate our total training steps as the size of the training dataloader may have changed 534 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 535 | if overrode_max_train_steps: 536 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 537 | # Afterwards we recalculate our number of training epochs 538 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 539 | 540 | # Figure out how many steps we should save the Accelerator states 541 | checkpointing_steps = args.checkpointing_steps 542 | if checkpointing_steps is not None and checkpointing_steps.isdigit(): 543 | checkpointing_steps = int(checkpointing_steps) 544 | 545 | # We need to initialize the trackers we use, and also store our configuration. 546 | # The trackers initializes automatically on the main process. 547 | if args.with_tracking: 548 | experiment_config = vars(args) 549 | # TensorBoard cannot log Enums, need the raw value 550 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value 551 | accelerator.init_trackers("glue_no_trainer", experiment_config) 552 | 553 | # Get the metric function 554 | if args.task_name is not None: 555 | metric = evaluate.load("glue", args.task_name) 556 | else: 557 | metric = evaluate.load("accuracy") 558 | 559 | # Train! 560 | total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps 561 | 562 | print("***** Running training *****") 563 | print(f" Num examples = {len(train_dataset)}") 564 | print(f" Num Epochs = {args.num_train_epochs}") 565 | print(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") 566 | print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 567 | print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 568 | print(f" Total optimization steps = {args.max_train_steps}") 569 | # Only show the progress bar once on each machine. 
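    # `completed_steps` below counts optimizer updates rather than micro-batches: it is only incremented
    # when `accelerator.sync_gradients` is True, so the checkpointing, logging, and periodic-evaluation
    # conditions in the training loop are all expressed in terms of optimizer steps.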
570 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 571 | completed_steps = 0 572 | starting_epoch = 0 573 | # Potentially load in the weights and states from a previous save 574 | if args.resume_from_checkpoint: 575 | if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "": 576 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}") 577 | accelerator.load_state(args.resume_from_checkpoint) 578 | path = os.path.basename(args.resume_from_checkpoint) 579 | else: 580 | # Get the most recent checkpoint 581 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()] 582 | dirs.sort(key=os.path.getctime) 583 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last 584 | # Extract `epoch_{i}` or `step_{i}` 585 | training_difference = os.path.splitext(path)[0] 586 | 587 | if "epoch" in training_difference: 588 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1 589 | resume_step = None 590 | else: 591 | resume_step = int(training_difference.replace("step_", "")) 592 | starting_epoch = resume_step // len(train_dataloader) 593 | resume_step -= starting_epoch * len(train_dataloader) 594 | 595 | performance_dict = {} 596 | for epoch in range(starting_epoch, args.num_train_epochs): 597 | model.train() 598 | if args.with_tracking: 599 | total_loss = 0 600 | if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None: 601 | # We skip the first `n` batches in the dataloader when resuming from a checkpoint 602 | train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step) 603 | 604 | for step, batch in enumerate(train_dataloader): 605 | # print(batch["attention_mask"]) 606 | # We need to skip steps until we reach the resumed step 607 | if args.resume_from_checkpoint and epoch == starting_epoch: 608 | if resume_step is not None and step < resume_step: 609 | completed_steps += 1 610 | continue 611 | 612 | with accelerator.accumulate(model): 613 | # print(batch) 614 | # return 615 | outputs = model(**batch) 616 | loss = outputs.loss 617 | # We keep track of the loss at each epoch 618 | if args.with_tracking: 619 | total_loss += loss.detach().float() 620 | accelerator.backward(loss) 621 | if completed_steps % 50 == 0: 622 | accelerator.print(f"Epoch: {epoch} | Step: {completed_steps} | Loss: {loss}") 623 | optimizer.step() 624 | lr_scheduler.step() 625 | optimizer.zero_grad() 626 | 627 | # Checks if the accelerator has performed an optimization step behind the scenes 628 | if accelerator.sync_gradients: 629 | progress_bar.update(1) 630 | completed_steps += 1 631 | 632 | if isinstance(checkpointing_steps, int): 633 | if completed_steps % checkpointing_steps == 0: 634 | output_dir = f"step_{completed_steps}" 635 | if args.output_dir is not None: 636 | output_dir = os.path.join(args.output_dir, output_dir) 637 | accelerator.save_state(output_dir) 638 | 639 | if completed_steps >= args.max_train_steps: 640 | break 641 | 642 | if completed_steps % 500 == 0 and step % args.gradient_accumulation_steps == 0: 643 | logger.info(f"The current loss is {loss}") 644 | 645 | if completed_steps % (5000 * 32 // total_batch_size) == 0 and step % args.gradient_accumulation_steps == 0: 646 | model.eval() 647 | eval_metric = run_eval(eval_dataloader, model, accelerator, is_regression, metric) 648 | 649 | # for k, v in eval_metric.items(): 650 | # writer.add_scalar(f"eval/{args.output_dir}/{k}", v, global_step=completed_steps) 651 | logger.info( 652 | f"seed {args.seed} learning rate {args.learning_rate} " 653 | + f"epoch {epoch}: {eval_metric}") 654 | performance_dict[completed_steps] = eval_metric[task_to_metrics[args.task_name]] 655 | 656 | torch.save(model.state_dict(), args.output_dir + f"/{completed_steps}.bin") 657 | 658 | if args.with_tracking: 659 | accelerator.log( 660 | { 661 | "accuracy" if args.task_name is not None else "glue": eval_metric, 662 | "train_loss": total_loss.item() / len(train_dataloader), 663 | "epoch": epoch, 664 | "step": completed_steps, 665 | }, 666 | step=completed_steps, 667 | ) 668 | if args.checkpointing_steps == "epoch": 669 | output_dir = f"epoch_{epoch}" 670 | if args.output_dir is not None: 671 | output_dir = os.path.join(args.output_dir, output_dir) 672 | 673 | accelerator.save_state(output_dir) 674 | 675 | accelerator.wait_for_everyone() 676 | if accelerator.is_main_process: 677 | from safetensors.torch import save_file 678 | save_file(accelerator.get_state_dict(model), output_dir + "/model.safetensors") 679 | print("Saved checkpoint in ", output_dir + "/model.safetensors") 680 | accelerator.wait_for_everyone() 681 | 682 | model.eval() 683 | eval_metric = run_eval(eval_dataloader, model, accelerator, is_regression, metric) 684 | logger.info(f"{args.output_dir} | epoch {epoch}: {eval_metric}") 685 | if args.with_tracking: 686 | accelerator.log( 687 | { 688 | "accuracy" if args.task_name is not None else "glue": eval_metric, 689 | "train_loss": total_loss.item() / len(train_dataloader), 690 | "epoch": epoch, 691 | "step": completed_steps, 692 | }, 693 | step=completed_steps, 694 | ) 695 | performance_dict[epoch] = eval_metric[task_to_metrics[args.task_name]] 696 | 697 | torch.save(model.state_dict(), args.output_dir + f"/{completed_steps}.bin") 698 | 699 | model.eval() 700 | eval_metric = run_eval(eval_dataloader, model, accelerator, is_regression, metric) 701 | 702 | print(f"{args.output_dir} | step {completed_steps}: {eval_metric}") 703 | if performance_dict: 704 | best_performance = max(performance_dict.values()) 705 | max_keys = [k for k, v in performance_dict.items() if 706 | v == best_performance] # getting all keys containing the `maximum` 707 | print(f"seed {args.seed} learning rate {args.learning_rate} " 708 | + f"The best performance is at {max_keys[0]} with {best_performance}") 709 | 710 | if args.task_name == "mnli": 711 | # Final evaluation on mismatched validation set 712 | eval_dataset = processed_datasets["validation_mismatched"] 713 | eval_dataloader = DataLoader( 714 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size 715 | ) 716 | 717 | eval_metric = run_eval(eval_dataloader, model, accelerator, is_regression, metric) 718 | print(f"{args.output_dir}|mnli-mm: {eval_metric}") 719 | 720 | if args.output_dir is not None: 721 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()} 722 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: 723 | json.dump(all_results, f) 724 | 725 | if __name__ == "__main__": 726 | main() 727 | --------------------------------------------------------------------------------