├── requirements.txt
├── README.md
├── sh
│   ├── Qwen3-30B-A3B.sh
│   ├── DeepSeek-R1-BF16.sh
│   ├── DeepSeek-V2-Lite.sh
│   └── Mixtral-8x7B-v0.1.sh
├── model_utils.py
├── data_utils.py
├── run.py
└── eval_utils.py

/requirements.txt:
--------------------------------------------------------------------------------
 1 | transformers==4.51.3
 2 | datasets
 3 | accelerate
 4 | psutil
 5 | safetensors
 6 | seaborn
 7 | matplotlib
 8 | numpy
 9 | scipy
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Installation
 2 | ```bash
 3 | conda create -n Super_Experts python==3.12 -y
 4 | conda activate Super_Experts
 5 | conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia
 6 | pip install -r requirements.txt
 7 | ```
 8 | # Quick start
 9 | ## Profiling Down_proj Outliers of Qwen3-30B-A3B
10 | ```shell
11 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
12 |     --model_name Your Model Path \
13 |     --save_path ./output/Qwen3-30B-A3B \
14 |     --profile_outliers \
15 |     --vis_outliers_heatmap
16 | ```
17 | ## Profiling Super Experts of Qwen3-30B-A3B
18 | ```shell
19 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
20 |     --model_name Your Model Path \
21 |     --save_path ./output/Qwen3-30B-A3B \
22 |     --profile_super_experts \
23 |     --vis_super_experts_line_plot
24 | ```
25 | ## PPL Test of Original Qwen3-30B-A3B
26 | ```shell
27 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
28 |     --model_name Your Model Path \
29 |     --save_path ./output/Qwen3-30B-A3B \
30 |     --dataset_name wikitext2 \
31 |     --eval_ppl
32 | ```
33 | ## PPL Test of Qwen3-30B-A3B After Pruning Super Experts
34 | ```shell
35 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
36 |     --model_name Your Model Path \
37 |     --save_path ./output/Qwen3-30B-A3B \
38 |     --eval_ppl \
39 |     --dataset_name wikitext2 \
40 |     --prune_super_experts
41 | ```
42 | 
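43 | ## PPL Test of Qwen3-30B-A3B After Pruning Selected Experts
44 | `run.py` also accepts an explicit expert list via `--prune_experts`, given as `"layer,expert"` pairs separated by semicolons (an expert index of `-1` denotes the shared experts). The command below mirrors the example list used in `sh/Qwen3-30B-A3B.sh`; the listed experts are only an illustration, not a recommended set.
45 | ```shell
46 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
47 |     --model_name Your Model Path \
48 |     --save_path ./output/Qwen3-30B-A3B \
49 |     --eval_ppl \
50 |     --dataset_name wikitext2 \
51 |     --prune_experts "3,54;4,38;2,3;5,63"
52 | ```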
--------------------------------------------------------------------------------
/sh/Qwen3-30B-A3B.sh:
--------------------------------------------------------------------------------
 1 | # sh sh/Qwen3-30B-A3B.sh 2>&1 | tee logs/Qwen3-30B-A3B.log
 2 | 
 3 | echo 'Profiling Down_proj Outliers of Qwen3-30B-A3B'
 4 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
 5 |     --model_name Your Model Path \
 6 |     --save_path ./output/Qwen3-30B-A3B \
 7 |     --profile_outliers \
 8 |     --vis_outliers_heatmap
 9 | 
10 | echo 'Profiling Massive Experts of Qwen3-30B-A3B'
11 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
12 |     --model_name Your Model Path \
13 |     --save_path ./output/Qwen3-30B-A3B \
14 |     --profile_massive_experts \
15 |     --vis_massive_experts_line_plot
16 | 
17 | echo 'Profiling Super Experts of Qwen3-30B-A3B'
18 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
19 |     --model_name Your Model Path \
20 |     --save_path ./output/Qwen3-30B-A3B \
21 |     --profile_super_experts \
22 |     --vis_super_experts_line_plot
23 | 
24 | echo 'PPL Test of Original Qwen3-30B-A3B'
25 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
26 |     --model_name Your Model Path \
27 |     --save_path ./output/Qwen3-30B-A3B \
28 |     --dataset_name wikitext2 \
29 |     --eval_ppl
30 | 
31 | echo 'PPL Test of Qwen3-30B-A3B After Pruning Experts'
32 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
33 |     --model_name Your Model Path \
34 |     --save_path ./output/Qwen3-30B-A3B \
35 |     --eval_ppl \
36 |     --dataset_name wikitext2 \
37 |     --prune_experts "3,54;4,38;2,3;5,63"
38 | 
39 | echo 'PPL Test of Qwen3-30B-A3B After Pruning Super Experts'
40 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
41 |     --model_name Your Model Path \
42 |     --save_path ./output/Qwen3-30B-A3B \
43 |     --eval_ppl \
44 |     --dataset_name wikitext2 \
45 |     --prune_super_experts
46 | 
--------------------------------------------------------------------------------
/sh/DeepSeek-R1-BF16.sh:
--------------------------------------------------------------------------------
 1 | # sh sh/DeepSeek-R1-BF16.sh 2>&1 | tee logs/DeepSeek-R1-BF16.log
 2 | 
 3 | echo 'Profiling Down_proj Outliers of DeepSeek-R1-BF16'
 4 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
 5 |     --model_name Your Model Path \
 6 |     --save_path ./output/DeepSeek-R1-BF16 \
 7 |     --profile_outliers \
 8 |     --vis_outliers_heatmap
 9 | 
10 | echo 'Profiling Massive Experts of DeepSeek-R1-BF16'
11 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
12 |     --model_name Your Model Path \
13 |     --save_path ./output/DeepSeek-R1-BF16 \
14 |     --profile_massive_experts \
15 |     --vis_massive_experts_line_plot
16 | 
17 | echo 'Profiling Super Experts of DeepSeek-R1-BF16'
18 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
19 |     --model_name Your Model Path \
20 |     --save_path ./output/DeepSeek-R1-BF16 \
21 |     --profile_super_experts \
22 |     --vis_super_experts_line_plot
23 | 
24 | echo 'PPL Test of Original DeepSeek-R1-BF16'
25 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
26 |     --model_name Your Model Path \
27 |     --save_path ./output/DeepSeek-R1-BF16 \
28 |     --dataset_name wikitext2 \
29 |     --eval_ppl
30 | 
31 | echo 'PPL Test of DeepSeek-R1-BF16 After Pruning Experts'
32 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
33 |     --model_name Your Model Path \
34 |     --save_path ./output/DeepSeek-R1-BF16 \
35 |     --eval_ppl \
36 |     --dataset_name wikitext2 \
37 |     --prune_experts "3,54;4,38;2,3;5,63"
38 | 
39 | echo 'PPL Test of DeepSeek-R1-BF16 After Pruning Super Experts'
40 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
41 |     --model_name Your Model Path \
42 |     --save_path ./output/DeepSeek-R1-BF16 \
43 |     --eval_ppl \
44 |     --dataset_name wikitext2 \
45 |     --prune_super_experts
46 | 
--------------------------------------------------------------------------------
/sh/DeepSeek-V2-Lite.sh:
--------------------------------------------------------------------------------
 1 | # sh sh/DeepSeek-V2-Lite.sh 2>&1 | tee logs/DeepSeek-V2-Lite.log
 2 | 
 3 | echo 'Profiling Down_proj Outliers of DeepSeek-V2-Lite'
 4 | CUDA_VISIBLE_DEVICES=6 python3 run.py \
 5 |     --model_name Your Model Path \
 6 |     --save_path ./output/DeepSeek-V2-Lite \
 7 |     --profile_outliers \
 8 |     --vis_outliers_heatmap
 9 | 
10 | echo 'Profiling Massive Experts of DeepSeek-V2-Lite'
11 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
12 |     --model_name Your Model Path \
13 |     --save_path ./output/DeepSeek-V2-Lite \
14 |     --profile_massive_experts \
15 |     --vis_massive_experts_line_plot
16 | 
17 | echo 'Profiling Super Experts of DeepSeek-V2-Lite'
18 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
19 |     --model_name Your Model Path \
20 |     --save_path ./output/DeepSeek-V2-Lite \
21 |     --profile_super_experts \
22 |     --vis_super_experts_line_plot
23 | 
24 | echo 'PPL Test of Original DeepSeek-V2-Lite'
25 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
26 |     --model_name Your Model Path \
27 |     --save_path ./output/DeepSeek-V2-Lite \
28 |     --dataset_name wikitext2 \
29 |     --eval_ppl
30 | 
31 | echo 'PPL Test of DeepSeek-V2-Lite After Pruning Experts'
32 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
33 |     --model_name Your Model Path \
34 |     --save_path ./output/DeepSeek-V2-Lite \
35 |     --eval_ppl \
36 |     --dataset_name wikitext2 \
37 |     --prune_experts "3,54;4,38;2,3;5,63"
38 | 
39 | echo 'PPL Test of DeepSeek-V2-Lite After Pruning Super Experts'
40 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
41 |     --model_name Your Model Path \
42 |     --save_path ./output/DeepSeek-V2-Lite \
43 |     --eval_ppl \
44 |     --dataset_name wikitext2 \
45 |     --prune_super_experts
46 | 
--------------------------------------------------------------------------------
/sh/Mixtral-8x7B-v0.1.sh:
--------------------------------------------------------------------------------
 1 | # sh sh/Mixtral-8x7B-v0.1.sh 2>&1 | tee logs/Mixtral-8x7B-v0.1.log
 2 | 
 3 | echo 'Profiling Down_proj Outliers of Mixtral-8x7B-v0.1'
 4 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
 5 |     --model_name Your Model Path \
 6 |     --save_path ./output/Mixtral-8x7B-v0.1 \
 7 |     --profile_outliers \
 8 |     --vis_outliers_heatmap
 9 | 
10 | echo 'Profiling Massive Experts of Mixtral-8x7B-v0.1'
11 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
12 |     --model_name Your Model Path \
13 |     --save_path ./output/Mixtral-8x7B-v0.1 \
14 |     --profile_massive_experts \
15 |     --vis_massive_experts_line_plot
16 | 
17 | echo 'Profiling Super Experts of Mixtral-8x7B-v0.1'
18 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
19 |     --model_name Your Model Path \
20 |     --save_path ./output/Mixtral-8x7B-v0.1 \
21 |     --profile_super_experts \
22 |     --vis_super_experts_line_plot
23 | 
24 | echo 'PPL Test of Original Mixtral-8x7B-v0.1'
25 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
26 |     --model_name Your Model Path \
27 |     --save_path ./output/Mixtral-8x7B-v0.1 \
28 |     --dataset_name wikitext2 \
29 |     --eval_ppl
30 | 
31 | echo 'PPL Test of Mixtral-8x7B-v0.1 After Pruning Experts'
32 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
33 |     --model_name Your Model Path \
34 |     --save_path ./output/Mixtral-8x7B-v0.1 \
35 |     --eval_ppl \
36 |     --dataset_name wikitext2 \
37 |     --prune_experts "3,54;4,38;2,3;5,63"
38 | 
39 | echo 'PPL Test of Mixtral-8x7B-v0.1 After Pruning Super Experts'
40 | CUDA_VISIBLE_DEVICES=0 python3 run.py \
41 |     --model_name Your Model Path \
42 |     --save_path ./output/Mixtral-8x7B-v0.1 \
43 |     --eval_ppl \
44 |     --dataset_name wikitext2 \
45 |     --prune_super_experts
46 | 
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import os
 3 | import psutil
 4 | process = psutil.Process(os.getpid())
 5 | from accelerate import init_empty_weights
 6 | from typing import TypeVar, Mapping
 7 | import safetensors
 8 | import json
 9 | import time
10 | 
11 | SAFETENSOR_MODEL_INDEX_FILE = "model.safetensors.index.json"
12 | 
13 | class SafeTensorsSDAdapter(Mapping[TypeVar("KT"), TypeVar("VT")]):
14 |     def __init__(
15 |         self,
16 |         model_path,
17 |         device,
18 |         filename=SAFETENSOR_MODEL_INDEX_FILE,
19 |         weight_map_key="weight_map",
20 |     ):
21 |         self.filename = filename
22 |         model_index_file = os.path.join(model_path, self.filename)
23 |         assert os.path.exists(model_index_file), f"file {model_index_file} does not exist!"
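        # Illustrative shape of the index file (the shard names here are
        # hypothetical; the real ones come from the checkpoint): a JSON document
        # whose "weight_map" maps each parameter name to the safetensors shard
        # that stores it, e.g.
        #   {"weight_map": {"model.embed_tokens.weight": "model-00001-of-000NN.safetensors", ...}}
        # __getitem__ below looks up this map so that only the shard containing
        # the requested tensor is opened, letting layers be materialized lazily.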
24 | with open(model_index_file, "r") as f: 25 | self.weight_map = json.load(f)[weight_map_key] 26 | self.model_path = model_path 27 | self.device = device 28 | self.st_f_cache = {} 29 | self.total_io_time = 0 30 | 31 | def __getitem__(self, key): 32 | st = time.time() 33 | st_file_name = self.weight_map[key] 34 | f = self._get_st(st_file_name) 35 | v = f.get_tensor(key) 36 | self.total_io_time += time.time() - st 37 | return v 38 | 39 | def io_time(self): 40 | return self.total_io_time 41 | 42 | def _get_st(self, st_file_name): 43 | if st_file_name not in self.st_f_cache: 44 | sf_f_path = os.path.join(self.model_path, st_file_name) 45 | 46 | st_f = safetensors.safe_open(sf_f_path, framework="pt", device=self.device) 47 | self.st_f_cache[st_file_name] = st_f 48 | return self.st_f_cache[st_file_name] 49 | 50 | def __contains__(self, key): 51 | return key in self.weight_map 52 | 53 | def __len__(self): 54 | return len(self.weight_map) 55 | 56 | def __iter__(self): 57 | return iter(self.weight_map) 58 | 59 | def keys(self): 60 | return self.weight_map.keys() 61 | 62 | def get_shape(self, key): 63 | st_file_name = self.weight_map[key] 64 | f = self._get_st(st_file_name) 65 | return f.get_slice(key).get_shape() 66 | 67 | def get_dtype(self, key): 68 | st_file_name = self.weight_map[key] 69 | f = self._get_st(st_file_name) 70 | return f.get_slice(key).get_dtype() 71 | 72 | def skip(*args, **kwargs): 73 | # This is a helper function to save time during the initialization! 74 | pass 75 | 76 | def get_model(model_path): 77 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM 78 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True) 79 | torch.nn.init.kaiming_uniform_ = skip 80 | torch.nn.init.uniform_ = skip 81 | torch.nn.init.normal_ = skip 82 | with init_empty_weights(): 83 | model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True) 84 | model.seqlen = 2048 85 | return model, tokenizer 86 | 87 | def get_model_safe_tensors(model_path): 88 | model_st = SafeTensorsSDAdapter(model_path, device='cpu') 89 | return model_st 90 | 91 | def print_memory_usage(): 92 | print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:.2f} MB") 93 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datasets 3 | import random 4 | 5 | def get_wikitext2(nsamples, tokenizer, seed, seqlen, eval_mode=False): 6 | if eval_mode: 7 | # testdata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') 8 | # testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 9 | testdata = datasets.load_from_disk('Your DateSets Path')['test'] 10 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 11 | return testenc 12 | else: 13 | # traindata = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') 14 | traindata = datasets.load_from_disk('Your DateSets Path')['train'] 15 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 16 | random.seed(seed) 17 | trainloader = [] 18 | for _ in range(nsamples): 19 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 20 | j = i + seqlen 21 | inp = trainenc.input_ids[:, i:j] 22 | tar = inp.clone() 23 | tar[:, :-1] = -100 24 | trainloader.append((inp, tar)) 25 | return trainloader 26 | 27 | def get_c4(nsamples, tokenizer, seed, seqlen, 
eval_mode=False): 28 | if eval_mode: 29 | # valdata = datasets.load_dataset( 30 | # 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') 31 | # valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') 32 | # valenc = valenc.input_ids[:, :(256 * seqlen)] 33 | # class TokenizerWrapper: 34 | # def __init__(self, input_ids): 35 | # self.input_ids = input_ids 36 | # valenc = TokenizerWrapper(valenc) 37 | valenc = [] 38 | valdata = datasets.load_from_disk('Your DateSets Path') 39 | random.seed(0) 40 | for _ in range(256): 41 | while True: 42 | i = random.randint(0, len(valdata) - 1) 43 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 44 | if tmp.input_ids.shape[1] > seqlen: 45 | break 46 | 47 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 48 | j = i + seqlen 49 | valenc.append(tmp.input_ids[:, i:j]) 50 | valenc = torch.hstack(valenc) 51 | class TokenizerWrapper: 52 | def __init__(self, input_ids): 53 | self.input_ids = input_ids 54 | valenc = TokenizerWrapper(valenc) 55 | return valenc 56 | else: 57 | traindata = datasets.load_from_disk('Your DateSets Path') 58 | random.seed(seed) 59 | trainloader = [] 60 | for _ in range(nsamples): 61 | while True: 62 | i = random.randint(0, len(traindata) - 1) 63 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 64 | if trainenc.input_ids.shape[1] >= seqlen: 65 | break 66 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 67 | j = i + seqlen 68 | inp = trainenc.input_ids[:, i:j] 69 | tar = inp.clone() 70 | tar[:, :-1] = -100 71 | trainloader.append((inp, tar)) 72 | return trainloader 73 | 74 | 75 | def get_loaders(name, tokenizer=None, nsamples=128, seed=0, seqlen=2048, eval_mode=False): 76 | if 'wikitext2' in name: 77 | return get_wikitext2(nsamples, tokenizer, seed, seqlen, eval_mode) 78 | elif 'c4' in name: 79 | return get_c4(nsamples, tokenizer, seed, seqlen, eval_mode) 80 | else: 81 | raise NotImplementedError 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import transformers 5 | from model_utils import get_model, get_model_safe_tensors 6 | from eval_utils import DeepSeekEvaluator, Qwen3MoeEvaluator, MixtralEvaluator 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | 11 | # config 12 | parser.add_argument('--model_name', type=str, default='', help='') 13 | parser.add_argument('--dataset_name', type=str, default='c4', help='') 14 | parser.add_argument('--save_path', type=str, default='', help='') 15 | parser.add_argument('--seed', type=float, default=83, help='') 16 | 17 | # profile outliers 18 | parser.add_argument('--profile_outliers', action='store_true', default=False, help='profile max input/output outliers of down_proj') 19 | parser.add_argument('--vis_outliers_heatmap', action='store_true', default=False, help='vis heatmap of max input/output outliers of down_proj across all layers and experts') 20 | 21 | # profile massive experts 22 | parser.add_argument('--profile_massive_experts', action='store_true', default=False, help='profile Massive Experts') 23 | parser.add_argument('--vis_massive_experts_line_plot', action='store_true', default=False, help='vis lineplot of max input/output outliers of down_proj across all layers for all experts') 24 | 25 | # profile super experts 26 | 
parser.add_argument('--profile_super_experts', action='store_true', default=False, help='profile Super Experts') 27 | parser.add_argument('--vis_super_experts_line_plot', action='store_true', default=False, help='vis lineplot of max input/output outliers of down_proj across all layers for super experts') 28 | parser.add_argument('--include_layers', type=float, default=0.75, help='') 29 | 30 | # eval 31 | parser.add_argument('--eval_ppl', action='store_true', default=False, help='') 32 | parser.add_argument('--prune_experts', type=lambda x: [tuple(map(int, item.split(','))) for item in x.split(';')],default=None, help='Pass a list of tuples, each in the format "layer,expert_index" separated by semicolons, e.g., "1,-1;2,3". -1 indicates shared experts.') 33 | parser.add_argument('--prune_super_experts', action='store_true', default=False, help='') 34 | args = parser.parse_args() 35 | return args 36 | 37 | 38 | def main(args): 39 | transformers.set_seed(args.seed) 40 | if args.save_path: 41 | os.makedirs(args.save_path, exist_ok=True) 42 | 43 | # initialize model 44 | model, tokenizer = get_model(args.model_name) 45 | model_st = get_model_safe_tensors(args.model_name) 46 | model_name = model.__class__.__name__ 47 | print(model.config) 48 | dev = torch.device('cuda:0') 49 | model.eval() 50 | 51 | # initialize evaluator 52 | if model_name in ['DeepseekV3ForCausalLM', 'DeepseekV2ForCausalLM']: 53 | evaluator = DeepSeekEvaluator(model, tokenizer, model_st, dev, args) 54 | elif model_name in ['Qwen3MoeForCausalLM']: 55 | evaluator = Qwen3MoeEvaluator(model, tokenizer, model_st, dev, args) 56 | elif model_name in ['MixtralForCausalLM']: 57 | evaluator = MixtralEvaluator(model, tokenizer, model_st, dev, args) 58 | else: 59 | raise NotImplementedError 60 | 61 | # profilling super experts 62 | if args.profile_outliers or args.vis_outliers_heatmap: 63 | outliers_info_path = evaluator.outliers_profiler(args.vis_outliers_heatmap, require_hook=True) 64 | print(f"Outliers Info is in {outliers_info_path}.") 65 | 66 | if args.profile_massive_experts or args.vis_massive_experts_line_plot: 67 | if not os.path.exists(evaluator.outliers_info_path): 68 | outliers_info_path = evaluator.outliers_profiler(args.vis_outliers_heatmap, require_hook=True) 69 | massive_experts_info_path = evaluator.massive_experts_profiler(args.vis_massive_experts_line_plot) 70 | print(f"Massive Experts Info is in {massive_experts_info_path}") 71 | 72 | if args.profile_super_experts: 73 | if not os.path.exists(evaluator.outliers_info_path): 74 | outliers_info_path = evaluator.outliers_profiler(args.vis_outliers_heatmap, require_hook=True) 75 | if not os.path.exists(evaluator.massive_experts_info_path): 76 | massive_experts_info_path = evaluator.massive_experts_profiler(args.vis_massive_experts_line_plot) 77 | super_experts_info_path = evaluator.super_experts_profiler(args.include_layers, args.vis_super_experts_line_plot) 78 | print(f"Super Experts Info is in {super_experts_info_path}") 79 | 80 | # ppl test 81 | if args.eval_ppl: 82 | if args.prune_super_experts: 83 | if not os.path.exists(evaluator.super_experts_info_path): 84 | print("Please perform profiling of the Super Experts first.") 85 | else: 86 | ppl = evaluator.evaluate_ppl(require_hook=False, prune_experts=args.prune_experts, prune_SE=args.prune_super_experts) 87 | print(f"After pruning Super Experts, Model {model_name} Dataset {args.dataset_name} PPL: {ppl}.") 88 | else: 89 | ppl = evaluator.evaluate_ppl(require_hook=False, prune_experts=args.prune_experts, 
prune_SE=args.prune_super_experts) 90 | print(f"Model {model_name} Dataset {args.dataset_name} PPL: {ppl}.") 91 | 92 | 93 | if __name__ == '__main__': 94 | args = parse_args() 95 | main(args) 96 | -------------------------------------------------------------------------------- /eval_utils.py: -------------------------------------------------------------------------------- 1 | from data_utils import get_loaders 2 | import torch 3 | import time 4 | import gc 5 | from tqdm import tqdm 6 | from collections import defaultdict 7 | import os 8 | import json 9 | import seaborn as sns 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | from scipy.stats import kurtosis 13 | 14 | class BaseEvaluator: 15 | 16 | def __init__(self, model, tokenizer, model_st, dev, args): 17 | self.model_name = model.__class__.__name__ 18 | self.model = model 19 | self.model.eval() 20 | self.tokenizer = tokenizer 21 | self.model_st = model_st 22 | self.dataset_name = args.dataset_name 23 | self.testloader = get_loaders(args.dataset_name, self.tokenizer, seed=args.seed, seqlen=self.model.seqlen, eval_mode=True) 24 | self.dev = dev 25 | self.args = args 26 | self.dtype = torch.bfloat16 27 | self.require_position_embeddings = None 28 | self.num_dense_layers = getattr(self.model.config, 'first_k_dense_replace', 0) 29 | self.total_layers = getattr(self.model.config, 'num_hidden_layers', None) 30 | self.outliers_info_path = os.path.join(self.args.save_path, "outliers_info") 31 | self.massive_experts_info_path = os.path.join(self.args.save_path, "massive_experts_info") 32 | self.layer_wise_save_path_of_massive = os.path.join(self.massive_experts_info_path, "layer_wise_analysis") 33 | self.expert_wise_save_path_of_massive = os.path.join(self.massive_experts_info_path, "expert_wise_analysis") 34 | self.super_experts_info_path = os.path.join(self.args.save_path, "super_experts_info") 35 | self.layer_wise_save_path_of_super = os.path.join(self.super_experts_info_path, "layer_wise_analysis") 36 | self.expert_wise_save_path_of_super = os.path.join(self.super_experts_info_path, "expert_wise_analysis") 37 | self.super_experts_report_path = os.path.join(self.super_experts_info_path, "super_experts_report") 38 | self.prune_experts=None 39 | self.SE_list = [] 40 | 41 | def _get_module(self, model, submodule_key): 42 | sub_tokens = submodule_key.split('.') 43 | cur_mod = model 44 | for s in sub_tokens: 45 | cur_mod = getattr(cur_mod, s) 46 | return cur_mod 47 | 48 | def _load_layer_weight(self, layer_idx): 49 | raise NotImplementedError("Subclasses should implement the evaluate_ppl method.") 50 | 51 | def _layer_wise_evaluate(self, require_position_embeddings, require_hook): 52 | dev = self.dev 53 | dtype = self.dtype 54 | use_cache = self.model.config.use_cache 55 | self.model.config.use_cache = False 56 | 57 | layers = self.model.model.layers 58 | # load embed_tokens 59 | self.model.model.embed_tokens = self.model.model.embed_tokens.to_empty(device=dev) 60 | self.model.model.embed_tokens.weight.data.copy_(self.model_st['model.embed_tokens.weight'].to(device=dev, dtype=dtype)) 61 | 62 | layers[0] = self._load_layer_weight(0) 63 | 64 | # Convert the whole text of evaluation dataset into batches of sequences. 65 | input_ids = self.testloader.input_ids # (1, text_len) 66 | nsamples = input_ids.numel() // self.model.seqlen # The tail is truncated. 
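        # The flat (1, text_len) token stream is cut into nsamples non-overlapping
        # windows of length self.model.seqlen; each window becomes one evaluation
        # sample below, and any leftover tokens shorter than seqlen are discarded.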
67 | input_ids = input_ids[:, :nsamples * self.model.seqlen].view(nsamples, self.model.seqlen).to(dev) # (nsamples, seqlen) 68 | batch_size = 1 69 | input_ids = [input_ids[i:i + batch_size] for i in range(0, nsamples, batch_size)] 70 | nbatches = len(input_ids) 71 | inps = torch.zeros( 72 | (nbatches, batch_size, self.model.seqlen, self.model.config.hidden_size), dtype=dtype, device=dev 73 | ) 74 | inps = [0] * nbatches 75 | if not require_position_embeddings: 76 | cache = {'i': 0, 'attention_mask': None} 77 | class Catcher(torch.nn.Module): 78 | def __init__(self, module): 79 | super().__init__() 80 | self.module = module 81 | def forward(self, inp, **kwargs): 82 | inps[cache['i']] = inp 83 | cache['i'] += 1 84 | cache['attention_mask'] = kwargs['attention_mask'] 85 | raise ValueError 86 | else: 87 | cache = {'i': 0, 'attention_mask': None} 88 | cache = {'i': 0, 'position_embeddings': None} 89 | class Catcher(torch.nn.Module): 90 | def __init__(self, module): 91 | super().__init__() 92 | self.module = module 93 | def forward(self, inp, **kwargs): 94 | inps[cache['i']] = inp 95 | cache['i'] += 1 96 | cache['attention_mask'] = kwargs['attention_mask'] 97 | cache['position_embeddings'] = kwargs['position_embeddings'] 98 | raise ValueError 99 | 100 | layers[0] = Catcher(layers[0]) 101 | 102 | for i in range(nbatches): 103 | batch = input_ids[i] 104 | try: 105 | self.model(batch) 106 | except ValueError: 107 | pass 108 | layers[0] = layers[0].module 109 | layers[0] = layers[0].cpu() 110 | 111 | self.model.model.embed_tokens = self.model.model.embed_tokens.cpu() 112 | del self.model.model.embed_tokens 113 | gc.collect() 114 | torch.cuda.empty_cache() 115 | time.sleep(0.1) 116 | outs = [0] * nbatches 117 | 118 | if not require_position_embeddings: 119 | attention_mask = cache['attention_mask'] 120 | else: 121 | attention_mask = cache['attention_mask'] 122 | position_embeddings = cache['position_embeddings'] 123 | 124 | # hook 125 | if require_hook: 126 | def experts_down_proj_hook(layer_index,layer_type, name): 127 | def hook(module, input, output): 128 | input_shape = input[0].shape 129 | # output_shape = output[0].shape 130 | # print(f"name={name}") 131 | # print(f"input_shape={output_shape}") 132 | if input_shape[0] == 0: 133 | inps_max = torch.tensor(0).to(input[0].device) 134 | outps_max = torch.tensor(0).to(input[0].device) 135 | inpsmax_index = torch.tensor(-1).to(input[0].device) 136 | outpsmax_index= torch.tensor(-1).to(input[0].device) 137 | else: 138 | inps_max = input[0].abs().max() 139 | inpsmax_index = input[0].abs().argmax() 140 | outps_max = output[0].abs().max() 141 | outpsmax_index = output[0].abs().argmax() 142 | input_max_value[layer_index][layer_type].append(inps_max) 143 | input_max_channel_index[layer_index][layer_type].append(inpsmax_index) 144 | output_max_value[layer_index][layer_type].append(outps_max) 145 | output_max_channel_index[layer_index][layer_type].append(outpsmax_index) 146 | return hook 147 | input_max_value = defaultdict(lambda: defaultdict(list)) 148 | output_max_value = defaultdict(lambda: defaultdict(list)) 149 | input_max_channel_index = defaultdict(lambda: defaultdict(list)) 150 | output_max_channel_index = defaultdict(lambda: defaultdict(list)) 151 | for name, module in self.model.named_modules(): 152 | if "down" in name or "w2" in name: 153 | layer_index = name.split('.')[2] 154 | layer_type = name 155 | module.register_forward_hook(experts_down_proj_hook(layer_index, layer_type, name)) 156 | 157 | # load layer and forward 158 | if not 
require_position_embeddings: 159 | for i in tqdm(range(len(layers)), desc="(Eval) Layers"): 160 | layer = self._load_layer_weight(i) 161 | for j in range(nbatches): 162 | with torch.no_grad(): 163 | outs[j] = layer(inps[j], attention_mask=attention_mask)[0] 164 | layers[i] = None 165 | del layer 166 | torch.cuda.empty_cache() 167 | time.sleep(0.1) 168 | inps, outs = outs, inps 169 | 170 | else: 171 | for i in tqdm(range(len(layers)), desc="(Eval) Layers"): 172 | layer = self._load_layer_weight(i) 173 | for j in range(nbatches): 174 | with torch.no_grad(): 175 | outs[j] = layer(inps[j], attention_mask=attention_mask, position_embeddings=position_embeddings)[0] 176 | layers[i] = None 177 | del layer 178 | torch.cuda.empty_cache() 179 | time.sleep(0.1) 180 | inps, outs = outs, inps 181 | 182 | if require_hook: 183 | return input_max_value, output_max_value, input_max_channel_index, output_max_channel_index 184 | else: 185 | return nbatches, inps, input_ids, use_cache 186 | 187 | def _vis_outliers_heatmap(self, save_path, num_dense_layers, total_layers): 188 | json_folder_path = save_path 189 | layer_input_max = {} 190 | layer_output_max = {} 191 | for file_name in os.listdir(json_folder_path): 192 | if file_name.endswith(".json"): 193 | with open(os.path.join(json_folder_path, file_name), 'r') as f: 194 | data = json.load(f) 195 | for entry in data: 196 | layer_index = int(entry["layer_index"]) 197 | input_max = entry["input_max"] 198 | if int(layer_index) < num_dense_layers: 199 | continue 200 | output_max = entry["output_max"] 201 | if layer_index not in layer_input_max: 202 | layer_input_max[layer_index] = [] 203 | layer_output_max[layer_index] = [] 204 | layer_input_max[layer_index].append(input_max) 205 | layer_output_max[layer_index].append(output_max) 206 | 207 | layers = sorted(layer_input_max.keys()) 208 | num_experts = max(len(layer_input_max[layer]) for layer in layers) 209 | 210 | heatmap_data_input = [] 211 | for layer in layers: 212 | experts = layer_input_max[layer] 213 | heatmap_data_input.append(experts) 214 | 215 | heatmap_data_output = [] 216 | for layer in layers: 217 | experts = layer_output_max[layer] 218 | heatmap_data_output.append(experts) 219 | 220 | figure_save_path = os.path.join(json_folder_path, "figure") 221 | os.makedirs(figure_save_path, exist_ok=True) 222 | plt.figure(figsize=(16, 6)) 223 | sns.heatmap(heatmap_data_input, cmap="coolwarm", xticklabels=[f"{i}" for i in range(0, num_experts)], yticklabels=[f"{layer}" for layer in layers], vmax=1000) 224 | plt.title("Input Max Values Heatmap", fontsize=24) 225 | plt.xlabel("Expert", fontsize=24) 226 | plt.ylabel("Layer", fontsize=24) 227 | plt.xticks(ticks=np.arange(0, num_experts, 10), labels=[f"{i}" for i in range(0, num_experts, 10)], fontsize=12) 228 | plt.yticks(ticks=np.arange(0, len(layers), 10), labels=[f"{layers[i]}" for i in range(0, len(layers), 10)], fontsize=12) 229 | plt.tight_layout() 230 | plt.savefig(os.path.join(figure_save_path, "outliers_heap_map_input.pdf")) 231 | 232 | plt.figure(figsize=(16, 6)) 233 | sns.heatmap(heatmap_data_output, cmap="coolwarm", xticklabels=[f"{i}" for i in range(0, num_experts)], yticklabels=[f"{layer}" for layer in layers], vmax=1000) 234 | plt.title("Output Max Values Heatmap", fontsize=24) 235 | plt.xlabel("Expert", fontsize=24) 236 | plt.ylabel("Layer", fontsize=24) 237 | plt.xticks(ticks=np.arange(0, num_experts, 10), labels=[f"{i}" for i in range(0, num_experts, 10)], fontsize=12) 238 | plt.yticks(ticks=np.arange(0, len(layers), 10), labels=[f"{layers[i]}" for i 
in range(0, len(layers), 10)], fontsize=12) 239 | plt.tight_layout() 240 | plt.savefig(os.path.join(figure_save_path, "outliers_heap_map_output.pdf")) 241 | 242 | def _vis_massive_experts_line_plot(self, save_path, num_dense_layers, expert_wise): 243 | for filename in os.listdir(save_path): 244 | if filename.endswith(".json"): 245 | name = filename.split('.')[-2] 246 | file_path = os.path.join(save_path, filename) 247 | with open(file_path, 'r') as file: 248 | data = json.load(file) 249 | if expert_wise or int(data[0]['layer_index']) >= num_dense_layers: 250 | X = [] 251 | Y = [] 252 | if expert_wise: 253 | sorted_data = sorted(data, key=lambda x:int(x['layer_index']), reverse=False) 254 | for layer_info in sorted_data: 255 | X.append(int(layer_info['layer_index'])) 256 | Y.append(layer_info['output_max']) 257 | else: 258 | shared_experts_info = [] 259 | router_experts_info = [] 260 | for expert_info in data: 261 | if int(expert_info['layer_index']) >= num_dense_layers: 262 | if 'shared_experts' in expert_info["layer_type"]: 263 | shared_experts_info.append(expert_info) 264 | else: 265 | router_experts_info.append(expert_info) 266 | sorted_data = sorted(router_experts_info, key=lambda x: int(x["layer_type"].split('.')[-2]), reverse=False) 267 | for expert_info in sorted_data: 268 | X.append(int(expert_info['layer_type'].split('.')[-2])) 269 | Y.append(expert_info['output_max']) 270 | for expert_info in shared_experts_info: 271 | X.append(X[-1] + 1) 272 | Y.append(expert_info['output_max']) 273 | 274 | plt.figure(figsize=(10, 6)) 275 | plt.plot(X, Y, marker='o', linestyle='-', color='b') 276 | plt.ylabel('Output Max', fontsize=24) 277 | if expert_wise: 278 | plt.xlabel('Layer Index', fontsize=24) 279 | plt.title(f'Output Max for {name}', fontsize=20) 280 | else: 281 | plt.xlabel('Expert Index', fontsize=24) 282 | plt.title(f'Output Max for {name}', fontsize=20) 283 | plt.grid(True) 284 | plt.xticks(fontsize=12) 285 | plt.yticks(fontsize=12) 286 | plt.tight_layout() 287 | figure_save_path = os.path.join(save_path, "figure") 288 | os.makedirs(figure_save_path, exist_ok=True) 289 | plot_filename = os.path.join(figure_save_path, f"{name}_plot.pdf") 290 | plt.savefig(plot_filename) 291 | plt.close() 292 | 293 | def _calculate_kurtosis(self, data): 294 | output_max_values = [entry['output_max'] for entry in data if 'output_max' in entry] 295 | return kurtosis(output_max_values, fisher=True) 296 | 297 | def _plot_meta_from_json(self, save_path, object, expert_wise): 298 | if not expert_wise: 299 | layers = [] 300 | meta_values = [] 301 | for filename in os.listdir(save_path): 302 | if filename.endswith(".json"): 303 | file_path = os.path.join(save_path, filename) 304 | with open(file_path, 'r') as file: 305 | data = json.load(file) 306 | meta_data = data.get('meta_data', {}) 307 | meta = meta_data.get(object, None) 308 | layer_index = meta_data.get('layer_index', None) 309 | layers.append(int(layer_index)) 310 | meta_values.append(meta) 311 | sorted_layers, sorted_meta_values = zip(*sorted(zip(layers, meta_values))) 312 | plt.figure(figsize=(10, 6)) 313 | plt.plot(sorted_layers, sorted_meta_values, marker='o', linestyle='-', color='b') 314 | plt.title(f"{object} by Layer", fontsize=24) 315 | plt.xlabel("Layer Index", fontsize=24) 316 | plt.grid(True) 317 | max_layer = max(sorted_layers) 318 | plt.xticks(ticks=np.arange(0, max_layer + 1, 10), labels=[f"{i}" for i in np.arange(0, max_layer + 1, 10)], fontsize=12) 319 | plt.yticks(fontsize=12) 320 | plt.tight_layout() 321 | figure_path = 
os.path.join(self.super_experts_report_path, "figure") 322 | os.makedirs(figure_path, exist_ok=True) 323 | plt.savefig(os.path.join(figure_path, f"{object} by Layer.pdf")) 324 | plt.show() 325 | else: 326 | shared_experts = [] 327 | meta_values_of_shared_experts = [] 328 | router_experts = [] 329 | meta_values_of_router_experts = [] 330 | for filename in os.listdir(save_path): 331 | if filename.endswith(".json"): 332 | file_path = os.path.join(save_path, filename) 333 | with open(file_path, 'r') as file: 334 | data = json.load(file) 335 | meta_data = data.get('meta_data', {}) 336 | meta = meta_data.get(object, None) 337 | expert_index = meta_data.get('expert_index', None) 338 | if "shared_experts" in expert_index: 339 | shared_experts.append(expert_index) 340 | meta_values_of_shared_experts.append(meta) 341 | else: 342 | router_experts.append(int(expert_index)) 343 | meta_values_of_router_experts.append(meta) 344 | sorted_experts, sorted_meta_values = zip(*sorted(zip(router_experts, meta_values_of_router_experts))) 345 | sorted_experts = list(sorted_experts) 346 | sorted_meta_values = list(sorted_meta_values) 347 | for shared_expert in range(len(shared_experts)): 348 | sorted_experts.append(shared_experts[shared_expert]) 349 | sorted_meta_values.append(meta_values_of_shared_experts[shared_expert]) 350 | 351 | plt.figure(figsize=(10, 6)) 352 | plt.plot(sorted_experts, sorted_meta_values, marker='o', linestyle='-', color='b') 353 | plt.title(f"{object} by Expert", fontsize=24) 354 | plt.xlabel("Expert Index", fontsize=24) 355 | plt.grid(True) 356 | max_expert = len(sorted_experts) 357 | plt.xticks(ticks=np.arange(0, max_expert + 1, 10), labels=[f"{i}" for i in np.arange(0, max_expert + 1, 10)], fontsize=12) 358 | plt.yticks(fontsize=12) 359 | plt.tight_layout() 360 | figure_path = os.path.join(self.super_experts_report_path, "figure") 361 | os.makedirs(figure_path, exist_ok=True) 362 | plt.savefig(os.path.join(figure_path, f"{object} by Expert.pdf")) 363 | plt.show() 364 | 365 | def _generate_SE_list(self): 366 | SE_json_path = os.path.join(self.super_experts_report_path, "Super Experts Report.json") 367 | with open(SE_json_path, 'r') as f: 368 | se_data = json.load(f) 369 | for entry in se_data: 370 | layer_index = entry['layer_index'] 371 | expert_index = entry['expert_index'] 372 | if expert_index == "shared_experts": 373 | expert_index = "-1" 374 | self.SE_list.append((int(layer_index),int(expert_index))) 375 | print(f"Super Experts: {self.SE_list}") 376 | 377 | 378 | 379 | def _outliers_info_to_json(self, save_path, input_max_value, output_max_value, input_max_channel_index, output_max_channel_index): 380 | for layer_index in input_max_value: 381 | outliers_file_path = os.path.join(save_path, f"layer_{layer_index}.json") 382 | if not os.path.exists(outliers_file_path): 383 | with open(outliers_file_path, 'w') as outliers_file: 384 | json.dump([], outliers_file) 385 | for layer_type in input_max_value[layer_index]: 386 | input_max_value_stack = torch.stack(input_max_value[layer_index][layer_type]) 387 | input_max_value[layer_index][layer_type], input_max_index = torch.max(input_max_value_stack, dim=0) 388 | output_max_value_stack = torch.stack(output_max_value[layer_index][layer_type]) 389 | output_max_value[layer_index][layer_type], output_max_index = torch.max(output_max_value_stack, dim=0) 390 | layer_data = { 391 | 'layer_index':layer_index, 392 | 'layer_type':layer_type, 393 | 'input_max': input_max_value[layer_index][layer_type].item(), 394 | # 'input_max_channel': 
input_max_channel_index[layer_index][layer_type][input_max_index].item(), 395 | 'output_max': output_max_value[layer_index][layer_type].item(), 396 | 'output_max_channel': output_max_channel_index[layer_index][layer_type][output_max_index].item(), 397 | } 398 | with open(outliers_file_path, 'r+') as outliers_file: 399 | outliers_data = json.load(outliers_file) 400 | outliers_data.append(layer_data) 401 | outliers_file.seek(0) 402 | json.dump(outliers_data, outliers_file, indent=4) 403 | 404 | # sort experts 405 | for filename in os.listdir(save_path): 406 | if filename.endswith(".json"): 407 | file_path = os.path.join(save_path, filename) 408 | with open(file_path, 'r') as file: 409 | data = json.load(file) 410 | if "experts" in data[0]['layer_type']: 411 | router_experts = [] 412 | shared_experts = [] 413 | for entry in data: 414 | if 'mlp.experts' in entry['layer_type']: 415 | router_experts.append(entry) 416 | else: 417 | shared_experts.append(entry) 418 | experts_sorted = sorted(router_experts, key=lambda x: (int(x['layer_type'].split('.')[-2]), x['layer_type'])) 419 | experts_sorted.extend(shared_experts) 420 | save_file_path = os.path.join(save_path, filename) 421 | with open(save_file_path, 'w') as file: 422 | json.dump(experts_sorted, file, indent=4) 423 | 424 | def _massive_experts_info_to_json(self, save_path, outliers_info_path): 425 | # layer-wise datas 426 | layer_wise_save_path = self.layer_wise_save_path_of_massive 427 | if not os.path.exists(layer_wise_save_path): 428 | os.makedirs(layer_wise_save_path) 429 | for filename in os.listdir(outliers_info_path): 430 | if filename.endswith(".json"): 431 | file_path = os.path.join(outliers_info_path, filename) 432 | with open(file_path, 'r') as file: 433 | data = json.load(file) 434 | sorted_data = sorted(data, key=lambda x: x['output_max'], reverse=True) 435 | for idx, entry in enumerate(sorted_data, start=1): 436 | entry['rank'] = idx 437 | save_file_path = os.path.join(layer_wise_save_path, filename) 438 | with open(save_file_path, 'w') as file: 439 | json.dump(sorted_data, file, indent=4) 440 | 441 | # expert-wise datas 442 | expert_wise_save_path = self.expert_wise_save_path_of_massive 443 | if not os.path.exists(expert_wise_save_path): 444 | os.makedirs(expert_wise_save_path) 445 | experts_data = {} 446 | for filename in os.listdir(outliers_info_path): 447 | if filename.endswith(".json"): 448 | file_path = os.path.join(outliers_info_path, filename) 449 | with open(file_path, 'r') as file: 450 | data = json.load(file) 451 | for entry in data: 452 | if "experts" in entry['layer_type']: 453 | if 'shared_experts' in entry['layer_type']: 454 | expert_index = "shared_experts" 455 | else: 456 | expert_index = entry["layer_type"].split('.')[-2] 457 | if expert_index not in experts_data: 458 | experts_data[expert_index] = [] 459 | experts_data[expert_index].append(entry) 460 | 461 | for expert_index, expert_entries in experts_data.items(): 462 | output_filename = f"expert_{expert_index}.json" 463 | output_path = os.path.join(expert_wise_save_path, output_filename) 464 | expert_entries = sorted(expert_entries, key=lambda x: x['output_max'], reverse=True) 465 | for idx, entry in enumerate(expert_entries, start=1): 466 | entry['rank'] = idx 467 | with open(output_path, 'w') as output_file: 468 | json.dump(expert_entries, output_file, indent=4) 469 | 470 | def _super_experts_analysis(self, save_path, massive_experts_info_path, total_layers, expert_wise, include_layers=0.75): 471 | include_layers = round(total_layers * include_layers) 472 
| for filename in os.listdir(massive_experts_info_path): 473 | if filename.endswith(".json"): 474 | file_path = os.path.join(massive_experts_info_path, filename) 475 | with open(file_path, 'r') as file: 476 | data = json.load(file) 477 | new_data = [] 478 | for each in data: 479 | if int(each['layer_index']) < include_layers: 480 | new_data.append(each) 481 | if new_data == []: 482 | continue 483 | kurt_value = self._calculate_kurtosis(new_data) 484 | if not expert_wise: 485 | meta_data = { 486 | 'include_layer': include_layers, 487 | 'output_max_kurtosis': kurt_value, 488 | "layer_index": new_data[0]['layer_index'], 489 | 'max_output_max': new_data[0]['output_max'], 490 | 'kurtosis*max_output_max': kurt_value * new_data[0]['output_max'] 491 | } 492 | else: 493 | if 'shared_experts' in new_data[0]['layer_type']: 494 | expert_index = "shared_experts" 495 | else: 496 | expert_index = new_data[0]["layer_type"].split('.')[-2] 497 | meta_data = { 498 | 'include_layer': include_layers, 499 | 'output_max_kurtosis': kurt_value, 500 | "expert_index": expert_index, 501 | 'max_output_max': new_data[0]['output_max'], 502 | 'kurtosis*max_output_max': kurt_value * new_data[0]['output_max'] 503 | } 504 | 505 | updated_data = { 506 | 'meta_data': meta_data, 507 | 'data': new_data 508 | } 509 | 510 | with open(os.path.join(save_path, filename), 'w') as file: 511 | json.dump(updated_data, file, indent=4) 512 | 513 | def _identify_super_experts_std(self, save_path, std_multiplier=3): 514 | results = [] 515 | for filename in os.listdir(save_path): 516 | if filename.endswith(".json"): 517 | file_path = os.path.join(save_path, filename) 518 | with open(file_path, 'r') as file: 519 | data = json.load(file) 520 | expert_index = data["meta_data"].get("expert_index", None) 521 | for item in data.get("data", []): 522 | output_max = item.get("output_max", None) 523 | layer_index = item.get("layer_index", None) 524 | if output_max is not None: 525 | results.append({ 526 | "expert_index": expert_index, 527 | "layer_index": layer_index, 528 | "output_max": output_max 529 | }) 530 | 531 | output_max_values = [item['output_max'] for item in results] 532 | mean = np.mean(output_max_values) 533 | std_dev = np.std(output_max_values) 534 | threshold = mean + std_multiplier * std_dev 535 | 536 | Super_Experts = [] 537 | for item in results: 538 | if item['output_max'] > threshold: 539 | Super_Experts.append({ 540 | 'expert_index': item['expert_index'], 541 | 'layer_index': item['layer_index'], 542 | 'output_max': item['output_max'] 543 | }) 544 | Super_Experts.sort(key=lambda x: x['output_max'], reverse=True) 545 | for rank, item in enumerate(Super_Experts, start=1): 546 | item['rank'] = rank 547 | return Super_Experts 548 | 549 | def _identify_super_experts_quantile(self, save_path, quantile=99.5): 550 | results = [] 551 | for filename in os.listdir(save_path): 552 | if filename.endswith(".json"): 553 | file_path = os.path.join(save_path, filename) 554 | with open(file_path, 'r') as file: 555 | data = json.load(file) 556 | expert_index = data["meta_data"].get("expert_index", None) 557 | for item in data.get("data", []): 558 | output_max = item.get("output_max", None) 559 | layer_index = item.get("layer_index", None) 560 | if output_max is not None: 561 | results.append({ 562 | "expert_index": expert_index, 563 | "layer_index": layer_index, 564 | "output_max": output_max 565 | }) 566 | 567 | output_max_values = [item['output_max'] for item in results] 568 | percentile = np.percentile(output_max_values, quantile) 569 | 
threshold = percentile 570 | 571 | Super_Experts = [] 572 | for item in results: 573 | if item['output_max'] > threshold: 574 | Super_Experts.append({ 575 | 'expert_index': item['expert_index'], 576 | 'layer_index': item['layer_index'], 577 | 'output_max': item['output_max'] 578 | }) 579 | Super_Experts.sort(key=lambda x: x['output_max'], reverse=True) 580 | for rank, item in enumerate(Super_Experts, start=1): 581 | item['rank'] = rank 582 | return Super_Experts 583 | 584 | def _identify_super_experts_aver(self, save_path, times=50): 585 | results = [] 586 | for filename in os.listdir(save_path): 587 | if filename.endswith(".json"): 588 | file_path = os.path.join(save_path, filename) 589 | with open(file_path, 'r') as file: 590 | data = json.load(file) 591 | expert_index = data["meta_data"].get("expert_index", None) 592 | for item in data.get("data", []): 593 | output_max = item.get("output_max", None) 594 | layer_index = item.get("layer_index", None) 595 | if output_max is not None: 596 | results.append({ 597 | "expert_index": expert_index, 598 | "layer_index": layer_index, 599 | "output_max": output_max 600 | }) 601 | 602 | output_max_values = [item['output_max'] for item in results] 603 | average = np.mean(output_max_values) 604 | threshold = average * times 605 | 606 | Super_Experts = [] 607 | for item in results: 608 | if item['output_max'] > threshold: 609 | Super_Experts.append({ 610 | 'expert_index': item['expert_index'], 611 | 'layer_index': item['layer_index'], 612 | 'output_max': item['output_max'] 613 | }) 614 | Super_Experts.sort(key=lambda x: x['output_max'], reverse=True) 615 | for rank, item in enumerate(Super_Experts, start=1): 616 | item['rank'] = rank 617 | return Super_Experts 618 | 619 | def _identify_super_experts(self, save_path, quantile=99.5, times=10): 620 | results = [] 621 | for filename in os.listdir(save_path): 622 | if filename.endswith(".json"): 623 | file_path = os.path.join(save_path, filename) 624 | with open(file_path, 'r') as file: 625 | data = json.load(file) 626 | expert_index = data["meta_data"].get("expert_index", None) 627 | for item in data.get("data", []): 628 | output_max = item.get("output_max", None) 629 | layer_index = item.get("layer_index", None) 630 | if output_max is not None: 631 | results.append({ 632 | "expert_index": expert_index, 633 | "layer_index": layer_index, 634 | "output_max": output_max 635 | }) 636 | 637 | output_max_values = [item['output_max'] for item in results] 638 | percentile = np.percentile(output_max_values, quantile) 639 | 640 | Super_Experts = [] 641 | for item in results: 642 | if item['output_max'] > percentile and item['output_max'] > np.max(output_max_values) // times: 643 | Super_Experts.append({ 644 | 'expert_index': item['expert_index'], 645 | 'layer_index': item['layer_index'], 646 | 'output_max': item['output_max'] 647 | }) 648 | Super_Experts.sort(key=lambda x: x['output_max'], reverse=True) 649 | for rank, item in enumerate(Super_Experts, start=1): 650 | item['rank'] = rank 651 | return Super_Experts 652 | 653 | 654 | def evaluate_ppl(self, require_hook, prune_experts, prune_SE): 655 | self.prune_experts=prune_experts 656 | if prune_SE: 657 | self._generate_SE_list() 658 | dev = self.dev 659 | dtype = self.dtype 660 | nbatches, inps, input_ids, use_cache = self._layer_wise_evaluate(self.require_position_embeddings, require_hook) 661 | self.model.eval() 662 | 663 | # load norm and lm_head 664 | self.model.model.norm = self.model.model.norm.to_empty(device=dev) 665 | 
self.model.model.norm.weight.data.copy_(self.model_st['model.norm.weight'].to(device=dev, dtype=dtype)) 666 | self.model.lm_head = self.model.lm_head.to_empty(device=dev) 667 | self.model.lm_head.weight.data.copy_(self.model_st['lm_head.weight'].to(device=dev, dtype=dtype)) 668 | nlls = [] 669 | loss_fct = torch.nn.CrossEntropyLoss(reduction = "none") 670 | with torch.no_grad(): 671 | for i in range(nbatches): 672 | hidden_states = inps[i] 673 | hidden_states = self.model.model.norm(hidden_states) 674 | lm_logits = self.model.lm_head(hidden_states) 675 | shift_logits = lm_logits[:, :-1, :] 676 | shift_labels = input_ids[i][:, 1:] 677 | loss = loss_fct(shift_logits.permute(0, 2, 1), shift_labels) 678 | neg_log_likelihood = loss.float().mean(dim=1) 679 | nlls.append(neg_log_likelihood) 680 | nlls_tensor = torch.cat(nlls) 681 | ppl = torch.exp(nlls_tensor.mean()) 682 | self.model.config.use_cache = use_cache 683 | 684 | # release gpu memory 685 | self.model.model.norm.cpu() 686 | self.model.lm_head.cpu() 687 | del self.model.model.norm 688 | del self.model.lm_head 689 | torch.cuda.empty_cache() 690 | 691 | return ppl.item() 692 | 693 | def outliers_profiler(self, vis_outliers_heatmap, require_hook=True): 694 | os.makedirs(self.outliers_info_path, exist_ok=True) 695 | input_max_value, output_max_value, input_max_channel_index, output_max_channel_index = self._layer_wise_evaluate(self.require_position_embeddings, require_hook) 696 | self._outliers_info_to_json(self.outliers_info_path, input_max_value, output_max_value, input_max_channel_index, output_max_channel_index) 697 | if vis_outliers_heatmap: 698 | self._vis_outliers_heatmap(self.outliers_info_path, self.num_dense_layers, self.total_layers) 699 | return self.outliers_info_path 700 | 701 | def massive_experts_profiler(self, vis_massive_experts_line_plot): 702 | os.makedirs(self.massive_experts_info_path, exist_ok=True) 703 | os.makedirs(self.expert_wise_save_path_of_massive, exist_ok=True) 704 | os.makedirs(self.layer_wise_save_path_of_massive, exist_ok=True) 705 | self._massive_experts_info_to_json(self.massive_experts_info_path, self.outliers_info_path) 706 | if vis_massive_experts_line_plot: 707 | self._vis_massive_experts_line_plot(self.expert_wise_save_path_of_massive, self.num_dense_layers, expert_wise=True) 708 | self._vis_massive_experts_line_plot(self.layer_wise_save_path_of_massive, self.num_dense_layers, expert_wise=False) 709 | return self.massive_experts_info_path 710 | 711 | def super_experts_profiler(self, include_layers, vis_super_experts_line_plot): 712 | os.makedirs(self.super_experts_info_path, exist_ok=True) 713 | os.makedirs(self.expert_wise_save_path_of_super, exist_ok=True) 714 | os.makedirs(self.layer_wise_save_path_of_super, exist_ok=True) 715 | os.makedirs(self.super_experts_report_path, exist_ok=True) 716 | self._super_experts_analysis(self.expert_wise_save_path_of_super, self.expert_wise_save_path_of_massive, self.total_layers, expert_wise=True, include_layers=include_layers) 717 | self._super_experts_analysis(self.layer_wise_save_path_of_super, self.layer_wise_save_path_of_massive, self.total_layers, expert_wise=False, include_layers=include_layers) 718 | 719 | Super_Experts = self._identify_super_experts(self.expert_wise_save_path_of_super) 720 | with open(os.path.join(self.super_experts_report_path, 'Super Experts Report.json'), 'w') as f: 721 | json.dump(Super_Experts, f, indent=4) 722 | 723 | if vis_super_experts_line_plot: 724 | # for object in ['output_max_kurtosis', 'max_output_max', 
'kurtosis*max_output_max']: 725 | for object in ['max_output_max']: 726 | self._plot_meta_from_json(self.expert_wise_save_path_of_super, object, expert_wise=True) 727 | self._plot_meta_from_json(self.layer_wise_save_path_of_super, object, expert_wise=False) 728 | 729 | 730 | return self.super_experts_info_path 731 | 732 | 733 | class DeepSeekEvaluator(BaseEvaluator): 734 | def __init__(self, model, tokenizer, model_st, dev, args): 735 | super().__init__(model, tokenizer, model_st, dev, args) 736 | self.require_position_embeddings = False 737 | 738 | def _load_layer_weight(self, layer_idx): 739 | if self.prune_experts is not None: 740 | prune_list_router = [f"model.layers.{layer}.mlp.experts.{expert}" for layer, expert in self.prune_experts if int(expert) != -1] 741 | else: 742 | prune_list_router = [] 743 | if self.SE_list: 744 | for layer, expert in self.SE_list: 745 | if int(expert) != -1: 746 | prune_list_router.append(f"model.layers.{layer}.mlp.experts.{expert}") 747 | layer_key = f"model.layers.{layer_idx}" 748 | layer = self._get_module(self.model, layer_key) 749 | dev = self.dev 750 | dtype = self.dtype 751 | 752 | # initialize meta tensor of attention 753 | ## layernorm 754 | W = layer.input_layernorm.to_empty(device=dev).to(dtype=dtype) 755 | W.weight.data.copy_(self.model_st[layer_key + '.input_layernorm.weight'].to(device=dev, dtype=dtype)) 756 | W = layer.post_attention_layernorm.to_empty(device=dev).to(dtype=dtype) 757 | W.weight.data.copy_(self.model_st[layer_key + '.post_attention_layernorm.weight'].to(device=dev, dtype=dtype)) 758 | W = layer.self_attn.kv_a_layernorm.to_empty(device=dev).to(dtype=dtype) 759 | W.weight.data.copy_(self.model_st[layer_key + '.self_attn.kv_a_layernorm.weight'].to(device=dev, dtype=dtype)) 760 | if hasattr(layer.self_attn, 'q_a_layernorm'): 761 | W = layer.self_attn.q_a_layernorm.to_empty(device=dev).to(dtype=dtype) 762 | W.weight.data.copy_(self.model_st[layer_key + '.self_attn.q_a_layernorm.weight'].to(device=dev, dtype=dtype)) 763 | 764 | ## mla 765 | if hasattr(layer.self_attn, 'q_b_proj'): 766 | W = layer.self_attn.q_a_proj.to_empty(device=dev).to(dtype=dtype) 767 | W.weight.data.copy_(self.model_st[layer_key + '.self_attn.q_a_proj.weight'].to(device=dev, dtype=dtype)) 768 | W = layer.self_attn.q_b_proj.to_empty(device=dev).to(dtype=dtype) 769 | W.weight.data.copy_(self.model_st[layer_key + '.self_attn.q_b_proj.weight'].to(device=dev, dtype=dtype)) 770 | else: 771 | W = layer.self_attn.q_proj.to_empty(device=dev).to(dtype=dtype) 772 | W.weight.data.copy_(self.model_st[layer_key + '.self_attn.q_proj.weight'].to(device=dev, dtype=dtype)) 773 | W = layer.self_attn.kv_a_proj_with_mqa.to_empty(device=dev).to(dtype=dtype) 774 | W.weight.data.copy_(self.model_st[layer_key + '.self_attn.kv_a_proj_with_mqa.weight'].to(device=dev, dtype=dtype)) 775 | W = layer.self_attn.kv_b_proj.to_empty(device=dev).to(dtype=dtype) 776 | W.weight.data.copy_(self.model_st[layer_key + '.self_attn.kv_b_proj.weight'].to(device=dev, dtype=dtype)) 777 | W = layer.self_attn.o_proj.to_empty(device=dev).to(dtype=dtype) 778 | W.weight.data.copy_(self.model_st[layer_key + '.self_attn.o_proj.weight'].to(device=dev, dtype=dtype)) 779 | 780 | # initialize meta tensor of mlp 781 | if hasattr(layer.mlp, 'experts'): 782 | expert_num = len(layer.mlp.experts) 783 | ## experts 784 | for expert_idx in range(expert_num): 785 | expert_key = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}" 786 | expert = self._get_module(self.model, expert_key) 787 | W = 
expert.up_proj.to_empty(device=dev).to(dtype=dtype) 788 | W.weight.data.copy_(self.model_st[expert_key + '.up_proj.weight'].to(device=dev, dtype=dtype)) 789 | W = expert.gate_proj.to_empty(device=dev).to(dtype=dtype) 790 | W.weight.data.copy_(self.model_st[expert_key + '.gate_proj.weight'].to(device=dev, dtype=dtype)) 791 | W = expert.down_proj.to_empty(device=dev).to(dtype=dtype) 792 | 793 | if expert_key in prune_list_router: 794 | print(f"prune {expert_key}") 795 | W.weight.data.copy_(torch.zeros_like(self.model_st[expert_key + '.down_proj.weight'].to(device=dev, dtype=dtype))) 796 | else: 797 | W.weight.data.copy_(self.model_st[expert_key + '.down_proj.weight'].to(device=dev, dtype=dtype)) 798 | 799 | ## router 800 | router_key = f"model.layers.{layer_idx}.mlp.gate" 801 | bias_key = f"model.layers.{layer_idx}.mlp.gate.e_score_correction_bias" 802 | W = layer.mlp.gate.to_empty(device=dev).to(dtype=torch.float32) 803 | W.weight.data.copy_(self.model_st[router_key + '.weight'].to(device=dev, dtype=torch.float32)) 804 | if hasattr(layer.mlp.gate, "e_score_correction_bias"): 805 | W.e_score_correction_bias.data.copy_(self.model_st[bias_key ].to(device=dev, dtype=torch.float32)) 806 | W = layer.mlp.shared_experts.up_proj.to_empty(device=dev).to(dtype=dtype) 807 | W.weight.data.copy_(self.model_st[layer_key + '.mlp.shared_experts.up_proj.weight'].to(device=dev, dtype=dtype)) 808 | W = layer.mlp.shared_experts.gate_proj.to_empty(device=dev).to(dtype=dtype) 809 | W.weight.data.copy_(self.model_st[layer_key + '.mlp.shared_experts.gate_proj.weight'].to(device=dev, dtype=dtype)) 810 | W = layer.mlp.shared_experts.down_proj.to_empty(device=dev).to(dtype=dtype) 811 | if self.prune_experts is not None: 812 | prune_list_shared = [int(layer) for layer, expert in self.prune_experts if int(expert) == -1] 813 | else: 814 | prune_list_shared = [] 815 | for LAYER, EXPERT in self.SE_list: 816 | if int(EXPERT) == -1: 817 | prune_list_shared.append(int(LAYER)) 818 | if layer_idx in prune_list_shared: 819 | print(f"prune model.layers.{layer_idx}.mlp.shared_experts") 820 | W.weight.data.copy_(torch.zeros_like(self.model_st[layer_key + '.mlp.shared_experts.down_proj.weight'].to(device=dev, dtype=dtype))) 821 | else: 822 | W.weight.data.copy_(self.model_st[layer_key + '.mlp.shared_experts.down_proj.weight'].to(device=dev, dtype=dtype)) 823 | 824 | else: 825 | W = layer.mlp.up_proj.to_empty(device=dev).to(dtype=dtype) 826 | W.weight.data.copy_(self.model_st[layer_key + '.mlp.up_proj.weight'].to(device=dev, dtype=dtype)) 827 | W = layer.mlp.gate_proj.to_empty(device=dev).to(dtype=dtype) 828 | W.weight.data.copy_(self.model_st[layer_key + '.mlp.gate_proj.weight'].to(device=dev, dtype=dtype)) 829 | W = layer.mlp.down_proj.to_empty(device=dev).to(dtype=dtype) 830 | W.weight.data.copy_(self.model_st[layer_key + '.mlp.down_proj.weight'].to(device=dev, dtype=dtype)) 831 | 832 | return layer 833 | 834 | 835 | class Qwen3MoeEvaluator(BaseEvaluator): 836 | def __init__(self, model, tokenizer, model_st, dev, args): 837 | super().__init__(model, tokenizer, model_st, dev, args) 838 | self.require_position_embeddings = True 839 | 840 | def _load_layer_weight(self, layer_idx): 841 | if self.prune_experts is not None: 842 | prune_list_router = [f"model.layers.{layer}.mlp.experts.{expert}" for layer, expert in self.prune_experts if int(expert) != -1] 843 | else: 844 | prune_list_router = [] 845 | if self.SE_list: 846 | for layer, expert in self.SE_list: 847 | if int(expert) != -1: 848 | 
        # print(f"prune_list_router={prune_list_router}")

        layer_key = f"model.layers.{layer_idx}"
        layer = self._get_module(self.model, layer_key)
        dev = self.dev
        dtype = self.dtype

        # initialize meta tensor of attention
        ## layernorm
        W = layer.input_layernorm.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.input_layernorm.weight'].to(device=dev, dtype=dtype))
        W = layer.post_attention_layernorm.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.post_attention_layernorm.weight'].to(device=dev, dtype=dtype))
        W = layer.self_attn.k_norm.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.k_norm.weight'].to(device=dev, dtype=dtype))
        W = layer.self_attn.q_norm.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.q_norm.weight'].to(device=dev, dtype=dtype))

        ## mha
        W = layer.self_attn.q_proj.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.q_proj.weight'].to(device=dev, dtype=dtype))
        W = layer.self_attn.k_proj.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.k_proj.weight'].to(device=dev, dtype=dtype))
        W = layer.self_attn.v_proj.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.v_proj.weight'].to(device=dev, dtype=dtype))
        W = layer.self_attn.o_proj.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.o_proj.weight'].to(device=dev, dtype=dtype))
        ## rotary

        # initialize meta tensor of mlp
        if hasattr(layer.mlp, 'experts'):
            expert_num = len(layer.mlp.experts)
            ## experts
            for expert_idx in range(expert_num):
                expert_key = f"model.layers.{layer_idx}.mlp.experts.{expert_idx}"
                expert = self._get_module(self.model, expert_key)
                W = expert.up_proj.to_empty(device=dev).to(dtype=dtype)
                W.weight.data.copy_(self.model_st[expert_key + '.up_proj.weight'].to(device=dev, dtype=dtype))
                W = expert.gate_proj.to_empty(device=dev).to(dtype=dtype)
                W.weight.data.copy_(self.model_st[expert_key + '.gate_proj.weight'].to(device=dev, dtype=dtype))
                W = expert.down_proj.to_empty(device=dev).to(dtype=dtype)
                if expert_key in prune_list_router:
                    print(f"prune {expert_key}")
                    W.weight.data.copy_(torch.zeros_like(self.model_st[expert_key + '.down_proj.weight'].to(device=dev, dtype=dtype)))
                else:
                    W.weight.data.copy_(self.model_st[expert_key + '.down_proj.weight'].to(device=dev, dtype=dtype))
            ## router
            router_key = f"model.layers.{layer_idx}.mlp.gate"

            W = layer.mlp.gate.to_empty(device=dev).to(dtype=dtype)
            W.weight.data.copy_(self.model_st[router_key + '.weight'].to(device=dev, dtype=dtype))
        else:
            W = layer.mlp.up_proj.to_empty(device=dev).to(dtype=dtype)
            W.weight.data.copy_(self.model_st[layer_key + '.mlp.up_proj.weight'].to(device=dev, dtype=dtype))
            W = layer.mlp.gate_proj.to_empty(device=dev).to(dtype=dtype)
            W.weight.data.copy_(self.model_st[layer_key + '.mlp.gate_proj.weight'].to(device=dev, dtype=dtype))
            W = layer.mlp.down_proj.to_empty(device=dev).to(dtype=dtype)
            W.weight.data.copy_(self.model_st[layer_key + '.mlp.down_proj.weight'].to(device=dev, dtype=dtype))

        return layer
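# ---------------------------------------------------------------------------
# NOTE: illustrative sketch, not part of the original file. The evaluators in
# this module expect `self.prune_experts` to be an iterable of (layer, expert)
# pairs, with expert == -1 meaning "prune this layer's shared experts". The
# shell scripts pass a spec such as "3,54;4,38;2,3;5,63"; a parser along these
# lines would yield that structure (the actual parsing lives elsewhere in the
# repository, e.g. in run.py, and may differ in detail).
def _parse_prune_experts_spec(spec):
    """Turn a "layer,expert;layer,expert" spec into a list of (layer, expert) tuples."""
    pairs = []
    for item in spec.split(";"):
        item = item.strip()
        if not item:
            continue
        layer, expert = item.split(",")
        pairs.append((int(layer), int(expert)))
    return pairs
# Example: _parse_prune_experts_spec("3,54;4,38") -> [(3, 54), (4, 38)]
# ---------------------------------------------------------------------------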


class MixtralEvaluator(BaseEvaluator):
    def __init__(self, model, tokenizer, model_st, dev, args):
        super().__init__(model, tokenizer, model_st, dev, args)
        self.require_position_embeddings = True

    def _load_layer_weight(self, layer_idx):
        if self.prune_experts is not None:
            prune_list_router = [f"model.layers.{layer}.block_sparse_moe.experts.{expert}" for layer, expert in self.prune_experts if int(expert) != -1]
        else:
            prune_list_router = []
        if self.SE_list:
            for layer, expert in self.SE_list:
                if int(expert) != -1:
                    prune_list_router.append(f"model.layers.{layer}.block_sparse_moe.experts.{expert}")
        # print(f"prune_list_router={prune_list_router}")
        layer_key = f"model.layers.{layer_idx}"
        layer = self._get_module(self.model, layer_key)
        dev = self.dev
        dtype = self.dtype

        # initialize meta tensor of attention
        ## layernorm
        W = layer.input_layernorm.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.input_layernorm.weight'].to(device=dev, dtype=dtype))
        W = layer.post_attention_layernorm.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.post_attention_layernorm.weight'].to(device=dev, dtype=dtype))

        ## mha
        W = layer.self_attn.q_proj.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.q_proj.weight'].to(device=dev, dtype=dtype))
        W = layer.self_attn.k_proj.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.k_proj.weight'].to(device=dev, dtype=dtype))
        W = layer.self_attn.v_proj.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.v_proj.weight'].to(device=dev, dtype=dtype))
        W = layer.self_attn.o_proj.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[layer_key + '.self_attn.o_proj.weight'].to(device=dev, dtype=dtype))
        ## rotary
        # W = layer.self_attn.rotary_emb.to_empty(device=dev).to(device=dev, dtype=torch.float32)
        # W.inv_freq.data.copy_(model_st[layer_key + '.self_attn.rotary_emb.inv_freq'].to(device=dev, dtype=torch.float32))

        # initialize meta tensor of mlp
        expert_num = len(layer.block_sparse_moe.experts)
        ## experts
        for expert_idx in range(expert_num):
            expert_key = f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}"
            expert = self._get_module(self.model, expert_key)
            W = expert.w1.to_empty(device=dev).to(dtype=dtype)
            W.weight.data.copy_(self.model_st[expert_key + '.w1.weight'].to(device=dev, dtype=dtype))
            W = expert.w2.to_empty(device=dev).to(dtype=dtype)
            W.weight.data.copy_(self.model_st[expert_key + '.w2.weight'].to(device=dev, dtype=dtype))
            W = expert.w3.to_empty(device=dev).to(dtype=dtype)

            if expert_key in prune_list_router:
                print(f"prune {expert_key}")
                # Zeroing w3 (the up projection) makes the expert output w2(act(w1(x)) * w3(x)) vanish, which prunes the expert.
                W.weight.data.copy_(torch.zeros_like(self.model_st[expert_key + '.w3.weight'].to(device=dev, dtype=dtype)))
            else:
                W.weight.data.copy_(self.model_st[expert_key + '.w3.weight'].to(device=dev, dtype=dtype))
        ## router
        router_key = f"model.layers.{layer_idx}.block_sparse_moe.gate"
        W = layer.block_sparse_moe.gate.to_empty(device=dev).to(dtype=dtype)
        W.weight.data.copy_(self.model_st[router_key + '.weight'].to(device=dev, dtype=dtype))

        return layer
--------------------------------------------------------------------------------
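All three evaluators above rebuild one decoder layer at a time from meta tensors: each submodule is materialized on the target device with `to_empty()` and then filled from the safetensors state dict `self.model_st`, or filled with zeros to prune an expert. Below is a minimal, self-contained sketch of that pattern; the module name `proj` and the key `"proj.weight"` are placeholders for illustration, not names from this repository.

```python
import torch
import torch.nn as nn

# Build a module on the meta device: it has parameter shapes but no real storage yet.
with torch.device("meta"):
    proj = nn.Linear(8, 8, bias=False)

# Stand-in for self.model_st: a flat dict of tensors keyed by parameter name.
model_st = {"proj.weight": torch.randn(8, 8)}

dev, dtype = "cpu", torch.float32  # the evaluators use a CUDA device and the model dtype

# to_empty() allocates uninitialized storage on the target device ...
W = proj.to_empty(device=dev).to(dtype=dtype)
# ... and copy_() fills it from the state dict, as in _load_layer_weight.
W.weight.data.copy_(model_st["proj.weight"].to(device=dev, dtype=dtype))

# Pruning an expert uses the same copy, but with zeros, so its output vanishes.
W.weight.data.copy_(torch.zeros_like(model_st["proj.weight"].to(device=dev, dtype=dtype)))
```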