├── img
│   ├── intro.png
│   ├── ruler.png
│   ├── ntk_extra.png
│   └── direct_extra.png
├── ppl
│   ├── diffusion_niah.png
│   ├── diffusion_ppl_gov.png
│   ├── get_ppl_llama.py
│   ├── get_ppl_llada.py
│   └── get_ppl_plot.py
├── needlebench
│   ├── needlebench
│   │   ├── needlebench.py
│   │   ├── needlebench_multi_retrieval.py
│   │   └── needlebench_single.py
│   ├── origin.py
│   ├── needlebench.py
│   └── needlebench_summarizer.py
├── eval
│   ├── eval_llada_ruler.py
│   ├── eval_llada_niah.py
│   └── eval_llada_long.py
├── llada
│   ├── llada_generate.py
│   └── llada_wrapper.py
└── README.md

--------------------------------------------------------------------------------
/img/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMOSS/LongLLaDA/HEAD/img/intro.png
--------------------------------------------------------------------------------
/img/ruler.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMOSS/LongLLaDA/HEAD/img/ruler.png
--------------------------------------------------------------------------------
/img/ntk_extra.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMOSS/LongLLaDA/HEAD/img/ntk_extra.png
--------------------------------------------------------------------------------
/img/direct_extra.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMOSS/LongLLaDA/HEAD/img/direct_extra.png
--------------------------------------------------------------------------------
/ppl/diffusion_niah.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMOSS/LongLLaDA/HEAD/ppl/diffusion_niah.png
--------------------------------------------------------------------------------
/ppl/diffusion_ppl_gov.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenMOSS/LongLLaDA/HEAD/ppl/diffusion_ppl_gov.png
--------------------------------------------------------------------------------
/needlebench/needlebench/needlebench.py:
--------------------------------------------------------------------------------
from mmengine.config import read_base

with read_base():

    from .needlebench_single import needlebench_en_datasets as needlebench_origin_en_datasets
    from .needlebench_single import needlebench_zh_datasets as needlebench_origin_zh_datasets
    from .needlebench_multi_retrieval import needlebench_en_datasets as needlebench_parallel_en_datasets
    from .needlebench_multi_retrieval import needlebench_zh_datasets as needlebench_parallel_zh_datasets

needlebench_datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
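# A small illustration of the aggregation above, with hypothetical names:
# if a_datasets = [d1] and b_datasets = [d2, d3] are in scope, then
# sum((v for k, v in locals().items() if k.endswith('_datasets')), [])
# evaluates to [d1, d2, d3], i.e. one flat list of every *_datasets list.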
--------------------------------------------------------------------------------
/ppl/get_ppl_llama.py:
--------------------------------------------------------------------------------
import os

import torch
import torch.nn.functional as F

from transformers import AutoModelForCausalLM, AutoTokenizer


def main(path):

    device = 'cuda:0'

    print(path)

    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    file = '/'.join(os.path.realpath(__file__).split('/')[:-1])
    file = f'{file}/gov_report_001.txt'

    text = open(file, mode='r').readline()

    input_ids = tokenizer(text)['input_ids']

    prompt_len = 16385 if len(input_ids) > 16384 else ((len(input_ids) // 64) * 64 + 1)

    prompt = torch.tensor([input_ids[:prompt_len]]).to(device)

    with torch.no_grad():
        outputs = model(prompt, labels=prompt)
        loss = F.cross_entropy(outputs.logits[0, :-1], prompt[0, 1:], reduction='none')
        # running average of token-level losses -> cumulative perplexity curve
        loss = torch.cumsum(loss, dim=-1) / (torch.arange(prompt_len - 1).to(loss) + 1)
        perplexity = torch.exp(loss)

    print(f"Perplexity: {perplexity.float().detach().cpu().numpy().tolist()[64::64]}", flush=True)


if __name__ == '__main__':

    path = 'meta-llama/Meta-Llama-3-8B-Instruct'

    main(path)

    path = 'meta-llama/Meta-Llama-3-8B'

    main(path)

--------------------------------------------------------------------------------
/needlebench/needlebench/needlebench_multi_retrieval.py:
--------------------------------------------------------------------------------
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.needlebench.parallel import NeedleBenchParallelDataset
from opencompass.datasets.needlebench.parallel import NeedleBenchParallelEvaluator
from opencompass.datasets.needlebench.origin import needlebench_postprocess
from opencompass.datasets.needlebench.origin import needlebench_dataset_postprocess
import math


def logistic(x, L=100, x0=50, k=0.1):
    return round(L / (1 + math.exp(-k * (x - x0))), 3)


def generate_linear_space(start, end, num):
    if num == 1:
        return [start]
    elif num < 1:
        raise ValueError('num must be at least 1.')
    step = (end - start) / (num - 1)
    return [start + step * i for i in range(num)]


def generate_depth_percents(intervals, interval_type):
    if interval_type == 'linear':
        return generate_linear_space(0, 100, intervals)
    elif interval_type == 'sigmoid':
        linear_space = generate_linear_space(0, 100, intervals)
        return [logistic(x) for x in linear_space]
    else:
        raise ValueError('Unsupported interval type')
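# For example, generate_depth_percents(5, 'linear') returns
# [0.0, 25.0, 50.0, 75.0, 100.0]; the 'sigmoid' variant passes the same grid
# through the logistic curve above, so sampled depths cluster near both ends.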

needlebench_reader_cfg = dict(input_columns=['prompt'], output_column='answer')

needlebench_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(role='HUMAN', prompt='{prompt}'),
                # dict(role='BOT', prompt='{answer}\n'),
            ]
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

needlebench_eval_cfg = dict(
    evaluator=dict(type=NeedleBenchParallelEvaluator),
    pred_postprocessor=dict(type=needlebench_postprocess),
    dataset_postprocessor=dict(type=needlebench_dataset_postprocess),
    pred_role='BOT',
)

context_lengths = list([2000, 4000, 8000, 16000, 24000, 32000])  # , 64000, 128000
document_depth_percent_intervals = 25
document_depth_percent_interval_type = 'linear'

base_path = 'opencompass/needlebench'
file_list = ['en_un_asr.jsonl']  # PaulGrahamEssays
needlebench_en_datasets = []
needle_file_name = 'needles.jsonl'
depths = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

for original_context_length in context_lengths:
    dataset_dict = {
        'abbr': f'Length{original_context_length}_parallel_en',
        'type': NeedleBenchParallelDataset,
        'path': base_path,
        'needle_file_name': needle_file_name,
        'length': original_context_length,
        'depths': depths,
        'tokenizer_model': 'gpt-4',
        'file_list': file_list,
        'num_repeats_per_file': 25,
        'length_buffer': 3000,
        'guide': True,
        'language': 'English',
        'reader_cfg': needlebench_reader_cfg,
        'infer_cfg': needlebench_infer_cfg,
        'eval_cfg': needlebench_eval_cfg,
    }
    needlebench_en_datasets.append(dataset_dict)

file_list = ['zh_all.jsonl']  # zh_finance
needlebench_zh_datasets = []

for original_context_length in context_lengths:
    dataset_dict = {
        'abbr': f'Length{original_context_length}_parallel_zh',
        'type': NeedleBenchParallelDataset,
        'path': base_path,
        'needle_file_name': needle_file_name,
        'length': original_context_length,
        'depths': depths,
        'tokenizer_model': 'gpt-4',
        'file_list': file_list,
        'num_repeats_per_file': 25,
        'length_buffer': 200,
        'guide': True,
        'language': 'Chinese',
        'reader_cfg': needlebench_reader_cfg,
        'infer_cfg': needlebench_infer_cfg,
        'eval_cfg': needlebench_eval_cfg,
    }
    needlebench_zh_datasets.append(dataset_dict)
--------------------------------------------------------------------------------
/needlebench/needlebench/needlebench_single.py:
--------------------------------------------------------------------------------
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets.needlebench.origin import NeedleBenchOriginDataset
from opencompass.datasets.needlebench.origin import NeedleBenchOriginEvaluator
from opencompass.datasets.needlebench.origin import needlebench_postprocess
from opencompass.datasets.needlebench.origin import needlebench_dataset_postprocess
import math


def logistic(x, L=100, x0=50, k=0.1):
    return round(L / (1 + math.exp(-k * (x - x0))), 3)


def generate_linear_space(start, end, num):
    if num == 1:
        return [start]
    elif num < 1:
        raise ValueError('num must be at least 1.')
    step = (end - start) / (num - 1)
    return [start + step * i for i in range(num)]


def generate_depth_percents(intervals, interval_type):
    if interval_type == 'linear':
        return generate_linear_space(0, 100, intervals)
    elif interval_type == 'sigmoid':
        linear_space = generate_linear_space(0, 100, intervals)
        return [logistic(x) for x in linear_space]
    else:
        raise ValueError('Unsupported interval type')


needlebench_reader_cfg = dict(input_columns=['prompt'], output_column='answer')

needlebench_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(role='HUMAN', prompt='{prompt}'),
                # dict(role='BOT', prompt='{answer}\n'),
            ]
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

needlebench_eval_cfg = dict(
    evaluator=dict(type=NeedleBenchOriginEvaluator),
    pred_postprocessor=dict(type=needlebench_postprocess),
    dataset_postprocessor=dict(type=needlebench_dataset_postprocess),
    pred_role='BOT',
)

context_lengths = list([2000, 4000, 8000, 16000, 24000, 32000, ])  # 64000, 128000

base_path = 'opencompass/needlebench'
file_list = ['en_un_asr.jsonl']
needlebench_en_datasets = []
needle_file_name = 'needles.jsonl'
depths_list = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

for original_context_length in context_lengths:
    for depth_percent in depths_list:
        dataset_dict = {
            'abbr': f'Length{original_context_length}Depth{int(depth_percent)}_origin_en',
            'type': NeedleBenchOriginDataset,
            'path': base_path,
            'length': original_context_length,
            'depth': int(depth_percent),
            'tokenizer_model': 'gpt-4',
            'file_list': file_list,
            'num_repeats_per_file': 10,
            'length_buffer': 600,
            'guide': True,
            'language': 'English',
            'needle_file_name': needle_file_name,
            'reader_cfg': needlebench_reader_cfg,
            'infer_cfg': needlebench_infer_cfg,
            'eval_cfg': needlebench_eval_cfg,
        }
        needlebench_en_datasets.append(dataset_dict)

file_list = ['zh_all.jsonl']  # zh_finance
needlebench_zh_datasets = []
needle_file_name = 'needles.jsonl'

for original_context_length in context_lengths:
    for depth_percent in depths_list:
        dataset_dict = {
            'abbr': f'Length{original_context_length}Depth{int(depth_percent)}_origin_zh',
            'type': NeedleBenchOriginDataset,
            'path': base_path,
            'length': original_context_length,
            'depth': int(depth_percent),
            'tokenizer_model': 'gpt-4',
            'file_list': file_list,
            'num_repeats_per_file': 10,
            'length_buffer': 200,
            'guide': True,
            'language': 'Chinese',
            'needle_file_name': needle_file_name,
            'reader_cfg': needlebench_reader_cfg,
            'infer_cfg': needlebench_infer_cfg,
            'eval_cfg': needlebench_eval_cfg,
        }
        needlebench_zh_datasets.append(dataset_dict)

--------------------------------------------------------------------------------
/eval/eval_llada_ruler.py:
--------------------------------------------------------------------------------
from mmengine.config import read_base
from opencompass.runners import LocalRunner
from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
from opencompass.models import LLaDACausalLM

with read_base():
    from opencompass.configs.datasets.ruler.ruler_4k_gen import ruler_datasets as ruler_datasets_4k
    from opencompass.configs.datasets.ruler.ruler_8k_gen import ruler_datasets as ruler_datasets_8k
    from opencompass.configs.datasets.ruler.ruler_16k_gen import ruler_datasets as ruler_datasets_16k

datasets = []
datasets += ruler_datasets_4k
datasets += ruler_datasets_8k
datasets += ruler_datasets_16k

num_gpus = {
    'llama_3_8b_base': 1, 'llama_3_8b_chat': 1,

    'llada_8b_base': 1, 'llada_8b_chat': 1, 'llada_1_5_8b': 1,

    'dream_v0_7b_base': 1, 'dream_v0_7b_chat': 1,
}

path_dict = {
    'llama_3_8b_base': 'meta-llama/Meta-Llama-3-8B',
    'llama_3_8b_chat': 'meta-llama/Meta-Llama-3-8B-Instruct',

    'llada_8b_base': 'GSAI-ML/LLaDA-8B-Base',
    'llada_8b_chat': 'GSAI-ML/LLaDA-8B-Instruct',

    'llada_1_5_8b': 'GSAI-ML/LLaDA-1.5',

    'dream_v0_7b_base': 'Dream-org/Dream-v0-Base-7B',
    'dream_v0_7b_chat': 'Dream-org/Dream-v0-Instruct-7B',
}
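# Each tuple below is (abbr, scaling_config, diffusion_config, max_out_len);
# a list comprehension after the list expands every tuple into a full
# OpenCompass model dict. The -ntkN suffix in abbr mirrors the NTK scaling factor.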
models = [

    ## llama series

    ('llama_3_8b_base-o64', {}, {}, 64),
    ('llama_3_8b_base-o64-ntk4', {'scaling_factor': 4}, {}, 64),
    ('llama_3_8b_base-o64-ntk13', {'scaling_factor': 13}, {}, 64),

    ('llama_3_8b_chat-o64', {}, {}, 64),
    ('llama_3_8b_chat-o64-ntk4', {'scaling_factor': 4}, {}, 64),
    ('llama_3_8b_chat-o64-ntk13', {'scaling_factor': 13}, {}, 64),

    ## llada series

    ('llada_8b_base-o64_b64_s64', {}, {'steps': 64, 'block_length': 64, }, 64),
    ('llada_8b_base-o64_b64_s64-ntk4', {'scaling_factor': 4}, {'steps': 64, 'block_length': 64, }, 64),
    ('llada_8b_base-o64_b64_s64-ntk14', {'scaling_factor': 14}, {'steps': 64, 'block_length': 64, }, 64),
    ('llada_8b_base-o64_b64_s64-ntk31', {'scaling_factor': 31}, {'steps': 64, 'block_length': 64, }, 64),

    ('llada_8b_chat-o64_b64_s64', {}, {'steps': 64, 'block_length': 64, }, 64),
    ('llada_8b_chat-o64_b64_s64-ntk4', {'scaling_factor': 4}, {'steps': 64, 'block_length': 64, }, 64),
    ('llada_8b_chat-o64_b64_s64-ntk14', {'scaling_factor': 14}, {'steps': 64, 'block_length': 64, }, 64),
    ('llada_8b_chat-o64_b64_s64-ntk31', {'scaling_factor': 31}, {'steps': 64, 'block_length': 64, }, 64),

    ('llada_1_5_8b-o64_b64_s64', {}, {'steps': 64, 'block_length': 64, }, 64),
    ('llada_1_5_8b-o64_b64_s64-ntk4', {'scaling_factor': 4}, {'steps': 64, 'block_length': 64, }, 64),
    ('llada_1_5_8b-o64_b64_s64-ntk14', {'scaling_factor': 14}, {'steps': 64, 'block_length': 64, }, 64),
    ('llada_1_5_8b-o64_b64_s64-ntk31', {'scaling_factor': 31}, {'steps': 64, 'block_length': 64, }, 64),

    ## dream series

    ('dream_v0_7b_base-o64_s64', {}, {'steps': 64, }, 64),
    ('dream_v0_7b_base-o64_s64-ntk5', {'scaling_factor': 5}, {'steps': 64, }, 64),
    ('dream_v0_7b_chat-o64_s64', {}, {'steps': 64, }, 64),
    ('dream_v0_7b_chat-o64_s64-ntk5', {'scaling_factor': 5}, {'steps': 64, }, 64),

]

models = [
    dict(
        type=LLaDACausalLM, abbr=abbr, path=path_dict[abbr.split('-')[0]],
        scaling_config=scaling_config, diffusion_config=diffusion_config, seed=2025, model_type=abbr.split('_')[0],
        model_kwargs={'flash_attention': True}, max_out_len=max_out_len, batch_size=1,
        run_cfg=dict(num_gpus=num_gpus[abbr.split('-')[0]], num_procs=num_gpus[abbr.split('-')[0]]),
    ) for abbr, scaling_config, diffusion_config, max_out_len in models
]

work_dir = './outputs/llada_ruler/'

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        task=dict(type=OpenICLInferTask),
    ),
)

eval = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=32,
        task=dict(type=OpenICLEvalTask, dump_details=True),
    ),
)

# python run.py eval/eval_llada_ruler.py --dump-eval-details -r
# python run.py eval/eval_llada_ruler.py --dump-eval-details -r --debug

--------------------------------------------------------------------------------
/ppl/get_ppl_llada.py:
--------------------------------------------------------------------------------
# copy from https://github.com/ML-GSAI/LLaDA/blob/main/get_log_likelihood.py

import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel

import os
import random

from tqdm import tqdm

import numpy as np

seed = 2025
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False  # if benchmark=True, deterministic will be False
torch.backends.cudnn.deterministic = True  # choose a deterministic algorithm


def forward_process(batch, prompt_index, mask_id):
    b, l = batch.shape

    target_len = (l - prompt_index.sum()).item()
    k = torch.randint(1, target_len + 1, (), device=batch.device)

    x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
    x = ((x - 1) % target_len) + 1
    assert x.min() >= 1 and x.max() <= target_len

    indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
    is_mask = indices < x.unsqueeze(1)
    for i in range(b):
        is_mask[i] = is_mask[i][torch.randperm(target_len)]

    is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1)
    noisy_batch = torch.where(is_mask, mask_id, batch)

    # Return the masked batch and the mask ratio
    return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)


def get_logits(model, batch, prompt_index, cfg_scale, mask_id):
    if cfg_scale > 0.:
        assert len(prompt_index) == batch.shape[1]
        prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
        un_batch = batch.clone()
        un_batch[prompt_index] = mask_id
        batch = torch.cat([batch, un_batch])

    input = batch
    logits = model(input).logits

    if cfg_scale > 0.:
        logits, un_logits = torch.chunk(logits, 2, dim=0)
        logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
    return logits
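# Note: when cfg_scale == 0. both guarded branches above are skipped, so
# get_logits reduces to a single forward pass. For cfg_scale > 0. the update
#     un_logits + (cfg_scale + 1) * (logits - un_logits)
# pushes predictions away from the unconditional logits computed with the
# prompt fully masked (classifier-free guidance).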
@ torch.no_grad()
def get_ppl(model, prompt, answer, mc_num=8, batch_size=1, cfg_scale=0., mask_id=126336):
    '''
    Args:
        model: Mask predictor.
        prompt: A tensor of shape (l1).
        answer: A tensor of shape (l2).
        mc_num: Monte Carlo estimation times.
            As detailed in Appendix B.5. Since MMLU, CMMLU, and C-EVAL only require the likelihood of a single token, a
            single Monte Carlo estimate is sufficient for these benchmarks. For all other benchmarks, we find that 128
            Monte Carlo samples are adequate to produce stable results.
        batch_size: Mini batch size.
        cfg_scale: Unsupervised classifier-free guidance scale.
        mask_id: The token id of [MASK] is 126336.
    '''
    seq = torch.concatenate([prompt, answer])[None, :]
    # seq = seq.repeat((batch_size, 1)).to(model.device)
    prompt_index = torch.arange(seq.shape[1], device=model.device) < len(prompt)

    loss_ = []
    for _ in range(mc_num):
        perturbed_seq, p_mask = forward_process(seq, prompt_index, mask_id)
        mask_index = perturbed_seq == mask_id

        logits = get_logits(model, perturbed_seq, prompt_index, cfg_scale, mask_id)

        loss = F.cross_entropy(logits[mask_index], seq[mask_index], reduction='none')
        loss = loss.mean()

        loss_.append(loss.item())

    return np.exp(sum(loss_) / len(loss_))


def main(path):
    device = 'cuda:0'

    print(path)

    model = AutoModel.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    file = '/'.join(os.path.realpath(__file__).split('/')[:-1])
    file = f'{file}/gov_report_001.txt'

    text = open(file, mode='r').readline()

    input_ids = tokenizer(text)['input_ids']
    print(f'{len(input_ids)}')

    context_len, chunk_size = 16384, 64

    perplexity = []

    for i in tqdm(list(range(int(context_len // chunk_size) - 1))):
        prompt_len = (i+1) * chunk_size
        prompt = torch.tensor(input_ids[:prompt_len]).to(device)
        answer = torch.tensor(input_ids[prompt_len:prompt_len+chunk_size]).to(device)
        ppl = get_ppl(model, prompt, answer, mc_num=8)
        print(prompt_len, flush=True)
        perplexity.append(ppl)

    num_sample = len(perplexity)
    perplexity = np.log(np.array(perplexity))
    perplexity = np.exp(np.cumsum(perplexity) / (np.arange(num_sample) + 1))

    print(f"Perplexity: {perplexity.tolist()}", flush=True)


if __name__ == '__main__':

    path = 'GSAI-ML/LLaDA-8B-Instruct'

    main(path)

    path = 'GSAI-ML/LLaDA-8B-Base'

    main(path)

--------------------------------------------------------------------------------
/llada/llada_generate.py:
--------------------------------------------------------------------------------
## copy from https://github.com/ML-GSAI/LLaDA/blob/main/generate.py

import torch
import numpy as np
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel


def add_gumbel_noise(logits, temperature):
    '''
    The Gumbel max is a method for sampling categorical distributions.
    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
    Thus, we use float64.
    '''
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    '''
    In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals.
    Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
    the expected number of tokens transitioned at each step should be consistent.

    This function is designed to precompute the number of tokens that need to be transitioned at each step.
    '''
    mask_num = mask_index.sum(dim=1, keepdim=True)

    base = mask_num // steps
    remainder = mask_num % steps

    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base

    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1

    return num_transfer_tokens
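# For example, with 10 masked tokens and steps = 4: base = 2, remainder = 2,
# so num_transfer_tokens = [3, 3, 2, 2] -- the first `remainder` steps each
# transfer one extra token, and the counts sum back to the mask total.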
@ torch.no_grad()
def generate(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
             cfg_scale=0., remasking='low_confidence', mask_id=126336):
    '''
    Args:
        model: Mask predictor.
        prompt: A tensor of shape (1, L).
        steps: Sampling steps, less than or equal to gen_length.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        cfg_scale: Unsupervised classifier-free guidance scale.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
    '''
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    for num_block in range(num_blocks):
        block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length:] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
        for i in range(steps):
            mask_index = (x == mask_id)
            if cfg_scale > 0.:
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                logits = model(x_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x).logits

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

            if remasking == 'low_confidence':
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)

            x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

    return x

def main():
    device = 'cuda'

    path = 'GSAI-ML/LLaDA-8B-Instruct'

    model = AutoModel.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

    prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"

    # Add special tokens for the Instruct model. The Base model does not require the following two lines.
    m = [{"role": "user", "content": prompt}, ]
    prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

    input_ids = tokenizer(prompt)['input_ids']
    input_ids = torch.tensor(input_ids).to(device).unsqueeze(0)

    out = generate(model, input_ids, steps=128, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
    print(tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0])


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/eval/eval_llada_niah.py:
--------------------------------------------------------------------------------
from mmengine.config import read_base
from opencompass.runners import LocalRunner
from opencompass.partitioners import NaivePartitioner, NumWorkerPartitioner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
from opencompass.models import LLaDACausalLM

with read_base():
    from opencompass.configs.datasets.needlebench.needlebench.needlebench import needlebench_origin_en_datasets
    from opencompass.configs.summarizers.needlebench import needlebench_summarizer as summarizer

datasets = []
datasets += needlebench_origin_en_datasets

num_gpus = {
    'llama_3_8b_base': 1, 'llama_3_8b_chat': 1,

    'llada_8b_base': 1, 'llada_8b_chat': 1, 'llada_1_5_8b': 1,

    'dream_v0_7b_base': 1, 'dream_v0_7b_chat': 1,
}

path_dict = {
    'llama_3_8b_base': 'meta-llama/Meta-Llama-3-8B',
    'llama_3_8b_chat': 'meta-llama/Meta-Llama-3-8B-Instruct',

    'llada_8b_base': 'GSAI-ML/LLaDA-8B-Base',
    'llada_8b_chat': 'GSAI-ML/LLaDA-8B-Instruct',

    'llada_1_5_8b': 'GSAI-ML/LLaDA-1.5',

    'dream_v0_7b_base': 'Dream-org/Dream-v0-Base-7B',
    'dream_v0_7b_chat': 'Dream-org/Dream-v0-Instruct-7B',
}

models = [

    ## llama series

    ('llama_3_8b_base-o32', {}, {}, 32),
    ('llama_3_8b_base-o32-ntk4', {'scaling_factor': 4}, {}, 32),
    ('llama_3_8b_base-o32-ntk13', {'scaling_factor': 13}, {}, 32),

    ('llama_3_8b_chat-o32', {}, {}, 32),
    ('llama_3_8b_chat-o32-ntk4', {'scaling_factor': 4}, {}, 32),
    ('llama_3_8b_chat-o32-ntk13', {'scaling_factor': 13}, {}, 32),

    ## llada series

    ### comparison on different sample steps

    ('llada_8b_base-o32_b32_s1', {}, {'steps': 1, 'block_length': 32, }, 32),
    ('llada_8b_base-o32_b32_s2', {}, {'steps': 2, 'block_length': 32, }, 32),
    ('llada_8b_base-o32_b32_s4', {}, {'steps': 4, 'block_length': 32, }, 32),
    ('llada_8b_base-o32_b32_s8', {}, {'steps': 8, 'block_length': 32, }, 32),
    ('llada_8b_base-o32_b32_s16', {}, {'steps': 16, 'block_length': 32, }, 32),
    ('llada_8b_base-o32_b32_s32', {}, {'steps': 32, 'block_length': 32, }, 32),  # default for llada_8b_base

    ('llada_8b_chat-o32_b32_s1', {}, {'steps': 1, 'block_length': 32, }, 32),
    ('llada_8b_chat-o32_b32_s2', {}, {'steps': 2, 'block_length': 32, }, 32),
    ('llada_8b_chat-o32_b32_s4', {}, {'steps': 4, 'block_length': 32, }, 32),
    ('llada_8b_chat-o32_b32_s8', {}, {'steps': 8, 'block_length': 32, }, 32),
    ('llada_8b_chat-o32_b32_s16', {}, {'steps': 16, 'block_length': 32, }, 32),
    ('llada_8b_chat-o32_b32_s32', {}, {'steps': 32, 'block_length': 32, }, 32),  # default for llada_8b_chat

    ('llada_1_5_8b-o32_b32_s1', {}, {'steps': 1, 'block_length': 32, }, 32),
    ('llada_1_5_8b-o32_b32_s2', {}, {'steps': 2, 'block_length': 32, }, 32),
    ('llada_1_5_8b-o32_b32_s4', {}, {'steps': 4, 'block_length': 32, }, 32),
    ('llada_1_5_8b-o32_b32_s8', {}, {'steps': 8, 'block_length': 32, }, 32),
    ('llada_1_5_8b-o32_b32_s16', {}, {'steps': 16, 'block_length': 32, }, 32),
    ('llada_1_5_8b-o32_b32_s32', {}, {'steps': 32, 'block_length': 32, }, 32),  # default for llada_1_5_8b

    ### comparison on different scaling factors

    ('llada_8b_base-o32_b32_s32-ntk4', {'scaling_factor': 4}, {'steps': 32, 'block_length': 32, }, 32),
    ('llada_8b_base-o32_b32_s32-ntk14', {'scaling_factor': 14}, {'steps': 32, 'block_length': 32, }, 32),
    ('llada_8b_base-o32_b32_s32-ntk31', {'scaling_factor': 31}, {'steps': 32, 'block_length': 32, }, 32),
    ('llada_8b_base-o32_b32_s32-ntk55', {'scaling_factor': 55}, {'steps': 32, 'block_length': 32, }, 32),

    ('llada_8b_chat-o32_b32_s32-ntk4', {'scaling_factor': 4}, {'steps': 32, 'block_length': 32, }, 32),
    ('llada_8b_chat-o32_b32_s32-ntk14', {'scaling_factor': 14}, {'steps': 32, 'block_length': 32, }, 32),
    ('llada_8b_chat-o32_b32_s32-ntk31', {'scaling_factor': 31}, {'steps': 32, 'block_length': 32, }, 32),
    ('llada_8b_chat-o32_b32_s32-ntk55', {'scaling_factor': 55}, {'steps': 32, 'block_length': 32, }, 32),

    ('llada_1_5_8b-o32_b32_s32-ntk4', {'scaling_factor': 4}, {'steps': 32, 'block_length': 32, }, 32),
    ('llada_1_5_8b-o32_b32_s32-ntk14', {'scaling_factor': 14}, {'steps': 32, 'block_length': 32, }, 32),
    ('llada_1_5_8b-o32_b32_s32-ntk31', {'scaling_factor': 31}, {'steps': 32, 'block_length': 32, }, 32),
    ('llada_1_5_8b-o32_b32_s32-ntk55', {'scaling_factor': 55}, {'steps': 32, 'block_length': 32, }, 32),

    ## dream series

    ### comparison on different sample steps

    ('dream_v0_7b_base-o32_s1', {}, {'steps': 1, }, 32),
    ('dream_v0_7b_base-o32_s8', {}, {'steps': 8, }, 32),
    ('dream_v0_7b_base-o32_s16', {}, {'steps': 16, }, 32),
    ('dream_v0_7b_base-o32_s32', {}, {'steps': 32, }, 32),  # default for dream_v0_7b_base

    ('dream_v0_7b_chat-o32_s1', {}, {'steps': 1, }, 32),
    ('dream_v0_7b_chat-o32_s8', {}, {'steps': 8, }, 32),
    ('dream_v0_7b_chat-o32_s16', {}, {'steps': 16, }, 32),
    ('dream_v0_7b_chat-o32_s32', {}, {'steps': 32, }, 32),  # default for dream_v0_7b_chat

    ### comparison on different scaling factors

    ('dream_v0_7b_base-o32_s32-ntk5', {'scaling_factor': 5}, {'steps': 32, }, 32),

    ('dream_v0_7b_chat-o32_s32-ntk5', {'scaling_factor': 5}, {'steps': 32, }, 32),

]

models = [
    dict(
        type=LLaDACausalLM, abbr=abbr, path=path_dict[abbr.split('-')[0]],
        scaling_config=scaling_config, diffusion_config=diffusion_config, seed=2025, model_type=abbr.split('_')[0],
        model_kwargs={'flash_attention': True}, max_out_len=max_out_len, batch_size=1,
        run_cfg=dict(num_gpus=num_gpus[abbr.split('-')[0]], num_procs=num_gpus[abbr.split('-')[0]]),
    ) for abbr, scaling_config, diffusion_config, max_out_len in models
]


work_dir = './outputs/llada_niah/'

infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        task=dict(type=OpenICLInferTask),
    ),
)
eval = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=32,
        task=dict(type=OpenICLEvalTask, dump_details=True),
    ),
)

# python run.py eval/eval_llada_niah.py --dump-eval-details -r
# python run.py eval/eval_llada_niah.py --dump-eval-details -r --debug

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
<div align="center">

<h1> LongLLaDA: Unlocking Long Context Capabilities in Diffusion LLMs </h1>
Xiaoran Liu<sup>1,2</sup>, Yuerong Song<sup>1,2</sup>, Zhigeng Liu<sup>1,2</sup>, Zengfeng Huang<sup>1,2</sup>, Qipeng Guo<sup>2,3</sup>, Ziwei He<sup>2,†</sup>, Xipeng Qiu<sup>1,2,†</sup>

<sup>1</sup>Fudan University, <sup>2</sup>Shanghai Innovation Institute, <sup>3</sup>Shanghai AI Laboratory

[📝 Paper] | [🤗 HF] | [🚀 Code]

</div>

## Introduction

In this work, we present the first systematic investigation comparing the long-context performance of diffusion LLMs and traditional auto-regressive LLMs. We first identify a unique characteristic of diffusion LLMs: unlike auto-regressive LLMs, they maintain remarkably ***stable perplexity*** during direct context extrapolation.

Moreover, where auto-regressive models fail outright on the Needle-In-A-Haystack task once the context exceeds their pretrained length, we discover that diffusion LLMs exhibit a distinct ***local perception*** phenomenon, enabling successful retrieval from recent context segments. We explain both phenomena through the lens of Rotary Position Embedding (RoPE) scaling theory.

Building on these observations, we propose ***LongLLaDA***, a training-free method that integrates LLaDA with NTK-based RoPE extrapolation. Our results validate that established extrapolation scaling laws remain effective for extending the context windows of diffusion LLMs.

Furthermore, we identify long-context tasks where diffusion LLMs outperform auto-regressive LLMs and others where they fall short. Consequently, this study establishes ***the first length extrapolation method for diffusion LLMs*** while providing essential theoretical insights and empirical benchmarks critical for advancing future research on long-context diffusion LLMs. This is the official implementation of LongLLaDA.
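At its core, the training-free extrapolation amounts to scaling the RoPE base frequency before the model is loaded, as implemented in `llada/llada_wrapper.py`. A minimal sketch (the scaling factor of 4 here is illustrative; see the eval configs for the factors we actually sweep):

```python
import torch
from transformers import AutoConfig, AutoModel

path, scaling_factor = 'GSAI-ML/LLaDA-8B-Base', 4

config = AutoConfig.from_pretrained(path, trust_remote_code=True)
config.rope_theta = config.rope_theta * scaling_factor  # NTK-based RoPE base scaling

model = AutoModel.from_pretrained(path, config=config,
                                  torch_dtype=torch.bfloat16, trust_remote_code=True)
```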
<div align="center">
<img src="img/intro.png">
</div>
## Installation

### Prepare Your OpenCompass

We run our downstream evaluation based on [OpenCompass](https://github.com/open-compass/opencompass).

```bash
git clone https://github.com/open-compass/opencompass
cd opencompass
pip install -e .
```

The necessary Python packages and the versions we use are listed below.

```
flash-attn==2.7.4.post1
torch==2.6.0
transformers==4.46.3
opencompass==0.4.2
```

### Prepare Your Model

Copy the folder `LongLLaDA/llada/` to `opencompass/models/` and add the following line to the end of `opencompass/models/__init__.py`.

```python
from .llada.llada_wrapper import LLaDACausalLM
```
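Once registered, `LLaDACausalLM` can be referenced from any OpenCompass config. A minimal sketch of one model entry, mirroring `eval/eval_llada_niah.py` (the values shown are illustrative):

```python
from opencompass.models import LLaDACausalLM

models = [
    dict(
        type=LLaDACausalLM, abbr='llada_8b_base-o32_b32_s32-ntk4',
        path='GSAI-ML/LLaDA-8B-Base', model_type='llada', seed=2025,
        scaling_config={'scaling_factor': 4},                # NTK RoPE base scaling
        diffusion_config={'steps': 32, 'block_length': 32},  # diffusion sampling setup
        model_kwargs={'flash_attention': True},
        max_out_len=32, batch_size=1,
        run_cfg=dict(num_gpus=1, num_procs=1),
    ),
]
```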
## Evaluation

Copy the folder `LongLLaDA/eval/` to your OpenCompass directory, and then you can try the following evaluations.

### Needle-In-A-Haystack (NIAH) evaluation

1. Add a NIAH evaluation script with customizable context length and depth. Copy `LongLLaDA/needlebench/needlebench` to `opencompass/configs/datasets/needlebench` and replace `opencompass/configs/summarizers/needlebench.py` with `LongLLaDA/needlebench/needlebench.py`.

2. Edit the prompt format of the NIAH benchmark to enable the base model to respond more effectively by replacing `opencompass/datasets/needlebench/origin.py` with `LongLLaDA/needlebench/origin.py`.

3. Optionally, you can also modify the plotting code in `opencompass/summarizers/needlebench.py` as shown in `LongLLaDA/needlebench/needlebench_summarizer.py`.

4. Execute the following command.

```bash
python run.py eval/eval_llada_niah.py --dump-eval-details -r
```

### LongBench evaluation

1. Execute the following command.

```bash
python run.py eval/eval_llada_long.py --dump-eval-details -r
```

### RULER evaluation

1. Edit the prompt format of the RULER benchmark to enable the base model to respond more effectively. In `ruler_cwe_gen.py`, `ruler_fwe_gen.py`, `ruler_niah_gen.py`, `ruler_qa_gen.py`, and `ruler_vt_gen.py` under `opencompass/configs/datasets/ruler/`, comment out the BOT round (which appends `'{answer}\n'`) at the end of the prompt template. The following is an example from `opencompass/configs/datasets/ruler/ruler_vt_gen.py`.

```python
vt_datasets = [
    {
        'abbr': 'ruler_vt',
        'type': RulerVtDataset,
        'num_chains': 1,
        'num_hops': 4,
        'reader_cfg': dict(input_columns=['prompt'], output_column='answer'),
        'infer_cfg': dict(
            prompt_template=dict(
                type=PromptTemplate,
                template=dict(
                    round=[
                        dict(role='HUMAN', prompt='{prompt}'),
                        # dict(role='BOT', prompt='{answer}\n'),  # comment out this line
                    ]
                ),
            ),
            retriever=dict(type=ZeroRetriever),
            inferencer=dict(type=GenInferencer),
        ),
        'eval_cfg': dict(
            evaluator=dict(type=RulerVtEvaluator),
        ),
    }
]
```

2. Execute the following command.

```bash
python run.py eval/eval_llada_ruler.py --dump-eval-details -r
```

### Perplexity (PPL) Evaluation

> We calculate perplexity in the LongLLaDA directory instead of OpenCompass, as follows.

1. Execute the following command to get the perplexity curve of LLaMA3.

```bash
python ppl/get_ppl_llama.py
```

2. Execute the following command to get the perplexity curve of LLaDA, which is evaluated in chunks of 64 tokens for efficiency.

```bash
python ppl/get_ppl_llada.py
```

3. Organize the related results and execute the following command to reproduce Figure 1 of our paper.

```bash
python ppl/get_ppl_plot.py
```
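For reference, `ppl/get_ppl_llada.py` estimates the perplexity of each 64-token chunk given its prefix with Monte Carlo masking, then aggregates the chunk perplexities into a running curve via a geometric mean. A sketch of the aggregation step (the chunk values are illustrative):

```python
import numpy as np

chunk_ppls = np.array([9.1, 8.7, 8.9])  # per-chunk perplexities (illustrative)
log_ppls = np.log(chunk_ppls)
# running geometric mean: exp of the running average of log-perplexities
curve = np.exp(np.cumsum(log_ppls) / (np.arange(len(log_ppls)) + 1))
```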
## Results

<div align="center">
<img src="img/ruler.png">
</div>

<div align="center">
<img src="img/ntk_extra.png">
</div>

<div align="center">
<img src="img/direct_extra.png">
</div>

<div align="center">
<img src="ppl/diffusion_niah.png">
</div>

<div align="center">
<img src="ppl/diffusion_ppl_gov.png">
</div>
## Citation

```
@article{liu2025longllada,
  title={LongLLaDA: Unlocking Long Context Capabilities in Diffusion LLMs},
  author={Liu, Xiaoran and Song, Yuerong and Liu, Zhigeng and Huang, Zengfeng and Guo, Qipeng and He, Ziwei and Qiu, Xipeng},
  journal={arXiv preprint arXiv:2506.14429},
  year={2025}
}
```

--------------------------------------------------------------------------------
/eval/eval_llada_long.py:
--------------------------------------------------------------------------------
from mmengine.config import read_base
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.runners import LocalRunner
from opencompass.tasks import OpenICLInferTask, OpenICLEvalTask
from opencompass.models import LLaDACausalLM

import torch

with read_base():

    # longbench

    from opencompass.configs.datasets.longbench.longbenchnarrativeqa.longbench_narrativeqa_gen import LongBench_narrativeqa_datasets
    from opencompass.configs.datasets.longbench.longbenchqasper.longbench_qasper_gen import LongBench_qasper_datasets
    from opencompass.configs.datasets.longbench.longbenchmultifieldqa_en.longbench_multifieldqa_en_gen import LongBench_multifieldqa_en_datasets
    from opencompass.configs.datasets.longbench.longbenchmultifieldqa_zh.longbench_multifieldqa_zh_gen import LongBench_multifieldqa_zh_datasets

    from opencompass.configs.datasets.longbench.longbenchhotpotqa.longbench_hotpotqa_gen import LongBench_hotpotqa_datasets
    from opencompass.configs.datasets.longbench.longbench2wikimqa.longbench_2wikimqa_gen import LongBench_2wikimqa_datasets
    from opencompass.configs.datasets.longbench.longbenchmusique.longbench_musique_gen import LongBench_musique_datasets
    from opencompass.configs.datasets.longbench.longbenchdureader.longbench_dureader_gen import LongBench_dureader_datasets

    from opencompass.configs.datasets.longbench.longbenchgov_report.longbench_gov_report_gen import LongBench_gov_report_datasets
    from opencompass.configs.datasets.longbench.longbenchqmsum.longbench_qmsum_gen import LongBench_qmsum_datasets
    from opencompass.configs.datasets.longbench.longbenchmulti_news.longbench_multi_news_gen import LongBench_multi_news_datasets
    from opencompass.configs.datasets.longbench.longbenchvcsum.longbench_vcsum_gen import LongBench_vcsum_datasets

    from opencompass.configs.datasets.longbench.longbenchtrec.longbench_trec_gen import LongBench_trec_datasets
    from opencompass.configs.datasets.longbench.longbenchtriviaqa.longbench_triviaqa_gen import LongBench_triviaqa_datasets
    from opencompass.configs.datasets.longbench.longbenchsamsum.longbench_samsum_gen import LongBench_samsum_datasets
    from opencompass.configs.datasets.longbench.longbenchlsht.longbench_lsht_gen import LongBench_lsht_datasets

    from opencompass.configs.datasets.longbench.longbenchpassage_count.longbench_passage_count_gen import LongBench_passage_count_datasets
    from opencompass.configs.datasets.longbench.longbenchpassage_retrieval_en.longbench_passage_retrieval_en_gen import LongBench_passage_retrieval_en_datasets
    from opencompass.configs.datasets.longbench.longbenchpassage_retrieval_zh.longbench_passage_retrieval_zh_gen import LongBench_passage_retrieval_zh_datasets

    from opencompass.configs.datasets.longbench.longbenchlcc.longbench_lcc_gen import LongBench_lcc_datasets
    from opencompass.configs.datasets.longbench.longbenchrepobench.longbench_repobench_gen import LongBench_repobench_datasets

datasets = sum((v for k, v in locals().items() if k.endswith('_datasets')), [])

num_gpus = {
    'llama_3_8b_base': 1, 'llama_3_8b_chat': 1,

    'llada_8b_base': 1, 'llada_8b_chat': 1, 'llada_1_5_8b': 1,

    'dream_v0_7b_base': 1, 'dream_v0_7b_chat': 1,
}

path_dict = {
    'llama_3_8b_base': 'meta-llama/Meta-Llama-3-8B',
    'llama_3_8b_chat': 'meta-llama/Meta-Llama-3-8B-Instruct',

    'llada_8b_base': 'GSAI-ML/LLaDA-8B-Base',
    'llada_8b_chat': 'GSAI-ML/LLaDA-8B-Instruct',

    'llada_1_5_8b': 'GSAI-ML/LLaDA-1.5',

    'dream_v0_7b_base': 'Dream-org/Dream-v0-Base-7B',
    'dream_v0_7b_chat': 'Dream-org/Dream-v0-Instruct-7B',
}
models = [
    # 8k
    ## llama series

    ('llama_3_8b_base-o512-8k', {}, {}, 7500, 512),
    ('llama_3_8b_chat-o512-8k', {}, {}, 7500, 512),

    ## llada series

    ('llada_8b_base-o512_b64_s512-8k', {}, {'steps': 512, 'block_length': 64, }, 7500, 512),
    ('llada_8b_base-o512_b64_s512-ntk4-8k', {'scaling_factor': 4}, {'steps': 512, 'block_length': 64, }, 7500, 512),
    ('llada_8b_chat-o512_b64_s512-8k', {}, {'steps': 512, 'block_length': 64, }, 7500, 512),
    ('llada_8b_chat-o512_b64_s512-ntk4-8k', {'scaling_factor': 4}, {'steps': 512, 'block_length': 64, }, 7500, 512),
    ('llada_1_5_8b-o512_b64_s512-8k', {}, {'steps': 512, 'block_length': 64, }, 7500, 512),
    ('llada_1_5_8b-o512_b64_s512-ntk4-8k', {'scaling_factor': 4}, {'steps': 512, 'block_length': 64, }, 7500, 512),

    ## dream series

    ('dream_v0_7b_base-o512_s512-8k', {}, {'steps': 512, }, 7500, 512),
    ('dream_v0_7b_base-o512_s512-ntk5-8k', {'scaling_factor': 5}, {'steps': 512, }, 7500, 512),
    ('dream_v0_7b_chat-o512_s512-8k', {}, {'steps': 512, }, 7500, 512),
    ('dream_v0_7b_chat-o512_s512-ntk5-8k', {'scaling_factor': 5}, {'steps': 512, }, 7500, 512),

    # 4k
    ## llama series

    ('llama_3_8b_base-o512-4k', {}, {}, 3500, 512),
    ('llama_3_8b_chat-o512-4k', {}, {}, 3500, 512),

    ## llada series

    ('llada_8b_base-o512_b64_s512-4k', {}, {'steps': 512, 'block_length': 64, }, 3500, 512),
    ('llada_8b_base-o512_b64_s512-ntk4-4k', {'scaling_factor': 4}, {'steps': 512, 'block_length': 64, }, 3500, 512),
    ('llada_8b_chat-o512_b64_s512-4k', {}, {'steps': 512, 'block_length': 64, }, 3500, 512),
    ('llada_8b_chat-o512_b64_s512-ntk4-4k', {'scaling_factor': 4}, {'steps': 512, 'block_length': 64, }, 3500, 512),
    ('llada_1_5_8b-o512_b64_s512-4k', {}, {'steps': 512, 'block_length': 64, }, 3500, 512),
    ('llada_1_5_8b-o512_b64_s512-ntk4-4k', {'scaling_factor': 4}, {'steps': 512, 'block_length': 64, }, 3500, 512),

    ## dream series

    ('dream_v0_7b_base-o512_s512-4k', {}, {'steps': 512, }, 3500, 512),
    ('dream_v0_7b_base-o512_s512-ntk5-4k', {'scaling_factor': 5}, {'steps': 512, }, 3500, 512),
    ('dream_v0_7b_chat-o512_s512-4k', {}, {'steps': 512, }, 3500, 512),
    ('dream_v0_7b_chat-o512_s512-ntk5-4k', {'scaling_factor': 5}, {'steps': 512, }, 3500, 512),

]
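# Note: max_seq_len (3500 / 7500) budgets the prompt so the 512-token output
# still fits within the 4k / 8k context windows; drop_middle=True (set below)
# truncates over-long prompts by keeping their head and tail halves, as
# implemented in llada_wrapper.LLaDACausalLM.generate.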
models = [
    dict(
        type=LLaDACausalLM, abbr=abbr, path=path_dict[abbr.split('-')[0]], drop_middle=True,
        scaling_config=scaling_config, diffusion_config=diffusion_config, seed=2025, model_type=abbr.split('_')[0],
        model_kwargs={'flash_attention': True}, max_out_len=max_out_len, batch_size=1, max_seq_len=max_seq_len,
        run_cfg=dict(num_gpus=num_gpus[abbr.split('-')[0]], num_procs=num_gpus[abbr.split('-')[0]]),
    ) for abbr, scaling_config, diffusion_config, max_seq_len, max_out_len in models
]

work_dir = './outputs/llada_long/'

infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=1000, gen_task_coef=15),
    runner=dict(
        type=LocalRunner,
        task=dict(type=OpenICLInferTask),
    ),
)

eval = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=32, retry=2,
        task=dict(type=OpenICLEvalTask, dump_details=True),
    ),
)

# python run.py eval/eval_llada_long.py --dump-eval-details -r
# python run.py eval/eval_llada_long.py --dump-eval-details -r --debug

--------------------------------------------------------------------------------
/llada/llada_wrapper.py:
--------------------------------------------------------------------------------
from typing import Dict, List, Optional, Union

import torch
from mmengine.device import is_npu_available

from opencompass.models.base import BaseModel, LMTemplateParser
from opencompass.models.base_api import APITemplateParser
from opencompass.registry import MODELS
from opencompass.utils.logging import get_logger
from opencompass.utils.prompt import PromptList

from ..huggingface_above_v4_33 import HuggingFaceBaseModel

from transformers import AutoConfig
from transformers.generation.utils import GenerationConfig

from transformers import AutoModelForCausalLM
from .llada_generate import generate

import os
import random

import numpy as np


def _get_stopping_criteria(stop_words, tokenizer, batch_size):
    from transformers import StoppingCriteria, StoppingCriteriaList

    class MultiTokenEOSCriteria(StoppingCriteria):
        """Criteria to stop on the specified multi-token sequence."""

        def __init__(self, stop_words: List[str], tokenizer, batch_size: int):
            self.done_tracker = [False] * batch_size
            self.stop_words, self.max_sequence_id_len = [], 0
            for s in stop_words:
                self.stop_words.append(s)
                sequence_ids = tokenizer.encode(s, add_special_tokens=False)
                self.max_sequence_id_len = max(self.max_sequence_id_len, len(sequence_ids))
            self.tokenizer = tokenizer

        def __call__(self, input_ids, scores, **kwargs) -> bool:
            # compare the last len(stop) tokens
            lookback_ids_batch = input_ids[:, -self.max_sequence_id_len:]
            lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
            for i, done in enumerate(self.done_tracker):
                if done:
                    continue
                self.done_tracker[i] = any(s in lookback_tokens_batch[i] for s in self.stop_words)
            return False not in self.done_tracker

    c = MultiTokenEOSCriteria(stop_words, tokenizer, batch_size)
    return StoppingCriteriaList([c])


def _get_possible_max_seq_len(max_seq_len, path):
    if max_seq_len is not None:
        return max_seq_len

    from transformers import AutoConfig
    config = AutoConfig.from_pretrained(path, trust_remote_code=True)
    possible_keys = [
        'max_position_embeddings',
        'seq_length',
        'model_max_length',
        'max_sequence_length',
    ]
    for k in possible_keys:
        if hasattr(config, k):
            return getattr(config, k)
    raise ValueError('max_seq_len is not provided and cannot be inferred from the model config.')

def _convert_base_messages(inputs):
    outputs = []
    for _input in inputs:
        if isinstance(_input, str):
            outputs.append(_input)
        else:
            messages = []
            for item in _input:
                messages.append(item['prompt'])
            outputs.append(''.join(messages))
    return outputs


def _set_model_kwargs_torch_dtype(model_kwargs):
    import torch
    if 'torch_dtype' not in model_kwargs:
        torch_dtype = torch.float16
    else:
        torch_dtype = {
            'torch.float16': torch.float16,
            'torch.bfloat16': torch.bfloat16,
            'torch.float': torch.float,
            'auto': 'auto',
            'None': None,
        }.get(model_kwargs['torch_dtype'])
    if torch_dtype is not None:
        model_kwargs['torch_dtype'] = torch_dtype
    return model_kwargs


@MODELS.register_module()
class LLaDACausalLM(HuggingFaceBaseModel):

    def __init__(self,
                 path: str,
                 model_kwargs: dict = dict(),
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 peft_path: Optional[str] = None,
                 peft_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 generation_kwargs: dict = dict(),
                 max_seq_len: Optional[int] = None,
                 pad_token_id: Optional[int] = None,
                 stop_words: Optional[List[str]] = [],
                 drop_middle: bool = False,

                 scaling_config: dict = None,
                 diffusion_config: dict = None,
                 model_type: str = None,
                 seed: int = None,

                 **other_kwargs):

        if seed is not None:
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.benchmark = False  # if benchmark=True, deterministic will be False
            torch.backends.cudnn.deterministic = True  # choose a deterministic algorithm

        self.logger = get_logger()
        self.path = path
        self.tokenizer_only = tokenizer_only
        self.template_parser = LMTemplateParser()
        self.max_seq_len = max_seq_len  # _get_possible_max_seq_len(max_seq_len, path)
        self.drop_middle = drop_middle
        self._load_tokenizer(tokenizer_path or path, tokenizer_kwargs, pad_token_id)

        self.scaling_config = scaling_config

        if model_type == 'dream':
            self.diffusion_config = {'steps': 32, 'alg': 'entropy', 'return_dict_in_generate': True, 'temperature': 0.2, 'top_p': 0.95}
        else:
            self.diffusion_config = {'steps': 128, 'block_length': 32, 'temperature': 0., 'cfg_scale': 0., 'remasking': 'low_confidence', }
        if diffusion_config is not None:
            self.diffusion_config.update(diffusion_config)

        print(self.diffusion_config, flush=True)

        self.model_type = model_type

        if not tokenizer_only:
            self._load_model(path=path, kwargs=model_kwargs, peft_path=peft_path, peft_kwargs=peft_kwargs)
        self.generation_kwargs = generation_kwargs
        self.stop_words = stop_words

        for k, v in other_kwargs.items():
            if v is not None:
                self.logger.warning(f'Unused argument {k}={v}')
    def _load_model(self, path: str, kwargs: dict, peft_path: Optional[str] = None, peft_kwargs: dict = dict()):
        from transformers import AutoModel, AutoModelForCausalLM

        DEFAULT_MODEL_KWARGS = dict(device_map='auto', trust_remote_code=True)
        model_kwargs = DEFAULT_MODEL_KWARGS
        model_kwargs.update(kwargs)
        model_kwargs = _set_model_kwargs_torch_dtype(model_kwargs)
        self.logger.debug(f'using model_kwargs: {model_kwargs}')
        if is_npu_available():
            model_kwargs['device_map'] = 'npu'

        config = AutoConfig.from_pretrained(path, trust_remote_code=True)
        config.flash_attention = True

        if self.scaling_config is not None:
            scaling_factor = self.scaling_config['scaling_factor'] if 'scaling_factor' in self.scaling_config else 1
            config.rope_theta = config.rope_theta * scaling_factor  # NTK-based extrapolation: scale the RoPE base
            print(f'{config.rope_theta=}', flush=True)

        if self.model_type == 'llama':
            self.model = AutoModelForCausalLM.from_pretrained(path, config=config, device_map='auto',
                                                              torch_dtype=torch.float16, trust_remote_code=True)
        elif self.model_type == 'dream':
            self.model = AutoModel.from_pretrained(path, config=config, device_map='auto',
                                                   torch_dtype=torch.bfloat16, trust_remote_code=True)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(path, config=config, device_map='auto',
                                                              torch_dtype=torch.float16, trust_remote_code=True)

        if peft_path is not None:
            from peft import PeftModel
            peft_kwargs['is_trainable'] = False
            self.model = PeftModel.from_pretrained(self.model, peft_path, **peft_kwargs)

        self.model.eval()
        self.model.generation_config.do_sample = False
256 | diffusion_config['steps'] = max_out_len 257 | print(diffusion_config, flush=True) 258 | 259 | outputs = self.model.diffusion_generate(tokens['input_ids'], 260 | max_new_tokens=max_out_len, **diffusion_config).sequences 261 | else: 262 | diffusion_config = self.diffusion_config.copy() # copy before the per-call adjustments below 263 | if max_out_len % diffusion_config['block_length'] != 0: 264 | max_out_len = int((max_out_len // diffusion_config['block_length'] + 1) * diffusion_config['block_length']) # round up to a whole number of blocks 265 | 266 | if diffusion_config['steps'] > max_out_len: 267 | diffusion_config['steps'] = max_out_len 268 | print(diffusion_config, flush=True) 269 | 270 | outputs = generate(self.model, tokens['input_ids'], 271 | gen_length=max_out_len, **diffusion_config) 272 | 273 | outputs = outputs[:, tokens['input_ids'].shape[1]:] 274 | 275 | # step-3: decode the output 276 | decodeds = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) 277 | 278 | if self.model_type == 'dream': 279 | decodeds = [ 280 | token.split(self.tokenizer.eos_token)[0] # keep only the text before the first EOS token 281 | for token in decodeds # for every decoded sample 282 | ] 283 | else: 284 | for stop in stopping_criteria: 285 | decodeds = [token.split(stop)[0] for token in decodeds] 286 | 287 | return decodeds 288 | 289 | def get_token_len(self, prompt: str, add_special_tokens: bool = True) -> int: 290 | m = _convert_base_messages([prompt])[0] 291 | t = self.tokenizer(m, add_special_tokens=add_special_tokens) 292 | return len(t['input_ids']) 293 | -------------------------------------------------------------------------------- /needlebench/origin.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import re 5 | from pathlib import Path 6 | 7 | import tiktoken 8 | from datasets import Dataset 9 | 10 | from opencompass.datasets.base import BaseDataset 11 | from opencompass.openicl import BaseEvaluator 12 | from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS 13 | from opencompass.utils import get_data_path 14 | 15 | 16 | def get_random_line_by_language(counter, file_path, language): 17 | with open(file_path, 'r', encoding='utf-8') as file: 18 | lines = [ 19 | json.loads(line.strip()) for line in file 20 | if json.loads(line.strip())['language'] == language 21 | ] 22 | 23 | if lines: 24 | random.seed(counter) 25 | random_line = random.choice(lines) 26 | return { 27 | 'needle': random_line['needle'], 28 | 'retrieval_question': random_line['retrieval_question'], 29 | 'keyword': random_line['arg2'] 30 | } 31 | else: 32 | return None 33 | 34 | 35 | @LOAD_DATASET.register_module() 36 | class NeedleBenchOriginDataset(BaseDataset): 37 | 38 | @staticmethod 39 | def load( 40 | path: str, 41 | length: int, 42 | depth: int, 43 | tokenizer_model: str, 44 | file_list: list[str], 45 | num_repeats_per_file: int, 46 | length_buffer: int, 47 | guide: bool, 48 | language: str, 49 | needle_file_name: str, 50 | position: str = 'End', 51 | ): 52 | data = {'prompt': [], 'answer': []} 53 | tokenizer = tiktoken.encoding_for_model(tokenizer_model) 54 | 55 | def _generate_context(tokens_context, depth_percent, needle): 56 | tokens_needle = _get_tokens_from_context(needle) 57 | insertion_point = int(len(tokens_context) * (depth_percent / 100)) 58 | tokens_context = (tokens_context[:insertion_point] + 59 | tokens_needle + tokens_context[insertion_point:]) 60 | new_context = _decode_tokens(tokens_context) 61 | return new_context 62 | 63 | def _get_tokens_from_context(context): 64 | return tokenizer.encode(context) 65 | 66 | def _decode_tokens(tokens):
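# Inverse of _get_tokens_from_context: turns tiktoken ids back into text,
# so the needle can be spliced at an exact token depth and the result
# rendered as a plain string.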
67 | return tokenizer.decode(tokens) 68 | 69 | def _modify_retrieval_question(retrieval_question): 70 | if language == 'Chinese': 71 | parts = retrieval_question.split('请按照') 72 | guide_retrieval_question = (parts[0] + '在回答之前,请思考文档中与此问题' 73 | '最相关的内容是什么。请按照' + parts[1]) 74 | return guide_retrieval_question 75 | elif language == 'English': 76 | parts = retrieval_question.split('Please answer in the format') 77 | guide_retrieval_question = ( 78 | parts[0] + 'Before answering, please consider' 79 | ' what in the document is most relevant to this question.' 80 | ' Please answer in the format' + parts[1]) 81 | return guide_retrieval_question 82 | else: 83 | raise ValueError(f"Language '{language}' is not supported.") 84 | 85 | def _modify_retrieval_question_for_base(retrieval_question): 86 | if language == 'Chinese': 87 | parts = retrieval_question.split('请按照') 88 | retrieval_question = (parts[0] + '在回答之前,请思考文档中与此问题' 89 | '最相关的内容是什么。请按照' + parts[1]) 90 | return retrieval_question.replace("请按照'", '')[:-16] 91 | elif language == 'English': 92 | parts = retrieval_question.split('Please answer in the format') 93 | retrieval_question = ( 94 | parts[0] + 'Before answering, please consider' 95 | ' what in the document is most relevant to this question.' 96 | ' Please answer in the format' + parts[1]) 97 | return retrieval_question.replace( 98 | "Please answer in the format '", '')[:-10] 99 | else: 100 | raise ValueError(f"Language '{language}' is not supported.") 101 | 102 | def _generate_prompt(context, retrieval_question): 103 | if guide: 104 | retrieval_question = _modify_retrieval_question( 105 | retrieval_question) 106 | else: 107 | retrieval_question = _modify_retrieval_question_for_base( 108 | retrieval_question) 109 | 110 | if language == 'Chinese': 111 | if position == 'End': 112 | retrieval_question = retrieval_question.replace("请按照'", '')[:-16] 113 | prompt = ('你是一个善于回答用户问题的智能AI助手\n' 114 | '请保持你的回答简洁清楚。不要说和下面文档中的无关的话' 115 | ',或重复你的回答\n' 116 | f'用户现在给你的文档是{context}\n\n' 117 | f'现在请问:{retrieval_question}') 118 | elif position == 'Start': 119 | prompt = ('你是一个善于回答用户问题的智能AI助手\n' 120 | '请保持你的回答简洁清楚。不要说和下面文档中的无关的话' 121 | ',或重复你的回答\n' 122 | f'现在请问:{retrieval_question}' 123 | f'用户现在给你的文档是{context}\n\n') # note: the adjacent literals must concatenate into one string; a trailing comma after the question piece would silently build a tuple instead 124 | else: 125 | raise ValueError('Unsupported position. ' 126 | 'Position must be "End" or "Start".') 127 | elif language == 'English': 128 | if position == 'End': 129 | retrieval_question = retrieval_question.replace("Please answer in the format '", '')[:-10] 130 | prompt = ('You are an intelligent AI assistant skilled in ' 131 | 'answering user questions.\n' 132 | 'Please keep your answers concise and clear. Do ' 133 | 'not talk about irrelevant topics or repeat ' 134 | 'your answers.\nThe document ' 135 | f'given to you by the user is {context}\n\n' 136 | f'Now, the question is: {retrieval_question}') 137 | elif position == 'Start': 138 | prompt = ('You are an intelligent AI assistant skilled in ' 139 | 'answering user questions.\n' 140 | 'Please keep your answers concise and clear. Do ' 141 | 'not talk about irrelevant topics or repeat ' 142 | 'your answers.\n' 143 | f'Now, the question is: {retrieval_question}' 144 | 'The document given to you by the user' 145 | f' is {context}\n\n') 146 | else: 147 | raise ValueError(f'Unsupported position {position}. 
' 148 | 'Position must be "End" or "Start".') 149 | else: 150 | raise ValueError(f"Language '{language}' is not supported.") 151 | 152 | return prompt 153 | 154 | file_names = [ 155 | 'en_un_asr.jsonl', 'zh_all.jsonl', 'PaulGrahamEssays.jsonl', 156 | 'multi_needle_reasoning_en.json', 'multi_needle_reasoning_zh.json', 157 | 'zh_finance.jsonl', 'zh_game.jsonl', 'zh_general.jsonl', 158 | 'zh_government.jsonl', 'zh_movie.jsonl', 'zh_tech.jsonl' 159 | ] 160 | path = get_data_path(path) 161 | if os.environ.get('DATASET_SOURCE') == 'HF': 162 | from huggingface_hub import snapshot_download 163 | path = snapshot_download(repo_id=path, repo_type='dataset') 164 | needle_file_path = os.path.join(path, needle_file_name) 165 | 166 | for file_name in file_names: 167 | file_path = os.path.join(path, file_name) 168 | if file_name not in file_list: 169 | continue 170 | 171 | with open(file_path, 'r', encoding='utf-8') as f: 172 | lines_bak = [json.loads(line.strip()) for line in f] 173 | lines = lines_bak.copy() 174 | for counter in range(num_repeats_per_file): 175 | random.seed(counter) 176 | random.shuffle(lines) 177 | needle_file_path = os.path.join(path, needle_file_name) 178 | random_needle = get_random_line_by_language(counter, file_path=needle_file_path, language=language) 179 | needle = '\n' + random_needle['needle'] + '\n' 180 | retrieval_question = random_needle['retrieval_question'] 181 | keyword = random_needle['keyword'] 182 | 183 | context_length = length - length_buffer 184 | target_length_per_record = context_length - len( 185 | _get_tokens_from_context(needle)) 186 | target_length_per_record = max(target_length_per_record, 0) 187 | accumulated_tokens = [] 188 | for line in lines: 189 | tokens_current_line = _get_tokens_from_context( 190 | line['text']) 191 | accumulated_tokens.extend(tokens_current_line) 192 | 193 | if len(accumulated_tokens) >= target_length_per_record: 194 | break 195 | 196 | processed_text = _generate_context( 197 | accumulated_tokens[:target_length_per_record], depth, 198 | needle) 199 | 200 | processed_prompt = _generate_prompt(processed_text, 201 | retrieval_question) 202 | 203 | data['prompt'].append(processed_prompt) 204 | data['answer'].append(needle + '*' + keyword) 205 | 206 | dataset = Dataset.from_dict({ 207 | 'prompt': data['prompt'], 208 | 'answer': data['answer'], 209 | }) 210 | return dataset # Dataset.from_dict({'test': dataset}) 211 | 212 | 213 | class NeedleBenchOriginEvaluator(BaseEvaluator): 214 | 215 | def __init__(self, use_trim=False): 216 | self.use_trim = use_trim 217 | 218 | @staticmethod 219 | def _trim_prediction(prediction, reference): 220 | """Trims the prediction string based on the length of the reference 221 | string. 222 | 223 | Args: 224 | prediction (str): The prediction string. 225 | reference (str): The reference string. 226 | 227 | Returns: 228 | str: The trimmed prediction string. 
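Example: with len(reference) == 10, at most prediction[:12] is kept; if the
reference's last character occurs at or after position 8 of that slice, the
prediction is cut just after its first such occurrence.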
229 | """ 230 | l08 = int(0.8 * len(reference)) 231 | l12 = int(1.2 * len(reference)) 232 | trimmed_prediction = prediction[:l12] 233 | 234 | if len(trimmed_prediction) > l08 and \ 235 | reference[-1] in trimmed_prediction[l08:]: 236 | end_pos = l08 + trimmed_prediction[l08:].index(reference[-1]) + 1 237 | trimmed_prediction = trimmed_prediction[:end_pos] 238 | 239 | return trimmed_prediction 240 | 241 | def levenshtein_distance(self, s1, s2): 242 | if len(s1) < len(s2): 243 | return self.levenshtein_distance(s2, s1) 244 | 245 | if len(s2) == 0: 246 | return len(s1) 247 | 248 | previous_row = range(len(s2) + 1) 249 | for i, c1 in enumerate(s1): 250 | current_row = [i + 1] 251 | for j, c2 in enumerate(s2): 252 | insertions = previous_row[j + 1] + 1 253 | deletions = current_row[j] + 1 254 | substitutions = previous_row[j] + (c1 != c2) 255 | current_row.append(min(insertions, deletions, substitutions)) 256 | previous_row = current_row 257 | 258 | return previous_row[-1] 259 | 260 | def score(self, predictions, gold): 261 | 262 | if len(predictions) != len(gold): 263 | return {'error': 'predictions and gold have different lengths'} 264 | 265 | total_score = 0 266 | details = [] 267 | for prediction, reference in zip(predictions, gold): 268 | keyword = reference.split('*')[1] 269 | reference = reference.split('*')[0] 270 | raw_prediction = prediction 271 | prediction = re.sub(r'\s+', '', prediction) 272 | reference = re.sub(r'\s+', '', reference) 273 | 274 | if self.use_trim: 275 | prediction = NeedleBenchOriginEvaluator._trim_prediction( 276 | prediction, reference) 277 | 278 | edit_distance = self.levenshtein_distance(prediction, reference) 279 | max_len = max(len(prediction), len(reference)) 280 | score = 100 * (1 - 281 | edit_distance / max_len) if max_len != 0 else 100 282 | 283 | if keyword in raw_prediction: 284 | print(f'{keyword} is in {prediction}') 285 | score = 100 286 | else: 287 | print(f'{keyword} is not in {prediction}') 288 | score = 0.2 * score 289 | 290 | detail = { 291 | 'pred': prediction, 292 | 'answer': reference, 293 | 'edit_distance': edit_distance, 294 | 'score': score 295 | } 296 | total_score += score 297 | details.append(detail) 298 | 299 | average_score = total_score / len(predictions) if predictions else 0 300 | result = {'score': average_score, 'details': details} 301 | return result 302 | 303 | 304 | @TEXT_POSTPROCESSORS.register_module('needlebench') 305 | def needlebench_postprocess(text: str) -> str: 306 | return text 307 | 308 | 309 | @TEXT_POSTPROCESSORS.register_module('needlebench_dataset') 310 | def needlebench_dataset_postprocess(text: str) -> str: 311 | return text 312 | -------------------------------------------------------------------------------- /needlebench/needlebench.py: -------------------------------------------------------------------------------- 1 | from opencompass.summarizers.needlebench import NeedleBenchSummarizer 2 | 3 | 4 | def create_m_rs_names_list(context_lengths, depths, needle_counts, 5 | languages, dataset_size): 6 | names_dict = {} 7 | multi_needle_list = [] 8 | multi_needle_en_list = [] 9 | multi_needle_zh_list = [] 10 | 11 | for needle_count in needle_counts: 12 | for language in languages: 13 | key = f'{needle_count}-Needle-{language.upper()}-{dataset_size.upper()}' 14 | names_list = [ 15 | f'Length{length}Depth{int(depth)}_{needle_count}needle_{language}_{dataset_size}' 16 | for length in context_lengths 17 | for depth in depths 18 | ] 19 | names_dict[key] = names_list 20 | 21 | multi_needle_list.extend(names_list) 22 | 
if language == 'en': 23 | multi_needle_en_list.extend(names_list) 24 | elif language == 'zh': 25 | multi_needle_zh_list.extend(names_list) 26 | names_dict[f'Multi-Needle-Reasoning(M-RS)-{dataset_size.upper()}'] = multi_needle_list 27 | names_dict[f'Multi-Needle-Reasoning-EN-{dataset_size.upper()}'] = multi_needle_en_list 28 | names_dict[f'Multi-Needle-Reasoning-ZH-{dataset_size.upper()}'] = multi_needle_zh_list 29 | 30 | return names_dict 31 | 32 | def create_summarizer(context_lengths, depths, dataset_size, 33 | sparse_depths=None): 34 | needle_counts = ['2', '3', '4', '5'] 35 | languages = ['en', 'zh'] 36 | if sparse_depths: 37 | depths = sparse_depths 38 | names_dict = {} 39 | multi_reasoning_names = create_m_rs_names_list( 40 | context_lengths, depths, needle_counts, languages, dataset_size) 41 | 42 | names_dict.update(multi_reasoning_names) 43 | 44 | single_needle_list = [] 45 | single_needle_en_list = [] 46 | single_needle_zh_list = [] 47 | 48 | for language in languages: 49 | names_list = [ 50 | f'Length{length}Depth{int(depth)}_origin_{language}_{dataset_size}' 51 | for length in context_lengths 52 | for depth in depths 53 | ] 54 | single_needle_list.extend(names_list) 55 | if language == 'en': 56 | single_needle_en_list.extend(names_list) 57 | elif language == 'zh': 58 | single_needle_zh_list.extend(names_list) 59 | names_dict[f'Single-Needle-Retrieval(S-RT)-{dataset_size.upper()}'] = single_needle_list 60 | names_dict[f'Single-Needle-Retrieval-EN-{dataset_size.upper()}'] = single_needle_en_list 61 | names_dict[f'Single-Needle-Retrieval-ZH-{dataset_size.upper()}'] = single_needle_zh_list 62 | 63 | parallel_list = [] 64 | parallel_en_list = [] 65 | parallel_zh_list = [] 66 | 67 | for language in languages: 68 | names_list = [ 69 | f'Length{length}_parallel_{language}_{dataset_size}' 70 | for length in context_lengths 71 | ] 72 | parallel_list.extend(names_list) 73 | if language == 'en': 74 | parallel_en_list.extend(names_list) 75 | elif language == 'zh': 76 | parallel_zh_list.extend(names_list) 77 | names_dict[f'Multi-Needle-Retrieval(M-RT)-{dataset_size.upper()}'] = parallel_list 78 | names_dict[f'Multi-Needle-Retrieval-EN-{dataset_size.upper()}'] = parallel_en_list 79 | names_dict[f'Multi-Needle-Retrieval-ZH-{dataset_size.upper()}'] = parallel_zh_list 80 | 81 | summary_groups = [ 82 | {'name': key, 'subsets': value} for key, value in names_dict.items() 83 | ] 84 | 85 | summary_groups.append({ 86 | 'name': f'NeedleBench-Overall-Score-{dataset_size.upper()}', 87 | 'subsets': [[f'Single-Needle-Retrieval(S-RT)-{dataset_size.upper()}', 'naive_average'], 88 | [f'Multi-Needle-Reasoning(M-RS)-{dataset_size.upper()}', 'naive_average'], 89 | [f'Multi-Needle-Retrieval(M-RT)-{dataset_size.upper()}', 'average_score']], 90 | 'weights': {f'Single-Needle-Retrieval(S-RT)-{dataset_size.upper()}': 0.4, 91 | f'Multi-Needle-Reasoning(M-RS)-{dataset_size.upper()}': 0.3, 92 | f'Multi-Needle-Retrieval(M-RT)-{dataset_size.upper()}': 0.3}}) 93 | summarizer_config = { 94 | 'type': NeedleBenchSummarizer, 95 | 'summary_groups': summary_groups, 96 | 'dataset_abbrs': [ 97 | f'NeedleBench-Overall-Score-{dataset_size.upper()}', 98 | f'--------- NeedleBench-{dataset_size.upper()}-Single-Needle-Retrieval ---------', 99 | f'Single-Needle-Retrieval(S-RT)-{dataset_size.upper()}', 100 | f'Single-Needle-Retrieval-EN-{dataset_size.upper()}', 101 | f'Single-Needle-Retrieval-ZH-{dataset_size.upper()}', 102 | f'--------- NeedleBench-{dataset_size.upper()}-Multi-Needle-Retrieval ---------', 103 | 
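# The bare '---------' strings in dataset_abbrs match no dataset abbr; the
# summarizer appears to echo them as-is, so they serve as section-header rows
# in the rendered report.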
f'Multi-Needle-Retrieval(M-RT)-{dataset_size.upper()}', 104 | f'Multi-Needle-Retrieval-EN-{dataset_size.upper()}', 105 | f'Multi-Needle-Retrieval-ZH-{dataset_size.upper()}', 106 | f'--------- NeedleBench-{dataset_size.upper()}-Multi-Needle-Reasoning ---------', 107 | f'Multi-Needle-Reasoning(M-RS)-{dataset_size.upper()}', 108 | f'Multi-Needle-Reasoning-EN-{dataset_size.upper()}', 109 | f'Multi-Needle-Reasoning-ZH-{dataset_size.upper()}', 110 | f'2-Needle-EN-{dataset_size.upper()}', 111 | f'2-Needle-ZH-{dataset_size.upper()}', 112 | f'3-Needle-EN-{dataset_size.upper()}', 113 | f'3-Needle-ZH-{dataset_size.upper()}', 114 | f'4-Needle-EN-{dataset_size.upper()}', 115 | f'4-Needle-ZH-{dataset_size.upper()}', 116 | f'5-Needle-EN-{dataset_size.upper()}', 117 | f'5-Needle-ZH-{dataset_size.upper()}', 118 | ] 119 | } 120 | return summarizer_config 121 | 122 | 123 | context_lengths = list([4000, 8000, 16000, 24000, 32000]) 124 | depths = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] 125 | needlebench_summarizer = create_summarizer(context_lengths, depths, '') 126 | 127 | depths = [0, 5, 10, 15, 21, 26, 31, 36, 42, 47, 52, 57, 63, 68, 73, 78, 84, 89, 94, 100] 128 | depths_list_sparse = [0, 10, 21, 31, 42, 52, 63, 73, 84, 94, 100] 129 | 130 | context_lengths_4k = list(range(1000, 5000, 1000)) 131 | needlebench_4k_summarizer = create_summarizer(context_lengths_4k, depths, '4k') 132 | context_lengths_8k = list(range(5000, 9000, 1000)) 133 | needlebench_8k_summarizer = create_summarizer(context_lengths_8k, depths, '8k') 134 | context_lengths_32k = [9000, 13000, 17000, 21000, 25000, 29000, 31000, 32000] 135 | needlebench_32k_summarizer = create_summarizer(context_lengths_32k, depths_list_sparse, '32k') 136 | context_lengths_128k = list([16000, 32000, 48000, 64000, 80000, 96000, 112000, 128000]) 137 | needlebench_128k_summarizer = create_summarizer(context_lengths_128k, depths_list_sparse, '128k') 138 | context_lengths_200k = list([16000, 48000, 80000, 112000, 128000, 144000, 176000, 200000]) 139 | needlebench_200k_summarizer = create_summarizer(context_lengths_200k, depths_list_sparse, '200k') 140 | context_lengths_256k = list([32000, 128000, 256000]) 141 | needlebench_256k_summarizer = create_summarizer(context_lengths_256k, depths_list_sparse, '256k') 142 | context_lengths_1000k = list([20000, 160000, 300000, 440000, 580000, 720000, 860000, 1000000]) 143 | needlebench_1000k_summarizer = create_summarizer(context_lengths_1000k, depths_list_sparse, '1000k') 144 | 145 | depths_list_internal = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, ] 146 | needlebench_internal_32k_summarizer = create_summarizer([32000], depths_list_internal, '32000') 147 | needlebench_internal_100k_summarizer = create_summarizer([100000], depths_list_internal, '100000') 148 | needlebench_internal_200k_summarizer = create_summarizer([200000], depths_list_internal, '200000') 149 | 150 | _needlebench_8k_parallel_en_batch1 = [] 151 | _needlebench_8k_parallel_en_batch5 = [] 152 | _needlebench_8k_parallel_en_batch10 = [] 153 | _needlebench_8k_parallel_en_batch15 = [] 154 | _needlebench_8k_parallel_en_batch20 = [] 155 | _needlebench_8k_parallel_zh_batch1 = [] 156 | _needlebench_8k_parallel_zh_batch5 = [] 157 | _needlebench_8k_parallel_zh_batch10 = [] 158 | _needlebench_8k_parallel_zh_batch15 = [] 159 | _needlebench_8k_parallel_zh_batch20 = [] 160 | for original_context_length in context_lengths_8k: 161 | _needlebench_8k_parallel_en_batch1.append(f'Length{original_context_length}_parallel_en_8k_batch1') 162 | 
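# The appends below repeat the same pattern for every (language, batch-size)
# pair; each abbr must match a dataset abbr produced by the corresponding
# needlebench_8k batch config.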
_needlebench_8k_parallel_en_batch5.append(f'Length{original_context_length}_parallel_en_8k_batch5') 163 | _needlebench_8k_parallel_en_batch10.append(f'Length{original_context_length}_parallel_en_8k_batch10') 164 | _needlebench_8k_parallel_en_batch15.append(f'Length{original_context_length}_parallel_en_8k_batch15') 165 | _needlebench_8k_parallel_en_batch20.append(f'Length{original_context_length}_parallel_en_8k_batch20') 166 | _needlebench_8k_parallel_zh_batch1.append(f'Length{original_context_length}_parallel_zh_8k_batch1') 167 | _needlebench_8k_parallel_zh_batch5.append(f'Length{original_context_length}_parallel_zh_8k_batch5') 168 | _needlebench_8k_parallel_zh_batch10.append(f'Length{original_context_length}_parallel_zh_8k_batch10') 169 | _needlebench_8k_parallel_zh_batch15.append(f'Length{original_context_length}_parallel_zh_8k_batch15') 170 | _needlebench_8k_parallel_zh_batch20.append(f'Length{original_context_length}_parallel_zh_8k_batch20') 171 | 172 | 173 | _needlebench_8k_parallel_batch1 = _needlebench_8k_parallel_en_batch1 + _needlebench_8k_parallel_zh_batch1 174 | _needlebench_8k_parallel_batch5 = _needlebench_8k_parallel_en_batch5 + _needlebench_8k_parallel_zh_batch5 175 | _needlebench_8k_parallel_batch10 = _needlebench_8k_parallel_en_batch10 + _needlebench_8k_parallel_zh_batch10 176 | _needlebench_8k_parallel_batch15 = _needlebench_8k_parallel_en_batch15 + _needlebench_8k_parallel_zh_batch15 177 | _needlebench_8k_parallel_batch20 = _needlebench_8k_parallel_en_batch20 + _needlebench_8k_parallel_zh_batch20 178 | 179 | needlebench_summary_groups = [ 180 | {'name': 'parallel_version_batch1', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_batch1]}, 181 | {'name': 'parallel_version_zh_batch1', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_zh_batch1]}, 182 | {'name': 'parallel_version_en_batch1', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_en_batch1]}, 183 | {'name': 'parallel_version_batch5', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_batch5]}, 184 | {'name': 'parallel_version_zh_batch5', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_zh_batch5]}, 185 | {'name': 'parallel_version_en_batch5', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_en_batch5]}, 186 | {'name': 'parallel_version_batch10', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_batch10]}, 187 | {'name': 'parallel_version_zh_batch10', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_zh_batch10]}, 188 | {'name': 'parallel_version_en_batch10', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_en_batch10]}, 189 | {'name': 'parallel_version_batch15', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_batch15]}, 190 | {'name': 'parallel_version_zh_batch15', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_zh_batch15]}, 191 | {'name': 'parallel_version_en_batch15', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_en_batch15]}, 192 | {'name': 'parallel_version_batch20', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_batch20]}, 193 | {'name': 'parallel_version_zh_batch20', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_zh_batch20]}, 194 | {'name': 
'parallel_version_en_batch20', 'subsets': [[_dataset, 'average_score'] for _dataset in _needlebench_8k_parallel_en_batch20]}, 195 | ] 196 | 197 | needlebench_8k_batch_overall_summarizer = dict( 198 | dataset_abbrs=[ 199 | '--------- NeedleBench-8k Parallel-Needles ---------', # category 200 | 'parallel_version_batch1', 201 | 'parallel_version_batch5', 202 | 'parallel_version_batch10', 203 | 'parallel_version_batch15', 204 | 'parallel_version_batch20', 205 | 'parallel_version_zh_batch1', 206 | 'parallel_version_en_batch1', 207 | 'parallel_version_zh_batch5', 208 | 'parallel_version_en_batch5', 209 | 'parallel_version_zh_batch10', 210 | 'parallel_version_en_batch10', 211 | 'parallel_version_zh_batch15', 212 | 'parallel_version_en_batch15', 213 | 'parallel_version_zh_batch20', 214 | 'parallel_version_en_batch20', 215 | ], 216 | summary_groups=needlebench_summary_groups, 217 | ) 218 | 219 | needlebench_summary_groups = [ 220 | {'name': 'parallel_version_batch1', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_batch1]}, 221 | {'name': 'parallel_version_zh_batch1', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_zh_batch1]}, 222 | {'name': 'parallel_version_en_batch1', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_en_batch1]}, 223 | {'name': 'parallel_version_batch5', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_batch5]}, 224 | {'name': 'parallel_version_zh_batch5', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_zh_batch5]}, 225 | {'name': 'parallel_version_en_batch5', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_en_batch5]}, 226 | {'name': 'parallel_version_batch10', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_batch10]}, 227 | {'name': 'parallel_version_zh_batch10', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_zh_batch10]}, 228 | {'name': 'parallel_version_en_batch10', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_en_batch10]}, 229 | {'name': 'parallel_version_batch15', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_batch15]}, 230 | {'name': 'parallel_version_zh_batch15', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_zh_batch15]}, 231 | {'name': 'parallel_version_en_batch15', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_en_batch15]}, 232 | {'name': 'parallel_version_batch20', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_batch20]}, 233 | {'name': 'parallel_version_zh_batch20', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_zh_batch20]}, 234 | {'name': 'parallel_version_en_batch20', 'subsets': [[_dataset, 'Depth0'] for _dataset in _needlebench_8k_parallel_en_batch20]}, 235 | ] 236 | 237 | needlebench_8k_batch_depth0_summarizer = dict( 238 | dataset_abbrs=[ 239 | '--------- NeedleBench-8k Parallel-Needles ---------', # category 240 | 'parallel_version_batch1', 241 | 'parallel_version_batch5', 242 | 'parallel_version_batch10', 243 | 'parallel_version_batch15', 244 | 'parallel_version_batch20', 245 | 'parallel_version_zh_batch1', 246 | 'parallel_version_en_batch1', 247 | 'parallel_version_zh_batch5', 248 | 'parallel_version_en_batch5', 249 | 'parallel_version_zh_batch10', 250 | 'parallel_version_en_batch10', 251 | 'parallel_version_zh_batch15', 252 | 'parallel_version_en_batch15', 253 | 
'parallel_version_zh_batch20', 254 | 'parallel_version_en_batch20', 255 | ], 256 | summary_groups=needlebench_summary_groups, 257 | ) 258 | 259 | def gen_atc_summarizer(needle_num_list): 260 | categories = [ 261 | 'ZH-Direct-CE', 'EN-Direct-CE', 262 | 'ZH-Reasoning-CE', 'EN-Reasoning-CE' 263 | ] 264 | needlebench_atc_summary_groups = [] 265 | 266 | # build one summary group per category 267 | for category in categories: 268 | # CircularEval (CE) categories are scored with the perf_4 metric; everything else uses acc_1 269 | metric = 'perf_4' if 'CE' in category else 'acc_1' 270 | # the dataset names themselves carry no CircularEval suffix, so strip it when building subsets 271 | cleaned_category = category.replace('-CE', '').replace('-Direct', '') 272 | needlebench_atc_summary_groups.append({ 273 | 'name': category, 274 | 'subsets': [ 275 | [f'NeedleBenchATCDataset-{num_needles}Needle-{cleaned_category}', metric] 276 | for num_needles in needle_num_list 277 | ], 278 | 'weights': {f'NeedleBenchATCDataset-{num_needles}Needle-{cleaned_category}': num_needles for num_needles in needle_num_list}, 279 | }) 280 | 281 | needlebench_atc_summary_groups.append({ 282 | 'name': 'ATC-CE-Overall', 283 | 'subsets': [ 284 | [f'{category}', 'weighted_average'] for category in categories 285 | ], 286 | }) 287 | atc_dataset_abbrs = [] 288 | atc_dataset_abbrs.append(['ATC-CE-Overall', 'naive_average']) 289 | 290 | for category in categories: 291 | weighted_average_score_entry = [f'{category}', 'weighted_average'] 292 | atc_dataset_abbrs.append(weighted_average_score_entry) 293 | 294 | needlebench_atc_summarizer = dict( 295 | dataset_abbrs=[ 296 | *atc_dataset_abbrs, 297 | '######## Needlebench-ATC Accuracy ########', # category 298 | *[[f'NeedleBenchATCDataset-{num_needles}Needle-ZH', 'acc_1'] for num_needles in needle_num_list], 299 | '------------------------------------------', 300 | *[[f'NeedleBenchATCDataset-{num_needles}Needle-EN', 'acc_1'] for num_needles in needle_num_list], 301 | '------------------------------------------', 302 | *[[f'NeedleBenchATCDataset-{num_needles}Needle-ZH-Reasoning', 'acc_1'] for num_needles in needle_num_list], 303 | '------------------------------------------', 304 | *[[f'NeedleBenchATCDataset-{num_needles}Needle-EN-Reasoning', 'acc_1'] for num_needles in needle_num_list], 305 | '------------------------------------------', 306 | '######## Needlebench-ATC CircularEval ########', # category 307 | *[[f'NeedleBenchATCDataset-{num_needles}Needle-ZH', 'perf_4'] for num_needles in needle_num_list], 308 | '------------------------------------------', 309 | *[[f'NeedleBenchATCDataset-{num_needles}Needle-EN', 'perf_4'] for num_needles in needle_num_list], 310 | '------------------------------------------', 311 | *[[f'NeedleBenchATCDataset-{num_needles}Needle-ZH-Reasoning', 'perf_4'] for num_needles in needle_num_list], 312 | '------------------------------------------', 313 | *[[f'NeedleBenchATCDataset-{num_needles}Needle-EN-Reasoning', 'perf_4'] for num_needles in needle_num_list], 314 | '------------------------------------------', 315 | ], 316 | summary_groups=needlebench_atc_summary_groups 317 | ) 318 | return needlebench_atc_summarizer 319 | 320 | 321 | atc_summarizer_20 = gen_atc_summarizer(list(range(2, 20, 1))) 322 | atc_summarizer_50 = gen_atc_summarizer(list(range(2, 50, 1))) 323 | atc_summarizer_80 = gen_atc_summarizer(list(range(2, 80, 1))) 324 | -------------------------------------------------------------------------------- /ppl/get_ppl_plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib as mpl 3 | import 
matplotlib.pyplot as plt 4 | import seaborn 5 | 6 | # mpl.rcParams['font.sans-serif'] = ['Times New Roman'] 7 | # mpl.rcParams['axes.unicode_minus'] = False 8 | 9 | ppl_dict = { 10 | 'LLaMA3-8B-Base': [ 11 | 14.9375, 12.9375, 11.0625, 9.5, 8.75, 7.5, 6.375, 6.375, 5.9375, 5.75, 5.5625, 5.4375, 5.71875, 5.53125, 5.84375, 5.84375, 5.75, 5.8125, 5.53125, 5.40625, 5.53125, 5.71875, 5.9375, 6.125, 6.125, 5.96875, 5.9375, 5.84375, 5.9375, 5.9375, 5.9375, 5.96875, 6.0625, 6.0625, 5.96875, 5.9375, 5.8125, 5.84375, 5.71875, 5.75, 5.71875, 5.53125, 5.4375, 5.5, 5.65625, 5.71875, 5.65625, 5.5625, 5.4375, 5.375, 5.28125, 5.28125, 5.15625, 5.125, 5.03125, 5.03125, 5.03125, 4.96875, 4.96875, 5.0, 5.03125, 5.0, 5.0, 5.09375, 5.03125, 5.03125, 4.96875, 5.09375, 5.125, 5.15625, 5.25, 5.28125, 5.25, 5.28125, 5.3125, 5.28125, 5.25, 5.25, 5.1875, 5.1875, 5.15625, 5.125, 5.15625, 5.125, 5.125, 5.09375, 5.125, 5.125, 5.15625, 5.15625, 5.15625, 5.125, 5.09375, 5.03125, 5.03125, 5.03125, 5.03125, 5.125, 5.125, 5.15625, 5.15625, 5.1875, 5.1875, 5.25, 5.375, 5.40625, 5.4375, 5.4375, 5.4375, 5.40625, 5.4375, 5.4375, 5.53125, 5.53125, 5.5, 5.5, 5.53125, 5.4375, 5.4375, 5.5, 5.5, 5.4375, 5.40625, 5.40625, 5.40625, 5.375, 5.3125, 5.28125, 5.25, 5.25, 5.25, 5.25, 5.25, 5.25, 5.28125, 5.3125, 5.375, 5.4375, 5.53125, 5.625, 5.71875, 5.84375, 6.0625, 6.28125, 6.40625, 6.625, 6.84375, 7.0, 7.09375, 7.5, 7.75, 8.0, 8.25, 8.375, 8.625, 8.75, 9.1875, 9.5, 9.8125, 9.9375, 10.25, 10.4375, 10.75, 11.0625, 11.625, 11.8125, 12.0, 12.5625, 12.75, 12.9375, 13.1875, 13.5625, 14.0, 14.4375, 14.9375, 15.1875, 15.375, 15.875, 16.125, 16.875, 17.125, 17.5, 18.0, 18.25, 18.625, 19.125, 19.5, 20.125, 20.75, 21.0, 21.375, 21.75, 22.375, 22.375, 23.5, 23.5, 24.25, 24.25, 25.0, 25.375, 25.75, 25.75, 27.0, 27.5, 27.5, 28.375, 28.75, 28.75, 29.25, 30.125, 30.125, 31.125, 31.625, 32.5, 32.5, 33.0, 33.0, 35.25, 35.75, 36.25, 37.0, 37.5, 37.5, 38.0, 40.0, 40.5, 41.25, 41.75, 41.75, 42.5, 43.75, 44.5, 45.25, 46.0, 46.75, 47.5, 48.25, 49.75, 50.5, 52.0, 53.0, 53.75, 54.5, 56.25, 56.25, 58.0, 58.0, 60.0, 61.75, 61.75 12 | ], 13 | 'LLaMA3-8B-Instruct' : [ 14 | 20.125, 19.125, 16.625, 15.1875, 13.375, 10.5625, 9.1875, 9.1875, 8.25, 8.125, 8.375, 8.0, 8.25, 8.0, 8.25, 8.25, 8.125, 8.125, 7.75, 7.5, 7.625, 7.75, 8.125, 8.5, 8.375, 8.25, 8.125, 8.0, 8.25, 8.125, 8.125, 8.125, 8.25, 8.375, 8.0, 8.0, 7.75, 7.875, 7.75, 7.875, 7.75, 7.625, 7.375, 7.375, 7.625, 7.75, 7.625, 7.5, 7.34375, 7.15625, 7.09375, 7.0625, 6.9375, 6.78125, 6.6875, 6.625, 6.6875, 6.625, 6.53125, 6.625, 6.625, 6.6875, 6.625, 6.6875, 6.625, 6.625, 6.625, 6.78125, 6.78125, 6.875, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.9375, 6.84375, 6.84375, 6.78125, 6.78125, 6.71875, 6.71875, 6.71875, 6.71875, 6.71875, 6.71875, 6.71875, 6.71875, 6.78125, 6.78125, 6.6875, 6.625, 6.625, 6.5625, 6.53125, 6.53125, 6.6875, 6.84375, 6.84375, 6.78125, 6.84375, 6.84375, 6.84375, 7.0, 7.0625, 7.09375, 7.09375, 7.09375, 7.09375, 7.09375, 7.09375, 7.21875, 7.28125, 7.28125, 7.21875, 7.15625, 7.15625, 7.15625, 7.15625, 7.15625, 7.15625, 7.09375, 7.09375, 7.09375, 7.0625, 6.9375, 6.9375, 6.875, 6.84375, 6.84375, 6.84375, 6.84375, 6.84375, 6.9375, 6.9375, 7.0, 7.0625, 7.15625, 7.15625, 7.375, 7.5, 7.75, 7.875, 8.0, 8.25, 8.375, 8.5, 8.5, 8.9375, 9.0625, 9.1875, 9.3125, 9.625, 9.625, 9.9375, 10.125, 10.4375, 10.5625, 10.75, 10.9375, 11.0625, 11.25, 11.625, 12.1875, 12.1875, 12.5625, 12.9375, 13.1875, 13.375, 13.8125, 14.25, 14.6875, 14.9375, 15.1875, 15.625, 15.625, 16.125, 16.375, 17.125, 17.75, 18.25, 18.25, 18.875, 18.875, 
19.5, 20.125, 20.375, 20.75, 21.0, 21.75, 22.375, 22.375, 23.125, 23.875, 24.25, 24.625, 25.375, 25.75, 26.25, 26.625, 26.625, 27.875, 28.75, 28.75, 29.25, 30.125, 30.625, 31.125, 32.0, 32.5, 33.75, 34.25, 34.75, 34.75, 35.25, 35.75, 38.0, 38.0, 38.0, 38.75, 39.25, 40.0, 40.5, 41.75, 42.5, 42.5, 42.5, 43.25, 44.5, 45.25, 46.0, 46.75, 47.5, 47.5, 48.25, 49.75, 50.5, 51.25, 52.0, 53.0, 53.75, 54.5, 56.25, 56.25, 56.25, 58.0, 60.0, 60.0, 60.0 15 | ], 16 | 'LLaDA-8B-Base': [ 17 | 31.215599965441594, 15.635950058313327, 15.031746350814204, 10.017377041294852, 8.256935111477564, 7.525485208461095, 7.985191489693327, 8.532730828979332, 7.608915188273775, 7.410531926465482, 7.471988117223264, 7.912616865786797, 8.217873724794813, 8.334566157399937, 8.965991706183758, 8.374734243537858, 7.865086852719313, 7.505556775666931, 7.6387474437014244, 7.716806199663392, 7.906625408063113, 7.826547443851141, 7.820434681185339, 7.773013353043445, 7.6823053487499395, 7.519993049682306, 7.5851127144894175, 7.614007368574815, 7.401337181284581, 7.082590073325164, 7.281449902763494, 7.250632708257208, 7.385190204863756, 7.240351640006626, 7.150823174571431, 7.117459750490899, 7.041011901524979, 7.14131396688932, 7.040747377935942, 7.174369274642432, 6.908463415532484, 6.6952907704209155, 6.617282542212192, 6.6266281413013015, 6.943368413066175, 7.163349689432342, 7.266828983660283, 7.1442882283546005, 7.0175985626424975, 6.809815358106866, 6.605365632966719, 6.504920388842127, 6.512745776581436, 6.303983319282476, 6.254125948376688, 6.175782817733016, 6.16037204764814, 6.098271597670605, 6.08832243719332, 6.018471158799754, 6.101231189198328, 6.1330231739688115, 6.19231266864724, 6.246279638681408, 6.308564195947122, 6.3513573411961834, 6.266276853364549, 6.197741407914564, 6.238450866156455, 6.368893046282213, 6.430820322892907, 6.498839604652489, 6.624221961698937, 6.659843342561861, 6.637067015646085, 6.535000578403961, 6.623806779474312, 6.590760187075854, 6.500062219327069, 6.422939775853884, 6.387339314856706, 6.366051173598784, 6.355988026544827, 6.419292628558981, 6.49201950356815, 6.50392465801976, 6.524574630235263, 6.493391020919108, 6.454417258179848, 6.4990548647396205, 6.568781642350071, 6.541922574705067, 6.525572527643752, 6.488880644922927, 6.429301036653022, 6.491336377003697, 6.410051686437492, 6.386877670991678, 6.451130022294027, 6.4900230342622365, 6.583575285091979, 6.613350518860776, 6.695928100544172, 6.693058399170227, 6.7635071922341, 6.756801560547409, 6.9233881138503035, 7.028897296986769, 7.034633061911196, 7.072288218832691, 7.036653736792537, 6.995942832381456, 7.039851956817035, 7.068983947474231, 7.017336989503497, 7.039576334002366, 7.071885295068534, 7.101866047258819, 7.030640402353992, 7.018517242289501, 6.961387279603931, 6.904072337450267, 6.860666587916588, 6.911942467951614, 6.978101509958631, 6.943476221682916, 6.901647048844326, 6.89616662432893, 6.868015395583732, 6.875920316123156, 6.879673504602549, 6.818817637311625, 6.848408645479323, 6.863628624124874, 6.919489332876996, 6.909812716763575, 6.893682397376075, 6.957150237244219, 6.927788067430783, 6.947660038078617, 6.9345063789830315, 6.925904337515425, 6.945530162533639, 7.025508326656496, 7.045796621342354, 7.046749159370119, 7.1033184840050385, 7.137764112577194, 7.122176226751386, 7.164645096076784, 7.137155050179349, 7.137029786684471, 7.127323659438644, 7.147098782772781, 7.255903835697371, 7.192219844780579, 7.154724802701353, 7.156549204381604, 7.123155292757565, 7.170813111993173, 7.189614670141452, 
7.143749797574801, 7.058235130031492, 7.022120813363837, 6.99003666539505, 6.979449534321139, 6.992211297002591, 6.9917247147550485, 7.073082269118717, 7.078803528705634, 7.01132011250281, 7.012936876118699, 7.0171936746287225, 7.0500918072587595, 7.076757317244631, 7.137000461550471, 7.0766110478211965, 7.026706405516082, 7.003560904943472, 6.995866267849627, 6.948686117218292, 6.95367650272566, 6.936544033785841, 6.898163716157902, 6.835324811841622, 6.861384567558508, 6.8325742679645085, 6.8027088037255385, 6.773611591671094, 6.803478568906453, 6.7482584732214, 6.718813213295617, 6.717270645383044, 6.677946236554242, 6.624944882055482, 6.596728060286893, 6.610178586778914, 6.565547870559994, 6.5610749200197676, 6.568782550774787, 6.507595755006346, 6.523162366986279, 6.476274589730972, 6.50234724514566, 6.482812761022524, 6.457086715183274, 6.50687245388926, 6.52666936435223, 6.468406217880967, 6.423556185376323, 6.440320069166522, 6.407369703127251, 6.386369798099357, 6.375067331802796, 6.324300721558993, 6.297648775115306, 6.304841900595934, 6.301055498884993, 6.27911521066224, 6.2665366911393106, 6.264916457374869, 6.303506970690386, 6.346278254500108, 6.36275580289926, 6.438236538821236, 6.422511028852881, 6.387037414003175, 6.381899200761289, 6.40995421066046, 6.409289066040262, 6.424362829674744, 6.447125550662103, 6.466607426514347, 6.50258140617411, 6.5003353981988345, 6.536285671659909, 6.549376669798613, 6.543884036796224, 6.521718893405075, 6.531438217193836, 6.5149661724684345, 6.498981074689645, 6.502924385606996, 6.480615003930579, 6.482151077655006, 6.475344266090059, 6.4701534912729235, 6.467185679725477, 6.470641347923814, 6.46835200950965, 6.47221676283681, 6.458665941557523, 6.46473014796024, 6.457447861777459, 6.420963570909659 18 | ], 19 | 'LLaDA-8B-Instruct': [ 20 | 21.488220103390674, 21.236578538844455, 20.7813846490413, 17.245732080909402, 12.443621856012372, 9.534163835109135, 8.963033185736979, 8.643642779333355, 8.169917255088496, 8.43842529967077, 9.426302474313154, 9.822586700194892, 9.743389224789144, 9.617303977331211, 9.490795831112784, 9.642376862067236, 9.87497774429867, 9.801472073639422, 9.250886630715607, 9.921821076540564, 10.206679613451383, 10.57621070106594, 10.922901903440252, 11.10293222021639, 10.912459861216538, 10.30596682268114, 10.172234310941882, 10.481844454246044, 10.209915855738293, 10.298246603467328, 10.478661932964982, 10.520886176407842, 10.650407543332788, 10.368751458302249, 10.162382392937458, 10.09867066747591, 10.168638425244206, 10.59139609373048, 10.487894581265193, 10.716540813845993, 10.25023152044307, 9.758503267321059, 9.47549887615146, 9.39723744594226, 9.7044450746539, 9.90654173478695, 9.807576983475562, 9.589998569811197, 9.360991885947877, 9.17281353389875, 9.1052019757222, 9.036065608957399, 8.830348904389545, 8.59552108786613, 8.42794105730301, 8.348287267322183, 8.28376063320556, 8.246212367149157, 8.073288770965886, 7.992107907171555, 7.97704512101963, 8.04085721922584, 8.060075133995781, 8.01039497394845, 8.192883764556996, 8.13500782012844, 8.037685374572467, 7.9657596973477665, 7.924664549574043, 7.926139559932092, 8.18079662876541, 8.162218869863649, 8.30397454793508, 8.388505069546042, 8.48557312008981, 8.551226529418708, 8.519010763004573, 8.332837016844348, 8.256752743258572, 8.24147737097778, 8.263983725381863, 8.271876450596881, 8.2209434408942, 8.151825036103698, 8.10366710150276, 7.9929904073911215, 7.986716898397476, 7.990571372257731, 7.962638879653429, 7.972046523036588, 7.9399001256374975, 
7.950937894311953, 7.9066642406486505, 7.863837457680564, 7.8068080743363595, 7.8231707448524075, 7.784040493700461, 7.761867477407588, 7.784828892578422, 7.837670975086094, 7.893160290800521, 7.961270679308158, 8.012051993664873, 8.17200218397134, 8.149582707028571, 8.155285751103595, 8.298172528325164, 8.395091438218962, 8.482968674263697, 8.572616120943557, 8.651279326158564, 8.621459890401164, 8.657583254502072, 8.689068153911045, 8.725672530432314, 8.749854663040013, 8.776545217398594, 8.803227014693142, 8.743580669100163, 8.76546272515898, 8.67841098228123, 8.639802520139495, 8.569346923227283, 8.515124057181485, 8.500649249207923, 8.465261404425284, 8.45532525806631, 8.43495488767312, 8.464503477062092, 8.448231191749667, 8.407353186221805, 8.34315813619307, 8.319599737809348, 8.34978676493025, 8.315058161681225, 8.384463124360378, 8.388324095024888, 8.485828646855282, 8.547655227381936, 8.539608160846065, 8.601449073698689, 8.588765368671085, 8.619403878270631, 8.66710688535702, 8.696965161139874, 8.705396320138346, 8.712172418299321, 8.726200340428724, 8.72543845804308, 8.729317337382744, 8.686141020359388, 8.692145712086324, 8.726454440780621, 8.790305212217296, 8.911938354622558, 8.876001546871361, 8.765226130637576, 8.777110937199263, 8.710209565291155, 8.718390190753857, 8.838115718233825, 8.818126144948158, 8.713789981418765, 8.665623301855977, 8.649628151422556, 8.605886639691093, 8.736447321892433, 8.714167768521587, 8.772044915201432, 8.74804449961673, 8.647654660339878, 8.616798351041934, 8.615402780461329, 8.633506555247033, 8.639813706938352, 8.671072174015425, 8.609185332617349, 8.508998456667502, 8.466883002795548, 8.45880536119881, 8.397833271830553, 8.417319944983538, 8.414786107418616, 8.37538808539126, 8.287549374739427, 8.26331545765366, 8.257806754166841, 8.255090794493297, 8.204036916789324, 8.18090416152759, 8.099321966331493, 8.05301942306742, 8.032382330991044, 8.009734201834618, 7.928030186587907, 7.909285629435085, 7.892657543658793, 7.85541348987676, 7.880059149864683, 7.854024395068605, 7.77392837852095, 7.780696444866186, 7.718985695581614, 7.738797436322837, 7.686715082903719, 7.67828607673687, 7.67087490256606, 7.689302245192047, 7.615612674257093, 7.566433041741517, 7.587727045121962, 7.531777554186748, 7.543120158726112, 7.546076717251367, 7.505624754574766, 7.471454602073076, 7.462646829558187, 7.469966640546512, 7.446424646306798, 7.3943501801583125, 7.395960112305919, 7.457478079235893, 7.493777501526946, 7.567006956625985, 7.643902806263908, 7.677245565938762, 7.739117507406339, 7.723614891240555, 7.734253634788368, 7.710221021850232, 7.722653291119107, 7.7611980153186275, 7.783167825829458, 7.792976677098177, 7.813486730485384, 7.827982750600113, 7.865485593647512, 7.89346865215495, 7.851089650003551, 7.8964846500955606, 7.907210734806524, 7.923583966523582, 7.951287389555553, 7.920522783769807, 7.907975408662956, 7.874987747992427, 7.883236060805292, 7.895871292057997, 7.869573943964885, 7.845479102744145, 7.86478269719948, 7.860474522820284, 7.871003597030237, 7.833492865319142, 7.788963710713078 21 | ], 22 | } 23 | 24 | niah_dict = { 25 | 'LLaMA3-8B-Base': { 26 | 'Length2000Depth0_origin_en': {'score': 100.0}, 27 | 'Length2000Depth10_origin_en': {'score': 100.0}, 28 | 'Length2000Depth20_origin_en': {'score': 100.0}, 29 | 'Length2000Depth30_origin_en': {'score': 100.0}, 30 | 'Length2000Depth40_origin_en': {'score': 100.0}, 31 | 'Length2000Depth50_origin_en': {'score': 100.0}, 32 | 'Length2000Depth60_origin_en': {'score': 100.0}, 33 | 
'Length2000Depth70_origin_en': {'score': 100.0}, 34 | 'Length2000Depth80_origin_en': {'score': 100.0}, 35 | 'Length2000Depth90_origin_en': {'score': 100.0}, 36 | 'Length2000Depth100_origin_en': {'score': 100.0}, 37 | 'Length4000Depth0_origin_en': {'score': 100.0}, 38 | 'Length4000Depth10_origin_en': {'score': 100.0}, 39 | 'Length4000Depth20_origin_en': {'score': 100.0}, 40 | 'Length4000Depth30_origin_en': {'score': 100.0}, 41 | 'Length4000Depth40_origin_en': {'score': 100.0}, 42 | 'Length4000Depth50_origin_en': {'score': 100.0}, 43 | 'Length4000Depth60_origin_en': {'score': 100.0}, 44 | 'Length4000Depth70_origin_en': {'score': 100.0}, 45 | 'Length4000Depth80_origin_en': {'score': 100.0}, 46 | 'Length4000Depth90_origin_en': {'score': 100.0}, 47 | 'Length4000Depth100_origin_en': {'score': 100.0}, 48 | 'Length8000Depth0_origin_en': {'score': 100.0}, 49 | 'Length8000Depth10_origin_en': {'score': 100.0}, 50 | 'Length8000Depth20_origin_en': {'score': 100.0}, 51 | 'Length8000Depth30_origin_en': {'score': 100.0}, 52 | 'Length8000Depth40_origin_en': {'score': 100.0}, 53 | 'Length8000Depth50_origin_en': {'score': 100.0}, 54 | 'Length8000Depth60_origin_en': {'score': 100.0}, 55 | 'Length8000Depth70_origin_en': {'score': 100.0}, 56 | 'Length8000Depth80_origin_en': {'score': 100.0}, 57 | 'Length8000Depth90_origin_en': {'score': 100.0}, 58 | 'Length8000Depth100_origin_en': {'score': 100.0}, 59 | 'Length16000Depth0_origin_en': {'score': 3.955578202023016}, 60 | 'Length16000Depth10_origin_en': {'score': 3.814358520895154}, 61 | 'Length16000Depth20_origin_en': {'score': 3.389939684805654}, 62 | 'Length16000Depth30_origin_en': {'score': 3.5647477858276977}, 63 | 'Length16000Depth40_origin_en': {'score': 3.41768610989437}, 64 | 'Length16000Depth50_origin_en': {'score': 3.4158382534870824}, 65 | 'Length16000Depth60_origin_en': {'score': 3.2570884072896567}, 66 | 'Length16000Depth70_origin_en': {'score': 3.3172103889231854}, 67 | 'Length16000Depth80_origin_en': {'score': 3.346333696112411}, 68 | 'Length16000Depth90_origin_en': {'score': 3.308409397536356}, 69 | 'Length16000Depth100_origin_en': {'score': 3.530094127304376}, 70 | 'Length24000Depth0_origin_en': {'score': 3.816289812600785}, 71 | 'Length24000Depth10_origin_en': {'score': 3.8801882233454608}, 72 | 'Length24000Depth20_origin_en': {'score': 3.712333058962397}, 73 | 'Length24000Depth30_origin_en': {'score': 3.57047673840194}, 74 | 'Length24000Depth40_origin_en': {'score': 3.753657209745467}, 75 | 'Length24000Depth50_origin_en': {'score': 3.777985348144155}, 76 | 'Length24000Depth60_origin_en': {'score': 3.676910191650002}, 77 | 'Length24000Depth70_origin_en': {'score': 3.6506695250024364}, 78 | 'Length24000Depth80_origin_en': {'score': 3.628475087255798}, 79 | 'Length24000Depth90_origin_en': {'score': 3.6317158379227665}, 80 | 'Length24000Depth100_origin_en': {'score': 3.6741107459202}, 81 | 'Length32000Depth0_origin_en': {'score': 2.9442774159743488}, 82 | 'Length32000Depth10_origin_en': {'score': 3.0268215785369144}, 83 | 'Length32000Depth20_origin_en': {'score': 2.978469304854226}, 84 | 'Length32000Depth30_origin_en': {'score': 3.2533993732580013}, 85 | 'Length32000Depth40_origin_en': {'score': 2.8608639769098962}, 86 | 'Length32000Depth50_origin_en': {'score': 3.171727716002395}, 87 | 'Length32000Depth60_origin_en': {'score': 3.0616042268478187}, 88 | 'Length32000Depth70_origin_en': {'score': 3.075137387226193}, 89 | 'Length32000Depth80_origin_en': {'score': 3.1226106428354994}, 90 | 'Length32000Depth90_origin_en': {'score': 3.0329053802681587}, 
91 | 'Length32000Depth100_origin_en': {'score': 3.110238195812485} 92 | }, 93 | 'LLaMA3-8B-Instruct': { 94 | 'Length2000Depth0_origin_en': {'score': 90.37837837837837}, 95 | 'Length2000Depth10_origin_en': {'score': 100.0}, 96 | 'Length2000Depth20_origin_en': {'score': 100.0}, 97 | 'Length2000Depth30_origin_en': {'score': 81.28487365870544}, 98 | 'Length2000Depth40_origin_en': {'score': 90.33333333333334}, 99 | 'Length2000Depth50_origin_en': {'score': 90.33333333333334}, 100 | 'Length2000Depth60_origin_en': {'score': 90.33333333333334}, 101 | 'Length2000Depth70_origin_en': {'score': 81.23071458172993}, 102 | 'Length2000Depth80_origin_en': {'score': 80.71171171171171}, 103 | 'Length2000Depth90_origin_en': {'score': 80.71171171171171}, 104 | 'Length2000Depth100_origin_en': {'score': 90.37837837837837}, 105 | 'Length4000Depth0_origin_en': {'score': 90.3731343283582}, 106 | 'Length4000Depth10_origin_en': {'score': 100.0}, 107 | 'Length4000Depth20_origin_en': {'score': 100.0}, 108 | 'Length4000Depth30_origin_en': {'score': 90.66165413533835}, 109 | 'Length4000Depth40_origin_en': {'score': 81.22881831444282}, 110 | 'Length4000Depth50_origin_en': {'score': 100.0}, 111 | 'Length4000Depth60_origin_en': {'score': 90.66165413533835}, 112 | 'Length4000Depth70_origin_en': {'score': 81.29580047680176}, 113 | 'Length4000Depth80_origin_en': {'score': 81.19788601939631}, 114 | 'Length4000Depth90_origin_en': {'score': 81.01370838525904}, 115 | 'Length4000Depth100_origin_en': {'score': 80.75197583689906}, 116 | 'Length8000Depth0_origin_en': {'score': 90.42735042735043}, 117 | 'Length8000Depth10_origin_en': {'score': 90.38571428571429}, 118 | 'Length8000Depth20_origin_en': {'score': 90.55223880597015}, 119 | 'Length8000Depth30_origin_en': {'score': 90.36496350364965}, 120 | 'Length8000Depth40_origin_en': {'score': 90.36496350364965}, 121 | 'Length8000Depth50_origin_en': {'score': 100.0}, 122 | 'Length8000Depth60_origin_en': {'score': 90.35555555555555}, 123 | 'Length8000Depth70_origin_en': {'score': 90.35555555555555}, 124 | 'Length8000Depth80_origin_en': {'score': 100.0}, 125 | 'Length8000Depth90_origin_en': {'score': 100.0}, 126 | 'Length8000Depth100_origin_en': {'score': 90.38571428571429}, 127 | 'Length16000Depth0_origin_en': {'score': 3.5705017639222305}, 128 | 'Length16000Depth10_origin_en': {'score': 3.7718933009296256}, 129 | 'Length16000Depth20_origin_en': {'score': 3.3968781977208073}, 130 | 'Length16000Depth30_origin_en': {'score': 3.4606092023781527}, 131 | 'Length16000Depth40_origin_en': {'score': 3.558577740759143}, 132 | 'Length16000Depth50_origin_en': {'score': 3.5311288307711415}, 133 | 'Length16000Depth60_origin_en': {'score': 3.4329180031799007}, 134 | 'Length16000Depth70_origin_en': {'score': 3.5914161088478176}, 135 | 'Length16000Depth80_origin_en': {'score': 3.547603317865215}, 136 | 'Length16000Depth90_origin_en': {'score': 3.3606714996833973}, 137 | 'Length16000Depth100_origin_en': {'score': 3.5553519245884972}, 138 | 'Length24000Depth0_origin_en': {'score': 3.3145166403207673}, 139 | 'Length24000Depth10_origin_en': {'score': 3.3208293172464662}, 140 | 'Length24000Depth20_origin_en': {'score': 3.7010730844738133}, 141 | 'Length24000Depth30_origin_en': {'score': 3.689621886548649}, 142 | 'Length24000Depth40_origin_en': {'score': 3.573080835316987}, 143 | 'Length24000Depth50_origin_en': {'score': 3.3362615900063135}, 144 | 'Length24000Depth60_origin_en': {'score': 3.288194557760524}, 145 | 'Length24000Depth70_origin_en': {'score': 3.288194557760524}, 146 | 'Length24000Depth80_origin_en': 
{'score': 3.288194557760524}, 147 | 'Length24000Depth90_origin_en': {'score': 3.3981945577605237}, 148 | 'Length24000Depth100_origin_en': {'score': 3.235488391347652}, 149 | 'Length32000Depth0_origin_en': {'score': 2.4731366996818798}, 150 | 'Length32000Depth10_origin_en': {'score': 1.9208024499246477}, 151 | 'Length32000Depth20_origin_en': {'score': 2.1175531667194476}, 152 | 'Length32000Depth30_origin_en': {'score': 2.1788650422778058}, 153 | 'Length32000Depth40_origin_en': {'score': 1.9651122448567935}, 154 | 'Length32000Depth50_origin_en': {'score': 1.9957320790193223}, 155 | 'Length32000Depth60_origin_en': {'score': 2.261448785836582}, 156 | 'Length32000Depth70_origin_en': {'score': 2.165767789683888}, 157 | 'Length32000Depth80_origin_en': {'score': 2.128031940627284}, 158 | 'Length32000Depth90_origin_en': {'score': 2.128031940627284}, 159 | 'Length32000Depth100_origin_en': {'score': 2.126351313530453} 160 | }, 161 | 'LLaDA-8B-Base': { 162 | 'Length2000Depth0_origin_en': {'score': 100.0}, 163 | 'Length2000Depth10_origin_en': {'score': 100.0}, 164 | 'Length2000Depth20_origin_en': {'score': 100.0}, 165 | 'Length2000Depth30_origin_en': {'score': 100.0}, 166 | 'Length2000Depth40_origin_en': {'score': 100.0}, 167 | 'Length2000Depth50_origin_en': {'score': 100.0}, 168 | 'Length2000Depth60_origin_en': {'score': 100.0}, 169 | 'Length2000Depth70_origin_en': {'score': 100.0}, 170 | 'Length2000Depth80_origin_en': {'score': 100.0}, 171 | 'Length2000Depth90_origin_en': {'score': 100.0}, 172 | 'Length2000Depth100_origin_en': {'score': 100.0}, 173 | 'Length4000Depth0_origin_en': {'score': 100.0}, 174 | 'Length4000Depth10_origin_en': {'score': 100.0}, 175 | 'Length4000Depth20_origin_en': {'score': 100.0}, 176 | 'Length4000Depth30_origin_en': {'score': 100.0}, 177 | 'Length4000Depth40_origin_en': {'score': 100.0}, 178 | 'Length4000Depth50_origin_en': {'score': 100.0}, 179 | 'Length4000Depth60_origin_en': {'score': 100.0}, 180 | 'Length4000Depth70_origin_en': {'score': 100.0}, 181 | 'Length4000Depth80_origin_en': {'score': 100.0}, 182 | 'Length4000Depth90_origin_en': {'score': 100.0}, 183 | 'Length4000Depth100_origin_en': {'score': 100.0}, 184 | 'Length8000Depth0_origin_en': {'score': 4.700679394220694}, 185 | 'Length8000Depth10_origin_en': {'score': 4.29118535359655}, 186 | 'Length8000Depth20_origin_en': {'score': 13.872474716319505}, 187 | 'Length8000Depth30_origin_en': {'score': 4.644095850571923}, 188 | 'Length8000Depth40_origin_en': {'score': 61.48321967260274}, 189 | 'Length8000Depth50_origin_en': {'score': 100.0}, 190 | 'Length8000Depth60_origin_en': {'score': 100.0}, 191 | 'Length8000Depth70_origin_en': {'score': 100.0}, 192 | 'Length8000Depth80_origin_en': {'score': 100.0}, 193 | 'Length8000Depth90_origin_en': {'score': 100.0}, 194 | 'Length8000Depth100_origin_en': {'score': 100.0}, 195 | 'Length16000Depth0_origin_en': {'score': 5.661801557895577}, 196 | 'Length16000Depth10_origin_en': {'score': 5.622672736186065}, 197 | 'Length16000Depth20_origin_en': {'score': 5.843090061328022}, 198 | 'Length16000Depth30_origin_en': {'score': 5.617648734318405}, 199 | 'Length16000Depth40_origin_en': {'score': 5.931713500948121}, 200 | 'Length16000Depth50_origin_en': {'score': 6.001014262081467}, 201 | 'Length16000Depth60_origin_en': {'score': 5.497573678292468}, 202 | 'Length16000Depth70_origin_en': {'score': 5.667443147821325}, 203 | 'Length16000Depth80_origin_en': {'score': 100.0}, 204 | 'Length16000Depth90_origin_en': {'score': 100.0}, 205 | 'Length16000Depth100_origin_en': {'score': 100.0}, 206 | 
'Length24000Depth0_origin_en': {'score': 4.963310490191196}, 207 | 'Length24000Depth10_origin_en': {'score': 5.054142095701048}, 208 | 'Length24000Depth20_origin_en': {'score': 4.537515059657501}, 209 | 'Length24000Depth30_origin_en': {'score': 4.801166503733159}, 210 | 'Length24000Depth40_origin_en': {'score': 5.018685960822426}, 211 | 'Length24000Depth50_origin_en': {'score': 4.935725731307151}, 212 | 'Length24000Depth60_origin_en': {'score': 4.910711291643793}, 213 | 'Length24000Depth70_origin_en': {'score': 4.997351987613704}, 214 | 'Length24000Depth80_origin_en': {'score': 4.6526419696785215}, 215 | 'Length24000Depth90_origin_en': {'score': 90.29457364341086}, 216 | 'Length24000Depth100_origin_en': {'score': 100.0}, 217 | 'Length32000Depth0_origin_en': {'score': 3.870724509538114}, 218 | 'Length32000Depth10_origin_en': {'score': 4.176837927059076}, 219 | 'Length32000Depth20_origin_en': {'score': 3.979273276540834}, 220 | 'Length32000Depth30_origin_en': {'score': 3.7557044844658583}, 221 | 'Length32000Depth40_origin_en': {'score': 3.9167490293613247}, 222 | 'Length32000Depth50_origin_en': {'score': 3.924985543776618}, 223 | 'Length32000Depth60_origin_en': {'score': 4.101317546440129}, 224 | 'Length32000Depth70_origin_en': {'score': 4.122466367547554}, 225 | 'Length32000Depth80_origin_en': {'score': 4.107608129796786}, 226 | 'Length32000Depth90_origin_en': {'score': 3.892856439917338}, 227 | 'Length32000Depth100_origin_en': {'score': 71.57156166814552} 228 | }, 229 | 'LLaDA-8B-Instruct': { 230 | 'Length2000Depth0_origin_en': {'score': 100.0}, 231 | 'Length2000Depth10_origin_en': {'score': 100.0}, 232 | 'Length2000Depth20_origin_en': {'score': 100.0}, 233 | 'Length2000Depth30_origin_en': {'score': 100.0}, 234 | 'Length2000Depth40_origin_en': {'score': 100.0}, 235 | 'Length2000Depth50_origin_en': {'score': 100.0}, 236 | 'Length2000Depth60_origin_en': {'score': 90.35374149659864}, 237 | 'Length2000Depth70_origin_en': {'score': 100.0}, 238 | 'Length2000Depth80_origin_en': {'score': 100.0}, 239 | 'Length2000Depth90_origin_en': {'score': 90.42424242424242}, 240 | 'Length2000Depth100_origin_en': {'score': 100.0}, 241 | 'Length4000Depth0_origin_en': {'score': 100.0}, 242 | 'Length4000Depth10_origin_en': {'score': 100.0}, 243 | 'Length4000Depth20_origin_en': {'score': 100.0}, 244 | 'Length4000Depth30_origin_en': {'score': 100.0}, 245 | 'Length4000Depth40_origin_en': {'score': 100.0}, 246 | 'Length4000Depth50_origin_en': {'score': 100.0}, 247 | 'Length4000Depth60_origin_en': {'score': 100.0}, 248 | 'Length4000Depth70_origin_en': {'score': 100.0}, 249 | 'Length4000Depth80_origin_en': {'score': 100.0}, 250 | 'Length4000Depth90_origin_en': {'score': 90.4360902255639}, 251 | 'Length4000Depth100_origin_en': {'score': 100.0}, 252 | 'Length8000Depth0_origin_en': {'score': 13.759822654411257}, 253 | 'Length8000Depth10_origin_en': {'score': 13.639063689557233}, 254 | 'Length8000Depth20_origin_en': {'score': 23.115124407066492}, 255 | 'Length8000Depth30_origin_en': {'score': 14.140713069136655}, 256 | 'Length8000Depth40_origin_en': {'score': 80.77278208441}, 257 | 'Length8000Depth50_origin_en': {'score': 100.0}, 258 | 'Length8000Depth60_origin_en': {'score': 100.0}, 259 | 'Length8000Depth70_origin_en': {'score': 100.0}, 260 | 'Length8000Depth80_origin_en': {'score': 100.0}, 261 | 'Length8000Depth90_origin_en': {'score': 100.0}, 262 | 'Length8000Depth100_origin_en': {'score': 100.0}, 263 | 'Length16000Depth0_origin_en': {'score': 4.221726547852791}, 264 | 'Length16000Depth10_origin_en': {'score': 
4.14880322926385}, 265 | 'Length16000Depth20_origin_en': {'score': 3.8003142535463454}, 266 | 'Length16000Depth30_origin_en': {'score': 4.107636678658605}, 267 | 'Length16000Depth40_origin_en': {'score': 4.173052537518885}, 268 | 'Length16000Depth50_origin_en': {'score': 4.130942440403301}, 269 | 'Length16000Depth60_origin_en': {'score': 4.116522447748491}, 270 | 'Length16000Depth70_origin_en': {'score': 3.739108642917869}, 271 | 'Length16000Depth80_origin_en': {'score': 80.76521739130435}, 272 | 'Length16000Depth90_origin_en': {'score': 90.40298507462687}, 273 | 'Length16000Depth100_origin_en': {'score': 100.0}, 274 | 'Length24000Depth0_origin_en': {'score': 4.351837221215318}, 275 | 'Length24000Depth10_origin_en': {'score': 4.087575138851376}, 276 | 'Length24000Depth20_origin_en': {'score': 4.033899595277184}, 277 | 'Length24000Depth30_origin_en': {'score': 4.13096601682714}, 278 | 'Length24000Depth40_origin_en': {'score': 4.342704335725307}, 279 | 'Length24000Depth50_origin_en': {'score': 4.311520333815464}, 280 | 'Length24000Depth60_origin_en': {'score': 4.111736986686991}, 281 | 'Length24000Depth70_origin_en': {'score': 4.129820183863647}, 282 | 'Length24000Depth80_origin_en': {'score': 4.127310925565743}, 283 | 'Length24000Depth90_origin_en': {'score': 62.041159012808485}, 284 | 'Length24000Depth100_origin_en': {'score': 90.37837837837837}, 285 | 'Length32000Depth0_origin_en': {'score': 3.7647783839273528}, 286 | 'Length32000Depth10_origin_en': {'score': 3.3173461589394244}, 287 | 'Length32000Depth20_origin_en': {'score': 3.5726719604550916}, 288 | 'Length32000Depth30_origin_en': {'score': 3.5103745714846353}, 289 | 'Length32000Depth40_origin_en': {'score': 3.578554782205976}, 290 | 'Length32000Depth50_origin_en': {'score': 3.6529694251862743}, 291 | 'Length32000Depth60_origin_en': {'score': 3.521140319687622}, 292 | 'Length32000Depth70_origin_en': {'score': 3.5572595597516354}, 293 | 'Length32000Depth80_origin_en': {'score': 3.4710161880112507}, 294 | 'Length32000Depth90_origin_en': {'score': 3.524176120032192}, 295 | 'Length32000Depth100_origin_en': {'score': 13.479082253627729} 296 | }, 297 | } 298 | 299 | color_dict = { 300 | 'LLaMA3-8B-Base': '#1F77B4', 301 | 'LLaMA3-8B-Instruct': '#2BA02B', 302 | 'LLaDA-8B-Base': '#D62727', 303 | 'LLaDA-8B-Instruct': '#FF7F0F', 304 | } 305 | 306 | fig = plt.figure(figsize=(6, 3.5), dpi=200) 307 | ax = fig.add_subplot(1, 1, 1) 308 | 309 | for model in ppl_dict: 310 | 311 | ppl_line = ppl_dict[model] 312 | num_sample = len(ppl_line) 313 | 314 | ppl_line = np.array(ppl_line[::2]) 315 | num_sample = (np.arange(len(ppl_line))*2 + 1) * 64 316 | l, = ax.plot(num_sample, ppl_line, lw=2, ls='-', label=model, color=color_dict[model]) 317 | 318 | ax.set_ylim((ax.get_ylim()[0], ax.get_ylim()[1] - 30)) 319 | 320 | ax.set_xticks([0, 4000, 8000, 12000, 16000], [0, 4000, 8000, 12000, 16000]) 321 | ax.set_xticklabels(ax.get_xticklabels(), fontsize=12) 322 | ax.set_yticklabels(ax.get_yticklabels(), fontsize=12) 323 | ax.set_xlabel('Context Length', fontsize=14, labelpad=6) 324 | ax.set_ylabel('Perplexity', fontsize=14) 325 | 326 | plt.legend(fontsize=12, loc='upper left') 327 | plt.subplots_adjust(left=0.12, bottom=0.15, right=0.92, top=0.96, hspace=0.25, wspace=0.2) 328 | plt.savefig('diffusion_ppl_gov.png') 329 | 330 | 331 | fig = plt.figure(figsize=(6, 3.5), dpi=200) 332 | ax = fig.add_subplot(1, 1, 1) 333 | 334 | depths = (np.arange(11) * 10).tolist() 335 | lengths = [2000, 4000, 8000, 16000, 24000, 32000] 336 | 337 | for model in niah_dict: 338 | 339 | 
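# For each model, average the NIAH score over all eleven depths (0-100%)
# at every context length, yielding one retrieval-accuracy point per
# length for the plotted curve.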
niah_model_dict = niah_dict[model] 340 | niah_line = [np.mean([niah_model_dict[f'Length{length}Depth{depth}_origin_en']['score'] for depth in depths]) for length in lengths] 341 | l, = ax.plot(niah_line, lw=2, ls='-', label=model, color=color_dict[model], # 342 | marker='o', markerfacecolor='white', markeredgewidth=2, markeredgecolor=color_dict[model]) 343 | 344 | ax.set_xticks(range(len(lengths)), lengths) 345 | 346 | ax.set_xticklabels(ax.get_xticklabels(), fontsize=12) 347 | ax.set_yticklabels(ax.get_yticklabels(), fontsize=12) 348 | 349 | ax.set_xlabel('Context Length', fontsize=14, labelpad=6) 350 | ax.set_ylabel('Retrieval Accuracy', fontsize=14) 351 | 352 | plt.legend(fontsize=12, loc='lower left') 353 | plt.subplots_adjust(left=0.12, bottom=0.15, right=0.92, top=0.96, hspace=0.25, wspace=0.2) 354 | plt.savefig('diffusion_niah.png') -------------------------------------------------------------------------------- /needlebench/needlebench_summarizer.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import getpass 3 | import math 4 | import os 5 | import os.path as osp 6 | import shutil 7 | from datetime import datetime 8 | from typing import Any, Dict, List, Optional 9 | 10 | import matplotlib.pyplot as plt 11 | # plt.rcParams["font.family"] = "Times New Roman" 12 | 13 | import mmengine 14 | import numpy as np 15 | import pandas as pd 16 | import seaborn as sns 17 | import tabulate 18 | from matplotlib.colors import LinearSegmentedColormap 19 | from mmengine import ConfigDict 20 | from tqdm import tqdm 21 | 22 | from opencompass.summarizers.default import ( 23 | METRIC_BLACKLIST, METRIC_WHITELIST, DefaultSummarizer, 24 | model_abbr_from_cfg_used_in_summarizer) 25 | from opencompass.utils import (LarkReporter, dataset_abbr_from_cfg, 26 | get_infer_output_path, get_logger, 27 | model_abbr_from_cfg) 28 | from opencompass.utils.prompt import get_prompt_hash 29 | 30 | model_name_mapping = { 31 | 'llama-2-7b-chat-hf': 'LLaMA-2-7B', 32 | 'llama-2-13b-chat-hf': 'LLaMA-2-13B', 33 | 'llama-2-70b-chat-hf': 'LLaMA-2-70B', 34 | 'baichuan2-7b-chat-hf': 'Baichuan2-7B', 35 | 'baichuan2-13b-chat-hf': 'Baichuan2-13B', 36 | 'yi-6b-chat-hf': 'Yi-6B', 37 | 'yi-34b-chat-hf': 'Yi-34B', 38 | 'deepseek-67b-chat-hf': 'DeepSeek-67B', 39 | 'wizardlm-70b-v1.0-vllm': 'WizardLM-70B', 40 | 'qwen-14b-chat-hf': 'Qwen-14B', 41 | 'qwen-72b-chat-hf': 'Qwen-72B', 42 | 'qwen-72b-chat-vllm': 'Qwen-72B-vLLM', 43 | 'internlm2-chat-7b-turbomind': 'InternLM2-7B-200K', 44 | 'internlm2-chat-20b-turbomind': 'InternLM2-20B-200K', 45 | 'internlm2-chat-7b-hf': 'InternLM2-7B', 46 | 'internlm2-chat-20b-hf': 'InternLM2-20B', 47 | 'qwen-7b-chat-hf': 'Qwen-7B', 48 | 'chatglm3-6b-hf': 'ChatGLM3-6B', 49 | 'chatglm3-6b-32k-hf': 'ChatGLM3-6B-32K', 50 | 'zephyr-7b-beta-vllm': 'Zephyr-7B Beta', 51 | 'mistral-7b-instruct-v0.2-vllm': 'Mistral-7B Inst. v0.2', 52 | 'mistral-7b-instruct-v0.1-vllm': 'Mistral-7B Inst. v0.1', 53 | 'mixtral-8x7b-instruct-v0.1-vllm': 'Mixtral-8x7B Inst. 
v0.1', 54 | 'orionstar-yi-34b-chat-hf': 'OrionStar-Yi-34B', 55 | 'orionstar-14b-long-chat-vllm': 'Orion-14B-LongChat', 56 | 'internlm-chat-7b-hf': 'InternLM-7B', 57 | 'gemma-2b-it-hf': 'Gemma-2B', 58 | 'gemma-7b-it-hf': 'Gemma-7B', 59 | 'qwen1.5-0.5b-chat-hf': 'Qwen-1.5-0.5B', 60 | 'qwen1.5-1.8b-chat-hf': 'Qwen-1.5-1.8B', 61 | 'qwen1.5-4b-chat-hf': 'Qwen-1.5-4B', 62 | 'qwen1.5-14b-chat-hf': 'Qwen-1.5-14B', 63 | 'qwen1.5-72b-chat-hf': 'Qwen-1.5-72B', 64 | 'qwen1.5-14b-chat-vllm': 'Qwen-1.5-14B-vLLM', 65 | 'qwen1.5-72b-chat-vllm': 'Qwen-1.5-72B-vLLM', 66 | 'glm4_notools': 'GLM-4', 67 | 'claude-3-opus': 'Claude-3-Opus', 68 | 'glm-4-9b-chat-1m-vllm': 'GLM4-9B-Chat-1M', 69 | 'internlm2_5-7b-chat-1m-turbomind': 'InternLM2.5-7B-Chat-1M', 70 | # Add more mappings as necessary 71 | } 72 | 73 | dataset_mapping_dict = {} 74 | 75 | # needle_counts = ['2', '3', '4', '5'] 76 | languages = ['en', 'zh'] 77 | # sizes = ['4k', '8k', '32k', '128k', '200k', '256k', '1000k', '1M'] 78 | types = ['origin', 'parallel'] 79 | 80 | # for needle_count in needle_counts: 81 | # for language in languages: 82 | # for size in sizes: 83 | # key = f'{needle_count}needle_{language}_{size}' 84 | # value = f'{needle_count}-Needle-Reasoning-{language.upper()}-{size.upper()}' 85 | # dataset_mapping_dict[key] = value 86 | for t in types: 87 | for language in languages: 88 | # for size in sizes: 89 | if t == 'origin': 90 | key = f'{t}_{language}' # _{size} 91 | value = f'Single-Needle-Retrieval-{language.upper()}' # -{size.upper()}' 92 | elif t == 'parallel': 93 | key = f'{t}_{language}' # _{size} 94 | value = f'Multi-Needle-Retrieval-{language.upper()}' # -{size.upper()}' 95 | dataset_mapping_dict[key] = value 96 | 97 | 98 | def calculate_elementwise_average(model_name, merged_df): 99 | score_columns = [col for col in merged_df.columns if col != 'dataset'] 100 | 101 | origin_columns = [col for col in score_columns if 'origin' in col] 102 | parallel_columns = [col for col in score_columns if 'parallel' in col] 103 | multi_columns = [col for col in score_columns if 'needle' in col] 104 | 105 | if origin_columns and parallel_columns and multi_columns: 106 | origin_avg = merged_df[origin_columns].mean(axis=1) * 0.4 107 | parallel_avg = merged_df[parallel_columns].mean(axis=1) * 0.3 108 | multi_avg = merged_df[multi_columns].mean(axis=1) * 0.3 109 | merged_df[model_name] = origin_avg + parallel_avg + multi_avg 110 | else: 111 | relevant_columns = origin_columns or parallel_columns or multi_columns 112 | if relevant_columns: 113 | merged_df[model_name] = merged_df[relevant_columns].mean(axis=1) 114 | else: 115 | merged_df[model_name] = pd.Series([0] * len(merged_df)) 116 | 117 | return merged_df.iloc[:, [0, -1]] 118 | 119 | def read_after_specific_line_except_last(file_name, keyword, offset): 120 | with open(file_name, 'r', encoding='utf-8') as file: 121 | lines = file.readlines() 122 | 123 | for index, line in enumerate(lines): 124 | if keyword in line: 125 | start_index = index + offset + 1 126 | break 127 | else: 128 | return '' 129 | 130 | return ''.join(lines[start_index:-1]) 131 | 132 | def create_model_dataframe(nested_dict, model_name, dataset_abbr, parallel=False): 133 | if model_name not in nested_dict: 134 | print(f'Model {model_name} not found in the provided data.') 135 | return pd.DataFrame() 136 | 137 | model_data = nested_dict[model_name] 138 | data = [] 139 | 140 | for key, value in model_data.items(): 141 | if parallel: 142 | if dataset_abbr in key: 143 | new_key_base = key.replace(dataset_abbr, '').strip('_') 144 | 
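# Parallel (multi-needle) entries hold one sub-score per depth, so each
# (length, depth) pair is flattened into its own row; aggregated
# 'average_score' entries are skipped.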
for depth_key, score in value.items(): 145 | new_key = f'{new_key_base}{depth_key}' 146 | if 'average_score' not in new_key: 147 | data.append([new_key, score]) 148 | else: 149 | if dataset_abbr in key: 150 | score = value.get('score', None) 151 | new_key = key.replace(dataset_abbr, '').strip('_') 152 | data.append([new_key, score]) 153 | 154 | df = pd.DataFrame(data, columns=['dataset', model_name]) 155 | return df 156 | 157 | def convert_to_k(value): 158 | try: 159 | return f'{int(value) // 1000}k' 160 | except ValueError: 161 | return value 162 | 163 | def parse_model_scores(text): 164 | lines = text.split('\n') 165 | 166 | result_dict = {} 167 | current_model = None 168 | 169 | for line in lines: 170 | if line.startswith('Model:'): 171 | current_model = line.split('Model:')[1].strip() 172 | result_dict[current_model] = {} 173 | elif current_model and ':' in line: 174 | dataset, score_str = line.split(':', 1) 175 | score_dict = eval(score_str.strip()) 176 | result_dict[current_model][dataset] = score_dict 177 | 178 | return result_dict 179 | 180 | def remove_empty_subfolders(plot_path): 181 | for folder_name in tqdm(os.listdir(plot_path), 182 | desc='Deleting Empty folders'): 183 | folder_path = os.path.join(plot_path, folder_name) 184 | if os.path.isdir(folder_path): 185 | if not os.listdir(folder_path): 186 | shutil.rmtree(folder_path) 187 | 188 | def save_results_to_plots(txt_results_save_path): 189 | content = read_after_specific_line_except_last(txt_results_save_path, 'raw format', 2) 190 | parsed_data = parse_model_scores(content) 191 | model_names = get_dict_model_names(parsed_data) 192 | numbers = [2, 3, 4, 5] 193 | languages = ['en', 'zh'] 194 | size_exists = [] 195 | sizes_origin = [''] # , '_4k', '_8k', '_32k', '_128k', '_200k', '_256k', '_1000k', '_1M'] 196 | 197 | for size in sizes_origin: 198 | if size in content: 199 | size_exists.append(size) 200 | 201 | multi_dataset_abbrs = [] # [f'{num}needle_{lang}{size}' for num in numbers for lang in languages for size in size_exists] 202 | origin_dataset_abbrs = [f'origin_{lang}{size}' for lang in languages for size in size_exists] 203 | parallel_dataset_abbrs = [f'parallel_{lang}{size}' for lang in languages for size in size_exists] 204 | 205 | dataset_abbrs = multi_dataset_abbrs + origin_dataset_abbrs + \ 206 | parallel_dataset_abbrs 207 | base_path = os.path.dirname(txt_results_save_path) 208 | plot_path = os.path.join(base_path, 'plots') 209 | 210 | model_scores = {} 211 | 212 | for model_name in tqdm(model_names): 213 | model_datasets_scores = {} # Dictionary to store scores for each dataset for the current model 214 | for dataset_abbr in dataset_abbrs: 215 | parallel_flag = 'parallel' in dataset_abbr 216 | 217 | folder_path = os.path.join(plot_path, dataset_mapping_dict[dataset_abbr]) 218 | ensure_directory(folder_path) 219 | 220 | save_path = os.path.join(folder_path, f'{model_name}.png') 221 | 222 | df = create_model_dataframe(parsed_data, model_name, dataset_abbr, parallel=parallel_flag) 223 | 224 | score = visualize(df, save_path, model_name, dataset_abbr) 225 | 226 | model_datasets_scores[dataset_abbr] = '{:.02f}'.format(score) 227 | 228 | # overall_dataset_abbrs = multi_dataset_abbrs + origin_dataset_abbrs + parallel_dataset_abbrs 229 | # overall_score_pic_path = os.path.join(plot_path, f'{model_name}_overall.png') 230 | # merged_df = merge_dataframes(model_name, overall_dataset_abbrs, parsed_data) 231 | # averaged_df = calculate_elementwise_average(model_name, merged_df) 232 | # overall_score = visualize(averaged_df, 
overall_score_pic_path, model_name, 'Overall Score') 233 | 234 | # Single-Retrieval 235 | single_retrieval_score_pic_path = os.path.join(plot_path, f'{model_name}_single_retrieval_overall.png') 236 | single_retrieval_merged_df = merge_dataframes(model_name, origin_dataset_abbrs, parsed_data) 237 | single_retrieval_averaged_df = calculate_elementwise_average(model_name, single_retrieval_merged_df) 238 | single_retrieval_overall_score = visualize(single_retrieval_averaged_df, single_retrieval_score_pic_path, model_name, 'Single-Retrieval Overall Score') 239 | 240 | # Multi-Retrieval 241 | multi_retrieval_score_pic_path = os.path.join(plot_path, f'{model_name}_multi_retrieval_overall.png') 242 | multi_retrieval_merged_df = merge_dataframes(model_name, parallel_dataset_abbrs, parsed_data) 243 | multi_retrieval_averaged_df = calculate_elementwise_average(model_name, multi_retrieval_merged_df) 244 | multi_retrieval_overall_score = visualize(multi_retrieval_averaged_df, multi_retrieval_score_pic_path, model_name, 'Multi-Retrieval Overall Score') 245 | 246 | # # Multi-Reasoning 247 | # multi_reasoning_score_pic_path = os.path.join(plot_path, f'{model_name}_multi_reasoning_overall.png') 248 | # multi_reasoning_merged_df = merge_dataframes(model_name, multi_dataset_abbrs, parsed_data) 249 | # multi_reasoning_averaged_df = calculate_elementwise_average(model_name, multi_reasoning_merged_df) 250 | # multi_reasoning_overall_score = visualize(multi_reasoning_averaged_df, multi_reasoning_score_pic_path, model_name, 'Multi-Reasoning Overall Score') 251 | 252 | # model_scores[model_name] = averaged_df 253 | remove_empty_subfolders(plot_path) 254 | return model_scores 255 | 256 | def visualize(df_raw, save_path: str,model_name: str ,dataset_type:str): 257 | df = df_raw.copy() 258 | if df.empty: 259 | return -1 260 | df['Context Length'] = df['dataset'].apply( 261 | lambda x: int(x.split('Length')[1].split('Depth')[0])) 262 | df['Document Depth'] = df['dataset'].apply( 263 | lambda x: float(x.split('Depth')[1].split('_')[0])) 264 | 265 | model_columns = [ 266 | col for col in df.columns 267 | if col not in ['Context Length', 'Document Depth'] 268 | ] 269 | 270 | for model_name in model_columns[1:]: 271 | model_df = df[['Document Depth', 'Context Length', 272 | model_name]].copy() 273 | model_df.rename(columns={model_name: 'Score'}, inplace=True) 274 | 275 | # Create pivot table 276 | pivot_table = pd.pivot_table(model_df, 277 | values='Score', 278 | index=['Document Depth'], 279 | columns=['Context Length'], 280 | aggfunc='mean') 281 | 282 | # Calculate mean scores 283 | mean_scores = pivot_table.mean().values 284 | 285 | # Calculate overall score 286 | overall_score = mean_scores.mean() 287 | 288 | # Create heatmap and line plot 289 | plt.figure(figsize=(15.5, 8), dpi=300) 290 | ax = plt.gca() 291 | cmap = LinearSegmentedColormap.from_list( 292 | 'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F']) 293 | 294 | # Draw heatmap 295 | tmp_ax = sns.heatmap(pivot_table, 296 | cmap=cmap, 297 | ax=ax, 298 | cbar_kws={'label': 'Score'}, 299 | vmin=0, 300 | vmax=100) 301 | tmp_ax.figure.axes[-1].yaxis.label.set_size(26) 302 | # print(dir(tmp_ax.figure.axes[-1])) 303 | # print(dir(tmp_ax.figure.axes[-1].yaxis)) 304 | # tmp_ax.figure.axes[-1].yticks.set_size(16) 305 | tmp_ax.figure.axes[-1].set_yticklabels(tmp_ax.figure.axes[-1].get_yticklabels(), 306 | rotation=0, fontsize=24) 307 | 308 | # Set line plot data 309 | x_data = [i + 0.5 for i in range(len(mean_scores))] 310 | y_data = mean_scores 311 | 312 | # Create twin 
axis for line plot 313 | ax2 = ax.twinx() 314 | # Draw line plot 315 | ax2.plot(x_data, 316 | y_data, 317 | color='white', 318 | marker='o', 319 | linestyle='-', 320 | linewidth=2, 321 | markersize=8, 322 | label='Average Depth Score') # 323 | # Set y-axis range 324 | ax2.set_ylim(0, 100) 325 | 326 | # Hide original y-axis ticks and labels 327 | ax2.set_yticklabels([]) 328 | ax2.set_yticks([]) 329 | 330 | # Add legend 331 | ax2.legend(fontsize=24) # loc='lower right', upper left 332 | 333 | # Set chart title and labels 334 | # ax.set_title(f'{model_name} {dataset_type} Context ' 335 | # 'Performance\nFact Retrieval Across ' 336 | # 'Context Lengths ("Needle In A Haystack")', fontsize=18) 337 | ax.set_xlabel('Token Limit', fontsize=26) 338 | ax.set_ylabel('Depth Percent', fontsize=26) 339 | ax.set_xticklabels(pivot_table.columns.values, rotation=0, fontsize=24) 340 | ax.set_yticklabels(pivot_table.index.values, rotation=0, fontsize=24) 341 | # Add overall score as a subtitle 342 | # plt.text(0.5, 343 | # -0.13, f'Overall Score for {model_name}: ' 344 | # f'{overall_score:.2f}', 345 | # ha='center', 346 | # va='center', 347 | # transform=ax.transAxes, 348 | # fontsize=18) 349 | ax.set_title(f'Overall Score: {overall_score:.2f}', fontsize=28, pad=10) 350 | 351 | plt.tight_layout() 352 | plt.subplots_adjust(right=1.04) 353 | plt.draw() 354 | save_path = save_path.split('.') 355 | save_path[-1] = 'pdf' 356 | save_path = '.'.join(save_path) 357 | plt.savefig(save_path) 358 | print(f'Saved: {save_path}') 359 | plt.close() # Close figure to prevent memory leaks 360 | return overall_score 361 | 362 | # for model_name in model_columns[1:]: 363 | # model_df = df[['Document Depth', 'Context Length', 364 | # model_name]].copy() 365 | # model_df.rename(columns={model_name: 'Score'}, inplace=True) 366 | # pivot_table = pd.pivot_table(model_df, 367 | # values='Score', 368 | # index=['Document Depth'], 369 | # columns=['Context Length'], 370 | # aggfunc='mean') 371 | 372 | # mean_scores = pivot_table.mean().values 373 | # overall_score = mean_scores.mean() 374 | # plt.figure(figsize=(10, 6)) 375 | # ax = plt.gca() 376 | # cmap = LinearSegmentedColormap.from_list( 377 | # 'custom_cmap', ['#F0496E', '#EBB839', '#0CD79F']) 378 | 379 | # sns.heatmap(pivot_table, 380 | # cmap=cmap, 381 | # ax=ax, 382 | # vmin=0, 383 | # vmax=100) 384 | # cbar = ax.collections[0].colorbar 385 | # x_data = [i + 0.5 for i in range(len(mean_scores))] 386 | # y_data = mean_scores 387 | 388 | # ax2 = ax.twinx() 389 | # ax2.plot(x_data, 390 | # y_data, 391 | # color='white', 392 | # marker='o', 393 | # linestyle='-', 394 | # linewidth=2, 395 | # markersize=8, 396 | # label='Average Depth Score' 397 | # ) 398 | # for x_value, y_value in zip(x_data, y_data): 399 | # ax2.text(x_value, y_value, f'{y_value:.2f}', ha='center', va='top') 400 | 401 | # ax2.set_ylim(0, 100) 402 | 403 | # ax2.set_yticklabels([]) 404 | # ax2.set_yticks([]) 405 | 406 | # ax2.legend(loc='lower left') 407 | 408 | # if model_name in model_name_mapping: 409 | # title_name = model_name_mapping[model_name] 410 | # else: 411 | # title_name = model_name 412 | 413 | # ax.set_title(title_name, fontsize=12, fontweight='bold', pad=15) 414 | 415 | # if dataset_type in dataset_mapping_dict: 416 | # dataset_name = dataset_mapping_dict[dataset_type] 417 | # else: 418 | # dataset_name = dataset_type 419 | 420 | # ax.text(0.5, 1.005, f'{dataset_name}:{overall_score:.2f}', 421 | # transform=ax.transAxes, 422 | # ha='center', 423 | # fontsize=12, 424 | # fontweight='normal') 425 | # 
ax.set_xlabel('Token Length', fontsize=13, fontweight='normal', labelpad=1) 426 | # ax.set_ylabel('Depth Percent(%)', fontsize=13, fontweight='normal', labelpad=1) 427 | # converted_labels = [convert_to_k(value) for value in pivot_table.columns.values] 428 | 429 | # ax.tick_params(axis='both', which='major', length=1, pad=1) 430 | # ax.tick_params(axis='both', which='minor', length=1, pad=1) 431 | # ax.set_xticklabels(converted_labels, rotation=45) 432 | # index_length = len(pivot_table.index) 433 | 434 | # selected_indices = pivot_table.index.values[::2] 435 | # labels = [str(int(index)) for index in selected_indices] 436 | # ax.set_yticks(np.arange(0, len(pivot_table.index), 2)) 437 | # ax.set_yticklabels(labels, rotation=0) 438 | # for spine in ax.spines.values(): 439 | # spine.set_visible(False) 440 | # for spine in ax2.spines.values(): 441 | # spine.set_visible(False) 442 | 443 | # plt.tight_layout() 444 | # plt.draw() 445 | # directory_path, original_filename = os.path.split(save_path) 446 | 447 | # filename_suffix = (title_name+'_'+dataset_name).replace(' ', '_') 448 | # new_filename = f'{filename_suffix}.png' 449 | 450 | # new_save_path = os.path.join(directory_path, new_filename) 451 | 452 | # plt.savefig(new_save_path, format='png', bbox_inches='tight', pad_inches=0) 453 | # print(f'Saved: {new_save_path}') 454 | 455 | # plt.close() 456 | 457 | # return overall_score 458 | 459 | 460 | def ensure_directory(path): 461 | if not os.path.exists(path): 462 | os.makedirs(path) 463 | 464 | def get_dict_model_names(nested_dict): 465 | model_names = [] 466 | for first_level_key in nested_dict: 467 | model_names.append(first_level_key) 468 | return model_names 469 | 470 | def merge_dataframes(model_name, dataset_abbrs, parsed_data): 471 | dfs = [] 472 | for dataset_abbr in dataset_abbrs: 473 | parallel_flag = 'parallel' in dataset_abbr 474 | df = create_model_dataframe(parsed_data, model_name, dataset_abbr, parallel=parallel_flag) 475 | 476 | if not df.empty and len(df.columns) > 1: 477 | score_column = df.columns[-1] 478 | df.rename(columns={score_column: dataset_abbr}, inplace=True) 479 | 480 | dfs.append(df) 481 | 482 | # print('dfs', dfs) 483 | # print('dataset_abbrs', dataset_abbrs) 484 | # print('parsed_data', parsed_data) 485 | 486 | from functools import reduce 487 | merged_df = reduce(lambda left, right: pd.merge(left, right, on='dataset', how='outer'), dfs) 488 | 489 | if merged_df.isnull().any().any(): 490 | print('Warning: Some rows were filtered out due to NaN values. ' 491 | 'This is often due to mismatched row counts among DataFrames.') 492 | merged_df = merged_df.dropna() 493 | return merged_df 494 | 495 | class NeedleBenchSummarizer(DefaultSummarizer): 496 | """NeedleBench summarizer in OpenCompass. 497 | 498 | Args: 499 | config (ConfigDict): The configuration object of the evaluation task. It's expected to be filled out at runtime. 500 | dataset_abbrs (list[str], optional): Dataset abbreviations to be listed in the summary. 501 | summary_groups (list): The dataset groups whose results need to be averaged out. For example, mmlu. Each item is a dict with 502 | 'name' (str) and 'subsets' (list of dataset abbrs), and optionally 503 | 'weights' if weighted average is needed. 504 | prompt_db: A deprecated field.
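
    Example (hypothetical values, for illustration only):
        summary_groups = [{
            'name': 'needlebench_origin_avg',
            'subsets': ['origin_en', 'origin_zh'],
            'weights': {'origin_en': 0.5, 'origin_zh': 0.5},
        }]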
505 | """ 506 | def _format_table(self, parsed_results, dataset_metrics, dataset_eval_mode): 507 | dataset_abbrs = [dataset_abbr_from_cfg(dataset) for dataset in self.dataset_cfgs] 508 | prompt_version = {dataset_abbr_from_cfg(d): get_prompt_hash(d)[:6] for d in self.dataset_cfgs} 509 | 510 | summarizer_dataset_abbrs = [] 511 | if self.dataset_abbrs is None: 512 | for dataset_abbr in dataset_abbrs: 513 | if dataset_abbr in dataset_metrics: 514 | for metric in dataset_metrics[dataset_abbr]: 515 | summarizer_dataset_abbrs.append((dataset_abbr, metric)) 516 | else: 517 | summarizer_dataset_abbrs.append((dataset_abbr, None)) 518 | for dataset_abbr in dataset_metrics: 519 | for metric in dataset_metrics[dataset_abbr]: 520 | if (dataset_abbr, metric) not in summarizer_dataset_abbrs: 521 | summarizer_dataset_abbrs.append((dataset_abbr, metric)) 522 | else: 523 | for item in self.dataset_abbrs: 524 | if isinstance(item, str): 525 | summarizer_dataset_abbrs.append((item, None)) 526 | elif isinstance(item, (list, tuple)): 527 | summarizer_dataset_abbrs.append((item[0], item[1])) 528 | 529 | table = [] 530 | header = ['dataset', 'version', 'metric', 'mode'] + self.model_abbrs 531 | table.append(header) 532 | 533 | for key in dataset_metrics: 534 | dataset_metrics[key] = list(set(dataset_metrics[key])) 535 | 536 | for dataset_abbr, metric in summarizer_dataset_abbrs: 537 | if dataset_abbr not in dataset_metrics: 538 | 539 | table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs)) 540 | table.append(header) 541 | continue 542 | if len(dataset_metrics[dataset_abbr]) >= 10: 543 | metric = 'average_score' 544 | 545 | if metric is None: 546 | metric = dataset_metrics[dataset_abbr][0] 547 | elif metric in dataset_metrics[dataset_abbr]: 548 | pass 549 | else: 550 | table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs)) 551 | continue 552 | 553 | row = [dataset_abbr, prompt_version.get(dataset_abbr, '-'), metric, dataset_eval_mode.get(dataset_abbr, '-')] 554 | for model_abbr in self.model_abbrs: 555 | if dataset_abbr in parsed_results[model_abbr]: 556 | row.append('{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][metric])) 557 | else: 558 | row.append('-') 559 | 560 | table.append(row) 561 | for i in range(len(table)): 562 | if i == 0 or table[i][0].startswith('---------'): 563 | table[i] = [table[i][0]] + table[i][4:] 564 | else: 565 | table[i] = [table[i][0]] + table[i][4:] 566 | 567 | return table 568 | 569 | def _format_raw_txt(self, raw_results): 570 | raw_dataset_abbrs = [] 571 | for model_abbr in self.model_abbrs: 572 | for dataset_abbr in raw_results[model_abbr]: 573 | if dataset_abbr not in raw_dataset_abbrs: 574 | raw_dataset_abbrs.append(dataset_abbr) 575 | raw_txts = [] 576 | for model_abbr in self.model_abbrs: 577 | raw_txts.append('-------------------------------') 578 | raw_txts.append(f'Model: {model_abbr}') 579 | for dataset_abbr in raw_dataset_abbrs: 580 | result = raw_results[model_abbr].get(dataset_abbr, '{}') 581 | raw_txts.append(f'{dataset_abbr}: {result}') 582 | raw_txts = '\n'.join(raw_txts) 583 | return raw_txts 584 | 585 | def _output_to_file(self, output_path, time_str, table, raw_txts): 586 | if output_path is None: 587 | output_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.txt') 588 | output_csv_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.csv') 589 | else: 590 | output_csv_path = output_path.replace('.txt', '.csv') 591 | 592 | output_dir = osp.split(output_path)[0] 593 | 
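# Ensure the summary directory exists, then write the same table both to a
# human-readable .txt (tabulate, csv and raw sections) and to a plain .csv.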
mmengine.mkdir_or_exist(output_dir) 594 | with open(output_path, 'w', encoding='utf-8') as f: 595 | text = f'{time_str}\n' + \ 596 | 'tabulate format\n' + \ 597 | '^' * 128 + '\n' + \ 598 | tabulate.tabulate(table, headers='firstrow') + '\n' + \ 599 | '$' * 128 + '\n\n' + \ 600 | '-' * 128 + ' THIS IS A DIVIDER ' + '-' * 128 + '\n\n' + \ 601 | 'csv format\n' + \ 602 | '^' * 128 + '\n' + \ 603 | '\n'.join([','.join(row) for row in table]) + '\n' + \ 604 | '$' * 128 + '\n\n' + \ 605 | '-' * 128 + ' THIS IS A DIVIDER ' + '-' * 128 + '\n\n' + \ 606 | 'raw format\n' + \ 607 | '^' * 128 + '\n' + \ 608 | raw_txts + '\n' + \ 609 | '$' * 128 + '\n' 610 | f.write(text) 611 | self.logger.info(f'write summary to {osp.abspath(output_path)}') 612 | 613 | with open(output_csv_path, 'w', encoding='utf-8') as f: 614 | f.write('\n'.join([','.join(row) for row in table]) + '\n') 615 | self.logger.info(f'write csv to {osp.abspath(output_csv_path)}') 616 | 617 | 618 | def summarize( 619 | self, 620 | output_path: str = None, 621 | time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')): # noqa 622 | 623 | raw_results, parsed_results, dataset_metrics, dataset_eval_mode = self._pick_up_results() 624 | raw_results, parsed_results, dataset_metrics, dataset_eval_mode = \ 625 | self._calculate_group_metrics(raw_results, parsed_results, dataset_metrics, dataset_eval_mode) 626 | table = self._format_table(parsed_results, dataset_metrics, dataset_eval_mode) 627 | raw_txts = self._format_raw_txt(raw_results) 628 | print(tabulate.tabulate(table, headers='firstrow')) 629 | self._output_to_file(output_path, time_str, table, raw_txts) 630 | if self.lark_reporter: 631 | content = f"{getpass.getuser()}'s " 632 | content += f'detailed evaluation summary has been written to {osp.abspath(output_path)}' 633 | self.lark_reporter.post(content) 634 | 635 | if output_path is None: 636 | output_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.txt') 637 | # plot the results for visualization 638 | save_results_to_plots(output_path) 639 | 640 | class NeedleBenchATCSummarizer(DefaultSummarizer): 641 | """NeedleBench-ATC summarizer in OpenCompass. 642 | 643 | Args: 644 | config (ConfigDict): The configuration object of the evaluation task. It's expected to be filled out at runtime. 645 | dataset_abbrs (list[str], optional): Dataset abbreviations to be listed in the summary. 646 | summary_groups (list): The dataset groups whose results need to be averaged out. For example, mmlu. Each item is a dict with 647 | 'name' (str) and 'subsets' (list of dataset abbrs), and optionally 648 | 'weights' if weighted average is needed. 649 | prompt_db: A deprecated field.
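
    Note: `_read_and_sort_dataframe` recovers the needle count from dataset
    names via the pattern `needle_(\d+)_`, so dataset abbreviations are
    expected to embed the needle count in that form.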
650 | """ 651 | 652 | def _format_table(self, parsed_results, dataset_metrics, dataset_eval_mode): 653 | dataset_abbrs = [dataset_abbr_from_cfg(dataset) for dataset in self.dataset_cfgs] 654 | prompt_version = {dataset_abbr_from_cfg(d): get_prompt_hash(d)[:6] for d in self.dataset_cfgs} 655 | 656 | summarizer_dataset_abbrs = [] 657 | if self.dataset_abbrs is None: 658 | # display all dataset metrics included in the config 659 | for dataset_abbr in dataset_abbrs: 660 | if dataset_abbr in dataset_metrics: 661 | for metric in dataset_metrics[dataset_abbr]: 662 | summarizer_dataset_abbrs.append((dataset_abbr, metric)) 663 | else: 664 | summarizer_dataset_abbrs.append((dataset_abbr, None)) 665 | # along with all possible group metrics 666 | for dataset_abbr in dataset_metrics: 667 | for metric in dataset_metrics[dataset_abbr]: 668 | if (dataset_abbr, metric) not in summarizer_dataset_abbrs: 669 | summarizer_dataset_abbrs.append((dataset_abbr, metric)) 670 | else: 671 | # follow the required order 672 | for item in self.dataset_abbrs: 673 | if isinstance(item, str): 674 | summarizer_dataset_abbrs.append((item, None)) 675 | elif isinstance(item, (list, tuple)): 676 | summarizer_dataset_abbrs.append((item[0], item[1])) 677 | 678 | table = [] 679 | header = ['dataset', 'version', 'metric', 'mode'] + self.model_abbrs 680 | table.append(header) 681 | 682 | for key in dataset_metrics: 683 | dataset_metrics[key] = list(set(dataset_metrics[key])) 684 | 685 | for dataset_abbr, metric in summarizer_dataset_abbrs: 686 | if dataset_abbr not in dataset_metrics: 687 | table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs)) 688 | table.append(header) 689 | continue 690 | if len(dataset_metrics[dataset_abbr]) >= 10: 691 | metric = 'average_score' 692 | 693 | if metric is None: 694 | metric = dataset_metrics[dataset_abbr][0] 695 | elif metric in dataset_metrics[dataset_abbr]: 696 | pass 697 | else: 698 | table.append([dataset_abbr, '-', '-', '-'] + ['-'] * len(self.model_abbrs)) 699 | continue 700 | 701 | row = [dataset_abbr, prompt_version.get(dataset_abbr, '-'), metric, dataset_eval_mode.get(dataset_abbr, '-')] 702 | for model_abbr in self.model_abbrs: 703 | if dataset_abbr in parsed_results[model_abbr]: 704 | row.append('{:.02f}'.format(parsed_results[model_abbr][dataset_abbr][metric])) 705 | else: 706 | row.append('-') 707 | 708 | table.append(row) 709 | for i in range(len(table)): 710 | if i == 0 or table[i][0].startswith('---------'): 711 | table[i] = [table[i][0]] + table[i][4:] 712 | else: 713 | table[i] = [table[i][0]] + table[i][4:] 714 | 715 | return table 716 | 717 | def _read_and_sort_dataframe(self, file_path): 718 | # Read the file without treating the first row as a header 719 | data = pd.read_csv(file_path) 720 | # print(data) 721 | # Correct the extraction of needle counts for all settings 722 | data['needle_count'] = data['dataset'].str.extract(r'needle_(\d+)_').astype(float) 723 | data['needle_count'] = data['needle_count'].astype(int) 724 | 725 | # Define experimental settings groups 726 | experimental_settings = { 727 | 'en': '_en$', 728 | 'zh': '_zh$', 729 | 'en_ordered': '_en_ordered', 730 | 'zh_ordered': '_zh_ordered', 731 | } 732 | 733 | # Function to calculate maximum needles 734 | def calculate_max_needles(dataset): 735 | max_needles = {model: None for model in dataset.columns if 'b' in model} 736 | for model in max_needles.keys(): 737 | consecutive_low_scores = 0 738 | previous_needle_count = 0 739 | for index, row in 
dataset.sort_values(by='needle_count').iterrows(): 740 | try: 741 | score = float(row[model]) 742 | except ValueError: 743 | score = -1 744 | if score < 60: 745 | consecutive_low_scores += 1 746 | if consecutive_low_scores == 1: 747 | max_needles[model] = previous_needle_count 748 | else: 749 | consecutive_low_scores = 0 750 | previous_needle_count = row['needle_count'] 751 | max_needle_count_seen = dataset['needle_count'].max() 752 | max_needles[model] = max_needle_count_seen if max_needles[model] is None else max_needles[model] 753 | return max_needles 754 | 755 | # Calculate max needles for each group and organize results in a DataFrame 756 | results = {} 757 | for setting, regex in experimental_settings.items(): 758 | filtered_data = data[data['dataset'].str.contains(regex)] 759 | results[setting] = calculate_max_needles(filtered_data) 760 | 761 | # Convert results to DataFrame and transpose it 762 | results_df = pd.DataFrame(results).transpose() 763 | 764 | # Return the sorted DataFrame 765 | results_df.index.name = 'ATC Experiment Type' 766 | return results_df 767 | 768 | def _output_to_file(self, output_path, time_str, table, raw_txts): 769 | # output to file 770 | if output_path is None: 771 | output_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.txt') 772 | output_csv_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.csv') 773 | else: 774 | output_csv_path = output_path.replace('.txt', '.csv') 775 | 776 | output_dir = osp.split(output_path)[0] 777 | mmengine.mkdir_or_exist(output_dir) 778 | with open(output_path, 'w', encoding='utf-8') as f: 779 | text = f'{time_str}\n' + \ 780 | 'tabulate format\n' + \ 781 | '^' * 128 + '\n' + \ 782 | tabulate.tabulate(table, headers='firstrow') + '\n' + \ 783 | '$' * 128 + '\n\n' + \ 784 | '-' * 128 + ' THIS IS A DIVIDER ' + '-' * 128 + '\n\n' + \ 785 | 'csv format\n' + \ 786 | '^' * 128 + '\n' + \ 787 | '\n'.join([','.join(row) for row in table]) + '\n' + \ 788 | '$' * 128 + '\n\n' + \ 789 | '-' * 128 + ' THIS IS A DIVIDER ' + '-' * 128 + '\n\n' + \ 790 | 'raw format\n' + \ 791 | '^' * 128 + '\n' + \ 792 | raw_txts + '\n' + \ 793 | '$' * 128 + '\n' 794 | f.write(text) 795 | self.logger.info(f'write summary to {osp.abspath(output_path)}') 796 | 797 | with open(output_csv_path, 'w', encoding='utf-8') as f: 798 | f.write('\n'.join([','.join(row) for row in table]) + '\n') 799 | # self.logger.info(f'write csv to {osp.abspath(output_csv_path)}') 800 | 801 | df_sorted = self._read_and_sort_dataframe(output_csv_path) 802 | 803 | df_sorted.to_csv(output_csv_path) 804 | 805 | self.logger.info(f'write sorted csv to {output_csv_path}') 806 | 807 | 808 | def summarize( 809 | self, 810 | output_path: str = None, 811 | time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')): # noqa 812 | 813 | # pick up results 814 | raw_results, parsed_results, dataset_metrics, dataset_eval_mode = self._pick_up_results() 815 | 816 | # calculate group metrics 817 | raw_results, parsed_results, dataset_metrics, dataset_eval_mode = \ 818 | self._calculate_group_metrics(raw_results, parsed_results, dataset_metrics, dataset_eval_mode) 819 | 820 | # format table 821 | table = self._format_table(parsed_results, dataset_metrics, dataset_eval_mode) 822 | 823 | # format raw txt 824 | raw_txts = self._format_raw_txt(raw_results) 825 | 826 | # output to .txt / .csv files 827 | self._output_to_file(output_path, time_str, table, raw_txts) 828 | 829 | if self.lark_reporter: 830 | content = f"{getpass.getuser()}'s " 831 | content += f'detailed evaluation summary has been written to {osp.abspath(output_path)}' 832 | self.lark_reporter.post(content) 833 | 834 | if output_path is None: 835 | output_path = osp.join(self.work_dir, 'summary', f'summary_{time_str}.txt') 836 | --------------------------------------------------------------------------------