├── .gitignore ├── LICENSE ├── README.md ├── USE_POLICY.md ├── assets └── q_eg.png ├── configs ├── few-shot.yaml ├── few-shot_cot.yaml ├── few-shot_cot_sc.yaml ├── few-shot_sc.yaml ├── few-shot_sys.yaml ├── few-shot_sys_cot.yaml ├── few-shot_tot.yaml ├── zero-shot.yaml ├── zero-shot_cot.yaml ├── zero-shot_cot_sc.yaml ├── zero-shot_sc.yaml ├── zero-shot_sys.yaml ├── zero-shot_sys_cot.yaml └── zero-shot_tot.yaml ├── environment.yml ├── evaluate.py ├── prompts ├── __init__.py ├── qa_prompt.py ├── sys_prompt.py └── tot_prompt.py ├── requirements.txt ├── run.py ├── scripts ├── run.sh ├── run_ctx.sh ├── run_os.sh └── serving.sh ├── solve ├── __init__.py └── solver.py ├── tokenizer.model ├── tokenizer_checklist.chk └── utils ├── __init__.py ├── arg_parser.py ├── tot_solver.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .cache 3 | .idea 4 | res/ 5 | nohup.out 6 | */__pycache__ 7 | logs/ 8 | download.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NLPBench: Evaluating NLP-Related Problem-solving Ability in Large Language Models 2 | 3 | 4 | NLPBench is a novel benchmark for Natural Language Processing problems consisting of 378 questions sourced from the NLP course final exams at Yale University. 
5 | 
6 | ## Data
7 | Our example questions:
8 | ![Example Questions](assets/q_eg.png)
9 | **Our dataset is licensed under the [CC BY-ND](https://creativecommons.org/licenses/by-nd/4.0/deed.en)**. You can download our dataset through [this link](https://drive.google.com/drive/folders/1haGLwzdZ_fejN7s-nBpDlz8gZPowZSZN?usp=sharing).
10 | 
11 | ## Environment Preparation
12 | You can create our environment from `environment.yml` with
13 | ```bash
14 | conda env create -f environment.yml
15 | ```
16 | then activate the conda environment with
17 | ```bash
18 | conda activate NLPBench
19 | ```
20 | 
21 | ## Evaluation
22 | Our evaluations cover both online LLMs (GPT-3.5, GPT-4, and PaLM 2) and open-source LLMs (Llama 2, Falcon, Bloom, etc.).
23 | 
24 | ### For Online LLMs
25 | Online LLMs usually require an API key for access. To use the OpenAI models, add the `OPENAI_API_KEY` to your system environment as follows:
26 | ```bash
27 | export OPENAI_API_KEY="YOUR OPENAI API KEY"
28 | ```
29 | For PaLM 2, add the `PALM_API_KEY` to your system environment as follows:
30 | ```bash
31 | export PALM_API_KEY="YOUR PALM API KEY"
32 | ```
33 | 
34 | ### For Open-source LLMs
35 | We use [vLLM](https://github.com/vllm-project/vllm) to start an OpenAI-compatible endpoint for evaluation. All configurations are summarized in `./utils/utils.py:oai_llm_config`. Check [this list](https://vllm.readthedocs.io/en/latest/models/supported_models.html) for the supported open-source models.
36 | 
37 | To evaluate other open-source models, add your model's configuration to `oai_llm_config` in the following format:
38 | ```json
39 | "HUGGINGFACE REPO": {
40 |     "model": "HUGGINGFACE REPO",
41 |     "api_key": "empty",
42 |     "api_base": "YOUR ENDPOINT HOST, DEFAULT: http://127.0.0.1:8000/v1"
43 | }
44 | ```
45 | then start the endpoint with the following script:
46 | ```bash
47 | bash scripts/serving.sh [-m HUGGINGFACE REPO][-n NUMBER OF GPUs][-a HOST ADDRESS, DEFAULT: 127.0.0.1][-p PORT, DEFAULT: 8000]
48 | ```
49 | 
50 | ### Run Evaluation
51 | Evaluation has two steps: (1) solving the problems and (2) calculating the accuracy.
52 | We adopt [sacred](https://github.com/IDSIA/sacred) to manage our configurations. All configs can be found in `./configs`. You can also add your own config by creating a new `yaml` file. For example, you can run the following command to have `GPT-3.5` answer the questions without context, using only `zero-shot` prompting:
53 | ```bash
54 | python run.py with configs/zero-shot.yaml model_name='gpt-3.5-turbo' ctx=False
55 | ```
56 | The answers will be saved in `./res/{SEED}/no_ctx/zero-shot_gpt-3.5-turbo.json`.
57 | You can evaluate the above result by running:
58 | ```bash
59 | python evaluate.py
60 | ```
61 | The result will then be saved in `./res/{SEED}/`.
62 | 
63 | ## Prompt
64 | All the prompts used in our evaluation can be found in `./prompts`, including the question-answering prompts (`qa_prompt.py`), system prompts (`sys_prompt.py`), and tree-of-thought prompts (`tot_prompt.py`). You can customize your prompts by modifying these three files.
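For instance, here is a minimal usage sketch (the question text is a made-up placeholder) showing how a template from `qa_prompt.py` is filled before being sent to a model:
```python
# Each QA template is a plain Python string with an {input} placeholder,
# so customizing a prompt is just editing the string in qa_prompt.py.
from prompts.qa_prompt import MULTICHOICE_STD_ZS

question = "Which company released ChatGPT?\n0: Google\n1: OpenAI"
prompt = MULTICHOICE_STD_ZS.format(input=question)
print(prompt)  # ready to use as the user message of a chat request
```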
65 | 
66 | ## Citation
67 | If you find our repository and results useful, please cite our paper:
68 | ```bibtex
69 | @misc{song2023nlpbench,
70 |       title={NLPBench: Evaluating Large Language Models on Solving NLP Problems},
71 |       author={Linxin Song and Jieyu Zhang and Lechao Cheng and Pengyuan Zhou and Tianyi Zhou and Irene Li},
72 |       year={2023},
73 |       eprint={2309.15630},
74 |       archivePrefix={arXiv},
75 |       primaryClass={cs.CL}
76 | }
77 | ```
--------------------------------------------------------------------------------
/USE_POLICY.md:
--------------------------------------------------------------------------------
1 | # Llama 2 Acceptable Use Policy
2 | 
3 | Meta is committed to promoting safe and fair use of its tools and features, including Llama 2. If you access or use Llama 2, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy).
4 | 
5 | ## Prohibited Uses
6 | We want everyone to use Llama 2 safely and responsibly. You agree you will not use, or allow others to use, Llama 2 to:
7 | 
8 | 1. Violate the law or others’ rights, including to:
9 |     1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
10 |         1. Violence or terrorism
11 |         2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
12 |         3. Human trafficking, exploitation, and sexual violence
13 |         4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
14 |         5. Sexual solicitation
15 |         6. Any other criminal activity
16 |     2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
17 |     3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
18 |     4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
19 |     5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
20 |     6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 2 Materials
21 |     7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
22 | 
23 | 
24 | 
25 | 2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 2 related to the following:
26 |     1. 
Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State 27 | 2. Guns and illegal weapons (including weapon development) 28 | 3. Illegal drugs and regulated/controlled substances 29 | 4. Operation of critical infrastructure, transportation technologies, or heavy machinery 30 | 5. Self-harm or harm to others, including suicide, cutting, and eating disorders 31 | 6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual 32 | 33 | 34 | 35 | 3. Intentionally deceive or mislead others, including use of Llama 2 related to the following: 36 | 1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation 37 | 2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content 38 | 3. Generating, promoting, or further distributing spam 39 | 4. Impersonating another individual without consent, authorization, or legal right 40 | 5. Representing that the use of Llama 2 or outputs are human-generated 41 | 6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement 42 | 4. Fail to appropriately disclose to end users any known dangers of your AI system 43 | 44 | Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means: 45 | 46 | * Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama) 47 | * Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback) 48 | * Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info) 49 | * Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com) 50 | 51 | -------------------------------------------------------------------------------- /assets/q_eg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinxinS97/NLPBench/82d6c109db9121ff60627873b8cc5ead922c8cba/assets/q_eg.png -------------------------------------------------------------------------------- /configs/few-shot.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'few-shot' 4 | prompt_r: null 5 | sys: false 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/few-shot_cot.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'few-shot' 4 | prompt_r: 'cot' 5 | sys: false 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/few-shot_cot_sc.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'few-shot' 4 | prompt_r: 'cot' 5 | sys: false 6 | self_consistency: true 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- 
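As a quick illustration of how these config files are consumed, here is a hypothetical sketch (the real entry point is `run.py`, whose internals may differ) of loading one of the YAML files with sacred and overriding a value from the command line:
```python
# Hypothetical sketch: sacred reads the YAML keys as config defaults and
# injects them by name; `python demo.py with max_tokens=512` overrides them.
from sacred import Experiment

ex = Experiment('nlpbench-demo')
ex.add_config('configs/few-shot_cot_sc.yaml')  # loads the YAML keys as defaults

@ex.automain
def main(shot_type, prompt_r, self_consistency, max_tokens):
    # sacred injects each config key into the captured function by name
    print(shot_type, prompt_r, self_consistency, max_tokens)
```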
/configs/few-shot_sc.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'few-shot' 4 | prompt_r: null 5 | sys: false 6 | self_consistency: true 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/few-shot_sys.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'few-shot' 4 | prompt_r: null 5 | sys: true 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/few-shot_sys_cot.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'few-shot' 4 | prompt_r: 'cot' 5 | sys: true 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/few-shot_tot.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'few-shot' 4 | prompt_r: 'tot' 5 | sys: false 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/zero-shot.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'zero-shot' 4 | prompt_r: null 5 | sys: false 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/zero-shot_cot.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'zero-shot' 4 | prompt_r: 'cot' 5 | sys: false 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/zero-shot_cot_sc.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'zero-shot' 4 | prompt_r: 'cot' 5 | sys: false 6 | self_consistency: true 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/zero-shot_sc.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'zero-shot' 4 | prompt_r: null 5 | sys: false 6 | self_consistency: true 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/zero-shot_sys.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'zero-shot' 4 | prompt_r: null 5 | sys: true 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/zero-shot_sys_cot.yaml: -------------------------------------------------------------------------------- 1 | max_tokens: 945 2 | seed: 41 3 | shot_type: 'zero-shot' 4 | prompt_r: 'cot' 5 | sys: true 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /configs/zero-shot_tot.yaml: -------------------------------------------------------------------------------- 1 | 
max_tokens: 945 2 | seed: 41 3 | shot_type: 'zero-shot' 4 | prompt_r: 'tot' 5 | sys: false 6 | self_consistency: false 7 | self_reflection: false 8 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: NLPBench 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - asttokens=2.2.1=pyhd8ed1ab_0 9 | - backcall=0.2.0=pyh9f0ad1d_0 10 | - backports=1.0=pyhd8ed1ab_3 11 | - backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0 12 | - bzip2=1.0.8=h7b6447c_0 13 | - ca-certificates=2023.7.22=hbcca054_0 14 | - comm=0.1.4=pyhd8ed1ab_0 15 | - debugpy=1.6.7=py310h6a678d5_0 16 | - decorator=5.1.1=pyhd8ed1ab_0 17 | - entrypoints=0.4=pyhd8ed1ab_0 18 | - executing=1.2.0=pyhd8ed1ab_0 19 | - ipykernel=6.25.1=pyh71e2992_0 20 | - ipython=8.14.0=pyh41d4057_0 21 | - jedi=0.19.0=pyhd8ed1ab_0 22 | - jupyter_core=5.3.1=py310hff52083_0 23 | - ld_impl_linux-64=2.38=h1181459_1 24 | - libffi=3.4.4=h6a678d5_0 25 | - libgcc-ng=11.2.0=h1234567_1 26 | - libgomp=11.2.0=h1234567_1 27 | - libsodium=1.0.18=h36c2ea0_1 28 | - libstdcxx-ng=11.2.0=h1234567_1 29 | - libuuid=1.41.5=h5eee18b_0 30 | - matplotlib-inline=0.1.6=pyhd8ed1ab_0 31 | - ncurses=6.4=h6a678d5_0 32 | - nest-asyncio=1.5.6=pyhd8ed1ab_0 33 | - openssl=3.0.10=h7f8727e_0 34 | - packaging=23.1=pyhd8ed1ab_0 35 | - parso=0.8.3=pyhd8ed1ab_0 36 | - pexpect=4.8.0=pyh1a96a4e_2 37 | - pickleshare=0.7.5=py_1003 38 | - pip=23.2.1=py310h06a4308_0 39 | - platformdirs=3.10.0=pyhd8ed1ab_0 40 | - prompt-toolkit=3.0.39=pyha770c72_0 41 | - prompt_toolkit=3.0.39=hd8ed1ab_0 42 | - psutil=5.9.0=py310h5eee18b_0 43 | - ptyprocess=0.7.0=pyhd3deb0d_0 44 | - pure_eval=0.2.2=pyhd8ed1ab_0 45 | - pygments=2.16.1=pyhd8ed1ab_0 46 | - python=3.10.12=h955ad1f_0 47 | - python-dateutil=2.8.2=pyhd8ed1ab_0 48 | - python_abi=3.10=2_cp310 49 | - pyzmq=25.1.0=py310h6a678d5_0 50 | - readline=8.2=h5eee18b_0 51 | - setuptools=68.0.0=py310h06a4308_0 52 | - six=1.16.0=pyh6c4a22f_0 53 | - sqlite=3.41.2=h5eee18b_0 54 | - stack_data=0.6.2=pyhd8ed1ab_0 55 | - tk=8.6.12=h1ccaba5_0 56 | - traitlets=5.9.0=pyhd8ed1ab_0 57 | - typing-extensions=4.7.1=hd8ed1ab_0 58 | - typing_extensions=4.7.1=pyha770c72_0 59 | - wcwidth=0.2.6=pyhd8ed1ab_0 60 | - wheel=0.38.4=py310h06a4308_0 61 | - xz=5.4.2=h5eee18b_0 62 | - zeromq=4.3.4=h9c3ff4c_1 63 | - zlib=1.2.13=h5eee18b_0 64 | - pip: 65 | - aiohttp==3.8.5 66 | - aiosignal==1.3.1 67 | - anyio==3.7.1 68 | - argon2-cffi==23.1.0 69 | - argon2-cffi-bindings==21.2.0 70 | - arrow==1.2.3 71 | - async-lru==2.0.4 72 | - async-timeout==4.0.3 73 | - attrs==23.1.0 74 | - babel==2.12.1 75 | - beautifulsoup4==4.12.2 76 | - bleach==6.0.0 77 | - cachetools==5.3.1 78 | - certifi==2023.7.22 79 | - cffi==1.15.1 80 | - chardet==5.2.0 81 | - charset-normalizer==3.2.0 82 | - click==8.1.7 83 | - cmake==3.27.2 84 | - colorama==0.4.6 85 | - contourpy==1.1.0 86 | - cryptography==41.0.3 87 | - cycler==0.11.0 88 | - dataproperty==1.0.1 89 | - defusedxml==0.7.1 90 | - diskcache==5.6.1 91 | - docopt==0.6.2 92 | - exceptiongroup==1.1.3 93 | - fastapi==0.103.2 94 | - fastjsonschema==2.18.0 95 | - filelock==3.12.2 96 | - flaml==2.0.0 97 | - fonttools==4.42.1 98 | - fqdn==1.5.1 99 | - frozenlist==1.4.0 100 | - fsspec==2023.6.0 101 | - gitdb==4.0.10 102 | - gitpython==3.1.34 103 | - google-ai-generativelanguage==0.2.0 104 | - google-api-core==2.11.1 105 | - google-auth==2.22.0 106 | - google-generativeai==0.1.0 107 
| - googleapis-common-protos==1.60.0 108 | - gptcache==0.1.40 109 | - grpcio==1.57.0 110 | - grpcio-status==1.57.0 111 | - guidance==0.0.64 112 | - h11==0.14.0 113 | - httptools==0.6.0 114 | - huggingface-hub==0.16.4 115 | - idna==3.4 116 | - install==1.3.5 117 | - ipython-genutils==0.2.0 118 | - ipywidgets==8.1.0 119 | - isoduration==20.11.0 120 | - jinja2==3.1.2 121 | - joblib==1.3.2 122 | - json5==0.9.14 123 | - jsonpickle==3.0.2 124 | - jsonpointer==2.4 125 | - jsonschema==4.19.0 126 | - jsonschema-specifications==2023.7.1 127 | - jupyter==1.0.0 128 | - jupyter-client==8.3.1 129 | - jupyter-console==6.6.3 130 | - jupyter-events==0.7.0 131 | - jupyter-lsp==2.2.0 132 | - jupyter-server==2.7.3 133 | - jupyter-server-terminals==0.4.4 134 | - jupyterlab==4.0.5 135 | - jupyterlab-pygments==0.2.2 136 | - jupyterlab-server==2.24.0 137 | - jupyterlab-widgets==3.0.8 138 | - kiwisolver==1.4.5 139 | - lit==16.0.6 140 | - markupsafe==2.1.3 141 | - matplotlib==3.7.2 142 | - mbstrdecoder==1.1.3 143 | - mistune==3.0.1 144 | - mpmath==1.3.0 145 | - msal==1.23.0 146 | - msgpack==1.0.7 147 | - multidict==6.0.4 148 | - munch==2.5.0 149 | - nbclient==0.8.0 150 | - nbconvert==7.8.0 151 | - nbformat==5.9.2 152 | - networkx==3.1 153 | - ninja==1.11.1 154 | - nltk==3.8.1 155 | - notebook==7.0.3 156 | - notebook-shim==0.2.3 157 | - numpy==1.26.0b1 158 | - nvidia-cublas-cu11==11.10.3.66 159 | - nvidia-cuda-cupti-cu11==11.7.101 160 | - nvidia-cuda-nvrtc-cu11==11.7.99 161 | - nvidia-cuda-runtime-cu11==11.7.99 162 | - nvidia-cudnn-cu11==8.5.0.96 163 | - nvidia-cufft-cu11==10.9.0.58 164 | - nvidia-curand-cu11==10.2.10.91 165 | - nvidia-cusolver-cu11==11.4.0.1 166 | - nvidia-cusparse-cu11==11.7.4.91 167 | - nvidia-nccl-cu11==2.14.3 168 | - nvidia-nvtx-cu11==11.7.91 169 | - openai==0.27.8 170 | - overrides==7.4.0 171 | - pandas==2.1.0 172 | - pandocfilters==1.5.0 173 | - pathvalidate==3.1.0 174 | - pillow==10.0.0 175 | - prometheus-client==0.17.1 176 | - proto-plus==1.22.3 177 | - protobuf==4.24.2 178 | - py-cpuinfo==9.0.0 179 | - pyarrow==13.0.0 180 | - pyasn1==0.5.0 181 | - pyasn1-modules==0.3.0 182 | - pyautogen==0.1.6 183 | - pycocoevalcap==1.2 184 | - pycocotools==2.0.7 185 | - pycparser==2.21 186 | - pydantic==1.10.13 187 | - pygtrie==2.5.0 188 | - pyjwt==2.8.0 189 | - pyparsing==3.0.9 190 | - pytablewriter==1.0.0 191 | - python-dotenv==1.0.0 192 | - python-json-logger==2.0.7 193 | - pytz==2023.3 194 | - pyyaml==6.0.1 195 | - qtconsole==5.4.4 196 | - qtpy==2.4.0 197 | - ray==2.7.0 198 | - referencing==0.30.2 199 | - regex==2023.8.8 200 | - requests==2.31.0 201 | - rfc3339-validator==0.1.4 202 | - rfc3986-validator==0.1.1 203 | - rpds-py==0.10.2 204 | - rsa==4.9 205 | - sacred==0.8.4 206 | - safetensors==0.3.3 207 | - seaborn==0.12.2 208 | - send2trash==1.8.2 209 | - sentencepiece==0.1.99 210 | - smmap==5.0.0 211 | - sniffio==1.3.0 212 | - soupsieve==2.5 213 | - starlette==0.27.0 214 | - sympy==1.12 215 | - tabledata==1.3.1 216 | - tcolorpy==0.1.3 217 | - termcolor==2.3.0 218 | - terminado==0.17.1 219 | - tiktoken==0.4.0 220 | - tinycss2==1.2.1 221 | - tokenizers==0.13.3 222 | - tomli==2.0.1 223 | - torch==2.0.1 224 | - torchaudio==2.0.2 225 | - torchvision==0.15.2 226 | - tornado==6.3.3 227 | - tqdm==4.66.1 228 | - transformers==4.33.3 229 | - tree-of-thoughts==0.3.6 230 | - triton==2.0.0 231 | - typepy==1.3.1 232 | - tzdata==2023.3 233 | - uri-template==1.3.0 234 | - urllib3==1.26.16 235 | - uvicorn==0.23.2 236 | - uvloop==0.17.0 237 | - vllm==0.2.0 238 | - watchfiles==0.20.0 239 | - webcolors==1.13 240 | - 
webencodings==0.5.1 241 | - websocket-client==1.6.2 242 | - websockets==11.0.3 243 | - widgetsnbextension==4.0.8 244 | - wrapt==1.15.0 245 | - xformers==0.0.22 246 | - yarl==1.9.2 247 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from os import listdir 4 | from os.path import isfile, join 5 | 6 | from utils import ApiManager 7 | 8 | cates = [ 9 | "Language Modeling and Syntax and Parsing", 10 | "Pragmatics and Discourse and Dialogue and Applications", 11 | "Semantics and Logic", 12 | "Information Retrieval and Topic Modeling", 13 | "Artificial Intelligence", 14 | "Other Topics" 15 | ] 16 | 17 | 18 | class Evaluation: 19 | def __init__(self, seed, open_sourced): 20 | self.open_sourced = open_sourced 21 | self.seed = seed 22 | osd = '_os' if open_sourced else '' 23 | self.osd = osd 24 | 25 | self.no_ctx_files = [f for f in listdir(f'res/{seed}{osd}/no_ctx/') if 26 | isfile(join(f'res/{seed}{osd}/no_ctx/', f))] 27 | self.ctx_files = [f for f in listdir(f'res/{seed}{osd}/ctx/') if isfile(join(f'res/{seed}{osd}/ctx/', f))] 28 | self.ctx_data = json.load(open('data/w_ctx.json', 'r')) 29 | self.no_ctx_data = json.load(open('data/wo_ctx.json', 'r')) 30 | 31 | def _init_result(self): 32 | return { 33 | 'gpt-3.5-turbo': {}, 34 | 'gpt-4': {}, 35 | 'text-bison-001': {} 36 | } if self.open_sourced is False else { 37 | 'Llama-2-70b-chat-hf': {}, 38 | 'Llama-2-13b-chat-hf': {}, 39 | } 40 | 41 | def _file_format(self, f, ctx_path='no_ctx'): 42 | args = f.split('_') 43 | llm_name = f.split('_')[1].replace('.json', '') 44 | shot_type = f.split('_')[0] 45 | prompt_r = '' 46 | if len(args) == 3: 47 | prompt_r = f.split('_')[2].split('.')[0] 48 | if len(args) == 4: 49 | prompt_r = f.split('_')[2] + '_' + f.split('_')[3].split('.')[0] 50 | llm_res = json.load(open(f'res/{self.seed}{self.osd}/{ctx_path}/{f}', 'r')) 51 | if prompt_r != '': 52 | prompt_r = '_' + prompt_r 53 | 54 | return llm_name, shot_type, prompt_r, llm_res 55 | 56 | def _path_format(self, suffix): 57 | no_ctx_save_path = f'res/{self.seed}{self.osd}/res_no_ctx_{suffix}.json' 58 | ctx_save_path = f'res/{self.seed}{self.osd}/res_ctx_{suffix}.json' 59 | 60 | return no_ctx_save_path, ctx_save_path 61 | 62 | # We use accuracy to evaluate multiple choice questions 63 | def evaluate_mc(self): 64 | ### w/o context 65 | res = self._init_result() 66 | no_ctx_save_path, ctx_save_path = self._path_format('mc') 67 | 68 | for f in self.no_ctx_files: 69 | llm_name, shot_type, prompt_r, llm_res = self._file_format(f) 70 | corr = {c: 0 for c in cates} 71 | cnt = {c: 0 for c in cates} 72 | for i, d in enumerate(llm_res): 73 | if d['type'] == 0: 74 | cate = self.no_ctx_data[i]['category'] 75 | cate_cnt = cnt.get(cate, 0) 76 | cnt[cate] = cate_cnt + 1 77 | if d['retrived_answer'] is not None: 78 | ans = set(d['answer']) 79 | llm_ans = d['retrived_answer'].replace(' ', '').replace("'", '').replace('"', '').split(',') 80 | try: 81 | llm_ans = [int(a) for a in llm_ans] 82 | except Exception: 83 | cate_ans = corr.get(cate, 0) 84 | corr[cate] = cate_ans 85 | llm_ans = set(llm_ans) 86 | if ans == llm_ans: 87 | cate_ans = corr.get(cate, 0) 88 | corr[cate] = cate_ans + 1 89 | else: 90 | cate_ans = corr.get(cate, 0) 91 | corr[cate] = cate_ans 92 | 93 | acc = {k: v / cnt[k] if cnt[k] != 0 else 1 for k, v in corr.items()} 94 | res['count'] = cnt 95 | res['total_count'] = sum(cnt.values()) 96 | 
overall_acc = np.sum(np.array(list(acc.values())) * np.array(list(cnt.values()))) / sum(cnt.values()) 97 | res[llm_name][f'{shot_type}{prompt_r}'] = { 98 | 'acc': acc, 99 | 'overall_acc': overall_acc 100 | } 101 | json.dump(res, open(no_ctx_save_path, 'w'), indent=4) 102 | 103 | ### w/ context 104 | res = self._init_result() 105 | for f in self.ctx_files: 106 | llm_name, shot_type, prompt_r, llm_res = self._file_format(f, 'ctx') 107 | corr = {c: 0 for c in cates} 108 | cnt = {c: 0 for c in cates} 109 | for i, ds in enumerate(llm_res): 110 | for j, d in enumerate(ds['questions']): 111 | if ds['type'][j] == 0: 112 | cate = self.ctx_data[i]['category'][j] 113 | cate_cnt = cnt.get(cate, 0) 114 | cnt[cate] = cate_cnt + 1 115 | if ds['retrived_answer'][j] is not None: 116 | ans = set(ds['answers'][j]) 117 | llm_ans = ds['retrived_answer'][j].replace(' ', '').replace("'", '').replace('"', '').split( 118 | ',') 119 | try: 120 | llm_ans = [int(a) for a in llm_ans] 121 | except Exception: 122 | cate_ans = corr.get(cate, 0) 123 | corr[cate] = cate_ans 124 | continue 125 | llm_ans = set(llm_ans) 126 | if ans == llm_ans: 127 | cate_ans = corr.get(cate, 0) 128 | corr[cate] = cate_ans + 1 129 | else: 130 | cate_ans = corr.get(cate, 0) 131 | corr[cate] = cate_ans 132 | 133 | acc = {k: v / cnt[k] if cnt[k] != 0 else 1 for k, v in corr.items()} 134 | res['count'] = cnt 135 | res['total_count'] = sum(cnt.values()) 136 | overall_acc = np.sum(np.array(list(acc.values())) * np.array(list(cnt.values()))) / sum(cnt.values()) 137 | res[llm_name][f'{shot_type}{prompt_r}'] = { 138 | 'acc': acc, 139 | 'overall_acc': overall_acc 140 | } 141 | json.dump(res, open(ctx_save_path, 'w'), indent=4) 142 | 143 | def evaluate_sc(self): 144 | ### w/o context 145 | res = self._init_result() 146 | no_ctx_save_path, ctx_save_path = self._path_format('mc-sc') 147 | files = [ 148 | 'few-shot-sc_gpt-3.5-turbo.json', 149 | 'few-shot-sc_gpt-3.5-turbo_cot.json', 150 | 'few-shot-sc_gpt-4.json', 151 | 'few-shot-sc_gpt-4_cot.json', 152 | 'few-shot-sc_text-bison-001.json', 153 | 'few-shot-sc_text-bison-001_cot.json', 154 | 'zero-shot-sc_gpt-3.5-turbo.json', 155 | 'zero-shot-sc_gpt-3.5-turbo_cot.json', 156 | 'zero-shot-sc_gpt-4.json', 157 | 'zero-shot-sc_gpt-4_cot.json', 158 | 'zero-shot-sc_text-bison-001.json', 159 | 'zero-shot-sc_text-bison-001_cot.json', 160 | ] 161 | for f in files: 162 | llm_name, shot_type, prompt_r, llm_res = self._file_format(f) 163 | corr = {c: 0 for c in cates} 164 | cnt = {c: 0 for c in cates} 165 | for i, d in enumerate(llm_res): 166 | if d['type'] == 0: 167 | cate = self.no_ctx_data[i]['category'] 168 | cate_cnt = cnt.get(cate, 0) 169 | cnt[cate] = cate_cnt + 1 170 | if d['llm_answer'] is not None: 171 | ans = set(d['answer']) 172 | llm_ans = d['llm_answer'].replace(' ', '').replace("'", '').replace('"', '').split(',') 173 | try: 174 | llm_ans = [int(a) for a in llm_ans] 175 | except Exception: 176 | print('model:', f, '\nerr: ', llm_ans) 177 | cate_ans = corr.get(cate, 0) 178 | corr[cate] = cate_ans 179 | llm_ans = set(llm_ans) 180 | if ans == llm_ans: 181 | cate_ans = corr.get(cate, 0) 182 | corr[cate] = cate_ans + 1 183 | else: 184 | cate_ans = corr.get(cate, 0) 185 | corr[cate] = cate_ans 186 | 187 | acc = {k: v / cnt[k] if cnt[k] != 0 else 1 for k, v in corr.items()} 188 | res['count'] = cnt 189 | res['total_count'] = sum(cnt.values()) 190 | overall_acc = np.sum(np.array(list(acc.values())) * np.array(list(cnt.values()))) / sum(cnt.values()) 191 | res[llm_name][f'{shot_type}{prompt_r}'] = { 192 | 'acc': 
acc, 193 | 'overall_acc': overall_acc 194 | } 195 | json.dump(res, open(no_ctx_save_path, 'w'), indent=4) 196 | 197 | ### w/ context 198 | res = self._init_result() 199 | for f in files: 200 | llm_name, shot_type, prompt_r, llm_res = self._file_format(f, 'ctx') 201 | corr = {c: 0 for c in cates} 202 | cnt = {c: 0 for c in cates} 203 | for i, ds in enumerate(llm_res): 204 | for j, d in enumerate(ds['questions']): 205 | if ds['type'][j] == 0: 206 | cate = self.ctx_data[i]['category'][j] 207 | cate_cnt = cnt.get(cate, 0) 208 | cnt[cate] = cate_cnt + 1 209 | if ds['llm_answer'][j] is not None: 210 | ans = set(ds['answers'][j]) 211 | llm_ans = ds['llm_answer'][j].replace(' ', '').replace("'", '').replace('"', '').split( 212 | ',') 213 | try: 214 | llm_ans = [int(a) for a in llm_ans] 215 | except Exception: 216 | print('model:', f, '\nerr: ', llm_ans) 217 | cate_ans = corr.get(cate, 0) 218 | corr[cate] = cate_ans 219 | continue 220 | llm_ans = set(llm_ans) 221 | if ans == llm_ans: 222 | cate_ans = corr.get(cate, 0) 223 | corr[cate] = cate_ans + 1 224 | else: 225 | cate_ans = corr.get(cate, 0) 226 | corr[cate] = cate_ans 227 | 228 | acc = {k: v / cnt[k] if cnt[k] != 0 else 1 for k, v in corr.items()} 229 | res['count'] = cnt 230 | res['total_count'] = sum(cnt.values()) 231 | overall_acc = np.sum(np.array(list(acc.values())) * np.array(list(cnt.values()))) / sum(cnt.values()) 232 | res[llm_name][f'{shot_type}{prompt_r}'] = { 233 | 'acc': acc, 234 | 'overall_acc': overall_acc 235 | } 236 | json.dump(res, open(ctx_save_path, 'w'), indent=4) 237 | 238 | # We use ROUGE-L, CIDEr (for unique answer) to evaluate short answer questions with unique answer 239 | def evaluate_sa_unique(self): 240 | from pycocoevalcap.rouge.rouge import Rouge 241 | from pycocoevalcap.cider.cider import Cider 242 | 243 | res = self._init_result() 244 | no_ctx_save_path, ctx_save_path = self._path_format('sa_unique') 245 | 246 | for f in self.no_ctx_files: 247 | llm_name, shot_type, prompt_r, llm_res = self._file_format(f) 248 | 249 | llm_ans_dict = {} 250 | ans_dict = {} 251 | overall_llm_ans = {} 252 | overall_ans = {} 253 | cnt = {} 254 | for i, d in enumerate(llm_res): 255 | if d['type'] == 1 and self.no_ctx_data[i].get('unique_ans', None) == 1: 256 | cate = self.no_ctx_data[i]['category'] 257 | cate_cnt = cnt.get(cate, 0) 258 | cnt[cate] = cate_cnt + 1 259 | 260 | ans = d['answer'] 261 | llm_ans = d['llm_answer'] 262 | 263 | if llm_ans == "" or llm_ans is None: 264 | llm_ans = "No answer provided." 
265 | 
266 |                     tmp1 = llm_ans_dict.get(cate, {})
267 |                     tmp1[i] = [llm_ans]
268 |                     llm_ans_dict[cate] = tmp1
269 | 
270 |                     tmp2 = ans_dict.get(cate, {})
271 |                     tmp2[i] = [ans]
272 |                     ans_dict[cate] = tmp2
273 | 
274 |                     overall_llm_ans[i] = [llm_ans]
275 |                     overall_ans[i] = [ans]
276 | 
277 |             scores = {}
278 |             for k in cnt.keys():
279 |                 rouge = Rouge().compute_score(llm_ans_dict[k], ans_dict[k])[0]
280 |                 cider = Cider().compute_score(llm_ans_dict[k], ans_dict[k])[0]
281 |                 scores[k] = {
282 |                     'ROUGE-L': rouge,
283 |                     'CIDEr': cider,
284 |                 }
285 |             res['count'] = cnt
286 |             res['total_count'] = sum(cnt.values())
287 |             res[llm_name][f'{shot_type}{prompt_r}'] = {
288 |                 'scores': scores,
289 |                 'avg_score': {
290 |                     'ROUGE-L': Rouge().compute_score(overall_llm_ans, overall_ans)[0],
291 |                     'CIDEr': Cider().compute_score(overall_llm_ans, overall_ans)[0],
292 |                 }
293 |             }
294 |         json.dump(res, open(no_ctx_save_path, 'w'), indent=4)
295 | 
296 |         ### w/ context
297 |         res = self._init_result()
298 |         for f in self.ctx_files:
299 |             llm_name, shot_type, prompt_r, llm_res = self._file_format(f, 'ctx')
300 |             llm_ans_dict = {}
301 |             ans_dict = {}
302 |             overall_llm_ans = {}
303 |             overall_ans = {}
304 |             cnt = {}
305 |             for i, ds in enumerate(llm_res):
306 |                 for j, d in enumerate(ds['questions']):
307 |                     if ds['type'][j] == 1 and self.ctx_data[i].get('unique_ans', [0 for _ in range(j + 1)])[j] == 1:
308 |                         cate = self.ctx_data[i]['category'][j]
309 |                         cate_cnt = cnt.get(cate, 0)
310 |                         cnt[cate] = cate_cnt + 1
311 | 
312 |                         ans = ds['answers'][j]
313 |                         llm_ans = ds['llm_answer'][j]
314 | 
315 |                         if llm_ans == "" or llm_ans is None:
316 |                             llm_ans = "No answer provided."
317 | 
318 |                         tmp1 = llm_ans_dict.get(cate, {})
319 |                         tmp1[f'{i}_{j}'] = [llm_ans]
320 |                         llm_ans_dict[cate] = tmp1
321 | 
322 |                         tmp2 = ans_dict.get(cate, {})
323 |                         tmp2[f'{i}_{j}'] = [ans]
324 |                         ans_dict[cate] = tmp2
325 | 
326 |                         overall_llm_ans[f'{i}_{j}'] = [llm_ans]
327 |                         overall_ans[f'{i}_{j}'] = [ans]
328 | 
329 |             scores = {}
330 |             for k in cnt.keys():
331 |                 rouge = Rouge().compute_score(llm_ans_dict[k], ans_dict[k])[0]
332 |                 cider = Cider().compute_score(llm_ans_dict[k], ans_dict[k])[0]
333 |                 scores[k] = {
334 |                     'ROUGE-L': rouge,
335 |                     'CIDEr': cider,
336 |                 }
337 | 
338 |             res['count'] = cnt
339 |             res['total_count'] = sum(cnt.values())
340 |             res[llm_name][f'{shot_type}{prompt_r}'] = {
341 |                 'scores': scores,
342 |                 'avg_score': {
343 |                     'ROUGE-L': Rouge().compute_score(overall_llm_ans, overall_ans)[0],
344 |                     'CIDEr': Cider().compute_score(overall_llm_ans, overall_ans)[0],
345 |                 }
346 |             }
347 |         json.dump(res, open(ctx_save_path, 'w'), indent=4)
348 | 
349 |     # We use GPT-4 to evaluate short answer questions
350 |     def evaluate_sa(self):
351 | 
352 |         res = self._init_result()
353 |         no_ctx_save_path, ctx_save_path = self._path_format('sa')
354 |         api = ApiManager(
355 |             model_name='gpt-4',
356 |             seed=41,
357 |             default_max_tokens=20,
358 |             temperature=0,
359 |         )
360 |         prompt_SA_EVAL = '''You are an NLP professional assistant; your job is to evaluate whether the student's answer to the given short answer question is correct.
361 | A teacher answer is also provided, and your evaluation should be based on the teacher answer.
362 | If the student is correct, return 1; otherwise return 0.
363 | Your response should ONLY contain 0 or 1.
364 | 365 | Short answer question: 366 | "{q}" 367 | 368 | Teacher answer: 369 | "{eg_ans}" 370 | 371 | Student answer (evaluate this answer): 372 | "{llm_ans}" 373 | 374 | Your response: 375 | ''' 376 | 377 | for f in self.no_ctx_files: 378 | llm_name, shot_type, prompt_r, llm_res = self._file_format(f) 379 | corr = {c: 0 for c in cates} 380 | cnt = {c: 0 for c in cates} 381 | for i, d in enumerate(llm_res): 382 | if d['type'] == 1: 383 | print(f'Processing Q.{i}') 384 | cate = self.no_ctx_data[i]['category'] 385 | cate_cnt = cnt.get(cate, 0) 386 | cnt[cate] = cate_cnt + 1 387 | 388 | q = d['question'] 389 | ans = d['answer'] 390 | llm_ans = d['llm_answer'] 391 | if llm_ans == "" or llm_ans is None: 392 | llm_ans = "No answer provided." 393 | 394 | score = api( 395 | [{'role': 'user', 'content': prompt_SA_EVAL.format(q=q, eg_ans=ans, llm_ans=llm_ans)}] 396 | ) 397 | try: 398 | score = int(score) 399 | except ValueError: 400 | score = 0 401 | if score == 1: 402 | cate_ans = corr.get(cate, 0) 403 | corr[cate] = cate_ans + 1 404 | else: 405 | cate_ans = corr.get(cate, 0) 406 | corr[cate] = cate_ans 407 | d['retrived_answer'] = score 408 | 409 | acc = {k: v / cnt[k] if cnt[k] != 0 else 1 for k, v in corr.items()} 410 | res['count'] = cnt 411 | res['total_count'] = sum(cnt.values()) 412 | overall_acc = np.sum(np.array(list(acc.values())) * np.array(list(cnt.values()))) / sum(cnt.values()) 413 | res[llm_name][f'{shot_type}{prompt_r}'] = { 414 | 'acc': acc, 415 | 'overall_acc': overall_acc 416 | } 417 | json.dump(llm_res, open(f'res/{self.seed}{self.osd}/no_ctx/{f}', 'w'), indent=4) 418 | json.dump(res, open(no_ctx_save_path, 'w'), indent=4) 419 | 420 | ### w/ context 421 | res = self._init_result() 422 | for f in self.ctx_files: 423 | llm_name, shot_type, prompt_r, llm_res = self._file_format(f, 'ctx') 424 | corr = {c: 0 for c in cates} 425 | cnt = {c: 0 for c in cates} 426 | for i, ds in enumerate(llm_res): 427 | tmp = [] 428 | for j, d in enumerate(ds['questions']): 429 | if ds['type'][j] == 1: 430 | print(f'Processing Q.{i}_{j}') 431 | cate = self.ctx_data[i]['category'][j] 432 | cate_cnt = cnt.get(cate, 0) 433 | cnt[cate] = cate_cnt + 1 434 | 435 | q = ds['questions'][j] 436 | ans = ds['answers'][j] 437 | llm_ans = ds['llm_answer'][j] 438 | if llm_ans == "" or llm_ans is None: 439 | llm_ans = "No answer provided." 
440 | 
441 |                         score = api(
442 |                             [{'role': 'user', 'content': prompt_SA_EVAL.format(q=q, eg_ans=ans, llm_ans=llm_ans)}]
443 |                         )
444 |                         try:
445 |                             score = int(score)
446 |                         except ValueError:
447 |                             print('err: ', score)
448 |                             score = 0
449 |                         if score == 1:
450 |                             cate_ans = corr.get(cate, 0)
451 |                             corr[cate] = cate_ans + 1
452 |                         else:
453 |                             cate_ans = corr.get(cate, 0)
454 |                             corr[cate] = cate_ans
455 |                         tmp.append(score)
456 |                 ds['retrived_answer'] = tmp
457 | 
458 |             acc = {k: v / cnt[k] if cnt[k] != 0 else 1 for k, v in corr.items()}
459 |             res['count'] = cnt
460 |             res['total_count'] = sum(cnt.values())
461 |             overall_acc = np.sum(np.array(list(acc.values())) * np.array(list(cnt.values()))) / sum(cnt.values())
462 |             res[llm_name][f'{shot_type}{prompt_r}'] = {
463 |                 'acc': acc,
464 |                 'overall_acc': overall_acc
465 |             }
466 |             json.dump(llm_res, open(f'res/{self.seed}{self.osd}/ctx/{f}', 'w'), indent=4)
467 |         json.dump(res, open(ctx_save_path, 'w'), indent=4)
468 | 
469 | 
470 | if __name__ == '__main__':
471 |     eval_mg = Evaluation(41, False)
472 |     eval_mg_oc = Evaluation(41, True)
473 | 
474 |     eval_mg.evaluate_mc()
475 |     eval_mg.evaluate_sa()
476 |     eval_mg.evaluate_sc()
477 |     eval_mg.evaluate_sa_unique()
478 | 
479 |     eval_mg_oc.evaluate_mc()
480 |     eval_mg_oc.evaluate_sa()
481 |     eval_mg_oc.evaluate_sc()
482 |     eval_mg_oc.evaluate_sa_unique()
483 | 
--------------------------------------------------------------------------------
/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | from .sys_prompt import SYS_PROMPT_MAPPING
2 | from .qa_prompt import PROMPT_MAPPING, SHORTANS_ZS_COT_ST1, CTX
3 | from .tot_prompt import SOLUTION_PROMPT, STATE_PROMPT, VOTE_PROMPT
--------------------------------------------------------------------------------
/prompts/qa_prompt.py:
--------------------------------------------------------------------------------
1 | MULTICHOICE_STD_ZS = '''
2 | Answer the final multiple choice question.
3 | Your output must be only numbers split by commas (e.g., 0,1,...) with no descriptions.
4 | 
5 | Example Input:
6 | ChatGPT is created by which of the following companies?
7 | 0: Google
8 | 1: Meta
9 | 2: Microsoft
10 | 3: OpenAI
11 | 4: Amazon
12 | 
13 | Example Output:
14 | 3
15 | 
16 | Example Input:
17 | This is the input question, choose the correct answer.
18 | 0: Correct answer
19 | 1: Option 2
20 | 2: Correct answer
21 | 3: Option 4
22 | 
23 | Example Output:
24 | 0,2
25 | 
26 | Example Input:
27 | GPT-4 is created by OpenAI.
28 | 0: True
29 | 1: False
30 | 
31 | Example Output:
32 | 0
33 | 
34 | Input (You need to answer this question):
35 | {input}
36 | 
37 | Output:
38 | '''
39 | 
40 | 
41 | MULTICHOICE_STD_ZS_COT = '''
42 | Answer the final multiple choice question. Your output must be only numbers split by commas (e.g., 0,1,...) with no descriptions.
43 | 
44 | Example Input:
45 | ChatGPT is created by which of the following companies?
46 | 0: Google
47 | 1: Meta
48 | 2: Microsoft
49 | 3: OpenAI
50 | 4: Amazon
51 | 
52 | Example Thought:
53 | ChatGPT is a large-scale transformer-based language model, created by OpenAI in 2022.
54 | 
55 | Example Output:
56 | 3
57 | 
58 | Example Input:
59 | This is the input question, choose the correct answer.
60 | 0: Correct answer
61 | 1: Option 2
62 | 2: Correct answer
63 | 3: Option 4
64 | 
65 | Example Thought:
66 | This is a multiple choice question; the "correct answer" appears at indices 0 and 2.
67 | 
68 | Example Output:
69 | 0,2
70 | 
71 | Example Input:
72 | GPT-4 is created by OpenAI.
73 | 0: True
74 | 1: False
75 | 
76 | Example Thought:
77 | GPT-4 is created by OpenAI in 2023.
78 | 
79 | Example Output:
80 | 0
81 | 
82 | Input (You need to answer this question):
83 | {input}
84 | 
85 | Output:
86 | '''
87 | 
88 | 
89 | MULTICHOICE_STD_FS = '''
90 | Answer the final multiple choice question. Your output must be only numbers split by commas (e.g., 0,1,...) with no descriptions.
91 | 
92 | Example Input:
93 | What is the main challenge(s) of NLP?
94 | 0: Handling Ambiguity of Sentences
95 | 1: Handling Tokenization
96 | 2: Handling POS-Tagging
97 | 3: All of the mentioned
98 | 
99 | Example Output:
100 | 0
101 | 
102 | Example Input:
103 | What is the field of Natural Language Processing (NLP)?
104 | 0: Computer Science
105 | 1: Artificial Intelligence
106 | 2: Linguistics
107 | 3: All of the mentioned
108 | 
109 | Example Output:
110 | 3
111 | 
112 | Example Input:
113 | Choose from the following areas where NLP can be useful.
114 | 0: Automatic Text Summarization
115 | 1: Automatic Question-Answering Systems
116 | 2: Information Retrieval
117 | 3: All of the mentioned
118 | 
119 | Example Output:
120 | 3
121 | 
122 | Input (You need to answer this question):
123 | {input}
124 | 
125 | Output:
126 | '''
127 | 
128 | 
129 | MULTICHOICE_STD_FS_COT = '''
130 | Answer the final multiple choice question. Your output must be only numbers split by commas (e.g., 0,1,...) with no descriptions.
131 | 
132 | Example Input:
133 | What is the main challenge(s) of NLP?
134 | 0: Handling Ambiguity of Sentences
135 | 1: Handling Tokenization
136 | 2: Handling POS-Tagging
137 | 3: All of the mentioned
138 | 
139 | Example Thought:
140 | Enormous ambiguity exists when processing natural language.
141 | 
142 | Example Output:
143 | 0
144 | 
145 | Example Input:
146 | What is Machine Translation?
147 | 0: Converts one human language to another
148 | 1: Converts human language to machine language
149 | 2: Converts any human language to English
150 | 3: Converts Machine language to human language
151 | 
152 | Example Thought:
153 | The best-known example of machine translation is Google Translate, which helps people translate one language into another.
154 | 
155 | Example Output:
156 | 0
157 | 
158 | Example Input:
159 | What is Coreference Resolution?
160 | 0: Anaphora Resolution
161 | 1: Given a sentence or larger chunk of text, determine which words (“mentions”) refer to the same objects (“entities”)
162 | 2: All of the mentioned
163 | 3: None of the mentioned
164 | 
165 | Example Thought:
166 | Anaphora resolution is a specific type of coreference resolution.
167 | 
168 | Example Output:
169 | 1
170 | 
171 | Input (You need to answer this question):
172 | {input}
173 | 
174 | Output:
175 | '''
176 | 
177 | 
178 | SHORTANS_ZS = '''
179 | Answer the following short answer question. Your answer should be no more than 150 words.
180 | 
181 | Input (You need to answer this question):
182 | {input}
183 | 
184 | Output:
185 | '''
186 | 
187 | 
188 | SHORTANS_ZS_COT_ST1 = '''
189 | Answer the following short answer question. Your answer should be no more than 150 words.
190 | 
191 | Input (You need to answer this question):
192 | {input}
193 | 
194 | Let's think step by step! Output your thoughts on the question first:
195 | '''
196 | 
197 | 
198 | SHORTANS_ZS_COT = '''
199 | Answer the following short answer question. Your answer should be no more than 150 words.
200 | 
201 | Input (You need to answer this question):
202 | {input}
203 | 
204 | Your thought:
205 | {thought}
206 | 
207 | Output:
208 | '''
209 | 
210 | 
211 | SHORTANS_FS = '''
212 | Answer the following short answer question. Your answer should be no more than 150 words.
213 | 
214 | Example Input:
215 | Order the following syntactic features in decreasing order by salience (according to the Lappin/Leass algorithm for anaphora resolution): direct object (accusative), indirect object, subject, recency.
216 | 
217 | Example Output:
218 | From strongest to weakest:
219 | 1. Recency
220 | 2. Subject
221 | 3. Direct object
222 | 4. Indirect object
223 | 
224 | Example Input:
225 | List any two real-life applications of Natural Language Processing.
226 | 
227 | Example Output:
228 | 1. Google Translate.
229 | 2. ChatGPT.
230 | 
231 | Example Input:
232 | What are stop words?
233 | 
234 | Example Output:
235 | Stop words are said to be useless data for a search engine. Words such as articles, prepositions, etc. are considered stop words.
236 | There are stop words such as was, were, is, am, the, a, an, how, why, and many more.
237 | 
238 | Input (You need to answer this question):
239 | {input}
240 | 
241 | Output:
242 | '''
243 | 
244 | 
245 | SHORTANS_FS_COT = '''
246 | Answer the following short answer question. Your answer should be no more than 150 words.
247 | 
248 | Example Input:
249 | Order the following syntactic features in decreasing order by salience (according to the Lappin/Leass algorithm for anaphora resolution): direct object (accusative), indirect object, subject, recency.
250 | 
251 | Example Thought:
252 | The following is the order of the given syntactic features in decreasing order by salience (according to the Lappin/Leass algorithm for anaphora resolution):
253 | Recency: The most salient antecedent is the noun phrase that occurs most recently in the discourse.
254 | Subject: The next most salient antecedent is the subject of the sentence.
255 | Direct object (accusative): The third most salient antecedent is the direct object of the sentence.
256 | Indirect object: The least salient antecedent is the indirect object of the sentence.
257 | 
258 | Example Output:
259 | From strongest to weakest:
260 | 1. Recency
261 | 2. Subject
262 | 3. Direct object
263 | 4. Indirect object
264 | 
265 | Example Input:
266 | List any two real-life applications of Natural Language Processing.
267 | 
268 | Example Thought:
269 | Natural Language Processing (NLP) has a wide range of applications across various industries due to its ability to understand and generate human language.
270 | There are many applications of NLP in machine translation (like Google Translate), chatbots (like ChatGPT), and many more.
271 | 
272 | Example Output:
273 | 1. Google Translate.
274 | 2. ChatGPT.
275 | 
276 | Example Input:
277 | What are stop words?
278 | 
279 | Example Thought:
280 | Stop words are said to be useless data for a search engine.
281 | There are stop words such as was, were, is, am, the, a, an, how, why, and many more.
282 | 
283 | Example Output:
284 | Words such as articles, prepositions, etc. are considered stop words.
285 | 
286 | Input (You need to answer this question):
287 | {input}
288 | '''
289 | 
290 | 
291 | MATH_ZS = '''
292 | Answer the following math question. Your answer should be a number, a list of numbers, or a LaTeX expression.
293 | 
294 | Example Input:
295 | 1 + 1
296 | 
297 | Example Output:
298 | 2
299 | 
300 | Example Input:
301 | $\\frac{{1}}{{2}} + \\frac{{1}}{{3}}$
302 | 
303 | Example Output:
304 | \\frac{{5}}{{6}}
305 | 
306 | Example Input:
307 | $f(x) = 4x^2 + 3y$
308 | Solve for $\\frac{{\\partial f(x)}}{{\\partial x}}$
309 | 
310 | Example Output:
311 | 8x
312 | 
313 | Input (You need to answer this question):
314 | {input}
315 | 
316 | Output:
317 | '''
318 | 
319 | 
320 | MATH_ZS_COT = '''
321 | Answer the following math question. Your answer should be a number, a list of numbers, or a LaTeX expression.
322 | 
323 | Example Input:
324 | 1 + 1
325 | 
326 | Example Thought:
327 | 1 + 1 = 2
328 | 
329 | Example Output:
330 | 2
331 | 
332 | Example Input:
333 | $\\frac{{1}}{{2}} + \\frac{{1}}{{3}}$
334 | 
335 | Example Thought:
336 | $\\frac{{1}}{{2}} + \\frac{{1}}{{3}} = \\frac{{5}}{{6}}$
337 | 
338 | Example Output:
339 | \\frac{{5}}{{6}}
340 | 
341 | Example Input:
342 | $f(x) = 4x^2 + 3y$
343 | Solve for $\\frac{{\\partial f(x)}}{{\\partial x}}$
344 | 
345 | Example Thought:
346 | $\\frac{{\\partial f(x)}}{{\\partial x}} = 2\\times 4x + 0$
347 | 
348 | Example Output:
349 | 8x
350 | 
351 | Input (You need to answer this question):
352 | {input}
353 | 
354 | Output:
355 | '''
356 | 
357 | 
358 | MATH_FS = '''
359 | Answer the following math question. Your answer should be a number, a list of numbers, or a LaTeX expression.
360 | 
361 | Example Input:
362 | $\\frac{{1}}{{2}} + \\frac{{1}}{{3}}$
363 | 
364 | Example Output:
365 | \\frac{{5}}{{6}}
366 | 
367 | Example Input:
368 | $f(x) = 4x^2 + 3y$
369 | Solve for $\\frac{{\\partial f(x)}}{{\\partial x}}$
370 | 
371 | Example Output:
372 | 8x
373 | 
374 | Example Input:
375 | Consider the following bilingual (Spanish-English) corpus.
376 | gato blanco
377 | white cat
378 | 
379 | el gato
380 | the cat
381 | 
382 | Considering only the following two alignment types:
383 | 1. || {{1,2}}
384 | 2. X {{2,1}}
385 | 
386 | Start with a uniform distribution for $t(white|gato)$, $t(white|blanco)$, $t(cat|gato)$, $t(cat|blanco)$, $t(cat|el)$, $t(the|el)$, and $t(the|gato)$.
387 | Show their values after two iterations of the EM algorithm.
388 | 
389 | Example Output:
390 | \\frac{{1}}{{6}},\\frac{{2}}{{3}},\\frac{{2}}{{3}},\\frac{{1}}{{3}},\\frac{{1}}{{3}},\\frac{{2}}{{3}},\\frac{{1}}{{6}}
391 | 
392 | Input (You need to answer this question):
393 | {input}
394 | 
395 | Output:
396 | '''
397 | 
398 | 
399 | MATH_FS_COT = '''
400 | Answer the following math question. Your answer should be a number, a list of numbers, or a LaTeX expression.
401 | 
402 | Example Input:
403 | $\\frac{{1}}{{2}} + \\frac{{1}}{{3}}$
404 | 
405 | Example Thought:
406 | $\\frac{{1}}{{2}} + \\frac{{1}}{{3}} = \\frac{{3}}{{6}} + \\frac{{2}}{{6}} = \\frac{{5}}{{6}}$
407 | 
408 | Example Output:
409 | \\frac{{5}}{{6}}
410 | 
411 | Example Input:
412 | $f(x) = 4x^2 + 3y$
413 | Solve for $\\frac{{\\partial f(x)}}{{\\partial x}}$
414 | 
415 | Example Thought:
416 | $\\frac{{\\partial f(x)}}{{\\partial x}} = 2\\times 4x + 0$
417 | 
418 | Example Output:
419 | 8x
420 | 
421 | Example Input:
422 | Consider the following bilingual (Spanish-English) corpus.
423 | gato blanco
424 | white cat
425 | 
426 | el gato
427 | the cat
428 | 
429 | Considering only the following two alignment types:
430 | 1. || {{1,2}}
431 | 2. X {{2,1}}
432 | 
433 | Start with a uniform distribution for $t(white|gato)$, $t(white|blanco)$, $t(cat|gato)$, $t(cat|blanco)$, $t(cat|el)$, $t(the|el)$, and $t(the|gato)$.
434 | Show their values after two iterations of the EM algorithm.
435 | 436 | Example Thought: 437 | Initialization: 438 | $t(white|gato)=\\frac{{1}}{{2}}$ 439 | $t(white|blanco)=\\frac{{1}}{{2}}$ 440 | $t(cat|gato)=\\frac{{1}}{{3}}$ 441 | $t(cat|blanco)=\\frac{{1}}{{3}}$ 442 | $t(cat|el)=\\frac{{1}}{{3}}$ 443 | $t(the|el)=\\frac{{1}}{{2}}$ 444 | $t(the|gato)=\\frac{{1}}{{2}}$ 445 | 446 | $\\textbf{{First Iteration:}}$ 447 | 448 | gato blanco 449 | white cat 450 | 451 | For alignment type X: 452 | $p(a,f|e)=\\frac{{1}}{{2}}\\times \\frac{{1}}{{3}}=\\frac{{1}}{{6}}$ 453 | $p(a|e,f)=\\frac{{1}}{{2}}$ 454 | 455 | For alignment type ||: 456 | $p(a, f|e) = \\frac{{1}}{{2}}\\times\\frac{{1}}{{3}} = \\frac{{1}}{{6}}$ 457 | $p(a|e, f) = \\frac{{1}}{{2}}$ 458 | 459 | el gato 460 | the cat 461 | 462 | For alignment type X: 463 | $p(a,f|e)=\\frac{{1}}{{2}}\\times \\frac{{1}}{{3}}=\\frac{{1}}{{6}}$ 464 | $p(a|e,f)=\\frac{{1}}{{2}}$ 465 | 466 | For alignment type ||: 467 | $p(a, f|e) = \\frac{{1}}{{2}}\\times\\frac{{1}}{{3}} = \\frac{{1}}{{6}}$ 468 | $p(a|e, f) = \\frac{{1}}{{2}}$ 469 | 470 | Fractional Counts: 471 | $t(white|gato)=\\frac{{1}}{{2}}$ 472 | $t(white|blanco)=\\frac{{1}}{{2}}$ 473 | $t(cat|gato)=\\frac{{1}}{{2}}+\\frac{{1}}{{2}}=1$ 474 | $t(cat|blanco)=\\frac{{1}}{{2}}$ 475 | $t(cat|el)=\\frac{{1}}{{2}}$ 476 | $t(the|el)=\\frac{{1}}{{2}}$ 477 | $t(the|gato)=\\frac{{1}}{{2}}$ 478 | 479 | Normalize to get updated parameters: 480 | $t(white|gato)=\\frac{{1}}{{4}}$ 481 | $t(white|blanco)=\\frac{{1}}{{2}}$ 482 | $t(cat|gato)=\\frac{{1}}{{2}}$ 483 | $t(cat|blanco)=\\frac{{1}}{{2}}$ 484 | $t(cat|el)=\\frac{{1}}{{2}}$ 485 | $t(the|el)=\\frac{{1}}{{2}}$ 486 | $t(the|gato)=\\frac{{1}}{{4}}$ 487 | 488 | $\\textbf{{Second Iteration:}}$ 489 | gato blanco\nwhite cat 490 | 491 | For alignment type X: 492 | $p(a,f|e)=\\frac{{1}}{{2}}\\times \\frac{{1}}{{2}}=\\frac{{1}}{{4}}$ 493 | $p(a|e,f)=\\frac{{2}}{{3}}$ 494 | 495 | For alignment type ||: 496 | $p(a, f|e) = \\frac{{1}}{{4}}\\times\\frac{{1}}{{2}} = \\frac{{1}}{{8}}$ 497 | $p(a|e, f) = \\frac{{1}}{{3}}$ 498 | 499 | el gato 500 | the cat 501 | 502 | For alignment type X: 503 | $p(a,f|e)=\\frac{{1}}{{4}}\\times \\frac{{1}}{{2}}=\\frac{{1}}{{8}}$ 504 | $p(a|e,f)=\\frac{{1}}{{3}}$ 505 | 506 | For alignment type ||: 507 | $p(a, f|e) = \\frac{{1}}{{2}}\\times\\frac{{1}}{{2}} = \\frac{{1}}{{4}}$ 508 | $p(a|e, f) = \\frac{{2}}{{3}}$ 509 | 510 | Fractional Counts: 511 | $t(white|gato)=\\frac{{1}}{{3}}$ 512 | $t(white|blanco)=\\frac{{2}}{{3}}$ 513 | $t(cat|gato)=\\frac{{2}}{{3}}+\\frac{{2}}{{3}}=\\frac{{4}}{{3}}$ 514 | $t(cat|blanco)=\\frac{{1}}{{3}}$ 515 | $t(cat|el)=\\frac{{1}}{{3}}$ 516 | $t(the|el)=\\frac{{2}}{{3}}$ 517 | $t(the|gato)=\\frac{{1}}{{3}}$ 518 | 519 | Finally, you can normalize to get updated parameters. 
520 | 
521 | Example Output:
522 | \\frac{{1}}{{6}},\\frac{{2}}{{3}},\\frac{{2}}{{3}},\\frac{{1}}{{3}},\\frac{{1}}{{3}},\\frac{{2}}{{3}},\\frac{{1}}{{6}}
523 | 
524 | Input (You need to answer this question):
525 | {input}
526 | 
527 | Output:
528 | '''
529 | 
530 | 
531 | CTX = '''The context of the following questions is:
532 | {context}
533 | '''
534 | 
535 | PROMPT_MAPPING = {
536 |     0: {
537 |         'zero-shot': MULTICHOICE_STD_ZS,
538 |         'few-shot': MULTICHOICE_STD_FS,
539 |         'zero-shot_cot': MULTICHOICE_STD_ZS_COT,
540 |         'few-shot_cot': MULTICHOICE_STD_FS_COT,
541 |         'zero-shot_tot': MULTICHOICE_STD_ZS_COT,
542 |         'few-shot_tot': MULTICHOICE_STD_FS_COT
543 |     },
544 |     1: {
545 |         'zero-shot': SHORTANS_ZS,
546 |         'few-shot': SHORTANS_FS,
547 |         'zero-shot_cot': SHORTANS_ZS_COT,
548 |         'few-shot_cot': SHORTANS_FS_COT,
549 |         'zero-shot_tot': SHORTANS_ZS_COT,
550 |         'few-shot_tot': SHORTANS_FS_COT
551 |     },
552 |     2: {
553 |         'zero-shot': MATH_ZS,
554 |         'few-shot': MATH_FS,
555 |         'zero-shot_cot': MATH_ZS_COT,
556 |         'few-shot_cot': MATH_FS_COT,
557 |         'zero-shot_tot': MATH_ZS_COT,
558 |         'few-shot_tot': MATH_FS_COT
559 |     }
560 | }
561 | 
--------------------------------------------------------------------------------
/prompts/sys_prompt.py:
--------------------------------------------------------------------------------
1 | SYS_PROMPT_MULTICHOICE = """Answer the multiple choice questions in the field of Natural Language Processing (NLP).
2 | You should select one or more answers from the given options.
3 | """
4 | 
5 | SYS_PROMPT_SHORTANS = """Answer the short answer questions in the field of Natural Language Processing (NLP).
6 | You should provide a concise short answer for the question.
7 | All the math symbols in your answer must be converted to LaTeX format (e.g., \\pi, \\sqrt{{2}}).
8 | """
9 | 
10 | SYS_PROMPT_MATH = """Answer the mathematics questions in the field of Natural Language Processing (NLP).
11 | You should provide one or more numbers or LaTeX expressions.
12 | When the answer is a fraction, please use \\frac{{}}{{}} to express it (e.g., \\frac{{1}}{{2}}).
13 | When the answer is a vector or matrix, please use \\begin{{bmatrix}} \\end{{bmatrix}} to express it (e.g., \\begin{{bmatrix}} 1 & 2 \\\\ 3 & 4 \\end{{bmatrix}}).
14 | You should not turn \\pi, \\sqrt{{2}}, etc. into decimals; keep the original format.
15 | """
16 | 
17 | SYS_PROMPT_MAPPING = {
18 |     0: SYS_PROMPT_MULTICHOICE,
19 |     1: SYS_PROMPT_SHORTANS,
20 |     2: SYS_PROMPT_MATH
21 | }
--------------------------------------------------------------------------------
/prompts/tot_prompt.py:
--------------------------------------------------------------------------------
1 | SOLUTION_PROMPT = """You're Tree-of-Thoughts, a superintelligent AI model devoted to helping humans by any means necessary.
2 | Your purpose is to generate a series of solutions that comply with the user's instructions. You must generate solutions with the aim of determining the most reliable solution in the shortest amount of time, while taking rejected solutions into account and learning from them.
3 | Considering the reasoning provided:\n\n
4 | ###'{state_text}'\n\n###
5 | Devise the best possible solution for the task: {initial_prompt}. Here are evaluated solutions that were rejected:
6 | ###{rejected_solutions}###,
7 | complete the task {initial_prompt} without making the same mistakes you did with the evaluated rejected solutions.
8 | Be simple. Be direct.
Provide intuitive solutions as soon as you think of them."""
9 | 
10 | STATE_PROMPT = """Given a question:
11 | '{initial_prompt}', rate the following past solutions as a float between 0 and 1.\n
12 | If a past solution does not directly and concretely help in achieving the goal, give it a low score.
13 | Past solutions:\n\n
14 | '{state_text}'
15 | """
16 | 
17 | VOTE_PROMPT = """Given the following states of reasoning, vote for the best state using a scalar value 1-10:\n{states_text}\n\nVote on the probability of this state of reasoning achieving {initial_prompt}; be very pessimistic, and output NOTHING ELSE"""
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.5
2 | aiosignal==1.3.1
3 | anyio==3.7.1
4 | argon2-cffi==23.1.0
5 | argon2-cffi-bindings==21.2.0
6 | arrow==1.2.3
7 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work
8 | async-lru==2.0.4
9 | async-timeout==4.0.3
10 | attrs==23.1.0
11 | Babel==2.12.1
12 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
13 | backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1687772187254/work
14 | beautifulsoup4==4.12.2
15 | bleach==6.0.0
16 | cachetools==5.3.1
17 | certifi==2023.7.22
18 | cffi==1.15.1
19 | chardet==5.2.0
20 | charset-normalizer==3.2.0
21 | click==8.1.7
22 | cmake==3.27.2
23 | colorama==0.4.6
24 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1691044910542/work
25 | contourpy==1.1.0
26 | cryptography==41.0.3
27 | cycler==0.11.0
28 | DataProperty==1.0.1
29 | debugpy @ file:///croot/debugpy_1690905042057/work
30 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
31 | defusedxml==0.7.1
32 | diskcache==5.6.1
33 | docopt==0.6.2
34 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
35 | exceptiongroup==1.1.3
36 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work
37 | fastapi==0.103.2
38 | fastjsonschema==2.18.0
39 | filelock==3.12.2
40 | FLAML==2.0.0
41 | fonttools==4.42.1
42 | fqdn==1.5.1
43 | frozenlist==1.4.0
44 | fsspec==2023.6.0
45 | gitdb==4.0.10
46 | GitPython==3.1.34
47 | google-ai-generativelanguage==0.2.0
48 | google-api-core==2.11.1
49 | google-auth==2.22.0
50 | google-generativeai==0.1.0
51 | googleapis-common-protos==1.60.0
52 | gptcache==0.1.40
53 | grpcio==1.57.0
54 | grpcio-status==1.57.0
55 | guidance==0.0.64
56 | h11==0.14.0
57 | httptools==0.6.0
58 | huggingface-hub==0.16.4
59 | idna==3.4
60 | install==1.3.5
61 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1691424382338/work
62 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1685727741709/work
63 | ipython-genutils==0.2.0
64 | ipywidgets==8.1.0
65 | isoduration==20.11.0
66 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1690896916983/work
67 | Jinja2==3.1.2
68 | joblib==1.3.2
69 | json5==0.9.14
70 | jsonpickle==3.0.2
71 | jsonpointer==2.4
72 | jsonschema==4.19.0
73 | jsonschema-specifications==2023.7.1
74 | jupyter==1.0.0
75 | jupyter-console==6.6.3
76 | jupyter-events==0.7.0
77 | jupyter-lsp==2.2.0
78 | jupyter_client==8.3.1
79 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1686775611663/work
80 | jupyter_server==2.7.3
81 |
jupyter_server_terminals==0.4.4 82 | jupyterlab==4.0.5 83 | jupyterlab-pygments==0.2.2 84 | jupyterlab-widgets==3.0.8 85 | jupyterlab_server==2.24.0 86 | kiwisolver==1.4.5 87 | lit==16.0.6 88 | MarkupSafe==2.1.3 89 | matplotlib==3.7.2 90 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work 91 | mbstrdecoder==1.1.3 92 | mistune==3.0.1 93 | mpmath==1.3.0 94 | msal==1.23.0 95 | msgpack==1.0.7 96 | multidict==6.0.4 97 | munch==2.5.0 98 | nbclient==0.8.0 99 | nbconvert==7.8.0 100 | nbformat==5.9.2 101 | nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work 102 | networkx==3.1 103 | ninja==1.11.1 104 | nltk==3.8.1 105 | notebook==7.0.3 106 | notebook_shim==0.2.3 107 | numpy==1.26.0b1 108 | nvidia-cublas-cu11==11.10.3.66 109 | nvidia-cuda-cupti-cu11==11.7.101 110 | nvidia-cuda-nvrtc-cu11==11.7.99 111 | nvidia-cuda-runtime-cu11==11.7.99 112 | nvidia-cudnn-cu11==8.5.0.96 113 | nvidia-cufft-cu11==10.9.0.58 114 | nvidia-curand-cu11==10.2.10.91 115 | nvidia-cusolver-cu11==11.4.0.1 116 | nvidia-cusparse-cu11==11.7.4.91 117 | nvidia-nccl-cu11==2.14.3 118 | nvidia-nvtx-cu11==11.7.91 119 | openai==0.27.8 120 | overrides==7.4.0 121 | packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1681337016113/work 122 | pandas==2.1.0 123 | pandocfilters==1.5.0 124 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work 125 | pathvalidate==3.1.0 126 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work 127 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 128 | Pillow==10.0.0 129 | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1690813113769/work 130 | prometheus-client==0.17.1 131 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1688565951714/work 132 | proto-plus==1.22.3 133 | protobuf==4.24.2 134 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work 135 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 136 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work 137 | py-cpuinfo==9.0.0 138 | pyarrow==13.0.0 139 | pyasn1==0.5.0 140 | pyasn1-modules==0.3.0 141 | pyautogen==0.1.6 142 | pycocoevalcap==1.2 143 | pycocotools==2.0.7 144 | pycparser==2.21 145 | pydantic==1.10.13 146 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1691408637400/work 147 | pygtrie==2.5.0 148 | PyJWT==2.8.0 149 | pyparsing==3.0.9 150 | pytablewriter==1.0.0 151 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work 152 | python-dotenv==1.0.0 153 | python-json-logger==2.0.7 154 | pytz==2023.3 155 | PyYAML==6.0.1 156 | pyzmq @ file:///croot/pyzmq_1686601365461/work 157 | qtconsole==5.4.4 158 | QtPy==2.4.0 159 | ray==2.7.0 160 | referencing==0.30.2 161 | regex==2023.8.8 162 | requests==2.31.0 163 | rfc3339-validator==0.1.4 164 | rfc3986-validator==0.1.1 165 | rpds-py==0.10.2 166 | rsa==4.9 167 | sacred==0.8.4 168 | safetensors==0.3.3 169 | seaborn==0.12.2 170 | Send2Trash==1.8.2 171 | sentencepiece==0.1.99 172 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 173 | smmap==5.0.0 174 | sniffio==1.3.0 175 | soupsieve==2.5 176 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work 177 
| starlette==0.27.0 178 | sympy==1.12 179 | tabledata==1.3.1 180 | tcolorpy==0.1.3 181 | termcolor==2.3.0 182 | terminado==0.17.1 183 | tiktoken==0.4.0 184 | tinycss2==1.2.1 185 | tokenizers==0.13.3 186 | tomli==2.0.1 187 | torch==2.0.1 188 | torchaudio==2.0.2 189 | torchvision==0.15.2 190 | tornado==6.3.3 191 | tqdm==4.66.1 192 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work 193 | transformers==4.33.3 194 | tree-of-thoughts==0.3.6 195 | triton==2.0.0 196 | typepy==1.3.1 197 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1688315532570/work 198 | tzdata==2023.3 199 | uri-template==1.3.0 200 | urllib3==1.26.16 201 | uvicorn==0.23.2 202 | uvloop==0.17.0 203 | vllm==0.2.0 204 | watchfiles==0.20.0 205 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work 206 | webcolors==1.13 207 | webencodings==0.5.1 208 | websocket-client==1.6.2 209 | websockets==11.0.3 210 | widgetsnbextension==4.0.8 211 | wrapt==1.15.0 212 | xformers==0.0.22 213 | yarl==1.9.2 214 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from solve import Solver 2 | from sacred import Experiment 3 | 4 | ex = Experiment("LLM") 5 | 6 | 7 | @ex.config 8 | def config(): 9 | model_name = 'gpt-3.5-turbo' 10 | config = 'zero-shot' 11 | ctx = False 12 | 13 | 14 | @ex.automain 15 | def run( 16 | model_name, 17 | seed, 18 | shot_type, 19 | prompt_r, 20 | sys, 21 | ctx, 22 | max_tokens, 23 | self_consistency, 24 | self_reflection 25 | ): 26 | if ctx: 27 | data_path = "data/w_ctx.json" 28 | else: 29 | data_path = "data/wo_ctx.json" 30 | 31 | solver = Solver( 32 | data_path=data_path, 33 | model_name=model_name, 34 | seed=seed, 35 | shot_type=shot_type, 36 | prompt_r=prompt_r, 37 | sys=sys, 38 | ctx=ctx, 39 | max_tokens=max_tokens, 40 | self_consistency=self_consistency, 41 | self_reflection=self_reflection, 42 | ) 43 | 44 | solver.run() 45 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | DIR=/home/elpis_ubuntu/LLM/NLPBench 2 | 3 | 4 | # zero-shot, cot, sys, cot + sys 5 | # python $DIR/run.py with "configs/zero-shot.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 6 | # wait 7 | # python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 8 | # wait 9 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 10 | # wait 11 | # python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 12 | # wait 13 | 14 | # python $DIR/run.py with "configs/zero-shot.yaml" "model_name='gpt-4'" "ctx=False" 15 | # wait 16 | # python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='gpt-4'" "ctx=False" 17 | # wait 18 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='gpt-4'" "ctx=False" 19 | # wait 20 | # python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='gpt-4'" "ctx=False" 21 | # wait 22 | 23 | # python $DIR/run.py with "configs/zero-shot.yaml" "model_name='models/text-bison-001'" "ctx=False" 24 | # wait 25 | # python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='models/text-bison-001'" "ctx=False" 26 | # wait 27 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='models/text-bison-001'" 
"ctx=False" 28 | # wait 29 | # python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='models/text-bison-001'" "ctx=False" 30 | # wait 31 | 32 | 33 | # few-shot, cot, sys, cot + sys 34 | #python $DIR/run.py with "configs/few-shot.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 35 | #wait 36 | #python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 37 | #wait 38 | #python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 39 | #wait 40 | #python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 41 | #wait 42 | 43 | #python $DIR/run.py with "configs/few-shot.yaml" "model_name='gpt-4'" "ctx=False" 44 | #wait 45 | #python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='gpt-4'" "ctx=False" 46 | #wait 47 | #python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='gpt-4'" "ctx=False" 48 | #wait 49 | #python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='gpt-4'" "ctx=False" 50 | #wait 51 | 52 | #python $DIR/run.py with "configs/few-shot.yaml" "model_name='models/text-bison-001'" "ctx=False" 53 | #wait 54 | #python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='models/text-bison-001'" "ctx=False" 55 | #wait 56 | #python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='models/text-bison-001'" "ctx=False" 57 | #wait 58 | #python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='models/text-bison-001'" "ctx=False" 59 | #wait 60 | 61 | # zero-shot, tot 62 | # python $DIR/run.py with "configs/zero-shot_tot.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 63 | # wait 64 | # python $DIR/run.py with "configs/zero-shot_tot.yaml" "model_name='gpt-4'" "ctx=False" 65 | # wait 66 | # python $DIR/run.py with "configs/zero-shot_tot.yaml" "model_name='models/text-bison-001'" "ctx=False" 67 | # wait 68 | 69 | # few-shot, tot 70 | # python $DIR/run.py with "configs/few-shot_tot.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 71 | # wait 72 | # python $DIR/run.py with "configs/few-shot_tot.yaml" "model_name='gpt-4'" "ctx=False" 73 | # wait 74 | # python $DIR/run.py with "configs/few-shot_tot.yaml" "model_name='models/text-bison-001'" "ctx=False" 75 | # wait 76 | 77 | # zero-shot, self consistency 78 | #python $DIR/run.py with "configs/zero-shot_sc.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 79 | #wait 80 | #python $DIR/run.py with "configs/zero-shot_cot_sc.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 81 | #wait 82 | #python $DIR/run.py with "configs/zero-shot_sc.yaml" "model_name='gpt-4'" "ctx=False" 83 | #wait 84 | #python $DIR/run.py with "configs/zero-shot_cot_sc.yaml" "model_name='gpt-4'" "ctx=False" 85 | #wait 86 | #python $DIR/run.py with "configs/zero-shot_sc.yaml" "model_name='models/text-bison-001'" "ctx=False" 87 | #wait 88 | #python $DIR/run.py with "configs/zero-shot_cot_sc.yaml" "model_name='models/text-bison-001'" "ctx=False" 89 | #wait 90 | 91 | # few-shot, self consistency 92 | #python $DIR/run.py with "configs/few-shot_sc.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 93 | #wait 94 | #python $DIR/run.py with "configs/few-shot_cot_sc.yaml" "model_name='gpt-3.5-turbo'" "ctx=False" 95 | #wait 96 | #python $DIR/run.py with "configs/few-shot_sc.yaml" "model_name='gpt-4'" "ctx=False" 97 | #wait 98 | #python $DIR/run.py with "configs/few-shot_cot_sc.yaml" "model_name='gpt-4'" "ctx=False" 99 | #wait 100 | #python $DIR/run.py with "configs/few-shot_sc.yaml" "model_name='models/text-bison-001'" "ctx=False" 101 | #wait 
102 | #python $DIR/run.py with "configs/few-shot_cot_sc.yaml" "model_name='models/text-bison-001'" "ctx=False" -------------------------------------------------------------------------------- /scripts/run_ctx.sh: -------------------------------------------------------------------------------- 1 | DIR=/home/elpis_ubuntu/LLM/NLPBench 2 | 3 | 4 | # # zero-shot, cot, sys, cot + sys 5 | # python $DIR/run.py with "configs/zero-shot.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 6 | # wait 7 | # python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 8 | # wait 9 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 10 | # wait 11 | # python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 12 | # wait 13 | 14 | # python $DIR/run.py with "configs/zero-shot.yaml" "model_name='gpt-4'" "ctx=True" 15 | # wait 16 | # python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='gpt-4'" "ctx=True" 17 | # wait 18 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='gpt-4'" "ctx=True" 19 | # wait 20 | # python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='gpt-4'" "ctx=True" 21 | # wait 22 | 23 | # python $DIR/run.py with "configs/zero-shot.yaml" "model_name='models/text-bison-001'" "ctx=True" 24 | # wait 25 | # python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='models/text-bison-001'" "ctx=True" 26 | # wait 27 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='models/text-bison-001'" "ctx=True" 28 | # wait 29 | # python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='models/text-bison-001'" "ctx=True" 30 | # wait 31 | 32 | 33 | # few-shot, cot, sys, cot + sys 34 | #python $DIR/run.py with "configs/few-shot.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 35 | #wait 36 | #python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 37 | #wait 38 | #python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 39 | #wait 40 | #python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 41 | #wait 42 | # 43 | #python $DIR/run.py with "configs/few-shot.yaml" "model_name='gpt-4'" "ctx=True" 44 | #wait 45 | #python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='gpt-4'" "ctx=True" 46 | #wait 47 | #python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='gpt-4'" "ctx=True" 48 | #wait 49 | #python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='gpt-4'" "ctx=True" 50 | #wait 51 | 52 | # python $DIR/run.py with "configs/few-shot.yaml" "model_name='models/text-bison-001'" "ctx=True" 53 | # wait 54 | # python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='models/text-bison-001'" "ctx=True" 55 | # wait 56 | # python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='models/text-bison-001'" "ctx=True" 57 | # wait 58 | # python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='models/text-bison-001'" "ctx=True" 59 | # wait 60 | 61 | 62 | # zero-shot, tot 63 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 64 | # wait 65 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='gpt-4'" "ctx=True" 66 | # wait 67 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='models/text-bison-001'" "ctx=True" 68 | # wait 69 | 70 | 71 | # few-shot, tot 72 | # python $DIR/run.py with 
"configs/few-shot_cot.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 73 | # wait 74 | # python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='gpt-4'" "ctx=True" 75 | # wait 76 | # python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='models/text-bison-001'" "ctx=True" 77 | # wait 78 | 79 | # zero-shot, self consistency 80 | python $DIR/run.py with "configs/zero-shot_sc.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 81 | wait 82 | python $DIR/run.py with "configs/zero-shot_cot_sc.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 83 | wait 84 | python $DIR/run.py with "configs/zero-shot_sc.yaml" "model_name='gpt-4'" "ctx=True" 85 | wait 86 | python $DIR/run.py with "configs/zero-shot_cot_sc.yaml" "model_name='gpt-4'" "ctx=True" 87 | wait 88 | python $DIR/run.py with "configs/zero-shot_sc.yaml" "model_name='models/text-bison-001'" "ctx=True" 89 | wait 90 | python $DIR/run.py with "configs/zero-shot_cot_sc.yaml" "model_name='models/text-bison-001'" "ctx=True" 91 | wait 92 | 93 | # few-shot, self consistency 94 | python $DIR/run.py with "configs/few-shot_sc.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 95 | wait 96 | python $DIR/run.py with "configs/few-shot_cot_sc.yaml" "model_name='gpt-3.5-turbo'" "ctx=True" 97 | wait 98 | python $DIR/run.py with "configs/few-shot_sc.yaml" "model_name='gpt-4'" "ctx=True" 99 | wait 100 | python $DIR/run.py with "configs/few-shot_cot_sc.yaml" "model_name='gpt-4'" "ctx=True" 101 | wait 102 | python $DIR/run.py with "configs/few-shot_sc.yaml" "model_name='models/text-bison-001'" "ctx=True" 103 | wait 104 | python $DIR/run.py with "configs/few-shot_cot_sc.yaml" "model_name='models/text-bison-001'" "ctx=True" 105 | -------------------------------------------------------------------------------- /scripts/run_os.sh: -------------------------------------------------------------------------------- 1 | PID_file="logs/pid_file" 2 | DIR=/root/llm/NLPBench 3 | 4 | getpid() { 5 | head -1 < $PID_file 6 | } 7 | 8 | 9 | #nohup bash $DIR/scripts/serving.sh -m "meta-llama/Llama-2-13b-chat-hf" 2>&1 10 | #sleep 90 11 | #PID=$(getpid) 12 | # python $DIR/run.py with "configs/zero-shot.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=False" 13 | # wait 14 | # python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=False" 15 | # wait 16 | # python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=False" 17 | # wait 18 | #python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=False" 19 | #wait 20 | #python $DIR/run.py with "configs/few-shot.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=False" 21 | #wait 22 | #python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=False" 23 | #wait 24 | #python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=False" 25 | #wait 26 | #python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=False" 27 | #wait 28 | 29 | #python $DIR/run.py with "configs/zero-shot.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=true" 30 | #wait 31 | #python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=true" 32 | #wait 33 | #python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=true" 34 | #wait 35 | #python $DIR/run.py with 
"configs/zero-shot_sys_cot.yaml" "model_name='meta-llama/Llama-2-13b-chat-hf'" "ctx=true" 36 | #wait 37 | #kill -9 $PID 38 | 39 | nohup bash $DIR/scripts/serving.sh -m "meta-llama/Llama-2-70b-chat-hf" 2>&1 40 | sleep 120 41 | PID=$(getpid) 42 | #python $DIR/run.py with "configs/zero-shot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=False" 43 | #wait 44 | #python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=False" 45 | #wait 46 | #python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=False" 47 | #wait 48 | python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=False" 49 | wait 50 | python $DIR/run.py with "configs/few-shot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=False" 51 | wait 52 | python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=False" 53 | wait 54 | python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=False" 55 | wait 56 | python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=False" 57 | wait 58 | 59 | python $DIR/run.py with "configs/zero-shot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=True" 60 | wait 61 | python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=True" 62 | wait 63 | python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=True" 64 | wait 65 | python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=True" 66 | wait 67 | python $DIR/run.py with "configs/few-shot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=True" 68 | wait 69 | python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=True" 70 | wait 71 | python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=True" 72 | wait 73 | python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='meta-llama/Llama-2-70b-chat-hf'" "ctx=True" 74 | wait 75 | kill -9 $PID 76 | 77 | nohup bash $DIR/scripts/serving.sh -m "sambanovasystems/BLOOMChat-176B-v1" 2>&1 78 | sleep 300 79 | PID=$(getpid) 80 | python $DIR/run.py with "configs/zero-shot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=False" 81 | wait 82 | python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=False" 83 | wait 84 | python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=False" 85 | wait 86 | python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=False" 87 | wait 88 | python $DIR/run.py with "configs/few-shot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=False" 89 | wait 90 | python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=False" 91 | wait 92 | python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=False" 93 | wait 94 | python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=False" 95 | wait 96 | 97 | python $DIR/run.py with "configs/zero-shot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=True" 98 | 
wait
99 | python $DIR/run.py with "configs/zero-shot_sys.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=True"
100 | wait
101 | python $DIR/run.py with "configs/zero-shot_cot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=True"
102 | wait
103 | python $DIR/run.py with "configs/zero-shot_sys_cot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=True"
104 | wait
105 | python $DIR/run.py with "configs/few-shot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=True"
106 | wait
107 | python $DIR/run.py with "configs/few-shot_sys.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=True"
108 | wait
109 | python $DIR/run.py with "configs/few-shot_cot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=True"
110 | wait
111 | python $DIR/run.py with "configs/few-shot_sys_cot.yaml" "model_name='sambanovasystems/BLOOMChat-176B-v1'" "ctx=True"
112 | wait
113 | kill -9 $PID
114 | 
115 | 
--------------------------------------------------------------------------------
/scripts/serving.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | model="tiiuae/falcon-180B-chat"
4 | number_of_gpus="8"
5 | host="127.0.0.1"
6 | port="8000"
7 | PID=$$
8 | 
9 | 
10 | while getopts ":m:n:a:p:" opt; do
11 |   case $opt in
12 |     m) model="$OPTARG"
13 |     ;;
14 |     n) number_of_gpus="$OPTARG"
15 |     ;;
16 |     a) host="$OPTARG"
17 |     ;;
18 |     p) port="$OPTARG"
19 |     ;;
20 |     \?) echo "Invalid option -$OPTARG" >&2
21 |     exit 1
22 |     ;;
23 |   esac
24 | 
25 |   case $OPTARG in
26 |     -*) echo "Option $opt needs a valid argument"
27 |     exit 1
28 |     ;;
29 |   esac
30 | done
31 | 
32 | IFS="/"
33 | read -ra ADDR <<< "$model"
34 | modelname=${ADDR[1]}
35 | unset IFS  # restore the default IFS (IFS="" would disable word splitting entirely, not reset it)
36 | 
37 | nohup /root/miniconda3/envs/LLM/bin/python -m vllm.entrypoints.openai.api_server \
38 |     --host ${host} \
39 |     --port ${port} \
40 |     --model ${model} \
41 |     --tensor-parallel-size ${number_of_gpus} > "logs/serving_${modelname}.log" &
42 | 
43 | echo $!
> "logs/pid_file" -------------------------------------------------------------------------------- /solve/__init__.py: -------------------------------------------------------------------------------- 1 | from .solver import Solver -------------------------------------------------------------------------------- /solve/solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import traceback 4 | from typing import Optional 5 | from collections import Counter 6 | from utils import ApiManager, PromptGenerator, tot_solver 7 | 8 | 9 | class Solver: 10 | def __init__( 11 | self, 12 | data_path: str, 13 | model_name: str, 14 | seed: int, 15 | shot_type: str, 16 | prompt_r: str, 17 | sys: bool, 18 | max_tokens: int, 19 | ctx: bool, 20 | self_consistency: bool = False, 21 | self_reflection: bool = False, 22 | ): 23 | self.data_path = data_path 24 | self.model_name_save = model_name.split('/')[-1] 25 | ctx_path = 'ctx' if ctx else 'no_ctx' 26 | self.save_path = (f'res/{str(seed)}/{ctx_path}/{shot_type}' 27 | f'{"-sc" if self_consistency else ""}' 28 | f'_{self.model_name_save}' 29 | f'{("_" + prompt_r) if prompt_r is not None else ""}' 30 | f'{"_sys" if sys else ""}.json') 31 | 32 | root = self.save_path.rsplit('/', 1)[0] 33 | if not os.path.exists(root): 34 | os.makedirs(root) 35 | 36 | self.ctx = ctx 37 | 38 | ### Init api manager and prompt generator 39 | self.llm = ApiManager( 40 | model_name=model_name, 41 | seed=seed, 42 | default_max_tokens=max_tokens, 43 | ) 44 | 45 | self.prompt = PromptGenerator( 46 | api=self.llm, 47 | system_prompt=sys, 48 | shot_type=shot_type, 49 | prompt_reinforcement=prompt_r, 50 | ) 51 | 52 | self.self_consistency = self_consistency 53 | self.self_reflection = self_reflection 54 | 55 | @staticmethod 56 | def load_cache(save_path: str): 57 | skip_cnt = 0 58 | res_cache = [] 59 | if os.path.exists(save_path): 60 | res_cache = json.load(open(save_path, 'r')) 61 | skip_cnt = len(res_cache) 62 | 63 | return res_cache, skip_cnt 64 | 65 | @staticmethod 66 | def majority_voting(resp_list: list): 67 | res = [r.replace(' ', '').replace("'", '').replace('"', '') for r in resp_list] 68 | res = Counter(res).most_common(1)[0][0] 69 | return res 70 | 71 | def solve_ctx(self): 72 | ### Load cache (if palm, else init result) 73 | if self.llm.api_type == 'palm': 74 | res, skip_cnt = self.load_cache(self.save_path) 75 | else: 76 | res = [] 77 | skip_cnt = 0 78 | 79 | for data in json.load(open(self.data_path, 'r')): 80 | if skip_cnt > 0: 81 | skip_cnt -= 1 82 | continue 83 | history = None 84 | resps = [] 85 | for idx, q in enumerate(data['questions']): 86 | q_type = data['type'][idx] 87 | q_opt = [] if data['options'][idx] is None else data['options'][idx] 88 | ctx = None if data['context'] == '-1' else data['context'] 89 | messages = self.prompt.generate(q=q, q_type=q_type, q_opt=q_opt, ctx=ctx, history=history) 90 | 91 | if self.self_consistency is True and q_type == 0: 92 | try: 93 | resp = self.llm(messages, choices=3) 94 | if resp is None: 95 | resp = "" 96 | except Exception as e: 97 | print(e) 98 | resp = "" 99 | if resp != "": 100 | resp = self.majority_voting(resp) 101 | else: 102 | try: 103 | resp = self.llm(messages) 104 | if resp is None: 105 | resp = "" 106 | except Exception as e: 107 | print(e) 108 | resp = "" 109 | 110 | if self.llm.api_type == 'palm': 111 | history = messages 112 | history['prompt'] += '\n' + resp 113 | 114 | elif self.llm.api_type == 'oai': 115 | history = messages 116 | 
history[-1]['content'] += resp 117 | resps.append(resp) 118 | print(f"\n{self.model_name_save}: {resp}") 119 | 120 | res.append({ 121 | **data, 122 | 'prompt': history, 123 | 'llm_answer': resps 124 | }) 125 | if self.llm.api_type == 'palm': 126 | json.dump(res, open(self.save_path, 'w'), indent=4) 127 | 128 | json.dump(res, open(self.save_path, 'w'), indent=4) 129 | 130 | def solve_no_ctx(self): 131 | ### Load cache (if palm, else init result) 132 | if self.llm.api_type == 'palm': 133 | res, skip_cnt = self.load_cache(self.save_path) 134 | else: 135 | res = [] 136 | skip_cnt = 0 137 | 138 | for data in json.load(open(self.data_path, 'r')): 139 | if skip_cnt > 0: 140 | skip_cnt -= 1 141 | continue 142 | q = data['question'] 143 | q_type = data['type'] 144 | q_opt = data.get('options', []) 145 | 146 | messages = self.prompt.generate(q=q, q_type=q_type, q_opt=q_opt) 147 | if q_type == 0 and self.self_consistency is True: 148 | try: 149 | resp = self.llm(messages, choices=3) 150 | except Exception as e: 151 | traceback.print_exc() 152 | messages = [] 153 | resp = "" 154 | if resp != "": 155 | resp = self.majority_voting(resp) 156 | else: 157 | try: 158 | resp = self.llm(messages) 159 | except Exception as e: 160 | traceback.print_exc() 161 | messages = [] 162 | resp = "" 163 | print(f"\n {self.model_name_save}: {resp}") 164 | res.append({ 165 | **data, 166 | 'prompt': messages, 167 | 'llm_answer': resp 168 | }) 169 | if self.llm.api_type == 'palm': 170 | json.dump(res, open(self.save_path, 'w'), indent=4) 171 | 172 | json.dump(res, open(self.save_path, 'w'), indent=4) 173 | 174 | def solve_tot_no_ctx( 175 | self, 176 | evaluation_strategy: Optional[str] = "vote", # value or vote 177 | num_thoughts: Optional[int] = 1, 178 | max_steps: Optional[int] = 3, 179 | max_states: Optional[int] = 4, 180 | pruning_threshold: Optional[float] = 0.5, 181 | ): 182 | ### Load cache (if palm, else init result) 183 | if self.llm.api_type == 'palm': 184 | res, skip_cnt = self.load_cache(self.save_path) 185 | else: 186 | res = [] 187 | skip_cnt = 0 188 | 189 | for data in json.load(open(self.data_path, 'r')): 190 | if skip_cnt > 0: 191 | skip_cnt -= 1 192 | continue 193 | try: 194 | resp, messages = tot_solver( 195 | generator=self.prompt, 196 | api=self.llm, 197 | q=data['question'], 198 | q_type=data['type'], 199 | q_opt=data.get('options', []), 200 | evaluation_strategy=evaluation_strategy, 201 | num_thoughts=num_thoughts, 202 | max_steps=max_steps, 203 | max_states=max_states, 204 | pruning_threshold=pruning_threshold, 205 | ) 206 | except Exception as e: 207 | traceback.print_exc() 208 | messages = [] 209 | resp = "" 210 | 211 | print(f"\n {self.model_name_save}: {resp}") 212 | res.append({ 213 | **data, 214 | 'messages': messages, 215 | 'llm_answer': resp 216 | }) 217 | if self.llm.api_type == 'palm': 218 | out_str = json.dumps(res[-1]) 219 | with open(self.save_path, 'a') as f: 220 | f.write(out_str + '\n') 221 | 222 | json.dump(res, open(self.save_path, 'w'), indent=4) 223 | 224 | def run(self): 225 | if self.ctx: 226 | self.solve_ctx() 227 | else: 228 | if self.prompt.prompt_reinf == 'tot': 229 | self.solve_tot_no_ctx() 230 | else: 231 | self.solve_no_ctx() 232 | -------------------------------------------------------------------------------- /tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LinxinS97/NLPBench/82d6c109db9121ff60627873b8cc5ead922c8cba/tokenizer.model 
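
Note on the self-consistency path in solve/solver.py above: when self_consistency is enabled and the question is multiple choice (q_type == 0), the solver samples three completions (choices=3) and keeps the most frequent answer after stripping spaces and quote characters (majority_voting). A minimal standalone sketch of that voting step; the response strings below are hypothetical sampled model outputs, not repo data:

    from collections import Counter

    responses = ["0, 1", "0,1", '"2"']  # hypothetical sampled answers
    # Normalize exactly as Solver.majority_voting does: drop spaces and quote characters
    normalized = [r.replace(' ', '').replace("'", '').replace('"', '') for r in responses]
    print(Counter(normalized).most_common(1)[0][0])  # -> 0,1
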
-------------------------------------------------------------------------------- /tokenizer_checklist.chk: -------------------------------------------------------------------------------- 1 | eeec4125e9c7560836b4873b6f8e3025 tokenizer.model 2 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .arg_parser import parser 2 | from .utils import PromptGenerator, ApiManager 3 | from .tot_solver import tot_solver -------------------------------------------------------------------------------- /utils/arg_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--config", type=str, default='zero-shot') -------------------------------------------------------------------------------- /utils/tot_solver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import concurrent.futures 3 | import os 4 | import json 5 | import numpy as np 6 | from typing import List, Dict, Any, Optional, Union 7 | from utils import ApiManager, PromptGenerator 8 | from prompts import SOLUTION_PROMPT, STATE_PROMPT, VOTE_PROMPT 9 | 10 | 11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | 16 | class ToTModel: 17 | 18 | api: ApiManager 19 | evaluation_strategy: Optional[str] = "value" 20 | 21 | def __init__(self, api: ApiManager, evaluation_strategy: Optional[str] = "value"): 22 | self.api = api 23 | self.evaluation_strategy = evaluation_strategy 24 | 25 | 26 | def message_handler(self, message: str): 27 | if self.api.api_type == 'palm': 28 | return { 29 | 'prompt': message, 30 | } 31 | elif self.api.api_type == 'oai': 32 | return [{ 33 | 'role': 'user', 34 | 'content': message, 35 | }] 36 | 37 | 38 | def generate_thoughts(self, 39 | state: Union[str, list], 40 | k: int, 41 | initial_prompt: Union[str, dict], 42 | rejected_solutions: Optional[str]=None): 43 | 44 | if type(state) == str: 45 | state_text = state 46 | else: 47 | state_text = '\n'.join(state) 48 | 49 | prompt = SOLUTION_PROMPT.format(state_text=state_text, 50 | initial_prompt=initial_prompt, 51 | rejected_solutions=rejected_solutions) 52 | 53 | prompt = self.message_handler(prompt) 54 | thoughts = [self.api(prompt) for _ in range(k)] 55 | return thoughts 56 | 57 | 58 | def generate_solution(self, 59 | initial_prompt: str, 60 | state: Union[str, list], 61 | rejected_solutions: Optional[str]=None): 62 | try: 63 | if type(state) == str: 64 | state_text = state 65 | else: 66 | state_text = '\n'.join(state) 67 | 68 | prompt = SOLUTION_PROMPT.format(state_text=state_text, 69 | initial_prompt=initial_prompt, 70 | rejected_solutions=rejected_solutions) 71 | prompt = self.message_handler(prompt) 72 | answer = self.api(prompt) 73 | return answer 74 | 75 | except Exception as e: 76 | logger.error(f"Error in generate_solutions: {e}") 77 | return None 78 | 79 | 80 | def evaluate_states(self, 81 | states: List, 82 | initial_prompt: str): 83 | if not states: 84 | return {} 85 | 86 | if self.evaluation_strategy == 'value': 87 | state_values = {} 88 | for state in states: 89 | if type(state) == str: 90 | state_text = state 91 | else: 92 | state_text = '\n'.join(state) 93 | 94 | prompt = STATE_PROMPT.format(initial_prompt=initial_prompt, state_text=state_text) 95 | 
                prompt = self.message_handler(prompt)
96 |                 try:
97 |                     value_text = self.api(prompt, max_tokens=10)  # pass by keyword: the second positional argument of ApiManager.__call__ is temperature
98 |                     value = float(value_text)
99 |                     print(f"Evaluated Thought Value: {value}")
100 |                 except ValueError:
101 |                     value = 0  # Assign a default value if the conversion fails
102 |                 state_values[state] = value
103 |             return state_values
104 | 
105 |         elif self.evaluation_strategy == 'vote':
106 |             states_text = '\n'.join([' '.join(state) for state in states])
107 | 
108 |             prompt = VOTE_PROMPT.format(states_text=states_text, initial_prompt=initial_prompt)
109 |             prompt = self.message_handler(prompt)
110 |             best_state_text = self.api(prompt, max_tokens=50)  # keyword for the same reason as above
111 | 
112 |             best_state = tuple(best_state_text.split())
113 | 
114 |             return {state: 1 if state == best_state else 0 for state in states}
115 | 
116 |         else:
117 |             raise ValueError("Invalid evaluation strategy. Choose 'value' or 'vote'.")
118 | 
119 | 
120 | 
121 | class TreeofThoughts:
122 | 
123 |     model: ToTModel
124 |     best_state = None
125 |     best_value = float("-inf")
126 |     history = []  # initialize history
127 |     tree: Dict[str, Dict[str, Union[float, Dict[str, Any]]]] = {
128 |         "nodes": {},
129 |     }
130 | 
131 |     def __init__(self, model: ToTModel):
132 |         self.model = model
133 | 
134 |     def save_tree_to_json(self, file_name):
135 |         os.makedirs(os.path.dirname(file_name), exist_ok=True)
136 |         with open(file_name, 'w') as json_file:
137 |             json.dump(self.tree, json_file, indent=4)
138 | 
139 |     def logNewState(self, state, evaluation):
140 |         if not (type(state) == str):
141 |             state = " | ".join(state)
142 |         if state in self.tree['nodes']:
143 |             self.tree['nodes'][state]['thoughts'].append(evaluation)
144 |         else:
145 |             self.tree['nodes'][state] = {'thoughts': [evaluation]}
146 | 
147 |     def adjust_pruning_threshold_precentile(self, evaluated_thoughts, percentile):
148 |         values = np.array(list(evaluated_thoughts.values()))
149 |         if values.size == 0:
150 |             return 0
151 |         return max(np.percentile(values, percentile), 0.1)
152 | 
153 | 
154 |     def adjust_pruning_threshold_moving_average(self, evaluated_thoughts, window_size):
155 |         values = list(evaluated_thoughts.values())
156 |         if len(values) < window_size:
157 |             return np.mean(values) if values else 0
158 |         else:
159 |             return max(np.mean(values[-window_size:]), 0.1)
160 | 
161 | 
162 | 
163 | class TreeofThoughtsBFS(TreeofThoughts):
164 |     def solve(
165 |         self,
166 |         initial_prompt,
167 |         num_thoughts,
168 |         max_steps,
169 |         max_states,
170 |         pruning_threshold=0.5
171 |     ):
172 |         current_states = [initial_prompt]
173 |         state_values = {}
174 |         dynamic_pruning_threshold = pruning_threshold
175 | 
176 |         try:
177 |             with concurrent.futures.ThreadPoolExecutor() as executor:
178 |                 for _ in range(1, max_steps + 1):
179 |                     selected_states = []
180 |                     for state in current_states:
181 |                         thoughts = self.model.generate_thoughts(state, num_thoughts, initial_prompt)
182 |                         futures = [executor.submit(self.model.evaluate_states, {thought: 0}, initial_prompt) for thought in thoughts]
183 |                         concurrent.futures.wait(futures)
184 |                         evaluated_thoughts = {thought: list(fut.result().values())[0] for thought, fut in zip(thoughts, futures)}  # each future yields {thought: score}; unwrap the scalar so the threshold comparison below works
185 | 
186 |                         if evaluated_thoughts:  # only adjust if you have evaluated thoughts
187 |                             dynamic_pruning_threshold = self.adjust_pruning_threshold_moving_average(evaluated_thoughts, 5)
188 | 
189 |                         for thought, value in evaluated_thoughts.items():
190 |                             flattened_state = (state, thought) if isinstance(state, str) else (*state, thought)
191 |                             selected_states.append((flattened_state, value))
192 | 
193 |                         selected_states.sort(key=lambda x: x[1],
reverse=True) 194 | selected_states = selected_states[:max_states] # Select only the top states 195 | 196 | for state, value in selected_states: 197 | if value >= dynamic_pruning_threshold: 198 | state_values[state] = value 199 | self.logNewState(state, value) 200 | logger.debug(f"State Values: {state_values}") 201 | 202 | if state_values: 203 | highest_rated_solution = max(state_values.items(), key=lambda x: x[1]) 204 | highest_rated_state = highest_rated_solution[0] 205 | solution = self.model.generate_solution(initial_prompt, highest_rated_state) 206 | 207 | return solution, highest_rated_state 208 | 209 | else: 210 | return None 211 | 212 | except Exception as e: 213 | logger.error(f"Error in tot_bfs: {e}") 214 | return None 215 | 216 | 217 | 218 | def tot_solver( 219 | generator: PromptGenerator, 220 | api: ApiManager, 221 | q: str, 222 | q_type: int, 223 | q_opt: Optional[List[str]] = [], 224 | evaluation_strategy: Optional[str] = "vote", # value or vote 225 | num_thoughts: Optional[int] = 1, 226 | max_steps: Optional[int] = 3, 227 | max_states: Optional[int] = 4, 228 | pruning_threshold: Optional[float] = 0.5, 229 | ): 230 | tot_model = ToTModel(api, evaluation_strategy) 231 | tot = TreeofThoughtsBFS(tot_model) 232 | initial_prompt = generator.generate(q, q_type, q_opt) 233 | 234 | if generator.api_type == 'palm': 235 | input_prompt = initial_prompt['prompt'] 236 | else: 237 | input_prompt = initial_prompt[-1]['content'] 238 | 239 | resp = tot.solve( 240 | initial_prompt=input_prompt, 241 | num_thoughts=num_thoughts, 242 | pruning_threshold=pruning_threshold, 243 | max_steps=max_steps, 244 | max_states=max_states 245 | ) 246 | 247 | return resp 248 | 249 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import google.generativeai as palm 5 | import autogen 6 | from typing import List, Optional, Union 7 | from prompts import SYS_PROMPT_MAPPING, PROMPT_MAPPING, SHORTANS_ZS_COT_ST1 8 | 9 | 10 | palm.configure(api_key=os.environ.get("PALM_API_KEY")) 11 | 12 | 13 | palm_llm_config = { 14 | "models/text-bison-001": { 15 | "model": "models/text-bison-001", 16 | }, 17 | } 18 | 19 | 20 | oai_llm_config = { 21 | "meta-llama/Llama-2-7b-chat-hf": { 22 | "model": "meta-llama/Llama-2-7b-chat-hf", 23 | "api_key": "empty", 24 | "api_base": "http://127.0.0.1:8000/v1", 25 | }, 26 | "meta-llama/Llama-2-13b-chat-hf": { 27 | "model": "meta-llama/Llama-2-13b-chat-hf", 28 | "api_key": "empty", 29 | "api_base": "http://127.0.0.1:8000/v1", 30 | }, 31 | "meta-llama/Llama-2-70b-chat-hf": { 32 | "model": "meta-llama/Llama-2-70b-chat-hf", 33 | "api_key": "empty", 34 | "api_base": "http://127.0.0.1:8000/v1", 35 | }, 36 | "sambanovasystems/BLOOMChat-176B-v1": { 37 | "model": "sambanovasystems/BLOOMChat-176B-v1", 38 | "api_key": "empty", 39 | "api_base": "http://127.0.0.1:8000/v1", 40 | }, 41 | "gpt-3.5-turbo": { 42 | "model": "gpt-3.5-turbo", 43 | "api_key": os.environ.get("OPENAI_API_KEY"), 44 | "api_base": "https://api.openai.com/v1", 45 | "api_version": None, 46 | }, 47 | "gpt-4": { 48 | "model": "gpt-4", 49 | "api_key": os.environ.get("OPENAI_API_KEY"), 50 | "api_base": "https://api.openai.com/v1", 51 | "api_version": None, 52 | }, 53 | } 54 | 55 | 56 | class Completion(autogen.ChatCompletion): 57 | request_timeout = 300 58 | retry_time = 20 59 | 60 | 61 | class ApiManager: 62 | model_name: str 63 | seed: int 64 | default_max_tokens: int 65 | 
66 |     def __init__(
67 |         self,
68 |         model_name: str,
69 |         seed: int,
70 |         temperature: Optional[float] = 1.0,
71 |         default_max_tokens: Optional[int] = 945
72 |     ):
73 |         self.model_name = model_name
74 |         self.seed = seed
75 |         self.max_tokens = default_max_tokens
76 |         self.temperature = temperature
77 | 
78 |     @property
79 |     def api_type(self):
80 |         return 'palm' if self.model_name in palm_llm_config.keys() else 'oai'
81 | 
82 |     def get_api_type(self):
83 |         return self.api_type
84 | 
85 |     def palm_api(self, messages: dict, choices: int = 1):
86 | 
87 |         config = {
88 |             "max_output_tokens": self.max_tokens,
89 |             "temperature": self.temperature,
90 |             **messages,  # prompt
91 |             **palm_llm_config[self.model_name]
92 |         }
93 |         while True:
94 |             try:
95 |                 if choices == 1:
96 |                     result = palm.generate_text(**config).result
97 |                 else:
98 |                     result = [palm.generate_text(**config).result for _ in range(choices)]
99 |             except Exception:
100 |                 print('Retrying in 20s...')
101 |                 time.sleep(20)
102 |                 continue
103 |             break
104 |         return result
105 | 
106 |     def oai_api(self, messages: list, choices: int = 1):
107 |         basic_config = {
108 |             "api_type": "open_ai",
109 |             "max_tokens": self.max_tokens,
110 |             "temperature": self.temperature,
111 |             "seed": self.seed,  # NOTE: the dialog will be cached for the same seed
112 |             "n": choices,
113 |             **oai_llm_config[self.model_name]
114 |         }
115 | 
116 |         # https://microsoft.github.io/FLAML/docs/Use-Cases/Autogen#basic-concept
117 |         completion = Completion.create(
118 |             messages=messages,
119 |             **basic_config
120 |         )
121 |         if choices == 1:
122 |             return completion.choices[0].message.content
123 |         else:
124 |             return [c.message.content for c in completion.choices]
125 | 
126 |     def __call__(self,
127 |                  messages: Union[list, dict],
128 |                  temperature: Optional[float] = None,
129 |                  max_tokens: Optional[int] = None,
130 |                  choices: Optional[int] = 1):
131 |         prev = (self.max_tokens, self.temperature)  # restore afterwards so per-call overrides do not leak into later calls
132 |         if max_tokens is not None:
133 |             self.max_tokens = max_tokens
134 |         if temperature is not None:
135 |             self.temperature = temperature
136 |         try:
137 |             if self.model_name in palm_llm_config.keys():
138 |                 return self.palm_api(messages, choices)
139 |             elif self.model_name in oai_llm_config.keys():
140 |                 return self.oai_api(messages, choices)
141 |             else:
142 |                 raise NotImplementedError
143 |         finally:
144 |             self.max_tokens, self.temperature = prev
145 | 
146 | 
147 | class PromptGenerator:
148 |     """
149 |     Generate the prompt for the given question and example.
150 | 
151 |     Args:
152 |         api: the api manager
153 |         system_prompt: whether to prepend the task-specific system prompt
154 |         shot_type: prompting type, 'zero-shot' or 'few-shot'
155 |         prompt_reinforcement: the prompt reinforcement to use (None, 'cot' or 'tot')
156 |     Note: self-consistency and self-reflection are configured on the Solver,
157 |     not on this class.
158 |     """
159 | 
160 |     api: ApiManager
161 |     sys_prompt: Optional[bool]
162 |     shot_type: Optional[str]
163 |     prompt_reinf: Optional[str]
164 |     return_symbol: Optional[str] = '\n'
165 | 
166 |     def __init__(
167 |         self,
168 |         api: ApiManager,
169 |         system_prompt: Optional[bool] = False,
170 |         shot_type: Optional[str] = 'zero-shot',
171 |         prompt_reinforcement: Optional[str] = None,
172 |     ):
173 |         self.api = api
174 |         self.sys_prompt = system_prompt
175 |         self.shot_type = shot_type
176 |         self.prompt_reinf = prompt_reinforcement
177 | 
178 |     @property
179 |     def api_type(self):
180 |         return self.api.get_api_type()
181 | 
182 |     def oai_prompt_generator(
183 |         self,
184 |         q: str,
185 |         q_type: int,
186 |         q_opt: Optional[List[str]] = None,
187 |         ctx: Optional[str] = None,
188 |         history: Optional[List[dict]] = None,
189 |     ) -> List[dict]:
190 |         """
191 |         Generate the prompt for the given question and example (OpenAI version).
192 | 
193 |         Args:
194 |             q: question content
195 |             q_type: question type
196 |             q_opt: question options (multiple choice only)
197 |             ctx, history: optional shared context and prior messages
198 | 
199 |         Return:
200 |             messages: the generated prompt
201 |         """
202 |         messages = [] if history is None else history
203 |         prompt_reinf = f'_{self.prompt_reinf}' if self.prompt_reinf is not None else ''
204 |         prompt_type = f'{self.shot_type}{prompt_reinf}'
205 | 
206 |         if self.sys_prompt and history is None:
207 |             messages += [{"role": "system", "content": SYS_PROMPT_MAPPING[q_type]}]
208 | 
209 |         if ctx is not None and history is None:
210 |             messages += [{"role": "user", "content": ctx}]
211 | 
212 |         if q_type == 0:
213 |             q = f"{q}\n{self.return_symbol.join([f'{i}: {opt}' for i, opt in enumerate(q_opt)])}"
214 | 
215 |         if q_type == 1 and prompt_type == 'zero-shot_cot':
216 |             cot_st1 = [{"role": "user", "content": SHORTANS_ZS_COT_ST1.format(input=q)}]
217 |             if self.sys_prompt:
218 |                 cot_st1 = [{"role": "system", "content": SYS_PROMPT_MAPPING[q_type]}] + cot_st1
219 |             try:
220 |                 t = self.api(cot_st1, max_tokens=128)
221 |             except Exception as e:
222 |                 print(e)
223 |                 t = ''
224 |             messages += [{"role": "user", "content": PROMPT_MAPPING[q_type][prompt_type].format(input=q, thought=t)}]
225 | 
226 |             return messages
227 |         try:
228 |             messages += [{"role": "user", "content": PROMPT_MAPPING[q_type][prompt_type].format(input=q)}]
229 |         except Exception as e:
230 |             print(e)
231 | 
232 |         return messages
233 | 
234 |     def palm_prompt_generator(
235 |         self,
236 |         q: str,
237 |         q_type: int,
238 |         q_opt: Optional[List[str]] = [],
239 |         ctx: Optional[str] = None,
240 |         history: Optional[List[dict]] = None,
241 |     ) -> dict:
242 |         """
243 |         Generate the prompt for the given question and example (PaLM version).
244 | 
245 |         Args:
246 |             q: question content
247 |             q_type: question type
248 |             q_opt: question options (multiple choice only)
249 |             ctx, history: optional shared context and prior prompt state
250 | 
251 |         Return:
252 |             messages: the generated prompt
253 |         """
254 |         messages = {'prompt': ''} if history is None else history
255 |         prompt_reinf = f'_{self.prompt_reinf}' if self.prompt_reinf is not None else ''
256 |         prompt_type = f'{self.shot_type}{prompt_reinf}'
257 |         sp = SYS_PROMPT_MAPPING[q_type] + self.return_symbol
258 | 
259 |         if self.sys_prompt and history is None:
260 |             messages['prompt'] += sp
261 | 
262 |         if ctx is not None and history is None:
263 |             messages['prompt'] += ctx + self.return_symbol
264 | 
265 |         if q_type == 0:
266 |             q = f"{q}\n{self.return_symbol.join([f'{i}: {opt}' for i, opt in enumerate(q_opt)])}"
267 | 
268 |         if q_type == 1 and prompt_type == 'zero-shot_cot':
269 |             cot_st1 = {'prompt': SHORTANS_ZS_COT_ST1.format(input=q)}
270 |             if self.sys_prompt:
271 |                 cot_st1['prompt'] = sp + cot_st1['prompt']
272 |             try:
273 |                 t = self.api(cot_st1, max_tokens=128)
274 |             except Exception as e:
275 |                 print(e)
276 |                 t = ''
277 |             messages['prompt'] += PROMPT_MAPPING[q_type][prompt_type].format(input=q, thought=t)
278 | 
279 |             return messages
280 |         try:
281 |             messages['prompt'] += PROMPT_MAPPING[q_type][prompt_type].format(input=q)
282 |         except Exception as e:
283 |             print(e)
284 |         return messages
285 | 
286 |     def generate(
287 |         self,
288 |         q: str = "",
289 |         q_type: int = 0,
290 |         q_opt: Optional[List[str]] = None,
291 |         ctx: Optional[str] = None,
292 |         history: Optional[List[dict]] = None,
293 |     ):
294 |         if self.api_type == 'oai':
295 |             messages = self.oai_prompt_generator(q, q_type, q_opt, ctx, history)
296 |         else:
297 |             messages = self.palm_prompt_generator(q, q_type, q_opt, ctx, history)
298 | 
299 |         return messages
300 | 
--------------------------------------------------------------------------------
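
To close, a minimal usage sketch of how ApiManager and PromptGenerator in utils/utils.py compose. This is not part of the repository; the model name, seed, max-token budget, and question are illustrative, and an OPENAI_API_KEY environment variable is assumed for the 'oai' path:

    from utils import ApiManager, PromptGenerator

    # Illustrative settings; any key of oai_llm_config or palm_llm_config works as model_name.
    llm = ApiManager(model_name='gpt-3.5-turbo', seed=0, default_max_tokens=256)
    gen = PromptGenerator(api=llm, system_prompt=True, shot_type='zero-shot')

    # Build an OpenAI-style message list for a short-answer question (q_type=1),
    # with the short-answer system prompt prepended.
    messages = gen.generate(q='What are stop words?', q_type=1)
    print(llm(messages))  # one completion; autogen caches responses per seed
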