├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs
│   ├── Makefile
│   ├── make.bat
│   ├── requirements.txt
│   └── source
│       ├── conf.py
│       ├── img
│       │   ├── openicl.png
│       │   └── overview.jpg
│       ├── index.rst
│       ├── modules
│       │   ├── dataset_reader.rst
│       │   ├── inferencer
│       │   │   ├── base_inferencer.rst
│       │   │   ├── cot_inferencer.rst
│       │   │   ├── gen_inferencer.rst
│       │   │   ├── inferencer_link.rst
│       │   │   └── ppl_inferencer.rst
│       │   ├── prompt_template.rst
│       │   └── retriever
│       │       ├── base_retriever.rst
│       │       ├── bm25_retriever.rst
│       │       ├── dpp_retriever.rst
│       │       ├── mdl_retriever.rst
│       │       ├── random_retriever.rst
│       │       ├── retriever_link.rst
│       │       ├── topk_retriever.rst
│       │       ├── votek_retriever.rst
│       │       └── zero_retriever.rst
│       └── notes
│           ├── example.rst
│           ├── installation.rst
│           └── tutorial.rst
├── examples
│   ├── README.md
│   ├── research_projects
│   │   ├── README.md
│   │   └── self-adaptive_in-context_learning.ipynb
│   └── tutorials
│       ├── README.md
│       ├── openicl_tutorial1_getting_started.ipynb
│       ├── openicl_tutorial2_use_different_models.ipynb
│       └── openicl_tutorial3_accelerate.ipynb
├── openicl
│   ├── __init__.py
│   ├── icl_dataset_reader.py
│   ├── icl_evaluator
│   │   ├── __init__.py
│   │   ├── icl_acc_evaluator.py
│   │   ├── icl_api_evaluator.py
│   │   ├── icl_base_evaluator.py
│   │   ├── icl_bleu_evaluator.py
│   │   ├── icl_rouge_evaluator.py
│   │   └── icl_squad_evaluator.py
│   ├── icl_inferencer
│   │   ├── __init__.py
│   │   ├── icl_base_inferencer.py
│   │   ├── icl_channel_inferencer.py
│   │   ├── icl_cot_inferencer.py
│   │   ├── icl_gen_inferencer.py
│   │   └── icl_ppl_inferencer.py
│   ├── icl_prompt_template.py
│   ├── icl_retriever
│   │   ├── __init__.py
│   │   ├── icl_base_retriever.py
│   │   ├── icl_bm25_retriever.py
│   │   ├── icl_dpp_retriever.py
│   │   ├── icl_mdl_retriever.py
│   │   ├── icl_random_retriever.py
│   │   ├── icl_topk_retriever.py
│   │   ├── icl_votek_retriever.py
│   │   └── icl_zero_retriever.py
│   └── utils
│       ├── __init__.py
│       ├── api_service.py
│       ├── calculate.py
│       ├── check_type.py
│       ├── collators.py
│       ├── icl_common_utils.py
│       └── logging.py
├── requirements.txt
├── scripts
│   └── self_consistency.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled python modules.
2 | *.pyc
3 |
4 | # Byte-compiled
5 | __pycache__/
6 | .cache/
7 |
8 | # Test files
9 | test.py
10 | test.sh
11 |
12 | # Output files
13 | icl_inference_output/
14 | batchscript-*
15 | output/
16 |
17 | # Packages
18 | dist/
19 | .eggs/
20 | openicl.egg-info/
21 | build/
22 | docs/build/
23 | *.xml
24 | *.iml
25 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: "ubuntu-22.04"
5 | tools:
6 | python: "3.8"
7 |
8 | python:
9 | install:
10 | - requirements: docs/requirements.txt
11 |
12 | formats: all
13 |
14 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ------
6 |
7 |
8 | [Overview](#overview) •
9 | [Installation](#installation) •
10 | [Paper](https://arxiv.org/abs/2303.02913) •
11 | [Examples](https://github.com/Shark-NLP/OpenICL/tree/main/examples) •
12 | [Docs](https://openicl.readthedocs.io/en/latest/index.html) •
13 | [Citation](#citation)
14 |
15 |
16 |
17 |
18 |
19 | ## Overview
20 | OpenICL provides an easy interface for in-context learning, with many state-of-the-art retrieval and inference methods built in to facilitate systematic comparison of LMs and fast research prototyping. Users can easily incorporate different retrieval and inference methods, as well as different prompt instructions into their workflow.
21 |
22 |

23 |
24 |
25 | ## What's New
26 | + **v0.1.8** Added support for LLaMA and self-consistency
27 |
28 | ## Installation
29 | Note: OpenICL requires Python 3.8+
30 |
31 | **Using Pip**
32 | ```
33 | pip install openicl
34 | ```
35 |
36 |
37 | **Installation for local development:**
38 | ```
39 | git clone https://github.com/Shark-NLP/OpenICL
40 | cd OpenICL
41 | pip install -e .
42 | ```
43 |
44 | ## Quick Start
45 | The following example shows how to perform ICL on a sentiment classification dataset. More examples and tutorials can be found at [examples](https://github.com/Shark-NLP/OpenICL/tree/main/examples).
46 |
47 | #### Step 1: Load and prepare data
48 | ```python
49 | from datasets import load_dataset
50 | from openicl import DatasetReader
51 |
52 | # Loading dataset from huggingface
53 | dataset = load_dataset('gpt3mix/sst2')
54 |
55 | # Define a DatasetReader, with specified column names where input and output are stored.
56 | data = DatasetReader(dataset, input_columns=['text'], output_column='label')
57 | ```
58 |
59 | #### Step 2: Define the prompt template (Optional)
60 | ```python
61 | from openicl import PromptTemplate
62 | tp_dict = {
63 |     0: "</E>Positive Movie Review: </text>",
64 |     1: "</E>Negative Movie Review: </text>"
65 | }
66 |
67 | template = PromptTemplate(tp_dict, {'text': '</text>'}, ice_token='</E>')
68 | ```
69 | The placeholders `</E>` and `</text>` will be replaced by the in-context examples and the testing input, respectively. For more detailed information about `PromptTemplate` (such as string-type templates), please see [tutorial1](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb).
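
As a quick illustration, a template can also be a single string shared by all examples; this is the form used by the generation-style tutorials later in this repository. The snippet below is a minimal hypothetical sketch, assuming the same `</E>` and `</text>` placeholder tokens as the dictionary template above:

```python
from openicl import PromptTemplate

# A single string template: </E> marks where the retrieved in-context
# examples are inserted, </text> marks where the test input is filled in.
tp_str = "</E>Movie Review: </text>\nSentiment: "
string_template = PromptTemplate(tp_str, column_token_map={'text': '</text>'}, ice_token='</E>')
```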
70 |
71 | #### Step 3: Initialize the Retriever
72 | ```python
73 | from openicl import TopkRetriever
74 | # Define a retriever using the previous `DatasetReader`.
75 | # `ice_num` stands for the number of in-context examples.
76 | retriever = TopkRetriever(data, ice_num=8)
77 | ```
78 | Here we use the popular TopK method to build the retriever.
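
The other retrievers in OpenICL follow the same interface, so swapping the retrieval strategy is typically a one-line change. A rough sketch (assuming `BM25Retriever` and `RandomRetriever` are importable from the top-level package, like the other retrievers used in the tutorials):

```python
from openicl import BM25Retriever, RandomRetriever

# Sparse lexical retrieval of in-context examples
bm25_retriever = BM25Retriever(data, ice_num=8)

# Randomly sampled in-context examples as a simple baseline
random_retriever = RandomRetriever(data, ice_num=8)
```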
79 |
80 | #### Step 4: Initialize the Inferencer
81 | ```python
82 | from openicl import PPLInferencer
83 | inferencer = PPLInferencer(model_name='distilgpt2')
84 | ```
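
`PPLInferencer` scores each candidate label by perplexity, which suits classification tasks like SST-2. For free-form outputs (e.g. translation or question answering), the tutorials in this repository use `GenInferencer` instead; a minimal sketch:

```python
from openicl import GenInferencer

# Generation-style inference, as in the XGLM and FLAN-T5 tutorial examples
gen_inferencer = GenInferencer(model_name='google/flan-t5-small')
```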
85 |
86 | #### Step 5: Inference and scoring
87 | ```python
88 | from openicl import AccEvaluator
89 | # The inferencer requires a retriever to collect in-context examples, as well as a template to wrap them up.
90 | predictions = inferencer.inference(retriever, ice_template=template)
91 | # compute accuracy for the prediction
92 | score = AccEvaluator().score(predictions=predictions, references=data.references)
93 | print(score)
94 | ```
95 |
96 |
97 |
98 | ## Docs
99 | **(updating...)**
100 |
101 | [OpenICL Documentation](https://openicl.readthedocs.io/en/latest/index.html)
102 |
103 | ## Citation
104 | If you find this repository helpful, feel free to cite our paper:
105 | ```bibtex
106 | @article{wu2023openicl,
107 | title={OpenICL: An Open-Source Framework for In-context Learning},
108 | author={Zhenyu Wu and Yaoxiang Wang and Jiacheng Ye and Jiangtao Feng and Jingjing Xu and Yu Qiao and Zhiyong Wu},
109 | journal={arXiv preprint arXiv:2303.02913},
110 | year={2023}
111 | }
112 | ```
113 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx-rtd-theme
2 | accelerate==0.15.0
3 | datasets==2.7.1
4 | evaluate==0.3.0
5 | faiss_gpu==1.7.2
6 | nltk==3.8
7 | numpy==1.23.4
8 | openai==0.27.1
9 | rank_bm25==0.2.2
10 | requests==2.28.1
11 | scikit_learn==1.2.1
12 | sentence_transformers==2.2.2
13 | torch==1.13.1
14 | tqdm==4.64.1
15 | transformers==4.24.0
16 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # -- Path setup --------------------------------------------------------------
2 |
3 | # If extensions (or modules to document with autodoc) are in another directory,
4 | # add these directories to sys.path here. If the directory is relative to the
5 | # documentation root, use os.path.abspath to make it absolute, like shown here.
6 | #
7 | import os
8 | import sys
9 | sys.path.insert(0, os.path.abspath('../../'))
10 | import sphinx_rtd_theme
11 | import datetime
12 |
13 |
14 | # -- Project information -----------------------------------------------------
15 |
16 | project = 'OpenICL'
17 | author = 'Shanghai AI Lab, SharkNLP'
18 | copyright = '{}, {}, Licensed under the Apache License, Version 2.0'.format(datetime.datetime.now().year, author)
19 |
20 | # The full version, including alpha/beta/rc tags
21 | release = 'v0.1.6'
22 | version = 'v0.1.6'
23 |
24 |
25 | # -- General configuration ---------------------------------------------------
26 |
27 | # Add any Sphinx extension module names here, as strings. They can be
28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
29 | # ones.
30 | extensions = [
31 | 'sphinx.ext.autodoc',
32 | 'sphinx.ext.autosummary',
33 | 'sphinx.ext.doctest',
34 | 'sphinx.ext.intersphinx',
35 | 'sphinx.ext.mathjax',
36 | 'sphinx.ext.napoleon',
37 | 'sphinx.ext.viewcode',
38 | 'sphinx.ext.githubpages',
39 | ]
40 |
41 | # Add any paths that contain templates here, relative to this directory.
42 | templates_path = ['_templates']
43 |
44 | # List of patterns, relative to source directory, that match files and
45 | # directories to ignore when looking for source files.
46 | # This pattern also affects html_static_path and html_extra_path.
47 | exclude_patterns = []
48 |
49 |
50 | # -- Options for HTML output -------------------------------------------------
51 |
52 | # The theme to use for HTML and HTML Help pages. See the documentation for
53 | # a list of builtin themes.
54 | #
55 | html_theme = 'sphinx_rtd_theme'
56 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
57 |
58 | # Add any paths that contain custom static files (such as style sheets) here,
59 | # relative to this directory. They are copied after the builtin static files,
60 | # so a file named "default.css" will overwrite the builtin "default.css".
61 | html_static_path = ['_static']
--------------------------------------------------------------------------------
/docs/source/img/openicl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shark-NLP/OpenICL/1613ae10b88ba2dbfed425c4ee078b2a6586152e/docs/source/img/openicl.png
--------------------------------------------------------------------------------
/docs/source/img/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Shark-NLP/OpenICL/1613ae10b88ba2dbfed425c4ee078b2a6586152e/docs/source/img/overview.jpg
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/Shark-NLP/OpenICL
2 |
3 | OpenICL Documentation
4 | ===================================
5 | .. figure:: img/openicl.png
6 |
7 | **OpenICL** is an open-source framework to facilitate research, development, and prototyping of in-context learning.
8 |
9 | .. figure:: img/overview.jpg
10 | :align: center
11 |
12 | Overview of the architecture in OpenICL.
13 |
14 | It provides an easy interface for in-context learning, with many state-of-the-art retrieval and inference methods built in to facilitate systematic comparison of LMs and fast research prototyping. Users can easily incorporate different retrieval and inference methods, as well as different prompt instructions into their workflow.
15 |
16 | .. note::
17 |
18 | This project is under active development.
19 |
20 |
21 | Citation
22 | --------
23 |
24 | If you find this repository helpful, feel free to cite our paper:
25 |
26 | .. code-block:: bibtex
27 |
28 | @article{wu2023openicl,
29 | title={OpenICL: An Open-Source Framework for In-context Learning},
30 | author={Zhenyu Wu and Yaoxiang Wang and Jiacheng Ye and Jiangtao Feng and Jingjing Xu and Yu Qiao and Zhiyong Wu},
31 | journal={arXiv preprint arXiv:2303.02913},
32 | year={2023}
33 | }
34 |
35 | .. toctree::
36 | :glob:
37 | :maxdepth: 3
38 | :caption: Getting Started
39 |
40 | notes/installation
41 | notes/example
42 |
43 | .. toctree::
44 | :glob:
45 | :maxdepth: 3
46 | :caption: Tutorials
47 |
48 | notes/tutorial
49 |
50 |
51 | .. toctree::
52 | :glob:
53 | :maxdepth: 2
54 | :caption: Modules
55 |
56 | modules/dataset_reader
57 | modules/prompt_template
58 | modules/inferencer/inferencer_link
59 | modules/retriever/retriever_link
60 |
61 | Indices and tables
62 | ==================
63 |
64 | * :ref:`genindex`
65 | * :ref:`search`
66 |
--------------------------------------------------------------------------------
/docs/source/modules/dataset_reader.rst:
--------------------------------------------------------------------------------
1 | DatasetReader
2 | =============
3 |
4 | .. autoclass:: openicl.DatasetReader
5 | :inherited-members:
--------------------------------------------------------------------------------
/docs/source/modules/inferencer/base_inferencer.rst:
--------------------------------------------------------------------------------
1 | BaseInferencer
2 | ==============
3 |
4 | .. autoclass:: openicl.icl_inferencer.icl_base_inferencer.BaseInferencer
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/inferencer/cot_inferencer.rst:
--------------------------------------------------------------------------------
1 | CoTInferencer
2 | =============
3 |
4 | .. autoclass:: openicl.icl_inferencer.icl_cot_inferencer.CoTInferencer
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/inferencer/gen_inferencer.rst:
--------------------------------------------------------------------------------
1 | GenInferencer
2 | =============
3 |
4 | .. autoclass:: openicl.icl_inferencer.icl_gen_inferencer.GenInferencer
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/inferencer/inferencer_link.rst:
--------------------------------------------------------------------------------
1 | Inferencer
2 | ==========
3 |
4 | .. toctree::
5 | :glob:
6 | :maxdepth: 3
7 | :caption: Inferencer Classes
8 |
9 | base_inferencer.rst
10 | ppl_inferencer.rst
11 | gen_inferencer.rst
12 | cot_inferencer.rst
13 |
--------------------------------------------------------------------------------
/docs/source/modules/inferencer/ppl_inferencer.rst:
--------------------------------------------------------------------------------
1 | PPLInferencer
2 | =============
3 |
4 | .. autoclass:: openicl.icl_inferencer.icl_ppl_inferencer.PPLInferencer
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/prompt_template.rst:
--------------------------------------------------------------------------------
1 | PromptTemplate
2 | ===============
3 |
4 | .. autoclass:: openicl.PromptTemplate
5 | :inherited-members:
--------------------------------------------------------------------------------
/docs/source/modules/retriever/base_retriever.rst:
--------------------------------------------------------------------------------
1 | BaseRetriever
2 | =============
3 |
4 | .. autoclass:: openicl.icl_retriever.icl_base_retriever.BaseRetriever
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/retriever/bm25_retriever.rst:
--------------------------------------------------------------------------------
1 | BM25Retriever
2 | =============
3 |
4 | .. autoclass:: openicl.icl_retriever.icl_bm25_retriever.BM25Retriever
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/retriever/dpp_retriever.rst:
--------------------------------------------------------------------------------
1 | DPPRetriever
2 | ============
3 |
4 | .. autoclass:: openicl.icl_retriever.icl_dpp_retriever.DPPRetriever
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/retriever/mdl_retriever.rst:
--------------------------------------------------------------------------------
1 | MDLRetriever
2 | ============
3 |
4 | .. autoclass:: openicl.icl_retriever.icl_mdl_retriever.MDLRetriever
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/retriever/random_retriever.rst:
--------------------------------------------------------------------------------
1 | RandomRetriever
2 | ===============
3 |
4 | .. autoclass:: openicl.icl_retriever.icl_random_retriever.RandomRetriever
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/retriever/retriever_link.rst:
--------------------------------------------------------------------------------
1 | Retriever
2 | ==========
3 |
4 | .. toctree::
5 | :glob:
6 | :maxdepth: 3
7 | :caption: Retriever Classes
8 |
9 | base_retriever
10 | random_retriever
11 | bm25_retriever
12 | topk_retriever
13 | votek_retriever
14 | dpp_retriever
15 | mdl_retriever
16 | zero_retriever
17 |
--------------------------------------------------------------------------------
/docs/source/modules/retriever/topk_retriever.rst:
--------------------------------------------------------------------------------
1 | TopkRetriever
2 | =============
3 |
4 | .. autoclass:: openicl.icl_retriever.icl_topk_retriever.TopkRetriever
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/retriever/votek_retriever.rst:
--------------------------------------------------------------------------------
1 | VotekRetriever
2 | ==============
3 |
4 | .. autoclass:: openicl.icl_retriever.icl_votek_retriever.VotekRetriever
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/modules/retriever/zero_retriever.rst:
--------------------------------------------------------------------------------
1 | ZeroRetriever
2 | =============
3 |
4 | .. autoclass:: openicl.icl_retriever.icl_zero_retriever.ZeroRetriever
5 | :inherited-members:
6 |
--------------------------------------------------------------------------------
/docs/source/notes/example.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/Shark-NLP/OpenICL
2 |
3 | A Simple Example
4 | ========================
5 |
6 | The following example shows how to perform ICL on a sentiment classification dataset. More examples and tutorials can be found in our GitHub `repository <https://github.com/Shark-NLP/OpenICL>`_.
7 |
8 | Step 1: Load and prepare data
9 | ----------------------------------
10 |
11 | .. code-block:: python
12 |
13 | from datasets import load_dataset
14 | from openicl import DatasetReader
15 |
16 | # Loading dataset from huggingface
17 | dataset = load_dataset('gpt3mix/sst2')
18 |
19 | # Define a DatasetReader, with specified column names where input and output are stored.
20 | data = DatasetReader(dataset, input_columns=['text'], output_column='label')
21 |
22 |
23 | Step 2: Define the prompt template (Optional)
24 | ---------------------------------------------
25 |
26 | .. code-block:: python
27 |
28 | from openicl import PromptTemplate
29 | tp_dict = {
30 |     0: "</E>Positive Movie Review: </text>",
31 |     1: "</E>Negative Movie Review: </text>"
32 | }
33 |
34 | template = PromptTemplate(tp_dict, {'text': '</text>'}, ice_token='</E>')
35 |
36 | The placeholders `</E>` and `</text>` will be replaced by the in-context examples and the testing input, respectively.
37 |
38 | Step 3: Initialize the Retriever
39 | --------------------------------
40 |
41 | .. code-block:: python
42 |
43 | from openicl import TopkRetriever
44 | # Define a retriever using the previous `DatasetReader`.
45 | # `ice_num` stands for the number of in-context examples.
46 | retriever = TopkRetriever(data, ice_num=8)
47 |
48 | Here we use the popular TopK method to build the retriever.
49 |
50 | Step 4: Initialize the Inferencer
51 | ---------------------------------
52 |
53 | .. code-block:: python
54 |
55 | from openicl import PPLInferencer
56 | inferencer = PPLInferencer(model_name='distilgpt2')
57 |
58 | Step 5: Inference and scoring
59 | -----------------------------
60 |
61 | .. code-block:: python
62 |
63 | from openicl import AccEvaluator
64 | # The inferencer requires a retriever to collect in-context examples, as well as a template to wrap them up.
65 | predictions = inferencer.inference(retriever, ice_template=template)
66 | # compute accuracy for the prediction
67 | score = AccEvaluator().score(predictions=predictions, references=data.references)
68 | print(score)
69 |
--------------------------------------------------------------------------------
/docs/source/notes/installation.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/Shark-NLP/OpenICL
2 |
3 | Installation
4 | ========================
5 |
6 | .. note::
7 |
8 | OpenICL requires Python 3.8+
9 |
10 | Using Pip
11 | ----------------------------------
12 | .. code-block:: bash
13 |
14 | pip install openicl
15 |
16 |
17 | Installation for local development
18 | ----------------------------------
19 |
20 | You can install the latest version of OpenICL from our GitHub `repository <https://github.com/Shark-NLP/OpenICL>`_.
21 |
22 | .. code-block:: bash
23 |
24 | git clone https://github.com/Shark-NLP/OpenICL
25 | cd OpenICL
26 | pip install -e .
27 |
--------------------------------------------------------------------------------
/docs/source/notes/tutorial.rst:
--------------------------------------------------------------------------------
1 | :github_url: https://github.com/Shark-NLP/OpenICL
2 |
3 | Tutorials List
4 | ===============
5 |
6 | Tutorials can be found in our GitHub `repository <https://github.com/Shark-NLP/OpenICL/tree/main/examples/tutorials>`_.
7 |
8 | .. note::
9 |
10 | Tutorials are still being updated.
11 |
12 | .. list-table::
13 | :header-rows: 1
14 |
15 | * - Notebook
16 | - Description
17 | - Colab Link
18 | * - `Getting Started with OpenICL <https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb>`_
19 | - Introduction to the main components of OpenICL
20 | - .. image:: https://colab.research.google.com/assets/colab-badge.svg
21 | :target: https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb
22 |
23 | * - `Using Different Language Models with OpenICL <https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial2_use_different_models.ipynb>`_
24 | - Run different language models with OpenICL
25 | - .. image:: https://colab.research.google.com/assets/colab-badge.svg
26 | :target: https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial2_use_different_models.ipynb
27 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | We host a wide range of [Tutorials](https://github.com/Shark-NLP/OpenICL/tree/main/examples/tutorials) that demonstrate the basic usage of OpenICL.
4 |
5 | We also have some [research projects](https://github.com/Shark-NLP/OpenICL/tree/main/examples/research_projects) that reproduce results in research papers using OpenICL.
6 |
7 | Before submitting a PR, please discuss the feature you would like to add to an
8 | example in an [issue](https://github.com/Shark-NLP/OpenICL/issues). We welcome bug fixes,
9 | but since we want to keep the examples as simple as possible, it's unlikely
10 | that we will merge a pull request that adds functionality at the cost of readability.
--------------------------------------------------------------------------------
/examples/research_projects/README.md:
--------------------------------------------------------------------------------
1 | # Research Projects
2 | Here you can find the code to reproduce some ICL-related paper experiments using OpenICL (**updating...**)
3 |
4 | | Notebook | Paper | |
5 | |:----------|:-------------|:-------------|
6 | [self-adaptive in-context learning](https://github.com/Shark-NLP/OpenICL/blob/main/examples/research_projects/self-adaptive_in-context_learning.ipynb) | [self-adaptive in-context learning](https://arxiv.org/abs/2212.10375) | [](https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/research_projects/self-adaptive_in-context_learning.ipynb) |
--------------------------------------------------------------------------------
/examples/research_projects/self-adaptive_in-context_learning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "ec91d2ba-7ca4-4bd4-a6f5-c78a5b321179",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%pip install --upgrade pip\n",
11 | "%pip install openicl"
12 | ]
13 | },
14 | {
15 | "attachments": {},
16 | "cell_type": "markdown",
17 | "id": "c9acf526-7e21-4f95-9a94-fb3b00973e8d",
18 | "metadata": {},
19 | "source": [
20 | "# Self-adaptive In-context Learning\n",
21 | "---\n",
22 | "Code for paper [Self-adaptive In-context Learning](https://arxiv.org/abs/2212.10375)"
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "id": "fd2e9c3a-c8eb-4a03-827f-253ed84b0c7e",
28 | "metadata": {},
29 | "source": [
30 | "## Templates "
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 1,
36 | "id": "73a948c3-aa16-4b1e-9bc3-c6a6c0fdb7e3",
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "from openicl import PromptTemplate\n",
41 | "\n",
42 | "# SST-2\n",
43 | "sst2_tp_dict = {\n",
44 | " 0: 'Positive Movie Review: \\\"\\\"', \n",
45 | " 1: 'Negative Movie Review: \\\"\\\"',\n",
46 | "}\n",
47 | "sst2_template = PromptTemplate(sst2_tp_dict, column_token_map={'text' : ''}, ice_token='')\n",
48 | "\n",
49 | "# SST-5\n",
50 | "sst5_tp_dict = {\n",
51 | " 0: \"Review: \\nSentiment: terrible\",\n",
52 | " 1: \"Review: \\nSentiment: bad\",\n",
53 | " 2: \"Review: \\nSentiment: okay\",\n",
54 | " 3: \"Review: \\nSentiment: good\",\n",
55 | " 4: \"Review: \\nSentiment: great\",\n",
56 | "}\n",
57 | "sst5_template = PromptTemplate(sst5_tp_dict, column_token_map={'text' : ''}, ice_token='')\n",
58 | "\n",
59 | "# AG_NEWS\n",
60 | "ag_news_tp_dict = {\n",
61 | " 0: \"\\\"\\\" It is about world.\",\n",
62 | " 1: \"\\\"\\\" It is about sports.\",\n",
63 | " 2: \"\\\"\\\" It is about business.\",\n",
64 | " 3: \"\\\"\\\" It is about science and technology.\",\n",
65 | "}\n",
66 | "ag_news_template = PromptTemplate(ag_news_tp_dict, column_token_map={'text' : ''}, ice_token='')\n",
67 | "\n",
68 | "# TREC\n",
69 | "trec_tp_dict = {\n",
70 | " 0: \"\\\"\\\" It is about abbreviation.\",\n",
71 | " 1: \"\\\"\\\" It is about entity.\",\n",
72 | " 2: \"\\\"\\\" It is about description and abstract concept.\",\n",
73 | " 3: \"\\\"\\\" It is about human being.\",\n",
74 | " 4: \"\\\"\\\" It is about location.\",\n",
75 | " 5: \"\\\"\\\" It is about numeric value.\"\n",
76 | "}\n",
77 | "trec_template = PromptTemplate(trec_tp_dict, column_token_map={'text' : ''}, ice_token='')\n",
78 | "\n",
79 | "# SNLI & MNLI\n",
80 | "xnli_tp_dict = {\n",
81 | " 0: '? Yes, ',\n",
82 | " 1: '? Maybe, ',\n",
83 | " 2: '? No, '\n",
84 | "}\n",
85 | "xnli_template = PromptTemplate(xnli_tp_dict, column_token_map={'premise' : '', 'hypothesis' : ''}, ice_token='')\n",
86 | "\n",
87 | "# QNLI \n",
88 | "qnli_tp_dict = {\n",
89 | " 0: \" Can we know ? Yes.\",\n",
90 | " 1: \" Can we know ? No.\",\n",
91 | "}\n",
92 | "qnli_template = PromptTemplate(qnli_tp_dict, column_token_map={'sentence' : '', 'question' : ''}, ice_token='')\n",
93 | "\n",
94 | "# Commonsense QA\n",
95 | "cmsqa_template=PromptTemplate(\n",
96 | " {\n",
97 | " 'A': \"Answer the following question:\\n\\nAnswer: \",\n",
98 | " 'B': \"Answer the following question:\\n\\nAnswer: \",\n",
99 | " 'C': \"Answer the following question:\\n\\nAnswer: \",\n",
100 | " 'D': \"Answer the following question:\\n\\nAnswer: \",\n",
101 | " 'E': \"Answer the following question:\\n\\nAnswer: \",\n",
102 | " },\n",
103 | " {'question':'', 'A': '', 'B': '', 'C': '', 'D': '', 'E': ''},\n",
104 | " ice_token='' \n",
105 | ")\n",
106 | "\n",
107 | "templates = {'sst2': sst2_template,\n",
108 | " 'snli': xnli_template,\n",
109 | " 'mnli': xnli_template,\n",
110 | " \"qnli\": qnli_template,\n",
111 | " \"sst5\": sst5_template,\n",
112 | " \"ag_news\": ag_news_template,\n",
113 | " \"trec\": trec_template,\n",
114 | " \"commonsense_qa\": cmsqa_template\n",
115 | " }"
116 | ]
117 | },
118 | {
119 | "cell_type": "code",
120 | "execution_count": null,
121 | "id": "cde382ff",
122 | "metadata": {},
123 | "outputs": [],
124 | "source": [
125 | "## Datasets "
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "id": "fb05fc8a",
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "from datasets import load_dataset\n",
136 | "from openicl import DatasetReader\n",
137 | "\n",
138 | "data_path = {'sst2': [\"gpt3mix/sst2\", None],\n",
139 | " 'snli': ['snli', None],\n",
140 | " 'mnli': ['LysandreJik/glue-mnli-train', None],\n",
141 | " \"qnli\": [\"glue\", \"qnli\"],\n",
142 | " \"sst5\": [\"SetFit/sst5\", None],\n",
143 | " \"ag_news\": [\"ag_news\", None],\n",
144 | " \"trec\": [\"trec\", None],\n",
145 | " \"commonsense_qa\": [\"commonsense_qa\", None]\n",
146 | " }\n",
147 | "\n",
148 | "input_columns={'sst2': [\"text\"],\n",
149 | " 'snli': ['premise', 'hypothesis'],\n",
150 | " 'mnli': ['premise', 'hypothesis'],\n",
151 | " \"qnli\": [\"sentence\", \"question\"],\n",
152 | " \"sst5\": [\"text\"],\n",
153 | " \"ag_news\": [\"text\"],\n",
154 | " \"trec\": [\"text\"],\n",
155 | " \"commonsense_qa\": ['question', 'A', 'B', 'C', 'D', 'E']\n",
156 | " }\n",
157 | "\n",
158 | "output_column={'sst2': 'label',\n",
159 | " 'snli': 'label',\n",
160 | " 'mnli': 'label',\n",
161 | " \"qnli\": 'label',\n",
162 | " \"sst5\": 'label',\n",
163 | " \"ag_news\": 'label',\n",
164 | " \"trec\": 'label-coarse',\n",
165 | " \"commonsense_qa\": \"answerKey\"\n",
166 | " }\n",
167 | "\n",
168 | "# Change it for other tasks\n",
169 | "task_name='snli'\n",
170 | "\n",
171 | "path,name=data_path[task_name]\n",
172 | "dataset = load_dataset(path=path,name=name)\n",
173 | "\n",
174 | "# Preprocess for commonsense_qa\n",
175 | "def pre_process(example):\n",
176 | " for i in range(5):\n",
177 | " example[chr(ord('A') + i)] = example['choices']['text'][i]\n",
178 | " return example\n",
179 | "\n",
180 | "if task_name=='commonsense_qa':\n",
181 | " dataset=dataset.map(pre_process).remove_columns(['question_concept', 'id', 'choices'])\n",
182 | "\n",
183 | "\n",
184 | "data=DatasetReader(dataset, input_columns=input_columns[task_name], output_column=output_column[task_name])\n",
185 | "\n",
186 | "\n",
187 | "test_split={\n",
188 | " 'sst2': 'test',\n",
189 | " 'snli': 'test',\n",
190 | " \"sst5\": 'test',\n",
191 | " \"ag_news\": 'test',\n",
192 | " \"trec\": 'test',\n",
193 | " 'mnli': 'validation', # cannot get gold labels for the test split\n",
194 | " \"qnli\": 'validation',\n",
195 | " \"commonsense_qa\": \"validation\"\n",
196 | "}\n",
197 | "# If you only want to test part of the test set for faster running, you can use the following codes\n",
198 | "# dataset['test'] = dataset['test'].select(list(range(100)))\n",
199 | "# dataset['validation'] = dataset['validation'].select(list(range(100))) # trec,agnews don't have validation\n",
200 | "# dataset['train'] = dataset['train'].select(list(range(100)))"
201 | ]
202 | },
203 | {
204 | "attachments": {},
205 | "cell_type": "markdown",
206 | "id": "5f28ecda-13f8-430c-a0c9-c85e52c806cf",
207 | "metadata": {},
208 | "source": [
209 | "### TopK-MDL Experiments (method in the paper)"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "id": "4b3bc424-fe92-4b2a-9d5f-ac942ebc675f",
216 | "metadata": {},
217 | "outputs": [],
218 | "source": [
219 | "from openicl import MDLRetriever, PPLInferencer, AccEvaluator\n",
220 | "\n",
221 | "retriever = MDLRetriever(data, ice_num=8, candidate_num=30, select_time=10, seed=1, batch_size=12, test_split=test_split[task_name])\n",
222 | "\n",
223 | "inferencer = PPLInferencer(model_name='gpt2-xl', batch_size=8)\n",
224 | "\n",
225 | "predictions = inferencer.inference(retriever, ice_template=templates[task_name], output_json_filename=f'mdl_{task_name}')\n",
226 | "\n",
227 | "scores = AccEvaluator().score(predictions=predictions, references=data.references)\n",
228 | "print(scores)"
229 | ]
230 | }
231 | ],
232 | "metadata": {
233 | "kernelspec": {
234 | "display_name": "Python 3",
235 | "language": "python",
236 | "name": "python3"
237 | },
238 | "language_info": {
239 | "codemirror_mode": {
240 | "name": "ipython",
241 | "version": 3
242 | },
243 | "file_extension": ".py",
244 | "mimetype": "text/x-python",
245 | "name": "python",
246 | "nbconvert_exporter": "python",
247 | "pygments_lexer": "ipython3",
248 | "version": "3.7.13"
249 | }
250 | },
251 | "nbformat": 4,
252 | "nbformat_minor": 5
253 | }
--------------------------------------------------------------------------------
/examples/tutorials/README.md:
--------------------------------------------------------------------------------
1 | # OpenICL Notebooks
2 |
3 | (**updating...**)
4 |
5 | | Notebook | Description | |
6 | |:----------|:-------------|:-------------|
7 | [Getting Started with OpenICL](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb) | Introduction to the main components of OpenICL | [](https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb) |
8 | [Using Different Language Models with OpenICL](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial2_use_different_models.ipynb) | Run different language models with OpenICL | [](https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial2_use_different_models.ipynb) |
9 | [Accelerating OpenICL with 🤗 Accelerate](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial3_accelerate.ipynb) | A Guide to Distributed Data Parallel and Model Parallel | [](https://colab.research.google.com/github/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial3_accelerate.ipynb)
--------------------------------------------------------------------------------
/examples/tutorials/openicl_tutorial2_use_different_models.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "318733a0-eb89-4df2-b278-e80afc144cca",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%pip install --upgrade pip\n",
11 | "%pip install openicl\n",
12 | "# Restart the kernel after the installation is completed"
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "id": "1be4e6a4-dff9-4f39-a86f-62736ccd82c4",
18 | "metadata": {},
19 | "source": [
20 | "# 2. Using Different Language Models with OpenICL"
21 | ]
22 | },
23 | {
24 | "attachments": {},
25 | "cell_type": "markdown",
26 | "id": "b92d14f4-2e73-4b1c-ba9e-0f9112a764b1",
27 | "metadata": {},
28 | "source": [
29 | "In this chapter, we will show you how to use OpenICL to do in-context learning (ICL) with different language models, mainly [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf), [FLAN-T5](https://arxiv.org/abs/2109.01652), [XGLM](https://arxiv.org/abs/2112.10668), OpenAI's [GPT-3](https://arxiv.org/abs/2005.14165) API, and the [OPT-175B](https://arxiv.org/abs/2205.01068) API."
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "id": "06ba89e7-2790-4097-9e5d-c75e356e152f",
35 | "metadata": {},
36 | "source": [
37 | "## 2-1 Huggingface Library's Models"
38 | ]
39 | },
40 | {
41 | "cell_type": "markdown",
42 | "id": "ca01fccb-6188-4960-85ee-662af04718ce",
43 | "metadata": {},
44 | "source": [
45 | "In this section, we will take GPT2, FLAN-T5, and XGLM as examples to show you how to use the models in the [huggingface library](https://huggingface.co/models) with OpenICL. Generally speaking, you only need to assign the corresponding name to the `model_name` parameter when declaring `Inferencer`, but we will still provide you with some specific examples."
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "id": "2aa253f5-7d3e-4810-a5fe-ddebc9699551",
51 | "metadata": {},
52 | "source": [
53 | "### 2-1-1 GPT-2"
54 | ]
55 | },
56 | {
57 | "attachments": {},
58 | "cell_type": "markdown",
59 | "id": "2a7f087a",
60 | "metadata": {},
61 | "source": [
62 | "This example can be found in [tutorial1](https://github.com/Shark-NLP/OpenICL/blob/main/examples/tutorials/openicl_tutorial1_getting_started.ipynb). But this time, we set `batch_size` for `TopkRetriever` and `PPLInferencer` to speed things up. Note that the two `batch_size` values can be different (`8` and `6`): at the beginning of retrieval and inference, the corresponding components receive the complete dataset or the retrieval results for the entire test set, instead of processing the data in batches."
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": null,
68 | "id": "faff9428",
69 | "metadata": {},
70 | "outputs": [],
71 | "source": [
72 | "from openicl import DatasetReader, PromptTemplate, TopkRetriever, PPLInferencer\n",
73 | "\n",
74 | "# Define a DatasetReader, loading dataset from huggingface.\n",
75 | "data = DatasetReader('gpt3mix/sst2', input_columns=['text'], output_column='label')\n",
76 | "\n",
77 | "# SST-2 Template Example\n",
78 | "tp_dict = {\n",
79 | " 0: 'Positive Movie Review: ',\n",
80 | " 1: 'Negative Movie Review: ' \n",
81 | "}\n",
82 | "template = PromptTemplate(tp_dict, {'text' : ''}, ice_token='')\n",
83 | "\n",
84 | "# TopK Retriever\n",
85 | "retriever = TopkRetriever(data, ice_num=2, batch_size=8)\n",
86 | "\n",
87 | "# Define a Inferencer\n",
88 | "inferencer = PPLInferencer(model_name='gpt2', batch_size=6)\n",
89 | "\n",
90 | "# Inference\n",
91 | "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='gpt2_sst2')\n",
92 | "print(predictions)"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "id": "f0a7b09a-2c1a-468e-bf1a-26f69ad51b21",
98 | "metadata": {},
99 | "source": [
100 | "### 2-1-2 XGLM"
101 | ]
102 | },
103 | {
104 | "attachments": {},
105 | "cell_type": "markdown",
106 | "id": "b56bb0c2",
107 | "metadata": {},
108 | "source": [
109 | "When it comes to machine translation, XGLM is a good choice. But when using XGLM, we **don't suggest** setting `batch_size` in `GenInferencer`. (When calling the `model.generate` method of the [huggingface transformers library](https://huggingface.co/docs/transformers/index), padding is needed if you want to input multiple pieces of data at a time, but we found in testing that if padding exists, the generation of XGLM is affected.) The code for evaluating the ICL performance of XGLM (7.5B) on the WMT16 (de-en) dataset\n",
110 | "with the direct inference strategy is as follows:"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": null,
116 | "id": "6951df6a",
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "from openicl import DatasetReader, PromptTemplate, BM25Retriever, GenInferencer\n",
121 | "from datasets import load_dataset\n",
122 | "\n",
123 | "# Loading dataset from huggingface \n",
124 | "dataset = load_dataset('wmt16', name='de-en')\n",
125 | "\n",
126 | "# Data Preprocessing\n",
127 | "dataset = dataset.map(lambda example: example['translation']).remove_columns('translation')\n",
128 | "\n",
129 | "# Define a DatasetReader, selecting 5 pieces of data randomly.\n",
130 | "data = DatasetReader(dataset, input_columns='de', output_column='en', ds_size=5)\n",
131 | "\n",
132 | "# WMT16 de->en Template Example\n",
133 | "template = PromptTemplate(' = ', {'en' : '', 'de' : ''}, ice_token='')\n",
134 | "\n",
135 | "# BM25 Retriever\n",
136 | "retriever = BM25Retriever(data, ice_num=1, index_split='validation', test_split='test', batch_size=5)\n",
137 | "\n",
138 | "# Define a Inferencer\n",
139 | "inferencer = GenInferencer(model_name='facebook/xglm-7.5B')\n",
140 | "\n",
141 | "# Inference\n",
142 | "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='xglm_wmt')\n",
143 | "print(predictions)"
144 | ]
145 | },
146 | {
147 | "cell_type": "markdown",
148 | "id": "2518e5ee-b58e-4378-a30f-ac0535da5559",
149 | "metadata": {},
150 | "source": [
151 | "### 2-1-3 FLAN-T5"
152 | ]
153 | },
154 | {
155 | "attachments": {},
156 | "cell_type": "markdown",
157 | "id": "8778eead",
158 | "metadata": {},
159 | "source": [
160 | "In this section, we will use FLAN-T5 with OpenICL to reproduce the results in the figure below:"
161 | ]
162 | },
163 | {
164 | "attachments": {},
165 | "cell_type": "markdown",
166 | "id": "be985353",
167 | "metadata": {},
168 | "source": [
169 | ""
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 13,
178 | "id": "1f592cb3-fae4-4295-baa9-1a3998c3ea19",
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "name": "stderr",
183 | "output_type": "stream",
184 | "text": [
185 | "Found cached dataset snli (/home/zhangyudejia/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)\n",
186 | "100%|██████████| 3/3 [00:00<00:00, 323.35it/s]\n",
187 | "[2023-03-10 15:40:19,712] [openicl.icl_inferencer.icl_gen_inferencer] [INFO] Starting inference process...\n",
188 | "100%|██████████| 10/10 [00:00<00:00, 15.22it/s]"
189 | ]
190 | },
191 | {
192 | "name": "stdout",
193 | "output_type": "stream",
194 | "text": [
195 | "['yes', 'yes', 'yes', 'no', 'yes', 'no', 'it is not possible to tell', 'yes', 'yes', 'it is not possible to tell']\n"
196 | ]
197 | },
198 | {
199 | "name": "stderr",
200 | "output_type": "stream",
201 | "text": [
202 | "\n"
203 | ]
204 | }
205 | ],
206 | "source": [
207 | "from openicl import DatasetReader, PromptTemplate, ZeroRetriever, GenInferencer\n",
208 | "\n",
209 | "# Define a DatasetReader, loading dataset from huggingface and selecting 10 pieces of data randomly.\n",
210 | "data = DatasetReader('snli', input_columns=['premise', 'hypothesis'], output_column='label', ds_size=10)\n",
211 | "\n",
212 | "# SNLI Template\n",
213 | "tp_str = 'Premise:\\nHypothesis:\\nDoes the premise entail the hypothesis?\\nOPTIONS:\\n-yes -It is not possible to tell -no'\n",
214 | "template = PromptTemplate(tp_str, column_token_map={'premise' : '', 'hypothesis' : ''}, ice_token='')\n",
215 | "\n",
216 | "# ZeroShot Retriever (do nothing)\n",
217 | "retriever = ZeroRetriever(data, index_split='train', test_split='test')\n",
218 | "\n",
219 | "# Define a Inferencer\n",
220 | "inferencer = GenInferencer(model_name='google/flan-t5-small', max_model_token_num=1000)\n",
221 | "\n",
222 | "# Inference\n",
223 | "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='flan-t5-small')\n",
224 | "print(predictions)"
225 | ]
226 | },
227 | {
228 | "attachments": {},
229 | "cell_type": "markdown",
230 | "id": "39e65e67",
231 | "metadata": {},
232 | "source": [
233 | "## 2-2 Using API-based model"
234 | ]
235 | },
236 | {
237 | "attachments": {},
238 | "cell_type": "markdown",
239 | "id": "17a9d790",
240 | "metadata": {},
241 | "source": [
242 | "OpenICL also currently supports OpenAI's GPT-3 API and OPT-175B API. But before using them, users need to do some configuration."
243 | ]
244 | },
245 | {
246 | "attachments": {},
247 | "cell_type": "markdown",
248 | "id": "deab3b33",
249 | "metadata": {},
250 | "source": [
251 | "### 2-2-1 OpenAI's GPT-3 API"
252 | ]
253 | },
254 | {
255 | "attachments": {},
256 | "cell_type": "markdown",
257 | "id": "5b784102",
258 | "metadata": {},
259 | "source": [
260 | "OpenAI provides its own open-source library -- [openai](https://github.com/openai/openai-python), for users to call their API services. To use this library in OpenICL, you need to set environment variable `OPEN_API_KEY` in advance. Here is a simple way (for detailed information, see openai's documentation [here](https://platform.openai.com/docs/api-reference/introduction)):"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "id": "84f229ea",
267 | "metadata": {
268 | "vscode": {
269 | "languageId": "powershell"
270 | }
271 | },
272 | "outputs": [],
273 | "source": [
274 | "# Replace 'your_api_key' with your key, and run this command in bash\n",
275 | "export OPENAI_API_KEY=\"your_api_key\""
276 | ]
277 | },
278 | {
279 | "attachments": {},
280 | "cell_type": "markdown",
281 | "id": "f7a7dab9",
282 | "metadata": {},
283 | "source": [
284 | "After the setting is complete, set `api_name='gpt3'` in `Inferencer` to use it normally. Below is a code snippet:"
285 | ]
286 | },
287 | {
288 | "cell_type": "code",
289 | "execution_count": 1,
290 | "id": "9510bd1e",
291 | "metadata": {},
292 | "outputs": [
293 | {
294 | "name": "stderr",
295 | "output_type": "stream",
296 | "text": [
297 | "/home/zhangyudejia/.local/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
298 | " from .autonotebook import tqdm as notebook_tqdm\n",
299 | "Found cached dataset mtop (/home/zhangyudejia/.cache/huggingface/datasets/iohadrubin___mtop/mtop/1.0.0/4ba6d9db9efaebd4f6504db7e36925632e959f456071b9d7f1b86a85cce52448)\n",
300 | "100%|██████████| 3/3 [00:00<00:00, 814.96it/s]\n",
301 | "[2023-03-10 19:03:00,481] [openicl.icl_retriever.icl_bm25_retriever] [INFO] Retrieving data for test set...\n",
302 | "100%|██████████| 1/1 [00:00<00:00, 1504.95it/s]\n",
303 | "[2023-03-10 19:03:00,486] [openicl.icl_inferencer.icl_gen_inferencer] [INFO] Starting inference process...\n",
304 | "100%|██████████| 1/1 [00:06<00:00, 6.38s/it]"
305 | ]
306 | },
307 | {
308 | "name": "stdout",
309 | "output_type": "stream",
310 | "text": [
311 | "['[IN:GET_EVENTS [SL:TYPE music festivals ] [SL:DATE in 2018 ] ]']\n"
312 | ]
313 | },
314 | {
315 | "name": "stderr",
316 | "output_type": "stream",
317 | "text": [
318 | "\n"
319 | ]
320 | }
321 | ],
322 | "source": [
323 | "from openicl import DatasetReader, PromptTemplate, BM25Retriever, GenInferencer\n",
324 | "from datasets import load_dataset\n",
325 | "\n",
326 | "dataset = load_dataset(\"iohadrubin/mtop\")\n",
327 | "dataset['train'] = dataset['train'].select([0, 1, 2])\n",
328 | "dataset['test'] = dataset['test'].select([0])\n",
329 | "\n",
330 | "dr = DatasetReader(dataset, input_columns=['question'], output_column='logical_form') \n",
331 | "\n",
332 | "tp_str = \"\\t\" \n",
333 | "tp = PromptTemplate(tp_str, column_token_map={'question' : '', 'logical_form' : ''}, ice_token='')\n",
334 | "\n",
335 | "rtr = BM25Retriever(dr, ice_num=1)\n",
336 | "\n",
337 | "infr = GenInferencer(api_name='gpt3', engine='text-davinci-003', sleep_time=3)\n",
338 | "\n",
339 | "print(infr.inference(rtr, ice_template=tp))"
340 | ]
341 | },
342 | {
343 | "attachments": {},
344 | "cell_type": "markdown",
345 | "id": "9d974057",
346 | "metadata": {},
347 | "source": [
348 | "Some models of OpenAI are charged and have a rate limit. So we set `sleep_time`(3 seconds) here to control the frequency of data requests. In order to prevent data loss caused by throwing exceptions, we also recommend using this API on a small-scale test set every time. For more information about API parameter configuration in OpenICL, please view [api_service.py](https://github.com/Shark-NLP/OpenICL/blob/main/openicl/utils/api_service.py)."
349 | ]
350 | },
351 | {
352 | "attachments": {},
353 | "cell_type": "markdown",
354 | "id": "48635ed8",
355 | "metadata": {},
356 | "source": [
357 | "### 2-2-2 OPT-175B API "
358 | ]
359 | },
360 | {
361 | "attachments": {},
362 | "cell_type": "markdown",
363 | "id": "f3a240a2",
364 | "metadata": {},
365 | "source": [
366 | "For OPT-175B, you need to deploy the model yourself (or get a URL of a deployed model from your friends :\\) ). \n",
367 | "Visit the [metaseq](https://github.com/facebookresearch/metaseq) repository for more information on deployment."
368 | ]
369 | },
370 | {
371 | "cell_type": "code",
372 | "execution_count": null,
373 | "id": "9bf3c44c",
374 | "metadata": {},
375 | "outputs": [],
376 | "source": [
377 | "from openicl import GenInferencer\n",
378 | "\n",
379 | "URL = \"xxx\"\n",
380 | "inferencer = GenInferencer(api_name='opt-175b', URL=URL)"
381 | ]
382 | }
383 | ],
384 | "metadata": {
385 | "kernelspec": {
386 | "display_name": "Python 3",
387 | "language": "python",
388 | "name": "python3"
389 | },
390 | "language_info": {
391 | "codemirror_mode": {
392 | "name": "ipython",
393 | "version": 3
394 | },
395 | "file_extension": ".py",
396 | "mimetype": "text/x-python",
397 | "name": "python",
398 | "nbconvert_exporter": "python",
399 | "pygments_lexer": "ipython3",
400 | "version": "3.8.16"
401 | }
402 | },
403 | "nbformat": 4,
404 | "nbformat_minor": 5
405 | }
406 |
--------------------------------------------------------------------------------
/examples/tutorials/openicl_tutorial3_accelerate.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%pip install --upgrade pip\n",
10 | "%pip install openicl\n",
11 | "# Restart the kernel after the installation is completed"
12 | ]
13 | },
14 | {
15 | "attachments": {},
16 | "cell_type": "markdown",
17 | "metadata": {},
18 | "source": [
19 | "# 3. Accelerating OpenICL with 🤗 Accelerate: Distributed Data Parallel and Model Parallel"
20 | ]
21 | },
22 | {
23 | "attachments": {},
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | "In OpenICL, we use 🤗 [Accelerate](https://github.com/huggingface/accelerate) to implement Distributed Data Parallel (DDP) and Model Parallel. 🤗 [Accelerate](https://github.com/huggingface/accelerate) is a library that enables the same PyTorch code to be run across any distributed configuration by adding just few lines of code, to quickly quickly set up 🤗 Accelerate, on your machine(s) just run:"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {
34 | "vscode": {
35 | "languageId": "shellscript"
36 | }
37 | },
38 | "outputs": [],
39 | "source": [
40 | "accelerate config"
41 | ]
42 | },
43 | {
44 | "attachments": {},
45 | "cell_type": "markdown",
46 | "metadata": {},
47 | "source": [
48 | "For more details on 🤗 Accelearte, you can check the [documentation](https://huggingface.co/docs/accelerate/index) here."
49 | ]
50 | },
51 | {
52 | "attachments": {},
53 | "cell_type": "markdown",
54 | "metadata": {},
55 | "source": [
56 | "## 3-1 Distributed Data Parallel"
57 | ]
58 | },
59 | {
60 | "attachments": {},
61 | "cell_type": "markdown",
62 | "metadata": {},
63 | "source": [
64 | "Distributed Data Parallel (DDP) implements data parallelism at the module level which can run across multiple machines. The recommended way to use DDP is to spawn one process for each model replica, where a model replica can span multiple devices. It is quite easy to use DDP in OpenICL after completing relevant settings through ```accelerate config```, just pass in the `Accelerator` instance in `Retriever` and `Inferencer`. The following are code and script examples:"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": null,
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "# test_sst2_ddp.py\n",
74 | "# Example adapted from tutorial 1-4-1\n",
75 | "\n",
76 | "from openicl import DatasetReader, PromptTemplate, TopkRetriever, PPLInferencer\n",
77 | "from accelerate import Accelerator\n",
78 | "\n",
79 | "# Accelerate Prepare\n",
80 | "accelerator = Accelerator()\n",
81 | "\n",
82 | "# Define a DatasetReader, loading dataset from huggingface.\n",
83 | "data = DatasetReader('gpt3mix/sst2', input_columns=['text'], output_column='label')\n",
84 | "\n",
85 | "# SST-2 Template Example\n",
86 | "template = PromptTemplate(template={\n",
87 | " 0: 'Positive Movie Review: ',\n",
88 | " 1: 'Negative Movie Review: ' \n",
89 | " },\n",
90 | " column_token_map={'text' : ''},\n",
91 | " ice_token=''\n",
92 | " )\n",
93 | "\n",
94 | "# TopK Retriever\n",
95 | "retriever = TopkRetriever(data, ice_num=8, index_split='train', test_split='test', accelerator=accelerator)\n",
96 | "\n",
97 | "# Define a Inferencer\n",
98 | "inferencer = PPLInferencer(model_name='distilgpt2', accelerator=accelerator)\n",
99 | "\n",
100 | "# Inference\n",
101 | "predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='ddp_sst2')\n",
102 | "\n",
103 | "# print(predictions)\n",
104 | "# Seeing results at ./icl_inference_output/ddp_sst2.json"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": null,
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "# run_sst2.ddp.sh \n",
114 | "# Replace `${your_gpu_num}` and `${your_port_id}` with your gpu number and running port number respectively\n",
115 | "accelerate launch --num_processes ${your_gpu_num} --main_process_port ${your_port_id} test_sst2_ddp.py"
116 | ]
117 | },
118 | {
119 | "attachments": {},
120 | "cell_type": "markdown",
121 | "metadata": {},
122 | "source": [
123 | "## 3-2 Model Parallel"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "metadata": {},
129 | "source": []
130 | }
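  | ,
  | {
  | "cell_type": "code",
  | "execution_count": null,
  | "metadata": {},
  | "outputs": [],
  | "source": [
  | "# test_sst2_mp.py\n",
  | "# A minimal sketch (not from the original tutorial): reuse the DatasetReader,\n",
  | "# template and TopkRetriever from the DDP example in 3-1, and only change how\n",
  | "# the Inferencer is constructed. The model name below is only an illustration.\n",
  | "\n",
  | "from openicl import PPLInferencer\n",
  | "\n",
  | "# `model_parallel=True` shards the model across the available GPUs via\n",
  | "# Accelerate (init_empty_weights + infer_auto_device_map, float16 weights).\n",
  | "# A custom `device_map` or `no_split_module_classes` can be passed as extra kwargs.\n",
  | "inferencer = PPLInferencer(model_name='facebook/opt-6.7b', model_parallel=True)\n",
  | "\n",
  | "# predictions = inferencer.inference(retriever, ice_template=template, output_json_filename='mp_sst2')"
  | ]
  | }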
131 | ],
132 | "metadata": {
133 | "language_info": {
134 | "name": "python"
135 | },
136 | "orig_nbformat": 4
137 | },
138 | "nbformat": 4,
139 | "nbformat_minor": 2
140 | }
141 |
--------------------------------------------------------------------------------
/openicl/__init__.py:
--------------------------------------------------------------------------------
1 | from .icl_dataset_reader import DatasetReader
2 | from .icl_prompt_template import PromptTemplate
3 | from .icl_retriever import *
4 | from .icl_evaluator import *
5 | from .icl_inferencer import *
6 |
--------------------------------------------------------------------------------
/openicl/icl_dataset_reader.py:
--------------------------------------------------------------------------------
1 | """Simple Dataset Reader"""
2 |
3 | from typing import List, Union, Optional, Dict
4 | from datasets import load_dataset
5 | from datasets import Dataset, DatasetDict
6 | from transformers import AutoTokenizer
7 | from datasets.splits import NamedSplit
8 | from openicl.icl_prompt_template import PromptTemplate
9 | from openicl.utils.check_type import _check_dataset, _check_type_list, _check_str
10 | import random
11 | import torch
12 |
13 |
14 | class DatasetReader:
15 | """In-conext Learning Dataset Reader Class
16 | Generate an DatasetReader instance through 'dataset'.
17 |
18 | Attributes:
19 | dataset (:obj:`Dataset` or :obj:`DatasetDict`): The dataset to be read.
20 | input_columns (:obj:`List[str]` or :obj:`str`): A list of column names (a string of column name) in the dataset that represent(s) the input field.
21 | output_column (:obj:`str`): A column name in the dataset that represents the prediction field.
22 | ds_size (:obj:`int` or :obj:`float`, optional): The number of pieces of data to return. When ds_size is an integer and greater than or equal to 1, `ds_size` pieces of data are randomly returned. When 0 < :obj:`ds_size` < 1, ``int(len(dataset) * ds_size)`` pieces of data are randomly returned. (used for testing)
23 | references(:obj:`list`, optional): The list of references, initialized by ``self.dataset[self.test_split][self.output_column]``.
24 | input_template (:obj:`PromptTemplate`, optional): An instance of the :obj:`PromptTemplate` class, used to format the input field content during the retrieval process. (in some retrieval methods)
25 | output_template (:obj:`PromptTemplate`, optional): An instance of the :obj:`PromptTemplate` class, used to format the output field content during the retrieval process. (in some learnable retrieval methods)
26 | input_output_template (:obj:`PromptTemplate`, optional): An instance of the `PromptTemplate` class, used to format the input-output field content during the retrieval process. (in some retrieval methods)
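  | 
  |     Example (a minimal sketch; ``gpt3mix/sst2`` is only an illustrative dataset)::
  | 
  |         >>> from openicl import DatasetReader
  |         >>> dr = DatasetReader('gpt3mix/sst2', input_columns=['text'], output_column='label', ds_size=10)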
27 | """
28 | dataset = None
29 | input_template = None
30 | output_template = None
31 | input_output_template = None
32 | references = None
33 |
34 | def __init__(self,
35 | dataset: Union[Dataset, DatasetDict, str],
36 | input_columns: Union[List[str], str],
37 | output_column: str,
38 | name: Optional[str] = None,
39 | data_files: Optional[str] = None,
40 | input_template: Optional[PromptTemplate] = None,
41 | output_template: Optional[PromptTemplate] = None,
42 | input_output_template: Optional[PromptTemplate] = None,
43 | ds_size: Union[None, int, float] = None,
44 | split: Optional[NamedSplit] = None,
45 | test_split: Optional[str] = 'test'
46 | ) -> None:
47 | self.input_columns = _check_type_list(input_columns, [List, str])
48 | if isinstance(self.input_columns, str):
49 | self.input_columns = self.input_columns.split()
50 | self.output_column = _check_str(output_column)
51 | self.ds_size = _check_type_list(ds_size, [None, int, float])
52 | if input_template is not None:
53 | self.input_template = PromptTemplate._check_prompt_template(input_template)
54 | if output_template is not None:
55 | self.output_template = PromptTemplate._check_prompt_template(output_template)
56 | if input_output_template is not None:
57 | self.input_output_template = PromptTemplate._check_prompt_template(input_output_template)
58 | if isinstance(dataset, str):
59 | self.dataset = load_dataset(dataset, name=name, data_files=data_files)
60 | else:
61 | self.dataset = _check_dataset(dataset)
62 | if split is not None and isinstance(self.dataset, DatasetDict):
63 | self.dataset = self.dataset[split]
64 | if self.ds_size is not None:
65 | if isinstance(self.dataset, Dataset):
66 | self.dataset = load_partial_dataset(dataset, size=self.ds_size)
67 | if isinstance(self.dataset, DatasetDict):
68 | for ds_name in self.dataset.keys():
69 | self.dataset[ds_name] = load_partial_dataset(self.dataset[ds_name], size=self.ds_size)
70 | if isinstance(self.dataset, DatasetDict):
71 | if test_split in self.dataset.keys():
72 | self.references = self.dataset[test_split][self.output_column]
73 | elif isinstance(self.dataset, Dataset):
74 | self.references = self.dataset[self.output_column]
75 |
76 | def set_references(self, column: str, split: Optional[str] = None) -> None:
77 | """Set :obj:`self.references` based on :obj:`column` and optional :obj:`split`.
78 |
79 | Args:
80 | column (:obj:`str`): A string of column name.
81 | split (:obj:`str`, optional): A string of dataset split. Defaults to ``None``.
82 | """
83 | if split is not None:
84 | self.references = self.dataset[split][column]
85 | else:
86 | self.references = self.dataset[column]
87 |
88 | def generate_input_field_prompt(self, entry: Dict) -> str:
89 | """Generate a prompt for the input field based on the provided :obj:`entry` data.
90 |
91 | Args:
92 | entry (:obj:`Dict`): A piece of data to be used for generating the prompt.
93 |
94 | Returns:
95 | :obj:`str`: The generated prompt.
96 | """
97 | prompt = None
98 | if self.input_template is None:
99 | prompt = ' '.join([str(entry[ctx]) for ctx in self.input_columns])
100 | else:
101 | prompt = self.input_template.generate_item(entry)
102 | return prompt
103 |
104 | def generate_input_field_corpus(self, dataset: Union[Dataset, DatasetDict], split: Optional[str] = None) -> List[
105 | str]:
106 | """Generate corpus for input field.
107 |
108 | Args:
109 | dataset (:obj:`Dataset` or :obj:`DatasetDict`): A :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict` instance.
110 | split (:obj:`str`, optional): The split of the dataset to use. If :obj:`None`, the entire dataset will be used. Defaults to ``None``.
111 |
112 | Returns:
113 | :obj:`List[str]`: A list of generated input field prompts.
114 | """
115 | if split is not None:
116 | dataset = dataset[split]
117 | corpus = []
118 | for entry in dataset:
119 | corpus.append(self.generate_input_field_prompt(entry))
120 | return corpus
121 |
122 | def generate_ouput_field_prompt(self, entry: Dict) -> str:
123 | """Generate a prompt for the output field based on the provided :obj:`entry` data.
124 |
125 | Args:
126 | entry (:obj:`Dict`): A piece of data to be used for generating the prompt.
127 |
128 | Returns:
129 | :obj:`str`: The generated prompt.
130 | """
131 | prompt = None
132 | if self.output_template is None:
133 | prompt = str(entry[self.output_column])
134 | else:
135 | prompt = self.output_template.generate_item(entry)
136 | return prompt
137 |
138 | def generate_output_field_corpus(self, dataset: Union[Dataset, DatasetDict], split: Optional[str] = None) -> List[
139 | str]:
140 | """Generate corpus for output field.
141 |
142 | Args:
143 | dataset (:obj:`Dataset` or :obj:`DatasetDict`): A :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict` instance.
144 | split (:obj:`str`, optional): The split of the dataset to use. If :obj:`None`, the entire dataset will be used. Defaults to ``None``.
145 |
146 | Returns:
147 | :obj:`List[str]`: A list of generated output field prompts.
148 | """
149 | if split is not None:
150 | dataset = dataset[split]
151 | corpus = []
152 | for entry in dataset:
153 | corpus.append(self.generate_ouput_field_prompt(entry))
154 | return corpus
155 |
156 | def generate_input_output_field_prompt(self, entry: Dict) -> str:
157 | """Generate a prompt for the input-output field based on the provided:obj:`entry` data.
158 |
159 | Args:
160 | entry (:obj:`Dict`): A piece of data to be used for generating the prompt.
161 |
162 | Returns:
163 | :obj:`str`: The generated prompt.
164 | """
165 | prompt = None
166 | if self.input_output_template is None:
167 |             prompt = ' '.join([str(entry[ctx]) for ctx in self.input_columns] + [str(entry[self.output_column])])
168 | else:
169 | prompt = self.input_output_template.generate_item(entry)
170 | return prompt
171 |
172 | def generate_input_output_field_corpus(self, dataset: Union[Dataset, DatasetDict], split: Optional[str] = None) -> \
173 | List[str]:
174 | """Generate corpus for input-output field.
175 |
176 | Args:
177 | dataset (:obj:`Dataset` or :obj:`DatasetDict`): A :obj:`datasets.Dataset` or :obj:`datasets.DatasetDict` instance.
178 | split (:obj:`str`, optional): The split of the dataset to use. If :obj:`None`, the entire dataset will be used. Defaults to ``None``.
179 |
180 | Returns:
181 | :obj:`List[str]`: A list of generated input-output field prompts.
182 | """
183 | if split is not None:
184 | dataset = dataset[split]
185 | corpus = []
186 | for entry in dataset:
187 | corpus.append(self.generate_input_output_field_prompt(entry))
188 | return corpus
189 |
190 | def _check_dataset_reader(obj) -> "DatasetReader":
191 | if isinstance(obj, DatasetReader):
192 | return obj
193 | else:
194 | raise TypeError(f"Expected a DatasetReader object, but got {obj}")
195 |
196 | def __len__(self):
197 | return len(self.dataset)
198 |
199 | def __getitem__(self, idx):
200 | return self.dataset[idx]
201 |
202 | def __repr__(self):
203 | return f"DatasetReader({{\n dataset: {self.dataset},\n input_columns: {self.input_columns},\n output_columns: {self.output_column}\n}})"
204 |
205 |
206 | def load_partial_dataset(dataset: Dataset, size: Optional[Union[int, float]] = None) -> Dataset:
207 | total_size = len(dataset)
208 | if size >= total_size or size <= 0:
209 | return dataset
210 | if size > 0 and size < 1:
211 | size = int(size * total_size)
212 | rand = random.Random(x=size)
213 | index_list = list(range(total_size))
214 | rand.shuffle(index_list)
215 | dataset = dataset.select(index_list[:size])
216 | return dataset
217 |
218 |
219 | class DatasetEncoder(torch.utils.data.Dataset):
220 | def __init__(self, datalist: List, model_name=None, tokenizer=None) -> None:
221 | self.datalist = datalist
222 | if model_name is None and tokenizer is None:
223 | raise ValueError("model_name and tokenizer could not both be None")
224 | if tokenizer is not None:
225 | self.tokenizer = tokenizer
226 | else:
227 | self.tokenizer = AutoTokenizer.from_pretrained(model_name)
228 | self.tokenizer.pad_token = self.tokenizer.eos_token
229 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
230 | self.tokenizer.padding_side = "left"
231 | self.encode_dataset = []
232 | self.init_dataset()
233 | self.datalist_length = len(self.encode_dataset)
234 |
235 | def init_dataset(self):
236 | for idx, data in enumerate(self.datalist):
237 | tokenized_data = self.tokenizer.encode_plus(data, truncation=True, return_tensors='pt', verbose=False)
238 | self.encode_dataset.append({
239 | 'input_ids': tokenized_data.input_ids[0],
240 | 'attention_mask': tokenized_data.attention_mask[0],
241 | "metadata": {"id": idx, "len": len(tokenized_data.input_ids[0]),
242 | "text": data}
243 | })
244 |
245 | def __len__(self):
246 | return self.datalist_length
247 |
248 | def __getitem__(self, idx):
249 | return self.encode_dataset[idx]
250 |
--------------------------------------------------------------------------------
/openicl/icl_evaluator/__init__.py:
--------------------------------------------------------------------------------
1 | from .icl_base_evaluator import BaseEvaluator
2 | from .icl_acc_evaluator import AccEvaluator
3 | from .icl_squad_evaluator import SquadEvaluator
4 | from .icl_bleu_evaluator import BleuEvaluator
5 | from .icl_rouge_evaluator import RougeEvaluator
6 | from .icl_api_evaluator import APIEvaluator
--------------------------------------------------------------------------------
/openicl/icl_evaluator/icl_acc_evaluator.py:
--------------------------------------------------------------------------------
1 | """Acc Evaluator"""
2 | from openicl.icl_evaluator import BaseEvaluator
3 | from typing import List
4 | import evaluate
5 |
6 |
7 | class AccEvaluator(BaseEvaluator):
8 | def __init__(self) -> None:
9 | super().__init__()
10 |
11 | def score(self, predictions, references):
12 | assert len(predictions) == len(references)
13 | mapping_to_int_dict = {label: idx for idx, label in enumerate(set(map(str, references)))}
14 | pred_set = set(predictions)
15 | for pred in pred_set:
16 | if str(pred) not in mapping_to_int_dict.keys():
17 | mapping_to_int_dict[str(pred)] = len(mapping_to_int_dict)
18 | golds = [mapping_to_int_dict[str(gold)] for gold in references]
19 | preds = [mapping_to_int_dict[str(pred)] for pred in predictions]
20 | metric = evaluate.load("accuracy")
21 | return metric.compute(references=golds, predictions=preds)
22 |
--------------------------------------------------------------------------------
/openicl/icl_evaluator/icl_api_evaluator.py:
--------------------------------------------------------------------------------
1 | """API evaluator"""
2 | from openicl.icl_evaluator import BaseEvaluator
3 | from typing import List, Dict
4 | import evaluate
5 |
6 |
7 | class APIEvaluator(BaseEvaluator):
8 | def __init__(self, metric) -> None:
9 | super().__init__()
10 | self.metric = metric
11 |
12 | def score(self, predictions, references):
13 | assert len(predictions) == len(references)
14 |         metric = evaluate.load(self.metric)
15 | scores = metric.compute(predictions=predictions, references=references)
16 | return scores
17 |
--------------------------------------------------------------------------------
/openicl/icl_evaluator/icl_base_evaluator.py:
--------------------------------------------------------------------------------
1 | """Base Evaluator"""
2 | from typing import List
3 |
4 |
5 | class BaseEvaluator:
6 | def __init__(self) -> None:
7 | pass
8 |
9 | def score(self):
10 | raise NotImplementedError("Method hasn't been implemented yet")
11 |
--------------------------------------------------------------------------------
/openicl/icl_evaluator/icl_bleu_evaluator.py:
--------------------------------------------------------------------------------
1 | """BLEU evaluator"""
2 | from openicl.icl_evaluator import BaseEvaluator
3 | from typing import List, Dict
4 | import evaluate
5 |
6 |
7 | class BleuEvaluator(BaseEvaluator):
8 | def __init__(self) -> None:
9 | super().__init__()
10 |
11 | def score(self, predictions, references):
12 | assert len(predictions) == len(references)
13 | metric = evaluate.load("sacrebleu")
14 | scores = metric.compute(predictions=predictions, references=references)
15 | return scores
16 |
--------------------------------------------------------------------------------
/openicl/icl_evaluator/icl_rouge_evaluator.py:
--------------------------------------------------------------------------------
1 | """ROUGE evaluator"""
2 | from openicl.icl_evaluator import BaseEvaluator
3 | from typing import List, Dict
4 | import evaluate
5 |
6 |
7 | class RougeEvaluator(BaseEvaluator):
8 | def __init__(self) -> None:
9 | super().__init__()
10 |
11 | def score(self, predictions, references):
12 | assert len(predictions) == len(references)
13 | metric = evaluate.load("rouge")
14 | scores = metric.compute(predictions=predictions, references=references)
15 | return scores
16 |
--------------------------------------------------------------------------------
/openicl/icl_evaluator/icl_squad_evaluator.py:
--------------------------------------------------------------------------------
1 | '''Squad Evaluator'''
2 | from openicl.icl_evaluator import BaseEvaluator
3 | from typing import List
4 | import evaluate
5 |
6 |
7 | class SquadEvaluator(BaseEvaluator):
8 | def __init__(self) -> None:
9 | super().__init__()
10 |
11 | def score(self, predictions, references):
12 | assert len(predictions) == len(references)
13 | p_list = [{'prediction_text': pred.split('\n')[0], 'id': str(i)} for i, pred in
14 | enumerate(predictions)]
15 | r_list = [{'answers': {'answer_start': [0], 'text': [ref]}, 'id': str(i)} for i, ref in
16 | enumerate(references)]
17 | metric = evaluate.load('squad')
18 | scores = metric.compute(predictions=p_list, references=r_list)
19 | return scores["f1"]
20 |
--------------------------------------------------------------------------------
/openicl/icl_inferencer/__init__.py:
--------------------------------------------------------------------------------
1 | from .icl_base_inferencer import BaseInferencer
2 | from .icl_ppl_inferencer import PPLInferencer
3 | from .icl_gen_inferencer import GenInferencer
4 | from .icl_cot_inferencer import CoTInferencer
5 | from .icl_channel_inferencer import ChannelInferencer
6 |
--------------------------------------------------------------------------------
/openicl/icl_inferencer/icl_base_inferencer.py:
--------------------------------------------------------------------------------
1 | """Basic Inferencer"""
2 |
3 | import os
  | import json
  | import warnings
4 | import torch
5 | from openicl import BaseRetriever, PromptTemplate
6 | from openicl.utils.api_service import *
7 | from openicl.icl_evaluator import *
8 | from transformers import AutoTokenizer, AutoModelForCausalLM, PretrainedConfig, GPT2Tokenizer, AutoConfig, \
9 | T5ForConditionalGeneration
10 | from typing import List, Union, Optional, Any
11 | from accelerate import Accelerator
12 | from accelerate import init_empty_weights, infer_auto_device_map
13 |
14 |
15 | class BaseInferencer:
16 | """Basic In-context Learning Inferencer Class
17 | Base class of In-context Learning Inferencer, with no inference method.
18 |
19 | Attributes:
20 | model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class.
21 | tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`.
22 | max_model_token_num (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM.
23 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
24 | accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing.
25 | output_json_filepath (:obj:`str`, optional): File path for output `JSON` file.
26 | output_json_filename (:obj:`str`, optional): File name for output `JSON` file.
27 | api_name (:obj:`str`, optional): Name of API service.
28 | call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`.
29 | """
30 | model = None
31 | tokenizer = None
32 | call_api = False
33 |
34 | def __init__(self,
35 | model_name: Optional[Union[str, Any]] = 'gpt2-xl',
36 | tokenizer_name: Optional[Union[str, Any]] = None,
37 | max_model_token_num: Optional[int] = None,
38 | model_config: Optional[PretrainedConfig] = None,
39 | batch_size: Optional[int] = 1,
40 | accelerator: Optional[Accelerator] = None,
41 | output_json_filepath: Optional[str] = "./icl_inference_output",
42 | output_json_filename: Optional[str] = "predictions",
43 | api_name: Optional[str] = None,
44 | model_parallel: Optional[bool] = False,
45 | **kwargs
46 | ) -> None:
47 | self.model_name = model_name
48 | self.tokenizer_name = tokenizer_name if tokenizer_name is not None else model_name
49 | self.accelerator = accelerator
50 |         self.is_main_process = self.accelerator is None or self.accelerator.is_main_process
51 | self.api_name = api_name
52 |
53 | if 'no_split_module_classes' not in kwargs.keys():
54 | kwargs['no_split_module_classes'] = []
55 | if 'device_map' not in kwargs.keys():
56 | kwargs['device_map'] = None
57 |
58 | no_split_module_classes = kwargs['no_split_module_classes']
59 | device_map = kwargs['device_map']
60 |
61 | self.__init_api(**kwargs)
62 | if not self.call_api:
63 | self.__init_model(self.model_name, model_config, model_parallel, device_map, no_split_module_classes)
64 | self.__init_tokenizer(self.tokenizer_name)
65 | else:
66 | if self.api_name == 'opt-175b':
67 | self.__init_tokenizer(self.tokenizer_name)
68 |
69 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
70 | if self.model is not None:
71 | self.model.to(self.device)
72 | self.model.eval()
73 | self.max_model_token_num = max_model_token_num
74 | self.batch_size = batch_size
75 | self.output_json_filepath = output_json_filepath
76 | self.output_json_filename = output_json_filename
77 | if not os.path.exists(self.output_json_filepath):
78 | os.makedirs(self.output_json_filepath)
79 |
80 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None,
81 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None,
82 | output_json_filename: Optional[str] = None) -> List:
83 | """Perform In-Context Inference given a retriever and optional templates.
84 |
85 | Args:
86 | retriever (:obj:`BaseRetriever`): An instance of a Retriever class that will be used to retrieve in-context examples
87 | ice_template (:obj:`PromptTemplate`, optional): A template for generating the in-context examples prompt. Defaults to None.
88 | prompt_template (:obj:`PromptTemplate`, optional): A template for generating the final prompt. Defaults to None.
89 | output_json_filepath (:obj:`str`, optional): The file path to save the results as a `JSON` file. Defaults to None.
90 | output_json_filename (:obj:`str`, optional): The file name to save the results as a `JSON` file. Defaults to None.
91 |
92 | Raises:
93 | NotImplementedError: If the function is not implemented in the subclass.
94 |
95 | Returns:
96 |             :obj:`List`: A list of strings, each representing the result of one inference.
97 | """
98 | raise NotImplementedError("Method hasn't been implemented yet")
99 |
100 | def __init_model(self, model_name, model_config, model_parallel, device_map, no_split_module_classes):
101 | if not isinstance(model_name, str):
102 | self.model = model_name
103 | self.model_name = '' # set model name to null since we pass the loaded model already
104 | return
105 | if not model_parallel:
106 | if model_config is not None:
107 | self.model = self.__get_hf_model_from_config(model_name, model_config)
108 | else:
109 | self.model = self.__get_hf_model_from_name(model_name)
110 | else:
111 | if model_config is None:
112 | model_config = AutoConfig.from_pretrained(model_name)
113 | with init_empty_weights():
114 | empty_model = AutoModelForCausalLM.from_config(model_config)
115 |
116 | if device_map is None:
117 | device_map = infer_auto_device_map(empty_model, no_split_module_classes=no_split_module_classes,
118 | dtype="float16")
119 |
120 | self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map=device_map,
121 | offload_folder="offload", offload_state_dict=True,
122 | torch_dtype=torch.float16)
123 |
124 | def __get_hf_model_from_name(self, model_name):
125 | if 't5' in model_name:
126 | return T5ForConditionalGeneration.from_pretrained(model_name)
127 | else:
128 | return AutoModelForCausalLM.from_pretrained(model_name)
129 |
130 | def __get_hf_model_from_config(self, model_name, model_config):
131 | if 't5' in model_name:
132 | raise TypeError("T5 model has no 'from_config' method")
133 | else:
134 | return AutoModelForCausalLM.from_config(model_config)
135 |
136 | def __init_tokenizer(self, tokenizer_name):
137 | if self.api_name == 'opt-175b':
138 | self.tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-30b", use_fast=False)
139 | else:
140 | if not isinstance(tokenizer_name, str):
141 | self.tokenizer = tokenizer_name
142 | else:
143 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
144 | self.tokenizer.pad_token = self.tokenizer.eos_token
145 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
146 | self.tokenizer.padding_side = "left"
147 |
148 | def __init_api(self, **kwargs):
149 |         if self.api_name is None:
150 |             return
151 |         self.call_api = is_api_available(self.api_name)
152 |         if not self.call_api:
153 |             warnings.warn(f"api_name '{self.api_name}' is not available, please check it")
154 | else:
155 | update_openicl_api_request_config(self.api_name, **kwargs)
156 |
157 | def get_input_token_num(self, inputs):
158 | return len(self.tokenizer(inputs, verbose=False)['input_ids'])
159 |
160 |
161 | class GenInferencerOutputHandler:
162 | origin_prompt_dict = {}
163 | output_dict = {}
164 | prediction_dict = {}
165 | results_dict = {}
166 |
167 | def __init__(self,
168 | num: int,
169 | accelerator: Optional[Accelerator] = None
170 | ) -> None:
171 | self.num = num
172 | self.accelerator = accelerator
173 | self.origin_prompt_dict = {}
174 | self.output_dict = {}
175 | self.prediction_dict = {}
176 | self.results_dict = {}
177 |
178 | def subprocess_write_to_json(self, output_json_filepath: str, output_json_filename: str):
179 | self.results_dict = {
180 | str(idx): {
181 | 'origin_prompt': self.origin_prompt_dict[str(idx)],
182 | 'output': self.output_dict[str(idx)],
183 | 'prediction': self.prediction_dict[str(idx)]
184 | } for idx in self.origin_prompt_dict.keys()
185 | }
186 | if self.accelerator is not None:
187 | with open(f'{output_json_filepath}/process{self.accelerator.process_index}_{output_json_filename}.json',
188 | 'w', encoding='utf-8') as json_file:
189 | json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False)
190 | json_file.close()
191 |
192 | def write_to_json(self, output_json_filepath: str, output_json_filename: str):
193 | with open(f'{output_json_filepath}/{output_json_filename}.json', 'w', encoding='utf-8') as json_file:
194 | json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False)
195 | json_file.close()
196 |
197 | def merge_to_main_process(self, output_json_filepath: str, output_json_filename: str):
198 | if self.accelerator is not None and self.accelerator.is_main_process:
199 | for pid in range(self.accelerator.num_processes):
200 | with open(f'{output_json_filepath}/process{pid}_{output_json_filename}.json', 'r',
201 | encoding='utf-8') as json_file:
202 | subprocess_results_dict = json.load(json_file)
203 | self.results_dict.update(subprocess_results_dict)
204 | json_file.close()
205 | self.results_dict = dict(sorted(self.results_dict.items(), key=lambda x: int(x[0])))
206 |
207 | def save_orgin_prompts(self, origin_prompts: List[str]):
208 | for idx, origin_prompt in enumerate(origin_prompts):
209 | if self.accelerator is not None:
210 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index
211 | self.origin_prompt_dict[str(idx)] = origin_prompt
212 |
213 | def save_prediction_and_output(self, prediction, output, idx):
214 | if self.accelerator is not None:
215 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index
216 | self.prediction_dict[str(idx)] = prediction
217 | self.output_dict[str(idx)] = output
218 |
219 |
220 | class PPLInferencerOutputHandler:
221 | results_dict = {}
222 |
223 | def __init__(self,
224 | accelerator: Optional[Accelerator] = None
225 | ) -> None:
226 | self.accelerator = accelerator
227 | self.results_dict = {}
228 |
229 | def subprocess_write_to_json(self, output_json_filepath: str, output_json_filename: str):
230 | if self.accelerator is not None:
231 | with open(f'{output_json_filepath}/process{self.accelerator.process_index}_{output_json_filename}.json',
232 | 'w', encoding='utf-8') as json_file:
233 | json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False)
234 | json_file.close()
235 |
236 | def write_to_json(self, output_json_filepath: str, output_json_filename: str):
237 | with open(f'{output_json_filepath}/{output_json_filename}.json', 'w', encoding='utf-8') as json_file:
238 | json.dump(self.results_dict, json_file, indent=4, ensure_ascii=False)
239 | json_file.close()
240 |
241 | def merge_to_main_process(self, output_json_filepath: str, output_json_filename: str):
242 | if self.accelerator is not None and self.accelerator.is_main_process:
243 | for pid in range(self.accelerator.num_processes):
244 | with open(f'{output_json_filepath}/process{pid}_{output_json_filename}.json', 'r',
245 | encoding='utf-8') as json_file:
246 | subprocess_results_dict = json.load(json_file)
247 | self.results_dict.update(subprocess_results_dict)
248 | json_file.close()
249 | self.results_dict = dict(sorted(self.results_dict.items(), key=lambda x: int(x[0])))
250 |
251 | def save_ice(self, ice):
252 | for idx, example in enumerate(ice):
253 | if self.accelerator is not None:
254 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index
255 | if str(idx) not in self.results_dict.keys():
256 | self.results_dict[str(idx)] = {}
257 | self.results_dict[str(idx)]['in-context examples'] = example
258 |
259 | def save_predictions(self, predictions):
260 | for idx, prediction in enumerate(predictions):
261 | if self.accelerator is not None:
262 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index
263 | if str(idx) not in self.results_dict.keys():
264 | self.results_dict[str(idx)] = {}
265 | self.results_dict[str(idx)]['prediction'] = prediction
266 |
267 | def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
268 | if self.accelerator is not None:
269 | idx = idx * self.accelerator.num_processes + self.accelerator.process_index
270 | if str(idx) not in self.results_dict.keys():
271 | self.results_dict[str(idx)] = {}
272 | if 'label: ' + str(label) not in self.results_dict[str(idx)].keys():
273 | self.results_dict[str(idx)]['label: ' + str(label)] = {}
274 | self.results_dict[str(idx)]['label: ' + str(label)]['testing input'] = input
275 | self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
276 | self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl
277 |
--------------------------------------------------------------------------------
/openicl/icl_inferencer/icl_channel_inferencer.py:
--------------------------------------------------------------------------------
1 | """PPL Inferencer"""
2 |
3 | import json
4 | import torch
5 | from openicl import PromptTemplate
6 | from openicl.icl_retriever import *
7 | from openicl.icl_evaluator import *
8 | from openicl.icl_inferencer.icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
9 | from openicl.icl_inferencer.icl_ppl_inferencer import PPLInferencer
10 | from openicl.utils.logging import get_logger
11 | from openicl.utils.api_service import *
12 | from typing import List, Union, Optional
13 | from tqdm import tqdm
14 | from tqdm import trange
15 | from transformers import PretrainedConfig
16 | from accelerate import Accelerator
17 |
18 | logger = get_logger(__name__)
19 |
20 |
21 | class ChannelInferencer(PPLInferencer):
22 | """PPL In-context Learning Inferencer Class
23 | Channel In-context Learning Inferencer.
24 | We recommend you to use ppl inferencer instead of channel inferencer
25 |
26 | """
27 |
28 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None,
29 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None,
30 | output_json_filename: Optional[str] = None, normalizing_str: Optional[str] = None) -> List:
31 | # 1. Preparation for output logs
32 | output_handler = PPLInferencerOutputHandler(self.accelerator)
33 |
34 | sub_predictions = []
35 | ppl = []
36 | ice = []
37 |
38 | if output_json_filepath is None:
39 | output_json_filepath = self.output_json_filepath
40 | if output_json_filename is None:
41 | output_json_filename = self.output_json_filename
42 |
43 | # 2. Get results of retrieval process
44 | ice_idx_list = retriever.retrieve()
45 |
46 | # 3. Get labels of all the classes
47 | if self.labels is None:
48 | labels = retriever.get_labels(ice_template=ice_template, prompt_template=prompt_template)
49 | else:
50 | labels = self.labels
51 |
52 | # 4. Generate in-context examples for testing inputs
53 | for idx in range(len(ice_idx_list)):
54 | ice.append(retriever.generate_ice(ice_idx_list[idx], ice_template=ice_template))
55 | output_handler.save_ice(ice)
56 |
57 | # 5. Calculating PPL for prompts in each label's class
58 | for label in labels:
59 | index = 0
60 | prompt_list = []
61 | sub_ppl_list = []
62 | context_length_list = []
63 |
64 | # 5.1 Generate prompts of current label and truncate
65 | for idx in range(len(ice_idx_list)):
66 | prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template,
67 | prompt_template=prompt_template,
68 | remain_sep=True)
69 | if self.max_model_token_num is not None and self.api_name != 'gpt3':
70 | prompt_token_num = self.get_input_token_num(prompt)
71 | while len(ice_idx_list[idx]) > 0 and prompt_token_num > self.max_model_token_num:
72 | ice_idx_list[idx] = ice_idx_list[idx][:-1]
73 | ice[idx] = retriever.generate_ice(ice_idx_list[idx], ice_template=ice_template)
74 | prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template,
75 | prompt_template=prompt_template)
76 | prompt_token_num = self.get_input_token_num(prompt)
77 |
78 | prompt_sep = prompt
79 | if prompt_template is not None:
80 | sep_token = prompt_template.sep_token
81 | else:
82 | sep_token = ice_template.sep_token
83 | sep_pos = prompt_sep.find(sep_token)
84 | context = prompt_sep[0:sep_pos]
85 | prompt = prompt_sep.replace(sep_token, '')
86 | context_length_list.append(self.get_input_token_num(context))
87 | prompt_list.append(prompt)
88 |
89 | # 5.2 Get PPL
90 | logger.info(f"Calculating PPL for prompts labeled '{label}'")
91 | for idx in trange(0, len(prompt_list), self.batch_size, disable=not self.is_main_process):
92 | sub_prompt_list = prompt_list[idx:idx + self.batch_size]
93 | sub_context_length_list = context_length_list[idx:idx + self.batch_size]
94 |
95 | with torch.no_grad():
96 | sub_res = self.__get_ppl(input_texts=sub_prompt_list, mask_length=sub_context_length_list)
97 | for res, prompt in zip(sub_res, sub_prompt_list):
98 | sub_ppl_list.append(res)
99 | output_handler.save_prompt_and_ppl(label, prompt[len(ice[idx]):], prompt, res, index)
100 | index = index + 1
101 | ppl.append(sub_ppl_list)
102 |
103 | # 6. Get lowest PPL class as predictions
104 | ppl = list(zip(*ppl))
105 | for single_ppl in ppl:
106 | sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
107 | output_handler.save_predictions(sub_predictions)
108 |
109 | # 7. Output
110 | output_handler.subprocess_write_to_json(output_json_filepath, output_json_filename)
111 | if self.accelerator is not None:
112 | self.accelerator.wait_for_everyone()
113 | output_handler.merge_to_main_process(output_json_filepath, output_json_filename)
114 | output_handler.write_to_json(output_json_filepath, output_json_filename)
115 |
116 | return [sample['prediction'] for sample in output_handler.results_dict.values()]
117 |
--------------------------------------------------------------------------------
/openicl/icl_inferencer/icl_cot_inferencer.py:
--------------------------------------------------------------------------------
1 | """chain-of-thought inferencer"""
2 |
3 | import torch
4 | from openicl import PromptTemplate
5 | from openicl.icl_retriever import *
6 | from openicl.icl_evaluator import *
7 | from openicl.icl_inferencer.icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler
8 | from typing import List, Union, Optional
9 | from tqdm import tqdm
10 | from transformers import PretrainedConfig
11 | from openicl.utils.api_service import *
12 | from openicl.utils.icl_common_utils import get_dataloader, get_generation_prompt_list_from_retriever_indices
13 | from openicl.utils.logging import get_logger
14 | from accelerate import Accelerator
15 |
16 | logger = get_logger(__name__)
17 |
18 |
19 | class CoTInferencer(BaseInferencer):
20 | """COT In-context Learning Inferencer Class
21 | Chain-of-Thought In-context Learning Inferencer.
22 |
23 | Attributes:
24 | model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class.
25 | tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`.
26 | max_model_token_num (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM.
27 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
28 | accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing.
29 | output_json_filepath (:obj:`str`, optional): File path for output `JSON` file.
30 | output_json_filename (:obj:`str`, optional): File name for output `JSON` file.
31 | api_name (:obj:`str`, optional): Name of API service.
32 | call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`.
33 | gen_field_replace_token (:obj:`str`, optional): Used to replace the generation field token when generating prompts.
34 | generation_kwargs (:obj:`Dict`, optional): Parameters for the :obj:`model.generate()` method.
35 | cot_list (:obj:`list`, optional): A list of sentences used for multiple-step generations.
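  | 
  |     Example (a minimal sketch; the chain-of-thought strings are only illustrative)::
  | 
  |         >>> infr = CoTInferencer(cot_list=["Let's think step by step.",
  |         ...                                " Therefore, the answer is"],
  |         ...                      model_name='gpt2-xl')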
36 | """
37 |
38 | def __init__(self,
39 | cot_list: Optional[List[str]] = [],
40 | model_name: Optional[str] = 'gpt2-xl',
41 | tokenizer_name: Optional[str] = None,
42 | max_model_token_num: Optional[int] = None,
43 | model_config: Optional[PretrainedConfig] = None,
44 | batch_size: Optional[int] = 1,
45 | gen_field_replace_token: Optional[str] = '',
46 | generation_kwargs={"max_new_tokens": 100},
47 | accelerator: Optional[Accelerator] = None,
48 | output_json_filepath: Optional[str] = "./icl_inference_output",
49 | output_json_filename: Optional[str] = "predictions",
50 | api_name: Optional[str] = None,
51 | model_parallel: Optional[bool] = False,
52 | **kwargs
53 | ) -> None:
54 | super().__init__(model_name, tokenizer_name, max_model_token_num, model_config, batch_size, accelerator,
55 | output_json_filepath, output_json_filename, api_name, model_parallel, **kwargs)
56 | self.cot_list = cot_list
57 | self.gen_field_replace_token = gen_field_replace_token
58 | self.generation_kwargs = generation_kwargs
59 |
60 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None,
61 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None,
62 | output_json_filename: Optional[str] = None) -> List:
63 | # 1. Preparation for output logs
64 | num = len(retriever.test_ds)
65 | output_handler = GenInferencerOutputHandler(num, self.accelerator)
66 | index = 0
67 | if output_json_filepath is None:
68 | output_json_filepath = self.output_json_filepath
69 | if output_json_filename is None:
70 | output_json_filename = self.output_json_filename
71 |
72 | # 2. Get results of retrieval process
73 | ice_idx_list = retriever.retrieve()
74 | cot_list_len = len(self.cot_list)
75 |
76 | # 3. Generate prompts for testing input
77 | prompt_list = get_generation_prompt_list_from_retriever_indices(ice_idx_list, retriever, self.tokenizer,
78 | self.gen_field_replace_token,
79 | max_model_token_num=self.max_model_token_num,
80 | ice_template=ice_template,
81 | prompt_template=prompt_template)
82 | if cot_list_len > 0:
83 | prompt_list = [prompt + self.cot_list[0] for prompt in prompt_list]
84 |
85 |         # 4. Run inference for `max(len(self.cot_list), 1)` steps
86 | for idx in range(0, max(cot_list_len, 1)):
87 | index = 0
88 | cot_idx = idx + 1
89 | # 4-1. Wrap prompts with Dataloader
90 | dataloader = get_dataloader(prompt_list, self.batch_size)
91 | output_handler.save_orgin_prompts(prompt_list)
92 |
93 | for entry in tqdm(dataloader, disable=not self.is_main_process):
94 | # 4-2-1. Inference with local model
95 | if not self.call_api:
96 | with torch.no_grad():
97 | tokenized_data = self.tokenizer.batch_encode_plus(entry, padding=True, return_tensors='pt').to(
98 | self.device)
99 | prompt_len = int(tokenized_data.attention_mask.shape[1])
100 | if 't5' in self.model_name:
101 | prompt_len = 0
102 | outputs = self.model.generate(input_ids=tokenized_data.input_ids,
103 | attention_mask=tokenized_data.attention_mask,
104 | eos_token_id=self.tokenizer.eos_token_id,
105 | pad_token_id=self.tokenizer.pad_token_id,
106 | **self.generation_kwargs)
107 | outputs = outputs.tolist()
108 | complete_output = self.tokenizer.batch_decode(outputs[:], skip_special_tokens=True)
109 | generated = self.tokenizer.batch_decode([output[prompt_len:] for output in outputs],
110 | skip_special_tokens=True)
111 | # 4-2-2. Inference with remote API
112 | else:
113 | complete_output, generated = api_get_tokens(self.api_name, entry)
114 |
115 | # 4-2-3. Save current output
116 | for prediction, output in zip(generated, complete_output):
117 | if 't5' in self.model_name:
118 | output = prompt_list[index] + output
119 | output_handler.save_prediction_and_output(prediction, output, index)
120 | prompt_list[index] = output
121 | index = index + 1
122 |
123 | # 4-3. Output for current step
124 | if cot_idx < cot_list_len:
125 | filename = output_json_filename + f'_step{idx}'
126 | else:
127 | filename = output_json_filename
128 | output_handler.subprocess_write_to_json(output_json_filepath, filename)
129 | if self.accelerator is not None:
130 | self.accelerator.wait_for_everyone()
131 | output_handler.merge_to_main_process(output_json_filepath, filename)
132 | output_handler.write_to_json(output_json_filepath, filename)
133 |
134 | # 4-4. Check for next string in `self.cot_list`
135 | if cot_idx < cot_list_len:
136 | prompt_list = [(prompt + str(self.cot_list[cot_idx])) for prompt in prompt_list]
137 | else:
138 | break
139 | return [sample['prediction'] for sample in output_handler.results_dict.values()]
140 |
--------------------------------------------------------------------------------
/openicl/icl_inferencer/icl_gen_inferencer.py:
--------------------------------------------------------------------------------
1 | """Direct Generation Inferencer"""
2 |
3 | import json
4 | import torch
5 | from openicl import PromptTemplate
6 | from openicl.icl_retriever import *
7 | from openicl.icl_evaluator import *
8 | from openicl.icl_inferencer.icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler
9 | from openicl.utils.api_service import *
10 | from openicl.utils.icl_common_utils import get_dataloader, get_generation_prompt_list_from_retriever_indices
11 | from openicl.utils.logging import get_logger
12 | from typing import List, Union, Optional
13 | from tqdm import tqdm
14 | from transformers import PretrainedConfig
15 | from accelerate import Accelerator
16 |
17 | logger = get_logger(__name__)
18 |
19 |
20 | class GenInferencer(BaseInferencer):
21 | """Generation In-context Learning Inferencer Class
22 |     In-context Learning Inferencer for Direct Generation.
23 |
24 | Attributes:
25 | model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class.
26 | tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`.
27 | max_model_token_num (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM.
28 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
29 | accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing.
30 | output_json_filepath (:obj:`str`, optional): File path for output `JSON` file.
31 | output_json_filename (:obj:`str`, optional): File name for output `JSON` file.
32 | api_name (:obj:`str`, optional): Name of API service.
33 | call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`.
34 | gen_field_replace_token (:obj:`str`, optional): Used to replace the generation field token when generating prompts.
35 | generation_kwargs (:obj:`Dict`, optional): Parameters for the :obj:`model.generate()` method.
36 | """
37 |
38 | def __init__(self,
39 | model_name: Optional[str] = 'gpt2-xl',
40 | tokenizer_name: Optional[str] = None,
41 | max_model_token_num: Optional[int] = None,
42 | model_config: Optional[PretrainedConfig] = None,
43 | batch_size: Optional[int] = 1,
44 | gen_field_replace_token: Optional[str] = '',
45 | generation_kwargs={"max_new_tokens": 100},
46 | accelerator: Optional[Accelerator] = None,
47 | output_json_filepath: Optional[str] = "./icl_inference_output",
48 | output_json_filename: Optional[str] = "predictions",
49 | api_name: Optional[str] = None,
50 | model_parallel: Optional[bool] = False,
51 | **kwargs
52 | ) -> None:
53 | super().__init__(model_name, tokenizer_name, max_model_token_num, model_config, batch_size, accelerator,
54 | output_json_filepath, output_json_filename, api_name, model_parallel, **kwargs)
55 | self.gen_field_replace_token = gen_field_replace_token
56 | self.generation_kwargs = generation_kwargs
57 |
58 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None,
59 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None,
60 | output_json_filename: Optional[str] = None, force_words=None) -> List:
61 | # 1. Preparation for output logs
62 | num = len(retriever.test_ds)
63 | output_handler = GenInferencerOutputHandler(num, self.accelerator)
64 | index = 0
65 |
66 | if output_json_filepath is None:
67 | output_json_filepath = self.output_json_filepath
68 | if output_json_filename is None:
69 | output_json_filename = self.output_json_filename
70 |
71 | # 2. Get results of retrieval process
72 | ice_idx_list = retriever.retrieve()
73 |
74 | # 3. Generate prompts for testing input
75 | prompt_list = get_generation_prompt_list_from_retriever_indices(ice_idx_list, retriever, self.tokenizer,
76 | self.gen_field_replace_token,
77 | max_model_token_num=self.max_model_token_num,
78 | ice_template=ice_template,
79 | prompt_template=prompt_template)
80 | output_handler.save_orgin_prompts(prompt_list)
81 |
82 | # 4. Wrap prompts with Dataloader
83 | dataloader = get_dataloader(prompt_list, self.batch_size)
84 |
85 | # 5. Inference for prompts in each batch
86 | logger.info("Starting inference process...")
87 | for entry in tqdm(dataloader, disable=not self.is_main_process):
88 | # 5-1. Inference with local model
89 | if not self.call_api:
90 | with torch.no_grad():
91 | tokenized_data = self.tokenizer.batch_encode_plus(entry, padding=True, return_tensors='pt').to(
92 | self.device)
93 | prompt_len = int(tokenized_data.attention_mask.shape[1])
94 | if 't5' in self.model_name:
95 | prompt_len = 0
96 | if force_words is not None:
97 | force_words_ids = [
98 | self.tokenizer(force_words).input_ids,
99 | ]
100 | outputs = self.model.generate(input_ids=tokenized_data.input_ids,
101 | force_words_ids=force_words_ids,
102 | num_beams=10,
103 | attention_mask=tokenized_data.attention_mask,
104 | eos_token_id=self.tokenizer.eos_token_id,
105 | pad_token_id=self.tokenizer.pad_token_id,
106 | **self.generation_kwargs)
107 | else:
108 | outputs = self.model.generate(input_ids=tokenized_data.input_ids,
109 | attention_mask=tokenized_data.attention_mask,
110 | eos_token_id=self.tokenizer.eos_token_id,
111 | pad_token_id=self.tokenizer.pad_token_id,
112 | **self.generation_kwargs)
113 | outputs = outputs.tolist()
114 | complete_output = self.tokenizer.batch_decode(outputs[:], skip_special_tokens=True)
115 | generated = self.tokenizer.batch_decode([output[prompt_len:] for output in outputs],
116 | skip_special_tokens=True)
117 | # 5-2. Inference with remote API
118 | else:
119 | complete_output, generated = api_get_tokens(self.api_name, entry)
120 |
121 | # 5-3. Save current output
122 | for prediction, output in zip(generated, complete_output):
123 | output_handler.save_prediction_and_output(prediction, output, index)
124 | index = index + 1
125 |
126 | # 6. Output
127 | output_handler.subprocess_write_to_json(output_json_filepath, output_json_filename)
128 | if self.accelerator is not None:
129 | self.accelerator.wait_for_everyone()
130 | output_handler.merge_to_main_process(output_json_filepath, output_json_filename)
131 | output_handler.write_to_json(output_json_filepath, output_json_filename)
132 | return [sample['prediction'] for sample in output_handler.results_dict.values()]
133 |
--------------------------------------------------------------------------------
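A minimal usage sketch for the generation inferencer above (not part of the source tree). The dataset name and column names ('xsum', 'document', 'summary') are illustrative assumptions, and DatasetReader is used with its dataset / input_columns / output_column arguments; swap in your own task accordingly.

from datasets import load_dataset
from openicl import DatasetReader, PromptTemplate, BM25Retriever, GenInferencer

# Illustrative dataset and columns; any generation-style dataset works the same way.
data = DatasetReader(load_dataset('xsum'), input_columns=['document'], output_column='summary')

# String template: '</E>' marks where in-context examples go; '</sum>' is the generation field,
# which is replaced by gen_field_replace_token ('' by default) in the test prompt.
template = PromptTemplate('</E>Document: </doc>\nSummary: </sum>',
                          column_token_map={'document': '</doc>', 'summary': '</sum>'},
                          ice_token='</E>')

retriever = BM25Retriever(data, ice_num=2)
inferencer = GenInferencer(model_name='gpt2-xl', generation_kwargs={'max_new_tokens': 50})
predictions = inferencer.inference(retriever, ice_template=template)
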
/openicl/icl_inferencer/icl_ppl_inferencer.py:
--------------------------------------------------------------------------------
1 | """PPL Inferencer"""
2 |
3 | import json
4 | import torch
5 | from openicl import PromptTemplate
6 | from openicl.icl_retriever import *
7 | from openicl.icl_evaluator import *
8 | from openicl.icl_inferencer.icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
9 | from openicl.utils.logging import get_logger
10 | from openicl.utils.api_service import *
11 | from typing import List, Union, Optional
12 | from tqdm import tqdm
13 | from tqdm import trange
14 | from transformers import PretrainedConfig
15 | from accelerate import Accelerator
16 |
17 | logger = get_logger(__name__)
18 |
19 |
20 | class PPLInferencer(BaseInferencer):
21 | """PPL In-context Learning Inferencer Class
22 | Perplexity-based In-context Learning Inferencer.
23 |
24 | Attributes:
25 | model (:obj:`AutoModelForCausalLM`, optional): Local PLM (loaded from Hugging Face), which can be initialized by name or a config class.
26 | tokenizer (:obj:`AutoTokenizer` or :obj:`GPT2Tokenizer`, optional): Tokenizer for :obj:`model`.
27 |         max_model_token_num (:obj:`int`, optional): Maximum number of prompt tokens accepted by the LM; in-context examples are dropped until the prompt fits.
28 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
29 | accelerator (:obj:`Accelerator`, optional): An instance of the `Accelerator` class, used for multiprocessing.
30 | output_json_filepath (:obj:`str`, optional): File path for output `JSON` file.
31 | output_json_filename (:obj:`str`, optional): File name for output `JSON` file.
32 | api_name (:obj:`str`, optional): Name of API service.
33 | call_api (:obj:`bool`): If ``True``, an API for LM models will be used, determined by :obj:`api_name`.
34 | labels (:obj:`List`, optional): A list of labels for all classes.
35 | """
36 |
37 | def __init__(self,
38 | model_name: Optional[str] = 'gpt2-xl',
39 | tokenizer_name: Optional[str] = None,
40 | max_model_token_num: Optional[int] = None,
41 | model_config: Optional[PretrainedConfig] = None,
42 | batch_size: Optional[int] = 1,
43 | accelerator: Optional[Accelerator] = None,
44 | output_json_filepath: Optional[str] = "./icl_inference_output",
45 | output_json_filename: Optional[str] = "predictions",
46 | api_name: Optional[str] = None,
47 | labels: Optional[List] = None,
48 | model_parallel: Optional[bool] = False,
49 | **kwargs
50 | ) -> None:
51 | super().__init__(model_name, tokenizer_name, max_model_token_num, model_config, batch_size, accelerator,
52 | output_json_filepath, output_json_filename, api_name, model_parallel, **kwargs)
53 | self.labels = labels
54 |
55 | def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None,
56 | prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None,
57 | output_json_filename: Optional[str] = None, normalizing_str: Optional[str] = None) -> List:
58 | # 1. Preparation for output logs
59 | output_handler = PPLInferencerOutputHandler(self.accelerator)
60 |
61 | sub_predictions = []
62 | ppl = []
63 | ice = []
64 |
65 | if output_json_filepath is None:
66 | output_json_filepath = self.output_json_filepath
67 | if output_json_filename is None:
68 | output_json_filename = self.output_json_filename
69 |
70 | # 2. Get results of retrieval process
71 | ice_idx_list = retriever.retrieve()
72 |
73 | # 3. Get labels of all the classes
74 | if self.labels is None:
75 | labels = retriever.get_labels(ice_template=ice_template, prompt_template=prompt_template)
76 | else:
77 | labels = self.labels
78 |
79 | # 4. Generate in-context examples for testing inputs
80 | for idx in range(len(ice_idx_list)):
81 | ice.append(retriever.generate_ice(ice_idx_list[idx], ice_template=ice_template))
82 | output_handler.save_ice(ice)
83 |
84 | # 5. Calculating PPL for prompts in each label's class
85 | for label in labels:
86 | index = 0
87 | prompt_list = []
88 | sub_ppl_list = []
89 | normalizing_prompt_list = []
90 | context_length_list = []
91 |
92 | # 5.1 Generate prompts of current label and truncate
93 | for idx in range(len(ice_idx_list)):
94 | prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template,
95 | prompt_template=prompt_template,
96 | remain_sep=normalizing_str is not None)
97 | if self.max_model_token_num is not None and self.api_name != 'gpt3':
98 | prompt_token_num = self.get_input_token_num(prompt)
99 | while len(ice_idx_list[idx]) > 0 and prompt_token_num > self.max_model_token_num:
100 | ice_idx_list[idx] = ice_idx_list[idx][:-1]
101 | ice[idx] = retriever.generate_ice(ice_idx_list[idx], ice_template=ice_template)
102 | prompt = retriever.generate_label_prompt(idx, ice[idx], label, ice_template=ice_template,
103 | prompt_template=prompt_template)
104 | prompt_token_num = self.get_input_token_num(prompt)
105 |
106 | if normalizing_str is not None:
107 | prompt_sep = prompt
108 | if prompt_template is not None:
109 | sep_token = prompt_template.sep_token
110 | else:
111 | sep_token = ice_template.sep_token
112 | sep_pos = prompt_sep.find(sep_token)
113 |
114 | context = prompt_sep[0:sep_pos]
115 | answer = prompt_sep[sep_pos:].replace(sep_token, '')
116 | prompt = context + answer
117 | normalizing_prompt = normalizing_str + answer
118 |
119 | context_length_list.append(self.get_input_token_num(context))
120 | normalizing_prompt_list.append(normalizing_prompt)
121 | prompt_list.append(prompt)
122 |
123 | if normalizing_str is not None:
124 | normalizing_str_len = self.get_input_token_num(normalizing_str)
125 |
126 | # 5.2 Get PPL
127 | logger.info(f"Calculating PPL for prompts labeled '{label}'")
128 | for idx in trange(0, len(prompt_list), self.batch_size, disable=not self.is_main_process):
129 | sub_prompt_list = prompt_list[idx:idx + self.batch_size]
130 | if normalizing_str is not None:
131 | sub_context_length_list = context_length_list[idx:idx + self.batch_size]
132 | sub_normalizing_prompt_list = normalizing_prompt_list[idx:idx + self.batch_size]
133 |
134 | with torch.no_grad():
135 | if normalizing_str is not None:
136 | res1 = self.__get_ppl(input_texts=sub_prompt_list, mask_length=sub_context_length_list)
137 | res2 = self.__get_ppl(input_texts=sub_normalizing_prompt_list,
138 | mask_length=[normalizing_str_len for i in range(len(sub_prompt_list))]
139 | )
140 | sub_res = res1 - res2
141 | else:
142 | sub_res = self.__get_ppl(sub_prompt_list).tolist()
143 | for res, prompt in zip(sub_res, sub_prompt_list):
144 | sub_ppl_list.append(res)
145 |                     output_handler.save_prompt_and_ppl(label, prompt[len(ice[index]):], prompt, res, index)
146 | index = index + 1
147 | ppl.append(sub_ppl_list)
148 |
149 | # 6. Get lowest PPL class as predictions
150 | ppl = list(zip(*ppl))
151 | for single_ppl in ppl:
152 | sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
153 | output_handler.save_predictions(sub_predictions)
154 |
155 | # 7. Output
156 | output_handler.subprocess_write_to_json(output_json_filepath, output_json_filename)
157 | if self.accelerator is not None:
158 | self.accelerator.wait_for_everyone()
159 | output_handler.merge_to_main_process(output_json_filepath, output_json_filename)
160 | output_handler.write_to_json(output_json_filepath, output_json_filename)
161 |
162 | return [sample['prediction'] for sample in output_handler.results_dict.values()]
163 |
164 | def __get_ppl(self, input_texts: List[str], mask_length=None):
165 | if self.call_api:
166 | return api_get_ppl(self.api_name, input_texts)
167 | self.tokenizer.padding_side = "right"
168 | inputs = self.tokenizer(input_texts, padding=True, return_tensors='pt', truncation=True)
169 | inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
170 | outputs = self.model(**inputs)
171 |
172 | shift_logits = outputs.logits[..., :-1, :].contiguous()
173 | shift_labels = inputs["input_ids"][..., 1:].contiguous()
174 |
175 | loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=self.tokenizer.pad_token_id)
176 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(
177 | shift_labels.size())
178 |
179 | if mask_length is not None:
180 | mask = torch.zeros_like(shift_labels) # [batch,seqlen]
181 | for i in range(len(mask)):
182 | for j in range(mask_length[i] - 1, len(mask[i])):
183 | mask[i][j] = 1
184 | loss = loss * mask
185 |
186 | lens = (inputs["input_ids"] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
187 | if mask_length is not None:
188 |             lens -= torch.tensor(mask_length).numpy()
189 | ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
190 | return ce_loss
191 |
--------------------------------------------------------------------------------
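A minimal classification sketch for the PPL inferencer above (not part of the source tree). The dataset 'gpt3mix/sst2' and its 'text'/'label' columns are assumptions for illustration; the inferencer scores one template per label and returns the lowest-perplexity label for each test item.

from datasets import load_dataset
from openicl import DatasetReader, PromptTemplate, TopkRetriever, PPLInferencer

data = DatasetReader(load_dataset('gpt3mix/sst2'), input_columns=['text'], output_column='label')

# One template per label; '</E>' marks where in-context examples are inserted.
tp_dict = {0: '</E>Positive Movie Review: </text>',
           1: '</E>Negative Movie Review: </text>'}
template = PromptTemplate(tp_dict, column_token_map={'text': '</text>'}, ice_token='</E>')

retriever = TopkRetriever(data, ice_num=8)
inferencer = PPLInferencer(model_name='distilgpt2')
predictions = inferencer.inference(retriever, ice_template=template)  # lowest-PPL label per test item
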
/openicl/icl_prompt_template.py:
--------------------------------------------------------------------------------
1 | """Prompt Template"""
2 |
3 | from typing import Dict, Optional, Union, Hashable
4 | from .utils.check_type import _check_type_list, _check_dict
5 |
6 |
7 | class PromptTemplate:
8 | """In-context Learning Prompt Template Class
9 | This class represents a template that guides the generation of prompts in the retrieval or inference process.
10 |
11 | Attributes:
12 |         template (:obj:`Dict` or :obj:`str`): A custom template dictionary or string. If a dictionary, its keys are the possible values of the output column (labels) and its values are the corresponding template strings. If a string, it is used as a single template for all data.
13 | column_token_map (:obj:`Dict`): A dictionary mapping column names to specific tokens. The tokens will be replaced by data in the corresponding column (one piece each time) during the retrieval or inference process.
14 | selected_column_name (:obj:`str`, optional): Used only with string-type templates. A specific column that needs its value to be mapped.
15 | selected_column_map (:obj:`Dict`, optional): Used only with string-type templates. Maps the value of the column :obj:`selected_column_name`.
16 |         ice_token (:obj:`str`, optional): A token marking where in-context examples are inserted. Leave it as None if the template is only used to generate in-context examples; otherwise the template can also generate the final prompt fed into the PLM. The ice_token is removed when generating in-context examples.
17 | """
18 |
19 | def __init__(self,
20 | template: Union[Dict, str],
21 | column_token_map: Dict,
22 | selected_column_name: Optional[str] = None,
23 | selected_column_map: Optional[Dict] = None,
24 | ice_token: Optional[str] = None,
25 | sep_token: Optional[str] = None,
26 | ) -> None:
27 | self.template = _check_type_list(template, [Dict, str])
28 | self.column_token_map = _check_dict(column_token_map)
29 | self.selected_column_name = _check_type_list(selected_column_name, [None, str])
30 | self.selected_column_map = _check_type_list(selected_column_map, [None, Dict])
31 | self.ice_token = _check_type_list(ice_token, [None, str])
32 | self.sep_token = _check_type_list(sep_token, [None, str])
33 | if (self.selected_column_name is not None and self.selected_column_map is None) or \
34 | self.selected_column_name is None and self.selected_column_map is not None:
35 | raise ValueError("self.selected_column_name and self.selected_column_map should be set together")
36 | self._check_template_legacy()
37 |
38 | def _check_template_legacy(self):
39 | if isinstance(self.template, Dict):
40 | # Check if token exists in values of tp_dict
41 | for tp_dict_val in self.template.values():
42 | if not isinstance(tp_dict_val, str):
43 | raise TypeError(f"dictionary of template expects a str value, but got '{tp_dict_val}'")
44 | if self.ice_token is not None and self.ice_token not in tp_dict_val:
45 | raise LookupError(f"'{self.ice_token}' not in '{tp_dict_val}'")
46 | if isinstance(self.template, str):
47 | if self.ice_token is not None and self.ice_token not in self.template:
48 | raise LookupError(f"'{self.ice_token}' not in '{self.template}'")
49 |
50 | # Check duplicates
51 | if len(self.column_token_map.values()) != len(set(self.column_token_map.values())):
52 | raise ValueError(f"There are duplicates in self.column_token_map.values()")
53 | if self.ice_token is not None and self.ice_token in self.column_token_map.values():
54 | raise ValueError(f"There are duplicates between self.column_token_map.values() and self.ice_token")
55 |
56 | def generate_ice_item(self, entry: Dict, label: Hashable) -> str:
57 | """Generate in-context example based on the provided :obj:`entry` data.
58 |
59 | Args:
60 | entry (:obj:`Dict`): A piece of data to be used for generating the in-context example.
61 | label (:obj:`Hashable`): The value of the output field.
62 |
63 | Returns:
64 | :obj:`str`: The generated in-context example.
65 | """
66 | # Select the corresponding template
67 | tp = self.template[label] if isinstance(self.template, Dict) else self.template
68 | # Remove sep token
69 | if self.sep_token is not None:
70 |             tp = tp.replace(self.sep_token, '')
71 | # Remove ice_token
72 | if self.ice_token is not None:
73 | tp = tp.replace(self.ice_token, '')
74 | # Replace context token
75 | for key, token in self.column_token_map.items():
76 | if self.selected_column_map is not None and key == self.selected_column_name:
77 | tp = tp.replace(token, str(self.selected_column_map[label]))
78 | else:
79 | tp = tp.replace(token, str(entry[key]))
80 | return tp
81 |
82 | def generate_label_prompt_item(self, entry: Dict, ice: str, label: Hashable, remain_sep: Optional[bool] = False) -> str:
83 | """Generate prompt based on :obj:`entry` data, :obj:`ice` in-context example, and the corresponding :obj:`label`.
84 |
85 | Args:
86 |
87 | entry (:obj:`Dict`): A piece of data containing the input field content.
88 | ice (:obj:`str`): The generated in-context example.
89 | label (:obj:`Hashable`): The value of the output field.
90 |             remain_sep (:obj:`bool`): Whether to keep the :obj:`sep_token` in the generated prompt.
91 |
92 | Raises:
93 | ValueError: If the :obj:`ice_token` attribute of the :obj:`PromptTemplate` instance is :obj:`None`.
94 |
95 | Returns:
96 | :obj:`str`: The generated prompt.
97 | """
98 | if self.ice_token is None:
99 |             raise ValueError("PromptTemplate.ice_token must not be None when generating a prompt")
100 | # Select the corresponding template
101 | tp = self.template[label] if isinstance(self.template, Dict) else self.template
102 | # Remove sep token
103 | if not remain_sep and self.sep_token is not None:
104 |             tp = tp.replace(self.sep_token, '')
105 | # Insert in-context examples
106 | tp = tp.replace(self.ice_token, ice)
107 | # Replace context token
108 | for key, token in self.column_token_map.items():
109 | if self.selected_column_map is not None and key == self.selected_column_name:
110 | tp = tp.replace(token, str(self.selected_column_map[label]))
111 | else:
112 | tp = tp.replace(token, str(entry[key]))
113 | return tp
114 |
115 |
116 | def generate_item(self, entry: Dict, output_field: Optional[Hashable] = None,
117 | output_field_replace_token: Optional[str] = '',
118 | ice_field_replace_token: Optional[str] = '') -> str:
119 | """Generate an item based on the provided :obj:`entry` data, as well as optional output field and ice field tokens.
120 |
121 | Args:
122 | entry (:obj:`Dict`): A piece of data.
123 | output_field (:obj:`Hashable`, optional): Column name of output field. Defaults to :obj:`None`.
124 | output_field_replace_token (:obj:`str`, optional): Tokens used to replace output field. Defaults to ``''``.
125 | ice_field_replace_token (str, optional): Tokens used to replace the :obj:`ice_token`. Defaults to ``''``.
126 |
127 | Returns:
128 | :obj:`str`: The generated item.
129 | """
130 | tp = None
131 | if isinstance(self.template, str):
132 | tp = self.template
133 | else:
134 | pred_label = None
135 | if self.selected_column_name is not None:
136 | pred_label = entry[self.selected_column_name]
137 | if pred_label in self.template.keys():
138 | tp = self.template[pred_label]
139 | else:
140 | tp = self.template[list(self.template.keys())[0]]
141 | if self.ice_token is not None:
142 | tp = tp.replace(self.ice_token, ice_field_replace_token)
143 | # Remove sep token
144 | if self.sep_token is not None:
145 |             tp = tp.replace(self.sep_token, '')
146 | for key, token in self.column_token_map.items():
147 | if output_field is not None and key == output_field:
148 | tp = tp.replace(token, output_field_replace_token)
149 | else:
150 | tp = tp.replace(token, str(entry[key]))
151 | return tp
152 |
153 | def _check_prompt_template(obj) -> "PromptTemplate":
154 | if isinstance(obj, PromptTemplate):
155 | return obj
156 | else:
157 | raise TypeError(f"Expect a PromptTemplate object, but got {obj}")
158 |
159 | def __repr__(self):
160 | return f"PromptTemplate({{\n\ttemplate: {self.template},\n\tcolumn_token_map: {self.column_token_map},\n\tice_token: {self.ice_token}\n}})"
161 |
--------------------------------------------------------------------------------
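A small sketch of the two prompt-building paths above (not part of the source tree); the review texts and labels are made up. generate_ice_item drops the ice_token and fills the column tokens, while generate_label_prompt_item inserts previously built in-context examples at the ice_token position.

from openicl import PromptTemplate

tp_dict = {0: '</E>Review: </text>\nSentiment: negative',
           1: '</E>Review: </text>\nSentiment: positive'}
template = PromptTemplate(tp_dict, column_token_map={'text': '</text>'}, ice_token='</E>')

# In-context example: '</E>' is stripped and '</text>' is filled from the entry.
ice = template.generate_ice_item({'text': 'a gripping, well-acted film'}, label=1)
# -> 'Review: a gripping, well-acted film\nSentiment: positive'

# Label prompt: the in-context example is inserted where '</E>' appears.
prompt = template.generate_label_prompt_item({'text': 'dull and lifeless'}, ice=ice + '\n', label=0)
# -> '<ice>\nReview: dull and lifeless\nSentiment: negative'
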
/openicl/icl_retriever/__init__.py:
--------------------------------------------------------------------------------
1 | from .icl_base_retriever import BaseRetriever
2 | from .icl_random_retriever import RandomRetriever
3 | from .icl_bm25_retriever import BM25Retriever
4 | from .icl_dpp_retriever import DPPRetriever
5 | from .icl_topk_retriever import TopkRetriever
6 | from .icl_mdl_retriever import MDLRetriever
7 | from .icl_votek_retriever import VotekRetriever
8 | from .icl_zero_retriever import ZeroRetriever
9 |
--------------------------------------------------------------------------------
/openicl/icl_retriever/icl_base_retriever.py:
--------------------------------------------------------------------------------
1 | """Basic Retriever"""
2 |
3 | from datasets import Dataset, DatasetDict
4 | from typing import List, Union, Optional, Tuple, Dict
5 | from openicl import DatasetReader, PromptTemplate
6 | from openicl.utils.check_type import _check_str
7 | from accelerate import Accelerator
8 |
9 |
10 | class BaseRetriever:
11 | """Basic In-context Learning Retriever Class
12 | Base class for In-context Learning Retriever, without any retrieval method.
13 |
14 | Attributes:
15 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
16 | ice_separator (:obj:`str`, optional): A string that separates each in-context example.
17 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
18 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
19 | ice_num (:obj:`int`, optional): The number of data in the in-context examples.
20 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
21 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
22 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
23 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
24 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
25 | """
26 | index_ds = None
27 | test_ds = None
28 |
29 | def __init__(self,
30 | dataset_reader: DatasetReader,
31 | ice_separator: Optional[str] = '\n',
32 | ice_eos_token: Optional[str] = '\n',
33 | prompt_eos_token: Optional[str] = '',
34 | ice_num: Optional[int] = 1,
35 | index_split: Optional[str] = 'train',
36 | test_split: Optional[str] = 'test',
37 | accelerator: Optional[Accelerator] = None
38 | ) -> None:
39 | self.dataset_reader = DatasetReader._check_dataset_reader(dataset_reader)
40 | self.ice_separator = ice_separator
41 | self.ice_eos_token = ice_eos_token
42 | self.prompt_eos_token = prompt_eos_token
43 | self.ice_num = ice_num
44 | self.index_split = index_split
45 | self.test_split = test_split
46 | self.accelerator = accelerator
47 |         self.is_main_process = self.accelerator is None or self.accelerator.is_main_process
48 | if isinstance(self.dataset_reader.dataset, Dataset):
49 | self.index_ds = self.dataset_reader.dataset
50 | self.test_ds = self.dataset_reader.dataset
51 | if self.accelerator is not None:
52 | self.test_ds = self.test_ds.shard(
53 | num_shards=self.accelerator.num_processes,
54 | index=self.accelerator.process_index
55 | )
56 | else:
57 | self.index_ds = self.dataset_reader.dataset[self.index_split]
58 | self.test_ds = self.dataset_reader.dataset[self.test_split]
59 |
60 | if self.accelerator is not None:
61 | self.test_ds = self.test_ds.shard(
62 | num_shards=self.accelerator.num_processes,
63 | index=self.accelerator.process_index
64 | )
65 |
66 | def retrieve(self) -> List[List]:
67 | """
68 |         Retrieve in-context example indices for each piece of data in `test_ds`.
69 | 
70 |         Returns:
71 |             `List[List]`: a list of in-context example index lists, one for each piece of data in `test_ds`.
72 | """
73 | raise NotImplementedError("Method hasn't been implemented yet")
74 |
75 | def get_labels(self, ice_template: Optional[PromptTemplate] = None,
76 | prompt_template: Optional[PromptTemplate] = None):
77 | labels = []
78 | if prompt_template is not None and isinstance(prompt_template.template, Dict):
79 | labels = list(prompt_template.template.keys())[:]
80 | elif ice_template is not None and ice_template.ice_token is not None and isinstance(ice_template.template,
81 | Dict):
82 | labels = list(ice_template.template.keys())[:]
83 | else:
84 | labels = list(set(self.test_ds[self.dataset_reader.output_column]))
85 | return labels
86 |
87 | def generate_ice(self, idx_list: List[int], ice_template: Optional[PromptTemplate] = None) -> str:
88 | generated_ice_list = []
89 | dr = self.dataset_reader
90 | for idx in idx_list:
91 | if ice_template is None:
92 | generated_ice_list.append(' '.join(list(map(str,
93 | [self.index_ds[idx][ctx] for ctx in dr.input_columns] + [
94 | self.index_ds[idx][dr.output_column]]))))
95 | else:
96 | generated_ice_list.append(
97 | ice_template.generate_ice_item(self.index_ds[idx], self.index_ds[idx][dr.output_column]))
98 | generated_ice = self.ice_separator.join(generated_ice_list) + self.ice_eos_token
99 | return generated_ice
100 |
101 | def generate_prompt(self, idx: int, ice: str, ice_template: Optional[PromptTemplate] = None,
102 | prompt_template: Optional[PromptTemplate] = None) -> Tuple[List[str], List]:
103 | prompt_list = []
104 | labels = []
105 | if prompt_template is not None and isinstance(prompt_template.template, Dict):
106 | labels = list(prompt_template.template.keys())[:]
107 | elif ice_template is not None and isinstance(ice_template.template,
108 | Dict) and ice_template.ice_token is not None:
109 | labels = list(ice_template.template.keys())[:]
110 | else:
111 | labels = list(set(self.test_ds[self.dataset_reader.output_column]))
112 | for label in labels:
113 |             prompt_list.append(self.generate_label_prompt(idx, ice, label, ice_template=ice_template, prompt_template=prompt_template))
114 | return prompt_list, labels
115 |
116 | def generate_label_prompt(self, idx: int, ice: str, label, ice_template: Optional[PromptTemplate] = None,
117 | prompt_template: Optional[PromptTemplate] = None, remain_sep: Optional[bool] = False) -> str:
118 | if prompt_template is not None:
119 | return prompt_template.generate_label_prompt_item(self.test_ds[idx], ice, label, remain_sep) + self.prompt_eos_token
120 | elif ice_template is not None and ice_template.ice_token is not None:
121 | return ice_template.generate_label_prompt_item(self.test_ds[idx], ice, label, remain_sep) + self.prompt_eos_token
122 | else:
123 | prefix_prompt = ' '.join(
124 | list(map(str, [self.test_ds[idx][ctx] for ctx in self.dataset_reader.input_columns])))
125 | return ice + prefix_prompt + ' ' + str(label) + self.prompt_eos_token
126 |
127 | def generate_prompt_for_generate_task(self, idx, ice, gen_field_replace_token='',
128 | ice_template: Optional[PromptTemplate] = None,
129 | prompt_template: Optional[PromptTemplate] = None):
130 | if prompt_template is not None:
131 | return prompt_template.generate_item(self.test_ds[idx], output_field=self.dataset_reader.output_column,
132 | output_field_replace_token=gen_field_replace_token,
133 | ice_field_replace_token=ice) + self.prompt_eos_token
134 | elif ice_template is not None and ice_template.ice_token is not None:
135 | return ice_template.generate_item(self.test_ds[idx], output_field=self.dataset_reader.output_column,
136 | output_field_replace_token=gen_field_replace_token,
137 | ice_field_replace_token=ice) + self.prompt_eos_token
138 | else:
139 | prefix_prompt = ' '.join(
140 | list(map(str, [self.test_ds[idx][ctx] for ctx in self.dataset_reader.input_columns])))
141 | return ice + prefix_prompt + gen_field_replace_token + self.prompt_eos_token
142 |
--------------------------------------------------------------------------------
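A toy subclass illustrating the retrieve() contract above (a sketch, not part of the library): retrieve() must return one list of index_ds indices per test example.

from typing import List
from openicl.icl_retriever import BaseRetriever

class FirstKRetriever(BaseRetriever):
    """Toy retriever that always picks the first `ice_num` examples of the index split."""

    def retrieve(self) -> List[List]:
        idxs = list(range(min(self.ice_num, len(self.index_ds))))
        return [idxs[:] for _ in range(len(self.test_ds))]
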
/openicl/icl_retriever/icl_bm25_retriever.py:
--------------------------------------------------------------------------------
1 | """BM25 Retriever"""
2 |
3 | from openicl import DatasetReader
4 | from openicl.icl_retriever import BaseRetriever
5 | from openicl.utils.logging import get_logger
6 | from typing import List, Union, Optional
7 | from rank_bm25 import BM25Okapi
8 | import numpy as np
9 | from tqdm import trange
10 | from accelerate import Accelerator
11 | from nltk.tokenize import word_tokenize
12 |
13 | logger = get_logger(__name__)
14 |
15 |
16 | class BM25Retriever(BaseRetriever):
17 | """BM25 In-context Learning Retriever Class
18 | Class of BM25 Retriever.
19 |
20 | Attributes:
21 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
22 | ice_separator (:obj:`str`, optional): A string that separates each in-context example.
23 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
24 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
25 | ice_num (:obj:`int`, optional): The number of data in the in-context examples.
26 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
27 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
28 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
29 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
30 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
31 | index_corpus (:obj:`List[str]`) : A corpus created from the input field data of :obj:`index_ds`.
32 | test_corpus (:obj:`List[str]`) : A corpus created from the input field data of :obj:`test_ds`.
33 |         bm25 (:obj:`BM25Okapi`): An instance of the :obj:`BM25Okapi` class, initialized using :obj:`index_ds`.
34 | """
35 | bm25 = None
36 | index_corpus = None
37 | test_corpus = None
38 |
39 | def __init__(self,
40 | dataset_reader: DatasetReader,
41 | ice_separator: Optional[str] = '\n',
42 | ice_eos_token: Optional[str] = '\n',
43 | prompt_eos_token: Optional[str] = '',
44 | ice_num: Optional[int] = 1,
45 | index_split: Optional[str] = 'train',
46 | test_split: Optional[str] = 'test',
47 | accelerator: Optional[Accelerator] = None
48 | ) -> None:
49 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token, ice_num, index_split,
50 | test_split, accelerator)
51 | self.index_corpus = [word_tokenize(data) for data in
52 | self.dataset_reader.generate_input_field_corpus(self.index_ds)]
53 | self.bm25 = BM25Okapi(self.index_corpus)
54 | self.test_corpus = [word_tokenize(data) for data in
55 | self.dataset_reader.generate_input_field_corpus(self.test_ds)]
56 |
57 | def retrieve(self) -> List[List]:
58 | rtr_idx_list = []
59 | logger.info("Retrieving data for test set...")
60 | for idx in trange(len(self.test_corpus), disable=not self.is_main_process):
61 | query = self.test_corpus[idx]
62 | scores = self.bm25.get_scores(query)
63 | near_ids = list(np.argsort(scores)[::-1][:self.ice_num])
64 | near_ids = [int(a) for a in near_ids]
65 | rtr_idx_list.append(near_ids)
66 | return rtr_idx_list
67 |
--------------------------------------------------------------------------------
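A short usage sketch (assuming the `data` DatasetReader from the earlier sketches). Because BM25Retriever tokenizes with nltk's word_tokenize, the NLTK 'punkt' resource usually has to be downloaded first (newer NLTK versions may additionally require 'punkt_tab').

import nltk
nltk.download('punkt')  # tokenizer models used by word_tokenize

from openicl import BM25Retriever

retriever = BM25Retriever(data, ice_num=4)
rtr_idx_list = retriever.retrieve()   # one list of index_ds indices per test example
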
/openicl/icl_retriever/icl_dpp_retriever.py:
--------------------------------------------------------------------------------
1 | """DPP Retriever"""
2 |
3 | from openicl import DatasetReader
4 | from openicl.icl_retriever.icl_topk_retriever import TopkRetriever
5 | from openicl.utils.logging import get_logger
6 | from typing import Optional
7 | import tqdm
8 | import numpy as np
9 | import math
10 | from accelerate import Accelerator
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | class DPPRetriever(TopkRetriever):
16 | """DPP In-context Learning Retriever Class
17 | Class of DPP Retriever.
18 |     Two-stage DPP is used: the first stage retrieves TopK results to reduce the candidate set.
19 |     See https://arxiv.org/abs/2302.05698 for details.
20 |
21 | Attributes:
22 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
23 | ice_separator (:obj:`str`, optional): A string that separates each in-context example.
24 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
25 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
26 | ice_num (:obj:`int`, optional): The number of data in the in-context examples.
27 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
28 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
29 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
30 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
31 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
32 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
33 | model (:obj:`SentenceTransformer`): An instance of :obj:`SentenceTransformer` class, used to calculate embeddings.
34 | tokenizer (:obj:`AutoTokenizer`): Tokenizer for :obj:`model`.
35 | index (:obj:`IndexIDMap`): Index generated with FAISS.
36 | seed (:obj:`int`, optional): Seed for the random number generator. (:obj:`random_state` in :obj:`sample_exact_k_dpp` method)
37 |         scale_factor (:obj:`float`, optional): A scaling factor used when computing the kernel matrix.
38 | """
39 | model = None
40 |
41 | def __init__(self,
42 | dataset_reader: DatasetReader,
43 | ice_separator: Optional[str] = '\n',
44 | ice_eos_token: Optional[str] = '\n',
45 | prompt_eos_token: Optional[str] = '',
46 | sentence_transformers_model_name: Optional[str] = 'all-mpnet-base-v2',
47 | ice_num: Optional[int] = 1,
48 | candidate_num: Optional[int] = 1,
49 | index_split: Optional[str] = 'train',
50 | test_split: Optional[str] = 'test',
51 | tokenizer_name: Optional[str] = 'gpt2-xl',
52 | batch_size: Optional[int] = 1,
53 | accelerator: Optional[Accelerator] = None,
54 | seed: Optional[int] = 1,
55 | scale_factor: Optional[float] = 0.1
56 | ) -> None:
57 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token,
58 | sentence_transformers_model_name, ice_num, index_split, test_split, tokenizer_name, batch_size,
59 | accelerator)
60 | self.candidate_num = candidate_num
61 | self.seed = seed
62 | self.scale_factor = scale_factor
63 |
64 | def dpp_search(self):
65 | res_list = self.forward(self.dataloader, process_bar=True, information="Embedding test set...")
66 | rtr_idx_list = [[] for _ in range(len(res_list))]
67 | logger.info("Retrieving data for test set...")
68 | for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
69 | idx = entry['metadata']['id']
70 |
71 | # get TopK results
72 | embed = np.expand_dims(entry['embed'], axis=0)
73 | near_ids = np.array(self.index.search(embed, self.candidate_num)[1][0].tolist())
74 |
75 | # DPP stage
76 | near_reps, rel_scores, kernel_matrix = self.get_kernel(embed, near_ids.tolist())
77 |
78 | # MAP inference
79 | samples_ids = fast_map_dpp(kernel_matrix, self.ice_num)
80 |
81 | # ordered by relevance score
82 | samples_scores = np.array([rel_scores[i] for i in samples_ids])
83 | samples_ids = samples_ids[(-samples_scores).argsort()].tolist()
84 | rtr_sub_list = [int(near_ids[i]) for i in samples_ids]
85 |
86 | rtr_idx_list[idx] = rtr_sub_list
87 |
88 | return rtr_idx_list
89 |
90 | def retrieve(self):
91 | return self.dpp_search()
92 |
93 | def get_kernel(self, embed, candidates):
94 | near_reps = np.stack([self.index.index.reconstruct(i) for i in candidates], axis=0)
95 | # normalize first
96 | embed = embed / np.linalg.norm(embed)
97 | near_reps = near_reps / np.linalg.norm(near_reps, keepdims=True, axis=1)
98 |
99 | # to make kernel-matrix non-negative
100 | rel_scores = np.matmul(embed, near_reps.T)[0]
101 | rel_scores = (rel_scores + 1) / 2
102 |
103 | # to prevent overflow error
104 | rel_scores -= rel_scores.max()
105 |
106 | # to balance relevance and diversity
107 | rel_scores = np.exp(rel_scores / (2 * self.scale_factor))
108 |
109 | # to make kernel-matrix non-negative
110 | sim_matrix = np.matmul(near_reps, near_reps.T)
111 | sim_matrix = (sim_matrix + 1) / 2
112 |
113 | kernel_matrix = rel_scores[None] * sim_matrix * rel_scores[:, None]
114 | return near_reps, rel_scores, kernel_matrix
115 |
116 |
117 | def fast_map_dpp(kernel_matrix, max_length):
118 | """
119 | fast implementation of the greedy algorithm
120 | reference: https://github.com/laming-chen/fast-map-dpp/blob/master/dpp_test.py
121 | paper: Fast Greedy MAP Inference for Determinantal Point Process to Improve Recommendation Diversity
122 | """
123 | item_size = kernel_matrix.shape[0]
124 | cis = np.zeros((max_length, item_size))
125 | di2s = np.copy(np.diag(kernel_matrix))
126 | selected_items = list()
127 | selected_item = np.argmax(di2s)
128 | selected_items.append(int(selected_item))
129 | while len(selected_items) < max_length:
130 | k = len(selected_items) - 1
131 | ci_optimal = cis[:k, selected_item]
132 | di_optimal = math.sqrt(di2s[selected_item])
133 | elements = kernel_matrix[selected_item, :]
134 | eis = (elements - np.dot(ci_optimal, cis[:k, :])) / di_optimal
135 | cis[k, :] = eis
136 | di2s -= np.square(eis)
137 | selected_item = np.argmax(di2s)
138 | selected_items.append(int(selected_item))
139 | return selected_items
140 |
--------------------------------------------------------------------------------
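A self-contained toy run of fast_map_dpp above, showing how greedy MAP inference trades relevance against diversity (the embeddings are made up):

import numpy as np
from openicl.icl_retriever.icl_dpp_retriever import fast_map_dpp

# Items 0 and 1 are nearly identical, item 2 is orthogonal to both.
emb = np.array([[1.0, 0.0], [0.99, 0.14], [0.0, 1.0]])
emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
kernel = emb @ emb.T                       # similarity kernel with unit diagonal
print(fast_map_dpp(kernel, max_length=2))  # [0, 2]: the redundant near-duplicate is skipped
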
/openicl/icl_retriever/icl_mdl_retriever.py:
--------------------------------------------------------------------------------
1 | """MDL Retriever"""
2 |
3 | from openicl import DatasetReader, PromptTemplate
4 | from openicl.icl_retriever.icl_topk_retriever import TopkRetriever
5 | from openicl.utils.calculate import entropy
6 | from openicl.utils.logging import get_logger
7 | from typing import List, Union, Optional, Tuple
8 | from transformers import AutoModelForCausalLM
9 | import tqdm
10 | import torch
11 | import numpy as np
12 | from accelerate import Accelerator
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | class MDLRetriever(TopkRetriever):
18 | """MDL In-context Learning Retriever Class
19 | Class of MDL Retriever.
20 |
21 | Attributes:
22 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
23 | ice_separator (:obj:`str`, optional): A string that separates each in-context example.
24 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
25 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
26 | ice_num (:obj:`int`, optional): The number of data in the in-context examples.
27 | candidate_num (:obj:`int`, optional): The number of data selected in TopK stage.
28 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
29 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
30 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
31 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
32 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
33 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
34 | model (:obj:`SentenceTransformer`): An instance of :obj:`SentenceTransformer` class, used to calculate embeddings.
35 | tokenizer (:obj:`AutoTokenizer`): Tokenizer for :obj:`model`.
36 | index (:obj:`IndexIDMap`): Index generated with FAISS.
37 | select_time (:obj:`int`, optional): Number of random selections in the MDL stage.
38 | labels (:obj:`List`, optional): A list of labels for all classes used to generate prompts when calculating MDL.
39 | seed (:obj:`int`, optional): Seed for the random number generator.
40 | """
41 | metric_model = None
42 |
43 | def __init__(self,
44 | dataset_reader: DatasetReader,
45 | ice_separator: Optional[str] = '\n',
46 | ice_eos_token: Optional[str] = '\n',
47 | prompt_eos_token: Optional[str] = '',
48 | sentence_transformers_model_name: Optional[str] = 'all-mpnet-base-v2',
49 | ice_num: Optional[int] = 1,
50 | candidate_num: Optional[int] = 1,
51 | index_split: Optional[str] = 'train',
52 | test_split: Optional[str] = 'test',
53 | tokenizer_name: Optional[str] = 'gpt2-xl',
54 | ce_model_name: Optional[str] = 'gpt2-xl',
55 | batch_size: Optional[int] = 1,
56 | select_time: Optional[int] = 5,
57 | accelerator: Optional[Accelerator] = None,
58 | ice_template: Optional[PromptTemplate] = None,
59 | prompt_template: Optional[PromptTemplate] = None,
60 | labels: Optional[List] = None,
61 | seed: Optional[int] = 1
62 | ) -> None:
63 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token,
64 | sentence_transformers_model_name, ice_num, index_split, test_split, tokenizer_name, batch_size,
65 | accelerator)
66 | self.ce_model_name = ce_model_name
67 | self.candidate_num = candidate_num
68 | self.select_time = select_time
69 | self.ice_template = ice_template
70 | self.prompt_template = prompt_template
71 | self.labels = labels
72 | self.seed = seed
73 |
74 | def topk_search(self):
75 | np.random.seed(self.seed)
76 | res_list = self.forward(self.dataloader)
77 | rtr_idx_list = [[] for _ in range(len(res_list))]
78 |
79 | logger.info("Retrieving data for test set...")
80 | for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
81 | idx = entry['metadata']['id']
82 |
83 | embed = np.expand_dims(entry['embed'], axis=0)
84 | near_ids = self.index.search(embed, min(self.candidate_num, len(self.index_ds)))[1][0].tolist()
85 | candidates = []
86 | mdl_scores = []
87 | for j in range(self.select_time):
88 | if j == 0:
89 | rand_idx_list = near_ids[:self.ice_num]
90 | else:
91 | rand_idx_list = np.random.choice(near_ids, self.ice_num, replace=False)
92 | rand_idx_list = [int(i) for i in rand_idx_list]
93 | candidates.append(rand_idx_list)
94 |
95 | ice = self.generate_ice(rand_idx_list, ice_template=self.ice_template)
96 | mask_length = len(self.tokenizer(ice + self.ice_eos_token, verbose=False)['input_ids'])
97 | if self.labels is None:
98 | labels = self.get_labels(self.ice_template, self.prompt_template)
99 | else:
100 | labels = self.labels
101 | prompt_list = []
102 | for label in labels:
103 | prompt = self.generate_label_prompt(idx, ice, label, self.ice_template, self.prompt_template)
104 | prompt_list.append(prompt)
105 | loss_list = self.cal_ce(prompt_list, mask_length=mask_length)
106 | probs = np.exp(-np.array(loss_list))
107 | normalized_probs = probs / probs.sum(0, keepdims=True)
108 | neg_entropy = -entropy(normalized_probs, label_dim=0)
109 | mdl_scores.append(neg_entropy)
110 |
111 | rtr_idx_list[idx] = candidates[mdl_scores.index(max(mdl_scores))]
112 | rtr_idx_list[idx] = [int(i) for i in rtr_idx_list[idx]]
113 |
114 | return rtr_idx_list
115 |
116 | def retrieve(self):
117 | return self.topk_search()
118 |
119 | def cal_ce(self, input_texts: List[str], mask_length=None):
120 | if self.metric_model is None:
121 |             logger.info(f'Loading model {self.ce_model_name} for calculating MDL...')
122 | self.metric_model = AutoModelForCausalLM.from_pretrained(self.ce_model_name)
123 | self.metric_model.to(self.device)
124 | inputs = self.tokenizer(input_texts, padding=True, return_tensors='pt', truncation=True)
125 | inputs = {k: v.to(self.device) for k, v in inputs.items()}
126 | outputs = self.metric_model(**inputs)
127 |
128 | shift_logits = outputs.logits[..., :-1, :].contiguous()
129 | shift_labels = inputs["input_ids"][..., 1:].contiguous()
130 |
131 | loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=self.tokenizer.pad_token_id)
132 | shift_logits = shift_logits.view(-1, shift_logits.size(-1))
133 | loss = loss_fct(shift_logits, shift_labels.view(-1)).view(shift_labels.size())
134 | if mask_length is not None:
135 | mask = torch.cat([torch.zeros([loss.shape[0], mask_length], dtype=torch.float),
136 | torch.ones([loss.shape[0], loss.shape[-1] - mask_length], dtype=torch.float)], -1)
137 | mask = mask.to(self.device)
138 | loss = torch.mul(mask, loss)
139 |
140 | lens = (inputs["input_ids"] != self.tokenizer.pad_token_id).sum(-1).cpu().numpy()
141 | if mask_length is not None:
142 | lens -= mask_length
143 | ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
144 | return ce_loss
145 |
--------------------------------------------------------------------------------
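A construction sketch for the MDL retriever above (assuming the `data` DatasetReader and the dict-style `template` from the earlier sketches). TopK first narrows the pool to candidate_num examples; then select_time random subsets of size ice_num are scored with the ce_model, and the subset whose prompt yields the lowest label entropy is kept.

from openicl.icl_retriever import MDLRetriever

retriever = MDLRetriever(data, ice_num=4, candidate_num=30, select_time=10,
                         ice_template=template, ce_model_name='gpt2',
                         labels=[0, 1], seed=1)
rtr_idx_list = retriever.retrieve()
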
/openicl/icl_retriever/icl_random_retriever.py:
--------------------------------------------------------------------------------
1 | """Random Retriever"""
2 |
3 | from openicl import DatasetReader
4 | from openicl.icl_retriever import BaseRetriever
5 | from openicl.utils.logging import get_logger
6 | from typing import List, Union, Optional
7 | from tqdm import trange
8 | import numpy as np
9 | from accelerate import Accelerator
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | class RandomRetriever(BaseRetriever):
15 | """Random In-context Learning Retriever Class
16 | Class of Random Retriever.
17 |
18 | Attributes:
19 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
20 | ice_separator (:obj:`str`, optional): A string that separates each in-context example.
21 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
22 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
23 | ice_num (:obj:`int`, optional): The number of data in the in-context examples.
24 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
25 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
26 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
27 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
28 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
29 | seed (`int`, optional): Seed for the random number generator.
30 | """
31 |
32 | def __init__(self,
33 | dataset_reader: DatasetReader,
34 | ice_separator: Optional[str] = '\n',
35 | ice_eos_token: Optional[str] = '\n',
36 | prompt_eos_token: Optional[str] = '',
37 | ice_num: Optional[int] = 1,
38 | index_split: Optional[str] = 'train',
39 | test_split: Optional[str] = 'test',
40 | seed: Optional[int] = 43,
41 | accelerator: Optional[Accelerator] = None
42 | ) -> None:
43 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token, ice_num, index_split,
44 | test_split, accelerator)
45 | self.seed = seed
46 |
47 | def retrieve(self):
48 | np.random.seed(self.seed)
49 | num_idx = len(self.index_ds)
50 | rtr_idx_list = []
51 | logger.info("Retrieving data for test set...")
52 | for _ in trange(len(self.test_ds), disable=not self.is_main_process):
53 | idx_list = np.random.choice(num_idx, self.ice_num, replace=False).tolist()
54 | rtr_idx_list.append(idx_list)
55 | return rtr_idx_list
56 |
--------------------------------------------------------------------------------
/openicl/icl_retriever/icl_topk_retriever.py:
--------------------------------------------------------------------------------
1 | """Topk Retriever"""
2 |
3 | from openicl import DatasetReader
4 | from openicl.icl_dataset_reader import DatasetEncoder
5 | from openicl.icl_retriever import BaseRetriever
6 | from openicl.utils.collators import DataCollatorWithPaddingAndCuda
7 | from openicl.utils.logging import get_logger
8 | import torch
9 | from torch.utils.data import DataLoader
10 | from typing import Optional
11 | from transformers import AutoTokenizer
12 | from sentence_transformers import SentenceTransformer
13 | import tqdm
14 | import faiss
15 | import copy
16 | import numpy as np
17 | from accelerate import Accelerator
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | class TopkRetriever(BaseRetriever):
23 | """Topk In-context Learning Retriever Class
24 | Class of Topk Retriever.
25 |
26 | Attributes:
27 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
28 | ice_separator (:obj:`str`, optional): A string that separates each in-context example.
29 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
30 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
31 | ice_num (:obj:`int`, optional): The number of data in the in-context examples.
32 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
33 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
34 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
35 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
36 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
37 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
38 | model (:obj:`SentenceTransformer`): An instance of :obj:`SentenceTransformer` class, used to calculate embeddings.
39 | tokenizer (:obj:`AutoTokenizer`): Tokenizer for :obj:`model`.
40 | index (:obj:`IndexIDMap`): Index generated with FAISS.
41 | """
42 | model = None
43 |
44 | def __init__(self,
45 | dataset_reader: DatasetReader,
46 | ice_separator: Optional[str] = '\n',
47 | ice_eos_token: Optional[str] = '\n',
48 | prompt_eos_token: Optional[str] = '',
49 | sentence_transformers_model_name: Optional[str] = 'all-mpnet-base-v2',
50 | ice_num: Optional[int] = 1,
51 | index_split: Optional[str] = 'train',
52 | test_split: Optional[str] = 'test',
53 | tokenizer_name: Optional[str] = 'gpt2-xl',
54 | batch_size: Optional[int] = 1,
55 | accelerator: Optional[Accelerator] = None
56 | ) -> None:
57 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token, ice_num, index_split,
58 | test_split, accelerator)
59 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
60 | self.batch_size = batch_size
61 | self.tokenizer_name = tokenizer_name
62 | gen_datalist = self.dataset_reader.generate_input_field_corpus(self.test_ds)
63 |
64 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
65 | self.tokenizer.pad_token = self.tokenizer.eos_token
66 | self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
67 | self.tokenizer.padding_side = "right"
68 |
69 | self.encode_dataset = DatasetEncoder(gen_datalist, tokenizer=self.tokenizer)
70 | co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer, device=self.device)
71 | self.dataloader = DataLoader(self.encode_dataset, batch_size=self.batch_size, collate_fn=co)
72 |
73 | self.model = SentenceTransformer(sentence_transformers_model_name)
74 |
75 | self.model = self.model.to(self.device)
76 | self.model.eval()
77 |
78 | self.index = self.create_index()
79 |
80 | def create_index(self):
81 | self.select_datalist = self.dataset_reader.generate_input_field_corpus(self.index_ds)
82 | encode_datalist = DatasetEncoder(self.select_datalist, tokenizer=self.tokenizer)
83 | co = DataCollatorWithPaddingAndCuda(tokenizer=self.tokenizer, device=self.device)
84 | dataloader = DataLoader(encode_datalist, batch_size=self.batch_size, collate_fn=co)
85 | index = faiss.IndexIDMap(faiss.IndexFlatIP(self.model.get_sentence_embedding_dimension()))
86 | res_list = self.forward(dataloader, process_bar=True, information="Creating index for index set...")
87 | id_list = np.array([res['metadata']['id'] for res in res_list])
88 | self.embed_list = np.stack([res['embed'] for res in res_list])
89 | index.add_with_ids(self.embed_list, id_list)
90 | return index
91 |
92 | def knn_search(self, ice_num):
93 | res_list = self.forward(self.dataloader, process_bar=True, information="Embedding test set...")
94 | rtr_idx_list = [[] for _ in range(len(res_list))]
95 | logger.info("Retrieving data for test set...")
96 | for entry in tqdm.tqdm(res_list, disable=not self.is_main_process):
97 | idx = entry['metadata']['id']
98 | embed = np.expand_dims(entry['embed'], axis=0)
99 | near_ids = self.index.search(embed, ice_num)[1][0].tolist()
100 | rtr_idx_list[idx] = near_ids
101 | return rtr_idx_list
102 |
103 | def forward(self, dataloader, process_bar=False, information=''):
104 | res_list = []
105 | _dataloader = copy.deepcopy(dataloader)
106 | if process_bar:
107 | logger.info(information)
108 | _dataloader = tqdm.tqdm(_dataloader, disable=not self.is_main_process)
109 | for _, entry in enumerate(_dataloader):
110 | with torch.no_grad():
111 | metadata = entry.pop("metadata")
112 | raw_text = self.tokenizer.batch_decode(entry['input_ids'], skip_special_tokens=True, verbose=False)
113 | res = self.model.encode(raw_text, show_progress_bar=False)
114 | res_list.extend([{"embed": r, "metadata": m} for r, m in zip(res, metadata)])
115 | return res_list
116 |
117 | def retrieve(self):
118 | return self.knn_search(self.ice_num)
119 |
--------------------------------------------------------------------------------
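A standalone sketch of the FAISS pattern used by create_index and knn_search above (the dimension and random vectors are made up; 768 matches many sentence-transformers models):

import faiss
import numpy as np

d = 768
index = faiss.IndexIDMap(faiss.IndexFlatIP(d))            # inner-product index keyed by explicit ids
embed_list = np.random.rand(100, d).astype('float32')
index.add_with_ids(embed_list, np.arange(100))
scores, ids = index.search(embed_list[:1], 5)             # 5 nearest neighbours of item 0
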
/openicl/icl_retriever/icl_votek_retriever.py:
--------------------------------------------------------------------------------
1 | """Votek Retriever"""
2 |
3 | import os
4 | import json
5 | from openicl import DatasetReader
6 | from openicl.icl_retriever.icl_topk_retriever import TopkRetriever
7 | from typing import List, Union, Optional, Tuple
8 | from sklearn.metrics.pairwise import cosine_similarity
9 | from collections import defaultdict
10 | import numpy as np
11 | import random
12 | from accelerate import Accelerator
13 |
14 |
15 | class VotekRetriever(TopkRetriever):
16 | """Vote-k In-context Learning Retriever Class
17 | Class of Vote-k Retriever.
18 |
19 | Attributes:
20 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
21 | ice_separator (:obj:`str`, optional): A string that separates each in-context example.
22 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
23 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
24 | ice_num (:obj:`int`, optional): The number of data in the in-context examples.
25 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
26 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
27 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
28 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
29 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
30 | batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`.
31 | model (:obj:`SentenceTransformer`): An instance of :obj:`SentenceTransformer` class, used to calculate embeddings.
32 |         tokenizer (:obj:`AutoTokenizer`): Tokenizer for :obj:`model`.
33 | index (:obj:`IndexIDMap`): Index generated with FAISS.
34 |         votek_k (:obj:`int`, optional): ``k`` value of the Vote-k Selective Annotation Algorithm. Defaults to ``3``.
35 | """
36 |
37 | def __init__(self,
38 | dataset_reader: DatasetReader,
39 | ice_separator: Optional[str] = '\n',
40 | ice_eos_token: Optional[str] = '\n',
41 | prompt_eos_token: Optional[str] = '',
42 | sentence_transformers_model_name: Optional[str] = 'all-mpnet-base-v2',
43 | ice_num: Optional[int] = 1,
44 | index_split: Optional[str] = 'train',
45 | test_split: Optional[str] = 'test',
46 | tokenizer_name: Optional[str] = 'gpt2-xl',
47 | batch_size: Optional[int] = 1,
48 | votek_k: Optional[int] = 3,
49 | accelerator: Optional[Accelerator] = None,
50 | ) -> None:
51 | super().__init__(dataset_reader, ice_separator, ice_eos_token, prompt_eos_token,
52 | sentence_transformers_model_name, ice_num, index_split, test_split, tokenizer_name, batch_size,
53 | accelerator)
54 | self.votek_k = votek_k
55 |
56 | def votek_select(self, embeddings=None, select_num=None, k=None, overlap_threshold=None, vote_file=None):
57 | n = len(embeddings)
58 | if vote_file is not None and os.path.isfile(vote_file):
59 | with open(vote_file) as f:
60 | vote_stat = json.load(f)
61 | else:
62 | vote_stat = defaultdict(list)
63 |
64 | for i in range(n):
65 | cur_emb = embeddings[i].reshape(1, -1)
66 | cur_scores = np.sum(cosine_similarity(embeddings, cur_emb), axis=1)
67 | sorted_indices = np.argsort(cur_scores).tolist()[-k - 1:-1]
68 | for idx in sorted_indices:
69 | if idx != i:
70 | vote_stat[idx].append(i)
71 |
72 | if vote_file is not None:
73 | with open(vote_file, 'w') as f:
74 | json.dump(vote_stat, f)
75 | votes = sorted(vote_stat.items(), key=lambda x: len(x[1]), reverse=True)
76 | j = 0
77 | selected_indices = []
78 | while len(selected_indices) < select_num and j < len(votes):
79 | candidate_set = set(votes[j][1])
80 | flag = True
81 | for pre in range(j):
82 | cur_set = set(votes[pre][1])
83 | if len(candidate_set.intersection(cur_set)) >= overlap_threshold * len(candidate_set):
84 | flag = False
85 | break
86 | if not flag:
87 | j += 1
88 | continue
89 | selected_indices.append(int(votes[j][0]))
90 | j += 1
91 | if len(selected_indices) < select_num:
92 | unselected_indices = []
93 | cur_num = len(selected_indices)
94 | for i in range(n):
95 |                 if i not in selected_indices:
96 | unselected_indices.append(i)
97 | selected_indices += random.sample(unselected_indices, select_num - cur_num)
98 | return selected_indices
99 |
100 | def vote_k_search(self):
101 | vote_k_idxs = self.votek_select(embeddings=self.embed_list, select_num=self.ice_num, k=self.votek_k,
102 | overlap_threshold=1)
103 | return [vote_k_idxs[:] for _ in range(len(self.test_ds))]
104 |
105 | def retrieve(self):
106 | return self.vote_k_search()
107 |
--------------------------------------------------------------------------------
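VotekRetriever runs the Vote-k selection above once over the embedded index split and then reuses the same demonstration indices for every test instance (see vote_k_search). A minimal end-to-end sketch of where it sits in the pipeline; the dataset, column names, and model names are placeholders chosen for illustration, not anything this file prescribes:

from datasets import load_dataset
from openicl import DatasetReader, PromptTemplate, VotekRetriever, PPLInferencer

# Placeholder classification dataset with 'text' / 'label' columns.
dataset = load_dataset('gpt3mix/sst2')
data = DatasetReader(dataset, input_columns=['text'], output_column='label')

# One template per label value; '</E>' marks where demonstrations are inserted.
template = PromptTemplate({0: '</E>Positive Movie Review: </text>',
                           1: '</E>Negative Movie Review: </text>'},
                          {'text': '</text>'}, ice_token='</E>')

retriever = VotekRetriever(data, ice_num=8, votek_k=3)   # votek_k = neighbours each point votes for
inferencer = PPLInferencer(model_name='gpt2')            # placeholder model
predictions = inferencer.inference(retriever, ice_template=template)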
/openicl/icl_retriever/icl_zero_retriever.py:
--------------------------------------------------------------------------------
1 | """Zeroshot Retriever"""
2 |
3 | from datasets import Dataset, DatasetDict
4 | from typing import List, Union, Optional, Tuple, Dict
5 | from openicl import DatasetReader, PromptTemplate
6 | from openicl.icl_retriever import BaseRetriever
7 | from openicl.utils.check_type import _check_str
8 | from accelerate import Accelerator
9 |
10 |
11 | class ZeroRetriever(BaseRetriever):
12 | """Zero In-context Learning Retriever Class
13 | Retriever for Zero-shot.
14 |
15 | Attributes:
16 | dataset_reader (:obj:`DatasetReader`): An instance of the :obj:`DatasetReader` class.
17 | ice_eos_token (:obj:`str`, optional): A string that is added to the end of in-context examples.
18 | prompt_eos_token (:obj:`str`, optional): A string that is added to the end of the prompt.
19 | index_split (:obj:`str`, optional): A string for the index dataset name. The index dataset is used to select data for in-context examples. Defaults to ``train``.
20 | test_split (:obj:`str`, optional): A string for the generation dataset name. The test dataset is used to generate prompts for each data. Defaults to ``test``.
21 | index_ds (:obj:`Dataset`): The index dataset. Used to select data for in-context examples.
22 | test_ds (:obj:`Dataset`): The test dataset. Used to generate prompts for each data.
23 | accelerator (:obj:`Accelerator`, optional): An instance of the :obj:`Accelerator` class, used for multiprocessing.
24 | """
25 |
26 | def __init__(self,
27 | dataset_reader: DatasetReader,
28 | ice_eos_token: Optional[str] = '',
29 | prompt_eos_token: Optional[str] = '',
30 | index_split: Optional[str] = 'train',
31 | test_split: Optional[str] = 'test',
32 | accelerator: Optional[Accelerator] = None
33 | ) -> None:
34 | super().__init__(dataset_reader, '', ice_eos_token, prompt_eos_token, 0, index_split, test_split, accelerator)
35 |
36 | def retrieve(self) -> List[List]:
37 | rtr_idx_list = [[] for _ in range(len(self.test_ds))]
38 | return rtr_idx_list
39 |
--------------------------------------------------------------------------------
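Because retrieve() returns an empty index list for every test example, ZeroRetriever builds prompts with no demonstrations; it is the natural choice for zero-shot baselines, or when the few-shot context is hard-coded into the template itself (as scripts/self_consistency.py does later in this repository). A small sketch, with placeholder dataset and model names and the usual OpenICL-style column tokens:

from datasets import load_dataset
from openicl import DatasetReader, PromptTemplate, ZeroRetriever, GenInferencer

dataset = load_dataset('gsm8k', 'main', split='test[:8]')         # placeholder dataset
data = DatasetReader(dataset, input_columns=['question'], output_column='answer')

template = PromptTemplate('</E>Question: </Q>\nAnswer: </A>',
                          {'question': '</Q>', 'answer': '</A>'},
                          ice_token='</E>')                        # '</E>' stays empty here

retriever = ZeroRetriever(data)                                    # no in-context examples
inferencer = GenInferencer(model_name='gpt2')                      # placeholder model
predictions = inferencer.inference(retriever, ice_template=template)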
/openicl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logging import *
2 |
--------------------------------------------------------------------------------
/openicl/utils/api_service.py:
--------------------------------------------------------------------------------
1 | import json
2 | import requests
3 | import os
4 | import openai
5 | import time
6 | import numpy as np
7 |
8 | OPENICL_API_NAME_LIST = ['opt-175b', 'gpt3']
9 | OPENICL_API_PARAMETER_DICT = {
10 | 'opt-175b': ['URL', 'headers'],
11 | 'gpt3': ['engine', 'temperature', 'max_tokens', 'top_p', 'frequency_penalty', 'presence_penalty', 'sleep_time']
12 | }
13 | OPENICL_API_REQUEST_CONFIG = {
14 | 'opt-175b': {
15 | 'URL': "", # http://xxx/completions or http://xxx/generate
16 | 'headers': {
17 | "Content-Type": "application/json; charset=UTF-8"
18 | }
19 | },
20 | 'gpt3': {
21 | 'engine': "text-davinci-003",
22 | 'temperature': 0,
23 | 'max_tokens': 256,
24 | 'top_p': 1.0,
25 | 'frequency_penalty': 0.0,
26 | 'presence_penalty': 0.0,
27 | 'sleep_time': 3
28 | }
29 | }
30 | PROXIES = {"https": "", "http": ""}
31 |
32 |
33 | def is_api_available(api_name):
34 | if api_name is None:
35 | return False
36 |     return api_name in OPENICL_API_NAME_LIST
37 |
38 |
39 | def update_openicl_api_request_config(api_name, **kwargs):
40 | if api_name is None or not is_api_available(api_name):
41 | return
42 |
43 | parameter_list = OPENICL_API_PARAMETER_DICT[api_name]
44 | for parameter in parameter_list:
45 | if parameter in kwargs.keys():
46 | OPENICL_API_REQUEST_CONFIG[api_name][parameter] = kwargs[parameter]
47 |
48 |
49 | def api_get_ppl(api_name, input_texts):
50 | if api_name == 'opt-175b':
51 |         payload = {"prompt": input_texts, "max_tokens": 0, "echo": True}
52 |         response = json.loads(
53 |             requests.post(OPENICL_API_REQUEST_CONFIG[api_name]['URL'], data=json.dumps(payload),
54 | headers=OPENICL_API_REQUEST_CONFIG[api_name]['headers'], proxies=PROXIES).text)
55 | lens = np.array([len(r['logprobs']['tokens']) for r in response['choices']])
56 | ce_loss = np.array([-sum(r['logprobs']['token_logprobs']) for r in response['choices']])
57 | return ce_loss / lens
58 |
59 | if api_name == 'gpt3':
60 | raise NotImplementedError("GPT-3 API doesn't support PPL calculation")
61 |
62 |
63 | def api_get_tokens(api_name, input_texts):
64 | length_list = [len(text) for text in input_texts]
65 |
66 | if api_name == 'opt-175b':
67 |         payload = {"prompt": input_texts, "max_tokens": 100, "echo": True}
68 |         response = json.loads(
69 |             requests.post(OPENICL_API_REQUEST_CONFIG[api_name]['URL'], data=json.dumps(payload),
70 | headers=OPENICL_API_REQUEST_CONFIG[api_name]['headers'], proxies=PROXIES).text)
71 | return [r['text'] for r in response['choices']], [r['text'][length:] for r, length in
72 | zip(response['choices'], length_list)]
73 |
74 | if api_name == 'gpt3':
75 | openai.api_key = os.getenv("OPENAI_API_KEY")
76 | response = openai.Completion.create(
77 | engine=OPENICL_API_REQUEST_CONFIG['gpt3']['engine'],
78 | prompt=input_texts,
79 | temperature=OPENICL_API_REQUEST_CONFIG['gpt3']['temperature'],
80 | max_tokens=OPENICL_API_REQUEST_CONFIG['gpt3']['max_tokens'],
81 | top_p=OPENICL_API_REQUEST_CONFIG['gpt3']['top_p'],
82 | frequency_penalty=OPENICL_API_REQUEST_CONFIG['gpt3']['frequency_penalty'],
83 | presence_penalty=OPENICL_API_REQUEST_CONFIG['gpt3']['presence_penalty']
84 | )
85 | time.sleep(OPENICL_API_REQUEST_CONFIG['gpt3']['sleep_time'])
86 | return [(input + r['text']) for r, input in zip(response['choices'], input_texts)], [r['text'] for r in
87 | response['choices']]
88 |
--------------------------------------------------------------------------------
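The dictionaries above are module-level defaults that update_openicl_api_request_config patches in place before api_get_ppl / api_get_tokens are called. A sketch of wiring up the 'opt-175b' backend; the URL is a placeholder that must point at a running completion server, and for the 'gpt3' backend the OPENAI_API_KEY environment variable must be set:

from openicl.utils.api_service import (api_get_tokens, is_api_available,
                                       update_openicl_api_request_config)

assert is_api_available('opt-175b')

# Placeholder endpoint; see the 'URL' comment in OPENICL_API_REQUEST_CONFIG.
update_openicl_api_request_config('opt-175b', URL='http://localhost:8000/completions')

full_texts, generations = api_get_tokens('opt-175b', ['In-context learning is'])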
/openicl/utils/calculate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def entropy(probs: np.ndarray, label_dim: int = 0, mask=None):
5 | if mask is None:
6 | return - (probs * np.log(probs)).sum(label_dim)
7 | return - (mask * probs * np.log(probs)).sum(label_dim)
8 |
--------------------------------------------------------------------------------
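entropy returns the Shannon entropy in nats, summed over label_dim, with an optional element-wise mask. For example:

import numpy as np
from openicl.utils.calculate import entropy

probs = np.array([0.5, 0.25, 0.25])
print(entropy(probs))                       # -(0.5*ln0.5 + 2*0.25*ln0.25) ≈ 1.04

# With a batch of distributions, point label_dim at the label axis.
batch = np.array([[0.5, 0.5], [0.9, 0.1]])
print(entropy(batch, label_dim=1))          # ≈ [0.693, 0.325]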
/openicl/utils/check_type.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset, DatasetDict
2 | from typing import List, Union, Dict
3 |
4 |
5 | def _check_type_list(obj, typelist: List):
6 | for _type in typelist:
7 | if _type is None:
8 | if obj is None:
9 | return obj
10 | elif isinstance(obj, _type):
11 | return obj
12 | raise TypeError(
13 | f"Expected an object in {[_.__name__ if _ is not None else None for _ in typelist]} type, but got {obj}")
14 |
15 |
16 | def _check_dataset(obj) -> Union[Dataset, DatasetDict]:
17 | if isinstance(obj, Dataset) or isinstance(obj, DatasetDict):
18 | return obj
19 | else:
20 | raise TypeError(f"Expected a datasets.Dataset or a datasets.DatasetDict object, but got {obj}")
21 |
22 |
23 | def _check_list(obj) -> List:
24 | if isinstance(obj, List):
25 | return obj
26 | else:
27 | raise TypeError(f"Expected a List object, but got {obj}")
28 |
29 |
30 | def _check_str(obj) -> str:
31 | if isinstance(obj, str):
32 | return obj
33 | else:
34 | raise TypeError(f"Expected a str object, but got {obj}")
35 |
36 |
37 | def _check_dict(obj) -> Dict:
38 | if isinstance(obj, Dict):
39 | return obj
40 | else:
41 | raise TypeError(f"Expected a Dict object, but got {obj}")
42 |
--------------------------------------------------------------------------------
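These helpers validate constructor arguments throughout the library: each returns its input unchanged when the type matches and raises TypeError otherwise, and _check_type_list accepts None in the type list so optional arguments can pass through. For example:

from openicl.utils.check_type import _check_str, _check_type_list

ice_num = _check_type_list(8, [int])          # returns 8
sep = _check_type_list(None, [None, str])     # None is explicitly allowed
name = _check_str('gpt2-xl')                  # returns the string unchanged
_check_str(42)                                # raises TypeError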
/openicl/utils/collators.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, List, Optional, Union
3 |
4 | import torch
5 | from transformers import PreTrainedTokenizerBase, BatchEncoding
6 | from transformers.file_utils import PaddingStrategy
7 | import numpy as np
8 |
9 |
10 | class ListWrapper:
11 | def __init__(self, data: List[Any]):
12 | self.data = data
13 |
14 | def to(self, device):
15 | return self.data
16 |
17 |
18 | def ignore_pad_dict(features):
19 | res_dict = {}
20 | if "metadata" in features[0]:
21 | res_dict['metadata'] = ListWrapper([x.pop("metadata") for x in features])
22 | return res_dict
23 |
24 |
25 | @dataclass
26 | class DataCollatorWithPaddingAndCuda:
27 | tokenizer: PreTrainedTokenizerBase
28 | device: object = None
29 | padding: Union[bool, str, PaddingStrategy] = True
30 | max_length: Optional[int] = 3000
31 | pad_to_multiple_of: Optional[int] = None
32 |
33 | def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> BatchEncoding:
34 | res_dict = ignore_pad_dict(features)
35 |
36 | has_labels = "labels" in features[0]
37 | if has_labels:
38 | labels = [{"input_ids": x.pop("labels")} for x in features]
39 | labels = self.tokenizer.pad(
40 | labels,
41 | padding=True,
42 | max_length=self.max_length,
43 | pad_to_multiple_of=self.pad_to_multiple_of,
44 | return_attention_mask=True,
45 | return_tensors="pt",
46 | verbose=False
47 | )
48 |
49 | # print(features)
50 | batch = self.tokenizer.pad(
51 | features,
52 | padding=True,
53 | max_length=self.max_length,
54 | pad_to_multiple_of=self.pad_to_multiple_of,
55 | return_attention_mask=True,
56 | return_tensors="pt",
57 | verbose=False
58 | )
59 |
60 | if has_labels:
61 | batch['labels'] = labels.input_ids
62 | batch.update(res_dict)
63 |
64 | if self.device:
65 | batch = batch.to(self.device)
66 |
67 | return batch
68 |
--------------------------------------------------------------------------------
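DataCollatorWithPaddingAndCuda pads a batch of tokenized features (and their labels, if present), keeps any 'metadata' entries out of the padding step via ListWrapper, and moves the resulting BatchEncoding to the target device. A standalone sketch of plugging it into a DataLoader; the tokenizer choice and pad-token workaround are illustrative, not mandated by this file:

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from openicl.utils.collators import DataCollatorWithPaddingAndCuda

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token          # GPT-2 has no pad token by default

features = [tokenizer(text, truncation=True)
            for text in ['a short prompt', 'a slightly longer prompt']]
collator = DataCollatorWithPaddingAndCuda(tokenizer=tokenizer, device='cpu')
loader = DataLoader(features, batch_size=2, collate_fn=collator)
batch = next(iter(loader))                         # padded input_ids / attention_mask on 'cpu'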
/openicl/utils/icl_common_utils.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from openicl.icl_retriever import BaseRetriever
3 | from typing import List, Union, Optional
4 | from openicl import PromptTemplate
5 | from accelerate import Accelerator
6 |
7 |
8 | def get_dataloader(datalist: List[List], batch_size: int) -> DataLoader:
9 | dataloader = DataLoader(datalist, batch_size=batch_size)
10 | return dataloader
11 |
12 |
13 | def get_generation_prompt_list_from_retriever_indices(ice_idx_list: List[List[int]], retriever: BaseRetriever,
14 | tokenizer, gen_field_replace_token: str,
15 | max_model_token_num: Optional[int] = None,
16 | ice_template: Optional[PromptTemplate] = None,
17 | prompt_template: Optional[PromptTemplate] = None):
18 | prompt_list = []
19 | for idx, ice_idx in enumerate(ice_idx_list):
20 | ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
21 | prompt = retriever.generate_prompt_for_generate_task(idx, ice, gen_field_replace_token=gen_field_replace_token,
22 | ice_template=ice_template, prompt_template=prompt_template)
23 | if max_model_token_num is not None and tokenizer is not None:
24 | prompt_token_num = get_input_token_num(tokenizer, prompt)
25 | while len(ice_idx) > 0 and prompt_token_num > max_model_token_num:
26 | ice_idx = ice_idx[:-1]
27 | ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
28 | prompt = retriever.generate_prompt_for_generate_task(idx, ice,
29 | gen_field_replace_token=gen_field_replace_token,
30 | ice_template=ice_template,
31 | prompt_template=prompt_template)
32 | prompt_token_num = get_input_token_num(tokenizer, prompt)
33 | prompt_list.append(prompt)
34 | return prompt_list
35 |
36 |
37 | def get_input_token_num(tokenizer, input):
38 | return len(tokenizer(input, verbose=False)['input_ids'])
39 |
--------------------------------------------------------------------------------
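The helper above keeps prompts inside the model's context window by repeatedly dropping the last retrieved in-context example and rebuilding the prompt until it fits max_model_token_num. The same idea in isolation, with a plain list of example strings standing in for the retriever; every name here is illustrative only:

from transformers import AutoTokenizer

def truncate_to_budget(examples, query, tokenizer, max_tokens):
    # Drop trailing in-context examples until the assembled prompt fits the budget.
    kept = list(examples)
    prompt = '\n'.join(kept + [query])
    while kept and len(tokenizer(prompt)['input_ids']) > max_tokens:
        kept = kept[:-1]
        prompt = '\n'.join(kept + [query])
    return prompt

tokenizer = AutoTokenizer.from_pretrained('gpt2')
prompt = truncate_to_budget(['Q: 1+1?\nA: 2', 'Q: 2+2?\nA: 4'],
                            'Q: 3+3?\nA:', tokenizer, max_tokens=24)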
/openicl/utils/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.distributed as dist
4 |
5 | LOG_LEVEL = logging.INFO
6 | SUBPROCESS_LOG_LEVEL = logging.ERROR
7 | LOG_FORMATTER = '[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s'
8 |
9 |
10 | def get_logger(name, level=LOG_LEVEL, log_file=None, file_mode='w'):
11 | formatter = logging.Formatter(LOG_FORMATTER)
12 |
13 | logger = logging.getLogger(name)
14 |
15 | for handler in logger.root.handlers:
16 | if type(handler) is logging.StreamHandler:
17 | handler.setLevel(logging.ERROR)
18 |
19 | if dist.is_available() and dist.is_initialized():
20 | rank = dist.get_rank()
21 | else:
22 | rank = 0
23 |
24 | if rank == 0 and log_file is not None:
25 | file_handler = logging.FileHandler(log_file, file_mode)
26 | file_handler.setFormatter(formatter)
27 | file_handler.setLevel(level)
28 | logger.addHandler(file_handler)
29 |
30 | if rank == 0:
31 | logger.setLevel(level)
32 | else:
33 | logger.setLevel(SUBPROCESS_LOG_LEVEL)
34 |
35 | stream_handler = logging.StreamHandler()
36 | stream_handler.setFormatter(formatter)
37 | stream_handler.setLevel(level)
38 | logger.addHandler(stream_handler)
39 |
40 | return logger
41 |
--------------------------------------------------------------------------------
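get_logger keeps multi-process runs quiet: only rank 0 logs at the requested level (and, optionally, to a file), while other ranks are limited to errors. Typical use inside a module:

from openicl.utils.logging import get_logger

logger = get_logger(__name__)                        # console logging, INFO on rank 0
# logger = get_logger(__name__, log_file='run.log')  # also mirror rank-0 logs to a file
logger.info('retrieval finished')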
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.28.1
2 | tokenizers==0.13.3
3 | accelerate
4 | datasets>=2.7.1
5 | evaluate>=0.3.0
6 | faiss_gpu>=1.7.2
7 | nltk>=3.8
8 | numpy>=1.23.4
9 | openai>=0.27.1
10 | rank_bm25>=0.2.2
11 | requests>=2.28.1
12 | scikit_learn>=1.2.1
13 | sentence_transformers>=2.2.2
14 | torch>=1.13.1
15 | tqdm>=4.64.1
16 |
--------------------------------------------------------------------------------
/scripts/self_consistency.py:
--------------------------------------------------------------------------------
1 | from os.path import dirname as d
2 | from os.path import abspath
3 | import sys
4 | root = d(d(abspath(__file__)))
5 | sys.path.append(root)
6 | from collections import Counter
7 |
8 | import json
9 | from openicl import DatasetReader, ZeroRetriever, PromptTemplate, TopkRetriever, GenInferencer, AccEvaluator
10 | # import fire
11 | import re
12 | from transformers import AutoTokenizer, AutoModelForCausalLM
13 | from datasets import load_dataset
14 | #from train_llm.test.proxy import proxy_on
15 |
16 |
17 | def processing_answer(text):
18 |     tokens = text.split(' ')[::-1]
19 |     flag = False
20 |     ret = ''
21 |     for i in range(len(tokens)):
22 |         s = tokens[i]
23 |         for j in range(len(s)):
24 |             if s[j].isdigit():
25 |                 flag = True
26 |                 ret = s
27 |                 break
28 |         if flag:
29 |             break
30 |     ret1 = ''
31 |     for i in range(len(ret)):
32 |         if ret[i].isdigit():
33 |             ret1 += ret[i]
34 |     return ret1
35 |
36 |
37 | def main(model_path, ice_num=4, batch_size=1, max_seq_len=2048, sc_size=5):
38 |
39 | tokenizer = AutoTokenizer.from_pretrained(model_path)
40 | model = AutoModelForCausalLM.from_pretrained(model_path)
41 |
42 | ds = load_dataset('gsm8k', 'main', split="test[:100]")
43 | print(ds)
44 | # import pdb;pdb.set_trace()
45 | def processing_test(example):
46 | example['answer'] = example['answer'].split("#### ")[1].replace(',', '')
47 | return example
48 |
49 | data = DatasetReader(ds, input_columns=['question'], output_column='answer')
50 |
51 | ref = ds.map(processing_test)
52 |
53 | #template = PromptTemplate("Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.\n\nQ: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\nA: Leah had 32 chocolates and Leah's sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.\n\nQ: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\nA: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.\n\nQ: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\nA: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.\n\nQ: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\nA: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.\n\nQ: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\nA: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.\n\nQ: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\nA: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8.\n\nQ: \nA: ",
54 | # {'question':'', 'answer':''},
55 | # ice_token='')
56 | #import pdb;pdb.set_trace()
57 | # prompt = open("llm_test/prompt_gsm8k_4shot.txt").readlines()
58 | # for _, line in enumerate(prompt):
59 | # if line == "Let's think step by step\n":
60 | # prompt[_] = "Let's think step by step\nAnswer:\n"
61 | # prompt = ''.join(prompt)
62 | prompt = """
63 | Question: Angelo and Melanie want to plan how many hours over the next week they should study together for their test next week. They have 2 chapters of their textbook to study and 4 worksheets to memorize. They figure out that they should dedicate 3 hours to each chapter of their textbook and 1.5 hours for each worksheet. If they plan to study no more than 4 hours each day, how many days should they plan to study total over the next week if they take a 10-minute break every hour, include 3 10-minute snack breaks each day, and 30 minutes for lunch each day?
64 | Let's think step by step
65 | Angelo and Melanie think they should dedicate 3 hours to each of the 2 chapters, 3 hours x 2 chapters = 6 hours total.
66 | For the worksheets they plan to dedicate 1.5 hours for each worksheet, 1.5 hours x 4 worksheets = 6 hours total.
67 | Angelo and Melanie need to start with planning 12 hours to study, at 4 hours a day, 12 / 4 = 3 days.
68 | However, they need to include time for breaks and lunch. Every hour they want to include a 10-minute break, so 12 total hours x 10 minutes = 120 extra minutes for breaks.
69 | They also want to include 3 10-minute snack breaks, 3 x 10 minutes = 30 minutes.
70 | And they want to include 30 minutes for lunch each day, so 120 minutes for breaks + 30 minutes for snack breaks + 30 minutes for lunch = 180 minutes, or 180 / 60 minutes per hour = 3 extra hours.
71 | So Angelo and Melanie want to plan 12 hours to study + 3 hours of breaks = 15 hours total.
72 | They want to study no more than 4 hours each day, 15 hours / 4 hours each day = 3.75
73 | They will need to plan to study 4 days to allow for all the time they need.
74 | The answer is 4
75 |
76 | Question: Mark's basketball team scores 25 2 pointers, 8 3 pointers and 10 free throws. Their opponents score double the 2 pointers but half the 3 pointers and free throws. What's the total number of points scored by both teams added together?
77 | Let's think step by step
78 | Mark's team scores 25 2 pointers, meaning they scored 25*2= 50 points in 2 pointers.
79 | His team also scores 8 3 pointers, meaning they scored 8*3= 24 points in 3 pointers
80 | They scored 10 free throws, and free throws count as one point so they scored 10*1=10 points in free throws.
81 | All together his team scored 50+24+10= 84 points
82 | Mark's opponents scored double his team's number of 2 pointers, meaning they scored 50*2=100 points in 2 pointers.
83 | His opponents scored half his team's number of 3 pointers, meaning they scored 24/2= 12 points in 3 pointers.
84 | They also scored half Mark's team's points in free throws, meaning they scored 10/2=5 points in free throws.
85 | All together Mark's opponents scored 100+12+5=117 points
86 | The total score for the game is both team's scores added together, so it is 84+117=201 points
87 | The answer is 201
88 |
89 | Question: Bella has two times as many marbles as frisbees. She also has 20 more frisbees than deck cards. If she buys 2/5 times more of each item, what would be the total number of the items she will have if she currently has 60 marbles?
90 | Let's think step by step
91 | When Bella buys 2/5 times more marbles, she'll have increased the number of marbles by 2/5*60 = 24
92 | The total number of marbles she'll have is 60+24 = 84
93 | If Bella currently has 60 marbles, and she has two times as many marbles as frisbees, she has 60/2 = 30 frisbees.
94 | If Bella buys 2/5 times more frisbees, she'll have 2/5*30 = 12 more frisbees.
95 | The total number of frisbees she'll have will increase to 30+12 = 42
96 | Bella also has 20 more frisbees than deck cards, meaning she has 30-20 = 10 deck cards
97 | If she buys 2/5 times more deck cards, she'll have 2/5*10 = 4 more deck cards.
98 | The total number of deck cards she'll have is 10+4 = 14
99 | Together, Bella will have a total of 14+42+84 = 140 items
100 | The answer is 140
101 |
102 | Question: A group of 4 fruit baskets contains 9 apples, 15 oranges, and 14 bananas in the first three baskets and 2 less of each fruit in the fourth basket. How many fruits are there?
103 | Let's think step by step
104 | For the first three baskets, the number of apples and oranges in one basket is 9+15=24
105 | In total, together with bananas, the number of fruits in one basket is 24+14=38 for the first three baskets.
106 | Since there are three baskets each having 38 fruits, there are 3*38=114 fruits in the first three baskets.
107 | The number of apples in the fourth basket is 9-2=7
108 | There are also 15-2=13 oranges in the fourth basket
109 | The combined number of oranges and apples in the fourth basket is 13+7=20
110 | The fourth basket also contains 14-2=12 bananas.
111 | In total, the fourth basket has 20+12=32 fruits.
112 | The four baskets together have 32+114=146 fruits.
113 | The answer is 146
114 |
115 | """
116 |
117 | template = PromptTemplate(f"{prompt}Question: \nLet's think step by step\n",
118 | {'question':'', 'answer':''},
119 | ice_token='')
120 |
121 | retriever = ZeroRetriever(data)
122 | all_predictions = []
123 |
124 | # generation_kwargs = dict(max_new_tokens=512, do_sample=True, temperature=0.7, top_k=40)
125 | generation_kwargs = dict(max_new_tokens=512)
126 | # {"max_gen_len": 512, "do_sample": True, "temperature": 0.8, "top_p": 0.8}
127 | for i in range(sc_size):
128 | print("**"*50)
129 | print("\t\t\tIteration:", str(i))
130 | print("**"*50)
131 | inferencer = GenInferencer(model_name=model, tokenizer_name=tokenizer, generation_kwargs=generation_kwargs,
132 | batch_size=batch_size, output_json_filepath=model_path.split('/')[-2], output_json_filename="gsm8k_"+str(i))
133 | predictions = inferencer.inference(retriever, ice_template=template)
134 | print(predictions[:2])
135 | predictions = [processing_answer(pred.split('\n\n')[0]) for pred in predictions]
136 | # print("**"*50)
137 | # print("\t\t\tProcessed prediction at iteration:", str(i))
138 | # print("**"*50)
139 | # print(predictions[:2])
140 | all_predictions.append(predictions)
141 | #import json
142 | # file = json.load(open("llm_llama/gsm8k.json"))
143 | # predictions = [file[str(i)]['prediction'] for i in range(len(file.keys()))]
144 | assert len(all_predictions) == sc_size
145 | final_prediction = []
146 | for i in range(len(all_predictions[0])):
147 | tmp_preds = []
148 | for j in range(sc_size):
149 | tmp_preds.append(all_predictions[j][i])
150 | counter = Counter(tmp_preds)
151 | if i < 5:
152 | print(counter)
153 | final_prediction.append(counter.most_common(1)[0][0])
154 |
155 | #import pdb;pdb.set_trace()
156 | print(final_prediction[:5], ref['answer'][:5])
157 | score = AccEvaluator().score(predictions=final_prediction, references=ref['answer'])
158 | print(score)
159 |
160 |
161 | if __name__ == "__main__":
162 | # fire.Fire(main)
163 |
164 | # replace with your model_path or huggingface model name here
165 | main(model_path="decapoda-research/llama-7b-hf",
166 | ice_num=4, batch_size=8, sc_size=1)
167 |
--------------------------------------------------------------------------------
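The self-consistency step in the script is a per-question majority vote over the sc_size sampled generations, after processing_answer reduces each generation to its final number. A toy illustration of those two pieces (the sample strings are made up; the helper re-expresses the script's logic):

from collections import Counter

def processing_answer(text):
    # Keep the last whitespace-separated token that contains a digit,
    # then strip it down to its digits (same behaviour as the script above).
    for token in text.split(' ')[::-1]:
        if any(ch.isdigit() for ch in token):
            return ''.join(ch for ch in token if ch.isdigit())
    return ''

samples = ['So the total is 15 + 3 = 18. The answer is 18',
           'They need 18 apples, so the answer is 18',
           'The answer is 17']
answers = [processing_answer(s) for s in samples]    # ['18', '18', '17']
final = Counter(answers).most_common(1)[0][0]         # '18'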
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | from setuptools.command.install import install
3 |
4 |
5 | class DownloadNLTK(install):
6 | def run(self):
7 | self.do_egg_install()
8 | import nltk
9 | nltk.download('punkt')
10 |
11 |
12 | REQUIRES = """
13 | transformers
14 | accelerate
15 | datasets>=2.7.1
16 | evaluate>=0.3.0
17 | faiss_gpu>=1.7.2
18 | nltk>=3.8
19 | numpy>=1.23.4
20 | openai>=0.27.1
21 | rank_bm25>=0.2.2
22 | requests>=2.28.1
23 | scikit_learn>=1.2.1
24 | sentence_transformers>=2.2.2
25 | torch>=1.13.1
26 | tqdm>=4.64.1
27 | """
28 |
29 |
30 | def get_install_requires():
31 | reqs = [req for req in REQUIRES.split("\n") if len(req) > 0]
32 | return reqs
33 |
34 |
35 | with open("README.md") as f:
36 | readme = f.read()
37 |
38 |
39 | def do_setup():
40 | setup(
41 | name="openicl",
42 | version='0.1.7',
43 | description="An open source framework for in-context learning.",
44 | url="https://github.com/Shark-NLP/OpenICL",
45 | author='Zhenyu Wu, Yaoxiang Wang, Zhiyong Wu, Jiacheng Ye',
46 | long_description=readme,
47 | long_description_content_type="text/markdown",
48 | cmdclass={'download_nltk': DownloadNLTK},
49 | install_requires=get_install_requires(),
50 | setup_requires=['nltk==3.8'],
51 | python_requires=">=3.8.0",
52 | packages=find_packages(
53 | exclude=[
54 | "test*",
55 | "paper_test*"
56 | ]
57 | ),
58 | keywords=["AI", "NLP", "in-context learning"],
59 | classifiers=[
60 | "Programming Language :: Python :: 3.8",
61 | "Programming Language :: Python :: 3.9",
62 | "Programming Language :: Python :: 3.10",
63 | "Intended Audience :: Developers",
64 | "Intended Audience :: Education",
65 | "Intended Audience :: Science/Research",
66 | ]
67 | )
68 |
69 |
70 | if __name__ == "__main__":
71 | do_setup()
72 |
--------------------------------------------------------------------------------