├── utils
│   ├── __init__.py
│   ├── mask.py
│   ├── structure_arb.py
│   └── autosearch_arb.py
├── figs
│   ├── qa.png
│   ├── teaser.png
│   ├── overview.png
│   ├── wikitext2_opt.png
│   ├── wikitext2_llama.png
│   └── wikitext2_vicuna.png
├── requirements.txt
├── modelutils.py
├── datautils.py
├── README.md
├── eval_ppl_utils.py
├── bigptq_arb.py
├── LICENSE
├── run_arb.py
└── binary_arb.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/figs/qa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZHITENGLI/ARB-LLM/HEAD/figs/qa.png
--------------------------------------------------------------------------------
/figs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZHITENGLI/ARB-LLM/HEAD/figs/teaser.png
--------------------------------------------------------------------------------
/figs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZHITENGLI/ARB-LLM/HEAD/figs/overview.png
--------------------------------------------------------------------------------
/figs/wikitext2_opt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZHITENGLI/ARB-LLM/HEAD/figs/wikitext2_opt.png
--------------------------------------------------------------------------------
/figs/wikitext2_llama.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZHITENGLI/ARB-LLM/HEAD/figs/wikitext2_llama.png
--------------------------------------------------------------------------------
/figs/wikitext2_vicuna.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZHITENGLI/ARB-LLM/HEAD/figs/wikitext2_vicuna.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.35.0
2 | datasets==2.14.6
3 | numpy==1.24.3
4 | huggingface-hub==0.16.4
5 | protobuf
6 | sentencepiece
7 | charset-normalizer==2.0.4
8 | pyarrow==12.0.0
9 |
--------------------------------------------------------------------------------
/modelutils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | DEV = torch.device('cuda:0')
6 |
7 |
8 | def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
9 | if type(module) in layers:
10 | return {name: module}
11 | res = {}
12 | for name1, child in module.named_children():
13 | res.update(find_layers(
14 | child, layers=layers, name=name + '.' + name1 if name != '' else name1
15 | ))
16 | return res
17 |
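18 | # A minimal usage sketch (hedged illustration, not part of the pipeline):
19 | # find_layers walks named_children recursively and returns a flat dict that
20 | # maps dotted module paths to the matching Conv2d/Linear layers.
21 | if __name__ == '__main__':
22 |     toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
23 |     print(find_layers(toy))  # {'0': Linear(...), '2': Linear(...)}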
--------------------------------------------------------------------------------
/utils/mask.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | '''
5 | Generate the structural masks on the basis of the split borders
6 | '''
7 | def generate_structural_mask(origin_matrix, mask3, braq1_border):
8 | mask1_2 = ~mask3
9 |
10 | binary_group = torch.abs(origin_matrix*mask1_2)
11 |
12 | mask2 = binary_group >= braq1_border
13 | mask1 = binary_group < braq1_border
14 |
15 | mask1 = mask1 * mask1_2
16 | mask2 = mask2 * mask1_2
17 |
18 | return mask1, mask2
19 |
20 | def generate_multi_structural_mask(origin_matrix, mask3, braq1_border, braq2_border, braq3_border):
21 | mask1_2 = ~mask3
22 |
23 | binary_group = torch.abs(origin_matrix*mask1_2)
24 |
25 | mask4 = binary_group >= braq3_border
26 | mask1 = binary_group < braq1_border
27 | mask2 = (binary_group >= braq1_border) & (binary_group < braq2_border)
28 | mask3 = (binary_group >= braq2_border) & (binary_group < braq3_border)
29 |
30 | mask1 = mask1 * mask1_2
31 | mask2 = mask2 * mask1_2
32 | mask3 = mask3 * mask1_2
33 |     mask4 = mask4 * mask1_2
34 |
35 | return mask1, mask2, mask3, mask4
36 |
37 |
38 | def generate_mask(origin_matrix, braq2_border, braq1_border):
39 | mask3 = torch.abs(origin_matrix) >= braq2_border
40 | mask1 = torch.abs(origin_matrix) <= braq1_border
41 | mask2 = (torch.abs(origin_matrix) > braq1_border) & (torch.abs(origin_matrix) < braq2_border)
42 | return mask1, mask2, mask3
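43 |
44 | # A minimal sketch of the two-threshold split above (toy values assumed):
45 | # |w| >= braq2_border -> salient group (mask3), |w| <= braq1_border -> low
46 | # group (mask1), everything in between -> middle group (mask2).
47 | if __name__ == '__main__':
48 |     w = torch.tensor([[-2.0, -0.5, 0.1, 0.7, 3.0]])
49 |     m1, m2, m3 = generate_mask(w, braq2_border=1.0, braq1_border=0.3)
50 |     print(m1)  # [[False, False, True, False, False]]
51 |     print(m2)  # [[False, True, False, True, False]]
52 |     print(m3)  # [[True, False, False, False, True]]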
--------------------------------------------------------------------------------
/utils/structure_arb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.autosearch_arb import structural_searching_multip, structural_searching_multip_alternating_group, structural_searching_multip_alternating_group_x, structural_searching_multip_alternating_group_rc
3 |
4 | import logging
5 | logger = logging.getLogger()
6 |
7 | '''
8 | Generate masks for the minor structural 2-bit salient data and split the major 1-bit normal data according to the chosen metric.
9 | '''
10 | def structural_guassian_distribution_multip_alternating_group_x(tmp, H=None, metric="magnitude", up_lim=30, num_p=1, inp=None, method='arb', order2_group=False):
11 | if metric == "hessian":
12 | target_weights = tmp ** 2 / (torch.diag(H).reshape((1, -1))) ** 2
13 | elif metric == "magnitude":
14 | target_weights = tmp
15 | else:
16 | raise NotImplementedError
17 |
18 | # print(f'debug', inp)
19 | if method == 'arb':
20 | optimal_split_list, mask_list = structural_searching_multip_alternating_group(target_weights, up_lim, num_p, inp, order2_group=order2_group)
21 | elif method == 'arb-x':
22 | optimal_split_list, mask_list = structural_searching_multip_alternating_group_x(target_weights, up_lim, num_p, inp, order2_group=order2_group)
23 | elif method == 'arb-rc':
24 | optimal_split_list, mask_list = structural_searching_multip_alternating_group_rc(target_weights, up_lim, num_p, inp, order2_group=order2_group)
25 | elif method == 'braq':
26 | optimal_split_list, mask_list = structural_searching_multip(target_weights, up_lim, num_p, order2_group=order2_group)
27 |
28 | # print(mask1.sum() / mask1.numel(), mask2.sum() / mask2.numel(), mask3.sum() / mask3.numel())
29 | mask_ratio = []
30 | for i in range(len(mask_list)):
31 | mask_ratio.append(mask_list[i].sum() / mask_list[i].numel())
32 |
33 | ratios_info = ", ".join([f"mask{idx+1} ratio: {ratio:.2f}" for idx, ratio in enumerate(mask_ratio)])
34 | logger.info(ratios_info)
35 |
36 | return mask_list
37 |
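38 | # A small illustration of the saliency metric (assumed toy shapes, hedged):
39 | # with metric="hessian", importance is w**2 / diag(H)**2 broadcast per column,
40 | # so weights whose Hessian diagonal is small rank as more salient.
41 | # W = torch.randn(8, 8); H = torch.eye(8) * 2
42 | # score = W ** 2 / (torch.diag(H).reshape((1, -1))) ** 2  # shape [8, 8]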
--------------------------------------------------------------------------------
/datautils.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import torch
5 | from datasets import load_dataset, load_from_disk
6 | from transformers import AutoTokenizer, LlamaTokenizer
7 | import os
8 |
9 |
10 | def set_seed(seed):
11 | np.random.seed(seed)
12 | torch.random.manual_seed(seed)
13 |
14 | '''
15 | Build the tokenizer and return it so datasets can be pre-encoded into token ids instead of raw text.
16 | '''
17 | def get_tokenizer(model):
18 | if "llama" in model.lower():
19 |         if '3' in model:  # heuristic for LLaMA-3 tokenizers; also matches names like '30b'
20 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
21 | else:
22 | tokenizer = LlamaTokenizer.from_pretrained(model, use_fast=False)
23 |         # fix for transformers 4.28.0.dev0 compatibility
24 | if tokenizer.bos_token_id != 1 or tokenizer.eos_token_id != 2:
25 | try:
26 | tokenizer.bos_token_id = 1
27 | tokenizer.eos_token_id = 2
28 | except AttributeError:
29 | pass
30 | else:
31 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
32 | return tokenizer
33 |
34 | def get_wikitext2(nsamples, seed, seqlen, model, tokenizer):
35 |
36 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
37 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
38 |
39 | # traindata = load_from_disk('/data/dataset/llm/wikitext/traindata')
40 | # testdata = load_from_disk('/data/dataset/llm/wikitext/testdata')
41 |
42 | trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
43 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
44 |
45 | random.seed(seed)
46 | trainloader = []
47 | for _ in range(nsamples):
48 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
49 | j = i + seqlen
50 | inp = trainenc.input_ids[:, i:j]
51 | tar = inp.clone()
52 | tar[:, :-1] = -100
53 | trainloader.append((inp, tar))
54 | return trainloader, testenc
55 |
56 | def get_ptb(nsamples, seed, seqlen, model, tokenizer):
57 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
58 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
59 |
60 | # traindata = load_from_disk('/data/dataset/llm/ptb/traindata')
61 | # testdata = load_from_disk('/data/dataset/llm/ptb/testdata')
62 |
63 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
64 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
65 |
66 | random.seed(seed)
67 | trainloader = []
68 | for _ in range(nsamples):
69 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
70 | j = i + seqlen
71 | inp = trainenc.input_ids[:, i:j]
72 | tar = inp.clone()
73 | tar[:, :-1] = -100
74 | trainloader.append((inp, tar))
75 | return trainloader, testenc
76 |
77 | class TokenizerWrapper:
78 | def __init__(self, input_ids):
79 | self.input_ids = input_ids
80 |
81 | def get_c4(nsamples, seed, seqlen, model, tokenizer):
82 | traindata = load_dataset(
83 | 'allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
84 | )
85 | valdata = load_dataset(
86 | 'allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
87 | )
88 |
89 | # traindata = load_from_disk('/data/dataset/llm/c4/traindata')
90 | # valdata = load_from_disk('/data/dataset/llm/c4/valdata')
91 |
92 | random.seed(seed)
93 | trainloader = []
94 | for _ in range(nsamples):
95 | while True:
96 | i = random.randint(0, len(traindata) - 1)
97 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
98 | if trainenc.input_ids.shape[1] > seqlen:
99 | break
100 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
101 | j = i + seqlen
102 | inp = trainenc.input_ids[:, i:j]
103 | tar = inp.clone()
104 | tar[:, :-1] = -100
105 | trainloader.append((inp, tar))
106 |
107 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
108 | valenc = valenc.input_ids[:, :(256 * seqlen)]
109 |
110 |
111 | valenc = TokenizerWrapper(valenc)
112 |
113 | return trainloader, valenc
114 |
115 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''):
116 |     cache_file = f'cache/{name}_{nsamples}_{seed}_{seqlen}_{model}.pt'
117 |     try:
118 |         return torch.load(cache_file)
119 |     except Exception:
120 |         pass
121 |
122 | tokenizer = get_tokenizer(model)
123 |
124 |     if 'wikitext2' in name:
125 |         loaders = get_wikitext2(nsamples, seed, seqlen, model, tokenizer)
126 |     elif 'ptb' in name:
127 |         loaders = get_ptb(nsamples, seed, seqlen, model, tokenizer)
128 |     elif 'c4' in name:
129 |         loaders = get_c4(nsamples, seed, seqlen, model, tokenizer)
130 |     directory = os.path.dirname(cache_file)
131 |     if not os.path.exists(directory):
132 |         os.makedirs(directory)
133 |
134 |     torch.save(loaders, cache_file)
135 |     return loaders
136 |
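137 | # A minimal usage sketch (hypothetical model id; the first call downloads the
138 | # dataset and caches the encoded loaders under cache/):
139 | # trainloader, testenc = get_loaders(
140 | #     'wikitext2', nsamples=128, seed=0, seqlen=2048, model='facebook/opt-125m'
141 | # )
142 | # trainloader is a list of (input_ids, target) pairs, each of shape [1, seqlen].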
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ICLR'25] ARB-LLM: Alternating Refined Binarizations for Large Language Models
2 |
3 | [Zhiteng Li](https://zhitengli.github.io), Xianglong Yan, Tianao Zhang, [Haotong Qin](https://htqin.github.io/), Dong Xie, Jiang Tian, Zhongchao Shi, [Linghe Kong](https://www.cs.sjtu.edu.cn/~linghe.kong/), [Yulun Zhang](http://yulunzhang.com/), and [Xiaokang Yang](https://scholar.google.com/citations?user=yDEavdMAAAAJ), "ARB-LLM: Alternating Refined Binarizations for Large Language Models", ICLR, 2025
4 |
5 | [[arXiv](https://arxiv.org/pdf/2410.03129)] [[supplementary material](https://github.com/ZHITENGLI/ARB-LLM/releases/tag/v1)]
6 |
7 |
8 |
9 | #### 🔥🔥🔥 News
10 |
11 | - **2025-02-16:** Code is released. ⭐️⭐️⭐️
12 | - **2025-01-23:** ARB-LLM is accepted at ICLR 2025. 🎉🎉🎉
13 | - **2024-10-03:** This repo is released.
14 |
15 | ---
16 |
17 | > **Abstract:** Large Language Models (LLMs) have greatly pushed forward advancements in natural language processing, yet their high memory and computational demands hinder practical deployment. Binarization, as an effective compression technique, can shrink model weights to just 1 bit, significantly reducing the high demands on computation and memory. However, current binarization methods struggle to narrow the distribution gap between binarized and full-precision weights, while also overlooking the column deviation in LLM weight distribution.
18 | > To tackle these issues, we propose ARB-LLM, a novel 1-bit post-training quantization (PTQ) technique tailored for LLMs. To narrow the distribution shift between binarized and full-precision weights, we first design an alternating refined binarization (ARB) algorithm to progressively update the binarization parameters, which significantly reduces the quantization error.
19 | > Moreover, considering the pivotal role of calibration data and the column deviation in LLM weights, we further extend ARB to ARB-X and ARB-RC. In addition, we refine the weight partition strategy with a column-group bitmap (CGB), which further enhances performance.
20 | > Equipping ARB-X and ARB-RC with CGB, we obtain ARB-LLM<sub>X</sub> and ARB-LLM<sub>RC</sub> respectively, which significantly outperform state-of-the-art (SOTA) binarization methods for LLMs.
21 | > As a binary PTQ method, our ARB-LLM<sub>RC</sub> is the first to surpass FP16 models of the same size. The code and models will be available at https://github.com/ZHITENGLI/ARB-LLM.
22 |
23 | 
24 |
25 | ---
26 |
27 | Figure 1 in the main paper demonstrates that our proposed ARB-LLM<sub>RC</sub> outperforms the previous state-of-the-art binary PTQ method, BiLLM, across all scales of the OPT model family. Furthermore, our binarized model surpasses full-precision models of similar size. For example, the memory footprint of the binarized OPT-13B is comparable to that of the full-precision OPT-2.7B, yet the binarized model achieves better performance.
28 |
29 |
30 |
31 |
32 |
33 | ## Dependencies
34 |
35 | ```bash
36 | # Clone the github repo and go to the default directory 'ARB-LLM'.
37 | git clone https://github.com/ZHITENGLI/ARB-LLM.git && cd ARB-LLM
38 | conda create -n arbllm python=3.11
39 | conda activate arbllm
40 | pip install torch torchvision torchaudio
41 | pip install -r requirements.txt
42 | ```
43 |
44 | ## 🔗 Contents
45 |
46 | 1. [Post-training quantization and evaluation](#post-training-quantization-with-ppl-evaluation)
47 | 2. [Results](#-results)
48 | 3. [Citation](#citation)
49 | 4. [Acknowledgements](#-acknowledgements)
50 |
51 | ## Post-training quantization with PPL evaluation
52 |
53 | ### Binarization for OPT families
54 |
55 | - ARB-X
56 | ```shell
57 | python3 run_arb.py facebook/opt-6.7b c4 arb-x --blocksize 128 --salient_metric hessian --device "cuda:0" --save --num_p 1 --order2_group
58 | ```
59 |
60 | - ARB-RC
61 | ```shell
62 | python3 run_arb.py facebook/opt-6.7b c4 arb-rc --blocksize 128 --salient_metric hessian --device "cuda:0" --save --num_p 1 --order2_group
63 | ```
64 |
65 | ### Binarization for LLaMA families
66 |
67 | - ARB-X
68 | ```shell
69 | python3 run_arb.py meta-llama/llama-2-7b-hf c4 arb-x --blocksize 128 --salient_metric hessian --device "cuda:0" --save --num_p 1 --order2_group
70 | ```
71 |
72 | - ARB-RC
73 | ```shell
74 | python3 run_arb.py meta-llama/llama-2-7b-hf c4 arb-rc --blocksize 128 --salient_metric hessian --device "cuda:0" --save --num_p 1 --order2_group
75 | ```
76 |
77 | ### Binarization for Vicuna families (instruction fine-tuned models)
78 |
79 | - ARB-X
80 | ```shell
81 | python3 run_arb.py lmsys/vicuna-7b-v1.5 c4 arb-x --blocksize 128 --salient_metric hessian --device "cuda:0" --save --num_p 1 --order2_group
82 | ```
83 |
84 | - ARB-RC
85 | ```shell
86 | python3 run_arb.py lmsys/vicuna-7b-v1.5 c4 arb-rc --blocksize 128 --salient_metric hessian --device "cuda:0" --save --num_p 1 --order2_group
87 | ```
88 |
89 | ## Evaluation on zero-shot QA datasets
90 |
91 | We use the [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) toolkit to evaluate performance on zero-shot QA datasets. Please refer to their framework for evaluating the quantized models.
92 |
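   | For example, a hypothetical invocation with the harness's `hf` backend (a sketch only; the exact flags may differ across lm-eval versions):
   |
   | ```shell
   | lm_eval --model hf --model_args pretrained=/path/to/quantized_model \
   |     --tasks piqa,winogrande,hellaswag --batch_size 8
   | ```
   |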
93 | ## 🔎 Results
94 |
95 | <details>
96 | <summary>ARB-LLM achieves superior perplexity performance on the WikiText2 dataset. (click to expand)</summary>
97 |
98 | - OPT family
99 |
100 | 
101 |
102 |
103 | - LLaMA, LLaMA-2 and LLaMA-3 families
104 |
105 | 
106 |
107 |
108 | - Vicuna 7B and 13B
109 |
110 | 
111 |
112 | </details>
113 |
114 | <details>
115 | <summary>ARB-LLM achieves superior average accuracy on 7 zero-shot QA datasets. (click to expand)</summary>
116 |
117 | 
118 |
119 | </details>
120 |
121 |
122 |
123 |
124 | ## Citation
125 |
126 | If you find the code helpful in your research or work, please cite the following paper.
127 |
128 | ```
129 | @article{li2024arbllmalternatingrefinedbinarizations,
130 | title={ARB-LLM: Alternating Refined Binarizations for Large Language Models},
131 |   author={Zhiteng Li and Xianglong Yan and Tianao Zhang and Haotong Qin and Dong Xie and Jiang Tian and Zhongchao Shi and Linghe Kong and Yulun Zhang and Xiaokang Yang},
132 | year={2024},
133 | eprint={2410.03129},
134 | archivePrefix={arXiv},
135 | primaryClass={cs.CV},
136 | url={https://arxiv.org/abs/2410.03129},
137 | }
138 | ```
139 |
140 | ## 💡 Acknowledgements
141 |
142 | This work is released under the Apache 2.0 license.
143 | The code is based on [BiLLM](https://github.com/Aaronhuang-778/BiLLM). Please also follow their license. Thanks for their awesome work.
144 |
--------------------------------------------------------------------------------
/eval_ppl_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | import logging
7 | logger = logging.getLogger()
8 |
9 | @torch.no_grad()
10 | def llama_eval(model, testenc, dev, dataset: str, log_wandb: bool = False):
11 | print("Evaluating ...")
12 |
13 | testenc = testenc.input_ids
14 | nsamples = testenc.numel() // model.seqlen
15 |
16 | use_cache = model.config.use_cache
17 | model.config.use_cache = False
18 | layers = model.model.layers
19 |
20 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
21 | layers[0] = layers[0].to(dev)
22 |
23 | dtype = next(iter(model.parameters())).dtype
24 | inps = torch.zeros(
25 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
26 | )
27 | cache = {"i": 0, "attention_mask": None}
28 |
29 | class Catcher(nn.Module):
30 | def __init__(self, module):
31 | super().__init__()
32 | self.module = module
33 |
34 | def forward(self, inp, **kwargs):
35 | inps[cache["i"]] = inp
36 | cache["i"] += 1
37 | cache["attention_mask"] = kwargs["attention_mask"]
38 | raise ValueError
39 |
40 | layers[0] = Catcher(layers[0])
41 | for i in range(nsamples):
42 | batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev)
43 | try:
44 | model(batch)
45 | except ValueError:
46 | pass
47 | layers[0] = layers[0].module
48 |
49 | layers[0] = layers[0].cpu()
50 | model.model.embed_tokens = model.model.embed_tokens.cpu()
51 | torch.cuda.empty_cache()
52 |
53 | outs = torch.zeros_like(inps)
54 | attention_mask = cache["attention_mask"]
55 |
56 | for i in range(len(layers)):
57 | # print(i)
58 | layer = layers[i].to(dev)
59 |
60 | for j in range(nsamples):
61 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
62 | layers[i] = layer.cpu()
63 | del layer
64 | torch.cuda.empty_cache()
65 | inps, outs = outs, inps
66 |
67 | if model.model.norm is not None:
68 | model.model.norm = model.model.norm.to(dev)
69 | model.lm_head = model.lm_head.to(dev)
70 |
71 | testenc = testenc.to(dev)
72 | nlls = []
73 | for i in range(nsamples):
74 | hidden_states = inps[i].unsqueeze(0)
75 | if model.model.norm is not None:
76 | hidden_states = model.model.norm(hidden_states)
77 | lm_logits = model.lm_head(hidden_states)
78 | shift_logits = lm_logits[:, :-1, :].contiguous()
79 | shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:]
80 | loss_fct = nn.CrossEntropyLoss()
81 | loss = loss_fct(
82 | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
83 | )
84 | neg_log_likelihood = loss.float() * model.seqlen
85 | nlls.append(neg_log_likelihood)
86 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
87 |     print(f"Perplexity: {ppl.item():.3f}")
88 |     logger.info(f"{dataset}/Perplexity: {ppl.item():.3f}")
89 |
90 | model.config.use_cache = use_cache
91 |
92 | @torch.no_grad()
93 | def opt_eval(model, testenc, dev, dataset: str, log_wandb: bool = False):
94 | print('Evaluating ...')
95 |
96 | testenc = testenc.input_ids
97 | nsamples = testenc.numel() // model.seqlen
98 |
99 | use_cache = model.config.use_cache
100 | model.config.use_cache = False
101 | layers = model.model.decoder.layers
102 |
103 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
104 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
105 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
106 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
107 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
108 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
109 | layers[0] = layers[0].to(dev)
110 |
111 | dtype = next(iter(model.parameters())).dtype
112 | inps = torch.zeros(
113 | (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
114 | )
115 | cache = {'i': 0, 'attention_mask': None}
116 |
117 | class Catcher(nn.Module):
118 | def __init__(self, module):
119 | super().__init__()
120 | self.module = module
121 | def forward(self, inp, **kwargs):
122 | inps[cache['i']] = inp
123 | cache['i'] += 1
124 | cache['attention_mask'] = kwargs['attention_mask']
125 | raise ValueError
126 | layers[0] = Catcher(layers[0])
127 | for i in range(nsamples):
128 | batch = testenc[:, (i * model.seqlen):((i + 1) * model.seqlen)].to(dev)
129 | try:
130 | model(batch)
131 | except ValueError:
132 | pass
133 | layers[0] = layers[0].module
134 |
135 | layers[0] = layers[0].cpu()
136 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
137 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
138 | if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
139 | model.model.decoder.project_out = model.model.decoder.project_out.cpu()
140 | if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
141 | model.model.decoder.project_in = model.model.decoder.project_in.cpu()
142 | torch.cuda.empty_cache()
143 |
144 | outs = torch.zeros_like(inps)
145 | attention_mask = cache['attention_mask']
146 |
147 | for i in range(len(layers)):
148 | # print(i)
149 | layer = layers[i].to(dev)
150 |
151 | for j in range(nsamples):
152 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
153 | layers[i] = layer.cpu()
154 | del layer
155 | torch.cuda.empty_cache()
156 | inps, outs = outs, inps
157 |
158 | if model.model.decoder.final_layer_norm is not None:
159 | model.model.decoder.final_layer_norm = model.model.decoder.final_layer_norm.to(dev)
160 | if model.model.decoder.project_out is not None:
161 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
162 | model.lm_head = model.lm_head.to(dev)
163 |
164 | testenc = testenc.to(dev)
165 | nlls = []
166 | for i in range(nsamples):
167 | hidden_states = inps[i].unsqueeze(0)
168 | if model.model.decoder.final_layer_norm is not None:
169 | hidden_states = model.model.decoder.final_layer_norm(hidden_states)
170 | if model.model.decoder.project_out is not None:
171 | hidden_states = model.model.decoder.project_out(hidden_states)
172 | lm_logits = model.lm_head(hidden_states)
173 | shift_logits = lm_logits[:, :-1, :].contiguous()
174 | shift_labels = testenc[
175 | :, (i * model.seqlen):((i + 1) * model.seqlen)
176 | ][:, 1:]
177 | loss_fct = nn.CrossEntropyLoss()
178 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
179 | neg_log_likelihood = loss.float() * model.seqlen
180 | nlls.append(neg_log_likelihood)
181 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen))
182 |     print(f"Perplexity: {ppl.item():.3f}")
183 |     print({f'{dataset}/perplexity': ppl.item()})
184 |     logger.info(f"Perplexity: {ppl.item():.3f}")
185 |     logger.info({f'{dataset}/perplexity': ppl.item()})
186 |
187 | model.config.use_cache = use_cache
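188 |
189 | # Both evaluators compute perplexity as the exponentiated mean token-level
190 | # negative log-likelihood: each window contributes loss * seqlen, so
191 | #     ppl = exp( sum_i nll_i / (nsamples * seqlen) )
192 | # e.g. a uniform cross-entropy of 2.0 per token would give ppl = e^2 ≈ 7.39.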
--------------------------------------------------------------------------------
/bigptq_arb.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 | from utils.structure_arb import structural_guassian_distribution_multip_alternating_group_x
8 |
9 | import logging
10 | logger = logging.getLogger()
11 |
12 | DEBUG = True
13 |
14 | torch.backends.cuda.matmul.allow_tf32 = False
15 | torch.backends.cudnn.allow_tf32 = False
16 |
17 | '''
18 | BRAGPTQ is GPTQ combined with the Binary Residual Approximation (braq) from the paper to realize 1-bit quantization.
19 | BRAGPTQ uses structural masks to separate salient outliers from the remaining weights, and leverages part of GPTQ to lower the quantization error.
20 | '''
21 | class BRAGPTQ:
22 | def __init__(
23 | self, layer, braq_quantizer,salient_metric, disable_gptq=False, method='arb', order2_group=False
24 | ):
25 | self.method = method
26 | self.order2_group = order2_group
27 | self.layer = layer
28 | self.dev = self.layer.weight.device
29 | W = layer.weight.data.clone()
30 | if isinstance(self.layer, nn.Conv2d):
31 | W = W.flatten(1)
32 | if isinstance(self.layer, transformers.Conv1D):
33 | W = W.t()
34 | self.rows = W.shape[0]
35 | self.columns = W.shape[1]
36 | self.H = torch.zeros((self.columns, self.columns), device=self.dev)
37 | self.nsamples = 0
38 | self.braq_quantizer = braq_quantizer
39 | self.salient_metric = salient_metric # "magnitude" or "hessian"
40 | self.disable_gptq = disable_gptq
41 |
42 | self.inp = []
43 | # self.inp2 = torch.zeros(self.columns, self.columns, device=self.dev, dtype=torch.float16)
44 |
45 | def add_batch(self, inp, out, blocksize=1024):
46 | if DEBUG:
47 | self.inp1 = inp
48 | self.out1 = out
49 |
50 | # save memory
51 | # print(inp.shape) # [1,2048,4096]
52 | # print(inp[0].T.shape)
53 | # self.inp2 = self.inp2 + (inp[0].T @ inp[0])
54 |
55 | if len(inp.shape) == 2:
56 | inp = inp.unsqueeze(0)
57 |
58 | if self.method == 'arb-x':
59 | self.inp.append(inp)
60 |
61 | tmp = inp.shape[0]
62 | if isinstance(self.layer, nn.Linear) or isinstance(
63 | self.layer, transformers.Conv1D
64 | ):
65 | if len(inp.shape) == 3:
66 | inp = inp.reshape((-1, inp.shape[-1]))
67 | inp = inp.t()
68 | self.H *= self.nsamples / (self.nsamples + tmp)
69 | self.nsamples += tmp
70 | inp = math.sqrt(2 / self.nsamples) * inp.float()
71 | self.H += inp.matmul(inp.t())
72 | # breakpoint()
73 |
74 | def fasterquant(self,
75 | blocksize=128,
76 | percdamp=0.01,
77 | orders=(1,1,2),
78 | num_p=1,
79 | ):
80 | W = self.layer.weight.data.clone()
81 | if isinstance(self.layer, nn.Conv2d):
82 | W = W.flatten(1)
83 | if isinstance(self.layer, transformers.Conv1D):
84 | W = W.t()
85 | W = W.float()
86 | tick = time.time()
87 |
88 | H = self.H
89 | del self.H
90 | dead = torch.diag(H) == 0
91 | H[dead, dead] = 1
92 | W[:, dead] = 0
93 |
94 | Losses = torch.zeros(self.rows, device=self.dev)
95 |
96 | damp = percdamp * torch.mean(torch.diag(H))
97 | diag = torch.arange(self.columns, device=self.dev)
98 | H[diag, diag] += damp
99 | H = torch.linalg.cholesky(H)
100 | H = torch.cholesky_inverse(H)
101 | H = torch.linalg.cholesky(H, upper=True)
102 | Hinv = H
103 |
104 | if self.method == 'arb-x':
105 | self.inp = torch.concat(self.inp)
106 | # print(self.inp.shape)
107 |
108 | for blocki, col_st in enumerate(range(0, self.columns, blocksize)):
109 | col_ed = min(col_st + blocksize, self.columns)
110 | n_cols = col_ed - col_st
111 |
112 | st = col_st
113 | ed = col_ed
114 |
115 | if self.method == 'arb-x':
116 | # S = torch.einsum('bki,bkj->ij', self.inp[:, :, st:ed], self.inp[:, :, st:ed])
117 | S = torch.matmul(self.inp[:, :, st:ed].to(torch.float32).transpose(1, 2), self.inp[:, :, st:ed].to(torch.float32)).mean(dim=0) # avoid overflow
118 | else:
119 | S = None
120 | # S = self.inp2[st:ed, st:ed]
121 | # print(S==S2)
122 |
123 | if self.order2_group:
124 | num_mask = 2 * (num_p+1)
125 | orders = [2 for _ in range(num_p+1)] + [1 for _ in range(num_p+1)]
126 | else:
127 | num_mask = 1 + num_p + 1
128 | orders = [2] + [1 for _ in range(num_p+1)]
129 | mask = torch.zeros_like(W[:, st:ed], dtype=torch.bool).unsqueeze(0).repeat_interleave(num_mask, dim=0)
130 | mask_list = structural_guassian_distribution_multip_alternating_group_x(W[:, st:ed], H[st:ed, st:ed], self.salient_metric, 50, num_p, S, self.method, self.order2_group)
131 | for i in range(num_mask):
132 | mask[i] = mask_list[i]
133 |
134 | assert self.braq_quantizer.groupsize % blocksize == 0
135 |
136 | if self.disable_gptq:
137 | # RTN
138 | # print("RTN")
139 | w = W[:, col_st:col_ed]
140 |
141 | # from low to high group
142 | q_part_groups = []
143 | for i in range(mask.shape[0]):
144 | q_part_groups.append(self.braq_quantizer.quantize(w, mask[i], order=orders[i]))
145 |
146 | q = torch.zeros_like(w)
147 | for j in range(mask.shape[0]):
148 | q += q_part_groups[j][:] * mask[j, :]
149 | W[:, col_st:col_ed] = q
150 | else:
151 | # shape of W1: [oc, n_cols]
152 | W1 = W[:, col_st:col_ed].clone()
153 | Q1 = torch.zeros_like(W1)
154 | Err1 = torch.zeros_like(W1)
155 | Losses1 = torch.zeros_like(W1)
156 | Hinv1 = Hinv[col_st:col_ed, col_st:col_ed]
157 |
158 | # old_q_part_groups = []
159 | q_part_groups = []
160 |
161 | for i in range(mask.shape[0]):
162 | q_part_groups.append(self.braq_quantizer.quantize(W1, mask[i], order=orders[i], S=S))
163 |
164 | for i in range(n_cols):
165 | # shape of w: [oc, 1]
166 | w = W1[:, i]
167 | d = Hinv1[i, i]
168 |
169 | q = torch.zeros_like(w)
170 | for j in range(mask.shape[0]):
171 | q += q_part_groups[j][:, i] * mask[j, :, i]
172 |
173 | Q1[:, i] = q
174 | Losses1[:, i] = (w - q) ** 2 / d**2
175 | # breakpoint()
176 |
177 | err1 = (w - q) / d
178 | Err1[:, i] = err1
179 |
180 | W[:, col_st:col_ed] = Q1
181 | Losses += torch.sum(Losses1, 1) / 2
182 |
183 | W[:, col_ed:] -= Err1.matmul(Hinv[col_st:col_ed, col_ed:])
184 |
185 | if DEBUG:
186 | self.layer.weight.data[:, :col_ed] = W[:, :col_ed]
187 | self.layer.weight.data[:, col_ed:] = W[:, col_ed:]
188 | x_error = torch.sum((self.layer(self.inp1) - self.out1) ** 2)
189 | # print(torch.sum(Losses))
190 |
191 | torch.cuda.synchronize()
192 | # print("time %.2f" % (time.time() - tick))
193 | # print("error", torch.sum(Losses).item())
194 | times = time.time() - tick
195 | logger.info(f'time {times:.2f}')
196 | logger.info(f'error {torch.sum(Losses).item()}')
197 |         if DEBUG: logger.info(f'x error {x_error.item()}')  # x_error only exists when DEBUG is set
198 |
199 | if isinstance(self.layer, transformers.Conv1D):
200 | W = W.t()
201 | self.layer.weight.data = W.reshape(self.layer.weight.shape).to(
202 | self.layer.weight.data.dtype
203 | )
204 | if DEBUG:
205 | print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
206 |
207 | del mask
208 | # del mask1, mask2, mask3
209 | del mask_list
210 | if not self.disable_gptq:
211 | del W1, Q1, W, Err1, Losses1, Hinv1
212 | del H, Hinv, self.inp, S, q_part_groups
213 | # del H, Hinv, self.inp2, S, q_part_groups
214 | torch.cuda.empty_cache()
215 | return {"error": torch.sum(Losses).item()}
216 |
217 | def free(self):
218 | if DEBUG:
219 | self.inp1 = None
220 | self.out1 = None
221 | self.H = None
222 | torch.cuda.empty_cache()
223 |
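224 | # Note (a sketch of the math, hedged): add_batch keeps a running GPTQ Hessian
225 | # H = (2/n) * sum_k x_k x_k^T over all calibration tokens; scaling H by
226 | # n/(n+t) and adding sqrt(2/(n+t))-scaled inputs keeps the average exact as
227 | # batches stream in. fasterquant then dampens H, inverts it via Cholesky, and
228 | # propagates each column's error (w - q) / d into the not-yet-quantized
229 | # columns, as in GPTQ.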
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2025 ARB-LLM Authors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/run_arb.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # optional: route Hugging Face downloads through a mirror
3 |
4 | import time
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | from bigptq_arb import BRAGPTQ
10 | from binary_arb import Binarization
11 | from modelutils import find_layers
12 | from datautils import get_tokenizer
13 |
14 | import logging
15 |
16 |
17 | def setup_logger(log_file):
18 | logger = logging.getLogger()
19 | logger.setLevel(logging.DEBUG)
20 |
21 | console_handler = logging.StreamHandler()
22 | console_handler.setLevel(logging.DEBUG)
23 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
24 | console_handler.setFormatter(formatter)
25 |
26 | file_handler = logging.FileHandler(log_file)
27 | file_handler.setLevel(logging.DEBUG)
28 | file_handler.setFormatter(formatter)
29 |
30 | logger.addHandler(console_handler)
31 | logger.addHandler(file_handler)
32 |
33 | return logger
34 |
35 |
36 | def get_model(model):
37 |
38 |
39 | def skip(*args, **kwargs):
40 | pass
41 |
42 | torch.nn.init.kaiming_uniform_ = skip
43 | torch.nn.init.uniform_ = skip
44 | torch.nn.init.normal_ = skip
45 | if "opt" in model:
46 | from transformers import OPTForCausalLM
47 |
48 | model = OPTForCausalLM.from_pretrained(model, torch_dtype="auto")
49 | model.seqlen = model.config.max_position_embeddings
50 | elif "llama" in model:
51 | from transformers import LlamaForCausalLM
52 |
53 | model = LlamaForCausalLM.from_pretrained(model, torch_dtype="auto")
54 | model.seqlen = 2048
55 | return model
56 |
57 |
58 | '''
59 | The function is employed to calibrate and quantize models layer by layer.
60 | '''
61 | @torch.no_grad()
62 | def quant_sequential(model, dataloader, dev):
63 | print("Starting ...")
64 |
65 | for name, module in model.named_modules():
66 | module.global_name = args.model + name
67 |
68 | use_cache = model.config.use_cache
69 | model.config.use_cache = False
70 |
71 | if "opt" in args.model:
72 | layers = model.model.decoder.layers
73 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev)
74 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(
75 | dev
76 | )
77 | if (
78 | hasattr(model.model.decoder, "project_out")
79 | and model.model.decoder.project_out
80 | ):
81 | model.model.decoder.project_out = model.model.decoder.project_out.to(dev)
82 | if (
83 | hasattr(model.model.decoder, "project_in")
84 | and model.model.decoder.project_in
85 | ):
86 | model.model.decoder.project_in = model.model.decoder.project_in.to(dev)
87 | elif "llama" in args.model:
88 | layers = model.model.layers
89 | model.model.embed_tokens = model.model.embed_tokens.to(dev)
90 | model.model.norm = model.model.norm.to(dev)
91 | layers[0] = layers[0].to(dev)
92 |
93 | dtype = next(iter(model.parameters())).dtype
94 | inps = torch.zeros(
95 | (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
96 | )
97 | cache = {"i": 0, "attention_mask": None}
98 |
99 | class Catcher(nn.Module):
100 | def __init__(self, module):
101 | super().__init__()
102 | self.module = module
103 |
104 | def forward(self, inp, **kwargs):
105 | inps[cache["i"]] = inp
106 | cache["i"] += 1
107 | cache["attention_mask"] = kwargs["attention_mask"]
108 | raise ValueError
109 |
110 | layers[0] = Catcher(layers[0])
111 | for batch in dataloader:
112 | try:
113 | model(batch[0].to(dev))
114 | except ValueError:
115 | pass
116 | layers[0] = layers[0].module
117 |
118 | layers[0] = layers[0].cpu()
119 | if "opt" in args.model:
120 | model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
121 | model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
122 | if (
123 | hasattr(model.model.decoder, "project_out")
124 | and model.model.decoder.project_out
125 | ):
126 | model.model.decoder.project_out = model.model.decoder.project_out.cpu()
127 | if (
128 | hasattr(model.model.decoder, "project_in")
129 | and model.model.decoder.project_in
130 | ):
131 | model.model.decoder.project_in = model.model.decoder.project_in.cpu()
132 | elif "llama" in args.model:
133 | model.model.embed_tokens = model.model.embed_tokens.cpu()
134 | model.model.norm = model.model.norm.cpu()
135 | torch.cuda.empty_cache()
136 |
137 | outs = torch.zeros_like(inps)
138 | attention_mask = cache["attention_mask"]
139 |
140 | print("Ready.")
141 |
142 | for i in range(len(layers)):
143 | layer = layers[i].to(dev)
144 |
145 | subset = find_layers(layer)
146 |
147 | gptq = {}
148 | for name in subset:
149 | if (
150 | not (args.minlayer <= i < args.maxlayer and args.quant_only in name)
151 | ) == (not args.invert):
152 | continue
153 | braq_quantizer = Binarization(
154 | subset[name].weight,
155 | method=args.low_quant_method,
156 | groupsize=groupsize,
157 | )
158 | gptq[name] = BRAGPTQ(
159 | subset[name],
160 | braq_quantizer,
161 | salient_metric=args.salient_metric,
162 | disable_gptq=args.disable_gptq,
163 | method=args.low_quant_method,
164 | order2_group=args.order2_group,
165 | )
166 |
167 | def add_batch(name):
168 | def tmp(_, inp, out):
169 | gptq[name].add_batch(inp[0].data, out.data)
170 |
171 | return tmp
172 |
173 | handles = []
174 | for name in gptq:
175 | handles.append(subset[name].register_forward_hook(add_batch(name)))
176 | for j in range(args.nsamples):
177 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
178 | for h in handles:
179 | h.remove()
180 |
181 | for name in gptq:
182 | # print(i, name)
183 | # print("Quantizing ...")
184 | logging.info(f'{i} {name}')
185 | logging.info("Quantizing ...")
186 | info = gptq[name].fasterquant(
187 | percdamp=args.percdamp,
188 | blocksize=args.blocksize,
189 | num_p=args.num_p,
190 | )
191 | gptq[name].free()
192 |
193 | for j in range(args.nsamples):
194 | outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
195 |
196 | # # debug
197 | # print('fp16', fp_outs.shape) # [128, 2048, 4096]
198 | # print('fp16', fp_outs[0][:5])
199 | # print('billm', outs.shape) # [128, 2048, 4096]
200 | # print('billm', outs[0][:5])
201 | # print('------------------')
202 |
203 | layers[i] = layer.cpu()
204 | del layer
205 | del gptq
206 | torch.cuda.empty_cache()
207 |
208 | inps, outs = outs, inps
209 |
210 | model.config.use_cache = use_cache
211 |
212 |
213 | if __name__ == "__main__":
214 | import argparse
215 | from datautils import *
216 |
217 | def list_of_ints(arg):
218 | return list(map(int, arg.split(',')))
219 |
220 | def list_of_floats(arg):
221 | return list(map(float, arg.split(',')))
222 |
223 | parser = argparse.ArgumentParser()
224 |
225 | parser.add_argument(
226 | "model", type=str, help="model to load; for example `huggyllama/llama-7b`."
227 | )
228 | parser.add_argument(
229 | "dataset",
230 | type=str,
231 | choices=["wikitext2", "ptb", "c4"],
232 | help="Where to extract calibration data from.",
233 | )
234 | parser.add_argument(
235 | "low_quant_method",
236 | type=str,
237 | choices=["arb", "arb-x", 'arb-rc', 'braq'],
238 | help="alternating refined binarization method",
239 | )
240 | parser.add_argument(
241 | "--order2_group",
242 | action='store_true',
243 |         help="use order-2 residual binarization for each weight group (not only the salient columns)",
244 | )
245 | parser.set_defaults(order2_group=False)
246 | parser.add_argument("--load_quantized", action="store_true")
247 | parser.add_argument(
248 | "--seed", type=int, default=0, help="Seed for sampling the calibration data."
249 | )
250 | parser.add_argument(
251 | "--nsamples", type=int, default=128, help="Number of calibration data samples."
252 | )
253 | parser.add_argument(
254 | "--percdamp",
255 | type=float,
256 | default=0.01,
257 | help="Percent of the average Hessian diagonal to use for dampening.",
258 | )
259 | parser.add_argument(
260 | "--blocksize",
261 | type=int,
262 | default=128,
263 | help="Blocksize to use for adaptive mask selection.",
264 | )
265 | parser.add_argument(
266 | "--num_p",
267 | type=int,
268 | default=1,
269 |         help="Number of divisions for the non-salient weights.",
270 | )
271 | parser.add_argument(
272 | "--salient_metric",
273 | type=str,
274 | default="magnitude",
275 | choices=["magnitude", "hessian"],
276 | )
277 | parser.add_argument(
278 | "--device",
279 | type=str,
280 | default="cuda:0",
281 | help="set the device to use for quantization.",
282 | )
283 | parser.add_argument(
284 | "--disable_gptq",
285 | action="store_true",
286 | help="disable GPTQ for quantization.",
287 | )
288 | parser.add_argument(
289 | "--minlayer", type=int, default=-1, help="Quant all layers with id >= this."
290 | )
291 | parser.add_argument(
292 | "--maxlayer", type=int, default=1000, help="Quant all layers with id < this."
293 | )
294 | parser.add_argument(
295 | "--quant_only",
296 | type=str,
297 | default="",
298 | help="Quant only layers that contain this text.",
299 | )
300 | parser.add_argument("--invert", action="store_true", help="Invert subset.")
301 | parser.add_argument(
302 | "--save",
303 | action="store_true",
304 | )
305 | parser.add_argument(
306 | "--log_wandb", action="store_true", help="Whether to log to wandb."
307 | )
308 | parser.add_argument(
309 | "--tasks",
310 | type=str,
311 | default="",
312 | )
313 | parser.add_argument(
314 | "--experiment",
315 | type=str,
316 | default="",
317 | )
318 | parser.add_argument("--num_fewshot", type=int, default=0)
319 | parser.add_argument("--limit", type=int, default=-1)
320 |
321 | args = parser.parse_args()
322 | groupsize = args.blocksize
323 |
324 | device = args.device
325 | save_title = f"{args.model.split('/')[-1]}_{args.dataset}_{args.low_quant_method}_{groupsize}_{args.salient_metric}_nump_{args.num_p}_order2group_{args.order2_group}"
326 | save_file = "./output/" + save_title.replace("/", "_") + ".pt"
327 | if args.load_quantized:
328 | model = get_model(save_file)
329 | model.eval()
330 |
331 |     else:  # quantize from scratch
332 | # log
333 | log_file = "./log/" + save_title.replace("/", "_") + f"_{args.experiment}" + ".log"
334 | log_path = os.path.dirname(log_file)
335 | if not os.path.exists(log_path):
336 | os.makedirs(log_path)
337 | logger = setup_logger(log_file)
338 |
339 | model = get_model(args.model)
340 | model.eval()
341 | tick = time.time()
342 | dataloader, testloader = get_loaders(
343 | args.dataset,
344 | nsamples=args.nsamples,
345 | seed=args.seed,
346 | model=args.model,
347 | seqlen=model.seqlen,
348 | )
349 | # print(model)
350 |
351 | quant_sequential(model, dataloader, device)
352 | print("quantization time:", time.time() - tick, "s")
353 | print(f'Experiment: {args.experiment}')
354 | logger.info(f'Experiment: {args.experiment}')
355 |
356 | if args.save:
357 | save_path = os.path.dirname(save_file)
358 | if not os.path.exists(save_path):
359 | os.makedirs(save_path)
360 | model.save_pretrained(save_file)
361 |
362 | for dataset in ["wikitext2", "ptb", "c4"]:
363 | dataloader, testloader = get_loaders(
364 | dataset, seed=args.seed, seqlen=model.seqlen, model=args.model
365 | )
366 | print(dataset)
367 | if "opt" in args.model:
368 | from eval_ppl_utils import opt_eval
369 |
370 | opt_eval(model, testloader, device, dataset, args.log_wandb)
371 | elif "llama" in args.model:
372 | from eval_ppl_utils import llama_eval
373 |
374 | llama_eval(model, testloader, device, dataset, args.log_wandb)
375 |
376 |
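377 | # Note on the layer filter in quant_sequential (hedged reading of the code):
378 | # a module is quantized iff
379 | #     (args.minlayer <= i < args.maxlayer and args.quant_only in name)
380 | # XOR args.invert. E.g. hypothetical flags `--quant_only fc --maxlayer 12`
381 | # binarize only the fc layers in the first 12 blocks; adding --invert flips
382 | # the selection to everything else.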
--------------------------------------------------------------------------------
/utils/autosearch_arb.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | from binary_arb import high_order_residual, high_order_residual_alternating_order1, high_order_residual_alternating_mean, high_order_residual_alternating_order2_rc_nomean, high_order_residual_alternating_order1_rc_nomean
6 | from utils.mask import generate_structural_mask
7 |
8 | error_N = 2048*4096*128  # normalization constant: seqlen * hidden_dim * nsamples (assumed calibration shape)
9 |
10 | def error_computing(origin_matrix, quantized_matrix):
11 | mse = torch.mean((origin_matrix - quantized_matrix) ** 2)
12 | return mse
13 |
14 | def error_computing_x_all_accelerate(origin_matrix, quantized_matrix, S):
15 | # inps shape [128, 2048, 128]
16 | R = (origin_matrix - quantized_matrix).T
17 | P = torch.einsum('ik,jk->ij', R, R)
18 | return torch.sum(P * S) / error_N
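   | # Note (sketch of the math above): with R = (W - Q)^T and S ≈ E[x x^T] from
   | # the calibration inputs, sum(P * S) = tr((W - Q)^T (W - Q) S), i.e. the
   | # input-weighted output error sum_k ||(W - Q) x_k||^2, scaled by 1 / error_N.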
19 |
20 | def calculate_percentage_and_variance_original(weights, abs_weights, bin_edges):
21 | percentages = []
22 | variances = []
23 | accum_percentages = [0]
24 | total_elements = abs_weights.numel()
25 | for i in range(len(bin_edges) - 1):
26 | bin_mask = (abs_weights >= bin_edges[i]) & (abs_weights < bin_edges[i + 1])
27 | bin_weights = weights[bin_mask]
28 | percentages.append(bin_weights.numel() / total_elements * 100)
29 | accum_percentages.append(accum_percentages[-1] + percentages[-1])
30 | variances.append(torch.var(bin_weights))
31 | return percentages, variances, accum_percentages
32 |
33 | '''
34 | Main search routine: find the number of 2-bit salient data columns and the optimal splits for the 1-bit data groups (a usage sketch follows this function)
35 | '''
36 | def structural_searching_multip(origin_matrix, up_lim=30, num_p=1, order2_group=False):  # note: order2_group is unused in this variant
37 | minimal_value = float('inf')
38 | minimal_value_0 = float('inf')
39 |
40 |     true_counts = origin_matrix.abs().sum(dim=0)  # per-column absolute mass, used to rank salient columns
41 |
42 | error = []
43 | lines = []
44 |     # search for the optimal number of 2-bit salient columns (order=2, structured search)
45 | _, top_braq_2_columns = torch.topk(true_counts, up_lim)
46 | for i in range(1, up_lim):
47 | mask3 = torch.full((origin_matrix.shape[0], origin_matrix.shape[1]), False).to(origin_matrix.device)
48 | mask3[:, top_braq_2_columns[:i]] = True
49 | group3 = high_order_residual(origin_matrix, mask3, order=2)
50 | group4 = high_order_residual(origin_matrix, ~mask3, order=2)
51 |
52 |
53 | quantize_error_0 = error_computing(origin_matrix, group4+group3)
54 | error.append(quantize_error_0.item())
55 | lines.append(i)
56 |
57 | if quantize_error_0 < minimal_value_0:
58 | minimal_value_0 = quantize_error_0
59 | optimal_split_0 = i
60 |
61 | _, top_braq_2_columns = torch.topk(true_counts, optimal_split_0)
62 | mask3 = torch.full((origin_matrix.shape[0], origin_matrix.shape[1]), False).to(origin_matrix.device)
63 | mask3[:, top_braq_2_columns] = True
64 | group3 = high_order_residual(origin_matrix, mask3, order=2)
65 |
66 | mask_list = [mask3]
67 | optimal_split_list = []
68 | for i in range(num_p):
69 | search_matrix = origin_matrix * (~mask3)
70 |
71 | flat_abs_tensor = torch.abs(search_matrix).view(-1)
72 | percentiles = torch.linspace(0.10, 0.90, 81).to(origin_matrix.device)
73 | percentile_values = torch.tensor(
74 | np.quantile(flat_abs_tensor.detach().cpu().numpy(), q=percentiles.cpu().numpy(), axis=None, keepdims=False)
75 | ).to(origin_matrix.device)
76 |
77 |         # search for the optimal split for the 1-bit groups (order=1, non-structured search)
78 | for split_value in percentile_values:
79 | mask1, mask2 = generate_structural_mask(origin_matrix, mask3, split_value)
80 | group1 = high_order_residual(origin_matrix, mask1, order=1)
81 | group2 = high_order_residual(origin_matrix, mask2, order=1)
82 |
83 | quantize_error = error_computing(origin_matrix, group1+group2+group3)
84 | if quantize_error < minimal_value:
85 | minimal_value = quantize_error
86 | optimal_split = split_value
87 | optimal_group2 = group2
88 | best_mask2 = mask2
89 | best_mask1 = mask1
90 |
91 | mask_list.append(best_mask2)
92 | optimal_split_list.append(optimal_split)
93 | group3 = group3 + optimal_group2
94 | mask3 = mask3 | best_mask2
95 |
96 | mask_list.append(best_mask1)
97 |
98 | return optimal_split_list, mask_list
99 |
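# Usage sketch for structural_searching_multip (illustrative, not part of the
# original file). The returned masks are [2-bit salient region, one 1-bit group
# per round, final 1-bit group], so the mixed approximation is rebuilt as:
#
#   W = torch.randn(4096, 4096)
#   _, masks = structural_searching_multip(W, up_lim=30, num_p=1)
#   W_hat = high_order_residual(W, masks[0], order=2)
#   for m in masks[1:]:
#       W_hat = W_hat + high_order_residual(W, m, order=1)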
100 | def structural_searching_multip_alternating_group(origin_matrix, up_lim=30, num_p=1, inp=None, iter=0, order2_group=False):
101 | minimal_value = float('inf')
102 | minimal_value_0 = float('inf')
103 |
104 | true_counts = origin_matrix.abs().sum(dim=0)
105 |
106 | # error = []
107 | # lines = []
108 |     # search for the optimal number of 2-bit salient columns (order=2, structured search)
109 | _, top_braq_2_columns = torch.topk(true_counts, up_lim)
110 | for i in range(1, up_lim):
111 | mask3 = torch.full((origin_matrix.shape[0], origin_matrix.shape[1]), False).to(origin_matrix.device)
112 | mask3[:, top_braq_2_columns[:i]] = True
113 |         group3 = high_order_residual(origin_matrix, mask3, order=2)  # plain (non-alternating) residual during the search, for fair comparison and speed
114 |         group4 = high_order_residual(origin_matrix, ~mask3, order=2)
115 |
116 | quantize_error_0 = error_computing(origin_matrix, group4+group3)
117 | # error.append(quantize_error_0.item())
118 | # lines.append(i)
119 | # print(quantize_error_0)
120 |
121 | if quantize_error_0 < minimal_value_0:
122 | minimal_value_0 = quantize_error_0
123 | optimal_split_0 = i
124 |
125 |
126 | _, top_braq_2_columns = torch.topk(true_counts, optimal_split_0)
127 | mask3 = torch.full((origin_matrix.shape[0], origin_matrix.shape[1]), False).to(origin_matrix.device)
128 | mask3[:, top_braq_2_columns] = True
129 |
130 | group3 = high_order_residual_alternating_mean(origin_matrix, mask3, order=2)
131 |
132 | mask_list = []
133 | optimal_split_list = []
134 |
135 | # 2nd order group
136 | if order2_group:
137 | mask0 = mask3.clone()
138 | minimal_value2 = float('inf')
139 | group0 = torch.zeros(origin_matrix.shape, device=origin_matrix.device)
140 | for i in range(num_p):
141 | search_matrix = origin_matrix * mask0
142 |
143 | flat_abs_tensor = torch.abs(search_matrix).view(-1)
144 | flat_abs_tensor_nonzero = flat_abs_tensor[flat_abs_tensor != 0]
145 | percentiles = torch.linspace(0.10, 0.90, 81).to(origin_matrix.device)
146 | percentile_values = torch.tensor(
147 | np.quantile(flat_abs_tensor_nonzero.detach().cpu().numpy(), q=percentiles.cpu().numpy(), axis=None, keepdims=False)
148 | ).to(origin_matrix.device)
149 |
150 |             # search for the optimal split within the 2-bit group (order=2, non-structured search)
151 | flag = False
152 | for split_value in percentile_values:
153 | mask4, mask5 = generate_structural_mask(origin_matrix, ~mask0, split_value)
154 | group1 = high_order_residual(origin_matrix, mask4, order=2)
155 | group2 = high_order_residual(origin_matrix, mask5, order=2)
156 |
157 | quantize_error = error_computing(origin_matrix, group1+group2+group0)
158 | if quantize_error < minimal_value2:
159 | minimal_value2 = quantize_error
160 | optimal_split = split_value
161 | optimal_group2 = group2
162 | best_mask4 = mask4
163 | best_mask5 = mask5
164 | flag = True
165 |
166 |             if not flag:  # no improving split found; fall back to the lowest percentile
167 |                 optimal_split = percentile_values[0]
168 |                 best_mask4, best_mask5 = generate_structural_mask(origin_matrix, ~mask0, optimal_split)
169 |                 optimal_group2 = high_order_residual(origin_matrix, best_mask5, order=2)  # recompute for the fallback masks
170 |
171 | mask0 = mask0 & (~best_mask5)
172 | mask_list.append(best_mask5)
173 | group0 = group0 + optimal_group2
174 |
175 | mask_list.append(best_mask4)
176 |
177 | else:
178 | mask_list.append(mask3)
179 |
180 | # 1st order group
181 | for i in range(num_p):
182 | search_matrix = origin_matrix * (~mask3)
183 |
184 | flat_abs_tensor = torch.abs(search_matrix).view(-1)
185 | percentiles = torch.linspace(0.10, 0.90, 81).to(origin_matrix.device)
186 | percentile_values = torch.tensor(
187 | np.quantile(flat_abs_tensor.detach().cpu().numpy(), q=percentiles.cpu().numpy(), axis=None, keepdims=False)
188 | ).to(origin_matrix.device)
189 |
190 |         # search for the optimal split for the 1-bit groups (order=1, non-structured search)
191 | flag = False
192 | for split_value in percentile_values:
193 | mask1, mask2 = generate_structural_mask(origin_matrix, mask3, split_value)
194 |
195 | group1 = high_order_residual(origin_matrix, mask1, order=1)
196 | group2 = high_order_residual(origin_matrix, mask2, order=1)
197 |
198 | quantize_error = error_computing(origin_matrix, group1+group2+group3)
199 | if quantize_error < minimal_value:
200 | minimal_value = quantize_error
201 | optimal_split = split_value
202 | best_mask2 = mask2
203 | best_mask1 = mask1
204 | flag = True
205 |
206 | if not flag:
207 |             print("no improving 1-bit split found; falling back to the lowest percentile")
208 | optimal_split = percentile_values[0]
209 | best_mask1, best_mask2 = generate_structural_mask(origin_matrix, mask3, optimal_split)
210 |
211 | optimal_group2 = high_order_residual_alternating_order1(origin_matrix, best_mask2, order=1)
212 | mask_list.append(best_mask2)
213 | optimal_split_list.append(optimal_split)
214 | group3 = group3 + optimal_group2
215 | mask3 = mask3 | best_mask2
216 |
217 | mask_list.append(best_mask1)
218 |
219 | return optimal_split_list, mask_list
220 |
221 | def structural_searching_multip_alternating_group_x(origin_matrix, up_lim=30, num_p=1, inp=None, iter=0, order2_group=False):
222 | minimal_value = float('inf')
223 | minimal_value_0 = float('inf')
224 |
225 | true_counts = origin_matrix.abs().sum(dim=0)
226 |
227 | # error = []
228 | # lines = []
229 |     # search for the optimal number of 2-bit salient columns (order=2, structured search)
230 | _, top_braq_2_columns = torch.topk(true_counts, up_lim)
231 | for i in range(1, up_lim):
232 | mask3 = torch.full((origin_matrix.shape[0], origin_matrix.shape[1]), False).to(origin_matrix.device)
233 | mask3[:, top_braq_2_columns[:i]] = True
234 |         group3 = high_order_residual(origin_matrix, mask3, order=2)  # plain (non-alternating) residual during the search, for fair comparison and speed
235 |         group4 = high_order_residual(origin_matrix, ~mask3, order=2)
236 |
237 | quantize_error_0 = error_computing(origin_matrix, group4+group3)
238 | # error.append(quantize_error_0.item())
239 | # lines.append(i)
240 | # print(quantize_error_0)
241 |
242 | if quantize_error_0 < minimal_value_0:
243 | minimal_value_0 = quantize_error_0
244 | optimal_split_0 = i
245 |
246 |
247 | _, top_braq_2_columns = torch.topk(true_counts, optimal_split_0)
248 | mask3 = torch.full((origin_matrix.shape[0], origin_matrix.shape[1]), False).to(origin_matrix.device)
249 | mask3[:, top_braq_2_columns] = True
250 |
251 | group3 = high_order_residual_alternating_mean(origin_matrix, mask3, order=2)
252 |
253 | mask_list = []
254 | optimal_split_list = []
255 |
256 | # 2nd order group
257 | if order2_group:
258 | mask0 = mask3.clone()
259 | minimal_value2 = float('inf')
260 | group0 = torch.zeros(origin_matrix.shape, device=origin_matrix.device)
261 | for i in range(num_p):
262 | search_matrix = origin_matrix * mask0
263 |
264 | flat_abs_tensor = torch.abs(search_matrix).view(-1)
265 | flat_abs_tensor_nonzero = flat_abs_tensor[flat_abs_tensor != 0]
266 | percentiles = torch.linspace(0.10, 0.90, 81).to(origin_matrix.device)
267 | percentile_values = torch.tensor(
268 | np.quantile(flat_abs_tensor_nonzero.detach().cpu().numpy(), q=percentiles.cpu().numpy(), axis=None, keepdims=False)
269 | ).to(origin_matrix.device)
270 |
271 |             # search for the optimal split within the 2-bit group (order=2, non-structured search)
272 | flag = False
273 | for split_value in percentile_values:
274 | mask4, mask5 = generate_structural_mask(origin_matrix, ~mask0, split_value)
275 | group1 = high_order_residual(origin_matrix, mask4, order=2)
276 | group2 = high_order_residual(origin_matrix, mask5, order=2)
277 |
278 | quantize_error = error_computing(origin_matrix, group1+group2+group0)
279 | if quantize_error < minimal_value2:
280 | minimal_value2 = quantize_error
281 | optimal_split = split_value
282 | optimal_group2 = group2
283 | best_mask4 = mask4
284 | best_mask5 = mask5
285 | flag = True
286 |
287 |             if not flag:  # no improving split found; fall back to the lowest percentile
288 |                 optimal_split = percentile_values[0]
289 |                 best_mask4, best_mask5 = generate_structural_mask(origin_matrix, ~mask0, optimal_split)
290 |                 optimal_group2 = high_order_residual(origin_matrix, best_mask5, order=2)  # recompute for the fallback masks
291 |
292 | mask0 = mask0 & (~best_mask5)
293 | mask_list.append(best_mask5)
294 | group0 = group0 + optimal_group2
295 |
296 | mask_list.append(best_mask4)
297 |
298 | else:
299 | mask_list.append(mask3)
300 |
301 | # 1st order group
302 | for i in range(num_p):
303 | search_matrix = origin_matrix * (~mask3)
304 |
305 | flat_abs_tensor = torch.abs(search_matrix).view(-1)
306 | percentiles = torch.linspace(0.10, 0.90, 81).to(origin_matrix.device)
307 | percentile_values = torch.tensor(
308 | np.quantile(flat_abs_tensor.detach().cpu().numpy(), q=percentiles.cpu().numpy(), axis=None, keepdims=False)
309 | ).to(origin_matrix.device)
310 |
311 |         # search for the optimal split for the 1-bit groups (order=1, non-structured search)
312 | flag = False
313 | for split_value in percentile_values:
314 | mask1, mask2 = generate_structural_mask(origin_matrix, mask3, split_value)
315 |
316 | group1 = high_order_residual(origin_matrix, mask1, order=1)
317 | group2 = high_order_residual(origin_matrix, mask2, order=1)
318 |
319 | quantize_error = error_computing_x_all_accelerate(origin_matrix, group1+group2+group3, inp)
320 | if quantize_error < minimal_value:
321 | minimal_value = quantize_error
322 | optimal_split = split_value
323 | best_mask2 = mask2
324 | best_mask1 = mask1
325 | flag = True
326 |
327 | if not flag:
328 |             print("no improving 1-bit split found; falling back to the lowest percentile")
329 | optimal_split = percentile_values[0]
330 | best_mask1, best_mask2 = generate_structural_mask(origin_matrix, mask3, optimal_split)
331 |
332 |         optimal_group2 = high_order_residual_alternating_order1(origin_matrix, best_mask2, order=1)  # alternating refinement only on the chosen mask (faster than refining every candidate)
333 | mask_list.append(best_mask2)
334 | optimal_split_list.append(optimal_split)
335 | group3 = group3 + optimal_group2
336 | mask3 = mask3 | best_mask2
337 |
338 | mask_list.append(best_mask1)
339 |
340 | return optimal_split_list, mask_list
341 |
342 | def structural_searching_multip_alternating_group_rc(origin_matrix, up_lim=30, num_p=1, inp=None, iter=0, order2_group=False):
343 | minimal_value = float('inf')
344 | minimal_value_0 = float('inf')
345 |
346 | true_counts = origin_matrix.abs().sum(dim=0)
347 |
348 | # error = []
349 | # lines = []
350 |     # search for the optimal number of 2-bit salient columns (order=2, structured search)
351 | _, top_braq_2_columns = torch.topk(true_counts, up_lim)
352 | for i in range(1, up_lim):
353 | mask3 = torch.full((origin_matrix.shape[0], origin_matrix.shape[1]), False).to(origin_matrix.device)
354 | mask3[:, top_braq_2_columns[:i]] = True
355 |         group3 = high_order_residual(origin_matrix, mask3, order=2)  # plain (non-alternating) residual during the search, for fair comparison and speed
356 |         group4 = high_order_residual(origin_matrix, ~mask3, order=2)
357 |
358 | quantize_error_0 = error_computing(origin_matrix, group4+group3)
359 | # error.append(quantize_error_0.item())
360 | # lines.append(i)
361 | # print(quantize_error_0)
362 |
363 | if quantize_error_0 < minimal_value_0:
364 | minimal_value_0 = quantize_error_0
365 | optimal_split_0 = i
366 |
367 |
368 | _, top_braq_2_columns = torch.topk(true_counts, optimal_split_0)
369 | mask3 = torch.full((origin_matrix.shape[0], origin_matrix.shape[1]), False).to(origin_matrix.device)
370 | mask3[:, top_braq_2_columns] = True
371 |
372 | group3 = high_order_residual_alternating_order2_rc_nomean(origin_matrix, mask3, order=2)
373 |
374 | mask_list = []
375 | optimal_split_list = []
376 |
377 | # 2nd order group
378 | if order2_group:
379 | mask0 = mask3.clone()
380 | minimal_value2 = float('inf')
381 | group0 = torch.zeros(origin_matrix.shape, device=origin_matrix.device)
382 | for i in range(num_p):
383 | search_matrix = origin_matrix * mask0
384 |
385 | flat_abs_tensor = torch.abs(search_matrix).view(-1)
386 | flat_abs_tensor_nonzero = flat_abs_tensor[flat_abs_tensor != 0]
387 | percentiles = torch.linspace(0.10, 0.90, 81).to(origin_matrix.device)
388 | percentile_values = torch.tensor(
389 | np.quantile(flat_abs_tensor_nonzero.detach().cpu().numpy(), q=percentiles.cpu().numpy(), axis=None, keepdims=False)
390 | ).to(origin_matrix.device)
391 |
392 |             # search for the optimal split within the 2-bit group (order=2, non-structured search)
393 | flag = False
394 | for split_value in percentile_values:
395 | mask4, mask5 = generate_structural_mask(origin_matrix, ~mask0, split_value)
396 | group1 = high_order_residual(origin_matrix, mask4, order=2)
397 | group2 = high_order_residual(origin_matrix, mask5, order=2)
398 |
399 | quantize_error = error_computing(origin_matrix, group1+group2+group0)
400 | if quantize_error < minimal_value2:
401 | minimal_value2 = quantize_error
402 | optimal_split = split_value
403 | optimal_group2 = group2
404 | best_mask4 = mask4
405 | best_mask5 = mask5
406 | flag = True
407 |
408 |             if not flag:  # no improving split found; fall back to the lowest percentile
409 |                 optimal_split = percentile_values[0]
410 |                 best_mask4, best_mask5 = generate_structural_mask(origin_matrix, ~mask0, optimal_split)
411 |                 optimal_group2 = high_order_residual(origin_matrix, best_mask5, order=2)  # recompute for the fallback masks
412 |
413 | mask0 = mask0 & (~best_mask5)
414 | mask_list.append(best_mask5)
415 | group0 = group0 + optimal_group2
416 |
417 | mask_list.append(best_mask4)
418 |
419 | else:
420 | mask_list.append(mask3)
421 |
422 | # 1st order group
423 | for i in range(num_p):
424 | search_matrix = origin_matrix * (~mask3)
425 |
426 | flat_abs_tensor = torch.abs(search_matrix).view(-1)
427 | percentiles = torch.linspace(0.10, 0.90, 81).to(origin_matrix.device)
428 | percentile_values = torch.tensor(
429 | np.quantile(flat_abs_tensor.detach().cpu().numpy(), q=percentiles.cpu().numpy(), axis=None, keepdims=False)
430 | ).to(origin_matrix.device)
431 |
432 |         # search for the optimal split for the 1-bit groups (order=1, non-structured search)
433 | flag = False
434 | for split_value in percentile_values:
435 | mask1, mask2 = generate_structural_mask(origin_matrix, mask3, split_value)
436 |
437 | group1 = high_order_residual(origin_matrix, mask1, order=1)
438 | group2 = high_order_residual(origin_matrix, mask2, order=1)
439 |
440 | # quantize_error = error_computing_x_all_accelerate(origin_matrix, group1+group2+group3, inp)
441 | quantize_error = error_computing(origin_matrix, group1+group2+group3)
442 | if quantize_error < minimal_value:
443 | minimal_value = quantize_error
444 | optimal_split = split_value
445 | best_mask2 = mask2
446 | best_mask1 = mask1
447 | flag = True
448 |
449 | if not flag:
450 |             print("no improving 1-bit split found; falling back to the lowest percentile")
451 | optimal_split = percentile_values[0]
452 | best_mask1, best_mask2 = generate_structural_mask(origin_matrix, mask3, optimal_split)
453 |
454 | # optimal_group2 = high_order_residual_alternating_order1(origin_matrix, best_mask2, order=1)
455 |         optimal_group2 = high_order_residual_alternating_order1_rc_nomean(origin_matrix, best_mask2, order=1, iter=0)  # iter=0: keep the rc binarization without extra alternating refinement
456 | mask_list.append(best_mask2)
457 | optimal_split_list.append(optimal_split)
458 | group3 = group3 + optimal_group2
459 | mask3 = mask3 | best_mask2
460 |
461 | mask_list.append(best_mask1)
462 |
463 | return optimal_split_list, mask_list
464 |
465 | def find_optimal_split(group_max, origin_matrix, border):
466 | optimal_split = None
467 | minimal_value = float('inf')
468 | searching_steps = torch.arange(0.1,0.8,0.01)
469 | searching_steps = searching_steps * group_max
470 |
471 | group3 = high_order_residual(origin_matrix, torch.abs(origin_matrix) > border, order=2)
472 | for split_value in searching_steps:
473 |
474 | group1 = high_order_residual(origin_matrix, (torch.abs(origin_matrix) > split_value) & (torch.abs(origin_matrix) <= border), order=1)
475 | group2 = high_order_residual(origin_matrix, torch.abs(origin_matrix) <= split_value, order=1)
476 |
477 | quantize_error = error_computing(origin_matrix, group1+group2+group3)
478 | if quantize_error < minimal_value:
479 | minimal_value = quantize_error
480 | optimal_split = split_value
481 |
482 | return optimal_split, minimal_value
483 |
--------------------------------------------------------------------------------
/binary_arb.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math
6 | index = 0  # global call counter, incremented by each quantizer but never read
7 | @torch.no_grad()
8 | def part_mean(tensor, op='-'):  # 'op' is currently unused
9 |     non_zero = tensor*(tensor!=0)  # zeros stay zero, so the row-wise mean below still includes them
10 |
11 | mean_val = non_zero.mean(-1).view(-1, 1)
12 |
13 | return mean_val
14 |
15 | @torch.no_grad()
16 | def high_order_residual(x, mask, order=2):
17 | sum_order = torch.zeros_like(x)
18 | new_matrix = x.clone()
19 | new_matrix = new_matrix * mask
20 | global index
21 | index += 1
22 | for od in range(order):
23 | residual = new_matrix - sum_order
24 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
25 |
26 | mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
27 | mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
28 | masked_x_tensor -= mean_tensor_all[:, None]
29 | scale_tensor_all = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
30 | scale_tensor_all = torch.where(torch.isnan(scale_tensor_all), torch.zeros_like(scale_tensor_all), scale_tensor_all)
31 |
32 | binary= torch.sign(masked_x_tensor)
33 | binary *= scale_tensor_all[:, None]
34 | binary += mean_tensor_all[:, None]
35 | sum_order = sum_order + binary*mask
36 |
37 | return sum_order
38 |
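# Worked example for one row of high_order_residual with a full mask, order=2
# (a sketch; the values are illustrative):
#   w = [0.5, -1.0, 2.0, -0.3]
#   pass 1: mu1 = mean(w) = 0.3, alpha1 = mean(|w - mu1|) = 0.95
#           q1  = mu1 + alpha1 * sign(w - mu1) = [1.25, -0.65, 1.25, -0.65]
#   pass 2 then binarizes the residual w - q1 the same way, so the order-2
#   reconstruction is q1 + (mu2 + alpha2 * sign(w - q1 - mu2)) on the mask.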
39 | @torch.no_grad()
40 | def high_order_residual_rc(x, mask, order=2):
41 | sum_order = torch.zeros_like(x)
42 | new_matrix = x.clone()
43 | new_matrix = new_matrix * mask
44 | global index
45 | index += 1
46 | for od in range(order):
47 | residual = new_matrix - sum_order
48 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
49 |
50 | # mean row
51 | mean_tensor_all_r = torch.nanmean(masked_x_tensor, dim=1)
52 | mean_tensor_all_r = torch.where(torch.isnan(mean_tensor_all_r), torch.zeros_like(mean_tensor_all_r), mean_tensor_all_r)
53 | masked_x_tensor -= mean_tensor_all_r[:, None]
54 | # mean column
55 | mean_tensor_all_c = torch.nanmean(masked_x_tensor, dim=0)
56 | mean_tensor_all_c = torch.where(torch.isnan(mean_tensor_all_c), torch.zeros_like(mean_tensor_all_c), mean_tensor_all_c)
57 | masked_x_tensor -= mean_tensor_all_c[None, :]
58 |
59 | # alpha row
60 | scale_tensor_all_r = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
61 | scale_tensor_all_r = torch.where(torch.isnan(scale_tensor_all_r), torch.zeros_like(scale_tensor_all_r), scale_tensor_all_r)
62 | # alpha column
63 | scale_tensor_all_c = torch.nanmean(torch.abs(masked_x_tensor / scale_tensor_all_r[:, None]), dim=0)
64 | scale_tensor_all_c = torch.where(torch.isnan(scale_tensor_all_c), torch.zeros_like(scale_tensor_all_c), scale_tensor_all_c)
65 |
66 | binary= torch.sign(masked_x_tensor)
67 | binary *= scale_tensor_all_r[:, None]
68 | binary *= scale_tensor_all_c[None, :]
69 | binary += mean_tensor_all_r[:, None] + mean_tensor_all_c[None, :]
70 | sum_order = sum_order + binary*mask
71 |
72 | return sum_order
73 |
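# In the rc ("row-column") variant above, each pass factorizes the scale as an
# outer product alpha_r[i] * alpha_c[j] and the shift as mean_r[i] + mean_c[j],
# so one binary pass reconstructs
#   w_ij ~= mean_r[i] + mean_c[j] + alpha_r[i] * alpha_c[j] * sign(residual_ij)
# at the cost of one extra column-mean and column-scale vector per matrix.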
74 | @torch.no_grad()
75 | def high_order_residual_alternating_order1(x, mask, order=2, iter=15):
76 | sum_order = torch.zeros_like(x)
77 | new_matrix = x.clone()
78 | new_matrix = new_matrix * mask
79 | global index
80 | index += 1
81 | for od in range(order):
82 | residual = new_matrix - sum_order
83 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
84 |
85 | mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
86 | mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
87 | masked_x_tensor -= mean_tensor_all[:, None]
88 | scale_tensor_all = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
89 | scale_tensor_all = torch.where(torch.isnan(scale_tensor_all), torch.zeros_like(scale_tensor_all), scale_tensor_all)
90 |
91 | binary= torch.sign(masked_x_tensor)
92 | new_binary = binary.clone()
93 | binary *= scale_tensor_all[:, None]
94 | binary += mean_tensor_all[:, None]
95 | sum_order = sum_order + binary*mask
96 |
97 | # Alternating update
98 | refine_mean = mean_tensor_all.clone()
99 | sum_order_alternating = sum_order.clone()
100 |
101 | for k in range(iter):
102 | # 1. Fix alpha and B, update mean
103 | residual = new_matrix - sum_order_alternating
104 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
105 | mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
106 | mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
107 | refine_mean += mean_tensor_all.clone()
108 |
109 | # 2. Fix mean and B, update alpha
110 | new_alpha = 1. / (torch.sum(new_binary * mask * new_binary * mask, dim=1) + 1e-8) * torch.sum(new_binary * mask * (new_matrix - refine_mean[:, None] * mask), dim=1)
111 |
112 | # 3. Fix mean and alpha, update B
113 | new_binary = torch.sign(new_matrix - refine_mean[:, None] * mask)
114 |
115 |         # final refined result
116 | sum_order_alternating = torch.zeros_like(x) + (new_alpha[:, None] * new_binary + refine_mean[:, None]) * mask
117 |
118 |
119 | return sum_order_alternating
120 |
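# The alternating step above has a closed form: with the mean mu and the sign
# matrix B fixed, the row-wise scale minimizing ||W - mu - a * (B * mask)||^2 is
#   a* = <B * mask, W - mu * mask> / <B * mask, B * mask>
# which is exactly what the `new_alpha = ...` line computes (the 1e-8 guards an
# empty mask row); B is then re-binarized against the refined mean.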
121 | @torch.no_grad()
122 | def high_order_residual_alternating_order1_x(x, mask, order=2, S=None, iter=15, iter2=15):
123 | sum_order = torch.zeros_like(x)
124 | new_matrix = x.clone()
125 | new_matrix = new_matrix * mask
126 | global index
127 | index += 1
128 | for od in range(order):
129 | residual = new_matrix - sum_order
130 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
131 |
132 | mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
133 | mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
134 | masked_x_tensor -= mean_tensor_all[:, None]
135 | scale_tensor_all = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
136 | scale_tensor_all = torch.where(torch.isnan(scale_tensor_all), torch.zeros_like(scale_tensor_all), scale_tensor_all)
137 |
138 | binary= torch.sign(masked_x_tensor)
139 | new_binary = binary.clone()
140 | binary *= scale_tensor_all[:, None]
141 | binary += mean_tensor_all[:, None]
142 | sum_order = sum_order + binary*mask
143 |
144 | # Alternating update
145 | refine_mean = mean_tensor_all.clone()
146 | sum_order_alternating = sum_order.clone()
147 | new_alpha = scale_tensor_all.clone()
148 |
149 | for k in range(iter):
150 | # 1. Fix alpha and B, update mean
151 | residual = new_matrix - sum_order_alternating
152 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
153 | mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
154 | mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
155 | refine_mean += mean_tensor_all.clone()
156 |
157 | # 2. Fix mean and B, update alpha
158 | new_alpha = 1. / (torch.sum(new_binary * mask * new_binary * mask, dim=1) + 1e-8) * torch.sum(new_binary * mask * (new_matrix - refine_mean[:, None] * mask), dim=1)
159 |
160 | # 3. Fix mean and alpha, update B
161 | new_binary = torch.sign(new_matrix - refine_mean[:, None] * mask)
162 |
163 |         # final refined result
164 | sum_order_alternating = torch.zeros_like(x) + (new_alpha[:, None] * new_binary + refine_mean[:, None]) * mask
165 |
166 | MM = mask[:, :, None] * mask[:, None, :]
167 | refine_mean_den = torch.sum(S * MM, dim=(1,2), dtype=torch.float32) + 1e-10
168 | masked_B = new_binary * mask
169 | new_alpha_den = torch.sum(S * masked_B[:, :, None] * masked_B[:, None, :], dim=(1,2)) + 1e-10
170 | # diag_S = torch.diag(S)
171 | for kk in range(iter2):
172 | # X error update mean
173 | refine_mean = torch.sum(S * (new_matrix - new_alpha[:, None] * new_binary * mask)[:, :, None] * MM, dim=(1,2)) / refine_mean_den
174 |
175 | # X error update alpha
176 | new_alpha = torch.sum(S * masked_B[:, :, None] * (new_matrix - refine_mean[:, None] * mask)[:, None, :], dim=(1,2)) / new_alpha_den
177 |
178 | sum_order_alternating = torch.zeros_like(x) + (new_alpha[:, None] * new_binary + refine_mean[:, None]) * mask
179 |
180 | return sum_order_alternating
181 |
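# The second loop above redoes the mean/alpha updates under the S-weighted
# error instead of the plain MSE: with MM = mask[:, :, None] * mask[:, None, :],
# the closed-form updates divide S-weighted correlations by the precomputed
# denominators (refine_mean_den, new_alpha_den), mirroring the identity
# sketched after error_computing_x_all_accelerate in utils/autosearch_arb.py.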
182 | @torch.no_grad()
183 | def high_order_residual_alternating_order2_rc_nomean(x, mask, order=2, iter=15):
184 | sum_order = torch.zeros_like(x)
185 | new_matrix = x.clone()
186 | new_matrix = new_matrix * mask
187 | global index
188 | index += 1
189 | binary_list = []
190 | alpha_list_r = []
191 | alpha_list_c = []
192 | for od in range(order):
193 | residual = new_matrix - sum_order
194 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
195 |
196 | # alpha row
197 | scale_tensor_all_r = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
198 | scale_tensor_all_r = torch.where(torch.isnan(scale_tensor_all_r), torch.zeros_like(scale_tensor_all_r), scale_tensor_all_r)
199 | alpha_list_r.append(scale_tensor_all_r.clone())
200 | # alpha column
201 | scale_tensor_all_c = torch.nanmean(torch.abs(masked_x_tensor / scale_tensor_all_r[:, None]), dim=0)
202 | scale_tensor_all_c = torch.where(torch.isnan(scale_tensor_all_c), torch.zeros_like(scale_tensor_all_c), scale_tensor_all_c)
203 | alpha_list_c.append(scale_tensor_all_c.clone())
204 |
205 | binary= torch.sign(masked_x_tensor)
206 | binary_list.append(binary.clone())
207 | binary *= scale_tensor_all_r[:, None]
208 | binary *= scale_tensor_all_c[None, :]
209 | sum_order = sum_order + binary*mask
210 |
211 | # Alternating update
212 | sum_order_alternating = sum_order.clone()
213 |
214 | for k in range(iter):
215 |         # 2-1. Fix alpha column and B, update alpha row 0 (no mean term in the rc variant)
216 | W_tilde = new_matrix - (alpha_list_c[1][None, :] * alpha_list_r[1][:, None] * binary_list[1]) * mask
217 | alpha_c_B = alpha_list_c[0][None, :] * binary_list[0] * mask
218 | alpha_list_r[0] = torch.sum(alpha_c_B * W_tilde, dim=1) / (torch.sum(alpha_c_B * alpha_c_B, dim=1) + 1e-8)
219 |
220 |         # 2-2. Fix alpha row and B, update alpha column 0
221 | alpha_r_B = alpha_list_r[0][:, None] * binary_list[0] * mask
222 | alpha_list_c[0] = torch.sum(alpha_r_B * W_tilde, dim=0) / (torch.sum(alpha_r_B * alpha_r_B, dim=0) + 1e-8)
223 |
224 |         # 2-3. Fix alpha column and B, update alpha row 1
225 | W_tilde = new_matrix - (alpha_list_c[0][None, :] * alpha_list_r[0][:, None] * binary_list[0]) * mask
226 | alpha_c_B = alpha_list_c[1][None, :] * binary_list[1] * mask
227 | alpha_list_r[1] = torch.sum(alpha_c_B * W_tilde, dim=1) / (torch.sum(alpha_c_B * alpha_c_B, dim=1) + 1e-8)
228 |
229 |         # 2-4. Fix alpha row and B, update alpha column 1
230 | alpha_r_B = alpha_list_r[1][:, None] * binary_list[1] * mask
231 | alpha_list_c[1] = torch.sum(alpha_r_B * W_tilde, dim=0) / (torch.sum(alpha_r_B * alpha_r_B, dim=0) + 1e-8)
232 |
233 |         # 3. Fix both alphas, update B1 and B2 jointly
234 | new_matrix_expanded = new_matrix.unsqueeze(-1)
235 | comb0 = alpha_list_r[0].reshape(-1, 1) @ alpha_list_c[0].reshape(1, -1)
236 | comb1 = alpha_list_r[1].reshape(-1, 1) @ alpha_list_c[1].reshape(1, -1)
237 | v = torch.stack([-comb0 - comb1, -comb0 + comb1,
238 | comb0 - comb1, comb0 + comb1], dim=2)
239 |
240 | min_indices = torch.argmin(torch.abs(new_matrix_expanded - v), dim=-1)
241 |
242 | binary_list[0] = torch.ones_like(min_indices)
243 | binary_list[0][(min_indices == 0) | (min_indices == 1)] = -1
244 | binary_list[1] = torch.ones_like(min_indices)
245 | binary_list[1][(min_indices == 0) | (min_indices == 2)] = -1
246 |
247 |         # final refined result
248 | sum_order_alternating = torch.zeros_like(x) + (alpha_list_c[0][None, :] * alpha_list_r[0][:, None] * binary_list[0] + alpha_list_c[1][None, :] * alpha_list_r[1][:, None] * binary_list[1]) * mask
249 |
250 | return sum_order_alternating
251 |
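# In the joint B1/B2 update above, every entry w_ij independently picks the
# nearest of the four representable values {-a0-a1, -a0+a1, a0-a1, a0+a1},
# where ak = alpha_row_k[i] * alpha_col_k[j]; the argmin index is then decoded
# back into the two sign bits (index 0 -> (-1,-1), 1 -> (-1,+1),
# 2 -> (+1,-1), 3 -> (+1,+1)).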
252 | @torch.no_grad()
253 | def high_order_residual_alternating_order1_rc_nomean(x, mask, order=2, iter=15):
254 | sum_order = torch.zeros_like(x)
255 | new_matrix = x.clone()
256 | new_matrix = new_matrix * mask
257 | global index
258 | index += 1
259 | for od in range(order):
260 | residual = new_matrix - sum_order
261 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
262 |
263 | # alpha row
264 | scale_tensor_all_r = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
265 | scale_tensor_all_r = torch.where(torch.isnan(scale_tensor_all_r), torch.zeros_like(scale_tensor_all_r), scale_tensor_all_r)
266 | # alpha column
267 | scale_tensor_all_c = torch.nanmean(torch.abs(masked_x_tensor / scale_tensor_all_r[:, None]), dim=0)
268 | scale_tensor_all_c = torch.where(torch.isnan(scale_tensor_all_c), torch.zeros_like(scale_tensor_all_c), scale_tensor_all_c)
269 |
270 | binary= torch.sign(masked_x_tensor)
271 | new_binary = binary.clone()
272 | binary *= scale_tensor_all_r[:, None]
273 | binary *= scale_tensor_all_c[None, :]
274 | sum_order = sum_order + binary*mask
275 |
276 | # Alternating update
277 | sum_order_alternating = sum_order.clone()
278 | new_alpha_r = scale_tensor_all_r.clone()
279 | new_alpha_c = scale_tensor_all_c.clone()
280 | for k in range(iter):
281 |         # 1-1. Fix alpha column and B, update alpha row (no mean term in the rc variant)
282 | alpha_c_B = new_alpha_c[None, :] * new_binary * mask
283 | new_alpha_r = torch.sum(alpha_c_B * new_matrix, dim=1) / (torch.sum(alpha_c_B * alpha_c_B, dim=1) + 1e-8)
284 |
285 |         # 1-2. Fix alpha row and B, update alpha column
286 | alpha_r_B = new_alpha_r[:, None] * new_binary * mask
287 | new_alpha_c = torch.sum(alpha_r_B * new_matrix, dim=0) / (torch.sum(alpha_r_B * alpha_r_B, dim=0) + 1e-8)
288 |
289 |         # final refined result
290 | sum_order_alternating = torch.zeros_like(x) + new_alpha_c[None, :] * new_alpha_r[:, None] * new_binary * mask
291 |
292 | return sum_order_alternating
293 |
294 | @torch.no_grad()
295 | def high_order_residual_alternating_mean(x, mask, order=2, num_iters=15):
296 | sum_order = torch.zeros_like(x)
297 | new_matrix = x.clone()
298 | new_matrix = new_matrix * mask
299 | global index
300 | index += 1
301 | binary_list = []
302 | alpha_list = []
303 | refine_mean = torch.zeros(x.shape[0], device=x.device)
304 | for od in range(order):
305 | residual = new_matrix - sum_order
306 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
307 |
308 | mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
309 | mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
310 | refine_mean += mean_tensor_all.clone()
311 | masked_x_tensor -= mean_tensor_all[:, None]
312 | scale_tensor_all = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
313 | scale_tensor_all = torch.where(torch.isnan(scale_tensor_all), torch.zeros_like(scale_tensor_all), scale_tensor_all)
314 | alpha_list.append(scale_tensor_all.clone())
315 |
316 | binary = torch.sign(masked_x_tensor)
317 | binary_list.append(binary.clone())
318 | binary *= scale_tensor_all[:, None]
319 | binary += mean_tensor_all[:, None]
320 | sum_order = sum_order + binary*mask
321 |
322 | new_matrix = x.clone() * mask
323 | sum_order_alternating = sum_order.clone()
324 |
325 | for k in range(num_iters):
326 | # 1. Fix alpha1, alpha2, B1, and B2, update mean
327 | residual = new_matrix - sum_order_alternating
328 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
329 | mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
330 | mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
331 | refine_mean += mean_tensor_all.clone()
332 |
333 | # 2. Fix mean, B1, and B2, update alpha1 and alpha2
334 | alpha_list[0] = 1. / (torch.sum(binary_list[0] * mask * binary_list[0] * mask, dim=1) + 1e-8) * torch.sum(binary_list[0] * mask * (new_matrix - refine_mean[:, None] * mask - alpha_list[1][:, None] * binary_list[1] * mask), dim=1)
335 | alpha_list[1] = 1. / (torch.sum(binary_list[1] * mask * binary_list[1] * mask, dim=1) + 1e-8) * torch.sum(binary_list[1] * mask * (new_matrix - refine_mean[:, None] * mask - alpha_list[0][:, None] * binary_list[0] * mask), dim=1)
336 |
337 | # 3. Fix mean, alpha1, and alpha2, update B1 and B2
338 | new_matrix_expanded = (new_matrix - refine_mean[:, None] * mask).unsqueeze(-1)
339 | v = torch.stack([-alpha_list[0] - alpha_list[1], -alpha_list[0] + alpha_list[1],
340 | alpha_list[0] - alpha_list[1], alpha_list[0] + alpha_list[1]], dim=1).unsqueeze(1)
341 |
342 | min_indices = torch.argmin(torch.abs(new_matrix_expanded - v), dim=-1)
343 |
344 | binary_list[0] = torch.ones_like(min_indices)
345 | binary_list[0][(min_indices == 0) | (min_indices == 1)] = -1
346 | binary_list[1] = torch.ones_like(min_indices)
347 | binary_list[1][(min_indices == 0) | (min_indices == 2)] = -1
348 |
349 | sum_order_alternating = torch.zeros_like(x) + (alpha_list[0][:, None] * binary_list[0] + alpha_list[1][:, None] * binary_list[1] + refine_mean[:, None]) * mask
350 |
351 | return sum_order_alternating
352 |
353 | @torch.no_grad()
354 | def high_order_residual_alternating_mean_x(x, mask, order=2, S=None, num_iters=15, iter2=15):
355 | sum_order = torch.zeros_like(x)
356 | new_matrix = x.clone()
357 | new_matrix = new_matrix * mask
358 | global index
359 | index += 1
360 | binary_list = []
361 | alpha_list = []
362 | refine_mean = torch.zeros(x.shape[0], device=x.device)
363 | for od in range(order):
364 | residual = new_matrix - sum_order
365 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
366 |
367 | mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
368 | mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
369 | refine_mean += mean_tensor_all.clone()
370 | masked_x_tensor -= mean_tensor_all[:, None]
371 | scale_tensor_all = torch.nanmean(torch.abs(masked_x_tensor), dim=1)
372 | scale_tensor_all = torch.where(torch.isnan(scale_tensor_all), torch.zeros_like(scale_tensor_all), scale_tensor_all)
373 | alpha_list.append(scale_tensor_all.clone())
374 |
375 | binary = torch.sign(masked_x_tensor)
376 | binary_list.append(binary.clone())
377 | binary *= scale_tensor_all[:, None]
378 | binary += mean_tensor_all[:, None]
379 | sum_order = sum_order + binary*mask
380 |
381 | new_matrix = x.clone() * mask
382 | sum_order_alternating = sum_order.clone()
383 |
384 | for k in range(num_iters):
385 | # 1. Fix alpha1, alpha2, B1, and B2, update mean
386 | residual = new_matrix - sum_order_alternating
387 | masked_x_tensor = torch.where(mask, residual, torch.tensor(float('nan')))
388 | mean_tensor_all = torch.nanmean(masked_x_tensor, dim=1)
389 | mean_tensor_all = torch.where(torch.isnan(mean_tensor_all), torch.zeros_like(mean_tensor_all), mean_tensor_all)
390 | refine_mean += mean_tensor_all.clone()
391 |
392 | # 2. Fix mean, B1, and B2, update alpha1 and alpha2
393 | alpha_list[0] = 1. / (torch.sum(binary_list[0] * mask * binary_list[0] * mask, dim=1) + 1e-8) * torch.sum(binary_list[0] * mask * (new_matrix - refine_mean[:, None] * mask - alpha_list[1][:, None] * binary_list[1] * mask), dim=1)
394 | alpha_list[1] = 1. / (torch.sum(binary_list[1] * mask * binary_list[1] * mask, dim=1) + 1e-8) * torch.sum(binary_list[1] * mask * (new_matrix - refine_mean[:, None] * mask - alpha_list[0][:, None] * binary_list[0] * mask), dim=1)
395 |
396 | # 3. Fix mean, alpha1, and alpha2, update B1 and B2
397 | new_matrix_expanded = (new_matrix - refine_mean[:, None] * mask).unsqueeze(-1)
398 | v = torch.stack([-alpha_list[0] - alpha_list[1], -alpha_list[0] + alpha_list[1],
399 | alpha_list[0] - alpha_list[1], alpha_list[0] + alpha_list[1]], dim=1).unsqueeze(1)
400 |
401 | min_indices = torch.argmin(torch.abs(new_matrix_expanded - v), dim=-1)
402 |
403 | binary_list[0] = torch.ones_like(min_indices)
404 | binary_list[0][(min_indices == 0) | (min_indices == 1)] = -1
405 | binary_list[1] = torch.ones_like(min_indices)
406 | binary_list[1][(min_indices == 0) | (min_indices == 2)] = -1
407 |
408 | sum_order_alternating = torch.zeros_like(x) + (alpha_list[0][:, None] * binary_list[0] + alpha_list[1][:, None] * binary_list[1] + refine_mean[:, None]) * mask
409 |
410 | MM = mask[:, :, None] * mask[:, None, :]
411 | refine_mean_den = torch.sum(S * MM, dim=(1,2)) + 1e-10
412 | masked_B0 = binary_list[0] * mask
413 | new_alpha0_den = torch.sum(S * masked_B0[:, :, None] * masked_B0[:, None, :], dim=(1,2)) + 1e-10
414 | masked_B1 = binary_list[1] * mask
415 | new_alpha1_den = torch.sum(S * masked_B1[:, :, None] * masked_B1[:, None, :], dim=(1,2)) + 1e-10
416 | for kk in range(iter2):
417 | # X error update mean
418 | refine_mean = torch.sum(S * (new_matrix - (alpha_list[0][:, None] * binary_list[0] + alpha_list[1][:, None] * binary_list[1]) * mask)[:, :, None] * MM, dim=(1,2)) / refine_mean_den
419 |
420 | # X error update alpha
421 | masked_W_mu = new_matrix - refine_mean[:, None] * mask
422 | alpha_list[0] = torch.sum(S * masked_B0[:, :, None] * (masked_W_mu[:, None, :] - (alpha_list[1][:, None] * masked_B1)[:, None, :]), dim=(1,2)) / new_alpha0_den
423 | alpha_list[1] = torch.sum(S * masked_B1[:, :, None] * (masked_W_mu[:, None, :] - (alpha_list[0][:, None] * masked_B0)[:, None, :]), dim=(1,2)) / new_alpha1_den
424 |
425 | sum_order_alternating = torch.zeros_like(x) + (alpha_list[0][:, None] * binary_list[0] + alpha_list[1][:, None] * binary_list[1] + refine_mean[:, None]) * mask
426 |
427 | return sum_order_alternating
428 |
429 | @torch.no_grad()
430 | def normal_quantize(x, scale, zero, maxq):
431 | q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
432 | return scale * (q - zero)
433 |
434 |
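# Numeric example for normal_quantize in the 2-bit branch below (a sketch):
# with maxq = 3 and a row spanning xmin = -1.5, xmax = 1.5,
#   scale = (xmax - xmin) / maxq = 1.0,  zero = round(-xmin / scale) = 2,
# so x = 0.7 quantizes to q = clamp(round(0.7 / 1.0) + 2, 0, 3) = 3 and
# dequantizes to scale * (q - zero) = 1.0, one of the four levels {-2, -1, 0, 1}.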
435 | class Binarization(nn.Module):
436 | def __init__(self, weight, method="arb", groupsize=-1):
437 | super().__init__()
438 | oc,ic=weight.shape
439 | if groupsize==-1:
440 | groupsize=ic
441 | self.groupsize=groupsize
442 | self.n_groups=math.ceil(ic/groupsize)
443 | self.method=method
444 |         self.mean = 0  # self.mean / self.scale are expected to be set externally before the xnor/sign/rtn paths run
445 |
446 | def quantize(self, w, mask, order=2, groupi=0, S=None):
447 | if self.method=="xnor":
448 | w_mean = self.mean[groupi]
449 | w = w - w_mean # oc, ic
450 | w = w.sign()
451 | w = w * self.scale[groupi]
452 | w+=w_mean
453 | elif self.method=="braq": # The method used in BiLLM
454 | w = high_order_residual(w, mask, order=order)
455 |
456 | # arb series
457 | elif self.method == "arb":
458 | if order == 2:
459 | w = high_order_residual_alternating_mean(w, mask, order=order)
460 | else:
461 | w = high_order_residual_alternating_order1(w, mask, order=order)
462 | elif self.method == 'arb-x':
463 | if order == 2:
464 | w = high_order_residual_alternating_mean_x(w, mask, order=order, S=S)
465 | else:
466 | w = high_order_residual_alternating_order1_x(w, mask, order=order, S=S)
467 | elif self.method == 'arb-rc':
468 | if order == 2:
469 | w = high_order_residual_alternating_order2_rc_nomean(w, mask, order=order)
470 | else:
471 | w = high_order_residual_alternating_order1_rc_nomean(w, mask, order=order)
472 |
473 | elif self.method=="sign":
474 | w=(w>0).float()
475 | w*=self.scale[groupi]
476 | elif self.method=="rtn":
477 | w=F.relu(w)
478 | w_int=(w/self.scale[groupi]).round().clamp(0,1)
479 | w=w_int*self.scale[groupi]
480 | elif self.method in ['2bit','4bit']:
481 |
482 | bits = int(self.method[0])
483 | perchannel = True
484 | weight = True
485 | dev = w.device
486 | maxq = torch.tensor(2 ** bits - 1)
487 | scale = torch.zeros(1)
488 | zero = torch.zeros(1)
489 |
490 | if dev != scale.device:
491 | scale=scale.to(dev)
492 | zero=zero.to(dev)
493 | maxq=maxq.to(dev)
494 |
495 | x = w.clone()
496 | shape = x.shape
497 |
498 | if perchannel:
499 | if weight:
500 | x = x.flatten(1)
501 | else:
502 | if len(shape) == 4:
503 | x = x.permute([1, 0, 2, 3])
504 | x = x.flatten(1)
505 | if len(shape) == 3:
506 | x = x.reshape((-1, shape[-1])).t()
507 | if len(shape) == 2:
508 | x = x.t()
509 | else:
510 | x = x.flatten().unsqueeze(0)
511 | tmp = torch.zeros(x.shape[0], device=dev)
512 | xmin = torch.minimum(x.min(1)[0], tmp)
513 | xmax = torch.maximum(x.max(1)[0], tmp)
514 |
515 | tmp = (xmin == 0) & (xmax == 0)
516 | xmin[tmp] = -1
517 | xmax[tmp] = +1
518 | scale = (xmax - xmin) / maxq
519 | zero = torch.round(-xmin / scale)
520 | if not perchannel:
521 | if weight:
522 | tmp = shape[0]
523 | else:
524 | tmp = shape[1] if len(shape) != 3 else shape[2]
525 | scale = scale.repeat(tmp)
526 | zero = zero.repeat(tmp)
527 |
528 | if weight:
529 | shape = [-1] + [1] * (len(shape) - 1)
530 | scale = scale.reshape(shape)
531 | zero = zero.reshape(shape)
532 | w = normal_quantize(w, scale, zero, maxq)
533 |
534 | elif self.method=="prune":
535 | return torch.zeros_like(w)
536 | return w
537 |
--------------------------------------------------------------------------------
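A usage sketch for the Binarization wrapper (not part of the repository; the mask here is illustrative, where in the real pipeline it comes from the structural search in utils/autosearch_arb.py):

    import torch
    from binary_arb import Binarization

    w = torch.randn(1024, 1024)
    mask = w.abs() > w.abs().mean()        # illustrative salient mask
    quantizer = Binarization(w, method="arb")
    w_q = quantizer.quantize(w, mask, order=2)
    print(((w * mask - w_q) ** 2).mean())  # reconstruction error on the mask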