├── README.md
├── hw_sim
│   ├── PE_array.py
│   ├── SRAM.py
│   ├── ViTCoD_comp.py
│   ├── __init__.py
│   ├── check_funcs.py
│   ├── reorder.py
│   └── utils.py
├── main.py
├── main_exps.sh
├── main_hw.py
├── models
│   ├── llama.py
│   ├── llama_seq.py
│   ├── ops
│   │   ├── mask_gen
│   │   │   ├── functions
│   │   │   │   └── mask_gen.py
│   │   │   └── src
│   │   │       ├── mask_gen.cpp
│   │   │       ├── mask_gen.h
│   │   │       └── mask_gen_kernal.cu
│   │   ├── setup.py
│   │   └── test_mask_gen.py
│   └── sparse_layers.py
├── optimizers
│   ├── __init__.py
│   └── prodigy.py
└── utils
    ├── data.py
    └── tools.py

/README.md:
--------------------------------------------------------------------------------
1 | # BESA
2 | 
3 | This repository contains code to reproduce the key results of the paper "BESA: Pruning Large Language Models with Blockwise Parameter-Efficient Sparsity Allocation", accepted at the International Conference on Learning Representations (ICLR), 2024.
4 | 
5 | ## Dependencies
6 | 
7 | * `torch`: tested on v2.0.1+cu118
8 | * `transformers`: tested on v4.31.0
9 | * `accelerate`: tested on v0.21.0
10 | * `datasets`: tested on v2.14.4
11 | * `timm`: tested on v0.9.5
12 | 
13 | **lm-evaluation-harness**
14 | ```
15 | git clone https://github.com/EleutherAI/lm-evaluation-harness
16 | cd lm-evaluation-harness
17 | pip install -e .
18 | ```
19 | 
20 | **Customized CUDA Operator**
21 | ```
22 | cd models/ops
23 | python setup.py install
24 | ```
25 | 
26 | ## Usage
27 | 
28 | Here is the command to run the baseline experiments, followed by perplexity evaluation on WikiText2, PTB, and C4 as well as the zero-shot tasks.
29 | See also the CMD-argument documentation.
30 | 
31 | ```
32 | bash main_exps.sh
33 | ```
34 | 
35 | ## Hardware Simulation
36 | 
37 | We use the ViTCoD accelerator to estimate the speed-up that a sparse accelerator can obtain on the pruned model. Here is the command to simulate the average runtime of each module in the pruned model.
38 | 
39 | ```
40 | python main_hw.py \
41 |     --model-name MODEL_DIR or HF_MODEL_NAME \
42 |     --func SIMULATE_CHOICE (q, k, v, o, gate, up, and down are available)
43 | ```
44 | 
45 | ## Others
46 | 
47 | In the experiment section of our paper, we present the results of row-wise sparsity, which customizes the sparsity of each row of the target layer's weight within a block. Additionally, we provide an extension with the results of layer-wise sparsity, where every row of the target layer is assigned the same sparsity. You can find the commands to execute the layer-wise sparsity experiments in the **main_exps.sh** script. Below, we present the perplexity results on the WikiText2 dataset.
48 | 
49 | | Method | 1-7B | 1-13B | 1-30B | 1-65B | 2-7B | 2-13B | 2-70B |
50 | |------------------:|:-----|:------|:------|:------|:-----|:------|:------|
51 | | Dense | 5.68 | 5.09 | 4.10 | 3.53 | 5.47 | 4.88 | 3.31 |
52 | | SparseGPT | 7.22 | 6.21 | 5.33 | 4.60 | 6.99 | 6.02 | 4.25 |
53 | | Wanda | 7.26 | 6.15 | 5.25 | 4.60 | 6.92 | 5.97 | 4.22 |
54 | | BESA (layer-wise) | 7.04 | 6.07 | 5.16 | 4.51 | 6.77 | 5.85 | 4.14 |
55 | | BESA (row-wise) | 6.86 | 5.92 | 5.00 | 4.33 | 6.60 | 5.75 | 4.09 |
56 | 
57 | ## Acknowledgement
58 | 
59 | This repo benefits from [SparseGPT](https://github.com/IST-DASLab/sparsegpt), [Prodigy](https://github.com/konstmish/prodigy), and [ViTCoD](https://github.com/GATECH-EIC/ViTCoD). Thanks for their wonderful work.
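The `hw_sim` entry points that follow can also be driven programmatically instead of through `main_hw.py`. A minimal sketch, assuming a pruned checkpoint saved at the placeholder path `llama-7b-0.5`:

```
from hw_sim.check_funcs import check_proj

# Simulate q_proj across all decoder layers with the default reorder threshold;
# this prints the average cycle count and FLOPs per layer (the same code path
# that main_hw.py takes for --func q).
check_proj('llama-7b-0.5', 'q', threshold_ratio=0.5)

# The simulator assumes a 500 MHz clock (see hw_sim/SRAM.py), so cycle counts
# convert to runtime as: runtime_seconds = cycles / 500e6
```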
--------------------------------------------------------------------------------
/hw_sim/PE_array.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | class PE_array:
5 |     def __init__(self, width=64, height=8):
6 |         # height x width PE array (8 x 64 by default)
7 |         self.width = width
8 |         self.height = height
9 |         self.res = [[0 for i in range(width)] for j in range(height)]
10 | 
11 |     def reg_Q(self, temp_Q):
12 |         self.Q = temp_Q
13 | 
14 |     def reg_K(self, temp_K):
15 |         self.K = temp_K
16 | 
17 |     def reg_V(self, temp_V):
18 |         self.V = temp_V
19 | 
20 |     def reg_attn(self, temp_attn):
21 |         self.attn = temp_attn
22 | 
23 |     def reg_index(self, temp_index):
24 |         self.index = temp_index
25 | 
26 |     def cal_attn_map(self):
27 |         assert self.Q is not None
28 |         assert self.K is not None
29 |         assert len(self.Q) == len(self.K)
30 | 
31 |         # print('Shape of Q: ', self.Q.shape)
32 |         # print('Shape of K: ', self.K.shape)
33 | 
34 |         cycle = 0
35 |         for k in range(len(self.Q)):
36 |             self.res[0][0] += self.Q[k] * self.K[k]
37 |             cycle += 1
38 | 
39 |         return self.res[0][0], cycle
40 | 
41 |     def cal_V_update(self):
42 |         assert self.attn is not None
43 |         assert self.V is not None
44 |         # print('Shape of attn_map: ', self.attn.shape)
45 |         # print('Shape of V: ', self.V.shape)
46 | 
47 |         assert len(self.attn[0]) == len(self.V[0])
48 | 
49 |         cycle = 0
50 |         for i in range(len(self.attn)):
51 |             for j in range(len(self.V)):
52 |                 for k in range(len(self.V[0])):
53 |                     self.res[i][j] += self.attn[i][k] * self.V[j][k]
54 |                     cycle += 1
55 | 
56 |         return np.array(self.res), cycle
57 | 
58 |     def reset_res(self):
59 |         self.res = [[0 for i in range(self.width)] for j in range(self.height)]
60 | 
61 |     def store_res(self):
62 |         cycle = 1
63 |         return cycle
64 | 
65 |     def store_res_V(self):
66 |         cycle = len(self.attn) * len(self.V) // 64 # assume 64-way parallelism per cycle
67 |         return cycle
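A minimal sketch of how the `PE_array` bookkeeping above is driven (the shapes here are illustrative only):

```
import numpy as np

from hw_sim.PE_array import PE_array

pe = PE_array(width=64, height=8)

# One Q.K dot product: cal_attn_map charges one cycle per multiply-accumulate.
pe.reg_Q(np.random.random(64))
pe.reg_K(np.random.random(64))
score, cycles = pe.cal_attn_map()   # cycles == 64 here
pe.reset_res()

# attn @ V^T update: the cycle count is the product of the three loop bounds.
pe.reg_attn(np.random.random((4, 16)))
pe.reg_V(np.random.random((8, 16)))
out, cycles = pe.cal_V_update()     # cycles == 4 * 8 * 16
```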
--------------------------------------------------------------------------------
/hw_sim/SRAM.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | 
4 | class SRAM:
5 |     def __init__(self):
6 |         self.max_Q = 53 * 1024 * 8 # 53KB
7 |         self.max_K = 53 * 1024 * 8 # 53KB
8 |         self.max_V = 53 * 1024 * 8 # 53KB
9 |         self.max_index = 20 * 1024 * 8 # 20KB
10 |         self.max_output = 108 * 1024 * 8 # 108KB
11 | 
12 |         # HBM to SRAM
13 |         self.bandwidth = 76.8 * 1024 * 1024 * 1024 * 8 # 76.8GB/s
14 |         self.clock_frequency = 500 * 1e6 # 500MHz
15 | 
16 |     def preload_decoder(self, nums=0, bits=32, bandwidth_ratio=1):
17 |         if nums * bits > self.max_Q:
18 |             raise ValueError('loading decoder inputs from DRAM to SRAM exceeds the buffer capacity')
19 |         else:
20 |             latency = nums * bits / (self.bandwidth * bandwidth_ratio)
21 |             cycle = math.ceil(latency * self.clock_frequency)
22 | 
23 |         return cycle
24 | 
25 |     def preload_encoder(self, nums=0, bits=32, bandwidth_ratio=1):
26 |         if nums * bits > self.max_Q:
27 |             raise ValueError('loading encoder inputs from DRAM to SRAM exceeds the buffer capacity')
28 |         else:
29 |             latency = nums * bits / (self.bandwidth * bandwidth_ratio)
30 |             cycle = math.ceil(latency * self.clock_frequency)
31 | 
32 |         return cycle
33 | 
34 |     def preload_Q(self, nums=0, bits=32, bandwidth_ratio=1):
35 |         if nums * bits > self.max_Q:
36 |             raise ValueError('loading Q from DRAM to SRAM exceeds the Q buffer')
37 |         else:
38 |             latency = nums * bits / (self.bandwidth * bandwidth_ratio)
39 |             cycle = math.ceil(latency * self.clock_frequency)
40 | 
41 |         return cycle
42 | 
43 |     def data_cycle(self, num, bit=8, bandwidth_ratio=1):
44 |         latency = (num*bit) / (self.bandwidth * bandwidth_ratio)
45 |         cycle = math.ceil(latency * self.clock_frequency)
46 |         return cycle
47 | 
48 |     def preload_K(self, nums=0, bits=32, bandwidth_ratio=1):
49 |         if nums * bits > self.max_K:
50 |             raise ValueError('loading K from DRAM to SRAM exceeds the K buffer')
51 |         else:
52 |             latency = nums * bits / (self.bandwidth * bandwidth_ratio)
53 |             cycle = math.ceil(latency * self.clock_frequency)
54 | 
55 |         return cycle
56 | 
57 |     def preload_V(self, nums=0, bits=32, bandwidth_ratio=1):
58 |         if nums * bits > self.max_V:
59 |             raise ValueError('loading V from DRAM to SRAM exceeds the V buffer')
60 |         else:
61 |             latency = nums * bits / (self.bandwidth * bandwidth_ratio)
62 |             cycle = math.ceil(latency * self.clock_frequency)
63 | 
64 |         return cycle
65 | 
66 |     def preload_index(self, nums=0, bits=32, bandwidth_ratio=1):
67 |         if nums * bits > self.max_index:
68 |             raise ValueError('loading index from DRAM to SRAM exceeds the index buffer')
69 |         else:
70 |             latency = nums * bits / (self.bandwidth * bandwidth_ratio)
71 |             cycle = math.ceil(latency * self.clock_frequency)
72 | 
73 |         return cycle
74 | 
75 |     def store_out(self, nums=0, bits=32, bandwidth_ratio=1):
76 |         if nums * bits > self.max_output:
77 |             raise ValueError('storing intermediate results from PE to SRAM exceeds the output buffer')
78 |         else:
79 |             latency = nums * bits / (self.bandwidth * bandwidth_ratio)
80 |             cycle = math.ceil(latency * self.clock_frequency)
81 | 
82 |         return cycle
83 | 
84 |     def preload_weight(self, nums=0, bits=32, bandwidth_ratio=1):
85 |         if nums * bits > self.max_output:
86 |             raise ValueError('loading weight from DRAM to SRAM exceeds the output buffer')
87 |         else:
88 |             latency = nums * bits / (self.bandwidth * bandwidth_ratio)
89 |             cycle = math.ceil(latency * self.clock_frequency)
90 | 
91 |         return cycle
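Every `preload_*`/`store_out` helper above shares the same latency model: `cycles = ceil(nums * bits / (bandwidth * bandwidth_ratio) * clock_frequency)`. A small worked sketch:

```
from hw_sim.SRAM import SRAM

sram = SRAM()

# Moving a 4096 x 4096 tile at 8 bits per element over the full 76.8 GB/s link:
cycles = sram.data_cycle(num=4096 * 4096, bit=8, bandwidth_ratio=1)
# latency ~= 134.2e6 bits / 659.7e9 bits per s ~= 0.20 ms -> ~102k cycles at 500 MHz
print(cycles)
```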
--------------------------------------------------------------------------------
/hw_sim/ViTCoD_comp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | from scipy.sparse import coo_matrix
4 | 
5 | from .PE_array import PE_array
6 | 
7 | 
8 | def sparse_linear_simulate(sparse_weight_mask, num_global_cols, PE_width=64, PE_height=64):
9 |     my_PE = PE_array(PE_width, PE_height)
10 |     inp = np.random.random((sparse_weight_mask.shape[1], 1))
11 | 
12 |     mask = sparse_weight_mask
13 |     global_tokens = int(num_global_cols)
14 |     sparser = coo_matrix(1 - mask[:, global_tokens:])
15 |     sparser = np.column_stack((sparser.row, sparser.col))
16 |     dense_ratio = global_tokens * sparse_weight_mask.shape[0] / (len(sparser) + global_tokens * sparse_weight_mask.shape[0])
17 |     dense_PE_width = int(my_PE.width * dense_ratio)
18 |     sparse_PE_width = my_PE.width - dense_PE_width
19 | 
20 |     # ############## dense pattern weight * inp ##############
21 |     dense_SpMM_PE_cycles = 0
22 |     if dense_PE_width > 0:
23 |         for _ in range(math.ceil((inp.shape[0] * inp.shape[1] * global_tokens) / (dense_PE_width * my_PE.height))):
24 |             dense_SpMM_PE_cycles += 1
25 |     print('Dense SpMM PE calculation | cycles: {}'.format(dense_SpMM_PE_cycles))
26 | 
27 |     # ############## sparse pattern weight * inp ##############
28 |     # accumulation
29 |     num_list = []
30 |     accumulator = 0
31 |     prev_cout_index = 0
32 |     for _cout_index, _cin_index in sparser:
33 |         if _cout_index == prev_cout_index:
34 |             accumulator += 1
35 |         else:
36 |             num_list.append(accumulator)
37 |             accumulator = 1
38 |             prev_cout_index = _cout_index
39 |     num_list.append(accumulator)
40 | 
41 |     # ############## sparse pattern weight * inp ##############
42 |     sparse_SpMM_PE_cycles = 0
43 |     for row_num in num_list:
44 |         sparse_SpMM_PE_cycles += row_num * inp.shape[1]
45 |     if sparse_PE_width > 0:
46 |         sparse_SpMM_PE_cycles = math.ceil(sparse_SpMM_PE_cycles / (sparse_PE_width * my_PE.height))
47 |     print('Sparse SpMM PE calculation | cycles: {}'.format(sparse_SpMM_PE_cycles))
48 | 
49 |     SpMM_PE_cycles = max(sparse_SpMM_PE_cycles, dense_SpMM_PE_cycles)
50 |     print(f"Computation Cycles: {SpMM_PE_cycles}")
51 |     return SpMM_PE_cycles
52 | 
53 | 
54 | def sparse_linear_flops(sparse_weight_mask, num_global_cols):
55 |     inp = np.random.random((sparse_weight_mask.shape[1], 1))
56 | 
57 |     mask = sparse_weight_mask
58 |     global_tokens = int(num_global_cols)
59 |     sparser = coo_matrix(1 - mask[:, global_tokens:])
60 |     sparser = np.column_stack((sparser.row, sparser.col))
61 |     dense_flops = global_tokens * sparse_weight_mask.shape[0]
62 | 
63 |     # ############## sparse pattern weight * inp ##############
64 |     num_list = []
65 |     accumulator = 0
66 |     prev_cout_index = 0
67 |     for _cout_index, _cin_index in sparser:
68 |         if _cout_index == prev_cout_index:
69 |             accumulator += 1
70 |         else:
71 |             num_list.append(accumulator)
72 |             accumulator = 1
73 |             prev_cout_index = _cout_index
74 |     num_list.append(accumulator)
75 | 
76 |     sparse_flops = 0
77 |     for row_num in num_list:
78 |         sparse_flops += row_num * inp.shape[1]
79 | 
80 |     total_flops = dense_flops + sparse_flops
81 |     print(f"Computation FLOPs: {total_flops}")
82 |     return total_flops
83 | 
--------------------------------------------------------------------------------
/hw_sim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/LLMPrune-BESA/4158d7749cc883135bd1872a6e91b10c67008417/hw_sim/__init__.py
--------------------------------------------------------------------------------
/hw_sim/check_funcs.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Process
2 | 
3 | from .utils import get_model, calc_sparsity
4 | from .utils import process_layer_loop, process_patch_layer_loop
5 | from .utils import process_dense_layer, process_patch_dense_layer
6 | 
7 | 
8 | def check_sparsity(model_name, model=None):
9 |     model = get_model(model_name) if model is None else model
10 |     layers = model.model.layers
11 |     q_sparsity, k_sparsity, v_sparsity, o_sparsity, gate_sparsity, up_sparsity, down_sparsity = [], [], [], [], [], [], []
12 |     for layer_index in range(len(layers)):
13 |         layer = layers[layer_index]
14 |         q_sparsity.append(calc_sparsity(layer.self_attn.q_proj.weight.data))
15 |         k_sparsity.append(calc_sparsity(layer.self_attn.k_proj.weight.data))
16 |         v_sparsity.append(calc_sparsity(layer.self_attn.v_proj.weight.data))
17 |         o_sparsity.append(calc_sparsity(layer.self_attn.o_proj.weight.data))
18 |         gate_sparsity.append(calc_sparsity(layer.mlp.gate_proj.weight.data))
19 |         up_sparsity.append(calc_sparsity(layer.mlp.up_proj.weight.data))
20 |         down_sparsity.append(calc_sparsity(layer.mlp.down_proj.weight.data))
21 | 
22 |     sparse_func = {
23 |         "Min": lambda x : 100 * min(x),
24 |         "Max": lambda x : 100 * max(x),
25 |         "Average": lambda x : 100 * sum(x) / len(x)
26 |     }
27 |     for func_name in sparse_func:
28 |         for layer_name in ['q', 'k', 'v', 'o', 'gate', 'up', 'down']:
29 |             print(f"{func_name} sparsity for {layer_name}_proj: {sparse_func[func_name](eval(f'{layer_name}_sparsity'))}")
30 | 
31 | 
32 | def check_proj(model_name, layer_name, threshold_ratio=0.5, model=None):
33 |     model = get_model(model_name) if model is None else model
34 |     layers = model.model.layers
35 |     if layer_name in ['q', 'k', 'v', 'o']:
36 |         block_name = 'self_attn'
37 |     elif layer_name in ['gate',
'up', 'down']: 38 | block_name = 'mlp' 39 | else: 40 | raise ValueError(f"Invalid layer_name: {layer_name}") 41 | proc_func = process_patch_layer_loop if layer_name == 'down' else process_layer_loop 42 | 43 | total_cycles, total_flops = 0, 0 44 | for layer_index in range(len(layers)): 45 | layer = layers[layer_index] 46 | pe_cycles, flops = proc_func(model_name, eval(f"layer.{block_name}.{layer_name}_proj.weight.data"), layer_index, f'{layer_name}_proj', threshold_ratio) 47 | total_cycles += pe_cycles 48 | total_flops += flops 49 | print(f"Avg {layer_name}_proj cycles: {total_cycles / len(layers)}") 50 | print(f"Avg {layer_name}_proj flops: {total_flops / len(layers)}") 51 | 52 | 53 | def check_attn(model_name, threshold_ratio=0.5, model=None): 54 | model = get_model(model_name) if model is None else model 55 | for layer_name in ['q', 'k', 'v', 'o']: 56 | check_proj(model_name, layer_name, threshold_ratio, model) 57 | 58 | 59 | def check_mlp(model_name, threshold_ratio=0.5, model=None): 60 | model = get_model(model_name) if model is None else model 61 | for layer_name in ['gate', 'up', 'down']: 62 | check_proj(model_name, layer_name, threshold_ratio, model) 63 | 64 | 65 | def check_dense(model_name, model=None): 66 | model = get_model(model_name) if model is None else model 67 | layers = model.model.layers 68 | layer = layers[0] 69 | process_dense_layer(layer.self_attn.q_proj.weight.data, 'q_proj') 70 | process_dense_layer(layer.self_attn.k_proj.weight.data, 'k_proj') 71 | process_dense_layer(layer.self_attn.v_proj.weight.data, 'v_proj') 72 | process_dense_layer(layer.self_attn.o_proj.weight.data, 'o_proj') 73 | process_dense_layer(layer.mlp.gate_proj.weight.data, 'gate_proj') 74 | process_dense_layer(layer.mlp.up_proj.weight.data, 'up_proj') 75 | process_dense_layer(layer.mlp.down_proj.weight.data, 'down_proj') 76 | process_patch_dense_layer(layer.mlp.down_proj.weight.data, 'down_proj') 77 | 78 | 79 | def check_model(model_name, threshold_ratio=0.5, model=None): 80 | model = get_model(model_name) if model is None else model 81 | layers = model.model.layers 82 | p_list = [] 83 | 84 | for layer_index in range(len(layers)): 85 | layer = layers[layer_index] 86 | p_list.append(Process(target=process_layer_loop, args=(model_name, layer.self_attn.q_proj.weight.data, layer_index, 'q_proj', threshold_ratio))) 87 | p_list.append(Process(target=process_layer_loop, args=(model_name, layer.self_attn.k_proj.weight.data, layer_index, 'k_proj', threshold_ratio))) 88 | p_list.append(Process(target=process_layer_loop, args=(model_name, layer.self_attn.v_proj.weight.data, layer_index, 'v_proj', threshold_ratio))) 89 | p_list.append(Process(target=process_layer_loop, args=(model_name, layer.self_attn.o_proj.weight.data, layer_index, 'o_proj', threshold_ratio))) 90 | p_list.append(Process(target=process_layer_loop, args=(model_name, layer.mlp.gate_proj.weight.data, layer_index, 'gate_proj', threshold_ratio))) 91 | p_list.append(Process(target=process_layer_loop, args=(model_name, layer.mlp.up_proj.weight.data, layer_index, 'up_proj', threshold_ratio))) 92 | p_list.append(Process(target=process_patch_layer_loop, args=(model_name, layer.mlp.down_proj.weight.data, layer_index, 'down_proj', threshold_ratio))) 93 | 94 | for p in p_list: 95 | p.start() 96 | for p in p_list: 97 | p.join() 98 | -------------------------------------------------------------------------------- /hw_sim/reorder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import dgl 4 
| import torch
5 | 
6 | 
7 | def calc(graph, threshold=90):
8 |     a0 = graph
9 |     u_list = []
10 |     v_list = []
11 |     for i in range(graph.shape[0]):
12 |         for j in range(graph.shape[1]):
13 |             if not a0[i][j]:
14 |                 u_list.append(j)
15 |                 v_list.append(i)
16 |     g = dgl.graph((u_list, v_list))
17 |     g.ndata['in_deg'] = g.in_degrees()
18 | 
19 |     n_node = g.num_nodes()
20 |     n_edge = g.num_edges()
21 |     out_deg = g.out_degrees()
22 |     high_density = out_deg[out_deg > threshold]
23 |     high_density_idx = np.where(out_deg > threshold)[0]
24 | 
25 |     total = len(high_density_idx)
26 |     tmp1 = n_node  # temporary swap label outside the valid node-id range, so it cannot collide with a real node
27 |     orig_a, orig_b = g.edges()
28 | 
29 |     total_dense = 0
30 |     for i in high_density_idx:
31 |         total_dense += torch.sum(orig_a == i)
32 | 
33 |     for i in range(total):
34 |         orig_a[orig_a == i] = tmp1
35 |         orig_b[orig_b == i] = tmp1
36 |         orig_a[orig_a == high_density_idx[i]] = i
37 |         orig_b[orig_b == high_density_idx[i]] = i
38 |         orig_a[orig_a == tmp1] = torch.tensor(high_density_idx[i])
39 |         orig_b[orig_b == tmp1] = torch.tensor(high_density_idx[i])
40 |     dense_cnt = total_dense
41 | 
42 |     new_graph = torch.ones(graph.shape[0], graph.shape[1])
43 |     for i in range(len(orig_a)):
44 |         try:
45 |             new_graph[orig_b[i], orig_a[i]] = 0
46 |         except IndexError:
47 |             pass
48 |     new_graph = new_graph.numpy()
49 |     total_cnt = n_edge
50 | 
51 |     return dense_cnt, total_cnt, new_graph, total
52 | 
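`calc` above reorders a boolean pruning pattern so that "global" columns denser than `threshold` are packed to the front; `process_sparse_layer` in the next file feeds it with `layer_weight == 0`. A minimal sketch of calling it directly (requires `dgl`; the 128 x 128 pattern is illustrative):

```
import numpy as np

from hw_sim.reorder import calc

# True marks a pruned weight, matching the `bool_weight = layer_weight == 0`
# usage in hw_sim/utils.py.
mask = np.random.random((128, 128)) > 0.5

dense_cnt, total_cnt, new_graph, num_global = calc(mask, threshold=0.5 * 128)
# dense_cnt: kept weights inside the global columns; total_cnt: all kept weights;
# new_graph: the reordered pattern (zeros mark kept weights); num_global: the
# number of global columns moved to the front.
print(dense_cnt, total_cnt, num_global)
```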
--------------------------------------------------------------------------------
/hw_sim/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from transformers import AutoModelForCausalLM
6 | 
7 | from .reorder import calc
8 | from .ViTCoD_comp import sparse_linear_simulate, sparse_linear_flops
9 | 
10 | 
11 | def get_model(model_name):
12 |     def skip(*args, **kwargs):
13 |         pass
14 |     nn.init.kaiming_uniform_ = skip
15 |     nn.init.uniform_ = skip
16 |     nn.init.normal_ = skip
17 |     model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True)
18 |     return model
19 | 
20 | 
21 | def calc_sparsity(layer_weight):
22 |     layer_params = layer_weight.numel()
23 |     layer_pruned = layer_params - torch.count_nonzero(layer_weight)
24 |     return layer_pruned / layer_params
25 | 
26 | 
27 | def process_sparse_layer(model_name, layer_weight, layer_index, layer_name, threshold_ratio=0.5):
28 |     print(f"Process layer-{layer_index}-{layer_name}")
29 |     save_file_name = f"{model_name}_masks/{threshold_ratio}_layer_{layer_index}_{layer_name}.pkl"
30 | 
31 |     try:
32 |         layer_info = pickle.load(open(save_file_name, 'rb'))
33 |         cnt_d = layer_info['dense_ratio']
34 |         cnt_e = layer_info['sparse_ratio']
35 |         SpMM_PE_cycles = layer_info['SpMM_PE_cycles']
36 |         print(f"Computation Cycles: {SpMM_PE_cycles}")
37 | 
38 |         if 'SpMM_FLOPs' not in layer_info:
39 |             SpMM_FLOPs = sparse_linear_flops(layer_info['sparse_weight_mask'], layer_info['num_global_cols'])
40 |             layer_info['SpMM_FLOPs'] = SpMM_FLOPs
41 |             pickle.dump(layer_info, open(save_file_name, 'wb'))
42 |         else:
43 |             SpMM_FLOPs = layer_info['SpMM_FLOPs']
44 |             print(f"Computation FLOPs: {SpMM_FLOPs}")
45 |     except Exception:
46 |         bool_weight = layer_weight == 0
47 |         bool_weight = bool_weight.numpy()
48 |         threshold = threshold_ratio * bool_weight.shape[0]
49 |         cnt_d, cnt_e, sparse_weight_mask, num_global_cols = calc(bool_weight, threshold)
50 |         SpMM_PE_cycles = sparse_linear_simulate(sparse_weight_mask, num_global_cols)
51 |         SpMM_FLOPs = sparse_linear_flops(sparse_weight_mask, num_global_cols)
52 |         pickle.dump({
53 |             'sparse_weight_mask': sparse_weight_mask,
54 |             'num_global_cols': num_global_cols,
55 |             'dense_ratio': cnt_d,
56 |             'sparse_ratio': cnt_e,
57 |             'SpMM_PE_cycles': SpMM_PE_cycles,
58 |             'SpMM_FLOPs': SpMM_FLOPs,
59 |             'layer_shape': [layer_weight.shape[0], layer_weight.shape[1]]
60 |         }, open(save_file_name, 'wb'))
61 | 
62 |     return cnt_d, cnt_e, SpMM_PE_cycles, SpMM_FLOPs
63 | 
64 | 
65 | def process_layer_loop(model_name, layer_weight, layer_index, layer_name, threshold_ratio=0.5):
66 |     proj_ratio = threshold_ratio
67 |     try_times = 0
68 |     min_pe_cycles = float('inf')
69 |     min_cycle_flops = float('inf')
70 | 
71 |     while True:
72 |         D, E, pe_cycles, flops = process_sparse_layer(model_name, layer_weight, layer_index, layer_name, proj_ratio)
73 |         if pe_cycles < min_pe_cycles:
74 |             min_pe_cycles = pe_cycles
75 |             min_cycle_flops = flops
76 |         if D/E >= 0.6:
77 |             proj_ratio += 0.1
78 |             try_times += 1
79 |         elif D/E <= 0.4:
80 |             proj_ratio -= 0.1
81 |             try_times += 1
82 |         else:
83 |             break
84 | 
85 |         if try_times == 5:
86 |             break
87 |         else:
88 |             continue
89 |     print(f'layer_{layer_index}_{layer_name}, Dense: {D} ({D/E * 100:.2f}%), Sparse: {E-D} ({(E-D)/E * 100:.2f}%), Total: {E}')
90 | 
91 |     return min_pe_cycles, min_cycle_flops
92 | 
93 | 
94 | def process_patch_layer_loop(model_name, layer_weight, layer_index, layer_name, threshold_ratio=0.5):
95 |     rows, cols = layer_weight.shape[0], layer_weight.shape[1]
96 |     assert cols > rows
97 |     patch_idx, total_pe_cycles, total_flops = 0, 0, 0
98 |     for begin_idx in range(0, cols, rows):
99 |         end_idx = min(cols, begin_idx + rows)
100 |         patch_weight = layer_weight[:, begin_idx:end_idx]
101 |         pe_cycles, flops = process_layer_loop(model_name, patch_weight, layer_index, f"{layer_name}_Patch-{patch_idx}", threshold_ratio)
102 |         patch_idx += 1
103 |         total_pe_cycles += pe_cycles
104 |         total_flops += flops
105 | 
106 |     return total_pe_cycles, total_flops
107 | 
108 | 
109 | def process_patch_dense_layer(layer_weight, layer_name):
110 |     rows, cols = layer_weight.shape[0], layer_weight.shape[1]
111 |     assert cols > rows
112 |     patch_idx, total_pe_cycles, total_flops = 0, 0, 0
113 |     for begin_idx in range(0, cols, rows):
114 |         end_idx = min(cols, begin_idx + rows)
115 |         patch_weight = layer_weight[:, begin_idx:end_idx]
116 |         pe_cycles = sparse_linear_simulate(patch_weight, patch_weight.shape[1])
117 |         flops = sparse_linear_flops(patch_weight, patch_weight.shape[1])
118 |         patch_idx += 1
119 |         total_pe_cycles += pe_cycles
120 |         total_flops += flops
121 |     print(f"{layer_name} dense cycles: {total_pe_cycles}")
122 |     print(f"{layer_name} flops: {total_flops}")
123 | 
124 | 
125 | def process_dense_layer(layer_weight, layer_name):
126 |     SpMM_PE_cycles = sparse_linear_simulate(layer_weight, layer_weight.shape[1])
127 |     SpMM_FLOPs = sparse_linear_flops(layer_weight, layer_weight.shape[1])
128 |     print(f"{layer_name} dense cycles: {SpMM_PE_cycles}")
129 |     print(f"{layer_name} flops: {SpMM_FLOPs}")
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import time
4 | import uuid
5 | import pickle
6 | import argparse
7 | import contextlib
8 | 
9 | import torch
10 | import torch.nn as nn
11 | 
12 | from timm.utils import NativeScaler
13 | 
14 | from optimizers import Prodigy
15 | from models.llama import LLaMA
16 | from models.llama_seq import LLaMA_Seq
17 | from models.sparse_layers import SparseLinear
18 | from utils.data import get_loaders
19 | from utils.tools import slurm_dist_init,
is_master_process, lm_eval_model, find_layers, init_learn_sparsity, finish_learn_sparsity, get_sparsity, get_sparsity_params, eval_ppl, FakeScaler, use_old_forward, use_new_forward, get_lora_params, auto_map_model 20 | 21 | USE_WANDB = False 22 | if USE_WANDB: 23 | import wandb 24 | else: 25 | class Wandb: 26 | def __init__(self): pass 27 | 28 | def login(self): pass 29 | 30 | def init(self, **kwargs): pass 31 | 32 | def finish(self): pass 33 | 34 | def log(self, log_info): 35 | print(f"wandb: {log_info}") 36 | 37 | wandb = Wandb() 38 | 39 | 40 | def get_model(model_name, batch_size=1): 41 | def skip(*args, **kwargs): 42 | pass 43 | nn.init.kaiming_uniform_ = skip 44 | nn.init.uniform_ = skip 45 | nn.init.normal_ = skip 46 | 47 | if 'llama' in model_name.lower(): 48 | if USE_LLaMA_SEQ: 49 | model = LLaMA_Seq(model_name, batch_size=batch_size) 50 | else: 51 | model = LLaMA(model_name, batch_size=batch_size) 52 | else: 53 | raise NotImplementedError(f"Invalid model name: {model_name}") 54 | return model 55 | 56 | 57 | def loss_func(l2_loss, sparsity): 58 | loss = args.l2_alpha * l2_loss + args.sparsity_beta * ((sparsity - args.sparsity) / args.sparsity) ** 2 59 | return loss 60 | 61 | 62 | def wrapped_forward(layer, inps, attention_mask=None, position_ids=None, repeat_size=1): 63 | attention_mask = attention_mask.repeat(repeat_size, 1, 1, 1) if attention_mask is not None else attention_mask 64 | position_ids = position_ids.repeat(repeat_size, 1) if position_ids is not None else position_ids 65 | return layer(inps, attention_mask=attention_mask, position_ids=position_ids) 66 | 67 | 68 | def val_epoch(layer, sparse_layers, attention_mask, position_ids, inps, outs, pruned_outs, dense_outs, refer_dense=False): 69 | refer_outs = dense_outs if refer_dense else outs 70 | with torch.no_grad(): 71 | loss_list, l2_loss_list, dense_l2_loss_list = [], [], [] 72 | sparsity = float(get_sparsity(sparse_layers)) 73 | if args.norm_all: 74 | l2_scaler = torch.norm(refer_outs.type(torch.float32).reshape((-1, refer_outs.shape[-1])).t(), p=2, dim=1) 75 | 76 | for begin_idx in range(0, args.nsamples, args.prune_batch_size): 77 | end_idx = min(args.nsamples, begin_idx + args.prune_batch_size) 78 | with inference_context: 79 | pruned_outs[begin_idx: end_idx,] = wrapped_forward(layer, inps[begin_idx: end_idx,], attention_mask, position_ids, end_idx - begin_idx)[0] 80 | if not args.norm_all: 81 | l2_scaler = torch.norm(refer_outs[begin_idx: end_idx,].type(torch.float32).reshape((-1, refer_outs[begin_idx: end_idx,].shape[-1])).t(), p=2, dim=1).detach() 82 | l2_loss = (((refer_outs[begin_idx: end_idx,] - pruned_outs[begin_idx: end_idx,]) / l2_scaler) ** 2).sum() / pruned_outs[begin_idx: end_idx,].shape[-1] 83 | loss = loss_func(l2_loss, sparsity) 84 | if not args.no_dense_loss: 85 | dense_l2_loss = ((dense_outs[begin_idx: end_idx,] - pruned_outs[begin_idx: end_idx,]) ** 2).sum() / pruned_outs[begin_idx: end_idx,].numel() 86 | dense_l2_loss_list.append(dense_l2_loss.item()) 87 | loss_list.append(float(loss)) 88 | l2_loss_list.append(l2_loss.item()) 89 | val_loss = sum(loss_list) / len(loss_list) 90 | val_l2_loss = sum(l2_loss_list) / len(l2_loss_list) 91 | return sparsity, val_loss, val_l2_loss, dense_l2_loss_list 92 | 93 | 94 | def train_epoch(layer, sparse_layers, attention_mask, position_ids, inps, refer_outs, optimizer, loss_scaler, train_params): 95 | l2_loss_list, loss_list = [], [] 96 | if args.norm_all: 97 | l2_scaler = torch.norm(refer_outs.type(torch.float32).reshape((-1, refer_outs.shape[-1])).t(), p=2, 
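# (l2_scaler is the per-feature L2 norm of the reference outputs; the
#  reconstruction error below is divided by it so that high-magnitude
#  channels do not dominate the layer-wise loss)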
dim=1).detach() 98 | 99 | for begin_idx in range(0, args.nsamples, args.prune_batch_size): 100 | end_idx = min(args.nsamples, begin_idx + args.prune_batch_size) 101 | with inference_context: 102 | pruned_out = wrapped_forward(layer, inps[begin_idx: end_idx,], attention_mask, position_ids, end_idx - begin_idx)[0] 103 | sparsity = get_sparsity(sparse_layers) 104 | if not args.norm_all: 105 | l2_scaler = torch.norm(refer_outs[begin_idx: end_idx,].type(torch.float32).reshape((-1, refer_outs[begin_idx: end_idx,].shape[-1])).t(), p=2, dim=1).detach() 106 | l2_loss = (((refer_outs[begin_idx: end_idx,] - pruned_out) / l2_scaler) ** 2).sum() / refer_outs[begin_idx: end_idx,].shape[-1] 107 | loss = loss_func(l2_loss, sparsity) 108 | loss_list.append(loss.item()) 109 | l2_loss_list.append(l2_loss.item()) 110 | optimizer.zero_grad() 111 | loss_scaler(loss, optimizer, parameters=train_params, clip_grad=args.clip_grad, clip_mode=args.clip_mode) 112 | torch.cuda.empty_cache() 113 | train_loss = sum(loss_list) / len(loss_list) 114 | train_l2_loss = sum(l2_loss_list) / len(l2_loss_list) 115 | 116 | return train_loss, train_l2_loss 117 | 118 | 119 | def grad_prune(layer_index, layer, sparse_layers, attention_mask, position_ids, inps, outs, pruned_outs, dense_outs): 120 | print(f"Grad prune layer {layer_index}") 121 | sparsity_params = get_sparsity_params(sparse_layers) 122 | lora_params = get_lora_params(sparse_layers) 123 | if len(lora_params) > 0: 124 | param_lr = args.prodigy_lr if not args.normal_opt else 1e-3 if args.normal_default else args.normal_opt_lr 125 | compress_params = [ 126 | {'params': sparsity_params, 'lr': param_lr}, 127 | {'params': lora_params, 'lr': param_lr}, 128 | ] 129 | train_params = sparsity_params + lora_params 130 | else: 131 | compress_params = train_params = sparsity_params 132 | loss_scaler = FakeScaler() if args.no_scaler else NativeScaler() 133 | 134 | if args.normal_opt: 135 | if args.normal_default: 136 | optimizer = torch.optim.AdamW(compress_params) 137 | else: 138 | optimizer = torch.optim.AdamW(compress_params, lr=args.normal_opt_lr, weight_decay=0) 139 | else: 140 | optimizer = Prodigy(compress_params, args.prodigy_lr, 141 | weight_decay=args.weight_decay, 142 | decouple=not args.no_decouple, 143 | use_bias_correction=args.use_bias_correction, 144 | safeguard_warmup=args.safeguard_warmup, 145 | d_coef=args.d_coef 146 | ) 147 | 148 | if args.use_cos_sche: 149 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) 150 | 151 | # learn sparsity epochs 152 | refer_outs = dense_outs if args.prune_dense else outs 153 | for epoch in range(args.epochs): 154 | # train epoch 155 | train_loss, train_l2_loss = train_epoch(layer, sparse_layers, attention_mask, position_ids, inps, refer_outs, optimizer, loss_scaler, train_params) 156 | if args.use_cos_sche: 157 | lr_scheduler.step(epoch) 158 | torch.cuda.empty_cache() 159 | 160 | # val epoch 161 | sparsity, val_loss, val_l2_loss, dense_l2_loss_list = val_epoch(layer, sparse_layers, attention_mask, position_ids, inps, outs, pruned_outs, dense_outs, args.prune_dense) 162 | 163 | wandb_log = { 164 | f'layer_{layer_index}-train_loss': train_loss, 165 | f'layer_{layer_index}-train_l2_loss': train_l2_loss, 166 | f'layer_{layer_index}-sparsity': sparsity, 167 | f'layer_{layer_index}-val_loss': val_loss, 168 | f'layer_{layer_index}-val_l2_loss': val_l2_loss, 169 | } 170 | if not args.no_dense_loss: 171 | dense_val_l2_loss = sum(dense_l2_loss_list) / len(dense_l2_loss_list) 172 | 
wandb_log[f'layer_{layer_index}-dense_val_l2_loss'] = dense_val_l2_loss 173 | for layer_name in sparse_layers: 174 | sparse_layer = sparse_layers[layer_name] 175 | wandb_log[f"layer_{layer_index}-{layer_name}_sparsity"] = float(sparse_layer.sparsity) 176 | wandb.log(wandb_log) 177 | 178 | return wandb_log, sparsity 179 | 180 | 181 | def fixed_prune(layer_index, layer, sparse_layers, attention_mask, position_ids, inps, outs, pruned_outs, dense_outs): 182 | print(f"Fixed prune layer {layer_index}") 183 | sparsity, val_loss, val_l2_loss, dense_l2_loss_list = val_epoch(layer, sparse_layers, attention_mask, position_ids, inps, outs, pruned_outs, dense_outs, args.prune_dense) 184 | wandb_log = { 185 | f'layer_{layer_index}-val_loss': val_loss, 186 | f'layer_{layer_index}-val_l2_loss': val_l2_loss, 187 | f'layer_{layer_index}-sparsity': args.sparsity, 188 | } 189 | if not args.no_dense_loss: 190 | dense_val_l2_loss = val_l2_loss if args.prune_dense else sum(dense_l2_loss_list) / len(dense_l2_loss_list) 191 | wandb_log[f'layer_{layer_index}-dense_val_l2_loss'] = dense_val_l2_loss 192 | wandb.log(wandb_log) 193 | 194 | return wandb_log, sparsity 195 | 196 | 197 | def compress_model(model, dataloader): 198 | class Catcher(nn.Module): 199 | def __init__(self, module): 200 | super().__init__() 201 | self.module = module 202 | def forward(self, inp, **kwargs): 203 | inps[cache['i']] = inp 204 | cache['i'] += 1 205 | cache['attention_mask'] = kwargs['attention_mask'] 206 | if 'position_ids' in kwargs: 207 | cache['position_ids'] = kwargs['position_ids'] 208 | raise ValueError 209 | 210 | def add_batch(layer_name): 211 | def tmp(_, inp, out): 212 | sparse_layers[layer_name].add_batch(inp[0].data, out.data) 213 | return tmp 214 | 215 | print('Starting ...') 216 | prune_start = time.time() 217 | dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 218 | use_cache = model.config.use_cache 219 | model.config.use_cache = False 220 | layers = model.model.model.layers 221 | model.model.model.embed_tokens = model.model.model.embed_tokens.to(dev) 222 | 223 | dtype = next(iter(model.model.parameters())).dtype 224 | inps = torch.zeros( 225 | (args.nsamples, model.seqlen, model.config.hidden_size), device=dev, dtype=dtype 226 | ) 227 | cache = {'i': 0, 'attention_mask': None, "position_ids": None} 228 | 229 | layers[0] = layers[0].to(dev) 230 | layers[0] = Catcher(layers[0]) 231 | for i in range(args.nsamples): 232 | try: 233 | batch = dataloader[i] 234 | model.model(batch[0].to(dev)) 235 | except ValueError: 236 | pass 237 | layers[0] = layers[0].module 238 | layers = layers.cpu() 239 | model.model.model.embed_tokens = model.model.model.embed_tokens.cpu() 240 | 241 | position_ids = cache['position_ids'] 242 | attention_mask = cache['attention_mask'] 243 | if args.use_fp32: 244 | inps = inps.float() 245 | attention_mask = attention_mask.float() 246 | dtype = torch.float32 247 | torch.cuda.empty_cache() 248 | 249 | pruned_outs = torch.zeros_like(inps) 250 | if args.prune_dense or (not args.no_dense_loss): 251 | dense_inps = inps.clone() 252 | dense_outs = torch.zeros_like(inps) 253 | else: 254 | dense_outs = None 255 | outs = None if args.prune_dense else torch.zeros_like(inps) 256 | 257 | print('Ready.') 258 | model_prune_log, model_sparsity = [], [] 259 | for i in range(len(layers)): 260 | layer = layers[i].to(dev) 261 | use_old_forward(layer, recurse=True) 262 | if args.use_fp32: 263 | layer = layer.float() 264 | 265 | layer.self_attn.q_proj = SparseLinear(layer.self_attn.q_proj, args.metric_type, 
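# (every linear projection of the block is wrapped in a SparseLinear so that
#  forward hooks can gather the Wanda/SparseGPT calibration statistics and a
#  differentiable sparsity can then be learned for it; the wrappers are
#  removed again once the layer has been pruned)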
args.wise_dim)
266 |         layer.self_attn.k_proj = SparseLinear(layer.self_attn.k_proj, args.metric_type, args.wise_dim)
267 |         layer.self_attn.v_proj = SparseLinear(layer.self_attn.v_proj, args.metric_type, args.wise_dim)
268 |         if 'llama' in args.model.lower():
269 |             layer.self_attn.o_proj = SparseLinear(layer.self_attn.o_proj, args.metric_type, args.wise_dim)
270 |             layer.mlp.gate_proj = SparseLinear(layer.mlp.gate_proj, args.metric_type, args.wise_dim)
271 |             layer.mlp.up_proj = SparseLinear(layer.mlp.up_proj, args.metric_type, args.wise_dim)
272 |             layer.mlp.down_proj = SparseLinear(layer.mlp.down_proj, args.metric_type, args.wise_dim)
273 |         elif 'opt' in args.model.lower():
274 |             layer.self_attn.out_proj = SparseLinear(layer.self_attn.out_proj, args.metric_type, args.wise_dim)
275 |             layer.fc1 = SparseLinear(layer.fc1, args.metric_type, args.wise_dim)
276 |             layer.fc2 = SparseLinear(layer.fc2, args.metric_type, args.wise_dim)
277 | 
278 |         handles = []
279 |         sparse_layers = find_layers(layer, layers=[SparseLinear])
280 |         for layer_name in sparse_layers:
281 |             sparse_layer = sparse_layers[layer_name]
282 |             handles.append(sparse_layer.register_forward_hook(add_batch(layer_name)))
283 |         with inference_context:
284 |             refer_outs = pruned_outs if outs is None else outs
285 |             for begin_idx in range(0, args.nsamples, args.prune_batch_size):
286 |                 end_idx = min(args.nsamples, begin_idx + args.prune_batch_size)
287 |                 refer_outs[begin_idx: end_idx,] = wrapped_forward(layer, inps[begin_idx: end_idx,], attention_mask, position_ids, end_idx - begin_idx)[0]
288 |             torch.cuda.empty_cache()
289 |         for h in handles:
290 |             h.remove()
291 | 
292 |         if args.prune_dense or (not args.no_dense_loss):
293 |             with inference_context:
294 |                 for begin_idx in range(0, args.nsamples, args.prune_batch_size):
295 |                     end_idx = min(args.nsamples, begin_idx + args.prune_batch_size)
296 |                     dense_outs[begin_idx: end_idx,] = wrapped_forward(layer, dense_inps[begin_idx: end_idx,], attention_mask, position_ids, end_idx - begin_idx)[0]
297 |                 torch.cuda.empty_cache()
298 | 
299 |         prune_func = grad_prune
300 |         if args.fix_layers:
301 |             fix_layers = list(sparse_layers.keys()) if args.fix_layers == 'all' else args.fix_layers.split(',')
302 |             prune_func = fixed_prune if args.fix_layers == 'all' else grad_prune
303 |             for layer_name in fix_layers:
304 |                 sparse_layers[layer_name].sparsity = args.sparsity
305 | 
306 |         torch.set_grad_enabled(True)
307 |         init_learn_sparsity(sparse_layers, args.sparsity_step, blocksize=args.blocksize, sigmoid_smooth=not args.no_sigmoid_smooth, lora_rank=args.lora_rank)
308 |         layer_prune_log, layer_sparsity = prune_func(i, layer, sparse_layers, attention_mask, position_ids, inps, outs, pruned_outs, dense_outs)
309 |         torch.set_grad_enabled(False)
310 |         finish_learn_sparsity(sparse_layers)
311 |         model_prune_log.append(layer_prune_log)
312 |         model_sparsity.append(layer_sparsity)
313 | 
314 |         # unwrap every SparseLinear back to its underlying pruned nn.Linear;
315 |         # this covers both the LLaMA and OPT projection names wrapped above
316 |         for module_name, module in list(layer.named_modules()):
317 |             if isinstance(module, SparseLinear):
318 |                 parent_name, _, child_name = module_name.rpartition('.')
319 |                 parent = layer.get_submodule(parent_name) if parent_name else layer
320 |                 setattr(parent, child_name, module.layer)
321 | 
322 |         layer = layer.cpu().to(dtype=dtype)
323 |         use_new_forward(layer, recurse=True)
324 |         layers[i] = layer
325 |         del layer
326 |         del sparse_layers
327 |         gc.collect()
328 |         torch.cuda.empty_cache()
329 |         inps, pruned_outs = pruned_outs, inps
330 |         if
args.prune_dense or (not args.no_dense_loss): 331 | dense_inps, dense_outs = dense_outs, dense_inps 332 | 333 | model.config.use_cache = use_cache 334 | prune_time_cost = time.time() - prune_start 335 | print(f'Prune time cost: {prune_time_cost:.3f} seconds') 336 | model_sparsity = sum(model_sparsity) / len(model_sparsity) 337 | print(f"Model sparsity: {model_sparsity:.2f}") 338 | 339 | return model_prune_log 340 | 341 | 342 | def get_args(): 343 | parser = argparse.ArgumentParser() 344 | parser.add_argument( 345 | '--model', type=str, default='/mnt/lustre/share_data/xupeng/llama-7b-hf', 346 | help='model to load.' 347 | ) 348 | parser.add_argument( 349 | '--test-datasets', type=str, default='piqa,boolq,hellaswag,winogrande,arc_easy,arc_challenge', 350 | help='Evaluate model on test datasets' 351 | ) 352 | parser.add_argument( 353 | '--eval-dense', action='store_true', 354 | help='Whether to evaluate the dense model' 355 | ) 356 | parser.add_argument( 357 | '--batch-size', type=int, default=1, 358 | help='batch size of model evaluation' 359 | ) 360 | parser.add_argument( 361 | '--seed', type=int, default=0, 362 | help='Seed for sampling the calibration data.' 363 | ) 364 | parser.add_argument( 365 | '--port', type=int, default=1999, 366 | help='Port to init torch distributed.' 367 | ) 368 | parser.add_argument( 369 | '--nsamples', type=int, default=128, 370 | help='Number of calibration data samples.' 371 | ) 372 | parser.add_argument( 373 | '--sparsity', type=float, default=0.5, 374 | help='Target sparsity' 375 | ) 376 | 377 | parser.add_argument('--save-path', type=str, default=None) 378 | parser.add_argument('--exp-name', type=str, default='exp_0') 379 | parser.add_argument('--fix-layers', type=str, default=None) 380 | parser.add_argument('--no-dense-loss', action='store_true') 381 | parser.add_argument('--epochs', type=int, default=1) 382 | parser.add_argument('--prune-batch-size', type=int, default=1) 383 | parser.add_argument('--use-fp32', action='store_true') 384 | parser.add_argument('--metric-type', type=str, default='Wanda') 385 | parser.add_argument('--wise-dim', type=str, default='row') 386 | # Learning parameter settings 387 | parser.add_argument('--blocksize', type=int, default=-1) 388 | parser.add_argument('--sparsity-step', type=float, default=0.01) 389 | parser.add_argument('--lora-rank', type=int, default=-1) 390 | # Loss settings 391 | parser.add_argument('--norm-all', action='store_true') 392 | parser.add_argument('--prune-dense', action='store_true') 393 | parser.add_argument('--l2-alpha', type=float, default=1) 394 | parser.add_argument('--sparsity-beta', type=float, default=1) 395 | parser.add_argument('--no-sigmoid-smooth', action='store_true') 396 | # Scaler (norm, value) and Scheduler 397 | parser.add_argument('--clip-grad', type=float) 398 | parser.add_argument('--clip-mode', type=str, default='norm') 399 | parser.add_argument('--no-scaler', action='store_true') 400 | parser.add_argument('--use-cos-sche', action='store_true') 401 | # Normal Opt settings (AdamW) 402 | parser.add_argument('--normal-opt', action='store_true') 403 | parser.add_argument('--normal-opt-lr', type=float, default=1e-2) 404 | parser.add_argument('--normal-default', action='store_true') 405 | # Prodigy settings 406 | parser.add_argument('--prodigy-lr', type=float, default=1) 407 | parser.add_argument('--no-decouple', action='store_true') 408 | parser.add_argument('--use-bias-correction', action='store_true') 409 | parser.add_argument('--safeguard-warmup', action='store_true') 410 | 
parser.add_argument('--weight-decay', type=float, default=0) 411 | parser.add_argument('--d-coef', type=float, default=1) 412 | 413 | args = parser.parse_args() 414 | 415 | return args 416 | 417 | 418 | def main(args): 419 | print('Getting model ...') 420 | model = get_model(args.model, args.batch_size) 421 | 422 | if args.sparsity: 423 | print('Loading dataset ...') 424 | dataloader, c4_testenc = get_loaders( 425 | "c4", nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 426 | ) 427 | _, wikitext_testenc = get_loaders( 428 | "wikitext2", nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 429 | ) 430 | _, ptb_testenc = get_loaders( 431 | "ptb", nsamples=args.nsamples, seed=args.seed, model=args.model, seqlen=model.seqlen 432 | ) 433 | 434 | ppl_test_sets = ['c4', 'wikitext', 'ptb'] 435 | gc.collect() 436 | torch.cuda.empty_cache() 437 | 438 | if args.eval_dense: 439 | result = lm_eval_model(model, args) 440 | print(f"Dense model zero-shot evaluation result: {result}") 441 | for set_name in ppl_test_sets: 442 | ppl = eval_ppl(model, eval(f"{set_name}_testenc"), args.batch_size) 443 | print(f"Dense model {set_name} ppl: {ppl}") 444 | 445 | wandb.login() 446 | wandb.init( 447 | project="LLaMA", 448 | name=args.exp_name, 449 | config={ 450 | "model": args.model, 451 | "sparsity-step": args.sparsity_step, 452 | "epochs": args.epochs, 453 | "prune-batch-size": args.prune_batch_size, 454 | 'l2-alpha': args.l2_alpha, 455 | 'sparsity-beta': args.sparsity_beta, 456 | 'fix-layers': args.fix_layers, 457 | 'prune-dense': args.prune_dense, 458 | 'dense-loss': not args.no_dense_loss 459 | }) 460 | model_prune_log = compress_model(model, dataloader) 461 | wandb.finish() 462 | 463 | del dataloader 464 | torch.cuda.empty_cache() 465 | 466 | if not USE_LLaMA_SEQ: 467 | auto_map_model(model) 468 | 469 | if args.save_path: 470 | model.model.save_pretrained(args.save_path) 471 | model.tokenizer.save_pretrained(args.save_path) 472 | 473 | eval_result = lm_eval_model(model, args) 474 | print(f"Evaluation result: {eval_result}") 475 | c4_ppl = eval_ppl(model, c4_testenc, args.batch_size) 476 | ptb_ppl = eval_ppl(model, ptb_testenc, args.batch_size) 477 | wikitext_ppl = eval_ppl(model, wikitext_testenc, args.batch_size) 478 | for set_name in ppl_test_sets: 479 | print(f"{set_name} ppl: {eval(f'{set_name}_ppl')}") 480 | 481 | exp_log = os.path.join('exp_logs', f"{args.model.split('/')[-1]}-{args.exp_name}-{str(uuid.uuid4())}.pkl") 482 | while os.path.exists(exp_log): 483 | exp_log = os.path.join('exp_logs', f"{args.model.split('/')[-1]}-{args.exp_name}-{str(uuid.uuid4())}.pkl") 484 | with open(exp_log, 'wb') as f: 485 | pickle.dump({ 486 | 'args': args, 487 | 'c4_ppl': c4_ppl, 488 | 'ptb_ppl': ptb_ppl, 489 | 'wikitext_ppl': wikitext_ppl, 490 | 'eval_result': eval_result, 491 | 'model_prune_log': model_prune_log, 492 | }, f) 493 | 494 | 495 | if __name__ == "__main__": 496 | args = get_args() 497 | if torch.cuda.device_count() > 1: 498 | slurm_dist_init(args.seed, args.port) 499 | USE_LLaMA_SEQ = torch.cuda.device_count() == 1 500 | inference_context = contextlib.nullcontext() if args.use_fp32 else torch.cuda.amp.autocast() 501 | if is_master_process(): 502 | print(args) 503 | main(args) 504 | -------------------------------------------------------------------------------- /main_exps.sh: -------------------------------------------------------------------------------- 1 | # row-wise sparsity 2 | python main.py --exp-name LLaMA-7B-r1-e1-df5e0-beta5e0 --epochs 1 
--d-coef 5e0 --sparsity-beta 5e0 --blocksize 1 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama-7b-hf 3 | python main.py --exp-name LLaMA-13B-r1-e1-df5e0-beta5e0 --epochs 1 --d-coef 5e0 --sparsity-beta 5e0 --blocksize 1 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama-13b-hf 4 | python main.py --exp-name LLaMA-30B-r1-e1-df1e1-beta5e0 --epochs 1 --d-coef 1e1 --sparsity-beta 5e0 --blocksize 1 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama-30b-hf 5 | python main.py --exp-name LLaMA-65B-r1-e1-df5e1-beta5e0 --epochs 1 --d-coef 5e1 --sparsity-beta 5e0 --blocksize 1 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama-65b-hf 6 | python main.py --exp-name LLaMA2-7B-r1-e1-df5e0-beta5e0 --epochs 1 --d-coef 5e0 --sparsity-beta 5e0 --blocksize 1 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama2-7b-hf 7 | python main.py --exp-name LLaMA2-13B-r1-e1-df5e0-beta5e0 --epochs 1 --d-coef 5e0 --sparsity-beta 5e0 --blocksize 1 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama2-13b-hf 8 | python main.py --exp-name LLaMA2-70B-r1-e1-df5e1-beta5e0 --epochs 1 --d-coef 5e1 --sparsity-beta 5e0 --blocksize 1 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama2-70b-hf 9 | 10 | # layer-wise sparsity 11 | python main.py --exp-name Dense-LLaMA-7B-e1-df5e1-beta5e0 --prune-dense --epochs 1 --d-coef 5e1 --sparsity-beta 5e0 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama-7b-hf 12 | python main.py --exp-name Dense-LLaMA-13B-e1-df5e-1-beta5e0 --prune-dense --epochs 1 --d-coef 5e-1 --sparsity-beta 5e0 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama-13b-hf 13 | python main.py --exp-name Dense-LLaMA-30B-e1-df5e-2-beta5e0 --prune-dense --epochs 1 --d-coef 5e-2 --sparsity-beta 5e0 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama-30b-hf 14 | python main.py --exp-name Dense-LLaMA-65B-e1-df5e-1-beta5e0 --prune-dense --epochs 1 --d-coef 5e-1 --sparsity-beta 5e0 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama-65b-hf 15 | python main.py --exp-name Dense-LLaMA2-7B-e1-df5e-2-beta5e0 --prune-dense --epochs 1 --d-coef 5e-2 --sparsity-beta 5e0 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama2-7b-hf 16 | python main.py --exp-name Dense-LLaMA2-13B-e1-df5e-2-beta5e0 --prune-dense --epochs 1 --d-coef 5e-2 --sparsity-beta 5e0 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama2-13b-hf 17 | python main.py --exp-name Dense-LLaMA2-70B-e1-df5e-1-beta5e0 --prune-dense --epochs 1 --d-coef 5e-1 --sparsity-beta 5e0 --batch-size 16 --model /mnt/lustre/share_data/xupeng/llama2-70b-hf -------------------------------------------------------------------------------- /main_hw.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from hw_sim.check_funcs import check_sparsity, check_dense, check_proj, check_attn, check_mlp, check_model 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('-m', '--model-name', type=str, default='llama-7b-0.5') 8 | parser.add_argument('-f', '--func', type=str, default='model') 9 | parser.add_argument('-t', '--threshold', type=float, default=0.5) 10 | args = parser.parse_args() 11 | 12 | model_name = args.model_name 13 | func_name = args.func 14 | if func_name in ['q', 'k', 'v', 'o', 'gate', 'up', 'down']: 15 | check_proj(model_name, func_name, args.threshold) 16 | elif func_name in ['attn', 'mlp', 'model']: 17 | eval(f"check_{func_name}")(model_name, args.threshold) 18 | elif func_name in 
['dense', 'sparsity']: 19 | eval(f"check_{func_name}")(model_name) 20 | else: 21 | raise ValueError(f"Invalid func name: {func_name}") 22 | -------------------------------------------------------------------------------- /models/llama.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional, Union 3 | from lm_eval.base import BaseLM, CachingLM 4 | from transformers import LlamaForCausalLM, LlamaTokenizer, BatchEncoding 5 | 6 | TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding] 7 | 8 | 9 | class LLaMA(BaseLM): 10 | def __init__(self, model_name, batch_size=1, device='cuda') -> None: 11 | self._batch_size = self.max_batch_size = batch_size 12 | self.seqlen = self._max_length = self._max_gen_toks = 2048 13 | self.add_special_tokens = False 14 | 15 | self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto") 16 | self.tokenizer = LlamaTokenizer.from_pretrained(model_name, use_fast=False) 17 | self.model.eval() 18 | 19 | self._device = device 20 | self.config = self.model.config 21 | CachingLM(self, '__lmcache__') 22 | torch.set_grad_enabled(False) 23 | 24 | @property 25 | def eot_token_id(self) -> int: 26 | return self.tokenizer.eos_token_id 27 | 28 | @property 29 | def max_length(self) -> int: 30 | return self._max_length 31 | 32 | @property 33 | def max_gen_toks(self): 34 | return self._max_gen_toks 35 | 36 | @property 37 | def batch_size(self) -> int: 38 | return self._batch_size 39 | 40 | @property 41 | def device(self) -> Union[int, str, torch.device]: 42 | return self._device 43 | 44 | def tok_encode(self, string: str) -> TokenSequence: 45 | return self.tokenizer.encode(string, add_special_tokens=self.add_special_tokens) 46 | 47 | def tok_decode(self, tokens: torch.LongTensor) -> List[str]: 48 | return self.tokenizer.batch_decode(tokens, skip_special_tokens=True) 49 | 50 | def _model_generate(self, context, max_length, eos_token_id): 51 | return None 52 | 53 | def _model_call( 54 | self, inputs: TokenSequence, labels: Optional[TokenSequence] = None 55 | ) -> TokenSequence: 56 | return self.model(inputs)["logits"] 57 | -------------------------------------------------------------------------------- /models/llama_seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional, Union 3 | from lm_eval.base import BaseLM, CachingLM 4 | from transformers import LlamaForCausalLM, LlamaTokenizer, BatchEncoding 5 | 6 | TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding] 7 | 8 | 9 | class LLaMA_Seq(BaseLM): 10 | def __init__(self, model_name, batch_size=1, device='cuda') -> None: 11 | self._batch_size = self.max_batch_size = batch_size 12 | self.seqlen = self._max_length = self._max_gen_toks = 2048 13 | self.add_special_tokens = False 14 | 15 | self.model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="cpu") 16 | self.tokenizer = LlamaTokenizer.from_pretrained(model_name, use_fast=False) 17 | self.model.eval() 18 | 19 | self._device = device 20 | self.config = self.model.config 21 | CachingLM(self, '__lmcache__') 22 | torch.set_grad_enabled(False) 23 | 24 | @property 25 | def eot_token_id(self) -> int: 26 | return self.tokenizer.eos_token_id 27 | 28 | @property 29 | def max_length(self) -> int: 30 | return self._max_length 31 | 32 | @property 33 | def 
max_gen_toks(self):
34 |         return self._max_gen_toks
35 | 
36 |     @property
37 |     def batch_size(self) -> int:
38 |         return self._batch_size
39 | 
40 |     @property
41 |     def device(self) -> Union[int, str, torch.device]:
42 |         return self._device
43 | 
44 |     def tok_encode(self, string: str) -> TokenSequence:
45 |         return self.tokenizer.encode(string, add_special_tokens=self.add_special_tokens)
46 | 
47 |     def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
48 |         return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
49 | 
50 |     def _model_generate(self, context, max_length, eos_token_id):
51 |         return None
52 | 
53 |     def _model_call(
54 |         self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
55 |     ) -> TokenSequence:
56 |         self.model.model.embed_tokens = self.model.model.embed_tokens.to(self.device)
57 |         layers = self.model.model.layers
58 |         layers[0] = layers[0].to(self.device)
59 | 
60 |         cache = {'hidden_states': None, 'attention_mask': None, 'position_ids': None}
61 |         class Catcher(torch.nn.Module):
62 |             def __init__(self, module):
63 |                 super().__init__()
64 |                 self.module = module
65 | 
66 |             def forward(self, inp, **kwargs):
67 |                 cache['hidden_states'] = inp
68 |                 cache['attention_mask'] = kwargs['attention_mask']
69 |                 cache['position_ids'] = kwargs['position_ids']
70 |                 raise ValueError
71 | 
72 |         layers[0] = Catcher(layers[0])
73 |         try:
74 |             self.model(inputs)
75 |         except ValueError:
76 |             pass
77 |         layers[0] = layers[0].module
78 |         layers[0] = layers[0].cpu()
79 |         self.model.model.embed_tokens = self.model.model.embed_tokens.cpu()
80 |         torch.cuda.empty_cache()
81 | 
82 |         inps = cache.pop('hidden_states')
83 |         outs = torch.zeros_like(inps)
84 |         for i in range(len(layers)):
85 |             layer = layers[i].to(self.device)
86 |             outs = layer(inps, **cache)[0]
87 |             layers[i] = layer.cpu()
88 |             del layer
89 |             torch.cuda.empty_cache()
90 |             inps, outs = outs, inps
91 | 
92 |         self.model.model.norm = self.model.model.norm.to(self.device)
93 |         self.model.lm_head = self.model.lm_head.to(self.device)
94 |         inps = self.model.model.norm(inps)
95 |         lm_logits = self.model.lm_head(inps)
96 |         self.model.model.norm = self.model.model.norm.cpu()
97 |         self.model.lm_head = self.model.lm_head.cpu()
98 |         torch.cuda.empty_cache()
99 | 
100 |         return lm_logits
--------------------------------------------------------------------------------
/models/ops/mask_gen/functions/mask_gen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Function
4 | import mask_gen_cuda
5 | 
6 | 
7 | class MaskGenFunction(Function):
8 |     @staticmethod
9 |     def forward(ctx, sort_index, mask_shape, top_k):
10 |         mask = torch.ones(mask_shape, device=sort_index.device)
11 |         x = mask_gen_cuda.mask_gen_forward(mask, sort_index, top_k)
12 |         return x[0]
13 | 
14 |     @staticmethod
15 |     def backward(ctx, grad_output):
16 |         return None, None, None
17 | 
18 | 
19 | mask_gen_op = MaskGenFunction.apply
20 | 
21 | 
22 | class MaskGen(nn.Module):
23 |     def __init__(self):
24 |         super().__init__()
25 | 
26 |     def forward(self, sort_index, mask_shape, top_k):
27 |         mask = mask_gen_op(sort_index, mask_shape, top_k)
28 |         return mask
29 | 
30 | 
31 | mask_gen = MaskGen()
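The CUDA sources below implement the same scatter as the pure-PyTorch reference in `models/ops/test_mask_gen.py`: given row-sorted indices and a per-row prune count, they zero the first `top_k[i]` sorted positions of row `i`. A minimal sketch of the call contract (the CUDA call is commented out since it needs the compiled `mask_gen_cuda` extension):

```
import torch

rows, cols = 4, 6
scores = torch.rand(rows, cols)
# Indices sorted by descending pruning probability, as in test_mask_gen.py.
sort_index = torch.argsort(scores, dim=1, descending=True)
top_k = torch.tensor([1, 2, 3, 0])  # number of weights to prune per row

# from models.ops.mask_gen.functions.mask_gen import mask_gen
# mask = mask_gen(sort_index.cuda(), (rows, cols), top_k.cuda())

# Reference semantics:
mask = torch.ones(rows, cols)
for i in range(rows):
    mask[i, sort_index[i, :top_k[i]]] = 0
```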
--------------------------------------------------------------------------------
/models/ops/mask_gen/src/mask_gen.cpp:
--------------------------------------------------------------------------------
1 | #include "mask_gen.h"
2 | #include <torch/extension.h>
3 | #include <ATen/ATen.h>
4 | #include <vector>
5 | #include <iostream>
6 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
7 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
8 | #define CHECK_INPUT(x) \
9 |     CHECK_CUDA(x);     \
10 |     CHECK_CONTIGUOUS(x)
11 | 
12 | // == Forward
13 | std::vector<torch::Tensor> mask_cuda_forward(torch::Tensor mask,         // parameter: K*group_num, C
14 |                                              torch::Tensor sorted_index, // tensor : B, N, C
15 |                                              torch::Tensor top_k)        // tensor: B, N, K
16 | {
17 |     CHECK_INPUT(mask);
18 |     CHECK_INPUT(sorted_index);
19 |     CHECK_INPUT(top_k);
20 | 
21 |     return MaskData_ongpu(mask, sorted_index, top_k);
22 | }
23 | 
24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
25 | {
26 |     m.def("mask_gen_forward", &mask_cuda_forward, "score forward (CUDA)");
27 | }
--------------------------------------------------------------------------------
/models/ops/mask_gen/src/mask_gen.h:
--------------------------------------------------------------------------------
1 | #ifndef _Score_CUDA
2 | #define _Score_CUDA
3 | #include <torch/extension.h>
4 | #include <vector>
5 | 
6 | std::vector<torch::Tensor> mask_cuda_forward(torch::Tensor mask,         // query t: N, H, W, C1
7 |                                              torch::Tensor sorted_index, // scene : N, H, W, C1
8 |                                              torch::Tensor top_k);       // scene : N, H, W, C1
9 | 
10 | std::vector<torch::Tensor> MaskData_ongpu(at::Tensor mask,         // query t: N, H, W, C1
11 |                                           at::Tensor sorted_index, // scene : N, H, W, C1
12 |                                           at::Tensor top_k);       // scene : N, H, W, C1
13 | 
14 | #endif
--------------------------------------------------------------------------------
/models/ops/mask_gen/src/mask_gen_kernal.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/ATen.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 | #include <cuda.h>
5 | #include <cuda_runtime.h>
6 | #include "mask_gen.h"
7 | #include <vector>
8 | 
9 | #define ROUND_OFF 50000
10 | 
11 | #define CUDA_NUM_THREADS 1024
12 | #define WARPS_PER_BLOCK 1
13 | #define THREADS_PER_WARP 32
14 | #define MAX_H 8
15 | 
16 | #define CUDA_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
17 | 
18 | #define GET_BLOCKS(n, t) (n+t-1) / t
19 | 
20 | 
21 | template <typename scalar_t>
22 | __global__ void MaskData_kernal(
23 |     torch::PackedTensorAccessor32<scalar_t, 2, torch::RestrictPtrTraits> mask,  // B, N1, 4, H, dim
24 |     torch::PackedTensorAccessor32<int64_t, 2, torch::RestrictPtrTraits> index,  // B, N1, K*4, H
25 |     torch::PackedTensorAccessor32<int64_t, 1, torch::RestrictPtrTraits> top_k   // B, N1, 4, K*4, H
26 |     ){
27 | 
28 |     int n = blockIdx.x;
29 |     int g = blockIdx.y;
30 | 
31 |     if(g < top_k[n]){
32 |         // zero the g-th highest-probability entry of row n (the MaskGenPytorch reference semantics)
33 |         mask[n][index[n][g]] = 0;
34 |     }
35 | }
36 | 
37 | 
38 | std::vector<torch::Tensor> MaskData_ongpu(torch::Tensor mask, // B, N1, 4, H, dim
39 |                                           torch::Tensor sorted_index, // B, N2, H, dim
40 |                                           torch::Tensor top_k) // B, N1, K, 4, H
41 | {
42 | 
43 |     const auto N = mask.size(0);
44 |     const auto G = mask.size(1);
45 | 
46 | 
47 |     //auto mask = torch::zeros({B, N1, 4, K, H},torch::device(torch::kCUDA));
48 | 
49 |     int shared_memory_per_block = 0;
50 | 
51 |     dim3 totalBlocks(N, G, 1);
52 |     dim3 threadsPerBlock(THREADS_PER_WARP);
53 |     AT_DISPATCH_FLOATING_TYPES(mask.type(), "MaskData_kernal", ([&] {
54 |         MaskData_kernal<scalar_t><<<totalBlocks, threadsPerBlock, shared_memory_per_block>>>(
55 |             mask.packed_accessor32<scalar_t, 2, torch::RestrictPtrTraits>(),
56 |             sorted_index.packed_accessor32<int64_t, 2, torch::RestrictPtrTraits>(),
57 |             top_k.packed_accessor32<int64_t, 1, torch::RestrictPtrTraits>());
58 |     }));
59 |     return {mask};
60 | 
61 | }
62 | 
--------------------------------------------------------------------------------
/models/ops/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
4 | 
5 | this_file = os.path.dirname(__file__)
6 | 
7 | setup(
8 |     name="sparse_mask",
9 |     ext_modules=[
10 |         CUDAExtension(
11 |             "mask_gen_cuda",
12 |             [
13 |                 "mask_gen/src/mask_gen.cpp",
14 |
"mask_gen/src/mask_gen_kernal.cu", 15 | ], 16 | extra_compile_args={"cxx": ["-g"], "nvcc": ["-O2"]}, 17 | ) 18 | ], 19 | cmdclass={"build_ext": BuildExtension}, 20 | ) 21 | -------------------------------------------------------------------------------- /models/ops/test_mask_gen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from mask_gen.functions.mask_gen import mask_gen 5 | 6 | 7 | class MaskGenPytorch(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | def forward(self, sort_index, mask_shape, top_k): 11 | mask=torch.ones(mask_shape, device=sort_index.device) 12 | 13 | for i in range(mask_shape[0]): 14 | k=top_k[i] 15 | mask[i,sort_index[i,:k]]=0 16 | return mask 17 | 18 | 19 | if __name__ == "__main__": 20 | import time 21 | rows, columns = 4096, 4096 22 | shape = (rows, columns) 23 | mask_gen_pytorch=MaskGenPytorch() 24 | 25 | row_blocksize = 8 26 | row_number = rows // row_blocksize 27 | row_sparsities = torch.rand(row_number).cuda() 28 | row_block_prune_num = (row_sparsities * columns).to(dtype=torch.long) 29 | row_prune_num = row_block_prune_num.reshape(-1, 1).repeat(1, row_blocksize).reshape(-1) 30 | 31 | mask_prob=torch.rand(shape).cuda() 32 | sort_index=torch.argsort(mask_prob, dim=1, descending=True) 33 | t=time.time() 34 | mask=mask_gen(sort_index, shape, row_prune_num) 35 | t2=time.time() 36 | mask2=mask_gen_pytorch(sort_index, shape, row_prune_num) 37 | t3=time.time() 38 | assert (mask==mask2).all() 39 | print('time for cuda op: ', t2-t) 40 | print('time for pytorch op: ', t3-t2) 41 | -------------------------------------------------------------------------------- /models/sparse_layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import mask_gen_cuda 6 | 7 | 8 | class SparseLinear(nn.Module): 9 | def __init__(self, layer, metric_type='Wanda', wise_dim='row') -> None: 10 | super().__init__() 11 | self.layer = layer 12 | self.linear_func = nn.functional.linear 13 | self.register_buffer('weight', layer.weight) 14 | if layer.bias is not None: 15 | self.register_buffer('bias', layer.bias) 16 | else: 17 | self.bias = None 18 | self.param_num = self.weight.numel() 19 | 20 | self.nsamples = 0 21 | self.use_lora = False 22 | self.learn_sparsity = False 23 | self.rows = self.weight.data.shape[0] 24 | self.columns = self.weight.data.shape[1] 25 | self.device = self.weight.device 26 | 27 | self.wise_dim = wise_dim 28 | assert self.wise_dim in ['row', 'column'], f"Invalid wise dim: {wise_dim}" 29 | 30 | self.metric_type = metric_type 31 | if metric_type == 'Wanda': 32 | self.scaler_row = torch.zeros((self.columns), device=self.device) 33 | elif metric_type == 'SparseGPT' or metric_type == 'SparseGPT-Git': 34 | self.Hessian = torch.zeros((self.columns, self.columns), device=self.device) 35 | elif metric_type == 'Weight': 36 | pass 37 | else: 38 | raise NotImplementedError(f"Invalid metric type: {metric_type}") 39 | 40 | def add_batch(self, inp, out): 41 | if self.metric_type == 'Weight': 42 | return 43 | 44 | if len(inp.shape) == 2: 45 | inp = inp.unsqueeze(0) 46 | tmp = inp.shape[0] 47 | if len(inp.shape) == 3: 48 | inp = inp.reshape((-1, inp.shape[-1])) 49 | inp = inp.t() 50 | 51 | if self.metric_type == 'Wanda': 52 | self.scaler_row *= self.nsamples / (self.nsamples+tmp) 53 | elif self.metric_type == 'SparseGPT' or self.metric_type == 'SparseGPT-Git': 54 | self.Hessian *= 
self.nsamples / (self.nsamples + tmp) 55 | else: 56 | raise NotImplementedError(f"Invalid metric type: {self.metric_type}") 57 | 58 | self.nsamples += tmp 59 |
60 | if self.metric_type == 'Wanda': 61 | self.scaler_row += torch.norm(inp.float(), p=2, dim=1) ** 2 / self.nsamples 62 | elif self.metric_type == 'SparseGPT' or self.metric_type == 'SparseGPT-Git': 63 | inp = math.sqrt(2 / self.nsamples) * inp.float() 64 | self.Hessian += inp.matmul(inp.t()) 65 | else: 66 | raise NotImplementedError(f"Invalid metric type: {self.metric_type}") 67 |
68 | def get_w_metric(self): 69 | # NOTE: in lora pruning, the importance metric will be: lora_B @ grad(lora_A) + grad(lora_B) @ lora_A - grad(lora_B) @ grad(lora_A) 70 | if self.metric_type == 'Weight': 71 | self.W_metric = torch.abs(self.weight) 72 | elif self.metric_type == 'Wanda': 73 | self.W_metric = torch.abs(self.weight) * torch.sqrt(self.scaler_row.reshape((1,-1))) 74 | self.scaler_row = None
75 | elif self.metric_type == 'SparseGPT' or self.metric_type == 'SparseGPT-Git': 76 | percdamp = 0.01 # Percent of the average Hessian diagonal used for dampening 77 | dead = torch.diag(self.Hessian) == 0 78 | self.Hessian[dead, dead] = 1 79 | # NOTE: SparseGPT zeroes the weight in columns whose Hessian diagonal is zero 80 | self.weight.data[:, dead] = 0
81 | damp = percdamp * torch.mean(torch.diag(self.Hessian)) 82 | diag = torch.arange(self.columns, device=self.device) 83 | self.Hessian[diag, diag] += damp 84 | Hinv = torch.linalg.cholesky(self.Hessian) 85 | Hinv = torch.cholesky_inverse(Hinv) 86 | # NOTE: use cholesky_ex since Hinv may not be a Hermitian positive-definite matrix 87 | Hinv, _ = torch.linalg.cholesky_ex(Hinv, upper=True) 88 | Hinv = torch.diag(Hinv).reshape((1, -1))
89 | if self.metric_type == 'SparseGPT-Git': 90 | self.W_metric = self.weight ** 2 / Hinv ** 2 91 | else: 92 | self.W_metric = self.weight ** 2 / Hinv 93 |
94 | if self.wise_dim == 'row': 95 | self.sort_indices = torch.sort(self.W_metric, dim=-1, stable=True)[1] 96 | elif self.wise_dim == 'column': 97 | self.sort_indices = torch.sort(self.W_metric, dim=0, stable=True)[1] 98 | else: 99 | raise NotImplementedError(f"Invalid wise dim: {self.wise_dim}") 100 |
101 | def init_learn_sparsity(self, sparsity_step=0.01, prune_n=0, prune_m=0, blocksize=-1, sigmoid_smooth=False, lora_rank=-1, lora_alpha=1): 102 | self.prune_n, self.prune_m = prune_n, prune_m 103 | self.get_w_metric() 104 | torch.cuda.empty_cache() 105 |
106 | if hasattr(self, 'sparsity'): 107 | self.block_wise = False 108 | self.learn_sparsity = False 109 | W_mask = self.get_weight_mask().detach() 110 | self.weight.data *= W_mask.to(dtype=self.weight.dtype) 111 | self.finish_learn_sparsity() 112 | return 113 |
114 | self.learn_sparsity = True 115 | self.block_wise = blocksize != -1 116 | self.sigmoid_smooth = sigmoid_smooth 117 | self.sparsity_candidates = torch.arange(1.0, -1 * sparsity_step, -1 * sparsity_step, device=self.device) 118 | self.sparsity_candidates[-1] = 0.0
119 | if self.block_wise: 120 | self.blocksize = blocksize 121 | if self.wise_dim == 'row': 122 | assert self.rows % blocksize == 0, "The number of rows must be divisible by the row blocksize" 123 | self.blocknum = self.rows // blocksize 124 | elif self.wise_dim == 'column': 125 | assert self.columns % blocksize == 0, "The number of columns must be divisible by the column blocksize" 126 | self.blocknum = self.columns // blocksize 127 | else: 128 | raise NotImplementedError(f"Invalid wise dim: {self.wise_dim}") 129 | self.sparsity_probabilities = nn.Parameter(torch.zeros((self.blocknum, 
self.sparsity_candidates.shape[0]), device=self.device)) 130 | else: 131 | self.sparsity_probabilities = nn.Parameter(torch.zeros_like(self.sparsity_candidates, device=self.device)) 132 | self.update_sparsity() 133 | 134 | map_dim_size = self.columns if self.wise_dim == 'row' else self.rows if self.wise_dim == 'column' else -1 135 | self.prob_map_matrix = torch.zeros((len(self.sparsity_candidates), map_dim_size), device=self.device) 136 | for i in range(len(self.sparsity_candidates)): 137 | self.prob_map_matrix[i, :int(map_dim_size * self.sparsity_candidates[i].item())] = 1 138 | 139 | self.use_lora = lora_rank != -1 140 | if self.use_lora: 141 | assert type(lora_rank) is int and 0 < lora_rank < min(self.rows, self.columns), f"Invalid Lora rank: {lora_rank}" 142 | self.lora_A = nn.Parameter(torch.zeros((lora_rank, self.columns), device=self.device)) 143 | self.lora_B = nn.Parameter(torch.zeros((self.rows, lora_rank), device=self.device)) 144 | self.lora_scaling = lora_alpha / lora_rank 145 | 146 | def finish_learn_sparsity(self): 147 | if self.learn_sparsity: 148 | if self.use_lora: 149 | lora_weight = (self.lora_B.data @ self.lora_A.data).detach() * self.lora_scaling 150 | self.weight.data += lora_weight.to(self.weight.dtype) 151 | self.lora_A = None 152 | self.lora_B = None 153 | self.lora_scaling = None 154 | self.update_sparsity() 155 | prune_mask = self.get_prune_mask().detach() 156 | self.weight.data *= prune_mask 157 | self.learn_sparsity = False 158 | 159 | self.W_metric = None 160 | self.scaler_row = None 161 | self.sort_indices = None 162 | self.sparsities = None 163 | self.prob_map_matrix = None 164 | self.sparsity_candidates = None 165 | self.sparsity_probabilities = None 166 | self.sparsity_probabilities_softmax = None 167 | torch.cuda.empty_cache() 168 | 169 | def update_sparsity(self): 170 | if self.sigmoid_smooth: 171 | self.sparsity_probabilities_softmax = self.sparsity_probabilities.sigmoid().softmax(dim=-1) 172 | else: 173 | self.sparsity_probabilities_softmax = self.sparsity_probabilities.softmax(dim=-1) 174 | 175 | if self.block_wise: 176 | self.sparsities = self.sparsity_probabilities_softmax @ self.sparsity_candidates 177 | self.sparsity = self.sparsities.mean() 178 | else: 179 | self.sparsity = torch.matmul(self.sparsity_candidates, self.sparsity_probabilities_softmax) 180 | return self.sparsity 181 | 182 | def get_weight_mask(self): 183 | W_mask = torch.ones((self.rows, self.columns), device=self.device) 184 | if self.prune_n != 0: 185 | # structured n:m sparsity 186 | for ii in range(self.columns): 187 | if ii % self.prune_m == 0: 188 | tmp = self.W_metric[:, ii:(ii + self.prune_m)].float() 189 | W_mask.scatter_(1, ii + torch.topk(tmp, self.prune_n, dim=1, largest=False)[1], 0) 190 | elif self.block_wise: 191 | # block wise unstructured pruning 192 | if self.wise_dim == 'row': 193 | row_block_prune_num = (self.sparsities * self.columns).to(dtype=torch.long) 194 | row_prune_num = row_block_prune_num.reshape((-1, 1)).repeat(1, self.blocksize).reshape(-1) 195 | W_mask = mask_gen_cuda.mask_gen_forward(W_mask, self.sort_indices, row_prune_num)[0] 196 | elif self.wise_dim == 'column': 197 | column_block_prune_num = (self.sparsities * self.rows).to(dtype=torch.long) 198 | column_prune_num = column_block_prune_num.reshape((-1, 1)).repeat(1, self.blocksize).reshape(-1) 199 | W_mask = mask_gen_cuda.mask_gen_forward(W_mask.t().contiguous(), self.sort_indices.t().contiguous(), column_prune_num)[0] 200 | W_mask = W_mask.t().contiguous() 201 | else: 202 | raise 
NotImplementedError(f"Invalid wise dim: {self.wise_dim}") 203 | else: 204 | # unstructured pruning 205 | if self.wise_dim == 'row': 206 | indices = self.sort_indices[:, :int(self.columns * self.sparsity)] 207 | W_mask.scatter_(1, indices, 0) 208 | elif self.wise_dim == 'column': 209 | indices = self.sort_indices[:int(self.rows * self.sparsity), :] 210 | W_mask.scatter_(0, indices, 0) 211 | else: 212 | raise NotImplementedError(f"Invalid wise dim: {self.wise_dim}") 213 | return W_mask 214 |
215 | def get_prob_mask(self): 216 | P_mask = torch.zeros((self.rows, self.columns), device=self.device) 217 | probabilities = 1 - (self.sparsity_probabilities_softmax @ self.prob_map_matrix)
218 | if not self.block_wise: 219 | if self.wise_dim == 'row': 220 | probabilities = probabilities.repeat(self.rows, 1) 221 | elif self.wise_dim == 'column': 222 | probabilities = probabilities.reshape((-1, 1)).repeat(1, self.columns) 223 | else: 224 | raise NotImplementedError(f"Invalid wise dim: {self.wise_dim}")
225 | else: 226 | if self.wise_dim == 'row': 227 | probabilities = probabilities.reshape((self.blocknum, 1, self.columns)) 228 | probabilities = probabilities.repeat(1, self.blocksize, 1) 229 | elif self.wise_dim == 'column': 230 | probabilities = probabilities.reshape((self.rows, self.blocknum, 1)) 231 | probabilities = probabilities.repeat(1, 1, self.blocksize) 232 | else: 233 | raise NotImplementedError(f"Invalid wise dim: {self.wise_dim}") 234 | probabilities = probabilities.reshape((self.rows, self.columns))
235 | probabilities = probabilities.to(dtype=P_mask.dtype) 236 | scatter_dim = 1 if self.wise_dim == 'row' else 0 if self.wise_dim == 'column' else -1 237 | P_mask.scatter_(scatter_dim, self.sort_indices, probabilities) 238 | return P_mask 239 |
240 | def get_prune_mask(self): 241 | W_mask = self.get_weight_mask() 242 | P_mask = self.get_prob_mask() 243 | prune_mask = W_mask.detach() - P_mask.detach() + P_mask  # straight-through: the forward pass uses the hard mask, gradients flow into P_mask 244 | prune_mask = prune_mask.to(dtype=self.weight.dtype) 245 | 246 | return prune_mask 247 |
248 | def forward(self, input: torch.Tensor): 249 | weight = self.weight.detach() 250 | if self.learn_sparsity: 251 | self.update_sparsity() 252 | prune_mask = self.get_prune_mask() 253 | if self.use_lora: 254 | lora_weight = (self.lora_B @ self.lora_A) * self.lora_scaling 255 | weight = weight + lora_weight.to(dtype=self.weight.dtype)  # out-of-place add: the detached alias shares storage with self.weight, so += would corrupt the stored weights on every forward 256 | weight = torch.mul(weight, prune_mask) 257 | out = self.linear_func(input, weight, self.bias) 258 | 259 | return out 260 | 
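The `get_prune_mask` expression above is the standard straight-through estimator: `W_mask.detach() - P_mask.detach() + P_mask` evaluates to the hard 0/1 mask in the forward pass, while the backward pass sees only the differentiable `P_mask`, so the sparsity-probability parameters still receive gradients. A tiny self-contained sketch of the same trick with made-up numbers (not from the repository):

```
# Straight-through estimator in isolation (illustrative sketch only).
import torch

P = torch.tensor([0.9, 0.2], requires_grad=True)  # soft keep-probabilities
W_hard = (P > 0.5).float()                        # hard 0/1 mask, non-differentiable
mask = W_hard.detach() - P.detach() + P           # value == W_hard, gradient flows through P
out = (mask * torch.tensor([2.0, 3.0])).sum()
out.backward()
print(mask.detach())  # tensor([1., 0.])
print(P.grad)         # tensor([2., 3.]) -- gradients as if `mask` were P itself
```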
--------------------------------------------------------------------------------
/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .prodigy import Prodigy
--------------------------------------------------------------------------------
/optimizers/prodigy.py:
--------------------------------------------------------------------------------
1 | import math 2 | 3 | import torch 4 | import torch.optim 5 | from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable 6 | import torch.distributed as dist 7 | 8 |
9 | class Prodigy(Optimizer): 10 | r""" 11 | Implements Adam with Prodigy step-sizes. 12 | Leave LR set to 1 unless you encounter instability. 13 |
14 | Arguments: 15 | params (iterable): 16 | Iterable of parameters to optimize or dicts defining parameter groups. 17 | lr (float): 18 | Learning rate adjustment parameter. Increases or decreases the Prodigy learning rate.
19 | betas (Tuple[float, float], optional): coefficients used for computing 20 | running averages of gradient and its square (default: (0.9, 0.999)) 21 | beta3 (float): 22 | Coefficient for computing the Prodigy stepsize using running averages. 23 | If set to None, uses the value of square root of beta2 (default: None).
24 | eps (float): 25 | Term added to the denominator outside of the root operation to improve numerical stability (default: 1e-8). 26 | weight_decay (float): 27 | Weight decay, i.e. an L2 penalty (default: 0). 28 | decouple (boolean): 29 | Use AdamW-style decoupled weight decay. 30 | use_bias_correction (boolean): 31 | Turn on Adam's bias correction. Off by default.
32 | safeguard_warmup (boolean): 33 | Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default. 34 | d0 (float): 35 | Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
36 | d_coef (float): 37 | Coefficient in the expression for the estimate of d (default 1.0). 38 | Values such as 0.5 and 2.0 typically work as well. 39 | Changing this parameter is the preferred way to tune the method.
40 | growth_rate (float): 41 | Prevent the D estimate from growing faster than this multiplicative rate. 42 | Default is inf, for unrestricted. Values like 1.02 give a kind of learning 43 | rate warmup effect.
44 | fsdp_in_use (bool): 45 | If you're using sharded parameters, this should be set to True. The optimizer 46 | will attempt to auto-detect this, but if you're using an implementation other 47 | than PyTorch's builtin version, the auto-detection won't work. 48 | """
49 | def __init__(self, params, lr=1.0, 50 | betas=(0.9, 0.999), beta3=None, 51 | eps=1e-8, weight_decay=0, decouple=True, 52 | use_bias_correction=False, safeguard_warmup=False, 53 | d0=1e-6, d_coef=1.0, growth_rate=float('inf'), 54 | fsdp_in_use=False):
55 | if not 0.0 < d0: 56 | raise ValueError("Invalid d0 value: {}".format(d0)) 57 | if not 0.0 < lr: 58 | raise ValueError("Invalid learning rate: {}".format(lr)) 59 | if not 0.0 < eps: 60 | raise ValueError("Invalid epsilon value: {}".format(eps)) 61 | if not 0.0 <= betas[0] < 1.0: 62 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 63 | if not 0.0 <= betas[1] < 1.0: 64 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 65 |
66 | if decouple and weight_decay > 0: 67 | print("Using decoupled weight decay") 68 | 69 |
70 | defaults = dict(lr=lr, betas=betas, beta3=beta3, 71 | eps=eps, weight_decay=weight_decay, 72 | d=d0, d0=d0, d_max=d0, 73 | d_numerator=0.0, d_coef=d_coef, 74 | k=0, growth_rate=growth_rate, 75 | use_bias_correction=use_bias_correction, 76 | decouple=decouple, safeguard_warmup=safeguard_warmup, 77 | fsdp_in_use=fsdp_in_use, differentiable=False)
78 | self.d0 = d0 79 | super().__init__(params, defaults) 80 |
81 | @property 82 | def supports_memory_efficient_fp16(self): 83 | return False 84 | 85 | @property 86 | def supports_flat_params(self): 87 | return True 88 |
89 | @_use_grad_for_differentiable 90 | def step(self, closure=None): 91 | """Performs a single optimization step. 92 | 93 | Arguments: 94 | closure (callable, optional): A closure that reevaluates the model 95 | and returns the loss.
96 | """ 97 | loss = None 98 | if closure is not None: 99 | loss = closure() 100 | 101 | d_denom = 0.0 102 |
103 | group = self.param_groups[0] 104 | use_bias_correction = group['use_bias_correction'] 105 | beta1, beta2 = group['betas'] 106 | beta3 = group['beta3'] 107 | if beta3 is None: 108 | beta3 = math.sqrt(beta2) 109 | k = group['k'] 110 |
111 | d = group['d'] 112 | d_max = group['d_max'] 113 | d_coef = group['d_coef'] 114 | lr = max(group['lr'] for group in self.param_groups) 115 |
116 | if use_bias_correction: 117 | bias_correction = ((1 - beta2**(k+1))**0.5) / (1 - beta1**(k+1)) 118 | else: 119 | bias_correction = 1 120 | 121 | dlr = d*lr*bias_correction 122 |
123 | growth_rate = group['growth_rate'] 124 | decouple = group['decouple'] 125 | fsdp_in_use = group['fsdp_in_use'] 126 | 127 | d_numerator = group['d_numerator'] 128 | d_numerator *= beta3 129 |
130 | for group in self.param_groups: 131 | decay = group['weight_decay'] 132 | k = group['k'] 133 | eps = group['eps'] 134 | group_lr = group['lr'] 135 | d0 = group['d0'] 136 | safeguard_warmup = group['safeguard_warmup'] 137 |
138 | if group_lr not in [lr, 0.0]: 139 | raise RuntimeError("Setting different lr values across parameter groups is only supported when the extra groups use lr = 0") 140 |
141 | for p in group['params']: 142 | if p.grad is None: 143 | continue 144 | if hasattr(p, "_fsdp_flattened"): 145 | fsdp_in_use = True 146 | 147 | grad = p.grad.data 148 |
149 | # Apply weight decay (coupled variant) 150 | if decay != 0 and not decouple: 151 | grad.add_(p.data, alpha=decay) 152 | 153 | state = self.state[p] 154 |
155 | # State initialization 156 | if 'step' not in state: 157 | state['step'] = 0 158 | state['s'] = torch.zeros_like(p.data).detach() 159 | state['p0'] = p.detach().clone() 160 | # Exponential moving average of gradient values 161 | state['exp_avg'] = torch.zeros_like(p.data).detach() 162 | # Exponential moving average of squared gradient values 163 | state['exp_avg_sq'] = torch.zeros_like(p.data).detach() 164 |
165 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 166 | 167 | s = state['s'] 168 | p0 = state['p0'] 169 |
170 | if group_lr > 0.0: 171 | # we use d / d0 instead of just d to avoid getting values that are too small 172 | d_numerator += (d / d0) * dlr * torch.dot(grad.flatten(), (p0.data - p.data).flatten()).item() 173 |
174 | # Adam EMA updates 175 | exp_avg.mul_(beta1).add_(grad, alpha=d * (1-beta1)) 176 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=d * d * (1-beta2)) 177 |
178 | if safeguard_warmup: 179 | s.mul_(beta3).add_(grad, alpha=((d / d0) * d)) 180 | else: 181 | s.mul_(beta3).add_(grad, alpha=((d / d0) * dlr)) 182 | d_denom += s.abs().sum().item() 183 |
184 | ###### 185 | 186 | d_hat = d 187 |
188 | # if we have not made any progress, return 189 | # if any gradients are available, d_denom > 0 (unless ||g|| = 0) 190 | if d_denom == 0: 191 | return loss 192 |
193 | if lr > 0.0: 194 | if fsdp_in_use: 195 | dist_tensor = torch.zeros(2).cuda() 196 | dist_tensor[0] = d_numerator 197 | dist_tensor[1] = d_denom 198 | dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM) 199 | global_d_numerator = dist_tensor[0] 200 | global_d_denom = dist_tensor[1] 201 | else: 202 | global_d_numerator = d_numerator 203 | global_d_denom = d_denom 204 |
205 | d_hat = d_coef * global_d_numerator / global_d_denom 206 | if d == group['d0']: 207 | d = max(d, d_hat) 208 | d_max = max(d_max, d_hat) 209 | d = min(d_max, d * growth_rate) 210 |
211 | for group in self.param_groups: 212 | group['d_numerator'] 
= global_d_numerator 213 | group['d_denom'] = global_d_denom 214 | group['d'] = d 215 | group['d_max'] = d_max 216 | group['d_hat'] = d_hat 217 |
218 | decay = group['weight_decay'] 219 | k = group['k'] 220 | eps = group['eps'] 221 |
222 | for p in group['params']: 223 | if p.grad is None: 224 | continue 225 | grad = p.grad.data 226 | 227 | state = self.state[p] 228 |
229 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 230 | 231 | state['step'] += 1 232 | 233 | denom = exp_avg_sq.sqrt().add_(d * eps) 234 |
235 | # Apply weight decay (decoupled variant) 236 | if decay != 0 and decouple: 237 | p.data.add_(p.data, alpha=-decay * dlr) 238 | 239 |
240 | ### Take step 241 | p.data.addcdiv_(exp_avg, denom, value=-dlr) 242 | 243 | group['k'] = k + 1 244 | 245 | return loss 246 | 
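A minimal usage sketch for the optimizer above, following the docstring's advice to leave `lr` at 1 and tune `d_coef` instead (hypothetical toy model and random data, not from the repository):

```
# Hypothetical sketch; toy model and random data for illustration only.
import torch
from optimizers import Prodigy

model = torch.nn.Linear(16, 4)
opt = Prodigy(model.parameters(), lr=1.0)  # leave lr at 1; tune d_coef to adjust the method
for _ in range(100):
    x, y = torch.randn(32, 16), torch.randn(32, 4)
    opt.zero_grad()
    torch.nn.functional.mse_loss(model(x), y).backward()
    opt.step()  # the D estimate (group['d']) adapts the effective step size over iterations
```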
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import random 2 | from datasets import load_dataset 3 | from transformers import AutoTokenizer 4 | 5 |
6 | class TokenizerWrapper: 7 | def __init__(self, input_ids): 8 | self.input_ids = input_ids 9 | 10 |
11 | def get_wikitext2(nsamples, seed, seqlen, model): 12 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 13 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 14 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 15 |
16 | trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt') 17 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 18 |
19 | random.seed(seed) 20 | trainloader = [] 21 | for _ in range(nsamples): 22 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 23 | j = i + seqlen 24 | inp = trainenc.input_ids[:, i:j] 25 | tar = inp.clone() 26 | tar[:, :-1] = -100 27 | trainloader.append((inp, tar)) 28 | return trainloader, testenc 29 | 30 |
31 | def get_ptb(nsamples, seed, seqlen, model): 32 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 33 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') 34 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 35 |
36 | trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt') 37 | testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt') 38 |
39 | random.seed(seed) 40 | trainloader = [] 41 | for _ in range(nsamples): 42 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 43 | j = i + seqlen 44 | inp = trainenc.input_ids[:, i:j] 45 | tar = inp.clone() 46 | tar[:, :-1] = -100 47 | trainloader.append((inp, tar)) 48 | return trainloader, testenc 49 | 50 |
51 | def get_c4(nsamples, seed, seqlen, model): 52 | traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') 53 | valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 54 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) 55 |
56 | random.seed(seed) 57 | trainloader = [] 58 | for _ in range(nsamples): 59 | while True: 60 | i = random.randint(0, len(traindata) - 1) 61 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 62 | if trainenc.input_ids.shape[1] > seqlen: 63 | break 64 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 65 | j = i + seqlen 66 | inp = trainenc.input_ids[:, i:j] 67 | tar = inp.clone() 68 | tar[:, :-1] = -100 69 | trainloader.append((inp, tar)) 70 |
71 | valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 72 | valenc = valenc.input_ids[:, :(256 * seqlen)] 73 | valenc = TokenizerWrapper(valenc) 74 | return trainloader, valenc 75 | 76 |
77 | def get_loaders( 78 | name, nsamples=128, seed=0, seqlen=2048, model='' 79 | ): 80 | if name == 'wikitext2': 81 | return get_wikitext2(nsamples, seed, seqlen, model) 82 | elif name == 'ptb': 83 | return get_ptb(nsamples, seed, seqlen, model) 84 | elif name == 'c4': 85 | return get_c4(nsamples, seed, seqlen, model) 86 | else: 87 | raise NotImplementedError(f"Invalid dataset name: {name}")
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import os 2 | import random 3 | import datetime 4 | import builtins 5 | import numpy as np 6 | import multiprocessing as mp 7 |
8 | import torch 9 | import torch.nn as nn 10 | import torch.distributed as dist 11 | import torch.backends.cudnn as cudnn 12 |
13 | import fnmatch 14 | from lm_eval import tasks, evaluator 15 | from accelerate import dispatch_model 16 | from accelerate.utils import get_balanced_memory, infer_auto_device_map 17 | 18 |
19 | def slurm_dist_init(seed=0, port=1999): 20 | mp.set_start_method('spawn', force=True) 21 |
22 | rank = int(os.environ['SLURM_PROCID']) 23 | world_size = os.environ['SLURM_NTASKS'] 24 | node_list = os.environ['SLURM_NODELIST'] 25 | num_gpus = torch.cuda.device_count() 26 | gpu_id = rank % num_gpus 27 | torch.cuda.set_device(gpu_id) 28 |
29 | if '[' in node_list: 30 | beg = node_list.find('[') 31 | pos1 = node_list.find('-', beg) 32 | if pos1 < 0: 33 | pos1 = 1000 34 | pos2 = node_list.find(',', beg) 35 | if pos2 < 0: 36 | pos2 = 1000 37 | node_list = node_list[:min(pos1, pos2)].replace('[', '') 38 | addr = node_list[8:].replace('-', '.') 39 |
40 | os.environ['MASTER_ADDR'] = addr 41 | # os.environ['MASTER_PORT'] = str(port) 42 | os.environ['WORLD_SIZE'] = world_size 43 | os.environ['RANK'] = str(rank) 44 |
45 | for _ in range(10): 46 | try: 47 | os.environ['MASTER_PORT'] = str(port) 48 | dist.init_process_group(backend='nccl') 49 | break 50 | except Exception:  # port already taken; retry with a different one 51 | port += 99 52 | continue 53 |
54 | # dist.init_process_group(backend='nccl') 55 | torch.cuda.set_device(gpu_id) 56 | dist.barrier() 57 |
58 | random.seed(seed) 59 | np.random.seed(seed) 60 | torch.manual_seed(seed) 61 | cudnn.benchmark = True 62 | 63 |
64 | def get_rank(): 65 | if not dist.is_available(): 66 | return 0 67 | if not dist.is_initialized(): 68 | return 0 69 | return dist.get_rank() 70 | 71 |
72 | def is_master_process(): 73 | return get_rank() == 0 74 | 75 |
76 | def setup_distributed_print(is_master_process): 77 | builtin_print = builtins.print 78 |
79 | def print(*args, **kwargs): 80 | if is_master_process: 81 | now = datetime.datetime.now().time() 82 | builtin_print('[{}] '.format(now), end='') 83 | builtin_print(*args, **kwargs) 84 |
85 | builtins.print = print 86 | 87 |
88 | def use_old_forward(module: nn.Module, recurse=False): 89 | if hasattr(module, '_old_forward'): 90 | module._new_forward = module.forward 91 | module.forward = module._old_forward 92 |
93 | if recurse: 94 | for child in module.children(): 95 | use_old_forward(child, recurse) 96 | 97 |
98 | def use_new_forward(module: nn.Module, recurse=False): 99 | if hasattr(module, '_new_forward'): 100 | module.forward = module._new_forward 101 | delattr(module, "_new_forward") 102 | 103 | if 
recurse: 104 | for child in module.children(): 105 | use_new_forward(child, recurse) 106 | 107 | 108 | def auto_map_model(model): 109 | print(f"Check no split modules: {model.model._no_split_modules}") 110 | max_memory = get_balanced_memory(model.model, dtype=torch.float16, no_split_module_classes=model.model._no_split_modules) 111 | print(f"Check max memory: {max_memory}") 112 | model.model.tie_weights() 113 | print("Model weights tied") 114 | device_map = infer_auto_device_map(model.model, dtype=torch.float16, max_memory=max_memory, no_split_module_classes=model.model._no_split_modules) 115 | print(f"Check device map: {device_map}") 116 | dispatch_model(model.model, device_map) 117 | 118 | 119 | class FakeScaler: 120 | def __call__(self, loss, optimizer, parameters=None, clip_grad=None, clip_mode=None): 121 | loss.backward() 122 | optimizer.step() 123 | 124 | 125 | def pattern_match(patterns, source_list): 126 | task_names = set() 127 | for pattern in patterns: 128 | for matching in fnmatch.filter(source_list, pattern): 129 | task_names.add(matching) 130 | return list(task_names) 131 | 132 | 133 | def lm_eval_model(model, args): 134 | if args.test_datasets is None: 135 | test_datasets = args.dataset 136 | else: 137 | test_datasets = pattern_match(args.test_datasets.split(","), tasks.ALL_TASKS) 138 | if test_datasets == []: 139 | return "No test dataset specified" 140 | 141 | return evaluator.simple_evaluate( 142 | model=model, 143 | tasks=test_datasets, 144 | batch_size=args.batch_size, 145 | device=model.device, 146 | no_cache=True 147 | ) 148 | 149 | 150 | def find_layers(module, layers=[nn.Linear], name=''): 151 | if type(module) in layers: 152 | return {name: module} 153 | res = {} 154 | for name1, child in module.named_children(): 155 | res.update(find_layers( 156 | child, layers=layers, name=name + '.' 
+ name1 if name != '' else name1 157 | )) 158 | return res 159 | 160 |
161 | def init_learn_sparsity(sparse_layers, sparsity_step=0.01, prune_n=0, prune_m=0, blocksize=-1, sigmoid_smooth=False, lora_rank=-1, lora_alpha=1): 162 | for layer_name in sparse_layers: 163 | sparse_layer = sparse_layers[layer_name] 164 | sparse_layer.init_learn_sparsity(sparsity_step, prune_n, prune_m, blocksize, sigmoid_smooth, lora_rank, lora_alpha) 165 | 166 |
167 | def finish_learn_sparsity(sparse_layers): 168 | for layer_name in sparse_layers: 169 | sparse_layer = sparse_layers[layer_name] 170 | sparse_layer.finish_learn_sparsity() 171 | 172 |
173 | def get_sparsity(sparse_layers): 174 | total_param = sum([sparse_layers[layer_name].param_num for layer_name in sparse_layers]) 175 | sparsity = 0 176 | for layer_name in sparse_layers: 177 | sparse_layer = sparse_layers[layer_name] 178 | sparsity += sparse_layer.sparsity * (sparse_layer.param_num / total_param) 179 | return sparsity 180 | 181 |
182 | def get_sparsity_params(sparse_layers): 183 | params = [] 184 | for layer_name in sparse_layers: 185 | sparse_layer = sparse_layers[layer_name] 186 | if sparse_layer.sparsity_probabilities is not None: 187 | layer_sparsity_params = sparse_layer.sparsity_probabilities 188 | if type(layer_sparsity_params) is list: 189 | params.extend(layer_sparsity_params) 190 | else: 191 | params.append(layer_sparsity_params) 192 | return params 193 | 194 |
195 | def get_lora_params(sparse_layers): 196 | params = [] 197 | for layer_name in sparse_layers: 198 | sparse_layer = sparse_layers[layer_name] 199 | if sparse_layer.use_lora: 200 | params.append(sparse_layer.lora_A) 201 | params.append(sparse_layer.lora_B) 202 | return params 203 | 204 |
205 | def eval_ppl(model, testenc, batch_size=1, device='cuda'): 206 | testenc = testenc.input_ids 207 | nsamples = testenc.numel() // model.seqlen 208 | neg_log_likelihoods = [] 209 | for i in range(0, nsamples, batch_size): 210 | j = min(i + batch_size, nsamples) 211 | inputs = testenc[:,(i * model.seqlen):(j * model.seqlen)].to(device) 212 | inputs = inputs.reshape(j-i, model.seqlen) 213 |
214 | with torch.no_grad(): lm_logits = model._model_call(inputs)  # inference only: avoids retaining autograd graphs across batches 215 | shift_logits = lm_logits[:, :-1, :].contiguous() 216 | shift_labels = inputs[:, 1:] 217 |
218 | loss_fct = nn.CrossEntropyLoss() 219 | loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)) 220 |
221 | neg_log_likelihood = loss.float() * model.seqlen * (j - i) 222 | neg_log_likelihoods.append(neg_log_likelihood) 223 |
224 | ppl = torch.exp(torch.stack(neg_log_likelihoods).sum() / (nsamples * model.seqlen)) 225 | torch.cuda.empty_cache() 226 | 227 | return ppl.item() 228 |
--------------------------------------------------------------------------------
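Taken together, a condensed sketch of how these utilities compose into a pruning pipeline (hypothetical stand-in module and calibration data; the actual driver is `main.py`, which is not shown in this dump):

```
# Hypothetical end-to-end sketch; the module and data below are placeholders.
import torch
import torch.nn as nn
from models.sparse_layers import SparseLinear
from utils.tools import find_layers, init_learn_sparsity, finish_learn_sparsity, get_sparsity

block = nn.Sequential(nn.Linear(64, 64))  # stand-in for one transformer block
sparse_layers = {name: SparseLinear(mod, metric_type='Wanda')
                 for name, mod in find_layers(block).items()}

for slayer in sparse_layers.values():  # calibration: accumulate activation statistics
    x = torch.randn(8, 64)
    slayer.add_batch(x, slayer(x))

init_learn_sparsity(sparse_layers, sparsity_step=0.01)  # makes per-layer sparsity differentiable
print('mean sparsity:', float(get_sparsity(sparse_layers)))  # ~0.5 before any optimization
finish_learn_sparsity(sparse_layers)  # bakes the learned masks into the weights
```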