├── .gitignore
├── README.md
├── ada_leval
│   ├── api.py
│   ├── dataset.py
│   ├── smp.py
│   └── util.py
├── assets
│   ├── AdaLEval.png
│   └── BestAnswer.png
├── fetch_data.sh
├── run.py
├── run.sh
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # Images
156 | images/
157 |
158 | scripts/*ttf
159 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ada-LEval
2 |
3 | **The official implementation of ["Ada-LEval: Evaluating long-context LLMs with length-adaptable benchmarks"](https://arxiv.org/abs/2404.06480)**
4 |
5 | 
6 | <div align="center"><img src="assets/AdaLEval.png" width="80%"/></div>
7 | 
8 | 
9 | **Ada-LEval** is a pioneering benchmark for assessing the long-context capabilities of LLMs with length-adaptable questions. It comprises two challenging tasks: **TSort**, which involves arranging text segments into the correct order, and **BestAnswer**, which requires choosing the best answer to a question from among multiple candidates.
10 |
11 | Both tasks feature the following advantages:
12 | 1. **Controllable Test Cases**: The length of each test case can be finely tuned by adjusting the number and length of text segments in TSort and altering the number of distractor options in BestAnswer.
13 | 2. **Necessity for Full-Text Comprehension**: Successful completion of both tasks mandates complete reading and understanding of the provided text.
14 | 3. **Precise Accuracy Measurement**: The design of these tasks allows for unambiguous accuracy calculation. TSort has a definitive 'correct' order, while in BestAnswer, the annotated responses by the questioner serve as definitive answers.
15 |
16 | 
17 | <div align="center"><img src="assets/BestAnswer.png" width="80%"/></div>
18 | 
19 | 
20 | ## 🛠️QuickStart
21 |
22 | In this repo, we implement the evaluation of Ada-LEval on GPT-4-Turbo-0125 (as an example of API models) and InternLM2-[7B/20B] (as an example of open-source LLMs). You can follow our implementation to evaluate Ada-LEval on your own custom LLMs (see the sketch at the end of this section).
23 |
24 | 1. **Preparation**
25 |
26 | 1. Installation and data preparation
27 |
28 | ```bash
29 | cd Ada-LEval
30 | pip install -e .
31 | bash fetch_data.sh
32 | ```
33 |
34 | 2. For evaluating GPT-4, please set the environment variable: `export OPENAI_API_KEY=sk-xxxxx`
35 |
36 | - Cost Estimation for GPT-4-Turbo-0125: `setting (2k, 4k, etc.) * n_samples * $0.01 / 1000`
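      - For example, evaluating the 16k setting on 200 samples costs roughly 16000 * 200 * $0.01 / 1000 = $32 in prompt tokens (completion tokens add a small extra).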
37 |
38 | 3. For evaluating InternLM2-7B, please follow the [official guide](https://github.com/InternLM/lmdeploy) to install LMDeploy.
39 |
40 | 2. **Evaluate GPT-4-Turbo-0125**: `python run.py --data {dataset_name} --model gpt-4-0125`
41 |
42 | 3. **Evaluate InternLM2-7B**: `bash run.sh --data {dataset_name} --model internlm2-7b`
43 |
44 | \* `dataset_name` can be `stackselect_{setting}` (for **BestAnswer**) or `textsort_{setting}` (for **TSort**). For example, `stackselect_16k`, `textsort_2k`, etc.
45 |
46 | \** `run.sh` detects the number of available GPUs and runs data-parallel inference (one worker process per GPU).
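
To evaluate a custom LLM, mirror what `run.py` expects: API-style models expose a `generate(prompt) -> str` method plus an `is_api = True` attribute (which selects the smaller `less` split and parallel inference via `track_progress_rich`), while local pipelines are called as `model(prompt).text`. Below is a minimal sketch with a hypothetical stand-in model (`EchoModel` is illustrative, not part of this repo):

```python
from ada_leval.dataset import StackSelect

class EchoModel:
    """Stand-in for your own LLM client; only `generate` is required."""
    is_api = True  # run.py uses this flag to pick the 'less' split and parallel inference

    def generate(self, prompt: str) -> str:
        # Replace this with a real completion call; this dummy always answers 'A1'.
        return 'Answer: A1'

dataset = StackSelect(setting='2k', mode='less')  # requires `bash fetch_data.sh` first
model = EchoModel()
print(model.generate(dataset.build_prompt(0)))
```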
47 |
48 | ## 📊Evaluation Result
49 | Here are the evaluation results of the TSort and BestAnswer benchmarks under the **long-context** & **ultra-long-context** settings. We also provide a 'random guess' baseline for each task.
50 |
51 | **Definition:** long-context means a context window < 32k; ultra-long-context means a context window >= 32k.
52 |
53 | **The Number of Evaluation Samples:** API models: 200 (long-context) / 50 (ultra-long-context); open-source models: 1000 (long-context) / 200 (ultra-long-context).
54 |
55 | #### TL;DR:
56 |
57 | 1. **TSort is an extremely challenging benchmark:** We observe positive results (significantly better than random guessing) only when evaluating SOTA API models (GPT-4 series) under short context settings (< 8k).
58 | 2. **BestAnswer is a challenging long-context benchmark with discrimination:** With a 32k context, GPT-4-Turbo-0125 still obtains a decent 30% accuracy, while other models lag significantly behind. When the context window reaches 64k or longer, models fail to solve almost all of the questions.
59 |
60 | #### TSort Evaluation Results
61 |
62 | Blank cells indicate that the corresponding setting was not evaluated.
63 |
64 | | TSort | 2k | 4k | 8k | 16k | 32k | 64k | 128k |
65 | | -------------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
66 | | GPT-4-Turbo-0125 | 15.5 | 16.5 | 8.5 | 5.5 | 2.0 | 4.0 | 2.0 |
67 | | GPT-4-Turbo-1106 | 18.5 | 15.5 | 7.5 | 3.5 | 6.0 | 6.0 | 6.0 |
68 | | GPT-3.5-Turbo-1106 | 4.0 | 4.5 | 4.5 | 5.5 | | | |
69 | | Claude-2 | 5.0 | 5.0 | 4.5 | 3.0 | 0.0 | 0.0 | |
70 | | LongChat-7b-v1.5-32k | 5.3 | 5.0 | 3.1 | 2.5 | | | |
71 | | ChatGLM2-6B-32k | 0.9 | 0.7 | 0.2 | 0.9 | | | |
72 | | ChatGLM3-6B-32k | 2.3 | 2.4 | 2.0 | 0.7 | | | |
73 | | Vicuna-7b-v1.5-16k | 5.3 | 2.2 | 2.3 | 1.7 | | | |
74 | | Vicuna-13b-v1.5-16k | 5.4 | 5.0 | 2.4 | 3.1 | | | |
75 | | InternLM2-7b | 5.1 | 3.9 | 5.1 | 4.3 | | | |
76 | | Random Guess | 4.2 | 4.2 | 4.2 | 4.2 | 4.2 | 4.2 | 4.2 |
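
\* Each TSort test case shuffles 4 text segments, so random guessing succeeds with probability 1/4! ≈ 4.2%.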
77 |
78 | #### BestAnswer Evaluation Results
79 |
80 | Blank cells indicate that the corresponding setting was not evaluated.
81 |
82 | | BestAnswer | 1k | 2k | 4k | 6k | 8k | 12k | 16k | 32k | 64k | 128k |
83 | | -------------------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
84 | | GPT-4-Turbo-0125 | 73.5 | 73.5 | 65.5 | 63.0 | 56.5 | 52.0 | 44.5 | 30.0 | 0.0 | 0.0 |
85 | | GPT-4-Turbo-1106 | 74.0 | 73.5 | 67.5 | 59.5 | 53.5 | 49.5 | 44.0 | 16.0 | 0.0 | 0.0 |
86 | | GPT-3.5-Turbo-1106 | 61.5 | 48.5 | 41.5 | 29.5 | 17.0 | 2.5 | 2.5 | | | |
87 | | Claude-2 | 65.0 | 43.5 | 23.5 | 15.0 | 17.0 | 12.0 | 11.0 | 4.0 | 0.0 | |
88 | | LongChat-7b-v1.5-32k | 32.4 | 10.7 | 5.7 | 3.1 | 1.9 | 1.6 | 0.8 | | | |
89 | | ChatGLM2-6B-32k | 31.2 | 10.9 | 4.5 | 1.6 | 1.6 | 0.0 | 0.3 | | | |
90 | | ChatGLM3-6B-32k | 39.8 | 18.8 | 9.0 | 5.0 | 3.4 | 0.9 | 0.5 | | | |
91 | | Vicuna-7b-v1.5-16k | 37.0 | 11.1 | 5.8 | 3.2 | 1.8 | 1.9 | 1.0 | | | |
92 | | Vicuna-13b-v1.5-16k | 53.4 | 29.2 | 13.1 | 4.3 | 2.2 | 1.4 | 0.9 | | | |
93 | | InternLM2-7b | 58.6 | 49.5 | 33.9 | 12.3 | 13.4 | 2.0 | 0.8 | 0.5 | 0.5 | 0.0 |
94 | | Random Guess | 26.7 | 10.1 | 4.5 | 3.0 | 2.3 | 1.4 | 1.1 | 0.6 | 0.3 | 0.1 |
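
\* The BestAnswer random-guess baseline is the expected accuracy of a uniform guess over the candidate answers (the mean of 1/num_choice across the test set); it drops with length because longer settings include more distractor answers.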
95 |
96 | ## 🖊️Citation
97 |
98 | ```bib
99 | @inproceedings{wang2024ada,
100 | title={Ada-LEval: Evaluating long-context LLMs with length-adaptable benchmarks},
101 | author={Wang, Chonghua and Duan, Haodong and Zhang, Songyang and Lin, Dahua and Chen, Kai},
102 | booktitle={Proceedings of the 2024 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers)},
103 | pages={3712--3724},
104 | year={2024}
105 | }
106 | ```
107 |
--------------------------------------------------------------------------------
/ada_leval/api.py:
--------------------------------------------------------------------------------
1 | import time
2 | import random as rd
3 | from abc import abstractmethod
4 | from .util import get_logger
5 | from .smp import *
6 |
7 |
8 | class BaseAPI:
9 |
10 | def __init__(self,
11 | retry=10,
12 | wait=3,
13 | system_prompt=None,
14 | verbose=True,
15 | fail_msg='Failed to obtain answer via API.',
16 | **kwargs):
17 | self.wait = wait
18 | self.retry = retry
19 | self.system_prompt = system_prompt
20 | self.kwargs = kwargs
21 | self.verbose = verbose
22 | self.fail_msg = fail_msg
23 | self.logger = get_logger('ChatAPI')
24 | if len(kwargs):
25 | self.logger.info(f'BaseAPI received the following kwargs: {kwargs}')
26 | self.logger.info('Will try to use them as kwargs for `generate`. ')
27 |
28 | @abstractmethod
29 | def generate_inner(self, inputs, **kwargs):
30 |         self.logger.warning('For BaseAPI, generate_inner is an abstract method. ')
31 | assert 0, 'generate_inner not defined'
32 | ret_code, answer, log = None, None, None
33 | # if ret_code is 0, means succeed
34 | return ret_code, answer, log
35 |
36 | def working(self):
37 | retry = 3
38 | while retry > 0:
39 | ret = self.generate('hello')
40 | if ret is not None and ret != '' and self.fail_msg not in ret:
41 | return True
42 | retry -= 1
43 | return False
44 |
45 | def generate(self, inputs, **kwargs):
46 | input_type = None
47 | if isinstance(inputs, str):
48 | input_type = 'str'
49 | elif isinstance(inputs, list) and isinstance(inputs[0], str):
50 | input_type = 'strlist'
51 | elif isinstance(inputs, list) and isinstance(inputs[0], dict):
52 | input_type = 'dictlist'
53 | assert input_type is not None, input_type
54 |
55 | answer = None
56 | # a very small random delay [0s - 0.5s]
57 | T = rd.random() * 0.5
58 | time.sleep(T)
59 |
60 | for i in range(self.retry):
61 | try:
62 | ret_code, answer, log = self.generate_inner(inputs, **kwargs)
63 | if ret_code == 0 and self.fail_msg not in answer and answer != '':
64 | if self.verbose:
65 | print(answer)
66 | return answer
67 | elif self.verbose:
68 | if not isinstance(log, str):
69 | try:
70 | log = log.text
71 | except:
72 | self.logger.warning(f'Failed to parse {log} as an http response. ')
73 | self.logger.info(f'RetCode: {ret_code}\nAnswer: {answer}\nLog: {log}')
74 | except Exception as err:
75 | if self.verbose:
76 |                     self.logger.error(f'An error occurred during try {i}:')
77 | self.logger.error(err)
78 | # delay before each retry
79 | T = rd.random() * self.wait * 2
80 | time.sleep(T)
81 |
82 | return self.fail_msg if answer in ['', None] else answer
83 |
84 | APIBASES = {
85 | 'OFFICIAL': 'https://api.openai.com/v1/chat/completions',
86 | }
87 |
88 |
89 | def GPT_context_window(model):
90 | length_map = {
91 | 'gpt-4-0125-preview': 128000,
92 | 'gpt-4-1106-preview': 128000,
93 | 'gpt-4-vision-preview': 128000,
94 | 'gpt-4': 8192,
95 | 'gpt-4-32k': 32768,
96 | 'gpt-4-0613': 8192,
97 | 'gpt-4-32k-0613': 32768,
98 | 'gpt-3.5-turbo-1106': 16385,
99 | 'gpt-3.5-turbo': 4096,
100 | 'gpt-3.5-turbo-16k': 16385,
101 | 'gpt-3.5-turbo-instruct': 4096,
102 | 'gpt-3.5-turbo-0613': 4096,
103 | 'gpt-3.5-turbo-16k-0613': 16385,
104 | }
105 | if model in length_map:
106 | return length_map[model]
107 | else:
108 | return 4096
109 |
110 |
111 | class OpenAIWrapper(BaseAPI):
112 |
113 | is_api: bool = True
114 |
115 | def __init__(self,
116 | model: str = 'gpt-3.5-turbo-0613',
117 | retry: int = 5,
118 | wait: int = 5,
119 |                  key: str = None,
120 | verbose: bool = True,
121 | system_prompt: str = None,
122 | temperature: float = 0,
123 | timeout: int = 60,
124 |                  api_base: str = 'OFFICIAL',
125 | max_tokens: int = 1024,
126 | img_size: int = 512,
127 | img_detail: str = 'low',
128 | **kwargs):
129 |
130 | self.model = model
131 | self.cur_idx = 0
132 | self.fail_msg = 'Failed to obtain answer via API. '
133 | self.max_tokens = max_tokens
134 | self.temperature = temperature
135 |
136 | env_key = os.environ.get('OPENAI_API_KEY', '')
137 | openai_key = env_key if key is None else key
138 | self.openai_key = openai_key
139 | assert img_size > 0 or img_size == -1
140 | self.img_size = img_size
141 | assert img_detail in ['high', 'low']
142 | self.img_detail = img_detail
143 |
144 | self.vision = False
145 | if model == 'gpt-4-vision-preview':
146 | self.vision = True
147 | self.timeout = timeout
148 |
149 | assert isinstance(openai_key, str) and openai_key.startswith('sk-'), (
150 | f'Illegal openai_key {openai_key}. '
151 | 'Please set the environment variable OPENAI_API_KEY to your openai key. '
152 | )
153 | super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
154 |
155 | if api_base in APIBASES:
156 | self.api_base = APIBASES[api_base]
157 | elif api_base.startswith('http'):
158 | self.api_base = api_base
159 | else:
160 | self.logger.error('Unknown API Base. ')
161 | sys.exit(-1)
162 |
163 | if 'OPENAI_API_BASE' in os.environ and os.environ['OPENAI_API_BASE'] != '':
164 | self.logger.error('Environment variable OPENAI_API_BASE is set. Will override the api_base arg. ')
165 | self.api_base = os.environ['OPENAI_API_BASE']
166 |
167 | # inputs can be a lvl-2 nested list: [content1, content2, content3, ...]
168 | # content can be a string or a list of image & text
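    # Examples of accepted formats (illustrative):
    #   'Hello'                                     -> a single user message
    #   ['Describe the image', '/path/to/img.jpg']  -> one user turn mixing text & images
    #   ['Hi', 'Hello!', 'Thanks']                  -> alternating user / assistant turns
    #   [dict(role='user', content='Hi'), ...]      -> raw OpenAI-style message dicts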
169 | def prepare_inputs(self, inputs):
170 | input_msgs = []
171 | if self.system_prompt is not None:
172 | input_msgs.append(dict(role='system', content=self.system_prompt))
173 | if isinstance(inputs, str):
174 | input_msgs.append(dict(role='user', content=inputs))
175 | return input_msgs
176 | assert isinstance(inputs, list)
177 | dict_flag = [isinstance(x, dict) for x in inputs]
178 | if np.all(dict_flag):
179 | input_msgs.extend(inputs)
180 | return input_msgs
181 | str_flag = [isinstance(x, str) for x in inputs]
182 | if np.all(str_flag):
183 | img_flag = [x.startswith('http') or osp.exists(x) for x in inputs]
184 | if np.any(img_flag):
185 | content_list = []
186 | for fl, msg in zip(img_flag, inputs):
187 | if not fl:
188 | content_list.append(dict(type='text', text=msg))
189 | elif msg.startswith('http'):
190 | content_list.append(dict(type='image_url', image_url={'url': msg, 'detail': self.img_detail}))
191 | elif osp.exists(msg):
192 | from PIL import Image
193 | img = Image.open(msg)
194 | b64 = encode_image_to_base64(img, target_size=self.img_size)
195 | img_struct = dict(url=f'data:image/jpeg;base64,{b64}', detail=self.img_detail)
196 | content_list.append(dict(type='image_url', image_url=img_struct))
197 | input_msgs.append(dict(role='user', content=content_list))
198 | return input_msgs
199 | else:
200 | roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
201 | roles = roles * len(inputs)
202 | for role, msg in zip(roles, inputs):
203 | input_msgs.append(dict(role=role, content=msg))
204 | return input_msgs
205 | raise NotImplementedError('list of list prompt not implemented now. ')
206 |
207 | def generate_inner(self, inputs, **kwargs) -> str:
208 | input_msgs = self.prepare_inputs(inputs)
209 | temperature = kwargs.pop('temperature', self.temperature)
210 | max_tokens = kwargs.pop('max_tokens', self.max_tokens)
211 |
212 | context_window = GPT_context_window(self.model)
213 | max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
214 | if 0 < max_tokens <= 100:
215 | self.logger.warning(
216 | 'Less than 100 tokens left, '
217 | 'may exceed the context window with some additional meta symbols. '
218 | )
219 | if max_tokens <= 0:
220 | return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
221 |
222 | headers = {'Content-Type': 'application/json', 'Authorization': f'Bearer {self.openai_key}'}
223 | payload = dict(
224 | model=self.model,
225 | messages=input_msgs,
226 | max_tokens=max_tokens,
227 | n=1,
228 | temperature=temperature,
229 | **kwargs)
230 | response = requests.post(self.api_base, headers=headers, data=json.dumps(payload), timeout=self.timeout * 1.1)
231 | ret_code = response.status_code
232 | ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
233 | answer = self.fail_msg
234 | try:
235 | resp_struct = json.loads(response.text)
236 | answer = resp_struct['choices'][0]['message']['content'].strip()
237 | except:
238 | pass
239 | return ret_code, answer, response
240 |
241 | def get_token_len(self, inputs) -> int:
242 | import tiktoken
243 | try:
244 | enc = tiktoken.encoding_for_model(self.model)
245 | except:
246 | enc = tiktoken.encoding_for_model('gpt-4')
247 | if isinstance(inputs, str):
248 | if inputs.startswith('http') or osp.exists(inputs):
249 | return 65 if self.img_detail == 'low' else 130
250 | else:
251 | return len(enc.encode(inputs))
252 | elif isinstance(inputs, dict):
253 | assert 'content' in inputs
254 | return self.get_token_len(inputs['content'])
255 | assert isinstance(inputs, list)
256 | res = 0
257 | for item in inputs:
258 | res += self.get_token_len(item)
259 | return res
260 |
--------------------------------------------------------------------------------
/ada_leval/dataset.py:
--------------------------------------------------------------------------------
1 | from ada_leval.smp import *
2 |
3 | class StackSelect:
4 |
5 | def __init__(self, setting='1k', mode='normal'):
6 | data = load(f'data/stackselect_{setting}.json')
7 | self.setting = setting
8 | assert mode in ['normal', 'less']
9 | if mode == 'normal':
10 | num = 1000 if int(setting[:-1]) < 32 else 200
11 | elif mode == 'less':
12 | num = 200 if int(setting[:-1]) < 32 else 50
13 |
14 | if num > 0:
15 | data = data[:num]
16 | for item in data:
17 | item['index'] = f"{item['question_id']}_{item['answer']}"
18 | self.data = data
19 |
20 | self.meta_prompt = """
21 | You are an AI assistant. Your job is to find out the most helpful answer to a given question.
22 | Each time, you will be provided with a question and n answers to this question.
23 | Each answer begins with an 'A' and a number(e.g. A4), which represents its designation.
24 | You need to determine which answer is the most helpful one to the question.
25 | The case sample is shown below and you should give me the answer in the format exactly the same as the sample. \n
26 | However, you should NOT focus on the content of sample answer. \n
27 | Sample Input (format only): \n
28 | The question is given below.
29 | XXX(The content of question)
30 | Possible answers are given below.
31 | A1:
32 | XXX(The content of answer 1)
33 | A2:
34 | XXX(The content of answer 2)
35 | .
36 | .
37 | .
38 | An:
39 | XXX(The content of answer n)
40 | Now the answers are over, please decide which answer is the most helpful one to the question.
41 | You must give me only the designation of the MOST helpful answer.
42 | Sample Output (format only): \n
43 | Answer: The designation of the most helpful answer.(e.g. A4 means answer 4 is the most helpful answer) \n\n
44 | """
45 |
46 | def __len__(self):
47 | return len(self.data)
48 |
49 | def get_meta(self):
50 | res = {
51 | 'index': [x['index'] for x in self.data],
52 | 'question': [x['question'] for x in self.data],
53 | 'answer': [x['answer'] for x in self.data],
54 | 'tags': [x['tags'] for x in self.data],
55 | 'num_choice': [len(x['all_answers']) for x in self.data]
56 | }
57 | return pd.DataFrame(res)
58 |
59 | def build_prompt(self, line):
60 | if isinstance(line, int):
61 | line = self.data[line]
62 | assert isinstance(line, dict)
63 | prompt = self.meta_prompt
64 | prompt += 'The question is given below.\n'
65 | prompt += line['question'] + '\n\n'
66 | prompt += 'Possible answers are given below.\n'
67 | all_answers = line['all_answers']
68 | for j in range(1, len(all_answers) + 1):
69 | prompt += 'A' + str(j) + ':\n\n' + all_answers[j - 1] + '\n\n'
70 |
71 | prompt += """
72 | Now the answers are over, please decide which answer is the most helpful one to the question.
73 | You must give me only the designation of the MOST helpful answer.
74 | """
75 | return prompt
76 |
77 | def evaluate(self, df):
78 | assert 'prediction' in df and 'answer' in df and 'num_choice' in df
79 |
80 |         def extract(line):
81 |             nc = line['num_choice']
82 |             cands = [f'A{i}' for i in range(1, nc + 1)]  # explicit designations: 'A1' .. 'An'
83 |             finds = [line['prediction'].find(c) for c in cands]
84 |             matched = sum([x >= 0 for x in finds])
85 |             if matched >= 1:
86 |                 for i in range(nc - 1, -1, -1):  # highest index first, so 'A12' is not misread as 'A1'
87 |                     if finds[i] >= 0:
88 |                         return cands[i]
89 |             else:
90 |                 cands = [str(i) for i in range(1, nc + 1)]  # fall back to bare numbers '1' .. 'n'
91 |                 finds = [line['prediction'].find(c) for c in cands]
92 |                 matched = sum([x >= 0 for x in finds])
93 |                 if matched >= 1:
94 |                     for i in range(nc - 1, -1, -1):
95 |                         if finds[i] >= 0:
96 |                             return 'A' + cands[i]
97 |                 else:
98 |                     return '???'  # no recognizable answer: counted as incorrect
99 |
100 | extracted = [extract(df.iloc[i]) for i in range(len(df))]
101 | df['extracted'] = extracted
102 | acc = np.mean([x == y for x, y in zip(df['extracted'], df['answer'])])
103 | acc = 100 * acc
104 | print(f'StackSelect {self.setting} Accuracy: {acc:.1f}%')
105 | return acc
106 |
107 |
108 | class TextSort:
109 |
110 | def __init__(self, setting='1k', mode='normal'):
111 | data = load(f'data/textsort_{setting}.json')
112 | self.setting = setting
113 | assert mode in ['normal', 'less']
114 | if mode == 'normal':
115 | num = 1000 if int(setting[:-1]) < 32 else 200
116 | elif mode == 'less':
117 | num = 200 if int(setting[:-1]) < 32 else 50
118 |
119 | if num > 0:
120 | data = data[:num]
121 | for item in data:
122 | book_id = item['book_id']
123 | para_offset = item['para_offset']
124 | item['index'] = f"{book_id}_{'_'.join([str(x) for x in para_offset])}"
125 | self.data = data
126 |
127 | def __len__(self):
128 | return len(self.data)
129 |
130 | def get_meta(self):
131 | res = {
132 | 'book_id': [x['book_id'] for x in self.data],
133 | 'para_offset': [x['para_offset'] for x in self.data],
134 | 'answer': [x['answer'] for x in self.data],
135 | 'index': [x['index'] for x in self.data]
136 | }
137 | return pd.DataFrame(res)
138 |
139 | def build_prompt(self, line):
140 | if isinstance(line, int):
141 | line = self.data[line]
142 | assert isinstance(line, dict)
143 | return line['prompt']
144 |
145 | def evaluate(self, df):
146 | assert 'prediction' in df and 'answer' in df
147 |
148 | def is_subseq(needle, haystack):
149 | current_pos = 0
150 | for c in needle:
151 | current_pos = haystack.find(c, current_pos) + 1
152 | if current_pos == 0:
153 | return False
154 | return True
155 |
156 |         def extract(line):
157 |             pred = line['prediction']
158 |             if 'Answer:' in pred:
159 |                 pred = pred.split('Answer:')[1].strip()
160 |             try:
161 |                 pred = json.loads(pred)  # e.g. '[2, 4, 1, 3]'
162 |                 return pred
163 |             except:
164 |                 import itertools
165 |                 perms = list(itertools.permutations(range(1, 5)))  # all 24 orderings of the 4 segments
166 |                 perms = [''.join([str(x) for x in p]) for p in perms]
167 |                 subseq = [is_subseq(p, pred) for p in perms]
168 |                 if sum(subseq) == 1:  # accept only an unambiguous match
169 |                     for p, s in zip(perms, subseq):
170 |                         if s:
171 |                             return [int(pp) for pp in p]
172 |                 return [0, 0, 0, 0]  # unparsable; never matches a valid answer
173 |
174 | extracted = [extract(df.iloc[i]) for i in range(len(df))]
175 | answers = [json.loads(x) if isinstance(x, str) else x for x in df['answer']]
176 | hit, tot = 0, 0
177 |
178 | for a, e in zip(answers, extracted):
179 | tot += 1
180 | flag = True
181 | for aa, ee in zip(a, e):
182 | if aa != ee:
183 | flag = False
184 | hit += flag
185 |
186 | acc = hit / tot
187 | acc = 100 * acc
188 | print(f'TextSort {self.setting} Accuracy: {acc:.1f}%')
189 | return acc
--------------------------------------------------------------------------------
/ada_leval/smp.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa: F401, F403
2 | import abc
3 | import argparse
4 | import collections
5 | import csv
6 | import json
7 | import multiprocessing as mp
8 | import numpy as np
9 | import os, sys, time, base64, io
10 | import os.path as osp
11 | import copy as cp
12 | import pickle
13 | import random as rd
14 | import requests
15 | import shutil
16 | import string
17 | import subprocess
18 | import warnings
19 | import pandas as pd
20 | from collections import OrderedDict, defaultdict
21 | from multiprocessing import Pool, current_process
22 | from tqdm import tqdm
23 | from PIL import Image
24 | import uuid
25 | from uuid import uuid4
26 | from datetime import datetime
27 | import matplotlib.pyplot as plt
28 | import seaborn as sns
29 | from tabulate import tabulate
30 |
31 | def d2df(D):
32 | return pd.DataFrame({x: [D[x]] for x in D})
33 |
34 | def LMUDataRoot():
35 | if 'LMUData' in os.environ and osp.exists(os.environ['LMUData']):
36 | return os.environ['LMUData']
37 | home = osp.expanduser('~')
38 | root = osp.join(home, 'LMUData')
39 | os.makedirs(root, exist_ok=True)
40 | return root
41 |
42 | def cn_string(s):
43 | import re
44 | if re.search(u'[\u4e00-\u9fff]', s):
45 | return True
46 | return False
47 |
48 | def timestr(second=False, minute=False):
49 | s = datetime.now().strftime('%Y%m%d%H%M%S')[2:]
50 | if second:
51 | return s
52 | elif minute:
53 | return s[:-2]
54 | else:
55 | return s[:-4]
56 |
57 | def num2uuid(num):
58 | rd.seed(num)
59 | return str(uuid.UUID(int=rd.getrandbits(128), version=4))
60 |
61 | def randomuuid():
62 | seed = rd.randint(0, 2 ** 32 - 1)
63 | return num2uuid(seed)
64 |
65 | def mrlines(fname, sp='\n'):
66 | f = open(fname).read().split(sp)
67 | while f != [] and f[-1] == '':
68 | f = f[:-1]
69 | return f
70 |
71 | def mwlines(lines, fname):
72 | with open(fname, 'w') as fout:
73 | fout.write('\n'.join(lines))
74 |
75 | def default_set(self, args, name, default):
76 | if hasattr(args, name):
77 | val = getattr(args, name)
78 | setattr(self, name, val)
79 | else:
80 | setattr(self, name, default)
81 |
82 | def dict_merge(dct, merge_dct):
83 | for k, _ in merge_dct.items():
84 | if (k in dct and isinstance(dct[k], dict) and isinstance(merge_dct[k], dict)): #noqa
85 | dict_merge(dct[k], merge_dct[k])
86 | else:
87 | dct[k] = merge_dct[k]
88 |
89 | def youtube_dl(idx):
90 | cmd = f'youtube-dl -f best -f mp4 "{idx}" -o {idx}.mp4'
91 | os.system(cmd)
92 |
93 | def run_command(cmd):
94 | if isinstance(cmd, str):
95 | cmd = cmd.split()
96 | return subprocess.check_output(cmd)
97 |
98 | def ls(dirname='.', match='', mode='all', level=1):
99 | if dirname == '.':
100 | ans = os.listdir(dirname)
101 | else:
102 | ans = [osp.join(dirname, x) for x in os.listdir(dirname)]
103 | assert mode in ['all', 'dir', 'file']
104 | assert level >= 1 and isinstance(level, int)
105 | if level == 1:
106 | ans = [x for x in ans if match in x]
107 | if mode == 'dir':
108 | ans = [x for x in ans if osp.isdir(x)]
109 | elif mode == 'file':
110 | ans = [x for x in ans if not osp.isdir(x)]
111 | else:
112 | ans = [x for x in ans if osp.isdir(x)]
113 | res = []
114 | for d in ans:
115 | res.extend(ls(d, match=match, mode=mode, level=level-1))
116 | ans = res
117 | return ans
118 |
119 | def intop(pred, label, n):
120 | pred = [np.argsort(x)[-n:] for x in pred]
121 | hit = [(l in p) for l, p in zip(label, pred)]
122 | return hit
123 |
124 | def topk(score, label, k=1):
125 | return np.mean(intop(score, label, k)) if isinstance(k, int) else [topk(score, label, kk) for kk in k]
126 |
127 | def download_file(url, filename=None):
128 | if filename is None:
129 | filename = url.split('/')[-1]
130 | response = requests.get(url)
131 | open(filename, 'wb').write(response.content)
132 |
133 | def fnp(model, input=None):
134 | from fvcore.nn import FlopCountAnalysis, parameter_count
135 | params = parameter_count(model)['']
136 | print('Parameter Size: {:.4f} M'.format(params / 1024 / 1024))
137 | if input is not None:
138 | flops = FlopCountAnalysis(model, input).total()
139 | print('FLOPs: {:.4f} G'.format(flops / 1024 / 1024 / 1024))
140 | return params, flops
141 | return params, None
142 |
143 | def proxy_set(s):
144 | import os
145 | for key in ['http_proxy', 'HTTP_PROXY', 'https_proxy', 'HTTPS_PROXY']:
146 | os.environ[key] = s
147 |
148 | # LOAD & DUMP
149 | def dump(data, f, **kwargs):
150 | def dump_pkl(data, pth, **kwargs):
151 | pickle.dump(data, open(pth, 'wb'))
152 |
153 | def dump_json(data, pth, **kwargs):
154 | json.dump(data, open(pth, 'w'), indent=4, ensure_ascii=False)
155 |
156 | def dump_jsonl(data, f, **kwargs):
157 | lines = [json.dumps(x, ensure_ascii=False) for x in data]
158 | with open(f, 'w', encoding='utf8') as fout:
159 | fout.write('\n'.join(lines))
160 |
161 | def dump_xlsx(data, f, **kwargs):
162 | data.to_excel(f, index=False)
163 |
164 | def dump_csv(data, f, quoting=csv.QUOTE_MINIMAL):
165 | data.to_csv(f, index=False, encoding='utf-8', quoting=quoting)
166 |
167 | def dump_tsv(data, f, quoting=csv.QUOTE_MINIMAL):
168 | data.to_csv(f, sep='\t', index=False, encoding='utf-8', quoting=quoting)
169 |
170 | handlers = dict(pkl=dump_pkl, json=dump_json, jsonl=dump_jsonl, xlsx=dump_xlsx, csv=dump_csv, tsv=dump_tsv)
171 | suffix = f.split('.')[-1]
172 | return handlers[suffix](data, f, **kwargs)
173 |
174 | import portalocker
175 | def safe_dump(data, f, **kwargs):
176 | with portalocker.Lock(f, timeout=5) as fh:
177 | dump(data, f, **kwargs)
178 | fh.flush()
179 | os.fsync(fh.fileno())
180 |
181 | def load(f):
182 | def load_pkl(pth):
183 | return pickle.load(open(pth, 'rb'))
184 |
185 | def load_json(pth):
186 | return json.load(open(pth, 'r', encoding='utf-8'))
187 |
188 | def load_jsonl(f):
189 | lines = open(f, encoding='utf-8').readlines()
190 | lines = [x.strip() for x in lines]
191 | if lines[-1] == '':
192 | lines = lines[:-1]
193 | data = [json.loads(x) for x in lines]
194 | return data
195 |
196 | def load_xlsx(f):
197 | return pd.read_excel(f)
198 |
199 | def load_csv(f):
200 | return pd.read_csv(f)
201 |
202 | def load_tsv(f):
203 | return pd.read_csv(f, sep='\t')
204 |
205 | handlers = dict(pkl=load_pkl, json=load_json, jsonl=load_jsonl, xlsx=load_xlsx, csv=load_csv, tsv=load_tsv)
206 | suffix = f.split('.')[-1]
207 | return handlers[suffix](f)
--------------------------------------------------------------------------------
/ada_leval/util.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | logger_initialized = {}
4 |
5 |
6 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
7 | logger = logging.getLogger(name)
8 | if name in logger_initialized:
9 | return logger
10 |
11 | for logger_name in logger_initialized:
12 | if name.startswith(logger_name):
13 | return logger
14 |
15 | stream_handler = logging.StreamHandler()
16 | handlers = [stream_handler]
17 |
18 | try:
19 | import torch.distributed as dist
20 | if dist.is_available() and dist.is_initialized():
21 | rank = dist.get_rank()
22 | else:
23 | rank = 0
24 | except ImportError:
25 | rank = 0
26 |
27 | if rank == 0 and log_file is not None:
28 | file_handler = logging.FileHandler(log_file, file_mode)
29 | handlers.append(file_handler)
30 |
31 | formatter = logging.Formatter(
32 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
33 | for handler in handlers:
34 | handler.setFormatter(formatter)
35 | handler.setLevel(log_level)
36 | logger.addHandler(handler)
37 |
38 | if rank == 0:
39 | logger.setLevel(log_level)
40 | else:
41 | logger.setLevel(logging.ERROR)
42 |
43 | logger_initialized[name] = True
44 | return logger
45 |
46 |
47 | from multiprocessing import Pool
48 | import os
49 | from typing import Callable, Iterable, Sized
50 |
51 | from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
52 | TaskProgressColumn, TextColumn, TimeRemainingColumn)
53 | from rich.text import Text
54 | import os.path as osp
55 | import portalocker
56 | from .smp import load, dump
57 |
58 |
59 | class _Worker:
60 | """Function wrapper for ``track_progress_rich``"""
61 |
62 | def __init__(self, func) -> None:
63 | self.func = func
64 |
65 | def __call__(self, inputs):
66 | inputs, idx = inputs
67 | if not isinstance(inputs, (tuple, list, dict)):
68 | inputs = (inputs, )
69 |
70 | if isinstance(inputs, dict):
71 | return self.func(**inputs), idx
72 | else:
73 | return self.func(*inputs), idx
74 |
75 |
76 | class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
77 | """Skip calculating remaining time for the first few times.
78 |
79 | Args:
80 | skip_times (int): The number of times to skip. Defaults to 0.
81 | """
82 |
83 | def __init__(self, *args, skip_times=0, **kwargs):
84 | super().__init__(*args, **kwargs)
85 | self.skip_times = skip_times
86 |
87 | def render(self, task: Task) -> Text:
88 | """Show time remaining."""
89 | if task.completed <= self.skip_times:
90 | return Text('-:--:--', style='progress.remaining')
91 | return super().render(task)
92 |
93 |
94 | def _tasks_with_index(tasks):
95 | """Add index to tasks."""
96 | for idx, task in enumerate(tasks):
97 | yield task, idx
98 |
99 |
100 | def track_progress_rich(func: Callable,
101 | tasks: Iterable = tuple(),
102 | task_num: int = None,
103 | nproc: int = 1,
104 | chunksize: int = 1,
105 | description: str = 'Processing',
106 | save=None, keys=None,
107 | color: str = 'blue') -> list:
108 | """Track the progress of parallel task execution with a progress bar. The
109 | built-in :mod:`multiprocessing` module is used for process pools and tasks
110 | are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
111 |
112 | Args:
113 | func (callable): The function to be applied to each task.
114 | tasks (Iterable or Sized): A tuple of tasks. There are several cases
115 | for different format tasks:
116 | - When ``func`` accepts no arguments: tasks should be an empty
117 | tuple, and ``task_num`` must be specified.
118 | - When ``func`` accepts only one argument: tasks should be a tuple
119 | containing the argument.
120 | - When ``func`` accepts multiple arguments: tasks should be a
121 | tuple, with each element representing a set of arguments.
122 | If an element is a ``dict``, it will be parsed as a set of
123 | keyword-only arguments.
124 | Defaults to an empty tuple.
125 | task_num (int, optional): If ``tasks`` is an iterator which does not
126 | have length, the number of tasks can be provided by ``task_num``.
127 | Defaults to None.
128 |         nproc (int): Process (worker) number. If nproc is 1,
129 |             a single process is used. Defaults to 1.
130 | chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
131 | Defaults to 1.
132 |         description (str): The description of progress bar.
133 |             Defaults to "Processing".
134 |         save (str, optional): Path of a file to which results are
135 |             incrementally dumped, keyed by ``keys``. Defaults to None.
136 |         keys (list, optional): Keys of the saved results; required
137 |             when ``save`` is given. Defaults to None.
138 |         color (str): The color of progress bar. Defaults to "blue".
135 |
136 | Examples:
137 | >>> import time
138 |
139 | >>> def func(x):
140 | ... time.sleep(1)
141 | ... return x**2
142 | >>> track_progress_rich(func, range(10), nproc=2)
143 |
144 | Returns:
145 | list: The task results.
146 | """
147 | if save is not None:
148 | assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
149 | if not osp.exists(save):
150 | dump({}, save)
151 | if keys is not None:
152 | assert len(keys) == len(tasks)
153 |
154 | if not callable(func):
155 | raise TypeError('func must be a callable object')
156 | if not isinstance(tasks, Iterable):
157 | raise TypeError(
158 | f'tasks must be an iterable object, but got {type(tasks)}')
159 | if isinstance(tasks, Sized):
160 | if len(tasks) == 0:
161 | if task_num is None:
162 | raise ValueError('If tasks is an empty iterable, '
163 | 'task_num must be set')
164 | else:
165 | tasks = tuple(tuple() for _ in range(task_num))
166 | else:
167 | if task_num is not None and task_num != len(tasks):
168 | raise ValueError('task_num does not match the length of tasks')
169 | task_num = len(tasks)
170 |
171 | if nproc <= 0:
172 | raise ValueError('nproc must be a positive number')
173 |
174 | skip_times = nproc * chunksize if nproc > 1 else 0
175 | prog_bar = Progress(
176 | TextColumn('{task.description}'),
177 | BarColumn(),
178 | _SkipFirstTimeRemainingColumn(skip_times=skip_times),
179 | MofNCompleteColumn(),
180 | TaskProgressColumn(show_speed=True),
181 | )
182 |
183 | worker = _Worker(func)
184 | task_id = prog_bar.add_task(
185 | total=task_num, color=color, description=description)
186 | tasks = _tasks_with_index(tasks)
187 |
188 | # Use single process when nproc is 1, else use multiprocess.
189 | with prog_bar:
190 | if nproc == 1:
191 | results = []
192 | for task in tasks:
193 |                 result, idx = worker(task)
194 |                 results.append(result)
195 | if save is not None:
196 | with portalocker.Lock(save, timeout=5) as fh:
197 | ans = load(save)
198 | ans[keys[idx]] = result
199 |
200 |                     if os.environ.get('VERBOSE', False):
201 | print(keys[idx], result, flush=True)
202 |
203 | dump(ans, save)
204 | fh.flush()
205 | os.fsync(fh.fileno())
206 |
207 | prog_bar.update(task_id, advance=1, refresh=True)
208 | else:
209 | with Pool(nproc) as pool:
210 | results = []
211 | unordered_results = []
212 | gen = pool.imap_unordered(worker, tasks, chunksize)
213 | try:
214 | for result in gen:
215 | result, idx = result
216 | unordered_results.append((result, idx))
217 |
218 | if save is not None:
219 | with portalocker.Lock(save, timeout=5) as fh:
220 | ans = load(save)
221 | ans[keys[idx]] = result
222 |
223 | if os.environ.get('VERBOSE', False):
224 | print(keys[idx], result, flush=True)
225 |
226 | dump(ans, save)
227 | fh.flush()
228 | os.fsync(fh.fileno())
229 |
230 | results.append(None)
231 | prog_bar.update(task_id, advance=1, refresh=True)
232 | except Exception as e:
233 | prog_bar.stop()
234 | raise e
235 | for result, idx in unordered_results:
236 | results[idx] = result
237 | return results
238 |
239 |
240 | def get_rank_and_world_size():
241 | local_rank = int(os.environ.get('LOCAL_RANK', 0))
242 | world_size = int(os.environ.get('WORLD_SIZE', 1))
243 | return local_rank, world_size
--------------------------------------------------------------------------------
/assets/AdaLEval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-compass/Ada-LEval/2154258d5fa3969ac5429b3132d505570ef8a57a/assets/AdaLEval.png
--------------------------------------------------------------------------------
/assets/BestAnswer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-compass/Ada-LEval/2154258d5fa3969ac5429b3132d505570ef8a57a/assets/BestAnswer.png
--------------------------------------------------------------------------------
/fetch_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | mkdir -p data
3 | for i in 1k 2k 4k 6k 8k 12k 16k 32k 64k 128k
4 | do
5 | wget http://opencompass.openxlab.space/utils/AdaLEval/stackselect_$i.json -O data/stackselect_$i.json
6 | done
7 | for i in 1k 2k 4k 8k 16k 32k 64k 128k
8 | do
9 | wget http://opencompass.openxlab.space/utils/AdaLEval/textsort_$i.json -O data/textsort_$i.json
10 | done
11 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | from ada_leval.smp import *
2 | from ada_leval.util import *
3 | from ada_leval.api import OpenAIWrapper
4 | from ada_leval.dataset import StackSelect, TextSort
5 |
6 | RESULT_FILE = 'result.json'
7 | if not osp.exists(RESULT_FILE):
8 | dump({}, RESULT_FILE)
9 |
10 | settings = ['1k', '2k', '4k', '8k', '16k', '32k', '64k', '128k']
11 | datasets = [f'stackselect_{k}' for k in settings + ['6k', '12k']] + [f'textsort_{k}' for k in settings]
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--data', type=str, nargs='+', required=True, choices=datasets)
16 | parser.add_argument('--model', type=str, required=True, choices=['internlm2-7b', 'internlm2-20b', 'gpt-4-0125'])
17 | parser.add_argument('--mode', type=str, default='all', choices=['infer', 'all'])
18 | parser.add_argument('--nproc', type=int, default=4)
19 | args = parser.parse_args()
20 | return args
21 |
22 | def build_model(m, setting=None):
23 | if 'internlm2' in m:
24 | session_len = 160000
25 | from lmdeploy import pipeline, TurbomindEngineConfig
26 | backend_config = TurbomindEngineConfig(rope_scaling_factor=2.0, session_len=session_len)
27 |
28 | if m == 'gpt-4-0125':
29 | model = OpenAIWrapper('gpt-4-0125-preview')
30 | elif m == 'internlm2-7b':
31 | model = pipeline('internlm/internlm2-chat-7b', backend_config=backend_config)
32 | elif m == 'internlm2-20b':
33 | model = pipeline('internlm/internlm2-chat-20b', backend_config=backend_config)
34 | return model
35 |
36 | import tiktoken
37 | ENC = tiktoken.encoding_for_model('gpt-4')
38 |
39 | def get_token_length(prompt):
40 | return len(ENC.encode(prompt))
41 |
42 | def main():
43 | rank, world_size = get_rank_and_world_size()
44 | if world_size > 1:
45 | import torch
46 | import torch.distributed as dist
47 | torch.cuda.set_device(rank)
48 | dist.init_process_group(backend='nccl')
49 |
50 | if rank == 0:
51 | os.makedirs('results', exist_ok=True)
52 |
53 | args = parse_args()
54 | model_name = args.model
55 | model = build_model(args.model)
56 | for dname in args.data:
57 | d, setting = dname.split('_')
58 | dataset_mode = 'less' if getattr(model, 'is_api', False) else 'normal'
59 |
60 | if d == 'stackselect':
61 | dataset = StackSelect(setting=setting, mode=dataset_mode)
62 | elif d == 'textsort':
63 | dataset = TextSort(setting=setting, mode=dataset_mode)
64 |
65 | lt = len(dataset)
66 | prompts = [dataset.build_prompt(i) for i in range(lt)]
67 | meta = dataset.get_meta()
68 | indices = list(meta['index'])
69 |
70 | out_file = f'results/{model_name}_{dname}.pkl'
71 | res = {} if not osp.exists(out_file) else load(out_file)
72 | tups = [(i, p) for i, p in zip(indices, prompts) if i not in res]
73 |
74 | if len(tups):
75 | if getattr(model, 'is_api', False):
76 | res = track_progress_rich(
77 | model.generate,
78 | [x[1] for x in tups],
79 | nproc=args.nproc,
80 | chunksize=args.nproc,
81 | save=out_file,
82 | keys=[x[0] for x in tups])
83 | else:
84 | sub_tups = tups[rank::world_size]
85 | sub_out_file = f'results/{model_name}_{dname}_{rank}.pkl'
86 | sub_res = {}
87 | import torch
88 | with torch.no_grad():
89 | for t in tqdm(sub_tups):
90 | index, prompt = t
91 | sub_res[index] = model(prompt).text
92 | dump(sub_res, sub_out_file)
93 |
94 | if world_size > 1:
95 | dist.barrier()
96 |
97 | if rank == 0:
98 | res = {}
99 | for i in range(world_size):
100 | sub_out_file = f'results/{model_name}_{dname}_{i}.pkl'
101 | if osp.exists(sub_out_file):
102 | res.update(load(sub_out_file))
103 | if osp.exists(out_file):
104 | res.update(load(out_file))
105 | dump(res, out_file)
106 |
107 |             res = load(out_file)  # only rank 0 gathers predictions and evaluates
108 |             meta['prediction'] = [res[k] for k in meta['index']]
109 |             dump(meta, f'results/{model_name}_{dname}.xlsx')
110 | 
111 |             if args.mode == 'all':
112 |                 results = load(RESULT_FILE)
113 |                 acc = dataset.evaluate(meta)
114 |                 results[f'{model_name}_{dname}'] = acc
115 |                 dump(results, RESULT_FILE)
116 |
117 | if world_size > 1:
118 | dist.barrier()
119 |         os.system(f'rm results/{model_name}_{dname}_{rank}.pkl')
120 |
121 | if __name__ == '__main__':
122 | main()
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
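# detect the number of available GPUs, then launch one data-parallel worker per GPU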
3 | export GPU=$(nvidia-smi --list-gpus | wc -l)
4 | torchrun --nproc-per-node=$GPU run.py ${@:1}
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | REQUIRES = """
4 | numpy>=1.23.4
5 | openai
6 | requests
7 | tqdm
8 | pandas>=1.5.3
9 | tiktoken
10 | rich
11 | portalocker
12 | pillow
13 | matplotlib
14 | seaborn
15 | tabulate
16 | """
17 |
18 |
19 | def get_install_requires():
20 | reqs = [req for req in REQUIRES.split('\n') if len(req) > 0]
21 | return reqs
22 |
23 |
24 | with open('README.md') as f:
25 | readme = f.read()
26 |
27 |
28 | def do_setup():
29 | setup(
30 | name='ada_leval',
31 | version='0.1.0',
32 | description='The official implementation for Ada-LEval',
33 | # url="",
34 | author="Haodong Duan",
35 | long_description=readme,
36 | long_description_content_type='text/markdown',
37 | cmdclass={},
38 | install_requires=get_install_requires(),
39 | setup_requires=[],
40 | python_requires='>=3.7.0',
41 | packages=find_packages(exclude=[
42 | 'test*',
43 | 'paper_test*',
44 | ]),
45 | keywords=['AI', 'NLP', 'in-context learning'],
46 | entry_points={
47 | "console_scripts": []
48 | },
49 | classifiers=[
50 | 'Programming Language :: Python :: 3.7',
51 | 'Programming Language :: Python :: 3.8',
52 | 'Programming Language :: Python :: 3.9',
53 | 'Programming Language :: Python :: 3.10',
54 | 'Intended Audience :: Developers',
55 | 'Intended Audience :: Education',
56 | 'Intended Audience :: Science/Research',
57 | ])
58 |
59 |
60 | if __name__ == '__main__':
61 | do_setup()
62 |
--------------------------------------------------------------------------------