├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── conf.py │ ├── img │ ├── openicl.png │ └── overview.jpg │ ├── index.rst │ ├── modules │ ├── dataset_reader.rst │ ├── inferencer │ │ ├── base_inferencer.rst │ │ ├── cot_inferencer.rst │ │ ├── gen_inferencer.rst │ │ ├── inferencer_link.rst │ │ └── ppl_inferencer.rst │ ├── prompt_template.rst │ └── retriever │ │ ├── base_retriever.rst │ │ ├── bm25_retriever.rst │ │ ├── dpp_retriever.rst │ │ ├── mdl_retriever.rst │ │ ├── random_retriever.rst │ │ ├── retriever_link.rst │ │ ├── topk_retriever.rst │ │ ├── votek_retriever.rst │ │ └── zero_retriever.rst │ └── notes │ ├── example.rst │ ├── installation.rst │ └── tutorial.rst ├── examples ├── README.md ├── research_projects │ ├── README.md │ └── self-adaptive_in-context_learning.ipynb └── tutorials │ ├── README.md │ ├── openicl_tutorial1_getting_started.ipynb │ ├── openicl_tutorial2_use_different_models.ipynb │ └── openicl_tutorial3_accelerate.ipynb ├── openicl ├── __init__.py ├── icl_dataset_reader.py ├── icl_evaluator │ ├── __init__.py │ ├── icl_acc_evaluator.py │ ├── icl_api_evaluator.py │ ├── icl_base_evaluator.py │ ├── icl_bleu_evaluator.py │ ├── icl_rouge_evaluator.py │ └── icl_squad_evaluator.py ├── icl_inferencer │ ├── __init__.py │ ├── icl_base_inferencer.py │ ├── icl_channel_inferencer.py │ ├── icl_cot_inferencer.py │ ├── icl_gen_inferencer.py │ └── icl_ppl_inferencer.py ├── icl_prompt_template.py ├── icl_retriever │ ├── __init__.py │ ├── icl_base_retriever.py │ ├── icl_bm25_retriever.py │ ├── icl_dpp_retriever.py │ ├── icl_mdl_retriever.py │ ├── icl_random_retriever.py │ ├── icl_topk_retriever.py │ ├── icl_votek_retriever.py │ └── icl_zero_retriever.py └── utils │ ├── __init__.py │ ├── api_service.py │ ├── calculate.py │ ├── check_type.py │ ├── collators.py │ ├── icl_common_utils.py │ └── logging.py ├── requirements.txt ├── scripts └── self_consistency.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | 8 | # Test files 9 | test.py 10 | test.sh 11 | 12 | # Output files 13 | icl_inference_output/ 14 | batchscript-* 15 | output/ 16 | 17 | # Packages 18 | dist/ 19 | .eggs/ 20 | openicl.egg-info/ 21 | build/ 22 | docs/build/ 23 | *.xml 24 | *.iml 25 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: "ubuntu-22.04" 5 | tools: 6 | python: "3.8" 7 | 8 | python: 9 | install: 10 | - requirements: docs/requirements.txt 11 | 12 | formats: all 13 | 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | ------ 6 | 7 |

8 | Overview • 9 | Installation • 10 | Paper • 11 | Examples • 12 | Docs • 13 | Citation 14 |

15 | 16 | ![version](https://img.shields.io/badge/version-0.1.8-blue) 17 | 18 | 19 | ## Overview 20 | OpenICL provides an easy interface for in-context learning, with many state-of-the-art retrieval and inference methods built in to facilitate systematic comparison of LMs and fast research prototyping. Users can easily incorporate different retrieval and inference methods, as well as different prompt instructions into their workflow. 21 |
22 | ![Overview of the OpenICL architecture](https://raw.githubusercontent.com/Shark-NLP/OpenICL/1613ae10b88ba2dbfed425c4ee078b2a6586152e/docs/source/img/overview.jpg)
23 | 
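For example, once a dataset and a prompt template have been defined (see the Quick Start below), swapping the retrieval strategy or the underlying model is a one-line change. The sketch below is illustrative only: it reuses the `data` and `template` objects defined in the Quick Start, and the retriever class and model name are just example choices.

```python
from openicl import BM25Retriever, PPLInferencer, AccEvaluator

# Swap the TopK retriever for lexical BM25 retrieval; the rest of the pipeline is unchanged.
retriever = BM25Retriever(data, ice_num=8)

# Any causal LM hosted on the Hugging Face Hub can be plugged in by name.
inferencer = PPLInferencer(model_name='gpt2-xl')

predictions = inferencer.inference(retriever, ice_template=template)
print(AccEvaluator().score(predictions=predictions, references=data.references))
```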
24 | 
25 | ## What's New
26 | + **v0.1.8** Added support for LLaMA and self-consistency
27 | 
28 | ## Installation
29 | Note: OpenICL requires Python 3.8+
30 | 
31 | **Using Pip**
32 | ```
33 | pip install openicl
34 | ```
35 | 
36 | 
37 | **Installation for local development:**
38 | ```
39 | git clone https://github.com/Shark-NLP/OpenICL
40 | cd OpenICL
41 | pip install -e .
42 | ```
43 | 
44 | ## Quick Start
45 | The following example shows how to perform ICL on a sentiment classification dataset. More examples and tutorials can be found at [examples](https://github.com/Shark-NLP/OpenICL/tree/main/examples).
46 | 
47 | #### Step 1: Load and prepare data
48 | ```python
49 | from datasets import load_dataset
50 | from openicl import DatasetReader
51 | 
52 | # Load the dataset from Hugging Face
53 | dataset = load_dataset('gpt3mix/sst2')
54 | 
55 | # Define a DatasetReader, specifying the column names where input and output are stored.
56 | data = DatasetReader(dataset, input_columns=['text'], output_column='label')
57 | ```
58 | 
59 | #### Step 2: Define the prompt template (Optional)
60 | ```python
61 | from openicl import PromptTemplate
62 | tp_dict = {
63 |     0: "</E>Positive Movie Review: </text>",
64 |     1: "</E>Negative Movie Review: </text>"
65 | }
66 | 
67 | template = PromptTemplate(tp_dict, {'text': '</text>'}, ice_token='</E>')
68 | ```
69 | The placeholders `</E>` and `</text>` will be replaced by in-context examples and the testing input, respectively. For more detailed information about `PromptTemplate` (such as string-type templates), please see [tutorial1](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb).
70 | 
71 | #### Step 3: Initialize the Retriever
72 | ```python
73 | from openicl import TopkRetriever
74 | # Define a retriever using the previous `DatasetReader`.
75 | # `ice_num` stands for the number of in-context examples.
76 | retriever = TopkRetriever(data, ice_num=8)
77 | ```
78 | Here we use the popular TopK method to build the retriever.
79 | 
80 | #### Step 4: Initialize the Inferencer
81 | ```python
82 | from openicl import PPLInferencer
83 | inferencer = PPLInferencer(model_name='distilgpt2')
84 | ```
85 | 
86 | #### Step 5: Inference and scoring
87 | ```python
88 | from openicl import AccEvaluator
89 | # The inferencer needs the retriever to collect in-context examples, as well as a template to wrap them up.
90 | predictions = inferencer.inference(retriever, ice_template=template)
91 | # Compute accuracy for the predictions
92 | score = AccEvaluator().score(predictions=predictions, references=data.references)
93 | print(score)
94 | ```
95 | 
96 | 
97 | 
98 | ## Docs
99 | **(updating...)**
100 | 
101 | [OpenICL Documentation](https://openicl.readthedocs.io/en/latest/index.html)
102 | 
103 | ## Citation
104 | If you find this repository helpful, feel free to cite our paper:
105 | ```bibtex
106 | @article{wu2023openicl,
107 |   title={OpenICL: An Open-Source Framework for In-context Learning},
108 |   author={Zhenyu Wu, Yaoxiang Wang, Jiacheng Ye, Jiangtao Feng, Jingjing Xu, Yu Qiao, Zhiyong Wu},
109 |   journal={arXiv preprint arXiv:2303.02913},
110 |   year={2023}
111 | }
112 | ```
113 | 
-------------------------------------------------------------------------------- /docs/Makefile: --------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx-rtd-theme 2 | accelerate==0.15.0 3 | datasets==2.7.1 4 | evaluate==0.3.0 5 | faiss_gpu==1.7.2 6 | nltk==3.8 7 | numpy==1.23.4 8 | openai==0.27.1 9 | rank_bm25==0.2.2 10 | requests==2.28.1 11 | scikit_learn==1.2.1 12 | sentence_transformers==2.2.2 13 | torch==1.13.1 14 | tqdm==4.64.1 15 | transformers==4.24.0 16 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -- Path setup -------------------------------------------------------------- 2 | 3 | # If extensions (or modules to document with autodoc) are in another directory, 4 | # add these directories to sys.path here. If the directory is relative to the 5 | # documentation root, use os.path.abspath to make it absolute, like shown here. 6 | # 7 | import os 8 | import sys 9 | sys.path.insert(0, os.path.abspath('../../')) 10 | import sphinx_rtd_theme 11 | import datetime 12 | 13 | 14 | # -- Project information ----------------------------------------------------- 15 | 16 | project = 'OpenICL' 17 | author = 'Shanghai AI Lab, SharkNLP' 18 | copyright = '{}, {}, Licenced under the Apache License, Version 2.0'.format(datetime.datetime.now().year, author) 19 | 20 | # The full version, including alpha/beta/rc tags 21 | release = 'v0.1.6' 22 | version = 'v0.1.6' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 
30 | extensions = [ 31 | 'sphinx.ext.autodoc', 32 | 'sphinx.ext.autosummary', 33 | 'sphinx.ext.doctest', 34 | 'sphinx.ext.intersphinx', 35 | 'sphinx.ext.mathjax', 36 | 'sphinx.ext.napoleon', 37 | 'sphinx.ext.viewcode', 38 | 'sphinx.ext.githubpages', 39 | ] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ['_templates'] 43 | 44 | # List of patterns, relative to source directory, that match files and 45 | # directories to ignore when looking for source files. 46 | # This pattern also affects html_static_path and html_extra_path. 47 | exclude_patterns = [] 48 | 49 | 50 | # -- Options for HTML output ------------------------------------------------- 51 | 52 | # The theme to use for HTML and HTML Help pages. See the documentation for 53 | # a list of builtin themes. 54 | # 55 | html_theme = 'sphinx_rtd_theme' 56 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 57 | 58 | # Add any paths that contain custom static files (such as style sheets) here, 59 | # relative to this directory. They are copied after the builtin static files, 60 | # so a file named "default.css" will overwrite the builtin "default.css". 61 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /docs/source/img/openicl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shark-NLP/OpenICL/1613ae10b88ba2dbfed425c4ee078b2a6586152e/docs/source/img/openicl.png -------------------------------------------------------------------------------- /docs/source/img/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shark-NLP/OpenICL/1613ae10b88ba2dbfed425c4ee078b2a6586152e/docs/source/img/overview.jpg -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/Shark-NLP/OpenICL 2 | 3 | OpenICL Documentation 4 | =================================== 5 | .. figure:: img/openicl.png 6 | 7 | **OpenICL** is an open-source framework to facilitate research, development, and prototyping of in-context learning. 8 | 9 | .. figure:: img/overview.jpg 10 | :align: center 11 | 12 | Overview of the architecture in OpenICL. 13 | 14 | It provides an easy interface for in-context learning, with many state-of-the-art retrieval and inference methods built in to facilitate systematic comparison of LMs and fast research prototyping. Users can easily incorporate different retrieval and inference methods, as well as different prompt instructions into their workflow. 15 | 16 | .. note:: 17 | 18 | This project is under active development. 19 | 20 | 21 | Citation 22 | -------- 23 | 24 | If you find this repository helpful, feel free to cite our paper: 25 | 26 | .. code-block:: bibtex 27 | 28 | @article{wu2023openicl, 29 | title={OpenICL: An Open-Source Framework for In-context Learning}, 30 | author={Zhenyu Wu, Yaoxiang Wang, Jiacheng Ye, Jiangtao Feng, Jingjing Xu, Yu Qiao, Zhiyong Wu}, 31 | journal={arXiv preprint arXiv:2303.02913}, 32 | year={2023} 33 | } 34 | 35 | .. toctree:: 36 | :glob: 37 | :maxdepth: 3 38 | :caption: Getting Started 39 | 40 | notes/installation 41 | notes/example 42 | 43 | .. toctree:: 44 | :glob: 45 | :maxdepth: 3 46 | :caption: Tutorials 47 | 48 | notes/tutorial 49 | 50 | 51 | .. 
toctree:: 52 | :glob: 53 | :maxdepth: 2 54 | :caption: Modules 55 | 56 | modules/dataset_reader 57 | modules/prompt_template 58 | modules/inferencer/inferencer_link 59 | modules/retriever/retriever_link 60 | 61 | Indices and tables 62 | ================== 63 | 64 | * :ref:`genindex` 65 | * :ref:`search` 66 | -------------------------------------------------------------------------------- /docs/source/modules/dataset_reader.rst: -------------------------------------------------------------------------------- 1 | DatasetReader 2 | ============= 3 | 4 | .. autoclass:: openicl.DatasetReader 5 | :inherited-members: -------------------------------------------------------------------------------- /docs/source/modules/inferencer/base_inferencer.rst: -------------------------------------------------------------------------------- 1 | BaseInferencer 2 | ============== 3 | 4 | .. autoclass:: openicl.icl_inferencer.icl_base_inferencer.BaseInferencer 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/inferencer/cot_inferencer.rst: -------------------------------------------------------------------------------- 1 | CoTInferencer 2 | ============= 3 | 4 | .. autoclass:: openicl.icl_inferencer.icl_cot_inferencer.CoTInferencer 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/inferencer/gen_inferencer.rst: -------------------------------------------------------------------------------- 1 | GenInferencer 2 | ============= 3 | 4 | .. autoclass:: openicl.icl_inferencer.icl_gen_inferencer.GenInferencer 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/inferencer/inferencer_link.rst: -------------------------------------------------------------------------------- 1 | Inferencer 2 | ========== 3 | 4 | .. toctree:: 5 | :glob: 6 | :maxdepth: 3 7 | :caption: Inferencer Classes 8 | 9 | base_inferencer.rst 10 | ppl_inferencer.rst 11 | gen_inferencer.rst 12 | cot_inferencer.rst 13 | -------------------------------------------------------------------------------- /docs/source/modules/inferencer/ppl_inferencer.rst: -------------------------------------------------------------------------------- 1 | PPLInferencer 2 | ============= 3 | 4 | .. autoclass:: openicl.icl_inferencer.icl_ppl_inferencer.PPLInferencer 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/prompt_template.rst: -------------------------------------------------------------------------------- 1 | PromptTemplate 2 | =============== 3 | 4 | .. autoclass:: openicl.PromptTemplate 5 | :inherited-members: -------------------------------------------------------------------------------- /docs/source/modules/retriever/base_retriever.rst: -------------------------------------------------------------------------------- 1 | BaseRetriever 2 | ============= 3 | 4 | .. autoclass:: openicl.icl_retriever.icl_base_retriever.BaseRetriever 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/retriever/bm25_retriever.rst: -------------------------------------------------------------------------------- 1 | BM25Retriever 2 | ============= 3 | 4 | .. 
autoclass:: openicl.icl_retriever.icl_bm25_retriever.BM25Retriever 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/retriever/dpp_retriever.rst: -------------------------------------------------------------------------------- 1 | DPPRetriever 2 | ============ 3 | 4 | .. autoclass:: openicl.icl_retriever.icl_dpp_retriever.DPPRetriever 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/retriever/mdl_retriever.rst: -------------------------------------------------------------------------------- 1 | MDLRetriever 2 | ============ 3 | 4 | .. autoclass:: openicl.icl_retriever.icl_mdl_retriever.MDLRetriever 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/retriever/random_retriever.rst: -------------------------------------------------------------------------------- 1 | RandomRetriever 2 | =============== 3 | 4 | .. autoclass:: openicl.icl_retriever.icl_random_retriever.RandomRetriever 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/retriever/retriever_link.rst: -------------------------------------------------------------------------------- 1 | Retriever 2 | ========== 3 | 4 | .. toctree:: 5 | :glob: 6 | :maxdepth: 3 7 | :caption: Retriever Classes 8 | 9 | base_retriever 10 | random_retriever 11 | bm25_retriever 12 | topk_retriever 13 | votek_retriever 14 | dpp_retriever 15 | mdl_retriever 16 | zero_retriever 17 | -------------------------------------------------------------------------------- /docs/source/modules/retriever/topk_retriever.rst: -------------------------------------------------------------------------------- 1 | TopkRetriever 2 | ============= 3 | 4 | .. autoclass:: openicl.icl_retriever.icl_topk_retriever.TopkRetriever 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/retriever/votek_retriever.rst: -------------------------------------------------------------------------------- 1 | VotekRetriever 2 | ============== 3 | 4 | .. autoclass:: openicl.icl_retriever.icl_votek_retriever.VotekRetriever 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/modules/retriever/zero_retriever.rst: -------------------------------------------------------------------------------- 1 | ZeroRetriever 2 | ============= 3 | 4 | .. autoclass:: openicl.icl_retriever.icl_zero_retriever.ZeroRetriever 5 | :inherited-members: 6 | -------------------------------------------------------------------------------- /docs/source/notes/example.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/Shark-NLP/OpenICL 2 | 3 | A Simple Example 4 | ======================== 5 | 6 | Following example shows you how to perform ICL on sentiment classification dataset. More examples and tutorials can be found at our github `repository `_. 7 | 8 | Step 1: Load and prepare data 9 | ---------------------------------- 10 | 11 | .. 
code-block:: python
12 | 
13 |     from datasets import load_dataset
14 |     from openicl import DatasetReader
15 | 
16 |     # Load the dataset from Hugging Face
17 |     dataset = load_dataset('gpt3mix/sst2')
18 | 
19 |     # Define a DatasetReader, specifying the column names where input and output are stored.
20 |     data = DatasetReader(dataset, input_columns=['text'], output_column='label')
21 | 
22 | 
23 | Step 2: Define the prompt template (Optional)
24 | ---------------------------------------------
25 | 
26 | .. code-block:: python
27 | 
28 |     from openicl import PromptTemplate
29 |     tp_dict = {
30 |         0: "</E>Positive Movie Review: </text>",
31 |         1: "</E>Negative Movie Review: </text>"
32 |     }
33 | 
34 |     template = PromptTemplate(tp_dict, {'text': '</text>'}, ice_token='</E>')
35 | 
36 | The placeholders ``</E>`` and ``</text>`` will be replaced by in-context examples and the testing input, respectively.
37 | 
38 | Step 3: Initialize the Retriever
39 | --------------------------------
40 | 
41 | .. code-block:: python
42 | 
43 |     from openicl import TopkRetriever
44 |     # Define a retriever using the previous `DatasetReader`.
45 |     # `ice_num` stands for the number of in-context examples.
46 |     retriever = TopkRetriever(data, ice_num=8)
47 | 
48 | Here we use the popular `TopK <https://arxiv.org/abs/2101.06804>`_ method to build the retriever.
49 | 
50 | Step 4: Initialize the Inferencer
51 | ---------------------------------
52 | 
53 | .. code-block:: python
54 | 
55 |     from openicl import PPLInferencer
56 |     inferencer = PPLInferencer(model_name='distilgpt2')
57 | 
58 | Step 5: Inference and scoring
59 | -----------------------------
60 | 
61 | .. code-block:: python
62 | 
63 |     from openicl import AccEvaluator
64 |     # The inferencer needs the retriever to collect in-context examples, as well as a template to wrap them up.
65 |     predictions = inferencer.inference(retriever, ice_template=template)
66 |     # Compute accuracy for the predictions
67 |     score = AccEvaluator().score(predictions=predictions, references=data.references)
68 |     print(score)
69 | 
-------------------------------------------------------------------------------- /docs/source/notes/installation.rst: --------------------------------------------------------------------------------
1 | :github_url: https://github.com/Shark-NLP/OpenICL
2 | 
3 | Installation
4 | ========================
5 | 
6 | .. note::
7 | 
8 |     OpenICL requires Python 3.8+
9 | 
10 | Using Pip
11 | ----------------------------------
12 | .. code-block:: bash
13 | 
14 |     pip install openicl
15 | 
16 | 
17 | Installation for local development
18 | ----------------------------------
19 | 
20 | You can install the latest version of OpenICL from our github `repository <https://github.com/Shark-NLP/OpenICL>`_.
21 | 
22 | .. code-block:: bash
23 | 
24 |     git clone https://github.com/Shark-NLP/OpenICL
25 |     cd OpenICL
26 |     pip install -e .
27 | 
-------------------------------------------------------------------------------- /docs/source/notes/tutorial.rst: --------------------------------------------------------------------------------
1 | :github_url: https://github.com/Shark-NLP/OpenICL
2 | 
3 | Tutorials List
4 | ===============
5 | 
6 | Tutorials can be found at our github `repository <https://github.com/Shark-NLP/OpenICL>`_.
7 | 
8 | .. note::
9 | 
10 |     Tutorials are still being updated.
11 | 
12 | .. list-table::
13 |     :header-rows: 1
14 | 
15 |     * - Notebook
16 |       - Description
17 |       - Colab Link
18 |     * - `Getting Started with OpenICL <https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb>`_
19 |       - Introduction to the main components of OpenICL
20 |       - ..
image:: https://colab.research.google.com/assets/colab-badge.svg 21 | :target: https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb 22 | 23 | * - `Using Different Language Models with OpenICL `_ 24 | - Run different language models with OpenICL 25 | - .. image:: https://colab.research.google.com/assets/colab-badge.svg 26 | :target: https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial2_use_different_models.ipynb 27 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | We host a wide range of [Tutorials](https://github.com/Shark-NLP/OpenICL/tree/main/examples/tutorials) to elaborate the basic usage of OpenICL. 4 | 5 | We also have some [research projects](https://github.com/Shark-NLP/OpenICL/tree/main/examples/research_projects) that reproduce results in research papers using OpenICL. 6 | 7 | Please discuss in an [issue](https://github.com/Shark-NLP/OpenICL/issues) a feature you would 8 | like to implement in an example before submitting a PR; we welcome bug fixes, 9 | but since we want to keep the examples as simple as possible it's unlikely 10 | that we will merge a pull request adding more functionality at the cost of readability. -------------------------------------------------------------------------------- /examples/research_projects/README.md: -------------------------------------------------------------------------------- 1 | # Research Projects 2 | Here, you can find the code to reproduce some ICL-related paper experiments using OpenICL(**updating...**) 3 | 4 | | Notebook | Paper | | 5 | |:----------|:-------------|:-------------| 6 | [self-adaptive in-context learning](https://github.com/Shark-NLP/OpenICL/blob/main/examples/research_projects/self-adaptive_in-context_learning.ipynb) | [self-adaptive in-context learning](https://arxiv.org/abs/2212.10375) | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/research_projects/self-adaptive_in-context_learning.ipynb) | -------------------------------------------------------------------------------- /examples/research_projects/self-adaptive_in-context_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "ec91d2ba-7ca4-4bd4-a6f5-c78a5b321179", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%pip install --upgrade pip\n", 11 | "%pip install openicl" 12 | ] 13 | }, 14 | { 15 | "attachments": {}, 16 | "cell_type": "markdown", 17 | "id": "c9acf526-7e21-4f95-9a94-fb3b00973e8d", 18 | "metadata": {}, 19 | "source": [ 20 | "# Self-adaptive In-context Learning\n", 21 | "---\n", 22 | "Code for paper [Self-adaptive In-context Learning](https://arxiv.org/abs/2212.10375)" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "fd2e9c3a-c8eb-4a03-827f-253ed84b0c7e", 28 | "metadata": {}, 29 | "source": [ 30 | "## Templates " 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 1, 36 | "id": "73a948c3-aa16-4b1e-9bc3-c6a6c0fdb7e3", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "from openicl import PromptTemplate\n", 41 | "\n", 42 | "# SST-2\n", 43 | "sst2_tp_dict = {\n", 44 | " 0: 
'Positive Movie Review: \\\"\\\"', \n", 45 | " 1: 'Negative Movie Review: \\\"\\\"',\n", 46 | "}\n", 47 | "sst2_template = PromptTemplate(sst2_tp_dict, column_token_map={'text' : ''}, ice_token='')\n", 48 | "\n", 49 | "# SST-5\n", 50 | "sst5_tp_dict = {\n", 51 | " 0: \"Review: \\nSentiment: terrible\",\n", 52 | " 1: \"Review: \\nSentiment: bad\",\n", 53 | " 2: \"Review: \\nSentiment: okay\",\n", 54 | " 3: \"Review: \\nSentiment: good\",\n", 55 | " 4: \"Review: \\nSentiment: great\",\n", 56 | "}\n", 57 | "sst5_template = PromptTemplate(sst5_tp_dict, column_token_map={'text' : ''}, ice_token='')\n", 58 | "\n", 59 | "# AG_NEWS\n", 60 | "ag_news_tp_dict = {\n", 61 | " 0: \"\\\"\\\" It is about world.\",\n", 62 | " 1: \"\\\"\\\" It is about sports.\",\n", 63 | " 2: \"\\\"\\\" It is about business.\",\n", 64 | " 3: \"\\\"\\\" It is about science and technology.\",\n", 65 | "}\n", 66 | "ag_news_template = PromptTemplate(ag_news_tp_dict, column_token_map={'text' : ''}, ice_token='')\n", 67 | "\n", 68 | "# TREC\n", 69 | "trec_tp_dict = {\n", 70 | " 0: \"\\\"\\\" It is about abbreviation.\",\n", 71 | " 1: \"\\\"\\\" It is about entity.\",\n", 72 | " 2: \"\\\"\\\" It is about description and abstract concept.\",\n", 73 | " 3: \"\\\"\\\" It is about human being.\",\n", 74 | " 4: \"\\\"\\\" It is about location.\",\n", 75 | " 5: \"\\\"\\\" It is about numeric value.\"\n", 76 | "}\n", 77 | "trec_template = PromptTemplate(trec_tp_dict, column_token_map={'text' : ''}, ice_token='')\n", 78 | "\n", 79 | "# SNLI & MNLI\n", 80 | "xnli_tp_dict = {\n", 81 | " 0: '? Yes, ',\n", 82 | " 1: '? Maybe, ',\n", 83 | " 2: '? No, '\n", 84 | "}\n", 85 | "xnli_template = PromptTemplate(xnli_tp_dict, column_token_map={'premise' : '', 'hypothesis' : ''}, ice_token='')\n", 86 | "\n", 87 | "# QNLI \n", 88 | "qnli_tp_dict = {\n", 89 | " 0: \" Can we know ? Yes.\",\n", 90 | " 1: \" Can we know ? 
No.\",\n", 91 | "}\n", 92 | "qnli_template = PromptTemplate(qnli_tp_dict, column_token_map={'sentence' : '', 'question' : ''}, ice_token='')\n", 93 | "\n", 94 | "# Commonsense QA\n", 95 | "cmsqa_template=PromptTemplate(\n", 96 | " {\n", 97 | " 'A': \"Answer the following question:\\n\\nAnswer: \",\n", 98 | " 'B': \"Answer the following question:\\n\\nAnswer: \",\n", 99 | " 'C': \"Answer the following question:\\n\\nAnswer: \",\n", 100 | " 'D': \"Answer the following question:\\n\\nAnswer: \",\n", 101 | " 'E': \"Answer the following question:\\n\\nAnswer: \",\n", 102 | " },\n", 103 | " {'question':'', 'A': '', 'B': '', 'C': '', 'D': '', 'E': ''},\n", 104 | " ice_token='' \n", 105 | ")\n", 106 | "\n", 107 | "templates = {'sst2': sst2_template,\n", 108 | " 'snli': xnli_template,\n", 109 | " 'mnli': xnli_template,\n", 110 | " \"qnli\": qnli_template,\n", 111 | " \"sst5\": sst5_template,\n", 112 | " \"ag_news\": ag_news_template,\n", 113 | " \"trec\": trec_template,\n", 114 | " \"commonsense_qa\": cmsqa_template\n", 115 | " }" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "cde382ff", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "## Datasets " 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "id": "fb05fc8a", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "from datasets import load_dataset\n", 136 | "from openicl import DatasetReader\n", 137 | "\n", 138 | "data_path = {'sst2': [\"gpt3mix/sst2\", None],\n", 139 | " 'snli': ['snli', None],\n", 140 | " 'mnli': ['LysandreJik/glue-mnli-train', None],\n", 141 | " \"qnli\": [\"glue\", \"qnli\"],\n", 142 | " \"sst5\": [\"SetFit/sst5\", None],\n", 143 | " \"ag_news\": [\"ag_news\", None],\n", 144 | " \"trec\": [\"trec\", None],\n", 145 | " \"commonsense_qa\": [\"commonsense_qa\", None]\n", 146 | " }\n", 147 | "\n", 148 | "input_columns={'sst2': [\"text\"],\n", 149 | " 'snli': ['premise', 'hypothesis'],\n", 150 | " 'mnli': ['premise', 'hypothesis'],\n", 151 | " \"qnli\": [\"sentence\", \"question\"],\n", 152 | " \"sst5\": [\"text\"],\n", 153 | " \"ag_news\": [\"text\"],\n", 154 | " \"trec\": [\"text\"],\n", 155 | " \"commonsense_qa\": ['question', 'A', 'B', 'C', 'D', 'E']\n", 156 | " }\n", 157 | "\n", 158 | "output_column={'sst2': 'label',\n", 159 | " 'snli': 'label',\n", 160 | " 'mnli': 'label',\n", 161 | " \"qnli\": 'label',\n", 162 | " \"sst5\": 'label',\n", 163 | " \"ag_news\": 'label',\n", 164 | " \"trec\": 'label-coarse',\n", 165 | " \"commonsense_qa\": \"answerKey\"\n", 166 | " }\n", 167 | "\n", 168 | "# Change it for other tasks\n", 169 | "task_name='snli'\n", 170 | "\n", 171 | "path,name=data_path[task_name]\n", 172 | "dataset = load_dataset(path=path,name=name)\n", 173 | "\n", 174 | "# Preprocess for commonsense_qa\n", 175 | "def pre_process(example):\n", 176 | " for i in range(5):\n", 177 | " example[chr(ord('A') + i)] = example['choices']['text'][i]\n", 178 | " return example\n", 179 | "\n", 180 | "if task_name=='commonsense_qa':\n", 181 | " dataset=dataset.map(pre_process).remove_columns(['question_concept', 'id', 'choices'])\n", 182 | "\n", 183 | "\n", 184 | "data=DatasetReader(dataset, input_columns=input_columns[task_name], output_column=output_column[task_name])\n", 185 | "\n", 186 | "\n", 187 | "test_split={\n", 188 | " 'sst2': 'test',\n", 189 | " 'snli': 'test',\n", 190 | " \"sst5\": 'test',\n", 191 | " \"ag_news\": 'test',\n", 192 | " \"trec\": 'test',\n", 193 | " 'mnli': 'validation', # cannot get gold 
labels for the test split\n", 194 | " \"qnli\": 'validation',\n", 195 | " \"commonsense_qa\": \"validation\"\n", 196 | "}\n", 197 | "# If you only want to test part of the test set for faster running, you can use the following codes\n", 198 | "# dataset['test'] = dataset['test'].select(list(range(100)))\n", 199 | "# dataset['validation'] = dataset['validation'].select(list(range(100))) # trec,agnews don't have validation\n", 200 | "# dataset['train'] = dataset['train'].select(list(range(100)))" 201 | ] 202 | }, 203 | { 204 | "attachments": {}, 205 | "cell_type": "markdown", 206 | "id": "5f28ecda-13f8-430c-a0c9-c85e52c806cf", 207 | "metadata": {}, 208 | "source": [ 209 | "### TopK-MDL Experiments (method in the paper)" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "4b3bc424-fe92-4b2a-9d5f-ac942ebc675f", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "from openicl import MDLRetriever, PPLInferencer, AccEvaluator\n", 220 | "\n", 221 | "retriever = MDLRetriever(data, ice_num=8, candidate_num=30, select_time=10, seed=1, batch_size=12, test_split=test_split[task_name])\n", 222 | "\n", 223 | "inferencer = PPLInferencer(model_name='gpt2-xl', batch_size=8)\n", 224 | "\n", 225 | "predictions = inferencer.inference(retriever, ice_template=templates[task_name], output_json_filename=f'mdl_{task_name}')\n", 226 | "\n", 227 | "scores = AccEvaluator().score(predictions=predictions, references=data.references)\n", 228 | "print(scores)" 229 | ] 230 | } 231 | ], 232 | "metadata": { 233 | "kernelspec": { 234 | "display_name": "Python 3", 235 | "language": "python", 236 | "name": "python3" 237 | }, 238 | "language_info": { 239 | "codemirror_mode": { 240 | "name": "ipython", 241 | "version": 3 242 | }, 243 | "file_extension": ".py", 244 | "mimetype": "text/x-python", 245 | "name": "python", 246 | "nbconvert_exporter": "python", 247 | "pygments_lexer": "ipython3", 248 | "version": "3.7.13" 249 | } 250 | }, 251 | "nbformat": 4, 252 | "nbformat_minor": 5 253 | } -------------------------------------------------------------------------------- /examples/tutorials/README.md: -------------------------------------------------------------------------------- 1 | # OpenICL Notebooks 2 | 3 | (**updating...**) 4 | 5 | | Notebook | Description | | 6 | |:----------|:-------------|:-------------| 7 | [Getting Started with OpenICL](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb) | Introduction to the main components of OpenICL | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb) | 8 | [Using Different Language Models with OpenICL](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial2_use_different_models.ipynb) | Run different language models with OpenICL | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial2_use_different_models.ipynb) | 9 | [Accelerating OpenICL with 🤗 Accelerate](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial3_accelerate.ipynb) | A Guide to Distributed Data Parallel and Model Parallel | [![Open in 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial3_accelerate.ipynb) -------------------------------------------------------------------------------- /examples/tutorials/openicl_tutorial2_use_different_models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "318733a0-eb89-4df2-b278-e80afc144cca", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%pip install --upgrade pip\n", 11 | "%pip install openicl\n", 12 | "# Restart the kernel after the installation is completed" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "1be4e6a4-dff9-4f39-a86f-62736ccd82c4", 18 | "metadata": {}, 19 | "source": [ 20 | "# 2. Using Different Language Models with OpenICL" 21 | ] 22 | }, 23 | { 24 | "attachments": {}, 25 | "cell_type": "markdown", 26 | "id": "b92d14f4-2e73-4b1c-ba9e-0f9112a764b1", 27 | "metadata": {}, 28 | "source": [ 29 | "In this chapter, we will show you how to use OpenICL to do in-context learning (ICL) with different language models. Mainly including [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf), [FLAN-T5](https://arxiv.org/abs/2109.01652), [XGLM](https://arxiv.org/abs/2112.10668), OpenAI's [GPT-3](https://arxiv.org/abs/2005.14165) API and [OPT-175B](https://arxiv.org/abs/2205.01068) API." 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "06ba89e7-2790-4097-9e5d-c75e356e152f", 35 | "metadata": {}, 36 | "source": [ 37 | "## 2-1 Huggingface Library's Models" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "ca01fccb-6188-4960-85ee-662af04718ce", 43 | "metadata": {}, 44 | "source": [ 45 | "In this section, we will take GPT2, FLAN-T5, and XGLM as examples to show you how to use the models in the [huggingface library](https://huggingface.co/models) with OpenICL. Generally speaking, you only need to assign the corresponding name to the `model_name` parameter when declaring `Inferencer`, but we will still provide you with some specific examples." 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "2aa253f5-7d3e-4810-a5fe-ddebc9699551", 51 | "metadata": {}, 52 | "source": [ 53 | "### 2-1-1 GPT-2" 54 | ] 55 | }, 56 | { 57 | "attachments": {}, 58 | "cell_type": "markdown", 59 | "id": "2a7f087a", 60 | "metadata": {}, 61 | "source": [ 62 | "This example can be found in [tutorial1](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb). But this time, we set `batch_size` for `TopkRetriever` and `PPLInference` to speed up. It can be noticed that the values ​​of the two `batch_size`(s) could be set to be different (`8` and `6`). That is because, at the beginning of retrieval and inference, the corresponding components will receive the complete dataset or the retrieval results for the entire test set, instead of processing the data in batches." 
63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "faff9428", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "from openicl import DatasetReader, PromptTemplate, TopkRetriever, PPLInferencer\n", 73 | "\n", 74 | "# Define a DatasetReader, loading dataset from huggingface.\n", 75 | "data = DatasetReader('gpt3mix/sst2', input_columns=['text'], output_column='label')\n", 76 | "\n", 77 | "# SST-2 Template Example\n", 78 | "tp_dict = {\n", 79 | " 0: 'Positive Movie Review: ',\n", 80 | " 1: 'Negative Movie Review: ' \n", 81 | "}\n", 82 | "template = PromptTemplate(tp_dict, {'text' : ''}, ice_token='')\n", 83 | "\n", 84 | "# TopK Retriever\n", 85 | "retriever = TopkRetriever(data, ice_num=2, batch_size=8)\n", 86 | "\n", 87 | "# Define a Inferencer\n", 88 | "inferencer = PPLInferencer(model_name='gpt2', batch_size=6)\n", 89 | "\n", 90 | "# Inference\n", 91 | "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='gpt2_sst2')\n", 92 | "print(predictions)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "id": "f0a7b09a-2c1a-468e-bf1a-26f69ad51b21", 98 | "metadata": {}, 99 | "source": [ 100 | "### 2-1-2 XGLM" 101 | ] 102 | }, 103 | { 104 | "attachments": {}, 105 | "cell_type": "markdown", 106 | "id": "b56bb0c2", 107 | "metadata": {}, 108 | "source": [ 109 | "When it comes to machine translation, it is a good choice to use XGLM. But when using XGLM, we **don't suggest** to set `batch_size` in `GenInferencer`. (When calling the `model.generate` method of [huggingface transformers library](https://huggingface.co/docs/transformers/index), padding is needed if you want to input multiple pieces of data at a time. But we found in the test that if padding exists, the generation of XGLM will be affected). 
The code for evaluating the ICL performance of XGLM (7.5B) on WMT16 (de-en) dataset\n", 110 | "with direct inference strategy is as follows:" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "6951df6a", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "from openicl import DatasetReader, PromptTemplate, BM25Retriever, GenInferencer\n", 121 | "from datasets import load_dataset\n", 122 | "\n", 123 | "# Loading dataset from huggingface \n", 124 | "dataset = load_dataset('wmt16', name='de-en')\n", 125 | "\n", 126 | "# Data Preprocessing\n", 127 | "dataset = dataset.map(lambda example: example['translation']).remove_columns('translation')\n", 128 | "\n", 129 | "# Define a DatasetReader, selecting 5 pieces of data randomly.\n", 130 | "data = DatasetReader(dataset, input_columns='de', output_column='en', ds_size=5)\n", 131 | "\n", 132 | "# WMT16 en->de Template Example\n", 133 | "template = PromptTemplate(' = ', {'en' : '', 'de' : ''}, ice_token='')\n", 134 | "\n", 135 | "# BM25 Retriever\n", 136 | "retriever = BM25Retriever(data, ice_num=1, index_split='validation', test_split='test', batch_size=5)\n", 137 | "\n", 138 | "# Define a Inferencer\n", 139 | "inferencer = GenInferencer(model_name='facebook/xglm-7.5B')\n", 140 | "\n", 141 | "# Inference\n", 142 | "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='xglm_wmt')\n", 143 | "print(predictions)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "id": "2518e5ee-b58e-4378-a30f-ac0535da5559", 149 | "metadata": {}, 150 | "source": [ 151 | "### 2-1-3 FLAN-T5" 152 | ] 153 | }, 154 | { 155 | "attachments": {}, 156 | "cell_type": "markdown", 157 | "id": "8778eead", 158 | "metadata": {}, 159 | "source": [ 160 | "In this section, we will use FLAN-T5 with OpenICL to reproduce the results in the figure below:" 161 | ] 162 | }, 163 | { 164 | "attachments": {}, 165 | "cell_type": "markdown", 166 | "id": "be985353", 167 | "metadata": {}, 168 | "source": [ 169 | "
\n", 170 | "\n", 171 | "

(figure in Finetuned Language Models Are Zero-Shot Learners)

\n", 172 | "
" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 13, 178 | "id": "1f592cb3-fae4-4295-baa9-1a3998c3ea19", 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "name": "stderr", 183 | "output_type": "stream", 184 | "text": [ 185 | "Found cached dataset snli (/home/zhangyudejia/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)\n", 186 | "100%|██████████| 3/3 [00:00<00:00, 323.35it/s]\n", 187 | "[2023-03-10 15:40:19,712] [openicl.icl_inferencer.icl_gen_inferencer] [INFO] Starting inference process...\n", 188 | "100%|██████████| 10/10 [00:00<00:00, 15.22it/s]" 189 | ] 190 | }, 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "['yes', 'yes', 'yes', 'no', 'yes', 'no', 'it is not possible to tell', 'yes', 'yes', 'it is not possible to tell']\n" 196 | ] 197 | }, 198 | { 199 | "name": "stderr", 200 | "output_type": "stream", 201 | "text": [ 202 | "\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "from openicl import DatasetReader, PromptTemplate, ZeroRetriever, GenInferencer\n", 208 | "\n", 209 | "# Define a DatasetReader, loading dataset from huggingface and selecting 10 pieces of data randomly.\n", 210 | "data = DatasetReader('snli', input_columns=['premise', 'hypothesis'], output_column='label', ds_size=10)\n", 211 | "\n", 212 | "# SNLI Template\n", 213 | "tp_str = 'Premise:\\nHypothesis:\\nDoes the premise entail the hypothesis?\\nOPTIONS:\\n-yes -It is not possible to tell -no'\n", 214 | "template = PromptTemplate(tp_str, column_token_map={'premise' : '', 'hypothesis' : ''}, ice_token='')\n", 215 | "\n", 216 | "# ZeroShot Retriever (do nothing)\n", 217 | "retriever = ZeroRetriever(data, index_split='train', test_split='test')\n", 218 | "\n", 219 | "# Define a Inferencer\n", 220 | "inferencer = GenInferencer(model_name='google/flan-t5-small', max_model_token_num=1000)\n", 221 | "\n", 222 | "# Inference\n", 223 | "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='flan-t5-small')\n", 224 | "print(predictions)" 225 | ] 226 | }, 227 | { 228 | "attachments": {}, 229 | "cell_type": "markdown", 230 | "id": "39e65e67", 231 | "metadata": {}, 232 | "source": [ 233 | "## 2-2 Using API-based model" 234 | ] 235 | }, 236 | { 237 | "attachments": {}, 238 | "cell_type": "markdown", 239 | "id": "17a9d790", 240 | "metadata": {}, 241 | "source": [ 242 | "OpenICL also currently supports OpenAI's GPT-3 API and OPT-175B API. But before using them, users need to do some configuration." 243 | ] 244 | }, 245 | { 246 | "attachments": {}, 247 | "cell_type": "markdown", 248 | "id": "deab3b33", 249 | "metadata": {}, 250 | "source": [ 251 | "### 2-2-1 OpenAI's GPT-3 API" 252 | ] 253 | }, 254 | { 255 | "attachments": {}, 256 | "cell_type": "markdown", 257 | "id": "5b784102", 258 | "metadata": {}, 259 | "source": [ 260 | "OpenAI provides its own open-source library -- [openai](https://github.com/openai/openai-python), for users to call their API services. To use this library in OpenICL, you need to set environment variable `OPEN_API_KEY` in advance. 
Here is a simple way (for detailed information, see openai's documentation [here](https://platform.openai.com/docs/api-reference/introduction)):" 261 | ] 262 | }, 263 | { 264 | "cell_type": "code", 265 | "execution_count": null, 266 | "id": "84f229ea", 267 | "metadata": { 268 | "vscode": { 269 | "languageId": "powershell" 270 | } 271 | }, 272 | "outputs": [], 273 | "source": [ 274 | "# Replace 'your_api_key' with your key, and run this command in bash\n", 275 | "export OPENAI_API_KEY=\"your_api_key\"" 276 | ] 277 | }, 278 | { 279 | "attachments": {}, 280 | "cell_type": "markdown", 281 | "id": "f7a7dab9", 282 | "metadata": {}, 283 | "source": [ 284 | "After the setting is complete, set `api_name='gpt3'` in `Inferencer` to use it normally. Below is a code snippet:" 285 | ] 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 1, 290 | "id": "9510bd1e", 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "name": "stderr", 295 | "output_type": "stream", 296 | "text": [ 297 | "/home/zhangyudejia/.local/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 298 | " from .autonotebook import tqdm as notebook_tqdm\n", 299 | "Found cached dataset mtop (/home/zhangyudejia/.cache/huggingface/datasets/iohadrubin___mtop/mtop/1.0.0/4ba6d9db9efaebd4f6504db7e36925632e959f456071b9d7f1b86a85cce52448)\n", 300 | "100%|██████████| 3/3 [00:00<00:00, 814.96it/s]\n", 301 | "[2023-03-10 19:03:00,481] [openicl.icl_retriever.icl_bm25_retriever] [INFO] Retrieving data for test set...\n", 302 | "100%|██████████| 1/1 [00:00<00:00, 1504.95it/s]\n", 303 | "[2023-03-10 19:03:00,486] [openicl.icl_inferencer.icl_gen_inferencer] [INFO] Starting inference process...\n", 304 | "100%|██████████| 1/1 [00:06<00:00, 6.38s/it]" 305 | ] 306 | }, 307 | { 308 | "name": "stdout", 309 | "output_type": "stream", 310 | "text": [ 311 | "['[IN:GET_EVENTS [SL:TYPE music festivals ] [SL:DATE in 2018 ] ]']\n" 312 | ] 313 | }, 314 | { 315 | "name": "stderr", 316 | "output_type": "stream", 317 | "text": [ 318 | "\n" 319 | ] 320 | } 321 | ], 322 | "source": [ 323 | "from openicl import DatasetReader, PromptTemplate, BM25Retriever, GenInferencer\n", 324 | "from datasets import load_dataset\n", 325 | "\n", 326 | "dataset = load_dataset(\"iohadrubin/mtop\")\n", 327 | "dataset['train'] = dataset['train'].select([0, 1, 2])\n", 328 | "dataset['test'] = dataset['test'].select([0])\n", 329 | "\n", 330 | "dr = DatasetReader(dataset, input_columns=['question'], output_column='logical_form') \n", 331 | "\n", 332 | "tp_str = \"\\t\" \n", 333 | "tp = PromptTemplate(tp_str, column_token_map={'question' : '', 'logical_form' : ''}, ice_token='')\n", 334 | "\n", 335 | "rtr = BM25Retriever(dr, ice_num=1)\n", 336 | "\n", 337 | "infr = GenInferencer(api_name='gpt3', engine='text-davinci-003', sleep_time=3)\n", 338 | "\n", 339 | "print(infr.inference(rtr, ice_template=tp))" 340 | ] 341 | }, 342 | { 343 | "attachments": {}, 344 | "cell_type": "markdown", 345 | "id": "9d974057", 346 | "metadata": {}, 347 | "source": [ 348 | "Some models of OpenAI are charged and have a rate limit. So we set `sleep_time`(3 seconds) here to control the frequency of data requests. In order to prevent data loss caused by throwing exceptions, we also recommend using this API on a small-scale test set every time. 
For more information about API parameter configuration in OpenICL, please view [api_service.py](https://github.com/Shark-NLP/OpenICL/blob/main/openicl/utils/api_service.py)." 349 | ] 350 | }, 351 | { 352 | "attachments": {}, 353 | "cell_type": "markdown", 354 | "id": "48635ed8", 355 | "metadata": {}, 356 | "source": [ 357 | "### 2-2-2 OPT-175B API " 358 | ] 359 | }, 360 | { 361 | "attachments": {}, 362 | "cell_type": "markdown", 363 | "id": "f3a240a2", 364 | "metadata": {}, 365 | "source": [ 366 | "For OPT-175B, you need to deploy the model yourself (or get a URL of a deployed model from your friends :\) ). \n", 367 | "Visit the [metaseq](https://github.com/facebookresearch/metaseq) repository for more information on deployment." 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "id": "9bf3c44c", 374 | "metadata": {}, 375 | "outputs": [], 376 | "source": [ 377 | "from openicl import GenInferencer\n", 378 | "\n", 379 | "URL = \"xxx\"\n", 380 | "inferencer = GenInferencer(api_name='opt-175b', URL=URL)" 381 | ] 382 | } 383 | ], 384 | "metadata": { 385 | "kernelspec": { 386 | "display_name": "Python 3", 387 | "language": "python", 388 | "name": "python3" 389 | }, 390 | "language_info": { 391 | "codemirror_mode": { 392 | "name": "ipython", 393 | "version": 3 394 | }, 395 | "file_extension": ".py", 396 | "mimetype": "text/x-python", 397 | "name": "python", 398 | "nbconvert_exporter": "python", 399 | "pygments_lexer": "ipython3", 400 | "version": "3.8.16" 401 | } 402 | }, 403 | "nbformat": 4, 404 | "nbformat_minor": 5 405 | } 406 | -------------------------------------------------------------------------------- /examples/tutorials/openicl_tutorial3_accelerate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%pip install --upgrade pip\n", 10 | "%pip install openicl\n", 11 | "# Restart the kernel after the installation is completed" 12 | ] 13 | }, 14 | { 15 | "attachments": {}, 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "# 3. Accelerating OpenICL with 🤗 Accelerate: Distributed Data Parallel and Model Parallel" 20 | ] 21 | }, 22 | { 23 | "attachments": {}, 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "In OpenICL, we use 🤗 [Accelerate](https://github.com/huggingface/accelerate) to implement Distributed Data Parallel (DDP) and Model Parallel. 🤗 [Accelerate](https://github.com/huggingface/accelerate) is a library that enables the same PyTorch code to be run across any distributed configuration by adding just a few lines of code. To quickly set up 🤗 Accelerate on your machine(s), just run:" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "vscode": { 35 | "languageId": "shellscript" 36 | } 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "accelerate config" 41 | ] 42 | }, 43 | { 44 | "attachments": {}, 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "For more details on 🤗 Accelerate, you can check the [documentation](https://huggingface.co/docs/accelerate/index) here."
49 | ] 50 | }, 51 | { 52 | "attachments": {}, 53 | "cell_type": "markdown", 54 | "metadata": {}, 55 | "source": [ 56 | "## 3-1 Distributed Data Parallel" 57 | ] 58 | }, 59 | { 60 | "attachments": {}, 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "Distributed Data Parallel (DDP) implements data parallelism at the module level which can run across multiple machines. The recommended way to use DDP is to spawn one process for each model replica, where a model replica can span multiple devices. It is quite easy to use DDP in OpenICL after completing relevant settings through ```accelerate config```, just pass in the `Accelerator` instance in `Retriever` and `Inferencer`. The following are code and script examples:" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "# test_sst2_ddp.py\n", 74 | "# Example adapted from tutorial 1-4-1\n", 75 | "\n", 76 | "from openicl import DatasetReader, PromptTemplate, TopkRetriever, PPLInferencer\n", 77 | "from accelerate import Accelerator\n", 78 | "\n", 79 | "# Accelerate Prepare\n", 80 | "accelerator = Accelerator()\n", 81 | "\n", 82 | "# Define a DatasetReader, loading dataset from huggingface.\n", 83 | "data = DatasetReader('gpt3mix/sst2', input_columns=['text'], output_column='label')\n", 84 | "\n", 85 | "# SST-2 Template Example\n", 86 | "template = PromptTemplate(template={\n", 87 | " 0: 'Positive Movie Review: ',\n", 88 | " 1: 'Negative Movie Review: ' \n", 89 | " },\n", 90 | " column_token_map={'text' : ''},\n", 91 | " ice_token=''\n", 92 | " )\n", 93 | "\n", 94 | "# TopK Retriever\n", 95 | "retriever = TopkRetriever(data, ice_num=8, index_split='train', test_split='test', accelerator=accelerator)\n", 96 | "\n", 97 | "# Define a Inferencer\n", 98 | "inferencer = PPLInferencer(model_name='distilgpt2', accelerator=accelerator)\n", 99 | "\n", 100 | "# Inference\n", 101 | "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='ddp_sst2')\n", 102 | "\n", 103 | "# print(predictions)\n", 104 | "# Seeing results at ./icl_inference_output/ddp_sst2.json" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# run_sst2.ddp.sh \n", 114 | "# Replace `${your_gpu_num}` and `${your_port_id}` with your gpu number and running port number respectively\n", 115 | "accelerate launch --num_processes ${your_gpu_num} --main_process_port ${your_port_id} test_sst2_ddp.py" 116 | ] 117 | }, 118 | { 119 | "attachments": {}, 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "## 3-2 Model Parallel" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [] 130 | } 131 | ], 132 | "metadata": { 133 | "language_info": { 134 | "name": "python" 135 | }, 136 | "orig_nbformat": 4 137 | }, 138 | "nbformat": 4, 139 | "nbformat_minor": 2 140 | } 141 | -------------------------------------------------------------------------------- /openicl/__init__.py: -------------------------------------------------------------------------------- 1 | from .icl_dataset_reader import DatasetReader 2 | from .icl_prompt_template import PromptTemplate 3 | from .icl_retriever import * 4 | from .icl_evaluator import * 5 | from .icl_inferencer import * 6 | -------------------------------------------------------------------------------- /openicl/icl_dataset_reader.py: 
-------------------------------------------------------------------------------- 1 | """Simple Dataset Reader""" 2 | 3 | from typing import List, Union, Optional, Dict 4 | from datasets import load_dataset 5 | from datasets import Dataset, DatasetDict 6 | from transformers import AutoTokenizer 7 | from datasets.splits import NamedSplit 8 | from openicl.icl_prompt_template import PromptTemplate 9 | from openicl.utils.check_type import _check_dataset, _check_type_list, _check_str 10 | import random 11 | import torch 12 | 13 | 14 | class DatasetReader: 15 | """In-conext Learning Dataset Reader Class 16 | Generate an DatasetReader instance through 'dataset'. 17 | 18 | Attributes: 19 | dataset (:obj:`Dataset` or :obj:`DatasetDict`): The dataset to be read. 20 | input_columns (:obj:`List[str]` or :obj:`str`): A list of column names (a string of column name) in the dataset that represent(s) the input field. 21 | output_column (:obj:`str`): A column name in the dataset that represents the prediction field. 22 | ds_size (:obj:`int` or :obj:`float`, optional): The number of pieces of data to return. When ds_size is an integer and greater than or equal to 1, `ds_size` pieces of data are randomly returned. When 0 < :obj:`ds_size` < 1, ``int(len(dataset) * ds_size)`` pieces of data are randomly returned. (used for testing) 23 | references(:obj:`list`, optional): The list of references, initialized by ``self.dataset[self.test_split][self.output_column]``. 24 | input_template (:obj:`PromptTemplate`, optional): An instance of the :obj:`PromptTemplate` class, used to format the input field content during the retrieval process. (in some retrieval methods) 25 | output_template (:obj:`PromptTemplate`, optional): An instance of the :obj:`PromptTemplate` class, used to format the output field content during the retrieval process. (in some learnable retrieval methods) 26 | input_output_template (:obj:`PromptTemplate`, optional): An instance of the `PromptTemplate` class, used to format the input-output field content during the retrieval process. 
(in some retrieval methods) 27 | """ 28 | dataset = None 29 | input_template = None 30 | output_template = None 31 | input_output_template = None 32 | references = None 33 | 34 | def __init__(self, 35 | dataset: Union[Dataset, DatasetDict, str], 36 | input_columns: Union[List[str], str], 37 | output_column: str, 38 | name: Optional[str] = None, 39 | data_files: Optional[str] = None, 40 | input_template: Optional[PromptTemplate] = None, 41 | output_template: Optional[PromptTemplate] = None, 42 | input_output_template: Optional[PromptTemplate] = None, 43 | ds_size: Union[None, int, float] = None, 44 | split: Optional[NamedSplit] = None, 45 | test_split: Optional[str] = 'test' 46 | ) -> None: 47 | self.input_columns = _check_type_list(input_columns, [List, str]) 48 | if isinstance(self.input_columns, str): 49 | self.input_columns = self.input_columns.split() 50 | self.output_column = _check_str(output_column) 51 | self.ds_size = _check_type_list(ds_size, [None, int, float]) 52 | if input_template is not None: 53 | self.input_template = PromptTemplate._check_prompt_template(input_template) 54 | if output_template is not None: 55 | self.output_template = PromptTemplate._check_prompt_template(output_template) 56 | if input_output_template is not None: 57 | self.input_output_template = PromptTemplate._check_prompt_template(input_output_template) 58 | if isinstance(dataset, str): 59 | self.dataset = load_dataset(dataset, name=name, data_files=data_files) 60 | else: 61 | self.dataset = _check_dataset(dataset) 62 | if split is not None and isinstance(self.dataset, DatasetDict): 63 | self.dataset = self.dataset[split] 64 | if self.ds_size is not None: 65 | if isinstance(self.dataset, Dataset): 66 | self.dataset = load_partial_dataset(dataset, size=self.ds_size) 67 | if isinstance(self.dataset, DatasetDict): 68 | for ds_name in self.dataset.keys(): 69 | self.dataset[ds_name] = load_partial_dataset(self.dataset[ds_name], size=self.ds_size) 70 | if isinstance(self.dataset, DatasetDict): 71 | if test_split in self.dataset.keys(): 72 | self.references = self.dataset[test_split][self.output_column] 73 | elif isinstance(self.dataset, Dataset): 74 | self.references = self.dataset[self.output_column] 75 | 76 | def set_references(self, column: str, split: Optional[str] = None) -> None: 77 | """Set :obj:`self.references` based on :obj:`column` and optional :obj:`split`. 78 | 79 | Args: 80 | column (:obj:`str`): A string of column name. 81 | split (:obj:`str`, optional): A string of dataset split. Defaults to ``None``. 82 | """ 83 | if split is not None: 84 | self.references = self.dataset[split][column] 85 | else: 86 | self.references = self.dataset[column] 87 | 88 | def generate_input_field_prompt(self, entry: Dict) -> str: 89 | """Generate a prompt for the input field based on the provided :obj:`entry` data. 90 | 91 | Args: 92 | entry (:obj:`Dict`): A piece of data to be used for generating the prompt. 93 | 94 | Returns: 95 | :obj:`str`: The generated prompt. 96 | """ 97 | prompt = None 98 | if self.input_template is None: 99 | prompt = ' '.join([str(entry[ctx]) for ctx in self.input_columns]) 100 | else: 101 | prompt = self.input_template.generate_item(entry) 102 | return prompt 103 | 104 | def generate_input_field_corpus(self, dataset: Union[Dataset, DatasetDict], split: Optional[str] = None) -> List[ 105 | str]: 106 | """Generate corpus for input field. 107 | 108 | Args: 109 | dataset (:obj:`Dataset` or :obj:`DatasetDict`): A :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict` instance. 
110 | split (:obj:`str`, optional): The split of the dataset to use. If :obj:`None`, the entire dataset will be used. Defaults to ``None``. 111 | 112 | Returns: 113 | :obj:`List[str]`: A list of generated input field prompts. 114 | """ 115 | if split is not None: 116 | dataset = dataset[split] 117 | corpus = [] 118 | for entry in dataset: 119 | corpus.append(self.generate_input_field_prompt(entry)) 120 | return corpus 121 | 122 | def generate_ouput_field_prompt(self, entry: Dict) -> str: 123 | """Generate a prompt for the output field based on the provided :obj:`entry` data. 124 | 125 | Args: 126 | entry (:obj:`Dict`): A piece of data to be used for generating the prompt. 127 | 128 | Returns: 129 | :obj:`str`: The generated prompt. 130 | """ 131 | prompt = None 132 | if self.output_template is None: 133 | prompt = str(entry[self.output_column]) 134 | else: 135 | prompt = self.output_template.generate_item(entry) 136 | return prompt 137 | 138 | def generate_output_field_corpus(self, dataset: Union[Dataset, DatasetDict], split: Optional[str] = None) -> List[ 139 | str]: 140 | """Generate corpus for output field. 141 | 142 | Args: 143 | dataset (:obj:`Dataset` or :obj:`DatasetDict`): A :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict` instance. 144 | split (:obj:`str`, optional): The split of the dataset to use. If :obj:`None`, the entire dataset will be used. Defaults to ``None``. 145 | 146 | Returns: 147 | :obj:`List[str]`: A list of generated output field prompts. 148 | """ 149 | if split is not None: 150 | dataset = dataset[split] 151 | corpus = [] 152 | for entry in dataset: 153 | corpus.append(self.generate_ouput_field_prompt(entry)) 154 | return corpus 155 | 156 | def generate_input_output_field_prompt(self, entry: Dict) -> str: 157 | """Generate a prompt for the input-output field based on the provided:obj:`entry` data. 158 | 159 | Args: 160 | entry (:obj:`Dict`): A piece of data to be used for generating the prompt. 161 | 162 | Returns: 163 | :obj:`str`: The generated prompt. 164 | """ 165 | prompt = None 166 | if self.input_output_template is None: 167 | prompt = ' '.join([entry[ctx] for ctx in self.input_columns] + [str(entry[self.output_column])]) 168 | else: 169 | prompt = self.input_output_template.generate_item(entry) 170 | return prompt 171 | 172 | def generate_input_output_field_corpus(self, dataset: Union[Dataset, DatasetDict], split: Optional[str] = None) -> \ 173 | List[str]: 174 | """Generate corpus for input-output field. 175 | 176 | Args: 177 | dataset (:obj:`Dataset` or :obj:`DatasetDict`): A :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict` instance. 178 | split (:obj:`str`, optional): The split of the dataset to use. If :obj:`None`, the entire dataset will be used. Defaults to ``None``. 179 | 180 | Returns: 181 | :obj:`List[str]`: A list of generated input-output field prompts. 
182 | """ 183 | if split is not None: 184 | dataset = dataset[split] 185 | corpus = [] 186 | for entry in dataset: 187 | corpus.append(self.generate_input_output_field_prompt(entry)) 188 | return corpus 189 | 190 | def _check_dataset_reader(obj) -> "DatasetReader": 191 | if isinstance(obj, DatasetReader): 192 | return obj 193 | else: 194 | raise TypeError(f"Expected a DatasetReader object, but got {obj}") 195 | 196 | def __len__(self): 197 | return len(self.dataset) 198 | 199 | def __getitem__(self, idx): 200 | return self.dataset[idx] 201 | 202 | def __repr__(self): 203 | return f"DatasetReader({{\n dataset: {self.dataset},\n input_columns: {self.input_columns},\n output_columns: {self.output_column}\n}})" 204 | 205 | 206 | def load_partial_dataset(dataset: Dataset, size: Optional[Union[int, float]] = None) -> Dataset: 207 | total_size = len(dataset) 208 | if size >= total_size or size <= 0: 209 | return dataset 210 | if size > 0 and size < 1: 211 | size = int(size * total_size) 212 | rand = random.Random(x=size) 213 | index_list = list(range(total_size)) 214 | rand.shuffle(index_list) 215 | dataset = dataset.select(index_list[:size]) 216 | return dataset 217 | 218 | 219 | class DatasetEncoder(torch.utils.data.Dataset): 220 | def __init__(self, datalist: List, model_name=None, tokenizer=None) -> None: 221 | self.datalist = datalist 222 | if model_name is None and tokenizer is None: 223 | raise ValueError("model_name and tokenizer could not both be None") 224 | if tokenizer is not None: 225 | self.tokenizer = tokenizer 226 | else: 227 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 228 | self.tokenizer.pad_token = self.tokenizer.eos_token 229 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 230 | self.tokenizer.padding_side = "left" 231 | self.encode_dataset = [] 232 | self.init_dataset() 233 | self.datalist_length = len(self.encode_dataset) 234 | 235 | def init_dataset(self): 236 | for idx, data in enumerate(self.datalist): 237 | tokenized_data = self.tokenizer.encode_plus(data, truncation=True, return_tensors='pt', verbose=False) 238 | self.encode_dataset.append({ 239 | 'input_ids': tokenized_data.input_ids[0], 240 | 'attention_mask': tokenized_data.attention_mask[0], 241 | "metadata": {"id": idx, "len": len(tokenized_data.input_ids[0]), 242 | "text": data} 243 | }) 244 | 245 | def __len__(self): 246 | return self.datalist_length 247 | 248 | def __getitem__(self, idx): 249 | return self.encode_dataset[idx] 250 | -------------------------------------------------------------------------------- /openicl/icl_evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from .icl_base_evaluator import BaseEvaluator 2 | from .icl_acc_evaluator import AccEvaluator 3 | from .icl_squad_evaluator import SquadEvaluator 4 | from .icl_bleu_evaluator import BleuEvaluator 5 | from .icl_rouge_evaluator import RougeEvaluator 6 | from .icl_api_evaluator import APIEvaluator -------------------------------------------------------------------------------- /openicl/icl_evaluator/icl_acc_evaluator.py: -------------------------------------------------------------------------------- 1 | """Acc Evaluator""" 2 | from openicl.icl_evaluator import BaseEvaluator 3 | from typing import List 4 | import evaluate 5 | 6 | 7 | class AccEvaluator(BaseEvaluator): 8 | def __init__(self) -> None: 9 | super().__init__() 10 | 11 | def score(self, predictions, references): 12 | assert len(predictions) == len(references) 13 | mapping_to_int_dict = {label: idx 
for idx, label in enumerate(set(map(str, references)))} 14 | pred_set = set(predictions) 15 | for pred in pred_set: 16 | if str(pred) not in mapping_to_int_dict.keys(): 17 | mapping_to_int_dict[str(pred)] = len(mapping_to_int_dict) 18 | golds = [mapping_to_int_dict[str(gold)] for gold in references] 19 | preds = [mapping_to_int_dict[str(pred)] for pred in predictions] 20 | metric = evaluate.load("accuracy") 21 | return metric.compute(references=golds, predictions=preds) 22 | -------------------------------------------------------------------------------- /openicl/icl_evaluator/icl_api_evaluator.py: -------------------------------------------------------------------------------- 1 | """API evaluator""" 2 | from openicl.icl_evaluator import BaseEvaluator 3 | from typing import List, Dict 4 | import evaluate 5 | 6 | 7 | class APIEvaluator(BaseEvaluator): 8 | def __init__(self, metric) -> None: 9 | super().__init__() 10 | self.metric = metric 11 | 12 | def score(self, predictions, references): 13 | assert len(predictions) == len(references) 14 | metric = evaluate.load(metric) 15 | scores = metric.compute(predictions=predictions, references=references) 16 | return scores 17 | -------------------------------------------------------------------------------- /openicl/icl_evaluator/icl_base_evaluator.py: -------------------------------------------------------------------------------- 1 | """Base Evaluator""" 2 | from typing import List 3 | 4 | 5 | class BaseEvaluator: 6 | def __init__(self) -> None: 7 | pass 8 | 9 | def score(self): 10 | raise NotImplementedError("Method hasn't been implemented yet") 11 | -------------------------------------------------------------------------------- /openicl/icl_evaluator/icl_bleu_evaluator.py: -------------------------------------------------------------------------------- 1 | """BLEU evaluator""" 2 | from openicl.icl_evaluator import BaseEvaluator 3 | from typing import List, Dict 4 | import evaluate 5 | 6 | 7 | class BleuEvaluator(BaseEvaluator): 8 | def __init__(self) -> None: 9 | super().__init__() 10 | 11 | def score(self, predictions, references): 12 | assert len(predictions) == len(references) 13 | metric = evaluate.load("sacrebleu") 14 | scores = metric.compute(predictions=predictions, references=references) 15 | return scores 16 | -------------------------------------------------------------------------------- /openicl/icl_evaluator/icl_rouge_evaluator.py: -------------------------------------------------------------------------------- 1 | """ROUGE evaluator""" 2 | from openicl.icl_evaluator import BaseEvaluator 3 | from typing import List, Dict 4 | import evaluate 5 | 6 | 7 | class RougeEvaluator(BaseEvaluator): 8 | def __init__(self) -> None: 9 | super().__init__() 10 | 11 | def score(self, predictions, references): 12 | assert len(predictions) == len(references) 13 | metric = evaluate.load("rouge") 14 | scores = metric.compute(predictions=predictions, references=references) 15 | return scores 16 | -------------------------------------------------------------------------------- /openicl/icl_evaluator/icl_squad_evaluator.py: -------------------------------------------------------------------------------- 1 | '''Squad Evaluator''' 2 | from openicl.icl_evaluator import BaseEvaluator 3 | from typing import List 4 | import evaluate 5 | 6 | 7 | class SquadEvaluator(BaseEvaluator): 8 | def __init__(self) -> None: 9 | super().__init__() 10 | 11 | def score(self, predictions, references): 12 | assert len(predictions) == len(references) 13 | p_list = 
[{'prediction_text': pred.split('\n')[0], 'id': str(i)} for i, pred in 14 | enumerate(predictions)] 15 | r_list = [{'answers': {'answer_start': [0], 'text': [ref]}, 'id': str(i)} for i, ref in 16 | enumerate(references)] 17 | metric = evaluate.load('squad') 18 | scores = metric.compute(predictions=p_list, references=r_list) 19 | return scores["f1"] 20 | -------------------------------------------------------------------------------- /openicl/icl_inferencer/__init__.py: -------------------------------------------------------------------------------- 1 | from .icl_base_inferencer import BaseInferencer 2 | from .icl_ppl_inferencer import PPLInferencer 3 | from .icl_gen_inferencer import GenInferencer 4 | from .icl_cot_inferencer import CoTInferencer 5 | from .icl_channel_inferencer import ChannelInferencer 6 | -------------------------------------------------------------------------------- /openicl/icl_inferencer/icl_base_inferencer.py: -------------------------------------------------------------------------------- 1 | """Basic Inferencer""" 2 | 3 | import os 4 | import torch 5 | from openicl import BaseRetriever, PromptTemplate 6 | from openicl.utils.api_service import * 7 | from openicl.icl_evaluator import * 8 | from transformers import AutoTokenizer, AutoModelForCausalLM, PretrainedConfig, GPT2Tokenizer, AutoConfig, \ 9 | T5ForConditionalGeneration 10 | from typing import List, Union, Optional, Any 11 | from accelerate import Accelerator 12 | from accelerate import init_empty_weights, infer_auto_device_map 13 | 14 | 15 | class BaseInferencer: 16 | """Basic In-context Learning Inferencer Class 17 | Base class of In-context Learning Inferencer, with no inference method. 18 | 19 | Attributes: 20 | model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class. 21 | tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`. 22 | max_model_token_num (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM. 23 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`. 24 | accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing. 25 | output_json_filepath (:obj:`str`, optional): File path for output `JSON` file. 26 | output_json_filename (:obj:`str`, optional): File name for output `JSON` file. 27 | api_name (:obj:`str`, optional): Name of API service. 28 | call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`. 
29 | """ 30 | model = None 31 | tokenizer = None 32 | call_api = False 33 | 34 | def __init__(self, 35 | model_name: Optional[Union[str, Any]] = 'gpt2-xl', 36 | tokenizer_name: Optional[Union[str, Any]] = None, 37 | max_model_token_num: Optional[int] = None, 38 | model_config: Optional[PretrainedConfig] = None, 39 | batch_size: Optional[int] = 1, 40 | accelerator: Optional[Accelerator] = None, 41 | output_json_filepath: Optional[str] = "./icl_inference_output", 42 | output_json_filename: Optional[str] = "predictions", 43 | api_name: Optional[str] = None, 44 | model_parallel: Optional[bool] = False, 45 | **kwargs 46 | ) -> None: 47 | self.model_name = model_name 48 | self.tokenizer_name = tokenizer_name if tokenizer_name is not None else model_name 49 | self.accelerator = accelerator 50 | self.is_main_process = True if self.accelerator is None or self.accelerator.is_main_process else False 51 | self.api_name = api_name 52 | 53 | if 'no_split_module_classes' not in kwargs.keys(): 54 | kwargs['no_split_module_classes'] = [] 55 | if 'device_map' not in kwargs.keys(): 56 | kwargs['device_map'] = None 57 | 58 | no_split_module_classes = kwargs['no_split_module_classes'] 59 | device_map = kwargs['device_map'] 60 | 61 | self.__init_api(**kwargs) 62 | if not self.call_api: 63 | self.__init_model(self.model_name, model_config, model_parallel, device_map, no_split_module_classes) 64 | self.__init_tokenizer(self.tokenizer_name) 65 | else: 66 | if self.api_name == 'opt-175b': 67 | self.__init_tokenizer(self.tokenizer_name) 68 | 69 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 70 | if self.model is not None: 71 | self.model.to(self.device) 72 | self.model.eval() 73 | self.max_model_token_num = max_model_token_num 74 | self.batch_size = batch_size 75 | self.output_json_filepath = output_json_filepath 76 | self.output_json_filename = output_json_filename 77 | if not os.path.exists(self.output_json_filepath): 78 | os.makedirs(self.output_json_filepath) 79 | 80 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None, 81 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None, 82 | output_json_filename: Optional[str] = None) -> List: 83 | """Perform In-Context Inference given a retriever and optional templates. 84 | 85 | Args: 86 | retriever (:obj:`BaseRetriever`): An instance of a Retriever class that will be used to retrieve in-context examples 87 | ice_template (:obj:`PromptTemplate`, optional): A template for generating the in-context examples prompt. Defaults to None. 88 | prompt_template (:obj:`PromptTemplate`, optional): A template for generating the final prompt. Defaults to None. 89 | output_json_filepath (:obj:`str`, optional): The file path to save the results as a `JSON` file. Defaults to None. 90 | output_json_filename (:obj:`str`, optional): The file name to save the results as a `JSON` file. Defaults to None. 91 | 92 | Raises: 93 | NotImplementedError: If the function is not implemented in the subclass. 94 | 95 | Returns: 96 | :obj:`List:` A list of string, each representing the results of one inference. 
97 | """ 98 | raise NotImplementedError("Method hasn't been implemented yet") 99 | 100 | def __init_model(self, model_name, model_config, model_parallel, device_map, no_split_module_classes): 101 | if not isinstance(model_name, str): 102 | self.model = model_name 103 | self.model_name = '' # set model name to null since we pass the loaded model already 104 | return 105 | if not model_parallel: 106 | if model_config is not None: 107 | self.model = self.__get_hf_model_from_config(model_name, model_config) 108 | else: 109 | self.model = self.__get_hf_model_from_name(model_name) 110 | else: 111 | if model_config is None: 112 | model_config = AutoConfig.from_pretrained(model_name) 113 | with init_empty_weights(): 114 | empty_model = AutoModelForCausalLM.from_config(model_config) 115 | 116 | if device_map is None: 117 | device_map = infer_auto_device_map(empty_model, no_split_module_classes=no_split_module_classes, 118 | dtype="float16") 119 | 120 | self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map, 121 | offload_folder="offload", offload_state_dict=True, 122 | torch_dtype=torch.float16) 123 | 124 | def __get_hf_model_from_name(self, model_name): 125 | if 't5' in model_name: 126 | return T5ForConditionalGeneration.from_pretrained(model_name) 127 | else: 128 | return AutoModelForCausalLM.from_pretrained(model_name) 129 | 130 | def __get_hf_model_from_config(self, model_name, model_config): 131 | if 't5' in model_name: 132 | raise TypeError("T5 model has no 'from_config' method") 133 | else: 134 | return AutoModelForCausalLM.from_config(model_config) 135 | 136 | def __init_tokenizer(self, tokenizer_name): 137 | if self.api_name == 'opt-175b': 138 | self.tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b", use_fast=False) 139 | else: 140 | if not isinstance(tokenizer_name, str): 141 | self.tokenizer = tokenizer_name 142 | else: 143 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 144 | self.tokenizer.pad_token = self.tokenizer.eos_token 145 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 146 | self.tokenizer.padding_side = "left" 147 | 148 | def __init_api(self, **kwargs): 149 | if self.api_name == None: 150 | return 151 | self.call_api = is_api_available(self.api_name) 152 | if not self.call_api: 153 | UserWarning(f"api_name '{self.api_name}' is not available, Please check it") 154 | else: 155 | update_openicl_api_request_config(self.api_name, **kwargs) 156 | 157 | def get_input_token_num(self, inputs): 158 | return len(self.tokenizer(inputs, verbose=False)['input_ids']) 159 | 160 | 161 | class GenInferencerOutputHandler: 162 | origin_prompt_dict = {} 163 | output_dict = {} 164 | prediction_dict = {} 165 | results_dict = {} 166 | 167 | def __init__(self, 168 | num: int, 169 | accelerator: Optional[Accelerator] = None 170 | ) -> None: 171 | self.num = num 172 | self.accelerator = accelerator 173 | self.origin_prompt_dict = {} 174 | self.output_dict = {} 175 | self.prediction_dict = {} 176 | self.results_dict = {} 177 | 178 | def subprocess_write_to_json(self, output_json_filepath: str, output_json_filename: str): 179 | self.results_dict = { 180 | str(idx): { 181 | 'origin_prompt': self.origin_prompt_dict[str(idx)], 182 | 'output': self.output_dict[str(idx)], 183 | 'prediction': self.prediction_dict[str(idx)] 184 | } for idx in self.origin_prompt_dict.keys() 185 | } 186 | if self.accelerator is not None: 187 | with open(f'{output_json_filepath}/process{self.accelerator.process_index}_{output_json_filename}.json', 188 | 'w', 
encoding='utf-8') as json_file: 189 | json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False) 190 | json_file.close() 191 | 192 | def write_to_json(self, output_json_filepath: str, output_json_filename: str): 193 | with open(f'{output_json_filepath}/{output_json_filename}.json', 'w', encoding='utf-8') as json_file: 194 | json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False) 195 | json_file.close() 196 | 197 | def merge_to_main_process(self, output_json_filepath: str, output_json_filename: str): 198 | if self.accelerator is not None and self.accelerator.is_main_process: 199 | for pid in range(self.accelerator.num_processes): 200 | with open(f'{output_json_filepath}/process{pid}_{output_json_filename}.json', 'r', 201 | encoding='utf-8') as json_file: 202 | subprocess_results_dict = json.load(json_file) 203 | self.results_dict.update(subprocess_results_dict) 204 | json_file.close() 205 | self.results_dict = dict(sorted(self.results_dict.items(), key=lambda x: int(x[0]))) 206 | 207 | def save_orgin_prompts(self, origin_prompts: List[str]): 208 | for idx, origin_prompt in enumerate(origin_prompts): 209 | if self.accelerator is not None: 210 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index 211 | self.origin_prompt_dict[str(idx)] = origin_prompt 212 | 213 | def save_prediction_and_output(self, prediction, output, idx): 214 | if self.accelerator is not None: 215 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index 216 | self.prediction_dict[str(idx)] = prediction 217 | self.output_dict[str(idx)] = output 218 | 219 | 220 | class PPLInferencerOutputHandler: 221 | results_dict = {} 222 | 223 | def __init__(self, 224 | accelerator: Optional[Accelerator] = None 225 | ) -> None: 226 | self.accelerator = accelerator 227 | self.results_dict = {} 228 | 229 | def subprocess_write_to_json(self, output_json_filepath: str, output_json_filename: str): 230 | if self.accelerator is not None: 231 | with open(f'{output_json_filepath}/process{self.accelerator.process_index}_{output_json_filename}.json', 232 | 'w', encoding='utf-8') as json_file: 233 | json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False) 234 | json_file.close() 235 | 236 | def write_to_json(self, output_json_filepath: str, output_json_filename: str): 237 | with open(f'{output_json_filepath}/{output_json_filename}.json', 'w', encoding='utf-8') as json_file: 238 | json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False) 239 | json_file.close() 240 | 241 | def merge_to_main_process(self, output_json_filepath: str, output_json_filename: str): 242 | if self.accelerator is not None and self.accelerator.is_main_process: 243 | for pid in range(self.accelerator.num_processes): 244 | with open(f'{output_json_filepath}/process{pid}_{output_json_filename}.json', 'r', 245 | encoding='utf-8') as json_file: 246 | subprocess_results_dict = json.load(json_file) 247 | self.results_dict.update(subprocess_results_dict) 248 | json_file.close() 249 | self.results_dict = dict(sorted(self.results_dict.items(), key=lambda x: int(x[0]))) 250 | 251 | def save_ice(self, ice): 252 | for idx, example in enumerate(ice): 253 | if self.accelerator is not None: 254 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index 255 | if str(idx) not in self.results_dict.keys(): 256 | self.results_dict[str(idx)] = {} 257 | self.results_dict[str(idx)]['in-context examples'] = example 258 | 259 | def save_predictions(self, predictions): 260 | for idx, 
prediction in enumerate(predictions): 261 | if self.accelerator is not None: 262 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index 263 | if str(idx) not in self.results_dict.keys(): 264 | self.results_dict[str(idx)] = {} 265 | self.results_dict[str(idx)]['prediction'] = prediction 266 | 267 | def save_prompt_and_ppl(self, label, input, prompt, ppl, idx): 268 | if self.accelerator is not None: 269 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index 270 | if str(idx) not in self.results_dict.keys(): 271 | self.results_dict[str(idx)] = {} 272 | if 'label: ' + str(label) not in self.results_dict[str(idx)].keys(): 273 | self.results_dict[str(idx)]['label: ' + str(label)] = {} 274 | self.results_dict[str(idx)]['label: ' + str(label)]['testing input'] = input 275 | self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt 276 | self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl 277 | -------------------------------------------------------------------------------- /openicl/icl_inferencer/icl_channel_inferencer.py: -------------------------------------------------------------------------------- 1 | """PPL Inferencer""" 2 | 3 | import json 4 | import torch 5 | from openicl import PromptTemplate 6 | from openicl.icl_retriever import * 7 | from openicl.icl_evaluator import * 8 | from openicl.icl_inferencer.icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler 9 | from openicl.icl_inferencer.icl_ppl_inferencer import PPLInferencer 10 | from openicl.utils.logging import get_logger 11 | from openicl.utils.api_service import * 12 | from typing import List, Union, Optional 13 | from tqdm import tqdm 14 | from tqdm import trange 15 | from transformers import PretrainedConfig 16 | from accelerate import Accelerator 17 | 18 | logger = get_logger(__name__) 19 | 20 | 21 | class ChannelInferencer(PPLInferencer): 22 | """PPL In-context Learning Inferencer Class 23 | Channel In-context Learning Inferencer. 24 | We recommend you to use ppl inferencer instead of channel inferencer 25 | 26 | """ 27 | 28 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None, 29 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None, 30 | output_json_filename: Optional[str] = None, normalizing_str: Optional[str] = None) -> List: 31 | # 1. Preparation for output logs 32 | output_handler = PPLInferencerOutputHandler(self.accelerator) 33 | 34 | sub_predictions = [] 35 | ppl = [] 36 | ice = [] 37 | 38 | if output_json_filepath is None: 39 | output_json_filepath = self.output_json_filepath 40 | if output_json_filename is None: 41 | output_json_filename = self.output_json_filename 42 | 43 | # 2. Get results of retrieval process 44 | ice_idx_list = retriever.retrieve() 45 | 46 | # 3. Get labels of all the classes 47 | if self.labels is None: 48 | labels = retriever.get_labels(ice_template=ice_template, prompt_template=prompt_template) 49 | else: 50 | labels = self.labels 51 | 52 | # 4. Generate in-context examples for testing inputs 53 | for idx in range(len(ice_idx_list)): 54 | ice.append(retriever.generate_ice(ice_idx_list[idx], ice_template=ice_template)) 55 | output_handler.save_ice(ice) 56 | 57 | # 5. 
Calculating PPL for prompts in each label's class 58 | for label in labels: 59 | index = 0 60 | prompt_list = [] 61 | sub_ppl_list = [] 62 | context_length_list = [] 63 | 64 | # 5.1 Generate prompts of current label and truncate 65 | for idx in range(len(ice_idx_list)): 66 | prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template, 67 | prompt_template=prompt_template, 68 | remain_sep=True) 69 | if self.max_model_token_num is not None and self.api_name != 'gpt3': 70 | prompt_token_num = self.get_input_token_num(prompt) 71 | while len(ice_idx_list[idx]) > 0 and prompt_token_num > self.max_model_token_num: 72 | ice_idx_list[idx] = ice_idx_list[idx][:-1] 73 | ice[idx] = retriever.generate_ice(ice_idx_list[idx], ice_template=ice_template) 74 | prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template, 75 | prompt_template=prompt_template) 76 | prompt_token_num = self.get_input_token_num(prompt) 77 | 78 | prompt_sep = prompt 79 | if prompt_template is not None: 80 | sep_token = prompt_template.sep_token 81 | else: 82 | sep_token = ice_template.sep_token 83 | sep_pos = prompt_sep.find(sep_token) 84 | context = prompt_sep[0:sep_pos] 85 | prompt = prompt_sep.replace(sep_token, '') 86 | context_length_list.append(self.get_input_token_num(context)) 87 | prompt_list.append(prompt) 88 | 89 | # 5.2 Get PPL 90 | logger.info(f"Calculating PPL for prompts labeled '{label}'") 91 | for idx in trange(0, len(prompt_list), self.batch_size, disable=not self.is_main_process): 92 | sub_prompt_list = prompt_list[idx:idx + self.batch_size] 93 | sub_context_length_list = context_length_list[idx:idx + self.batch_size] 94 | 95 | with torch.no_grad(): 96 | sub_res = self.__get_ppl(input_texts=sub_prompt_list, mask_length=sub_context_length_list) 97 | for res, prompt in zip(sub_res, sub_prompt_list): 98 | sub_ppl_list.append(res) 99 | output_handler.save_prompt_and_ppl(label, prompt[len(ice[idx]):], prompt, res, index) 100 | index = index + 1 101 | ppl.append(sub_ppl_list) 102 | 103 | # 6. Get lowest PPL class as predictions 104 | ppl = list(zip(*ppl)) 105 | for single_ppl in ppl: 106 | sub_predictions.append(labels[single_ppl.index(min(single_ppl))]) 107 | output_handler.save_predictions(sub_predictions) 108 | 109 | # 7. 
Output 110 | output_handler.subprocess_write_to_json(output_json_filepath, output_json_filename) 111 | if self.accelerator is not None: 112 | self.accelerator.wait_for_everyone() 113 | output_handler.merge_to_main_process(output_json_filepath, output_json_filename) 114 | output_handler.write_to_json(output_json_filepath, output_json_filename) 115 | 116 | return [sample['prediction'] for sample in output_handler.results_dict.values()] 117 | -------------------------------------------------------------------------------- /openicl/icl_inferencer/icl_cot_inferencer.py: -------------------------------------------------------------------------------- 1 | """chain-of-thought inferencer""" 2 | 3 | import torch 4 | from openicl import PromptTemplate 5 | from openicl.icl_retriever import * 6 | from openicl.icl_evaluator import * 7 | from openicl.icl_inferencer.icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler 8 | from typing import List, Union, Optional 9 | from tqdm import tqdm 10 | from transformers import PretrainedConfig 11 | from openicl.utils.api_service import * 12 | from openicl.utils.icl_common_utils import get_dataloader, get_generation_prompt_list_from_retriever_indices 13 | from openicl.utils.logging import get_logger 14 | from accelerate import Accelerator 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | class CoTInferencer(BaseInferencer): 20 | """COT In-context Learning Inferencer Class 21 | Chain-of-Thought In-context Learning Inferencer. 22 | 23 | Attributes: 24 | model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class. 25 | tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`. 26 | max_model_token_num (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM. 27 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`. 28 | accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing. 29 | output_json_filepath (:obj:`str`, optional): File path for output `JSON` file. 30 | output_json_filename (:obj:`str`, optional): File name for output `JSON` file. 31 | api_name (:obj:`str`, optional): Name of API service. 32 | call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`. 33 | gen_field_replace_token (:obj:`str`, optional): Used to replace the generation field token when generating prompts. 34 | generation_kwargs (:obj:`Dict`, optional): Parameters for the :obj:`model.generate()` method. 35 | cot_list (:obj:`list`, optional): A list of sentences used for multiple-step generations. 
36 | """ 37 | 38 | def __init__(self, 39 | cot_list: Optional[List[str]] = [], 40 | model_name: Optional[str] = 'gpt2-xl', 41 | tokenizer_name: Optional[str] = None, 42 | max_model_token_num: Optional[int] = None, 43 | model_config: Optional[PretrainedConfig] = None, 44 | batch_size: Optional[int] = 1, 45 | gen_field_replace_token: Optional[str] = '', 46 | generation_kwargs={"max_new_tokens": 100}, 47 | accelerator: Optional[Accelerator] = None, 48 | output_json_filepath: Optional[str] = "./icl_inference_output", 49 | output_json_filename: Optional[str] = "predictions", 50 | api_name: Optional[str] = None, 51 | model_parallel: Optional[bool] = False, 52 | **kwargs 53 | ) -> None: 54 | super().__init__(model_name, tokenizer_name, max_model_token_num, model_config, batch_size, accelerator, 55 | output_json_filepath, output_json_filename, api_name, model_parallel, **kwargs) 56 | self.cot_list = cot_list 57 | self.gen_field_replace_token = gen_field_replace_token 58 | self.generation_kwargs = generation_kwargs 59 | 60 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None, 61 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None, 62 | output_json_filename: Optional[str] = None) -> List: 63 | # 1. Preparation for output logs 64 | num = len(retriever.test_ds) 65 | output_handler = GenInferencerOutputHandler(num, self.accelerator) 66 | index = 0 67 | if output_json_filepath is None: 68 | output_json_filepath = self.output_json_filepath 69 | if output_json_filename is None: 70 | output_json_filename = self.output_json_filename 71 | 72 | # 2. Get results of retrieval process 73 | ice_idx_list = retriever.retrieve() 74 | cot_list_len = len(self.cot_list) 75 | 76 | # 3. Generate prompts for testing input 77 | prompt_list = get_generation_prompt_list_from_retriever_indices(ice_idx_list, retriever, self.tokenizer, 78 | self.gen_field_replace_token, 79 | max_model_token_num=self.max_model_token_num, 80 | ice_template=ice_template, 81 | prompt_template=prompt_template) 82 | if cot_list_len > 0: 83 | prompt_list = [prompt + self.cot_list[0] for prompt in prompt_list] 84 | 85 | # 4. Inference for `max((len(self.cot_list) + 1), 1)` times 86 | for idx in range(0, max(cot_list_len, 1)): 87 | index = 0 88 | cot_idx = idx + 1 89 | # 4-1. Wrap prompts with Dataloader 90 | dataloader = get_dataloader(prompt_list, self.batch_size) 91 | output_handler.save_orgin_prompts(prompt_list) 92 | 93 | for entry in tqdm(dataloader, disable=not self.is_main_process): 94 | # 4-2-1. Inference with local model 95 | if not self.call_api: 96 | with torch.no_grad(): 97 | tokenized_data = self.tokenizer.batch_encode_plus(entry, padding=True, return_tensors='pt').to( 98 | self.device) 99 | prompt_len = int(tokenized_data.attention_mask.shape[1]) 100 | if 't5' in self.model_name: 101 | prompt_len = 0 102 | outputs = self.model.generate(input_ids=tokenized_data.input_ids, 103 | attention_mask=tokenized_data.attention_mask, 104 | eos_token_id=self.tokenizer.eos_token_id, 105 | pad_token_id=self.tokenizer.pad_token_id, 106 | **self.generation_kwargs) 107 | outputs = outputs.tolist() 108 | complete_output = self.tokenizer.batch_decode(outputs[:], skip_special_tokens=True) 109 | generated = self.tokenizer.batch_decode([output[prompt_len:] for output in outputs], 110 | skip_special_tokens=True) 111 | # 4-2-2. Inference with remote API 112 | else: 113 | complete_output, generated = api_get_tokens(self.api_name, entry) 114 | 115 | # 4-2-3. 
Save current output 116 | for prediction, output in zip(generated, complete_output): 117 | if 't5' in self.model_name: 118 | output = prompt_list[index] + output 119 | output_handler.save_prediction_and_output(prediction, output, index) 120 | prompt_list[index] = output 121 | index = index + 1 122 | 123 | # 4-3. Output for current step 124 | if cot_idx < cot_list_len: 125 | filename = output_json_filename + f'_step{idx}' 126 | else: 127 | filename = output_json_filename 128 | output_handler.subprocess_write_to_json(output_json_filepath, filename) 129 | if self.accelerator is not None: 130 | self.accelerator.wait_for_everyone() 131 | output_handler.merge_to_main_process(output_json_filepath, filename) 132 | output_handler.write_to_json(output_json_filepath, filename) 133 | 134 | # 4-4. Check for next string in `self.cot_list` 135 | if cot_idx < cot_list_len: 136 | prompt_list = [(prompt + str(self.cot_list[cot_idx])) for prompt in prompt_list] 137 | else: 138 | break 139 | return [sample['prediction'] for sample in output_handler.results_dict.values()] 140 | -------------------------------------------------------------------------------- /openicl/icl_inferencer/icl_gen_inferencer.py: -------------------------------------------------------------------------------- 1 | """Direct Generation Inferencer""" 2 | 3 | import json 4 | import torch 5 | from openicl import PromptTemplate 6 | from openicl.icl_retriever import * 7 | from openicl.icl_evaluator import * 8 | from openicl.icl_inferencer.icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler 9 | from openicl.utils.api_service import * 10 | from openicl.utils.icl_common_utils import get_dataloader, get_generation_prompt_list_from_retriever_indices 11 | from openicl.utils.logging import get_logger 12 | from typing import List, Union, Optional 13 | from tqdm import tqdm 14 | from transformers import PretrainedConfig 15 | from accelerate import Accelerator 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | class GenInferencer(BaseInferencer): 21 | """Generation In-context Learning Inferencer Class 22 | In-context Learning Inferencer for Directly Generation. 23 | 24 | Attributes: 25 | model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class. 26 | tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`. 27 | max_model_token_num (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM. 28 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`. 29 | accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing. 30 | output_json_filepath (:obj:`str`, optional): File path for output `JSON` file. 31 | output_json_filename (:obj:`str`, optional): File name for output `JSON` file. 32 | api_name (:obj:`str`, optional): Name of API service. 33 | call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`. 34 | gen_field_replace_token (:obj:`str`, optional): Used to replace the generation field token when generating prompts. 35 | generation_kwargs (:obj:`Dict`, optional): Parameters for the :obj:`model.generate()` method. 
36 | """ 37 | 38 | def __init__(self, 39 | model_name: Optional[str] = 'gpt2-xl', 40 | tokenizer_name: Optional[str] = None, 41 | max_model_token_num: Optional[int] = None, 42 | model_config: Optional[PretrainedConfig] = None, 43 | batch_size: Optional[int] = 1, 44 | gen_field_replace_token: Optional[str] = '', 45 | generation_kwargs={"max_new_tokens": 100}, 46 | accelerator: Optional[Accelerator] = None, 47 | output_json_filepath: Optional[str] = "./icl_inference_output", 48 | output_json_filename: Optional[str] = "predictions", 49 | api_name: Optional[str] = None, 50 | model_parallel: Optional[bool] = False, 51 | **kwargs 52 | ) -> None: 53 | super().__init__(model_name, tokenizer_name, max_model_token_num, model_config, batch_size, accelerator, 54 | output_json_filepath, output_json_filename, api_name, model_parallel, **kwargs) 55 | self.gen_field_replace_token = gen_field_replace_token 56 | self.generation_kwargs = generation_kwargs 57 | 58 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None, 59 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None, 60 | output_json_filename: Optional[str] = None, force_words=None) -> List: 61 | # 1. Preparation for output logs 62 | num = len(retriever.test_ds) 63 | output_handler = GenInferencerOutputHandler(num, self.accelerator) 64 | index = 0 65 | 66 | if output_json_filepath is None: 67 | output_json_filepath = self.output_json_filepath 68 | if output_json_filename is None: 69 | output_json_filename = self.output_json_filename 70 | 71 | # 2. Get results of retrieval process 72 | ice_idx_list = retriever.retrieve() 73 | 74 | # 3. Generate prompts for testing input 75 | prompt_list = get_generation_prompt_list_from_retriever_indices(ice_idx_list, retriever, self.tokenizer, 76 | self.gen_field_replace_token, 77 | max_model_token_num=self.max_model_token_num, 78 | ice_template=ice_template, 79 | prompt_template=prompt_template) 80 | output_handler.save_orgin_prompts(prompt_list) 81 | 82 | # 4. Wrap prompts with Dataloader 83 | dataloader = get_dataloader(prompt_list, self.batch_size) 84 | 85 | # 5. Inference for prompts in each batch 86 | logger.info("Starting inference process...") 87 | for entry in tqdm(dataloader, disable=not self.is_main_process): 88 | # 5-1. 
Inference with local model 89 | if not self.call_api: 90 | with torch.no_grad(): 91 | tokenized_data = self.tokenizer.batch_encode_plus(entry, padding=True, return_tensors='pt').to( 92 | self.device) 93 | prompt_len = int(tokenized_data.attention_mask.shape[1]) 94 | if 't5' in self.model_name: 95 | prompt_len = 0 96 | if force_words is not None: 97 | force_words_ids = [ 98 | self.tokenizer(force_words).input_ids, 99 | ] 100 | outputs = self.model.generate(input_ids=tokenized_data.input_ids, 101 | force_words_ids=force_words_ids, 102 | num_beams=10, 103 | attention_mask=tokenized_data.attention_mask, 104 | eos_token_id=self.tokenizer.eos_token_id, 105 | pad_token_id=self.tokenizer.pad_token_id, 106 | **self.generation_kwargs) 107 | else: 108 | outputs = self.model.generate(input_ids=tokenized_data.input_ids, 109 | attention_mask=tokenized_data.attention_mask, 110 | eos_token_id=self.tokenizer.eos_token_id, 111 | pad_token_id=self.tokenizer.pad_token_id, 112 | **self.generation_kwargs) 113 | outputs = outputs.tolist() 114 | complete_output = self.tokenizer.batch_decode(outputs[:], skip_special_tokens=True) 115 | generated = self.tokenizer.batch_decode([output[prompt_len:] for output in outputs], 116 | skip_special_tokens=True) 117 | # 5-2. Inference with remote API 118 | else: 119 | complete_output, generated = api_get_tokens(self.api_name, entry) 120 | 121 | # 5-3. Save current output 122 | for prediction, output in zip(generated, complete_output): 123 | output_handler.save_prediction_and_output(prediction, output, index) 124 | index = index + 1 125 | 126 | # 6. Output 127 | output_handler.subprocess_write_to_json(output_json_filepath, output_json_filename) 128 | if self.accelerator is not None: 129 | self.accelerator.wait_for_everyone() 130 | output_handler.merge_to_main_process(output_json_filepath, output_json_filename) 131 | output_handler.write_to_json(output_json_filepath, output_json_filename) 132 | return [sample['prediction'] for sample in output_handler.results_dict.values()] 133 | -------------------------------------------------------------------------------- /openicl/icl_inferencer/icl_ppl_inferencer.py: -------------------------------------------------------------------------------- 1 | """PPL Inferencer""" 2 | 3 | import json 4 | import torch 5 | from openicl import PromptTemplate 6 | from openicl.icl_retriever import * 7 | from openicl.icl_evaluator import * 8 | from openicl.icl_inferencer.icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler 9 | from openicl.utils.logging import get_logger 10 | from openicl.utils.api_service import * 11 | from typing import List, Union, Optional 12 | from tqdm import tqdm 13 | from tqdm import trange 14 | from transformers import PretrainedConfig 15 | from accelerate import Accelerator 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | class PPLInferencer(BaseInferencer): 21 | """PPL In-context Learning Inferencer Class 22 | Perplexity-based In-context Learning Inferencer. 23 | 24 | Attributes: 25 | model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class. 26 | tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`. 27 | max_model_token_num (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM. 28 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`. 
29 | accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing. 30 | output_json_filepath (:obj:`str`, optional): File path for output `JSON` file. 31 | output_json_filename (:obj:`str`, optional): File name for output `JSON` file. 32 | api_name (:obj:`str`, optional): Name of API service. 33 | call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`. 34 | labels (:obj:`List`, optional): A list of labels for all classes. 35 | """ 36 | 37 | def __init__(self, 38 | model_name: Optional[str] = 'gpt2-xl', 39 | tokenizer_name: Optional[str] = None, 40 | max_model_token_num: Optional[int] = None, 41 | model_config: Optional[PretrainedConfig] = None, 42 | batch_size: Optional[int] = 1, 43 | accelerator: Optional[Accelerator] = None, 44 | output_json_filepath: Optional[str] = "./icl_inference_output", 45 | output_json_filename: Optional[str] = "predictions", 46 | api_name: Optional[str] = None, 47 | labels: Optional[List] = None, 48 | model_parallel: Optional[bool] = False, 49 | **kwargs 50 | ) -> None: 51 | super().__init__(model_name, tokenizer_name, max_model_token_num, model_config, batch_size, accelerator, 52 | output_json_filepath, output_json_filename, api_name, model_parallel, **kwargs) 53 | self.labels = labels 54 | 55 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None, 56 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None, 57 | output_json_filename: Optional[str] = None, normalizing_str: Optional[str] = None) -> List: 58 | # 1. Preparation for output logs 59 | output_handler = PPLInferencerOutputHandler(self.accelerator) 60 | 61 | sub_predictions = [] 62 | ppl = [] 63 | ice = [] 64 | 65 | if output_json_filepath is None: 66 | output_json_filepath = self.output_json_filepath 67 | if output_json_filename is None: 68 | output_json_filename = self.output_json_filename 69 | 70 | # 2. Get results of retrieval process 71 | ice_idx_list = retriever.retrieve() 72 | 73 | # 3. Get labels of all the classes 74 | if self.labels is None: 75 | labels = retriever.get_labels(ice_template=ice_template, prompt_template=prompt_template) 76 | else: 77 | labels = self.labels 78 | 79 | # 4. Generate in-context examples for testing inputs 80 | for idx in range(len(ice_idx_list)): 81 | ice.append(retriever.generate_ice(ice_idx_list[idx], ice_template=ice_template)) 82 | output_handler.save_ice(ice) 83 | 84 | # 5. 
Calculating PPL for prompts in each label's class 85 | for label in labels: 86 | index = 0 87 | prompt_list = [] 88 | sub_ppl_list = [] 89 | normalizing_prompt_list = [] 90 | context_length_list = [] 91 | 92 | # 5.1 Generate prompts of current label and truncate 93 | for idx in range(len(ice_idx_list)): 94 | prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template, 95 | prompt_template=prompt_template, 96 | remain_sep=normalizing_str is not None) 97 | if self.max_model_token_num is not None and self.api_name != 'gpt3': 98 | prompt_token_num = self.get_input_token_num(prompt) 99 | while len(ice_idx_list[idx]) > 0 and prompt_token_num > self.max_model_token_num: 100 | ice_idx_list[idx] = ice_idx_list[idx][:-1] 101 | ice[idx] = retriever.generate_ice(ice_idx_list[idx], ice_template=ice_template) 102 | prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template, 103 | prompt_template=prompt_template) 104 | prompt_token_num = self.get_input_token_num(prompt) 105 | 106 | if normalizing_str is not None: 107 | prompt_sep = prompt 108 | if prompt_template is not None: 109 | sep_token = prompt_template.sep_token 110 | else: 111 | sep_token = ice_template.sep_token 112 | sep_pos = prompt_sep.find(sep_token) 113 | 114 | context = prompt_sep[0:sep_pos] 115 | answer = prompt_sep[sep_pos:].replace(sep_token, '') 116 | prompt = context + answer 117 | normalizing_prompt = normalizing_str + answer 118 | 119 | context_length_list.append(self.get_input_token_num(context)) 120 | normalizing_prompt_list.append(normalizing_prompt) 121 | prompt_list.append(prompt) 122 | 123 | if normalizing_str is not None: 124 | normalizing_str_len = self.get_input_token_num(normalizing_str) 125 | 126 | # 5.2 Get PPL 127 | logger.info(f"Calculating PPL for prompts labeled '{label}'") 128 | for idx in trange(0, len(prompt_list), self.batch_size, disable=not self.is_main_process): 129 | sub_prompt_list = prompt_list[idx:idx + self.batch_size] 130 | if normalizing_str is not None: 131 | sub_context_length_list = context_length_list[idx:idx + self.batch_size] 132 | sub_normalizing_prompt_list = normalizing_prompt_list[idx:idx + self.batch_size] 133 | 134 | with torch.no_grad(): 135 | if normalizing_str is not None: 136 | res1 = self.__get_ppl(input_texts=sub_prompt_list, mask_length=sub_context_length_list) 137 | res2 = self.__get_ppl(input_texts=sub_normalizing_prompt_list, 138 | mask_length=[normalizing_str_len for i in range(len(sub_prompt_list))] 139 | ) 140 | sub_res = res1 - res2 141 | else: 142 | sub_res = self.__get_ppl(sub_prompt_list).tolist() 143 | for res, prompt in zip(sub_res, sub_prompt_list): 144 | sub_ppl_list.append(res) 145 | output_handler.save_prompt_and_ppl(label, prompt[len(ice[idx]):], prompt, res, index) 146 | index = index + 1 147 | ppl.append(sub_ppl_list) 148 | 149 | # 6. Get lowest PPL class as predictions 150 | ppl = list(zip(*ppl)) 151 | for single_ppl in ppl: 152 | sub_predictions.append(labels[single_ppl.index(min(single_ppl))]) 153 | output_handler.save_predictions(sub_predictions) 154 | 155 | # 7. 
Output 156 | output_handler.subprocess_write_to_json(output_json_filepath, output_json_filename) 157 | if self.accelerator is not None: 158 | self.accelerator.wait_for_everyone() 159 | output_handler.merge_to_main_process(output_json_filepath, output_json_filename) 160 | output_handler.write_to_json(output_json_filepath, output_json_filename) 161 | 162 | return [sample['prediction'] for sample in output_handler.results_dict.values()] 163 | 164 | def __get_ppl(self, input_texts: List[str], mask_length=None): 165 | if self.call_api: 166 | return api_get_ppl(self.api_name, input_texts) 167 | self.tokenizer.padding_side = "right" 168 | inputs = self.tokenizer(input_texts, padding=True, return_tensors='pt', truncation=True) 169 | inputs = {k: v.to(self.model.device) for k, v in inputs.items()} 170 | outputs = self.model(**inputs) 171 | 172 | shift_logits = outputs.logits[..., :-1, :].contiguous() 173 | shift_labels = inputs["input_ids"][..., 1:].contiguous() 174 | 175 | loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=self.tokenizer.pad_token_id) 176 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view( 177 | shift_labels.size()) 178 | 179 | if mask_length is not None: 180 | mask = torch.zeros_like(shift_labels) # [batch,seqlen] 181 | for i in range(len(mask)): 182 | for j in range(mask_length[i] - 1, len(mask[i])): 183 | mask[i][j] = 1 184 | loss = loss * mask 185 | 186 | lens = (inputs["input_ids"] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy() 187 | if mask_length is not None: 188 | lens -= np.array(mask_length) 189 | ce_loss = loss.sum(-1).cpu().detach().numpy() / lens 190 | return ce_loss 191 | -------------------------------------------------------------------------------- /openicl/icl_prompt_template.py: -------------------------------------------------------------------------------- 1 | """Prompt Template""" 2 | 3 | from typing import Dict, Optional, Union, Hashable 4 | from .utils.check_type import _check_type_list, _check_dict 5 | 6 | 7 | class PromptTemplate: 8 | """In-context Learning Prompt Template Class 9 | This class represents a template that guides the generation of prompts in the retrieval or inference process. 10 | 11 | Attributes: 12 | template (:obj:`Dict` or :obj:`str`): A custom template dictionary or string. If a dictionary, the keys of the dictionary represent the values of the output_column, and the values represent the corresponding generated statement. If a string, it represents a string template. 13 | column_token_map (:obj:`Dict`): A dictionary mapping column names to specific tokens. The tokens will be replaced by data in the corresponding column (one piece each time) during the retrieval or inference process. 14 | selected_column_name (:obj:`str`, optional): Used only with string-type templates. A specific column that needs its value to be mapped. 15 | selected_column_map (:obj:`Dict`, optional): Used only with string-type templates. Maps the value of the column :obj:`selected_column_name`. 16 | ice_token(:obj:`str`, optional): A string that represents the specific token mapping from in-context examples. None if you want to use this template only to generate in-context examples, otherwise it can be used to generate the final prompt that is fed into the PLM. The ice_token will be invisible when generating in-context examples. 
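    Example:
        A minimal construction sketch for a binary sentiment task. The label keys ``0``/``1``, the ``text``
        column and the ``</E>``/``</text>`` marker strings are illustrative assumptions, not values fixed by
        this class:

            from openicl import PromptTemplate

            # One template string per value of the output column; '</E>' marks where retrieved
            # in-context examples are inserted, '</text>' is replaced by each entry's 'text' column.
            tp_dict = {
                0: '</E>Review: </text> Sentiment: negative',
                1: '</E>Review: </text> Sentiment: positive',
            }
            template = PromptTemplate(tp_dict,
                                      column_token_map={'text': '</text>'},
                                      ice_token='</E>')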
 17 |     """
 18 | 
 19 |     def __init__(self,
 20 |                  template: Union[Dict, str],
 21 |                  column_token_map: Dict,
 22 |                  selected_column_name: Optional[str] = None,
 23 |                  selected_column_map: Optional[Dict] = None,
 24 |                  ice_token: Optional[str] = None,
 25 |                  sep_token: Optional[str] = None,
 26 |                  ) -> None:
 27 |         self.template = _check_type_list(template, [Dict, str])
 28 |         self.column_token_map = _check_dict(column_token_map)
 29 |         self.selected_column_name = _check_type_list(selected_column_name, [None, str])
 30 |         self.selected_column_map = _check_type_list(selected_column_map, [None, Dict])
 31 |         self.ice_token = _check_type_list(ice_token, [None, str])
 32 |         self.sep_token = _check_type_list(sep_token, [None, str])
 33 |         if (self.selected_column_name is not None and self.selected_column_map is None) or \
 34 |                 self.selected_column_name is None and self.selected_column_map is not None:
 35 |             raise ValueError("self.selected_column_name and self.selected_column_map should be set together")
 36 |         self._check_template_legacy()
 37 | 
 38 |     def _check_template_legacy(self):
 39 |         if isinstance(self.template, Dict):
 40 |             # Check if token exists in values of tp_dict
 41 |             for tp_dict_val in self.template.values():
 42 |                 if not isinstance(tp_dict_val, str):
 43 |                     raise TypeError(f"dictionary of template expects a str value, but got '{tp_dict_val}'")
 44 |                 if self.ice_token is not None and self.ice_token not in tp_dict_val:
 45 |                     raise LookupError(f"'{self.ice_token}' not in '{tp_dict_val}'")
 46 |         if isinstance(self.template, str):
 47 |             if self.ice_token is not None and self.ice_token not in self.template:
 48 |                 raise LookupError(f"'{self.ice_token}' not in '{self.template}'")
 49 | 
 50 |         # Check duplicates
 51 |         if len(self.column_token_map.values()) != len(set(self.column_token_map.values())):
 52 |             raise ValueError(f"There are duplicates in self.column_token_map.values()")
 53 |         if self.ice_token is not None and self.ice_token in self.column_token_map.values():
 54 |             raise ValueError(f"There are duplicates between self.column_token_map.values() and self.ice_token")
 55 | 
 56 |     def generate_ice_item(self, entry: Dict, label: Hashable) -> str:
 57 |         """Generate in-context example based on the provided :obj:`entry` data.
 58 | 
 59 |         Args:
 60 |             entry (:obj:`Dict`): A piece of data to be used for generating the in-context example.
 61 |             label (:obj:`Hashable`): The value of the output field.
 62 | 
 63 |         Returns:
 64 |             :obj:`str`: The generated in-context example.
 65 |         """
 66 |         # Select the corresponding template
 67 |         tp = self.template[label] if isinstance(self.template, Dict) else self.template
 68 |         # Remove sep token
 69 |         if self.sep_token is not None:
 70 |             tp = tp.replace(self.sep_token, '')
 71 |         # Remove ice_token
 72 |         if self.ice_token is not None:
 73 |             tp = tp.replace(self.ice_token, '')
 74 |         # Replace context token
 75 |         for key, token in self.column_token_map.items():
 76 |             if self.selected_column_map is not None and key == self.selected_column_name:
 77 |                 tp = tp.replace(token, str(self.selected_column_map[label]))
 78 |             else:
 79 |                 tp = tp.replace(token, str(entry[key]))
 80 |         return tp
 81 | 
 82 |     def generate_label_prompt_item(self, entry: Dict, ice: str, label: Hashable, remain_sep: Optional[bool] = False) -> str:
 83 |         """Generate prompt based on :obj:`entry` data, :obj:`ice` in-context example, and the corresponding :obj:`label`.
 84 | 
 85 |         Args:
 86 | 
 87 |             entry (:obj:`Dict`): A piece of data containing the input field content.
 88 |             ice (:obj:`str`): The generated in-context example.
 89 |             label (:obj:`Hashable`): The value of the output field.
 90 |             remain_sep (:obj:`bool`): Whether to keep the :obj:`sep_token` in the generated prompt. Defaults to ``False``.
 91 | 
 92 |         Raises:
 93 |             ValueError: If the :obj:`ice_token` attribute of the :obj:`PromptTemplate` instance is :obj:`None`.
 94 | 
 95 |         Returns:
 96 |             :obj:`str`: The generated prompt.
 97 |         """
 98 |         if self.ice_token is None:
 99 |             raise ValueError("PromptTemplate.ice_token must not be None when generating a prompt")
100 |         # Select the corresponding template
101 |         tp = self.template[label] if isinstance(self.template, Dict) else self.template
102 |         # Remove sep token
103 |         if not remain_sep and self.sep_token is not None:
104 |             tp = tp.replace(self.sep_token, '')
105 |         # Insert in-context examples
106 |         tp = tp.replace(self.ice_token, ice)
107 |         # Replace context token
108 |         for key, token in self.column_token_map.items():
109 |             if self.selected_column_map is not None and key == self.selected_column_name:
110 |                 tp = tp.replace(token, str(self.selected_column_map[label]))
111 |             else:
112 |                 tp = tp.replace(token, str(entry[key]))
113 |         return tp
114 | 
115 | 
116 |     def generate_item(self, entry: Dict, output_field: Optional[Hashable] = None,
117 |                       output_field_replace_token: Optional[str] = '',
118 |                       ice_field_replace_token: Optional[str] = '') -> str:
119 |         """Generate an item based on the provided :obj:`entry` data, as well as optional output field and ice field tokens.
120 | 
121 |         Args:
122 |             entry (:obj:`Dict`): A piece of data.
123 |             output_field (:obj:`Hashable`, optional): Column name of output field. Defaults to :obj:`None`.
124 |             output_field_replace_token (:obj:`str`, optional): Tokens used to replace output field. Defaults to ``''``.
125 |             ice_field_replace_token (:obj:`str`, optional): Tokens used to replace the :obj:`ice_token`. Defaults to ``''``.
126 | 
127 |         Returns:
128 |             :obj:`str`: The generated item.
129 |         """
130 |         tp = None
131 |         if isinstance(self.template, str):
132 |             tp = self.template
133 |         else:
134 |             pred_label = None
135 |             if self.selected_column_name is not None:
136 |                 pred_label = entry[self.selected_column_name]
137 |             if pred_label in self.template.keys():
138 |                 tp = self.template[pred_label]
139 |             else:
140 |                 tp = self.template[list(self.template.keys())[0]]
141 |         if self.ice_token is not None:
142 |             tp = tp.replace(self.ice_token, ice_field_replace_token)
143 |         # Remove sep token
144 |         if self.sep_token is not None:
145 |             tp = tp.replace(self.sep_token, '')
146 |         for key, token in self.column_token_map.items():
147 |             if output_field is not None and key == output_field:
148 |                 tp = tp.replace(token, output_field_replace_token)
149 |             else:
150 |                 tp = tp.replace(token, str(entry[key]))
151 |         return tp
152 | 
153 |     def _check_prompt_template(obj) -> "PromptTemplate":
154 |         if isinstance(obj, PromptTemplate):
155 |             return obj
156 |         else:
157 |             raise TypeError(f"Expect a PromptTemplate object, but got {obj}")
158 | 
159 |     def __repr__(self):
160 |         return f"PromptTemplate({{\n\ttemplate: {self.template},\n\tcolumn_token_map: {self.column_token_map},\n\tice_token: {self.ice_token}\n}})"
161 | 
--------------------------------------------------------------------------------
/openicl/icl_retriever/__init__.py:
--------------------------------------------------------------------------------
  1 | from .icl_base_retriever import BaseRetriever
  2 | from .icl_random_retriever import RandomRetriever
  3 | from .icl_bm25_retriever import BM25Retriever
  4 | from .icl_dpp_retriever import DPPRetriever
  5 | from .icl_topk_retriever import TopkRetriever
  6 | from .icl_mdl_retriever import MDLRetriever
  7 | from .icl_votek_retriever import VotekRetriever
  8 | from
.icl_zero_retriever import ZeroRetriever 9 | -------------------------------------------------------------------------------- /openicl/icl_retriever/icl_base_retriever.py: -------------------------------------------------------------------------------- 1 | """Basic Retriever""" 2 | 3 | from datasets import Dataset, DatasetDict 4 | from typing import List, Union, Optional, Tuple, Dict 5 | from openicl import DatasetReader, PromptTemplate 6 | from openicl.utils.check_type import _check_str 7 | from accelerate import Accelerator 8 | 9 | 10 | class BaseRetriever: 11 | """Basic In-context Learning Retriever Class 12 | Base class for In-context Learning Retriever, without any retrieval method. 13 | 14 | Attributes: 15 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class. 16 | ice_separator (:obj:`str`, optional): A string that separates each in-context example. 17 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples. 18 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt. 19 | ice_num (:obj:`int`, optional): The number of data in the in-context examples. 20 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``. 21 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``. 22 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples. 23 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data. 24 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing. 25 | """ 26 | index_ds = None 27 | test_ds = None 28 | 29 | def __init__(self, 30 | dataset_reader: DatasetReader, 31 | ice_separator: Optional[str] = '\n', 32 | ice_eos_token: Optional[str] = '\n', 33 | prompt_eos_token: Optional[str] = '', 34 | ice_num: Optional[int] = 1, 35 | index_split: Optional[str] = 'train', 36 | test_split: Optional[str] = 'test', 37 | accelerator: Optional[Accelerator] = None 38 | ) -> None: 39 | self.dataset_reader = DatasetReader._check_dataset_reader(dataset_reader) 40 | self.ice_separator = ice_separator 41 | self.ice_eos_token = ice_eos_token 42 | self.prompt_eos_token = prompt_eos_token 43 | self.ice_num = ice_num 44 | self.index_split = index_split 45 | self.test_split = test_split 46 | self.accelerator = accelerator 47 | self.is_main_process = True if self.accelerator is None or self.accelerator.is_main_process else False 48 | if isinstance(self.dataset_reader.dataset, Dataset): 49 | self.index_ds = self.dataset_reader.dataset 50 | self.test_ds = self.dataset_reader.dataset 51 | if self.accelerator is not None: 52 | self.test_ds = self.test_ds.shard( 53 | num_shards=self.accelerator.num_processes, 54 | index=self.accelerator.process_index 55 | ) 56 | else: 57 | self.index_ds = self.dataset_reader.dataset[self.index_split] 58 | self.test_ds = self.dataset_reader.dataset[self.test_split] 59 | 60 | if self.accelerator is not None: 61 | self.test_ds = self.test_ds.shard( 62 | num_shards=self.accelerator.num_processes, 63 | index=self.accelerator.process_index 64 | ) 65 | 66 | def retrieve(self) -> List[List]: 67 | """ 68 | Retrieve for each data in generation_ds. 69 | 70 | Returns: 71 | `List[List]`: the index list of in-context example for each data in `test_ds`. 
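        Example:
            A minimal sketch of a custom subclass (purely illustrative, not part of the library) showing the
            expected return shape, one list of index-split row ids per entry in `test_ds`:

                from openicl.icl_retriever import BaseRetriever

                class FirstKRetriever(BaseRetriever):
                    def retrieve(self):
                        # always reuse the first `ice_num` rows of the index split as in-context examples
                        return [list(range(self.ice_num)) for _ in range(len(self.test_ds))]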
72 | """ 73 | raise NotImplementedError("Method hasn't been implemented yet") 74 | 75 | def get_labels(self, ice_template: Optional[PromptTemplate] = None, 76 | prompt_template: Optional[PromptTemplate] = None): 77 | labels = [] 78 | if prompt_template is not None and isinstance(prompt_template.template, Dict): 79 | labels = list(prompt_template.template.keys())[:] 80 | elif ice_template is not None and ice_template.ice_token is not None and isinstance(ice_template.template, 81 | Dict): 82 | labels = list(ice_template.template.keys())[:] 83 | else: 84 | labels = list(set(self.test_ds[self.dataset_reader.output_column])) 85 | return labels 86 | 87 | def generate_ice(self, idx_list: List[int], ice_template: Optional[PromptTemplate] = None) -> str: 88 | generated_ice_list = [] 89 | dr = self.dataset_reader 90 | for idx in idx_list: 91 | if ice_template is None: 92 | generated_ice_list.append(' '.join(list(map(str, 93 | [self.index_ds[idx][ctx] for ctx in dr.input_columns] + [ 94 | self.index_ds[idx][dr.output_column]])))) 95 | else: 96 | generated_ice_list.append( 97 | ice_template.generate_ice_item(self.index_ds[idx], self.index_ds[idx][dr.output_column])) 98 | generated_ice = self.ice_separator.join(generated_ice_list) + self.ice_eos_token 99 | return generated_ice 100 | 101 | def generate_prompt(self, idx: int, ice: str, ice_template: Optional[PromptTemplate] = None, 102 | prompt_template: Optional[PromptTemplate] = None) -> Tuple[List[str], List]: 103 | prompt_list = [] 104 | labels = [] 105 | if prompt_template is not None and isinstance(prompt_template.template, Dict): 106 | labels = list(prompt_template.template.keys())[:] 107 | elif ice_template is not None and isinstance(ice_template.template, 108 | Dict) and ice_template.ice_token is not None: 109 | labels = list(ice_template.template.keys())[:] 110 | else: 111 | labels = list(set(self.test_ds[self.dataset_reader.output_column])) 112 | for label in labels: 113 | prompt_list.append(self.generate_label_prompt(idx, ice, label)) 114 | return prompt_list, labels 115 | 116 | def generate_label_prompt(self, idx: int, ice: str, label, ice_template: Optional[PromptTemplate] = None, 117 | prompt_template: Optional[PromptTemplate] = None, remain_sep: Optional[bool] = False) -> str: 118 | if prompt_template is not None: 119 | return prompt_template.generate_label_prompt_item(self.test_ds[idx], ice, label, remain_sep) + self.prompt_eos_token 120 | elif ice_template is not None and ice_template.ice_token is not None: 121 | return ice_template.generate_label_prompt_item(self.test_ds[idx], ice, label, remain_sep) + self.prompt_eos_token 122 | else: 123 | prefix_prompt = ' '.join( 124 | list(map(str, [self.test_ds[idx][ctx] for ctx in self.dataset_reader.input_columns]))) 125 | return ice + prefix_prompt + ' ' + str(label) + self.prompt_eos_token 126 | 127 | def generate_prompt_for_generate_task(self, idx, ice, gen_field_replace_token='', 128 | ice_template: Optional[PromptTemplate] = None, 129 | prompt_template: Optional[PromptTemplate] = None): 130 | if prompt_template is not None: 131 | return prompt_template.generate_item(self.test_ds[idx], output_field=self.dataset_reader.output_column, 132 | output_field_replace_token=gen_field_replace_token, 133 | ice_field_replace_token=ice) + self.prompt_eos_token 134 | elif ice_template is not None and ice_template.ice_token is not None: 135 | return ice_template.generate_item(self.test_ds[idx], output_field=self.dataset_reader.output_column, 136 | output_field_replace_token=gen_field_replace_token, 137 | 
ice_field_replace_token=ice) + self.prompt_eos_token 138 | else: 139 | prefix_prompt = ' '.join( 140 | list(map(str, [self.test_ds[idx][ctx] for ctx in self.dataset_reader.input_columns]))) 141 | return ice + prefix_prompt + gen_field_replace_token + self.prompt_eos_token 142 | -------------------------------------------------------------------------------- /openicl/icl_retriever/icl_bm25_retriever.py: -------------------------------------------------------------------------------- 1 | """BM25 Retriever""" 2 | 3 | from openicl import DatasetReader 4 | from openicl.icl_retriever import BaseRetriever 5 | from openicl.utils.logging import get_logger 6 | from typing import List, Union, Optional 7 | from rank_bm25 import BM25Okapi 8 | import numpy as np 9 | from tqdm import trange 10 | from accelerate import Accelerator 11 | from nltk.tokenize import word_tokenize 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | class BM25Retriever(BaseRetriever): 17 | """BM25 In-context Learning Retriever Class 18 | Class of BM25 Retriever. 19 | 20 | Attributes: 21 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class. 22 | ice_separator (:obj:`str`, optional): A string that separates each in-context example. 23 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples. 24 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt. 25 | ice_num (:obj:`int`, optional): The number of data in the in-context examples. 26 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``. 27 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``. 28 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples. 29 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data. 30 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing. 31 | index_corpus (:obj:`List[str]`) : A corpus created from the input field data of :obj:`index_ds`. 32 | test_corpus (:obj:`List[str]`) : A corpus created from the input field data of :obj:`test_ds`. 33 | bm25 (:obj:`BM250kapi`): An instance of :obj:`BM250kapi` class, initialized using :obj:`index_ds`. 
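    Example:
        A minimal usage sketch. The dataset name, the column names, and the ``DatasetReader`` argument names
        (``input_columns``/``output_column``) are illustrative assumptions; nltk's ``punkt`` tokenizer data
        must be available for ``word_tokenize``:

            from datasets import load_dataset
            from openicl import DatasetReader
            from openicl.icl_retriever import BM25Retriever

            dataset = load_dataset('gpt3mix/sst2')    # assumed example dataset
            data = DatasetReader(dataset, input_columns=['text'], output_column='label')

            retriever = BM25Retriever(data, ice_num=8)
            ice_idx_list = retriever.retrieve()       # List[List[int]]: 8 highest-scoring index rows per test example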
 34 |     """
 35 |     bm25 = None
 36 |     index_corpus = None
 37 |     test_corpus = None
 38 | 
 39 |     def __init__(self,
 40 |                  dataset_reader: DatasetReader,
 41 |                  ice_separator: Optional[str] = '\n',
 42 |                  ice_eos_token: Optional[str] = '\n',
 43 |                  prompt_eos_token: Optional[str] = '',
 44 |                  ice_num: Optional[int] = 1,
 45 |                  index_split: Optional[str] = 'train',
 46 |                  test_split: Optional[str] = 'test',
 47 |                  accelerator: Optional[Accelerator] = None
 48 |                  ) -> None:
 49 |         super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token, ice_num, index_split,
 50 |                          test_split, accelerator)
 51 |         self.index_corpus = [word_tokenize(data) for data in
 52 |                              self.dataset_reader.generate_input_field_corpus(self.index_ds)]
 53 |         self.bm25 = BM25Okapi(self.index_corpus)
 54 |         self.test_corpus = [word_tokenize(data) for data in
 55 |                             self.dataset_reader.generate_input_field_corpus(self.test_ds)]
 56 | 
 57 |     def retrieve(self) -> List[List]:
 58 |         rtr_idx_list = []
 59 |         logger.info("Retrieving data for test set...")
 60 |         for idx in trange(len(self.test_corpus), disable=not self.is_main_process):
 61 |             query = self.test_corpus[idx]
 62 |             scores = self.bm25.get_scores(query)
 63 |             near_ids = list(np.argsort(scores)[::-1][:self.ice_num])
 64 |             near_ids = [int(a) for a in near_ids]
 65 |             rtr_idx_list.append(near_ids)
 66 |         return rtr_idx_list
 67 | 
--------------------------------------------------------------------------------
/openicl/icl_retriever/icl_dpp_retriever.py:
--------------------------------------------------------------------------------
  1 | """DPP Retriever"""
  2 | 
  3 | from openicl import DatasetReader
  4 | from openicl.icl_retriever.icl_topk_retriever import TopkRetriever
  5 | from openicl.utils.logging import get_logger
  6 | from typing import Optional
  7 | import tqdm
  8 | import numpy as np
  9 | import math
 10 | from accelerate import Accelerator
 11 | 
 12 | logger = get_logger(__name__)
 13 | 
 14 | 
 15 | class DPPRetriever(TopkRetriever):
 16 |     """DPP In-context Learning Retriever Class
 17 |         Class of DPP Retriever.
 18 |         A two-stage DPP is used, where the first stage retrieves TopK results to reduce the candidate set;
 19 |         see https://arxiv.org/abs/2302.05698 for details.
 20 | 
 21 |     Attributes:
 22 |         dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
 23 |         ice_separator (:obj:`str`, optional): A string that separates each in-context example.
 24 |         ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
 25 |         prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
 26 |         ice_num (:obj:`int`, optional): The number of data in the in-context examples.
 27 |         index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
 28 |         test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
 29 |         index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
 30 |         test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
 31 |         accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
 32 |         batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
 33 |         model (:obj:`SentenceTransformer`): An instance of :obj:`SentenceTransformer` class, used to calculate embeddings.
34 | tokenizer (:obj:`AutoTokenizer`): Tokenizer for :obj:`model`. 35 | index (:obj:`IndexIDMap`): Index generated with FAISS. 36 | seed (:obj:`int`, optional): Seed for the random number generator. (:obj:`random_state` in :obj:`sample_exact_k_dpp` method) 37 | scale_factor (:obj:`float`, optional): A factor when gets the kernel. 38 | """ 39 | model = None 40 | 41 | def __init__(self, 42 | dataset_reader: DatasetReader, 43 | ice_separator: Optional[str] = '\n', 44 | ice_eos_token: Optional[str] = '\n', 45 | prompt_eos_token: Optional[str] = '', 46 | sentence_transformers_model_name: Optional[str] = 'all-mpnet-base-v2', 47 | ice_num: Optional[int] = 1, 48 | candidate_num: Optional[int] = 1, 49 | index_split: Optional[str] = 'train', 50 | test_split: Optional[str] = 'test', 51 | tokenizer_name: Optional[str] = 'gpt2-xl', 52 | batch_size: Optional[int] = 1, 53 | accelerator: Optional[Accelerator] = None, 54 | seed: Optional[int] = 1, 55 | scale_factor: Optional[float] = 0.1 56 | ) -> None: 57 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token, 58 | sentence_transformers_model_name, ice_num, index_split, test_split, tokenizer_name, batch_size, 59 | accelerator) 60 | self.candidate_num = candidate_num 61 | self.seed = seed 62 | self.scale_factor = scale_factor 63 | 64 | def dpp_search(self): 65 | res_list = self.forward(self.dataloader, process_bar=True, information="Embedding test set...") 66 | rtr_idx_list = [[] for _ in range(len(res_list))] 67 | logger.info("Retrieving data for test set...") 68 | for entry in tqdm.tqdm(res_list, disable=not self.is_main_process): 69 | idx = entry['metadata']['id'] 70 | 71 | # get TopK results 72 | embed = np.expand_dims(entry['embed'], axis=0) 73 | near_ids = np.array(self.index.search(embed, self.candidate_num)[1][0].tolist()) 74 | 75 | # DPP stage 76 | near_reps, rel_scores, kernel_matrix = self.get_kernel(embed, near_ids.tolist()) 77 | 78 | # MAP inference 79 | samples_ids = fast_map_dpp(kernel_matrix, self.ice_num) 80 | 81 | # ordered by relevance score 82 | samples_scores = np.array([rel_scores[i] for i in samples_ids]) 83 | samples_ids = samples_ids[(-samples_scores).argsort()].tolist() 84 | rtr_sub_list = [int(near_ids[i]) for i in samples_ids] 85 | 86 | rtr_idx_list[idx] = rtr_sub_list 87 | 88 | return rtr_idx_list 89 | 90 | def retrieve(self): 91 | return self.dpp_search() 92 | 93 | def get_kernel(self, embed, candidates): 94 | near_reps = np.stack([self.index.index.reconstruct(i) for i in candidates], axis=0) 95 | # normalize first 96 | embed = embed / np.linalg.norm(embed) 97 | near_reps = near_reps / np.linalg.norm(near_reps, keepdims=True, axis=1) 98 | 99 | # to make kernel-matrix non-negative 100 | rel_scores = np.matmul(embed, near_reps.T)[0] 101 | rel_scores = (rel_scores + 1) / 2 102 | 103 | # to prevent overflow error 104 | rel_scores -= rel_scores.max() 105 | 106 | # to balance relevance and diversity 107 | rel_scores = np.exp(rel_scores / (2 * self.scale_factor)) 108 | 109 | # to make kernel-matrix non-negative 110 | sim_matrix = np.matmul(near_reps, near_reps.T) 111 | sim_matrix = (sim_matrix + 1) / 2 112 | 113 | kernel_matrix = rel_scores[None] * sim_matrix * rel_scores[:, None] 114 | return near_reps, rel_scores, kernel_matrix 115 | 116 | 117 | def fast_map_dpp(kernel_matrix, max_length): 118 | """ 119 | fast implementation of the greedy algorithm 120 | reference: https://github.com/laming-chen/fast-map-dpp/blob/master/dpp_test.py 121 | paper: Fast Greedy MAP Inference for Determinantal Point Process to 
Improve Recommendation Diversity 122 | """ 123 | item_size = kernel_matrix.shape[0] 124 | cis = np.zeros((max_length, item_size)) 125 | di2s = np.copy(np.diag(kernel_matrix)) 126 | selected_items = list() 127 | selected_item = np.argmax(di2s) 128 | selected_items.append(int(selected_item)) 129 | while len(selected_items) < max_length: 130 | k = len(selected_items) - 1 131 | ci_optimal = cis[:k, selected_item] 132 | di_optimal = math.sqrt(di2s[selected_item]) 133 | elements = kernel_matrix[selected_item, :] 134 | eis = (elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal 135 | cis[k, :] = eis 136 | di2s -= np.square(eis) 137 | selected_item = np.argmax(di2s) 138 | selected_items.append(int(selected_item)) 139 | return selected_items 140 | -------------------------------------------------------------------------------- /openicl/icl_retriever/icl_mdl_retriever.py: -------------------------------------------------------------------------------- 1 | """MDL Retriever""" 2 | 3 | from openicl import DatasetReader, PromptTemplate 4 | from openicl.icl_retriever.icl_topk_retriever import TopkRetriever 5 | from openicl.utils.calculate import entropy 6 | from openicl.utils.logging import get_logger 7 | from typing import List, Union, Optional, Tuple 8 | from transformers import AutoModelForCausalLM 9 | import tqdm 10 | import torch 11 | import numpy as np 12 | from accelerate import Accelerator 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | class MDLRetriever(TopkRetriever): 18 | """MDL In-context Learning Retriever Class 19 | Class of MDL Retriever. 20 | 21 | Attributes: 22 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class. 23 | ice_separator (:obj:`str`, optional): A string that separates each in-context example. 24 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples. 25 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt. 26 | ice_num (:obj:`int`, optional): The number of data in the in-context examples. 27 | candidate_num (:obj:`int`, optional): The number of data selected in TopK stage. 28 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``. 29 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``. 30 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples. 31 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data. 32 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing. 33 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`. 34 | model (:obj:`SentenceTransformer`): An instance of :obj:`SentenceTransformer` class, used to calculate embeddings. 35 | tokenizer (:obj:`AutoTokenizer`): Tokenizer for :obj:`model`. 36 | index (:obj:`IndexIDMap`): Index generated with FAISS. 37 | select_time (:obj:`int`, optional): Number of random selections in the MDL stage. 38 | labels (:obj:`List`, optional): A list of labels for all classes used to generate prompts when calculating MDL. 39 | seed (:obj:`int`, optional): Seed for the random number generator. 
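    Example:
        A minimal usage sketch (the dataset name, column names, ``DatasetReader`` arguments and template are
        illustrative assumptions). The retriever first takes ``candidate_num`` nearest neighbours per test
        example, then scores ``select_time`` candidate subsets of them (the top-k subset plus random draws)
        and keeps the subset whose label distribution under ``ce_model_name`` has the lowest entropy; the
        prompts used for this scoring are built from ``ice_template`` when one is given:

            from datasets import load_dataset
            from openicl import DatasetReader, PromptTemplate
            from openicl.icl_retriever import MDLRetriever

            dataset = load_dataset('gpt3mix/sst2')    # assumed example dataset
            data = DatasetReader(dataset, input_columns=['text'], output_column='label')
            template = PromptTemplate({0: '</E>Review: </text> Sentiment: negative',
                                       1: '</E>Review: </text> Sentiment: positive'},
                                      column_token_map={'text': '</text>'}, ice_token='</E>')

            retriever = MDLRetriever(data, ice_num=4, candidate_num=30, select_time=10,
                                     ce_model_name='gpt2-xl', ice_template=template)
            ice_idx_list = retriever.retrieve()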
40 | """ 41 | metric_model = None 42 | 43 | def __init__(self, 44 | dataset_reader: DatasetReader, 45 | ice_separator: Optional[str] = '\n', 46 | ice_eos_token: Optional[str] = '\n', 47 | prompt_eos_token: Optional[str] = '', 48 | sentence_transformers_model_name: Optional[str] = 'all-mpnet-base-v2', 49 | ice_num: Optional[int] = 1, 50 | candidate_num: Optional[int] = 1, 51 | index_split: Optional[str] = 'train', 52 | test_split: Optional[str] = 'test', 53 | tokenizer_name: Optional[str] = 'gpt2-xl', 54 | ce_model_name: Optional[str] = 'gpt2-xl', 55 | batch_size: Optional[int] = 1, 56 | select_time: Optional[int] = 5, 57 | accelerator: Optional[Accelerator] = None, 58 | ice_template: Optional[PromptTemplate] = None, 59 | prompt_template: Optional[PromptTemplate] = None, 60 | labels: Optional[List] = None, 61 | seed: Optional[int] = 1 62 | ) -> None: 63 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token, 64 | sentence_transformers_model_name, ice_num, index_split, test_split, tokenizer_name, batch_size, 65 | accelerator) 66 | self.ce_model_name = ce_model_name 67 | self.candidate_num = candidate_num 68 | self.select_time = select_time 69 | self.ice_template = ice_template 70 | self.prompt_template = prompt_template 71 | self.labels = labels 72 | self.seed = seed 73 | 74 | def topk_search(self): 75 | np.random.seed(self.seed) 76 | res_list = self.forward(self.dataloader) 77 | rtr_idx_list = [[] for _ in range(len(res_list))] 78 | 79 | logger.info("Retrieving data for test set...") 80 | for entry in tqdm.tqdm(res_list, disable=not self.is_main_process): 81 | idx = entry['metadata']['id'] 82 | 83 | embed = np.expand_dims(entry['embed'], axis=0) 84 | near_ids = self.index.search(embed, min(self.candidate_num, len(self.index_ds)))[1][0].tolist() 85 | candidates = [] 86 | mdl_scores = [] 87 | for j in range(self.select_time): 88 | if j == 0: 89 | rand_idx_list = near_ids[:self.ice_num] 90 | else: 91 | rand_idx_list = np.random.choice(near_ids, self.ice_num, replace=False) 92 | rand_idx_list = [int(i) for i in rand_idx_list] 93 | candidates.append(rand_idx_list) 94 | 95 | ice = self.generate_ice(rand_idx_list, ice_template=self.ice_template) 96 | mask_length = len(self.tokenizer(ice + self.ice_eos_token, verbose=False)['input_ids']) 97 | if self.labels is None: 98 | labels = self.get_labels(self.ice_template, self.prompt_template) 99 | else: 100 | labels = self.labels 101 | prompt_list = [] 102 | for label in labels: 103 | prompt = self.generate_label_prompt(idx, ice, label, self.ice_template, self.prompt_template) 104 | prompt_list.append(prompt) 105 | loss_list = self.cal_ce(prompt_list, mask_length=mask_length) 106 | probs = np.exp(-np.array(loss_list)) 107 | normalized_probs = probs / probs.sum(0, keepdims=True) 108 | neg_entropy = -entropy(normalized_probs, label_dim=0) 109 | mdl_scores.append(neg_entropy) 110 | 111 | rtr_idx_list[idx] = candidates[mdl_scores.index(max(mdl_scores))] 112 | rtr_idx_list[idx] = [int(i) for i in rtr_idx_list[idx]] 113 | 114 | return rtr_idx_list 115 | 116 | def retrieve(self): 117 | return self.topk_search() 118 | 119 | def cal_ce(self, input_texts: List[str], mask_length=None): 120 | if self.metric_model is None: 121 | logger.info(f'Load model {self.metric_model} for calculating MDL...') 122 | self.metric_model = AutoModelForCausalLM.from_pretrained(self.ce_model_name) 123 | self.metric_model.to(self.device) 124 | inputs = self.tokenizer(input_texts, padding=True, return_tensors='pt', truncation=True) 125 | inputs = {k: 
v.to(self.device) for k, v in inputs.items()} 126 | outputs = self.metric_model(**inputs) 127 | 128 | shift_logits = outputs.logits[..., :-1, :].contiguous() 129 | shift_labels = inputs["input_ids"][..., 1:].contiguous() 130 | 131 | loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=self.tokenizer.pad_token_id) 132 | shift_logits = shift_logits.view(-1, shift_logits.size(-1)) 133 | loss = loss_fct(shift_logits, shift_labels.view(-1)).view(shift_labels.size()) 134 | if mask_length is not None: 135 | mask = torch.cat([torch.zeros([loss.shape[0], mask_length], dtype=torch.float), 136 | torch.ones([loss.shape[0], loss.shape[-1] - mask_length], dtype=torch.float)], -1) 137 | mask = mask.to(self.device) 138 | loss = torch.mul(mask, loss) 139 | 140 | lens = (inputs["input_ids"] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy() 141 | if mask_length is not None: 142 | lens -= mask_length 143 | ce_loss = loss.sum(-1).cpu().detach().numpy() / lens 144 | return ce_loss 145 | -------------------------------------------------------------------------------- /openicl/icl_retriever/icl_random_retriever.py: -------------------------------------------------------------------------------- 1 | '''Random Retriever''' 2 | 3 | from openicl import DatasetReader 4 | from openicl.icl_retriever import BaseRetriever 5 | from openicl.utils.logging import get_logger 6 | from typing import List, Union, Optional 7 | from tqdm import trange 8 | import numpy as np 9 | from accelerate import Accelerator 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | class RandomRetriever(BaseRetriever): 15 | """Random In-context Learning Retriever Class 16 | Class of Random Retriever. 17 | 18 | Attributes: 19 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class. 20 | ice_separator (:obj:`str`, optional): A string that separates each in-context example. 21 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples. 22 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt. 23 | ice_num (:obj:`int`, optional): The number of data in the in-context examples. 24 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``. 25 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``. 26 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples. 27 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data. 28 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing. 29 | seed (`int`, optional): Seed for the random number generator. 
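    Example:
        A minimal end-to-end sketch (the dataset name, column names, ``DatasetReader`` arguments and template
        are illustrative assumptions). Demonstrations are sampled uniformly from the index split, and the same
        ``seed`` reproduces the same selection:

            from datasets import load_dataset
            from openicl import DatasetReader, PromptTemplate
            from openicl.icl_retriever import RandomRetriever
            from openicl.icl_inferencer.icl_ppl_inferencer import PPLInferencer

            dataset = load_dataset('gpt3mix/sst2')    # assumed example dataset
            data = DatasetReader(dataset, input_columns=['text'], output_column='label')
            template = PromptTemplate({0: '</E>Review: </text> Sentiment: negative',
                                       1: '</E>Review: </text> Sentiment: positive'},
                                      column_token_map={'text': '</text>'}, ice_token='</E>')

            retriever = RandomRetriever(data, ice_num=8, seed=43)
            inferencer = PPLInferencer(model_name='gpt2')
            predictions = inferencer.inference(retriever, ice_template=template)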
30 | """ 31 | 32 | def __init__(self, 33 | dataset_reader: DatasetReader, 34 | ice_separator: Optional[str] = '\n', 35 | ice_eos_token: Optional[str] = '\n', 36 | prompt_eos_token: Optional[str] = '', 37 | ice_num: Optional[int] = 1, 38 | index_split: Optional[str] = 'train', 39 | test_split: Optional[str] = 'test', 40 | seed: Optional[int] = 43, 41 | accelerator: Optional[Accelerator] = None 42 | ) -> None: 43 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token, ice_num, index_split, 44 | test_split, accelerator) 45 | self.seed = seed 46 | 47 | def retrieve(self): 48 | np.random.seed(self.seed) 49 | num_idx = len(self.index_ds) 50 | rtr_idx_list = [] 51 | logger.info("Retrieving data for test set...") 52 | for _ in trange(len(self.test_ds), disable=not self.is_main_process): 53 | idx_list = np.random.choice(num_idx, self.ice_num, replace=False).tolist() 54 | rtr_idx_list.append(idx_list) 55 | return rtr_idx_list 56 | -------------------------------------------------------------------------------- /openicl/icl_retriever/icl_topk_retriever.py: -------------------------------------------------------------------------------- 1 | """Topk Retriever""" 2 | 3 | from openicl import DatasetReader 4 | from openicl.icl_dataset_reader import DatasetEncoder 5 | from openicl.icl_retriever import BaseRetriever 6 | from openicl.utils.collators import DataCollatorWithPaddingAndCuda 7 | from openicl.utils.logging import get_logger 8 | import torch 9 | from torch.utils.data import DataLoader 10 | from typing import Optional 11 | from transformers import AutoTokenizer 12 | from sentence_transformers import SentenceTransformer 13 | import tqdm 14 | import faiss 15 | import copy 16 | import numpy as np 17 | from accelerate import Accelerator 18 | 19 | logger = get_logger(__name__) 20 | 21 | 22 | class TopkRetriever(BaseRetriever): 23 | """Topk In-context Learning Retriever Class 24 | Class of Topk Retriever. 25 | 26 | Attributes: 27 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class. 28 | ice_separator (:obj:`str`, optional): A string that separates each in-context example. 29 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples. 30 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt. 31 | ice_num (:obj:`int`, optional): The number of data in the in-context examples. 32 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``. 33 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``. 34 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples. 35 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data. 36 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing. 37 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`. 38 | model (:obj:`SentenceTransformer`): An instance of :obj:`SentenceTransformer` class, used to calculate embeddings. 39 | tokenizer (:obj:`AutoTokenizer`): Tokenizer for :obj:`model`. 40 | index (:obj:`IndexIDMap`): Index generated with FAISS. 
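    Example:
        A minimal usage sketch (the dataset name, column names and ``DatasetReader`` arguments are
        illustrative assumptions). Test and index sentences are embedded with the given sentence-transformers
        model and neighbours are looked up in a FAISS inner-product index:

            from datasets import load_dataset
            from openicl import DatasetReader
            from openicl.icl_retriever import TopkRetriever

            dataset = load_dataset('gpt3mix/sst2')    # assumed example dataset
            data = DatasetReader(dataset, input_columns=['text'], output_column='label')

            retriever = TopkRetriever(data, ice_num=8,
                                      sentence_transformers_model_name='all-mpnet-base-v2',
                                      batch_size=16)
            ice_idx_list = retriever.retrieve()       # 8 nearest index rows per test example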
41 | """ 42 | model = None 43 | 44 | def __init__(self, 45 | dataset_reader: DatasetReader, 46 | ice_separator: Optional[str] = '\n', 47 | ice_eos_token: Optional[str] = '\n', 48 | prompt_eos_token: Optional[str] = '', 49 | sentence_transformers_model_name: Optional[str] = 'all-mpnet-base-v2', 50 | ice_num: Optional[int] = 1, 51 | index_split: Optional[str] = 'train', 52 | test_split: Optional[str] = 'test', 53 | tokenizer_name: Optional[str] = 'gpt2-xl', 54 | batch_size: Optional[int] = 1, 55 | accelerator: Optional[Accelerator] = None 56 | ) -> None: 57 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token, ice_num, index_split, 58 | test_split, accelerator) 59 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 60 | self.batch_size = batch_size 61 | self.tokenizer_name = tokenizer_name 62 | gen_datalist = self.dataset_reader.generate_input_field_corpus(self.test_ds) 63 | 64 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 65 | self.tokenizer.pad_token = self.tokenizer.eos_token 66 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id 67 | self.tokenizer.padding_side = "right" 68 | 69 | self.encode_dataset = DatasetEncoder(gen_datalist, tokenizer=self.tokenizer) 70 | co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer, device=self.device) 71 | self.dataloader = DataLoader(self.encode_dataset, batch_size=self.batch_size, collate_fn=co) 72 | 73 | self.model = SentenceTransformer(sentence_transformers_model_name) 74 | 75 | self.model = self.model.to(self.device) 76 | self.model.eval() 77 | 78 | self.index = self.create_index() 79 | 80 | def create_index(self): 81 | self.select_datalist = self.dataset_reader.generate_input_field_corpus(self.index_ds) 82 | encode_datalist = DatasetEncoder(self.select_datalist, tokenizer=self.tokenizer) 83 | co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer, device=self.device) 84 | dataloader = DataLoader(encode_datalist, batch_size=self.batch_size, collate_fn=co) 85 | index = faiss.IndexIDMap(faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension())) 86 | res_list = self.forward(dataloader, process_bar=True, information="Creating index for index set...") 87 | id_list = np.array([res['metadata']['id'] for res in res_list]) 88 | self.embed_list = np.stack([res['embed'] for res in res_list]) 89 | index.add_with_ids(self.embed_list, id_list) 90 | return index 91 | 92 | def knn_search(self, ice_num): 93 | res_list = self.forward(self.dataloader, process_bar=True, information="Embedding test set...") 94 | rtr_idx_list = [[] for _ in range(len(res_list))] 95 | logger.info("Retrieving data for test set...") 96 | for entry in tqdm.tqdm(res_list, disable=not self.is_main_process): 97 | idx = entry['metadata']['id'] 98 | embed = np.expand_dims(entry['embed'], axis=0) 99 | near_ids = self.index.search(embed, ice_num)[1][0].tolist() 100 | rtr_idx_list[idx] = near_ids 101 | return rtr_idx_list 102 | 103 | def forward(self, dataloader, process_bar=False, information=''): 104 | res_list = [] 105 | _dataloader = copy.deepcopy(dataloader) 106 | if process_bar: 107 | logger.info(information) 108 | _dataloader = tqdm.tqdm(_dataloader, disable=not self.is_main_process) 109 | for _, entry in enumerate(_dataloader): 110 | with torch.no_grad(): 111 | metadata = entry.pop("metadata") 112 | raw_text = self.tokenizer.batch_decode(entry['input_ids'], skip_special_tokens=True, verbose=False) 113 | res = self.model.encode(raw_text, show_progress_bar=False) 114 | res_list.extend([{"embed": r, 
"metadata": m} for r, m in zip(res, metadata)]) 115 | return res_list 116 | 117 | def retrieve(self): 118 | return self.knn_search(self.ice_num) 119 | -------------------------------------------------------------------------------- /openicl/icl_retriever/icl_votek_retriever.py: -------------------------------------------------------------------------------- 1 | """Votek Retriever""" 2 | 3 | import os 4 | import json 5 | from openicl import DatasetReader 6 | from openicl.icl_retriever.icl_topk_retriever import TopkRetriever 7 | from typing import List, Union, Optional, Tuple 8 | from sklearn.metrics.pairwise import cosine_similarity 9 | from collections import defaultdict 10 | import numpy as np 11 | import random 12 | from accelerate import Accelerator 13 | 14 | 15 | class VotekRetriever(TopkRetriever): 16 | """Vote-k In-context Learning Retriever Class 17 | Class of Vote-k Retriever. 18 | 19 | Attributes: 20 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class. 21 | ice_separator (:obj:`str`, optional): A string that separates each in-context example. 22 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples. 23 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt. 24 | ice_num (:obj:`int`, optional): The number of data in the in-context examples. 25 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``. 26 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``. 27 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples. 28 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data. 29 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing. 30 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`. 31 | model (:obj:`SentenceTransformer`): An instance of :obj:`SentenceTransformer` class, used to calculate embeddings. 32 | tokenizer (:obj:`AutoTokenizer`): Tokenizer for :obj:``model``. 33 | index (:obj:`IndexIDMap`): Index generated with FAISS. 34 | votek_k (:obj:`int`, optional): ``k`` value of Voke-k Selective Annotation Algorithm. Defaults to ``3``. 
35 | """ 36 | 37 | def __init__(self, 38 | dataset_reader: DatasetReader, 39 | ice_separator: Optional[str] = '\n', 40 | ice_eos_token: Optional[str] = '\n', 41 | prompt_eos_token: Optional[str] = '', 42 | sentence_transformers_model_name: Optional[str] = 'all-mpnet-base-v2', 43 | ice_num: Optional[int] = 1, 44 | index_split: Optional[str] = 'train', 45 | test_split: Optional[str] = 'test', 46 | tokenizer_name: Optional[str] = 'gpt2-xl', 47 | batch_size: Optional[int] = 1, 48 | votek_k: Optional[int] = 3, 49 | accelerator: Optional[Accelerator] = None, 50 | ) -> None: 51 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token, 52 | sentence_transformers_model_name, ice_num, index_split, test_split, tokenizer_name, batch_size, 53 | accelerator) 54 | self.votek_k = votek_k 55 | 56 | def votek_select(self, embeddings=None, select_num=None, k=None, overlap_threshold=None, vote_file=None): 57 | n = len(embeddings) 58 | if vote_file is not None and os.path.isfile(vote_file): 59 | with open(vote_file) as f: 60 | vote_stat = json.load(f) 61 | else: 62 | vote_stat = defaultdict(list) 63 | 64 | for i in range(n): 65 | cur_emb = embeddings[i].reshape(1, -1) 66 | cur_scores = np.sum(cosine_similarity(embeddings, cur_emb), axis=1) 67 | sorted_indices = np.argsort(cur_scores).tolist()[-k - 1:-1] 68 | for idx in sorted_indices: 69 | if idx != i: 70 | vote_stat[idx].append(i) 71 | 72 | if vote_file is not None: 73 | with open(vote_file, 'w') as f: 74 | json.dump(vote_stat, f) 75 | votes = sorted(vote_stat.items(), key=lambda x: len(x[1]), reverse=True) 76 | j = 0 77 | selected_indices = [] 78 | while len(selected_indices) < select_num and j < len(votes): 79 | candidate_set = set(votes[j][1]) 80 | flag = True 81 | for pre in range(j): 82 | cur_set = set(votes[pre][1]) 83 | if len(candidate_set.intersection(cur_set)) >= overlap_threshold * len(candidate_set): 84 | flag = False 85 | break 86 | if not flag: 87 | j += 1 88 | continue 89 | selected_indices.append(int(votes[j][0])) 90 | j += 1 91 | if len(selected_indices) < select_num: 92 | unselected_indices = [] 93 | cur_num = len(selected_indices) 94 | for i in range(n): 95 | if not i in selected_indices: 96 | unselected_indices.append(i) 97 | selected_indices += random.sample(unselected_indices, select_num - cur_num) 98 | return selected_indices 99 | 100 | def vote_k_search(self): 101 | vote_k_idxs = self.votek_select(embeddings=self.embed_list, select_num=self.ice_num, k=self.votek_k, 102 | overlap_threshold=1) 103 | return [vote_k_idxs[:] for _ in range(len(self.test_ds))] 104 | 105 | def retrieve(self): 106 | return self.vote_k_search() 107 | -------------------------------------------------------------------------------- /openicl/icl_retriever/icl_zero_retriever.py: -------------------------------------------------------------------------------- 1 | """Zeroshot Retriever""" 2 | 3 | from datasets import Dataset, DatasetDict 4 | from typing import List, Union, Optional, Tuple, Dict 5 | from openicl import DatasetReader, PromptTemplate 6 | from openicl.icl_retriever import BaseRetriever 7 | from openicl.utils.check_type import _check_str 8 | from accelerate import Accelerator 9 | 10 | 11 | class ZeroRetriever(BaseRetriever): 12 | """Zero In-context Learning Retriever Class 13 | Retriever for Zero-shot. 14 | 15 | Attributes: 16 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class. 17 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples. 
18 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt. 19 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``. 20 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``. 21 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples. 22 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data. 23 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing. 24 | """ 25 | 26 | def __init__(self, 27 | dataset_reader: DatasetReader, 28 | ice_eos_token: Optional[str] = '', 29 | prompt_eos_token: Optional[str] = '', 30 | index_split: Optional[str] = 'train', 31 | test_split: Optional[str] = 'test', 32 | accelerator: Optional[Accelerator] = None 33 | ) -> None: 34 | super().__init__(dataset_reader, '', ice_eos_token, prompt_eos_token, 0, index_split, test_split, accelerator) 35 | 36 | def retrieve(self) -> List[List]: 37 | rtr_idx_list = [[] for _ in range(len(self.test_ds))] 38 | return rtr_idx_list 39 | -------------------------------------------------------------------------------- /openicl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logging import * 2 | -------------------------------------------------------------------------------- /openicl/utils/api_service.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | import os 4 | import openai 5 | import time 6 | import numpy as np 7 | 8 | OPENICL_API_NAME_LIST = ['opt-175b', 'gpt3'] 9 | OPENICL_API_PARAMETER_DICT = { 10 | 'opt-175b': ['URL', 'headers'], 11 | 'gpt3': ['engine', 'temperature', 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty', 'sleep_time'] 12 | } 13 | OPENICL_API_REQUEST_CONFIG = { 14 | 'opt-175b': { 15 | 'URL': "", # http://xxx/completions or http://xxx/generate 16 | 'headers': { 17 | "Content-Type": "application/json; charset=UTF-8" 18 | } 19 | }, 20 | 'gpt3': { 21 | 'engine': "text-davinci-003", 22 | 'temperature': 0, 23 | 'max_tokens': 256, 24 | 'top_p': 1.0, 25 | 'frequency_penalty': 0.0, 26 | 'presence_penalty': 0.0, 27 | 'sleep_time': 3 28 | } 29 | } 30 | PROXIES = {"https": "", "http": ""} 31 | 32 | 33 | def is_api_available(api_name): 34 | if api_name is None: 35 | return False 36 | return True if api_name in OPENICL_API_NAME_LIST else False 37 | 38 | 39 | def update_openicl_api_request_config(api_name, **kwargs): 40 | if api_name is None or not is_api_available(api_name): 41 | return 42 | 43 | parameter_list = OPENICL_API_PARAMETER_DICT[api_name] 44 | for parameter in parameter_list: 45 | if parameter in kwargs.keys(): 46 | OPENICL_API_REQUEST_CONFIG[api_name][parameter] = kwargs[parameter] 47 | 48 | 49 | def api_get_ppl(api_name, input_texts): 50 | if api_name == 'opt-175b': 51 | pyload = {"prompt": input_texts, "max_tokens": 0, "echo": True} 52 | response = json.loads( 53 | requests.post(OPENICL_API_REQUEST_CONFIG[api_name]['URL'], data=json.dumps(pyload), 54 | headers=OPENICL_API_REQUEST_CONFIG[api_name]['headers'], proxies=PROXIES).text) 55 | lens = np.array([len(r['logprobs']['tokens']) for r in response['choices']]) 56 | ce_loss = np.array([-sum(r['logprobs']['token_logprobs']) for r in 
response['choices']]) 57 | return ce_loss / lens 58 | 59 | if api_name == 'gpt3': 60 | raise NotImplementedError("GPT-3 API doesn't support PPL calculation") 61 | 62 | 63 | def api_get_tokens(api_name, input_texts): 64 | length_list = [len(text) for text in input_texts] 65 | 66 | if api_name == 'opt-175b': 67 | pyload = {"prompt": input_texts, "max_tokens": 100, "echo": True} 68 | response = json.loads( 69 | requests.post(OPENICL_API_REQUEST_CONFIG[api_name]['URL'], data=json.dumps(pyload), 70 | headers=OPENICL_API_REQUEST_CONFIG[api_name]['headers'], proxies=PROXIES).text) 71 | return [r['text'] for r in response['choices']], [r['text'][length:] for r, length in 72 | zip(response['choices'], length_list)] 73 | 74 | if api_name == 'gpt3': 75 | openai.api_key = os.getenv("OPENAI_API_KEY") 76 | response = openai.Completion.create( 77 | engine=OPENICL_API_REQUEST_CONFIG['gpt3']['engine'], 78 | prompt=input_texts, 79 | temperature=OPENICL_API_REQUEST_CONFIG['gpt3']['temperature'], 80 | max_tokens=OPENICL_API_REQUEST_CONFIG['gpt3']['max_tokens'], 81 | top_p=OPENICL_API_REQUEST_CONFIG['gpt3']['top_p'], 82 | frequency_penalty=OPENICL_API_REQUEST_CONFIG['gpt3']['frequency_penalty'], 83 | presence_penalty=OPENICL_API_REQUEST_CONFIG['gpt3']['presence_penalty'] 84 | ) 85 | time.sleep(OPENICL_API_REQUEST_CONFIG['gpt3']['sleep_time']) 86 | return [(input + r['text']) for r, input in zip(response['choices'], input_texts)], [r['text'] for r in 87 | response['choices']] 88 | -------------------------------------------------------------------------------- /openicl/utils/calculate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def entropy(probs: np.array, label_dim: int = 0, mask=None): 5 | if mask is None: 6 | return - (probs * np.log(probs)).sum(label_dim) 7 | return - (mask * probs * np.log(probs)).sum(label_dim) 8 | -------------------------------------------------------------------------------- /openicl/utils/check_type.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset, DatasetDict 2 | from typing import List, Union, Dict 3 | 4 | 5 | def _check_type_list(obj, typelist: List): 6 | for _type in typelist: 7 | if _type is None: 8 | if obj is None: 9 | return obj 10 | elif isinstance(obj, _type): 11 | return obj 12 | raise TypeError( 13 | f"Expected an object in {[_.__name__ if _ is not None else None for _ in typelist]} type, but got {obj}") 14 | 15 | 16 | def _check_dataset(obj) -> Union[Dataset, DatasetDict]: 17 | if isinstance(obj, Dataset) or isinstance(obj, DatasetDict): 18 | return obj 19 | else: 20 | raise TypeError(f"Expected a datasets.Dataset or a datasets.DatasetDict object, but got {obj}") 21 | 22 | 23 | def _check_list(obj) -> List: 24 | if isinstance(obj, List): 25 | return obj 26 | else: 27 | raise TypeError(f"Expected a List object, but got {obj}") 28 | 29 | 30 | def _check_str(obj) -> str: 31 | if isinstance(obj, str): 32 | return obj 33 | else: 34 | raise TypeError(f"Expected a str object, but got {obj}") 35 | 36 | 37 | def _check_dict(obj) -> Dict: 38 | if isinstance(obj, Dict): 39 | return obj 40 | else: 41 | raise TypeError(f"Expected a Dict object, but got {obj}") 42 | -------------------------------------------------------------------------------- /openicl/utils/collators.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, List, Optional, 
Union 3 | 4 | import torch 5 | from transformers import PreTrainedTokenizerBase, BatchEncoding 6 | from transformers.file_utils import PaddingStrategy 7 | import numpy as np 8 | 9 | 10 | class ListWrapper: 11 | def __init__(self, data: List[Any]): 12 | self.data = data 13 | 14 | def to(self, device): 15 | return self.data 16 | 17 | 18 | def ignore_pad_dict(features): 19 | res_dict = {} 20 | if "metadata" in features[0]: 21 | res_dict['metadata'] = ListWrapper([x.pop("metadata") for x in features]) 22 | return res_dict 23 | 24 | 25 | @dataclass 26 | class DataCollatorWithPaddingAndCuda: 27 | tokenizer: PreTrainedTokenizerBase 28 | device: object = None 29 | padding: Union[bool, str, PaddingStrategy] = True 30 | max_length: Optional[int] = 3000 31 | pad_to_multiple_of: Optional[int] = None 32 | 33 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> BatchEncoding: 34 | res_dict = ignore_pad_dict(features) 35 | 36 | has_labels = "labels" in features[0] 37 | if has_labels: 38 | labels = [{"input_ids": x.pop("labels")} for x in features] 39 | labels = self.tokenizer.pad( 40 | labels, 41 | padding=True, 42 | max_length=self.max_length, 43 | pad_to_multiple_of=self.pad_to_multiple_of, 44 | return_attention_mask=True, 45 | return_tensors="pt", 46 | verbose=False 47 | ) 48 | 49 | # print(features) 50 | batch = self.tokenizer.pad( 51 | features, 52 | padding=True, 53 | max_length=self.max_length, 54 | pad_to_multiple_of=self.pad_to_multiple_of, 55 | return_attention_mask=True, 56 | return_tensors="pt", 57 | verbose=False 58 | ) 59 | 60 | if has_labels: 61 | batch['labels'] = labels.input_ids 62 | batch.update(res_dict) 63 | 64 | if self.device: 65 | batch = batch.to(self.device) 66 | 67 | return batch 68 | -------------------------------------------------------------------------------- /openicl/utils/icl_common_utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from openicl.icl_retriever import BaseRetriever 3 | from typing import List, Union, Optional 4 | from openicl import PromptTemplate 5 | from accelerate import Accelerator 6 | 7 | 8 | def get_dataloader(datalist: List[List], batch_size: int) -> DataLoader: 9 | dataloader = DataLoader(datalist, batch_size=batch_size) 10 | return dataloader 11 | 12 | 13 | def get_generation_prompt_list_from_retriever_indices(ice_idx_list: List[List[int]], retriever: BaseRetriever, 14 | tokenizer, gen_field_replace_token: str, 15 | max_model_token_num: Optional[int] = None, 16 | ice_template: Optional[PromptTemplate] = None, 17 | prompt_template: Optional[PromptTemplate] = None): 18 | prompt_list = [] 19 | for idx, ice_idx in enumerate(ice_idx_list): 20 | ice = retriever.generate_ice(ice_idx, ice_template=ice_template) 21 | prompt = retriever.generate_prompt_for_generate_task(idx, ice, gen_field_replace_token=gen_field_replace_token, 22 | ice_template=ice_template, prompt_template=prompt_template) 23 | if max_model_token_num is not None and tokenizer is not None: 24 | prompt_token_num = get_input_token_num(tokenizer, prompt) 25 | while len(ice_idx) > 0 and prompt_token_num > max_model_token_num: 26 | ice_idx = ice_idx[:-1] 27 | ice = retriever.generate_ice(ice_idx, ice_template=ice_template) 28 | prompt = retriever.generate_prompt_for_generate_task(idx, ice, 29 | gen_field_replace_token=gen_field_replace_token, 30 | ice_template=ice_template, 31 | prompt_template=prompt_template) 32 | prompt_token_num = get_input_token_num(tokenizer, prompt) 33 | 
prompt_list.append(prompt) 34 | return prompt_list 35 | 36 | 37 | def get_input_token_num(tokenizer, input): 38 | return len(tokenizer(input, verbose=False)['input_ids']) 39 | -------------------------------------------------------------------------------- /openicl/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.distributed as dist 4 | 5 | LOG_LEVEL = logging.INFO 6 | SUBPROCESS_LOG_LEVEL = logging.ERROR 7 | LOG_FORMATTER = '[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s' 8 | 9 | 10 | def get_logger(name, level=LOG_LEVEL, log_file=None, file_mode='w'): 11 | formatter = logging.Formatter(LOG_FORMATTER) 12 | 13 | logger = logging.getLogger(name) 14 | 15 | for handler in logger.root.handlers: 16 | if type(handler) is logging.StreamHandler: 17 | handler.setLevel(logging.ERROR) 18 | 19 | if dist.is_available() and dist.is_initialized(): 20 | rank = dist.get_rank() 21 | else: 22 | rank = 0 23 | 24 | if rank == 0 and log_file is not None: 25 | file_handler = logging.FileHandler(log_file, file_mode) 26 | file_handler.setFormatter(formatter) 27 | file_handler.setLevel(level) 28 | logger.addHandler(file_handler) 29 | 30 | if rank == 0: 31 | logger.setLevel(level) 32 | else: 33 | logger.setLevel(SUBPROCESS_LOG_LEVEL) 34 | 35 | stream_handler = logging.StreamHandler() 36 | stream_handler.setFormatter(formatter) 37 | stream_handler.setLevel(level) 38 | logger.addHandler(stream_handler) 39 | 40 | return logger 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.28.1 2 | tokenizers==0.13.3 3 | accelerate 4 | datasets>=2.7.1 5 | evaluate>=0.3.0 6 | faiss_gpu>=1.7.2 7 | nltk>=3.8 8 | numpy>=1.23.4 9 | openai>=0.27.1 10 | rank_bm25>=0.2.2 11 | requests>=2.28.1 12 | scikit_learn>=1.2.1 13 | sentence_transformers>=2.2.2 14 | torch>=1.13.1 15 | tqdm>=4.64.1 16 | -------------------------------------------------------------------------------- /scripts/self_consistency.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname as d 2 | from os.path import abspath 3 | import sys 4 | root = d(d(abspath(__file__))) 5 | sys.path.append(root) 6 | from collections import Counter 7 | 8 | import json 9 | from openicl import DatasetReader, ZeroRetriever, PromptTemplate, TopkRetriever, GenInferencer, AccEvaluator 10 | # import fire 11 | import re 12 | from transformers import AutoTokenizer, AutoModelForCausalLM 13 | from datasets import load_dataset 14 | #from train_llm.test.proxy import proxy_on 15 | 16 | 17 | def processing_answer(str): 18 | str = str.split(' ')[::-1] 19 | flag = False 20 | ret = '' 21 | for i in range(len(str)): 22 | s = str[i] 23 | for i in range(len(s)): 24 | if s[i].isdigit(): 25 | flag = True 26 | ret = s 27 | break 28 | if flag: 29 | break 30 | ret1 = '' 31 | for i in range(len(ret)): 32 | if ret[i].isdigit(): 33 | ret1 += ret[i] 34 | return ret1 35 | 36 | 37 | def main(model_path, ice_num=4, batch_size=1, max_seq_len=2048, sc_size=5): 38 | 39 | tokenizer = AutoTokenizer.from_pretrained(model_path) 40 | model = AutoModelForCausalLM.from_pretrained(model_path) 41 | 42 | ds = load_dataset('gsm8k', 'main', split="test[:100]") 43 | print(ds) 44 | # import pdb;pdb.set_trace() 45 | def processing_test(example): 46 | example['answer'] = example['answer'].split("#### ")[1].replace(',', '') 47 | return 
example 48 | 49 | data = DatasetReader(ds, input_columns=['question'], output_column='answer') 50 | 51 | ref = ds.map(processing_test) 52 | 53 | #template = PromptTemplate("Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.\n\nQ: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\nA: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.\n\nQ: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\nA: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.\n\nQ: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\nA: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.\n\nQ: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\nA: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.\n\nQ: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\nA: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.\n\nQ: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\nA: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8.\n\nQ: \nA: ", 54 | # {'question':'', 'answer':''}, 55 | # ice_token='') 56 | #import pdb;pdb.set_trace() 57 | # prompt = open("llm_test/prompt_gsm8k_4shot.txt").readlines() 58 | # for _, line in enumerate(prompt): 59 | # if line == "Let's think step by step\n": 60 | # prompt[_] = "Let's think step by step\nAnswer:\n" 61 | # prompt = ''.join(prompt) 62 | prompt = """ 63 | Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. 
If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day? 64 | Let's think step by step 65 | Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total. 66 | For the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total. 67 | Angelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days. 68 | However, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks. 69 | They also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes. 70 | And they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours. 71 | So Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total. 72 | They want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75 73 | They will need to plan to study 4 days to allow for all the time they need. 74 | The answer is 4 75 | 76 | Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together? 77 | Let's think step by step 78 | Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers. 79 | His team also scores 8 3 pointers, meaning they scored 8*3= 24 points in 3 pointers 80 | They scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws. 81 | All together his team scored 50+24+10= 84 points 82 | Mark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers. 83 | His opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers. 84 | They also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws. 85 | All together Mark's opponents scored 100+12+5=117 points 86 | The total score for the game is both team's scores added together, so it is 84+117=201 points 87 | The answer is 201 88 | 89 | Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles? 90 | Let's think step by step 91 | When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24 92 | The total number of marbles she'll have is 60+24 = 84 93 | If Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees. 94 | If Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees. 95 | The total number of frisbees she'll have will increase to 30+12 = 42 96 | Bella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards 97 | If she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.
98 | The total number of deck cards she'll have is 10+4 = 14 99 | Together, Bella will have a total of 14+42+84 = 140 items 100 | The answer is 140 101 | 102 | Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there? 103 | Let's think step by step 104 | For the first three baskets, the number of apples and oranges in one basket is 9+15=24 105 | In total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets. 106 | Since there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets. 107 | The number of apples in the fourth basket is 9-2=7 108 | There are also 15-2=13 oranges in the fourth basket 109 | The combined number of oranges and apples in the fourth basket is 13+7=20 110 | The fourth basket also contains 14-2=12 bananas. 111 | In total, the fourth basket has 20+12=32 fruits. 112 | The four baskets together have 32+114=146 fruits. 113 | The answer is 146 114 | 115 | """ 116 | 117 | template = PromptTemplate(f"{prompt}Question: \nLet's think step by step\n", 118 | {'question':'', 'answer':''}, 119 | ice_token='') 120 | 121 | retriever = ZeroRetriever(data) 122 | all_predictions = [] 123 | 124 | # generation_kwargs = dict(max_new_tokens=512, do_sample=True, temperature=0.7, top_k=40) 125 | generation_kwargs = dict(max_new_tokens=512) 126 | # {"max_gen_len": 512, "do_sample": True, "temperature": 0.8, "top_p": 0.8} 127 | for i in range(sc_size): 128 | print("**"*50) 129 | print("\t\t\tIteration:", str(i)) 130 | print("**"*50) 131 | inferencer = GenInferencer(model_name=model, tokenizer_name=tokenizer, generation_kwargs=generation_kwargs, 132 | batch_size=batch_size, output_json_filepath=model_path.split('/')[-2], output_json_filename="gsm8k_"+str(i)) 133 | predictions = inferencer.inference(retriever, ice_template=template) 134 | print(predictions[:2]) 135 | predictions = [processing_answer(pred.split('\n\n')[0]) for pred in predictions] 136 | # print("**"*50) 137 | # print("\t\t\tProcessed prediction at iteration:", str(i)) 138 | # print("**"*50) 139 | # print(predictions[:2]) 140 | all_predictions.append(predictions) 141 | #import json 142 | # file = json.load(open("llm_llama/gsm8k.json")) 143 | # predictions = [file[str(i)]['prediction'] for i in range(len(file.keys()))] 144 | assert len(all_predictions) == sc_size 145 | final_prediction = [] 146 | for i in range(len(all_predictions[0])): 147 | tmp_preds = [] 148 | for j in range(sc_size): 149 | tmp_preds.append(all_predictions[j][i]) 150 | counter = Counter(tmp_preds) 151 | if i < 5: 152 | print(counter) 153 | final_prediction.append(counter.most_common(1)[0][0]) 154 | 155 | #import pdb;pdb.set_trace() 156 | print(final_prediction[:5], ref['answer'][:5]) 157 | score = AccEvaluator().score(predictions=final_prediction, references=ref['answer']) 158 | print(score) 159 | 160 | 161 | if __name__ == "__main__": 162 | # fire.Fire(main) 163 | 164 | # replace with your model_path or huggingface model name here 165 | main(model_path="decapoda-research/llama-7b-hf", 166 | ice_num=4, batch_size=8, sc_size=1) 167 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from setuptools.command.install import install 3 | 4 | 5 | class DownloadNLTK(install): 6 | def run(self): 
7 | self.do_egg_install() 8 | import nltk 9 | nltk.download('punkt') 10 | 11 | 12 | REQUIRES = """ 13 | transformers 14 | accelerate 15 | datasets>=2.7.1 16 | evaluate>=0.3.0 17 | faiss_gpu>=1.7.2 18 | nltk>=3.8 19 | numpy>=1.23.4 20 | openai>=0.27.1 21 | rank_bm25>=0.2.2 22 | requests>=2.28.1 23 | scikit_learn>=1.2.1 24 | sentence_transformers>=2.2.2 25 | torch>=1.13.1 26 | tqdm>=4.64.1 27 | """ 28 | 29 | 30 | def get_install_requires(): 31 | reqs = [req for req in REQUIRES.split("\n") if len(req) > 0] 32 | return reqs 33 | 34 | 35 | with open("README.md") as f: 36 | readme = f.read() 37 | 38 | 39 | def do_setup(): 40 | setup( 41 | name="openicl", 42 | version='0.1.7', 43 | description="An open source framework for in-context learning.", 44 | url="https://github.com/Shark-NLP/OpenICL", 45 | author='Zhenyu Wu, Yaoxiang Wang, Zhiyong Wu, Jiacheng Ye', 46 | long_description=readme, 47 | long_description_content_type="text/markdown", 48 | cmdclass={'download_nltk': DownloadNLTK}, 49 | install_requires=get_install_requires(), 50 | setup_requires=['nltk==3.8'], 51 | python_requires=">=3.8.0", 52 | packages=find_packages( 53 | exclude=[ 54 | "test*", 55 | "paper_test*" 56 | ] 57 | ), 58 | keywords=["AI", "NLP", "in-context learning"], 59 | classifiers=[ 60 | "Programming Language :: Python :: 3.8", 61 | "Programming Language :: Python :: 3.9", 62 | "Programming Language :: Python :: 3.10", 63 | "Intended Audience :: Developers", 64 | "Intended Audience :: Education", 65 | "Intended Audience :: Science/Research", 66 | ] 67 | ) 68 | 69 | 70 | if __name__ == "__main__": 71 | do_setup() 72 | --------------------------------------------------------------------------------
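
Usage note: the modules above compose into a complete in-context learning run, and scripts/self_consistency.py shows the full pattern: wrap a dataset with DatasetReader, pick a retriever, render prompts through a PromptTemplate, generate with GenInferencer, and score with an evaluator. The sketch below is a minimal, illustrative version of that loop, not code from the repository: the GSM8K slice, the 'gpt2-xl' model name, the '</E>'/'</question>'/'</answer>' placeholder tokens, and the generation settings are assumed stand-ins, and the answer-normalization step (processing_answer / processing_test in the script) is omitted for brevity.

from collections import Counter

from datasets import load_dataset
from openicl import AccEvaluator, DatasetReader, GenInferencer, PromptTemplate, ZeroRetriever

# Wrap a small evaluation slice; input/output column names follow self_consistency.py.
ds = load_dataset('gsm8k', 'main', split='test[:20]')
data = DatasetReader(ds, input_columns=['question'], output_column='answer')

# ZeroRetriever.retrieve() returns an empty in-context example list for every test item,
# so each prompt is just the template filled with the test question.
# The placeholder tokens below are assumptions, not values copied from the script.
template = PromptTemplate("</E>Question: </question>\nLet's think step by step\n",
                          {'question': '</question>', 'answer': '</answer>'},
                          ice_token='</E>')
retriever = ZeroRetriever(data)

# Sampling (do_sample=True) is what makes the repeated runs below differ,
# which is the precondition for self-consistency voting.
inferencer = GenInferencer(model_name='gpt2-xl', batch_size=8,
                           generation_kwargs=dict(max_new_tokens=256, do_sample=True, temperature=0.7))

# Self-consistency: sample several completions per question and keep the most
# common answer string, mirroring the Counter-based vote in self_consistency.py.
sc_size = 5
all_predictions = [inferencer.inference(retriever, ice_template=template) for _ in range(sc_size)]

final_prediction = []
for answers in zip(*all_predictions):  # answers: the sc_size samples for one question
    final_prediction.append(Counter(answers).most_common(1)[0][0])

# In practice the raw generations would be normalized (see processing_answer in the
# script) before being compared against the gold answers.
print(AccEvaluator().score(predictions=final_prediction, references=ds['answer']))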