├── .gitignore ├── README.md ├── block_influence.py ├── block_influence ├── eval.sh ├── evaluate_block_influence.py ├── llama_model.py └── mistral_model.py ├── block_influence_mmlu_shapley ├── eval.sh └── evaluate_mmlu_block_influence.py ├── dataset.py ├── dist_utils.py ├── evals ├── __init__.py ├── dist_mmlu.py ├── mmlu.py └── mmlu_utils.py ├── layer_influence ├── eval.sh ├── evaluate_layer_influence.py ├── llama_layer_influence.py └── mistral_layer_influence.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ 2 | **/*logs*/ 3 | **/*datasets*/ 4 | **/*checkpoints*/ 5 | **/wandb/ 6 | *.log 7 | *.out 8 | *.png 9 | *.csv 10 | machine.* 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A deeper look at depth pruning of LLMs 2 | 3 | The official implementation for the paper: "A deeper look at depth pruning of LLMs" (https://arxiv.org/abs/2407.16286). 4 | 5 | 6 | ## Usage 7 | 8 | The main scripts are divided into three different categories. 9 | 10 | ### Block influence 11 | 12 | The main experiments are based on block influence, where we compare the impact of different block influence estimation techniques on block pruning. 13 | The associated scripts are located in the directory: `./block_influence/`. 14 | To evaluate the model based on different block influence techniques, check the script: `block_influence/eval.sh`. 15 | 16 | ### MMLU Shapley 17 | 18 | MMLU Shapley focuses on computing the loss Shapley value directly on the MMLU test set, which serves as an upper bound on the performance that can be achieved by different block pruning techniques on the MMLU benchmark. 19 | The experiments associated with MMLU Shapley-based block pruning are located in the directory: `./block_influence_mmlu_shapley/`. 20 | To evaluate the model based on MMLU loss Shapley-based block pruning, check the script: `block_influence_mmlu_shapley/eval.sh`. 21 | 22 | ### Layer influence 23 | 24 | We further dissect each transformer block into its constituent feed-forward and self-attention layers, and evaluate their impact separately. 25 | The experiments associated with layer influence estimation are located in the directory: `./layer_influence/`. 26 | To evaluate the model based on different layer influence techniques, check the script: `layer_influence/eval.sh`. 27 | 28 | 29 | ## Citation 30 | 31 | ``` 32 | @inproceedings{siddiqui2024deeper, 33 | title={A deeper look at depth pruning of LLMs}, 34 | author={Siddiqui, Shoaib Ahmed and Dong, Xin and Heinrich, Greg and Breuel, Thomas and Kautz, Jan and Krueger, David and Molchanov, Pavlo}, 35 | booktitle={ICML 2024 Workshop on Theoretical Foundations of Foundation Models} 36 | } 37 | ``` 38 | 39 | ## License 40 | 41 | MIT 42 | -------------------------------------------------------------------------------- /block_influence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from typing import List, Dict 4 | 5 | from dist_utils import reduce_tensor 6 | 7 | 8 | class BlockInfluenceEstimator: 9 | """ 10 | Implemented from the paper: https://arxiv.org/abs/2403.03853 11 | This influence estimator assumes that the importance of a block is directly related to the size of 12 | the change it induces in the hidden representation.
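    A minimal usage sketch (illustrative only; `per_block_states`, `hidden_in`, and `hidden_out` are
    hypothetical names for the hidden states captured before and after each decoder block):

        estimator = BlockInfluenceEstimator(num_layers=32, device=torch.device("cuda"))
        for block_idx, (hidden_in, hidden_out) in enumerate(per_block_states):  # BLD tensors
            estimator.update_block_stats(block_idx, hidden_in, hidden_out)
        influences = estimator.get_block_influences()  # one (1 - avg. cosine similarity) score per block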
13 | """ 14 | def __init__(self, num_layers: int, device: torch.device, use_avg: bool = True): 15 | self.num_layers = num_layers 16 | self.device = device 17 | self.use_avg = use_avg 18 | 19 | # Initialize the counters 20 | self.cosine_similarity_dict = {i: 0. for i in range(self.num_layers)} 21 | self.total_dict = {i: 0 for i in range(self.num_layers)} 22 | 23 | @torch.no_grad() 24 | def update_block_stats(self, block_idx: int, prev_rep: torch.Tensor, updated_rep: torch.Tensor): 25 | cosine_sim = torch.nn.functional.cosine_similarity(prev_rep, updated_rep, dim=-1) # BLD format 26 | num_elements = np.prod(cosine_sim.shape) # all others should have the same shape 27 | cosine_sim = cosine_sim.mean() if self.use_avg else cosine_sim.sum() # sum cosine similarity over batch and token position 28 | self.cosine_similarity_dict[block_idx] += float(cosine_sim) 29 | self.total_dict[block_idx] += 1 if self.use_avg else num_elements 30 | 31 | def get_block_influence(self, block_idx: int) -> float: 32 | if self.total_dict[block_idx] == 0: # block not used 33 | return None 34 | avg_cosine_sim = self.cosine_similarity_dict[block_idx] / self.total_dict[block_idx] 35 | avg_cosine_sim = float(reduce_tensor(torch.tensor(avg_cosine_sim).to(self.device), average=True)) # collect from processes 36 | avg_cosine_dist = 1. - avg_cosine_sim 37 | return avg_cosine_dist 38 | 39 | def get_block_influences(self) -> List[float]: 40 | return [self.get_block_influence(i) for i in range(self.num_layers)] 41 | 42 | def __repr__(self) -> str: 43 | return f"[Block influence estimator]" 44 | -------------------------------------------------------------------------------- /block_influence/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export TOKENIZERS_PARALLELISM=false # disable tokenizer warning 4 | pip install wget 5 | 6 | # Get the DDP args 7 | HEAD_NODE_IP=$1 8 | NUM_NODES=$2 9 | NUM_GPUS_PER_NODE=8 10 | echo "Head node IP: ${HEAD_NODE_IP} / # nodes: ${NUM_NODES} / # GPUs per node: ${NUM_GPUS_PER_NODE}" 11 | 12 | # Check if HEAD_NODE_IP is given 13 | if [ -z "${HEAD_NODE_IP}" ]; then 14 | echo "No head node IP found. Using torchrun runner." 
15 | RUNNER_CMD="torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS_PER_NODE}" 16 | else 17 | export WORLD_SIZE=${SLURM_NTASKS} 18 | export RANK=${SLURM_PROCID} 19 | export LOCAL_RANK=${SLURM_LOCALID} 20 | export MASTER_ADDR=${HEAD_NODE_IP} 21 | export MASTER_PORT=29500 22 | echo "python args / world size: ${WORLD_SIZE} / rank: ${RANK} / local rank: ${LOCAL_RANK} / master addr: ${MASTER_ADDR} / master port: ${MASTER_PORT}" 23 | 24 | RUNNER_CMD="python" 25 | fi 26 | 27 | DEFAULT_MODEL="llama-2" 28 | MODEL=${3:-$DEFAULT_MODEL} 29 | echo "Using model: ${MODEL}" 30 | 31 | ${RUNNER_CMD} block_influence/evaluate_block_influence.py \ 32 | --dataset "openwebtext" \ 33 | --model-name ${MODEL} \ 34 | --model-size 7b \ 35 | --batch-size 1 \ 36 | --sequence-length 2048 \ 37 | --subsample-size 250000 \ 38 | --wandb-project 'block_influence' 39 | -------------------------------------------------------------------------------- /block_influence/evaluate_block_influence.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import os 4 | import time 5 | import json 6 | import random 7 | import argparse 8 | from tqdm import tqdm 9 | from typing import Tuple 10 | 11 | import wandb 12 | 13 | import torch 14 | import numpy as np 15 | 16 | from transformers import AutoTokenizer, AutoConfig 17 | 18 | import sys 19 | sys.path.append('.') 20 | from llama_model import LlamaForCausalLM 21 | from mistral_model import MistralForCausalLM 22 | from dataset import NLPDataset, get_dataloader 23 | from train_utils import get_num_model_params 24 | from dist_utils import init_distributed_env, is_main_proc, wait_for_other_procs, reduce_tensor 25 | from block_influence import BlockInfluenceEstimator 26 | from evals.dist_mmlu import MMLUDataset, evaluate_mmlu 27 | 28 | 29 | def load_model(args, only_tokenizer=False, pretrained=False): 30 | # assumes huggingface login: `huggingface-cli login`` 31 | if args.model_name == "llama-2": 32 | if args.use_instruct_model: 33 | model_name = f"meta-llama/Llama-2-{args.model_size.lower()}-chat-hf" 34 | else: 35 | model_name = f"meta-llama/Llama-2-{args.model_size.lower()}-hf" 36 | elif args.model_name == "mistral": 37 | if args.use_instruct_model: 38 | model_name = f"mistralai/Mistral-{args.model_size.upper()}-Instruct-v0.2" 39 | else: 40 | model_name = f"mistralai/Mistral-{args.model_size.upper()}-v0.1" 41 | else: 42 | raise RuntimeError(f"Unsupported model: {args.model_name}") 43 | print("!! 
Loading model:", model_name) 44 | 45 | # Load the tokenizer 46 | tokenizer = AutoTokenizer.from_pretrained(model_name) 47 | if only_tokenizer: 48 | return tokenizer 49 | 50 | # Load the model as well as the tokenizer 51 | config = AutoConfig.from_pretrained(model_name) 52 | print("Config:", config) 53 | kwargs = dict(torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") 54 | print("Model precision:", kwargs["torch_dtype"]) 55 | if pretrained: 56 | print("Using pretrained model...") 57 | 58 | if args.model_name == "llama-2": 59 | if not pretrained: 60 | model = LlamaForCausalLM(config).to(kwargs["torch_dtype"]) 61 | else: 62 | model = LlamaForCausalLM.from_pretrained(model_name, **kwargs) 63 | elif args.model_name == "mistral": 64 | if not pretrained: 65 | model = MistralForCausalLM(config).to(kwargs["torch_dtype"]) 66 | else: 67 | model = MistralForCausalLM.from_pretrained(model_name, **kwargs) 68 | else: 69 | raise RuntimeError(f"Unsupported model: {args.model_name}") 70 | return model, tokenizer 71 | 72 | 73 | def compute_log_probs(logits: torch.Tensor, target_ids: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]: 74 | # Apply softmax and log to obtain log probabilities from logits (summing original logits would be incorrect) 75 | log_probs = torch.log_softmax(logits.float(), dim=-1) 76 | 77 | log_probs = torch.gather(log_probs, 2, target_ids.unsqueeze(-1)).squeeze(-1) 78 | sequence_log_prob = log_probs.sum(dim=1).cpu().float().numpy() 79 | 80 | # Calculate perplexity 81 | sequence_length = target_ids.size(-1) 82 | assert sequence_length > 0, logits 83 | sequence_perplexity = np.exp(-sequence_log_prob / sequence_length) 84 | 85 | return sequence_perplexity, sequence_log_prob 86 | 87 | 88 | @torch.no_grad() 89 | def evaluate_model(model: torch.nn.Module, eval_loader: torch.utils.data.DataLoader, device: torch.device, split_name: str): 90 | model.eval() 91 | avg_sequence_perplexity = 0. 92 | avg_loss = 0. 
93 | num_ex = 0 94 | 95 | for batch in tqdm(eval_loader): 96 | tokenized_input = batch["input_ids"].to(device) 97 | 98 | # Forward prop through the model (will also populate the loss, but one extra logit) 99 | outputs = model(tokenized_input, labels=tokenized_input) 100 | 101 | # Compute metrics on top of LM logits 102 | lm_logits = outputs.logits[:, :-1, :] # BTD format (discard the final logit) 103 | target_ids = tokenized_input[:, 1:] # input ids strided by one 104 | assert len(lm_logits.shape) == 3, lm_logits.shape 105 | assert len(target_ids.shape) == 2, target_ids.shape 106 | assert lm_logits.shape[1] == target_ids.shape[1], f"{lm_logits.shape} != {target_ids.shape}" 107 | perplexity, log_prob = compute_log_probs(lm_logits, target_ids) 108 | 109 | avg_sequence_perplexity += float(perplexity.sum()) 110 | avg_loss += float(outputs.loss) 111 | num_ex += len(tokenized_input) 112 | 113 | # Collect the stats from all processes 114 | avg_sequence_perplexity = float(reduce_tensor(torch.tensor(avg_sequence_perplexity).to(device))) 115 | avg_loss = float(reduce_tensor(torch.tensor(avg_loss).to(device))) 116 | num_ex = int(reduce_tensor(torch.tensor(num_ex).to(device))) 117 | 118 | avg_sequence_perplexity = avg_sequence_perplexity / num_ex 119 | avg_loss = avg_loss / num_ex 120 | output_dict = {"split": split_name, "num_ex": num_ex, "avg_loss": avg_loss, "avg_seq_perplexity": avg_sequence_perplexity} 121 | print(json.dumps(output_dict)) 122 | if split_name is not None and wandb.run is not None: 123 | wandb.log({f"eval_{split_name}": {"num_ex": num_ex, "avg_loss": avg_loss, "avg_seq_perplexity": avg_sequence_perplexity}}) 124 | return avg_loss, avg_sequence_perplexity 125 | 126 | 127 | @torch.no_grad() 128 | def compute_block_shapley(model: torch.nn.Module, eval_loader: torch.utils.data.DataLoader, device: torch.device, 129 | use_random_subnetworks: bool = False, subnetwork_len: float = 0.5, seed: int = 43, 130 | num_subsampled_networks: int = 10, max_samples_per_proc: int = None): 131 | model.eval() 132 | num_model_layers = model.get_num_model_layers() 133 | print(f"!! 
Computing the logit shapley value for the model with {num_model_layers} layers...") 134 | rng = np.random.default_rng(seed) 135 | if not use_random_subnetworks: 136 | num_subsampled_networks = num_model_layers 137 | 138 | all_statistics = [] 139 | for iterator, batch in enumerate(tqdm(eval_loader)): 140 | tokenized_input = batch["input_ids"].to(device) 141 | base_logits = None 142 | for i in range(1+num_subsampled_networks): # first one is always base model eval 143 | selected_blocks = None # use full network 144 | if i != 0: # use subnetwork 145 | if use_random_subnetworks: 146 | selected_blocks = rng.choice(range(num_model_layers), int(subnetwork_len*num_model_layers), replace=False) 147 | else: 148 | block_to_remove = i - 1 149 | selected_blocks = [x for x in range(num_model_layers) if x != block_to_remove] 150 | model.select_blocks(selected_blocks, verbose=False) 151 | 152 | outputs = model(tokenized_input, labels=tokenized_input) 153 | lm_logits = outputs.logits[:, :-1, :] # BTD format (discard the final logit) 154 | lm_loss = outputs.loss 155 | if base_logits is None: 156 | assert selected_blocks is None 157 | base_logits = lm_logits 158 | else: 159 | assert selected_blocks is not None 160 | diff_norm = torch.norm(base_logits - lm_logits, p=2, dim=-1).mean() # mean over batch and sequence 161 | all_statistics.append((selected_blocks, float(diff_norm), float(lm_loss))) 162 | 163 | # Check if stopping condition is met 164 | if max_samples_per_proc is not None and iterator >= max_samples_per_proc - 1: 165 | print(f"{iterator} samples collected for logit shapley value. Stopping further computations!") 166 | break 167 | 168 | # Compute the block influence based on the computed statistics 169 | logit_dist = {i: {"present": [], "absent": []} for i in range(num_model_layers)} 170 | loss_dist = {i: {"present": [], "absent": []} for i in range(num_model_layers)} 171 | for selected_blocks, diff_norm, loss in all_statistics: 172 | for i in range(num_model_layers): 173 | key = "present" if i in selected_blocks else "absent" 174 | logit_dist[i][key].append(diff_norm) 175 | loss_dist[i][key].append(loss) 176 | 177 | # Compute average distances 178 | print("~~~~~~ Block shapley statistics ~~~~~~") 179 | logit_shapley_list = [] 180 | loss_shapley_list = [] 181 | for key, input_container, output_container in [("dist", logit_dist, logit_shapley_list), 182 | ("loss", loss_dist, loss_shapley_list)]: 183 | for i in range(num_model_layers): 184 | for name in ["present", "absent"]: 185 | mean = np.mean(input_container[i][name]) # convert it to mean 186 | input_container[i][name] = float(reduce_tensor(torch.tensor(mean).to(device), average=True)) 187 | shapley = input_container[i]['present'] - input_container[i]['absent'] 188 | print(f"> block {i} / present mean {key}: {input_container[i]['present']} / absent mean {key}: {input_container[i]['absent']} / shapley: {shapley}") 189 | output_container.append(shapley) 190 | print("-"*50) 191 | return logit_shapley_list, loss_shapley_list 192 | 193 | 194 | def main(args): 195 | init_distributed_env(args) 196 | 197 | generator = None 198 | if args.seed is not None: # Set process seed to reduce stochasticity 199 | torch.manual_seed(args.seed) 200 | torch.cuda.manual_seed(args.seed) 201 | np.random.seed(seed=args.seed) 202 | random.seed(args.seed) 203 | print("Setting process seed:", args.seed) 204 | 205 | # Generator to seed dataloaders 206 | generator = torch.Generator() 207 | generator.manual_seed(args.seed) 208 | 209 | dataset_dir = 
f"{args.dataset}_model_{args.model_name}_seq_len_{args.sequence_length}_subsample_{args.subsample_size}_comb_docs" 210 | args.dataset_output_dir = os.path.join("datasets", dataset_dir) 211 | 212 | suffix = "block_pruning" 213 | args.wandb_run_name = f"{dataset_dir}_{suffix}" 214 | 215 | if args.wandb_project is not None and is_main_proc(): 216 | print("Initialization w&b...") 217 | wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args, resume=False) 218 | 219 | if is_main_proc() and not NLPDataset.is_dataset_processed(args.dataset_output_dir): 220 | tokenizer = load_model(args, only_tokenizer=True) 221 | dataset = NLPDataset(args.dataset, tokenizer, max_length=args.sequence_length, 222 | combine_documents=True, subsample_size=args.subsample_size) 223 | dataset.save_datasets(args.dataset_output_dir) 224 | wait_for_other_procs() # wait for the main process to write the dataset 225 | 226 | # Load the dataset 227 | dataset = NLPDataset.load_dataset(args.dataset_output_dir) # returns a dataset dict 228 | train_dataset = dataset["train"] 229 | test_dataset = dataset["test"] 230 | 231 | # Load the model 232 | model, tokenizer = load_model(args, pretrained=True) 233 | num_model_params = get_num_model_params(model) 234 | num_model_layers = model.get_num_model_layers() 235 | print(f"# model params: {num_model_params/1_000_000:.2f}M / # layers: {num_model_layers}") 236 | 237 | # Convert to DDP 238 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 239 | model = model.to(device) # move to device 240 | 241 | # Create the dataloaders 242 | train_loader = get_dataloader(train_dataset, args.batch_size, args.num_workers, drop_last=True, generator=generator) 243 | eval_loader = get_dataloader(test_dataset, args.test_batch_size, args.num_workers, generator=generator) 244 | 245 | # Load MMLU dataset 246 | mmlu_dataset = MMLUDataset(tokenizer, args.model_name) 247 | print("# examples in MMLU dataset:", len(mmlu_dataset)) 248 | mmlu_loader = get_dataloader(mmlu_dataset, 1, args.num_workers, generator=generator) # bs=1 for MMLU 249 | 250 | print(">> Estimating block influences...") 251 | eval_start_time = time.time() 252 | model.select_blocks(None) # use all blocks 253 | block_influence_estimator = BlockInfluenceEstimator(num_model_layers, device) 254 | model.add_block_influence_estimator(block_influence_estimator) 255 | evaluate_model(model, train_loader, device, split_name=None) # use the train set to compute block influences 256 | final_block_influences = block_influence_estimator.get_block_influences() 257 | model.add_block_influence_estimator(None) # remove the block influence computation 258 | print("Final block influences:", final_block_influences) 259 | 260 | cosine_block_influences = [x["cosine_dist"] for x in final_block_influences] 261 | l1_block_influences = [x["l1_update_norm"] for x in final_block_influences] 262 | relative_l1_block_influences = [x["l1_relative_update_norm"] for x in final_block_influences] 263 | l2_block_influences = [x["l2_update_norm"] for x in final_block_influences] 264 | relative_l2_block_influences = [x["l2_relative_update_norm"] for x in final_block_influences] 265 | print("Cosine block influences:", cosine_block_influences) 266 | print("L1 block influences:", l1_block_influences) 267 | print("Relative L1 block influences:", relative_l1_block_influences) 268 | print("L2 block influences:", l2_block_influences) 269 | print("Relative L2 block influences:", relative_l2_block_influences) 270 | 271 | if wandb.run is not None: 272 | 
wandb.log({f"block_{i}_cosine_influence": block_influence for i, block_influence in enumerate(cosine_block_influences)}) 273 | wandb.log({f"block_{i}_l1_influence": block_influence for i, block_influence in enumerate(l1_block_influences)}) 274 | wandb.log({f"block_{i}_relative_l1_influence": block_influence for i, block_influence in enumerate(relative_l1_block_influences)}) 275 | wandb.log({f"block_{i}_l2_influence": block_influence for i, block_influence in enumerate(l2_block_influences)}) 276 | wandb.log({f"block_{i}_relative_l2_influence": block_influence for i, block_influence in enumerate(relative_l2_block_influences)}) 277 | 278 | # Compute the block logit shapley 279 | max_samples_per_proc = None 280 | if args.limit_shapley_samples is not None: 281 | max_samples_per_proc = args.limit_shapley_samples // args.world_size 282 | print(f"Total samples limit: {args.limit_shapley_samples}. Capping the max_samples_per_proc for logit shapley computation to be: {max_samples_per_proc}") 283 | block_logit_shapley, block_loss_shapley = compute_block_shapley(model, train_loader, device, max_samples_per_proc=max_samples_per_proc) 284 | print("Block logit shapley:", block_logit_shapley) 285 | print("Block loss shapley:", block_loss_shapley) 286 | logit_shapley_block_influence = [-x for x in block_logit_shapley] # negative shapely (lower distance) indicates higher importance 287 | loss_shapley_block_influence = [-x for x in block_loss_shapley] # negative shapely (lower distance) indicates higher importance 288 | if wandb.run is not None: 289 | wandb.log({f"block_{i}_logit_shapley_influence": block_influence for i, block_influence in enumerate(logit_shapley_block_influence)}) 290 | wandb.log({f"block_{i}_loss_shapley_influence": block_influence for i, block_influence in enumerate(loss_shapley_block_influence)}) 291 | 292 | block_influence_list = [("cosine", cosine_block_influences), ("relative_l1", relative_l1_block_influences), 293 | ("relative_l2", relative_l2_block_influences), ("logit_shapley", logit_shapley_block_influence), 294 | ("loss_shapley", loss_shapley_block_influence)] 295 | for influence_name, block_influences in block_influence_list: 296 | print("Using block influence method:", influence_name) 297 | print("Block influence values:", block_influences) 298 | sorted_blocks = np.argsort(block_influences) # ascending order 299 | print("Sorted block list:", sorted_blocks) 300 | 301 | remaining_blocks = list(range(num_model_layers)) 302 | weighted_acc_list = [] 303 | perplexity_list = [] 304 | iterator = -1 305 | for _ in range(len(sorted_blocks)+1): # one additional iteration for no dropping 306 | if iterator > -1: # do nothing for the first block i.e., all blocks are selected 307 | lowest_block = sorted_blocks[iterator] # prune blocks based on the estimated block influence 308 | print(f"Removing block {lowest_block} with lowest influence: {block_influences[lowest_block]}") 309 | remaining_blocks = [i for i in remaining_blocks if i != lowest_block] # remove lowest block 310 | print("Remaining blocks:", remaining_blocks) 311 | model.select_blocks(remaining_blocks) # use all blocks 312 | _, _, weighted_acc = evaluate_mmlu(model, tokenizer, mmlu_loader, device, f"{influence_name}_blocks_pruned_{iterator+1}") 313 | _, avg_perplexity = evaluate_model(model, eval_loader, device, f"{influence_name}_blocks_pruned_{iterator+1}") 314 | weighted_acc_list.append(weighted_acc) 315 | perplexity_list.append(avg_perplexity) 316 | iterator += 1 317 | 318 | print(f">>>>> Block pruning statistics using 
{influence_name} metric <<<<<") 319 | print(f"{influence_name} weighted ACC list: {weighted_acc_list}") 320 | print(f"{influence_name} perplexity list: {perplexity_list}") 321 | print("="*25) 322 | 323 | eval_time_elapsed_h = (time.time() - eval_start_time) / (60 * 60) # convert seconds into hours 324 | print(f"Block pruning evaluation completed / time elapsed: {eval_time_elapsed_h:.2f}h") 325 | 326 | if wandb.run is not None: 327 | wandb.finish() 328 | 329 | 330 | if __name__ == "__main__": 331 | supported_datasets = ['pg19', 'cc_news', 'wikitext-2', 'bookcorpus', 'c4', 'openwebtext', 'slimpajama'] 332 | 333 | # Create ArgumentParser object 334 | parser = argparse.ArgumentParser(description='Argument parser for LLM block influence evaluator') 335 | 336 | # Add arguments 337 | parser.add_argument('-d', '--dataset', default='wikitext-2', choices=supported_datasets, 338 | help='Dataset name (default: wikitext-2)') 339 | parser.add_argument('-m', '--model-name', default='llama-2', choices=['llama-2', 'mistral'], 340 | help='Model name (default: llama-2)') 341 | parser.add_argument('-s', '--model-size', default='7b', choices=['7b'], 342 | help='Model size (default: 7b)') 343 | parser.add_argument('--use-instruct-model', action='store_true', default=False, 344 | help='Use instruction-tuned model rather than the base model') 345 | parser.add_argument('--batch-size', type=int, default=1, 346 | help='Batch size per process (default: 1)') 347 | parser.add_argument('--test-batch-size', type=int, default=None, 348 | help='Batch size per process for testing (default: equal to --batch-size)') 349 | parser.add_argument('--sequence-length', type=int, default=1024, 350 | help='Sequence length for computing the model perplexity (default: 1024)') 351 | parser.add_argument('--subsample-size', type=int, default=1000000, 352 | help='Dataset subsample size in terms of number of docs (default: 1M)') 353 | parser.add_argument('--num-workers', type=int, default=8, 354 | help='Number of workers for the dataloader (default: 8)') 355 | parser.add_argument('--seed', type=int, default=43, 356 | help='seed value (default: 43)') 357 | parser.add_argument('--wandb-project', type=str, default=None, 358 | help='W&B project name (none indicates no W&B initialization)') 359 | parser.add_argument('--limit-shapley-samples', type=int, default=None, 360 | help='limit the number of samples to the specified value for shapley computation (default: None i.e., no limit)') 361 | 362 | # Parse the arguments 363 | args = parser.parse_args() 364 | 365 | if args.test_batch_size is None: 366 | args.test_batch_size = args.batch_size 367 | print("Setting test batch size to be equal to batch size:", args.test_batch_size) 368 | 369 | main(args) 370 | -------------------------------------------------------------------------------- /block_influence/llama_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base code taken from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py 3 | """ 4 | 5 | # coding=utf-8 6 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 7 | # 8 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 9 | # and OPT implementations in this library. It has been modified from its 10 | # original forms to accommodate minor architectural differences compared 11 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
12 | # 13 | # Licensed under the Apache License, Version 2.0 (the "License"); 14 | # you may not use this file except in compliance with the License. 15 | # You may obtain a copy of the License at 16 | # 17 | # http://www.apache.org/licenses/LICENSE-2.0 18 | # 19 | # Unless required by applicable law or agreed to in writing, software 20 | # distributed under the License is distributed on an "AS IS" BASIS, 21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | # See the License for the specific language governing permissions and 23 | # limitations under the License. 24 | """ PyTorch LLaMA model.""" 25 | import math 26 | import numpy as np 27 | from typing import List, Optional, Tuple, Union 28 | 29 | import torch 30 | import torch.nn.functional as F 31 | import torch.utils.checkpoint 32 | from torch import nn 33 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 34 | 35 | from transformers.activations import ACT2FN 36 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast 37 | from transformers.modeling_utils import PreTrainedModel 38 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 39 | from transformers.utils import ( 40 | add_start_docstrings, 41 | add_start_docstrings_to_model_forward, 42 | is_flash_attn_2_available, 43 | logging, 44 | replace_return_docstrings 45 | ) 46 | from transformers.models.llama.configuration_llama import LlamaConfig 47 | 48 | import sys 49 | sys.path.append('.') 50 | from block_influence import BlockInfluenceEstimator 51 | 52 | 53 | if is_flash_attn_2_available(): 54 | from flash_attn import flash_attn_func, flash_attn_varlen_func 55 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 56 | 57 | 58 | logger = logging.get_logger(__name__) 59 | 60 | _CONFIG_FOR_DOC = "LlamaConfig" 61 | 62 | 63 | def _get_unpad_data(padding_mask): 64 | seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) 65 | indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() 66 | max_seqlen_in_batch = seqlens_in_batch.max().item() 67 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 68 | return ( 69 | indices, 70 | cu_seqlens, 71 | max_seqlen_in_batch, 72 | ) 73 | 74 | 75 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 76 | def _make_causal_mask( 77 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 78 | ): 79 | """ 80 | Make causal mask used for bi-directional self-attention. 81 | """ 82 | bsz, tgt_len = input_ids_shape 83 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 84 | mask_cond = torch.arange(mask.size(-1), device=device) 85 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 86 | mask = mask.to(dtype) 87 | 88 | if past_key_values_length > 0: 89 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 90 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 91 | 92 | 93 | # Copied from transformers.models.bart.modeling_bart._expand_mask 94 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 95 | """ 96 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
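    For intuition (illustrative values): a padding mask `[[1, 1, 0]]` with `tgt_len=3` expands to shape
    `[1, 1, 3, 3]`; after inversion and masking, kept positions hold `0` and the padded last column holds
    `torch.finfo(dtype).min`, so the result can be added directly to the raw attention scores.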
97 | """ 98 | bsz, src_len = mask.size() 99 | tgt_len = tgt_len if tgt_len is not None else src_len 100 | 101 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 102 | 103 | inverted_mask = 1.0 - expanded_mask 104 | 105 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 106 | 107 | 108 | class LlamaRMSNorm(nn.Module): 109 | def __init__(self, hidden_size, eps=1e-6): 110 | """ 111 | LlamaRMSNorm is equivalent to T5LayerNorm 112 | """ 113 | super().__init__() 114 | self.weight = nn.Parameter(torch.ones(hidden_size)) 115 | self.variance_epsilon = eps 116 | 117 | def forward(self, hidden_states): 118 | input_dtype = hidden_states.dtype 119 | hidden_states = hidden_states.to(torch.float32) 120 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 121 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 122 | return self.weight * hidden_states.to(input_dtype) 123 | 124 | 125 | ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) 126 | 127 | 128 | class LlamaRotaryEmbedding(nn.Module): 129 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 130 | super().__init__() 131 | 132 | self.dim = dim 133 | self.max_position_embeddings = max_position_embeddings 134 | self.base = base 135 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 136 | self.register_buffer("inv_freq", inv_freq, persistent=False) 137 | 138 | # Build here to make `torch.jit.trace` work. 139 | self._set_cos_sin_cache( 140 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 141 | ) 142 | 143 | def _set_cos_sin_cache(self, seq_len, device, dtype): 144 | self.max_seq_len_cached = seq_len 145 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 146 | 147 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 148 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 149 | emb = torch.cat((freqs, freqs), dim=-1) 150 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 151 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 152 | 153 | def forward(self, x, seq_len=None): 154 | # x: [bs, num_attention_heads, seq_len, head_size] 155 | if seq_len > self.max_seq_len_cached: 156 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 157 | 158 | return ( 159 | self.cos_cached[:seq_len].to(dtype=x.dtype), 160 | self.sin_cached[:seq_len].to(dtype=x.dtype), 161 | ) 162 | 163 | 164 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 165 | """LlamaRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" 166 | 167 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 168 | self.scaling_factor = scaling_factor 169 | super().__init__(dim, max_position_embeddings, base, device) 170 | 171 | def _set_cos_sin_cache(self, seq_len, device, dtype): 172 | self.max_seq_len_cached = seq_len 173 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 174 | t = t / self.scaling_factor 175 | 176 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 177 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 178 | emb = torch.cat((freqs, freqs), dim=-1) 179 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 180 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 181 | 182 | 183 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 184 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" 185 | 186 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 187 | self.scaling_factor = scaling_factor 188 | super().__init__(dim, max_position_embeddings, base, device) 189 | 190 | def _set_cos_sin_cache(self, seq_len, device, dtype): 191 | self.max_seq_len_cached = seq_len 192 | 193 | if seq_len > self.max_position_embeddings: 194 | base = self.base * ( 195 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 196 | ) ** (self.dim / (self.dim - 2)) 197 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 198 | self.register_buffer("inv_freq", inv_freq, persistent=False) 199 | 200 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 201 | 202 | freqs = torch.einsum("i,j->ij", t, self.inv_freq) 203 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 204 | emb = torch.cat((freqs, freqs), dim=-1) 205 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 206 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 207 | 208 | 209 | def rotate_half(x): 210 | """Rotates half the hidden dims of the input.""" 211 | x1 = x[..., : x.shape[-1] // 2] 212 | x2 = x[..., x.shape[-1] // 2 :] 213 | return torch.cat((-x2, x1), dim=-1) 214 | 215 | 216 | # Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb 217 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 218 | cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim] 219 | sin = sin[position_ids].unsqueeze(1) 220 | q_embed = (q * cos) + (rotate_half(q) * sin) 221 | k_embed = (k * cos) + (rotate_half(k) * sin) 222 | return q_embed, k_embed 223 | 224 | 225 | class LlamaMLP(nn.Module): 226 | def __init__(self, config): 227 | super().__init__() 228 | self.config = config 229 | self.hidden_size = config.hidden_size 230 | self.intermediate_size = config.intermediate_size 231 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 232 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 233 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 234 | self.act_fn = ACT2FN[config.hidden_act] 235 | 236 | def forward(self, x): 237 | if self.config.pretraining_tp > 1: 238 | slice = self.intermediate_size // 
self.config.pretraining_tp 239 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) 240 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) 241 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) 242 | 243 | gate_proj = torch.cat( 244 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 245 | ) 246 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) 247 | 248 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) 249 | down_proj = [ 250 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) 251 | ] 252 | down_proj = sum(down_proj) 253 | else: 254 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 255 | 256 | return down_proj 257 | 258 | 259 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 260 | """ 261 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, 262 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 263 | """ 264 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 265 | if n_rep == 1: 266 | return hidden_states 267 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 268 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 269 | 270 | 271 | class LlamaAttention(nn.Module): 272 | """Multi-headed attention from 'Attention Is All You Need' paper""" 273 | 274 | def __init__(self, config: LlamaConfig): 275 | super().__init__() 276 | self.config = config 277 | self.hidden_size = config.hidden_size 278 | self.num_heads = config.num_attention_heads 279 | self.head_dim = self.hidden_size // self.num_heads 280 | self.num_key_value_heads = config.num_key_value_heads 281 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 282 | self.max_position_embeddings = config.max_position_embeddings 283 | self.rope_theta = config.rope_theta 284 | 285 | if (self.head_dim * self.num_heads) != self.hidden_size: 286 | raise ValueError( 287 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 288 | f" and `num_heads`: {self.num_heads})." 
289 | ) 290 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) 291 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 292 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 293 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) 294 | self._init_rope() 295 | 296 | def _init_rope(self): 297 | if self.config.rope_scaling is None: 298 | self.rotary_emb = LlamaRotaryEmbedding( 299 | self.head_dim, 300 | max_position_embeddings=self.max_position_embeddings, 301 | base=self.rope_theta, 302 | ) 303 | else: 304 | scaling_type = self.config.rope_scaling["type"] 305 | scaling_factor = self.config.rope_scaling["factor"] 306 | if scaling_type == "linear": 307 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding( 308 | self.head_dim, 309 | max_position_embeddings=self.max_position_embeddings, 310 | scaling_factor=scaling_factor, 311 | base=self.rope_theta, 312 | ) 313 | elif scaling_type == "dynamic": 314 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( 315 | self.head_dim, 316 | max_position_embeddings=self.max_position_embeddings, 317 | scaling_factor=scaling_factor, 318 | base=self.rope_theta, 319 | ) 320 | else: 321 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 322 | 323 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 324 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 325 | 326 | def forward( 327 | self, 328 | hidden_states: torch.Tensor, 329 | attention_mask: Optional[torch.Tensor] = None, 330 | position_ids: Optional[torch.LongTensor] = None, 331 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 332 | output_attentions: bool = False, 333 | use_cache: bool = False, 334 | padding_mask: Optional[torch.LongTensor] = None, 335 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 336 | bsz, q_len, _ = hidden_states.size() 337 | 338 | if self.config.pretraining_tp > 1: 339 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp 340 | query_slices = self.q_proj.weight.split( 341 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 342 | ) 343 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) 344 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) 345 | 346 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] 347 | query_states = torch.cat(query_states, dim=-1) 348 | 349 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] 350 | key_states = torch.cat(key_states, dim=-1) 351 | 352 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] 353 | value_states = torch.cat(value_states, dim=-1) 354 | 355 | else: 356 | query_states = self.q_proj(hidden_states) 357 | key_states = self.k_proj(hidden_states) 358 | value_states = self.v_proj(hidden_states) 359 | 360 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 361 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 362 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 363 | 364 | kv_seq_len = key_states.shape[-2] 365 | if 
past_key_value is not None: 366 | kv_seq_len += past_key_value[0].shape[-2] 367 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 368 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 369 | 370 | if past_key_value is not None: 371 | # reuse k, v, self_attention 372 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 373 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 374 | 375 | past_key_value = (key_states, value_states) if use_cache else None 376 | 377 | key_states = repeat_kv(key_states, self.num_key_value_groups) 378 | value_states = repeat_kv(value_states, self.num_key_value_groups) 379 | 380 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 381 | 382 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 383 | raise ValueError( 384 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 385 | f" {attn_weights.size()}" 386 | ) 387 | 388 | if attention_mask is not None: 389 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 390 | raise ValueError( 391 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 392 | ) 393 | attn_weights = attn_weights + attention_mask 394 | 395 | # upcast attention to fp32 396 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 397 | attn_output = torch.matmul(attn_weights, value_states) 398 | 399 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 400 | raise ValueError( 401 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 402 | f" {attn_output.size()}" 403 | ) 404 | 405 | attn_output = attn_output.transpose(1, 2).contiguous() 406 | 407 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 408 | 409 | if self.config.pretraining_tp > 1: 410 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) 411 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) 412 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) 413 | else: 414 | attn_output = self.o_proj(attn_output) 415 | 416 | if not output_attentions: 417 | attn_weights = None 418 | 419 | return attn_output, attn_weights, past_key_value 420 | 421 | 422 | class LlamaFlashAttention2(LlamaAttention): 423 | """ 424 | Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays 425 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 426 | flash attention and deal with padding tokens in case the input contains any of them. 
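    Concretely, `_flash_attention_forward` below unpads the batch and calls `flash_attn_varlen_func`
    whenever a `padding_mask` is supplied, and otherwise takes the dense `flash_attn_func` fast path.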
427 | """ 428 | 429 | def forward( 430 | self, 431 | hidden_states: torch.Tensor, 432 | attention_mask: Optional[torch.Tensor] = None, 433 | position_ids: Optional[torch.LongTensor] = None, 434 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 435 | output_attentions: bool = False, 436 | use_cache: bool = False, 437 | padding_mask: Optional[torch.LongTensor] = None, 438 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 439 | # LlamaFlashAttention2 attention does not support output_attentions 440 | output_attentions = False 441 | 442 | bsz, q_len, _ = hidden_states.size() 443 | 444 | query_states = self.q_proj(hidden_states) 445 | key_states = self.k_proj(hidden_states) 446 | value_states = self.v_proj(hidden_states) 447 | 448 | # Flash attention requires the input to have the shape 449 | # batch_size x seq_length x head_dime x hidden_dim 450 | # therefore we just need to keep the original shape 451 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 452 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 453 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 454 | 455 | kv_seq_len = key_states.shape[-2] 456 | if past_key_value is not None: 457 | kv_seq_len += past_key_value[0].shape[-2] 458 | 459 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 460 | 461 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 462 | 463 | if past_key_value is not None: 464 | # reuse k, v, self_attention 465 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 466 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 467 | 468 | past_key_value = (key_states, value_states) if use_cache else None 469 | 470 | query_states = query_states.transpose(1, 2) 471 | key_states = key_states.transpose(1, 2) 472 | value_states = value_states.transpose(1, 2) 473 | 474 | # TODO: llama does not have dropout in the config?? 475 | # It is recommended to use dropout with FA according to the docs 476 | # when training. 477 | dropout_rate = 0.0 # if not self.training else self.attn_dropout 478 | 479 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 480 | # therefore the input hidden states gets silently casted in float32. Hence, we need 481 | # cast them back in float16 just to be sure everything works as expected. 482 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 483 | # in fp32. (LlamaRMSNorm handles it correctly) 484 | input_dtype = query_states.dtype 485 | if input_dtype == torch.float32: 486 | logger.warning_once( 487 | "The input hidden states seems to be silently casted in float32, this might be related to" 488 | " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 489 | " float16." 
490 | ) 491 | 492 | query_states = query_states.to(torch.float16) 493 | key_states = key_states.to(torch.float16) 494 | value_states = value_states.to(torch.float16) 495 | 496 | attn_output = self._flash_attention_forward( 497 | query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate 498 | ) 499 | 500 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 501 | attn_output = self.o_proj(attn_output) 502 | 503 | if not output_attentions: 504 | attn_weights = None 505 | 506 | return attn_output, attn_weights, past_key_value 507 | 508 | def _flash_attention_forward( 509 | self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None 510 | ): 511 | """ 512 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 513 | first unpad the input, then computes the attention scores and pad the final attention scores. 514 | 515 | Args: 516 | query_states (`torch.Tensor`): 517 | Input query states to be passed to Flash Attention API 518 | key_states (`torch.Tensor`): 519 | Input key states to be passed to Flash Attention API 520 | value_states (`torch.Tensor`): 521 | Input value states to be passed to Flash Attention API 522 | padding_mask (`torch.Tensor`): 523 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 524 | position of padding tokens and 1 for the position of non-padding tokens. 525 | dropout (`int`, *optional*): 526 | Attention dropout 527 | softmax_scale (`float`, *optional*): 528 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 529 | """ 530 | # Contains at least one padding token in the sequence 531 | if padding_mask is not None: 532 | batch_size = query_states.shape[0] 533 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 534 | query_states, key_states, value_states, padding_mask, query_length 535 | ) 536 | 537 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 538 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 539 | 540 | attn_output_unpad = flash_attn_varlen_func( 541 | query_states, 542 | key_states, 543 | value_states, 544 | cu_seqlens_q=cu_seqlens_q, 545 | cu_seqlens_k=cu_seqlens_k, 546 | max_seqlen_q=max_seqlen_in_batch_q, 547 | max_seqlen_k=max_seqlen_in_batch_k, 548 | dropout_p=dropout, 549 | softmax_scale=softmax_scale, 550 | causal=True, 551 | ) 552 | 553 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 554 | else: 555 | attn_output = flash_attn_func( 556 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True 557 | ) 558 | 559 | return attn_output 560 | 561 | def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): 562 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) 563 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 564 | 565 | key_layer = index_first_axis( 566 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 567 | ) 568 | value_layer = index_first_axis( 569 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 570 | ) 571 | if query_length == kv_seq_len: 572 | query_layer = index_first_axis( 573 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k 574 | ) 575 | cu_seqlens_q = cu_seqlens_k 576 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 
577 | indices_q = indices_k 578 | elif query_length == 1: 579 | max_seqlen_in_batch_q = 1 580 | cu_seqlens_q = torch.arange( 581 | batch_size + 1, dtype=torch.int32, device=query_layer.device 582 | ) # There is a memcpy here, that is very bad. 583 | indices_q = cu_seqlens_q[:-1] 584 | query_layer = query_layer.squeeze(1) 585 | else: 586 | # The -q_len: slice assumes left padding. 587 | padding_mask = padding_mask[:, -query_length:] 588 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) 589 | 590 | return ( 591 | query_layer, 592 | key_layer, 593 | value_layer, 594 | indices_q, 595 | (cu_seqlens_q, cu_seqlens_k), 596 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 597 | ) 598 | 599 | 600 | class LlamaDecoderLayer(nn.Module): 601 | def __init__(self, config: LlamaConfig): 602 | super().__init__() 603 | self.hidden_size = config.hidden_size 604 | self.self_attn = ( 605 | LlamaAttention(config=config) 606 | if not getattr(config, "_flash_attn_2_enabled", False) 607 | else LlamaFlashAttention2(config=config) 608 | ) 609 | self.mlp = LlamaMLP(config) 610 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 611 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 612 | 613 | def forward( 614 | self, 615 | hidden_states: torch.Tensor, 616 | attention_mask: Optional[torch.Tensor] = None, 617 | position_ids: Optional[torch.LongTensor] = None, 618 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 619 | output_attentions: Optional[bool] = False, 620 | use_cache: Optional[bool] = False, 621 | padding_mask: Optional[torch.LongTensor] = None, 622 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 623 | """ 624 | Args: 625 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 626 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 627 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 628 | output_attentions (`bool`, *optional*): 629 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 630 | returned tensors for more detail. 631 | use_cache (`bool`, *optional*): 632 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 633 | (see `past_key_values`). 
634 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 635 | """ 636 | 637 | residual = hidden_states 638 | 639 | hidden_states = self.input_layernorm(hidden_states) 640 | 641 | # Self Attention 642 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 643 | hidden_states=hidden_states, 644 | attention_mask=attention_mask, 645 | position_ids=position_ids, 646 | past_key_value=past_key_value, 647 | output_attentions=output_attentions, 648 | use_cache=use_cache, 649 | padding_mask=padding_mask, 650 | ) 651 | hidden_states = residual + hidden_states 652 | 653 | # Fully Connected 654 | residual = hidden_states 655 | hidden_states = self.post_attention_layernorm(hidden_states) 656 | hidden_states = self.mlp(hidden_states) 657 | hidden_states = residual + hidden_states 658 | 659 | outputs = (hidden_states,) 660 | 661 | if output_attentions: 662 | outputs += (self_attn_weights,) 663 | 664 | if use_cache: 665 | outputs += (present_key_value,) 666 | 667 | return outputs 668 | 669 | 670 | LLAMA_START_DOCSTRING = r""" 671 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 672 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 673 | etc.) 674 | 675 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 676 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 677 | and behavior. 678 | 679 | Parameters: 680 | config ([`LlamaConfig`]): 681 | Model configuration class with all the parameters of the model. Initializing with a config file does not 682 | load the weights associated with the model, only the configuration. Check out the 683 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 684 | """ 685 | 686 | 687 | @add_start_docstrings( 688 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 689 | LLAMA_START_DOCSTRING, 690 | ) 691 | class LlamaPreTrainedModel(PreTrainedModel): 692 | config_class = LlamaConfig 693 | base_model_prefix = "model" 694 | supports_gradient_checkpointing = True 695 | _no_split_modules = ["LlamaDecoderLayer"] 696 | _skip_keys_device_placement = "past_key_values" 697 | _supports_flash_attn_2 = True 698 | 699 | def _init_weights(self, module): 700 | std = self.config.initializer_range 701 | if isinstance(module, nn.Linear): 702 | module.weight.data.normal_(mean=0.0, std=std) 703 | if module.bias is not None: 704 | module.bias.data.zero_() 705 | elif isinstance(module, nn.Embedding): 706 | module.weight.data.normal_(mean=0.0, std=std) 707 | if module.padding_idx is not None: 708 | module.weight.data[module.padding_idx].zero_() 709 | 710 | def _set_gradient_checkpointing(self, module, value=False): 711 | if isinstance(module, LlamaModel): 712 | module.gradient_checkpointing = value 713 | 714 | 715 | LLAMA_INPUTS_DOCSTRING = r""" 716 | Args: 717 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 718 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 719 | it. 720 | 721 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 722 | [`PreTrainedTokenizer.__call__`] for details. 
723 | 724 | [What are input IDs?](../glossary#input-ids) 725 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 726 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 727 | 728 | - 1 for tokens that are **not masked**, 729 | - 0 for tokens that are **masked**. 730 | 731 | [What are attention masks?](../glossary#attention-mask) 732 | 733 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 734 | [`PreTrainedTokenizer.__call__`] for details. 735 | 736 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 737 | `past_key_values`). 738 | 739 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 740 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 741 | information on the default strategy. 742 | 743 | - 1 indicates the head is **not masked**, 744 | - 0 indicates the head is **masked**. 745 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 746 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 747 | config.n_positions - 1]`. 748 | 749 | [What are position IDs?](../glossary#position-ids) 750 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 751 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 752 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 753 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 754 | 755 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 756 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 757 | 758 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 759 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 760 | of shape `(batch_size, sequence_length)`. 761 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 762 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 763 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 764 | model's internal embedding lookup matrix. 765 | use_cache (`bool`, *optional*): 766 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 767 | `past_key_values`). 768 | output_attentions (`bool`, *optional*): 769 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 770 | tensors for more detail. 771 | output_hidden_states (`bool`, *optional*): 772 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 773 | more detail. 774 | return_dict (`bool`, *optional*): 775 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
776 | """ 777 | 778 | 779 | @add_start_docstrings( 780 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", 781 | LLAMA_START_DOCSTRING, 782 | ) 783 | class LlamaModel(LlamaPreTrainedModel): 784 | """ 785 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] 786 | 787 | Args: 788 | config: LlamaConfig 789 | """ 790 | 791 | def __init__(self, config: LlamaConfig): 792 | super().__init__(config) 793 | self.padding_idx = config.pad_token_id 794 | self.vocab_size = config.vocab_size 795 | 796 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 797 | self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 798 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 799 | 800 | self.gradient_checkpointing = False 801 | # Initialize weights and apply final processing 802 | self.post_init() 803 | 804 | # New args for block selection 805 | self.selected_blocks = None 806 | self.block_influence_estimator = None 807 | 808 | def get_input_embeddings(self): 809 | return self.embed_tokens 810 | 811 | def set_input_embeddings(self, value): 812 | self.embed_tokens = value 813 | 814 | def select_blocks(self, block_list: List[int], verbose: bool = True): 815 | if block_list is not None: 816 | assert all([0 <= x < len(self.layers) for x in block_list]), block_list 817 | self.selected_blocks = block_list 818 | if verbose: 819 | print("Selected blocks:", self.selected_blocks) 820 | 821 | def add_block_influence_estimator(self, infl_est: BlockInfluenceEstimator): 822 | self.block_influence_estimator = infl_est 823 | print(f"Added block influence estimator to the model: {self.block_influence_estimator}") 824 | 825 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 826 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 827 | # create causal mask 828 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 829 | combined_attention_mask = None 830 | if input_shape[-1] > 1: 831 | combined_attention_mask = _make_causal_mask( 832 | input_shape, 833 | inputs_embeds.dtype, 834 | device=inputs_embeds.device, 835 | past_key_values_length=past_key_values_length, 836 | ) 837 | 838 | if attention_mask is not None: 839 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 840 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 841 | inputs_embeds.device 842 | ) 843 | combined_attention_mask = ( 844 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 845 | ) 846 | 847 | return combined_attention_mask 848 | 849 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 850 | def forward( 851 | self, 852 | input_ids: torch.LongTensor = None, 853 | attention_mask: Optional[torch.Tensor] = None, 854 | position_ids: Optional[torch.LongTensor] = None, 855 | past_key_values: Optional[List[torch.FloatTensor]] = None, 856 | inputs_embeds: Optional[torch.FloatTensor] = None, 857 | use_cache: Optional[bool] = None, 858 | output_attentions: Optional[bool] = None, 859 | output_hidden_states: Optional[bool] = None, 860 | return_dict: Optional[bool] = None, 861 | ) -> Union[Tuple, BaseModelOutputWithPast]: 862 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 863 | 
output_hidden_states = ( 864 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 865 | ) 866 | use_cache = use_cache if use_cache is not None else self.config.use_cache 867 | 868 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 869 | 870 | # retrieve input_ids and inputs_embeds 871 | if input_ids is not None and inputs_embeds is not None: 872 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 873 | elif input_ids is not None: 874 | batch_size, seq_length = input_ids.shape 875 | elif inputs_embeds is not None: 876 | batch_size, seq_length, _ = inputs_embeds.shape 877 | else: 878 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 879 | 880 | seq_length_with_past = seq_length 881 | past_key_values_length = 0 882 | 883 | if past_key_values is not None: # integrated fixes for cache 884 | first_valid_key = [i for i in range(len(past_key_values)) if past_key_values[i] is not None] 885 | assert len(first_valid_key) > 0, "no valid keys" 886 | past_key_values_length = past_key_values[first_valid_key[0]][0].shape[2] 887 | seq_length_with_past = seq_length_with_past + past_key_values_length 888 | 889 | if position_ids is None: 890 | device = input_ids.device if input_ids is not None else inputs_embeds.device 891 | position_ids = torch.arange( 892 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 893 | ) 894 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) 895 | else: 896 | position_ids = position_ids.view(-1, seq_length).long() 897 | 898 | if inputs_embeds is None: 899 | inputs_embeds = self.embed_tokens(input_ids) 900 | # embed positions 901 | if attention_mask is None: 902 | attention_mask = torch.ones( 903 | (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device 904 | ) 905 | attention_mask = self._prepare_decoder_attention_mask( 906 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 907 | ) 908 | 909 | hidden_states = inputs_embeds 910 | 911 | if self.gradient_checkpointing and self.training: 912 | if use_cache: 913 | logger.warning_once( 914 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
915 | ) 916 | use_cache = False 917 | 918 | # decoder layers 919 | all_hidden_states = () if output_hidden_states else None 920 | all_self_attns = () if output_attentions else None 921 | next_decoder_cache = () if use_cache else None 922 | 923 | for idx, decoder_layer in enumerate(self.layers): 924 | if self.selected_blocks is not None: # skip blocks that are not selected 925 | if idx not in self.selected_blocks: 926 | if use_cache: 927 | next_decoder_cache += (None,) 928 | if output_attentions: 929 | all_self_attns += (None,) 930 | continue 931 | 932 | if output_hidden_states: 933 | all_hidden_states += (hidden_states,) 934 | 935 | past_key_value = past_key_values[idx] if past_key_values is not None else None 936 | 937 | if self.gradient_checkpointing and self.training: 938 | 939 | def create_custom_forward(module): 940 | def custom_forward(*inputs): 941 | # None for past_key_value 942 | return module(*inputs, past_key_value, output_attentions) 943 | 944 | return custom_forward 945 | 946 | layer_outputs = torch.utils.checkpoint.checkpoint( 947 | create_custom_forward(decoder_layer), 948 | hidden_states, 949 | attention_mask, 950 | position_ids, 951 | ) 952 | else: 953 | layer_outputs = decoder_layer( 954 | hidden_states, 955 | attention_mask=attention_mask, 956 | position_ids=position_ids, 957 | past_key_value=past_key_value, 958 | output_attentions=output_attentions, 959 | use_cache=use_cache, 960 | ) 961 | 962 | if self.block_influence_estimator is not None: # update the stats 963 | self.block_influence_estimator.update_block_stats(idx, hidden_states, layer_outputs[0]) 964 | 965 | # default LLaMa update 966 | hidden_states = layer_outputs[0] 967 | 968 | if use_cache: 969 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 970 | 971 | if output_attentions: 972 | all_self_attns += (layer_outputs[1],) 973 | 974 | hidden_states = self.norm(hidden_states) 975 | 976 | # add hidden states from the last decoder layer 977 | if output_hidden_states: 978 | all_hidden_states += (hidden_states,) 979 | 980 | next_cache = next_decoder_cache if use_cache else None 981 | if not return_dict: 982 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 983 | return BaseModelOutputWithPast( 984 | last_hidden_state=hidden_states, 985 | past_key_values=next_cache, 986 | hidden_states=all_hidden_states, 987 | attentions=all_self_attns, 988 | ) 989 | 990 | 991 | class LlamaForCausalLM(LlamaPreTrainedModel): 992 | _tied_weights_keys = ["lm_head.weight"] 993 | 994 | def __init__(self, config): 995 | super().__init__(config) 996 | self.model = LlamaModel(config) 997 | self.vocab_size = config.vocab_size 998 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 999 | 1000 | # Initialize weights and apply final processing 1001 | self.post_init() 1002 | 1003 | def get_input_embeddings(self): 1004 | return self.model.embed_tokens 1005 | 1006 | def set_input_embeddings(self, value): 1007 | self.model.embed_tokens = value 1008 | 1009 | def get_output_embeddings(self): 1010 | return self.lm_head 1011 | 1012 | def set_output_embeddings(self, new_embeddings): 1013 | self.lm_head = new_embeddings 1014 | 1015 | def set_decoder(self, decoder): 1016 | self.model = decoder 1017 | 1018 | def get_decoder(self): 1019 | return self.model 1020 | 1021 | def get_num_model_layers(self) -> int: 1022 | return len(self.model.layers) 1023 | 1024 | def select_blocks(self, block_list: List[int], verbose: bool = True): 1025 | 
self.model.select_blocks(block_list, verbose) 1026 | 1027 | def add_block_influence_estimator(self, infl_est: BlockInfluenceEstimator): 1028 | self.model.add_block_influence_estimator(infl_est) 1029 | 1030 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) 1031 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) 1032 | def forward( 1033 | self, 1034 | input_ids: torch.LongTensor = None, 1035 | attention_mask: Optional[torch.Tensor] = None, 1036 | position_ids: Optional[torch.LongTensor] = None, 1037 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1038 | inputs_embeds: Optional[torch.FloatTensor] = None, 1039 | labels: Optional[torch.LongTensor] = None, 1040 | use_cache: Optional[bool] = None, 1041 | output_attentions: Optional[bool] = None, 1042 | output_hidden_states: Optional[bool] = None, 1043 | return_dict: Optional[bool] = None, 1044 | ) -> Union[Tuple, CausalLMOutputWithPast]: 1045 | r""" 1046 | Args: 1047 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 1048 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 1049 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 1050 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 1051 | 1052 | Returns: 1053 | 1054 | Example: 1055 | 1056 | ```python 1057 | >>> from transformers import AutoTokenizer, LlamaForCausalLM 1058 | 1059 | >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) 1060 | >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) 1061 | 1062 | >>> prompt = "Hey, are you conscious? Can you talk to me?" 1063 | >>> inputs = tokenizer(prompt, return_tensors="pt") 1064 | 1065 | >>> # Generate 1066 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30) 1067 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 1068 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
1069 | ```""" 1070 | 1071 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1072 | output_hidden_states = ( 1073 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1074 | ) 1075 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1076 | 1077 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 1078 | outputs = self.model( 1079 | input_ids=input_ids, 1080 | attention_mask=attention_mask, 1081 | position_ids=position_ids, 1082 | past_key_values=past_key_values, 1083 | inputs_embeds=inputs_embeds, 1084 | use_cache=use_cache, 1085 | output_attentions=output_attentions, 1086 | output_hidden_states=output_hidden_states, 1087 | return_dict=return_dict, 1088 | ) 1089 | 1090 | hidden_states = outputs[0] 1091 | if self.config.pretraining_tp > 1: 1092 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) 1093 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] 1094 | logits = torch.cat(logits, dim=-1) 1095 | else: 1096 | logits = self.lm_head(hidden_states) 1097 | logits = logits.float() 1098 | 1099 | loss = None 1100 | if labels is not None: 1101 | # Shift so that tokens < n predict n 1102 | shift_logits = logits[..., :-1, :].contiguous() 1103 | shift_labels = labels[..., 1:].contiguous() 1104 | # Flatten the tokens 1105 | loss_fct = CrossEntropyLoss() 1106 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 1107 | shift_labels = shift_labels.view(-1) 1108 | # Enable model parallelism 1109 | shift_labels = shift_labels.to(shift_logits.device) 1110 | loss = loss_fct(shift_logits, shift_labels) 1111 | 1112 | if not return_dict: 1113 | output = (logits,) + outputs[1:] 1114 | return (loss,) + output if loss is not None else output 1115 | 1116 | return CausalLMOutputWithPast( 1117 | loss=loss, 1118 | logits=logits, 1119 | past_key_values=outputs.past_key_values, 1120 | hidden_states=outputs.hidden_states, 1121 | attentions=outputs.attentions, 1122 | ) 1123 | 1124 | def prepare_inputs_for_generation( 1125 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 1126 | ): 1127 | if past_key_values: 1128 | input_ids = input_ids[:, -1:] 1129 | 1130 | position_ids = kwargs.get("position_ids", None) 1131 | if attention_mask is not None and position_ids is None: 1132 | # create position_ids on the fly for batch generation 1133 | position_ids = attention_mask.long().cumsum(-1) - 1 1134 | position_ids.masked_fill_(attention_mask == 0, 1) 1135 | if past_key_values: 1136 | position_ids = position_ids[:, -1].unsqueeze(-1) 1137 | 1138 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 1139 | if inputs_embeds is not None and past_key_values is None: 1140 | model_inputs = {"inputs_embeds": inputs_embeds} 1141 | else: 1142 | model_inputs = {"input_ids": input_ids} 1143 | 1144 | model_inputs.update( 1145 | { 1146 | "position_ids": position_ids, 1147 | "past_key_values": past_key_values, 1148 | "use_cache": kwargs.get("use_cache"), 1149 | "attention_mask": attention_mask, 1150 | } 1151 | ) 1152 | return model_inputs 1153 | 1154 | @staticmethod 1155 | def _reorder_cache(past_key_values, beam_idx): 1156 | reordered_past = () 1157 | for layer_past in past_key_values: 1158 | if layer_past is None: 1159 | reordered_past += (None,) 1160 | else: 1161 | reordered_past += ( 1162 | 
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), 1163 | ) 1164 | return reordered_past 1165 | -------------------------------------------------------------------------------- /block_influence_mmlu_shapley/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export TOKENIZERS_PARALLELISM=false # disable tokenizer warning 4 | pip install wget 5 | 6 | # Get the DDP args 7 | # HEAD_NODE_IP=$1 8 | # NUM_NODES=$2 9 | NUM_GPUS_PER_NODE=2 10 | echo "Head node IP: ${HEAD_NODE_IP} / # nodes: ${NUM_NODES} / # GPUs per node: ${NUM_GPUS_PER_NODE}" 11 | 12 | # Check if HEAD_NODE_IP is given 13 | if [ -z "${HEAD_NODE_IP}" ]; then 14 | echo "No head node IP found. Using torchrun runner." 15 | RUNNER_CMD="torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS_PER_NODE}" 16 | else 17 | export WORLD_SIZE=${SLURM_NTASKS} 18 | export RANK=${SLURM_PROCID} 19 | export LOCAL_RANK=${SLURM_LOCALID} 20 | export MASTER_ADDR=${HEAD_NODE_IP} 21 | export MASTER_PORT=29500 22 | echo "python args / world size: ${WORLD_SIZE} / rank: ${RANK} / local rank: ${LOCAL_RANK} / master addr: ${MASTER_ADDR} / master port: ${MASTER_PORT}" 23 | 24 | RUNNER_CMD="python" 25 | fi 26 | 27 | DEFAULT_MODEL="mistral" 28 | MODEL=${3:-$DEFAULT_MODEL} 29 | echo "Using model: ${MODEL}" 30 | 31 | EXTRA_ARGS="" 32 | DEFAULT_ITERATIVE_PRUNING="false" 33 | ITERATIVE_PRUNING=${4:-$DEFAULT_ITERATIVE_PRUNING} 34 | if [ "${ITERATIVE_PRUNING}" = "true" ]; then 35 | EXTRA_ARGS="--iterative-pruning" 36 | echo "Enabling iterative pruning" 37 | fi 38 | echo "Extra args: ${EXTRA_ARGS}" 39 | 40 | ${RUNNER_CMD} block_influence_mmlu_shapley/evaluate_mmlu_block_influence.py \ 41 | --dataset "openwebtext" \ 42 | --model-name ${MODEL} \ 43 | --model-size 7b \ 44 | --batch-size 1 \ 45 | --sequence-length 2048 \ 46 | --subsample-size 250000 \ 47 | --compute-shapley-on-test-set \ 48 | --wandb-project 'block_influence_mmlu_shapley' \ 49 | ${EXTRA_ARGS} 50 | -------------------------------------------------------------------------------- /block_influence_mmlu_shapley/evaluate_mmlu_block_influence.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import os 4 | import time 5 | import json 6 | import random 7 | import argparse 8 | from tqdm import tqdm 9 | from typing import Tuple, List, Dict, Union 10 | 11 | import wandb 12 | 13 | import torch 14 | import numpy as np 15 | import pandas as pd 16 | 17 | from transformers import AutoTokenizer, AutoConfig, PreTrainedTokenizer 18 | 19 | import sys 20 | sys.path.append('block_influence/') 21 | from llama_model import LlamaForCausalLM 22 | from mistral_model import MistralForCausalLM 23 | 24 | from dataset import NLPDataset, get_dataloader 25 | from train_utils import get_num_model_params 26 | from dist_utils import init_distributed_env, is_main_proc, wait_for_other_procs, reduce_tensor 27 | from block_influence import BlockInfluenceEstimator 28 | 29 | from evals.mmlu import format_example, gen_prompt 30 | from evals.dist_mmlu import MMLUDataset, evaluate_mmlu 31 | 32 | 33 | class MMLUTrainDataset(MMLUDataset): 34 | def __init__(self, tokenizer: PreTrainedTokenizer, model_name: str, dataset_path: str = None, n_train: int = 5): 35 | super().__init__(tokenizer, f"{model_name}_train_shapley", dataset_path, n_train) 36 | 37 | def tokenize_mmlu(self, tokenizer: PreTrainedTokenizer, n_train: int): 38 | subjects = sorted( 39 | [ 40 | f.split("_test.csv")[0] 41 | 
for f in os.listdir(os.path.join(self.dataset_path, "test")) 42 | if "_test.csv" in f 43 | ] 44 | ) 45 | 46 | for subject in subjects: 47 | # Uses the vaildation set instead of the test set used for computing statistics 48 | dev_df = pd.read_csv( 49 | os.path.join(self.dataset_path, "dev", subject + "_dev.csv"), header=None 50 | )[:n_train] 51 | val_df = pd.read_csv( 52 | os.path.join(self.dataset_path, "val", subject + "_val.csv"), header=None 53 | ) 54 | 55 | for i in range(val_df.shape[0]): 56 | # get prompt and make sure it fits 57 | k = n_train 58 | prompt_end = format_example(val_df, i, include_answer=True) # included with the label 59 | train_prompt = gen_prompt(dev_df, subject, k) 60 | prompt = train_prompt + prompt_end 61 | 62 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 63 | 64 | while input_ids.shape[-1] > 2048: 65 | k -= 1 66 | train_prompt = gen_prompt(dev_df, subject, k) 67 | prompt = train_prompt + prompt_end 68 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 69 | label = val_df.iloc[i, val_df.shape[1] - 1] 70 | 71 | # Add to the list 72 | assert input_ids.shape[0] == 1, input_ids.shape 73 | self.prompts.append(input_ids[0, :].clone()) 74 | self.labels.append(label) 75 | self.subjects.append(subject) 76 | 77 | def __len__(self): 78 | return len(self.prompts) 79 | 80 | def __getitem__(self, idx): 81 | return {"input_ids": self.prompts[idx], "label": self.label2idx[self.labels[idx]], 82 | "subject": self.subjects[idx]} 83 | 84 | 85 | def load_model(args, only_tokenizer=False, pretrained=False): 86 | # assumes huggingface login: `huggingface-cli login`` 87 | if args.model_name == "llama-2": 88 | if args.use_instruct_model: 89 | model_name = f"meta-llama/Llama-2-{args.model_size.lower()}-chat-hf" 90 | else: 91 | model_name = f"meta-llama/Llama-2-{args.model_size.lower()}-hf" 92 | elif args.model_name == "mistral": 93 | if args.use_instruct_model: 94 | model_name = f"mistralai/Mistral-{args.model_size.upper()}-Instruct-v0.2" 95 | else: 96 | model_name = f"mistralai/Mistral-{args.model_size.upper()}-v0.1" 97 | else: 98 | raise RuntimeError(f"Unsupported model: {args.model_name}") 99 | print("!! 
Loading model:", model_name) 100 | 101 | # Load the tokenizer 102 | tokenizer = AutoTokenizer.from_pretrained(model_name) 103 | if only_tokenizer: 104 | return tokenizer 105 | 106 | # Load the model as well as the tokenizer 107 | config = AutoConfig.from_pretrained(model_name) 108 | print("Config:", config) 109 | kwargs = dict(torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") 110 | print("Model precision:", kwargs["torch_dtype"]) 111 | if pretrained: 112 | print("Using pretrained model...") 113 | 114 | if args.model_name == "llama-2": 115 | if not pretrained: 116 | model = LlamaForCausalLM(config).to(kwargs["torch_dtype"]) 117 | else: 118 | model = LlamaForCausalLM.from_pretrained(model_name, **kwargs) 119 | elif args.model_name == "mistral": 120 | if not pretrained: 121 | model = MistralForCausalLM(config).to(kwargs["torch_dtype"]) 122 | else: 123 | model = MistralForCausalLM.from_pretrained(model_name, **kwargs) 124 | else: 125 | raise RuntimeError(f"Unsupported model: {args.model_name}") 126 | return model, tokenizer 127 | 128 | 129 | def compute_log_probs(logits: torch.Tensor, target_ids: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]: 130 | # Apply softmax and log to obtain log probabilities from logits (summing original logits would be incorrect) 131 | log_probs = torch.log_softmax(logits.float(), dim=-1) 132 | 133 | log_probs = torch.gather(log_probs, 2, target_ids.unsqueeze(-1)).squeeze(-1) 134 | sequence_log_prob = log_probs.sum(dim=1).cpu().float().numpy() 135 | 136 | # Calculate perplexity 137 | sequence_length = target_ids.size(-1) 138 | assert sequence_length > 0, logits 139 | sequence_perplexity = np.exp(-sequence_log_prob / sequence_length) 140 | 141 | return sequence_perplexity, sequence_log_prob 142 | 143 | 144 | @torch.no_grad() 145 | def evaluate_model(model: torch.nn.Module, eval_loader: torch.utils.data.DataLoader, device: torch.device, split_name: str): 146 | model.eval() 147 | avg_sequence_perplexity = 0. 148 | avg_loss = 0. 
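    # num_ex counts the examples processed on this rank; the three running totals are aggregated across
    # processes after the loop before being averaged over the total example count.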
149 | num_ex = 0 150 | 151 | for batch in tqdm(eval_loader): 152 | tokenized_input = batch["input_ids"].to(device) 153 | 154 | # Forward prop through the model (will also populate the loss, but one extra logit) 155 | outputs = model(tokenized_input, labels=tokenized_input) 156 | 157 | # Compute metrics on top of LM logits 158 | lm_logits = outputs.logits[:, :-1, :] # BTD format (discard the final logit) 159 | target_ids = tokenized_input[:, 1:] # input ids strided by one 160 | assert len(lm_logits.shape) == 3, lm_logits.shape 161 | assert len(target_ids.shape) == 2, target_ids.shape 162 | assert lm_logits.shape[1] == target_ids.shape[1], f"{lm_logits.shape} != {target_ids.shape}" 163 | perplexity, log_prob = compute_log_probs(lm_logits, target_ids) 164 | 165 | avg_sequence_perplexity += float(perplexity.sum()) 166 | avg_loss += float(outputs.loss) 167 | num_ex += len(tokenized_input) 168 | 169 | # Collect the stats from all processes 170 | avg_sequence_perplexity = float(reduce_tensor(torch.tensor(avg_sequence_perplexity).to(device))) 171 | avg_loss = float(reduce_tensor(torch.tensor(avg_loss).to(device))) 172 | num_ex = int(reduce_tensor(torch.tensor(num_ex).to(device))) 173 | 174 | avg_sequence_perplexity = avg_sequence_perplexity / num_ex 175 | avg_loss = avg_loss / num_ex 176 | output_dict = {"split": split_name, "num_ex": num_ex, "avg_loss": avg_loss, "avg_seq_perplexity": avg_sequence_perplexity} 177 | print(json.dumps(output_dict)) 178 | if split_name is not None and wandb.run is not None: 179 | wandb.log({f"eval_{split_name}": {"num_ex": num_ex, "avg_loss": avg_loss, "avg_seq_perplexity": avg_sequence_perplexity}}) 180 | return avg_loss, avg_sequence_perplexity 181 | 182 | 183 | def extract_label_logits(logits: torch.Tensor, tokenized_ids: Dict[str, int], return_probs: bool = False) -> torch.Tensor: 184 | assert logits.shape[0] == 1, f"batch size should be 1 / found: {logits.shape}" 185 | logits = logits[:, -1, :].flatten() # BSV format 186 | logit_tensor = torch.tensor( 187 | [ 188 | logits[tokenized_ids["A"]], 189 | logits[tokenized_ids["B"]], 190 | logits[tokenized_ids["C"]], 191 | logits[tokenized_ids["D"]], 192 | ] 193 | ) 194 | if not return_probs: 195 | return logit_tensor 196 | 197 | probs = torch.nn.functional.softmax(logit_tensor, dim=0) 198 | return probs 199 | 200 | 201 | @torch.no_grad() 202 | def compute_mmlu_block_shapley(model: torch.nn.Module, tokenizer: PreTrainedTokenizer, 203 | eval_loader: torch.utils.data.DataLoader, device: torch.device, 204 | use_random_subnetworks: bool = False, subnetwork_len: float = 0.5, 205 | seed: int = 43, num_subsampled_networks: int = 10, 206 | max_samples_per_proc: int = None, only_label_loss: bool = True): 207 | model.eval() 208 | num_model_layers = model.get_num_model_layers() 209 | print(f"!! 
Computing the logit shapley value for the model with {num_model_layers} layers...") 210 | rng = np.random.default_rng(seed) 211 | if not use_random_subnetworks: 212 | num_subsampled_networks = num_model_layers 213 | 214 | loss_fn = torch.nn.CrossEntropyLoss() 215 | tokenized_ids = {char: tokenizer(char).input_ids[-1] for char in ["A", "B", "C", "D"]} 216 | 217 | all_statistics = [] 218 | for iterator, batch in enumerate(tqdm(eval_loader)): 219 | tokenized_input = batch["input_ids"].to(device) 220 | base_logits = None 221 | for i in range(1+num_subsampled_networks): # first one is always base model eval 222 | selected_blocks = None # use full network 223 | if i != 0: # use subnetwork 224 | if use_random_subnetworks: 225 | selected_blocks = rng.choice(range(num_model_layers), int(subnetwork_len*num_model_layers), replace=False) 226 | else: 227 | block_to_remove = i - 1 228 | selected_blocks = [x for x in range(num_model_layers) if x != block_to_remove] 229 | model.select_blocks(selected_blocks, verbose=False) 230 | 231 | outputs = model(tokenized_input, labels=tokenized_input) 232 | if only_label_loss: 233 | lm_logits = extract_label_logits(outputs.logits, tokenized_ids, return_probs=False) 234 | lm_loss = loss_fn(lm_logits.unsqueeze(dim=0), batch["label"]) 235 | else: 236 | lm_logits = outputs.logits[:, :-1, :] # BTD format (discard the final logit) 237 | lm_loss = outputs.loss 238 | 239 | if base_logits is None: 240 | assert selected_blocks is None 241 | base_logits = lm_logits 242 | else: 243 | assert selected_blocks is not None 244 | if only_label_loss: 245 | diff_norm = torch.norm(base_logits - lm_logits, p=2, dim=-1) # no mean required 246 | assert len(diff_norm.shape) == 0, f"should be a scalar. found shape: {diff_norm.shape}" 247 | else: 248 | diff_norm = torch.norm(base_logits - lm_logits, p=2, dim=-1).mean() # mean over batch and sequence 249 | all_statistics.append((selected_blocks, float(diff_norm), float(lm_loss))) 250 | 251 | # Check if stopping condition is met 252 | if max_samples_per_proc is not None and iterator >= max_samples_per_proc - 1: 253 | print(f"{iterator} samples collected for logit shapley value. 
Stopping further computations!") 254 | break 255 | 256 | # Compute the block influence based on the computed statistics 257 | logit_dist = {i: {"present": [], "absent": []} for i in range(num_model_layers)} 258 | loss_dist = {i: {"present": [], "absent": []} for i in range(num_model_layers)} 259 | for selected_blocks, diff_norm, loss in all_statistics: 260 | for i in range(num_model_layers): 261 | key = "present" if i in selected_blocks else "absent" 262 | logit_dist[i][key].append(diff_norm) 263 | loss_dist[i][key].append(loss) 264 | 265 | # Compute average distances 266 | print("~~~~~~ Block shapley statistics ~~~~~~") 267 | logit_shapley_list = [] 268 | loss_shapley_list = [] 269 | for key, input_container, output_container in [("dist", logit_dist, logit_shapley_list), 270 | ("loss", loss_dist, loss_shapley_list)]: 271 | for i in range(num_model_layers): 272 | for name in ["present", "absent"]: 273 | mean = np.mean(input_container[i][name]) # convert it to mean 274 | input_container[i][name] = float(reduce_tensor(torch.tensor(mean).to(device), average=True)) 275 | shapley = input_container[i]['present'] - input_container[i]['absent'] 276 | print(f"> block {i} / present mean {key}: {input_container[i]['present']} / absent mean {key}: {input_container[i]['absent']} / shapley: {shapley}") 277 | output_container.append(shapley) 278 | print("-"*50) 279 | return logit_shapley_list, loss_shapley_list 280 | 281 | 282 | @torch.no_grad() 283 | def compute_mmlu_loss_shapley(model: torch.nn.Module, tokenizer: PreTrainedTokenizer, 284 | eval_loader: torch.utils.data.DataLoader, device: torch.device, 285 | subnetwork_len: Union[float, int] = 0.9, seed: int = 43, 286 | num_subsampled_networks: int = 10, max_samples_per_proc: int = None, 287 | track_zero_one_loss: bool = True, removed_blocks: List[int] = None): 288 | model.eval() 289 | num_model_layers = model.get_num_model_layers() 290 | rng = np.random.default_rng(seed) 291 | 292 | loss_fn = torch.nn.CrossEntropyLoss() 293 | tokenized_ids = {char: tokenizer(char).input_ids[-1] for char in ["A", "B", "C", "D"]} 294 | 295 | remaining_blocks = range(num_model_layers) 296 | if removed_blocks is not None: 297 | remaining_blocks = [x for x in remaining_blocks if x not in removed_blocks] 298 | print("Remaining blocks for shapley computation:", remaining_blocks) 299 | 300 | if isinstance(subnetwork_len, float): 301 | assert 0. < subnetwork_len < 1. 302 | subnetwork_len = int(subnetwork_len * len(remaining_blocks)) 303 | print(f"!! 
Computing the loss shapley value for the model with {len(remaining_blocks)} remaining layers using subnetwork length of {subnetwork_len}")
304 |     assert subnetwork_len < len(remaining_blocks), "Shapley value computation requires the subnetwork length to be smaller than the number of remaining blocks"
305 | 
306 |     all_statistics = []
307 |     correct = 0
308 |     total = 0
309 | 
310 |     for iterator, batch in enumerate(tqdm(eval_loader)):
311 |         tokenized_input = batch["input_ids"].to(device)
312 |         for _ in range(num_subsampled_networks):  # sample a fresh random subnetwork for each evaluation
313 |             selected_blocks = rng.choice(remaining_blocks, subnetwork_len, replace=False)
314 |             model.select_blocks(selected_blocks, verbose=False)
315 | 
316 |             outputs = model(tokenized_input, labels=tokenized_input)
317 |             lm_logits = extract_label_logits(outputs.logits, tokenized_ids, return_probs=False)
318 |             is_correct = torch.argmax(lm_logits) == int(batch["label"])
319 |             if track_zero_one_loss:
320 |                 loss = 1 - int(is_correct)  # loss is zero for correct prediction
321 |             else:
322 |                 loss = loss_fn(lm_logits.unsqueeze(dim=0), batch["label"])
323 | 
324 |             all_statistics.append((selected_blocks, float(loss)))
325 |             correct += int(is_correct)
326 |             total += 1
327 | 
328 |         # Check if stopping condition is met
329 |         if max_samples_per_proc is not None and iterator >= max_samples_per_proc - 1:
330 |             print(f"{iterator} samples collected for loss shapley value. Stopping further computations!")
331 |             break
332 | 
333 |     correct = int(reduce_tensor(torch.tensor(correct).cuda()))
334 |     total = int(reduce_tensor(torch.tensor(total).cuda()))
335 |     accuracy = correct / total
336 |     print(f"!! [STATS] Correct: {correct} / total: {total} / accuracy: {100.*accuracy:.2f}%")
337 | 
338 |     # Compute the block influence based on the computed statistics
339 |     loss_dist = {i: {"present": [], "absent": []} for i in range(num_model_layers)}
340 |     for selected_blocks, loss in all_statistics:
341 |         for i in range(num_model_layers):
342 |             key = "present" if i in selected_blocks else "absent"
343 |             loss_dist[i][key].append(loss)
344 | 
345 |     # Compute the average present/absent losses per block
346 |     print("~~~~~~ Block shapley statistics ~~~~~~")
347 |     loss_shapley_list = []
348 |     for i in range(num_model_layers):
349 |         if removed_blocks is not None and i in removed_blocks:
350 |             dummy_max_val = 100000.
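            # Placeholder entry for a block that has already been pruned; these indices are discarded
            # from the sorted block order downstream, so the exact value is irrelevant.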
351 | loss_shapley_list.append(dummy_max_val) # argsort indices ignored 352 | print(f"> block {i} / ~~~ removed ~~~") 353 | continue 354 | 355 | for name in ["present", "absent"]: 356 | mean = np.mean(loss_dist[i][name]) # convert it to mean 357 | loss_dist[i][name] = float(reduce_tensor(torch.tensor(mean).to(device), average=True)) 358 | shapley = loss_dist[i]['present'] - loss_dist[i]['absent'] 359 | print(f"> block {i} / present mean: {loss_dist[i]['present']} / absent mean: {loss_dist[i]['absent']} / shapley: {shapley}") 360 | loss_shapley_list.append(shapley) 361 | print("-"*50) 362 | return loss_shapley_list 363 | 364 | 365 | def main(args): 366 | init_distributed_env(args) 367 | 368 | generator = None 369 | if args.seed is not None: # Set process seed to reduce stochasticity 370 | torch.manual_seed(args.seed) 371 | torch.cuda.manual_seed(args.seed) 372 | np.random.seed(seed=args.seed) 373 | random.seed(args.seed) 374 | print("Setting process seed:", args.seed) 375 | 376 | # Generator to seed dataloaders 377 | generator = torch.Generator() 378 | generator.manual_seed(args.seed) 379 | 380 | dataset_dir = f"{args.dataset}_model_{args.model_name}_seq_len_{args.sequence_length}_subsample_{args.subsample_size}_comb_docs" 381 | args.dataset_output_dir = os.path.join("datasets", dataset_dir) 382 | 383 | suffix = f"block_{'iterative_' if args.iterative_pruning else ''}pruning_mmlu_shapley{'_test' if args.compute_shapley_on_test_set else ''}_zero_one_loss_num_subnetworks_{args.num_sampled_subnetworks}" 384 | args.wandb_run_name = f"{dataset_dir}_{suffix}" 385 | 386 | if args.wandb_project is not None and is_main_proc(): 387 | print("Initialization w&b...") 388 | wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args, resume=False) 389 | 390 | if is_main_proc() and not NLPDataset.is_dataset_processed(args.dataset_output_dir): 391 | tokenizer = load_model(args, only_tokenizer=True) 392 | dataset = NLPDataset(args.dataset, tokenizer, max_length=args.sequence_length, 393 | combine_documents=True, subsample_size=args.subsample_size) 394 | dataset.save_datasets(args.dataset_output_dir) 395 | wait_for_other_procs() # wait for the main process to write the dataset 396 | 397 | # Load the dataset 398 | dataset = NLPDataset.load_dataset(args.dataset_output_dir) # returns a dataset dict 399 | train_dataset = dataset["train"] 400 | test_dataset = dataset["test"] 401 | 402 | # Load the model 403 | model, tokenizer = load_model(args, pretrained=True) 404 | num_model_params = get_num_model_params(model) 405 | num_model_layers = model.get_num_model_layers() 406 | print(f"# model params: {num_model_params/1_000_000:.2f}M / # layers: {num_model_layers}") 407 | 408 | # Convert to DDP 409 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 410 | model = model.to(device) # move to device 411 | 412 | # Create the dataloaders 413 | train_loader = get_dataloader(train_dataset, args.batch_size, args.num_workers, drop_last=True, generator=generator) 414 | eval_loader = get_dataloader(test_dataset, args.test_batch_size, args.num_workers, generator=generator) 415 | 416 | # Load MMLU dataset 417 | mmlu_dataset = MMLUDataset(tokenizer, args.model_name) 418 | # mmlu_dataset = torch.utils.data.Subset(mmlu_dataset, range(10)) 419 | mmlu_loader = get_dataloader(mmlu_dataset, 1, args.num_workers, generator=generator) # bs=1 for MMLU 420 | print("# examples in MMLU dataset:", len(mmlu_dataset)) 421 | 422 | if args.compute_shapley_on_test_set: 423 | print("Computing shapley on the test 
set...") 424 | train_mmlu_dataset, train_mmlu_loader = mmlu_dataset, mmlu_loader 425 | else: 426 | print("Computing shapley on the train set...") 427 | train_mmlu_dataset = MMLUTrainDataset(tokenizer, args.model_name) # uses the validation set instead of the test set 428 | train_mmlu_loader = get_dataloader(train_mmlu_dataset, 1, args.num_workers, generator=generator) # bs=1 for MMLU 429 | print("# examples in MMLU train dataset:", len(train_mmlu_dataset)) 430 | 431 | print(">> Estimating block influences...") 432 | eval_start_time = time.time() 433 | compute_all_stats = False 434 | if compute_all_stats: 435 | model.select_blocks(None) # use all blocks 436 | block_influence_estimator = BlockInfluenceEstimator(num_model_layers, device) 437 | model.add_block_influence_estimator(block_influence_estimator) 438 | evaluate_model(model, train_loader, device, split_name=None) # use the train set to compute block influences 439 | final_block_influences = block_influence_estimator.get_block_influences() 440 | model.add_block_influence_estimator(None) # remove the block influence computation 441 | print("Final block influences:", final_block_influences) 442 | 443 | cosine_block_influences = [x["cosine_dist"] for x in final_block_influences] 444 | l1_block_influences = [x["l1_update_norm"] for x in final_block_influences] 445 | relative_l1_block_influences = [x["l1_relative_update_norm"] for x in final_block_influences] 446 | l2_block_influences = [x["l2_update_norm"] for x in final_block_influences] 447 | relative_l2_block_influences = [x["l2_relative_update_norm"] for x in final_block_influences] 448 | print("Cosine block influences:", cosine_block_influences) 449 | print("L1 block influences:", l1_block_influences) 450 | print("Relative L1 block influences:", relative_l1_block_influences) 451 | print("L2 block influences:", l2_block_influences) 452 | print("Relative L2 block influences:", relative_l2_block_influences) 453 | 454 | if wandb.run is not None: 455 | wandb.log({f"block_{i}_cosine_influence": block_influence for i, block_influence in enumerate(cosine_block_influences)}) 456 | wandb.log({f"block_{i}_l1_influence": block_influence for i, block_influence in enumerate(l1_block_influences)}) 457 | wandb.log({f"block_{i}_relative_l1_influence": block_influence for i, block_influence in enumerate(relative_l1_block_influences)}) 458 | wandb.log({f"block_{i}_l2_influence": block_influence for i, block_influence in enumerate(l2_block_influences)}) 459 | wandb.log({f"block_{i}_relative_l2_influence": block_influence for i, block_influence in enumerate(relative_l2_block_influences)}) 460 | 461 | # Compute the block logit shapley 462 | max_samples_per_proc = None 463 | if args.limit_shapley_samples is not None: 464 | max_samples_per_proc = args.limit_shapley_samples // args.world_size 465 | print(f"Total samples limit: {args.limit_shapley_samples}. 
Capping the max_samples_per_proc for logit shapley computation to be: {max_samples_per_proc}") 466 | new_shapley_computation = True 467 | if new_shapley_computation: 468 | block_loss_shapley = compute_mmlu_loss_shapley(model, tokenizer, train_mmlu_loader, device, num_subsampled_networks=args.num_sampled_subnetworks, 469 | max_samples_per_proc=max_samples_per_proc) 470 | print("Block loss shapley:", block_loss_shapley) 471 | loss_shapley_block_influence = [-x for x in block_loss_shapley] # negative shapely (lower distance) indicates higher importance 472 | if wandb.run is not None: 473 | wandb.log({f"block_{i}_loss_shapley_influence": block_influence for i, block_influence in enumerate(loss_shapley_block_influence)}) 474 | else: 475 | block_logit_shapley, block_loss_shapley = compute_mmlu_block_shapley(model, tokenizer, train_mmlu_loader, device, num_subsampled_networks=args.num_sampled_subnetworks, 476 | max_samples_per_proc=max_samples_per_proc) 477 | print("Block logit shapley:", block_logit_shapley) 478 | print("Block loss shapley:", block_loss_shapley) 479 | logit_shapley_block_influence = [-x for x in block_logit_shapley] # negative shapely (lower distance) indicates higher importance 480 | loss_shapley_block_influence = [-x for x in block_loss_shapley] # negative shapely (lower distance) indicates higher importance 481 | if wandb.run is not None: 482 | wandb.log({f"block_{i}_logit_shapley_influence": block_influence for i, block_influence in enumerate(logit_shapley_block_influence)}) 483 | wandb.log({f"block_{i}_loss_shapley_influence": block_influence for i, block_influence in enumerate(loss_shapley_block_influence)}) 484 | 485 | if compute_all_stats: 486 | block_influence_list = [("cosine", cosine_block_influences), ("relative_l1", relative_l1_block_influences), 487 | ("relative_l2", relative_l2_block_influences), ("logit_shapley", logit_shapley_block_influence), 488 | ("loss_shapley", loss_shapley_block_influence)] 489 | else: 490 | block_influence_list = [("loss_shapley", loss_shapley_block_influence)] 491 | 492 | compute_perplexity = False 493 | for influence_name, block_influences in block_influence_list: 494 | print("Using block influence method:", influence_name) 495 | print("Block influence values:", block_influences) 496 | sorted_blocks = np.argsort(block_influences) # ascending order 497 | block_influences = sorted_blocks 498 | print("Sorted block list:", sorted_blocks) 499 | 500 | remaining_blocks = list(range(num_model_layers)) 501 | weighted_acc_list = [] 502 | perplexity_list = [] 503 | iterator = -1 504 | removed_blocks_list = [] 505 | for _ in range(len(sorted_blocks)+1): # one additional iteration for no dropping 506 | if iterator > -1: # do nothing for the first block i.e., all blocks are selected 507 | if args.iterative_pruning and iterator == len(block_influences) - 1: 508 | assert len(remaining_blocks) == 1, remaining_blocks 509 | lowest_block = remaining_blocks[0] 510 | print("Using the last remaining block as the block instead of recomputing shapley value...") 511 | elif args.iterative_pruning and iterator > 0: # first block already pruned 512 | if not new_shapley_computation: 513 | raise NotImplementedError 514 | current_block_loss_shapley = compute_mmlu_loss_shapley(model, tokenizer, train_mmlu_loader, device, 515 | num_subsampled_networks=args.num_sampled_subnetworks, 516 | max_samples_per_proc=max_samples_per_proc, 517 | removed_blocks=removed_blocks_list) 518 | print(f"Block loss shapley at pruning step {iterator}: {current_block_loss_shapley}") 519 | 
block_influences = [-x for x in current_block_loss_shapley] # negative shapely (lower distance) indicates higher importance 520 | if wandb.run is not None: 521 | wandb.log({f"step_{iterator}_block_{i}_loss_shapley_influence": block_influence 522 | for i, block_influence in enumerate(block_influences)}) 523 | 524 | # Compute the new sorted blocks list 525 | sorted_blocks = np.argsort(block_influences) 526 | sorted_blocks = [x for x in sorted_blocks if x not in removed_blocks_list] # sorted block list after discarding removed blocks 527 | lowest_block = sorted_blocks[0] 528 | print("New sorted blocks:", sorted_blocks) 529 | else: # use the precomputed list 530 | lowest_block = sorted_blocks[iterator] # prune blocks based on the estimated block influence 531 | 532 | print(f"Removing block {lowest_block} with lowest influence: {block_influences[lowest_block]}") 533 | assert lowest_block not in removed_blocks_list # only relevant for iterative pruning 534 | removed_blocks_list.append(lowest_block) 535 | remaining_blocks = [i for i in remaining_blocks if i != lowest_block] # remove lowest block 536 | print("Remaining blocks:", remaining_blocks) 537 | model.select_blocks(remaining_blocks) # use all blocks 538 | _, _, weighted_acc = evaluate_mmlu(model, tokenizer, mmlu_loader, device, f"{influence_name}_blocks_pruned_{iterator+1}") 539 | weighted_acc_list.append(weighted_acc) 540 | if compute_perplexity: 541 | _, avg_perplexity = evaluate_model(model, eval_loader, device, f"{influence_name}_blocks_pruned_{iterator+1}") 542 | perplexity_list.append(avg_perplexity) 543 | iterator += 1 544 | 545 | print(f">>>>> Block pruning statistics using {influence_name} metric <<<<<") 546 | print(f"{influence_name} weighted ACC list: {weighted_acc_list}") 547 | if compute_perplexity: 548 | print(f"{influence_name} perplexity list: {perplexity_list}") 549 | print("="*25) 550 | 551 | eval_time_elapsed_h = (time.time() - eval_start_time) / (60 * 60) # convert seconds into hours 552 | print(f"Block pruning evaluation completed / time elapsed: {eval_time_elapsed_h:.2f}h") 553 | 554 | if wandb.run is not None: 555 | wandb.finish() 556 | 557 | 558 | if __name__ == "__main__": 559 | supported_datasets = ['pg19', 'cc_news', 'wikitext-2', 'bookcorpus', 'c4', 'openwebtext', 'slimpajama'] 560 | 561 | # Create ArgumentParser object 562 | parser = argparse.ArgumentParser(description='Argument parser for LLM block influence evaluator') 563 | 564 | # Add arguments 565 | parser.add_argument('-d', '--dataset', default='wikitext-2', choices=supported_datasets, 566 | help='Dataset name (default: wikitext-2)') 567 | parser.add_argument('-m', '--model-name', default='llama-2', choices=['llama-2', 'mistral'], 568 | help='Model name (default: llama-2)') 569 | parser.add_argument('-s', '--model-size', default='7b', choices=['7b'], 570 | help='Model size (default: 7b)') 571 | parser.add_argument('--use-instruct-model', action='store_true', default=False, 572 | help='Use instruction-tuned model rather than the base model') 573 | parser.add_argument('--batch-size', type=int, default=1, 574 | help='Batch size per process (default: 1)') 575 | parser.add_argument('--test-batch-size', type=int, default=None, 576 | help='Batch size per process for testing (default: equal to --batch-size)') 577 | parser.add_argument('--sequence-length', type=int, default=1024, 578 | help='Sequence length for computing the model perplexity (default: 1024)') 579 | parser.add_argument('--subsample-size', type=int, default=1000000, 580 | help='Dataset subsample size 
in terms of number of docs (default: 1M)') 581 | parser.add_argument('--num-workers', type=int, default=8, 582 | help='Number of workers for the dataloader (default: 8)') 583 | parser.add_argument('--seed', type=int, default=43, 584 | help='seed value (default: 43)') 585 | parser.add_argument('--wandb-project', type=str, default=None, 586 | help='W&B project name (none indicates no W&B initialization)') 587 | parser.add_argument('--limit-shapley-samples', type=int, default=None, 588 | help='limit the number of samples to the specified value for shapley computation (default: None i.e., no limit)') 589 | parser.add_argument('--num-sampled-subnetworks', type=int, default=10, 590 | help='number of subnetworks to sample for each input example for shapley value computation (default: 10)') 591 | parser.add_argument('--compute-shapley-on-test-set', action='store_true', default=False, 592 | help='compute shapley value on the test set rather than the training set') 593 | parser.add_argument('--iterative-pruning', action='store_true', default=False, 594 | help='use iterative pruning instead of one-shot pruning (recomputes the shapley value after each step)') 595 | 596 | # Parse the arguments 597 | args = parser.parse_args() 598 | 599 | if args.test_batch_size is None: 600 | args.test_batch_size = args.batch_size 601 | print("Setting test batch size to be equal to batch size:", args.test_batch_size) 602 | 603 | main(args) 604 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import os 4 | import psutil 5 | import random 6 | from itertools import chain 7 | 8 | import torch 9 | import numpy as np 10 | from datasets import load_dataset, DatasetDict, load_from_disk 11 | from transformers import PreTrainedTokenizer 12 | 13 | 14 | def seed_worker(worker_id): 15 | worker_seed = torch.initial_seed() % 2**32 16 | np.random.seed(worker_seed) 17 | random.seed(worker_seed) 18 | 19 | 20 | def get_dataloader(dataset: torch.utils.data.Dataset, batch_size: int, num_workers: int = 8, 21 | drop_last: bool = False, pin_loader_memory: bool = False, generator=None): 22 | sampler = None 23 | if torch.distributed.is_initialized(): 24 | print("!! Attaching sampler to the DataLoader for distributed training...") 25 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 26 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, 27 | sampler=sampler, drop_last=drop_last, pin_memory=pin_loader_memory, 28 | worker_init_fn=seed_worker, generator=generator) 29 | return dataloader 30 | 31 | 32 | class NLPDataset(torch.utils.data.Dataset): 33 | def __init__(self, dataset_name: str, tokenizer: PreTrainedTokenizer, max_length: int, 34 | combine_documents: bool, logging_level: int = 0, subsample_size: int = 1000000, 35 | sampler_seed: int = 43, include_other_cols: bool = False, num_proc: int = None): 36 | # Load the original dataset 37 | assert dataset_name in ["pg19", "cc_news", "wikitext-2", "bookcorpus", "c4", "openwebtext", "slimpajama"] 38 | print("!! 
Loading dataset:", dataset_name) 39 | 40 | subsample_dataset = True # Subsample in this case due to large dataset size 41 | if dataset_name == "pg19": 42 | dataset = load_dataset("pg19") 43 | subsample_dataset = False # dataset small enough 44 | elif dataset_name == "c4": 45 | # Load the en-noblocklist subset of C4 (https://huggingface.co/datasets/c4) 46 | dataset = load_dataset("c4", "en.noblocklist") 47 | elif dataset_name == "openwebtext": 48 | dataset = load_dataset("openwebtext") 49 | elif dataset_name == "slimpajama": 50 | dataset = load_dataset("cerebras/SlimPajama-627B") 51 | elif dataset_name == "cc_news": 52 | dataset = load_dataset("cc_news") 53 | subsample_dataset = False # dataset small enough 54 | elif dataset_name == "bookcorpus": 55 | dataset = load_dataset("bookcorpus") 56 | else: 57 | assert dataset_name == "wikitext-2" 58 | dataset = load_dataset("wikitext", "wikitext-2-raw-v1") 59 | subsample_dataset = False # dataset small enough 60 | 61 | # Create a test split in case the original test split is not provided 62 | if "test" not in dataset: 63 | # Split the dataset into training (90%) and testing (10%) 64 | print("Creating synthetic test split...") 65 | assert "train" in dataset, dataset 66 | d = dataset["train"].train_test_split(test_size=0.1, seed=sampler_seed, shuffle=True) 67 | else: 68 | print("Using the official test split...") 69 | d = dataset 70 | 71 | if subsample_dataset: 72 | # Define the random number generator based on the random seed 73 | rng = np.random.default_rng(sampler_seed) 74 | 75 | # Subsample the dataset for OpenWebText as a starting point 76 | train_examples = subsample_size # 1M examples from the dataset 77 | print(f"!! Subsampling train dataset to {train_examples} examples...") 78 | possible_idx = list(range(len(d["train"]))) 79 | selected_idx = rng.choice(possible_idx, size=(train_examples), replace=False) 80 | d["train"] = d["train"].select(selected_idx) 81 | eval_examples = int(0.1 * train_examples) 82 | if len(d["test"]) > eval_examples: 83 | print(f"!! Subsampling test dataset to {eval_examples} examples...") 84 | possible_idx = list(range(len(d["test"]))) 85 | selected_idx = rng.choice(possible_idx, size=(eval_examples), replace=False) 86 | d["test"] = d["test"].select(selected_idx) 87 | print("!! 
Dataset subsampling completed...") 88 | 89 | filter_dataset = True 90 | if filter_dataset: 91 | prev_train_size = len(d["train"]) 92 | d["train"] = d["train"].filter(lambda example: len(example["text"]) > 0) 93 | print(f"Train dataset filtering / old size: {prev_train_size} / new size: {len(d['train'])}") 94 | 95 | prev_test_size = len(d["test"]) 96 | d["test"] = d["test"].filter(lambda example: len(example["text"]) > 0) 97 | print(f"Test dataset filtering / old size: {prev_test_size} / new size: {len(d['test'])}") 98 | 99 | if logging_level > 0: 100 | print("Full dataset:", dataset) 101 | print(f"Splits / train: {d['train']} / test: {d['test']}") 102 | 103 | if logging_level > 1: 104 | for t in d["train"]["text"][:3]: 105 | print(t) 106 | print("="*50) 107 | 108 | self.max_length = max_length 109 | self.tokenizer = tokenizer 110 | 111 | truncate_longer_samples = False 112 | num_proc = psutil.cpu_count() if num_proc is None else num_proc 113 | print(f"# processes for mapping: {num_proc} / combine documents: {combine_documents}") 114 | 115 | # the encode function will depend on the truncate_longer_samples variable 116 | encode = self.encode_with_truncation if truncate_longer_samples else self.encode_without_truncation 117 | 118 | # tokenizing the train/test dataset (essential to remove columns as they can result in wrong model keys) 119 | train_dataset = d["train"].map(encode, remove_columns=['text'], batched=True, num_proc=num_proc, desc="Train encoding") 120 | test_dataset = d["test"].map(encode, remove_columns=['text'], batched=True, num_proc=num_proc, desc="Test encoding") 121 | 122 | if truncate_longer_samples: 123 | columns = ["input_ids", "attention_mask"] if include_other_cols else ["input_ids"] 124 | else: 125 | columns = ["input_ids", "attention_mask", "special_tokens_mask"] if include_other_cols else ["input_ids"] 126 | train_dataset.set_format(type="torch", columns=columns) 127 | test_dataset.set_format(type="torch", columns=columns) 128 | 129 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a 130 | # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value 131 | # might be slower to preprocess. 132 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information: 133 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 134 | if not truncate_longer_samples: 135 | print(f"!! 
Grouping {'combined ' if combine_documents else ''}documents to size: {self.max_length}") 136 | grouping_func = self.group_texts if combine_documents else self.group_texts_within_documents 137 | train_dataset = train_dataset.map(grouping_func, batched=combine_documents, num_proc=num_proc, 138 | desc=f"Grouping texts in chunks of {self.max_length}") 139 | test_dataset = test_dataset.map(grouping_func, batched=combine_documents, num_proc=num_proc, 140 | desc=f"Grouping texts in chunks of {self.max_length}") 141 | 142 | # convert them from lists to torch tensors 143 | train_dataset.set_format("torch") 144 | test_dataset.set_format("torch") 145 | 146 | self.datasets = DatasetDict({"train": train_dataset, "test": test_dataset}) 147 | 148 | @staticmethod 149 | def is_dataset_processed(dataset_dir): 150 | return os.path.exists(dataset_dir) 151 | 152 | def save_datasets(self, dataset_dir): 153 | if not NLPDataset.is_dataset_processed(dataset_dir): 154 | print("Saving dataset to disk:", dataset_dir) 155 | self.datasets.save_to_disk(dataset_dir) 156 | 157 | @staticmethod 158 | def load_dataset(dataset_dir): 159 | datasets = None 160 | if NLPDataset.is_dataset_processed(dataset_dir): 161 | print("Loading dataset from disk:", dataset_dir) 162 | datasets = load_from_disk(dataset_dir) 163 | return datasets 164 | 165 | def group_texts(self, examples): 166 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of max_seq_length. 167 | # grabbed from: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py 168 | # Concatenate all texts. 169 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 170 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 171 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 172 | # customize this part to your needs. 173 | if total_length >= self.max_length: 174 | total_length = (total_length // self.max_length) * self.max_length 175 | # Split by chunks of max_len. 176 | result = { 177 | k: [t[i : i + self.max_length] for i in range(0, total_length, self.max_length)] 178 | for k, t in concatenated_examples.items() 179 | } 180 | return result 181 | 182 | def group_texts_within_documents(self, example): 183 | total_length = len(example["input_ids"]) 184 | doc_result = { 185 | "input_ids": [], 186 | "attention_mask": [], 187 | "special_tokens_mask": [] 188 | } 189 | 190 | # If the length of input is less than `self.max_length`, then add the whole input into `doc_result`. 
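# For example (illustrative numbers): with max_length=2048, a 5,000-token document is split into chunks of
# 2048, 2048, and 904 tokens, while a 1,500-token document is kept as a single chunk.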
191 | if total_length < self.max_length: 192 | doc_result["input_ids"].append(example["input_ids"]) 193 | doc_result["attention_mask"].append(example["attention_mask"]) 194 | doc_result["special_tokens_mask"].append(example["special_tokens_mask"]) 195 | else: 196 | doc_result["input_ids"] = [example["input_ids"][i : i + self.max_length] 197 | for i in range(0, total_length, self.max_length)] 198 | doc_result["attention_mask"] = [example["attention_mask"][i : i + self.max_length] 199 | for i in range(0, total_length, self.max_length)] 200 | doc_result["special_tokens_mask"] = [example["special_tokens_mask"][i : i + self.max_length] 201 | for i in range(0, total_length, self.max_length)] 202 | assert all([len(x) > 0 for x in doc_result["input_ids"]]) 203 | return doc_result 204 | 205 | def encode_with_truncation(self, examples): 206 | """Mapping function to tokenize the sentences passed with truncation""" 207 | return self.tokenizer(examples["text"], truncation=True, padding="max_length", 208 | max_length=self.max_length, return_special_tokens_mask=True) 209 | 210 | def encode_without_truncation(self, examples): 211 | """Mapping function to tokenize the sentences passed without truncation""" 212 | return self.tokenizer(examples["text"], return_special_tokens_mask=True) 213 | -------------------------------------------------------------------------------- /dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import timedelta 3 | 4 | import torch 5 | 6 | 7 | def init_distributed_env(args): 8 | # Initialize the distributed environment 9 | args.world_size = int(os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS', 1))) 10 | args.distributed = args.world_size > 1 11 | args.rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID', 0))) 12 | args.local_rank = int(os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID', 0))) 13 | args.gpu = args.local_rank 14 | 15 | if args.distributed: 16 | torch.cuda.set_device(args.gpu) 17 | torch.distributed.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(hours=1)) 18 | obtained_world_size = torch.distributed.get_world_size() 19 | assert obtained_world_size == args.world_size, f"{obtained_world_size} != {args.world_size}" 20 | print(f"Initializing the environment with {args.world_size} processes / Process rank: {args.rank} / Local rank: {args.local_rank}") 21 | setup_for_distributed(args.local_rank == 0) # print via one process per node 22 | args.effective_batch_size = args.batch_size * args.world_size 23 | print(f"# processes: {args.world_size} / batch size: {args.batch_size} / effective batch size: {args.effective_batch_size}") 24 | 25 | 26 | def is_main_proc(local_rank=None, shared_fs=True): 27 | assert shared_fs or local_rank is not None 28 | main_proc = not torch.distributed.is_initialized() or (torch.distributed.get_rank() == 0 if shared_fs else local_rank == 0) 29 | return main_proc 30 | 31 | 32 | def setup_for_distributed(is_master): 33 | """ 34 | This function disables printing when not in master process 35 | """ 36 | import builtins as __builtin__ 37 | builtin_print = __builtin__.print 38 | 39 | def print(*args, **kwargs): 40 | force = kwargs.pop('force', False) 41 | if is_master or force: 42 | builtin_print(*args, **kwargs) 43 | 44 | __builtin__.print = print 45 | 46 | 47 | def get_world_size(): 48 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 49 | 50 | 51 | def wait_for_other_procs(): 52 | 
if torch.distributed.is_initialized(): 53 | torch.distributed.barrier() 54 | 55 | 56 | def reduce_tensor(tensor, average=False): 57 | world_size = get_world_size() 58 | if world_size == 1: 59 | return tensor 60 | rt = tensor.clone() 61 | torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM) 62 | if average: 63 | rt /= world_size 64 | return rt 65 | -------------------------------------------------------------------------------- /evals/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shoaibahmed/llm_depth_pruning/eea4dbd7cbd55a429ddd0b30ae8180a9f649f2f9/evals/__init__.py -------------------------------------------------------------------------------- /evals/dist_mmlu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import wandb 5 | import pickle 6 | from tqdm import tqdm 7 | 8 | import torch 9 | import numpy as np 10 | import pandas as pd 11 | from transformers import PreTrainedTokenizer 12 | 13 | from .mmlu import gen_prompt, format_example 14 | from .mmlu_utils import download_mmlu 15 | 16 | sys.path.append("..") # top-level package 17 | from dist_utils import is_main_proc, wait_for_other_procs, reduce_tensor 18 | 19 | 20 | class MMLUDataset(torch.utils.data.Dataset): 21 | def __init__(self, tokenizer: PreTrainedTokenizer, model_name: str, dataset_path: str = None, n_train: int = 5): 22 | if dataset_path is None: 23 | dataset_path = "./datasets/mmlu/" # default dataset directory 24 | if not os.path.exists(dataset_path) and is_main_proc(): 25 | MMLUDataset.download_dataset() 26 | wait_for_other_procs() 27 | assert os.path.exists(dataset_path) 28 | self.dataset_path = dataset_path 29 | self.model_name = model_name 30 | self.n_train = n_train 31 | self.tokenized_dataset_path = os.path.join(dataset_path, f"{model_name}_tokenized_ntrain_{n_train}.pkl") 32 | 33 | # Create the dataset containers 34 | self.prompts = [] 35 | self.labels = [] 36 | self.subjects = [] 37 | self.load_mmlu(tokenizer) 38 | self.label2idx = {char: i for i, char in enumerate(["A", "B", "C", "D"])} 39 | 40 | def load_mmlu(self, tokenizer: PreTrainedTokenizer): 41 | if not self.is_dataset_processed(): 42 | if is_main_proc(): 43 | self.tokenize_mmlu(tokenizer, self.n_train) 44 | self.save_dataset() 45 | wait_for_other_procs() 46 | self.load_dataset() 47 | 48 | def tokenize_mmlu(self, tokenizer: PreTrainedTokenizer, n_train: int): 49 | subjects = sorted( 50 | [ 51 | f.split("_test.csv")[0] 52 | for f in os.listdir(os.path.join(self.dataset_path, "test")) 53 | if "_test.csv" in f 54 | ] 55 | ) 56 | 57 | for subject in subjects: 58 | dev_df = pd.read_csv( 59 | os.path.join(self.dataset_path, "dev", subject + "_dev.csv"), header=None 60 | )[:n_train] 61 | test_df = pd.read_csv( 62 | os.path.join(self.dataset_path, "test", subject + "_test.csv"), header=None 63 | ) 64 | 65 | for i in range(test_df.shape[0]): 66 | # get prompt and make sure it fits 67 | k = n_train 68 | prompt_end = format_example(test_df, i, include_answer=False) 69 | train_prompt = gen_prompt(dev_df, subject, k) 70 | prompt = train_prompt + prompt_end 71 | 72 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 73 | 74 | while input_ids.shape[-1] > 2048: 75 | k -= 1 76 | train_prompt = gen_prompt(dev_df, subject, k) 77 | prompt = train_prompt + prompt_end 78 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 79 | label = test_df.iloc[i, test_df.shape[1] - 1] 80 | 81 
| # Add to the list 82 | assert input_ids.shape[0] == 1, input_ids.shape 83 | self.prompts.append(input_ids[0, :].clone()) 84 | self.labels.append(label) 85 | self.subjects.append(subject) 86 | 87 | def __len__(self): 88 | return len(self.prompts) 89 | 90 | def __getitem__(self, idx): 91 | return {"input_ids": self.prompts[idx], "label": self.label2idx[self.labels[idx]], 92 | "subject": self.subjects[idx]} 93 | 94 | @staticmethod 95 | def download_dataset(dataset_dir=None): 96 | download_mmlu(dataset_dir) 97 | 98 | def is_dataset_processed(self): 99 | return os.path.exists(self.tokenized_dataset_path) 100 | 101 | def save_dataset(self): 102 | if not self.is_dataset_processed(): 103 | print("Saving MMLU dataset to disk:", self.tokenized_dataset_path) 104 | with open(self.tokenized_dataset_path, "wb") as f: 105 | output_dict = {"prompts": self.prompts, "labels": self.labels, "subjects": self.subjects} 106 | pickle.dump(output_dict, f) 107 | 108 | def load_dataset(self): 109 | assert self.is_dataset_processed() 110 | print("Loading MMLU dataset from disk:", self.tokenized_dataset_path) 111 | with open(self.tokenized_dataset_path, "rb") as f: 112 | output_dict = pickle.load(f) 113 | self.prompts = output_dict["prompts"] 114 | self.labels = output_dict["labels"] 115 | self.subjects = output_dict["subjects"] 116 | 117 | 118 | @torch.no_grad() 119 | def evaluate_mmlu(model: torch.nn.Module, tokenizer: PreTrainedTokenizer, mmlu_loader: torch.utils.data.DataLoader, 120 | device: torch.device, split_name: str, verbose: bool = False): 121 | total = 0 122 | correct = 0 123 | tokenized_ids = {char: tokenizer(char).input_ids[-1] for char in ["A", "B", "C", "D"]} 124 | if verbose: 125 | print("Tokenized IDs:", tokenized_ids) 126 | 127 | for batch in tqdm(mmlu_loader): 128 | input_ids = batch["input_ids"].to(device) 129 | assert input_ids.shape[0] == 1, input_ids.shape 130 | 131 | # Forward prop through the model 132 | logits = model(input_ids=input_ids).logits 133 | assert logits.shape[0] == 1, f"batch size should be 1. Found: {logits.shape}" 134 | logits = logits[:, -1, :].flatten() # BSV format 135 | 136 | probs = ( 137 | torch.nn.functional.softmax( 138 | torch.tensor( 139 | [ 140 | logits[tokenized_ids["A"]], 141 | logits[tokenized_ids["B"]], 142 | logits[tokenized_ids["C"]], 143 | logits[tokenized_ids["D"]], 144 | ] 145 | ), 146 | dim=0, 147 | ) 148 | .detach() 149 | .cpu() 150 | .numpy() 151 | ) 152 | pred = np.argmax(probs) # idx of the correct option 153 | is_correct = int(pred) == int(batch["label"]) 154 | correct += int(is_correct) 155 | if verbose: 156 | print(f"probs: {probs} / pred: {pred} / label: {int(batch['label'])} / correct: {is_correct}") 157 | total += 1 158 | 159 | correct = int(reduce_tensor(torch.tensor(correct).to(device))) 160 | total = int(reduce_tensor(torch.tensor(total).to(device))) 161 | weighted_acc = 100. 
* correct / total 162 | 163 | output_dict = {"mmlu_split": split_name, "total": total, "correct": correct, "weighted_acc": weighted_acc} 164 | print(json.dumps(output_dict)) 165 | if split_name is not None and wandb.run is not None: 166 | wandb.log({f"mmlu_{split_name}": {"total": total, "correct": correct, "weighted_acc": weighted_acc}}) 167 | return correct, total, weighted_acc 168 | 169 | 170 | if __name__ == "__main__": 171 | from argparse import Namespace 172 | from trainer import load_model 173 | from dataset import get_dataloader 174 | from dist_utils import init_distributed_env 175 | 176 | model_name = 'llama-2' 177 | args = Namespace(model_name=model_name, model_size='7b', use_instruct_model=False, 178 | amp_dtype=None, use_gradient_checkpointing=False, batch_size=1, 179 | num_workers=8, seed=43) 180 | 181 | # Setup the distributed env 182 | init_distributed_env(args) 183 | 184 | # Load the tokenizer 185 | tokenizer = load_model(args, only_tokenizer=True) 186 | 187 | # Wrap the dataset into a dataloader 188 | dataset = MMLUDataset(tokenizer, model_name) 189 | print("# examples in dataset:", len(dataset)) 190 | 191 | # Generator to seed dataloaders 192 | generator = torch.Generator() 193 | generator.manual_seed(args.seed) 194 | dl = get_dataloader(dataset, batch_size=args.batch_size, num_workers=args.num_workers, 195 | generator=generator) # no padding so batch size should be 1 196 | 197 | for i, input_dict in enumerate(dl): 198 | print("Prompt:", input_dict["input_ids"].shape, input_dict["input_ids"][:10]) 199 | print("Label:", input_dict["label"]) 200 | print("Subject:", input_dict["subject"]) 201 | print("="*10) 202 | if i >= 5: 203 | break 204 | 205 | # Test evaluation 206 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 207 | model, _, _ = load_model(args, pretrained=True) 208 | model.eval() # set the model in eval mode 209 | model = model.to(device) # move the model to device 210 | correct, total, weighted_acc = evaluate_mmlu(model, tokenizer, dl, device, split_name=None, verbose=True) 211 | print(f"Final weighted acc: {weighted_acc:.2f}% ({correct}/{total})") 212 | -------------------------------------------------------------------------------- /evals/mmlu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation adopted from official implementation by Dan Hendrycks: 3 | https://github.com/hendrycks/test/blob/master/evaluate_flan.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import numpy as np 9 | import pandas as pd 10 | 11 | import torch 12 | 13 | from .mmlu_utils import categories, subcategories, download_mmlu 14 | 15 | sys.path.append("..") # top-level package 16 | from dist_utils import is_main_proc, wait_for_other_procs 17 | 18 | choices = ["A", "B", "C", "D"] 19 | 20 | 21 | def format_subject(subject): 22 | l = subject.split("_") 23 | s = "" 24 | for entry in l: 25 | s += " " + entry 26 | return s 27 | 28 | 29 | def format_example(df, idx, include_answer=True): 30 | prompt = df.iloc[idx, 0] 31 | k = df.shape[1] - 2 32 | for j in range(k): 33 | prompt += "\n{}. 
{}".format(choices[j], df.iloc[idx, j+1]) 34 | prompt += "\nAnswer:" 35 | if include_answer: 36 | prompt += " {}\n\n".format(df.iloc[idx, k + 1]) 37 | return prompt 38 | 39 | 40 | def gen_prompt(train_df, subject, k=-1): 41 | prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(format_subject(subject)) 42 | if k == -1: 43 | k = train_df.shape[0] 44 | for i in range(k): 45 | prompt += format_example(train_df, i) 46 | return prompt 47 | 48 | 49 | @torch.no_grad 50 | def eval(subject, model, tokenizer, dev_df, test_df, n_train): 51 | cors = [] 52 | all_probs = [] 53 | answers = choices[:test_df.shape[1]-2] 54 | 55 | for i in range(test_df.shape[0]): 56 | # get prompt and make sure it fits 57 | k = n_train 58 | prompt_end = format_example(test_df, i, include_answer=False) 59 | train_prompt = gen_prompt(dev_df, subject, k) 60 | prompt = train_prompt + prompt_end 61 | 62 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() 63 | 64 | while input_ids.shape[-1] > 2048: 65 | k -= 1 66 | train_prompt = gen_prompt(dev_df, subject, k) 67 | prompt = train_prompt + prompt_end 68 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda() 69 | # print("prompt:", prompt) 70 | label = test_df.iloc[i, test_df.shape[1] - 1] 71 | 72 | # Forward prop through the model 73 | logits = model(input_ids=input_ids).logits 74 | assert logits.shape[0] == 1, f"batch size should be 1. Found: {logits.shape}" 75 | logits = logits[:, -1, :].flatten() # BSV format 76 | 77 | probs = ( 78 | torch.nn.functional.softmax( 79 | torch.tensor( 80 | [ 81 | logits[tokenizer("A").input_ids[-1]], 82 | logits[tokenizer("B").input_ids[-1]], 83 | logits[tokenizer("C").input_ids[-1]], 84 | logits[tokenizer("D").input_ids[-1]], 85 | ] 86 | ), 87 | dim=0, 88 | ) 89 | .detach() 90 | .cpu() 91 | .numpy() 92 | ) 93 | pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)] 94 | 95 | cor = pred == label 96 | cors.append(cor) 97 | all_probs.append(probs) 98 | 99 | acc = np.mean(cors) 100 | cors = np.array(cors) 101 | 102 | all_probs = np.array(all_probs) 103 | print("Average accuracy {:.3f} - {}".format(acc, subject)) 104 | 105 | return cors, acc, all_probs 106 | 107 | 108 | def evaluate_mmlu(model, tokenizer, output_log_file, n_train=5, data_dir=None): 109 | if data_dir is None: 110 | data_dir = "./datasets/mmlu/" # default dataset directory 111 | if not os.path.exists(data_dir) and is_main_proc(): 112 | download_mmlu() 113 | wait_for_other_procs() # barrier to wait for model download 114 | print("Using MMLU dataset directory:", data_dir) 115 | assert os.path.exists(data_dir), data_dir 116 | 117 | subjects = sorted( 118 | [ 119 | f.split("_test.csv")[0] 120 | for f in os.listdir(os.path.join(data_dir, "test")) 121 | if "_test.csv" in f 122 | ] 123 | ) 124 | 125 | all_cors = [] 126 | subcat_cors = { 127 | subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists 128 | } 129 | cat_cors = {cat: [] for cat in categories} 130 | 131 | for subject in subjects: 132 | dev_df = pd.read_csv( 133 | os.path.join(data_dir, "dev", subject + "_dev.csv"), header=None 134 | )[:n_train] 135 | test_df = pd.read_csv( 136 | os.path.join(data_dir, "test", subject + "_test.csv"), header=None 137 | ) 138 | 139 | cors, acc, probs = eval(subject, model, tokenizer, dev_df, test_df, n_train) 140 | subcats = subcategories[subject] 141 | for subcat in subcats: 142 | subcat_cors[subcat].append(cors) 143 | for key in categories.keys(): 144 | if subcat in categories[key]: 145 | 
cat_cors[key].append(cors) 146 | all_cors.append(cors) 147 | 148 | test_df["correct"] = cors 149 | for j in range(probs.shape[1]): 150 | choice = choices[j] 151 | test_df[f"choice{choice}_probs"] = probs[:, j] 152 | test_df.to_csv(output_log_file, index=None) 153 | 154 | for subcat in subcat_cors: 155 | subcat_acc = np.mean(np.concatenate(subcat_cors[subcat])) 156 | print("Average accuracy {:.3f} - {}".format(subcat_acc, subcat)) 157 | 158 | for cat in cat_cors: 159 | cat_acc = np.mean(np.concatenate(cat_cors[cat])) 160 | print("Average accuracy {:.3f} - {}".format(cat_acc, cat)) 161 | weighted_acc = np.mean(np.concatenate(all_cors)) 162 | print("Average accuracy: {:.3f}".format(weighted_acc)) 163 | return weighted_acc 164 | 165 | 166 | if __name__ == "__main__": 167 | download_mmlu() 168 | -------------------------------------------------------------------------------- /evals/mmlu_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation adopted from official implementation by Dan Hendrycks: 3 | https://github.com/hendrycks/test/blob/master/categories.py 4 | """ 5 | 6 | import os 7 | import wget 8 | import tarfile 9 | 10 | 11 | subcategories = { 12 | "abstract_algebra": ["math"], 13 | "anatomy": ["health"], 14 | "astronomy": ["physics"], 15 | "business_ethics": ["business"], 16 | "clinical_knowledge": ["health"], 17 | "college_biology": ["biology"], 18 | "college_chemistry": ["chemistry"], 19 | "college_computer_science": ["computer science"], 20 | "college_mathematics": ["math"], 21 | "college_medicine": ["health"], 22 | "college_physics": ["physics"], 23 | "computer_security": ["computer science"], 24 | "conceptual_physics": ["physics"], 25 | "econometrics": ["economics"], 26 | "electrical_engineering": ["engineering"], 27 | "elementary_mathematics": ["math"], 28 | "formal_logic": ["philosophy"], 29 | "global_facts": ["other"], 30 | "high_school_biology": ["biology"], 31 | "high_school_chemistry": ["chemistry"], 32 | "high_school_computer_science": ["computer science"], 33 | "high_school_european_history": ["history"], 34 | "high_school_geography": ["geography"], 35 | "high_school_government_and_politics": ["politics"], 36 | "high_school_macroeconomics": ["economics"], 37 | "high_school_mathematics": ["math"], 38 | "high_school_microeconomics": ["economics"], 39 | "high_school_physics": ["physics"], 40 | "high_school_psychology": ["psychology"], 41 | "high_school_statistics": ["math"], 42 | "high_school_us_history": ["history"], 43 | "high_school_world_history": ["history"], 44 | "human_aging": ["health"], 45 | "human_sexuality": ["culture"], 46 | "international_law": ["law"], 47 | "jurisprudence": ["law"], 48 | "logical_fallacies": ["philosophy"], 49 | "machine_learning": ["computer science"], 50 | "management": ["business"], 51 | "marketing": ["business"], 52 | "medical_genetics": ["health"], 53 | "miscellaneous": ["other"], 54 | "moral_disputes": ["philosophy"], 55 | "moral_scenarios": ["philosophy"], 56 | "nutrition": ["health"], 57 | "philosophy": ["philosophy"], 58 | "prehistory": ["history"], 59 | "professional_accounting": ["other"], 60 | "professional_law": ["law"], 61 | "professional_medicine": ["health"], 62 | "professional_psychology": ["psychology"], 63 | "public_relations": ["politics"], 64 | "security_studies": ["politics"], 65 | "sociology": ["culture"], 66 | "us_foreign_policy": ["politics"], 67 | "virology": ["health"], 68 | "world_religions": ["philosophy"], 69 | } 70 | 71 | categories = { 72 | "STEM": 
["physics", "chemistry", "biology", "computer science", "math", "engineering"], 73 | "humanities": ["history", "philosophy", "law"], 74 | "social sciences": ["politics", "culture", "economics", "geography", "psychology"], 75 | "other (business, health, misc.)": ["other", "business", "health"], 76 | } 77 | 78 | 79 | def extract_all_files(tar_file_path, extract_to): 80 | with tarfile.open(tar_file_path, 'r') as tar: 81 | tar.extractall(extract_to) 82 | 83 | 84 | def download_mmlu(output_dir=None): 85 | url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar" 86 | if output_dir is None: 87 | output_dir = "./datasets/" 88 | if not os.path.exists(output_dir): 89 | os.mkdir(output_dir) 90 | print("Output directory created:", output_dir) 91 | output_path = os.path.join(output_dir, "mmlu.tar") 92 | 93 | # Download dataset file 94 | if not os.path.exists(output_path): 95 | print("Downloading dataset from URL:", url) 96 | filename = wget.download(url, out=output_path) 97 | print("Dataset file downloaded:", output_path) 98 | assert os.path.exists(output_path), output_path 99 | 100 | # Extract dataset file 101 | dataset_dir = output_path.replace(".tar", "/") 102 | if not os.path.exists(dataset_dir): 103 | print("Exacting files...") 104 | temp_output_dir = os.path.join(output_dir, "temp") 105 | extract_all_files(output_path, temp_output_dir) 106 | os.rename(os.path.join(temp_output_dir, "data"), dataset_dir) # remove data from directory name 107 | os.rmdir(temp_output_dir) # remove temp dir 108 | assert os.path.exists(dataset_dir), dataset_dir 109 | 110 | print("Number of files in dataset directory:", len(os.listdir(dataset_dir)), 111 | os.listdir(dataset_dir)[:5]) 112 | -------------------------------------------------------------------------------- /layer_influence/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export TOKENIZERS_PARALLELISM=false # disable tokenizer warning 4 | pip install wget 5 | 6 | # Get the DDP args 7 | HEAD_NODE_IP=$1 8 | NUM_NODES=$2 9 | NUM_GPUS_PER_NODE=2 10 | echo "Head node IP: ${HEAD_NODE_IP} / # nodes: ${NUM_NODES} / # GPUs per node: ${NUM_GPUS_PER_NODE}" 11 | 12 | # Check if HEAD_NODE_IP is given 13 | if [ -z "${HEAD_NODE_IP}" ]; then 14 | echo "No head node IP found. Using torchrun runner." 
15 | RUNNER_CMD="torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS_PER_NODE}" 16 | else 17 | export WORLD_SIZE=${SLURM_NTASKS} 18 | export RANK=${SLURM_PROCID} 19 | export LOCAL_RANK=${SLURM_LOCALID} 20 | export MASTER_ADDR=${HEAD_NODE_IP} 21 | export MASTER_PORT=29500 22 | echo "python args / world size: ${WORLD_SIZE} / rank: ${RANK} / local rank: ${LOCAL_RANK} / master addr: ${MASTER_ADDR} / master port: ${MASTER_PORT}" 23 | 24 | RUNNER_CMD="python" 25 | fi 26 | 27 | DEFAULT_MODEL="llama-2" 28 | MODEL=${3:-$DEFAULT_MODEL} 29 | echo "Using model: ${MODEL}" 30 | 31 | DEFAULT_PRUNING_SCHEME="mhsa" 32 | PRUNING_SCHEME=${4:-$DEFAULT_PRUNING_SCHEME} 33 | echo "Pruning scheme: ${PRUNING_SCHEME}" 34 | 35 | ${RUNNER_CMD} layer_influence/evaluate_layer_influence.py \ 36 | --dataset "openwebtext" \ 37 | --model-name ${MODEL} \ 38 | --model-size 7b \ 39 | --batch-size 1 \ 40 | --sequence-length 2048 \ 41 | --subsample-size 250000 \ 42 | --pruning-scheme ${PRUNING_SCHEME} \ 43 | --limit-shapley-samples 25000 \ 44 | --wandb-project 'layer_influence' 45 | -------------------------------------------------------------------------------- /layer_influence/evaluate_layer_influence.py: -------------------------------------------------------------------------------- 1 | #!/bin/python 2 | 3 | import os 4 | import time 5 | import json 6 | import random 7 | import argparse 8 | from tqdm import tqdm 9 | from typing import Tuple 10 | 11 | import wandb 12 | 13 | import torch 14 | import numpy as np 15 | 16 | from transformers import AutoTokenizer, AutoConfig 17 | 18 | from llama_layer_influence import LlamaForCausalLM 19 | from mistral_layer_influence import MistralForCausalLM 20 | 21 | import sys 22 | sys.path.append('.') 23 | from dataset import NLPDataset, get_dataloader 24 | from train_utils import get_num_model_params 25 | from dist_utils import init_distributed_env, is_main_proc, wait_for_other_procs, reduce_tensor 26 | from block_influence import BlockInfluenceEstimator 27 | from evals.dist_mmlu import MMLUDataset, evaluate_mmlu 28 | 29 | 30 | def load_model(args, only_tokenizer=False, pretrained=False): 31 | # assumes huggingface login: `huggingface-cli login`` 32 | if args.model_name == "llama-2": 33 | if args.use_instruct_model: 34 | model_name = f"meta-llama/Llama-2-{args.model_size.lower()}-chat-hf" 35 | else: 36 | model_name = f"meta-llama/Llama-2-{args.model_size.lower()}-hf" 37 | elif args.model_name == "mistral": 38 | if args.use_instruct_model: 39 | model_name = f"mistralai/Mistral-{args.model_size.upper()}-Instruct-v0.2" 40 | else: 41 | model_name = f"mistralai/Mistral-{args.model_size.upper()}-v0.1" 42 | else: 43 | raise RuntimeError(f"Unsupported model: {args.model_name}") 44 | print("!! 
Loading model:", model_name) 45 | 46 | # Load the tokenizer 47 | tokenizer = AutoTokenizer.from_pretrained(model_name) 48 | if only_tokenizer: 49 | return tokenizer 50 | 51 | # Load the model as well as the tokenizer 52 | config = AutoConfig.from_pretrained(model_name) 53 | print("Config:", config) 54 | kwargs = dict(torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2") 55 | print("Model precision:", kwargs["torch_dtype"]) 56 | if pretrained: 57 | print("Using pretrained model...") 58 | 59 | if args.model_name == "llama-2": 60 | if not pretrained: 61 | model = LlamaForCausalLM(config).to(kwargs["torch_dtype"]) 62 | else: 63 | model = LlamaForCausalLM.from_pretrained(model_name, **kwargs) 64 | elif args.model_name == "mistral": 65 | if not pretrained: 66 | model = MistralForCausalLM(config).to(kwargs["torch_dtype"]) 67 | else: 68 | model = MistralForCausalLM.from_pretrained(model_name, **kwargs) 69 | else: 70 | raise RuntimeError(f"Unsupported model: {args.model_name}") 71 | 72 | return model, tokenizer 73 | 74 | 75 | def compute_log_probs(logits: torch.Tensor, target_ids: torch.Tensor) -> Tuple[np.ndarray, np.ndarray]: 76 | # Apply softmax and log to obtain log probabilities from logits (summing original logits would be incorrect) 77 | log_probs = torch.log_softmax(logits.float(), dim=-1) 78 | 79 | log_probs = torch.gather(log_probs, 2, target_ids.unsqueeze(-1)).squeeze(-1) 80 | sequence_log_prob = log_probs.sum(dim=1).cpu().float().numpy() 81 | 82 | # Calculate perplexity 83 | sequence_length = target_ids.size(-1) 84 | assert sequence_length > 0, logits 85 | sequence_perplexity = np.exp(-sequence_log_prob / sequence_length) 86 | 87 | return sequence_perplexity, sequence_log_prob 88 | 89 | 90 | @torch.no_grad() 91 | def evaluate_model(model: torch.nn.Module, eval_loader: torch.utils.data.DataLoader, device: torch.device, split_name: str): 92 | model.eval() 93 | avg_sequence_perplexity = 0. 94 | avg_loss = 0. 
95 | num_ex = 0 96 | 97 | for batch in tqdm(eval_loader): 98 | tokenized_input = batch["input_ids"].to(device) 99 | 100 | # Forward prop through the model (will also populate the loss, but one extra logit) 101 | outputs = model(tokenized_input, labels=tokenized_input) 102 | 103 | # Compute metrics on top of LM logits 104 | lm_logits = outputs.logits[:, :-1, :] # BTD format (discard the final logit) 105 | target_ids = tokenized_input[:, 1:] # input ids strided by one 106 | assert len(lm_logits.shape) == 3, lm_logits.shape 107 | assert len(target_ids.shape) == 2, target_ids.shape 108 | assert lm_logits.shape[1] == target_ids.shape[1], f"{lm_logits.shape} != {target_ids.shape}" 109 | perplexity, log_prob = compute_log_probs(lm_logits, target_ids) 110 | 111 | avg_sequence_perplexity += float(perplexity.sum()) 112 | avg_loss += float(outputs.loss) 113 | num_ex += len(tokenized_input) 114 | 115 | # Collect the stats from all processes 116 | avg_sequence_perplexity = float(reduce_tensor(torch.tensor(avg_sequence_perplexity).to(device))) 117 | avg_loss = float(reduce_tensor(torch.tensor(avg_loss).to(device))) 118 | num_ex = int(reduce_tensor(torch.tensor(num_ex).to(device))) 119 | 120 | avg_sequence_perplexity = avg_sequence_perplexity / num_ex 121 | avg_loss = avg_loss / num_ex 122 | output_dict = {"split": split_name, "num_ex": num_ex, "avg_loss": avg_loss, "avg_seq_perplexity": avg_sequence_perplexity} 123 | print(json.dumps(output_dict)) 124 | if split_name is not None and wandb.run is not None: 125 | wandb.log({f"eval_{split_name}": {"num_ex": num_ex, "avg_loss": avg_loss, "avg_seq_perplexity": avg_sequence_perplexity}}) 126 | return avg_loss, avg_sequence_perplexity 127 | 128 | 129 | @torch.no_grad() 130 | def compute_layer_shapley(model: torch.nn.Module, eval_loader: torch.utils.data.DataLoader, device: torch.device, 131 | use_random_subnetworks: bool = False, subnetwork_len: float = 0.5, seed: int = 43, 132 | num_subsampled_networks: int = 10, max_samples_per_proc: int = None): 133 | model.eval() 134 | base_module = model.module if hasattr(model, 'module') else model # redefine the base module 135 | num_model_layers = base_module.get_num_model_layers() 136 | num_model_layers = num_model_layers * 2 # MHSA + MLP in each block 137 | print(f"!! 
Computing the logit shapley value for the model with {num_model_layers} layers...") 138 | rng = np.random.default_rng(seed) 139 | if not use_random_subnetworks: 140 | num_subsampled_networks = num_model_layers 141 | 142 | all_statistics = [] 143 | for iterator, batch in enumerate(tqdm(eval_loader)): 144 | tokenized_input = batch["input_ids"].to(device) 145 | base_logits = None 146 | for i in range(1+num_subsampled_networks): # first one is always base model eval 147 | selected_layers = None # use full network 148 | if i != 0: # use subnetwork 149 | if use_random_subnetworks: 150 | selected_layers = rng.choice(range(num_model_layers), int(subnetwork_len*num_model_layers), replace=False) 151 | else: 152 | layer_to_remove = i - 1 153 | selected_layers = [x for x in range(num_model_layers) if x != layer_to_remove] 154 | base_module.select_layers(selected_layers, verbose=False) 155 | 156 | outputs = model(tokenized_input, labels=tokenized_input) 157 | lm_logits = outputs.logits[:, :-1, :] # BTD format (discard the final logit) 158 | lm_loss = outputs.loss 159 | if base_logits is None: 160 | assert selected_layers is None 161 | base_logits = lm_logits 162 | else: 163 | assert selected_layers is not None 164 | diff_norm = torch.norm(base_logits - lm_logits, p=2, dim=-1).mean() # mean over batch and sequence 165 | all_statistics.append((selected_layers, float(diff_norm), float(lm_loss))) 166 | 167 | # Check if stopping condition is met 168 | if max_samples_per_proc is not None and iterator >= max_samples_per_proc - 1: 169 | print(f"{iterator} samples collected for logit shapley value. Stopping further computations!") 170 | break 171 | 172 | # Compute the layer influence based on the computed statistics 173 | logit_dist = {i: {"present": [], "absent": []} for i in range(num_model_layers)} 174 | loss_dist = {i: {"present": [], "absent": []} for i in range(num_model_layers)} 175 | for selected_layers, diff_norm, loss in all_statistics: 176 | for i in range(num_model_layers): 177 | key = "present" if i in selected_layers else "absent" 178 | logit_dist[i][key].append(diff_norm) 179 | loss_dist[i][key].append(loss) 180 | 181 | # Compute average distances 182 | print("~~~~~~ Layer shapley statistics ~~~~~~") 183 | logit_shapley_list = [] 184 | loss_shapley_list = [] 185 | for key, input_container, output_container in [("dist", logit_dist, logit_shapley_list), 186 | ("loss", loss_dist, loss_shapley_list)]: 187 | for i in range(num_model_layers): 188 | for name in ["present", "absent"]: 189 | mean = np.mean(input_container[i][name]) # convert it to mean 190 | input_container[i][name] = float(reduce_tensor(torch.tensor(mean).to(device), average=True)) 191 | shapley = input_container[i]['present'] - input_container[i]['absent'] 192 | print(f"> layer {i} / present mean {key}: {input_container[i]['present']} / absent mean {key}: {input_container[i]['absent']} / shapley: {shapley}") 193 | output_container.append(shapley) 194 | print("-"*50) 195 | return logit_shapley_list, loss_shapley_list 196 | 197 | 198 | def main(args): 199 | init_distributed_env(args) 200 | 201 | generator = None 202 | if args.seed is not None: # Set process seed to reduce stochasticity 203 | torch.manual_seed(args.seed) 204 | torch.cuda.manual_seed(args.seed) 205 | np.random.seed(seed=args.seed) 206 | random.seed(args.seed) 207 | print("Setting process seed:", args.seed) 208 | 209 | # Generator to seed dataloaders 210 | generator = torch.Generator() 211 | generator.manual_seed(args.seed) 212 | 213 | dataset_dir = 
f"{args.dataset}_model_{args.model_name}_seq_len_{args.sequence_length}_subsample_{args.subsample_size}_comb_docs" 214 | args.dataset_output_dir = os.path.join("datasets", dataset_dir) 215 | 216 | suffix = f"layer_pruning_{args.pruning_scheme}" 217 | args.wandb_run_name = f"{dataset_dir}_{suffix}" 218 | 219 | if args.wandb_project is not None and is_main_proc(): 220 | print("Initialization w&b...") 221 | wandb.init(project=args.wandb_project, name=args.wandb_run_name, config=args, resume=False) 222 | 223 | if is_main_proc() and not NLPDataset.is_dataset_processed(args.dataset_output_dir): 224 | tokenizer = load_model(args, only_tokenizer=True) 225 | dataset = NLPDataset(args.dataset, tokenizer, max_length=args.sequence_length, 226 | combine_documents=True, subsample_size=args.subsample_size) 227 | dataset.save_datasets(args.dataset_output_dir) 228 | wait_for_other_procs() # wait for the main process to write the dataset 229 | 230 | # Load the dataset 231 | dataset = NLPDataset.load_dataset(args.dataset_output_dir) # returns a dataset dict 232 | train_dataset = dataset["train"] 233 | test_dataset = dataset["test"] 234 | 235 | # Load the model 236 | model, tokenizer = load_model(args, pretrained=True) 237 | num_model_params = get_num_model_params(model) 238 | num_model_layers = model.get_num_model_layers() 239 | print(f"# model params: {num_model_params/1_000_000:.2f}M / # layers: {num_model_layers}") 240 | 241 | # Create the low-rank adapters 242 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 243 | model = model.to(device) # move to device 244 | 245 | # Create the dataloaders 246 | train_loader = get_dataloader(train_dataset, args.batch_size, args.num_workers, drop_last=True, generator=generator) 247 | eval_loader = get_dataloader(test_dataset, args.test_batch_size, args.num_workers, generator=generator) 248 | 249 | # Load MMLU dataset 250 | mmlu_dataset = MMLUDataset(tokenizer, args.model_name) 251 | print("# examples in MMLU dataset:", len(mmlu_dataset)) 252 | mmlu_loader = get_dataloader(mmlu_dataset, 1, args.num_workers, generator=generator) # bs=1 for MMLU 253 | 254 | eval_start_time = time.time() 255 | model.select_layers(None) # use all layers 256 | total_layers = num_model_layers * 2 # MHSA + MLP in each block 257 | print("Total model layers:", total_layers) 258 | if args.use_block_influence: 259 | print(">> Using block influences for layer pruning...") 260 | if args.model_name == "llama-2": 261 | cosine_block_influences = [0.4065510630607605, 0.24990439414978027, 0.1453840732574463, 0.1462407112121582, 0.1515178084373474, 0.15010249614715576, 0.1436394453048706, 0.14198929071426392, 0.1338726282119751, 0.1287376880645752, 0.12480568885803223, 0.11314243078231812, 0.10922092199325562, 0.10999512672424316, 0.10579818487167358, 0.10590404272079468, 0.10621917247772217, 0.08393210172653198, 0.07559847831726074, 0.06368941068649292, 0.06160825490951538, 0.046613335609436035, 0.04402130842208862, 0.036509573459625244, 0.0347287654876709, 0.030786514282226562, 0.033209800720214844, 0.02987152338027954, 0.03320187330245972, 0.033727407455444336, 0.06028544902801514, 0.3071935772895813] 262 | else: 263 | assert args.model_name == "mistral" 264 | cosine_block_influences = [0.4131488800048828, 0.22294247150421143, 0.18501925468444824, 0.15420764684677124, 0.15791428089141846, 0.16193944215774536, 0.15579771995544434, 0.15514129400253296, 0.14191526174545288, 0.13151496648788452, 0.12524676322937012, 0.1169435977935791, 0.11011826992034912, 0.10614091157913208, 
0.11162376403808594, 0.11073744297027588, 0.1094699501991272, 0.10527843236923218, 0.10619938373565674, 0.09696263074874878, 0.07502329349517822, 0.06127279996871948, 0.04937338829040527, 0.04653525352478027, 0.04156315326690674, 0.03952103853225708, 0.03852170705795288, 0.04080432653427124, 0.050363004207611084, 0.06016743183135986, 0.06128782033920288, 0.17413002252578735] 265 | cosine_layer_influences = [] 266 | for inf in cosine_block_influences: 267 | cosine_layer_influences += [inf, inf] # MHSA / MLP 268 | assert len(cosine_layer_influences) == total_layers, f"{len(cosine_layer_influences)} != {total_layers}" 269 | layer_influence_list = [("block_cosine", cosine_layer_influences)] 270 | assert args.pruning_scheme in ["mhsa", "mlp"] 271 | else: 272 | print(">> Estimating layer influences...") 273 | block_influence_estimator = BlockInfluenceEstimator(total_layers, device) 274 | model.add_layer_influence_estimator(block_influence_estimator) 275 | evaluate_model(model, train_loader, device, split_name=None) # use the train set to compute block influences 276 | final_layer_influences = block_influence_estimator.get_block_influences() 277 | model.add_layer_influence_estimator(None) # remove the block influence computation 278 | print("Final layer influences:", final_layer_influences) 279 | 280 | cosine_layer_influences = [x["cosine_dist"] for x in final_layer_influences] 281 | l1_layer_influences = [x["l1_update_norm"] for x in final_layer_influences] 282 | relative_l1_layer_influences = [x["l1_relative_update_norm"] for x in final_layer_influences] 283 | l2_layer_influences = [x["l2_update_norm"] for x in final_layer_influences] 284 | relative_l2_layer_influences = [x["l2_relative_update_norm"] for x in final_layer_influences] 285 | print("Cosine layer influences:", cosine_layer_influences) 286 | print("L1 layer influences:", l1_layer_influences) 287 | print("Relative L1 layer influences:", relative_l1_layer_influences) 288 | print("L2 layer influences:", l2_layer_influences) 289 | print("Relative L2 layer influences:", relative_l2_layer_influences) 290 | 291 | if wandb.run is not None: 292 | wandb.log({f"layer_{i}_cosine_influence": layer_influence for i, layer_influence in enumerate(cosine_layer_influences)}) 293 | wandb.log({f"layer_{i}_l1_influence": layer_influence for i, layer_influence in enumerate(l1_layer_influences)}) 294 | wandb.log({f"layer_{i}_relative_l1_influence": layer_influence for i, layer_influence in enumerate(relative_l1_layer_influences)}) 295 | wandb.log({f"layer_{i}_l2_influence": layer_influence for i, layer_influence in enumerate(l2_layer_influences)}) 296 | wandb.log({f"layer_{i}_relative_l2_influence": layer_influence for i, layer_influence in enumerate(relative_l2_layer_influences)}) 297 | 298 | # Compute the block logit shapley 299 | max_samples_per_proc = None 300 | if args.limit_shapley_samples is not None: 301 | max_samples_per_proc = args.limit_shapley_samples // args.world_size 302 | print(f"Total samples limit: {args.limit_shapley_samples}. 
Capping the max_samples_per_proc for logit shapley computation to be: {max_samples_per_proc}")
303 | layer_logit_shapley, layer_loss_shapley = compute_layer_shapley(model, train_loader, device, max_samples_per_proc=max_samples_per_proc)
304 | print("Layer logit shapley:", layer_logit_shapley)
305 | print("Layer loss shapley:", layer_loss_shapley)
306 | logit_shapley_layer_influence = [-x for x in layer_logit_shapley] # negative shapley (lower distance) indicates higher importance
307 | loss_shapley_layer_influence = [-x for x in layer_loss_shapley] # negative shapley (lower distance) indicates higher importance
308 | if wandb.run is not None:
309 | wandb.log({f"layer_{i}_logit_shapley_influence": layer_influence for i, layer_influence in enumerate(logit_shapley_layer_influence)})
310 | wandb.log({f"layer_{i}_loss_shapley_influence": layer_influence for i, layer_influence in enumerate(loss_shapley_layer_influence)})
311 | 
312 | layer_influence_list = [("cosine", cosine_layer_influences), ("relative_l1", relative_l1_layer_influences),
313 | ("relative_l2", relative_l2_layer_influences), ("logit_shapley", logit_shapley_layer_influence),
314 | ("loss_shapley", loss_shapley_layer_influence)]
315 | 
316 | for influence_name, layer_influences in layer_influence_list:
317 | print("Using layer influence method:", influence_name)
318 | print("Layer influence values:", layer_influences)
319 | sorted_layers = np.argsort(layer_influences) # ascending order
320 | if args.pruning_scheme != "both":
321 | print("Selected pruning scheme:", args.pruning_scheme)
322 | if args.pruning_scheme == "mhsa":
323 | print("Only keeping even layers for removal...")
324 | sorted_layers = [x for x in sorted_layers if x % 2 == 0] # MHSA layers are even
325 | else:
326 | assert args.pruning_scheme == "mlp", args.pruning_scheme
327 | print("Only keeping odd layers for removal...")
328 | sorted_layers = [x for x in sorted_layers if x % 2 == 1] # MLP layers are odd
329 | print("Sorted layer list:", sorted_layers)
330 | 
331 | remaining_layers = list(range(total_layers))
332 | weighted_acc_list = []
333 | perplexity_list = []
334 | iterator = -1
335 | for _ in range(len(sorted_layers)+1): # one additional iteration for no dropping
336 | if iterator > -1: # do nothing on the first iteration, i.e., all layers are still selected
337 | lowest_layer = sorted_layers[iterator] # prune layers based on the estimated layer influence
338 | print(f"Removing layer {lowest_layer} with lowest influence: {layer_influences[lowest_layer]}")
339 | remaining_layers = [i for i in remaining_layers if i != lowest_layer] # remove lowest layer
340 | print("Remaining layers:", remaining_layers)
341 | model.select_layers(remaining_layers) # use the selected layers
342 | _, _, weighted_acc = evaluate_mmlu(model, tokenizer, mmlu_loader, device, f"{influence_name}_layers_pruned_{iterator+1}")
343 | _, avg_perplexity = evaluate_model(model, eval_loader, device, f"{influence_name}_layers_pruned_{iterator+1}")
344 | weighted_acc_list.append(weighted_acc)
345 | perplexity_list.append(avg_perplexity)
346 | iterator += 1
347 | 
348 | print(f">>>>> Layer pruning statistics using {influence_name} metric <<<<<")
349 | print(f"{influence_name} weighted ACC list: {weighted_acc_list}")
350 | print(f"{influence_name} perplexity list: {perplexity_list}")
351 | print("="*25)
352 | 
353 | eval_time_elapsed_h = (time.time() - eval_start_time) / (60 * 60) # convert seconds into hours
354 | print(f"Layer pruning evaluation completed / time elapsed: {eval_time_elapsed_h:.2f}h")
355 | 
356 | 
if wandb.run is not None: 357 | wandb.finish() 358 | print("Script execution completed!") 359 | 360 | 361 | if __name__ == "__main__": 362 | supported_datasets = ['pg19', 'cc_news', 'wikitext-2', 'bookcorpus', 'c4', 'openwebtext', 'slimpajama'] 363 | 364 | # Create ArgumentParser object 365 | parser = argparse.ArgumentParser(description='Argument parser for LLM layer influence evaluator') 366 | 367 | # Add arguments 368 | parser.add_argument('-d', '--dataset', default='wikitext-2', choices=supported_datasets, 369 | help='Dataset name (default: wikitext-2)') 370 | parser.add_argument('-m', '--model-name', default='llama-2', choices=['llama-2', 'mistral'], 371 | help='Model name (default: llama-2)') 372 | parser.add_argument('-s', '--model-size', default='7b', choices=['7b'], 373 | help='Model size (default: 7b)') 374 | parser.add_argument('--use-instruct-model', action='store_true', default=False, 375 | help='Use instruction-tuned model rather than the base model') 376 | parser.add_argument('--batch-size', type=int, default=1, 377 | help='Batch size per process (default: 1)') 378 | parser.add_argument('--test-batch-size', type=int, default=None, 379 | help='Batch size per process for testing (default: equal to --batch-size)') 380 | parser.add_argument('--sequence-length', type=int, default=1024, 381 | help='Sequence length for computing the model perplexity (default: 1024)') 382 | parser.add_argument('--subsample-size', type=int, default=1000000, 383 | help='Dataset subsample size in terms of number of docs (default: 1M)') 384 | parser.add_argument('--num-workers', type=int, default=8, 385 | help='Number of workers for the dataloader (default: 8)') 386 | parser.add_argument('--seed', type=int, default=43, 387 | help='seed value (default: 43)') 388 | parser.add_argument('--wandb-project', type=str, default=None, 389 | help='W&B project name (none indicates no W&B initialization)') 390 | parser.add_argument('--pruning-scheme', default='both', choices=['mhsa', 'mlp', 'both'], 391 | help='Pruning scheme (default: both)') 392 | parser.add_argument('--limit-shapley-samples', type=int, default=None, 393 | help='limit the number of samples to the specified value for shapley computation (default: None i.e., no limit)') 394 | parser.add_argument('--use-block-influence', action='store_true', default=False, 395 | help='use block influence for deciding the layer importance rather than computing it') 396 | 397 | # Parse the arguments 398 | args = parser.parse_args() 399 | 400 | if args.test_batch_size is None: 401 | args.test_batch_size = args.batch_size 402 | print("Setting test batch size to be equal to batch size:", args.test_batch_size) 403 | 404 | main(args) 405 | -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_num_model_params(model): 5 | return sum( 6 | [p.numel() * 2 if p.is_complex() else p.numel() for p in model.parameters() if p.requires_grad] 7 | ) 8 | --------------------------------------------------------------------------------