├── .python-version
├── src
│   └── caldera
│       ├── __init__.py
│       ├── utils
│       │   ├── enums.py
│       │   └── quantization.py
│       └── decomposition
│           ├── dataclasses.py
│           ├── weight_compression.py
│           ├── layer_quantization.py
│           ├── quantized_layer.py
│           └── alg.py
├── assets
│   └── caldera_decomposition.png
├── .gitmodules
├── .gitignore
├── setup_quip_sharp.sh
├── shell_scripts
│   ├── accelerate_config.yaml
│   ├── deepspeed_config.json
│   ├── run_eval_ppl.sh
│   ├── run_save_hessians.sh
│   ├── run_eval_zeroshot.sh
│   ├── run_finetune_RHT.sh
│   ├── run_finetune_glue.sh
│   ├── run_finetune_winogrande.sh
│   ├── run_finetune_wikitext.sh
│   └── run_quantize_save_caldera.sh
├── scripts
│   ├── get_sv_info.py
│   ├── save_llama_hessians.py
│   ├── eval_zero_shot.py
│   ├── eval_ppl.py
│   ├── finetune_RHT.py
│   ├── quantize_save_llama.py
│   ├── finetune_wikitext.py
│   ├── finetune_winogrande.py
│   └── finetune_glue.py
├── requirements.txt
├── quip-sharp-pyproject.toml
├── notebooks
│   ├── test_caldera.ipynb
│   └── eval_throughput.ipynb
├── pyproject.toml
└── README.md
/.python-version:
--------------------------------------------------------------------------------
1 | 3.9.0
2 |
--------------------------------------------------------------------------------
/src/caldera/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.0.1.dev0"
2 |
--------------------------------------------------------------------------------
/assets/caldera_decomposition.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pilancilab/caldera/HEAD/assets/caldera_decomposition.png
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "quip-sharp"]
2 | path = quip-sharp
3 | url = git@github.com:Cornell-RelaxML/quip-sharp.git
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .python_version
3 | .conda
4 | data/
5 | *.egg-info
6 | naomi/results
7 | .gitconfig
8 | .vscode/settings.json
9 | artifacts/
10 | .vscode
11 | shell_scripts/custom
12 | wandb
13 | .python-version
--------------------------------------------------------------------------------
/setup_quip_sharp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | git submodule init
3 | git submodule update
4 | cp quip-sharp-pyproject.toml quip-sharp/pyproject.toml
5 | cd quip-sharp && pip install --editable . && cd ..
6 | cd quip-sharp/quiptools && python setup.py install && cd ../..
7 |
8 |
--------------------------------------------------------------------------------
/src/caldera/utils/enums.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum
2 |
3 |
4 | class DevSet(IntEnum):
5 | RP1T = 0
6 | FALCON = 1
7 |
8 | class TransformerSubLayers(IntEnum):
9 | QUERY = 0
10 | KEY = 1
11 | VALUE = 2
12 | O = 3
13 | UP = 4
14 | GATE = 5
15 | DOWN = 6
--------------------------------------------------------------------------------
/shell_scripts/accelerate_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_config_file: ./shell_scripts/deepspeed_config.json
5 | zero3_init_flag: false
6 | distributed_type: DEEPSPEED
7 | downcast_bf16: 'no'
8 | machine_rank: 0
9 | main_training_function: main
10 | num_machines: 1
11 | num_processes: 4
12 | rdzv_backend: static
13 | same_network: true
14 | tpu_env: []
15 | tpu_use_cluster: false
16 | tpu_use_sudo: false
17 | use_cpu: false
18 | main_process_port: 29500
19 |
--------------------------------------------------------------------------------
/shell_scripts/deepspeed_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 2,
4 | "overlap_comm": true,
5 | "contiguous_gradients": true,
6 | "stage3_max_live_parameters": 5e8,
7 | "offload_optimizer": {
8 | "device": "cpu"
9 | },
10 | "offload_param": {
11 | "device": "cpu"
12 | }
13 | },
14 | "bf16": {
15 | "enabled": true,
16 | "auto_cast": true
17 | },
18 | "train_micro_batch_size_per_gpu": "auto",
19 | "gradient_accumulation_steps": "auto"
20 | }
21 |
--------------------------------------------------------------------------------
/shell_scripts/run_eval_ppl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPT_FILLED_IN=0
4 |
5 | BASE_MODEL="meta-llama/Llama-2-7b-hf"
6 | CALDERA_MODEL_SAVE_PATH="YOUR_MODEL_PATH"
7 | DEVICE="cuda:0"
8 | OUTPUT_FILENAME="YOUR FILENAME HERE"
9 |
10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then
11 | echo -e "This script is meant as a template for running scripts/eval_ppl.py. \
12 | Please go into shell_scripts/run_eval_ppl.sh and replace BASE_MODEL, \
13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file."
14 | exit 1
15 | fi
16 |
17 | python scripts/eval_ppl.py \
18 | --model_save_path $CALDERA_MODEL_SAVE_PATH \
19 | --base_model $BASE_MODEL \
20 | --seed 0 \
21 | --seqlen 4096 \
22 | --datasets wikitext2 c4 \
23 | --device $DEVICE \
24 | --output_path $OUTPUT_FILENAME
25 |
--------------------------------------------------------------------------------
/shell_scripts/run_save_hessians.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPT_FILLED_IN=0
4 |
5 | BASE_MODEL="meta-llama/Llama-2-7b-hf"
6 | HESSIAN_SAVE_DIR="YOUR DIRECTORY HERE"
7 | DEVICES="cuda:0 cuda:1 cuda:2 cuda:3"
8 |
9 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then
10 | echo -e "This script is meant as a template for running scripts/save_llama_hessians.py. \
11 | Please go into shell_scripts/run_save_hessians.sh and replace BASE_MODEL, HESSIAN_SAVE_DIR, etc., \
12 | and then set SCRIPT_FILLED_IN=1 at the top of the file."
13 | exit 1
14 | fi
15 |
16 | python scripts/save_llama_hessians.py \
17 | --base_model $BASE_MODEL \
18 | --hessian_save_path $HESSIAN_SAVE_DIR \
19 | --devset rp1t \
20 | --context_length 4096 \
21 | --devset_size 256 \
22 | --chunk_size 64 \
23 | --batch_size 32 \
24 | --devices $DEVICES
--------------------------------------------------------------------------------
/shell_scripts/run_eval_zeroshot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPT_FILLED_IN=0
4 |
5 | BASE_MODEL="meta-llama/Llama-2-7b-hf"
6 | CALDERA_MODEL_SAVE_PATH="PATH OF .pt FILE WITH MODEL"
7 | DEVICE="cuda:0"
8 | OUTPUT_FILENAME="YOUR FILE HERE"
9 |
10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then
11 | echo -e "This script is meant as a template for running scripts/eval_zero_shot.py. \
12 | Please go into shell_scripts/run_eval_zeroshot.sh and replace BASE_MODEL, \
13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file."
14 | exit 1
15 | fi
16 |
17 | python scripts/eval_zero_shot.py \
18 | --model_save_path $CALDERA_MODEL_SAVE_PATH \
19 | --device $DEVICE \
20 | --tasks winogrande rte piqa arc_easy arc_challenge \
21 | --base_model $BASE_MODEL \
22 | --batch_size 8 \
23 | --output_path $OUTPUT_FILENAME
--------------------------------------------------------------------------------
/shell_scripts/run_finetune_RHT.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPT_FILLED_IN=0
4 |
5 | BASE_MODEL="meta-llama/Llama-2-7b-hf"
6 | CALDERA_MODEL_SAVE_PATH_NO_RHT_FT="PATH OF .pt FILE WITH MODEL TO BE FINETUNED"
7 | CALDERA_MODEL_SAVE_PATH_WITH_RHT_FT="PATH OF .pt FILE TO SAVE FINETUNED MODEL"
8 |
9 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then
10 | echo -e "This script is meant as a template for running scripts/finetune_RHT.py. \
11 | Please go into shell_scripts/run_finetune_RHT.sh and replace BASE_MODEL, \
12 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file."
13 | exit 1
14 | fi
15 |
16 | python scripts/finetune_RHT.py \
17 | --base_model $BASE_MODEL \
18 | --model_path $CALDERA_MODEL_SAVE_PATH_NO_RHT_FT \
19 | --finetuned_save_path $CALDERA_MODEL_SAVE_PATH_WITH_RHT_FT \
20 | --devset_size 256 \
21 | --ctx_size 512 \
22 | --device cuda:0 \
23 | --ft_bs 2 \
24 | --ft_valid_size 64 \
25 | --RHT_learning_rate 1e-3 \
26 | --epochs 1
--------------------------------------------------------------------------------
/shell_scripts/run_finetune_glue.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPT_FILLED_IN=0
4 |
5 | BASE_MODEL="meta-llama/Llama-2-7b-hf"
6 | CALDERA_MODEL_SAVE_PATH="PATH OF .pt FILE WITH MODEL"
7 | OUTPUT_DIR="FINETUNING OUTPUT DIRECTORY"
8 | GLUE_TASK="rte"
9 |
10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then
11 | echo -e "This script is meant as a template for running scripts/finetune_glue.py. \
12 | Please go into shell_scripts/run_finetune_glue.sh and replace BASE_MODEL, \
13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file."
14 | exit 1
15 | fi
16 |
17 | accelerate launch --config_file shell_scripts/accelerate_config.yaml \
18 | scripts/finetune_glue.py \
19 | --task_name $GLUE_TASK \
20 | --model_name_or_path $CALDERA_MODEL_SAVE_PATH \
21 | --base_model $BASE_MODEL \
22 | --output_dir $OUTPUT_DIR \
23 | --learning_rate 3e-5 \
24 | --num_train_epochs 15 \
25 | --per_device_train_batch_size 1 \
26 | --per_device_eval_batch_size 1 \
27 | --gradient_accumulation_steps 2 \
28 | --weight_decay 0.01 \
29 | --num_warmup_steps 100 \
30 | --lr_scheduler_type linear \
31 | --report_to tensorboard \
32 | --with_tracking \
33 | --checkpointing_steps epoch \
34 | --pad_to_max_length \
35 | --seed 314
--------------------------------------------------------------------------------
/shell_scripts/run_finetune_winogrande.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPT_FILLED_IN=0
4 |
5 | BASE_MODEL="meta-llama/Llama-2-7b-hf"
6 | CALDERA_MODEL_SAVE_PATH="PATH OF .pt FILE WITH MODEL"
7 | OUTPUT_DIR="FINETUNING OUTPUT DIRECTORY"
8 | SEQ_LEN=256
9 | DEVICES="0,1,2,3"  # GPU indices passed to CUDA_VISIBLE_DEVICES below; adjust as needed
10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then
11 | echo -e "This script is meant as a template for running scripts/finetune_winogrande.py. \
12 | Please go into shell_scripts/run_finetune_winogrande.sh and replace BASE_MODEL, \
13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file."
14 | exit 1
15 | fi
16 |
17 |
18 | CUDA_VISIBLE_DEVICES=$DEVICES accelerate launch --config_file ./shell_scripts/accelerate_config.yaml \
19 | scripts/finetune_winogrande.py \
20 | --model_name_or_path $CALDERA_MODEL_SAVE_PATH \
21 | --base_model $BASE_MODEL \
22 | --output_dir $OUTPUT_DIR \
23 | --learning_rate 1e-5 \
24 | --weight_decay 0.01 \
25 | --lr_scheduler_type linear \
26 | --warmup_ratio 0.033 \
27 | --num_warmup_steps 100 \
28 | --seed 202 \
29 | --max_seq_length $SEQ_LEN \
30 | --num_train_epochs 1 \
31 | --per_device_train_batch_size 10 \
32 | --per_device_eval_batch_size 10 \
33 | --gradient_accumulation_steps 1 \
34 | --logging_steps 10 \
35 | --save_steps 200 \
36 | --bf16 \
37 | --with_tracking true \
38 | --report_to tensorboard
--------------------------------------------------------------------------------
/shell_scripts/run_finetune_wikitext.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPT_FILLED_IN=0
4 |
5 | BASE_MODEL="meta-llama/Llama-2-7b-hf"
6 | CALDERA_MODEL_SAVE_PATH="PATH OF .pt FILE WITH MODEL"
7 | OUTPUT_DIR="FINETUNING OUTPUT DIRECTORY"
8 | BLOCK_SIZE=256
9 |
10 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then
11 | echo -e "This script is meant as a template for running scripts/finetune_wikitext.py. \
12 | Please go into shell_scripts/run_finetune_wikitext.sh and replace BASE_MODEL, \
13 | CALDERA_MODEL_SAVE_PATH, etc., and then set SCRIPT_FILLED_IN=1 at the top of the file."
14 | exit 1
15 | fi
16 |
17 | accelerate launch --config_file shell_scripts/accelerate_config.yaml \
18 | scripts/finetune_wikitext.py \
19 | --model_save_path $CALDERA_MODEL_SAVE_PATH \
20 | --base_model $BASE_MODEL \
21 | --output_dir $OUTPUT_DIR \
22 | --dataset_name wikitext \
23 | --dataset_config_name wikitext-2-raw-v1 \
24 | --block_size $BLOCK_SIZE \
25 | --learning_rate 3e-6 \
26 | --num_train_epochs 3 \
27 | --per_device_train_batch_size 1 \
28 | --per_device_eval_batch_size 1 \
29 | --gradient_accumulation_steps 1 \
30 | --weight_decay 0.001 \
31 | --warmup_ratio 0.02 \
32 | --lr_scheduler_type linear \
33 | --bf16 \
34 | --logging_steps 10 \
35 | --save_steps 200 \
36 | --eval_steps 100 \
37 | --evaluation_strategy steps \
38 | --prediction_loss_only
39 |
--------------------------------------------------------------------------------
/shell_scripts/run_quantize_save_caldera.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPT_FILLED_IN=0
4 |
5 | # CALDERA PARAMETERS
6 | RANK=256
7 | LR_BITS=4
8 | CALDERA_ITERS=20
9 | LPLR_ITERS=50
10 | # END CALDERA PARAMETERS
11 |
12 | BASE_MODEL="meta-llama/Llama-2-7b-hf"
13 | DEVICES="cuda:0 cuda:1 cuda:2 cuda:3"
14 |
15 | CALDERA_MODEL_SAVE_PATH="YOUR_MODEL_OUTPUT_FILENAME"
16 | HESSIAN_SAVE_DIR="$HOME/.cache/huggingface/hub/models--relaxml--Hessians-Llama-2-7b-6144/snapshots/SNAPSHOT_ID_HERE"
17 |
18 | if [ $SCRIPT_FILLED_IN -eq 0 ]; then
19 | echo -e "This script is meant as a template for running scripts/quantize_save_llama.py. \
20 | Please go into shell_scripts/run_quantize_save_caldera.sh and replace the CALDERA parameters, etc., \
21 | and then set SCRIPT_FILLED_IN=1 at the top of the file."
22 | exit 1
23 | fi
24 |
25 | QUANT_PARAMS="--Q_bits 2 \
26 | --compute_low_rank_factors true \
27 | --compute_quantized_component true \
28 | --L_bits $LR_BITS \
29 | --R_bits $LR_BITS \
30 | --lattice_quant_LR true \
31 | --rank $RANK \
32 | --activation_aware_LR true \
33 | --activation_aware_Q true \
34 | --hadamard_transform true \
35 | --iters $CALDERA_ITERS \
36 | --lplr_iters $LPLR_ITERS \
37 | --rand_svd false \
38 | --update_order LR Q \
39 | --Q_hessian_downdate true \
40 | --ft_rank 0 \
41 | --random_seed 42"
42 |
43 | python scripts/quantize_save_llama.py \
44 | --hessian_save_path $HESSIAN_SAVE_DIR \
45 | --model_save_path $CALDERA_MODEL_SAVE_PATH \
46 | --devices $DEVICES \
47 | --base_model $BASE_MODEL \
48 | $QUANT_PARAMS
--------------------------------------------------------------------------------
/scripts/get_sv_info.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM, HfArgumentParser
3 | import gc
4 | from tqdm import tqdm
5 | from dataclasses import dataclass, field
6 | import os
7 |
8 |
9 | @dataclass
10 | class Arguments:
11 | base_model: str = field(default="meta-llama/Llama-2-7b-hf", metadata={
12 | "help": ("Path of the original model, as either a local or a "
13 | "Huggingface path")
14 | })
15 | device: str = field(default="cuda:0", metadata={
16 | "help": "Device on which to run evaluation."
17 | })
18 | save_dir: str = field(default="./data")
19 |
20 | SUBLAYERS = {
21 | "q_proj": lambda layer: layer.self_attn.q_proj,
22 | "k_proj": lambda layer: layer.self_attn.k_proj,
23 | "v_proj": lambda layer: layer.self_attn.v_proj,
24 | "o_proj": lambda layer: layer.self_attn.o_proj,
25 | "up_proj": lambda layer: layer.mlp.up_proj,
26 | "gate_proj": lambda layer: layer.mlp.gate_proj,
27 | "down_proj": lambda layer: layer.mlp.down_proj
28 | }
29 |
30 | def main(base_model, device, save_dir):
31 | model = AutoModelForCausalLM.from_pretrained(
32 | base_model, torch_dtype='auto', low_cpu_mem_usage=True
33 | ).cpu()
34 |
35 | for label in SUBLAYERS:
36 | A = SUBLAYERS[label](model.model.layers[0]).weight
37 | n = min(A.shape[0], A.shape[1])
38 | SVs = torch.zeros(n, len(model.model.layers), requires_grad=False)
39 | for i, layer in tqdm(enumerate(model.model.layers)):
40 | A = SUBLAYERS[label](layer).weight.to(device).float().detach()
41 | _, S, _ = torch.linalg.svd(A)
42 | S = S.cpu()
43 | del A
44 | gc.collect()
45 | torch.cuda.empty_cache()
46 |
47 | SVs[:, i] = S
48 |
49 | torch.save({
50 | "SV_data": SVs.detach(),
51 | "means": torch.mean(SVs.detach(), dim=1),
52 | "stdevs": torch.std(SVs.detach(), dim=1)
53 | }, f"{save_dir}/{label}_sv_info.pt")
54 |
55 | if __name__ == "__main__":
56 | parser = HfArgumentParser([Arguments])
57 | args = parser.parse_args_into_dataclasses()[0]
58 | os.makedirs(args.save_dir, exist_ok=True)
59 | main(args.base_model, args.device, args.save_dir)
60 |
61 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | packaging==24.0
2 | torch==2.2.1
3 | torchaudio==2.2.1
4 | torchvision==0.17.1
5 | absl-py==2.1.0
6 | accelerate==0.28.0
7 | aiohttp==3.9.3
8 | aiosignal==1.3.1
9 | annotated-types==0.6.0
10 | anyio==4.3.0
11 | argparse-dataclass==2.0.0
12 | asttokens==2.4.1
13 | attrs==23.2.0
14 | chardet==5.2.0
15 | click==8.1.7
16 | colorama==0.4.6
17 | comm==0.2.2
18 | contourpy==1.2.0
19 | cycler==0.12.1
20 | DataProperty==1.0.1
21 | datasets==2.18.0
22 | debugpy==1.8.1
23 | decorator==5.1.1
24 | deepspeed==0.14.0
25 | dill==0.3.8
26 | distro==1.9.0
27 | evaluate==0.4.1
28 | executing==2.0.1
29 | fonttools==4.50.0
30 | frozenlist==1.4.1
31 | fsspec==2024.2.0
32 | glog==0.3.1
33 | h11==0.14.0
34 | hjson==3.1.0
35 | httpcore==1.0.4
36 | httpx==0.27.0
37 | huggingface-hub[cli]==0.21.4
38 | ipykernel==6.29.3
39 | ipython==8.22.2
40 | jedi==0.19.1
41 | joblib==1.3.2
42 | jsonlines==4.0.0
43 | jupyter_client==8.6.1
44 | jupyter_core==5.7.2
45 | kiwisolver==1.4.5
46 | lm-eval==0.3.0
47 | lxml==5.1.0
48 | matplotlib==3.8.3
49 | matplotlib-inline==0.1.6
50 | mbstrdecoder==1.1.3
51 | more-itertools==10.2.0
52 | multidict==6.0.5
53 | multiprocess==0.70.16
54 | nest-asyncio==1.6.0
55 | ninja==1.11.1.1
56 | nltk==3.8.1
57 | numexpr==2.9.0
58 | openai==1.14.2
59 | pandas==2.2.1
60 | parso==0.8.3
61 | pathvalidate==3.2.0
62 | peft==0.10.0
63 | pexpect==4.9.0
64 | platformdirs==4.2.0
65 | portalocker==2.8.2
66 | primefac==2.0.12
67 | prompt-toolkit==3.0.43
68 | psutil==5.9.8
69 | ptyprocess==0.7.0
70 | pure-eval==0.2.2
71 | py-cpuinfo==9.0.0
72 | pyarrow==15.0.2
73 | pyarrow-hotfix==0.6
74 | pybind11==2.11.1
75 | pycountry==23.12.11
76 | pydantic==2.6.4
77 | pydantic_core==2.16.3
78 | Pygments==2.17.2
79 | pynvml==11.5.0
80 | pyparsing==3.1.2
81 | pytablewriter==1.2.0
82 | python-dateutil==2.9.0.post0
83 | python-gflags==3.1.2
84 | pytz==2024.1
85 | pyzmq==25.1.2
86 | regex==2023.12.25
87 | responses==0.18.0
88 | rouge_score==0.1.2
89 | sacrebleu==1.5.0
90 | safetensors==0.4.2
91 | scikit-learn==1.4.1.post1
92 | scipy==1.12.0
93 | six==1.16.0
94 | sniffio==1.3.1
95 | sqlitedict==2.1.0
96 | stack-data==0.6.3
97 | tabledata==1.3.3
98 | tabulate==0.9.0
99 | tcolorpy==0.1.4
100 | threadpoolctl==3.4.0
101 | tokenizers==0.15.2
102 | tornado==6.4
103 | tqdm==4.66.2
104 | tqdm-multiprocess==0.0.11
105 | traitlets==5.14.2
106 | transformers==4.39.1
107 | triton==2.2.0
108 | typepy==1.3.2
109 | tzdata==2024.1
110 | wcwidth==0.2.13
111 | word2number==1.1
112 | xxhash==3.4.1
113 | yarl==1.9.4
114 | zstandard==0.22.0
115 | sentencepiece==0.2.0
116 | protobuf==5.26.0
117 | tensorboard==2.16.2
118 |
--------------------------------------------------------------------------------
/scripts/save_llama_hessians.py:
--------------------------------------------------------------------------------
1 | import transformers
2 | from caldera.decomposition.weight_compression import *
3 |
4 |
5 | @dataclass
6 | class Arguments:
7 | base_model: str = field(default="meta-llama/Llama-2-7b-hf", metadata={
8 | "help": ("Path of the model that will eventually be quantized, as "
9 | "either a local or a Huggingface path.")
10 | })
11 | token: str = field(default=None, metadata={
12 | "help": "Huggingface access token for private models."
13 | })
14 | hessian_save_path: str = field(
15 | default="./data/hessians/llama-2-7b", metadata={
16 | "help": ("Directory in which to save Hessians.")
17 | })
18 | n_sample_proc: int = field(default=4, metadata={
19 | "help": "Number of processes used to sample calibration data."
20 | })
21 |
22 |
23 | @dataclass
24 | class DataParametersCommandLine:
25 | """
26 | Parameters for loading the calibration dataset and computing the
27 | inputs to each layer.
28 | """
29 | devset: str = field(
30 | default="rp1t", metadata={"help": (
31 | "Calibration dataset; either rp1t or falcon"
32 | ), "choices": ["rp1t", "falcon"]}
33 | )
34 | devset_size: int = field(
35 | default=256, metadata={"help": (
36 | "Number of calibration samples to use."
37 | )}
38 | )
39 | context_length: int = field(
40 | default=4096, metadata={"help": (
41 | "Length of context window."
42 | )}
43 | )
44 | batch_size: int = field(
45 | default=2, metadata={"help": (
46 | "Number of datapoints to pass into the model at once."
47 | )}
48 | )
49 | chunk_size: int = field(
50 | default=256, metadata={"help": (
51 | "Number of batches sent to each GPU at a time."
52 | )}
53 | )
54 | devices: list[str] = field(
55 | default=None, metadata={"help": (
56 | "Specific CUDA devices to use for Hessian computation. Defaults "
57 | "to None, which means that all available devices are used."
58 | )}
59 | )
60 |
61 |
62 | if __name__ == "__main__":
63 | parser = transformers.HfArgumentParser([Arguments, DataParametersCommandLine])
64 | args, data_params_command_line = parser.parse_args_into_dataclasses()
65 |
66 | devset = DevSet.RP1T if data_params_command_line.devset == "rp1t" \
67 | else DevSet.FALCON
68 |
69 | data_params = DataParameters(
70 | devset=devset,
71 | devset_size=data_params_command_line.devset_size,
72 | context_length=data_params_command_line.context_length,
73 | batch_size=data_params_command_line.batch_size,
74 | chunk_size=data_params_command_line.chunk_size,
75 | devices=data_params_command_line.devices
76 | )
77 |
78 | ActivationAwareWeightCompressor(
79 | model_params=ModelParameters(
80 | base_model=args.base_model,
81 | token=args.token
82 | ),
83 | data_params=data_params,
84 | hessian_save_path=args.hessian_save_path,
85 | quant_device="cuda",
86 | n_sample_proc=args.n_sample_proc,
87 | compute_hessians=True
88 | )
89 |
--------------------------------------------------------------------------------
/quip-sharp-pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "quip-sharp"
7 | version = "0.0.1"
8 | requires-python = ">=3.8.0"
9 |
10 | dependencies = [
11 | "packaging==24.0",
12 | "torch==2.2.1",
13 | "torchaudio==2.2.1",
14 | "torchvision==0.17.1",
15 | "absl-py==2.1.0",
16 | "accelerate==0.28.0",
17 | "aiohttp==3.9.3",
18 | "aiosignal==1.3.1",
19 | "annotated-types==0.6.0",
20 | "anyio==4.3.0",
21 | "argparse-dataclass==2.0.0",
22 | "asttokens==2.4.1",
23 | "attrs==23.2.0",
24 | "chardet==5.2.0",
25 | "click==8.1.7",
26 | "colorama==0.4.6",
27 | "comm==0.2.2",
28 | "contourpy==1.2.0",
29 | "cycler==0.12.1",
30 | "DataProperty==1.0.1",
31 | "datasets==2.18.0",
32 | "debugpy==1.8.1",
33 | "decorator==5.1.1",
34 | "deepspeed==0.14.0",
35 | "dill==0.3.8",
36 | "distro==1.9.0",
37 | "evaluate==0.4.1",
38 | "executing==2.0.1",
39 | "fonttools==4.50.0",
40 | "frozenlist==1.4.1",
41 | "fsspec==2024.2.0",
42 | "glog==0.3.1",
43 | "h11==0.14.0",
44 | "hjson==3.1.0",
45 | "httpcore==1.0.4",
46 | "httpx==0.27.0",
47 | "huggingface-hub==0.21.4",
48 | "ipykernel==6.29.3",
49 | "ipython==8.22.2",
50 | "jedi==0.19.1",
51 | "joblib==1.3.2",
52 | "jsonlines==4.0.0",
53 | "jupyter_client==8.6.1",
54 | "jupyter_core==5.7.2",
55 | "kiwisolver==1.4.5",
56 | "lm-eval==0.3.0",
57 | "lxml==5.1.0",
58 | "matplotlib==3.8.3",
59 | "matplotlib-inline==0.1.6",
60 | "mbstrdecoder==1.1.3",
61 | "more-itertools==10.2.0",
62 | "multidict==6.0.5",
63 | "multiprocess==0.70.16",
64 | "nest-asyncio==1.6.0",
65 | "ninja==1.11.1.1",
66 | "nltk==3.8.1",
67 | "numexpr==2.9.0",
68 | "openai==1.14.2",
69 | "pandas==2.2.1",
70 | "parso==0.8.3",
71 | "pathvalidate==3.2.0",
72 | "peft==0.10.0",
73 | "pexpect==4.9.0",
74 | "platformdirs==4.2.0",
75 | "portalocker==2.8.2",
76 | "primefac==2.0.12",
77 | "prompt-toolkit==3.0.43",
78 | "psutil==5.9.8",
79 | "ptyprocess==0.7.0",
80 | "pure-eval==0.2.2",
81 | "py-cpuinfo==9.0.0",
82 | "pyarrow==15.0.2",
83 | "pyarrow-hotfix==0.6",
84 | "pybind11==2.11.1",
85 | "pycountry==23.12.11",
86 | "pydantic==2.6.4",
87 | "pydantic_core==2.16.3",
88 | "Pygments==2.17.2",
89 | "pynvml==11.5.0",
90 | "pyparsing==3.1.2",
91 | "pytablewriter==1.2.0",
92 | "python-dateutil==2.9.0.post0",
93 | "python-gflags==3.1.2",
94 | "pytz==2024.1",
95 | "pyzmq==25.1.2",
96 | "regex==2023.12.25",
97 | "responses==0.18.0",
98 | "rouge_score==0.1.2",
99 | "sacrebleu==1.5.0",
100 | "safetensors==0.4.2",
101 | "scikit-learn==1.4.1.post1",
102 | "scipy==1.12.0",
103 | "six==1.16.0",
104 | "sniffio==1.3.1",
105 | "sqlitedict==2.1.0",
106 | "stack-data==0.6.3",
107 | "tabledata==1.3.3",
108 | "tabulate==0.9.0",
109 | "tcolorpy==0.1.4",
110 | "threadpoolctl==3.4.0",
111 | "tokenizers==0.15.2",
112 | "tornado==6.4",
113 | "tqdm==4.66.2",
114 | "tqdm-multiprocess==0.0.11",
115 | "traitlets==5.14.2",
116 | "transformers==4.39.1",
117 | "triton==2.2.0",
118 | "typepy==1.3.2",
119 | "tzdata==2024.1",
120 | "wcwidth==0.2.13",
121 | "word2number==1.1",
122 | "xxhash==3.4.1",
123 | "yarl==1.9.4",
124 | "zstandard==0.22.0",
125 | "sentencepiece==0.2.0",
126 | "protobuf==5.26.0",
127 | "tensorboard==2.16.2",
128 | ]
129 |
130 | [tool.setuptools.packages.find]
131 | where = ["."]
--------------------------------------------------------------------------------
/scripts/eval_zero_shot.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from lib.utils import LMEvalAdaptor
3 | from lm_eval import evaluator
4 | import os
5 | import json
6 | from quantize_save_llama import load_quantized_model
7 | from dataclasses import field, dataclass
8 | import transformers
9 | import glob
10 |
11 | # Adapted from https://github.com/Cornell-RelaxML/quip-sharp
12 |
13 | class LMEvalAdaptorWithDevice(LMEvalAdaptor):
14 | def __init__(self,
15 | model_name,
16 | model,
17 | tokenizer,
18 | batch_size=1,
19 | max_length=-1,
20 | device="cuda"):
21 | super().__init__(
22 | model_name, model, tokenizer, batch_size, max_length
23 | )
24 | self._device = device
25 |
26 | @property
27 | def device(self):
28 | return self._device
29 |
30 |
31 | @dataclass
32 | class Arguments:
33 | model_save_path: str = field(metadata={
34 | "help": ("Path of the .pt file in which the model can be found.")
35 | })
36 | finetune_save_dir: str = field(default=None, metadata={
37 | "help": ("If using a finetuned model, the directory in which the "
38 | "model.safetensors file is stored")
39 | })
40 | output_path: str = field(default=None, metadata={
41 | "help": ("Path in which to save a JSON file with zero-shot results.")
42 | })
43 | base_model: str = field(default="meta-llama/Llama-2-7b-hf", metadata={
44 | "help": ("Path of the original model, as either a local or a "
45 | "Huggingface path")
46 | })
47 | tasks: list[str] = field(default_factory=list, metadata={
48 | "help": ("Task on which to measure zero-shot accuracy, e.g."
49 | "wingorande, piqa, arc_easy, arc_challenge, rte, cola...")
50 | })
51 | batch_size: int = field(default=1, metadata={
52 | "help": "Number of datapoints processed at once"
53 | })
54 | device: str = field(default="cuda:0", metadata={
55 | "help": "Device on which to run evaluation."
56 | })
57 | cuda_graph: bool = field(default=False, metadata={
58 | "help": "Whether to use CUDA graphs and flash attention to speed up evaluation."
59 | })
60 |
61 |
62 | def eval_zero_shot(args: Arguments):
63 | print(args.base_model)
64 | model = load_quantized_model(args.model_save_path, args.base_model, args.device, cuda_graph=args.cuda_graph)
65 | model = model.to(args.device)
66 | if args.finetune_save_dir is not None:
67 | from safetensors.torch import load_model
68 | for safetensor_file in glob.glob(args.finetune_save_dir + "/model*.safetensors"):
69 | print("Loading ", safetensor_file)
70 | load_model(model, safetensor_file, strict=False)
71 | tokenizer = AutoTokenizer.from_pretrained(args.base_model, use_fast=False)
72 | print("Loaded model!")
73 |
74 | tokenizer.pad_token = tokenizer.eos_token
75 |
76 | lm_eval_model = LMEvalAdaptorWithDevice(
77 | args.base_model, model, tokenizer, args.batch_size, device=args.device)
78 | lm_eval_model.device
79 | results = evaluator.simple_evaluate(
80 | model=lm_eval_model,
81 | tasks=args.tasks,
82 | batch_size=args.batch_size,
83 | no_cache=True,
84 | num_fewshot=0,
85 | device=args.device
86 | )
87 |
88 | print(evaluator.make_table(results))
89 |
90 | if args.output_path is not None:
91 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
92 | # otherwise cannot save
93 | results["config"]["model"] = args.base_model
94 | with open(args.output_path + ".json", "w") as f:
95 | json.dump(results, f, indent=2)
96 |
97 |
98 | if __name__ == "__main__":
99 | parser = transformers.HfArgumentParser([Arguments])
100 | args = parser.parse_args_into_dataclasses()[0]
101 | eval_zero_shot(args)
--------------------------------------------------------------------------------
/notebooks/test_caldera.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Notebook to try out CALDERA decomposition on a random matrix"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import torch\n",
17 | "\n",
18 | "import sys\n",
19 | "import os"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "src_dir = os.path.abspath(os.path.join(os.getcwd(), \"..\"))\n",
29 | "\n",
30 | "if src_dir not in sys.path:\n",
31 | " sys.path.append(src_dir)"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "from src.caldera.decomposition.dataclasses import CalderaParams\n",
41 | "from src.caldera.utils.quantization import QuantizerFactory\n",
42 | "from src.caldera.decomposition.alg import caldera"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {},
49 | "outputs": [],
50 | "source": [
51 | "quant_factory_Q = QuantizerFactory(method=\"uniform\", block_size=64)\n",
52 | "quant_factor_LR = QuantizerFactory(method=\"uniform\", block_size=64)"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "quant_params = CalderaParams(\n",
62 | " compute_quantized_component=True,\n",
63 | " compute_low_rank_factors=True,\n",
64 | " Q_bits=4,\n",
65 | " L_bits=4,\n",
66 | " R_bits=4,\n",
67 | " rank=16,\n",
68 | " iters=20,\n",
69 | " lplr_iters=5,\n",
70 | " activation_aware_Q=False,\n",
71 | " activation_aware_LR=True,\n",
72 | " lattice_quant_Q=False,\n",
73 | " lattice_quant_LR=False,\n",
74 | " update_order=[\"Q\", \"LR\"],\n",
75 | " quant_factory_Q=quant_factory_Q,\n",
76 | " quant_factory_LR=quant_factor_LR,\n",
77 | " rand_svd=False\n",
78 | ")"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "torch.manual_seed(42)\n",
88 | "\n",
89 | "W = torch.rand(1024, 1024)\n",
90 | "X = torch.randn(1024, 2048)\n",
91 | "H = torch.matmul(X, X.T)"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "caldera_decomposition = caldera(\n",
101 | " quant_params=quant_params,\n",
102 | " W=W,\n",
103 | " H=H,\n",
104 | " device=\"cpu\",\n",
105 | " use_tqdm=True,\n",
106 | " scale_W=True\n",
107 | ")"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": null,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "print(f\"caldera_decomposition.Q.shape: {caldera_decomposition.Q.shape}\")\n",
117 | "print(f\"caldera_decomposition.L.shape: {caldera_decomposition.L.shape}\")\n",
118 | "print(f\"caldera_decomposition.R.shape: {caldera_decomposition.R.shape}\")"
119 | ]
120 | }
121 | ],
122 | "metadata": {
123 | "kernelspec": {
124 | "display_name": "Python 3.11 (venv)",
125 | "language": "python",
126 | "name": "venv"
127 | },
128 | "language_info": {
129 | "codemirror_mode": {
130 | "name": "ipython",
131 | "version": 3
132 | },
133 | "file_extension": ".py",
134 | "mimetype": "text/x-python",
135 | "name": "python",
136 | "nbconvert_exporter": "python",
137 | "pygments_lexer": "ipython3",
138 | "version": "3.10.15"
139 | }
140 | },
141 | "nbformat": 4,
142 | "nbformat_minor": 2
143 | }
144 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools >= 61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "caldera"
7 | version = "0.0.1"
8 | description = "Decomposition of neural network (e.g., LLM) weight matrices into Q + LR, where L and R are low-rank factors and all matrices are quantized to low-bit precision."
9 | readme = "README.md"
10 | requires-python = ">=3.8.0"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: MIT License",
14 | "Operating System :: OS Independent",
15 | ]
16 |
17 | dependencies = [
18 | "packaging==24.0",
19 | "torch==2.2.1",
20 | "torchaudio==2.2.1",
21 | "torchvision==0.17.1",
22 | "absl-py==2.1.0",
23 | "accelerate==0.28.0",
24 | "aiohttp==3.9.3",
25 | "aiosignal==1.3.1",
26 | "annotated-types==0.6.0",
27 | "anyio==4.3.0",
28 | "argparse-dataclass==2.0.0",
29 | "asttokens==2.4.1",
30 | "attrs==23.2.0",
31 | "chardet==5.2.0",
32 | "click==8.1.7",
33 | "colorama==0.4.6",
34 | "comm==0.2.2",
35 | "contourpy==1.2.0",
36 | "cycler==0.12.1",
37 | "DataProperty==1.0.1",
38 | "datasets==2.18.0",
39 | "debugpy==1.8.1",
40 | "decorator==5.1.1",
41 | "deepspeed==0.14.0",
42 | "dill==0.3.8",
43 | "distro==1.9.0",
44 | "evaluate==0.4.1",
45 | "executing==2.0.1",
46 | "fonttools==4.50.0",
47 | "frozenlist==1.4.1",
48 | "fsspec==2024.2.0",
49 | "glog==0.3.1",
50 | "h11==0.14.0",
51 | "hjson==3.1.0",
52 | "httpcore==1.0.4",
53 | "httpx==0.27.0",
54 | "huggingface-hub==0.21.4",
55 | "ipykernel==6.29.3",
56 | "ipython==8.22.2",
57 | "jedi==0.19.1",
58 | "joblib==1.3.2",
59 | "jsonlines==4.0.0",
60 | "jupyter_client==8.6.1",
61 | "jupyter_core==5.7.2",
62 | "kiwisolver==1.4.5",
63 | "lm-eval==0.3.0",
64 | "lxml==5.1.0",
65 | "matplotlib==3.8.3",
66 | "matplotlib-inline==0.1.6",
67 | "mbstrdecoder==1.1.3",
68 | "more-itertools==10.2.0",
69 | "multidict==6.0.5",
70 | "multiprocess==0.70.16",
71 | "nest-asyncio==1.6.0",
72 | "ninja==1.11.1.1",
73 | "nltk==3.8.1",
74 | "numexpr==2.9.0",
75 | "openai==1.14.2",
76 | "pandas==2.2.1",
77 | "parso==0.8.3",
78 | "pathvalidate==3.2.0",
79 | "peft==0.10.0",
80 | "pexpect==4.9.0",
81 | "platformdirs==4.2.0",
82 | "portalocker==2.8.2",
83 | "primefac==2.0.12",
84 | "prompt-toolkit==3.0.43",
85 | "psutil==5.9.8",
86 | "ptyprocess==0.7.0",
87 | "pure-eval==0.2.2",
88 | "py-cpuinfo==9.0.0",
89 | "pyarrow==15.0.2",
90 | "pyarrow-hotfix==0.6",
91 | "pybind11==2.11.1",
92 | "pycountry==23.12.11",
93 | "pydantic==2.6.4",
94 | "pydantic_core==2.16.3",
95 | "Pygments==2.17.2",
96 | "pynvml==11.5.0",
97 | "pyparsing==3.1.2",
98 | "pytablewriter==1.2.0",
99 | "python-dateutil==2.9.0.post0",
100 | "python-gflags==3.1.2",
101 | "pytz==2024.1",
102 | "pyzmq==25.1.2",
103 | "regex==2023.12.25",
104 | "responses==0.18.0",
105 | "rouge_score==0.1.2",
106 | "sacrebleu==1.5.0",
107 | "safetensors==0.4.2",
108 | "scikit-learn==1.4.1.post1",
109 | "scipy==1.12.0",
110 | "six==1.16.0",
111 | "sniffio==1.3.1",
112 | "sqlitedict==2.1.0",
113 | "stack-data==0.6.3",
114 | "tabledata==1.3.3",
115 | "tabulate==0.9.0",
116 | "tcolorpy==0.1.4",
117 | "threadpoolctl==3.4.0",
118 | "tokenizers==0.15.2",
119 | "tornado==6.4",
120 | "tqdm==4.66.2",
121 | "tqdm-multiprocess==0.0.11",
122 | "traitlets==5.14.2",
123 | "transformers==4.39.1",
124 | "triton==2.2.0",
125 | "typepy==1.3.2",
126 | "tzdata==2024.1",
127 | "wcwidth==0.2.13",
128 | "word2number==1.1",
129 | "xxhash==3.4.1",
130 | "yarl==1.9.4",
131 | "zstandard==0.22.0",
132 | "sentencepiece==0.2.0",
133 | "protobuf==5.26.0",
134 | "tensorboard==2.16.2",
135 | ]
136 |
137 | [tool.setuptools.packages.find]
138 | where = ["src"]
--------------------------------------------------------------------------------
/scripts/eval_ppl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from lib.utils import gptq_data_utils
3 | from tqdm import tqdm
4 | from quantize_save_llama import load_quantized_model
5 | from dataclasses import dataclass, field
6 | import transformers
7 | import os
8 | import json
9 | import glob
10 |
11 | # Adapted from https://github.com/Cornell-RelaxML/quip-sharp
12 |
13 | @dataclass
14 | class Arguments:
15 | model_save_path: str = field(metadata={
16 | "help": ("Path of the .pt file in which the model can be found.")
17 | })
18 | finetune_save_dir: str = field(default=None, metadata={
19 | "help": ("If using a finetuned model, the directory in which the "
20 | "model.safetensors file is stored")
21 | })
22 | output_path: str = field(default=None, metadata={
23 | "help": ("Path in which to save a JSON file with zero-shot results.")
24 | })
25 | base_model: str = field(default="meta-llama/Llama-2-7b-hf", metadata={
26 | "help": ("Path of the original model, as either a local or a "
27 | "Huggingface path")
28 | })
29 | seed: int = field(default=0, metadata={
30 | "help": "Random seed for selecting test points from the dataset"
31 | })
32 | seqlen: int = field(default=4096, metadata={
33 | "help": "Sequence length of model inputs"
34 | })
35 | device: str = field(default="cuda:0", metadata={
36 | "help": "Device on which to run evaluation."
37 | })
38 | datasets: list[str] = field(default_factory=list, metadata={
39 | "help": ("Which datasets, out of \"wikitext2\" and \"c4\" to compute "
40 | "perplexity. Defaults to both datasets")})
41 | cuda_graph: bool = field(default=False, metadata={
42 | "help": "Whether to use CUDA graphs and flash attention to speed up evaluation."
43 | })
44 |
45 |
46 | def eval_ppl(args: Arguments):
47 |
48 | with torch.no_grad():
49 | model = load_quantized_model(args.model_save_path, args.base_model, args.device, cuda_graph=args.cuda_graph)
50 |
51 | if args.finetune_save_dir is not None:
52 | from safetensors.torch import load_model
53 | for safetensor_file in glob.glob(args.finetune_save_dir + "/model*.safetensors"):
54 | print("Loading ", safetensor_file)
55 | load_model(model, safetensor_file, strict=False)
56 |
57 |
58 | if not args.datasets:
59 | args.datasets = ["wikitext2", "c4"]
60 |
61 | ppls = {}
62 |
63 | for dataset in args.datasets:
64 | input_tok = gptq_data_utils.get_test_tokens(
65 | dataset, seed=args.seed, seqlen=args.seqlen, model=args.base_model)
66 | nsamples = input_tok.numel() // args.seqlen
67 | input_tok = input_tok[0, :(args.seqlen * nsamples)].view(
68 | nsamples, args.seqlen)
69 |
70 | loss_fct = torch.nn.CrossEntropyLoss().cuda()
71 | acc_loss = 0.0
72 | progress = tqdm(range(nsamples))
73 | for ii in progress:
74 | input = input_tok[ii, :].to(args.device).view(1, -1)
75 | output = model(input,
76 | use_cache=False,
77 | output_hidden_states=False,
78 | output_attentions=False)[0]
79 |
80 | shift_logits = output[:, :-1, :].contiguous()
81 | shift_labels = input[:, 1:]
82 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
83 | shift_labels.view(-1))
84 | acc_loss += loss.item()
85 | progress.set_description(f"avg_loss = {acc_loss/(ii+1)}")
86 |
87 | avg_loss = acc_loss / nsamples
88 |
89 | ppl = torch.exp(torch.tensor(avg_loss)).item()
90 | print(f'{dataset} perplexity: {ppl}')
91 |
92 | ppls[dataset] = ppl
93 |
94 | if args.output_path is not None:
95 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
96 | with open(args.output_path + ".json", "w") as f:
97 | json.dump(ppls, f, indent=2)
98 |
99 |
100 | if __name__ == "__main__":
101 | parser = transformers.HfArgumentParser([Arguments])
102 | args = parser.parse_args_into_dataclasses()[0]
103 | eval_ppl(args)
--------------------------------------------------------------------------------
/notebooks/eval_throughput.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import gc\n",
20 | "import torch\n",
21 | "from lib.utils import graph_wrapper\n",
22 | "from transformers import AutoTokenizer, LlamaForCausalLM\n",
23 | "import time"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "def load_quantized_model(\n",
33 | " model_save_path,\n",
34 | " base_model,\n",
35 | " device,\n",
36 | "):\n",
37 | " model = torch.load(model_save_path, map_location=device).to(device) # Llama with Caldera\n",
38 | " graph_model = graph_wrapper.get_graph_wrapper(LlamaForCausalLM, device=\"cpu\").from_pretrained(\n",
39 | " base_model, torch_dtype='auto', device_map=\"cpu\", low_cpu_mem_usage=True,\n",
40 | " use_flash_attention_2=True\n",
41 | " ).to(\"cpu\") # base Llama\n",
42 | "\n",
43 | " for i in range(len(graph_model.model.layers)):\n",
44 | " graph_model.model.layers[i].self_attn.q_proj = model.model.layers[i].self_attn.q_proj\n",
45 | " graph_model.model.layers[i].self_attn.k_proj = model.model.layers[i].self_attn.k_proj\n",
46 | " graph_model.model.layers[i].self_attn.v_proj = model.model.layers[i].self_attn.v_proj\n",
47 | " graph_model.model.layers[i].self_attn.o_proj = model.model.layers[i].self_attn.o_proj\n",
48 | " graph_model.model.layers[i].mlp = model.model.layers[i].mlp\n",
49 | " graph_model.model.layers[i].post_attention_layernorm = graph_model.model.layers[i].post_attention_layernorm.to(device)\n",
50 | " graph_model.model.layers[i].input_layernorm = graph_model.model.layers[i].input_layernorm.to(device)\n",
51 | " graph_model.model.norm = graph_model.model.norm.to(device)\n",
52 | " graph_model.model.embed_tokens = graph_model.model.embed_tokens.to(device)\n",
53 | " graph_model.lm_head = graph_model.lm_head.to(device)\n",
54 | " graph_model.graph_device = device\n",
55 | " return graph_model.to(device)\n",
56 | " "
57 | ]
58 | },
59 | {
60 | "cell_type": "markdown",
61 | "metadata": {},
62 | "source": [
63 | "## Test Throughput of CALDERA Model"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "MODEL_PATH = \"/media/hdd1/caldera-full-models/llama-2-7b/caldera-rank-256-4B-factors-downdate-no-RHT-ft.pt\"\n",
73 | "BASE_MODEL = \"meta-llama/Llama-2-7b-hf\"\n",
74 | "DEVICE = \"cuda:2\"\n",
75 | "SAMPLES = 500"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "model = load_quantized_model(MODEL_PATH, BASE_MODEL, DEVICE)"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "def eval_throughput(model, samples, base_model, device, batch_size=1, seq_len=1):\n",
94 | " tokenizer = AutoTokenizer.from_pretrained(base_model)\n",
95 | "\n",
96 | " prompt = 'It is a truth universally acknowledged that'\n",
97 | " inputs = tokenizer(prompt, return_tensors='pt')\n",
98 | " token = inputs['input_ids'][0:1, 0:1].to(device).repeat(batch_size, seq_len)\n",
99 | " model(token)\n",
100 | "\n",
101 | " torch.cuda.synchronize()\n",
102 | " start = time.time()\n",
103 | " for _ in range(samples):\n",
104 | " model(token)\n",
105 | " torch.cuda.synchronize()\n",
106 | " end = time.time()\n",
107 | " print('TIME:', (end - start) / samples, 's/tok')\n",
108 | " print (f'THROUGHPUT: {samples / (end - start)} tok/s')"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "eval_throughput(model, SAMPLES, BASE_MODEL, DEVICE)"
118 | ]
119 | },
120 | {
121 | "cell_type": "markdown",
122 | "metadata": {},
123 | "source": [
124 | "## Compare with Unquantized"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "del model\n",
134 | "gc.collect()\n",
135 | "torch.cuda.empty_cache()"
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "model = graph_wrapper.get_graph_wrapper(LlamaForCausalLM, device=DEVICE).from_pretrained(\n",
145 | " BASE_MODEL, torch_dtype='auto', device_map=DEVICE, low_cpu_mem_usage=True,\n",
146 | " use_flash_attention_2=True\n",
147 | " )"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": null,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "eval_throughput(model, SAMPLES, BASE_MODEL, DEVICE)"
157 | ]
158 | }
159 | ],
160 | "metadata": {
161 | "kernelspec": {
162 | "display_name": "Python 3",
163 | "language": "python",
164 | "name": "python3"
165 | },
166 | "language_info": {
167 | "codemirror_mode": {
168 | "name": "ipython",
169 | "version": 3
170 | },
171 | "file_extension": ".py",
172 | "mimetype": "text/x-python",
173 | "name": "python",
174 | "nbconvert_exporter": "python",
175 | "pygments_lexer": "ipython3",
176 | "version": "3.11.9"
177 | }
178 | },
179 | "nbformat": 4,
180 | "nbformat_minor": 2
181 | }
182 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CALDERA (Calibration Aware Low-Precision DEcomposition with Low-Rank Adaptation)
2 |
3 | CALDERA is a post-training compression method that represents the weights of LLM matrices via a low-rank, low-precision decomposition $\mathbf{W} \approx \mathbf{Q} + \mathbf{L} \mathbf{R}$, where $\mathbf{L}$ and $\mathbf{R}$ are low-rank factors and $\mathbf{Q}, \mathbf{L}$ and $\mathbf{R}$ are all quantized to low-precision formats.
4 | By formulating this decomposition as an optimization problem and solving it via alternating minimization, CALDERA outperforms existing compression techniques in the regime of less than 2.5 bits per parameter.
5 | To enhance performance on specific tasks, CALDERA also supports Low Rank Adaptation (LoRA) fine tuning ([Hu et al, 2021](https://arxiv.org/pdf/2106.09685)) of a portion of the low-rank factors.
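
For intuition, here is a highly simplified sketch of the alternating pattern (a toy illustration only; it uses a naive uniform quantizer, leaves $\mathbf{L}$ and $\mathbf{R}$ unquantized, and ignores the activation-aware objective, unlike the actual implementation in `src/caldera/decomposition/alg.py`): fix the low-rank factors and quantize the residual to get $\mathbf{Q}$, then fix $\mathbf{Q}$ and refit $\mathbf{L}\mathbf{R}$ to the new residual.

```python
import torch

def toy_alternate(W, rank=16, bits=2, iters=10):
    # Naive symmetric uniform quantizer (a stand-in for CALDERA's quantizers).
    scale = W.abs().max() / (2 ** (bits - 1) - 1)
    quantize = lambda X: torch.round(X / scale).clamp(
        -2 ** (bits - 1), 2 ** (bits - 1) - 1) * scale

    L = torch.zeros(W.shape[0], rank)
    R = torch.zeros(rank, W.shape[1])
    for _ in range(iters):
        Q = quantize(W - L @ R)                    # backbone: quantize the residual
        U, S, Vh = torch.linalg.svd(W - Q, full_matrices=False)
        L, R = U[:, :rank] * S[:rank], Vh[:rank]   # best rank-k fit to W - Q
    return Q, L, R

W = torch.randn(512, 512)
Q, L, R = toy_alternate(W)
print(torch.linalg.norm(W - (Q + L @ R)) / torch.linalg.norm(W))  # relative error
```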
6 |
7 | 🔗 Paper link: [Compressing Large Language Models using Low Rank
8 | and Low Precision Decomposition](https://openreview.net/pdf?id=lkx3OpcqSZ)
9 |
10 |
11 | ![CALDERA decomposition of a weight matrix into Q + LR](assets/caldera_decomposition.png)
12 |
13 |
14 | CALDERA decomposes a full-precision weight matrix into a low-rank component $\mathbf{L}\mathbf{R}$, which captures the contribution of the top singular values using $\mathrm{B}_\mathrm{L}$ and $\mathrm{B}_\mathrm{R}$ bits, and a quantized backbone $\mathbf{Q}$, which captures the trailing singular values using $\mathrm{B}_\mathrm{Q}$ bits, enabling flexible precision settings for each component.
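
As a rough storage accounting (ignoring quantizer metadata such as scales): for $\mathbf{W} \in \mathbb{R}^{m \times n}$ with rank-$k$ factors,

$$
\text{bits per parameter} \;=\; \mathrm{B}_\mathrm{Q} + \frac{k\,(\mathrm{B}_\mathrm{L}\, m + \mathrm{B}_\mathrm{R}\, n)}{mn}.
$$

For example, with $m = n = 4096$, $k = 256$, $\mathrm{B}_\mathrm{Q} = 2$, and $\mathrm{B}_\mathrm{L} = \mathrm{B}_\mathrm{R} = 4$, this comes to $2 + 0.5 = 2.5$ bits per parameter, which is the sub-2.5-bit regime targeted above.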
15 |
16 |
17 | ## Setup Instructions
18 |
19 | ### Note on Python, CUDA, and PyTorch Versions
20 | These setup instructions have been tested on Python 3.10 and 3.11, CUDA 12.1 and 12.2, and PyTorch 2.2.
21 |
22 | In particular, the package `fast-hadamard-transform` lacks wheels for newer versions of these dependencies; the available wheels can be found [here](https://github.com/Dao-AILab/fast-hadamard-transform) (the wheel filenames are of the form `fast_hadamard_transform-1.0.4.post1+cu<CUDA VERSION>torch<TORCH VERSION>cxx11abiTRUE-cp<PYTHON VERSION>-cp<PYTHON VERSION>-linux_x86_64.whl`).
23 |
24 | ### 🛠 Instructions
25 | 1. Install `caldera` as a Python package (named `caldera`).
26 | From the home directory of this repository, run
27 | ```
28 | pip install .
29 | ```
30 | This will automatically install all dependencies except `fast-hadamard-transform`, which must be installed separately (see step 3) because of the wheel-compatibility issues noted above.
31 |
32 | 2. While CALDERA can be used with any quantizer, we demonstrate the results using QuIP#'s quantizer. Setup the QuIP# ([Tseng et al, 2024](https://arxiv.org/pdf/2402.04396)) submodule:
33 | ```
34 | ./setup_quip_sharp.sh
35 | ```
36 |
37 | This script first sets up the QuIP# Python library, and then builds the `quiptools` CUDA library, which provides dequantization kernels for inference.
38 |
39 | QuIP# is used for the quantization of the $\mathbf{Q}$ matrix (backbone), and also provides useful subroutines for Hessian computation.
40 |
41 | 3. Install `fast-hadamard-transform`: `pip install fast-hadamard-transform`.
42 |
43 | **Note**: If you get the error `package 'wheel' is not installed`, you can install it using `pip install wheel`.
44 |
45 |
46 | ## Repo structure
47 |
48 | ### `src/caldera`
49 | This folder contains the bulk of the code for CALDERA. Via step 1 above, everything in this folder is installed as part of the `caldera` Python package.
50 |
51 | **`src/caldera/utils`**: utils for CALDERA. Some relevant utils files are listed below:
52 | - `enums.py`: `Enum` objects, e.g., for specifying transformer sublayers (query, key, etc.) and the name of the calibration dataset.
53 | - `quantization.py`: Uniform and Normal Float ([Dettmers et al, 2023](https://arxiv.org/pdf/2305.14314)) quantizers.
54 | Generally, these are not recommended; E8 Lattice quantizers from QuIP# typically perform better.
55 |
56 | **`src/caldera/decomposition`**: code for the CALDERA decomposition algorithm, as well as its application to transformer layers.
57 |
58 | - `dataclasses.py`: classes for storing parameters of the CALDERA algorithm, as well as information about quantized layers.
59 | - `weight_compression.py`: code for the `ActivationAwareWeightCompressor` class. Unless Hessians have already been computed, this performs Hessian computation upon instantiation. The method `get_layer_quantizer`, called on a layer index, instantiates an `ActivationAwareLayerQuant` object.
60 | - `layer_quantization.py`: code for the `ActivationAwareLayerQuant` class. The `compress_sublayer` method compresses the specified sublayer, calling the `caldera` function from `alg.py`.
61 | There are also methods for plotting the data-aware error, saving errors and quantization parameters to a JSON file, and instantiating a quantized linear layer.
62 | - `alg.py`: the CALDERA algorithm. A minimal usage sketch is shown after this list.
63 | - `quantized_layer.py`: code for the `CalderaQuantizedLinear` class, which is a neural network module that computes $X^\top (Q + LR)^\top$ on layer input $X$, performing dequantization on the fly.
64 |
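As a quick orientation, the snippet below (condensed from `notebooks/test_caldera.ipynb`) applies the decomposition to a random matrix with uniform quantizers; the parameter values are illustrative, and it assumes the `caldera` package has been installed via step 1 of the setup instructions.

```python
import torch
from caldera.decomposition.dataclasses import CalderaParams
from caldera.decomposition.alg import caldera
from caldera.utils.quantization import QuantizerFactory

params = CalderaParams(
    compute_quantized_component=True, compute_low_rank_factors=True,
    Q_bits=4, L_bits=4, R_bits=4, rank=16, iters=20, lplr_iters=5,
    activation_aware_Q=False, activation_aware_LR=True,
    lattice_quant_Q=False, lattice_quant_LR=False,
    update_order=["Q", "LR"],
    quant_factory_Q=QuantizerFactory(method="uniform", block_size=64),
    quant_factory_LR=QuantizerFactory(method="uniform", block_size=64),
    rand_svd=False,
)

torch.manual_seed(42)
W = torch.rand(1024, 1024)   # matrix to decompose
X = torch.randn(1024, 2048)  # surrogate activations; H = XX^T plays the role of the Hessian
decomp = caldera(quant_params=params, W=W, H=X @ X.T,
                 device="cpu", use_tqdm=True, scale_W=True)
print(decomp.Q.shape, decomp.L.shape, decomp.R.shape)
```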
65 |
66 | ### `scripts`
67 | This folder contains python scripts for running zero-shot, perplexity, and finetuning experiments.
68 |
69 | Parameters for all of these scripts are specified via command-line arguments.
70 |
71 | ### `shell_scripts`
72 | These are Bash scripts for running experiments in the `scripts` folder with some reasonable parameters.
73 |
74 | Each shell script has variables at the top specifying, e.g., directories in which to save script outputs.
75 | Make sure you set those variables as appropriate.
76 |
77 | **Note**: all shell scripts are meant to be run from the root directory of this repo, i.e., `./shell_scripts/run_eval_ppl.sh` instead of `cd shell_scripts && ./run_eval_ppl.sh`.
78 |
79 | ### `quip-sharp`
80 | This is the quip-sharp submodule, which is initialized in step 2 of the setup instructions.
81 |
82 | ### `notebooks`
83 |
84 | - `test_caldera.ipynb`: quickly try out the CALDERA decomposition on a random matrix.
85 | - `eval_throughput.ipynb`: measure the autoregressive generation throughput of the quantized model.
86 |
87 | ## 🚀 Trying out CALDERA: Example experiment workflow
88 |
89 | **Note**: Edit each script before running it to make sure desired parameters are used.
90 |
91 | 1. **Compute the Hessians** using `./shell_scripts/run_save_hessians.sh`, which will store Hessian matrices for each layer to files.
92 |
93 | 2. **Quantize the full model** using `shell_scripts/run_quantize_save_caldera.sh`. This stores each quantized transformer layer.
94 | The quantized model can later be loaded with the `load_quantized_model` function in `scripts/quantize_save_llama.py`, as sketched after this list.
95 |
96 | 3. **Run zero-shot/perplexity experiments** using `shell_scripts/run_eval_zeroshot.sh` or `shell_scripts/run_eval_ppl.sh`.
97 |
98 | 4. **Finetune** using, e.g., `shell_scripts/run_finetune_wikitext.sh`.
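
For a quick sanity check after step 2, the quantized model can be loaded and used like an ordinary Hugging Face causal LM. The sketch below follows the `load_quantized_model` call pattern used in `scripts/eval_ppl.py`; the model path is a placeholder, and it assumes the returned model supports the usual `generate` API.

```python
import sys, torch
sys.path.append("scripts")  # run from the repo root
from transformers import AutoTokenizer
from quantize_save_llama import load_quantized_model

BASE_MODEL = "meta-llama/Llama-2-7b-hf"
model = load_quantized_model("path/to/caldera_model.pt", BASE_MODEL, "cuda:0")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

inputs = tokenizer("It is a truth universally acknowledged that",
                   return_tensors="pt").to("cuda:0")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```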
99 |
100 | ## Citation
101 | If you find our work useful, consider citing it as:
102 | ```bibtex
103 | @inproceedings{
104 | saha2024compressing,
105 | title={Compressing Large Language Models using Low Rank and Low Precision Decomposition},
106 | author={Rajarshi Saha and Naomi Sagan and Varun Srivastava and Andrea Goldsmith and Mert Pilanci},
107 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
108 | year={2024},
109 | url={https://openreview.net/forum?id=lkx3OpcqSZ}
110 | }
111 | ```
112 |
--------------------------------------------------------------------------------
/scripts/finetune_RHT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
4 | from lib.utils import sample_rp1t
5 | from lib.utils.data_utils import split_data
6 | import gc
7 | from quantize_save_llama import load_quantized_model
8 | from dataclasses import dataclass, field
9 | from tqdm import tqdm
10 | from safetensors.torch import save_model, load_model
11 | import os
12 |
13 | # Adapted from https://github.com/Cornell-RelaxML/quip-sharp
14 |
15 |
16 | @dataclass
17 | class Arguments:
18 | base_model: str = field(metadata={
19 | "help": ("Path of the original model, as either a local or a "
20 | "Huggingface path")
21 | })
22 | model_path: str = field(metadata={
23 | "help": ("Path of the .pt file in which the model can be found.")
24 | })
25 | finetuned_save_path: str = field(metadata={
26 | "help": ("Path in which to save the final finetuned model")
27 | })
28 | devset_size: int = field(default=256, metadata={
29 | "help": ("Number of datapoints to sample from the calibration set "
30 | "for finetuning.")
31 | })
32 | ctx_size: int = field(default=512, metadata={
33 | "help": ("Length of each input data sequence.")
34 | })
35 | device: str = field(default="cuda", metadata={
36 | "help": "Device to use for finetuning."
37 | })
38 | ft_bs: int = field(default=2, metadata={
39 | "help": "Batch size for finetuning."
40 | })
41 | ft_valid_size: int = field(default=64, metadata={
42 | "help": ("Number of datapoints to set aside for validation. "
43 | "The number of training datapoints is devset_size, minus "
44 | "ft_valid_size.")
45 | })
46 | finetune_factors: bool = field(default=False, metadata={
47 | "help": ("Whether to finetune L and R in addition to the randomized "
48 | "Hadamard transform diagonal matrices.")
49 | })
50 | RHT_learning_rate: float = field(default=1e-3, metadata={
51 | "help": "Learning rate for the randomized Hadamard transform parameters."
52 | })
53 | factors_learning_rate: float = field(default=1e-4, metadata={
54 | "help": "Learning rate for L andsR, if finetune_factors is set True."
55 | })
56 | epochs: int = field(default=5, metadata={
57 | "help": "Number of epochs of finetuning."
58 | })
59 |
60 |
61 | def main(args: Arguments):
62 | torch.set_grad_enabled(False)
63 | tokenizer = AutoTokenizer.from_pretrained(args.base_model)
64 | tokenizer.pad_token = tokenizer.eos_token
65 | devset = sample_rp1t(tokenizer, args.devset_size, args.ctx_size, 1)
66 |
67 | # Get the logits for the calibration set from the original model. The loss
68 | # function for finetuning will be the cross-entropy between the quantized model's
69 | # logits and the original model's output probabilities.
70 | orig_model = AutoModelForCausalLM.from_pretrained(
71 | args.base_model, torch_dtype='auto', device_map=args.device, low_cpu_mem_usage=True
72 | ).to(args.device)
73 |
74 | orig_logits = torch.zeros(args.devset_size, args.ctx_size, orig_model.config.vocab_size)
75 | for i in range(args.devset_size):
76 | input = devset[i:i+1].to(args.device)
77 | output = orig_model(input).logits.cpu()
78 | orig_logits[i:i+1, :, :] = output
79 |
80 | orig_logits = orig_logits[:, :-1].contiguous().softmax(dim=-1).float()
81 | del orig_model
82 | gc.collect()
83 | torch.cuda.empty_cache()
84 |
85 | torch.set_grad_enabled(True)
86 | quant_model = load_quantized_model(
87 | args.model_path, args.base_model, args.device
88 | ).to(args.device).float()
89 |
90 | factor_params = []
91 | RHT_params = []
92 |
93 | for name, param in quant_model.named_parameters():
94 | if 'L_ft' in name or 'R_ft' in name:
95 | if not args.finetune_factors:
96 | param.requires_grad = False
97 | else:
98 | factor_params.append(param)
99 | elif 'SU' in name or 'SV' in name:
100 | RHT_params.append(param)
101 | train_dataloader, valid_dataloader = split_data(devset, orig_logits, args)
102 |
103 | adam_params = [{
104 | 'params': RHT_params,
105 | 'lr': args.RHT_learning_rate
106 | }]
107 | if args.finetune_factors:
108 | adam_params.append({
109 | 'params': factor_params,
110 | 'lr': args.factors_learning_rate
111 | })
112 | optim = torch.optim.AdamW(adam_params)
113 | scaler = torch.cuda.amp.GradScaler(enabled=True)
114 |
115 | quant_model.eval()
116 | print("Running eval.")
117 | with torch.no_grad():
118 | val_loss = 0
119 | for _, (input, targets) in enumerate(valid_dataloader):
120 | input, targets = input.to(args.device), targets.to(args.device)
121 | output = quant_model(input).logits[:, :-1].contiguous()
122 |
123 | val_loss += nn.CrossEntropyLoss()(
124 | output.view(-1, output.shape[-1]),
125 | targets.view(-1, targets.shape[-1])).item()
126 | val_loss /= len(valid_dataloader)
127 | print("Validation loss: ", val_loss)
128 | best_val_loss = val_loss
129 | save_model(quant_model, args.finetuned_save_path)
130 |
131 | progress_bar = tqdm(range(len(train_dataloader)*args.epochs))
132 | for _ in range(args.epochs):
133 | for _, (input, targets) in enumerate(train_dataloader):
134 | input, targets = input.to(args.device), targets.to(args.device)
135 | output = quant_model(input).logits[:, :-1].contiguous()
136 |
137 | loss = nn.CrossEntropyLoss()(output.view(-1, output.shape[-1]),
138 | targets.view(-1, targets.shape[-1]))
139 | scaler.scale(loss).backward()
140 | scaler.step(optim)
141 | scaler.update()
142 | optim.zero_grad()
143 | progress_bar.update(1)
144 |
145 | # validation
146 | quant_model.eval()
147 | print("Running eval.")
148 | with torch.no_grad():
149 | val_loss = 0
150 | for _, (input, targets) in enumerate(valid_dataloader):
151 | input, targets = input.to(args.device), targets.to(args.device)
152 | output = quant_model(input).logits[:, :-1].contiguous()
153 |
154 | val_loss += nn.CrossEntropyLoss()(
155 | output.view(-1, output.shape[-1]),
156 | targets.view(-1, targets.shape[-1])).item()
157 | val_loss /= len(valid_dataloader)
158 | print("Validation loss: ", val_loss)
159 | if val_loss < best_val_loss:
160 | save_model(quant_model, args.finetuned_save_path)
161 | best_val_loss = val_loss
162 | quant_model.train()
163 |
164 | quant_model = load_quantized_model(
165 | args.model_path, args.base_model, args.device
166 | ).to(args.device)
167 | load_model(quant_model, args.finetuned_save_path)
168 | torch.save(quant_model, args.finetuned_save_path)
169 |
170 | if __name__ == "__main__":
171 | parser = HfArgumentParser([Arguments])
172 | args = parser.parse_args()
173 | main(args)
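174 |
175 | # Illustrative invocation (paths are placeholders; the flags mirror the
176 | # Arguments dataclass above):
177 | #   python scripts/finetune_RHT.py \
178 | #       --base_model meta-llama/Llama-2-7b-hf \
179 | #       --model_path artifacts/model.pt \
180 | #       --finetuned_save_path artifacts/model_RHT_ft.pt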
--------------------------------------------------------------------------------
/src/caldera/decomposition/dataclasses.py:
--------------------------------------------------------------------------------
1 | from caldera.utils.enums import DevSet
2 | from caldera.utils.quantization import QuantizerFactory, \
3 | AbstractQuantizer, LowMemoryQuantizer
4 |
5 | from dataclasses import field, dataclass
6 | import torch
7 |
8 |
9 | @dataclass
10 | class DataParameters:
11 | """
12 | Parameters for loading the calibration dataset and computing the
13 | inputs to each layer.
14 | """
15 |     devset: DevSet = field(
16 | default=DevSet.RP1T, metadata={"help": (
17 | "Calibration dataset, as a member of the DevSet enum."
18 | )}
19 | )
20 | devset_size: int = field(
21 | default=256, metadata={"help": (
22 | "Number of calibration samples to use."
23 | )}
24 | )
25 | context_length: int = field(
26 | default=4096, metadata={"help": (
27 | "Length of context window."
28 | )}
29 | )
30 | batch_size: int = field(
31 | default=2, metadata={"help": (
32 | "Number of datapoints to pass into the model at once."
33 | )}
34 | )
35 | chunk_size: int = field(
36 | default=256, metadata={"help": (
37 | "Number of datapoints sent to each GPU at a time. "
38 | "Must be a multiple of batch_size"
39 | )}
40 | )
41 | devices: list[str] = field(
42 | default=None, metadata={"help": (
43 | "Specific CUDA devices to use for Hessian computation. Defaults "
44 | "to None, which means that all available devices are used."
45 | )}
46 | )
47 |
48 |
49 | @dataclass
50 | class ModelParameters:
51 | """
52 | Parameters for loading in a transformer model and simulating forward
53 | passes.
54 | """
55 | base_model: str = field(
56 | default="meta-llama/Llama-2-7b-hf", metadata={"help": (
57 | "Model to quantize."
58 | )}
59 | )
60 | token: str = field(default=None, metadata={
61 | "help": "Huggingface access token for private models."
62 | })
63 |
64 |
65 | @dataclass
66 | class AccumulatorArgs:
67 | """
68 | These arguments will be passed into the `accumulator` function in
69 | quip-sharp/quantize_llama/hessian_offline_llama.py. This class will be
70 | automatically instantiated by ActivationAwareWeightCompressor.
71 | """
72 | scratch_path = None
73 | save_path: str = field(default="./hessians/")
74 |
75 |
76 | @dataclass
77 | class QuIPArgs:
78 | """
79 | Parameters for QuIP. See the documentation of the quip-sharp repository
80 | for descriptions of these (or some of these) parameters.
81 | """
82 | lora_rank: int = field(default=0)
83 | full_svd: bool = field(default=False)
84 | use_fp64: bool = False
85 | lowmem_ldlq: bool = field(default=False)
86 | scale_override: float = field(default=0.9)
87 | resid_scale_override: float = field(default=0.9)
88 | no_use_buffered: bool = field(default=False)
89 | sigma_reg: float = field(default=1e-2)
90 | sigma_reg2: float = field(default=1e-2)
91 | incoh_mode: str = field(default="had", metadata={
92 |         "help": ("Which form of incoherence processing to use. Either \"had\" "
93 | "for a randomized Hadamard transform, or \"kron\" for a "
94 | "randomized Kronecker product of 2x2 matrices.")
95 | })
96 | rescale_WH: bool = field(default=False)
97 | quip_tune_iters: int = field(default=10, metadata={
98 | "help": ("Number of iterations in the LDLQ step.")
99 | })
100 |
101 |
102 | @dataclass
103 | class CalderaParams:
104 | """
105 | Parameters for the CALDERA decomposition.
106 | """
107 | quip_args: QuIPArgs = field(default_factory=QuIPArgs)
108 | compute_quantized_component: bool = field(
109 | default=True, metadata={"help": (
110 |             "Whether the decomposition should include a quantized full-size "
111 | "component (denoted Q)."
112 | )}
113 | )
114 | compute_low_rank_factors: bool = field(
115 | default=True, metadata={"help": (
116 | "Whether the decomposition should include low-rank factors (L, R)."
117 | )}
118 | )
119 | Q_bits: int = field(default=2, metadata={
120 | "help": "Either 2, 3, or 4 bit lattice quantization"
121 | })
122 | L_bits: int = field(default=2, metadata={
123 | "help": "Either 2, 3, or 4 bit lattice quantization"
124 | })
125 | R_bits: int = field(default=2, metadata={
126 | "help": "Either 2, 3, or 4 bit lattice quantization"
127 | })
128 | rank: int = field(default=64, metadata={
129 | "help": "Rank of L and R factors"
130 | })
131 | iters: int = field(default=20)
132 | lplr_iters: int = field(default=5)
133 | activation_aware_Q: bool = field(default=True, metadata={
134 | "help": ("Use QuIP# activation-aware quantization for Q, as opposed "
135 | "to naive quantization.")
136 | })
137 | activation_aware_LR: bool = field(default=True, metadata={
138 | "help": "Use activation-aware LPLR for computing the factors."
139 | })
140 | lattice_quant_Q: bool = field(default=True, metadata={
141 | "help": ("If Q is not data-aware, this determines whether to use "
142 | "lattice quantization, as opposed to unif/normal float quantization "
143 | "implementations.")
144 | })
145 | lattice_quant_LR: bool = field(default=True, metadata={
146 | "help": ("Use lattice quantization from the QuIP# codebase, as opposed"
147 | " to uniform or normal float, for L and R")
148 | })
149 | hadamard_transform: bool = field(default=False, metadata={
150 | "help": ("Whether to perform a randomized Hadamard transform on W "
151 | "before computing the decomposition W = Q + LR.")
152 | })
153 | full_quip_sharp: bool = field(default=False, metadata={
154 | "help": ("If Q is activation-aware and this parameter is True, then "
155 | "Q is computed using the full quip-sharp algorithm. "
156 | "Otherwise, we only use LDLQ.")
157 | })
158 | update_order: list[str] = field(default_factory=list, metadata={
159 | "help": ("List specifying whether to update the \"LR\" factors before "
160 |                  "\"Q\" or vice versa. The default is [\"LR\", \"Q\"]; pass "
161 | "in [\"Q\", \"LR\"] to swap the update order.")
162 | })
163 | quant_factory_Q: QuantizerFactory = field(
164 | default_factory=QuantizerFactory, metadata={"help": (
165 |             "(Non-data-aware only) QuantizerFactory (from caldera.utils.quantization)"
166 | " object used to instantiate quantizer for Q. Only used if "
167 | "activation_aware_Q is False."
168 | )}
169 | )
170 | quant_factory_LR: QuantizerFactory = field(
171 | default_factory=QuantizerFactory, metadata={"help": (
172 | "(Non-lattice quant only) QuantizerFactory (from "
173 |             "caldera.utils.quantization) object used to instantiate quantizer for L "
174 | "and R. Only used if lattice_quant_LR is False."
175 | )}
176 | )
177 | rand_svd: bool = field(default=True, metadata={
178 | "help": "Whether to use randomized SVD for LPLR initialization"
179 | })
180 | Q_hessian_downdate: bool = field(default=False, metadata={
181 |         "help": ("Whether to do quip-sharp's heuristic Hessian correction "
182 | "via Cholesky downdating before updating Q.")
183 | })
184 | lattice_quant_block_size: int = field(default=32000, metadata={
185 | "help": ("For lattice quantization, quantize parameters in groups of "
186 | "(codesize * lattice_quant_block_size) to reduce memory "
187 | "usage")
188 | })
189 |
190 | @dataclass
191 | class CalderaDecomposition:
192 | Q: torch.Tensor = field(default=None)
193 | L: torch.Tensor = field(default=None)
194 | R: torch.Tensor = field(default=None)
195 | W: torch.Tensor = field(default=None)
196 | Q_idxs: torch.Tensor = field(default=None)
197 | L_idxs: torch.Tensor = field(default=None)
198 | R_idxs: torch.Tensor = field(default=None)
199 | Q_scale: float = field(default=1)
200 | L_scale: float = field(default=1)
201 | R_scale: float = field(default=1)
202 | global_scale: float = field(default=1)
203 | SU: torch.Tensor = field(default=None)
204 | SV: torch.Tensor = field(default=None)
205 | scaleWH: torch.Tensor = field(default=None)
206 | errors: dict[str,list[float]] = field(default_factory=dict)
207 |
208 |
209 | @dataclass
210 | class SubLayerInfo:
211 | """
212 | Class for storing information about a transformer sub-layer (i.e., one of
213 | {query, key, value, out, gate, up, down}), including the computed
214 | decomposition Q + LR, and the activation-aware error at each iteration of
215 | the CALDERA algorithm.
216 | """
217 | sublayer: torch.nn.Module = field(default=None)
218 | key: str = field(default="")
219 | out_key: str = field(default="")
220 | started_quant: bool = field(default=False)
221 | caldera: CalderaDecomposition = field(default_factory=CalderaDecomposition)
222 |
223 |
224 | @dataclass
225 | class QuantInfo:
226 | """
227 | Stores information necessary for quantizing a specific matrix:
228 | 1. Whether to use lattice quantization (QuIP#) or Unif./NormalFloat
229 | quantization.
230 | 2. If lattice quantization is used, the codebook.
231 | 3. If our quantization methods are used, the quantizer object.
232 | """
233 | lattice_quant: bool = field(default=True)
234 | lattice_cb: torch.nn.Module = field(default=None)
235 | quant: AbstractQuantizer = field(default_factory=LowMemoryQuantizer)
236 |
237 |
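238 | # Illustrative configuration sketch (values are examples, not library
239 | # defaults): a 2-bit quantized full-size component Q plus rank-64 factors
240 | # L and R, each lattice-quantized to 4 bits.
241 | #
242 | #   params = CalderaParams(
243 | #       Q_bits=2, L_bits=4, R_bits=4, rank=64,
244 | #       iters=20, lplr_iters=5,
245 | #       quip_args=QuIPArgs(quip_tune_iters=10),
246 | #   )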
--------------------------------------------------------------------------------
/src/caldera/utils/quantization.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from scipy import stats
4 | from abc import ABC, abstractmethod
5 | import numpy as np
6 |
7 |
8 | class AbstractQuantizer(ABC):
9 | @abstractmethod
10 | def quantize_block(self, weight):
11 | ...
12 |
13 | @abstractmethod
14 | def dequantize_block(self, weight_quant, weight_max, weight_shape):
15 | ...
16 |
17 | class LowMemoryQuantizer(AbstractQuantizer):
18 | def __init__(self, num_bits=2, method="normal", block_size=64):
19 | self.num_bits = num_bits
20 | assert num_bits in [2, 4, 8, 16]
21 | self.method = method
22 | self.block_size = block_size
23 |
24 | if self.method != "normal" and self.method != "uniform" and self.method != "uniform_clipped":
25 | raise NotImplementedError("Other quantization methods not supported yet.")
26 |
27 | def _quantize_uniform(self, weight_divabs):
28 | # weight_divabs is between -1 and 1, inclusive
29 | weight_scaled = weight_divabs * (2**(self.num_bits - 1) - 1)
30 | weight_scaled = weight_scaled.round()
31 | if self.num_bits <= 8:
32 | return weight_scaled.to(torch.int8)
33 | else:
34 | return weight_scaled.to(torch.int16)
35 |
36 | def _quantize_nf(self, weight_divabs):
37 | Normal = torch.distributions.Normal(0, 1)
38 | # We quantize the range [-1, 0) and the range [0, 1] separately, with
39 | # each having 2^{b-1} levels.
40 | #
41 | # The quantization levels are found as follows: take 2^{b-1} evenly-spaced
42 |         # points from [delta, 1/2] and 2^{b-1} + 1 from [1/2, 1-delta], where delta
43 | # is as defined below. The quantization levels are the corresponding
44 | # quantiles of a standard normal distribution, scaled such that they lie
45 | # in the range [-1, 1].
46 | M = 2**(self.num_bits-1)
47 | delta = 1/2 * (1/30 + 1/32) # as described above
48 | res_neg = (1/2 - delta) / (M - 1) # resolution for [delta, 1/2]
49 | res_pos = (1/2 - delta) / M # resolution for [1/2, 1-delta]
50 | # levels to be in [-1, 1]
51 |
52 | # We index into q_neg and q_pos with these indices to get the quantized
53 | # values for the negative and positive parts of A, respectively.
54 | q_neg = Normal.icdf(res_neg * torch.arange(M).to(weight_divabs.device) + delta) / stats.norm.ppf(1-delta)
55 | q_pos = Normal.icdf(res_pos * torch.arange(M + 1).to(weight_divabs.device) + 1/2) / stats.norm.ppf(1-delta)
56 |
57 | neg_quantiles = (weight_divabs < 0) * \
58 | ((Normal.cdf(weight_divabs * stats.norm.ppf(1-delta)) - delta) / res_neg)
59 | neg_quantiles_round_down = neg_quantiles.floor().long()
60 | neg_quantiles_round_up = torch.minimum(neg_quantiles.ceil().long(), torch.tensor(M-1))
61 | mask = (torch.abs(weight_divabs - q_neg[neg_quantiles_round_down]) <= torch.abs(weight_divabs - q_neg[neg_quantiles_round_up]))
62 | neg_quant_idxs = neg_quantiles_round_down * mask + neg_quantiles_round_up * (~mask)
63 |
64 | pos_quantiles = (weight_divabs >= 0) * \
65 | ((Normal.cdf(weight_divabs * stats.norm.ppf(1-delta)) - 1/2) / res_pos)
66 | pos_quantiles_round_down = pos_quantiles.floor().long()
67 | pos_quantiles_round_up = torch.minimum(pos_quantiles.ceil().long(), torch.tensor(M))
68 | mask = (torch.abs(weight_divabs - q_pos[pos_quantiles_round_down]) <= torch.abs(weight_divabs - q_pos[pos_quantiles_round_up]))
69 | pos_quant_idxs = pos_quantiles_round_down * mask + pos_quantiles_round_up * (~mask)
70 |
71 | idxs = neg_quant_idxs + (weight_divabs >= 0) * (pos_quant_idxs + M - 1)
72 |
73 | if self.num_bits <= 8:
74 | return idxs.to(torch.uint8)
75 | else:
76 | return idxs
77 |
78 | def _dequantize_uniform(self, weight_quant):
79 | return weight_quant.float() / (2**(self.num_bits - 1) - 1)
80 |
81 | def _dequantize_nf(self, weight_quant):
82 | Normal = torch.distributions.Normal(0, 1)
83 | M = 2**(self.num_bits-1)
84 | delta = 1/2 * (1/30 + 1/32) # as described above
85 | res_neg = (1/2 - delta) / (M - 1) # resolution for [delta, 1/2]
86 | res_pos = (1/2 - delta) / M # resolution for [1/2, 1-delta]
87 | # levels to be in [-1, 1]
88 | # quantization levels for the negative and positive halves, respectively
89 | q_neg = Normal.icdf(res_neg * torch.arange(M - 1).to(weight_quant.device) + delta) / stats.norm.ppf(1-delta)
90 | q_pos = Normal.icdf(res_pos * torch.arange(M + 1).to(weight_quant.device) + 1/2) / stats.norm.ppf(1-delta)
91 | q_levels = torch.cat((q_neg, q_pos))
92 | return q_levels[weight_quant.long()]
93 |
94 | def quantize_block(self, weight, epsilon=1e-8):
95 | if len(weight.shape) != 2:
96 | raise ValueError(f"Only support 2D matrix, but your input has {len(weight.shape)} dimensions.")
97 | if weight.shape[0] * weight.shape[1] % self.block_size != 0:
98 | raise ValueError(
99 | f"Weight with shape ({weight.shape[0]} x {weight.shape[1]}) "
100 |                 f"is not divisible by block size {self.block_size}."
101 | )
102 |
103 |         weight_reshape = weight.flatten().reshape(-1, self.block_size) # shape (M*N/B, B)
104 | weight_max = weight_reshape.abs().max(dim=-1)[0].unsqueeze(-1)
105 | if self.method == "uniform_clipped":
106 | weight_max = weight_reshape.mean(dim=1) + 2.5 * weight_reshape.std(dim=1)
107 | weight_max = weight_max.unsqueeze(-1)
108 | weight_reshape = torch.minimum(weight_reshape, weight_max)
109 | weight_reshape = torch.maximum(weight_reshape, -weight_max)
110 | weight_max = torch.maximum(weight_max, torch.Tensor([epsilon]).to(weight.device))
111 | weight_divabs = weight_reshape / weight_max
112 | if self.method == "normal":
113 | weight_quant = self._quantize_nf(weight_divabs)
114 | else:
115 | weight_quant = self._quantize_uniform(weight_divabs)
116 | return weight_quant, weight_max, weight.shape
117 |
118 | def dequantize_block(self, weight_quant, weight_max, weight_shape):
119 | if self.method == "normal":
120 | weight = self._dequantize_nf(weight_quant)
121 | else:
122 | weight = self._dequantize_uniform(weight_quant)
123 |
124 | return (weight * weight_max).reshape(weight_shape)
125 |
126 | def simulated_quant(
127 | quant: AbstractQuantizer,
128 | X: torch.Tensor
129 | ):
130 | if quant.num_bits >= 16:
131 | if X.dtype == torch.float32:
132 | return X.to(dtype=torch.bfloat16).float()
133 | return X
134 | return quant.dequantize_block(*quant.quantize_block(X))
135 |
136 | class QuantizerFactory:
137 | def __init__(self, method="uniform", block_size=64):
138 | self.method = method
139 | self.block_size = block_size
140 |
141 | def get_quantizer(self, num_bits, device="cpu"):
142 | return LowMemoryQuantizer(num_bits=num_bits, method=self.method, block_size=self.block_size)
143 |
144 | def __str__(self):
145 | return f"QuantizerFactory(method={self.method}, block_size={self.block_size})"
146 |
147 | def mixed_precision_quantize(
148 | X: torch.Tensor = None,
149 |     strip_widths: list[int] = None,
150 | quantizers: list[AbstractQuantizer] = None,
151 | transposed: bool = False
152 | ) -> torch.Tensor:
153 | """
154 | Given a matrix X, quantizes strips of columns with different bit levels.
155 | The width of each quantized strip is given by the argument strip_widths,
156 | and the corresponding quantizer object for each strip is given by the
157 | argument quantizers.
158 |
159 | There is a one-to-one correspondence between the elements of strip_widths
160 | and the elements of quantizers, so the lists must be the same length.
161 | If sum(strip_widths) is less than the total number of columns in X, the
162 | remaining columns are dropped.
163 |
164 | If transposed is set to True, it quantizes strips of rows instead.
165 | """
166 | assert len(strip_widths) == len(quantizers)
167 | if transposed:
168 | X = X.T
169 |
170 |     assert sum(strip_widths) <= X.shape[1], "sum(strip_widths) should be at most X.shape[1]"
171 |
172 | idxs = np.hstack(([0], np.cumsum(strip_widths)))
173 |     # Perform simulated quantization by quantizing and then dequantizing
174 | quantized_components = [
175 | simulated_quant(quantizers[i], X[:, idxs[i]:idxs[i+1]]) \
176 | for i in range(len(strip_widths)) \
177 | if idxs[i+1] > idxs[i]
178 | ]
179 | X_quant = torch.cat(quantized_components, dim=1)
180 | return X_quant.T if transposed else X_quant
181 |
182 | def quantize_small_sv_components(
183 | X: torch.Tensor = None,
184 | r: int = 0,
185 |     quantizer: AbstractQuantizer = None,
186 | transposed: bool = False
187 | ) -> torch.Tensor:
188 | """
189 | Keep the first r columns in original dtype and quantize the last
190 | (X.shape[1] - r) columns.
191 | If "transposed" is True, then quantize rows instead of columns.
192 |
193 |     The `quantizer` argument (an AbstractQuantizer, e.g., LowMemoryQuantizer)
194 |     determines whether uniform or normal-float quantization is applied to the
195 |     quantized columns.
196 | """
197 |
198 | if transposed:
199 | X = X.T
200 |     assert r <= X.shape[1], "r should be at most X.shape[1]"
201 |
202 | if r == X.shape[1]:
203 | return X
204 |
205 |     # Perform simulated quantization by quantizing and then dequantizing
206 | quantized_component = simulated_quant(quantizer, X[:, r:])
207 | X_quant = torch.cat((X[:, :r], quantized_component), dim=1)
208 | return X_quant.T if transposed else X_quant
209 |
210 | def absmax_quantize_int8(X: torch.Tensor) -> tuple[float, torch.Tensor]:
211 |     """Quantize a float16/32 tensor to int8; returns the absmax scale (max/127) and the int8 tensor."""
212 |     scale = X.abs().max().item() / 127.0
213 |     int8_tensor = (X / scale).round().to(torch.int8)
214 |     return scale, int8_tensor
215 |
216 | def absmax_dequantize_int8(Xq: torch.Tensor, scale: float) -> torch.Tensor:
217 |     """Dequantize an int8 tensor back to float16 using the stored absmax scale."""
218 |     return Xq.to(torch.float16) * scale
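219 |
220 | if __name__ == "__main__":
221 |     # Minimal self-check sketch (illustrative, CPU-only): round-trip a random
222 |     # matrix through 4-bit NormalFloat simulated quantization and report the
223 |     # relative reconstruction error.
224 |     torch.manual_seed(0)
225 |     quantizer = QuantizerFactory(method="normal", block_size=64).get_quantizer(num_bits=4)
226 |     W = torch.randn(128, 256)
227 |     W_hat = simulated_quant(quantizer, W)
228 |     rel_err = ((W - W_hat).norm() / W.norm()).item()
229 |     print(f"4-bit NF relative reconstruction error: {rel_err:.4f}")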
--------------------------------------------------------------------------------
/src/caldera/decomposition/weight_compression.py:
--------------------------------------------------------------------------------
1 | from quantize_llama.hessian_offline_llama import forward_layer, accumulate
2 | from lib.utils.data_utils import sample_rp1t, sample_falcon_refinedweb
3 |
4 | from transformers import AutoModelForCausalLM, AutoTokenizer
5 | from transformers.modeling_attn_mask_utils import \
6 | _prepare_4d_causal_attention_mask
7 |
8 | import torch
9 | import torch.multiprocessing as mp
10 | import gc
11 |
12 | from caldera.decomposition.dataclasses import *
13 | from caldera.utils.enums import DevSet
14 | from caldera.decomposition.layer_quantization import \
15 | ActivationAwareLayerQuant
16 |
18 | import os
19 |
20 |
21 | # Partially adapted from https://github.com/Cornell-RelaxML/quip-sharp
22 |
23 |
24 | class ActivationAwareWeightCompressor:
25 | """
26 | Sets up the framework for activation aware weight compression: loads in
27 | the model and calibration dataset, and then computes the inputs to each
28 | layer and the corresponding Hessian matrix (the second moment of the
29 | inputs). The inputs and Hessians are stored in files at
30 | `hessian_save_path`.
31 |
32 |     You can instantiate an `ActivationAwareLayerQuant` for a given layer by
33 | calling the `get_layer_quantizer` method. See `ActivationAwareLayerQuant`
34 | for more usage details.
35 |
36 | Note: if you have already done Hessian computation and the data is stored
37 |     in the appropriate files under `hessian_save_path`, you can pass in
38 |     `compute_hessians=False` to skip the data sampling and Hessian computation.
39 | """
40 | def __init__(
41 | self,
42 | model_params: ModelParameters = ModelParameters(),
43 | data_params: DataParameters = DataParameters(),
44 | quant_params: CalderaParams = CalderaParams(),
45 | quant_device: str = "cuda",
46 | hessian_save_path: str = "",
47 | start_at_layer: int = 0,
48 | stop_at_layer: int = None,
49 | n_sample_proc: int = 4,
50 | compute_hessians: bool = True
51 | ):
52 | try:
53 | mp.set_start_method('spawn')
54 | except RuntimeError:
55 |             print(("Warning: an exception occurred when setting the "
56 | "multiprocessing context. If you have previously "
57 | "instantiated an ActivationAwareWeightCompressor, you can "
58 | "ignore this."))
59 |
60 | self.compute_hessians = compute_hessians
61 |
62 | torch.set_grad_enabled(False)
63 | os.makedirs(hessian_save_path, exist_ok=True)
64 |
65 | # passed into `ActivationAwareLayerQuant` when `get_layer_quantizer`
66 | # is called:
67 | self.hessian_save_path = hessian_save_path
68 | self.quant_params = quant_params
69 | self.quant_device = quant_device
70 | self.data_params = data_params
71 |
72 | self._setup_model_and_data(
73 | model_params,
74 | data_params,
75 | n_sample_proc
76 | )
77 |
78 | if stop_at_layer is None:
79 | stop_at_layer = float('inf')
80 |
81 | if self.compute_hessians:
82 |             # Loop through transformer layers and compute + save the Hessians
83 | for transformer_layer_index, transformer_layer \
84 | in enumerate(self.model.model.layers):
85 | if transformer_layer_index < start_at_layer:
86 | continue
87 | if transformer_layer_index >= stop_at_layer:
88 | break
89 | self._process_layer(
90 | transformer_layer_index=transformer_layer_index,
91 | transformer_layer=transformer_layer,
92 | data_params=data_params,
93 | hessian_save_path=hessian_save_path
94 | )
95 |
96 | def get_layer_quantizer(
97 | self,
98 | layer_idx: int,
99 | device: str = None,
100 | label: str = None
101 | ):
102 | """
103 | Instantiates an `ActivationAwareLayerQuant` object for a given
104 | transformer layer.
105 | """
106 | assert layer_idx >= 0 and layer_idx < len(self.model.model.layers)
107 |
108 | if device is None:
109 | device = self.quant_device
110 |
111 | return ActivationAwareLayerQuant(
112 | layer=self.model.model.layers[layer_idx],
113 | layer_idx=layer_idx,
114 | hessian_save_path=self.hessian_save_path,
115 | quant_params=self.quant_params,
116 | device=device,
117 | label=label
118 | )
119 |
120 | def _setup_model_and_data(
121 | self,
122 | model_params: ModelParameters,
123 | data_params: DataParameters,
124 | n_sample_proc: int # Number of processes used to sample the
125 | # calibration dataset. Unrelated to Hessian
126 | # computation.
127 | ):
128 | """
129 | Loads in the model and calibration dataset.
130 | """
131 | # Model
132 | self.model = AutoModelForCausalLM.from_pretrained(
133 | model_params.base_model, torch_dtype="auto", low_cpu_mem_usage=True,
134 | token=model_params.token
135 | )
136 |
137 | if self.compute_hessians:
138 | # Tokenizer
139 | tokenizer = AutoTokenizer.from_pretrained(
140 | model_params.base_model, use_fast=True,
141 | token=model_params.token
142 | )
143 | tokenizer.pad_token = tokenizer.eos_token
144 |
145 | # Calibration dataset
146 | if data_params.devset == DevSet.RP1T:
147 | self.devset = sample_rp1t(tokenizer,
148 | data_params.devset_size,
149 | data_params.context_length,
150 | nproc=n_sample_proc)
151 | self.dev_emb = self.model.model.embed_tokens(self.devset)
152 | elif data_params.devset == DevSet.FALCON:
153 | self.devset = sample_falcon_refinedweb(
154 | tokenizer, data_params.devset_size,
155 | data_params.context_length,
156 | nproc=n_sample_proc
157 | )
158 | self.dev_emb = self.model.model.embed_tokens(self.devset)
159 | else:
160 | raise NotImplementedError("Dataset not implemented yet")
161 | self.dev_emb.share_memory_()
162 |
163 | # Attention mask and position IDs
164 | self.position_ids = torch.arange(
165 | data_params.context_length, dtype=torch.int64
166 | )[None, :] + torch.zeros(
167 | data_params.batch_size,
168 | data_params.context_length,
169 | dtype=torch.int64
170 | )
171 |
172 | if hasattr(self.model.config, 'sliding_window'):
173 | self.attention_mask = _prepare_4d_causal_attention_mask(
174 | None, (data_params.batch_size, data_params.context_length),
175 | self.dev_emb[0:data_params.batch_size], 0,
176 | sliding_window=self.model.config.sliding_window
177 | )
178 | else:
179 | self.attention_mask = _prepare_4d_causal_attention_mask(
180 | None, (data_params.batch_size, data_params.context_length),
181 | self.dev_emb[0:data_params.batch_size], 0
182 | )
183 |
184 | def _process_layer(
185 | self,
186 | transformer_layer_index: int,
187 | transformer_layer: torch.nn.Module,
188 | data_params: DataParameters,
189 | hessian_save_path: str
190 | ):
191 | """
192 | Compute the layer inputs and Hessians via the same process as
193 | quip-sharp/quantize_llama/hessian_offline_llama.py.
194 | """
195 |         # Check that there are seven linear sublayers (Q, K, V, O attn + 3 MLP), as expected
196 | assert (len([
197 | m for m in transformer_layer.modules()
198 | if isinstance(m, torch.nn.Linear)
199 | ]) == 7)
200 |
201 | chunk_size = min(data_params.chunk_size, len(self.dev_emb))
202 |
203 | devices_available = data_params.devices if \
204 | data_params.devices is not None else \
205 | range(torch.cuda.device_count())
206 | ngpus = min(len(devices_available), len(self.dev_emb) // chunk_size)
207 | devices = devices_available[:ngpus]
208 | print(f"Computing hessians on {devices}")
209 |
210 | manager = mp.get_context('spawn').Manager()
211 | in_q = manager.Queue()
212 | out_q = manager.Queue()
213 |
214 | accumulate_proc = mp.Process(
215 | target=accumulate,
216 | args=(
217 | out_q, None, ngpus,
218 | AccumulatorArgs(save_path=hessian_save_path),
219 | transformer_layer_index
220 | )
221 | )
222 | accumulate_proc.start()
223 |
224 | forward_procs = []
225 | for device in devices:
226 | p = mp.Process(
227 | target=forward_layer,
228 | args=(
229 | transformer_layer,
230 | self.position_ids,
231 | self.attention_mask,
232 | data_params.batch_size,
233 | device, in_q, out_q
234 | )
235 | )
236 | p.start()
237 | forward_procs.append(p)
238 |
239 | assert len(self.dev_emb) % data_params.batch_size == 0 and \
240 | chunk_size % data_params.batch_size == 0
241 | i = 0
242 | while i < len(self.dev_emb):
243 | next = min(i + chunk_size, len(self.dev_emb))
244 | in_q.put(self.dev_emb[i:next])
245 | i = next
246 |
247 | for device in devices:
248 | in_q.put(None)
249 |
250 | for p in forward_procs:
251 | p.join()
252 |
253 | accumulate_proc.join()
254 |
255 | transformer_layer.cpu()
256 | # self.model.model.layers[transformer_layer_index] = None
257 | gc.collect()
258 | torch.cuda.empty_cache()
259 |
260 | print(f"done processing layer {transformer_layer_index}")
261 |
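262 | # Illustrative usage sketch (model name and paths are placeholders): compute
263 | # and save Hessians for the first two decoder layers only, then obtain a
264 | # per-layer quantizer for layer 0.
265 | #
266 | #   compressor = ActivationAwareWeightCompressor(
267 | #       model_params=ModelParameters(base_model="meta-llama/Llama-2-7b-hf"),
268 | #       data_params=DataParameters(devset_size=256, context_length=4096),
269 | #       quant_params=CalderaParams(rank=64),
270 | #       hessian_save_path="./data/hessians/llama-2-7b",
271 | #       stop_at_layer=2,
272 | #   )
273 | #   layer_quant = compressor.get_layer_quantizer(layer_idx=0)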
--------------------------------------------------------------------------------
/scripts/quantize_save_llama.py:
--------------------------------------------------------------------------------
1 | import random
2 | import sys
3 |
4 | import numpy as np
5 | from caldera.decomposition.dataclasses import *
6 | from caldera.decomposition.weight_compression import ActivationAwareWeightCompressor
7 | from caldera.utils.enums import TransformerSubLayers
8 | import gc
9 | from transformers import (
10 | AutoModelForCausalLM,
11 | HfArgumentParser,
12 | AutoModelForSequenceClassification,
13 | LlamaForCausalLM
14 | )
15 | import torch
16 | import os
17 | import torch.multiprocessing as mp
18 | import glog
19 | from lib.utils import graph_wrapper
20 | import shutil
21 |
22 | SUBLAYER_TO_STRING = {
23 | TransformerSubLayers.KEY: "Key Projection (attn)",
24 | TransformerSubLayers.QUERY: "Query Projection (attn)",
25 | TransformerSubLayers.VALUE: "Value Projection (attn)",
26 | TransformerSubLayers.O: "O Projection (attn)",
27 | TransformerSubLayers.GATE: "Gate Projection (mlp)",
28 | TransformerSubLayers.UP: "Up Projection (mlp)",
29 | TransformerSubLayers.DOWN: "Down Projection (mlp)",
30 | }
31 |
32 |
33 | @dataclass
34 | class Arguments:
35 | hessian_save_path: str = field(
36 | metadata={"help": "Path in which the Hessians were stored"}
37 | )
38 | model_save_path: str = field(
39 | metadata={"help": ("Path in which to save the quantized model, e.g., artifacts/model.pt")}
40 | )
41 | base_model: str = field(
42 | metadata={
43 | "help": (
44 | "Path of the model that is being quantized, as "
45 | "either a local or a Huggingface path"
46 | )
47 | }
48 | )
49 | devices: list[str] = field(
50 | metadata={
51 | "help": (
52 | "List of devices to use for quantization, e.g. "
53 | '"cuda:0 cuda:1 cuda:2 cuda:3"'
54 | )
55 | }
56 | )
57 | ft_rank: int = field(
58 | default=64,
59 | metadata={
60 | "help": (
61 |                 "Number of columns of L and rows of R, in the decomposition "
62 | "W approx. Q + LR to finetune. The remaining columns will "
63 | "remain fixed."
64 | )
65 | },
66 | )
67 | token: str = field(
68 | default="", metadata={"help": "Huggingface token for private models."}
69 | )
70 | start_layer: int = field(
71 | default=0,
72 | metadata={
73 | "help": "Layer index to start quantizing from (to resume quantization from an interrupt)"
74 | },
75 | )
76 | stop_layer: int = field(
77 | default=int(sys.maxsize),
78 | metadata={
79 | "help": "Layer index to stop quantizing at (to resume quantization from an interrupt)"
80 | },
81 | )
82 |
83 | random_seed: int = field(
84 | default=42,
85 | metadata={
86 | "help": "Random seed for reproducibility."
87 | },
88 | )
89 |
90 |
91 | def quant_layer(
92 | in_q,
93 | model_save_path,
94 | base_model,
95 | config,
96 | ft_rank,
97 | grad_ckpt,
98 | device,
99 | data_params,
100 | quant_params,
101 | hessian_save_path,
102 | random_seed
103 | ):
104 | np.random.seed(random_seed)
105 | torch.manual_seed(random_seed)
106 | torch.cuda.manual_seed_all(random_seed)
107 | random.seed(random_seed)
108 |
109 | model = AutoModelForCausalLM.from_pretrained(
110 | base_model, torch_dtype="auto", low_cpu_mem_usage=True
111 | ).cpu()
112 |
113 | while True:
114 | layer_idx = in_q.get()
115 |
116 | if layer_idx is None:
117 | return
118 |
119 | weight_compressor = ActivationAwareWeightCompressor(
120 | model_params=ModelParameters(base_model),
121 | data_params=data_params,
122 | hessian_save_path=hessian_save_path,
123 | quant_params=quant_params,
124 | compute_hessians=False,
125 | )
126 | layer_quant = weight_compressor.get_layer_quantizer(layer_idx, device)
127 |
128 | with torch.no_grad():
129 | layer = model.model.layers[layer_idx]
130 |
131 | for sublayer in layer_quant.sublayer_info.keys():
132 | print(f"Quantizing layer {layer_idx}, {SUBLAYER_TO_STRING[sublayer]}")
133 | layer_quant.compress_sublayer(sublayer)
134 |
135 | attr_names = layer_quant.sublayer_info[sublayer].out_key.split(".")
136 | setattr(
137 | getattr(layer, attr_names[0]),
138 | attr_names[1],
139 | layer_quant.get_quantized_linear_layer(
140 | sublayer, ft_rank, grad_ckpt
141 | ),
142 | )
143 | layer_quant.clean_up_sublayer(sublayer)
144 | layer = layer.cpu()
145 | torch.save(layer, f"{model_save_path}/layers/quant_layer_{layer_idx}.pt")
146 | del layer_quant
147 | gc.collect()
148 | torch.cuda.empty_cache()
149 |
150 |
151 | def quantize_save_llama(
152 | base_model: str,
153 | hessian_save_path: str,
154 | model_save_path: str,
155 | token: str = "",
156 | ft_rank: int = 64,
157 | grad_ckpt: bool = True,
158 | data_params: DataParameters = DataParameters(),
159 | quant_params: CalderaParams = CalderaParams(),
160 | quant_devices=["cuda"],
161 | start_layer=0,
162 | stop_layer=int(sys.maxsize),
163 | random_seed: int = 42,
164 | ):
165 | os.makedirs(f"{model_save_path}/layers", exist_ok=True)
166 | mp.set_start_method("spawn")
167 |
168 | if token:
169 | model = AutoModelForCausalLM.from_pretrained(
170 | base_model, torch_dtype="auto", low_cpu_mem_usage=True, token=token
171 | ).cpu()
172 | else:
173 | model = AutoModelForCausalLM.from_pretrained(
174 | base_model, torch_dtype="auto", low_cpu_mem_usage=True
175 | ).cpu()
176 |
177 | model_config = model.config
178 | n_layers = len(model.model.layers)
179 | del model
180 | gc.collect()
181 | torch.cuda.empty_cache()
182 |
183 | manager = mp.get_context("spawn").Manager()
184 | in_q = manager.Queue()
185 | quant_procs = []
186 |
187 | for device in quant_devices:
188 | p = mp.Process(
189 | target=quant_layer,
190 | args=(
191 | in_q,
192 | model_save_path,
193 | base_model,
194 | model_config,
195 | ft_rank,
196 | grad_ckpt,
197 | device,
198 | data_params,
199 | quant_params,
200 | hessian_save_path,
201 | random_seed,
202 | ),
203 | )
204 | p.start()
205 | quant_procs.append(p)
206 |
207 | stop_layer: int = min(stop_layer, n_layers)
208 | for layer_idx in range(start_layer, stop_layer):
209 | in_q.put(layer_idx)
210 |
211 | for _ in quant_devices:
212 | in_q.put(None)
213 |
214 | for p in quant_procs:
215 | p.join()
216 |
217 | # now save the full model
218 | model = load_layers_cpu(model_save_path, base_model)
219 | shutil.rmtree(f'{model_save_path}/')
220 | torch.save(model, f"{model_save_path}")
221 |
222 | def load_layers_cpu(
223 | model_save_path,
224 | base_model,
225 | ):
226 | model = AutoModelForCausalLM.from_pretrained(
227 | base_model, torch_dtype='auto', device_map="cpu", low_cpu_mem_usage=True
228 | ).cpu()
229 |
230 | for layer_idx in range(len(model.model.layers)):
231 | layer = torch.load(
232 | f"{model_save_path}/layers/quant_layer_{layer_idx}.pt",
233 | map_location="cpu"
234 | )
235 | layer.post_attention_layernorm.weight.requires_grad = False
236 | layer.input_layernorm.weight.requires_grad = False
237 |
238 | for sublayer in [
239 | layer.self_attn.q_proj, layer.self_attn.k_proj, layer.self_attn.v_proj,
240 | layer.self_attn.o_proj, layer.mlp.gate_proj, layer.mlp.up_proj,
241 | layer.mlp.down_proj
242 | ]:
243 | if sublayer.ft_rank > 0:
244 | sublayer.L_ft = torch.nn.Parameter(sublayer.L_ft.contiguous(), requires_grad=True)
245 | sublayer.R_ft = torch.nn.Parameter(sublayer.R_ft.contiguous(), requires_grad=True)
246 |
247 | model.model.layers[layer_idx] = layer
248 |
249 | return model
250 |
251 | def load_quantized_model(
252 | model_save_path,
253 | base_model,
254 | device,
255 | sequence_classification=False,
256 | seq_class_num_labels=2,
257 | cuda_graph=False,
258 | ):
259 | model = torch.load(model_save_path, map_location=device).to(device)
260 | if cuda_graph:
261 | graph_model = graph_wrapper.get_graph_wrapper(AutoModelForCausalLM, device="cpu").from_pretrained(
262 | base_model, torch_dtype='auto', device_map="cpu", low_cpu_mem_usage=True,
263 | use_flash_attention_2=True
264 | ).to("cpu")
265 | for i in range(len(graph_model.model.layers)):
266 | graph_model.model.layers[i].self_attn.q_proj = model.model.layers[i].self_attn.q_proj
267 | graph_model.model.layers[i].self_attn.k_proj = model.model.layers[i].self_attn.k_proj
268 | graph_model.model.layers[i].self_attn.v_proj = model.model.layers[i].self_attn.v_proj
269 | graph_model.model.layers[i].self_attn.o_proj = model.model.layers[i].self_attn.o_proj
270 | graph_model.model.layers[i].mlp = model.model.layers[i].mlp
271 | graph_model.model.layers[i].post_attention_layernorm = graph_model.model.layers[i].post_attention_layernorm.to(device)
272 | graph_model.model.layers[i].input_layernorm = graph_model.model.layers[i].input_layernorm.to(device)
273 | graph_model.model.norm = graph_model.model.norm.to(device)
274 | graph_model.model.embed_tokens = graph_model.model.embed_tokens.to(device)
275 | graph_model.lm_head = graph_model.lm_head.to(device)
276 | graph_model.graph_device = device
277 |         model = graph_model.to(device)
278 |
279 | elif sequence_classification:
280 | seq_model = AutoModelForSequenceClassification.from_pretrained(
281 | base_model, torch_dtype='auto', device_map="cpu", low_cpu_mem_usage=True, num_labels=seq_class_num_labels
282 | ).cpu()
283 | seq_model.score = seq_model.score.to(device)
284 | seq_model.score.weight.requires_grad = True
285 | model.model.embed_tokens = model.model.embed_tokens.to(device)
286 | seq_model.model.layers = model.model.layers
287 | model = seq_model.to(device)
288 |
289 | if not sequence_classification:
290 | model.lm_head.weight.requires_grad = False
291 | model.model.embed_tokens.weight.requires_grad = False
292 | model.model.norm.weight.requires_grad = False
293 | for layer in model.model.layers:
294 | layer.post_attention_layernorm.weight.requires_grad = False
295 | layer.input_layernorm.weight.requires_grad = False
296 |
297 | return model
298 |
299 |
300 | if __name__ == "__main__":
301 | glog.setLevel("WARN")
302 |
303 | parser = HfArgumentParser([Arguments, CalderaParams, QuIPArgs])
304 |
305 | args, quant_params, quip_args = parser.parse_args_into_dataclasses()
306 | quant_params.quip_args = quip_args
307 | quantize_save_llama(
308 | base_model=args.base_model,
309 | hessian_save_path=args.hessian_save_path,
310 | model_save_path=args.model_save_path,
311 | token=args.token,
312 | ft_rank=args.ft_rank,
313 | grad_ckpt=False,
314 | data_params=DataParameters(),
315 | quant_params=quant_params,
316 | quant_devices=args.devices,
317 | start_layer=args.start_layer,
318 | stop_layer=args.stop_layer,
319 | random_seed=args.random_seed,
320 | )
321 |
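322 | # Illustrative invocation (paths and devices are placeholders). Fields of
323 | # CalderaParams and QuIPArgs, e.g. --rank and --Q_bits, can also be passed on
324 | # the command line since those dataclasses are registered with HfArgumentParser:
325 | #   python scripts/quantize_save_llama.py \
326 | #       --base_model meta-llama/Llama-2-7b-hf \
327 | #       --hessian_save_path ./data/hessians/llama-2-7b \
328 | #       --model_save_path artifacts/caldera_model.pt \
329 | #       --devices cuda:0 cuda:1 \
330 | #       --rank 64 --Q_bits 2 --L_bits 4 --R_bits 4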
--------------------------------------------------------------------------------
/scripts/finetune_wikitext.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import transformers
3 | from transformers import (
4 | AutoConfig,
5 | AutoTokenizer,
6 | Trainer,
7 | default_data_collator,
8 | TrainingArguments
9 | )
10 | from transformers.testing_utils import CaptureLogger
11 | import math
12 | import logging
13 | from itertools import chain
14 | import evaluate
15 | from dataclasses import dataclass, field
16 | from typing import Optional
17 | from transformers.utils.versions import require_version
18 | from quantize_save_llama import load_quantized_model
19 | from accelerate import Accelerator
20 | import os
21 |
22 |
23 | @dataclass
24 | class Arguments:
25 | model_save_path: str = field(metadata={
26 | "help": ("Path in which the quantized model was saved via "
27 | "quantize_save_llama.py")
28 | })
29 | base_model: str = field(metadata={
30 | "help": ("Path of the original model, as either a local or a "
31 | "Huggingface path")
32 | })
33 |
34 | @dataclass
35 | class DataTrainingArguments:
36 | """
37 | Arguments pertaining to what data we are going to input our model for training and eval.
38 | """
39 |
40 | dataset_name: Optional[str] = field(
41 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
42 | )
43 | dataset_config_name: Optional[str] = field(
44 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
45 | )
46 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
47 | validation_file: Optional[str] = field(
48 | default=None,
49 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
50 | )
51 | max_train_samples: Optional[int] = field(
52 | default=None,
53 | metadata={
54 | "help": (
55 | "For debugging purposes or quicker training, truncate the number of training examples to this "
56 | "value if set."
57 | )
58 | },
59 | )
60 | max_eval_samples: Optional[int] = field(
61 | default=None,
62 | metadata={
63 | "help": (
64 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
65 | "value if set."
66 | )
67 | },
68 | )
69 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
70 | block_size: Optional[int] = field(
71 | default=None,
72 | metadata={
73 | "help": (
74 | "Optional input sequence length after tokenization. "
75 | "The training dataset will be truncated in block of this size for training. "
76 | "Default to the model max input length for single sentence inputs (take into account special tokens)."
77 | )
78 | },
79 | )
80 | overwrite_cache: bool = field(
81 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
82 | )
83 | validation_split_percentage: Optional[int] = field(
84 | default=5,
85 | metadata={
86 | "help": "The percentage of the train set used as validation set in case there's no validation split"
87 | },
88 | )
89 | preprocessing_num_workers: Optional[int] = field(
90 | default=None,
91 | metadata={"help": "The number of processes to use for the preprocessing."},
92 | )
93 | keep_linebreaks: bool = field(
94 | default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
95 | )
96 |
97 | def __post_init__(self):
98 | if self.streaming:
99 | require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
100 |
101 | if self.dataset_name is None and self.train_file is None and self.validation_file is None:
102 | raise ValueError("Need either a dataset name or a training/validation file.")
103 | else:
104 | if self.train_file is not None:
105 | extension = self.train_file.split(".")[-1]
106 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
107 | if self.validation_file is not None:
108 | extension = self.validation_file.split(".")[-1]
109 | assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
110 |
111 |
112 | def main(
113 | model_save_path: str,
114 | base_model: str,
115 | output_dir: str,
116 | data_args: DataTrainingArguments,
117 | training_args: TrainingArguments
118 | ):
119 | os.makedirs(output_dir, exist_ok=True)
120 | accelerator = Accelerator()
121 | model = load_quantized_model(model_save_path, base_model, accelerator.device)
122 | for name, param in model.named_parameters():
123 | if 'SU' in name or 'SV' in name:
124 | param.requires_grad = False
125 |
126 | logger = logging.getLogger(__name__)
127 |
128 | config = AutoConfig.from_pretrained(base_model)
129 |
130 | raw_datasets = load_dataset(
131 | data_args.dataset_name,
132 | data_args.dataset_config_name,
133 | streaming=data_args.streaming,
134 | )
135 |
136 | tokenizer = AutoTokenizer.from_pretrained(
137 | base_model, use_fast=True)
138 | embedding_size = model.get_input_embeddings().weight.shape[0]
139 |
140 | if len(tokenizer) > embedding_size:
141 | model.resize_token_embeddings(len(tokenizer))
142 |
143 | column_names = list(raw_datasets["validation"].features)
144 | text_column_name = "text" if "text" in column_names else column_names[0]
145 |
146 | tok_logger = transformers.utils.logging.get_logger(
147 | "transformers.tokenization_utils_base"
148 | )
149 |
150 | def tokenize_function(examples):
151 | with CaptureLogger(tok_logger) as cl:
152 | output = tokenizer(examples[text_column_name])
153 | # clm input could be much much longer than block_size
154 | if "Token indices sequence length is longer than the" in cl.out:
155 | tok_logger.warning(
156 | "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long "
157 | "input will be chunked into smaller bits before being passed "
158 | "to the model."
159 | )
160 | return output
161 |
162 | with training_args.main_process_first(desc="dataset map tokenization"):
163 | if not data_args.streaming:
164 | tokenized_datasets = raw_datasets.map(
165 | tokenize_function,
166 | batched=True,
167 | num_proc=data_args.preprocessing_num_workers,
168 | remove_columns=column_names,
169 | load_from_cache_file=not data_args.overwrite_cache,
170 | desc="Running tokenizer on dataset",
171 | )
172 | else:
173 | tokenized_datasets = raw_datasets.map(
174 | tokenize_function,
175 | batched=True,
176 | remove_columns=column_names,
177 | )
178 |
179 | if hasattr(config, "max_position_embeddings"):
180 | max_pos_embeddings = config.max_position_embeddings
181 | else:
182 | # Define a default value if the attribute is missing in the config.
183 | max_pos_embeddings = 1024
184 |
185 | if data_args.block_size is None:
186 | block_size = tokenizer.model_max_length
187 | if block_size > max_pos_embeddings:
188 | logger.warning(
189 | "The tokenizer picked seems to have a very large "
190 | f"`model_max_length` ({tokenizer.model_max_length}). "
191 | f"Using block_size={min(1024, max_pos_embeddings)} instead. "
192 | "You can change that default value by passing --block_size xxx."
193 | )
194 | if max_pos_embeddings > 0:
195 | block_size = min(1024, max_pos_embeddings)
196 | else:
197 | block_size = 1024
198 | else:
199 | if data_args.block_size > tokenizer.model_max_length:
200 | logger.warning(
201 | f"The block_size passed ({data_args.block_size}) is larger "
202 | "than the maximum length for the model "
203 | f"({tokenizer.model_max_length}). Using "
204 | f"block_size={tokenizer.model_max_length}."
205 | )
206 | block_size = min(data_args.block_size, tokenizer.model_max_length)
207 |
208 | # Main data processing function that will concatenate all texts from
209 | # our dataset and generate chunks of block_size.
210 | def group_texts(examples):
211 | # Concatenate all texts.
212 | concatenated_examples = {k: list(chain(*examples[k]))
213 | for k in examples.keys()}
214 | total_length = len(concatenated_examples[list(examples.keys())[0]])
215 | # We drop the small remainder, and if the total_length < block_size
216 | # we exclude this batch and return an empty dict.
217 | # We could add padding if the model supported it instead of this drop,
218 | # you can customize this part to your needs.
219 | total_length = (total_length // block_size) * block_size
220 | # Split by chunks of max_len.
221 | result = {
222 | k: [t[i: i + block_size]
223 | for i in range(0, total_length, block_size)]
224 | for k, t in concatenated_examples.items()
225 | }
226 | result["labels"] = result["input_ids"].copy()
227 | return result
228 |
229 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
230 | # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
231 | # to preprocess.
232 | #
233 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
234 | # https://huggingface.co/docs/datasets/process#map
235 |
236 | with training_args.main_process_first(desc="grouping texts together"):
237 | if not data_args.streaming:
238 | lm_datasets = tokenized_datasets.map(
239 | group_texts,
240 | batched=True,
241 | num_proc=data_args.preprocessing_num_workers,
242 | load_from_cache_file=not data_args.overwrite_cache,
243 | desc=f"Grouping texts in chunks of {block_size}",
244 | )
245 | else:
246 | lm_datasets = tokenized_datasets.map(
247 | group_texts,
248 | batched=True,
249 | )
250 | eval_dataset = lm_datasets["validation"]
251 | if data_args.max_eval_samples is not None:
252 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
253 | eval_dataset = eval_dataset.select(range(max_eval_samples))
254 | train_dataset = lm_datasets["train"]
255 | if data_args.max_train_samples is not None:
256 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
257 | train_dataset = train_dataset.select(range(max_train_samples))
258 |
259 | trainer = Trainer(
260 | model=model,
261 | args=training_args,
262 | train_dataset=train_dataset,
263 | eval_dataset=eval_dataset,
264 | tokenizer=tokenizer,
265 | # Data collator will default to DataCollatorWithPadding
266 | data_collator=default_data_collator
267 | )
268 |
269 |     # Train and save the model, then evaluate
270 | trainer.train()
271 | trainer.save_model()
272 |
273 | metrics = trainer.evaluate()
274 |
275 | max_eval_samples = len(eval_dataset)
276 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
277 | try:
278 | perplexity = math.exp(metrics["eval_loss"])
279 | except OverflowError:
280 | perplexity = float("inf")
281 | print(f"perplexity: {perplexity}")
282 |
283 |
284 | if __name__ == "__main__":
285 | parser = transformers.HfArgumentParser([
286 | Arguments, DataTrainingArguments, TrainingArguments])
287 | args, data_args, training_args = parser.parse_args_into_dataclasses()
288 | main(
289 | args.model_save_path,
290 | args.base_model,
291 | training_args.output_dir,
292 | data_args=data_args,
293 | training_args=training_args
294 | )
295 |
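296 | # Illustrative invocation (paths are placeholders; standard transformers
297 | # TrainingArguments flags may be appended as needed):
298 | #   python scripts/finetune_wikitext.py \
299 | #       --model_save_path artifacts/caldera_model.pt \
300 | #       --base_model meta-llama/Llama-2-7b-hf \
301 | #       --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 \
302 | #       --output_dir artifacts/wikitext_ft \
303 | #       --per_device_train_batch_size 1 --num_train_epochs 1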
--------------------------------------------------------------------------------
/src/caldera/decomposition/layer_quantization.py:
--------------------------------------------------------------------------------
1 | from lib.utils.data_utils import flat_to_sym
2 | from lib.utils.math_utils import regularize_H
3 |
4 | import torch
5 | import matplotlib.pyplot as plt
6 | import gc
7 |
8 | from caldera.decomposition.dataclasses import *
9 | from caldera.utils.enums import TransformerSubLayers
10 | from caldera.decomposition.alg import caldera
11 |
12 | from dataclasses import asdict
13 | import json
14 | import datetime
15 |
16 |
17 | # Maps number of bits to name of the QuIP# lattice quantizer
18 | BITS_TO_CODEBOOK = {
19 | 2: 'E8P12',
20 | 3: 'E8P12RVQ3B',
21 | 4: 'E8P12RVQ4B'
22 | }
23 |
24 |
25 | class ActivationAwareLayerQuant:
26 | """
27 | For a given transformer layer, decomposes each sublayer (i.e., one of
28 | {query, key, value, out, gate, up, down}) into Q + LR, where Q is a
29 |     matrix quantized according to QuIP#, and L, R are low-rank, quantized
30 |     factors computed via an iterative, activation-aware procedure.
31 |
32 | This class is instantiated by `ActivationAwareWeightCompressor` upon
33 | calling the method `get_layer_quantizer`.
34 |
35 | Usage example:
36 | ```
37 | # Instantiate ActivationAwareWeightCompressor. This will automatically
38 | # compute all the Hessians upon initialization, unless you pass in the
39 | # `stop_at_layer` keyword argument
40 | weight_compressor = ActivationAwareWeightCompressor(
41 | model_params=ModelParameters(
42 | base_model="meta-llama/Llama-2-7b-hf"
43 | ),
44 | data_params=DataParameters(
45 | devset=DevSet.FALCON,
46 | devset_size=128,
47 | context_length=4096,
48 | batch_size=2,
49 | devices=["cuda", "cuda:2"]
50 | ),
51 | hessian_save_path="./data/hessians/llama-7b",
52 |         quant_params=CalderaParams(
53 | Q_bits=2,
54 | L_bits=3,
55 | R_bits=4,
56 | rank=256,
57 | iters=25,
58 | lplr_iters=3
59 | ),
60 | quant_device="cuda:2"
61 | )
62 |
63 | # Instantiate ActivationAwareLayerQuant for a layer
64 | layer_quant = weight_compressor.get_layer_quantizer(layer_idx=12)
65 |
66 | # Quantize one sub-layer using CALDERA
67 | layer_quant.compress_sublayer(TransformerSubLayers.VALUE)
68 |
69 | # Plot the quantization error
70 | best_error = layer_quant.min_error(TransformerSubLayers.VALUE)
71 | layer_quant.plot_errors(TransformerSubLayers.VALUE)
72 |
73 | # Delete Q, L, and R matrices to free GPU memory before quantizing another
74 | # sublayer (optional)
75 | layer_quant.clean_up_sublayer(TransformerSubLayers.VALUE)
76 | ```
77 | """
78 | def __init__(
79 | self,
80 | layer: torch.nn.Module,
81 | layer_idx: int,
82 | hessian_save_path: str = "",
83 | quant_params: CalderaParams = CalderaParams(),
84 | label: str = "CALDERA",
85 | device: str = "cuda",
86 | ):
87 | self.hessian_save_path = hessian_save_path
88 | self.layer = layer.to(device)
89 | self.label = label
90 | if label is None:
91 | self.label = "CALDERA"
92 |
93 | self.layer_idx = layer_idx
94 | self.quant_params = quant_params
95 | if not self.quant_params.update_order:
96 | if not self.quant_params.compute_low_rank_factors:
97 | self.quant_params.update_order = ["Q"]
98 | elif not self.quant_params.compute_quantized_component:
99 | self.quant_params.update_order = ["LR"]
100 | else:
101 | self.quant_params.update_order = ["LR", "Q"]
102 | if not self.quant_params.compute_low_rank_factors:
103 | self.quant_params.rank = 0
104 | self.device = device
105 |
106 | self._set_sublayer_weights_and_info()
107 |
108 | def compress_sublayer(self, sublayer):
109 | """
110 | Decomposes a sublayer (e.g., query, key, value, etc.) by alternating
111 | between LDLQ and LPLR.
112 |
113 | The sublayer argument must be a member of the TransformerSubLayers
114 | enum.
115 | """
116 | assert sublayer in self.sublayer_info.keys(), \
117 | ("Invalid sublayer! Please use a member of the "
118 | "TransformerSubLayers enum.")
119 | sublayer_info = self.sublayer_info[sublayer]
120 | sublayer_info.started_quant = True
121 |
122 | sublayer_info.caldera.W = sublayer_info.sublayer.weight.to(self.device).float()
123 |
124 | W = sublayer_info.caldera.W
125 | H = self._get_H(sublayer)
126 |
127 | sublayer_info.caldera = caldera(self.quant_params, W, H, self.device)
128 |
129 | def get_quantized_linear_layer(self, sublayer, ft_rank, grad_ckpt=True):
130 | from caldera.decomposition.quantized_layer import \
131 | CalderaQuantizedLinear
132 |
133 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer)
134 |
135 | have_L_codebook = self.quant_params.lattice_quant_LR and self.quant_params.compute_low_rank_factors and \
136 | self.quant_params.L_bits < 16
137 | have_R_codebook = self.quant_params.lattice_quant_LR and self.quant_params.compute_low_rank_factors and \
138 | self.quant_params.R_bits < 16
139 |
140 | L_codebook_version = BITS_TO_CODEBOOK[self.quant_params.L_bits] if have_L_codebook else None
141 |
142 | R_codebook_version = BITS_TO_CODEBOOK[self.quant_params.R_bits] if have_R_codebook else None
143 |
144 | if ft_rank < self.quant_params.rank and self.quant_params.compute_low_rank_factors and \
145 | ((L_codebook_version is None and self.quant_params.L_bits < 16)
146 | or (R_codebook_version is None and
147 | self.quant_params.R_bits < 16)):
148 | raise NotImplementedError(
149 | "Only lattice quantization for L and R implemented so far"
150 | )
151 |
152 | return CalderaQuantizedLinear(
153 | # Dimensions
154 | in_features=sublayer_info.caldera.W.shape[1],
155 | out_features=sublayer_info.caldera.W.shape[0],
156 | # Codebooks
157 | Q_codebook_version=BITS_TO_CODEBOOK[self.quant_params.Q_bits],
158 | L_codebook_version=L_codebook_version,
159 | R_codebook_version=R_codebook_version,
160 | # L and R
161 | L=sublayer_info.caldera.L,
162 | R=sublayer_info.caldera.R,
163 | # Quantized idxs
164 | L_idxs=sublayer_info.caldera.L_idxs,
165 | R_idxs=sublayer_info.caldera.R_idxs,
166 | Q_idxs=sublayer_info.caldera.Q_idxs,
167 | # Scaling
168 | L_scale=sublayer_info.caldera.L_scale,
169 | R_scale=sublayer_info.caldera.R_scale,
170 | Q_scale=sublayer_info.caldera.Q_scale,
171 | global_scale=sublayer_info.caldera.global_scale,
172 | scaleWH=sublayer_info.caldera.scaleWH,
173 | # Hadamard
174 | hadamard=(self.quant_params.hadamard_transform
175 | or self.quant_params.full_quip_sharp),
176 | # SU and SV
177 | SU=sublayer_info.caldera.SU,
178 | SV=sublayer_info.caldera.SV,
179 | # Rank and fine-tuning
180 | rank=max(self.quant_params.rank,
181 | self.quant_params.quip_args.lora_rank),
182 | ft_rank=ft_rank,
183 | grad_ckpt=grad_ckpt
184 | )
185 |
186 | def plot_errors(self, sublayer, plot_first_iter=True, savefile=None):
187 | """
188 | Plot the per-iteration approximation errors for a given sublayer
189 | (i.e., a member of the TransformerSubLayers enum).
190 | """
191 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer)
192 | self._plot(sublayer_info.caldera.errors, plot_first_iter, savefile=savefile)
193 |
194 | def _plot(self, errors, plot_first_iter, savefile=None):
195 | COLORS = ['b', 'r', 'm']
196 | plt.figure(figsize=(12, 4))
197 | title = f"Activation-Aware Error per iteration: {self.label}"
198 | plt.title(title)
199 | for i, key in enumerate(errors.keys()):
200 |             if plot_first_iter or len(errors[key]) == 1:
201 | plt.plot(range(len(errors[key])), errors[key],
202 | marker='o', linestyle='-', color=COLORS[i], label=key)
203 | else:
204 | plt.plot(range(1, len(errors[key])), errors[key][1:],
205 | marker='o', linestyle='-', color=COLORS[i], label=key)
206 | plt.yscale('log')
207 | plt.legend()
208 | plt.grid(True)
209 | if savefile is not None:
210 | plt.savefig(savefile)
211 | plt.close()
212 | else:
213 | plt.show()
214 |
215 | def plot_errors_json(self, file, plot_first_iter=True, savefile=None):
216 | with open(file, "r") as infile:
217 | json_str = infile.read()
218 | data = json.loads(json_str)
219 | self._plot(data["per_iter_errors"], plot_first_iter, savefile=savefile)
220 |
221 | def min_error(self, sublayer):
222 | """
223 | Returns the minimum value of the activation-aware loss for a given
224 | sublayer (i.e., a member of the TransformerSubLayers enum).
225 | """
226 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer)
227 | min_error = float('inf')
228 | for key in sublayer_info.caldera.errors:
229 | min_error = min(min_error, min(sublayer_info.caldera.errors[key]))
230 | return min_error
231 |
232 | def export_errors_json(self, sublayer, savefile):
233 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer)
234 | param_dict = asdict(self.quant_params)
235 | del param_dict["quant_factory_Q"]
236 | del param_dict["quant_factory_LR"]
237 |
238 | now = datetime.datetime.now()
239 |
240 | data = {
241 | "per_iter_errors": sublayer_info.caldera.errors,
242 | "layer_idx": self.layer_idx,
243 | "sublayer": sublayer_info.key,
244 | "datetime": str(now),
245 | "timestamp": datetime.datetime.timestamp(now),
246 | "params": param_dict
247 | }
248 | json_object = json.dumps(data)
249 | with open(savefile + ".json", "w") as out:
250 | out.write(json_object)
251 |
252 | def clean_up_sublayer(self, sublayer):
253 | """
254 | Delete Q, L, and R matrices for a sublayer to free GPU memory.
255 | """
256 | sublayer_info = self._get_sublayer_info_and_check_sublayer(sublayer)
257 |
258 | self.sublayer_info[sublayer] = SubLayerInfo(
259 | sublayer=sublayer_info.sublayer, key=sublayer_info.key
260 | )
261 | gc.collect()
262 | torch.cuda.empty_cache()
263 |
264 | def _set_sublayer_weights_and_info(self):
265 | """
266 | Initializes a SubLayerInfo object for each of the seven transformer
267 | sublayers. Called upon instantiation.
268 | """
269 | self.sublayer_info = {
270 | TransformerSubLayers.KEY: SubLayerInfo(
271 | sublayer=self.layer.self_attn.k_proj, key="qkv",
272 | out_key="self_attn.k_proj"),
273 | TransformerSubLayers.QUERY: SubLayerInfo(
274 | sublayer=self.layer.self_attn.q_proj, key="qkv",
275 | out_key="self_attn.q_proj"),
276 | TransformerSubLayers.VALUE: SubLayerInfo(
277 | sublayer=self.layer.self_attn.v_proj, key="qkv",
278 | out_key="self_attn.v_proj"),
279 | TransformerSubLayers.O: SubLayerInfo(
280 | sublayer=self.layer.self_attn.o_proj, key="o",
281 | out_key="self_attn.o_proj"),
282 | TransformerSubLayers.UP: SubLayerInfo(
283 | sublayer=self.layer.mlp.up_proj, key="up",
284 | out_key="mlp.up_proj"),
285 | TransformerSubLayers.GATE: SubLayerInfo(
286 | sublayer=self.layer.mlp.gate_proj, key="up",
287 | out_key="mlp.gate_proj"),
288 | TransformerSubLayers.DOWN: SubLayerInfo(
289 | sublayer=self.layer.mlp.down_proj, key="down",
290 | out_key="mlp.down_proj")
291 | }
292 |
293 | def _get_H(self, sublayer):
294 | """
295 | Reads the Hessian (sum X_i X_i^T) for a specific sublayer from the
296 | corresponding file (in which it was saved by
297 | ActivationAwareWeightCompressor).
298 | """
299 | sublayer_key = self.sublayer_info[sublayer].key
300 | H_data = torch.load(
301 | f'{self.hessian_save_path}/{self.layer_idx}_{sublayer_key}.pt',
302 | map_location=torch.device(self.device),
303 | )
304 | H = flat_to_sym(H_data['flatH'], H_data['n'])
305 |
306 | # Add back in the mean
307 | mu = H_data['mu']
308 | H.add_(mu[None, :] * mu[:, None])
309 | H = regularize_H(H, H_data['n'], self.quant_params.quip_args.sigma_reg)
310 |
311 | return H
312 |
313 | def _get_sublayer_info_and_check_sublayer(self, sublayer):
314 | assert sublayer in self.sublayer_info.keys(), \
315 | ("Invalid sublayer! Please use a member of the "
316 | "TransformerSubLayers enum.")
317 |
318 | sublayer_info = self.sublayer_info[sublayer]
319 | assert sublayer_info.started_quant, \
320 |             "Sublayer hasn't been quantized yet!"
321 | return sublayer_info
322 |
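323 | 
324 | # --- Illustrative sketch (not part of the original class above) -------------
325 | # A minimal, hedged example of swapping a compressed sublayer back into its
326 | # parent transformer layer. It assumes `layer_quant` is an
327 | # ActivationAwareLayerQuant on which `compress_sublayer` has already been
328 | # called, and that the layer follows the Hugging Face Llama naming used in
329 | # `_set_sublayer_weights_and_info` (e.g., "self_attn.v_proj"). The helper
330 | # name below is illustrative and not part of the package API.
331 | def replace_sublayer_with_quantized(layer_quant, sublayer, ft_rank=64,
332 |                                     grad_ckpt=True):
333 |     q_linear = layer_quant.get_quantized_linear_layer(
334 |         sublayer, ft_rank=ft_rank, grad_ckpt=grad_ckpt)
335 |     # e.g., "self_attn.v_proj" -> parent "self_attn", attribute "v_proj"
336 |     parent_name, attr_name = \
337 |         layer_quant.sublayer_info[sublayer].out_key.rsplit(".", 1)
338 |     setattr(layer_quant.layer.get_submodule(parent_name), attr_name, q_linear)
339 |     layer_quant.clean_up_sublayer(sublayer)  # frees Q, L, R on the GPU
340 |     return layer_quant.layer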
--------------------------------------------------------------------------------
/src/caldera/decomposition/quantized_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from lib import codebook
5 | from lib.utils import get_hadK
6 |
7 | import quiptools_cuda
8 | from lib.utils.matmul_had import matmul_hadU_cuda, matmul_hadUt_cuda
9 |
10 |
11 | # Adapted from https://github.com/Cornell-RelaxML/quip-sharp
12 |
13 |
14 | class LatticeQuantizedParameter(nn.Module):
15 | def __init__(
16 | self,
17 | in_features,
18 | out_features,
19 | idxs,
20 | scale,
21 | codebook_version,
22 | transposed=False
23 | ):
24 | super(LatticeQuantizedParameter, self).__init__()
25 |
26 | self.in_features = in_features
27 | self.out_features = out_features
28 | self.transposed = transposed
29 | self.scale = scale
30 |
31 | self.codebook_version = codebook_version
32 | self.codebook = codebook.codebook_id[codebook_version][1](inference=True).to(torch.float16).to(idxs.device)
33 |
34 | idxs_dev = idxs.device
35 | self.idxs = idxs.cpu()
36 | codebook_class = codebook.get_quantized_class(
37 | codebook.get_id(codebook_version)
38 | )(self.idxs.device)
39 |
40 | split_idxs = codebook_class.maybe_unpack_idxs(
41 | self.idxs
42 | )
43 | self.idxs_list = []
44 | for i in range(len(split_idxs)):
45 | self.register_buffer(f'idxs_{i}', split_idxs[i].to(idxs_dev))
46 |             self.idxs_list.append(getattr(self, f'idxs_{i}'))
47 |
48 | self.idxs = None
49 |
50 | def get_W_decompressed(self):
51 | n = self.in_features
52 | m = self.out_features
53 | if self.codebook_version == 'E8P12':
54 | return quiptools_cuda.decompress_packed_e8p(
55 | self.idxs_list[0].view(m // 16, n // 64, 8, 4),
56 | self.codebook.grid_packed_abs) * self.scale
57 | elif self.codebook_version == 'E8P12RVQ4B':
58 | resid_scale = self.codebook.opt_resid_scale
59 | return (quiptools_cuda.decompress_packed_e8p(
60 | self.idxs_list[0].view(m // 16, n // 64, 8, 4),
61 | self.codebook.grid_packed_abs) + \
62 | quiptools_cuda.decompress_packed_e8p(
63 | self.idxs_list[1].view(m // 16, n // 64, 8, 4),
64 | self.codebook.grid_packed_abs
65 | ) / resid_scale) * self.scale
66 |
67 |         else:
68 |             resid_scale = self.codebook.opt_resid_scale
69 |             W_decompressed = quiptools_cuda.decompress_packed_e8p(
70 |                 self.idxs_list[0].view(m // 16, n // 64, 8, 4),
71 |                 self.codebook.grid_packed_abs)
72 |             W_resid_decompressed = torch.zeros(
73 |                 self.idxs_list[1].shape[0], 64 * self.idxs_list[1].shape[-1],
74 |                 device=self.idxs_list[1].device, dtype=torch.float16)
75 |             quiptools_cuda.decompress_e81b_packed(
76 |                 self.idxs_list[1], self.codebook.e81b_grid, W_resid_decompressed)
77 |             return (W_decompressed + W_resid_decompressed / resid_scale) * self.scale
78 |
79 | def forward(self, x, float_precision=False):
80 | dtype = x.dtype
81 | n = self.in_features
82 | m = self.out_features
83 |
84 | if self.idxs_list[0].device != x.device:
85 | for i in range(len(self.idxs_list)):
86 |                 self.idxs_list[i] = self.idxs_list[i].to(x.device)
87 | self.codebook = self.codebook.to(x.device)
88 | x = x / 32
89 | if not float_precision:
90 | x = x.half()
91 | else:
92 | x = x.float()
93 |
94 | if self.codebook_version == 'E8P12':
95 | if x.size(0) == 1 and not self.transposed and not float_precision:
96 | x = quiptools_cuda.decode_matvec_e8p(
97 | x[0].to(torch.float16),
98 | self.idxs_list[0].view(m // 16, n // 64, 8, 4),
99 | self.codebook.grid_packed_abs).unsqueeze(0)
100 | else:
101 | W_decompressed = quiptools_cuda.decompress_packed_e8p(
102 | self.idxs_list[0].view(m // 16, n // 64, 8, 4),
103 | self.codebook.grid_packed_abs)
104 | if float_precision:
105 | W_decompressed = W_decompressed.float()
106 | if self.transposed:
107 | x = (x @ W_decompressed)
108 | else:
109 | x = (x @ W_decompressed.T)
110 |
111 | elif self.codebook_version == 'E8P12RVQ3B':
112 | resid_scale = self.codebook.opt_resid_scale
113 | x16 = x.to(torch.float16)
114 | if x.shape[0] == 1 and not self.transposed and not float_precision:
115 | x_padded = torch.zeros(
116 | 8, x16.shape[1], dtype=torch.float16, device=x16.device)
117 | x_padded[0] = x16[0]
118 | z = torch.zeros(
119 | 8, m, dtype=torch.float16, device=x_padded.device)
120 | quiptools_cuda.lookupmatmul_e81b_k8(
121 | x_padded / resid_scale, self.idxs_list[1],
122 | self.codebook.e81b_grid, z
123 | )
124 |
125 | x = quiptools_cuda.decode_matvec_e8p(
126 | x16[0], self.idxs_list[0].view(m // 16, n // 64, 8, 4),
127 | self.codebook.grid_packed_abs) + z[0]
128 | x = x.unsqueeze(0)
129 |
130 | else:
131 | W_decompressed = quiptools_cuda.decompress_packed_e8p(
132 | self.idxs_list[0].view(m // 16, n // 64, 8, 4),
133 | self.codebook.grid_packed_abs)
134 |
135 | W_resid_decompressed = torch.zeros(
136 | self.idxs_list[1].shape[0],
137 | 64 * self.idxs_list[1].shape[-1],
138 | device=self.idxs_list[1].device, dtype=torch.float16
139 | )
140 |
141 | quiptools_cuda.decompress_e81b_packed(
142 | self.idxs_list[1], self.codebook.e81b_grid,
143 | W_resid_decompressed
144 | )
145 |
146 | if float_precision:
147 |                     W_decompressed = W_decompressed.float()
148 |                     W_resid_decompressed = W_resid_decompressed.float()
149 |
150 | if self.transposed:
151 | x = (x @ (W_decompressed +
152 | W_resid_decompressed / resid_scale))
153 | else:
154 | x = (x @ (W_decompressed +
155 | W_resid_decompressed / resid_scale).T)
156 | else:
157 | resid_scale = self.codebook.opt_resid_scale
158 | if x.size(0) == 1 and not self.transposed and not float_precision:
159 | x16 = x[0].to(torch.float16)
160 | x = (quiptools_cuda.decode_matvec_e8p(
161 | x16, self.idxs_list[0].view(m // 16, n // 64, 8, 4),
162 | self.codebook.grid_packed_abs) +
163 | quiptools_cuda.decode_matvec_e8p(
164 | x16 / resid_scale, self.idxs_list[1].view(
165 | m // 16, n // 64, 8, 4),
166 | self.codebook.grid_packed_abs)).unsqueeze(0)
167 | else:
168 | W_decompressed = quiptools_cuda.decompress_packed_e8p(
169 | self.idxs_list[0].view(m // 16, n // 64, 8, 4),
170 | self.codebook.grid_packed_abs) + \
171 | quiptools_cuda.decompress_packed_e8p(
172 | self.idxs_list[1].view(m // 16, n // 64, 8, 4),
173 | self.codebook.grid_packed_abs
174 | ) / resid_scale
175 | if float_precision:
176 | W_decompressed = W_decompressed.float()
177 | if self.transposed:
178 | x = (x @ W_decompressed)
179 | else:
180 | x = (x @ W_decompressed.T)
181 | x = x.to(dtype)
182 | x *= self.scale * 32
183 | return x
184 |
185 |
186 | class CalderaQuantizedLinear(nn.Module):
187 |
188 | def __init__(
189 | self,
190 | in_features,
191 | out_features,
192 | Q_codebook_version,
193 | L_codebook_version,
194 | R_codebook_version,
195 | L, R,
196 | L_idxs, R_idxs, Q_idxs,
197 | L_scale, R_scale, Q_scale,
198 | global_scale,
199 | scaleWH,
200 | hadamard,
201 | SU, SV,
202 | rank=64,
203 | ft_rank=64,
204 | grad_ckpt=True
205 | ):
206 | super(CalderaQuantizedLinear, self).__init__()
207 |
208 | self.rank = rank
209 | self.ft_rank = ft_rank
210 |
211 | self.in_features = in_features
212 | self.out_features = out_features
213 |
214 | self.hadamard = hadamard
215 | self.global_scale = global_scale
216 | self.scaleWH = scaleWH
217 | if self.scaleWH is not None:
218 | self.scaleWH = nn.Parameter(self.scaleWH, requires_grad=False)
219 |
220 | if Q_idxs is not None:
221 | self.Q = LatticeQuantizedParameter(
222 | in_features=in_features,
223 | out_features=out_features,
224 | idxs=Q_idxs,
225 | scale=Q_scale,
226 | codebook_version=Q_codebook_version
227 | )
228 | else:
229 | self.Q = None
230 |
231 | self.L_idxs = L_idxs
232 | self.R_idxs = R_idxs
233 | self.L_codebook_version = L_codebook_version
234 | self.R_codebook_version = R_codebook_version
235 | self.L_scale = L_scale
236 | self.R_scale = R_scale
237 |
238 | self.split_L_and_R_for_LoRA(ft_rank, L, R)
239 |
240 | self.SU = nn.Parameter(SU, requires_grad=True)
241 | self.SV = nn.Parameter(SV, requires_grad=True)
242 |
243 | had_left, K_left = get_hadK(in_features)
244 | had_right, K_right = get_hadK(out_features)
245 |
246 | self.had_left = nn.Parameter(had_left, requires_grad=False)
247 | self.had_right = nn.Parameter(had_right, requires_grad=False)
248 |
249 | self.K_left = K_left
250 | self.K_right = K_right
251 |
252 | self.grad_ckpt = grad_ckpt
253 |
254 | def split_L_and_R_for_LoRA(self, ft_rank, L, R):
255 | if ft_rank > 0:
256 | self.L_ft = nn.Parameter(
257 | L[:, :ft_rank], requires_grad=True)
258 | self.R_ft = nn.Parameter(
259 | R[:ft_rank, :], requires_grad=True)
260 |             assert self.L_ft.numel() > 0 and self.R_ft.numel() > 0
261 |
262 | if self.rank > ft_rank:
263 | if self.L_codebook_version is not None:
264 | self.L = LatticeQuantizedParameter(
265 | in_features=self.out_features,
266 | out_features=self.rank - ft_rank,
267 | idxs=self.L_idxs[ft_rank:, :],
268 | scale=self.L_scale,
269 | codebook_version=self.L_codebook_version,
270 | transposed=True
271 | )
272 | self.L_idxs = None
273 | self.quant_L = True
274 | else:
275 | self.L = nn.Parameter(L[:, ft_rank:], requires_grad=False)
276 | self.quant_L = False
277 |
278 | if self.R_codebook_version is not None:
279 | self.R = LatticeQuantizedParameter(
280 | in_features=self.in_features,
281 | out_features=self.rank - ft_rank,
282 | idxs=self.R_idxs[ft_rank:, :],
283 | scale=self.R_scale,
284 | codebook_version=self.R_codebook_version
285 | )
286 | self.R_idxs = None
287 | self.quant_R = True
288 | else:
289 | self.R = nn.Parameter(R[ft_rank:, :], requires_grad=False)
290 | self.quant_R = False
291 | else:
292 | self.L = None
293 | self.R = None
294 |
295 |
296 | def forward(self, x):
297 | old_dtype = x.dtype
298 | x = x.float()
299 | shape = x.shape
300 | n, m = len(self.SU), len(self.SV)
301 | x = x.view(-1, n)
302 | # Preprocessing
303 | if self.scaleWH is not None:
304 | x /= self.scaleWH
305 | x = x * self.SU
306 | x = matmul_hadUt_cuda(x, self.had_left, self.K_left)
307 |
308 | # Apply Q
309 | output_no_ft = self.Q.forward(x)
310 |
311 | # Apply quantized L and R
312 | if self.L is not None:
313 | if self.quant_R:
314 | xR = self.R.forward(x, float_precision=True)
315 | else:
316 | xR = (x.float() @ self.R.T.float())
317 |
318 | if self.quant_L:
319 | output_no_ft += self.L.forward(xR, float_precision=True)
320 | else:
321 | output_no_ft += xR.float() @ self.L.T.float()
322 |
323 | # Apply LoRA factors
324 | if self.ft_rank > 0:
325 | output = output_no_ft + x @ self.R_ft.T.float() @ self.L_ft.T.float()
326 | else:
327 | output = output_no_ft
328 |
329 | output = matmul_hadU_cuda(output, self.had_right, self.K_right)
330 |
331 | output = output * self.SV * self.global_scale
332 | if self.scaleWH is not None:
333 | output *= self.scaleWH
334 | return output.view(*shape[:-1], m).to(old_dtype)
335 |
336 | def compare_outputs(self, input, W_hat):
337 |         output = self.forward(input)
338 | comparison = input @ W_hat.T
339 | return (torch.linalg.matrix_norm(output - comparison, ord='fro') /
340 | torch.linalg.matrix_norm(comparison, ord='fro')).mean().item()
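341 | 
342 | 
343 | # --- Illustrative sketch (not part of the original module above) ------------
344 | # A hedged sanity check: up to float16 rounding, the fused decode kernels in
345 | # LatticeQuantizedParameter.forward should agree with an explicit matmul
346 | # against the weight returned by get_W_decompressed() (which already folds in
347 | # `scale`). The helper below is illustrative only; it assumes `param` is an
348 | # already-constructed LatticeQuantizedParameter on a CUDA device.
349 | @torch.no_grad()
350 | def check_lattice_param_consistency(param, batch_size=4):
351 |     in_dim = param.out_features if param.transposed else param.in_features
352 |     x = torch.randn(batch_size, in_dim, dtype=torch.float16,
353 |                     device=param.idxs_list[0].device)
354 |     y_kernel = param.forward(x)
355 |     W_hat = param.get_W_decompressed()
356 |     y_ref = x @ (W_hat if param.transposed else W_hat.T)
357 |     # Relative Frobenius-norm discrepancy; expected to be small (fp16 noise).
358 |     return (torch.linalg.matrix_norm(y_kernel.float() - y_ref.float()) /
359 |             torch.linalg.matrix_norm(y_ref.float())).item()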
--------------------------------------------------------------------------------
/src/caldera/decomposition/alg.py:
--------------------------------------------------------------------------------
1 | from lib.utils.math_utils import block_LDL
2 | import lib.algo.quip as quip
3 | from lib import codebook
4 |
5 | import torch
6 | from tqdm import tqdm
7 |
8 | from caldera.decomposition.dataclasses import *
9 | from caldera.utils.quantization import QuantizerFactory
10 |
11 | from collections import namedtuple
12 | from copy import deepcopy
13 |
14 |
15 | # Maps number of bits to name of the QuIP# lattice quantizer
16 | BITS_TO_CODEBOOK = {
17 | 2: 'E8P12',
18 | 3: 'E8P12RVQ3B',
19 | 4: 'E8P12RVQ4B'
20 | }
21 |
22 | def caldera(
23 | quant_params: CalderaParams,
24 | W: torch.Tensor,
25 | H: torch.Tensor = None,
26 | device: str = "cuda",
27 | use_tqdm: bool = True,
28 | scale_W: bool = True,
29 | ):
30 | """
31 |     Runs the CALDERA algorithm to decompose a weight matrix W into Q + LR,
32 |     where Q is a full-rank quantized matrix, L and R are low-rank factors, and
33 |     all three are stored in low-precision formats.
34 | """
35 | # scaling
36 | if scale_W:
37 | global_scale = W.square().mean().sqrt().item()
38 | else:
39 | global_scale = 1
40 | W = W / global_scale
41 |
42 | if H is None:
43 | H = torch.eye(W.shape[1]).to(device)
44 |
45 | # Compute the symmetric square root of H, because the data-aware
46 | # objective can be formulated as
47 | # min_{L, R} ||(W - LR - Q)H^{1/2}||_F^2.
48 | EigTuple = namedtuple("EigTuple", ["eigenvalues", "eigenvectors"])
49 | if not quant_params.activation_aware_LR and not quant_params.activation_aware_Q \
50 | and not quant_params.full_quip_sharp:
51 | H_sqrt = H
52 | eigH = EigTuple(torch.ones(W.shape[1]).to(device), H)
53 | else:
54 | eigH = torch.linalg.eigh(H)
55 |
56 | eigvals = eigH.eigenvalues
57 | # if eigvals.min() < quant_params.quip_args.sigma_reg:
58 | # H = H + (quant_params.quip_args.sigma_reg - eigvals.min()) * \
59 | # torch.eye(H.shape[0], device=H.device, dtype=H.dtype)
60 | # eigvals += quant_params.quip_args.sigma_reg - eigvals.min()
61 | # eigH = EigTuple(eigvals, eigH.eigenvectors)
62 |
63 | H_sqrt = (eigH.eigenvectors @
64 | torch.diag(torch.sqrt(eigvals)) @
65 | eigH.eigenvectors.T)
66 |
67 | # Initialization and Hadamard transform
68 | best_decomp = CalderaDecomposition(
69 | Q=torch.zeros_like(W).float(),
70 | L=torch.zeros(W.shape[0], quant_params.rank).to(device),
71 | R=torch.zeros(quant_params.rank, W.shape[1]).to(device))
72 |
73 | if quant_params.hadamard_transform:
74 | _, H, W, SU, SV, scaleWH = quip.incoherence_preprocess(
75 | H, W, quant_params.quip_args
76 | )
77 | best_decomp.SU = SU.to(W.device)
78 | best_decomp.SV = SV.to(W.device)
79 | best_decomp.scaleWH = scaleWH
80 |
81 | eigH = torch.linalg.eigh(H)
82 | H_sqrt = (eigH.eigenvectors @
83 | torch.diag(torch.sqrt(eigH.eigenvalues)) @
84 | eigH.eigenvectors.T)
85 | else:
86 | best_decomp.scaleWH = None
87 | best_decomp.SU = torch.ones(W.shape[1]).to(W.dtype).to(W.device)
88 | best_decomp.SV = torch.ones(W.shape[0]).to(W.dtype).to(W.device)
89 |
90 | best_decomp.W = W.cpu()
91 | errors = {}
92 | for mtx in quant_params.update_order:
93 | errors[mtx] = []
94 |
95 | min_error = float('inf')
96 | curr_decomp = deepcopy(best_decomp)
97 |
98 | updated = {mtx: False for mtx in quant_params.update_order}
99 |
100 | to_iter = range(quant_params.iters)
101 | if use_tqdm:
102 | to_iter = tqdm(to_iter)
103 | for _ in to_iter:
104 | for mtx in quant_params.update_order:
105 | if mtx == "LR":
106 | maybe_update_LR(curr_decomp, quant_params, W, H_sqrt, eigH, device)
107 | elif mtx == "Q":
108 | maybe_update_Q(curr_decomp, quant_params, W, H, device)
109 | updated[mtx] = True
110 |
111 | errors[mtx].append(
112 | activation_aware_error(W, H, curr_decomp, device)
113 | )
114 | if errors[mtx][-1] < min_error and all(updated.values()):
115 | min_error = errors[mtx][-1]
116 | best_decomp = deepcopy(curr_decomp)
117 | best_decomp.errors = errors
118 |
119 | # Update scales
120 | best_decomp.global_scale = global_scale
121 | return best_decomp
122 |
123 |
124 | def activation_aware_error(
125 | W: torch.Tensor,
126 | H: torch.Tensor,
127 | caldera_info: CalderaDecomposition,
128 | device: str
129 | ):
130 | """
131 |     Computes the activation-aware loss for a sublayer as
132 |     sqrt(tr((W - W_hat) H (W - W_hat)^T) / tr(W H W^T)),
133 |     i.e., the relative error in the H-weighted Frobenius norm.
134 | """
135 |
136 | W = W.to(device).float()
137 | W_hat = caldera_info.Q + caldera_info.L @ caldera_info.R
138 | W_hat *= caldera_info.global_scale
139 |
140 | error = (torch.trace((W_hat - W) @ H @ (W_hat - W).T) /
141 | torch.trace(W @ H @ W.T)).sqrt().item()
142 | return error
143 |
144 |
145 | def get_quant_info(
146 | use_lattice_quant: bool,
147 | quant_factory: QuantizerFactory,
148 | bits: int,
149 | device: str
150 | ):
151 | cb = None
152 | quantizer = None
153 | if use_lattice_quant:
154 | cb = codebook.get_codebook(BITS_TO_CODEBOOK[bits]).to(device)
155 | else:
156 | quantizer = quant_factory.get_quantizer(bits, device)
157 | return QuantInfo(
158 | lattice_quant=use_lattice_quant,
159 | lattice_cb=cb,
160 | quant=quantizer
161 | )
162 |
163 |
164 | def quantize_matrix(
165 | A, quant_params,
166 | quant_info: QuantInfo = None
167 | ):
168 | QuantReturn = namedtuple(
169 | 'QuantReturn', ['A_hat', 'A_idxs', 'scale']
170 | )
171 | if not quant_info.lattice_quant:
172 | quant_info.quant.block_size = A.shape[0] * A.shape[1]
173 | A_idxs, scales, shape = quant_info.quant.quantize_block(A)
174 | A_hat = quant_info.quant.dequantize_block(A_idxs, scales, shape)
175 | return QuantReturn(A_hat, A_idxs, scales)
176 |
177 | # Scale before quantization, as in QuIP#
178 | scale = A.square().mean().sqrt().item()
179 |
180 | m, n = A.shape
181 |
182 | A = A.reshape(-1, quant_info.lattice_cb.codesz).clone() / scale
183 | A_idxs = torch.zeros(
184 | m * n // quant_info.lattice_cb.codesz,
185 | dtype=quant_info.lattice_cb.idx_dtype,
186 | device=A.device
187 | )
188 | K = quant_params.lattice_quant_block_size
189 | for i in range(0, A.shape[0], K):
190 | A[i:i+K], A_idxs[i:i+K] = \
191 | quant_info.lattice_cb.quantize(A.float()[i:i+K])
192 | A = A.reshape(m, n)
193 | A_idxs = A_idxs.reshape(m, n // quant_info.lattice_cb.codesz)
194 |
195 | A = A * scale
196 |
197 | A_idxs = quant_info.lattice_cb.maybe_pack_idxs(A_idxs)
198 | return QuantReturn(A, A_idxs, scale)
199 |
200 |
201 | def maybe_update_Q(
202 | caldera_info: CalderaDecomposition,
203 | quant_params: CalderaParams,
204 | W: torch.Tensor,
205 | H: torch.Tensor,
206 | device: str
207 | ):
208 |
209 | if quant_params.compute_quantized_component:
210 | residual = W - caldera_info.L @ caldera_info.R
211 | if not quant_params.compute_low_rank_factors:
212 | residual = W
213 | if quant_params.activation_aware_Q:
214 | update_Q_data_aware(caldera_info, quant_params, H, residual, device)
215 | else:
216 | update_Q_non_data_aware(caldera_info, quant_params, residual, device)
217 |
218 |
219 | def update_Q_non_data_aware(
220 | caldera_info: CalderaDecomposition,
221 | quant_params: CalderaParams,
222 | residual: torch.Tensor,
223 | device: str
224 | ):
225 | quant_info = get_quant_info(
226 | use_lattice_quant=quant_params.lattice_quant_Q,
227 | quant_factory=quant_params.quant_factory_Q,
228 | bits=quant_params.Q_bits,
229 | device=device
230 | )
231 |
232 | quant_return = quantize_matrix(residual, quant_params, quant_info)
233 | caldera_info.Q = quant_return.A_hat
234 | caldera_info.Q_idxs = quant_return.A_idxs
235 | caldera_info.Q_scale = quant_return.scale
236 |
237 |
238 | def update_Q_data_aware(
239 | caldera_info: CalderaDecomposition,
240 | quant_params: CalderaParams,
241 | H: torch.Tensor,
242 | residual: torch.Tensor,
243 | device: str
244 | ):
245 | """
246 | Performs an LDLQ update on the residual (W - LR)
247 | """
248 |
249 | # Scale the residual, as done in the quantize_linear function of QuIP#
250 | scale = residual.square().mean().sqrt().item()
251 | residual /= scale
252 |
253 | codebook_str = BITS_TO_CODEBOOK[quant_params.Q_bits]
254 | cb = codebook.get_codebook(codebook_str).to(residual.device)
255 |
256 | if quant_params.compute_low_rank_factors and \
257 | quant_params.Q_hessian_downdate:
258 |
259 | M = torch.linalg.cholesky(H)
260 | if quant_params.rand_svd:
261 | _, _, V = torch.svd_lowrank(
262 | caldera_info.L @ caldera_info.R @ M,
263 | quant_params.rank * 3, niter=10)
264 | V = V[:, :quant_params.rank]
265 | else:
266 | _, _, Vh = torch.linalg.svd(
267 | caldera_info.L @ caldera_info.R @ M, full_matrices=False)
268 | V = Vh.T[:, :quant_params.rank]
269 |
270 | H = H - (M @ V @ V.T @ M.T).to(H.dtype)
271 | min_eigval = torch.linalg.eigh(H).eigenvalues.min()
272 | H = H + (quant_params.quip_args.sigma_reg2 + max(-min_eigval, 0)) * torch.eye(H.shape[0], device=H.device, dtype=H.dtype)
273 | alpha = torch.diag(H).mean().abs() * quant_params.quip_args.sigma_reg2
274 | H = H + alpha * torch.eye(H.shape[0], device=H.device, dtype=H.dtype)
275 |
276 | if quant_params.full_quip_sharp:
277 | assert not quant_params.hadamard_transform, \
278 | ("Full QuIP# incompatible with performing Hadamard transform "
279 | "on our end")
280 | assert quant_params.rank == 0 or \
281 | not quant_params.compute_low_rank_factors, \
282 | ("Full QuIP# incompatible with separately computing low-rank "
283 | "factors.")
284 |
285 | caldera_info.Q, attr = quip.quantize(
286 | H_orig=H,
287 | W_orig=residual,
288 | rank=0,
289 | codebook_orig=cb,
290 | args=quant_params.quip_args,
291 | device=device
292 | )
293 | caldera_info.Q_idxs = attr['Qidxs'].to(device)
294 |
295 | caldera_info.scaleWH = attr['scaleWH']
296 | caldera_info.SU = attr['SU']
297 | caldera_info.SV = attr['SV']
298 | if quant_params.quip_args.lora_rank != 0:
299 | caldera_info.L = attr['A'].to(device) / caldera_info.SV[0].abs().sqrt()
300 | caldera_info.R = attr['B'].to(device) / caldera_info.SV[0].abs().sqrt()
301 | caldera_info.L_scale = scale
302 | caldera_info.R_scale = scale
303 | caldera_info.Q -= caldera_info.L @ caldera_info.R
304 |
305 | else:
306 | # Just do LDLQ
307 | block_LDL_out = block_LDL(H, cb.codesz)
308 | assert block_LDL_out is not None
309 |
310 | L, D = block_LDL_out
311 | del block_LDL_out
312 |
313 | scale /= cb.opt_scale
314 | residual *= cb.opt_scale
315 |
316 | if quant_params.quip_args.no_use_buffered:
317 | Q, Qidxs = quip.LDLQ(
318 | residual, H, L, D, cb, quant_params.quip_args)
319 | elif quant_params.quip_args.lowmem_ldlq or \
320 | quant_params.quip_args.use_fp64:
321 | Q, Qidxs = quip.LDLQ_buffered_lowmem(
322 | residual, H, L, D, cb, quant_params.quip_args,
323 | buf_cols=128)
324 | else:
325 | Q, Qidxs = quip.LDLQ_buffered(
326 | residual, H, L, D, cb, quant_params.quip_args,
327 | buf_cols=128)
328 | caldera_info.Q_idxs = Qidxs
329 | caldera_info.Q = Q
330 | caldera_info.Q_idxs = cb.maybe_pack_idxs(caldera_info.Q_idxs)
331 |
332 | caldera_info.Q_scale = scale
333 | caldera_info.Q *= scale
334 |
335 |
336 | def LR_init(
337 | caldera_info: CalderaDecomposition,
338 | quant_params: CalderaParams,
339 | H_sqrt: torch.Tensor,
340 | eigH: torch.Tensor,
341 | residual: torch.Tensor
342 | ):
343 | """
344 |     Initializes L and R in closed form. When activation-aware, minimizes
345 |     ||(residual - LR) H^{1/2}||_F^2 via a truncated SVD of residual @ H^{1/2}
346 |     (see the check at the end of this file); otherwise uses a plain SVD.
347 | """
348 | if quant_params.activation_aware_LR:
349 | Y = residual @ H_sqrt @ eigH.eigenvectors
350 | if quant_params.rand_svd:
351 | q = min(quant_params.rank*2, min(*caldera_info.W.shape))
352 | U, Sigma, V = torch.svd_lowrank(Y, q)
353 | Vh = V.T
354 | else:
355 | U, Sigma, Vh = torch.linalg.svd(Y, full_matrices=False)
356 |
357 | L = U[:, :quant_params.rank]
358 | R = torch.diag(Sigma[:quant_params.rank]) @ \
359 | Vh[:quant_params.rank, :] @ \
360 | torch.diag(1 / eigH.eigenvalues.sqrt()) @ eigH.eigenvectors.T
361 | else:
362 | if quant_params.rand_svd:
363 | q = min(quant_params.rank*2,
364 | min(*caldera_info.W.shape))
365 | U, Sigma, V = torch.svd_lowrank(residual, q)
366 | Vh = V.T
367 | else:
368 | U, Sigma, Vh = torch.linalg.svd(residual, full_matrices=False)
369 | L = U[:, :quant_params.rank] @ \
370 | torch.diag(Sigma[:quant_params.rank].sqrt())
371 | R = torch.diag(Sigma[:quant_params.rank].sqrt()) @ \
372 | Vh[:quant_params.rank, :]
373 | return L, R
374 |
375 | def maybe_update_LR(
376 | caldera_info: CalderaDecomposition,
377 | quant_params: CalderaParams,
378 | W: torch.Tensor,
379 | H_sqrt: torch.Tensor,
380 | eigH,
381 | device
382 | ):
383 | if quant_params.compute_low_rank_factors:
384 | residual = W - caldera_info.Q
385 | update_LR(caldera_info, quant_params, residual, H_sqrt, eigH, device)
386 |
387 |
388 | def update_LR(
389 | caldera_info: CalderaDecomposition,
390 | quant_params: CalderaParams,
391 | residual: torch.Tensor,
392 | H_sqrt: torch.Tensor,
393 | eigH,
394 | device
395 | ):
396 | """
397 |     Runs LPLR (low-precision low-rank factorization) on the residual (W - Q).
398 | """
399 | data_aware = quant_params.activation_aware_LR
400 |
401 | # Initialization of L, R
402 | L, R = LR_init(caldera_info, quant_params, H_sqrt, eigH, residual)
403 |
404 | if quant_params.L_bits < 16 or quant_params.R_bits < 16:
405 | quant_info_L = get_quant_info(
406 | use_lattice_quant=quant_params.lattice_quant_LR,
407 | quant_factory=quant_params.quant_factory_LR,
408 | bits=quant_params.L_bits,
409 | device=device
410 | )
411 | quant_info_R = get_quant_info(
412 | use_lattice_quant=quant_params.lattice_quant_LR,
413 | quant_factory=quant_params.quant_factory_LR,
414 | bits=quant_params.R_bits,
415 | device=device
416 | )
417 |
418 | best_L, best_R = L, R
419 | best_L_quant_out, best_R_quant_out = None, None
420 | best_error = float('inf')
421 |
422 | for _ in range(quant_params.lplr_iters):
423 | # L
424 | if data_aware:
425 | L = torch.linalg.lstsq((R @ H_sqrt).T, (residual @ H_sqrt).T)[0].T
426 | if torch.isnan(L).any():
427 | L = (residual @ H_sqrt) @ torch.linalg.pinv(R @ H_sqrt)
428 | else:
429 | L = torch.linalg.lstsq(R.T, residual.T)[0].T
430 |                 if torch.isnan(L).any():
431 | L = residual @ torch.linalg.pinv(R)
432 |
433 | quant_out_L = quantize_matrix(L.T, quant_params, quant_info_L)
434 | L = quant_out_L.A_hat.T
435 |
436 | # R
437 | R = torch.linalg.lstsq(L, residual)[0]
438 | if torch.isnan(R).any():
439 | R = torch.linalg.pinv(L) @ residual
440 |
441 | quant_out_R = quantize_matrix(R, quant_params, quant_info_R)
442 | R = quant_out_R.A_hat
443 |
444 | error = torch.linalg.matrix_norm((residual - L @ R) @ H_sqrt) #/ \
445 | # torch.linalg.matrix_norm((residual + caldera_info.Q) @ H_sqrt)
446 | if error < best_error:
447 | best_L, best_R = L, R
448 | best_L_quant_out = quant_out_L
449 | best_R_quant_out = quant_out_R
450 | best_error = error
451 |
452 | caldera_info.L_idxs = best_L_quant_out.A_idxs
453 | caldera_info.R_idxs = best_R_quant_out.A_idxs
454 | caldera_info.L_scale = best_L_quant_out.scale
455 | caldera_info.R_scale = best_R_quant_out.scale
456 |
457 | L, R = best_L, best_R
458 |
459 | caldera_info.L = L
460 | caldera_info.R = R
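461 | 
462 | 
463 | # --- Illustrative sketch (not part of the original module above) ------------
464 | # A hedged numerical check of the closed form used in the activation-aware
465 | # branch of LR_init: for positive-definite H, the rank-k minimizer of
466 | # ||(A - LR) H^{1/2}||_F satisfies (L R) H^{1/2} = best rank-k approximation
467 | # of A H^{1/2} (Eckart-Young). All names below are illustrative only.
468 | @torch.no_grad()
469 | def _check_lr_init_closed_form(m=32, n=48, k=4):
470 |     A = torch.randn(m, n, dtype=torch.float64)
471 |     B = torch.randn(n, n, dtype=torch.float64)
472 |     H = B @ B.T + n * torch.eye(n, dtype=torch.float64)  # positive definite
473 |     eigH = torch.linalg.eigh(H)
474 |     H_sqrt = (eigH.eigenvectors @ torch.diag(eigH.eigenvalues.sqrt()) @
475 |               eigH.eigenvectors.T)
476 |     # Closed form, mirroring LR_init's activation-aware branch.
477 |     Y = A @ H_sqrt @ eigH.eigenvectors
478 |     U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
479 |     L = U[:, :k]
480 |     R = (torch.diag(S[:k]) @ Vh[:k, :] @
481 |          torch.diag(1 / eigH.eigenvalues.sqrt()) @ eigH.eigenvectors.T)
482 |     # Direct Eckart-Young solution in the H^{1/2}-transformed space.
483 |     U2, S2, Vh2 = torch.linalg.svd(A @ H_sqrt, full_matrices=False)
484 |     best_rank_k = U2[:, :k] @ torch.diag(S2[:k]) @ Vh2[:k, :]
485 |     # Should be ~0 up to floating-point error.
486 |     return torch.linalg.matrix_norm(L @ R @ H_sqrt - best_rank_k).item()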
--------------------------------------------------------------------------------
/scripts/finetune_winogrande.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/tatsu-lab/stanford_alpaca/blob/main/train.py
2 |
3 | import copy
4 | import logging
5 | import os
6 | from dataclasses import dataclass, field
7 | from typing import Dict, Optional, Sequence
8 |
9 | import torch
10 | import transformers
11 | from transformers import default_data_collator, DataCollatorWithPadding, get_scheduler
12 |
13 | from datasets import load_dataset, Dataset
14 | from quantize_save_llama import load_quantized_model
15 | from torch.utils.data import DataLoader
16 |
17 | from accelerate import Accelerator
18 | from accelerate.logging import get_logger
19 | from accelerate.utils import set_seed
20 | import datasets
21 |
22 | import math
23 | import evaluate
24 | from tqdm import tqdm
25 |
26 |
27 | logger = get_logger(__name__)
28 |
29 |
30 | @dataclass
31 | class ModelArguments:
32 | base_model: str = field()
33 | model_name_or_path: str = field()
34 | token: Optional[str] = field(
35 | default=None,
36 | metadata={"help": "HF token to access to private models, e.g., meta-llama"},
37 | )
38 |
39 |
40 | @dataclass
41 | class TrainingArguments(transformers.TrainingArguments):
42 | max_seq_length: int = field(default=128, metadata={
43 | "help": ("The maximum total input sequence length after tokenization. Sequences longer "
44 | "than this will be truncated, sequences shorter will be padded.")
45 | })
46 | with_tracking: bool = field(default=False, metadata={
47 | "help": "Whether to report eval accuracies to, e.g., tensorboard."
48 | })
49 | num_warmup_steps: int = field(default=0)
50 |
51 | def run_eval(eval_dataloader, model, accelerator, metric):
52 | samples_seen = 0
53 |
54 | for step, batch in enumerate(eval_dataloader):
55 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
56 | with torch.no_grad():
57 | # batch = {k: v.to(device) if torch.is_tensor(v) else v for k, v in batch.items()}
58 | predictions = model(**batch)
59 |
60 | references = batch["labels"]
61 | predictions, references = accelerator.gather_for_metrics((predictions, references))
62 |
63 | # print(predictions.logits)
64 | if predictions.logits.isnan().any():
65 | print("WARNING NaN OUTPUT LOGITS")
66 | predictions = predictions.logits.argmax(dim=-1)
67 | # If we are in a multiprocess environment, the last batch has duplicates
68 | if step == len(eval_dataloader) - 1:
69 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
70 | references = references[: len(eval_dataloader.dataset) - samples_seen]
71 | else:
72 | samples_seen += references.shape[0]
73 | metric.add_batch(
74 | predictions=predictions,
75 | references=references,
76 | )
77 | return metric.compute()
78 |
79 |
80 | def train():
81 | parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
82 | model_args, training_args = parser.parse_args_into_dataclasses()
83 |
84 | transformers.utils.send_example_telemetry("run_clm_no_trainer", training_args)
85 |
86 | training_args.output_dir = os.path.join(
87 | training_args.output_dir,
88 | f"ep_{int(training_args.num_train_epochs)}_lr_{training_args.learning_rate}_seed_{training_args.seed}"
89 | )
90 |
91 | accelerator_log_kwargs = {}
92 |
93 | if training_args.with_tracking:
94 | accelerator_log_kwargs["log_with"] = training_args.report_to
95 | accelerator_log_kwargs["project_dir"] = training_args.output_dir
96 |
97 | accelerator = Accelerator(
98 | gradient_accumulation_steps=training_args.gradient_accumulation_steps,
99 | **accelerator_log_kwargs)
100 |
101 | logging.basicConfig(
102 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
103 | datefmt="%m/%d/%Y %H:%M:%S",
104 | level=logging.INFO,
105 | )
106 |
107 | logger.info(accelerator.state, main_process_only=False)
108 | if accelerator.is_local_main_process:
109 | datasets.utils.logging.set_verbosity_warning()
110 | transformers.utils.logging.set_verbosity_info()
111 | else:
112 | datasets.utils.logging.set_verbosity_error()
113 | transformers.utils.logging.set_verbosity_error()
114 |
115 | # If passed along, set the training seed now.
116 | if accelerator.is_main_process:
117 | if training_args.seed is not None:
118 | set_seed(training_args.seed)
119 |
120 | if training_args.output_dir is not None:
121 | os.makedirs(training_args.output_dir, exist_ok=True)
122 | # writer = SummaryWriter(args.output_dir)
123 | accelerator.wait_for_everyone()
124 |
125 | model = load_quantized_model(
126 | model_args.model_name_or_path, model_args.base_model,
127 | accelerator.device, sequence_classification=True
128 | ).to(torch.bfloat16)
129 |
130 | for name, param in model.named_parameters():
131 | if 'SU' in name or 'SV' in name:
132 | param.requires_grad = False
133 |
134 | tokenizer = transformers.AutoTokenizer.from_pretrained(
135 | model_args.base_model,
136 | token=model_args.token,
137 | use_fast=True,
138 | )
139 |
140 | model.config.pad_token_id = 0
141 | tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
142 | tokenizer.padding_side = "left" # Allow batched inference
143 | tokenizer.truncation_side = "left"
144 |
145 | embedding_size = model.get_input_embeddings().weight.shape[0]
146 | if len(tokenizer) > embedding_size:
147 | model.resize_token_embeddings(len(tokenizer))
148 |
149 | train_data = load_dataset("winogrande", "winogrande_xl", split="train").to_pandas()
150 | eval_data = load_dataset("winogrande", "winogrande_xl", split="validation").to_pandas()
151 |
152 | def preprocess(data):
153 | data["text1"] = data.apply(lambda x: x["sentence"].replace("_", x["option1"]), axis=1)
154 | data["text2"] = data.apply(lambda x: x["sentence"].replace("_", x["option2"]), axis=1)
155 | data["label"] = data.apply(lambda x: int(x["answer"]) - 1, axis=1)
156 | return Dataset.from_pandas(data)
157 |
158 | def tokenize(sample):
159 | model_inps = tokenizer(sample["text1"], sample["text2"], padding="max_length",
160 | truncation=True, max_length=training_args.max_seq_length)
161 | model_inps["labels"] = sample["label"]
162 | return model_inps
163 |
164 | with accelerator.main_process_first():
165 | train_data = preprocess(train_data)
166 | eval_data = preprocess(eval_data)
167 | tokenized_train_data = train_data.map(tokenize, batched=True, desc="Tokenizing training data",
168 | remove_columns=train_data.column_names)
169 | tokenized_eval_data = eval_data.map(tokenize, batched=True, desc="Tokenizing eval data",
170 | remove_columns=train_data.column_names)
171 |
172 | print(tokenized_train_data)
173 | # print(tokenized_train_data["labels"])
174 |
175 | train_dataloader = DataLoader(
176 | tokenized_train_data, collate_fn=default_data_collator,
177 | batch_size=training_args.per_device_train_batch_size, shuffle=True
178 | )
179 | eval_dataloader = DataLoader(
180 | tokenized_eval_data, collate_fn=default_data_collator,
181 | batch_size=training_args.per_device_eval_batch_size
182 | )
183 |
184 | # Optimizer
185 | # Split weights in two groups, one with weight decay and the other not.
186 | no_decay = ["bias", "LayerNorm.weight"]
187 | optimizer_grouped_parameters = [
188 | {
189 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
190 | "weight_decay": training_args.weight_decay,
191 | },
192 | {
193 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
194 | "weight_decay": 0.0,
195 | },
196 | ]
197 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=training_args.learning_rate)
198 |
199 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
200 | max_train_steps = training_args.num_train_epochs * num_update_steps_per_epoch
201 |
202 | lr_scheduler = get_scheduler(
203 | name=training_args.lr_scheduler_type,
204 | optimizer=optimizer,
205 | num_warmup_steps=training_args.num_warmup_steps,
206 | num_training_steps=max_train_steps,
207 | )
208 |
209 | # Prepare everything with our `accelerator`.
210 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
211 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
212 | )
213 |
214 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
215 | max_train_steps = training_args.num_train_epochs * num_update_steps_per_epoch
216 | # Afterwards we recalculate our number of training epochs
217 | training_args.num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
218 |
219 | # Figure out how many steps we should save the Accelerator states
220 | save_steps = training_args.save_steps
221 | if save_steps is not None and isinstance(save_steps, str) and save_steps.isdigit():
222 | save_steps = int(save_steps)
223 |
224 | # We need to initialize the trackers we use, and also store our configuration.
225 | # The trackers initializes automatically on the main process.
226 | if training_args.with_tracking:
227 | experiment_config = vars(training_args)
228 | # TensorBoard cannot log Enums, need the raw value
229 | # experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
230 | accelerator.init_trackers("winogrande_no_trainer", {})
231 |
232 | metric = evaluate.load("accuracy")
233 | max_train_steps = int(max_train_steps)
234 |
235 | # Train!
236 | total_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
237 |
238 | print("***** Running training *****")
239 | print(f" Num examples = {len(tokenized_train_data)}")
240 | print(f" Num Epochs = {training_args.num_train_epochs}")
241 | print(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
242 | print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
243 | print(f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}")
244 | print(f" Total optimization steps = {max_train_steps}")
245 | # Only show the progress bar once on each machine.
246 | progress_bar = tqdm(range(max_train_steps))
247 | completed_steps = 0
248 | starting_epoch = 0
249 | # Potentially load in the weights and states from a previous save
250 | if training_args.resume_from_checkpoint:
251 | if training_args.resume_from_checkpoint is not None or training_args.resume_from_checkpoint != "":
252 | accelerator.print(f"Resumed from checkpoint: {training_args.resume_from_checkpoint}")
253 | accelerator.load_state(training_args.resume_from_checkpoint)
254 | path = os.path.basename(training_args.resume_from_checkpoint)
255 | else:
256 | # Get the most recent checkpoint
257 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
258 | dirs.sort(key=os.path.getctime)
259 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
260 | # Extract `epoch_{i}` or `step_{i}`
261 | training_difference = os.path.splitext(path)[0]
262 |
263 | if "epoch" in training_difference:
264 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
265 | resume_step = None
266 | else:
267 | resume_step = int(training_difference.replace("step_", ""))
268 | starting_epoch = resume_step // len(train_dataloader)
269 | resume_step -= starting_epoch * len(train_dataloader)
270 | completed_steps = starting_epoch * len(train_dataloader) + resume_step
271 | progress_bar.update(completed_steps)
272 |
273 |     performance_dict = {}
274 | for epoch in range(starting_epoch, training_args.num_train_epochs):
275 | model.train()
276 | if training_args.with_tracking:
277 | total_loss = 0
278 | # if training_args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
279 | # # We skip the first `n` batches in the dataloader when resuming from a checkpoint
280 | # train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
281 |
282 | for step, batch in enumerate(train_dataloader):
283 | # print(batch["attention_mask"])
284 | # # We need to skip steps until we reach the resumed step
285 | if training_args.resume_from_checkpoint and epoch == starting_epoch:
286 | if resume_step is not None and step < resume_step:
287 | # completed_steps += 1
288 | continue
289 |
290 | # print(batch["labels"], batch["labels"].shape)
291 | with accelerator.accumulate(model):
292 | # print(batch)
293 | # return
294 | outputs = model(**batch)
295 | loss = outputs.loss
296 | # We keep track of the loss at each epoch
297 | if training_args.with_tracking:
298 | total_loss += loss.detach().float()
299 | accelerator.backward(loss)
300 |                 if completed_steps % 50 == 0:
301 | accelerator.print(f"Epoch: {epoch} | Step: {completed_steps} | Loss: {loss}")
302 | optimizer.step()
303 | lr_scheduler.step()
304 | optimizer.zero_grad()
305 |
306 | # Checks if the accelerator has performed an optimization step behind the scenes
307 | if accelerator.sync_gradients:
308 | progress_bar.update(1)
309 | completed_steps += 1
310 |
311 | if isinstance(save_steps, int):
312 | if completed_steps % save_steps == 0:
313 | output_dir = f"step_{completed_steps}"
314 | if output_dir is not None:
315 | output_dir = os.path.join(training_args.output_dir, output_dir)
316 | accelerator.save_state(output_dir)
317 |
318 | if completed_steps % 200 == 0:
319 | model.eval()
320 | eval_metric = run_eval(eval_dataloader, model, accelerator, metric)
321 |
322 | # for k, v in eval_metric.items():
323 | # writer.add_scalar(f"eval/{args.output_dir}/{k}", v, global_step=completed_steps)
324 | logger.info(
325 | f"seed {training_args.seed} learning rate {training_args.learning_rate} "
326 | + f"epoch {epoch}: {eval_metric}")
327 |                     performance_dict[completed_steps] = eval_metric["accuracy"]
328 |
329 | if training_args.with_tracking and total_loss != 0:
330 | accelerator.log(
331 | {
332 | "accuracy": eval_metric,
333 | "train_loss": total_loss.item() / len(train_dataloader),
334 | "epoch": epoch,
335 | "step": completed_steps,
336 | },
337 | step=completed_steps,
338 | )
339 |
340 | if completed_steps >= max_train_steps:
341 | break
342 |
343 |             if completed_steps % 500 == 0 and step % training_args.gradient_accumulation_steps == 0:
344 | logger.info(f"The current loss is {loss}")
345 |
346 |
347 | if training_args.save_steps == "epoch":
348 | output_dir = f"epoch_{epoch}"
349 | if output_dir is not None:
350 | output_dir = os.path.join(training_args.output_dir, output_dir)
351 |
352 | accelerator.save_state(output_dir)
353 |
354 | accelerator.wait_for_everyone()
355 | if accelerator.is_main_process:
356 | from safetensors.torch import save_file
357 | save_file(accelerator.get_state_dict(model), output_dir + "/model.safetensors")
358 | print("Saved checkpoint in ", output_dir + "/model.safetensors")
359 | accelerator.wait_for_everyone()
360 |
361 | model.eval()
362 | eval_metric = run_eval(eval_dataloader, model, accelerator, metric)
363 |
364 | # for k, v in eval_metric.items():
365 | # writer.add_scalar(f"eval/{args.output_dir}/{k}", v, global_step=epoch)
366 | logger.info(f"{training_args.output_dir} | epoch {epoch}: {eval_metric}")
367 | if training_args.with_tracking and total_loss != 0:
368 | accelerator.log(
369 | {
370 | "accuracy": eval_metric,
371 | "train_loss": total_loss.item() / len(train_dataloader),
372 | "epoch": epoch,
373 | "step": completed_steps,
374 | },
375 | step=completed_steps,
376 | )
377 |         performance_dict[epoch] = eval_metric["accuracy"]
378 |
379 | # torch.save(model.state_dict(), args.output_dir + f"/{completed_steps}.bin")
380 |
381 | model.eval()
382 | eval_metric = run_eval(eval_dataloader, model, accelerator, metric)
383 |
384 | # for k, v in eval_metric.items():
385 | # writer.add_scalar(f"eval/{args.output_dir}/{k}", v, global_step=completed_steps)
386 | print(f"{training_args.output_dir} | step {completed_steps}: {eval_metric}")
387 |     if performance_dict:
388 |         best_performance = max(performance_dict.values())
389 |         max_keys = [k for k, v in performance_dict.items() if
390 |                     v == best_performance]  # all keys attaining the maximum
391 |         print(f"seed {training_args.seed} learning rate {training_args.learning_rate} "
392 |               + f"the best performance is at {max_keys[0]} with {best_performance}")
393 |
394 |
395 | if __name__ == "__main__":
396 | train()
397 |
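398 | # --- Illustrative note (not part of the original script) --------------------
399 | # The `preprocess`/`tokenize` helpers above turn each WinoGrande row into a
400 | # sentence-pair classification example. Schematically, a row such as
401 | #   sentence = "The trophy doesn't fit in the suitcase because _ is too big."
402 | #   option1  = "the trophy", option2 = "the suitcase", answer = "1"
403 | # becomes
404 | #   text1 = "... because the trophy is too big."
405 | #   text2 = "... because the suitcase is too big."
406 | #   label = 0   # answer "1" -> class 0, answer "2" -> class 1
407 | # and the tokenizer encodes the (text1, text2) pair with left-side padding
408 | # and truncation to `max_seq_length`.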
--------------------------------------------------------------------------------
/scripts/finetune_glue.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Finetuning a 🤗 Transformers model for sequence classification on GLUE."""
16 | import argparse
17 | import json
18 | import logging
19 | import math
20 | import os
21 | import random
22 | import numpy as np
23 |
24 |
25 | import datasets
26 | import evaluate
27 | import torch
28 | from accelerate import Accelerator
29 | from accelerate.logging import get_logger
30 | from accelerate.utils import set_seed
31 |
32 | from datasets import load_dataset
33 | from torch.utils.data import DataLoader
34 | from tqdm.auto import tqdm
35 |
36 | import transformers
37 | from transformers import (
38 | AutoTokenizer,
39 | AutoConfig,
40 | DataCollatorWithPadding,
41 | PretrainedConfig,
42 | SchedulerType,
43 | default_data_collator,
44 | get_scheduler,
45 | )
46 | from quantize_save_llama import load_quantized_model
47 |
48 | logger = get_logger(__name__)
49 |
50 | task_to_keys = {
51 | "cola": ("sentence", None),
52 | "mnli": ("premise", "hypothesis"),
53 | "mrpc": ("sentence1", "sentence2"),
54 | "qnli": ("question", "sentence"),
55 | "qqp": ("question1", "question2"),
56 | "rte": ("sentence1", "sentence2"),
57 | "sst2": ("sentence", None),
58 | "stsb": ("sentence1", "sentence2"),
59 | "wnli": ("sentence1", "sentence2"),
60 | }
61 |
62 | task_to_metrics = {
63 | "cola": "matthews_correlation",
64 | "mnli": "accuracy",
65 | "mrpc": "f1",
66 | "qnli": "accuracy",
67 | "qqp": "f1",
68 | "rte": "accuracy",
69 | "sst2": "accuracy",
70 | "stsb": "pearson",
71 | }
72 |
73 | device = torch.device(0)
74 | DEBUG = False
75 |
76 | def set_seed(seed: int):
77 | random.seed(seed)
78 | np.random.seed(seed)
79 | torch.manual_seed(seed)
80 | torch.cuda.manual_seed_all(seed)
81 |
82 |
83 | def parse_args():
84 | parser = argparse.ArgumentParser(description="Finetune a transformers model on a text classification task")
85 | parser.add_argument(
86 | "--task_name",
87 | type=str,
88 | default=None,
89 | help="The name of the glue task to train on.",
90 | choices=list(task_to_keys.keys()),
91 | )
92 | parser.add_argument(
93 | "--train_file", type=str, default=None, help="A csv or a json file containing the training data."
94 | )
95 | parser.add_argument(
96 | "--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
97 | )
98 | parser.add_argument(
99 | "--max_length",
100 | type=int,
101 | default=128,
102 | help=(
103 | "The maximum total input sequence length after tokenization. Sequences longer than this will be truncated,"
104 | " sequences shorter will be padded if `--pad_to_max_length` is passed."
105 | ),
106 | )
107 | parser.add_argument(
108 | "--pad_to_max_length",
109 | action="store_true",
110 | help="If passed, pad all samples to `max_length`. Otherwise, dynamic padding is used.",
111 | )
112 | parser.add_argument(
113 | "--model_name_or_path",
114 | type=str,
115 | help="Path to pretrained model or model identifier from huggingface.co/models.",
116 | required=True,
117 | )
118 | parser.add_argument(
119 | "--base_model",
120 | type=str,
121 | help="Huggingface identifier of original model",
122 | required=True
123 | )
124 | parser.add_argument(
125 | "--tokenizer_name",
126 | type=str,
127 | default=None,
128 | help="Pretrained tokenizer name or path if not the same as model_name",
129 | )
130 | parser.add_argument(
131 | "--use_slow_tokenizer",
132 | action="store_true",
133 | help="If passed, will use a slow tokenizer (not backed by the 🤗 Tokenizers library).",
134 | )
135 | parser.add_argument(
136 | "--per_device_train_batch_size",
137 | type=int,
138 | default=8,
139 | help="Batch size (per device) for the training dataloader.",
140 | )
141 | parser.add_argument(
142 | "--per_device_eval_batch_size",
143 | type=int,
144 | default=8,
145 | help="Batch size (per device) for the evaluation dataloader.",
146 | )
147 | parser.add_argument(
148 | "--learning_rate",
149 | type=float,
150 | default=5e-5,
151 | help="Initial learning rate (after the potential warmup period) to use.",
152 | )
153 | parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
154 | parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
155 | parser.add_argument(
156 | "--max_train_steps",
157 | type=int,
158 | default=None,
159 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
160 | )
161 | parser.add_argument(
162 | "--gradient_accumulation_steps",
163 | type=int,
164 | default=1,
165 | help="Number of updates steps to accumulate before performing a backward/update pass.",
166 | )
167 | parser.add_argument(
168 | "--lr_scheduler_type",
169 | type=SchedulerType,
170 | default="linear",
171 | help="The scheduler type to use.",
172 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
173 | )
174 | parser.add_argument(
175 | "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
176 | )
177 | parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
178 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
179 | parser.add_argument(
180 | "--trust_remote_code",
181 | type=bool,
182 | default=False,
183 | help=(
184 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
185 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
186 | "execute code present on the Hub on your local machine."
187 | ),
188 | )
189 | parser.add_argument(
190 | "--checkpointing_steps",
191 | type=str,
192 | default=None,
193 | help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
194 | )
195 | parser.add_argument(
196 | "--resume_from_checkpoint",
197 | type=str,
198 | default=None,
199 | help="If the training should continue from a checkpoint folder.",
200 | )
201 | parser.add_argument(
202 | "--with_tracking",
203 | action="store_true",
204 | help="Whether to enable experiment trackers for logging.",
205 | )
206 | parser.add_argument(
207 | "--report_to",
208 | type=str,
209 | default="all",
210 | help=(
211 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
212 |             ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations. '
213 |             "Only applicable when `--with_tracking` is passed."
214 | ),
215 | )
216 | parser.add_argument(
217 | "--config_name",
218 | type=str,
219 | default=None,
220 | help="Pretrained config name or path if not the same as model_name",
221 | )
222 |
223 | args = parser.parse_args()
224 |
225 | # Sanity checks
226 | if args.task_name is None and args.train_file is None and args.validation_file is None:
227 | raise ValueError("Need either a task name or a training/validation file.")
228 | else:
229 | if args.train_file is not None:
230 | extension = args.train_file.split(".")[-1]
231 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
232 | if args.validation_file is not None:
233 | extension = args.validation_file.split(".")[-1]
234 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
235 |
236 | return args
237 |
238 |
239 | def run_eval(eval_dataloader, model, accelerator, is_regression, metric):
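    |     """Evaluate `model` on `eval_dataloader` under bf16 autocast, dropping the duplicated
    |     samples that distributed evaluation adds to the last batch, and return `metric.compute()`."""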
240 | samples_seen = 0
241 |
242 | for step, batch in enumerate(eval_dataloader):
243 | with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
244 | with torch.no_grad():
245 | predictions = model(**batch)
246 |
247 | references = batch["labels"]
248 | predictions, references = accelerator.gather_for_metrics((predictions, references))
249 |
250 | if predictions.logits.isnan().any():
251 | print("Warning: some of the output logits for evaluation were NaN!")
252 |
253 | predictions = predictions.logits.argmax(dim=-1) if not is_regression else predictions.logits.squeeze()
254 | # If we are in a multiprocess environment, the last batch has duplicates
255 | if step == len(eval_dataloader) - 1:
256 | predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
257 | references = references[: len(eval_dataloader.dataset) - samples_seen]
258 | else:
259 | samples_seen += references.shape[0]
260 | metric.add_batch(
261 | predictions=predictions,
262 | references=references,
263 | )
264 | return metric.compute()
265 |
266 |
267 | def main():
268 | args = parse_args()
269 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
270 | # information sent is the one passed as arguments along with your Python/PyTorch versions.
271 | # send_example_telemetry("run_glue_no_trainer", args)
272 |
273 |     transformers.utils.send_example_telemetry("run_glue_no_trainer", args)
274 |
275 | # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
276 | # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers
277 | # in the environment
278 | accelerator_log_kwargs = {}
279 |
280 | if args.with_tracking:
281 | accelerator_log_kwargs["log_with"] = args.report_to
282 | accelerator_log_kwargs["project_dir"] = args.output_dir
283 |
284 | accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, **accelerator_log_kwargs)
285 |
286 | # Make one log on every process with the configuration for debugging.
287 | logging.basicConfig(
288 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
289 | datefmt="%m/%d/%Y %H:%M:%S",
290 | level=logging.INFO,
291 | )
292 |
293 | logger.info(accelerator.state, main_process_only=False)
294 | if accelerator.is_local_main_process:
295 | datasets.utils.logging.set_verbosity_warning()
296 | transformers.utils.logging.set_verbosity_info()
297 | else:
298 | datasets.utils.logging.set_verbosity_error()
299 | transformers.utils.logging.set_verbosity_error()
300 |
301 | # If passed along, set the training seed now.
302 | if accelerator.is_main_process:
303 | if args.seed is not None:
304 | set_seed(args.seed)
305 |
306 | if args.output_dir is not None:
307 | os.makedirs(args.output_dir, exist_ok=True)
308 | # writer = SummaryWriter(args.output_dir)
309 | accelerator.wait_for_everyone()
310 |
311 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
312 | # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
313 |
314 | # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
315 | # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
316 | # label if at least two columns are provided.
317 |
318 | # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
319 | # single column. You can easily tweak this behavior (see below)
320 |
321 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
322 | # download the dataset.
323 | if args.task_name is not None:
324 | # Downloading and loading a dataset from the hub.
325 | raw_datasets = load_dataset("glue", args.task_name)
326 | else:
327 | # Loading the dataset from local csv or json file.
328 | data_files = {}
329 | if args.train_file is not None:
330 | data_files["train"] = args.train_file
331 | if args.validation_file is not None:
332 | data_files["validation"] = args.validation_file
333 | extension = (args.train_file if args.train_file is not None else args.validation_file).split(".")[-1]
334 | raw_datasets = load_dataset(extension, data_files=data_files)
335 | # See more about loading any type of standard or custom dataset at
336 | # https://huggingface.co/docs/datasets/loading_datasets.html.
337 |
338 | # Labels
339 | if args.task_name is not None:
340 | is_regression = args.task_name == "stsb"
341 | if not is_regression:
342 | label_list = raw_datasets["train"].features["label"].names
343 | num_labels = len(label_list)
344 | else:
345 | num_labels = 1
346 | else:
347 | # Trying to have good defaults here, don't hesitate to tweak to your needs.
348 | is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
349 | if is_regression:
350 | num_labels = 1
351 | else:
352 | # A useful fast method:
353 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
354 | label_list = raw_datasets["train"].unique("label")
355 | label_list.sort() # Let's sort it for determinism
356 | num_labels = len(label_list)
357 |
358 | # Load pretrained model and tokenizer
359 | #
360 | # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
361 | # download model & vocab.
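    |     # `model_name_or_path` points to the quantized CALDERA checkpoint, so the config is taken
    |     # from `--config_name` if given, and otherwise from the original (base) model.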
362 | if args.config_name:
363 | config = AutoConfig.from_pretrained(
364 | args.config_name,
365 | trust_remote_code=args.trust_remote_code,
366 | )
367 | elif args.model_name_or_path:
368 | config = AutoConfig.from_pretrained(
369 | args.base_model,
370 | trust_remote_code=args.trust_remote_code,
371 | )
372 |
373 | # Model and tokenizer
374 | model = load_quantized_model(
375 | args.model_name_or_path, args.base_model,
376 | accelerator.device, sequence_classification=True,
377 | ).to(torch.bfloat16)
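    |     # Freeze the SU/SV parameters (the quantizer's random sign vectors for the randomized
    |     # Hadamard transform) so that only the remaining trainable weights are fine-tuned.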
378 | for name, param in model.named_parameters():
379 | if 'SU' in name or 'SV' in name:
380 | param.requires_grad = False
381 |
382 | if args.tokenizer_name:
383 | tokenizer = AutoTokenizer.from_pretrained(
384 | args.tokenizer_name, use_fast=not args.use_slow_tokenizer, trust_remote_code=args.trust_remote_code
385 | )
386 | elif args.base_model:
387 | tokenizer = AutoTokenizer.from_pretrained(
388 | args.base_model,
389 | use_fast=not args.use_slow_tokenizer,
390 | )
391 | else:
392 | raise ValueError(
393 | "You are instantiating a new tokenizer from scratch. This is not supported by this script."
394 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
395 | )
396 | model.config.pad_token_id = 0
397 | tokenizer.pad_token_id = 0 # unk. we want this to be different from the eos token
398 | tokenizer.padding_side = "left" # Allow batched inference
399 | tokenizer.truncation_side = "left"
400 |
401 | embedding_size = model.get_input_embeddings().weight.shape[0]
402 | if len(tokenizer) > embedding_size:
403 | model.resize_token_embeddings(len(tokenizer))
404 |
405 | model = model.to(device)
406 |
407 | # Preprocessing the datasets
408 | if args.task_name is not None:
409 | sentence1_key, sentence2_key = task_to_keys[args.task_name]
410 | else:
411 | # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
412 | non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
413 | if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
414 | sentence1_key, sentence2_key = "sentence1", "sentence2"
415 | else:
416 | if len(non_label_column_names) >= 2:
417 | sentence1_key, sentence2_key = non_label_column_names[:2]
418 | else:
419 | sentence1_key, sentence2_key = non_label_column_names[0], None
420 |
421 | # Some models have set the order of the labels to use, so let's make sure we do use it.
422 | label_to_id = None
423 | if (
424 | model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
425 | and args.task_name is not None
426 | and not is_regression
427 | ):
428 | # Some have all caps in their config, some don't.
429 | label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
430 | if sorted(label_name_to_id.keys()) == sorted(label_list):
431 | print(
432 | f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
433 | "Using it!"
434 | )
435 | label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
436 | else:
437 | print(
438 | "Your model seems to have been trained with labels, but they don't match the dataset: ",
439 | f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
440 | "\nIgnoring the model labels as a result.",
441 | )
442 | elif args.task_name is None and not is_regression:
443 | label_to_id = {v: i for i, v in enumerate(label_list)}
444 |
445 | if label_to_id is not None:
446 | model.config.label2id = label_to_id
447 |         model.config.id2label = {id: label for label, id in model.config.label2id.items()}
448 |     elif args.task_name is not None and not is_regression:
449 |         model.config.label2id = {l: i for i, l in enumerate(label_list)}
450 |         model.config.id2label = {id: label for label, id in model.config.label2id.items()}
451 |
452 | padding = "max_length" if args.pad_to_max_length else False
453 |
454 | def preprocess_function(examples):
455 | # Tokenize the texts
456 | texts = (
457 | (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
458 | )
459 | result = tokenizer(*texts, padding=padding, max_length=args.max_length, truncation=True)
460 |
461 | if "label" in examples:
462 | if label_to_id is not None:
463 | # Map labels to IDs (not necessary for GLUE tasks)
464 | result["labels"] = [label_to_id[l] for l in examples["label"]]
465 | else:
466 | # In all cases, rename the column to labels because the model will expect that.
467 | result["labels"] = examples["label"]
468 |
469 | return result
470 |
471 | with accelerator.main_process_first():
472 | processed_datasets = raw_datasets.map(
473 | preprocess_function,
474 | batched=True,
475 | remove_columns=raw_datasets["train"].column_names,
476 | desc="Running tokenizer on dataset",
477 | )
478 |
479 | train_dataset = processed_datasets["train"]
480 | eval_dataset = processed_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
481 |
482 | # DataLoaders creation:
483 | if args.pad_to_max_length:
484 |         # If padding was already done to max length, we use the default data collator that will just convert everything
485 | # to tensors.
486 | data_collator = default_data_collator
487 | else:
488 | # Otherwise, `DataCollatorWithPadding` will apply dynamic padding for us (by padding to the maximum length of
489 |         # the samples passed). Padding to a multiple of 8 would enable Tensor Cores on NVIDIA hardware with compute
490 |         # capability >= 7.5 (Volta), but only dynamic padding is used here (`pad_to_multiple_of=None`).
491 |         data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=None)
492 |
493 | train_dataloader = DataLoader(
494 | train_dataset, shuffle=True, collate_fn=data_collator, batch_size=args.per_device_train_batch_size
495 | )
496 | eval_dataloader = DataLoader(eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size)
497 |
498 | # Optimizer
499 | # Split weights in two groups, one with weight decay and the other not.
500 | no_decay = ["bias", "LayerNorm.weight"]
501 | optimizer_grouped_parameters = [
502 | {
503 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
504 | "weight_decay": args.weight_decay,
505 | },
506 | {
507 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
508 | "weight_decay": 0.0,
509 | },
510 | ]
511 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
512 |
513 | # optimizer = torch.optim.SGD(optimizer_grouped_parameters, lr=args.learning_rate)
514 | # Scheduler and math around the number of training steps.
515 | overrode_max_train_steps = False
516 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
517 | if args.max_train_steps is None:
518 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
519 | overrode_max_train_steps = True
520 |
521 | lr_scheduler = get_scheduler(
522 | name=args.lr_scheduler_type,
523 | optimizer=optimizer,
524 | num_warmup_steps=args.num_warmup_steps,
525 | num_training_steps=args.max_train_steps,
526 | )
527 |
528 | # Prepare everything with our `accelerator`.
529 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
530 | model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
531 | )
532 |
533 | # We need to recalculate our total training steps as the size of the training dataloader may have changed
534 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
535 | if overrode_max_train_steps:
536 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
537 | # Afterwards we recalculate our number of training epochs
538 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
539 |
540 | # Figure out how many steps we should save the Accelerator states
541 | checkpointing_steps = args.checkpointing_steps
542 | if checkpointing_steps is not None and checkpointing_steps.isdigit():
543 | checkpointing_steps = int(checkpointing_steps)
544 |
545 | # We need to initialize the trackers we use, and also store our configuration.
546 | # The trackers initializes automatically on the main process.
547 | if args.with_tracking:
548 | experiment_config = vars(args)
549 | # TensorBoard cannot log Enums, need the raw value
550 | experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
551 | accelerator.init_trackers("glue_no_trainer", experiment_config)
552 |
553 | # Get the metric function
554 | if args.task_name is not None:
555 | metric = evaluate.load("glue", args.task_name)
556 | else:
557 | metric = evaluate.load("accuracy")
558 |
559 | # Train!
560 | total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
561 |
562 | print("***** Running training *****")
563 | print(f" Num examples = {len(train_dataset)}")
564 | print(f" Num Epochs = {args.num_train_epochs}")
565 | print(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
566 |     print(f"  Total train batch size (w. accumulation, per process) = {total_batch_size}")
567 | print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
568 | print(f" Total optimization steps = {args.max_train_steps}")
569 | # Only show the progress bar once on each machine.
570 |     progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
571 | completed_steps = 0
572 | starting_epoch = 0
573 | # Potentially load in the weights and states from a previous save
574 | if args.resume_from_checkpoint:
575 |         if args.resume_from_checkpoint is not None and args.resume_from_checkpoint != "":
576 | accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
577 | accelerator.load_state(args.resume_from_checkpoint)
578 | path = os.path.basename(args.resume_from_checkpoint)
579 | else:
580 | # Get the most recent checkpoint
581 | dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
582 | dirs.sort(key=os.path.getctime)
583 | path = dirs[-1] # Sorts folders by date modified, most recent checkpoint is the last
584 | # Extract `epoch_{i}` or `step_{i}`
585 | training_difference = os.path.splitext(path)[0]
586 |
587 | if "epoch" in training_difference:
588 | starting_epoch = int(training_difference.replace("epoch_", "")) + 1
589 | resume_step = None
590 | else:
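    |             # A "step_{i}" checkpoint stores the global step; convert it into an epoch index
    |             # plus an offset within that epoch.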
591 | resume_step = int(training_difference.replace("step_", ""))
592 | starting_epoch = resume_step // len(train_dataloader)
593 | resume_step -= starting_epoch * len(train_dataloader)
594 |
595 |     performance_dict = {}  # task metric recorded at each evaluation (keyed by step or epoch)
596 | for epoch in range(starting_epoch, args.num_train_epochs):
597 | model.train()
598 | if args.with_tracking:
599 | total_loss = 0
600 | if args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
601 | # We skip the first `n` batches in the dataloader when resuming from a checkpoint
602 | train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
603 |
604 | for step, batch in enumerate(train_dataloader):
605 | # print(batch["attention_mask"])
606 |             # We need to skip steps until we reach the resumed step
607 | if args.resume_from_checkpoint and epoch == starting_epoch:
608 | if resume_step is not None and step < resume_step:
609 | completed_steps += 1
610 | continue
611 |
612 | with accelerator.accumulate(model):
613 | # print(batch)
614 | # return
615 | outputs = model(**batch)
616 | loss = outputs.loss
617 | # We keep track of the loss at each epoch
618 | if args.with_tracking:
619 | total_loss += loss.detach().float()
620 | accelerator.backward(loss)
621 |                 if completed_steps % 50 == 0:
622 |                     accelerator.print(f"Epoch: {epoch} | Step: {completed_steps} | Loss: {loss}")
623 | optimizer.step()
624 | lr_scheduler.step()
625 | optimizer.zero_grad()
626 |
627 | # Checks if the accelerator has performed an optimization step behind the scenes
628 | if accelerator.sync_gradients:
629 | progress_bar.update(1)
630 | completed_steps += 1
631 |
632 | if isinstance(checkpointing_steps, int):
633 | if completed_steps % checkpointing_steps == 0:
634 | output_dir = f"step_{completed_steps}"
635 | if args.output_dir is not None:
636 | output_dir = os.path.join(args.output_dir, output_dir)
637 | accelerator.save_state(output_dir)
638 |
639 | if completed_steps >= args.max_train_steps:
640 | break
641 |
642 |             if completed_steps % 500 == 0 and step % args.gradient_accumulation_steps == 0:
643 | logger.info(f"The current loss is {loss}")
644 |
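    |             # Periodic evaluation: roughly every 5000 * 32 training examples, i.e. 5000 steps at a
    |             # reference batch size of 32, regardless of the actual per-step batch size.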
645 |             if completed_steps % max(1, int(5000 * 32 / total_batch_size)) == 0 and step % args.gradient_accumulation_steps == 0:
646 | model.eval()
647 | eval_metric = run_eval(eval_dataloader, model, accelerator, is_regression, metric)
648 |
649 | # for k, v in eval_metric.items():
650 | # writer.add_scalar(f"eval/{args.output_dir}/{k}", v, global_step=completed_steps)
651 | logger.info(
652 | f"seed {args.seed} learning rate {args.learning_rate} "
653 | + f"epoch {epoch}: {eval_metric}")
654 |                 performance_dict[completed_steps] = eval_metric[task_to_metrics[args.task_name]]
655 |
656 |                 torch.save(model.state_dict(), args.output_dir + f"/{completed_steps}.bin")
657 |
658 | if args.with_tracking:
659 | accelerator.log(
660 | {
661 | "accuracy" if args.task_name is not None else "glue": eval_metric,
662 | "train_loss": total_loss.item() / len(train_dataloader),
663 | "epoch": epoch,
664 | "step": completed_steps,
665 | },
666 | step=completed_steps,
667 | )
668 | if args.checkpointing_steps == "epoch":
669 | output_dir = f"epoch_{epoch}"
670 | if args.output_dir is not None:
671 | output_dir = os.path.join(args.output_dir, output_dir)
672 |
673 | accelerator.save_state(output_dir)
674 |
675 | accelerator.wait_for_everyone()
676 | if accelerator.is_main_process:
677 | from safetensors.torch import save_file
678 | save_file(accelerator.get_state_dict(model), output_dir + "/model.safetensors")
679 | print("Saved checkpoint in ", output_dir + "/model.safetensors")
680 | accelerator.wait_for_everyone()
681 |
682 | model.eval()
683 | eval_metric = run_eval(eval_dataloader, model, accelerator, is_regression, metric)
684 | logger.info(f"{args.output_dir} | epoch {epoch}: {eval_metric}")
685 | if args.with_tracking:
686 | accelerator.log(
687 | {
688 | "accuracy" if args.task_name is not None else "glue": eval_metric,
689 | "train_loss": total_loss.item() / len(train_dataloader),
690 | "epoch": epoch,
691 | "step": completed_steps,
692 | },
693 | step=completed_steps,
694 | )
695 |         performance_dict[epoch] = eval_metric[task_to_metrics[args.task_name]]
696 |
697 | torch.save(model.state_dict(), args.output_dir + f"/{completed_steps}.bin")
698 |
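    |     # Final evaluation on the validation split after training completes.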
699 | model.eval()
700 | eval_metric = run_eval(eval_dataloader, model, accelerator, is_regression, metric)
701 |
702 | print(f"{args.output_dir} | step {completed_steps}: {eval_metric}")
703 |     if performance_dict:  # report the best metric recorded during training, if any
704 |         best_performance = max(performance_dict.values())
705 |         max_keys = [k for k, v in performance_dict.items() if
706 |                     v == best_performance]  # all steps/epochs that achieved the maximum
707 |         print(f"seed {args.seed} learning rate {args.learning_rate} "
708 |               + f"The best performance is at {max_keys[0]} with {best_performance}")
709 |
710 | if args.task_name == "mnli":
711 | # Final evaluation on mismatched validation set
712 | eval_dataset = processed_datasets["validation_mismatched"]
713 | eval_dataloader = DataLoader(
714 | eval_dataset, collate_fn=data_collator, batch_size=args.per_device_eval_batch_size
715 | )
716 |
717 | eval_metric = run_eval(eval_dataloader, model, accelerator, is_regression, metric)
718 | print(f"{args.output_dir}|mnli-mm: {eval_metric}")
719 |
720 | if args.output_dir is not None:
721 | all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
722 | with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
723 | json.dump(all_results, f)
724 |
725 | if __name__ == "__main__":
726 | main()
727 |
--------------------------------------------------------------------------------