├── README.md ├── attenscore ├── configuration_chatglm.py ├── datastore.py ├── main.py ├── modeling_chatglm.py ├── modeling_llama.py ├── modeling_qwen2.py └── run_all.sh ├── datas └── system_benchmark_eval_datas.json ├── eval_system_bench.py ├── eval_system_bench_with_gt.py ├── models ├── __init__.py ├── claude35_opus.py ├── deepseek.py ├── ernie4.py ├── glm4.py ├── glm_9b_client.py ├── gpt35.py ├── gpt4_turbo_0409.py ├── gpt4o.py ├── llama3_1_70b.py ├── llama3_1_8b.py ├── moonshot.py ├── qwen2_72b.py ├── qwen2_7b.py ├── template.py └── yi_large.py ├── output ├── .gitignore ├── claude35_opus │ └── claude35_opus_analysis.xlsx ├── deepseek │ └── deepseek_analysis.xlsx ├── ernie4 │ └── ernie4_analysis.xlsx ├── glm4 │ └── glm4_analysis.xlsx ├── glm_9b_client │ └── glm_9b_client_analysis.xlsx ├── gpt35 │ └── gpt35_analysis.xlsx ├── gpt4_turbo_0409 │ └── gpt4_turbo_0409_analysis.xlsx ├── gpt4o │ └── gpt4o_analysis.xlsx ├── llama3_70b │ └── llama3_70b_analysis.xlsx ├── llama3_8b │ └── llama3_8b_analysis.xlsx ├── moonshot │ └── moonshot_analysis.xlsx ├── qwen2_72b │ └── qwen2_72b_analysis.xlsx ├── qwen2_7b │ └── qwen2_7b_analysis.xlsx ├── with_gt_history_output │ ├── claude35_opus │ │ └── claude35_opus_analysis.xlsx │ ├── ernie4 │ │ └── ernie4_analysis.xlsx │ ├── gpt35 │ │ └── gpt35_analysis.xlsx │ ├── gpt4o │ │ └── gpt4o_analysis.xlsx │ ├── llama3_70b │ │ └── llama3_70b_analysis.xlsx │ ├── llama3_8b │ │ └── llama3_8b_analysis.xlsx │ └── qwen2_72b │ │ └── qwen2_72b_analysis.xlsx └── yi_large │ └── yi_large_analysis.xlsx ├── plot ├── .gitignore ├── analyze_history_gt.py ├── eval_output.py ├── fig3_stat.py ├── fig4_radar.py ├── fig5_hgt_histo.py ├── fig6_atscore.py ├── fig_atscore_curve.py ├── fig_atscore_replace.py ├── fig_constraint.py ├── fig_domain.py ├── tab1_categoty.py ├── tab2_overall.py ├── tab3_align.py ├── tab4_turn.py ├── tab6_csr_full.py ├── tab7_align_full.py └── utils │ ├── change_color.py │ ├── generate_n_color.py │ ├── get_rank.py │ ├── parse_xls.py │ └── smooth.py ├── requirements.txt ├── run.sh ├── run_metric.sh ├── servers └── run_vllm_serve.sh └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # SysBench 2 | Code for [SysBench: Can Large Language Models Follow System Messages?](https://arxiv.org/abs/2408.10943) 3 | 4 | ## Introduction 5 | 6 | In this section, we introduce the usage of our attached codes, including: 7 | 8 | - SysBench's dataset. 9 | - Running customized model on SysBench 10 | - Reproducing figures/tables using provided data. 11 | - Reproducing figures/tables from scratch. 12 | 13 | 14 | ## The Dataset 15 | 16 | `datas/system_benchmark_eval_datas.json` is the dataset file, a JSON array containing 500 dialogues with system messages. For each entry, the meanings of its important JSON fields are listed in the table below. 17 | |Field Name|Meaning| 18 | |-|-| 19 | |`system_id`|The ID of the system message (or dialogue)| 20 | |`system_prompt`|The content of the system message.| 21 | |`messages`|A JSON array containing the roles and contents of system message and the whole 5-turn conversation. The contents of the role "assistant" is the ground truth.| 22 | |`prompt_infos`|Containing 5 JSON entries corresponding to each user instruction. 
For each entry, the field `alignment` denotes its alignment with the system message, and the field `criteria` is the checklists with constraint types labeled.| 23 | 24 | 25 | ## Evaluating Customized Model 26 | 27 | This section presents the steps to evaluate a customized model on SysBench. 28 | 29 | ### Software Dependencies 30 | 31 | Only Python (>= 3.10) and the openai (>= 1.0) package are required for the code base itself. 32 | 33 | ### Implement Interface 34 | 35 | We provide a template model file for easily adding a new model. Suppose the customized model's name is `myModel`, copy the template by: 36 | 37 | ```sh 38 | cd models 39 | cp template.py myModel.py 40 | cd .. 41 | ``` 42 | 43 | Then, rename the class name to `myModel` and implement its `__call__` method. It receives a list called `messages`, where each element is a Python dictionary containing `role` and `content` keys, representing the whole historical dialog contents for generating the next model outputs. This method should return model response contents in string format. 44 | 45 | Some of the existing code in this directory can be used for reference. For example, `gpt4o.py` is for the OpenAI style API, `glm_9b_client.py` is for the vLLM server, while `qwen2_7b.py` is for offline inference. 46 | 47 | ### Prepare GPT-4o as Verifier 48 | 49 | GPT-4o is used as the model-based verifier, please fill in the OpenAI Key and base URL in `gpt4o.py` to configure GPT-4o inference. Run the following command to test its usability: 50 | 51 | ```sh 52 | python models/gpt4o.py 53 | ``` 54 | 55 | ### Run Evaluation 56 | 57 | Run the following command for evaluation: 58 | 59 | ```sh 60 | python -m eval_system_bench \ 61 | --infer_model_name myModel \ 62 | --output_dir output \ 63 | --max_threads 20 64 | ``` 65 | 66 | **Note:** It is highly recommended to use online inference and keep your `__call__` method re-entrant. Setting max_threads to 1 is required in the absence of such a guarantee. 67 | 68 | ### Calculate Metrics 69 | 70 | After finishing the evaluation step, the detailed model and verifier outputs are both automatically stored in the `output/myModel` directory by default. To calculate the metric scores, run: 71 | 72 | ```sh 73 | python -m eval_output \ 74 | --infer_model_name myModel \ 75 | --output_dir output 76 | ``` 77 | 78 | This command will output the metric scores with detailed information. 79 | 80 | 81 | ## Reproducing Results from Provided Data 82 | 83 | Since all API keys are removed from our provided data due to privacy and anonymity requests, reproducing all results in the paper from scratch is more complicated, and we place the instructions in the next subsection. In this section, we elaborate on the steps to reproduce results with our provided raw data, which is much easier to follow. 84 | 85 | ### Software Dependencies 86 | 87 | Python (>= 3.10), matplotlib (>= 3.9), pandas (>= 2.2), and openpyxl(>= 3.1) are required. 88 | 89 | ### Plot Figures 90 | 91 | Run the following commands to plot figures and generate tables (in LaTeX code). 
We recommend installing any missing fonts for better display: 92 | 93 | ```sh 94 | mkdir figures # create the output directory 95 | 96 | # Plot each figure 97 | python plot/fig3_stat.py 98 | python plot/fig4_radar.py 99 | python plot/fig5_hgt_histo.py 100 | python plot/fig6_atscore.py 101 | 102 | # Generate each table in LaTeX code 103 | python plot/tab1_category.py 104 | python plot/tab2_overall.py 105 | python plot/tab3_align.py 106 | python plot/tab4_turn.py 107 | python plot/tab6_csr_full.py 108 | python plot/tab7_align_full.py 109 | ``` 110 | 111 | These commands parse the raw data in `output/` and generate the figures and tables presented in the paper. 112 | 113 | ### Expected Results 114 | 115 | All results should be **strictly consistent** with those presented in the paper. 116 | 117 | 118 | ## Reproducing Results from Scratch 119 | 120 | To reproduce from scratch, you need to obtain the API keys (for all closed models) or prepare the checkpoints (for all open models). The detailed steps are listed below. 121 | 122 | ### Hardware Dependencies 123 | 124 | GPU instances are required for running the open-source models. For the largest model, Qwen2-72B, we use 4× NVIDIA H100 80GB GPUs. 125 | 126 | ### Software Dependencies 127 | 128 | In addition, transformers (>= 4.44.0) and vLLM (>= 0.5.0) are required. 129 | 130 | ### Configure Models 131 | 132 | Please modify **all** the model files listed in the `./models` directory. 133 | For models accessed through a public API, fill in your API keys and base URLs. 134 | For open-source models run locally (i.e., the Qwen family, the Llama family, and GLM-4 9B), we recommend deploying a vLLM server for online serving; see `glm_9b_client.py` for reference and modify the other local-model files accordingly. 135 | 136 | We also provide a sample script for starting the vLLM server at `servers/run_vllm_serve.sh`. 137 | 138 | ### Backup Our Data (Optional) 139 | 140 | The `output/` directory will be overwritten later. 141 | 142 | ```sh 143 | mv output output-backup && mkdir output 144 | ``` 145 | 146 | ### Exp. 1: Evaluate Models 147 | 148 | For each model, run the following command for evaluation, setting `--infer_model_name` to the model's file name under `models/` (e.g., `qwen2_7b`); please set `max_threads` to 1 for models without a re-entrancy guarantee. 149 | 150 | ```sh 151 | python -m eval_system_bench \ 152 | --infer_model_name \ 153 | --output_dir output \ 154 | --max_threads 20 155 | ``` 156 | 157 | Then, the detailed evaluation results are available in the `output/` directory. 158 | 159 | ### Exp. 2: Ground-truth History 160 | 161 | We also evaluate with the historical model responses replaced by the ground truth: 162 | 163 | ```sh 164 | OUTDIR=output/with_gt_history_output 165 | python -m eval_system_bench_with_gt \ 166 | --infer_model_name \ 167 | --output_dir $OUTDIR \ 168 | --max_threads 20 169 | ``` 170 | 171 | To reproduce the corresponding figure in the paper, the following models should be run with the command above: `qwen2_72b`, `claude35_opus`, `ernie4`, and `llama3_8b`. All results will be stored in `output/with_gt_history_output` for later use. 172 | 173 | ### Exp. 3: Attention Score 174 | 175 | To explore the distribution of attention scores, first specify the Hugging Face checkpoint paths of the `glm-9b`, `llama31-8b`, and `qwen-72b` models in Lines 20-22 of `attenscore/main.py`. 176 | 177 | Then, change the working directory to `./attenscore` and run the provided script: 178 | 179 | ```sh 180 | cd attenscore 181 | bash run_all.sh 182 | ``` 183 | 184 | You can change the value of the `--id` flag to explore a system message other than the one shown in the paper's figure.
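For example, to record attention scores for a single system message with one model, you can also call `main.py` directly; the following is a minimal sketch based on the flags defined in `attenscore/main.py` (the save directory name `llama_out` is only illustrative):

```sh
cd attenscore
# Record attention scores for system message 175 with the Llama-3.1-8B checkpoint;
# per-layer results are written to llama_out/ as layer_<n>_sid175.npy files
python main.py --model llama31-8b --id 175 --save_path llama_out
# With --replace, the system message is merged into the first user turn and
# the script automatically appends the suffix "_replace" to the save path
python main.py --model llama31-8b --id 175 --save_path llama_out --replace
```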
Setting `--id` to -1 will run all 500 system messages on the current model, but this is very time-consuming. All results will be stored in `output/attenscore` for plotting the figure later. 185 | 186 | ### Reproduce Figures and Tables 187 | 188 | Finally, when all the experimental data are ready in `output/`, follow the plotting instructions above for reproduction. Note that there are more command flags available for the attention-score figure; run the following command for more details: 189 | 190 | ```sh 191 | python plot/fig6_atscore.py -h 192 | ``` 193 | 194 | ### Expected Results 195 | 196 | Although there is unavoidable randomness and fluctuation, especially for the closed models, all the figures and tables should statistically match the patterns shown in the paper. 197 | 198 | ## Citation 199 | 200 | ```bibtex 201 | @article{qin2024sysbench, 202 | title={SysBench: Can Large Language Models Follow System Messages?}, 203 | author={Qin, Yanzhao and Zhang, Tao and Shen, Yanjun and Luo, Wenjing and Sun, Haoze and Zhang, Yan and Qiao, Yujing and Chen, Weipeng and Zhou, Zenan and Zhang, Wentao and others}, 204 | journal={arXiv preprint arXiv:2408.10943}, 205 | year={2024} 206 | } 207 | ``` 208 | -------------------------------------------------------------------------------- /attenscore/configuration_chatglm.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class ChatGLMConfig(PretrainedConfig): 5 | model_type = "chatglm" 6 | 7 | def __init__( 8 | self, 9 | num_layers=28, 10 | padded_vocab_size=65024, 11 | hidden_size=4096, 12 | ffn_hidden_size=13696, 13 | kv_channels=128, 14 | num_attention_heads=32, 15 | seq_length=2048, 16 | hidden_dropout=0.0, 17 | classifier_dropout=None, 18 | attention_dropout=0.0, 19 | layernorm_epsilon=1e-5, 20 | rmsnorm=True, 21 | apply_residual_connection_post_layernorm=False, 22 | post_layer_norm=True, 23 | add_bias_linear=False, 24 | add_qkv_bias=False, 25 | bias_dropout_fusion=True, 26 | multi_query_attention=False, 27 | multi_query_group_num=1, 28 | rope_ratio=1, 29 | apply_query_key_layer_scaling=True, 30 | attention_softmax_in_fp32=True, 31 | fp32_residual_connection=False, 32 | **kwargs 33 | ): 34 | self.num_layers = num_layers 35 | self.vocab_size = padded_vocab_size 36 | self.padded_vocab_size = padded_vocab_size 37 | self.hidden_size = hidden_size 38 | self.ffn_hidden_size = ffn_hidden_size 39 | self.kv_channels = kv_channels 40 | self.num_attention_heads = num_attention_heads 41 | self.seq_length = seq_length 42 | self.hidden_dropout = hidden_dropout 43 | self.classifier_dropout = classifier_dropout 44 | self.attention_dropout = attention_dropout 45 | self.layernorm_epsilon = layernorm_epsilon 46 | self.rmsnorm = rmsnorm 47 | self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm 48 | self.post_layer_norm = post_layer_norm 49 | self.add_bias_linear = add_bias_linear 50 | self.add_qkv_bias = add_qkv_bias 51 | self.bias_dropout_fusion = bias_dropout_fusion 52 | self.multi_query_attention = multi_query_attention 53 | self.multi_query_group_num = multi_query_group_num 54 | self.rope_ratio = rope_ratio 55 | self.apply_query_key_layer_scaling = apply_query_key_layer_scaling 56 | self.attention_softmax_in_fp32 = attention_softmax_in_fp32 57 | self.fp32_residual_connection = fp32_residual_connection 58 | super().__init__(**kwargs) 59 | -------------------------------------------------------------------------------- /attenscore/datastore.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | class DataStore: 6 | def __init__(self) -> None: 7 | self._data_store = {} 8 | self._split_indices = [] 9 | self.has_extra = False 10 | 11 | def add_split_index(self, split_index : int, extra : bool = False) -> None: 12 | if extra: 13 | assert len(self._split_indices) == 0, "Extra split index can only be added at the beginning" 14 | self.has_extra = True 15 | self._split_indices.append(split_index) 16 | 17 | def append(self, layer_num : int, data : torch.Tensor) -> None: 18 | if data.dim() == 4: 19 | data = data.squeeze(0) 20 | # [np, sq, sk] -> [sq, sk] 21 | # print(f"Data shape: {data.shape}", ", indices:", self._split_indices) 22 | data = data.sum(dim=0, dtype=torch.float32) 23 | # print(f"Data shape: {data.shape}", data[0]) 24 | data = data.cumsum(dim=-1).cpu().numpy() 25 | 26 | to_save_data = np.zeros((data.shape[0], 12 if self.has_extra else 11), dtype=np.float32) 27 | sliced_data = data[:, self._split_indices] 28 | to_save_data[:, :sliced_data.shape[1]] = sliced_data 29 | to_save_data[:, sliced_data.shape[1]] = data[:, -1] 30 | 31 | if layer_num not in self._data_store: 32 | self._data_store[layer_num] = [] 33 | self._data_store[layer_num].append(to_save_data) 34 | # print(f"Data appended to layer {layer_num}, shape is {to_save_data.shape}") 35 | 36 | def _collect(self, layer_num : int) -> np.ndarray: 37 | if layer_num not in self._data_store: 38 | return None 39 | return np.vstack(self._data_store[layer_num]) 40 | 41 | def get_keys(self) -> list: 42 | return list(self._data_store.keys()) 43 | 44 | def save_data(self, save_path : str, file_name : str = '') -> None: 45 | split_indices_np = np.array(self._split_indices) 46 | os.makedirs(save_path, exist_ok=True) 47 | 48 | for layer_num in self._data_store: 49 | data = self._collect(layer_num) 50 | to_save = { 51 | "data": data, 52 | "split_indices": split_indices_np 53 | } 54 | fn = f"layer_{layer_num}_{file_name}.npy" if file_name else f"layer_{layer_num}.npy" 55 | np.save(os.path.join(save_path, fn), to_save) 56 | print(f"File saved: {fn}, with {data.shape[0]} samples") 57 | 58 | def load_data(self, load_path : str) -> None: 59 | assert len(self._data_store) == 0, "Data store is not empty" 60 | for file in os.listdir(load_path): 61 | if file.endswith(".npy"): 62 | data = np.load(os.path.join(load_path, file), allow_pickle=True).item() 63 | layer_num = int(file.split("_")[1].split(".")[0]) 64 | self._data_store[layer_num] = data["data"] 65 | self._split_indices = data["split_indices"] 66 | 67 | def clear(self) -> None: 68 | self._data_store = {} 69 | self._split_indices = [] 70 | self.has_extra = False 71 | 72 | def get_split_indices(self) -> list: 73 | return self._split_indices 74 | 75 | data_store = DataStore() 76 | 77 | def get_data_store() -> DataStore: 78 | global data_store 79 | return data_store -------------------------------------------------------------------------------- /attenscore/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import datetime 5 | import re 6 | 7 | import torch 8 | from transformers import AutoTokenizer 9 | from transformers.generation.utils import GenerationConfig 10 | 11 | from datastore import get_data_store 12 | from modeling_chatglm import ChatGLMForConditionalGeneration 13 | from modeling_llama import LlamaForCausalLM 14 | from modeling_qwen2 import 
Qwen2ForCausalLM 15 | 16 | data_file_path = '../datas/system_benchmark_eval_datas.json' 17 | 18 | # THUDM/glm-4-9b-chat, meta-llama/Meta-Llama-3.1-8B-Instruct, Qwen/Qwen2-72B-Instruct 19 | checkpoint_paths = { 20 | 'glm-9b': '/path/to/.cache/huggingface/hub/models--THUDM--glm-4-9b-chat/snapshots/aae8bd74af5c6dff63a49d7fbdcc89349ebf87aa/', 21 | 'llama31-8b': '/path/to/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct/snapshots/8c22764a7e3675c50d4c7c9a4edb474456022b16/', 22 | 'qwen-72b': '/path/to/.cache/huggingface/hub/models--Qwen--Qwen2-72B-Instruct/snapshots/1af63c698f59c4235668ec9c1395468cb7cd7e79/' 23 | } 24 | 25 | extra_tokens = { 26 | 'glm-9b': 2, 27 | 'llama31-8b': 25, 28 | 'qwen-72b': 3 29 | } 30 | 31 | cached_sid = set() 32 | 33 | def load_examples(dataset_filepath): 34 | data = json.load(open(dataset_filepath, encoding="utf-8")) 35 | return data 36 | 37 | def converation_generator(sysmeg_id): 38 | for entry in load_examples(data_file_path): 39 | if entry['system_id'] == sysmeg_id: 40 | print("System message ID:", sysmeg_id) 41 | for message in entry['messages']: 42 | if message['role'] == 'assistant': 43 | continue # ignore ground truth 44 | yield message 45 | break 46 | else: 47 | raise ValueError(f"System message with id {sysmeg_id} not found") 48 | 49 | def get_model_type(model_name): 50 | if model_name == 'glm-9b': 51 | return ChatGLMForConditionalGeneration 52 | elif model_name == 'llama31-8b': 53 | return LlamaForCausalLM 54 | elif model_name == 'qwen-72b': 55 | return Qwen2ForCausalLM 56 | else: 57 | raise ValueError(f"Model name {model_name} not found") 58 | 59 | def workflow(arg, model, tokenizer, generation_config, datastore): 60 | if arg.id in cached_sid and not arg.ignore_cache: 61 | print(f"System message {arg.id} already cached, ignore") 62 | return 63 | 64 | datastore.clear() 65 | datastore.add_split_index(extra_tokens[arg.model] - 1, extra=True) # special tokens at the beginning 66 | 67 | generation_length = generation_config.max_length 68 | 69 | messages = [] 70 | for message in converation_generator(arg.id): 71 | if len(messages) > 0 and messages[-1]["role"] == message["role"]: 72 | # concat the content 73 | print(messages[-1]["role"]) 74 | assert len(messages) == 1 and messages[-1]["role"] == "user" and arg.replace 75 | messages[-1]["content"] += message["content"] 76 | 77 | messages.append(message) 78 | 79 | tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt") 80 | input_length = tokenized_chat.shape[-1] 81 | datastore.add_split_index(input_length - 2) 82 | 83 | if message['role'] == 'system': 84 | print("System message:", message) 85 | if arg.replace: 86 | messages[-1]["role"] = "user" 87 | continue # concat next message 88 | 89 | generation_config.max_length = input_length + generation_length 90 | kwargs = { 91 | 'inputs': tokenized_chat.to('cuda'), 92 | 'generation_config' : generation_config 93 | } 94 | # if arg.seed is not None: 95 | # kwargs['seed'] = arg.seed 96 | 97 | outputs = model.generate(**kwargs) 98 | output_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True) 99 | datastore.add_split_index(outputs[0].shape[-1] - 2) 100 | print(datetime.datetime.now(), "Output text:", output_text, flush=True) 101 | 102 | messages.append({"role": "assistant", "content": output_text}) 103 | 104 | split_indices = datastore.get_split_indices() 105 | split_indices = [v + 1 for v in split_indices] 106 | 107 | print("Split indices:", split_indices) 108 | for 
i, idx in enumerate(split_indices): 109 | if i == 0: 110 | if datastore.has_extra: 111 | continue 112 | print(f'===== Split {i} =====', tokenizer.decode(outputs[0, :idx], skip_special_tokens=False), sep='\n') 113 | else: 114 | print(f'===== Split {i} =====', tokenizer.decode(outputs[0, split_indices[i-1]:idx], skip_special_tokens=False), sep='\n') 115 | 116 | datastore.save_data(arg.save_path, file_name=f'sid{arg.id}') 117 | 118 | def main(): 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument("--id", type=int, default=287, help="System message ID, -1 means all") 121 | parser.add_argument("--save_path", type=str, default='glm', help="Path to save the data") 122 | parser.add_argument("--model", type=str, default='glm-9b', choices=['glm-9b', 'llama31-8b', 'qwen-72b'], help="Model name") 123 | parser.add_argument("--ignore_cache", action='store_true', help="Ignore cache") 124 | parser.add_argument("--replace", action='store_true', help="Replace system message as user message") 125 | # parser.add_argument("--seed", type=int, default=None, help="Random seed") 126 | arg = parser.parse_args() 127 | 128 | 129 | model_cls = get_model_type(arg.model) 130 | checkpoint_path = checkpoint_paths[arg.model] 131 | 132 | model = model_cls.from_pretrained(checkpoint_path, device_map="auto") 133 | tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, trust_remote_code=True) 134 | 135 | generation_config = GenerationConfig.from_pretrained(checkpoint_path) 136 | generation_config.max_length = 8192 137 | print(generation_config) 138 | 139 | datastore = get_data_store() 140 | 141 | 142 | if arg.id == -1: # -1 means all 143 | from tqdm import tqdm 144 | # id_list = [63, 175, 187, 287, 401, 424] 145 | id_list = list(range(1, 501)) 146 | 147 | arg.replace = False 148 | for id in tqdm(id_list): 149 | arg.id = id 150 | workflow(arg, model, tokenizer, generation_config, datastore) 151 | 152 | arg.replace = True 153 | arg.save_path += '_replace' 154 | for id in tqdm(id_list): 155 | arg.id = id 156 | workflow(arg, model, tokenizer, generation_config, datastore) 157 | 158 | else: 159 | if arg.replace: 160 | arg.save_path += '_replace' 161 | 162 | pattern = re.compile(r"layer_\d+_sid(\d+).npy") 163 | if os.path.exists(arg.save_path): 164 | for file in os.listdir(arg.save_path): 165 | match = pattern.match(file) 166 | if match: 167 | cached_sid.add(int(match.group(1))) 168 | 169 | workflow(arg, model, tokenizer, generation_config, datastore) 170 | 171 | if __name__ == "__main__": 172 | main() -------------------------------------------------------------------------------- /attenscore/modeling_chatglm.py: -------------------------------------------------------------------------------- 1 | """ PyTorch ChatGLM model. 
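This copy is instrumented for SysBench's attention-score analysis: CoreAttention.forward additionally appends the attention probabilities of two selected layers (layer numbers num_layers - 1 and (num_layers - 1) // 2) to the shared DataStore defined in datastore.py, so that the attention mass on the system message can be examined after generation.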
""" 2 | 3 | import math 4 | import sys 5 | import torch 6 | import torch.utils.checkpoint 7 | import torch.nn.functional as F 8 | from torch import nn 9 | from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss 10 | from torch.nn.utils import skip_init 11 | from typing import Optional, Tuple, Union, List, Dict, Any 12 | 13 | from transformers.modeling_outputs import ( 14 | BaseModelOutputWithPast, 15 | CausalLMOutputWithPast, 16 | SequenceClassifierOutputWithPast, 17 | ) 18 | from transformers.modeling_utils import PreTrainedModel 19 | from transformers.utils import logging, is_torch_npu_available 20 | from transformers.generation.logits_process import LogitsProcessor 21 | from transformers.generation.utils import ModelOutput 22 | 23 | from configuration_chatglm import ChatGLMConfig 24 | from datastore import get_data_store 25 | 26 | try: 27 | from transformers.utils import is_flash_attn_greater_or_equal_2_10, is_flash_attn_2_available 28 | 29 | if is_flash_attn_2_available(): 30 | from flash_attn import flash_attn_func, flash_attn_varlen_func 31 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 32 | except: 33 | pass 34 | 35 | # flags required to enable jit fusion kernels 36 | 37 | if sys.platform != 'darwin' and not is_torch_npu_available(): 38 | torch._C._jit_set_profiling_mode(False) 39 | torch._C._jit_set_profiling_executor(False) 40 | torch._C._jit_override_can_fuse_on_cpu(True) 41 | torch._C._jit_override_can_fuse_on_gpu(True) 42 | 43 | logger = logging.get_logger(__name__) 44 | 45 | _CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" 46 | _CONFIG_FOR_DOC = "ChatGLMConfig" 47 | 48 | 49 | def default_init(cls, *args, **kwargs): 50 | return cls(*args, **kwargs) 51 | 52 | 53 | class InvalidScoreLogitsProcessor(LogitsProcessor): 54 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 55 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 56 | scores.zero_() 57 | scores[..., 198] = 5e4 58 | return scores 59 | 60 | 61 | def split_tensor_along_last_dim( 62 | tensor: torch.Tensor, 63 | num_partitions: int, 64 | contiguous_split_chunks: bool = False, 65 | ) -> List[torch.Tensor]: 66 | """Split a tensor along its last dimension. 67 | 68 | Arguments: 69 | tensor: input tensor. 70 | num_partitions: number of partitions to split the tensor 71 | contiguous_split_chunks: If True, make each chunk contiguous 72 | in memory. 73 | 74 | Returns: 75 | A list of Tensors 76 | """ 77 | # Get the size and dimension. 78 | last_dim = tensor.dim() - 1 79 | last_dim_size = tensor.size()[last_dim] // num_partitions 80 | # Split. 81 | tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) 82 | # Note: torch.split does not create contiguous tensors by default. 83 | if contiguous_split_chunks: 84 | return tuple(chunk.contiguous() for chunk in tensor_list) 85 | 86 | return tensor_list 87 | 88 | 89 | class RotaryEmbedding(nn.Module): 90 | def __init__(self, dim, rope_ratio=1, original_impl=False, device=None, dtype=None): 91 | super().__init__() 92 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) 93 | self.register_buffer("inv_freq", inv_freq) 94 | self.dim = dim 95 | self.original_impl = original_impl 96 | self.rope_ratio = rope_ratio 97 | 98 | def forward_impl( 99 | self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 100 | ): 101 | """Enhanced Transformer with Rotary Position Embedding. 
102 | 103 | Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ 104 | transformers/rope/__init__.py. MIT License: 105 | https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. 106 | """ 107 | # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ 108 | base = base * self.rope_ratio 109 | theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) 110 | 111 | # Create position indexes `[0, 1, ..., seq_len - 1]` 112 | seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) 113 | 114 | # Calculate the product of position index and $\theta_i$ 115 | idx_theta = torch.outer(seq_idx, theta).float() 116 | 117 | cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) 118 | 119 | # this is to mimic the behaviour of complex32, else we will get different results 120 | if dtype in (torch.float16, torch.bfloat16, torch.int8): 121 | cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() 122 | return cache 123 | 124 | def forward(self, max_seq_len, offset=0): 125 | return self.forward_impl( 126 | max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device 127 | ) 128 | 129 | 130 | @torch.jit.script 131 | def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: 132 | # x: [b, np, sq, hn] 133 | b, np, sq, hn = x.size(0), x.size(1), x.size(2), x.size(3) 134 | rot_dim = rope_cache.shape[-2] * 2 135 | x, x_pass = x[..., :rot_dim], x[..., rot_dim:] 136 | # truncate to support variable sizes 137 | rope_cache = rope_cache[:, :sq] 138 | xshaped = x.reshape(b, np, sq, rot_dim // 2, 2) 139 | rope_cache = rope_cache.view(-1, 1, sq, xshaped.size(3), 2) 140 | x_out2 = torch.stack( 141 | [ 142 | xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], 143 | xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], 144 | ], 145 | -1, 146 | ) 147 | x_out2 = x_out2.flatten(3) 148 | return torch.cat((x_out2, x_pass), dim=-1) 149 | 150 | 151 | class RMSNorm(torch.nn.Module): 152 | def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): 153 | super().__init__() 154 | self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) 155 | self.eps = eps 156 | 157 | def forward(self, hidden_states: torch.Tensor): 158 | input_dtype = hidden_states.dtype 159 | variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) 160 | hidden_states = hidden_states * torch.rsqrt(variance + self.eps) 161 | 162 | return (self.weight * hidden_states).to(input_dtype) 163 | 164 | 165 | class CoreAttention(torch.nn.Module): 166 | def __init__(self, config: ChatGLMConfig, layer_number): 167 | super(CoreAttention, self).__init__() 168 | self.config = config 169 | self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling 170 | self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 171 | if self.apply_query_key_layer_scaling: 172 | self.attention_softmax_in_fp32 = True 173 | self.layer_number = max(1, layer_number) 174 | self.is_causal = True 175 | 176 | projection_size = config.kv_channels * config.num_attention_heads 177 | 178 | # Per attention head and per partition values. 
179 | self.hidden_size_per_partition = projection_size 180 | self.hidden_size_per_attention_head = projection_size // config.num_attention_heads 181 | self.num_attention_heads_per_partition = config.num_attention_heads 182 | 183 | coeff = None 184 | self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) 185 | if self.apply_query_key_layer_scaling: 186 | coeff = self.layer_number 187 | self.norm_factor *= coeff 188 | self.coeff = coeff 189 | 190 | self.attention_dropout = torch.nn.Dropout(config.attention_dropout) 191 | 192 | def forward(self, query_layer, key_layer, value_layer, attention_mask): 193 | # [b, np, sq, sk] 194 | output_size = (query_layer.size(0), query_layer.size(1), query_layer.size(2), key_layer.size(2)) 195 | 196 | # [b, np, sq, hn] -> [b * np, sq, hn] 197 | query_layer = query_layer.view(output_size[0] * output_size[1], output_size[2], -1) 198 | # [b, np, sk, hn] -> [b * np, sk, hn] 199 | key_layer = key_layer.view(output_size[0] * output_size[1], output_size[3], -1) 200 | 201 | # preallocting input tensor: [b * np, sq, sk] 202 | matmul_input_buffer = torch.empty( 203 | output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, 204 | device=query_layer.device 205 | ) 206 | 207 | # Raw attention scores. [b * np, sq, sk] 208 | matmul_result = torch.baddbmm( 209 | matmul_input_buffer, 210 | query_layer, # [b * np, sq, hn] 211 | key_layer.transpose(1, 2), # [b * np, hn, sk] 212 | beta=0.0, 213 | alpha=(1.0 / self.norm_factor), 214 | ) 215 | 216 | # change view to [b, np, sq, sk] 217 | attention_scores = matmul_result.view(*output_size) 218 | 219 | # =========================== 220 | # Attention probs and dropout 221 | # =========================== 222 | 223 | # attention scores and attention mask [b, np, sq, sk] 224 | if self.attention_softmax_in_fp32: 225 | attention_scores = attention_scores.float() 226 | if self.coeff is not None: 227 | attention_scores = attention_scores * self.coeff 228 | if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: 229 | attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], 230 | device=attention_scores.device, dtype=torch.bool) 231 | attention_mask.tril_() 232 | attention_mask = ~attention_mask 233 | if attention_mask is not None: 234 | attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) 235 | attention_probs = F.softmax(attention_scores, dim=-1) 236 | attention_probs = attention_probs.type_as(value_layer) 237 | 238 | # This is actually dropping out entire tokens to attend to, which might 239 | # seem a bit unusual, but is taken from the original Transformer paper. 
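# SysBench instrumentation note: a few lines below, the resulting attention
# probabilities of the two selected layers are also appended to the global
# DataStore (get_data_store().append) for the attention-score analysis.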
240 | attention_probs = self.attention_dropout(attention_probs) 241 | 242 | # query layer shape: [b * np, sq, hn] 243 | # value layer shape: [b, np, sk, hn] 244 | # attention shape: [b, np, sq, sk] 245 | # context layer shape: [b, np, sq, hn] 246 | output_size = (value_layer.size(0), value_layer.size(1), query_layer.size(1), value_layer.size(3)) 247 | # change view [b * np, sk, hn] 248 | value_layer = value_layer.view(output_size[0] * output_size[1], value_layer.size(2), -1) 249 | # change view [b * np, sq, sk] 250 | attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) 251 | # print('value_layer', value_layer.size()) 252 | # print('attention_probs', attention_probs.size()) 253 | if self.layer_number in (self.config.num_layers - 1, (self.config.num_layers -1 ) // 2): 254 | get_data_store().append(self.layer_number, attention_probs) 255 | # matmul: [b * np, sq, hn] 256 | context_layer = torch.bmm(attention_probs, value_layer) 257 | # change view [b, np, sq, hn] 258 | context_layer = context_layer.view(*output_size) 259 | # [b, np, sq, hn] --> [b, sq, np, hn] 260 | context_layer = context_layer.transpose(1, 2).contiguous() 261 | # [b, sq, np, hn] --> [b, sq, hp] 262 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) 263 | context_layer = context_layer.reshape(*new_context_layer_shape) 264 | 265 | return context_layer 266 | 267 | 268 | class SdpaAttention(CoreAttention): 269 | def forward(self, query_layer, key_layer, value_layer, attention_mask): 270 | if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: 271 | context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, 272 | is_causal=True, 273 | dropout_p=self.config.attention_dropout if self.training else 0.0) 274 | else: 275 | if attention_mask is not None: 276 | attention_mask = ~attention_mask 277 | context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, 278 | attention_mask, 279 | dropout_p=self.config.attention_dropout if self.training else 0.0) 280 | context_layer = context_layer.transpose(1, 2).contiguous() 281 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) 282 | context_layer = context_layer.reshape(*new_context_layer_shape) 283 | return context_layer 284 | 285 | 286 | def _get_unpad_data(attention_mask): 287 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 288 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 289 | max_seqlen_in_batch = seqlens_in_batch.max().item() 290 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 291 | return ( 292 | indices, 293 | cu_seqlens, 294 | max_seqlen_in_batch, 295 | ) 296 | 297 | 298 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 299 | class FlashAttention2(CoreAttention): 300 | def __init__(self, *args, **kwargs): 301 | super().__init__(*args, **kwargs) 302 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 303 | 304 | def forward(self, query_states, key_states, value_states, attention_mask): 305 | query_states = query_states.transpose(1, 2) 306 | key_states = key_states.transpose(1, 2) 307 | value_states = value_states.transpose(1, 2) 308 | batch_size, query_length = query_states.shape[:2] 309 | if not self._flash_attn_uses_top_left_mask: 310 | causal = self.is_causal 311 | else: 312 | # TODO: Remove the `query_length != 1` check 
once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 313 | causal = self.is_causal and query_length != 1 314 | dropout = self.config.attention_dropout if self.training else 0.0 315 | # Contains at least one padding token in the sequence 316 | if attention_mask is not None: 317 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 318 | query_states, key_states, value_states, attention_mask, query_length 319 | ) 320 | 321 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 322 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 323 | 324 | attn_output_unpad = flash_attn_varlen_func( 325 | query_states, 326 | key_states, 327 | value_states, 328 | cu_seqlens_q=cu_seqlens_q, 329 | cu_seqlens_k=cu_seqlens_k, 330 | max_seqlen_q=max_seqlen_in_batch_q, 331 | max_seqlen_k=max_seqlen_in_batch_k, 332 | dropout_p=dropout, 333 | softmax_scale=None, 334 | causal=causal, 335 | ) 336 | 337 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 338 | else: 339 | attn_output = flash_attn_func( 340 | query_states, key_states, value_states, dropout, softmax_scale=None, causal=causal 341 | ) 342 | attn_output = attn_output.reshape(batch_size, query_length, self.hidden_size_per_partition).contiguous() 343 | return attn_output 344 | 345 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 346 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 347 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 348 | 349 | key_layer = index_first_axis( 350 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 351 | ) 352 | value_layer = index_first_axis( 353 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 354 | ) 355 | if query_length == kv_seq_len: 356 | query_layer = index_first_axis( 357 | query_layer.reshape(batch_size * kv_seq_len, self.num_attention_heads_per_partition, head_dim), 358 | indices_k 359 | ) 360 | cu_seqlens_q = cu_seqlens_k 361 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 362 | indices_q = indices_k 363 | elif query_length == 1: 364 | max_seqlen_in_batch_q = 1 365 | cu_seqlens_q = torch.arange( 366 | batch_size + 1, dtype=torch.int32, device=query_layer.device 367 | ) # There is a memcpy here, that is very bad. 368 | indices_q = cu_seqlens_q[:-1] 369 | query_layer = query_layer.squeeze(1) 370 | else: 371 | # The -q_len: slice assumes left padding. 372 | attention_mask = attention_mask[:, -query_length:] 373 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 374 | 375 | return ( 376 | query_layer, 377 | key_layer, 378 | value_layer, 379 | indices_q, 380 | (cu_seqlens_q, cu_seqlens_k), 381 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 382 | ) 383 | 384 | 385 | CORE_ATTENTION_CLASSES = { 386 | "eager": CoreAttention, 387 | "sdpa": SdpaAttention, 388 | "flash_attention_2": FlashAttention2 389 | } 390 | 391 | 392 | class SelfAttention(torch.nn.Module): 393 | """Parallel self-attention layer abstract class. 394 | 395 | Self-attention layer takes input with size [s, b, h] 396 | and returns output of the same size. 
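Note: in this implementation the hidden states are actually laid out
batch-first, i.e. [b, sq, h], as noted in the comments inside forward().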
397 | """ 398 | 399 | def __init__(self, config: ChatGLMConfig, layer_number, device=None): 400 | super(SelfAttention, self).__init__() 401 | self.layer_number = max(1, layer_number) 402 | 403 | self.projection_size = config.kv_channels * config.num_attention_heads 404 | 405 | # Per attention head and per partition values. 406 | self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads 407 | self.num_attention_heads_per_partition = config.num_attention_heads 408 | 409 | self.multi_query_attention = config.multi_query_attention 410 | self.qkv_hidden_size = 3 * self.projection_size 411 | if self.multi_query_attention: 412 | self.num_multi_query_groups_per_partition = config.multi_query_group_num 413 | self.qkv_hidden_size = ( 414 | self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num 415 | ) 416 | self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, 417 | bias=config.add_bias_linear or config.add_qkv_bias, 418 | device=device, **_config_to_kwargs(config) 419 | ) 420 | 421 | self.core_attention = CORE_ATTENTION_CLASSES[config._attn_implementation](config, self.layer_number) 422 | 423 | # Output. 424 | self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, 425 | device=device, **_config_to_kwargs(config) 426 | ) 427 | 428 | def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): 429 | if self.multi_query_attention: 430 | num_attention_heads = self.num_multi_query_groups_per_partition 431 | else: 432 | num_attention_heads = self.num_attention_heads_per_partition 433 | return torch.empty( 434 | inference_max_sequence_len, 435 | batch_size, 436 | num_attention_heads, 437 | self.hidden_size_per_attention_head, 438 | dtype=dtype, 439 | device=device, 440 | ) 441 | 442 | def forward( 443 | self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True 444 | ): 445 | # hidden_states: [b, sq, h] 446 | 447 | # ================================================= 448 | # Pre-allocate memory for key-values for inference. 
449 | # ================================================= 450 | # ===================== 451 | # Query, Key, and Value 452 | # ===================== 453 | 454 | # Attention heads [b, sq, h] --> [b, sq, (np * 3 * hn)] 455 | mixed_x_layer = self.query_key_value(hidden_states) 456 | 457 | if self.multi_query_attention: 458 | (query_layer, key_layer, value_layer) = mixed_x_layer.split( 459 | [ 460 | self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, 461 | self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, 462 | self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, 463 | ], 464 | dim=-1, 465 | ) 466 | query_layer = query_layer.view( 467 | query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) 468 | ) 469 | key_layer = key_layer.view( 470 | key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) 471 | ) 472 | value_layer = value_layer.view( 473 | value_layer.size()[:-1] 474 | + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) 475 | ) 476 | else: 477 | new_tensor_shape = mixed_x_layer.size()[:-1] + \ 478 | (self.num_attention_heads_per_partition, 479 | 3 * self.hidden_size_per_attention_head) 480 | mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) 481 | 482 | # [b, sq, np, 3 * hn] --> 3 [b, sq, np, hn] 483 | (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) 484 | 485 | # [b, sq, np, hn] -> [b, np, sq, hn] 486 | query_layer, key_layer, value_layer = [k.transpose(1, 2) for k in [query_layer, key_layer, value_layer]] 487 | 488 | # apply relative positional encoding (rotary embedding) 489 | if rotary_pos_emb is not None: 490 | query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) 491 | key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) 492 | 493 | # adjust key and value for inference 494 | if kv_cache is not None: 495 | cache_k, cache_v = kv_cache 496 | key_layer = torch.cat((cache_k, key_layer), dim=2) 497 | value_layer = torch.cat((cache_v, value_layer), dim=2) 498 | if use_cache: 499 | if kv_cache is None: 500 | kv_cache = torch.cat((key_layer.unsqueeze(0).unsqueeze(0), value_layer.unsqueeze(0).unsqueeze(0)), 501 | dim=1) 502 | else: 503 | kv_cache = (key_layer, value_layer) 504 | else: 505 | kv_cache = None 506 | 507 | if self.multi_query_attention: 508 | key_layer = key_layer.unsqueeze(2) 509 | key_layer = key_layer.expand( 510 | -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1 511 | ) 512 | key_layer = key_layer.contiguous().view( 513 | key_layer.size()[:1] + (self.num_attention_heads_per_partition,) + key_layer.size()[3:] 514 | ) 515 | value_layer = value_layer.unsqueeze(2) 516 | value_layer = value_layer.expand( 517 | -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1, -1 518 | ) 519 | value_layer = value_layer.contiguous().view( 520 | value_layer.size()[:1] + (self.num_attention_heads_per_partition,) + value_layer.size()[3:] 521 | ) 522 | 523 | # ================================== 524 | # core attention computation 525 | # ================================== 526 | 527 | context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) 528 | 529 | # ================= 530 | # Output. 
[sq, b, h] 531 | # ================= 532 | 533 | output = self.dense(context_layer) 534 | 535 | return output, kv_cache 536 | 537 | 538 | def _config_to_kwargs(args): 539 | common_kwargs = { 540 | "dtype": args.torch_dtype, 541 | } 542 | return common_kwargs 543 | 544 | 545 | class MLP(torch.nn.Module): 546 | """MLP. 547 | 548 | MLP will take the input with h hidden state, project it to 4*h 549 | hidden dimension, perform nonlinear transformation, and project the 550 | state back into h hidden dimension. 551 | """ 552 | 553 | def __init__(self, config: ChatGLMConfig, device=None): 554 | super(MLP, self).__init__() 555 | 556 | self.add_bias = config.add_bias_linear 557 | 558 | # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf 559 | self.dense_h_to_4h = nn.Linear( 560 | config.hidden_size, 561 | config.ffn_hidden_size * 2, 562 | bias=self.add_bias, 563 | device=device, 564 | **_config_to_kwargs(config) 565 | ) 566 | 567 | def swiglu(x): 568 | x = torch.chunk(x, 2, dim=-1) 569 | return F.silu(x[0]) * x[1] 570 | 571 | self.activation_func = swiglu 572 | 573 | # Project back to h. 574 | self.dense_4h_to_h = nn.Linear( 575 | config.ffn_hidden_size, 576 | config.hidden_size, 577 | bias=self.add_bias, 578 | device=device, 579 | **_config_to_kwargs(config) 580 | ) 581 | 582 | def forward(self, hidden_states): 583 | # [s, b, 4hp] 584 | intermediate_parallel = self.dense_h_to_4h(hidden_states) 585 | intermediate_parallel = self.activation_func(intermediate_parallel) 586 | # [s, b, h] 587 | output = self.dense_4h_to_h(intermediate_parallel) 588 | return output 589 | 590 | 591 | class GLMBlock(torch.nn.Module): 592 | """A single transformer layer. 593 | 594 | Transformer layer takes input with size [s, b, h] and returns an 595 | output of the same size. 596 | """ 597 | 598 | def __init__(self, config: ChatGLMConfig, layer_number, device=None): 599 | super(GLMBlock, self).__init__() 600 | self.layer_number = layer_number 601 | 602 | self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm 603 | 604 | self.fp32_residual_connection = config.fp32_residual_connection 605 | 606 | LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm 607 | # Layernorm on the input data. 608 | self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, 609 | dtype=config.torch_dtype) 610 | 611 | # Self attention. 612 | self.self_attention = SelfAttention(config, layer_number, device=device) 613 | self.hidden_dropout = config.hidden_dropout 614 | 615 | # Layernorm on the attention output 616 | self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, 617 | dtype=config.torch_dtype) 618 | 619 | # MLP 620 | self.mlp = MLP(config, device=device) 621 | 622 | def forward( 623 | self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, 624 | ): 625 | # hidden_states: [s, b, h] 626 | 627 | # Layer norm at the beginning of the transformer layer. 628 | layernorm_output = self.input_layernorm(hidden_states) 629 | # Self attention. 630 | attention_output, kv_cache = self.self_attention( 631 | layernorm_output, 632 | attention_mask, 633 | rotary_pos_emb, 634 | kv_cache=kv_cache, 635 | use_cache=use_cache 636 | ) 637 | 638 | # Residual connection. 
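# Depending on apply_residual_connection_post_layernorm, the residual is taken
# either from the post-layernorm output or from the raw block input.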
639 | if self.apply_residual_connection_post_layernorm: 640 | residual = layernorm_output 641 | else: 642 | residual = hidden_states 643 | 644 | layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) 645 | layernorm_input = residual + layernorm_input 646 | 647 | # Layer norm post the self attention. 648 | layernorm_output = self.post_attention_layernorm(layernorm_input) 649 | 650 | # MLP. 651 | mlp_output = self.mlp(layernorm_output) 652 | 653 | # Second residual connection. 654 | if self.apply_residual_connection_post_layernorm: 655 | residual = layernorm_output 656 | else: 657 | residual = layernorm_input 658 | 659 | output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) 660 | output = residual + output 661 | 662 | return output, kv_cache 663 | 664 | 665 | class GLMTransformer(torch.nn.Module): 666 | """Transformer class.""" 667 | 668 | def __init__(self, config: ChatGLMConfig, device=None): 669 | super(GLMTransformer, self).__init__() 670 | 671 | self.fp32_residual_connection = config.fp32_residual_connection 672 | self.post_layer_norm = config.post_layer_norm 673 | 674 | # Number of layers. 675 | self.num_layers = config.num_layers 676 | 677 | # Transformer layers. 678 | def build_layer(layer_number): 679 | return GLMBlock(config, layer_number, device=device) 680 | 681 | self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) 682 | 683 | if self.post_layer_norm: 684 | LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm 685 | # Final layer norm before output. 686 | self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, 687 | dtype=config.torch_dtype) 688 | 689 | self.gradient_checkpointing = False 690 | 691 | def _get_layer(self, layer_number): 692 | return self.layers[layer_number] 693 | 694 | def forward( 695 | self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, 696 | use_cache: Optional[bool] = True, 697 | output_hidden_states: Optional[bool] = False, 698 | ): 699 | if not kv_caches: 700 | kv_caches = [None for _ in range(self.num_layers)] 701 | presents = () if use_cache else None 702 | if self.gradient_checkpointing and self.training: 703 | if use_cache: 704 | logger.warning_once( 705 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
706 | ) 707 | use_cache = False 708 | 709 | all_self_attentions = None 710 | all_hidden_states = () if output_hidden_states else None 711 | for index in range(self.num_layers): 712 | if output_hidden_states: 713 | all_hidden_states = all_hidden_states + (hidden_states,) 714 | 715 | layer = self._get_layer(index) 716 | if self.gradient_checkpointing and self.training: 717 | layer_ret = torch.utils.checkpoint.checkpoint( 718 | layer, 719 | hidden_states, 720 | attention_mask, 721 | rotary_pos_emb, 722 | kv_caches[index], 723 | use_cache, 724 | use_reentrant=False 725 | ) 726 | else: 727 | layer_ret = layer( 728 | hidden_states, 729 | attention_mask, 730 | rotary_pos_emb, 731 | kv_cache=kv_caches[index], 732 | use_cache=use_cache 733 | ) 734 | hidden_states, kv_cache = layer_ret 735 | if use_cache: 736 | # token by token decoding, use tuple format 737 | if kv_caches[0] is not None: 738 | presents = presents + (kv_cache,) 739 | # prefilling in decoding, use tensor format to save cuda memory 740 | else: 741 | if len(presents) == 0: 742 | presents = kv_cache 743 | else: 744 | presents = torch.cat((presents, kv_cache.to(presents.device)), dim=0) 745 | 746 | if output_hidden_states: 747 | all_hidden_states = all_hidden_states + (hidden_states,) 748 | 749 | # Final layer norm. 750 | if self.post_layer_norm: 751 | hidden_states = self.final_layernorm(hidden_states) 752 | 753 | return hidden_states, presents, all_hidden_states, all_self_attentions 754 | 755 | 756 | class ChatGLMPreTrainedModel(PreTrainedModel): 757 | """ 758 | An abstract class to handle weights initialization and 759 | a simple interface for downloading and loading pretrained models. 760 | """ 761 | 762 | is_parallelizable = False 763 | supports_gradient_checkpointing = True 764 | config_class = ChatGLMConfig 765 | base_model_prefix = "transformer" 766 | _no_split_modules = ["GLMBlock"] 767 | _supports_flash_attn_2 = True 768 | _supports_sdpa = True 769 | 770 | def _init_weights(self, module: nn.Module): 771 | """Initialize the weights.""" 772 | return 773 | 774 | def get_masks(self, input_ids, past_key_values, padding_mask=None): 775 | if self.config._attn_implementation == "flash_attention_2": 776 | if padding_mask is not None and not padding_mask.all(): 777 | return padding_mask 778 | return None 779 | batch_size, seq_length = input_ids.shape 780 | full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) 781 | full_attention_mask.tril_() 782 | past_length = 0 783 | if past_key_values: 784 | past_length = past_key_values[0][0].shape[2] 785 | if past_length: 786 | full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, 787 | device=input_ids.device), full_attention_mask), dim=-1) 788 | if padding_mask is not None: 789 | full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) 790 | if not past_length and padding_mask is not None: 791 | full_attention_mask -= padding_mask.unsqueeze(-1) - 1 792 | full_attention_mask = (full_attention_mask < 0.5).bool() 793 | full_attention_mask.unsqueeze_(1) 794 | return full_attention_mask 795 | 796 | def get_position_ids(self, input_ids, device): 797 | batch_size, seq_length = input_ids.shape 798 | position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) 799 | return position_ids 800 | 801 | class Embedding(torch.nn.Module): 802 | """Language model embeddings.""" 803 | 804 | def __init__(self, config: ChatGLMConfig, device=None): 805 | super(Embedding, self).__init__() 806 | 
807 | self.hidden_size = config.hidden_size 808 | # Word embeddings (parallel). 809 | self.word_embeddings = nn.Embedding( 810 | config.padded_vocab_size, 811 | self.hidden_size, 812 | dtype=config.torch_dtype, 813 | device=device 814 | ) 815 | self.fp32_residual_connection = config.fp32_residual_connection 816 | 817 | def forward(self, input_ids): 818 | # Embeddings. 819 | words_embeddings = self.word_embeddings(input_ids) 820 | embeddings = words_embeddings 821 | # If the input flag for fp32 residual connection is set, convert for float. 822 | if self.fp32_residual_connection: 823 | embeddings = embeddings.float() 824 | return embeddings 825 | 826 | 827 | class ChatGLMModel(ChatGLMPreTrainedModel): 828 | def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): 829 | super().__init__(config) 830 | if empty_init: 831 | init_method = skip_init 832 | else: 833 | init_method = default_init 834 | init_kwargs = {} 835 | if device is not None: 836 | init_kwargs["device"] = device 837 | self.embedding = init_method(Embedding, config, **init_kwargs) 838 | self.num_layers = config.num_layers 839 | self.multi_query_group_num = config.multi_query_group_num 840 | self.kv_channels = config.kv_channels 841 | 842 | # Rotary positional embeddings 843 | self.seq_length = config.seq_length 844 | rotary_dim = ( 845 | config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels 846 | ) 847 | 848 | self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, rope_ratio=config.rope_ratio, 849 | original_impl=config.original_rope, 850 | device=device, dtype=config.torch_dtype) 851 | self.encoder = init_method(GLMTransformer, config, **init_kwargs) 852 | self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, 853 | dtype=config.torch_dtype, **init_kwargs) 854 | 855 | def get_input_embeddings(self): 856 | return self.embedding.word_embeddings 857 | 858 | def set_input_embeddings(self, value): 859 | self.embedding.word_embeddings = value 860 | 861 | def forward( 862 | self, 863 | input_ids, 864 | position_ids: Optional[torch.Tensor] = None, 865 | attention_mask: Optional[torch.BoolTensor] = None, 866 | full_attention_mask: Optional[torch.BoolTensor] = None, 867 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 868 | inputs_embeds: Optional[torch.Tensor] = None, 869 | use_cache: Optional[bool] = None, 870 | output_attentions: Optional[bool] = None, 871 | output_hidden_states: Optional[bool] = None, 872 | return_dict: Optional[bool] = None, 873 | ): 874 | output_hidden_states = ( 875 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 876 | ) 877 | use_cache = use_cache if use_cache is not None else self.config.use_cache 878 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 879 | 880 | batch_size, seq_length = input_ids.shape 881 | 882 | if inputs_embeds is None: 883 | inputs_embeds = self.embedding(input_ids) 884 | 885 | if full_attention_mask is None: 886 | if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): 887 | full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) 888 | 889 | # Rotary positional embeddings 890 | rotary_pos_emb = self.rotary_pos_emb(self.seq_length) 891 | if position_ids is not None: 892 | rotary_pos_emb = rotary_pos_emb[position_ids] 893 | else: 894 | rotary_pos_emb = rotary_pos_emb[None, 
:seq_length] 895 | 896 | # Run encoder. 897 | hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( 898 | inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, 899 | kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states 900 | ) 901 | if presents is not None and type(presents) is torch.Tensor: 902 | presents = presents.split(1, dim=0) 903 | presents = list(presents) 904 | presents = [list(x.squeeze(0).split(1, dim=0)) for x in presents] 905 | presents = [tuple([x.squeeze(0) for x in y]) for y in presents] 906 | presents = tuple(presents) 907 | 908 | if not return_dict: 909 | return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) 910 | 911 | return BaseModelOutputWithPast( 912 | last_hidden_state=hidden_states, 913 | past_key_values=presents, 914 | hidden_states=all_hidden_states, 915 | attentions=all_self_attentions, 916 | ) 917 | 918 | 919 | class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): 920 | def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): 921 | config._attn_implementation = "eager" 922 | super().__init__(config) 923 | 924 | self.max_sequence_length = config.max_length 925 | self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) 926 | self.config = config 927 | 928 | def _update_model_kwargs_for_generation( 929 | self, 930 | outputs: ModelOutput, 931 | model_kwargs: Dict[str, Any], 932 | is_encoder_decoder: bool = False, 933 | ) -> Dict[str, Any]: 934 | # update past_key_values 935 | cache_name, cache = self._extract_past_from_model_output(outputs) 936 | model_kwargs[cache_name] = cache 937 | 938 | # update attention mask 939 | if "attention_mask" in model_kwargs: 940 | attention_mask = model_kwargs["attention_mask"] 941 | model_kwargs["attention_mask"] = torch.cat( 942 | [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 943 | ) 944 | 945 | # update position ids 946 | if "position_ids" in model_kwargs: 947 | position_ids = model_kwargs["position_ids"] 948 | new_position_id = position_ids[..., -1:].clone() 949 | new_position_id += 1 950 | model_kwargs["position_ids"] = torch.cat( 951 | [position_ids, new_position_id], dim=-1 952 | ) 953 | 954 | model_kwargs["is_first_forward"] = False 955 | return model_kwargs 956 | 957 | def prepare_inputs_for_generation( 958 | self, 959 | input_ids: torch.LongTensor, 960 | past_key_values: Optional[torch.Tensor] = None, 961 | attention_mask: Optional[torch.Tensor] = None, 962 | position_ids: Optional[torch.Tensor] = None, 963 | use_cache: Optional[bool] = None, 964 | is_first_forward: bool = True, 965 | **kwargs 966 | ) -> dict: 967 | # only last token for input_ids if past is not None 968 | if position_ids is None: 969 | position_ids = self.get_position_ids(input_ids, device=input_ids.device) 970 | if not is_first_forward: 971 | if past_key_values is not None: 972 | position_ids = position_ids[..., -1:] 973 | input_ids = input_ids[:, -1:] 974 | return { 975 | "input_ids": input_ids, 976 | "past_key_values": past_key_values, 977 | "position_ids": position_ids, 978 | "attention_mask": attention_mask, 979 | "return_last_logit": True, 980 | "use_cache": use_cache 981 | } 982 | 983 | def forward( 984 | self, 985 | input_ids: Optional[torch.Tensor] = None, 986 | position_ids: Optional[torch.Tensor] = None, 987 | attention_mask: Optional[torch.Tensor] = None, 988 | past_key_values: Optional[Tuple[torch.FloatTensor]] = None, 989 | 
inputs_embeds: Optional[torch.Tensor] = None, 990 | labels: Optional[torch.Tensor] = None, 991 | use_cache: Optional[bool] = None, 992 | output_attentions: Optional[bool] = None, 993 | output_hidden_states: Optional[bool] = None, 994 | return_dict: Optional[bool] = None, 995 | return_last_logit: Optional[bool] = False, 996 | ): 997 | use_cache = use_cache if use_cache is not None else self.config.use_cache 998 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 999 | 1000 | transformer_outputs = self.transformer( 1001 | input_ids=input_ids, 1002 | position_ids=position_ids, 1003 | attention_mask=attention_mask, 1004 | past_key_values=past_key_values, 1005 | inputs_embeds=inputs_embeds, 1006 | use_cache=use_cache, 1007 | output_hidden_states=output_hidden_states, 1008 | return_dict=return_dict, 1009 | ) 1010 | 1011 | hidden_states = transformer_outputs[0] 1012 | if return_last_logit: 1013 | hidden_states = hidden_states[:, -1:] 1014 | lm_logits = self.transformer.output_layer(hidden_states) 1015 | 1016 | loss = None 1017 | if labels is not None: 1018 | lm_logits = lm_logits.to(torch.float32) 1019 | 1020 | # Shift so that tokens < n predict n 1021 | shift_logits = lm_logits[..., :-1, :].contiguous() 1022 | shift_labels = labels[..., 1:].contiguous() 1023 | # Flatten the tokens 1024 | loss_fct = CrossEntropyLoss(ignore_index=-100) 1025 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 1026 | 1027 | lm_logits = lm_logits.to(hidden_states.dtype) 1028 | loss = loss.to(hidden_states.dtype) 1029 | 1030 | if not return_dict: 1031 | output = (lm_logits,) + transformer_outputs[1:] 1032 | return ((loss,) + output) if loss is not None else output 1033 | 1034 | return CausalLMOutputWithPast( 1035 | loss=loss, 1036 | logits=lm_logits, 1037 | past_key_values=transformer_outputs.past_key_values, 1038 | hidden_states=transformer_outputs.hidden_states, 1039 | attentions=transformer_outputs.attentions, 1040 | ) 1041 | 1042 | @staticmethod 1043 | def _reorder_cache( 1044 | past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor 1045 | ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: 1046 | """ 1047 | This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or 1048 | [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 1049 | beam_idx at every generation step. 1050 | 1051 | Output shares the same memory storage as `past`. 
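        Concretely, each layer's key and value tensors are re-gathered with ``index_select``
        along dimension 0 using ``beam_idx``, as shown in the implementation below.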
1052 | """ 1053 | return tuple( 1054 | ( 1055 | layer_past[0].index_select(0, beam_idx.to(layer_past[0].device)), 1056 | layer_past[1].index_select(0, beam_idx.to(layer_past[1].device)), 1057 | ) 1058 | for layer_past in past 1059 | ) 1060 | 1061 | 1062 | class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): 1063 | def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): 1064 | super().__init__(config) 1065 | 1066 | self.num_labels = config.num_labels 1067 | self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) 1068 | 1069 | self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=config.torch_dtype) 1070 | if config.classifier_dropout is not None: 1071 | self.dropout = nn.Dropout(config.classifier_dropout) 1072 | else: 1073 | self.dropout = None 1074 | self.config = config 1075 | 1076 | def forward( 1077 | self, 1078 | input_ids: Optional[torch.LongTensor] = None, 1079 | position_ids: Optional[torch.LongTensor] = None, 1080 | attention_mask: Optional[torch.Tensor] = None, 1081 | full_attention_mask: Optional[torch.Tensor] = None, 1082 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 1083 | inputs_embeds: Optional[torch.LongTensor] = None, 1084 | labels: Optional[torch.LongTensor] = None, 1085 | use_cache: Optional[bool] = None, 1086 | output_attentions: Optional[bool] = None, 1087 | output_hidden_states: Optional[bool] = None, 1088 | return_dict: Optional[bool] = None, 1089 | ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: 1090 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1091 | 1092 | transformer_outputs = self.transformer( 1093 | input_ids=input_ids, 1094 | position_ids=position_ids, 1095 | attention_mask=attention_mask, 1096 | full_attention_mask=full_attention_mask, 1097 | past_key_values=past_key_values, 1098 | inputs_embeds=inputs_embeds, 1099 | use_cache=use_cache, 1100 | output_attentions=output_attentions, 1101 | output_hidden_states=output_hidden_states, 1102 | return_dict=return_dict, 1103 | ) 1104 | 1105 | hidden_states = transformer_outputs[0] 1106 | pooled_hidden_states = hidden_states[:, -1] 1107 | if self.dropout is not None: 1108 | pooled_hidden_states = self.dropout(pooled_hidden_states) 1109 | logits = self.classifier_head(pooled_hidden_states) 1110 | 1111 | loss = None 1112 | if labels is not None: 1113 | if self.config.problem_type is None: 1114 | if self.num_labels == 1: 1115 | self.config.problem_type = "regression" 1116 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 1117 | self.config.problem_type = "single_label_classification" 1118 | else: 1119 | self.config.problem_type = "multi_label_classification" 1120 | 1121 | if self.config.problem_type == "regression": 1122 | loss_fct = MSELoss() 1123 | if self.num_labels == 1: 1124 | loss = loss_fct(logits.squeeze().float(), labels.squeeze()) 1125 | else: 1126 | loss = loss_fct(logits.float(), labels) 1127 | elif self.config.problem_type == "single_label_classification": 1128 | loss_fct = CrossEntropyLoss() 1129 | loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) 1130 | elif self.config.problem_type == "multi_label_classification": 1131 | loss_fct = BCEWithLogitsLoss() 1132 | loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) 1133 | 1134 | if not return_dict: 1135 | output = (logits,) + transformer_outputs[1:] 1136 | return ((loss,) + output) if loss is not 
None else output 1137 | 1138 | return SequenceClassifierOutputWithPast( 1139 | loss=loss, 1140 | logits=logits, 1141 | past_key_values=transformer_outputs.past_key_values, 1142 | hidden_states=transformer_outputs.hidden_states, 1143 | attentions=transformer_outputs.attentions, 1144 | ) -------------------------------------------------------------------------------- /attenscore/run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 3 | python main.py \ 4 | --id 287 \ 5 | --save_path ../output/attenscore/qwen \ 6 | --model qwen-72b 7 | 8 | # export CUDA_VISIBLE_DEVICES=0,1 9 | python main.py \ 10 | --id 287 \ 11 | --save_path ../output/attenscore/llama31 \ 12 | --model llama31-8b 13 | 14 | # export CUDA_VISIBLE_DEVICES=0,1 15 | python main.py \ 16 | --id 287 \ 17 | --save_path ../output/attenscore/glm \ 18 | --model glm-9b 19 | -------------------------------------------------------------------------------- /eval_system_bench.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import concurrent.futures 4 | import os 5 | import random 6 | import traceback 7 | import sys 8 | import threading 9 | import importlib 10 | from models import * 11 | from utils import * 12 | 13 | class SystemBenchEval(): 14 | def __init__(self, infer_model_name, infer_with_gt_history, eval_model_name, eval_dataset_path, output_dir): 15 | self.infer_with_gt_history = infer_with_gt_history 16 | if self.infer_with_gt_history: 17 | self.output_dir = os.path.join(output_dir, f"{infer_model_name}_with_gt_history") 18 | else: 19 | self.output_dir = os.path.join(output_dir, infer_model_name) 20 | os.makedirs(self.output_dir, exist_ok=True) 21 | self.infer_model_name = infer_model_name 22 | self.infer_model = self.get_model_class(infer_model_name) 23 | self.eval_model_name = eval_model_name 24 | self.eval_model = self.get_model_class(eval_model_name) 25 | self.eval_dataset_path = eval_dataset_path 26 | 27 | self.infer_output_path = os.path.join(self.output_dir, f"{infer_model_name}_infer.json") 28 | self.eval_output_path = os.path.join(self.output_dir, f"{infer_model_name}_eval.json") 29 | self.analysis_result_output_path = os.path.join(self.output_dir, f"{infer_model_name}_analysis.xlsx") 30 | 31 | def get_model_class(self, model_type): 32 | module_name = f"models.{model_type}" 33 | class_name = f"{model_type}" 34 | print(module_name, class_name) 35 | try: 36 | module = importlib.import_module(module_name) 37 | model_class = getattr(module, class_name) 38 | return model_class() 39 | except (ImportError, AttributeError) as e: 40 | raise ValueError(f"Model type '{model_type}' is not defined: {e}") 41 | 42 | def load_examples(self, dataset_filepath): 43 | try: 44 | datas = json.load(open(dataset_filepath, encoding="utf-8")) 45 | except: 46 | return list() 47 | return datas 48 | 49 | def do_infer(self, data, retry_time=10): 50 | messages = list() 51 | for mess in data["messages"]: 52 | if mess["role"] in {"system", "user"}: 53 | messages.append(mess) 54 | else: 55 | retry_i = 0 56 | while retry_i < retry_time: 57 | try: 58 | response = self.infer_model(messages) 59 | assert response is not None and isinstance(response, str) 60 | messages.append({"role": "assistant", "content": response}) 61 | break 62 | except Exception as e: 63 | traceback.print_exc() 64 | retry_i += 1 65 | else: 66 | raise 67 | 68 | assert len(messages) == len(data["messages"]) 69 | 
data["infer_model"] = self.infer_model_name 70 | data["infer_results"] = messages 71 | 72 | return data 73 | 74 | def do_eval(self, data, retry_time=10): 75 | eval_results = dict() 76 | messages = list() 77 | for message in data["infer_results"]: 78 | if message["role"] in {"system", "user"}: 79 | messages.append(message) 80 | else: 81 | messages.append(message) 82 | criteria = data["prompt_infos"][messages[-2]["content"]]["criteria"] 83 | eval_pattern = get_eval_pattern(messages=messages, criteria=criteria) 84 | 85 | retry_i = 0 86 | while retry_i < retry_time: 87 | try: 88 | eval_response = self.eval_model([{"role": "user", "content": eval_pattern}], temperature=0 if retry_i < retry_time // 2 else 0.5).strip() 89 | eval_response_js = eval(eval_response[7:-3]) 90 | assert "评判理由" in eval_response_js and "评判结果" in eval_response_js and isinstance(eval_response_js["评判结果"], dict), eval_response 91 | 92 | assert set([int(n) for n in eval_response_js["评判结果"].keys()]) == set([int(ck) for ck in criteria]), "-" * 50 + eval_pattern + "\n" + "-" * 50 + json.dumps(eval_response_js, ensure_ascii=False, indent=2) + "\n" + "-" * 50 + json.dumps(criteria, ensure_ascii=False, indent=2) + "-" * 50 93 | all(value in {"是", "否"} for value in eval_response_js["评判结果"].values()) 94 | eval_response_js["eval_pattern"] = eval_pattern 95 | eval_response_js["response"] = message["content"] 96 | eval_response_js["criteria"] = criteria 97 | eval_response_js["retry_time"] = retry_i 98 | 99 | for cr in criteria: 100 | check_res = character_count(criteria[cr]["criteria_content"], message["content"]) 101 | if check_res == -1: 102 | continue 103 | else: 104 | if check_res is True: 105 | eval_response_js["评判结果"][cr] = "是" 106 | else: 107 | eval_response_js["评判结果"][cr] = "否" 108 | 109 | eval_results[messages[-2]["content"]] = eval_response_js 110 | break 111 | except Exception as e: 112 | print(eval_response) 113 | print(json.dumps(criteria, ensure_ascii=False, indent=2)) 114 | print(eval_pattern) 115 | traceback.print_exc() 116 | retry_i += 1 117 | else: 118 | data["eval_results"] = None 119 | return data 120 | 121 | data["eval_results"] = eval_results 122 | return data 123 | 124 | def execute(self, do_infer=False, do_eval=False, max_threads=20): 125 | def worker(task_type, input_path, output_path): 126 | datas = self.load_examples(input_path) 127 | total_count = len(datas) 128 | 129 | cache_filepath = output_path + "_cache.json" 130 | if os.path.exists(cache_filepath): 131 | cache_datas = [data for data in self.load_examples(cache_filepath) if data[f"{task_type}_results"] is not None] 132 | else: 133 | cache_datas = list() 134 | completed = len(cache_datas) 135 | 136 | print(f"{task_type}: {completed} / {total_count}") 137 | cache_data_ids = [data["system_id"] for data in cache_datas] 138 | rest_datas = [data for data in datas if data["system_id"] not in cache_data_ids] 139 | random.shuffle(rest_datas) 140 | 141 | lock = threading.Lock() 142 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor: 143 | tasks = {executor.submit(getattr(self, f"do_{task_type}"), rest_data) for rest_data in rest_datas} 144 | for future in concurrent.futures.as_completed(tasks): 145 | result = future.result() 146 | try: 147 | result = future.result() 148 | except Exception as e: 149 | traceback.print_exc() 150 | else: 151 | with lock: 152 | cache_datas.append(result) 153 | json.dump(cache_datas, open(cache_filepath, "w", encoding="utf-8"), ensure_ascii=False, indent=2) 154 | if result[f"{task_type}_results"] is not 
None: 155 | completed += 1 156 | 157 | print(f"{task_type}: {completed} / {total_count}") 158 | 159 | json.dump(cache_datas, open(output_path, "w", encoding="utf-8"), ensure_ascii=False, indent=2) 160 | assert total_count - completed == 0, f"失败数量:{total_count - completed}" 161 | 162 | if do_infer: 163 | print("do infer") 164 | worker("infer", self.eval_dataset_path, self.infer_output_path) 165 | print("-" * 50) 166 | if do_eval: 167 | print("do eval") 168 | worker("eval", self.infer_output_path, self.eval_output_path) 169 | analysis_eval_results(eval_result_filepath=self.eval_output_path, analysis_eval_output_path=self.analysis_result_output_path) 170 | 171 | 172 | if __name__ == "__main__": 173 | parser = argparse.ArgumentParser() 174 | parser.add_argument("--infer_model_name", type=str) 175 | parser.add_argument("--infer_with_gt_history", type=str2bool, default=False, help="true or false") 176 | parser.add_argument("--output_dir", type=str, default="output") 177 | parser.add_argument("--max_threads", type=int, default=100) 178 | args = parser.parse_args() 179 | 180 | eval_model_name = "gpt4o" 181 | eval_dataset_path = "datas/system_benchmark_eval_datas.json" 182 | output_dir = args.output_dir 183 | 184 | system_bench_eval = SystemBenchEval(infer_model_name=args.infer_model_name, infer_with_gt_history=args.infer_with_gt_history, eval_model_name=eval_model_name, eval_dataset_path=eval_dataset_path, output_dir=output_dir) 185 | system_bench_eval.execute(do_infer=True, do_eval=True, max_threads=args.max_threads) 186 | -------------------------------------------------------------------------------- /eval_system_bench_with_gt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import concurrent.futures 4 | import os 5 | import random 6 | import traceback 7 | import sys 8 | import threading 9 | import importlib 10 | from models import * 11 | from utils import * 12 | 13 | class SystemBenchEval(): 14 | def __init__(self, infer_model_name, infer_with_gt_history, eval_model_name, eval_dataset_path, output_dir): 15 | self.infer_with_gt_history = infer_with_gt_history 16 | if self.infer_with_gt_history: 17 | self.output_dir = os.path.join(output_dir, f"{infer_model_name}_with_gt_history") 18 | else: 19 | self.output_dir = os.path.join(output_dir, infer_model_name) 20 | os.makedirs(self.output_dir, exist_ok=True) 21 | self.infer_model_name = infer_model_name 22 | self.infer_model = self.get_model_class(infer_model_name) 23 | self.eval_model_name = eval_model_name 24 | self.eval_model = self.get_model_class(eval_model_name) 25 | self.eval_dataset_path = eval_dataset_path 26 | 27 | self.infer_output_path = os.path.join(self.output_dir, f"{infer_model_name}_infer.json") 28 | self.eval_output_path = os.path.join(self.output_dir, f"{infer_model_name}_eval.json") 29 | self.analysis_result_output_path = os.path.join(self.output_dir, f"{infer_model_name}_analysis.xlsx") 30 | 31 | def get_model_class(self, model_type): 32 | module_name = f"models.{model_type}" 33 | class_name = f"{model_type}" 34 | print(module_name, class_name) 35 | try: 36 | module = importlib.import_module(module_name) 37 | model_class = getattr(module, class_name) 38 | return model_class() 39 | except (ImportError, AttributeError) as e: 40 | raise ValueError(f"Model type '{model_type}' is not defined: {e}") 41 | 42 | def load_examples(self, dataset_filepath): 43 | try: 44 | datas = json.load(open(dataset_filepath, encoding="utf-8")) 45 | except: 46 | return list() 47 
| return datas 48 | 49 | def do_infer(self, data, retry_time=10): 50 | all_messages = list() 51 | for index, mess in enumerate(data["messages"]): 52 | if mess["role"] in {"system", "user"}: 53 | pass 54 | else: 55 | messages = data["messages"][0:index] 56 | retry_i = 0 57 | while retry_i < retry_time: 58 | try: 59 | response = self.infer_model(messages) 60 | assert response is not None and isinstance(response, str) 61 | messages.append({"role": "assistant", "content": response}) 62 | all_messages.append(messages) 63 | break 64 | except Exception as e: 65 | traceback.print_exc() 66 | retry_i += 1 67 | else: 68 | raise 69 | 70 | assert len(all_messages) == len(data["messages"]) // 2 71 | data["infer_model"] = self.infer_model_name 72 | data["infer_results"] = all_messages 73 | 74 | return data 75 | 76 | def do_eval(self, data, retry_time=10): 77 | eval_results = dict() 78 | for messages in data["infer_results"]: 79 | prompt = messages[-2]["content"] 80 | answer = messages[-1]["content"] 81 | criteria = data["prompt_infos"][messages[-2]["content"]]["criteria"] 82 | eval_pattern = get_eval_pattern(messages=messages, criteria=criteria) 83 | 84 | retry_i = 0 85 | while retry_i < retry_time: 86 | try: 87 | eval_response = self.eval_model([{"role": "user", "content": eval_pattern}], temperature=0).strip() 88 | eval_response_js = eval(eval_response[7:-3]) 89 | assert "评判理由" in eval_response_js and "评判结果" in eval_response_js and isinstance(eval_response_js["评判结果"], dict), eval_response 90 | 91 | assert set([int(n) for n in eval_response_js["评判结果"].keys()]) == set([int(ck) for ck in criteria]), "-" * 50 + eval_pattern + "\n" + "-" * 50 + json.dumps(eval_response_js, ensure_ascii=False, indent=2) + "\n" + "-" * 50 + json.dumps(criteria, ensure_ascii=False, indent=2) + "-" * 50 92 | all(value in {"是", "否"} for value in eval_response_js["评判结果"].values()) 93 | eval_response_js["eval_pattern"] = eval_pattern 94 | eval_response_js["response"] = answer 95 | eval_response_js["criteria"] = criteria 96 | eval_response_js["retry_time"] = retry_i 97 | 98 | for cr in criteria: 99 | check_res = character_count(criteria[cr]["criteria_content"], answer) 100 | if check_res == -1: 101 | continue 102 | else: 103 | if check_res is True: 104 | eval_response_js["评判结果"][cr] = "是" 105 | else: 106 | eval_response_js["评判结果"][cr] = "否" 107 | 108 | eval_results[messages[-2]["content"]] = eval_response_js 109 | break 110 | except Exception as e: 111 | print(eval_response) 112 | print(json.dumps(criteria, ensure_ascii=False, indent=2)) 113 | print(eval_pattern) 114 | traceback.print_exc() 115 | retry_i += 1 116 | else: 117 | data["eval_results"] = None 118 | return data 119 | 120 | data["eval_results"] = eval_results 121 | return data 122 | 123 | def execute(self, do_infer=False, do_eval=False, max_threads=20): 124 | def worker(task_type, input_path, output_path): 125 | datas = self.load_examples(input_path) 126 | total_count = len(datas) 127 | 128 | cache_filepath = output_path + "_cache.json" 129 | if os.path.exists(cache_filepath): 130 | cache_datas = [data for data in self.load_examples(cache_filepath) if data[f"{task_type}_results"] is not None] 131 | else: 132 | cache_datas = list() 133 | completed = len(cache_datas) 134 | 135 | print(f"{task_type}: {completed} / {total_count}") 136 | cache_data_ids = [data["system_id"] for data in cache_datas] 137 | rest_datas = [data for data in datas if data["system_id"] not in cache_data_ids] 138 | random.shuffle(rest_datas) 139 | 140 | lock = threading.Lock() 141 | with 
concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor: 142 | tasks = {executor.submit(getattr(self, f"do_{task_type}"), rest_data) for rest_data in rest_datas} 143 | for future in concurrent.futures.as_completed(tasks): 144 | result = future.result() 145 | try: 146 | result = future.result() 147 | except Exception as e: 148 | traceback.print_exc() 149 | else: 150 | with lock: 151 | cache_datas.append(result) 152 | json.dump(cache_datas, open(cache_filepath, "w", encoding="utf-8"), ensure_ascii=False, indent=2) 153 | if result[f"{task_type}_results"] is not None: 154 | completed += 1 155 | 156 | print(f"{task_type}: {completed} / {total_count}") 157 | 158 | json.dump(cache_datas, open(output_path, "w", encoding="utf-8"), ensure_ascii=False, indent=2) 159 | assert total_count - completed == 0, f"失败数量:{total_count - completed}" 160 | 161 | if do_infer: 162 | print("do infer") 163 | worker("infer", self.eval_dataset_path, self.infer_output_path) 164 | print("-" * 50) 165 | if do_eval: 166 | print("do eval") 167 | worker("eval", self.infer_output_path, self.eval_output_path) 168 | analysis_eval_results(eval_result_filepath=self.eval_output_path, analysis_eval_output_path=self.analysis_result_output_path) 169 | 170 | 171 | if __name__ == "__main__": 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument("--infer_model_name", type=str) 174 | parser.add_argument("--infer_with_gt_history", type=str2bool, default=False, help="true or false") 175 | parser.add_argument("--output_dir", type=str, default="output") 176 | parser.add_argument("--max_threads", type=int, default=100) 177 | args = parser.parse_args() 178 | 179 | eval_model_name = "gpt4o" 180 | eval_dataset_path = "datas/system_benchmark_eval_datas.json" 181 | output_dir = args.output_dir 182 | 183 | system_bench_eval = SystemBenchEval(infer_model_name=args.infer_model_name, infer_with_gt_history=args.infer_with_gt_history, eval_model_name=eval_model_name, eval_dataset_path=eval_dataset_path, output_dir=output_dir) 184 | system_bench_eval.execute(do_infer=True, do_eval=True, max_threads=args.max_threads) 185 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/models/__init__.py -------------------------------------------------------------------------------- /models/claude35_opus.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from openai import OpenAI 3 | 4 | 5 | class claude35_opus(object): 6 | def __init__(self) -> None: 7 | API_BASE = "http:///v1" 8 | API_KEY = "" 9 | 10 | self.model = OpenAI( 11 | api_key=API_KEY, 12 | base_url=API_BASE 13 | ) 14 | 15 | def __call__(self, messages): 16 | for i in range(100): 17 | try: 18 | response = self.model.chat.completions.create( 19 | model="claude-3-opus-20240229", 20 | messages=messages 21 | ) 22 | print(response.json()) 23 | return response.choices[0].message.content 24 | except Exception as e: 25 | print(e) 26 | continue 27 | else: 28 | return None 29 | 30 | if __name__ == "__main__": 31 | model = claude35_opus() 32 | messages = [{"role": "system", "content": "你的名字叫百灵鸟,你擅长给人看病"}, {"role": "user", "content": "你叫什么名字"}] 33 | print(model(messages)) 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /models/deepseek.py: 
-------------------------------------------------------------------------------- 1 | import openai 2 | from openai import OpenAI 3 | 4 | 5 | class deepseek(object): 6 | def __init__(self) -> None: 7 | API_BASE = "https://api.deepseek.com" 8 | API_KEY = "" 9 | 10 | self.model = OpenAI( 11 | api_key=API_KEY, 12 | base_url=API_BASE 13 | ) 14 | 15 | def __call__(self, messages): 16 | response = self.model.chat.completions.create( 17 | model="deepseek-chat", 18 | messages=messages 19 | ) 20 | return response.choices[0].message.content 21 | 22 | 23 | if __name__ == "__main__": 24 | model = deepseek() 25 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}] 26 | print(model(messages)) 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /models/ernie4.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import random 4 | import requests 5 | import json 6 | 7 | random.seed(1234) 8 | 9 | 10 | class ernie4(object): 11 | def __init__(self) -> None: 12 | pass 13 | 14 | def __call__(self, messages, system=None, temperature=0.95, finish_try=2, keys=None): 15 | if isinstance(messages, str): 16 | messages = [{'role': 'user', 'content': messages}] 17 | if messages[0]["role"] in {"system", "user_system"}: 18 | system = messages[0]["content"] 19 | messages = messages[1:] 20 | 21 | assert isinstance(messages, list) 22 | url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-8k-0613?access_token=" + self.get_access_token(keys) 23 | payload = { 24 | "messages": messages, 25 | "top_p":0.8, 26 | "penalty_score": 1, 27 | "temperature": 0.95, 28 | "disable_search": False, 29 | "enable_citation": False} 30 | if system: 31 | payload["system"] = system 32 | payload = json.dumps(payload) 33 | headers = {'Content-Type': 'application/json'} 34 | while True: 35 | try: 36 | response = requests.request("POST", url, headers=headers, data=payload) 37 | assert response.status_code == 200 38 | response = json.loads(response.text) 39 | print(f"response:{response}") 40 | 41 | response = response["result"] 42 | finish_try -= 1 43 | 44 | return response 45 | except Exception as e: 46 | print(f"【response】:{response}\t【Error】:{e}", flush=True) 47 | try: 48 | error_code = json.loads(e.http_body)['error']['code'] 49 | if error_code in ('billing_not_active', 'context_length_exceeded'): 50 | return '', error_code 51 | except: 52 | pass 53 | return None 54 | 55 | 56 | def get_access_token(self, keys=None): 57 | """ 58 | 使用 API Key,Secret Key 获取access_token,替换下列示例中的应用API Key、应用Secret Key 59 | """ 60 | API_SECRET_KEYs = [["client_id", "client_secret"]] 61 | url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={API_SECRET_KEYs[0][0]}&client_secret={API_SECRET_KEYs[0][1]}" 62 | payload = json.dumps("") 63 | headers = { 64 | 'Content-Type': 'application/json', 65 | 'Accept': 'application/json' 66 | } 67 | response = requests.request("POST", url, headers=headers, data=payload) 68 | return response.json().get("access_token") 69 | 70 | 71 | if __name__ == "__main__": 72 | model = ernie4() 73 | print(model([{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}])) 74 | -------------------------------------------------------------------------------- /models/glm4.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import 
random 4 | from zhipuai import ZhipuAI 5 | 6 | class glm4(object): 7 | def __init__(self, model="glm-4-0520") -> None: 8 | self.model = model 9 | 10 | def __call__(self, messages): 11 | client = ZhipuAI(api_key="") 12 | response = client.chat.completions.create( 13 | model=self.model, # 填写需要调用的模型编码 14 | messages=messages, 15 | stream=False, 16 | ) 17 | return response.choices[0].message.content 18 | 19 | if __name__ == "__main__": 20 | model = glm4() 21 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}] 22 | print(model(messages)) 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /models/glm_9b_client.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoModelForCausalLM, AutoTokenizer 3 | from transformers.generation.utils import GenerationConfig 4 | from openai import OpenAI 5 | import datetime 6 | device = "cuda" # the device to load the model onto 7 | 8 | class glm_9b_client(): 9 | def __init__(self): 10 | API_BASE = "http://localhost:33618/v1" 11 | API_KEY = "custom-key" 12 | 13 | self.model = OpenAI( 14 | api_key=API_KEY, 15 | base_url=API_BASE 16 | ) 17 | 18 | # https://huggingface.co/THUDM/glm-4-9b-chat/blob/main/generation_config.json 19 | self.kwargs = { 20 | "temperature": 0.8, 21 | "top_p": 0.8, 22 | "max_tokens": 8192, 23 | } 24 | 25 | def __call__(self, messages): 26 | for i in range(100): 27 | try: 28 | response = self.model.chat.completions.create( 29 | model="THUDM/glm-4-9b-chat", 30 | messages=messages, 31 | **self.kwargs, 32 | ) 33 | return response.choices[0].message.content 34 | except Exception as e: 35 | print(e) 36 | continue 37 | else: 38 | return None 39 | 40 | 41 | if __name__ == "__main__": 42 | messages = [{'role': 'system', 'content': '你是一个计算器,只允许回答计算问题,其他问题需要拒绝回答。'}, 43 | {'role': 'user', 'content': '太阳系内有几颗大行星?'}] 44 | 45 | glm_9b_model = glm_9b_client() 46 | print(glm_9b_model(messages)) 47 | 48 | -------------------------------------------------------------------------------- /models/gpt35.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import random 4 | # import openai 5 | from openai import OpenAI 6 | 7 | random.seed(1234) 8 | 9 | class gpt35(object): 10 | def __init__(self, model_name="gpt-3.5-turbo-1106", key="") -> None: 11 | self.client = OpenAI(api_key=key) 12 | self.model_name = model_name 13 | print(f"model_name: {self.model_name}") 14 | 15 | def __call__(self, query, retry=10, temperature=None): 16 | if isinstance(query, str): 17 | messages = [{"role":"user","content": query}] 18 | elif isinstance(query, list): 19 | messages = query 20 | else: 21 | raise ValueError("query must be str or list") 22 | i = 0 23 | while i < retry: 24 | try: 25 | if temperature is None: 26 | response = self.client.chat.completions.create( 27 | model = self.model_name, 28 | messages=messages, 29 | ) 30 | else: 31 | response = self.client.chat.completions.create( 32 | model = self.model_name, 33 | messages=messages, 34 | temperature=temperature 35 | ) 36 | 37 | result = response.choices[0].message.content 38 | 39 | assert isinstance(result, str) and response.choices[0].finish_reason == "stop" 40 | return result 41 | except Exception as e: 42 | print(e) 43 | else: 44 | raise 45 | 46 | 47 | if __name__ == "__main__": 48 | gpt = gpt35() 49 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", 
"content": "你叫什么名字"}] 50 | print(gpt(messages)) 51 | 52 | 53 | -------------------------------------------------------------------------------- /models/gpt4_turbo_0409.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import random 4 | # import openai 5 | from openai import OpenAI 6 | 7 | random.seed(1234) 8 | 9 | class gpt4_turbo_0409(object): 10 | def __init__(self, model_name="gpt-4-turbo-2024-04-09", key="") -> None: 11 | self.client = OpenAI(api_key=key) 12 | self.model_name = model_name 13 | print(f"model_name: {self.model_name}") 14 | 15 | def __call__(self, query, retry=10, temperature=None): 16 | if isinstance(query, str): 17 | messages = [{"role":"user","content": query}] 18 | elif isinstance(query, list): 19 | messages = query 20 | else: 21 | raise ValueError("query must be str or list") 22 | i = 0 23 | while i < retry: 24 | try: 25 | if temperature is None: 26 | response = self.client.chat.completions.create( 27 | model = self.model_name, 28 | messages=messages, 29 | ) 30 | else: 31 | response = self.client.chat.completions.create( 32 | model = self.model_name, 33 | messages=messages, 34 | temperature=temperature 35 | ) 36 | 37 | result = response.choices[0].message.content 38 | 39 | assert isinstance(result, str) and response.choices[0].finish_reason == "stop" 40 | return result 41 | except Exception as e: 42 | print(e) 43 | else: 44 | raise 45 | 46 | if __name__ == "__main__": 47 | gpt4 = gpt4_turbo_0409() 48 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}] 49 | print(gpt4(messages)) 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /models/gpt4o.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import random 4 | # import openai 5 | from openai import OpenAI 6 | 7 | random.seed(1234) 8 | 9 | class gpt4o(object): 10 | def __init__(self, model_name="gpt-4o-2024-05-13", key="") -> None: 11 | self.client = OpenAI(api_key=key) 12 | self.model_name = model_name 13 | print(f"model_name: {self.model_name}") 14 | 15 | def __call__(self, query, retry=10, temperature=None): 16 | if isinstance(query, str): 17 | messages = [{"role":"user","content": query}] 18 | elif isinstance(query, list): 19 | messages = query 20 | else: 21 | raise ValueError("query must be str or list") 22 | i = 0 23 | while i < retry: 24 | try: 25 | if temperature is None: 26 | response = self.client.chat.completions.create( 27 | model = self.model_name, 28 | messages=messages, 29 | ) 30 | else: 31 | response = self.client.chat.completions.create( 32 | model = self.model_name, 33 | messages=messages, 34 | temperature=temperature 35 | ) 36 | 37 | result = response.choices[0].message.content 38 | 39 | assert isinstance(result, str) and response.choices[0].finish_reason == "stop" 40 | return result 41 | except Exception as e: 42 | print(e) 43 | else: 44 | raise 45 | 46 | if __name__ == "__main__": 47 | gpt4 = gpt4o() 48 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}] 49 | print(gpt4(messages)) 50 | 51 | 52 | -------------------------------------------------------------------------------- /models/llama3_1_70b.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import random 3 | import json 4 | import os 5 | 6 | 7 | IP_with_NAMES = [ 8 | ["enter your server host 
name", "llama3-1-70b-Instruct"], 9 | ] 10 | 11 | model_name_config = {name: ip for name, ip in IP_with_NAMES} 12 | 13 | class llama3_1_70b(): 14 | def __init__(self, ip_with_name=None, model_name=None): 15 | pass 16 | def __call__(self, messages, max_new_tokens=1000, temperature=0.9, ip_with_name=None, model_name=None): 17 | if isinstance(messages, str): 18 | messages = [{'role': 'user', 'content': messages}] 19 | 20 | if not ip_with_name and not model_name: 21 | ip_with_name = random.choice(IP_with_NAMES) 22 | elif model_name: 23 | ip_with_name = model_name_config[model_name], model_name 24 | elif ip_with_name: 25 | if isinstance(ip_with_name[0], str): 26 | ip_with_name = ip_with_name 27 | elif isinstance(ip_with_name[0], list): ## ip是个列表,随机选一个 28 | ip_with_name = random.choice(ip_with_name[0]), ip_with_name[1] 29 | else: 30 | raise ValueError('Please specify ip_with_name or model_name.') 31 | 32 | else: 33 | raise ValueError('Please specify ip_with_name or model_name.') 34 | server, model_name = ip_with_name 35 | 36 | parameters = { 37 | 'max_tokens': max_new_tokens, 38 | 'temperature': temperature, 39 | 'top_k': 5, 40 | 'top_p': 0.85, 41 | 'repetition_penalty': 1.05, 42 | # 'use_beam_search': True 43 | } 44 | 45 | response = requests.post( 46 | url = f'http://{server}/v1/chat/completions', 47 | json = { 48 | 'model': model_name, 49 | 'messages': messages, 50 | **parameters 51 | }, 52 | ) 53 | # print(11111, response.json()) 54 | answer = response.json()['choices'][0]['message']['content'] 55 | 56 | return answer 57 | 58 | 59 | if __name__ == '__main__': 60 | llama3_70b_model = llama3_1_70b() 61 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}] 62 | print(llama3_70b_model(messages)) 63 | -------------------------------------------------------------------------------- /models/llama3_1_8b.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import random 3 | import json 4 | import os 5 | 6 | 7 | IP_with_NAMES = [ 8 | ["", "Llama-3.1-8B-Instruct"], 9 | ] 10 | 11 | 12 | model_name_config = {name: ip for name, ip in IP_with_NAMES} 13 | 14 | class llama3_1_8b(): 15 | def __init__(self, ip_with_name=None, model_name=None): 16 | pass 17 | def __call__(self, messages, max_new_tokens=1000, temperature=0.9, ip_with_name=None, model_name=None): 18 | if isinstance(messages, str): 19 | messages = [{'role': 'user', 'content': messages}] 20 | 21 | if not ip_with_name and not model_name: 22 | ip_with_name = random.choice(IP_with_NAMES) 23 | elif model_name: 24 | ip_with_name = model_name_config[model_name], model_name 25 | elif ip_with_name: 26 | if isinstance(ip_with_name[0], str): 27 | ip_with_name = ip_with_name 28 | elif isinstance(ip_with_name[0], list): ## ip是个列表,随机选一个 29 | ip_with_name = random.choice(ip_with_name[0]), ip_with_name[1] 30 | else: 31 | raise ValueError('Please specify ip_with_name or model_name.') 32 | 33 | else: 34 | raise ValueError('Please specify ip_with_name or model_name.') 35 | server, model_name = ip_with_name 36 | 37 | parameters = { 38 | 'max_tokens': max_new_tokens, 39 | 'temperature': temperature, 40 | 'top_k': 5, 41 | 'top_p': 0.85, 42 | 'repetition_penalty': 1.05, 43 | # 'use_beam_search': True 44 | } 45 | 46 | response = requests.post( 47 | url = f'http://{server}/v1/chat/completions', 48 | json = { 49 | 'model': model_name, 50 | 'messages': messages, 51 | **parameters 52 | }, 53 | ) 54 | # print(11111, response.json()) 55 | answer = 
response.json()['choices'][0]['message']['content'] 56 | 57 | return answer 58 | 59 | 60 | if __name__ == '__main__': 61 | model = llama3_1_8b() 62 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}] 63 | print(model(messages)) 64 | -------------------------------------------------------------------------------- /models/moonshot.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | from tqdm import tqdm 4 | from concurrent.futures import ThreadPoolExecutor 5 | from openai import OpenAI 6 | 7 | 8 | class moonshot(): 9 | def __call__(self, messages): 10 | 11 | client = OpenAI( 12 | api_key = "", 13 | base_url = "https://api.moonshot.cn/v1", 14 | ) 15 | Max_Try = 10 16 | i = 0 17 | response = "" 18 | while i < Max_Try: 19 | try: 20 | completion = client.chat.completions.create( 21 | model = "moonshot-v1-8k", 22 | messages = messages) 23 | response = completion.choices[0].message.content 24 | return response 25 | except Exception as e: 26 | print(f"Try {i}/{Max_Try}【response】:{response}\t message:【Error】:{e}", flush=True) 27 | i += 1 28 | continue 29 | return response 30 | 31 | 32 | if __name__ == "__main__": 33 | model = moonshot() 34 | print(model([{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}])) -------------------------------------------------------------------------------- /models/qwen2_72b.py: -------------------------------------------------------------------------------- 1 | from modelscope import AutoModelForCausalLM, AutoTokenizer 2 | device = "cuda" 3 | 4 | class qwen2_72b(): 5 | def __init__(self, model_path="/path/to/Qwen2-72B-Instruct"): 6 | self.model = AutoModelForCausalLM.from_pretrained( 7 | model_path, 8 | torch_dtype="auto", 9 | device_map="auto" 10 | ) 11 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 12 | 13 | def __call__(self, messages): 14 | text = self.tokenizer.apply_chat_template( 15 | messages, 16 | tokenize=False, 17 | add_generation_prompt=True 18 | ) 19 | model_inputs = self.tokenizer([text], return_tensors="pt").to(device) 20 | 21 | generated_ids = self.model.generate( 22 | model_inputs.input_ids, 23 | max_new_tokens=2048 24 | ) 25 | generated_ids = [ 26 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 27 | ] 28 | 29 | response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 30 | return response 31 | 32 | 33 | if __name__ == "__main__": 34 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}] 35 | 36 | qwen2_72b_model = qwen2_72b() 37 | print(qwen2_72b_model(messages)) 38 | 39 | -------------------------------------------------------------------------------- /models/qwen2_7b.py: -------------------------------------------------------------------------------- 1 | from modelscope import AutoModelForCausalLM, AutoTokenizer 2 | device = "cuda" 3 | 4 | 5 | class qwen2_7b(): 6 | def __init__(self, model_path="/path/to/Qwen2-7B-Instruct"): 7 | self.model = AutoModelForCausalLM.from_pretrained( 8 | model_path, 9 | torch_dtype="auto", 10 | device_map="auto" 11 | ) 12 | self.tokenizer = AutoTokenizer.from_pretrained(model_path) 13 | 14 | def __call__(self, messages): 15 | text = self.tokenizer.apply_chat_template( 16 | messages, 17 | tokenize=False, 18 | add_generation_prompt=True 19 | ) 20 | model_inputs = self.tokenizer([text], return_tensors="pt").to(device) 21 
| 22 | generated_ids = self.model.generate( 23 | model_inputs.input_ids, 24 | max_new_tokens=2048 25 | ) 26 | generated_ids = [ 27 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 28 | ] 29 | 30 | response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 31 | return response 32 | 33 | 34 | if __name__ == "__main__": 35 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}] 36 | 37 | qwen2_7b_model = qwen2_7b() 38 | print(qwen2_7b_model(messages)) 39 | 40 | -------------------------------------------------------------------------------- /models/template.py: -------------------------------------------------------------------------------- 1 | 2 | class template_model(): 3 | def __init__(self): 4 | pass 5 | 6 | def __call__(self, messages): 7 | raise NotImplementedError("please implement the __call__ method") 8 | 9 | if __name__ == "__main__": 10 | messages = [{"role": "system", "content": "You are a doctor named Jack"}, {"role": "user", "content": "What's your name"}] 11 | model = template_model() 12 | print(model(messages)) 13 | -------------------------------------------------------------------------------- /models/yi_large.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from openai import OpenAI 3 | 4 | 5 | class yi_large(object): 6 | def __init__(self) -> None: 7 | API_BASE = "https://api.lingyiwanwu.com/v1" 8 | API_KEY = "" 9 | 10 | self.model = OpenAI( 11 | api_key=API_KEY, 12 | base_url=API_BASE 13 | ) 14 | 15 | def __call__(self, messages): 16 | response = self.model.chat.completions.create( 17 | model="yi-large", 18 | messages=messages 19 | ) 20 | return response.choices[0].message.content 21 | 22 | 23 | if __name__ == "__main__": 24 | model = yi_large() 25 | messages = [{"role": "system", "content": "你是百灵鸟,你是一个给人看病的医生"}, {"role": "user", "content": "你叫什么名字"}] 26 | print(model(messages)) 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /output/.gitignore: -------------------------------------------------------------------------------- 1 | *.json 2 | *.npy -------------------------------------------------------------------------------- /output/claude35_opus/claude35_opus_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/claude35_opus/claude35_opus_analysis.xlsx -------------------------------------------------------------------------------- /output/deepseek/deepseek_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/deepseek/deepseek_analysis.xlsx -------------------------------------------------------------------------------- /output/ernie4/ernie4_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/ernie4/ernie4_analysis.xlsx -------------------------------------------------------------------------------- /output/glm4/glm4_analysis.xlsx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/glm4/glm4_analysis.xlsx -------------------------------------------------------------------------------- /output/glm_9b_client/glm_9b_client_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/glm_9b_client/glm_9b_client_analysis.xlsx -------------------------------------------------------------------------------- /output/gpt35/gpt35_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/gpt35/gpt35_analysis.xlsx -------------------------------------------------------------------------------- /output/gpt4_turbo_0409/gpt4_turbo_0409_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/gpt4_turbo_0409/gpt4_turbo_0409_analysis.xlsx -------------------------------------------------------------------------------- /output/gpt4o/gpt4o_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/gpt4o/gpt4o_analysis.xlsx -------------------------------------------------------------------------------- /output/llama3_70b/llama3_70b_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/llama3_70b/llama3_70b_analysis.xlsx -------------------------------------------------------------------------------- /output/llama3_8b/llama3_8b_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/llama3_8b/llama3_8b_analysis.xlsx -------------------------------------------------------------------------------- /output/moonshot/moonshot_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/moonshot/moonshot_analysis.xlsx -------------------------------------------------------------------------------- /output/qwen2_72b/qwen2_72b_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/qwen2_72b/qwen2_72b_analysis.xlsx -------------------------------------------------------------------------------- /output/qwen2_7b/qwen2_7b_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/qwen2_7b/qwen2_7b_analysis.xlsx -------------------------------------------------------------------------------- /output/with_gt_history_output/claude35_opus/claude35_opus_analysis.xlsx: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/with_gt_history_output/claude35_opus/claude35_opus_analysis.xlsx -------------------------------------------------------------------------------- /output/with_gt_history_output/ernie4/ernie4_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/with_gt_history_output/ernie4/ernie4_analysis.xlsx -------------------------------------------------------------------------------- /output/with_gt_history_output/gpt35/gpt35_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/with_gt_history_output/gpt35/gpt35_analysis.xlsx -------------------------------------------------------------------------------- /output/with_gt_history_output/gpt4o/gpt4o_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/with_gt_history_output/gpt4o/gpt4o_analysis.xlsx -------------------------------------------------------------------------------- /output/with_gt_history_output/llama3_70b/llama3_70b_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/with_gt_history_output/llama3_70b/llama3_70b_analysis.xlsx -------------------------------------------------------------------------------- /output/with_gt_history_output/llama3_8b/llama3_8b_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/with_gt_history_output/llama3_8b/llama3_8b_analysis.xlsx -------------------------------------------------------------------------------- /output/with_gt_history_output/qwen2_72b/qwen2_72b_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/with_gt_history_output/qwen2_72b/qwen2_72b_analysis.xlsx -------------------------------------------------------------------------------- /output/yi_large/yi_large_analysis.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Baichuan-MLSystemLab/SysBench/627ffa8010d00e270426975b33b1fb7a0a635602/output/yi_large/yi_large_analysis.xlsx -------------------------------------------------------------------------------- /plot/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | .DS_Store 4 | __MACOSX 5 | 6 | __pycache__ 7 | *.pyc 8 | *.pyo 9 | *.pyd 10 | 11 | *.svg 12 | *.png 13 | *.pdf 14 | 15 | ~*.xlsx 16 | ~*.pptx -------------------------------------------------------------------------------- /plot/analyze_history_gt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 
| from utils.parse_xls import parse_xls, TURN_NUMBER 5 | 6 | KEY_LIST = 'GPT-4o', 'Claude-3.5', 'Llama3.1-70B', 'Llama3.1-8B', \ 7 | 'Qwen2-72B', 'ERNIE-4', 'GPT-3.5' 8 | OUTPUT_FILE = 'output/with_gt_history_output/resultsV2.csv' 9 | 10 | def get_data(key, root_dir='output'): 11 | res = np.zeros(13) 12 | cnt = np.zeros(13) 13 | 14 | def update_res(base, val): 15 | res[base] += val 16 | cnt[base] += 1 17 | 18 | try: 19 | df = parse_xls(key, root_dir=root_dir) 20 | except Exception as e: 21 | print(f'Error: {e}, when reading {key}') 22 | return res 23 | 24 | 25 | for index, row in df.iterrows(): 26 | val = row['是否可用'] 27 | turn = index % TURN_NUMBER 28 | 29 | assert isinstance(row['multi_rounds_related'], bool), f'Unknown multi_rounds_related value: {row["multi_rounds_related"]}, at index {index} of {key}' 30 | base = 0 if row['multi_rounds_related'] else 6 31 | 32 | update_res(base + turn, val) 33 | if turn > 0: 34 | update_res(base + 5, val) 35 | else: 36 | update_res(12, val) 37 | 38 | return np.vstack((res, cnt, res / cnt)) 39 | 40 | def parse_data(key): 41 | data_table = np.zeros((7, 13)) 42 | data_table[:3, :] = get_data(key) 43 | data_table[3:6, :] = get_data(key, root_dir=os.path.dirname(OUTPUT_FILE)) 44 | data_table[6] = data_table[5] - data_table[2] 45 | return data_table 46 | 47 | if __name__ == '__main__': 48 | with open(OUTPUT_FILE, 'w') as f: 49 | caps = 'Corrent', 'Total', 'Ratio', 'Corrent', 'Total', 'Ratio', 'Gain' 50 | 51 | for key in KEY_LIST: 52 | data_table = parse_data(key) 53 | header = key + ',' 54 | header += ','.join([f'R{i}' for i in range(1, 6)]) + ',AVG,' 55 | header += ','.join([f'R{i}' for i in range(1, 6)]) + ',AVG,AVG-R1,' 56 | f.write(header + '\n') 57 | for i in range(7): 58 | row = f'{caps[i]},' + ','.join([str(x) for x in data_table[i].flatten()]) 59 | f.write(row + '\n') 60 | f.write('\n') 61 | 62 | print(f'Output to {OUTPUT_FILE}') -------------------------------------------------------------------------------- /plot/eval_output.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | from tab4_turn import get_data as get_data_ssr 4 | from tab3_align import get_data as get_data_isr 5 | from tab6_csr_full import get_data as get_data_csr 6 | 7 | C_LIST = [ 8 | 'Action' , 9 | 'Content', 10 | 'Background', 11 | 'Role', 12 | 'Format', 13 | 'Style', 14 | 'Total' 15 | ] 16 | 17 | def read_metrics(infer_model_name, output_dir): 18 | ssr = get_data_ssr(infer_model_name, output_dir) 19 | isr = get_data_isr(infer_model_name, output_dir) 20 | csr = get_data_csr(infer_model_name, output_dir) 21 | placeholder = 30 22 | placeholder_2 = 20 23 | print("="*placeholder+"Total Metrics"+"="*placeholder) 24 | print("CSR:\t", f"{csr[-1]:.3f}") 25 | print("ISR:\t", f"{isr[-1]:.3f}") 26 | print("SSR:\t", f"{ssr[-1]:.3f}") 27 | 28 | print("-"*placeholder_2+"Constraints-categorized Results"+"-"*placeholder_2) 29 | for i, item in enumerate(csr): 30 | #print(C_LIST[i], f"{item:.3f}", sep=":\t") 31 | print("{0:15} {1:.3f}".format(C_LIST[i]+":", item)) 32 | 33 | print("-"*placeholder_2+"Instructions-categoried Results"+"-"*placeholder_2) 34 | print("{0:15} {1:.3f}".format("Aligned:", isr[0])) 35 | print("{0:15} {1:.3f}".format("Misaligned:", isr[1])) 36 | print("{0:15} {1:.3f}".format("Total:", isr[-1])) 37 | 38 | print("-"*placeholder_2+"Sessions-categoried Results"+"-"*placeholder_2) 39 | #print(ssr) 40 | ssr_dep = ssr[0:6] 41 | ssr_para = ssr[6:12] 42 | #print(ssr_dep) 43 | #print(ssr_para) 44 | print("Multi-turn Dependent") 
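    # ssr_dep and ssr_para each hold five per-round scores followed by their average at the
    # last index, so the loops below print the individual rounds and then the Avg row;
    # ssr[-1] is the overall figure reported as "Total" at the end.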
45 | for i in range(0, 5): 46 | print("\tR{0}:\t{1:.3f}".format(i+1, ssr_dep[i])) 47 | print("\tAvg:\t{:.3f}".format(ssr_dep[-1])) 48 | print("Multi-turn Parallel") 49 | for i in range(0, 5): 50 | print("\tR{0}:\t{1:.3f}".format(i, ssr_para[i])) 51 | print("\tAvg:\t{:.3f}".format(ssr_para[-1])) 52 | print("{0:15} {1:.3f}".format("Total:", ssr[-1])) 53 | return 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument("--infer_model_name", type=str) 59 | parser.add_argument("--output_dir", type=str, default='./output') 60 | 61 | args = parser.parse_args() 62 | 63 | infer_model_name = args.infer_model_name 64 | output_dir = args.output_dir 65 | read_metrics(infer_model_name, output_dir) 66 | -------------------------------------------------------------------------------- /plot/fig3_stat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | from matplotlib.gridspec import GridSpec 7 | from matplotlib.transforms import Bbox 8 | 9 | from fig_domain import plot_histogram 10 | from fig_constraint import plot_pie 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser(description='Generate figure for statistics') 14 | parser.add_argument('--png', action='store_true', help='Save to PDF') 15 | parser.add_argument('--stat_type', '-s', choices=['session', 'base'], default='session', help='Type of statistics to plot') 16 | args = parser.parse_args() 17 | 18 | to_pdf = not args.png 19 | 20 | plt.rcParams['font.family'] = 'Calibri' 21 | mpl.rcParams.update({'font.size': 10}) 22 | 23 | # magic implementation 24 | fig = plt.figure(figsize=(10, 4), dpi=300, tight_layout=True) 25 | gs = GridSpec(2, 2, width_ratios=[6.5, 4], height_ratios=[10, 0.1], 26 | wspace=0) 27 | 28 | ax1 = fig.add_subplot(gs[0, 0]) 29 | ax2 = fig.add_subplot(gs[:, 1]) 30 | 31 | plot_histogram(ax1) 32 | plot_pie(ax2, stat_type=args.stat_type) 33 | 34 | fig.canvas.draw() 35 | 36 | bbox = fig.get_tightbbox(fig.canvas.get_renderer()) 37 | left, bottom, right, top = bbox.extents 38 | print(left, bottom, right, top) 39 | new_bbox = Bbox.from_extents(left - 0.1, 40 | bottom + 0.1, 41 | right - 0.1, 42 | top + 0.1) 43 | 44 | file_name = 'figures/fig_stat' + ('.pdf' if to_pdf else '.png') 45 | plt.savefig(file_name, bbox_inches=new_bbox) 46 | print(f'File saved to {file_name}') -------------------------------------------------------------------------------- /plot/fig4_radar.py: -------------------------------------------------------------------------------- 1 | """ 2 | ====================================== 3 | Radar chart (aka spider or star chart) 4 | ====================================== 5 | 6 | This example creates a radar chart, also known as a spider or star chart [1]_. 7 | 8 | Although this example allows a frame of either 'circle' or 'polygon', polygon 9 | frames don't have proper gridlines (the lines are circles instead of polygons). 10 | It's possible to get a polygon grid by setting GRIDLINE_INTERPOLATION_STEPS in 11 | matplotlib.axis to the desired number of vertices, but the orientation of the 12 | polygon is not aligned with the radial axes. 13 | 14 | .. 
[1] http://en.wikipedia.org/wiki/Radar_chart 15 | """ 16 | import numpy as np 17 | 18 | import matplotlib as mpl 19 | import matplotlib.pyplot as plt 20 | from matplotlib.path import Path 21 | from matplotlib.spines import Spine 22 | from matplotlib.projections.polar import PolarAxes 23 | from matplotlib.projections import register_projection 24 | from matplotlib.transforms import Affine2D, Bbox 25 | 26 | 27 | from utils.parse_xls import parse_xls 28 | from utils.generate_n_color import generate_n_colors 29 | 30 | # KEY_LIST = 'GPT-4o', 'GPT-4-Turbo', 'GPT-3.5', 'Claude-3.5', 'ERNIE-4', \ 31 | # 'Yi-Large', 'Moonshot', 'DeepSeek-V2', 'GLM-4', 'Llama3.1-70B', \ 32 | # 'Qwen2-72B', 'Llama3.1-8B', 'Mistral-7B', 'Qwen2-7B' 33 | KEY_LIST = 'GPT-4o', 'Claude-3.5', 'Qwen2-72B', 'GLM-4', 'Moonshot', \ 34 | 'GPT-3.5', 'ERNIE-4', 'Qwen2-7B' 35 | LABEL_MAP = { 36 | # 'Total': 'total', 37 | 'Action' : '动作约束', 38 | 'Content': '内容约束', 39 | 'Background': '背景约束', 40 | 'Role': '角色约束', 41 | 'Format': '格式约束', 42 | 'Style': '风格约束' 43 | } 44 | to_pdf = True 45 | ignore_first = False 46 | 47 | N = len(LABEL_MAP) 48 | M = len(KEY_LIST) 49 | 50 | # The seed in Yanzhao's lunar calendar birthday 51 | color_palette = generate_n_colors(M, seed=808, hue_offset_ratio=0.05, brightness_bias=0.1, 52 | saturation_mean=0.5, saturation_bias=0, 53 | shuffle='interleave' if ignore_first else 'interleave') 54 | 55 | # https://colorkit.co/palettes/9-colors/ 56 | # color_palette = ["#538fff","#b431e6","#ff5b58","#f7ed65","#28d2ab","#fca207","#f6ccf9","#268189","#2d1a77"] 57 | 58 | def get_data(key): 59 | res = np.zeros(N) 60 | 61 | try: 62 | df = parse_xls(key, sheet_name='不同约束类型遵循') 63 | except Exception as e: 64 | print(f'Error: {e}, when reading {key}') 65 | return res 66 | 67 | for i, (_, col) in enumerate(LABEL_MAP.items()): 68 | column = df[col] 69 | res[i] = column[1] / column[0] 70 | 71 | return res 72 | 73 | data_table = np.zeros((len(KEY_LIST), N)) 74 | for i, key in enumerate(KEY_LIST): 75 | data_table[i] = get_data(key) 76 | 77 | if ignore_first: 78 | data_table = data_table[:, 1:] 79 | data_firstcol = data_table[:, 0] 80 | del LABEL_MAP['Total'] 81 | 82 | def radar_factory(num_vars, frame='circle'): 83 | """Create a radar chart with `num_vars` axes. 84 | 85 | This function creates a RadarAxes projection and registers it. 86 | 87 | Parameters 88 | ---------- 89 | num_vars : int 90 | Number of variables for radar chart. 91 | frame : {'circle' | 'polygon'} 92 | Shape of frame surrounding axes. 
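    Returns
    -------
    theta : np.ndarray
        Evenly spaced axis angles in radians; as a side effect the custom
        ``RadarAxes`` projection is registered under the name 'radar'.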
93 | 94 | """ 95 | # calculate evenly-spaced axis angles 96 | theta = np.linspace(0, 2*np.pi, num_vars, endpoint=False) 97 | 98 | def draw_poly_patch(self): 99 | # rotate theta such that the first axis is at the top 100 | verts = unit_poly_verts(theta + np.pi / 2) 101 | return plt.Polygon(verts, closed=True, edgecolor='k') 102 | 103 | def draw_circle_patch(self): 104 | # unit circle centered on (0.5, 0.5) 105 | return plt.Circle((0.5, 0.5), 0.5) 106 | 107 | patch_dict = {'polygon': draw_poly_patch, 'circle': draw_circle_patch} 108 | if frame not in patch_dict: 109 | raise ValueError('unknown value for `frame`: %s' % frame) 110 | 111 | class RadarAxes(PolarAxes): 112 | 113 | name = 'radar' 114 | # use 1 line segment to connect specified points 115 | RESOLUTION = 1 116 | # define draw_frame method 117 | draw_patch = patch_dict[frame] 118 | 119 | def __init__(self, *args, **kwargs): 120 | super(RadarAxes, self).__init__(*args, **kwargs) 121 | # rotate plot such that the first axis is at the top 122 | self.set_theta_zero_location('N') 123 | 124 | def fill(self, *args, **kwargs): 125 | """Override fill so that line is closed by default""" 126 | closed = kwargs.pop('closed', True) 127 | return super(RadarAxes, self).fill(closed=closed, *args, **kwargs) 128 | 129 | def plot(self, *args, **kwargs): 130 | """Override plot so that line is closed by default""" 131 | lines = super(RadarAxes, self).plot(*args, **kwargs) 132 | for line in lines: 133 | self._close_line(line) 134 | 135 | def _close_line(self, line): 136 | x, y = line.get_data() 137 | # FIXME: markers at x[0], y[0] get doubled-up 138 | if x[0] != x[-1]: 139 | x = np.concatenate((x, [x[0]])) 140 | y = np.concatenate((y, [y[0]])) 141 | line.set_data(x, y) 142 | 143 | def set_varlabels(self, labels): 144 | deg_theta = np.degrees(theta) 145 | _, xticklabels = self.set_thetagrids(deg_theta, labels) 146 | for i, (label, angle) in enumerate(zip(xticklabels, deg_theta)): 147 | x,y = label.get_position() 148 | trans = Affine2D().scale(0.92, 0.92) # MAGIC NUMBER 149 | trans = trans + Affine2D().translate(7, 6) # MAGIC NUMBER 150 | lab = ax.text(x,y, label.get_text(), transform=(label.get_transform() + trans), 151 | ha=label.get_ha(), va=label.get_va()) 152 | if angle >= 90 and angle <= 270: 153 | angle += 180 154 | lab.set_rotation(angle % 360) 155 | self.set_thetagrids(deg_theta, []) 156 | 157 | def _gen_axes_patch(self): 158 | return self.draw_patch() 159 | 160 | def _gen_axes_spines(self): 161 | if frame == 'circle': 162 | return PolarAxes._gen_axes_spines(self) 163 | # The following is a hack to get the spines (i.e. the axes frame) 164 | # to draw correctly for a polygon frame. 165 | 166 | # spine_type must be 'left', 'right', 'top', 'bottom', or `circle`. 167 | spine_type = 'circle' 168 | verts = unit_poly_verts(theta + np.pi / 2) 169 | # close off polygon by repeating first vertex 170 | verts.append(verts[0]) 171 | path = Path(verts) 172 | 173 | spine = Spine(self, spine_type, path) 174 | spine.set_transform(self.transAxes) 175 | return {'polar': spine} 176 | 177 | register_projection(RadarAxes) 178 | return theta 179 | 180 | 181 | def unit_poly_verts(theta): 182 | """Return vertices of polygon for subplot axes. 
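    (Returned as a list of (x, y) tuples in axes coordinates, one vertex per
    angle in `theta`.)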
183 | 184 | This polygon is circumscribed by a unit circle centered at (0.5, 0.5) 185 | """ 186 | x0, y0, r = [0.5] * 3 187 | verts = [(r*np.cos(t) + x0, r*np.sin(t) + y0) for t in theta] 188 | return verts 189 | 190 | def white_space_align(labels, data, space=2): 191 | max_len = max(map(len, labels)) 192 | return [label + ' ' * (max_len - len(label) + space) + f'{data[i]*100:.1f}%' for i, label in enumerate(labels)] 193 | 194 | if __name__ == '__main__': 195 | plt.rcParams['font.family'] = 'Calibri' 196 | mpl.rcParams.update({'font.size': 12}) 197 | 198 | theta = radar_factory(N - (1 if ignore_first else 0), frame='circle') 199 | 200 | if ignore_first: 201 | figsize = (5, 4.5) 202 | else: 203 | figsize = (4, 4) 204 | 205 | fig, ax = plt.subplots(1, 1, figsize=figsize, dpi=300, tight_layout=True, 206 | subplot_kw=dict(projection='radar')) 207 | 208 | aligned_labels = white_space_align(KEY_LIST, data_firstcol) if ignore_first else KEY_LIST 209 | indexes = list(range(M)) 210 | if ignore_first: 211 | indexes = sorted(indexes, key=lambda i: data_firstcol[i], reverse=True) 212 | 213 | ax.set_rlim(0, 1) 214 | for color_id, i in enumerate(indexes): 215 | print(KEY_LIST[i], data_table[i]) 216 | ax.plot(theta, data_table[i], color=color_palette[color_id], linewidth=1.2, linestyle='solid', label=aligned_labels[i]) 217 | ax.fill(theta, data_table[i], facecolor=color_palette[color_id], alpha=0.05) 218 | ax.set_varlabels(LABEL_MAP.keys()) 219 | ax.set_rgrids([0.2, 0.4, 0.6, 0.8], size=8) 220 | 221 | # add legend relative to top-left plot, top-aligned with it 222 | legend_kwargs = { 223 | 'loc': (1.15, 0.17), 224 | 'fontsize': 11, 225 | 'labelspacing': 0.2, 226 | # 'frameon': False, 227 | } 228 | if ignore_first: 229 | legend_kwargs['loc'] = (1.15, 0.00) 230 | legend_kwargs['prop'] = { 231 | 'family': 'Consolas', 232 | 'size': 10, 233 | } 234 | # 5.2359877559829887307710723054658 235 | ax.text(5.105, 2.42, 'The CSR Scores', weight='bold', size=12, ha='center') 236 | 237 | legend = ax.legend(**legend_kwargs) 238 | 239 | fig.canvas.draw() 240 | 241 | bbox = fig.get_tightbbox(fig.canvas.get_renderer()) 242 | left, bottom, right, top = bbox.extents 243 | print(left, bottom, right, top) 244 | 245 | if ignore_first: 246 | new_bbox = Bbox.from_extents(left + 0.2, 247 | bottom + 0.03, 248 | right + 0.25, 249 | top + 0.06) 250 | else: 251 | new_bbox = Bbox.from_extents(left + 0.08, 252 | bottom - 0.08, 253 | right + 0.28, 254 | top + 0.1) 255 | file_name = 'figures/fig_radar' + ('.pdf' if to_pdf else '.png') 256 | plt.savefig(file_name, bbox_inches=new_bbox) 257 | print(f'Saved to {file_name}.') 258 | -------------------------------------------------------------------------------- /plot/fig5_hgt_histo.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import colorsys 5 | 6 | import matplotlib.lines as mlines 7 | import matplotlib.patches as mpatches 8 | 9 | from analyze_history_gt import parse_data 10 | 11 | color_palette = ["#05b9e2", "#e88290"] 12 | KEY_LIST = 'Qwen2-72B', 'Claude-3.5', 'Llama3.1-8B', 'ERNIE-4' 13 | 14 | to_pdf = True 15 | xtick_labels = ['T2', 'T3', 'T4', 'T5', 'AVG'] 16 | LABEL_LSIT = ['Multi-turn Dependent', 'Multi-turn Parallel'] 17 | 18 | def add_hsv_for_color(color, h=0, s=0, v=0.1): 19 | old_h, old_s, old_v = colorsys.rgb_to_hsv(*mpl.colors.to_rgb(color)) 20 | print(old_h, old_s, old_v) 21 | new_h = (old_h + h) % 1 22 | new_s = max(0, min(1, old_s + s)) 23 | 
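# the brightness (value) channel is clamped to [0, 1] in the same way before
# converting back to RGB; only the hue wraps around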
new_v = max(0, min(1, old_v + v)) 24 | color = colorsys.hsv_to_rgb(new_h, new_s, new_v) 25 | return color 26 | 27 | color_palette2 = [ 28 | add_hsv_for_color(color_palette[0], h=0.0, s=-0.4, v=0), 29 | add_hsv_for_color(color_palette[1], h=0.0, s=-0.1, v=0), 30 | ] 31 | 32 | def plot_bar(ax, key, width=0.25): 33 | data_table = parse_data(key) * 100 34 | 35 | for i in range(5): 36 | pad = width/2 + width / 6 37 | ax.bar(i - pad, data_table[6, i+1], 38 | color=color_palette2[0], width=width) 39 | ax.bar(i + pad, data_table[6, i+7], color=color_palette2[1], width=width) 40 | 41 | ax.text(i - pad, max(0, data_table[6, i+1]) + 0.5, '%+.1f%%'%(data_table[6, i+1]), 42 | ha='center', rotation=90, fontsize=12, color=color_palette[0]) 43 | ax.text(i + pad, max(0, data_table[6, i+7]) + 0.5, '%+.1f%%'%(data_table[6, i+7]), 44 | ha='center', rotation=90, fontsize=12, color=color_palette[1]) 45 | 46 | ax.text(i, -3, xtick_labels[i], rotation=45, ha='center', fontsize=14) 47 | 48 | tx, ty = 2, 12.5 49 | if key == 'Qwen2-72B': 50 | tx = 1.2 51 | ax.text(tx, ty, key, fontsize=14, ha='center', va='center', weight='bold') 52 | 53 | ax.spines['bottom'].set_visible(False) 54 | ax.spines['top'].set_visible(False) 55 | ax.spines['right'].set_visible(False) 56 | 57 | # ax.xaxis.set_ticks_position('bottom') 58 | ax.axhline(0, color='black', linewidth=1) 59 | 60 | # for i in range(2): 61 | # offset = abs(data_table[6, 6 * i]) 62 | # ax.axhline(offset, color=color_palette2[i], linewidth=0.5, linestyle='dotted') 63 | 64 | ax.axhline(abs(data_table[6, -1]), color='black', linewidth=0.5, linestyle='dotted') 65 | 66 | ax.tick_params(axis='x', length=0) 67 | 68 | ax.set_xticks(range(5)) 69 | ax.set_xticklabels([]) 70 | 71 | ax.set_ylim(-4, 15) 72 | 73 | if __name__ == '__main__': 74 | plt.rcParams["font.family"] = "Calibri" 75 | mpl.rcParams.update({'font.size': 10}) 76 | 77 | fig, axs = plt.subplots(2, 2, dpi=300, tight_layout=True, 78 | figsize=(7.2, 4.8), 79 | sharex=True, sharey=True) 80 | # print figure size 81 | # print(fig.get_size_inches()) 82 | 83 | plt.subplots_adjust(wspace=0.08, hspace=0.28) 84 | 85 | 86 | for idx, key in enumerate(KEY_LIST): 87 | plot_bar(axs[idx//2, idx%2], key) 88 | 89 | patches = [mpatches.Patch(color=color_palette2[i], label=LABEL_LSIT[i]) for i in range(2)] 90 | patches.append( 91 | mlines.Line2D([], [], color='black', label="Uncertainty", linewidth=1, linestyle='dotted') 92 | ) 93 | legend = fig.legend(handles=patches, loc='upper center', bbox_to_anchor=(0.5, 0.56), ncol=3, 94 | fontsize=13, borderpad=0.5, handleheight=0.7, columnspacing=1) 95 | fig.text(-0.01, 0.48, 'ISR Improvment(%) with Ground-truth History', 96 | va='center', rotation='vertical', fontsize=15) 97 | 98 | file_name = 'figures/fig_hgt' + ('.pdf' if to_pdf else '.png') 99 | plt.savefig(file_name, bbox_inches='tight', pad_inches=0.1) 100 | print(f'Figure saved to {file_name}.') 101 | -------------------------------------------------------------------------------- /plot/fig6_atscore.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | import matplotlib.lines as mlines 7 | import matplotlib.patches as mpatches 8 | 9 | import numpy as np 10 | import pandas as pd 11 | 12 | from fig_atscore_curve import do_plot as do_plot_l 13 | from fig_atscore_replace import do_plot as do_plot_r 14 | 15 | from fig_atscore_curve import color_palette, model_list 16 | 17 | is_pdf = True 18 | 19 | map_str = 
{ 20 | 'glm': 'GLM4-9B', 21 | 'llama31': 'Llama3.1-8B', 22 | 'qwen': 'Qwen2-72B', 23 | } 24 | 25 | if __name__ == '__main__': 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--id', '-i', type=int, default=287, help='System Message ID for the plot') 28 | parser.add_argument('--layer', '-l', type=int, default=0, help='Middle(0) / final(1) Layer for the plot') 29 | parser.add_argument('--window_size', '-w', type=int, default=21, help='Window size for moving average') 30 | args = parser.parse_args() 31 | 32 | fig, axs = plt.subplots(1, 2, figsize=(6.9, 2.9), 33 | dpi=300, tight_layout=True, 34 | gridspec_kw={'width_ratios': [1.4, 0.9]}) 35 | plt.subplots_adjust(left=0.0, right=1.0, top=0.9, bottom=0.1) 36 | 37 | kwargs = { 38 | 'plot_sid': args.id, 39 | 'layer_idx': args.layer 40 | } 41 | print(f'Plotting for System Message ID {args.id}, Layer idx {args.layer}.') 42 | 43 | do_plot_l(axs[0], args.window_size, **kwargs) 44 | do_plot_r(axs[1], **kwargs) 45 | 46 | axs[0].set_ylim(0.0, 0.75) 47 | axs[0].set_xlim(-2, 102) 48 | 49 | # fig.text(-0.01, 0.52, "System Message's Share of Total AS", va='center', rotation='vertical', fontsize=11) 50 | patches = [mpatches.Patch(color=color_palette[i], label=map_str[model_list[i]]) for i in range(len(model_list))] 51 | lines = [ 52 | mlines.Line2D([], [], color='black', linestyle='-.', label='Average'), 53 | mlines.Line2D([], [], color='black', linestyle='-', label='As System'), 54 | mlines.Line2D([], [], color='black', linestyle='--', label='As User'), 55 | ] 56 | 57 | fig.text(0.55, 0.8, '(a)', ha='center', 58 | fontdict={'fontsize': 12, 'font': 'Times New Roman'}) 59 | fig.text(0.95, 0.8, '(b)', ha='center', 60 | fontdict={'fontsize': 12, 'font': 'Times New Roman'}) 61 | legend = fig.legend(handles=patches + lines, loc='upper center', ncol=6, fontsize=10, 62 | bbox_to_anchor=(0.50, 1.08), columnspacing=0.5, labelspacing=0.2, 63 | frameon=False,handletextpad=0.3) 64 | # legend.get_frame().set_alpha(None) 65 | # legend.get_frame().set_facecolor((0.95, 0.95, 0.95, 0.95)) 66 | 67 | file_name = 'figures/atscore' + ('.pdf' if is_pdf else '.png') 68 | plt.savefig(file_name, bbox_inches='tight', pad_inches=0.0) 69 | print(f'Figure saved to {file_name}.') -------------------------------------------------------------------------------- /plot/fig_atscore_curve.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | 9 | from utils.smooth import weighted_moving_average 10 | 11 | model_list = ['glm', 'llama31', 'qwen'] 12 | # model_list = ['qwen'] 13 | 14 | color_palette = "#F27970", "#BB9727", "#32B897" 15 | layer_ids = { 16 | 'glm': [19, 39], 17 | 'llama31': [16, 31], 18 | 'qwen': [40, 79], 19 | } 20 | 21 | PLOT_SID, LAYER_IDX, IS_REPLACE = 287, 0, False 22 | BASE_DIR = './output/attenscore' 23 | is_pdf = True 24 | 25 | X_AXIS_SCALE = 100. 
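# X_AXIS_SCALE rescales the pooled per-turn lengths so that the x axis always spans
# 0-100 "decoding steps", regardless of each model's actual token counts.
#
# A sketch of the assumed layout of the .npy dumps consumed by load_data() below
# (inferred from how they are indexed here; the files live under
# output/attenscore/<model>[_replace]/ and are presumably produced by the
# attenscore/ scripts):
#
#   dump = {
#       'data':          # (num_decode_steps, num_boundaries) cumulative attention
#                        # scores; columns 0/1 are assumed to bracket the system
#                        # message and the last non-zero column the full context
#       'split_indices': # token positions delimiting the system message and the
#                        # five dialogue turns; seg_indices selects the six
#                        # boundaries whose differences give the five turn lengths
#   }
#
# value_mapper() then reports the system message's share of the total attention
# mass, (row[1] - row[0]) / (row[last] - row[0]), at every decoding step.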
26 | seg_indices = [1, 3, 5, 7, 9, 11] 27 | 28 | def value_mapper(row, splits, last=None): 29 | if last is None: 30 | last = get_last_nonzero_col(row) 31 | return (row[1] - row[0]) / (row[last] - row[0]) 32 | 33 | def get_last_nonzero_col(row): 34 | for i in range(row.shape[0] - 1, -1, -1): 35 | if row[i] != 0: 36 | return i 37 | return 0 38 | 39 | def load_data(file_name): 40 | data = np.load(file_name, allow_pickle=True).item() 41 | start_row = -data['split_indices'][-1] 42 | start_row += data['split_indices'][1] # filter out system message 43 | data['data'] = data['data'][start_row:] 44 | return data['data'], data['split_indices'] 45 | 46 | def plot_curve(ax, model_idx, data, split_indices, seg_len, color=None, window_size=5): 47 | model_name = model_list[model_idx] 48 | 49 | y = [value_mapper(row, split_indices) for row in data] 50 | x = np.arange(len(y)).astype(np.float32) 51 | 52 | start_idx = 0 53 | x_start = 0. 54 | for i in range(seg_len.shape[1]): 55 | seg_length = seg_len[model_idx, i] 56 | target_length = seg_len[-1, i] 57 | 58 | assert seg_length.is_integer() 59 | end_idx = start_idx + int(seg_length) 60 | assert end_idx <= len(x), f'{end_idx} > {len(x)} for model {model_name}' 61 | 62 | x_seg = x[start_idx:end_idx] 63 | print(model_name, i, np.mean(y[start_idx:end_idx])) 64 | x[start_idx:end_idx] = (x_seg - x_seg[0]) * target_length / seg_length + x_start 65 | 66 | x_start += target_length 67 | start_idx = end_idx 68 | 69 | avg_y = np.mean(y) 70 | print(f'{model_name}: {avg_y:.4f}', 'split:', split_indices) 71 | 72 | y_smooth = weighted_moving_average(x, y, window_size=window_size) 73 | ax.axhline(avg_y, color=color, linewidth=0.5, linestyle='-.') 74 | ax.plot(x, y_smooth, label=model_name, color=color, linewidth=1, linestyle='-') 75 | return avg_y 76 | 77 | def read_all_data(plot_sid=PLOT_SID, layer_idx=LAYER_IDX, is_replace=IS_REPLACE): 78 | data_full = {} 79 | split_indices_full = {} 80 | 81 | for model_name in model_list: 82 | fn = model_name + ('_replace' if is_replace else '') 83 | file_path = os.path.join(BASE_DIR, fn, 84 | f'layer_{layer_ids[model_name][layer_idx]}_sid{plot_sid}.npy') 85 | data, split_indices = load_data(file_path) 86 | data_full[fn] = data 87 | split_indices_full[fn] = split_indices 88 | 89 | return data_full, split_indices_full 90 | 91 | def do_plot(ax, window_size=5, **kwargs): 92 | data_full, split_indices_full = read_all_data(**kwargs) 93 | 94 | # cal average length for each segment 95 | data_seg_len = np.zeros((len(model_list) + 1, 5)) 96 | for i, split_indices in enumerate(split_indices_full.values()): 97 | print(split_indices) 98 | data_seg_len[i] = np.diff(split_indices[seg_indices]) 99 | data_seg_len[-1] = np.mean(data_seg_len[:-1], axis=0) 100 | data_seg_len[-1] /= np.sum(data_seg_len[-1]) 101 | data_seg_len[-1] *= X_AXIS_SCALE 102 | print(data_seg_len) 103 | 104 | for i, (data, split_indices) in enumerate(zip(data_full.values(), split_indices_full.values())): 105 | plot_curve(ax, i, data, split_indices, data_seg_len, color=color_palette[i], window_size=window_size) 106 | 107 | x_splits = np.cumsum(data_seg_len[-1]) 108 | for i, x_val in enumerate(x_splits): 109 | ax.axvline(x_val, color='lightgray', linewidth=0.5, linestyle='dotted') 110 | ax.text((x_val + (x_splits[i-1] if i else 0)) / 2, -0.05, f'T{i+1}', ha='center', fontsize=10) 111 | 112 | ax.set_xticks(x_splits) 113 | ax.set_xticklabels(['' for _ in range(5)]) 114 | # ax.set_xticklabels([f'T{i+1}' for i in range(5)]) 115 | 116 | ax.text(70, 0.4, 'Decoding Step', ha='center', 
fontsize=11) 117 | ax.arrow(50, 0.37, 40, 0, head_width=0.02, head_length=2, fc='k', ec='k') 118 | 119 | ax.text(51.5, 0.68, "System Message's Share of Total Attention Score", ha='center', 120 | fontsize=9.2) 121 | 122 | if __name__ == '__main__': 123 | plt.rcParams["font.family"] = "Calibri" 124 | mpl.rcParams.update({'font.size': 14}) 125 | 126 | fig, ax = plt.subplots(1, 1, figsize=(4, 2.5), dpi=300, tight_layout=True) 127 | 128 | do_plot(ax) 129 | 130 | hadles, labels = ax.get_legend_handles_labels() 131 | fig.legend(hadles, labels, loc='upper center', ncol=3, fontsize=12) 132 | 133 | file_name = f'figures/attenscore' + ('.pdf' if is_pdf else '.png') 134 | plt.savefig(file_name, bbox_inches='tight', pad_inches=0.1) 135 | print(f'Figure saved to {file_name}.') 136 | -------------------------------------------------------------------------------- /plot/fig_atscore_replace.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import matplotlib as mpl 5 | import matplotlib.patches as patches 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import pandas as pd 9 | 10 | from fig_atscore_curve import read_all_data, \ 11 | color_palette, model_list, value_mapper, seg_indices 12 | 13 | is_pdf = True 14 | 15 | def calc_turn(model_idx, data, split_indices, seg_len): 16 | model_name = model_list[model_idx] 17 | 18 | y = [value_mapper(row, split_indices) for row in data] 19 | x = np.arange(len(y)).astype(np.float32) 20 | 21 | start_idx = 0 22 | ans_y = [] 23 | for i in range(seg_len.shape[1]): 24 | seg_length = seg_len[model_idx, i] 25 | 26 | assert seg_length.is_integer() 27 | end_idx = start_idx + int(seg_length) 28 | assert end_idx <= len(x), f'{end_idx} > {len(x)} for model {model_name}' 29 | 30 | ans_y.append(np.mean(y[start_idx:end_idx])) 31 | 32 | start_idx = end_idx 33 | 34 | print(model_name, ans_y) 35 | return ans_y 36 | 37 | def process_data(is_replace=False, **kwargs): 38 | data_full, split_indices_full = read_all_data(is_replace=is_replace, **kwargs) 39 | 40 | # cal average length for each segment 41 | data_seg_len = np.zeros((len(model_list), 5)) 42 | for i, split_indices in enumerate(split_indices_full.values()): 43 | print(split_indices) 44 | data_seg_len[i] = np.diff(split_indices[seg_indices]) 45 | print(data_seg_len) 46 | 47 | res = [] 48 | for i, (data, split_indices) in enumerate(zip(data_full.values(), split_indices_full.values())): 49 | res.append(calc_turn(i, data, split_indices, data_seg_len)) 50 | 51 | return res 52 | 53 | def do_plot(ax, **kwargs): 54 | if 'is_replace' in kwargs: 55 | del kwargs['is_replace'] 56 | res_org = process_data(is_replace=False, **kwargs) 57 | res_rep = process_data(is_replace=True, **kwargs) 58 | 59 | x = np.arange(5) + 1 60 | for i, res in enumerate([res_org, res_rep]): 61 | for j in range(len(res)): 62 | ax.plot(x, res[j], label=model_list[j], 63 | color=color_palette[j], 64 | linestyle='-' if i == 0 else '--', 65 | linewidth=1.2 if i == 0 else 0.5) 66 | 67 | dy = 0.035 68 | base_x, basey = 3.8, 0.22 69 | head_length_ratio = 1 / 8 70 | for i in range(len(model_list)): 71 | get_avg = lambda x: np.mean(x[i]) 72 | avg_org = get_avg(res_org[i]) 73 | avg_rep = get_avg(res_rep[i]) 74 | 75 | if np.sign(avg_rep - avg_org) > 0: 76 | ax.arrow(base_x, basey + dy * i, 0, dy * 0.7, color=color_palette[i], 77 | head_width=0.05, head_length=dy * head_length_ratio, 78 | length_includes_head=True, 79 | fc=color_palette[i], ec=color_palette[i]) 80 | else: 81 | 
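# replacement lowered the average attention share: draw the arrow pointing downwards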
ax.arrow(base_x, basey + dy * (i + 0.7), 0, -dy * 0.7, color=color_palette[i], 82 | head_width=0.05, head_length=dy * head_length_ratio, 83 | length_includes_head=True, 84 | fc=color_palette[i], ec=color_palette[i]) 85 | ax.text(base_x + 0.1, basey + 0.003 + dy * i, f'{(avg_rep - avg_org)*100:+.2f}%', 86 | color=color_palette[i], fontdict={'fontsize': 9, 'font': 'Consolas'}) 87 | 88 | rectangle = patches.Rectangle((base_x - 0.15, basey - 0.008), 1.15, dy * 3+0.04, 89 | edgecolor='black', facecolor='none', 90 | linewidth=0.8) 91 | ax.add_patch(rectangle) 92 | ax.text(base_x - 0.1, basey + dy * 3 + 0.004, 'Avg Diff.', fontsize=8, weight='bold') 93 | 94 | ax.text(3, 0.45, 'Treat System Message', ha='center', fontsize=10) 95 | ax.text(3, 0.415, 'as User Instruction', ha='center', fontsize=10) 96 | 97 | ax.set_xticks(x) 98 | ax.set_xticklabels([f'T{i+1}' for i in range(5)]) 99 | 100 | 101 | if __name__ == '__main__': 102 | plt.rcParams["font.family"] = "Calibri" 103 | mpl.rcParams.update({'font.size': 14}) 104 | 105 | fig, ax = plt.subplots(1, 1, figsize=(4, 3), dpi=300, tight_layout=True) 106 | 107 | do_plot(ax) 108 | 109 | hadles, labels = ax.get_legend_handles_labels() 110 | fig.legend(hadles[:3], labels[:3], loc='upper center', ncol=3, fontsize=12) 111 | 112 | file_name = f'figures/attenscore_r' + ('.pdf' if is_pdf else '.png') 113 | plt.savefig(file_name, bbox_inches='tight', pad_inches=0.1) 114 | print(f'Figure saved to {file_name}.') 115 | 116 | -------------------------------------------------------------------------------- /plot/fig_constraint.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | 7 | from utils.parse_xls import parse_xls, TURN_NUMBER 8 | 9 | LABEL_MAP = { 10 | 'Action' : '动作约束', 11 | 'Content': '内容约束', 12 | 'Background': '背景约束', 13 | 'Role': '角色约束', 14 | 'Format': '格式约束', 15 | 'Style': '风格约束', 16 | } 17 | 18 | sector_labels = list(LABEL_MAP.keys()) 19 | # sector_colors = ['#ff9999','#66b3ff','#99ff99','#ffcc99','#c2c2f0','#ffb3e6'] 20 | sector_colors = ['#b7cef2', '#b9ecea', '#f2f2ca', '#f2ddb6', '#eec1c1', '#d2b6e2'] 21 | 22 | pattern = r'\d+\.\s(..约束)' 23 | 24 | def get_data_old(key): 25 | res = np.zeros(len(LABEL_MAP), dtype=int) 26 | 27 | try: 28 | df = parse_xls(key, sheet_name='不同约束类型遵循') 29 | except Exception as e: 30 | print(f'Error: {e}, when reading {key}') 31 | return res 32 | 33 | for i, (_, col) in enumerate(LABEL_MAP.items()): 34 | res[i] = df[col][0] 35 | 36 | return res 37 | 38 | def get_data(key): 39 | res = np.zeros(len(LABEL_MAP), dtype=int) 40 | 41 | try: 42 | df = parse_xls(key) 43 | except Exception as e: 44 | print(f'Error: {e}, when reading {key}') 45 | return res 46 | 47 | for index, row in df.iterrows(): 48 | text = row['评判结果'] 49 | turn = index % TURN_NUMBER 50 | 51 | if turn == 0: 52 | constraints = set() 53 | 54 | # get all constraints 55 | constraints.update(re.findall(pattern, text)) 56 | 57 | if turn == TURN_NUMBER - 1: 58 | # print('session:', index//TURN_NUMBER, ', constraints:', constraints) 59 | for i, col in enumerate(LABEL_MAP.values()): 60 | if col in constraints: 61 | res[i] += 1 62 | # print('res:', res) 63 | return res 64 | 65 | 66 | to_pdf = True 67 | 68 | def plot_pie(ax, fontsize=14, radius=0.9, stat_type='session'): 69 | if stat_type == 'base': 70 | data = get_data_old('GPT-4o') 71 | elif stat_type == 'session': 72 | data = get_data('GPT-4o') 73 | else: 74 | raise ValueError(f'Invalid 
stat_type: {stat_type}') 75 | sector_sizes = data 76 | 77 | # Properties for the wedges 78 | wedge_properties = {'edgecolor': 'white', 'linewidth': 1.5} # Adjust linewidth as needed 79 | 80 | total = sum(sector_sizes) 81 | 82 | # Ring 83 | wedges, texts, autotexts = ax.pie(sector_sizes, colors=sector_colors, autopct='%1.1f%%', 84 | startangle=90, radius=radius, wedgeprops=wedge_properties, pctdistance=0.65) 85 | 86 | # Rotate labels to align with wedges 87 | for i, (wedge, label) in enumerate(zip(wedges, autotexts)): 88 | if wedge.theta2 - wedge.theta1 < 40: # MAGIC NUMBER 89 | angle = (wedge.theta1 + wedge.theta2) / 2 90 | angle = angle % 360 91 | if angle > 90 and angle < 270: 92 | angle += 180 93 | label.set_rotation(angle) 94 | if wedge.theta2 - wedge.theta1 < 20: # MAGIC NUMBER 95 | label.set_size(fontsize - 2) 96 | else: 97 | label.set_size(fontsize) 98 | label.set_text(sector_labels[i] + f' ({round(sector_sizes[i]*100 / total)}%)') 99 | label.set_horizontalalignment('center') 100 | label.set_verticalalignment('center') 101 | label.set_color('black') 102 | 103 | # Draw circle in the center to make it look like a donut chart 104 | # centre_circle = plt.Circle((0,0),0.70,fc='white') 105 | # fig = plt.gcf() 106 | # fig.gca().add_artist(centre_circle) 107 | 108 | # Equal aspect ratio ensures that pie is drawn as a circle. 109 | ax.axis('equal') 110 | length = 2.5 111 | y_start, x_start = -1.5, -1.2 112 | ax.set_ylim(y_start, y_start + length) 113 | ax.set_xlim(x_start, x_start + length) 114 | 115 | ax.text(x_start + length/2, y_start + 0.25, 'Constraint Distribution', fontsize=16, ha='center', weight='bold') 116 | 117 | if __name__ == '__main__': 118 | plt.rcParams['font.family'] = 'Calibri' 119 | mpl.rcParams.update({'font.size': 13}) 120 | 121 | fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=300, tight_layout=True) 122 | plot_pie(ax) 123 | 124 | plt.savefig('figures/fig_constraint' + ('.pdf' if to_pdf else '.png'), bbox_inches='tight', pad_inches=-0.1) -------------------------------------------------------------------------------- /plot/fig_domain.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | from utils.parse_xls import parse_xls 6 | 7 | to_pdf = True 8 | color = '#B8ECBE' 9 | 10 | TRANSLATE_TABLE = { 11 | "NLP" : "NLP", 12 | "互联网" : "Internet", 13 | "传媒" : "Media", 14 | "医疗" : "Healthcare", 15 | "宗教玄学" : "Religion", 16 | "心理情感" : "Psychology", 17 | "房地产" : "Real Estate", 18 | "招聘" : "Recruitment", 19 | "政务" : "Govern. 
Affairs", 20 | "教育" : "Education", 21 | "文化娱乐" : "Culture", 22 | "旅游" : "Travel", 23 | "汽车" : "Automobile", 24 | "法律" : "Law", 25 | "游戏" : "Gaming", 26 | "科技数码" : "Technology", 27 | "美食" : "Cuisine", 28 | "运动健身" : "Sports", 29 | "通用工作" : "General Work", 30 | "通用生活" : "General Life", 31 | "金融" : "Finance", 32 | "其他" : "Other", 33 | } 34 | 35 | def get_data(key): 36 | df = parse_xls(key) 37 | 38 | output = df[['领域', 'answer']].groupby('领域').count() 39 | output['answer'] //= 5 40 | return output.sort_values(by='answer', ascending=True) 41 | 42 | data = get_data('GPT-4o') 43 | 44 | def plot_histogram(ax, fontsize=13): 45 | y = data['answer'] 46 | x = np.arange(len(y)) 47 | 48 | ax.bar(x, y, color=color, width=0.7) 49 | ax.set_xticks(x) 50 | ax.set_xticklabels([TRANSLATE_TABLE[domain] for domain in data.index], rotation=45, ha='right', fontsize=fontsize) 51 | ax.set_ylabel('# System Messages', fontsize=fontsize+2) 52 | 53 | ax.text(7, 35, 'Domain Distribution', fontsize=16, ha='center', va='center', weight='bold') 54 | ax.spines['top'].set_visible(False) 55 | ax.spines['right'].set_visible(False) 56 | 57 | 58 | if __name__ == '__main__': 59 | plt.rcParams['font.family'] = 'Calibri' 60 | 61 | fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300, tight_layout=True) 62 | 63 | plot_histogram(ax) 64 | 65 | file_name = 'figures/fig_domain' + ('.pdf' if to_pdf else '.png') 66 | plt.savefig(file_name, bbox_inches='tight', pad_inches=0.1) 67 | print(f'Figure saved to {file_name}.') -------------------------------------------------------------------------------- /plot/tab1_categoty.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from utils.parse_xls import parse_xls, TURN_NUMBER 5 | 6 | row_margins = 2.0, 1.5 7 | def number_mapper(x, col_id): 8 | if col_id == 2: 9 | return f'{x:.2f}' 10 | else: 11 | return str(int(x)) 12 | 13 | def get_data(key): 14 | res = np.zeros((3, 4)) 15 | 16 | try: 17 | df = parse_xls(key) 18 | except Exception as e: 19 | print(f'Error: {e}, when reading {key}') 20 | return res 21 | 22 | for index, row in df.iterrows(): 23 | text = row['评判结果'] 24 | base = 1 if row['multi_rounds_related'] else 0 25 | res[base, (0 if row['alignment'] == 'align' else 1)] += 1 26 | res[base, 2] += text.count('约束') 27 | res[base, 3] += 1 28 | 29 | res[2] = res[0] + res[1] 30 | 31 | res[:, 2] /= res[:, 3] 32 | res[:, 3] /= TURN_NUMBER 33 | 34 | return res 35 | 36 | BEFORE_TEX = r'''\begin{table}[t] 37 | \centering 38 | \small 39 | \begin{tabular}{c|cccc} 40 | \toprule 41 | & Aligned & Misaligned & C. per I. 
& \# Session \\ 42 | \midrule 43 | ''' 44 | ATRER_TEX = r'''\bottomrule 45 | \end{tabular} 46 | \caption{Table 1} 47 | \end{table} 48 | ''' 49 | 50 | if __name__ == '__main__': 51 | data_table = get_data('GPT-4o') 52 | 53 | print(BEFORE_TEX) 54 | print("% ====<<<<==== Auto-generated LaTeX code begin ====>>>>==== %\n") 55 | print(f"% --- generated by {os.path.basename(__file__)} --- %\n") 56 | 57 | for i, key in enumerate(['Parallel', 'Dependent', 'Total']): 58 | print(r'\rule{0pt}{' + str(row_margins[1 if i == 1 else 0]) + r'ex}') 59 | print(f'{key} & ' 60 | + ' & '.join([number_mapper(x, j) for j, x in enumerate(data_table[i])]) 61 | + ' \\\\') 62 | if i == 1: 63 | print('\\hline') 64 | 65 | print("\n% ====<<<<==== Auto-generated LaTeX code end ====>>>>==== %") 66 | print(ATRER_TEX) 67 | 68 | # print(data_table) -------------------------------------------------------------------------------- /plot/tab2_overall.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | from tab4_turn import get_data as get_data_ssr 5 | from tab3_align import get_data as get_data_isr 6 | from tab6_csr_full import get_data as get_data_csr 7 | 8 | from utils.parse_xls import get_full_name 9 | from utils.get_rank import rank_columns_desc 10 | 11 | KEY_LIST = 'GPT-4o', 'GPT-4-Turbo', 'GPT-3.5', 'Claude-3.5', 'Llama3.1-70B', 'Llama3.1-8B', \ 12 | 'Qwen2-72B', 'GLM-4', 'DeepSeek-V2', 'Moonshot', 'Yi-Large' ,'GLM-4-9B', \ 13 | 'ERNIE-4', 'Qwen2-7B' 14 | 15 | number_mapper = lambda x, col_id: f'{x * 100:.1f}' + r'\%' 16 | def hilight_mapper(x, col_id, rank): 17 | if rank == 0: 18 | return f'\\textbf{{{number_mapper(x, col_id)}}}' 19 | elif rank == 1: 20 | return f'\\underline{{{number_mapper(x, col_id)}}}' 21 | else: 22 | return number_mapper(x, col_id) 23 | 24 | row_margins = 2.0, 1.5 # the first, and the rest, in ex 25 | 26 | def generate_table(): 27 | data_entries = [] 28 | 29 | for key in KEY_LIST: 30 | data_ssr = get_data_ssr(key) 31 | data_isr = get_data_isr(key) 32 | data_csr = get_data_csr(key) 33 | 34 | entry = np.array([data_ssr[-1], data_isr[-1], data_csr[-1]])[::-1] 35 | 36 | data_entries.append(entry) 37 | # print('Written:', key) 38 | 39 | data_table = np.array(data_entries) 40 | return data_table 41 | 42 | BEFORE_TEX=r'''\begin{table}[t] 43 | \centering 44 | \small 45 | \begin{tabular}{c|ccc} 46 | \toprule 47 | Full Model Name & \textbf{CSR} & \textbf{ISR} & \textbf{SSR} \\ 48 | \midrule 49 | ''' 50 | AFTER_TEX=r'''\bottomrule 51 | \end{tabular} 52 | \caption{Table 2} 53 | \end{table} 54 | ''' 55 | 56 | if __name__ == '__main__': 57 | data_table = generate_table() 58 | 59 | # find desending order for each column 60 | data_ranked = rank_columns_desc(data_table) 61 | 62 | print(BEFORE_TEX) 63 | print("% ====<<<<==== Auto-generated LaTeX code begin ====>>>>==== %\n") 64 | print(f"% --- generated by {os.path.basename(__file__)} --- %\n") 65 | 66 | for i, key in enumerate(KEY_LIST): 67 | print(r'\rule{0pt}{' + str(row_margins[0 if i == 6 else 1]) + r'ex}') 68 | 69 | print(f'{get_full_name(key)} & ' 70 | + ' & '.join([hilight_mapper(x, j, data_ranked[i, j]) for j, x in enumerate(data_table[i])]) 71 | + ' \\\\') 72 | if i == 5: 73 | print(r'\hline') 74 | 75 | print("\n% ====<<<<==== Auto-generated LaTeX code end ====>>>>==== %") 76 | print(AFTER_TEX) -------------------------------------------------------------------------------- /plot/tab3_align.py: -------------------------------------------------------------------------------- 1 | 
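# Builds the ISR (Instruction Satisfaction Rate) table: for every model it reads the
# per-turn '是否可用' (usable) flags and alignment labels from
# <output_dir>/<model>/<model>_analysis.xlsx and prints LaTeX rows with the
# satisfaction rate for aligned instructions, misaligned instructions, and the total.
# Assumed invocation (by analogy with run_metric.sh, which runs the plot/ scripts
# from the repository root): python plot/tab3_align.py > tab3_align.tex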
import os 2 | import numpy as np 3 | import pandas as pd 4 | from utils.parse_xls import parse_xls 5 | from utils.get_rank import rank_columns_desc 6 | 7 | KEY_LIST = 'GPT-4o', 'GPT-4-Turbo', 'GPT-3.5', 'Claude-3.5', 'Llama3.1-70B', 'Llama3.1-8B', \ 8 | 'Qwen2-72B', 'GLM-4', 'DeepSeek-V2', 'Moonshot', 'Yi-Large' ,'GLM-4-9B', \ 9 | 'ERNIE-4', 'Qwen2-7B' 10 | 11 | number_mapper = lambda x, col_id: f'{x * 100:.1f}' + r'\%' 12 | def hilight_mapper(x, col_id, rank): 13 | if rank == 0: 14 | return f'\\textbf{{{number_mapper(x, col_id)}}}' 15 | elif rank == 1: 16 | return f'\\underline{{{number_mapper(x, col_id)}}}' 17 | else: 18 | return number_mapper(x, col_id) 19 | 20 | 21 | row_margins = 2.0, 1.5 # the first, and the rest, in ex 22 | 23 | def get_data(key, root_dir='output'): 24 | res = np.zeros(3) 25 | cnt = np.zeros(3) 26 | 27 | def update_res(base, val): 28 | res[base] += val 29 | cnt[base] += 1 30 | 31 | try: 32 | df = parse_xls(key, root_dir=root_dir) 33 | except Exception as e: 34 | print(f'Error: {e}, when reading {key}') 35 | return res 36 | 37 | for index, row in df.iterrows(): 38 | val = row['是否可用'] 39 | 40 | assert row['alignment'] in ('align', 'misalign', 'unknown'), f'Unknown alignment value: {row["alignment"]}, at index {index} of {key}' 41 | 42 | base = 0 43 | update_res(base + (0 if row['alignment'] == 'align' else 1), val) 44 | update_res(base + 2, val) 45 | # print(key, res, cnt) 46 | return res / cnt 47 | 48 | BEFORE_TEX = r'''\begin{table}[t] 49 | \centering 50 | \small 51 | \begin{tabular}{c|cc|c} 52 | \toprule 53 | % \rule{0pt}{2.0ex} 54 | Model & Aligned & Misaligned & Total\\ 55 | \midrule''' 56 | AFTER_TEX = r'''\bottomrule 57 | \end{tabular} 58 | \caption{Table 3} 59 | \end{table}''' 60 | 61 | if __name__ == '__main__': 62 | data_table = np.zeros((len(KEY_LIST), 3)) 63 | for i, key in enumerate(KEY_LIST): 64 | data_table[i] = get_data(key) 65 | 66 | # sort by the last column 67 | index = np.argsort(data_table[:, -1])[::-1] 68 | data_table = data_table[index] 69 | label_list = [KEY_LIST[i] for i in index] 70 | 71 | # find desending order for each column 72 | data_ranked = rank_columns_desc(data_table) 73 | print(BEFORE_TEX) 74 | print("% ====<<<<==== Auto-generated LaTeX code begin ====>>>>==== %\n") 75 | print(f"% --- generated by {os.path.basename(__file__)} --- %\n") 76 | 77 | for i, key in enumerate(label_list): 78 | print(r'\rule{0pt}{' + str(row_margins[1 if i else 0]) + r'ex}') 79 | print(f'{key} & ' 80 | + ' & '.join([hilight_mapper(x, j, data_ranked[i, j]) for j, x in enumerate(data_table[i])]) 81 | + ' \\\\') 82 | # print('\\hline') 83 | 84 | print("\n% ====<<<<==== Auto-generated LaTeX code end ====>>>>==== %") 85 | print(AFTER_TEX) -------------------------------------------------------------------------------- /plot/tab4_turn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from utils.parse_xls import parse_xls, TURN_NUMBER 5 | from utils.get_rank import rank_columns_desc 6 | 7 | KEY_LIST = 'GPT-4o', 'GPT-4-Turbo', 'GPT-3.5', 'Claude-3.5', 'Llama3.1-70B', 'Llama3.1-8B', \ 8 | 'Qwen2-72B', 'GLM-4', 'DeepSeek-V2', 'Moonshot', 'Yi-Large' ,'GLM-4-9B', \ 9 | 'ERNIE-4', 'Qwen2-7B' 10 | 11 | number_mapper = lambda x, col_id: (f'{x * 100:.1f}' + r'\%') # if col_id % 6 != 5 else f'{x:.2f}' 12 | 13 | def hilight_mapper(x, col_id, rank): 14 | if rank == 0: 15 | return f'\\textbf{{{number_mapper(x, col_id)}}}' 16 | elif rank == 1: 17 | # underline 18 | return 
f'\\underline{{{number_mapper(x, col_id)}}}' 19 | else: 20 | return number_mapper(x, col_id) 21 | 22 | row_margins = 2.0, 1.5 # the first, and the rest, in ex 23 | 24 | def get_data(key, root_dir='output'): 25 | res = np.zeros(13) 26 | cnt = np.zeros(13) 27 | 28 | def update_res(base, val): 29 | res[base] += val 30 | cnt[base] += 1 31 | 32 | try: 33 | df = parse_xls(key, root_dir=root_dir) 34 | except Exception as e: 35 | print(f'Error: {e}, when reading {key}') 36 | return res 37 | 38 | 39 | for index, row in df.iterrows(): 40 | val = row['是否可用'] 41 | turn = index % TURN_NUMBER 42 | 43 | if turn == 0: 44 | accmulated_turn = 0 45 | 46 | assert isinstance(row['multi_rounds_related'], bool), f'Unknown multi_rounds_related value: {row["multi_rounds_related"]}, at index {index} of {key}' 47 | base = 0 if row['multi_rounds_related'] else 6 48 | 49 | if val > 0 and accmulated_turn == turn: 50 | accmulated_turn += 1 51 | 52 | update_res(base + turn, 1 if accmulated_turn == turn + 1 else 0) 53 | 54 | if turn == TURN_NUMBER - 1: 55 | update_res(base + 5, accmulated_turn) 56 | update_res(12, accmulated_turn) 57 | 58 | res[[5, 11, 12]] /= TURN_NUMBER 59 | # print(key, res, cnt) 60 | 61 | return res / cnt 62 | 63 | BEFORE_TEX = r''' 64 | % !!!!!!!!! set this at beginning of the document !!!!!!!!! 65 | \newcolumntype{M}[1]{>{\centering\arraybackslash}m{#1}} 66 | % !!!!!!!!! set this at beginning of the document !!!!!!!!! 67 | 68 | \begin{table*}[t] 69 | \centering 70 | \small 71 | \begin{tabular}{c|M{0.75cm}M{0.75cm}M{0.75cm}M{0.75cm}M{0.75cm}|M{0.75cm}|M{0.75cm}M{0.75cm}M{0.75cm}M{0.75cm}M{0.75cm}|M{0.75cm}|M{0.75cm}} 72 | \hline 73 | \multirow{2}{*}{Model} & \multicolumn{6}{c|}{Multi-turn Dependent} & \multicolumn{6}{c|}{Multi-turn Parallel} & Total\\ 74 | & R1 & R2 & R3 & R4 & R5 & \textbf{SSR} & R1 & R2 & R3 & R4 & R5 & \textbf{SSR} & \textbf{SSR}\\ 75 | \hline\hline''' 76 | AFTER_TEX = r''' 77 | \end{tabular} 78 | \caption{Tabel 4} 79 | \end{table*} 80 | ''' 81 | 82 | if __name__ == '__main__': 83 | data_table = np.zeros((len(KEY_LIST), 13)) 84 | for i, key in enumerate(KEY_LIST): 85 | data_table[i] = get_data(key) 86 | 87 | # sort by the last column 88 | index = np.argsort(data_table[:, -1])[::-1] 89 | data_table = data_table[index] 90 | label_list = [KEY_LIST[i] for i in index] 91 | 92 | # find desending order for each column 93 | data_ranked = rank_columns_desc(data_table) 94 | 95 | print(BEFORE_TEX) 96 | print("% ====<<<<==== Auto-generated LaTeX code begin ====>>>>==== %\n") 97 | print(f"% --- generated by {os.path.basename(__file__)} --- %\n") 98 | 99 | for i, key in enumerate(label_list): 100 | print(r'\rule{0pt}{' + str(row_margins[1 if i else 0]) + r'ex}') 101 | print(f'{key} & ' 102 | + ' & '.join([hilight_mapper(x, j, data_ranked[i, j]) for j, x in enumerate(data_table[i])]) 103 | + ' \\\\') 104 | print('\\hline') 105 | print(AFTER_TEX) 106 | 107 | print("\n% ====<<<<==== Auto-generated LaTeX code end ====>>>>==== %") -------------------------------------------------------------------------------- /plot/tab6_csr_full.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from utils.parse_xls import parse_xls 5 | from utils.get_rank import rank_columns_desc 6 | 7 | KEY_LIST = 'GPT-4o', 'GPT-4-Turbo', 'GPT-3.5', 'Claude-3.5', 'Llama3.1-70B', 'Llama3.1-8B', \ 8 | 'Qwen2-72B', 'GLM-4', 'DeepSeek-V2', 'Moonshot', 'Yi-Large' ,'GLM-4-9B', \ 9 | 'ERNIE-4', 'Qwen2-7B' 10 | LABEL_MAP = { 11 | 'Action' : 
'动作约束', 12 | 'Content': '内容约束', 13 | 'Background': '背景约束', 14 | 'Role': '角色约束', 15 | 'Format': '格式约束', 16 | 'Style': '风格约束', 17 | 'Total': 'total' 18 | } 19 | 20 | number_mapper = lambda x, col_id: f'{x * 100:.1f}' + r'\%' 21 | def hilight_mapper(x, col_id, rank): 22 | if rank == 0: 23 | return f'\\textbf{{{number_mapper(x, col_id)}}}' 24 | elif rank == 1: 25 | return f'\\underline{{{number_mapper(x, col_id)}}}' 26 | else: 27 | return number_mapper(x, col_id) 28 | 29 | row_margins = 2.0, 1.5 # the first, and the rest, in ex 30 | 31 | N = len(LABEL_MAP) 32 | def get_data(key, root_dir='output'): 33 | res = np.zeros(N) 34 | 35 | try: 36 | df = parse_xls(key, sheet_name='不同约束类型遵循', root_dir=root_dir) 37 | except Exception as e: 38 | print(f'Error: {e}, when reading {key}') 39 | return res 40 | 41 | for i, (_, col) in enumerate(LABEL_MAP.items()): 42 | column = df[col] 43 | res[i] = column[1] / column[0] 44 | 45 | return res 46 | 47 | BEFORE_TEX = r'''\begin{table*}[htp] 48 | \centering 49 | % \small 50 | \begin{tabular}{|c|cccccc|c|} 51 | \hline 52 | \rule{0pt}{2.0ex} 53 | \multirow{2}{*}{Model} & \multicolumn{7}{c|}{\textbf{CSR}} \\ 54 | & Action & Content & Background & Role & Format & Style & Total \\\hline''' 55 | AFTER_TEX = r'''\hline 56 | \end{tabular} 57 | \caption{Table 6} 58 | \end{table*}''' 59 | 60 | if __name__ == '__main__': 61 | data_table = np.zeros((len(KEY_LIST), N)) 62 | for i, key in enumerate(KEY_LIST): 63 | data_table[i] = get_data(key) 64 | 65 | # sort by the last column 66 | index = np.argsort(data_table[:, -1])[::-1] 67 | data_table = data_table[index] 68 | label_list = [KEY_LIST[i] for i in index] 69 | 70 | # find desending order for each column 71 | data_ranked = rank_columns_desc(data_table) 72 | print(BEFORE_TEX) 73 | print("% ====<<<<==== Auto-generated LaTeX code begin ====>>>>==== %\n") 74 | print(f"% --- generated by {os.path.basename(__file__)} --- %\n") 75 | 76 | for i, key in enumerate(label_list): 77 | print(r'\rule{0pt}{' + str(row_margins[1 if i else 0]) + r'ex}') 78 | print(f'{key} & ' 79 | + ' & '.join([hilight_mapper(x, j, data_ranked[i, j]) for j, x in enumerate(data_table[i])]) 80 | + ' \\\\') 81 | print('\\hline') 82 | 83 | print("\n% ====<<<<==== Auto-generated LaTeX code end ====>>>>==== %") 84 | print(AFTER_TEX) -------------------------------------------------------------------------------- /plot/tab7_align_full.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | from utils.parse_xls import parse_xls 5 | from utils.get_rank import rank_columns_desc 6 | 7 | KEY_LIST = 'GPT-4o', 'GPT-4-Turbo', 'GPT-3.5', 'Claude-3.5', 'Llama3.1-70B', 'Llama3.1-8B', \ 8 | 'Qwen2-72B', 'GLM-4', 'DeepSeek-V2', 'Moonshot', 'Yi-Large' ,'GLM-4-9B', \ 9 | 'ERNIE-4', 'Qwen2-7B' 10 | 11 | number_mapper = lambda x, col_id: f'{x * 100:.1f}' + r'\%' 12 | def hilight_mapper(x, col_id, rank): 13 | if rank == 0: 14 | return f'\\textbf{{{number_mapper(x, col_id)}}}' 15 | elif rank == 1: 16 | return f'\\underline{{{number_mapper(x, col_id)}}}' 17 | else: 18 | return number_mapper(x, col_id) 19 | 20 | 21 | row_margins = 2.0, 1.5 # the first, and the rest, in ex 22 | 23 | def get_data(key): 24 | res = np.zeros(9) 25 | cnt = np.zeros(9) 26 | 27 | def update_res(base, val): 28 | res[base] += val 29 | cnt[base] += 1 30 | 31 | try: 32 | df = parse_xls(key) 33 | except Exception as e: 34 | print(f'Error: {e}, when reading {key}') 35 | return res 36 | 37 | for index, row in df.iterrows(): 38 
| val = row['是否可用'] 39 | 40 | assert isinstance(row['multi_rounds_related'], bool), f'Unknown multi_rounds_related value: {row["multi_rounds_related"]}, at index {index} of {key}' 41 | assert row['alignment'] in ('align', 'misalign', 'unknown'), f'Unknown alignment value: {row["alignment"]}, at index {index} of {key}' 42 | base = 0 if row['multi_rounds_related'] else 3 43 | 44 | update_res(base + (0 if row['alignment'] == 'align' else 1), val) 45 | update_res(base + 2, val) 46 | update_res(6 + (0 if row['alignment'] == 'align' else 1), val) 47 | update_res(8, val) 48 | 49 | return res / cnt 50 | 51 | BEFORE_TEX = r'''\begin{table*}[htp] 52 | \centering 53 | \small 54 | \begin{tabular}{|c|ccc|ccc|ccc|} 55 | \hline 56 | \multirow{2}{*}{Model} & \multicolumn{3}{c|}{Multi-turn Dependent} & \multicolumn{3}{c|}{Multi-turn Parallel} & \multicolumn{3}{c|}{Overall} \\ 57 | & Aligned & Misaligned & Average & Aligned & Misaligned & Average & Aligned & Misaligned & Average \\\hline 58 | \hline''' 59 | AFTER_TEX = r'''\end{tabular} 60 | \caption{Table 7} 61 | \end{table*}''' 62 | 63 | if __name__ == '__main__': 64 | data_table = np.zeros((len(KEY_LIST), 9)) 65 | for i, key in enumerate(KEY_LIST): 66 | data_table[i] = get_data(key) 67 | 68 | # sort by the last column 69 | index = np.argsort(data_table[:, -1])[::-1] 70 | data_table = data_table[index] 71 | label_list = [KEY_LIST[i] for i in index] 72 | 73 | # find desending order for each column 74 | data_ranked = rank_columns_desc(data_table) 75 | 76 | print(BEFORE_TEX) 77 | print("% ====<<<<==== Auto-generated LaTeX code begin ====>>>>==== %\n") 78 | print(f"% --- generated by {os.path.basename(__file__)} --- %\n") 79 | 80 | for i, key in enumerate(label_list): 81 | print(r'\rule{0pt}{' + str(row_margins[1 if i else 0]) + r'ex}') 82 | print(f'{key} & ' 83 | + ' & '.join([hilight_mapper(x, j, data_ranked[i, j]) for j, x in enumerate(data_table[i])]) 84 | + ' \\\\') 85 | print('\\hline') 86 | 87 | print("\n% ====<<<<==== Auto-generated LaTeX code end ====>>>>==== %") 88 | print(AFTER_TEX) -------------------------------------------------------------------------------- /plot/utils/change_color.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | 3 | def hex_to_rgb(hex_color): 4 | """Convert hex color to RGB.""" 5 | hex_color = hex_color.lstrip('#') 6 | return tuple(int(hex_color[i:i+2], 16) for i in (0, 2, 4)) 7 | 8 | def rgb_to_hex(rgb_color): 9 | """Convert RGB color back to hex.""" 10 | return '#' + ''.join(f'{int(x):02x}' for x in rgb_color) 11 | 12 | def adjust_saturation(colors): 13 | adjusted_colors = [] 14 | for color in colors: 15 | # Convert hex to RGB 16 | rgb = hex_to_rgb(color) 17 | # Convert RGB to HSV 18 | hsv = colorsys.rgb_to_hsv(rgb[0]/255.0, rgb[1]/255.0, rgb[2]/255.0) 19 | # Increase saturation by 0.1, ensuring it does not exceed 1 20 | new_saturation = min(hsv[1] + 0.08, 1) 21 | new_value = min(hsv[2] - 0.05, 1) 22 | print(new_saturation, new_value) 23 | # Convert back to RGB 24 | new_rgb = colorsys.hsv_to_rgb(hsv[0], new_saturation, new_value) 25 | # Scale RGB back to 0-255 range and convert to integer 26 | new_rgb_scaled = tuple(int(x * 255) for x in new_rgb) 27 | # Convert RGB back to hex 28 | new_hex = rgb_to_hex(new_rgb_scaled) 29 | adjusted_colors.append(new_hex) 30 | return adjusted_colors 31 | 32 | # Sample input 33 | # https://colorkit.co/palettes/9-colors/ 34 | input_colors = ["#d6e6ff","#d7f9f8","#ffffea","#fff0d4","#fbe0e0","#e5d4ef"] 35 | adjusted_colors = 
adjust_saturation(input_colors) 36 | 37 | print(adjusted_colors) -------------------------------------------------------------------------------- /plot/utils/generate_n_color.py: -------------------------------------------------------------------------------- 1 | import colorsys 2 | import random 3 | import argparse 4 | 5 | def generate_n_colors(n, 6 | hue_offset_ratio=0.2, 7 | saturation_mean=0.6, 8 | saturation_bias=0.1, 9 | brightness_mean=0.75, 10 | brightness_bias=0.15, 11 | seed=None, 12 | shuffle='none'): 13 | """Generate n colors with random hue, saturation, and brightness.""" 14 | 15 | assert shuffle in ['none', 'interleave', 'random', 'max_adjacent'], \ 16 | 'Invalid value for shuffle' 17 | 18 | if seed is not None: 19 | random.seed(seed) 20 | 21 | # random starting hue in [0, 1/n) 22 | start_hue = random.random() / n 23 | 24 | colors = [] 25 | for i in range(n): 26 | hue = (start_hue + i / n + (2*random.random() - 1) * hue_offset_ratio / n) % 1 27 | saturation = max(0, min(1, saturation_mean + (2*random.random() - 1) * saturation_bias)) 28 | brightness = max(0, min(1, brightness_mean + (2*random.random() - 1) * brightness_bias)) 29 | 30 | rgb = colorsys.hsv_to_rgb(hue, saturation, brightness) 31 | rgb_scaled = tuple(int(x * 255) for x in rgb) 32 | hex_color = '#' + ''.join(f'{int(x):02x}' for x in rgb_scaled) 33 | colors.append(hex_color) 34 | 35 | if shuffle == 'interleave': 36 | colors = colors[::2] + colors[1::2] 37 | elif shuffle == 'random': 38 | random.shuffle(colors) 39 | elif shuffle == 'max_adjacent': 40 | half_n = n // 2 41 | colors = [colors[i // 2] if i % 2 == 0 else colors[half_n + i // 2] for i in range(n)] 42 | 43 | return colors 44 | 45 | if __name__ == '__main__': 46 | arg_parser = argparse.ArgumentParser() 47 | arg_parser.add_argument('-n', type=int, help='Number of colors to generate', default=9) 48 | args = arg_parser.parse_args() 49 | 50 | n = args.n 51 | colors = generate_n_colors(n) 52 | print(colors) 53 | 54 | import matplotlib.pyplot as plt 55 | 56 | # Plot the colors in patches 57 | fig, ax = plt.subplots(1, 1, figsize=(n / 4, 1), dpi=300, tight_layout=True) 58 | for i, color in enumerate(colors): 59 | ax.add_patch(plt.Rectangle((i / n, 0), 0.95/n, 0.95, color=color)) 60 | ax.axis('off') 61 | plt.show() -------------------------------------------------------------------------------- /plot/utils/get_rank.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def rank_columns_desc(data): 4 | n = data.shape[0] 5 | 6 | sorted_indices = np.argsort(-data, axis=0) 7 | 8 | ranks = np.empty_like(sorted_indices, dtype=int) 9 | 10 | for i in range(data.shape[1]): 11 | column = data[:, i] 12 | sorted_column = -column[sorted_indices[:, i]] 13 | 14 | _, first_index_positions = np.unique(sorted_column, return_index=True) 15 | 16 | first_index_positions = first_index_positions.tolist() + [n] 17 | # print(first_index_positions) 18 | 19 | ranks_for_uniques = np.zeros(n, dtype=float) 20 | last_idx = 0 21 | for j in range(len(first_index_positions) - 1): 22 | ranks_for_uniques[last_idx:first_index_positions[j+1]] = first_index_positions[j] 23 | last_idx = first_index_positions[j+1] 24 | 25 | # print(ranks_for_uniques) 26 | ranks[:, i] = ranks_for_uniques[np.argsort(sorted_indices[:, i])] 27 | 28 | return ranks 29 | 30 | if __name__ == '__main__': 31 | data = np.array([ 32 | [20, 2, 20], 33 | [10, 15, 5], 34 | [50, 25, 5], 35 | [30, 10, 20] 36 | ]) 37 | 38 | # Compute ranks 39 | ranked_data = 
rank_columns_desc(data) 40 | print("Original Data:\n", data) 41 | print("Ranked Data:\n", ranked_data) -------------------------------------------------------------------------------- /plot/utils/parse_xls.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | 4 | TOTAL_SYSTEM_ID = 500 5 | TURN_NUMBER = 5 6 | 7 | ENTRY_NUMBER = TOTAL_SYSTEM_ID * TURN_NUMBER 8 | 9 | KEY_MAP = { 10 | 'GPT-4-Turbo' : 'gpt4_turbo_0409', 11 | 'GPT-4o' : 'gpt4o', 12 | 'GPT-3.5' : 'gpt35', 13 | 'ERNIE-4' : 'ernie4', 14 | 'Moonshot' : 'moonshot', 15 | 'Qwen2-7B' : 'qwen2_7b', 16 | 'Qwen2-72B' : 'qwen2_72b', 17 | 'GLM-4' : 'glm4', 18 | 'Yi-Large' : 'yi_large', 19 | 'DeepSeek-V2' : 'deepseek', 20 | 'Claude-3.5' : 'claude35_opus', 21 | 'Llama3.1-70B' : 'llama3_70b', 22 | 'Llama3.1-8B' : 'llama3_8b', 23 | 'GLM-4-9B' : 'glm_9b_client', 24 | } 25 | 26 | FULL_MAP = { 27 | 'GPT-4-Turbo' : 'GPT4-Turbo-20240409$^\dag$', 28 | 'GPT-4o' : 'GPT4o$^\dag$', 29 | 'GPT-3.5' : 'GPT3.5-Turbo-20231106$^\dag$', 30 | 'ERNIE-4' : 'ERNIE-4-8K-0613$^\dag$', 31 | 'Moonshot' : 'Moonshot-V1-8K$^\dag$', 32 | 'Qwen2-7B' : 'Qwen2-7B-Instruct', 33 | 'Qwen2-72B' : 'Qwen2-72B-Instruct', 34 | 'GLM-4' : 'GLM-4-0520$^\dag$', 35 | 'DeepSeek-V2' : 'DeepSeek-V2-0628$^\dag$', 36 | 'Claude-3.5' : 'Claude-3.5-Opus$^\dag$', 37 | 'Llama3.1-70B' : 'Llama3.1-70B-Instruct', 38 | 'Llama3.1-8B' : 'Llama3.1-8B-Instruct', 39 | 'Baichuan2-13B' : 'Baichuan2-13B-Chat', 40 | 'GLM-4-9B' : 'GLM-4-9B-Chat', 41 | } 42 | 43 | def get_full_name(key): 44 | return FULL_MAP.get(key, key) 45 | 46 | def parse_xls(key, sheet_name='详情', root_dir='output'): 47 | file_path = os.path.join(root_dir, KEY_MAP.get(key, key), f'{KEY_MAP.get(key, key)}_analysis.xlsx') 48 | df = pd.read_excel(file_path, sheet_name) 49 | if sheet_name == '详情': 50 | assert len(df) == ENTRY_NUMBER, f'Reading error: {len(df)} entries found, expected {ENTRY_NUMBER} entries ({file_path})' 51 | return df 52 | 53 | if __name__ == '__main__': 54 | file_path = 'output/ernie4/ernie4_analysis.xlsx' 55 | data = parse_xls(file_path) 56 | print(data) -------------------------------------------------------------------------------- /plot/utils/smooth.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def weighted_moving_average(x, y, window_size=5): 4 | y_smooth = np.zeros_like(y) 5 | n = len(x) 6 | 7 | for i in range(n): 8 | start = max(0, i - window_size) 9 | end = min(n, i + window_size + 1) 10 | 11 | x_neigh = x[start:end] 12 | y_neigh = y[start:end] 13 | 14 | distances = np.abs(x_neigh - x[i]) 15 | weights = 1 / (distances + 0.1) 16 | weights /= weights.sum() 17 | 18 | y_smooth[i] = np.dot(y_neigh, weights) 19 | 20 | return y_smooth -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=3.9 2 | pandas>=2.2 3 | openpyxl>=3.1 4 | transformers>=4.44 5 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | infer_model="gpt4_turbo_0409" # change this to the model you want to evaluate 2 | max_threads=20 3 | python -m eval_system_bench \ 4 | --infer_model_name ${infer_model} \ 5 | --output_dir output \ 6 | --max_threads ${max_threads} 7 | -------------------------------------------------------------------------------- /run_metric.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | infer_model="gpt4_turbo_0409" 4 | output="./output" 5 | python plot/eval_output.py \ 6 | --infer_model_name ${infer_model} \ 7 | --output_dir ${output} \ 8 | 9 | -------------------------------------------------------------------------------- /servers/run_vllm_serve.sh: -------------------------------------------------------------------------------- 1 | # export HF_ENDPOINT=https://hf-mirror.com 2 | # export CUDA_VISIBLE_DEVICES=4,5 3 | 4 | model="THUDM/glm-4-9b-chat" 5 | vllm serve $model \ 6 | --dtype auto \ 7 | --port 33618 \ 8 | --tensor-parallel-size 2 \ 9 | --api-key custom-key \ 10 | --trust-remote-code 11 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pandas as pd 4 | from collections import defaultdict 5 | import re 6 | import numpy as np 7 | 8 | def str2bool(value): 9 | if isinstance(value, bool): 10 | return value 11 | if value.lower() == "true": 12 | return True 13 | elif value.lower() == "false": 14 | return False 15 | else: 16 | raise argparse.ArgumentTypeError('Boolean value expected.') 17 | 18 | # 定义计算加权平均的自定义函数 19 | def weighted_mean(series): 20 | count = series.notna().sum() # 非缺失值计数 21 | return series.sum() / count if count != 0 else np.nan 22 | 23 | def analysis_eval_results(eval_result_filepath, analysis_eval_output_path): 24 | datas = json.load(open(eval_result_filepath, encoding="utf-8")) 25 | 26 | count_every_round = {i + 1: 0 for i in range(5)} 27 | count_continuous_round = {i + 1: 0 for i in range(5)} 28 | count_continuous_round_relate = {i + 1: 0 for i in range(5)} 29 | count_continuous_round_parallel = {i + 1: 0 for i in range(5)} 30 | 31 | relate_align = {"align": [], "misalign": []} 32 | parallel_align = {"align": [], "misalign": []} 33 | total_align = {"align": [], "misalign": []} 34 | 35 | count_relate = 0 36 | count_parallel = 0 37 | 38 | categorys_all = defaultdict(int) 39 | categorys_valid = defaultdict(int) 40 | 41 | all_infos = list() 42 | datas = sorted(datas, key=lambda key:key["system_id"]) 43 | 44 | for data in datas: 45 | if data["rounds_related"]: 46 | count_relate += 1 47 | else: 48 | count_parallel += 1 49 | session_flag = True 50 | for index, message in enumerate([m for m in data["messages"] if m["role"] == "user"]): 51 | prompt = message["content"] 52 | eval_res = data["eval_results"][message["content"]] 53 | 54 | round_info = dict() 55 | round_info["system_id"] = data["system_id"] 56 | round_info["领域"] = data["领域"] 57 | round_info["场景"] = data["场景"] 58 | round_info["multi_rounds_related"] = data["rounds_related"] 59 | round_info["alignment"] = data["prompt_infos"][prompt]["alignment"] 60 | round_info["round_index"] = index + 1 61 | round_info["system_prompt"] = data["system_prompt"] 62 | round_info["prompt"] = prompt 63 | round_info["referrence"] = data["messages"][index*2 + 2]["content"] 64 | round_info["answer"] = eval_res["response"] 65 | round_info["infer_model"] = data["infer_model"] 66 | round_info["评判细则"] = "\n".join([str(cri["criteria_id"]) + ". " + cri["criteria_content"] + " | " + cri["criteria_type"] for cri_index, cri in eval_res["criteria"].items()]) 67 | round_info["评判理由"] = eval_res["评判理由"] 68 | round_info["评判结果"] = "\n".join([k + ". 
" + eval_res["criteria"][k]["criteria_type"] + " | " + v for k,v in eval_res["评判结果"].items()]) 69 | round_info["是否可用"] = 1 if "否" not in eval_res["评判结果"].values() else 0 70 | 71 | # 统计多轮相关/平行-prompt对齐/冲突的可用率结果 72 | if round_info["multi_rounds_related"]: 73 | relate_align[round_info["alignment"]].append(round_info["是否可用"]) 74 | else: 75 | parallel_align[round_info["alignment"]].append(round_info["是否可用"]) 76 | total_align[round_info["alignment"]].append(round_info["是否可用"]) 77 | 78 | # 统计多轮相关/平行-prompt对齐/冲突下的多轮连续遵循情况 79 | round_flag = True 80 | for cri in eval_res["评判结果"]: 81 | categorys_all[eval_res["criteria"][cri]["criteria_type"]] += 1 82 | if eval_res["评判结果"][cri] == "是": 83 | categorys_valid[eval_res["criteria"][cri]["criteria_type"]] += 1 84 | else: 85 | round_flag = False 86 | if round_flag: 87 | count_every_round[index + 1] += 1 88 | if session_flag: 89 | count_continuous_round[index + 1] += 1 90 | if round_info["multi_rounds_related"]: 91 | count_continuous_round_relate[index + 1] += 1 92 | else: 93 | count_continuous_round_parallel[index + 1] += 1 94 | else: 95 | session_flag = False 96 | 97 | all_infos.append(round_info) 98 | 99 | # 结果总表 100 | all_infos = pd.DataFrame(all_infos) 101 | # 每一轮的遵循率 102 | count_every_round_rate = {k: round(v/len(datas) * 100, 2) for k, v in count_every_round.items()} 103 | count_every_round["total"] = sum([v for _, v in count_every_round.items()]) 104 | count_every_round_rate["total"] = round(count_every_round["total"] / (len(datas) * 5) * 100, 2) 105 | 106 | # 多轮连续遵循率 107 | count_continuous_round_relate_rate = {k: round(v/count_relate * 100, 2) for k, v in count_continuous_round_relate.items()} 108 | count_continuous_round_parallel_rate = {k: round(v/count_parallel * 100, 2) for k, v in count_continuous_round_parallel.items()} 109 | count_continuous_round_rate = {k: round(v/len(datas) * 100, 2) for k, v in count_continuous_round.items()} 110 | # 不同约束类型的可用率 111 | category_valid_rate = {type_key: round(categorys_valid[type_key] / categorys_all[type_key] * 100, 2) for type_key in categorys_all} 112 | categorys_all["total"] = sum(categorys_all.values()) 113 | categorys_valid["total"] = sum(categorys_valid.values()) 114 | category_valid_rate["total"] = round(categorys_valid["total"] / categorys_all["total"] * 100, 2) 115 | 116 | relate_align_total = {k: len(v) for k, v in relate_align.items()} 117 | parallel_align_total = {k: len(v) for k, v in parallel_align.items()} 118 | total_align_total = {k: len(v) for k, v in total_align.items()} 119 | 120 | relate_align_valid = {k: sum(v) for k, v in relate_align.items()} 121 | parallel_align_valid = {k: sum(v) for k, v in parallel_align.items()} 122 | total_align_valid = {k: sum(v) for k, v in total_align.items()} 123 | 124 | for align_item in [relate_align_total, parallel_align_total, total_align_total, relate_align_valid, parallel_align_valid, total_align_valid]: 125 | align_item["total"] = sum(align_item.values()) 126 | 127 | relate_align_rate = {k: round(v/relate_align_total[k] * 100, 2) for k, v in relate_align_valid.items()} 128 | parallel_align_rate = {k: round(v/parallel_align_total[k] * 100, 2) for k, v in parallel_align_valid.items()} 129 | total_align_rate = {k: round(v/total_align_total[k] * 100, 2) for k, v in total_align_valid.items()} 130 | 131 | relate_align_rate["total"] = round(relate_align_valid["total"] / relate_align_total["total"] * 100, 2) 132 | parallel_align_rate["total"] = round(parallel_align_valid["total"] / parallel_align_total["total"] * 100, 2) 133 | total_align_rate["total"] = 
round(total_align_valid["total"] / total_align_total["total"] * 100, 2) 134 | 135 | print("=" * 50) 136 | print(f"多轮相关:{count_relate}, 多轮平行:{count_parallel}") 137 | print(json.dumps(count_continuous_round_relate, ensure_ascii=False, indent=2)) 138 | print(json.dumps(count_continuous_round_parallel, ensure_ascii=False, indent=2)) 139 | print(json.dumps(count_continuous_round, ensure_ascii=False, indent=2)) 140 | print("-" * 50) 141 | print(json.dumps(count_continuous_round_relate_rate, ensure_ascii=False, indent=2)) 142 | print(json.dumps(count_continuous_round_parallel_rate, ensure_ascii=False, indent=2)) 143 | print(json.dumps(count_continuous_round_rate, ensure_ascii=False, indent=2)) 144 | print("=" * 50) 145 | print(json.dumps(categorys_all, ensure_ascii=False, indent=2)) 146 | print(json.dumps(categorys_valid, ensure_ascii=False, indent=2)) 147 | print(json.dumps(category_valid_rate, ensure_ascii=False, indent=2)) 148 | print("=" * 50) 149 | 150 | with pd.ExcelWriter(analysis_eval_output_path) as writer: 151 | # sheet1:详情 152 | all_infos.to_excel(writer, sheet_name='详情', index=False) 153 | # sheet2:不同约束类型遵循 154 | round_evals = pd.DataFrame([categorys_all, categorys_valid, category_valid_rate], index=["约束总量", "遵循数量", "遵循率"]) 155 | cols = list(round_evals.columns) 156 | cols.remove("total") 157 | cols.sort() 158 | cols.append("total") 159 | round_evals = round_evals[cols] 160 | round_evals.to_excel(writer, sheet_name='不同约束类型遵循') 161 | # sheet3:不同轮次遵循 162 | round_index = pd.DataFrame([count_every_round, count_every_round_rate], index=["当前轮次遵循数量", "遵循率"]) 163 | round_index.to_excel(writer, sheet_name='不同轮次遵循') 164 | # sheet4:最大连续遵循轮次 165 | # count_continuous_round_relate["total"] = count_relate 166 | # count_continuous_round_parallel["total"] = count_parallel 167 | # count_continuous_round["total"] = count_relate + count_parallel 168 | 169 | # count_continuous_round_relate_rate["total"] = round(count_continuous_round_relate["total"] / count_relate * 100, 2) 170 | # count_continuous_round_parallel_rate["total"] = round(count_continuous_round_parallel["total"] / count_parallel * 100, 2) 171 | # count_continuous_round_rate["total"] = round(count_continuous_round["total"] / (count_relate + count_parallel) * 100, 2) 172 | def cal_continuous_avg(item, total): 173 | continuous_count = list(item.values()) 174 | continuous_count.append(0) 175 | continuous_count.insert(0, total) 176 | return round(sum([(continuous_count[i] - continuous_count[i + 1]) * i for i in range(len(continuous_count) - 1)]) / total, 2) 177 | 178 | count_continuous_round_relate_rate["total"] = cal_continuous_avg(count_continuous_round_relate, count_relate) 179 | count_continuous_round_parallel_rate["total"] = cal_continuous_avg(count_continuous_round_parallel, count_parallel) 180 | count_continuous_round_rate["total"] = cal_continuous_avg(count_continuous_round, count_relate + count_parallel) 181 | 182 | for item in [count_continuous_round_relate, count_continuous_round_parallel, count_continuous_round]: 183 | item["total"] = round(sum(k * v for k, v in item.items()) / len(item), 2) 184 | 185 | merge_continuous_round = pd.concat([pd.DataFrame([count_continuous_round_relate, count_continuous_round_relate_rate]), 186 | pd.DataFrame([count_continuous_round_parallel, count_continuous_round_parallel_rate]), 187 | pd.DataFrame([count_continuous_round, count_continuous_round_rate])], axis=1) 188 | merge_continuous_round.to_excel(writer, sheet_name='最大连续遵循轮次') 189 | 190 | # sheet5:统计多轮相关/平行-prompt对齐/冲突的可用率结果 191 | 
merge_sys_prompt_align = pd.concat([pd.DataFrame([relate_align_total, relate_align_valid, relate_align_rate]), 192 | pd.DataFrame([parallel_align_total, parallel_align_valid, parallel_align_rate]), 193 | pd.DataFrame([total_align_total, total_align_valid, total_align_rate])], axis=1) 194 | merge_sys_prompt_align.to_excel(writer, sheet_name='统计多轮相关|平行-prompt对齐|冲突的可用率结果') 195 | 196 | 197 | def get_eval_pattern(messages, criteria): 198 | system_prompt = messages[0]["content"] 199 | historys = messages[1:-2] 200 | current_dialogs = messages[-2::] 201 | criteria_pattern = "\n".join([str(criteria[cri_id]["criteria_id"]) + ". " + criteria[cri_id]["criteria_content"] + " | " + criteria[cri_id]["criteria_type"] for cri_id in criteria]) 202 | 203 | pattern = f""" 204 | # 背景与目标 205 | 你现在是一个大模型结果评估专家。以下,你将会面对一个大模型system prompt遵循能力评估任务。 206 | 我将会给出你对应的system prompt、历史对话、当前轮次的问题、和当前轮次的回答 207 | 需要你准确判断当前轮次的回答是否合格。 208 | 为了让你的评判足够准确,我还将会给出评判细则,需要你准确判断并告知评判细则当中每一条约束的遵循情况。 209 | 210 | # 对话信息 211 | ## system prompt 212 | 213 | {system_prompt} 214 | 215 | 216 | ## 历史对话轮次 217 | 218 | """ 219 | 220 | for index in range(len(historys) // 2): 221 | user_message = historys[index * 2] 222 | assistant_message = historys[index * 2 + 1] 223 | 224 | pattern += f""" 225 | 226 | 227 | {user_message["role"]} 228 | 229 | 230 | {user_message["content"]} 231 | 232 | {assistant_message["role"]} 233 | 234 | {assistant_message["content"]} 235 | 236 | 237 | 238 | """ 239 | 240 | pattern += f""" 241 | ## 当前待评估的对话轮次 242 | 243 | {current_dialogs[0]["role"]} 244 | 245 | 246 | {current_dialogs[0]["content"]} 247 | 248 | 249 | 250 | {current_dialogs[1]["role"]} 251 | 252 | 253 | {current_dialogs[1]["content"]} 254 | 255 | 256 | # 评判细则 257 | <评判细则> 258 | {criteria_pattern} 259 | 260 | """ 261 | 262 | pattern += """ 263 | 请你认真阅读上述system prompt设定与历史对话轮次,并严格以评判细则为评判标准,针对评判细则当中的逐条要求,判断当前对话轮次的回答是否遵循。 264 | 请以json格式回答,包含两个字段:评判理由、评判结果(评判结果为一个dict,dict的key评判细则的序号,value为对应的评判结果(是/否)) 265 | 输出格式如下: 266 | '''json 267 | { 268 | "评判理由": "……", 269 | "评判结果": { 270 | 1: "……", 271 | …… 272 | } 273 | } 274 | ''' 275 | 276 | """ 277 | return pattern 278 | 279 | 280 | def character_count(answer, criteria_content): 281 | chinese_character_count = re.findall(r"汉字字数数量(|大于|小于|等于)(\d{1,5})", criteria_content) 282 | character_count = re.findall(r"字数数量(|大于|小于|等于)(\d{1,5})", criteria_content) 283 | assert len(chinese_character_count) <= 1 # a single criterion is expected to contain at most one character-count constraint 284 | assert len(character_count) <= 1 285 | 286 | if len(chinese_character_count) > 0: 287 | comparison, number = chinese_character_count[0] 288 | number = int(number) 289 | chinese_character = re.findall("[\u4e00-\u9fa5]", answer) 290 | if comparison == "大于": 291 | return len(chinese_character) > number 292 | elif comparison == "等于": 293 | return len(chinese_character) == number 294 | else: 295 | return len(chinese_character) < number 296 | elif len(character_count) > 0: 297 | comparison, number = character_count[0] 298 | number = int(number) 299 | if comparison == "大于": 300 | return len(answer) > number 301 | elif comparison == "等于": 302 | return len(answer) == number 303 | else: 304 | return len(answer) < number 305 | else: 306 | return -1 307 | --------------------------------------------------------------------------------