├── .gitignore ├── README.md ├── ada_leval ├── api.py ├── dataset.py ├── smp.py └── util.py ├── assets ├── AdaLEval.png └── BestAnswer.png ├── fetch_data.sh ├── run.py ├── run.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # Images 156 | images/ 157 | 158 | scripts/*ttf 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Ada-LEval 2 | 3 | **The official implementation of ["Ada-LEval: Evaluating long-context LLMs with length-adaptable benchmarks"](https://arxiv.org/abs/2404.06480)** 4 | 5 |
<img src="assets/AdaLEval.png">
6 | 7 |

8 |
9 | **Ada-LEval** is a pioneering benchmark for assessing the long-context capabilities of LLMs with length-adaptable questions. It comprises two challenging tasks: **TSort**, which involves arranging text segments into the correct order, and **BestAnswer**, which requires choosing the best answer to a question from multiple candidates.
10 |
11 | Both tasks feature the following advantages:
12 | 1. **Controllable Test Cases**: The length of each test case can be finely tuned, by adjusting the number and length of text segments in TSort and by altering the number of distractor options in BestAnswer.
13 | 2. **Necessity for Full-Text Comprehension**: Successful completion of both tasks requires reading and understanding the provided text in full.
14 | 3. **Precise Accuracy Measurement**: The design of these tasks allows for unambiguous accuracy calculation. TSort has a definitive correct order, while in BestAnswer, the answer accepted by the questioner serves as the definitive answer.
15 |
16 |
<img src="assets/BestAnswer.png">
17 | 18 |

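Ada-LEval's length control can be made concrete with a small sketch: a BestAnswer-style case is grown to a target token budget by appending distractor answers. The snippet below is a conceptual illustration only (the `build_case` helper and its arguments are hypothetical, not the repo's actual dataset-construction code); it counts tokens with the same tiktoken tokenizer the repo uses elsewhere.

```python
import tiktoken

ENC = tiktoken.encoding_for_model('gpt-4')

def ntok(text: str) -> int:
    # Token count under the GPT-4 tokenizer, matching the repo's token accounting.
    return len(ENC.encode(text))

def build_case(question: str, accepted: str, distractors: list, budget: int) -> list:
    """Grow a BestAnswer-style case toward `budget` tokens by adding distractors."""
    answers = [accepted]
    used = ntok(question) + ntok(accepted)
    for d in distractors:
        if used + ntok(d) > budget:
            break
        answers.append(d)
        used += ntok(d)
    return answers  # shuffle before prompting so the accepted answer's position varies
```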
19 | 20 | ## 🛠️QuickStart 21 | 22 | In this repo, we implement the evaluation of Ada-LEval on GPT-4-Turbo-0125 (an example for APIs) and internlm2-[7b/20b] (an example for opensource LLMs). You can follow our implementation to evaluate Ada-LEval on your custom LLMs. 23 | 24 | 1. **Preparation** 25 | 26 | 1. Installation and data preparation 27 | 28 | ```bash 29 | cd Ada-LEval 30 | pip install -e . 31 | bash fetch_data.sh 32 | ``` 33 | 34 | 2. For evaluating GPT-4, please set the environment variable: `export OPENAI_API_KEY=sk-xxxxx` 35 | 36 | - Cost Estimation for GPT-4-Turbo-0125: `setting (2k, 4k, etc.) * n_samples * $0.01 / 1000` 37 | 38 | 3. For evaluating InternLM2-7B, please follow the [official guide](https://github.com/InternLM/lmdeploy) to install LMDeploy. 39 | 40 | 2. **Evaluate GPT-4-Turbo-0125**: `python run.py --data {dataset_name} --model gpt-4-0125` 41 | 42 | 3. **Evaluate InternLM2-7B**: `bash run.sh --data {dataset_name} --model internlm2-7b` 43 | 44 | \* `dataset_name` can be `stackselect_{setting}` (for **BestAnswer**) or `textsort_{setting}` (for **TSort**). For example, `stackselect_16k`, `textsort_2k`, etc. 45 | 46 | \** `run.sh` detect the number of available GPUs and do the data parallel. 47 | 48 | ## 📊Evaluation Result 49 | Here is the evaluation result of TSort and BestAnswer benchmark under **long-context** & **ultra-long-context** settings. We also provide a 'random guess' baseline for each task. 50 | 51 | **Definition:** long-context -> context window < 32k; ultra-long-context: context-window >= 32k 52 | 53 | **The Number of Evaluation Samples:** 1. API models on long-context: 200; 2. API models on ultra-long-context: 50; 3. Open-source models on long-context: 1000; 4. Open-source models on ultra-long-context: 200. 54 | 55 | #### TL;DR: 56 | 57 | 1. **TSort is an extremely challenging benchmark:** We observe positive results (significantly better than random guess) only when evaluating SOTA API models (GPT-4 series) under short context settings (< 8k). 58 | 2. **BestAnswer is a challenging long-context benchmark with discrimination:** With 32k long-context, GPT-4-Turbo-0125 still obtains a decent 30% accuracy, while other models significantly lag behind. When the context window is 64k or even longer, models failed to solve almost all of the questions. 59 | 60 | #### TSort Evaluation Results 61 | 62 | Blanks indicate the result under the corresponding setting is not evaluated. 63 | 64 | | TSort | 2k | 4k | 8k | 16k | 32k | 64k | 128k | 65 | | -------------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 66 | | GPT-4-Turbo-0125 | 15.5 | 16.5 | 8.5 | 5.5 | 2.0 | 4.0 | 2.0 | 67 | | GPT-4-Turbo-1106 | 18.5 | 15.5 | 7.5 | 3.5 | 6.0 | 6.0 | 6.0 | 68 | | GPT-3.5-Turbo-1106 | 4.0 | 4.5 | 4.5 | 5.5 | | | | 69 | | Claude-2 | 5.0 | 5.0 | 4.5 | 3.0 | 0.0 | 0.0 | | 70 | | LongChat-7b-v1.5-32k | 5.3 | 5.0 | 3.1 | 2.5 | | | | 71 | | ChatGLM2-6B-32k | 0.9 | 0.7 | 0.2 | 0.9 | | | | 72 | | ChatGLM3-6B-32k | 2.3 | 2.4 | 2.0 | 0.7 | | | | 73 | | Vicuna-7b-v1.5-16k | 5.3 | 2.2 | 2.3 | 1.7 | | | | 74 | | Vicuna-13b-v1.5-16k | 5.4 | 5.0 | 2.4 | 3.1 | | | | 75 | | InternLM2-7b | 5.1 | 3.9 | 5.1 | 4.3 | | | | 76 | | Random Guess | 4.2 | 4.2 | 4.2 | 4.2 | 4.2 | 4.2 | 4.2 | 77 | 78 | #### BestAnswer Evaluation Results 79 | 80 | Blanks indicate the result under the corresponding setting is not evaluated. 
81 | 82 | | BestAnswer | 1k | 2k | 4k | 6k | 8k | 12k | 16k | 32k | 64k | 128k | 83 | | -------------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 84 | | GPT-4-Turbo-0125 | 73.5 | 73.5 | 65.5 | 63.0 | 56.5 | 52.0 | 44.5 | 30.0 | 0.0 | 0.0 | 85 | | GPT-4-Turbo-1106 | 74.0 | 73.5 | 67.5 | 59.5 | 53.5 | 49.5 | 44.0 | 16.0 | 0.0 | 0.0 | 86 | | GPT-3.5-Turbo-1106 | 61.5 | 48.5 | 41.5 | 29.5 | 17.0 | 2.5 | 2.5 | | | | 87 | | Claude-2 | 65.0 | 43.5 | 23.5 | 15.0 | 17.0 | 12.0 | 11.0 | 4.0 | 0.0 | | 88 | | LongChat-7b-v1.5-32k | 32.4 | 10.7 | 5.7 | 3.1 | 1.9 | 1.6 | 0.8 | | | | 89 | | ChatGLM2-6B-32k | 31.2 | 10.9 | 4.5 | 1.6 | 1.6 | 0.0 | 0.3 | | | | 90 | | ChatGLM3-6B-32k | 39.8 | 18.8 | 9.0 | 5.0 | 3.4 | 0.9 | 0.5 | | | | 91 | | Vicuna-7b-v1.5-16k | 37.0 | 11.1 | 5.8 | 3.2 | 1.8 | 1.9 | 1.0 | | | | 92 | | Vicuna-13b-v1.5-16k | 53.4 | 29.2 | 13.1 | 4.3 | 2.2 | 1.4 | 0.9 | | | | 93 | | InternLM2-7b | 58.6 | 49.5 | 33.9 | 12.3 | 13.4 | 2.0 | 0.8 | 0.5 | 0.5 | 0.0 | 94 | | Random Guess | 26.7 | 10.1 | 4.5 | 3.0 | 2.3 | 1.4 | 1.1 | 0.6 | 0.3 | 0.1 | 95 | 96 | ## 🖊️Citation 97 | 98 | ```bib 99 | @inproceedings{wang2024ada, 100 | title={Ada-LEval: Evaluating long-context LLMs with length-adaptable benchmarks}, 101 | author={Wang, Chonghua and Duan, Haodong and Zhang, Songyang and Lin, Dahua and Chen, Kai}, 102 | booktitle={Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)}, 103 | pages={3712--3724}, 104 | year={2024} 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /ada_leval/api.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random as rd 3 | from abc import abstractmethod 4 | from .util import get_logger 5 | from .smp import * 6 | 7 | 8 | class BaseAPI: 9 | 10 | def __init__(self, 11 | retry=10, 12 | wait=3, 13 | system_prompt=None, 14 | verbose=True, 15 | fail_msg='Failed to obtain answer via API.', 16 | **kwargs): 17 | self.wait = wait 18 | self.retry = retry 19 | self.system_prompt = system_prompt 20 | self.kwargs = kwargs 21 | self.verbose = verbose 22 | self.fail_msg = fail_msg 23 | self.logger = get_logger('ChatAPI') 24 | if len(kwargs): 25 | self.logger.info(f'BaseAPI received the following kwargs: {kwargs}') 26 | self.logger.info('Will try to use them as kwargs for `generate`. ') 27 | 28 | @abstractmethod 29 | def generate_inner(self, inputs, **kwargs): 30 | self.logger.warning('For APIBase, generate_inner is an abstract method. 
')
31 |         assert 0, 'generate_inner not defined'
32 |         ret_code, answer, log = None, None, None
33 |         # a ret_code of 0 means success
34 |         return ret_code, answer, log
35 | 
36 |     def working(self):
37 |         retry = 3
38 |         while retry > 0:
39 |             ret = self.generate('hello')
40 |             if ret is not None and ret != '' and self.fail_msg not in ret:
41 |                 return True
42 |             retry -= 1
43 |         return False
44 | 
45 |     def generate(self, inputs, **kwargs):
46 |         input_type = None
47 |         if isinstance(inputs, str):
48 |             input_type = 'str'
49 |         elif isinstance(inputs, list) and isinstance(inputs[0], str):
50 |             input_type = 'strlist'
51 |         elif isinstance(inputs, list) and isinstance(inputs[0], dict):
52 |             input_type = 'dictlist'
53 |         assert input_type is not None, input_type
54 | 
55 |         answer = None
56 |         # a very small random delay [0s - 0.5s]
57 |         T = rd.random() * 0.5
58 |         time.sleep(T)
59 | 
60 |         for i in range(self.retry):
61 |             try:
62 |                 ret_code, answer, log = self.generate_inner(inputs, **kwargs)
63 |                 if ret_code == 0 and self.fail_msg not in answer and answer != '':
64 |                     if self.verbose:
65 |                         print(answer)
66 |                     return answer
67 |                 elif self.verbose:
68 |                     if not isinstance(log, str):
69 |                         try:
70 |                             log = log.text
71 |                         except:
72 |                             self.logger.warning(f'Failed to parse {log} as an http response. ')
73 |                     self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
74 |             except Exception as err:
75 |                 if self.verbose:
76 |                     self.logger.error(f'An error occurred during try {i}:')
77 |                     self.logger.error(err)
78 |             # delay before each retry
79 |             T = rd.random() * self.wait * 2
80 |             time.sleep(T)
81 | 
82 |         return self.fail_msg if answer in ['', None] else answer
83 | 
84 | APIBASES = {
85 |     'OFFICIAL': 'https://api.openai.com/v1/chat/completions',
86 | }
87 | 
88 | 
89 | def GPT_context_window(model):
90 |     length_map = {
91 |         'gpt-4-0125-preview': 128000,
92 |         'gpt-4-1106-preview': 128000,
93 |         'gpt-4-vision-preview': 128000,
94 |         'gpt-4': 8192,
95 |         'gpt-4-32k': 32768,
96 |         'gpt-4-0613': 8192,
97 |         'gpt-4-32k-0613': 32768,
98 |         'gpt-3.5-turbo-1106': 16385,
99 |         'gpt-3.5-turbo': 4096,
100 |         'gpt-3.5-turbo-16k': 16385,
101 |         'gpt-3.5-turbo-instruct': 4096,
102 |         'gpt-3.5-turbo-0613': 4096,
103 |         'gpt-3.5-turbo-16k-0613': 16385,
104 |     }
105 |     if model in length_map:
106 |         return length_map[model]
107 |     else:
108 |         return 4096
109 | 
110 | 
111 | class OpenAIWrapper(BaseAPI):
112 | 
113 |     is_api: bool = True
114 | 
115 |     def __init__(self,
116 |                  model: str = 'gpt-3.5-turbo-0613',
117 |                  retry: int = 5,
118 |                  wait: int = 5,
119 |                  key: str = None,  # falls back to the OPENAI_API_KEY environment variable
120 |                  verbose: bool = True,
121 |                  system_prompt: str = None,
122 |                  temperature: float = 0,
123 |                  timeout: int = 60,
124 |                  api_base: str = 'OFFICIAL',
125 |                  max_tokens: int = 1024,
126 |                  img_size: int = 512,
127 |                  img_detail: str = 'low',
128 |                  **kwargs):
129 | 
130 |         self.model = model
131 |         self.cur_idx = 0
132 |         self.fail_msg = 'Failed to obtain answer via API. 
' 133 | self.max_tokens = max_tokens 134 | self.temperature = temperature 135 | 136 | env_key = os.environ.get('OPENAI_API_KEY', '') 137 | openai_key = env_key if key is None else key 138 | self.openai_key = openai_key 139 | assert img_size > 0 or img_size == -1 140 | self.img_size = img_size 141 | assert img_detail in ['high', 'low'] 142 | self.img_detail = img_detail 143 | 144 | self.vision = False 145 | if model == 'gpt-4-vision-preview': 146 | self.vision = True 147 | self.timeout = timeout 148 | 149 | assert isinstance(openai_key, str) and openai_key.startswith('sk-'), ( 150 | f'Illegal openai_key {openai_key}. ' 151 | 'Please set the environment variable OPENAI_API_KEY to your openai key. ' 152 | ) 153 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs) 154 | 155 | if api_base in APIBASES: 156 | self.api_base = APIBASES[api_base] 157 | elif api_base.startswith('http'): 158 | self.api_base = api_base 159 | else: 160 | self.logger.error('Unknown API Base. ') 161 | sys.exit(-1) 162 | 163 | if 'OPENAI_API_BASE' in os.environ and os.environ['OPENAI_API_BASE'] != '': 164 | self.logger.error('Environment variable OPENAI_API_BASE is set. Will override the api_base arg. ') 165 | self.api_base = os.environ['OPENAI_API_BASE'] 166 | 167 | # inputs can be a lvl-2 nested list: [content1, content2, content3, ...] 168 | # content can be a string or a list of image & text 169 | def prepare_inputs(self, inputs): 170 | input_msgs = [] 171 | if self.system_prompt is not None: 172 | input_msgs.append(dict(role='system', content=self.system_prompt)) 173 | if isinstance(inputs, str): 174 | input_msgs.append(dict(role='user', content=inputs)) 175 | return input_msgs 176 | assert isinstance(inputs, list) 177 | dict_flag = [isinstance(x, dict) for x in inputs] 178 | if np.all(dict_flag): 179 | input_msgs.extend(inputs) 180 | return input_msgs 181 | str_flag = [isinstance(x, str) for x in inputs] 182 | if np.all(str_flag): 183 | img_flag = [x.startswith('http') or osp.exists(x) for x in inputs] 184 | if np.any(img_flag): 185 | content_list = [] 186 | for fl, msg in zip(img_flag, inputs): 187 | if not fl: 188 | content_list.append(dict(type='text', text=msg)) 189 | elif msg.startswith('http'): 190 | content_list.append(dict(type='image_url', image_url={'url': msg, 'detail': self.img_detail})) 191 | elif osp.exists(msg): 192 | from PIL import Image 193 | img = Image.open(msg) 194 | b64 = encode_image_to_base64(img, target_size=self.img_size) 195 | img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail) 196 | content_list.append(dict(type='image_url', image_url=img_struct)) 197 | input_msgs.append(dict(role='user', content=content_list)) 198 | return input_msgs 199 | else: 200 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user'] 201 | roles = roles * len(inputs) 202 | for role, msg in zip(roles, inputs): 203 | input_msgs.append(dict(role=role, content=msg)) 204 | return input_msgs 205 | raise NotImplementedError('list of list prompt not implemented now. 
') 206 | 207 | def generate_inner(self, inputs, **kwargs) -> str: 208 | input_msgs = self.prepare_inputs(inputs) 209 | temperature = kwargs.pop('temperature', self.temperature) 210 | max_tokens = kwargs.pop('max_tokens', self.max_tokens) 211 | 212 | context_window = GPT_context_window(self.model) 213 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs)) 214 | if 0 < max_tokens <= 100: 215 | self.logger.warning( 216 | 'Less than 100 tokens left, ' 217 | 'may exceed the context window with some additional meta symbols. ' 218 | ) 219 | if max_tokens <= 0: 220 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. ' 221 | 222 | headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.openai_key}'} 223 | payload = dict( 224 | model=self.model, 225 | messages=input_msgs, 226 | max_tokens=max_tokens, 227 | n=1, 228 | temperature=temperature, 229 | **kwargs) 230 | response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1) 231 | ret_code = response.status_code 232 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code 233 | answer = self.fail_msg 234 | try: 235 | resp_struct = json.loads(response.text) 236 | answer = resp_struct['choices'][0]['message']['content'].strip() 237 | except: 238 | pass 239 | return ret_code, answer, response 240 | 241 | def get_token_len(self, inputs) -> int: 242 | import tiktoken 243 | try: 244 | enc = tiktoken.encoding_for_model(self.model) 245 | except: 246 | enc = tiktoken.encoding_for_model('gpt-4') 247 | if isinstance(inputs, str): 248 | if inputs.startswith('http') or osp.exists(inputs): 249 | return 65 if self.img_detail == 'low' else 130 250 | else: 251 | return len(enc.encode(inputs)) 252 | elif isinstance(inputs, dict): 253 | assert 'content' in inputs 254 | return self.get_token_len(inputs['content']) 255 | assert isinstance(inputs, list) 256 | res = 0 257 | for item in inputs: 258 | res += self.get_token_len(item) 259 | return res 260 | -------------------------------------------------------------------------------- /ada_leval/dataset.py: -------------------------------------------------------------------------------- 1 | from ada_leval.smp import * 2 | 3 | class StackSelect: 4 | 5 | def __init__(self, setting='1k', mode='normal'): 6 | data = load(f'data/stackselect_{setting}.json') 7 | self.setting = setting 8 | assert mode in ['normal', 'less'] 9 | if mode == 'normal': 10 | num = 1000 if int(setting[:-1]) < 32 else 200 11 | elif mode == 'less': 12 | num = 200 if int(setting[:-1]) < 32 else 50 13 | 14 | if num > 0: 15 | data = data[:num] 16 | for item in data: 17 | item['index'] = f"{item['question_id']}_{item['answer']}" 18 | self.data = data 19 | 20 | self.meta_prompt = """ 21 | You are an AI assistant. Your job is to find out the most helpful answer to a given question. 22 | Each time, you will be provided with a question and n answers to this question. 23 | Each answer begins with an 'A' and a number(e.g. A4), which represents its designation. 24 | You need to determine which answer is the most helpful one to the question. 25 | The case sample is shown below and you should give me the answer in the format exactly the same as the sample. \n 26 | However, you should NOT focus on the content of sample answer. \n 27 | Sample Input (format only): \n 28 | The question is given below. 29 | XXX(The content of question) 30 | Possible answers are given below. 
31 | A1: 32 | XXX(The content of answer 1) 33 | A2: 34 | XXX(The content of answer 2) 35 | . 36 | . 37 | . 38 | An: 39 | XXX(The content of answer n) 40 | Now the answers are over, please decide which answer is the most helpful one to the question. 41 | You must give me only the designation of the MOST helpful answer. 42 | Sample Output (format only): \n 43 | Answer: The designation of the most helpful answer.(e.g. A4 means answer 4 is the most helpful answer) \n\n 44 | """ 45 | 46 | def __len__(self): 47 | return len(self.data) 48 | 49 | def get_meta(self): 50 | res = { 51 | 'index': [x['index'] for x in self.data], 52 | 'question': [x['question'] for x in self.data], 53 | 'answer': [x['answer'] for x in self.data], 54 | 'tags': [x['tags'] for x in self.data], 55 | 'num_choice': [len(x['all_answers']) for x in self.data] 56 | } 57 | return pd.DataFrame(res) 58 | 59 | def build_prompt(self, line): 60 | if isinstance(line, int): 61 | line = self.data[line] 62 | assert isinstance(line, dict) 63 | prompt = self.meta_prompt 64 | prompt += 'The question is given below.\n' 65 | prompt += line['question'] + '\n\n' 66 | prompt += 'Possible answers are given below.\n' 67 | all_answers = line['all_answers'] 68 | for j in range(1, len(all_answers) + 1): 69 | prompt += 'A' + str(j) + ':\n\n' + all_answers[j - 1] + '\n\n' 70 | 71 | prompt += """ 72 | Now the answers are over, please decide which answer is the most helpful one to the question. 73 | You must give me only the designation of the MOST helpful answer. 74 | """ 75 | return prompt 76 | 77 | def evaluate(self, df): 78 | assert 'prediction' in df and 'answer' in df and 'num_choice' in df 79 | 80 | def extract(line): 81 | nc = line['num_choice'] 82 | cands = [f'A{i}' for i in range(1, nc + 1)] 83 | finds = [line['prediction'].find(c) for c in cands] 84 | matched = sum([x >= 0 for x in finds]) 85 | if matched >= 1: 86 | for i in range(nc - 1, -1, -1): 87 | if finds[i] >= 0: 88 | return cands[i] 89 | else: 90 | cands = [str(i) for i in range(1, nc + 1)] 91 | finds = [line['prediction'].find(c) for c in cands] 92 | matched = sum([x >= 0 for x in finds]) 93 | if matched >= 1: 94 | for i in range(nc - 1, -1, -1): 95 | if finds[i] >= 0: 96 | return 'A' + cands[i] 97 | else: 98 | return '???' 
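        # Extraction policy (for reference): the scan above returns the highest-numbered
        # designation that occurs in the prediction, iterating from 'An' down to 'A1', so a
        # prediction containing 'A12' is not mis-read as a match on the substring 'A1'. If no
        # 'A<k>' tag occurs, bare digits are tried the same way; '???' marks an unparseable
        # prediction, which never matches the ground-truth answer and is scored as incorrect.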
99 | 100 | extracted = [extract(df.iloc[i]) for i in range(len(df))] 101 | df['extracted'] = extracted 102 | acc = np.mean([x == y for x, y in zip(df['extracted'], df['answer'])]) 103 | acc = 100 * acc 104 | print(f'StackSelect {self.setting} Accuracy: {acc:.1f}%') 105 | return acc 106 | 107 | 108 | class TextSort: 109 | 110 | def __init__(self, setting='1k', mode='normal'): 111 | data = load(f'data/textsort_{setting}.json') 112 | self.setting = setting 113 | assert mode in ['normal', 'less'] 114 | if mode == 'normal': 115 | num = 1000 if int(setting[:-1]) < 32 else 200 116 | elif mode == 'less': 117 | num = 200 if int(setting[:-1]) < 32 else 50 118 | 119 | if num > 0: 120 | data = data[:num] 121 | for item in data: 122 | book_id = item['book_id'] 123 | para_offset = item['para_offset'] 124 | item['index'] = f"{book_id}_{'_'.join([str(x) for x in para_offset])}" 125 | self.data = data 126 | 127 | def __len__(self): 128 | return len(self.data) 129 | 130 | def get_meta(self): 131 | res = { 132 | 'book_id': [x['book_id'] for x in self.data], 133 | 'para_offset': [x['para_offset'] for x in self.data], 134 | 'answer': [x['answer'] for x in self.data], 135 | 'index': [x['index'] for x in self.data] 136 | } 137 | return pd.DataFrame(res) 138 | 139 | def build_prompt(self, line): 140 | if isinstance(line, int): 141 | line = self.data[line] 142 | assert isinstance(line, dict) 143 | return line['prompt'] 144 | 145 | def evaluate(self, df): 146 | assert 'prediction' in df and 'answer' in df 147 | 148 | def is_subseq(needle, haystack): 149 | current_pos = 0 150 | for c in needle: 151 | current_pos = haystack.find(c, current_pos) + 1 152 | if current_pos == 0: 153 | return False 154 | return True 155 | 156 | def extract(line): 157 | pred = line['prediction'] 158 | if 'Answer:' in pred: 159 | pred = pred.split('Answer:')[1].strip() 160 | try: 161 | pred = json.loads(pred) 162 | return pred 163 | except: 164 | import itertools 165 | perms = list(itertools.permutations(range(1, 5))) 166 | perms = [''.join([str(x) for x in p]) for p in perms] 167 | subseq = [is_subseq(p, pred) for p in perms] 168 | if sum(subseq) == 1: 169 | for p, s in zip(perms, subseq): 170 | if s: 171 | return [int(pp) for pp in p] 172 | return [0, 0, 0, 0] 173 | 174 | extracted = [extract(df.iloc[i]) for i in range(len(df))] 175 | answers = [json.loads(x) if isinstance(x, str) else x for x in df['answer']] 176 | hit, tot = 0, 0 177 | 178 | for a, e in zip(answers, extracted): 179 | tot += 1 180 | flag = True 181 | for aa, ee in zip(a, e): 182 | if aa != ee: 183 | flag = False 184 | hit += flag 185 | 186 | acc = hit / tot 187 | acc = 100 * acc 188 | print(f'TextSort {self.setting} Accuracy: {acc:.1f}%') 189 | return acc -------------------------------------------------------------------------------- /ada_leval/smp.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa: F401, F403 2 | import abc 3 | import argparse 4 | import collections 5 | import csv 6 | import json 7 | import multiprocessing as mp 8 | import numpy as np 9 | import os, sys, time, base64, io 10 | import os.path as osp 11 | import copy as cp 12 | import pickle 13 | import random as rd 14 | import requests 15 | import shutil 16 | import string 17 | import subprocess 18 | import warnings 19 | import pandas as pd 20 | from collections import OrderedDict, defaultdict 21 | from multiprocessing import Pool, current_process 22 | from tqdm import tqdm 23 | from PIL import Image 24 | import uuid 25 | from uuid import uuid4 26 | 
from datetime import datetime 27 | import matplotlib.pyplot as plt 28 | import seaborn as sns 29 | from tabulate import tabulate 30 | 31 | def d2df(D): 32 | return pd.DataFrame({x: [D[x]] for x in D}) 33 | 34 | def LMUDataRoot(): 35 | if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']): 36 | return os.environ['LMUData'] 37 | home = osp.expanduser('~') 38 | root = osp.join(home, 'LMUData') 39 | os.makedirs(root, exist_ok=True) 40 | return root 41 | 42 | def cn_string(s): 43 | import re 44 | if re.search(u'[\u4e00-\u9fff]', s): 45 | return True 46 | return False 47 | 48 | def timestr(second=False, minute=False): 49 | s = datetime.now().strftime('%Y%m%d%H%M%S')[2:] 50 | if second: 51 | return s 52 | elif minute: 53 | return s[:-2] 54 | else: 55 | return s[:-4] 56 | 57 | def num2uuid(num): 58 | rd.seed(num) 59 | return str(uuid.UUID(int=rd.getrandbits(128), version=4)) 60 | 61 | def randomuuid(): 62 | seed = rd.randint(0, 2 ** 32 - 1) 63 | return num2uuid(seed) 64 | 65 | def mrlines(fname, sp='\n'): 66 | f = open(fname).read().split(sp) 67 | while f != [] and f[-1] == '': 68 | f = f[:-1] 69 | return f 70 | 71 | def mwlines(lines, fname): 72 | with open(fname, 'w') as fout: 73 | fout.write('\n'.join(lines)) 74 | 75 | def default_set(self, args, name, default): 76 | if hasattr(args, name): 77 | val = getattr(args, name) 78 | setattr(self, name, val) 79 | else: 80 | setattr(self, name, default) 81 | 82 | def dict_merge(dct, merge_dct): 83 | for k, _ in merge_dct.items(): 84 | if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa 85 | dict_merge(dct[k], merge_dct[k]) 86 | else: 87 | dct[k] = merge_dct[k] 88 | 89 | def youtube_dl(idx): 90 | cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4' 91 | os.system(cmd) 92 | 93 | def run_command(cmd): 94 | if isinstance(cmd, str): 95 | cmd = cmd.split() 96 | return subprocess.check_output(cmd) 97 | 98 | def ls(dirname='.', match='', mode='all', level=1): 99 | if dirname == '.': 100 | ans = os.listdir(dirname) 101 | else: 102 | ans = [osp.join(dirname, x) for x in os.listdir(dirname)] 103 | assert mode in ['all', 'dir', 'file'] 104 | assert level >= 1 and isinstance(level, int) 105 | if level == 1: 106 | ans = [x for x in ans if match in x] 107 | if mode == 'dir': 108 | ans = [x for x in ans if osp.isdir(x)] 109 | elif mode == 'file': 110 | ans = [x for x in ans if not osp.isdir(x)] 111 | else: 112 | ans = [x for x in ans if osp.isdir(x)] 113 | res = [] 114 | for d in ans: 115 | res.extend(ls(d, match=match, mode=mode, level=level-1)) 116 | ans = res 117 | return ans 118 | 119 | def intop(pred, label, n): 120 | pred = [np.argsort(x)[-n:] for x in pred] 121 | hit = [(l in p) for l, p in zip(label, pred)] 122 | return hit 123 | 124 | def topk(score, label, k=1): 125 | return np.mean(intop(score, label, k)) if isinstance(k, int) else [topk(score, label, kk) for kk in k] 126 | 127 | def download_file(url, filename=None): 128 | if filename is None: 129 | filename = url.split('/')[-1] 130 | response = requests.get(url) 131 | open(filename, 'wb').write(response.content) 132 | 133 | def fnp(model, input=None): 134 | from fvcore.nn import FlopCountAnalysis, parameter_count 135 | params = parameter_count(model)[''] 136 | print('Parameter Size: {:.4f} M'.format(params / 1024 / 1024)) 137 | if input is not None: 138 | flops = FlopCountAnalysis(model, input).total() 139 | print('FLOPs: {:.4f} G'.format(flops / 1024 / 1024 / 1024)) 140 | return params, flops 141 | return params, None 142 | 143 | def proxy_set(s): 144 | 
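    # Route HTTP and HTTPS traffic through the proxy URL `s` by setting all four
    # standard proxy environment variables (lower- and upper-case variants).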
import os 145 | for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']: 146 | os.environ[key] = s 147 | 148 | # LOAD & DUMP 149 | def dump(data, f, **kwargs): 150 | def dump_pkl(data, pth, **kwargs): 151 | pickle.dump(data, open(pth, 'wb')) 152 | 153 | def dump_json(data, pth, **kwargs): 154 | json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False) 155 | 156 | def dump_jsonl(data, f, **kwargs): 157 | lines = [json.dumps(x, ensure_ascii=False) for x in data] 158 | with open(f, 'w', encoding='utf8') as fout: 159 | fout.write('\n'.join(lines)) 160 | 161 | def dump_xlsx(data, f, **kwargs): 162 | data.to_excel(f, index=False) 163 | 164 | def dump_csv(data, f, quoting=csv.QUOTE_MINIMAL): 165 | data.to_csv(f, index=False, encoding='utf-8', quoting=quoting) 166 | 167 | def dump_tsv(data, f, quoting=csv.QUOTE_MINIMAL): 168 | data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting) 169 | 170 | handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv) 171 | suffix = f.split('.')[-1] 172 | return handlers[suffix](data, f, **kwargs) 173 | 174 | import portalocker 175 | def safe_dump(data, f, **kwargs): 176 | with portalocker.Lock(f, timeout=5) as fh: 177 | dump(data, f, **kwargs) 178 | fh.flush() 179 | os.fsync(fh.fileno()) 180 | 181 | def load(f): 182 | def load_pkl(pth): 183 | return pickle.load(open(pth, 'rb')) 184 | 185 | def load_json(pth): 186 | return json.load(open(pth, 'r', encoding='utf-8')) 187 | 188 | def load_jsonl(f): 189 | lines = open(f, encoding='utf-8').readlines() 190 | lines = [x.strip() for x in lines] 191 | if lines[-1] == '': 192 | lines = lines[:-1] 193 | data = [json.loads(x) for x in lines] 194 | return data 195 | 196 | def load_xlsx(f): 197 | return pd.read_excel(f) 198 | 199 | def load_csv(f): 200 | return pd.read_csv(f) 201 | 202 | def load_tsv(f): 203 | return pd.read_csv(f, sep='\t') 204 | 205 | handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv) 206 | suffix = f.split('.')[-1] 207 | return handlers[suffix](f) -------------------------------------------------------------------------------- /ada_leval/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger_initialized = {} 4 | 5 | 6 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 7 | logger = logging.getLogger(name) 8 | if name in logger_initialized: 9 | return logger 10 | 11 | for logger_name in logger_initialized: 12 | if name.startswith(logger_name): 13 | return logger 14 | 15 | stream_handler = logging.StreamHandler() 16 | handlers = [stream_handler] 17 | 18 | try: 19 | import torch.distributed as dist 20 | if dist.is_available() and dist.is_initialized(): 21 | rank = dist.get_rank() 22 | else: 23 | rank = 0 24 | except ImportError: 25 | rank = 0 26 | 27 | if rank == 0 and log_file is not None: 28 | file_handler = logging.FileHandler(log_file, file_mode) 29 | handlers.append(file_handler) 30 | 31 | formatter = logging.Formatter( 32 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 33 | for handler in handlers: 34 | handler.setFormatter(formatter) 35 | handler.setLevel(log_level) 36 | logger.addHandler(handler) 37 | 38 | if rank == 0: 39 | logger.setLevel(log_level) 40 | else: 41 | logger.setLevel(logging.ERROR) 42 | 43 | logger_initialized[name] = True 44 | return logger 45 | 46 | 47 | from multiprocessing import Pool 48 | import os 49 | from typing import Callable, 
Iterable, Sized 50 | 51 | from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, 52 | TaskProgressColumn, TextColumn, TimeRemainingColumn) 53 | from rich.text import Text 54 | import os.path as osp 55 | import portalocker 56 | from .smp import load, dump 57 | 58 | 59 | class _Worker: 60 | """Function wrapper for ``track_progress_rich``""" 61 | 62 | def __init__(self, func) -> None: 63 | self.func = func 64 | 65 | def __call__(self, inputs): 66 | inputs, idx = inputs 67 | if not isinstance(inputs, (tuple, list, dict)): 68 | inputs = (inputs, ) 69 | 70 | if isinstance(inputs, dict): 71 | return self.func(**inputs), idx 72 | else: 73 | return self.func(*inputs), idx 74 | 75 | 76 | class _SkipFirstTimeRemainingColumn(TimeRemainingColumn): 77 | """Skip calculating remaining time for the first few times. 78 | 79 | Args: 80 | skip_times (int): The number of times to skip. Defaults to 0. 81 | """ 82 | 83 | def __init__(self, *args, skip_times=0, **kwargs): 84 | super().__init__(*args, **kwargs) 85 | self.skip_times = skip_times 86 | 87 | def render(self, task: Task) -> Text: 88 | """Show time remaining.""" 89 | if task.completed <= self.skip_times: 90 | return Text('-:--:--', style='progress.remaining') 91 | return super().render(task) 92 | 93 | 94 | def _tasks_with_index(tasks): 95 | """Add index to tasks.""" 96 | for idx, task in enumerate(tasks): 97 | yield task, idx 98 | 99 | 100 | def track_progress_rich(func: Callable, 101 | tasks: Iterable = tuple(), 102 | task_num: int = None, 103 | nproc: int = 1, 104 | chunksize: int = 1, 105 | description: str = 'Processing', 106 | save=None, keys=None, 107 | color: str = 'blue') -> list: 108 | """Track the progress of parallel task execution with a progress bar. The 109 | built-in :mod:`multiprocessing` module is used for process pools and tasks 110 | are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. 111 | 112 | Args: 113 | func (callable): The function to be applied to each task. 114 | tasks (Iterable or Sized): A tuple of tasks. There are several cases 115 | for different format tasks: 116 | - When ``func`` accepts no arguments: tasks should be an empty 117 | tuple, and ``task_num`` must be specified. 118 | - When ``func`` accepts only one argument: tasks should be a tuple 119 | containing the argument. 120 | - When ``func`` accepts multiple arguments: tasks should be a 121 | tuple, with each element representing a set of arguments. 122 | If an element is a ``dict``, it will be parsed as a set of 123 | keyword-only arguments. 124 | Defaults to an empty tuple. 125 | task_num (int, optional): If ``tasks`` is an iterator which does not 126 | have length, the number of tasks can be provided by ``task_num``. 127 | Defaults to None. 128 | nproc (int): Process (worker) number, if nuproc is 1, 129 | use single process. Defaults to 1. 130 | chunksize (int): Refer to :class:`multiprocessing.Pool` for details. 131 | Defaults to 1. 132 | description (str): The description of progress bar. 133 | Defaults to "Process". 134 | color (str): The color of progress bar. Defaults to "blue". 135 | 136 | Examples: 137 | >>> import time 138 | 139 | >>> def func(x): 140 | ... time.sleep(1) 141 | ... return x**2 142 | >>> track_progress_rich(func, range(10), nproc=2) 143 | 144 | Returns: 145 | list: The task results. 
146 |     """
147 |     if save is not None:
148 |         assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
149 |         if not osp.exists(save):
150 |             dump({}, save)
151 |         if keys is not None:
152 |             assert len(keys) == len(tasks)
153 | 
154 |     if not callable(func):
155 |         raise TypeError('func must be a callable object')
156 |     if not isinstance(tasks, Iterable):
157 |         raise TypeError(
158 |             f'tasks must be an iterable object, but got {type(tasks)}')
159 |     if isinstance(tasks, Sized):
160 |         if len(tasks) == 0:
161 |             if task_num is None:
162 |                 raise ValueError('If tasks is an empty iterable, '
163 |                                  'task_num must be set')
164 |             else:
165 |                 tasks = tuple(tuple() for _ in range(task_num))
166 |         else:
167 |             if task_num is not None and task_num != len(tasks):
168 |                 raise ValueError('task_num does not match the length of tasks')
169 |             task_num = len(tasks)
170 | 
171 |     if nproc <= 0:
172 |         raise ValueError('nproc must be a positive number')
173 | 
174 |     skip_times = nproc * chunksize if nproc > 1 else 0
175 |     prog_bar = Progress(
176 |         TextColumn('{task.description}'),
177 |         BarColumn(),
178 |         _SkipFirstTimeRemainingColumn(skip_times=skip_times),
179 |         MofNCompleteColumn(),
180 |         TaskProgressColumn(show_speed=True),
181 |     )
182 | 
183 |     worker = _Worker(func)
184 |     task_id = prog_bar.add_task(
185 |         total=task_num, color=color, description=description)
186 |     tasks = _tasks_with_index(tasks)
187 | 
188 |     # Use single process when nproc is 1, else use multiprocess.
189 |     with prog_bar:
190 |         if nproc == 1:
191 |             results = []
192 |             for task in tasks:
193 |                 result, idx = worker(task)
194 |                 results.append(result)  # reuse the result computed above; calling worker(task) again would run func a second time
195 |                 if save is not None:
196 |                     with portalocker.Lock(save, timeout=5) as fh:
197 |                         ans = load(save)
198 |                         ans[keys[idx]] = result
199 | 
200 |                         if os.environ.get('VERBOSE', True):
201 |                             print(keys[idx], result, flush=True)
202 | 
203 |                         dump(ans, save)
204 |                         fh.flush()
205 |                         os.fsync(fh.fileno())
206 | 
207 |                 prog_bar.update(task_id, advance=1, refresh=True)
208 |         else:
209 |             with Pool(nproc) as pool:
210 |                 results = []
211 |                 unordered_results = []
212 |                 gen = pool.imap_unordered(worker, tasks, chunksize)
213 |                 try:
214 |                     for result in gen:
215 |                         result, idx = result
216 |                         unordered_results.append((result, idx))
217 | 
218 |                         if save is not None:
219 |                             with portalocker.Lock(save, timeout=5) as fh:
220 |                                 ans = load(save)
221 |                                 ans[keys[idx]] = result
222 | 
223 |                                 if os.environ.get('VERBOSE', False):
224 |                                     print(keys[idx], result, flush=True)
225 | 
226 |                                 dump(ans, save)
227 |                                 fh.flush()
228 |                                 os.fsync(fh.fileno())
229 | 
230 |                         results.append(None)
231 |                         prog_bar.update(task_id, advance=1, refresh=True)
232 |                 except Exception as e:
233 |                     prog_bar.stop()
234 |                     raise e
235 |                 for result, idx in unordered_results:
236 |                     results[idx] = result
237 |     return results
238 | 
239 | 
240 | def get_rank_and_world_size():
241 |     local_rank = int(os.environ.get('LOCAL_RANK', 0))
242 |     world_size = int(os.environ.get('WORLD_SIZE', 1))
243 |     return local_rank, world_size
--------------------------------------------------------------------------------
/assets/AdaLEval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-compass/Ada-LEval/2154258d5fa3969ac5429b3132d505570ef8a57a/assets/AdaLEval.png
--------------------------------------------------------------------------------
/assets/BestAnswer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-compass/Ada-LEval/2154258d5fa3969ac5429b3132d505570ef8a57a/assets/BestAnswer.png -------------------------------------------------------------------------------- /fetch_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | mkdir -p data 3 | for i in 1k 2k 4k 6k 8k 12k 16k 32k 64k 128k 4 | do 5 | wget http://opencompass.openxlab.space/utils/AdaLEval/stackselect_$i.json -O data/stackselect_$i.json;\ 6 | done 7 | for i in 1k 2k 4k 8k 16k 32k 64k 128k 8 | do 9 | wget http://opencompass.openxlab.space/utils/AdaLEval/textsort_$i.json -O data/textsort_$i.json;\ 10 | done 11 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from ada_leval.smp import * 2 | from ada_leval.util import * 3 | from ada_leval.api import OpenAIWrapper 4 | from ada_leval.dataset import StackSelect, TextSort 5 | 6 | RESULT_FILE = 'result.json' 7 | if not osp.exists(RESULT_FILE): 8 | dump({}, RESULT_FILE) 9 | 10 | settings = ['1k', '2k', '4k', '8k', '16k', '32k', '64k', '128k'] 11 | datasets = [f'stackselect_{k}' for k in settings + ['6k', '12k']] + [f'textsort_{k}' for k in settings] 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--data', type=str, nargs='+', required=True, choices=datasets) 16 | parser.add_argument('--model', type=str, required=True, choices=['internlm2-7b', 'internlm2-20b', 'gpt-4-0125']) 17 | parser.add_argument('--mode', type=str, default='all', choices=['infer', 'all']) 18 | parser.add_argument('--nproc', type=int, default=4) 19 | args = parser.parse_args() 20 | return args 21 | 22 | def build_model(m, setting=None): 23 | if 'internlm2' in m: 24 | session_len = 160000 25 | from lmdeploy import pipeline, TurbomindEngineConfig 26 | backend_config = TurbomindEngineConfig(rope_scaling_factor=2.0, session_len=session_len) 27 | 28 | if m == 'gpt-4-0125': 29 | model = OpenAIWrapper('gpt-4-0125-preview') 30 | elif m == 'internlm2-7b': 31 | model = pipeline('internlm/internlm2-chat-7b', backend_config=backend_config) 32 | elif m == 'internlm2-20b': 33 | model = pipeline('internlm/internlm2-chat-20b', backend_config=backend_config) 34 | return model 35 | 36 | import tiktoken 37 | ENC = tiktoken.encoding_for_model('gpt-4') 38 | 39 | def get_token_length(prompt): 40 | return len(ENC.encode(prompt)) 41 | 42 | def main(): 43 | rank, world_size = get_rank_and_world_size() 44 | if world_size > 1: 45 | import torch 46 | import torch.distributed as dist 47 | torch.cuda.set_device(rank) 48 | dist.init_process_group(backend='nccl') 49 | 50 | if rank == 0: 51 | os.makedirs('results', exist_ok=True) 52 | 53 | args = parse_args() 54 | model_name = args.model 55 | model = build_model(args.model) 56 | for dname in args.data: 57 | d, setting = dname.split('_') 58 | dataset_mode = 'less' if getattr(model, 'is_api', False) else 'normal' 59 | 60 | if d == 'stackselect': 61 | dataset = StackSelect(setting=setting, mode=dataset_mode) 62 | elif d == 'textsort': 63 | dataset = TextSort(setting=setting, mode=dataset_mode) 64 | 65 | lt = len(dataset) 66 | prompts = [dataset.build_prompt(i) for i in range(lt)] 67 | meta = dataset.get_meta() 68 | indices = list(meta['index']) 69 | 70 | out_file = f'results/{model_name}_{dname}.pkl' 71 | res = {} if not osp.exists(out_file) else load(out_file) 72 | tups = [(i, p) for i, p in zip(indices, prompts) if i not in res] 73 | 74 | if 
len(tups): 75 | if getattr(model, 'is_api', False): 76 | res = track_progress_rich( 77 | model.generate, 78 | [x[1] for x in tups], 79 | nproc=args.nproc, 80 | chunksize=args.nproc, 81 | save=out_file, 82 | keys=[x[0] for x in tups]) 83 | else: 84 | sub_tups = tups[rank::world_size] 85 | sub_out_file = f'results/{model_name}_{dname}_{rank}.pkl' 86 | sub_res = {} 87 | import torch 88 | with torch.no_grad(): 89 | for t in tqdm(sub_tups): 90 | index, prompt = t 91 | sub_res[index] = model(prompt).text 92 | dump(sub_res, sub_out_file) 93 | 94 | if world_size > 1: 95 | dist.barrier() 96 | 97 | if rank == 0: 98 | res = {} 99 | for i in range(world_size): 100 | sub_out_file = f'results/{model_name}_{dname}_{i}.pkl' 101 | if osp.exists(sub_out_file): 102 | res.update(load(sub_out_file)) 103 | if osp.exists(out_file): 104 | res.update(load(out_file)) 105 | dump(res, out_file) 106 | 107 | res = load(out_file) 108 | meta['prediction'] = [res[k] for k in meta['index']] 109 | dump(meta, f'results/{model_name}_{dname}.xlsx') 110 | 111 | if args.mode == 'all': 112 | results = load(RESULT_FILE) 113 | acc = dataset.evaluate(meta) 114 | results[f'{model_name}_{dname}'] = acc 115 | dump(results, RESULT_FILE) 116 | 117 | if world_size > 1: 118 | dist.barrier() 119 | os.system(f"rm {f'results/{model_name}_{dname}_{rank}.pkl'}") 120 | 121 | if __name__ == '__main__': 122 | main() -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -x 3 | export GPU=$(nvidia-smi --list-gpus | wc -l) 4 | torchrun --nproc-per-node=$GPU run.py ${@:1} -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | REQUIRES = """ 4 | numpy>=1.23.4 5 | openai 6 | requests 7 | tqdm 8 | pandas>=1.5.3 9 | tiktoken 10 | rich 11 | portalocker 12 | pillow 13 | matplotlib 14 | seaborn 15 | tabulate 16 | """ 17 | 18 | 19 | def get_install_requires(): 20 | reqs = [req for req in REQUIRES.split('\n') if len(req) > 0] 21 | return reqs 22 | 23 | 24 | with open('README.md') as f: 25 | readme = f.read() 26 | 27 | 28 | def do_setup(): 29 | setup( 30 | name='ada_leval', 31 | version='0.1.0', 32 | description='The official implementation for Ada-LEval', 33 | # url="", 34 | author="Haodong Duan", 35 | long_description=readme, 36 | long_description_content_type='text/markdown', 37 | cmdclass={}, 38 | install_requires=get_install_requires(), 39 | setup_requires=[], 40 | python_requires='>=3.7.0', 41 | packages=find_packages(exclude=[ 42 | 'test*', 43 | 'paper_test*', 44 | ]), 45 | keywords=['AI', 'NLP', 'in-context learning'], 46 | entry_points={ 47 | "console_scripts": [] 48 | }, 49 | classifiers=[ 50 | 'Programming Language :: Python :: 3.7', 51 | 'Programming Language :: Python :: 3.8', 52 | 'Programming Language :: Python :: 3.9', 53 | 'Programming Language :: Python :: 3.10', 54 | 'Intended Audience :: Developers', 55 | 'Intended Audience :: Education', 56 | 'Intended Audience :: Science/Research', 57 | ]) 58 | 59 | 60 | if __name__ == '__main__': 61 | do_setup() 62 | --------------------------------------------------------------------------------
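As a closing note on the custom-LLM workflow mentioned in the README: the evaluation loop only needs an object that maps a prompt string to a response string. The sketch below (with a hypothetical `MyModel`; it assumes `pip install -e .` and `bash fetch_data.sh` have been run) drives the dataset classes directly, without going through `run.py`:

```python
from ada_leval.dataset import StackSelect

class MyModel:
    is_api = False  # run.py checks this flag to choose between the API and local inference paths

    def generate(self, prompt: str) -> str:
        # Replace with a real call into your model or serving stack.
        raise NotImplementedError

model = MyModel()
dataset = StackSelect(setting='2k', mode='less')
meta = dataset.get_meta()  # columns: index / question / answer / tags / num_choice
meta['prediction'] = [model.generate(dataset.build_prompt(i)) for i in range(len(dataset))]
dataset.evaluate(meta)  # prints and returns accuracy (%)
```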