├── .github
│   └── workflows
│       └── unit_tests.yml
├── .gitignore
├── README.md
├── data
│   ├── ab_example.jsonl
│   └── reddit_r_cooking_sample.jsonl
├── examples
│   ├── ner.yaml
│   ├── textcat_binary.yaml
│   └── textcat_multilabel.yaml
├── recipes
│   ├── __init__.py
│   ├── openai.py
│   ├── openai_ab.py
│   ├── openai_ner.py
│   ├── openai_terms.py
│   ├── openai_textcat.py
│   └── style.css
├── requirements.txt
├── requirements_train.txt
├── scripts
│   └── train_hf_ner.py
├── templates
│   ├── ab
│   │   ├── input.jinja2
│   │   ├── prompt1.jinja2
│   │   └── prompt2.jinja2
│   ├── ner_prompt.jinja2
│   ├── terms_prompt.jinja2
│   └── textcat_prompt.jinja2
└── tests
    ├── __init__.py
    ├── test_ner.py
    ├── test_openai.py
    ├── test_terms.py
    ├── test_textcat.py
    └── utils.py
/.github/workflows/unit_tests.yml:
--------------------------------------------------------------------------------
1 | name: Unit Tests
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - main
8 |
9 | jobs:
10 | setup:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v3
14 | - name: Set up Python 3.9
15 | uses: actions/setup-python@v4
16 | with:
17 | python-version: 3.9
18 | cache: "pip" # caching pip dependencies
19 |
20 | - name: Install dependencies
21 | run: |
22 | pip install --upgrade pip
23 | pip install -r requirements.txt
24 | pip install -r requirements_train.txt
25 |
26 | - name: Check out Prodigy
27 | uses: actions/checkout@v3
28 | with:
29 | repository: explosion/prodigy
30 | ref: v1.11.8
31 | path: ./prodigy
32 | ssh-key: ${{ secrets.GHA_PRODIGY_READ }}
33 |
34 | - name: Install prodigy
35 | run: |
36 | ls -la
37 | pip install ./prodigy
38 |
39 | - name: Test with pytest
40 | run: |
41 | python -m pytest --ignore=prodigy
42 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | .env
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 | *.jsonl
132 | *.ipynb
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Archival notice
4 |
5 | The recipes in this repository have since moved to [Prodigy](https://prodi.gy/) and are being maintained there. They will soon even get an upgrade with the advent of [spacy-llm](https://github.com/explosion/spacy-llm) support, which features better prompts and multiple LLM providers. That is why we've opted to archive this repo, so that we may focus on maintaining these recipes as part of spaCy and Prodigy directly.
6 |
7 | You can learn more by checking out the [large language models section](https://prodi.gy/docs/large-language-models) in the docs.
8 |
9 | # Prodigy OpenAI recipes
10 |
11 | This repository contains example code on how to combine **zero- and few-shot learning
12 | with a small annotation effort** to obtain a **high-quality dataset with maximum efficiency**. Specifically, we use large language models available from [OpenAI](https://openai.com) to provide us with an initial set of predictions,
13 | then spin up a [Prodigy](https://prodi.gy) instance on our local machine
14 | to go through these predictions and curate them. This allows us to obtain a
15 | gold-standard dataset pretty quickly, and train a smaller, supervised model that fits
16 | our exact needs and use-case.
17 |
18 | 
19 |
20 | https://user-images.githubusercontent.com/13643239/208504034-0ab6bcbe-6d2b-415d-8257-233f2074ba31.mp4
21 |
22 | ## ⏳ Setup and Install
23 |
24 | Make sure to [install Prodigy](https://prodi.gy/docs/install) as well as a few additional Python dependencies:
25 |
26 | ```bash
27 | python -m pip install prodigy -f https://XXXX-XXXX-XXXX-XXXX@download.prodi.gy
28 | python -m pip install -r requirements.txt
29 | ```
30 |
31 | With `XXXX-XXXX-XXXX-XXXX` being your personal Prodigy license key.
32 |
33 | Then, create a new API key from [openai.com](https://beta.openai.com/account/api-keys) or fetch an existing
34 | one. Record the secret key as well as the [organization key](https://beta.openai.com/account/org-settings)
35 | and make sure these are available as environment variables. For instance, set them in a `.env` file in the
36 | root directory:
37 |
38 | ```
39 | OPENAI_ORG = "org-..."
40 | OPENAI_KEY = "sk-..."
41 | ```
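
The recipes read these variables at startup (via `python-dotenv` and the `get_api_credentials()` helper in `recipes/openai.py`). A minimal sketch of what that amounts to:

```python
# Minimal sketch of how the credentials are picked up (see recipes/openai.py).
import os

from dotenv import load_dotenv

load_dotenv()  # reads the .env file in the current working directory
api_key = os.getenv("OPENAI_KEY")
api_org = os.getenv("OPENAI_ORG")
assert api_key and api_org, "Set OPENAI_KEY and OPENAI_ORG before running a recipe"
```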
42 |
43 | ## 📋 Named-entity recognition (NER)
44 |
45 | ### `ner.openai.correct`: NER annotation with zero- or few-shot learning
46 |
47 | This recipe marks entity predictions obtained from a large language model and allows you to flag them as correct, or to
48 | manually curate them. This allows you to quickly gather a gold-standard dataset through zero-shot or few-shot learning.
49 | It's very much like using the standard [`ner.correct`](https://prodi.gy/docs/recipes#ner-correct) recipe in Prodigy,
50 | but we're using GPT-3 as a backend model to make predictions.
51 |
52 | ```bash
53 | python -m prodigy ner.openai.correct dataset filepath labels [--options] -F ./recipes/openai_ner.py
54 | ```
55 |
56 | | Argument | Type | Description | Default |
57 | | ----------------------- | ---- | ----------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------- |
58 | | `dataset` | str | Prodigy dataset to save annotations to. | |
59 | | `filepath` | Path | Path to `.jsonl` data to annotate. The data should at least contain a `"text"` field. | |
60 | | `labels` | str | Comma-separated list defining the NER labels the model should predict. | |
61 | | `--lang`, `-l` | str | Language of the input data - will be used to obtain a relevant tokenizer. | `"en"` |
62 | | `--segment`, `-S` | bool | Flag to set when examples should be split into sentences. By default, the full input article is shown. | `False` |
63 | | `--model`, `-m` | str | GPT-3 model to use for initial predictions. | `"text-davinci-003"` |
64 | | `--prompt-path`, `-p`   | Path | Path to the `.jinja2` [prompt template](templates).                                                                               | `./templates/ner_prompt.jinja2` |
65 | | `--examples-path`, `-e` | Path | Path to examples to help define the task. The file can be a .yml, .yaml or .json. If set to `None`, zero-shot learning is applied. | `None` |
66 | | `--max-examples`, `-n` | int | Max number of examples to include in the prompt to OpenAI. If set to 0, zero-shot learning is always applied, even when examples are available. | 2 |
67 | | `--batch-size`, `-b` | int | Batch size of queries to send to the OpenAI API. | 10 |
68 | | `--verbose`, `-v` | bool | Flag to print extra information to the terminal. | `False` |
69 |
70 | #### Example usage
71 |
72 | Let's say we want to recognize dishes, ingredients and cooking equipment from some text we obtained from a cooking subreddit.
73 | We'll send the text to GPT-3, hosted by OpenAI, and provide an annotation prompt to explain
74 | to the language model the type of predictions we want. Something like:
75 |
76 | ```
77 | From the text below, extract the following entities in the following format:
78 | dish:
79 | ingredient:
80 | equipment:
81 |
82 | Text:
83 | ...
84 | ```
85 |
86 | We define this prompt in a .jinja2 file, which also describes how to append examples for few-shot learning.
87 | You can create your own [template](templates) and provide it to the recipe with the `--prompt-path` or `-p` option.
88 | Additionally, with `--examples-path` or `-e` you can set the file path of a .y(a)ml or .json file that contains additional examples:
89 |
90 | ```bash
91 | python -m prodigy ner.openai.correct my_ner_data ./data/reddit_r_cooking_sample.jsonl "dish,ingredient,equipment" -p ./templates/ner_prompt.jinja2 -e ./examples/ner.yaml -n 2 -F ./recipes/openai_ner.py
92 | ```
93 |
94 | After receiving the results from the OpenAI API, the Prodigy recipe converts the predictions into an annotation task
95 | that can be rendered with Prodigy. The task even shows the original prompt as well as the raw answer we obtained
96 | from the language model.
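
To give a rough idea of what such a task looks like under the hood, here's an illustrative sketch (the field values are made up; the `"openai"` block is attached in `recipes/openai.py`, and the spans come from the parsed response):

```python
# Illustrative only: a simplified task as yielded by the recipe.
task = {
    "text": "I made carbonara in a cast-iron skillet.",
    "spans": [{"start": 7, "end": 16, "label": "dish"}],  # parsed from the response
    "openai": {
        "prompt": "From the text below, extract the following entities ...",
        "response": "dish: carbonara\ningredient:\nequipment: cast-iron skillet",
    },
}
```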
97 |
98 |
99 |
100 | Here, we see that the model is able to correctly recognize dishes, ingredients and cooking equipment right from the start!
101 |
102 | The recipe also offers a `--verbose` or `-v` option that prints the exact prompt and response to the terminal as traffic is received.
103 | Note that because the requests to the API are batched, you might have to scroll back a bit to find the current prompt.
104 |
105 | ### Interactively tune the prompt examples
106 |
107 | At some point, you might notice a mistake in the predictions of the OpenAI language model. For instance, we noticed an error
108 | in the recognition of cooking equipment in this example:
109 |
110 |
111 |
112 | If you see this kind of systematic error, you can steer the predictions in the right direction by correcting the example and then selecting the small "flag" icon
113 | in the top right of the Prodigy UI:
114 |
115 |
116 |
117 | Once you hit accept on the Prodigy interface, the flagged example will be automatically picked up and added to the examples
118 | that are sent to the OpenAI API as part of the prompt.
119 |
120 | > **Note**
121 | > Because Prodigy batches these requests, the prompt will be updated with a slight
122 | > delay, after the next batch of prompts is sent to OpenAI. You can experiment
123 | > with making the batch size (`--batch-size` or `-b`) smaller to have the change
124 | > come into effect sooner, but this might negatively impact the speed of the
125 | > annotation workflow.
126 |
127 | ### `ner.openai.fetch`: Fetch examples up-front
128 |
129 | The `ner.openai.correct` recipe fetches examples from OpenAI while annotating, but we've also included a recipe that can fetch a large batch of examples upfront.
130 |
131 | ```bash
132 | python -m prodigy ner.openai.fetch input_data.jsonl predictions.jsonl "dish,ingredient,equipment" -F ./recipes/openai_ner.py
133 | ```
134 |
135 | This will create a `predictions.jsonl` file that can be loaded with the [`ner.manual`](https://prodi.gy/docs/recipes#ner-manual) recipe.
136 |
137 | Note that the OpenAI API might return "429 Too Many Requests" errors when requesting too much data at once - in this case it's best to ensure you only request
138 | 100 or so examples at a time.
139 |
140 | ### Exporting the annotations and training an NER model
141 |
142 | After you've curated a set of predictions, you can export the results with [`db-out`](https://prodi.gy/docs/recipes#db-out):
143 |
144 | ```bash
145 | python -m prodigy db-out my_ner_data > ner_data.jsonl
146 | ```
147 |
148 | The format of the exported annotations contains all the data you need to train a smaller model downstream. Each example
149 | in the dataset contains the original text, the tokens, span annotations denoting the entities, etc.
150 |
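If you want to inspect the exported file programmatically, here's a minimal sketch (the field names are the standard Prodigy export fields; `srsly` is installed with the requirements):

```python
# Minimal sketch: print the accepted texts and their entity spans from ner_data.jsonl.
import srsly

for eg in srsly.read_jsonl("ner_data.jsonl"):
    if eg.get("answer") == "accept":
        spans = [(eg["text"][s["start"]:s["end"]], s["label"]) for s in eg.get("spans", [])]
        print(eg["text"][:60], spans)
```
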
151 | You can also export the data to spaCy's [binary format](https://spacy.io/api/data-formats#training), using [`data-to-spacy`](https://prodi.gy/docs/recipes#data-to-spacy). This format lets you load in the annotations as spaCy `Doc` objects, which can be convenient for further conversion. The `data-to-spacy` command also makes it easy to train an NER model with spaCy. First you export the data, specifying the evaluation split as 20% of the total:
152 |
153 | ```bash
154 | python -m prodigy data-to-spacy ./data/annotations/ --ner my_ner_data -es 0.2
155 | ```
156 |
157 | Then you can train a model with spaCy or [Prodigy](https://prodi.gy/docs/recipes/#training):
158 |
159 | ```bash
160 | python -m spacy train ./data/annotations/config.cfg --paths.train ./data/annotations/train.spacy --paths.dev ./data/annotations/dev.spacy -o ner-model
161 | ```
162 |
163 | This will save a model to the `ner-model/` directory.
164 |
165 | We've also included an experimental script to load in the `.spacy` binary format and train a model with the HuggingFace `transformers` library. You can use the same data you just exported and run the script like this:
166 |
167 | ```bash
168 | # First you need to install the HuggingFace library and requirements
169 | pip install -r requirements_train.txt
170 | python ./scripts/train_hf_ner.py ./data/annotations/train.spacy ./data/annotations/dev.spacy -o hf-ner-model
171 | ```
172 |
173 | The resulting model will be saved to the `hf-ner-model/` directory.
174 |
175 | ## 📋 Text categorization (Textcat)
176 |
177 | ### `textcat.openai.correct`: Textcat annotation with zero- or few-shot learning
178 |
179 | This recipe enables us to classify texts faster with the help of a large
180 | language model. It also provides a "reason" to explain why a particular label
181 | was chosen.
182 |
183 | ```bash
184 | python -m prodigy textcat.openai.correct dataset filepath labels [--options] -F ./recipes/openai_textcat.py
185 | ```
186 |
187 | | Argument | Type | Description | Default |
188 | | --------------------------- | ---- | ----------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------- |
189 | | `dataset` | str | Prodigy dataset to save annotations to. | |
190 | | `filepath` | Path | Path to `.jsonl` data to annotate. The data should at least contain a `"text"` field. | |
191 | | `labels` | str | Comma-separated list defining the text categorization labels the model should predict. | |
192 | | `--lang`, `-l` | str | Language of the input data - will be used to obtain a relevant tokenizer. | `"en"` |
193 | | `--segment`, `-S` | bool | Flag to set when examples should be split into sentences. By default, the full input article is shown. | `False` |
194 | | `--model`, `-m` | str | GPT-3 model to use for initial predictions. | `"text-davinci-003"` |
195 | | `--prompt-path`, `-p` | Path | Path to the `.jinja2` [prompt template](templates). | `./templates/textcat_prompt.jinja2` |
196 | | `--examples-path`, `-e` | Path | Path to examples to help define the task. The file can be a .yml, .yaml or .json. If set to `None`, zero-shot learning is applied. | `None` |
197 | | `--max-examples`, `-n` | int | Max number of examples to include in the prompt to OpenAI. If set to 0, zero-shot learning is always applied, even when examples are available. | 2 |
198 | | `--batch-size`, `-b` | int | Batch size of queries to send to the OpenAI API. | 10 |
199 | | `--exclusive-classes`, `-E` | bool | Flag to make the classification task exclusive. | `False` |
200 | | `--verbose`, `-v` | bool | Flag to print extra information to the terminal. | `False` |
201 |
202 |
203 | #### Example usage
204 |
205 | The `textcat` recipes can be used for binary, multiclass, and multilabel text
206 | categorization. You can set this by passing the appropriate number of labels in
207 | the `--labels` parameter; for example, passing a single label turns it into
208 | binary classification and so on. We will talk about each one in the following
209 | sections.
210 |
211 | ##### Binary text categorization
212 |
213 | Suppose we want to know if a particular Reddit comment talks about a food
214 | recipe. We'll send the text to GPT-3 and provide a prompt that describes the
215 | predictions we want.
216 |
217 | ```
218 | From the text below, determine whether or not it contains a recipe. If it is a
219 | recipe, answer "accept." If it is not a recipe, answer "reject."
220 |
221 | Your answer should only be in the following format:
222 | answer:
223 | reason:
224 |
225 | Text:
226 | ```
227 |
228 | For binary classification, we want GPT-3 to return "accept" if a given text is a
229 | food recipe and "reject" otherwise. GPT-3's suggestion is then displayed
230 | prominently in the UI. We can press the ACCEPT (check mark) button to
231 | include the text as a positive example or press the REJECT (cross
232 | mark) button if it is a negative example.
233 |
234 |
235 | ```sh
236 | python -m prodigy textcat.openai.correct my_binary_textcat_data data/reddit_r_cooking_sample.jsonl --labels recipe -F recipes/openai_textcat.py
237 | ```
238 |
239 |
240 |
241 |
242 | ##### Multilabel and multiclass text categorization
243 |
244 | Now, suppose we want to classify Reddit comments as a recipe, feedback, or a
245 | question. We can write the following prompt:
246 |
247 | ```
248 | Classify the text below to any of the following labels: recipe, feedback, question.
249 | The task is exclusive, so only choose one label from what I provided.
250 |
251 | Your answer should only be in the following format:
252 | answer:
253 | reason:
254 |
255 | Text:
256 | ```
257 |
258 | Then, we can use this recipe to handle multilabel and multiclass cases by
259 | passing the three labels to the `--labels` parameter. We should also set the
260 | `--exclusive-classes` flag to render a single-choice UI:
261 |
262 | ```sh
263 | python -m prodigy textcat.openai.correct my_multi_textcat_data data/reddit_r_cooking_sample.jsonl \
264 | --labels recipe,feedback,question \
265 | --exclusive-classes \
266 | -F recipes/openai_textcat.py
267 | ```
268 |
269 |
270 |
271 | ### Writing templates
272 |
273 | We write these prompts as a .jinja2 template that can also take in examples for
274 | few-shot learning. You can create your own [template](templates) and provide it
275 | to the recipe with the `--prompt-path` or `-p` option. Additionally, with
276 | `--examples-path` or `-e` you can set the file path of a .y(a)ml or .json file
277 | that contains additional examples. You can also add context to these examples, as
278 | we observed that this improves the output:
279 |
280 | ```bash
281 | python -m prodigy textcat.openai.correct my_binary_textcat_data \
282 | ./data/reddit_r_cooking_sample.jsonl \
283 | --labels recipe \
284 | --prompt-path ./templates/textcat_prompt.jinja2 \
285 | --examples-path ./examples/textcat_binary.yaml -n 2 \
286 | -F ./recipes/openai_textcat.py
287 | ```
288 |
289 | Similar to the NER recipe, this recipe also converts the predictions into an
290 | annotation task that can be rendered with Prodigy. For binary classification, we
291 | use the [`classification`](https://prodi.gy/docs/api-interfaces#classification)
292 | interface with custom HTML elements, while for multilabel or multiclass text
293 | categorization, we use the
294 | [`choice`](https://prodi.gy/docs/api-interfaces#choice) annotation interface.
295 | Notice that we include the original prompt and the OpenAI response in the UI.
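
As an illustrative sketch (values made up; the structure follows Prodigy's `choice` interface, with the `"openai"` block added in `recipes/openai.py`), an exclusive multiclass task looks roughly like this:

```python
# Illustrative only: a simplified multiclass task for the choice interface.
task = {
    "text": "Mix the rice with the egg and fry everything on high heat.",
    "options": [
        {"id": "recipe", "text": "recipe"},
        {"id": "feedback", "text": "feedback"},
        {"id": "question", "text": "question"},
    ],
    "accept": ["recipe"],  # GPT-3's suggestion, pre-selected for the annotator
    "openai": {"prompt": "...", "response": "answer: recipe\nreason: ..."},
}
```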
296 |
297 |
298 | Lastly, you can use the `--verbose` or `-v` flag to show the exact prompt and
299 | response on the terminal. Note that because the requests to the API are batched,
300 | you might have to scroll back a bit to find the current prompt.
301 |
302 |
303 | ### Interactively tune the prompt examples
304 |
305 | Similar to the NER recipes, you can also steer the predictions in the right
306 | direction by correcting the example and then selecting the small "flag" icon in
307 | the top right of the Prodigy UI:
308 |
309 |
310 |
311 | Once you hit the accept button on the Prodigy interface, the flagged
312 | example will be picked up and added to the few-shot examples sent to the OpenAI
313 | API as part of the prompt.
314 |
315 | > **Note**
316 | > Because Prodigy batches these requests, the prompt will be updated with a slight
317 | > delay, after the next batch of prompts is sent to OpenAI. You can experiment
318 | > with making the batch size (`--batch-size` or `-b`) smaller to have the change
319 | > come into effect sooner, but this might negatively impact the speed of the
320 | > annotation workflow.
321 |
322 | ### `textcat.openai.fetch`: Fetch text categorization examples up-front
323 |
324 | The `textcat.openai.fetch` recipe allows us to fetch a large batch of examples
325 | upfront. This is helpful when you are working with highly imbalanced data and are
326 | interested only in the rare examples.
327 |
328 | ```bash
329 | python -m prodigy textcat.openai.fetch input_data.jsonl predictions.jsonl --labels Recipe -F ./recipes/openai_textcat.py
330 | ```
331 |
332 | This will create a `predictions.jsonl` file that can be loaded with the
333 | [`textcat.manual`](https://prodi.gy/docs/recipes#textcat-manual) recipe.
334 |
335 | Note that the OpenAI API might return "429 Too Many Requests" errors when
336 | requesting too much data at once - in this case it's best to ensure you only
337 | request 100 or so examples at a time and take a look at the [API's rate
338 | limits](https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits).
339 |
340 | #### Working with imbalanced data
341 |
342 | The `textcat.openai.fetch` recipe is suitable for working with datasets where
343 | there is severe class imbalance. Usually, you'd want to find examples of the
344 | rare class rather than annotating a random sample. From there, you can
345 | upsample the rare class to train a decent model and so on.
346 |
347 | This is where large language models like GPT-3 might help.
348 |
349 | Using the [Reddit r/cooking dataset](data), we prompted OpenAI to
350 | look for comments that resemble a food recipe. Instead of annotating 10,000
351 | examples, we ran `textcat.openai.fetch` and obtained 145 positive predictions. Out
352 | of those 145 examples, 114 turned out to be true positives (79% precision). We
353 | then checked 1,000 negative examples and found 12 false negative cases (98%
354 | recall).
355 |
356 | Ideally, once we've fully annotated the dataset, we can train a supervised model
357 | that is preferable to relying on zero-shot predictions in production: the
358 | running cost is lower and it's easier to manage.
359 |
360 | ### Exporting the annotations and training a text categorization model
361 |
362 | After you've curated a set of predictions, you can export the results with
363 | [`db-out`](https://prodi.gy/docs/recipes#db-out):
364 |
365 | ```bash
366 | python -m prodigy db-out my_textcat_data > textcat_data.jsonl
367 | ```
368 |
369 | The format of the exported annotations contains all the data you need to train a
370 | smaller model downstream. Each example in the dataset contains the original
371 | text, the tokens, the accepted category labels, etc.
372 |
373 | You can also export the data to spaCy's [binary
374 | format](https://spacy.io/api/data-formats#training), using
375 | [`data-to-spacy`](https://prodi.gy/docs/recipes#data-to-spacy). This format lets
376 | you load in the annotations as spaCy `Doc` objects, which can be convenient for
377 | further conversion. The `data-to-spacy` command also makes it easy to train a
378 | text categorization model with spaCy. First you export the data, specifying the
379 | evaluation split as 20% of the total:
380 |
381 | ```bash
382 | # For binary textcat
383 | python -m prodigy data-to-spacy ./data/annotations/ --textcat my_textcat_data -es 0.2
384 | # For multilabel textcat
385 | python -m prodigy data-to-spacy ./data/annotations/ --textcat-multilabel my_textcat_data -es 0.2
386 | ```
387 | Then you can train a model with spaCy or [Prodigy](https://prodi.gy/docs/recipes/#training):
388 |
389 | ```bash
390 | python -m spacy train ./data/annotations/config.cfg --paths.train ./data/annotations/train.spacy --paths.dev ./data/annotations/dev.spacy -o textcat-model
391 | ```
392 |
393 | This will save a model to the `textcat-model/` directory.
394 |
395 | ## 📋 Terms
396 |
397 | ### `terms.openai.fetch`: Fetch phrases and terms based on a query
398 |
399 | This recipe generates terms and phrases obtained from a large language model. These
400 | terms can be curated and turned into patterns files, which can help with downstream annotation tasks.
401 |
402 | ```bash
403 | python -m prodigy terms.openai.fetch query filepath [--options] -F ./recipes/openai_terms.py
404 | ```
405 |
406 | | Argument | Type | Description | Default |
407 | |:---------------------:|-------|-----------------------------------------------------|---------------------------------|
408 | | `query` | str | Query to send to OpenAI | |
409 | | `output_path` | Path | Path to save the output | |
410 | | `--seeds`,`-s` | str | One or more comma-separated seed phrases. | `""` |
411 | | `--n`,`-n` | int | Minimum number of items to generate | `100` |
412 | | `--model`, `-m` | str | GPT-3 model to use for completion | `"text-davinci-003"` |
413 | | `--prompt-path`, `-p` | Path | Path to jinja2 prompt template | `templates/terms_prompt.jinja2` |
414 | | `--verbose`,`-v` | bool | Print extra information to terminal | `False` |
415 | | `--resume`, `-r` | bool | Resume by loading in text examples from output file | `False` |
416 | | `--progress`,`-pb` | bool | Print progress of the recipe. | `False` |
417 | | `--temperature`,`-t` | float | OpenAI temperature param | `1.0` |
418 | | `--top-p`, `--tp` | float | OpenAI top_p param | `1.0` |
419 | | `--best-of`, `-bo`    | int   | OpenAI best_of param                                 | `10`                            |
420 | | `--n-batch`,`-nb` | int | OpenAI batch size param | `10` |
421 | | `--max-tokens`, `-mt` | int | Max tokens to generate per call | `100` |
422 |
423 | #### Example usage
424 |
425 | Suppose you're interested in detecting skateboard tricks in text; then you might want to start
426 | with a term list of known tricks. You could begin with the following query:
427 |
428 | ```bash
429 | # Base behavior, fetch at least 100 terms/phrases
430 | python -m prodigy terms.openai.fetch "skateboard tricks" tricks.jsonl --n 100 --prompt-path templates/terms_prompt.jinja2 -F recipes/openai_terms.py
431 | ```
432 |
433 | This will generate a prompt that asks OpenAI to try and generate at least 100 examples of "skateboard tricks".
434 | There's an upper limit to the number of tokens that OpenAI can generate per request, but this recipe will keep
435 | collecting terms until it reaches the number specified.
436 |
437 | You can make the query more elaborate if you want to be more precise, but you can alternatively
438 | add some seed terms via `--seeds`. These will act as starting examples to help steer OpenAI
439 | in the right direction.
440 |
441 | ```bash
442 | # Base behavior but with seeds
443 | python -m prodigy terms.openai.fetch "skateboard tricks" tricks.jsonl --n 100 --seeds "kickflip,ollie" --prompt-path templates/terms_prompt.jinja2 -F recipes/openai_terms.py
444 | ```
445 |
446 | Collecting many examples can take a while, so it can be helpful to show the progress, via `--progress`
447 | as requests are sent.
448 |
449 | ```bash
450 | # Adding progress output as we wait for 500 examples
451 | python -m prodigy terms.openai.fetch "skateboard tricks" tricks.jsonl --n 500 --progress --seeds "kickflip,ollie" --prompt-path templates/terms_prompt.jinja2 -F recipes/openai_terms.py
452 | ```
453 |
454 | After collecting a few examples, you might want to generate more. You can choose to continue from a
455 | previous output file. This will effectively re-use those examples as seeds for the prompt to OpenAI.
456 |
457 | ```bash
458 | # Use the `--resume` flag to re-use previous examples
459 | python -m prodigy terms.openai.fetch "skateboard tricks" tricks.jsonl --n 50 --resume --prompt-path templates/terms_prompt.jinja2 -F recipes/openai_terms.py
460 | ```
461 |
462 | When the recipe is done, you'll have a `tricks.jsonl` file that has contents that look like this:
463 |
464 | ```json
465 | {"text":"pop shove it","meta":{"openai_query":"skateboard tricks"}}
466 | {"text":"switch flip","meta":{"openai_query":"skateboard tricks"}}
467 | {"text":"nose slides","meta":{"openai_query":"skateboard tricks"}}
468 | {"text":"lazerflip","meta":{"openai_query":"skateboard tricks"}}
469 | {"text":"lipslide","meta":{"openai_query":"skateboard tricks"}}
470 | ...
471 | ```
472 |
473 | ### Towards Patterns
474 |
475 | You now have a `tricks.jsonl` file on disk that contains skateboard tricks, but you cannot
476 | assume that all of these will be accurate. The next step would be to review the terms and you
477 | can use the [`textcat.manual`](https://prodi.gy/docs/recipes/#textcat-manual) recipe that comes
478 | with Prodigy for that.
479 |
480 | ```bash
481 | # The tricks.jsonl was fetched from OpenAI beforehand
482 | python -m prodigy textcat.manual skateboard-tricks-list tricks.jsonl --label skateboard-tricks
483 | ```
484 |
485 | This generates an interface that looks like this:
486 |
487 |
488 |
489 | You can manually accept or reject each example and when you're done annotating you can export
490 | the annotated text into a patterns file via the [`terms.to-patterns`](https://prodi.gy/docs/recipes/#terms-to-patterns) recipe.
491 |
492 | ```bash
493 | # Generate a `patterns.jsonl` file.
494 | python -m prodigy terms.to-patterns skateboard-tricks-list patterns.jsonl --label skateboard-tricks --spacy-model blank:en
495 | ```
496 |
497 | When the recipe is done, you'll have a `patterns.jsonl` file that has contents that look like this:
498 |
499 | ```json
500 | {"label":"skateboard-tricks","pattern":[{"lower":"pop"},{"lower":"shove"},{"lower":"it"}]}
501 | {"label":"skateboard-tricks","pattern":[{"lower":"switch"},{"lower":"flip"}]}
502 | {"label":"skateboard-tricks","pattern":[{"lower":"nose"},{"lower":"slides"}]}
503 | {"label":"skateboard-tricks","pattern":[{"lower":"lazerflip"}]}
504 | {"label":"skateboard-tricks","pattern":[{"lower":"lipslide"}]}
505 | ...
506 | ```
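
Since these entries use spaCy's token-pattern format, one way to try them out (a sketch, not part of these recipes) is to load them into an `EntityRuler`; you can also pass the file to recipes such as `ner.manual` via `--patterns`:

```python
# Sketch: load the generated patterns into a spaCy EntityRuler.
import spacy

nlp = spacy.blank("en")
ruler = nlp.add_pipe("entity_ruler")
ruler.from_disk("patterns.jsonl")

doc = nlp("Finally landed a lazerflip and a pop shove it today!")
print([(ent.text, ent.label_) for ent in doc.ents])
```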
507 |
508 | ### Known Limitations
509 |
510 | OpenAI has a hard limit on the prompt size. You cannot have a prompt larger than 4079 tokens. Unfortunately,
511 | that means that there is a limit to the size of term lists that you can generate. The recipe will report
512 | an error when this happens, but it's good to be aware of this limitation.
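
If you want to check how close a prompt is to that limit before sending it, one option (not used by these recipes, just an illustrative sketch) is the `tiktoken` library:

```python
# Sketch: count the tokens in a rendered prompt with tiktoken (pip install tiktoken).
import tiktoken

enc = tiktoken.encoding_for_model("text-davinci-003")
prompt = "Here are 100 examples of skateboard tricks: ..."
print(len(enc.encode(prompt)))  # keep this well below the model's context limit
```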
513 |
514 | ## 📋 Prompt A/B evaluation
515 |
516 | ### `ab.openai.prompts`: A/B evaluation of prompts
517 |
518 | The goal of this recipe is to quickly allow someone to compare the quality of outputs from two prompts
519 | in a quantifiable and blind way.
520 |
521 | ```bash
522 | python -m prodigy ab.openai.prompts dataset inputs_path display_template_path prompt1_template_path prompt2_template_path [--options] -F ./recipes/openai_ab.py
523 | ```
524 |
525 | | Argument | Type | Description | Default |
526 | |:----------------------:|-------|------------------------------------------------------|----------------------|
527 | | `dataset` | str | Prodigy dataset to save answers into | |
528 | | `inputs_path` | Path | Path to jsonl inputs | |
529 | | `display_template_path`| Path | Template for summarizing the arguments | |
530 | | `prompt1_template_path`| Path | Path to the first jinja2 prompt template | |
531 | | `prompt2_template_path`| Path | Path to the second jinja2 prompt template | |
532 | | `--model`, `-m` | str | GPT-3 model to use for completion | `"text-davinci-003"` |
533 | | `--batch-size`, `-b` | int | Batch size to send to OpenAI API | `10` |
534 | | `--verbose`,`-v` | bool | Print extra information to terminal | `False` |
535 | | `--no-random`,`-NR`    | bool  | Don't randomize the order in which the two responses are shown | `False`              |
536 | | `--repeat`, `-r` | int | How often to send the same prompt to OpenAI | `1` |
537 |
538 | #### Example usage
539 |
540 | As an example, let's try to generate humorous haikus. To do that, we first need to
541 | construct two jinja2 files, each representing a prompt to send to OpenAI.
542 |
543 | ##### `templates/ab/prompt1.jinja2`
544 |
545 | ```
546 | Write a haiku about {{topic}}.
547 | ```
548 |
549 | ##### `templates/ab/prompt2.jinja2`
550 |
551 | ```
552 | Write an incredibly hilarious haiku about {{topic}}. So funny!
553 | ```
554 |
555 | You can provide variables for these prompts by constructing a .jsonl file with the required
556 | parameters. In this case we need to make sure that `{{topic}}` is accounted for.
557 |
558 | Here's an example `.jsonl` file that could work.
559 |
560 | ##### `data/ab_example.jsonl`
561 |
562 | ```json
563 | {"id": 0, "prompt_args": {"topic": "star wars"}}
564 | {"id": 0, "prompt_args": {"topic": "kittens"}}
565 | {"id": 0, "prompt_args": {"topic": "the python programming language"}}
566 | {"id": 0, "prompt_args": {"topic": "maths"}}
567 | ```
568 |
569 | > **Note**
570 | >
571 | > All the arguments under `prompt_args` will be passed to render the jinja templates.
572 | > The `id` is mandatory and can be used to identify groups in later analysis.
573 |
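To make this concrete, here is roughly how each row's `prompt_args` end up in the templates (`recipes/openai_ab.py` calls `.render(**prompt_args)` on the loaded jinja2 templates):

```python
# Sketch: rendering a prompt template with prompt_args, as the recipe does internally.
import jinja2

template = jinja2.Template("Write a haiku about {{topic}}.")
print(template.render(topic="kittens"))  # -> Write a haiku about kittens.
```
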
574 | We're nearly ready to evaluate, but this recipe requires one final jinja2 template.
575 | This one won't be used to generate a prompt, but it will generate a useful title
576 | that reminds the annotator of the current task. Here's an example of such a template.
577 |
578 | ##### `templates/ab/input.jinja2`
579 |
580 | ```
581 | A haiku about {{topic}}.
582 | ```
583 |
584 | When you put all of these templates together you can start annotating. The command below
585 | starts the annotation interface and also uses the `--repeat 4` option. This ensures
586 | that each topic is used to generate a prompt at least 4 times.
587 |
588 | ```
589 | python -m prodigy ab.openai.prompts haiku data/ab_example.jsonl templates/ab/input.jinja2 templates/ab/prompt1.jinja2 templates/ab/prompt2.jinja2 --repeat 4 -F recipes/openai_ab.py
590 | ```
591 |
592 | This is what the annotation interface looks like:
593 |
594 | 
595 |
596 | When you look at this interface you'll notice that the title template is rendered and that
597 | you're able to pick from two options. Both options are responses from OpenAI that were
598 | generated by the two prompt templates. You can also see the `prompt_args` rendered in
599 | the lower right corner of the choice menu.
600 |
601 | From here you can annotate your favorite examples and gather data that might help you
602 | decide on which prompt is best.
603 |
604 | #### Results
605 |
606 | Once you're done annotating you'll be presented with an overview of the results.
607 |
608 | ```
609 | =========================== ✨ Evaluation results ===========================
610 | ✔ You preferred prompt1.jinja2
611 |
612 | prompt1.jinja2 11
613 | prompt2.jinja2 5
614 | ```
615 |
616 | But you can also fetch the raw annotations from the database for further analysis.
617 |
618 | ```
619 | python -m prodigy db-out haiku
620 | ```
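
For example, assuming you redirected the output to a file (`python -m prodigy db-out haiku > haiku_annotations.jsonl`), you can reproduce the tally yourself; this sketch mirrors the `on_exit()` logic in `recipes/openai_ab.py`:

```python
# Sketch: tally the winning template names, mirroring on_exit() in recipes/openai_ab.py.
from collections import Counter

import srsly

counts = Counter()
for eg in srsly.read_jsonl("haiku_annotations.jsonl"):
    selected = eg.get("accept", [])
    if eg.get("answer") == "accept" and len(selected) == 1:
        counts[selected[0]] += 1
print(counts.most_common())
```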
621 |
622 | ## ❓ What's next?
623 |
624 | There are lots of interesting follow-up experiments to this, and lots of ways to adapt the basic idea to different tasks or datasets. We're also interested in trying out different prompts. It's unclear how much the format in which the annotations are requested might change the model's predictions, or whether there's a shorter prompt that might perform just as well. We also want to run some end-to-end experiments.
625 |
--------------------------------------------------------------------------------
/data/ab_example.jsonl:
--------------------------------------------------------------------------------
1 | {"id": 0, "prompt_args": {"topic": "star wars"}}
2 | {"id": 0, "prompt_args": {"topic": "kittens"}}
3 | {"id": 0, "prompt_args": {"topic": "the python programming language"}}
4 | {"id": 0, "prompt_args": {"topic": "maths"}}
--------------------------------------------------------------------------------
/examples/ner.yaml:
--------------------------------------------------------------------------------
1 | -
2 | text: "You can't get a great chocolate flavor with carob."
3 |
4 | entities:
5 | Ingredient:
6 | - carob
7 |
8 | -
9 | text: "You can probably sand-blast it if it's an anodized aluminum pan, which it probably is."
10 | entities:
11 | Ingredient: []
12 | Equipment: ["anodized aluminum pan"]
13 |
--------------------------------------------------------------------------------
/examples/textcat_binary.yaml:
--------------------------------------------------------------------------------
1 | - text: "This is a recipe for scrambled egg: 2 eggs, 1 butter, batter them, and then fry in a hot pan for 2 minutes"
2 | answer: "accept"
3 | reason: "This is a recipe for making a scrambled egg"
4 | - text: "This is a recipe for fried rice: 1 cup of day old rice, 1 butter, 2 cloves of garlic: put them all in a wok and stir them together."
5 | answer: "accept"
6 | reason: "This is a recipe for making a fried rice"
7 | - text: "I tried it and it's not good"
8 | answer: "reject"
9 | reason: "It doesn't talk about a recipe."
10 |
--------------------------------------------------------------------------------
/examples/textcat_multilabel.yaml:
--------------------------------------------------------------------------------
1 | - text: "Can someone try this recipe?"
2 | answer: "question"
3 | reason: "It is a question about trying a recipe."
4 | - text: "1 cup of rice then egg and then mix them well. Should I add garlic last?"
5 | answer: "question,recipe"
6 | reason: "It is a question about the steps in making a fried rice."
7 |
--------------------------------------------------------------------------------
/recipes/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explosion/prodigy-openai-recipes/a0866f0d05f0a744cab15ad202201c5b691d48b9/recipes/__init__.py
--------------------------------------------------------------------------------
/recipes/openai.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import copy
3 | import os
4 | import sys
5 | import time
6 | from dataclasses import dataclass
7 | from pathlib import Path
8 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar
9 |
10 | import httpx
11 | import jinja2
12 | import rich
13 | import srsly
14 | from prodigy.components import preprocess
15 | from prodigy.util import msg, set_hashes
16 | from rich.panel import Panel
17 | from spacy.language import Language
18 |
19 | _ItemT = TypeVar("_ItemT")
20 | _PromptT = TypeVar("_PromptT", bound="PromptExample")
21 |
22 |
23 | @dataclass
24 | class PromptExample(abc.ABC):
25 | """An example to be passed into an OpenAI prompt.
26 | When inheriting this dataclass, you should implement the `from_prodigy`
27 | function that takes in a Prodigy task example and formats it back
28 | into a dataclass that can fill a prompt template.
29 | You can refer to Prodigy's API Interfaces documentation
30 | (https://prodi.gy/docs/api-interfaces) to see how most examples are structured
31 | for each task.
32 | """
33 |
34 | @staticmethod
35 | def is_flagged(example: _ItemT) -> bool:
36 | """Check whether a Prodigy example is flagged for use
37 | in the prompt."""
38 |
39 | return (
40 | example.get("flagged") is True
41 | and example.get("answer") == "accept"
42 | and "text" in example
43 | )
44 |
45 | @classmethod
46 | def from_prodigy(cls, example: _ItemT, labels: Iterable[str]) -> "PromptExample":
47 | """Create a prompt example from Prodigy's format."""
48 | ...
49 |
50 |
51 | def normalize_label(label: str) -> str:
52 | return label.lower()
53 |
54 |
55 | class OpenAISuggester:
56 | """Suggest annotations using OpenAI's GPT-3
57 | prompt_template (jinja2.Template): A Jinja2 template that contains the
58 | prompt to send to OpenAI's GPT-3 API.
59 | model (str): The GPT-3 model ID to use for completion. Check the OpenAI
60 | documentation for more information https://beta.openai.com/docs/models/overview.
61 | labels (List[str]): List of labels for annotation.
62 |     max_examples (int): The maximum number of few-shot examples to include in the prompt to OpenAI.
63 | segment (bool): If set to True, segment the documents into sentences.
64 | verbose (bool): Show verbose output in the command-line, including the prompt and response from OpenAI.
65 | openai_api_org (str): The OpenAI API organization.
66 | openai_api_key (str): The OpenAI API key.
67 | openai_temperature (float): The temperature parameter (from 0 to 1) that controls the
68 | randomness of GPT-3's output.
69 |     openai_max_tokens (int): The maximum number of tokens that GPT-3's
70 | completion API can generate.
71 | openai_n (int): The number of completions to generate for each prompt.
72 | openai_n_retries (int): The number of retries whenever a 429 error occurs.
73 | openai_retry_timeout_s (int): The amount of time before attempting another request whenever we
74 | encounter a 429 error. Increases exponentially for each retry.
75 |     openai_read_timeout_s (int): The amount of time to wait for a response during a request.
76 | examples (List[PromptExample]): A list of examples to add to the prompt to guide GPT-3 output.
77 | response_parser (Callable[str] -> Dict): A function that accepts a string that represents
78 | GPT-3's raw response, and parses it into a dictionary that is compatible to Prodigy's
79 | annotation interfaces.
80 | render_vars (Dict[str, Any]): A dictionary containing additional variables to render in the
81 | Jinja2 template. By default, the Jinja2 template will render the text (str), some labels (List[str]),
82 | and examples (PromptExample). If you wish to add other task-specific variables, you should supply
83 | them in this variable.
84 | """
85 |
86 | prompt_template: jinja2.Template
87 | model: str
88 | labels: List[str]
89 | max_examples: int
90 | segment: bool
91 | verbose: bool
92 | openai_api_org: str
93 | openai_api_key: str
94 | openai_temperature: float
95 | openai_max_tokens: int
96 | openai_retry_timeout_s: int
97 | openai_read_timeout_s: int
98 | openai_n_retries: int
99 | openai_n: int
100 | examples: List[PromptExample]
101 | response_parser: Callable
102 | render_vars: Dict[str, Any]
103 | prompt_example_class: PromptExample
104 |
105 | OPENAI_COMPLETIONS_ENDPOINT: str = "https://api.openai.com/v1/completions"
106 | RETRY_ERROR_CODES: List[int] = [429, 503]
107 |
108 | def __init__(
109 | self,
110 | prompt_template: jinja2.Template,
111 | *,
112 | labels: List[str],
113 | max_examples: int,
114 | segment: bool,
115 | openai_api_org: str,
116 | openai_api_key: str,
117 | openai_model: str,
118 | response_parser: Callable,
119 | prompt_example_class: PromptExample,
120 | openai_temperature: int = 0,
121 | openai_max_tokens: int = 500,
122 | openai_retry_timeout_s: int = 1,
123 | openai_read_timeout_s: int = 30,
124 | openai_n_retries: int = 10,
125 | openai_n: int = 1,
126 | render_vars: Optional[Dict[str, Any]] = None,
127 | verbose: bool = False,
128 | ):
129 | self.prompt_template = prompt_template
130 | self.model = openai_model
131 | self.labels = [normalize_label(label) for label in labels]
132 | self.max_examples = max_examples
133 | self.verbose = verbose
134 | self.segment = segment
135 | self.examples = []
136 | self.openai_api_org = openai_api_org
137 | self.openai_api_key = openai_api_key
138 | self.openai_temperature = openai_temperature
139 | self.openai_max_tokens = openai_max_tokens
140 | self.openai_retry_timeout_s = openai_retry_timeout_s
141 | self.openai_read_timeout_s = openai_read_timeout_s
142 | self.openai_n = openai_n
143 | self.openai_n_retries = openai_n_retries
144 | self.response_parser = response_parser
145 | self.prompt_example_class = prompt_example_class
146 | self.render_vars = {} if render_vars is None else render_vars
147 |
148 | def __call__(
149 | self,
150 | stream: Iterable[_ItemT],
151 | *,
152 | nlp: Language,
153 | batch_size: int,
154 | **kwargs,
155 | ) -> Iterable[_ItemT]:
156 | if self.segment:
157 | stream = preprocess.split_sentences(nlp, stream) # type: ignore
158 |
159 | stream = self.pipe(stream, nlp, batch_size, **kwargs)
160 | stream = self.set_hashes(stream)
161 | return stream
162 |
163 | def pipe(
164 | self, stream: Iterable[_ItemT], nlp: Language, batch_size: int, **kwargs
165 | ) -> Iterable[_ItemT]:
166 | """Process the stream and add suggestions from OpenAI."""
167 | stream = self.stream_suggestions(stream, batch_size)
168 | stream = self.format_suggestions(stream, nlp=nlp)
169 | return stream
170 |
171 | def set_hashes(self, stream: Iterable[_ItemT]) -> Iterable[_ItemT]:
172 | for example in stream:
173 | yield set_hashes(example)
174 |
175 | def update(self, examples: Iterable[_ItemT]) -> float:
176 | """Update the examples that will be used in the prompt based on user flags."""
177 | for eg in examples:
178 | if PromptExample.is_flagged(eg):
179 | self.add_example(
180 | self.prompt_example_class.from_prodigy(eg, self.labels)
181 | )
182 | return 0.0
183 |
184 | def add_example(self, example: PromptExample) -> None:
185 | """Add an example for use in the prompts. Examples are pruned to the most recent max_examples."""
186 | if self.max_examples and example:
187 | self.examples.append(example)
188 | if len(self.examples) > self.max_examples:
189 | self.examples = self.examples[-self.max_examples :]
190 |
191 | def stream_suggestions(
192 | self, stream: Iterable[_ItemT], batch_size: int
193 | ) -> Iterable[_ItemT]:
194 | """Get zero-shot or few-shot annotations from OpenAI.
195 | Given a stream of input examples, we define a prompt, get a response from OpenAI,
196 | and yield each example with their predictions to the output stream.
197 | """
198 | for batch in batch_sequence(stream, batch_size):
199 | prompts = [
200 | self._get_prompt(eg["text"], labels=self.labels, examples=self.examples)
201 | for eg in batch
202 | ]
203 | responses = self._get_openai_response(prompts)
204 | for eg, prompt, response in zip(batch, prompts, responses):
205 | if self.verbose:
206 | rich.print(Panel(prompt, title="Prompt to OpenAI"))
207 | eg["openai"] = {"prompt": prompt, "response": response}
208 | if self.verbose:
209 | rich.print(Panel(response, title="Response from OpenAI"))
210 | yield eg
211 |
212 | def format_suggestions(
213 | self, stream: Iterable[_ItemT], *, nlp: Language
214 | ) -> Iterable[_ItemT]:
215 | """Parse the examples in the stream and set up labels
216 | to display in the Prodigy UI.
217 | """
218 | stream = preprocess.add_tokens(nlp, stream, skip=True) # type: ignore
219 | for example in stream:
220 | example = copy.deepcopy(example)
221 | if "meta" not in example:
222 | example["meta"] = {}
223 |
224 | response = example["openai"].get("response", "")
225 | example.update(self.response_parser(response, example))
226 | yield example
227 |
228 | def _get_prompt(
229 | self, text: str, labels: List[str], examples: List[PromptExample]
230 | ) -> str:
231 | """Generate a prompt for GPT-3 OpenAI."""
232 | return self.prompt_template.render(
233 | text=text, labels=labels, examples=examples, **self.render_vars
234 | )
235 |
236 | def _get_openai_response(self, prompts: List[str]) -> List[str]:
237 | headers = {
238 | "Authorization": f"Bearer {self.openai_api_key}",
239 | "OpenAI-Organization": self.openai_api_org,
240 | "Content-Type": "application/json",
241 | }
242 | r = retry(
243 | lambda: httpx.post(
244 | self.OPENAI_COMPLETIONS_ENDPOINT,
245 | headers=headers,
246 | json={
247 | "model": self.model,
248 | "prompt": prompts,
249 | "temperature": self.openai_temperature,
250 | "max_tokens": self.openai_max_tokens,
251 | "n": self.openai_n,
252 | },
253 | timeout=self.openai_read_timeout_s,
254 | ),
255 | n=self.openai_n_retries,
256 | timeout_s=self.openai_retry_timeout_s,
257 | error_codes=self.RETRY_ERROR_CODES,
258 | )
259 | r.raise_for_status()
260 | responses = r.json()
261 | return [responses["choices"][i]["text"] for i in range(len(prompts))]
262 |
263 |
264 | def get_api_credentials(model: Optional[str] = None) -> Tuple[str, str]:
265 | # Fetch and check the key
266 | api_key = os.getenv("OPENAI_KEY")
267 | if api_key is None:
268 | m = (
269 | "Could not find the API key to access the openai API. Ensure you have an API key "
270 | "set up via https://beta.openai.com/account/api-keys, then make it available as "
271 | "an environment variable 'OPENAI_KEY', for instance in a .env file."
272 | )
273 | msg.fail(m)
274 | sys.exit(-1)
275 | # Fetch and check the org
276 | org = os.getenv("OPENAI_ORG")
277 | if org is None:
278 | m = (
279 | "Could not find the organisation to access the openai API. Ensure you have an API key "
280 | "set up via https://beta.openai.com/account/api-keys, obtain its organization ID 'org-XXX' "
281 | "via https://beta.openai.com/account/org-settings, then make it available as "
282 | "an environment variable 'OPENAI_ORG', for instance in a .env file."
283 | )
284 | msg.fail(m)
285 | sys.exit(-1)
286 |
287 | # Check the access and get a list of available models to verify the model argument (if not None)
288 | # Even if the model is None, this call is used as a healthcheck to verify access.
289 | headers = {
290 | "Authorization": f"Bearer {api_key}",
291 | "OpenAI-Organization": org,
292 | }
293 | r = retry(
294 | lambda: httpx.get(
295 | "https://api.openai.com/v1/models",
296 | headers=headers,
297 | ),
298 | n=1,
299 | timeout_s=1,
300 | )
301 | if r.status_code == 422:
302 | m = (
303 | "Could not access api.openai.com -- 422 permission denied."
304 | "Visit https://beta.openai.com/account/api-keys to check your API keys."
305 | )
306 | msg.fail(m)
307 | sys.exit(-1)
308 | elif r.status_code != 200:
309 |         m = f"Error accessing api.openai.com: {r.status_code}: {r.text}"
310 | msg.fail(m)
311 | sys.exit(-1)
312 |
313 | if model is not None:
314 | response = r.json()["data"]
315 | models = [response[i]["id"] for i in range(len(response))]
316 | if model not in models:
317 | e = f"The specified model '{model}' is not available. Choices are: {sorted(set(models))}"
318 | msg.fail(e, exits=1)
319 |
320 | return api_key, org
321 |
322 |
323 | def read_prompt_examples(
324 | path: Optional[Path], *, example_class: Type[PromptExample]
325 | ) -> List[PromptExample]:
326 | if path is None:
327 | return []
328 | elif path.suffix in (".yml", ".yaml"):
329 | return read_yaml_examples(path, example_class=example_class)
330 | elif path.suffix == ".json":
331 | data = srsly.read_json(path)
332 | assert isinstance(data, list)
333 | return [PromptExample(**eg) for eg in data]
334 | else:
335 | msg.fail(
336 | "The --examples-path (-e) parameter expects a .yml, .yaml or .json file."
337 | )
338 | sys.exit(-1)
339 |
340 |
341 | def load_template(path: Path) -> jinja2.Template:
342 | # I know jinja has a lot of complex file loading stuff,
343 | # but we're not using the inheritance etc that makes
344 | # that stuff worthwhile.
345 | if not path.suffix == ".jinja2":
346 | msg.fail(
347 | "The --prompt-path (-p) parameter expects a .jinja2 file.",
348 | exits=1,
349 | )
350 | with path.open("r", encoding="utf8") as file_:
351 | text = file_.read()
352 | return jinja2.Template(text, undefined=jinja2.DebugUndefined)
353 |
354 |
355 | def retry(
356 | call_api: Callable[[], httpx.Response],
357 | n: int,
358 | timeout_s: int,
359 | error_codes: List[int] = [429, 503],
360 | ) -> httpx.Response:
361 | """Retry a call to the OpenAI API if we get a non-ok status code.
362 | This function automatically retries a request if it catches a response
363 | with an error code in `error_codes`. The amount of timeout also increases
364 | exponentially every time we retry.
365 | """
366 | assert n >= 0
367 | assert timeout_s >= 1
368 | r = call_api()
369 | i = -1
370 | # We don't want to retry on every non-ok status code. Some are about
371 | # incorrect inputs, etc. and we want to terminate on those.
372 | while i < n and r.status_code in error_codes:
373 | time.sleep(timeout_s)
374 | i += 1
375 | timeout_s = timeout_s * 2 # Increase timeout everytime you retry
376 | msg.text(
377 | f"Retrying call (retries left: {n-i}, timeout: {timeout_s}s). "
378 | f"Previous call returned: {r.status_code} ({r.reason_phrase})"
379 | )
380 | return r
381 |
382 |
383 | def read_yaml_examples(
384 | path: Path, *, example_class: Type[PromptExample]
385 | ) -> List[PromptExample]:
386 | data = srsly.read_yaml(path)
387 | if not isinstance(data, list):
388 | msg.fail("Cannot interpret prompt examples from yaml", exits=True)
389 | assert isinstance(data, list)
390 | output = [example_class(**eg) for eg in data]
391 | return output
392 |
393 |
394 | def batch_sequence(items: Iterable[_ItemT], batch_size: int) -> Iterable[List[_ItemT]]:
395 | batch = []
396 | for eg in items:
397 | batch.append(eg)
398 | if len(batch) >= batch_size:
399 | yield batch
400 | batch = []
401 | if batch:
402 | yield batch
403 |
--------------------------------------------------------------------------------
/recipes/openai_ab.py:
--------------------------------------------------------------------------------
1 | """A/B evaluation of OpenAI responses, for prompt engineering.
2 |
3 | A/B evaluation is basically a blind "taste test". Results are produced
4 | from two experimental conditions (in this case, responses from two
5 | prompts). The annotator is shown the response pair, without knowing which
6 | condition produced which response. The annotator marks which one is better,
7 | and the response is recorded.
8 |
9 | At the end of annotation, the results are tallied up, so you can see whether
10 | one condition produces better results than the other. This lets you apply
11 | a sound methodology to subjective decisions.
12 | """
13 | from collections import Counter
14 | import os
15 | import random
16 | import sys
17 | import time
18 | from pathlib import Path
19 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, cast
20 |
21 | import httpx
22 | import jinja2
23 | import rich
24 | from rich.panel import Panel
25 |
26 | import prodigy
27 | import pydantic
28 | import srsly
29 | from dotenv import load_dotenv
30 | from prodigy.util import msg
31 |
32 | _ItemT = TypeVar("_ItemT")
33 | # Set up openai
34 | load_dotenv() # take environment variables from .env.
35 |
36 |
37 | class PromptInput(pydantic.BaseModel):
38 | id: str
39 | prompt_args: Dict[str, Any]
40 |
41 |
42 | class OpenAIPromptAB:
43 | display: jinja2.Template
44 | prompts: Dict[str, jinja2.Template]
45 | inputs: Iterable[PromptInput]
46 | batch_size: int
47 | verbose: bool
48 | randomize: bool
49 | openai_api_org: str
50 | openai_api_key: str
51 | openai_temperature: float
52 | openai_max_tokens: int
53 | openai_timeout_s: int
54 | openai_n: int
55 |
56 | def __init__(
57 | self,
58 | display: jinja2.Template,
59 | prompts: Dict[str, jinja2.Template],
60 | inputs: Iterable[PromptInput],
61 | *,
62 | openai_api_org: str,
63 | openai_api_key: str,
64 | openai_model: str,
65 | batch_size: int = 10,
66 | verbose: bool = False,
67 | randomize: bool = True,
68 | openai_temperature: float = 0,
69 | openai_max_tokens: int = 500,
70 | openai_timeout_s: int = 50,
71 | openai_n: int = 1,
72 | repeat: int = 3
73 | ):
74 | self.display = display
75 | self.inputs = inputs
76 | self.prompts = prompts
77 | self.model = openai_model
78 | self.batch_size = batch_size
79 | self.verbose = verbose
80 | self.openai_api_org = openai_api_org
81 | self.openai_api_key = openai_api_key
82 | self.openai_temperature = openai_temperature
83 | self.openai_max_tokens = openai_max_tokens
84 | self.openai_timeout_s = openai_timeout_s
85 | self.openai_n = openai_n
86 | self.randomize = randomize
87 | self.repeat = repeat
88 |
89 | def __iter__(self) -> Iterable[Dict]:
90 | for input_batch in _batch_sequence(self.inputs, self.batch_size):
91 | for _ in range(self.repeat):
92 | response_batch = self._get_response_batch(input_batch)
93 | for input_, responses in zip(input_batch, response_batch):
94 | yield self._make_example(
95 | input_.id,
96 | self.display.render(**input_.prompt_args),
97 | responses,
98 | randomize=self.randomize,
99 | prompt_args=input_.prompt_args
100 | )
101 |
102 | def on_exit(self, ctrl):
103 | examples = ctrl.db.get_dataset_examples(ctrl.dataset)
104 | counts = Counter({k: 0 for k in self.prompts.keys()})
105 | # Get last example per ID
106 | for eg in examples:
107 | selected = eg.get("accept", [])
108 | if not selected or len(selected) != 1 or eg["answer"] != "accept":
109 | continue
110 | counts[selected[0]] += 1
111 | print("")
112 | if sum(counts.values()) == 0:
113 | msg.warn("No answers found", exits=0)
114 | msg.divider("Evaluation results", icon="emoji")
115 | # Handle edge case when both are equal:
116 | nr1, nr2 = counts.most_common(2)
117 | if nr1[1] == nr2[1]:
118 | msg.good("It's a draw!")
119 | else:
120 | pref, _ = nr1
121 | msg.good(f"You preferred {pref}")
122 | rows = [(name, count) for name, count in counts.most_common()]
123 | msg.table(rows, aligns=("l", "r"))
124 |
125 | def _get_response_batch(self, inputs: List[PromptInput]) -> List[Dict[str, str]]:
126 | name1, name2 = self._choose_rivals()
127 | prompts = []
128 | for input_ in inputs:
129 | prompts.append(self._get_prompt(name1, input_.prompt_args))
130 | prompts.append(self._get_prompt(name2, input_.prompt_args))
131 | if self.verbose:
132 | for prompt in prompts:
133 | rich.print(Panel(prompt, title="Prompt to OpenAI"))
134 | responses = self._get_responses(prompts)
135 | assert len(responses) == len(inputs) * 2
136 | output = []
137 | # Pair out the responses. There's a fancy
138 | # zip way to do this but I think that's less
139 | # readable
140 | for i in range(0, len(responses), 2):
141 | output.append({name1: responses[i], name2: responses[i + 1]})
142 | return output
143 |
144 | def _choose_rivals(self) -> Tuple[str, str]:
145 | assert len(self.prompts) == 2
146 | return tuple(sorted(self.prompts.keys()))
147 |
148 | def _get_prompt(self, name: str, args: Dict) -> str:
149 | return self.prompts[name].render(**args)
150 |
151 | def _get_responses(self, prompts: List[str]) -> List[str]:
152 | headers = {
153 | "Authorization": f"Bearer {self.openai_api_key}",
154 | "OpenAI-Organization": self.openai_api_org,
155 | "Content-Type": "application/json",
156 | }
157 | r = _retry429(
158 | lambda: httpx.post(
159 | "https://api.openai.com/v1/completions",
160 | headers=headers,
161 | json={
162 | "model": self.model,
163 | "prompt": prompts,
164 | "temperature": self.openai_temperature,
165 | "max_tokens": self.openai_max_tokens,
166 | },
167 | timeout=self.openai_timeout_s,
168 | ),
169 | n=self.openai_n,
170 | timeout_s=self.openai_timeout_s,
171 | )
172 | r.raise_for_status()
173 | responses = r.json()
174 | return [responses["choices"][i]["text"].strip() for i in range(len(prompts))]
175 |
176 | def _make_example(
177 | self, id_str: str, input: str, responses: Dict[str, str], randomize: bool, prompt_args: Dict[str, Any]
178 | ) -> Dict:
179 |
180 | question = {
181 | "id": id_str,
182 | "text": input,
183 | "options": [],
184 | }
185 | response_pairs = list(responses.items())
186 | if randomize:
187 | random.shuffle(response_pairs)
188 | else:
189 | response_pairs = list(sorted(response_pairs))
190 | for name, value in response_pairs:
191 | question["options"].append({"id": name, "text": value})
192 | question["meta"] = prompt_args
193 | return question
194 |
195 |
196 | @prodigy.recipe(
197 | # fmt: off
198 | "ab.openai.prompts",
199 | dataset=("Dataset to save answers to", "positional", None, str),
200 | inputs_path=("Path to jsonl inputs", "positional", None, Path),
201 | display_template_path=("Template for summarizing the arguments","positional", None, Path),
202 | prompt1_template_path=("Path to the first jinja2 prompt template","positional", None, Path),
203 | prompt2_template_path=("Path to the second jinja2 prompt template","positional", None, Path),
204 | model=("GPT-3 model to use for responses", "option", "m", str),
205 | batch_size=("Batch size to send to OpenAI API", "option", "b", int),
206 | verbose=("Print extra information to terminal", "flag", "v", bool),
207 | no_random=("Don't randomize the order in which the two responses are shown","flag","NR",bool,),
208 | repeat=("How often to send the same prompt to OpenAI", "option", "r", int)
209 | # fmt: on
210 | )
211 | def ab_openai_prompts(
212 | dataset: str,
213 | inputs_path: Path,
214 | display_template_path: Path,
215 | prompt1_template_path: Path,
216 | prompt2_template_path: Path,
217 | model: str = "text-davinci-003",
218 | batch_size: int = 10,
219 | verbose: bool = False,
220 | no_random: bool = False,
221 | repeat: int = 1,
222 | ):
223 | api_key, api_org = _get_api_credentials(model)
224 | inputs = [PromptInput(**x) for x in cast(List[Dict], srsly.read_jsonl(inputs_path))]
225 |
226 | display = _load_template(display_template_path)
227 | prompt1 = _load_template(prompt1_template_path)
228 | prompt2 = _load_template(prompt2_template_path)
229 | stream = OpenAIPromptAB(
230 | display=display,
231 | prompts={
232 | prompt1_template_path.name: prompt1,
233 | prompt2_template_path.name: prompt2,
234 | },
235 | inputs=inputs,
236 | openai_api_org=api_org,
237 | openai_api_key=api_key,
238 | openai_model=model,
239 | batch_size=batch_size,
240 | verbose=verbose,
241 | randomize=not no_random,
242 | openai_temperature=0.9,
243 | repeat=repeat
244 | )
245 | return {
246 | "dataset": dataset,
247 | "view_id": "choice",
248 | "stream": stream,
249 | "on_exit": stream.on_exit,
250 | "config": {
251 | "batch_size": batch_size,
252 | "choice_auto_accept": True,
253 | "exclude_by": "input",
254 | "global_css": ".prodigy-content{line-height: 1.2;};"
255 | },
256 | }
257 |
258 | def _get_api_credentials(model: str) -> Tuple[str, str]:
259 | # Fetch and check the key
260 | api_key = os.getenv("OPENAI_KEY")
261 | if api_key is None:
262 | m = (
263 | "Could not find the API key to access the openai API. Ensure you have an API key "
264 | "set up via https://beta.openai.com/account/api-keys, then make it available as "
265 | "an environment variable 'OPENAI_KEY', for instance in a .env file."
266 | )
267 | msg.fail(m)
268 | sys.exit(-1)
269 | # Fetch and check the org
270 | org = os.getenv("OPENAI_ORG")
271 | if org is None:
272 | m = (
273 | "Could not find the organisation to access the openai API. Ensure you have an API key "
274 | "set up via https://beta.openai.com/account/api-keys, obtain its organization ID 'org-XXX' "
275 | "via https://beta.openai.com/account/org-settings, then make it available as "
276 | "an environment variable 'OPENAI_ORG', for instance in a .env file."
277 | )
278 | msg.fail(m)
279 | sys.exit(-1)
280 |
281 | # Check the access and get a list of available models to verify the model argument (if not None)
282 | # Even if the model is None, this call is used as a healthcheck to verify access.
283 | headers = {
284 | "Authorization": f"Bearer {api_key}",
285 | "OpenAI-Organization": org,
286 | }
287 | r = _retry429(
288 | lambda: httpx.get(
289 | "https://api.openai.com/v1/models",
290 | headers=headers,
291 | ),
292 | n=1,
293 | timeout_s=1,
294 | )
295 | if r.status_code == 422:
296 | m = (
297 | "Could not access api.openai.com -- 422 permission denied."
298 | "Visit https://beta.openai.com/account/api-keys to check your API keys."
299 | )
300 | msg.fail(m)
301 | sys.exit(-1)
302 | elif r.status_code != 200:
303 | m = "Error accessing api.openai.com" f"{r.status_code}: {r.text}"
304 | msg.fail(m)
305 | sys.exit(-1)
306 |
307 | if model is not None:
308 | response = r.json()["data"]
309 | models = [response[i]["id"] for i in range(len(response))]
310 | if model not in models:
311 | e = f"The specified model '{model}' is not available. Choices are: {sorted(set(models))}"
312 | msg.fail(e, exits=1)
313 |
314 | return api_key, org
315 |
316 |
317 | def _load_template(path: Path) -> jinja2.Template:
318 | # I know jinja has a lot of complex file loading stuff,
319 | # but we're not using the inheritance etc that makes
320 | # that stuff worthwhile.
321 | if not path.suffix == ".jinja2":
322 | msg.fail(
323 | f"The parameter expects a .jinja2 file. Received path='{path}'",
324 | exits=1,
325 | )
326 | with path.open("r", encoding="utf8") as file_:
327 | text = file_.read()
328 | return jinja2.Template(text)
329 |
330 |
331 | def _retry429(
332 | call_api: Callable[[], httpx.Response], n: int, timeout_s: int
333 | ) -> httpx.Response:
334 | """Retry a call to the OpenAI API if we get a 429: Too many requests
335 | error.
336 | """
337 | assert n >= 0
338 | assert timeout_s >= 1
339 | r = call_api()
340 | i = -1
341 | while i < n and r.status_code == 429:
342 | time.sleep(timeout_s)
343 | i += 1
344 | r = call_api()  # Retry the request after waiting
345 | return r
346 |
347 | def _batch_sequence(items: Iterable[_ItemT], batch_size: int) -> Iterable[List[_ItemT]]:
348 | batch = []
349 | for eg in items:
350 | batch.append(eg)
351 | if len(batch) >= batch_size:
352 | yield batch
353 | batch = []
354 | if batch:
355 | yield batch
356 |
--------------------------------------------------------------------------------
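For context, each line in the inputs file for `ab.openai.prompts` must parse into a `PromptInput`: an `id` plus the `prompt_args` that get rendered into the display and prompt templates. A minimal sketch, assuming the haiku templates under `templates/ab/` (the `inputs.jsonl` file name is just an example):

```python
import srsly

# One record per annotation input; prompt_args feed the {{topic}} variable
# used by templates/ab/input.jinja2, prompt1.jinja2 and prompt2.jinja2.
inputs = [
    {"id": "cats", "prompt_args": {"topic": "cats"}},
    {"id": "winter", "prompt_args": {"topic": "winter"}},
]
srsly.write_jsonl("inputs.jsonl", inputs)

# Illustrative invocation (dataset name and paths are placeholders):
# prodigy ab.openai.prompts my_ab_dataset inputs.jsonl templates/ab/input.jinja2 \
#     templates/ab/prompt1.jinja2 templates/ab/prompt2.jinja2 -F recipes/openai_ab.py
```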
/recipes/openai_ner.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from pathlib import Path
3 | from typing import Dict, Iterable, List, Optional, Tuple, cast, Callable
4 | from dataclasses import dataclass
5 |
6 | import prodigy
7 | import prodigy.components.db
8 | import prodigy.components.preprocess
9 | import prodigy.util
10 | import spacy
11 | import srsly
12 | import tqdm
13 | from dotenv import load_dotenv
14 | from spacy.language import Language
15 | from spacy.util import filter_spans
16 |
17 | from recipes.openai import OpenAISuggester, PromptExample, get_api_credentials
18 | from recipes.openai import load_template, normalize_label
19 | from recipes.openai import read_prompt_examples
20 |
21 | CSS_FILE_PATH = Path(__file__).parent / "style.css"
22 | DEFAULT_PROMPT_PATH = Path(__file__).parent.parent / "templates" / "ner_prompt.jinja2"
23 | HTML_TEMPLATE = """<div class="cleaned">
24 | <details>
25 | <summary>Show the prompt for OpenAI</summary>
26 | <pre>{{openai.prompt}}</pre>
27 | </details>
28 | <details>
29 | <summary>Show the response from OpenAI</summary>
30 | <pre>{{openai.response}}</pre>
31 | </details>
32 | </div>
33 | 
34 | """
35 |
36 | # Set up openai
37 | load_dotenv() # take environment variables from .env.
38 |
39 |
40 | @dataclass
41 | class NERPromptExample(PromptExample):
42 | """An example to be passed into an OpenAI NER prompt"""
43 |
44 | text: str
45 | entities: Dict[str, List[str]]
46 |
47 | @classmethod
48 | def from_prodigy(cls, example: Dict, labels: Iterable[str]) -> "PromptExample":
49 | """Create a prompt example from Prodigy's format.
50 | Only entities with a label from the given set will be retained.
51 | The given set of labels is assumed to be already normalized.
52 | """
53 | if "text" not in example:
54 | raise ValueError("Cannot make PromptExample without text")
55 | entities_by_label = defaultdict(list)
56 | full_text = example["text"]
57 | for span in example.get("spans", []):
58 | label = normalize_label(span["label"])
59 | if label in labels:
60 | mention = full_text[int(span["start"]) : int(span["end"])]
61 | entities_by_label[label].append(mention)
62 |
63 | return cls(text=full_text, entities=entities_by_label)
64 |
65 |
66 | def make_ner_response_parser(labels: List[str], lang: str) -> Callable:
67 |
68 | nlp = spacy.blank(lang)
69 |
70 | def _parse_response(text: str, example: Optional[Dict] = None) -> Dict:
71 | """Interpret OpenAI's NER response. It's supposed to be
72 | a list of lines, with each line having the form:
73 | Label: phrase1, phrase2, ...
74 | However, there's no guarantee that the model will give
75 | us well-formed output. It could say anything, it's an LM.
76 | So we need to be robust.
77 | """
78 | output = []
79 | for line in text.strip().split("\n"):
80 | if line and ":" in line:
81 | label, phrases = line.split(":", 1)
82 | label = normalize_label(label)
83 | if label in labels:
84 | if phrases.strip():
85 | phrases = [
86 | phrase.strip() for phrase in phrases.strip().split(",")
87 | ]
88 | output.append((label, phrases))
89 |
90 | example = _fmt_response(output, example)
91 | return example
92 |
93 | def _fmt_response(response: List[Tuple[str, List[str]]], example: Dict):
94 | doc = nlp.make_doc(example["text"])
95 | spacy_spans = []
96 | for label, phrases in response:
97 | label = normalize_label(label)
98 | if label in labels:
99 | offsets = _find_substrings(doc.text, phrases)
100 | for start, end in offsets:
101 | span = doc.char_span(
102 | start, end, alignment_mode="contract", label=label
103 | )
104 | if span is not None:
105 | spacy_spans.append(span)
106 | # This step prevents the same token from being used in multiple spans.
107 | # If there's a conflict, the longer span is preserved.
108 | spacy_spans = filter_spans(spacy_spans)
109 | spans = [
110 | {
111 | "label": span.label_,
112 | "start": span.start_char,
113 | "end": span.end_char,
114 | "token_start": span.start,
115 | "token_end": span.end - 1,
116 | }
117 | for span in spacy_spans
118 | ]
119 | return {"spans": spans}
120 |
121 | return _parse_response
122 |
123 |
124 | @prodigy.recipe(
125 | "ner.openai.correct",
126 | dataset=("Dataset to save answers to", "positional", None, str),
127 | filepath=("Path to jsonl data to annotate", "positional", None, Path),
128 | labels=("Labels (comma delimited)", "positional", None, lambda s: s.split(",")),
129 | model=("GPT-3 model to use for initial predictions", "option", "m", str),
130 | examples_path=("Path to examples to help define the task", "option", "e", Path),
131 | lang=("Language to use for tokenizer", "option", "l", str),
132 | max_examples=("Max examples to include in prompt", "option", "n", int),
133 | prompt_path=("Path to jinja2 prompt template", "option", "p", Path),
134 | batch_size=("Batch size to send to OpenAI API", "option", "b", int),
135 | segment=("Split articles into sentences", "flag", "S", bool),
136 | verbose=("Print extra information to terminal", "flag", "v", bool),
137 | )
138 | def ner_openai_correct(
139 | dataset: str,
140 | filepath: Path,
141 | labels: List[str],
142 | lang: str = "en",
143 | model: str = "text-davinci-003",
144 | batch_size: int = 10,
145 | segment: bool = False,
146 | examples_path: Optional[Path] = None,
147 | prompt_path: Path = DEFAULT_PROMPT_PATH,
148 | max_examples: int = 2,
149 | verbose: bool = False,
150 | ):
151 | examples = read_prompt_examples(examples_path, example_class=NERPromptExample)
152 | nlp = spacy.blank(lang)
153 | if segment:
154 | nlp.add_pipe("sentencizer")
155 | api_key, api_org = get_api_credentials(model)
156 | openai = OpenAISuggester(
157 | response_parser=make_ner_response_parser(labels=labels, lang=lang),
158 | openai_model=model,
159 | labels=labels,
160 | max_examples=max_examples,
161 | prompt_template=load_template(prompt_path),
162 | segment=segment,
163 | verbose=verbose,
164 | openai_api_org=api_org,
165 | openai_api_key=api_key,
166 | openai_n=1,
167 | openai_retry_timeout_s=10,
168 | openai_read_timeout_s=20,
169 | openai_n_retries=10,
170 | prompt_example_class=NERPromptExample,
171 | )
172 | for eg in examples:
173 | openai.add_example(eg)
174 | if max_examples >= 1:
175 | db = prodigy.components.db.connect()
176 | db_examples = db.get_dataset(dataset)
177 | if db_examples:
178 | for eg in db_examples:
179 | if PromptExample.is_flagged(eg):
180 | openai.add_example(PromptExample.from_prodigy(eg, openai.labels))
181 | stream = cast(Iterable[Dict], srsly.read_jsonl(filepath))
182 | return {
183 | "dataset": dataset,
184 | "view_id": "blocks",
185 | "stream": openai(stream, batch_size=batch_size, nlp=nlp),
186 | "update": openai.update,
187 | "config": {
188 | "labels": openai.labels,
189 | "batch_size": batch_size,
190 | "exclude_by": "input",
191 | "blocks": [
192 | {"view_id": "ner_manual"},
193 | {"view_id": "html", "html_template": HTML_TEMPLATE},
194 | ],
195 | "show_flag": True,
196 | "global_css": CSS_FILE_PATH.read_text(),
197 | },
198 | }
199 |
200 |
201 | @prodigy.recipe(
202 | "ner.openai.fetch",
203 | input_path=("Path to jsonl data to annotate", "positional", None, Path),
204 | output_path=("Path to save the output", "positional", None, Path),
205 | labels=("Labels (comma delimited)", "positional", None, lambda s: s.split(",")),
206 | lang=("Language to use for tokenizer.", "option", "l", str),
207 | model=("GPT-3 model to use for completion", "option", "m", str),
208 | examples_path=("Examples file to help define the task", "option", "e", Path),
209 | max_examples=("Max examples to include in prompt", "option", "n", int),
210 | prompt_path=("Path to jinja2 prompt template", "option", "p", Path),
211 | batch_size=("Batch size to send to OpenAI API", "option", "b", int),
212 | segment=("Split sentences", "flag", "S", bool),
213 | verbose=("Print extra information to terminal", "flag", "v", bool),
214 | )
215 | def ner_openai_fetch(
216 | input_path: Path,
217 | output_path: Path,
218 | labels: List[str],
219 | lang: str = "en",
220 | model: str = "text-davinci-003",
221 | batch_size: int = 10,
222 | segment: bool = False,
223 | examples_path: Optional[Path] = None,
224 | prompt_path: Path = DEFAULT_PROMPT_PATH,
225 | max_examples: int = 2,
226 | verbose: bool = False,
227 | ):
228 | """Get bulk NER suggestions from an OpenAI API, using zero-shot or few-shot learning.
229 | The results can then be corrected using the `ner.manual` recipe.
230 | This approach lets you get the OpenAI queries out of the way upfront, which can help
231 | if you want to use multiple annotators or if you want to make sure you don't have to
232 | wait on the OpenAI queries. The downside is that you can't flag examples to be integrated
233 | into the prompt during the annotation, unlike the ner.openai.correct recipe.
234 | """
235 | api_key, api_org = get_api_credentials(model)
236 | examples = read_prompt_examples(examples_path, example_class=NERPromptExample)
237 | nlp = spacy.blank(lang)
238 | if segment:
239 | nlp.add_pipe("sentencizer")
240 | openai = OpenAISuggester(
241 | response_parser=make_ner_response_parser(labels=labels, lang=lang),
242 | openai_model=model,
243 | labels=labels,
244 | max_examples=max_examples,
245 | prompt_template=load_template(prompt_path),
246 | segment=segment,
247 | verbose=verbose,
248 | openai_api_org=api_org,
249 | openai_api_key=api_key,
250 | openai_n=1,
251 | openai_retry_timeout_s=10,
252 | openai_read_timeout_s=20,
253 | openai_n_retries=10,
254 | prompt_example_class=NERPromptExample,
255 | )
256 | for eg in examples:
257 | openai.add_example(eg)
258 | stream = list(srsly.read_jsonl(input_path))
259 | stream = openai(tqdm.tqdm(stream), batch_size=batch_size, nlp=nlp)
260 | srsly.write_jsonl(output_path, stream)
261 |
262 |
263 | def _find_substrings(
264 | text: str,
265 | substrings: List[str],
266 | *,
267 | case_sensitive: bool = False,
268 | single_match: bool = False,
269 | ) -> List[Tuple[int, int]]:
270 | """Given a list of substrings, find their character start and end positions in a text. The substrings are assumed to be sorted by the order of their occurrence in the text.
271 | text: The text to search over.
272 | substrings: The strings to find.
273 | case_sensitive: Whether the search should be case-sensitive.
274 | single_match: If False, allow one substring to match multiple times in the text. If True, returns the first hit.
275 | """
276 | # remove empty and duplicate strings, and lowercase everything if need be
277 | substrings = [s for s in substrings if s and len(s) > 0]
278 | if not case_sensitive:
279 | text = text.lower()
280 | substrings = [s.lower() for s in substrings]
281 | substrings = _unique(substrings)
282 | offsets = []
283 | for substring in substrings:
284 | search_from = 0
285 | # Search until one hit is found. Continue only if single_match is False.
286 | while True:
287 | start = text.find(substring, search_from)
288 | if start == -1:
289 | break
290 | end = start + len(substring)
291 | offsets.append((start, end))
292 | if single_match:
293 | break
294 | search_from = end
295 | return offsets
296 |
297 |
298 | def _unique(items: List[str]) -> List[str]:
299 | """Remove duplicates without changing order"""
300 | seen = set()
301 | output = []
302 | for item in items:
303 | if item not in seen:
304 | output.append(item)
305 | seen.add(item)
306 | return output
307 |
--------------------------------------------------------------------------------
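To make the few-shot mechanics above concrete, here is a small sketch of how a flagged Prodigy annotation becomes an `NERPromptExample` via `from_prodigy` (assuming the `recipes` package is importable and the label set is already normalized to lowercase):

```python
from recipes.openai_ner import NERPromptExample

# A Prodigy-style task with character-offset spans, e.g. from ner_manual.
prodigy_task = {
    "text": "David Bowie lived in Berlin in the 1970s.",
    "spans": [
        {"label": "person", "start": 0, "end": 11},
        {"label": "place", "start": 21, "end": 27},
    ],
}

example = NERPromptExample.from_prodigy(prodigy_task, labels=["person", "place"])
# example.text     -> "David Bowie lived in Berlin in the 1970s."
# example.entities -> {"person": ["David Bowie"], "place": ["Berlin"]}
```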
/recipes/openai_terms.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from functools import reduce
4 | from pathlib import Path
5 | from typing import Callable, Dict, Iterable, List, TypeVar, Union
6 |
7 | import httpx
8 | import jinja2
9 | import prodigy
10 | import prodigy.components.preprocess
11 | import prodigy.util
12 | import rich
13 | import srsly
14 | from dotenv import load_dotenv
15 | from prodigy import set_hashes
16 | from prodigy.components.db import connect
17 | from prodigy.util import msg
18 | from rich.panel import Panel
19 | from rich.pretty import Pretty
20 |
21 | TERMS_DEFAULT_PROMPT_PATH = (
22 | Path(__file__).parent.parent / "templates" / "terms_prompt.jinja2"
23 | )
24 |
25 | # Set up openai
26 | load_dotenv() # take environment variables from .env.
27 |
28 | _ItemT = TypeVar("_ItemT")
29 |
30 |
31 | def _load_template(path: Path) -> jinja2.Template:
32 | # I know jinja has a lot of complex file loading stuff,
33 | # but we're not using the inheritance etc that makes
34 | # that stuff worthwhile.
35 | if not path.suffix == ".jinja2":
36 | msg.fail(
37 | "The --prompt-path (-p) parameter expects a .jinja2 file.",
38 | exits=1,
39 | )
40 | with path.open("r", encoding="utf8") as file_:
41 | text = file_.read()
42 | return jinja2.Template(text)
43 |
44 |
45 | def _parse_terms(completion: str) -> List[str]:
46 | if "\n" not in completion:
47 | # Sometimes it only returns a single item. For example, when there are
48 | # many, many seeds around.
49 | lines = [completion]
50 | else:
51 | # Other times we cannot assume the final item will have had sufficient
52 | # tokens available to complete the term, so we have to discard it.
53 | lines = [item for item in completion.split("\n") if len(item)]
54 | lines = lines[:-1]
55 | return [item.replace("-", "").strip() for item in lines]
56 |
57 |
58 | def _retry429(
59 | call_api: Callable[[], httpx.Response], n: int, timeout_s: int
60 | ) -> httpx.Response:
61 | """Retry a call to the OpenAI API if we get a 429: Too many requests
62 | error.
63 | """
64 | assert n >= 0
65 | assert timeout_s >= 1
66 | r = call_api()
67 | i = -1
68 | while i < n and r.status_code == 429:
69 | time.sleep(timeout_s)
70 | i += 1
71 | r = call_api()  # Retry the request after waiting
72 | return r
73 |
74 | def _generate_headers() -> Dict[str, str]:
75 | if not os.getenv("OPENAI_KEY"):
76 | msg.fail("The `OPENAI_KEY` is missing from your `.env` file.", exits=1)
77 |
78 | if not os.getenv("OPENAI_ORG"):
79 | msg.fail("The `OPENAI_ORG` is missing from your `.env` file.", exits=1)
80 |
81 | return {
82 | "Authorization": f"Bearer {os.getenv('OPENAI_KEY')}",
83 | "OpenAI-Organization": os.getenv("OPENAI_ORG"),
84 | "Content-Type": "application/json",
85 | }
86 |
87 |
88 | @prodigy.recipe(
89 | # fmt: off
90 | "terms.openai.fetch",
91 | query=("Query to send to OpenAI", "positional", None, str),
92 | output_path=("Path to save the output", "positional", None, Path),
93 | seeds=("One or more comma-separated seed phrases.","option","s",lambda d: d.split(",")),
94 | n=("The minimum number of items to generate", "option", "n", int),
95 | model=("GPT-3 model to use for completion", "option", "m", str),
96 | prompt_path=("Path to jinja2 prompt template", "option", "p", Path),
97 | verbose=("Print extra information to terminal", "flag", "v", bool),
98 | resume=("Resume by loading in text examples from output file.", "flag", "r", bool),
99 | progress=("Print progress of the recipe.", "flag", "pb", bool),
100 | temperature=("OpenAI temperature param", "option", "t", float),
101 | top_p=("OpenAI top_p param", "option", "tp", float),
102 | best_of=("OpenAI best_of param", "option", "bo", int),
103 | n_batch=("OpenAI batch size param", "option", "nb", int),
104 | max_tokens=("Max tokens to generate per call", "option", "mt", int),
105 | # fmt: on
106 | )
107 | def terms_openai_fetch(
108 | query: str,
109 | output_path: Path,
110 | seeds: List[str] = [],
111 | n: int = 100,
112 | model: str = "text-davinci-003",
113 | prompt_path: Path = TERMS_DEFAULT_PROMPT_PATH,
114 | verbose: bool = False,
115 | resume: bool = False,
116 | progress: bool = False,
117 | temperature=1.0,
118 | top_p=1.0,
119 | best_of=10,
120 | n_batch=10,
121 | max_tokens=100,
122 | ):
123 | """Get bulk term suggestions from the OpenAI API, using zero-shot learning.
124 |
125 | The results can then be corrected using the `prodigy textcat.manual` recipe and
126 | turned into patterns via `prodigy terms.to-patterns`.
127 | """
128 | tic = time.time()
129 | template = _load_template(prompt_path)
130 | # The `best_of` param cannot be less than the amount we batch.
131 | if best_of < n_batch:
132 | best_of = n_batch
133 |
134 | # Start collection of terms. If we resume we also fill seed terms with file contents.
135 | terms = []
136 | if resume:
137 | if output_path.exists():
138 | examples = srsly.read_jsonl(output_path)
139 | terms.extend([e["text"] for e in examples])
140 |
141 | # Mimic behavior from Prodigy terms recipe to ensure that seed terms also appear in output
142 | for seed in seeds:
143 | if seed not in terms:
144 | srsly.write_jsonl(
145 | output_path,
146 | [{"text": seed, "meta": {"openai_query": query}}],
147 | append=True,
148 | append_new_line=False,
149 | )
150 |
151 | # Ensure we have access to correct environment variables and construct headers
152 | if not os.getenv("OPENAI_KEY"):
153 | msg.fail("The `OPENAI_KEY` is missing from your `.env` file.", exits=1)
154 |
155 | if not os.getenv("OPENAI_ORG"):
156 | msg.fail("The `OPENAI_ORG` is missing from your `.env` file.", exits=1)
157 |
158 | headers = _generate_headers()
159 |
160 | # This recipe may overshoot the target, but we keep going until we have at least `n`
161 | while len(terms) < n:
162 | prompt = template.render(n=n, examples=seeds + terms, description=query)
163 | if verbose:
164 | rich.print(Panel(prompt, title="Prompt to OpenAI"))
165 |
166 | make_request = lambda: httpx.post(
167 | "https://api.openai.com/v1/completions",
168 | headers=headers,
169 | json={
170 | "model": model,
171 | "prompt": [prompt],
172 | "temperature": temperature,
173 | "max_tokens": max_tokens,
174 | "top_p": top_p,
175 | "n": min(n_batch, best_of),
176 | "best_of": best_of,
177 | },
178 | timeout=45,
179 | )
180 |
181 | # Catch 429: too many request errors
182 | resp = _retry429(make_request, n=1, timeout_s=30)
183 |
184 | # Report on any other error that might happen. The most typical case is
185 | # exceeding the maximum context length of 4097 tokens when the prompt gets big.
186 | if resp.status_code != 200:
187 | msg.fail(f"Received status code {resp.status_code} from OpenAI. Details:")
188 | rich.print(resp.json())
189 | exit(code=1)
190 |
191 | # Cast to a set to make sure we remove duplicates
192 | choices = resp.json()["choices"]
193 | sets_of_terms = [set(_parse_terms(c["text"])) for c in choices]
194 | parsed_terms = list(reduce(lambda a, b: a.union(b), sets_of_terms))
195 |
196 | # Save intermediate results into the file, in case of a hiccup
197 | srsly.write_jsonl(
198 | output_path,
199 | [{"text": t, "meta": {"openai_query": query}} for t in parsed_terms],
200 | append=True,
201 | append_new_line=False,
202 | )
203 |
204 | # Make the terms list bigger and re-use terms in next prompt.
205 | terms.extend(parsed_terms)
206 | if verbose:
207 | rich.print(Panel(Pretty(terms), title="Terms collected so far."))
208 | if progress:
209 | rich.print(
210 | f"Received {len(parsed_terms)} items, totalling {len(terms)} terms. "
211 | f"Progress at {round(len(terms)/n*100)}% after {round(time.time() - tic)}s."
212 | )
213 |
--------------------------------------------------------------------------------
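The behaviour of `_parse_terms` above can be pinned down with a couple of concrete completions: the last line of a multi-line completion is dropped because it may have been cut off by the token limit, and hyphens plus surrounding whitespace are stripped from the rest (a sketch, assuming the `recipes` package and its dependencies are importable):

```python
from recipes.openai_terms import _parse_terms

# "risk" is discarded because the final item may be truncated.
assert _parse_terms("monopoly\n- Scrabble\n- risk\n") == ["monopoly", "Scrabble"]

# A completion without newlines is treated as a single term.
assert _parse_terms("monopoly") == ["monopoly"]
```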
/recipes/openai_textcat.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from pathlib import Path
3 | from typing import Callable, Dict, List, Optional
4 |
5 | import prodigy
6 | import spacy
7 | import srsly
8 | import tqdm
9 | from dotenv import load_dotenv
10 | from prodigy.util import msg
11 |
12 | from recipes.openai import _ItemT, OpenAISuggester, PromptExample, get_api_credentials
13 | from recipes.openai import load_template, normalize_label
14 | from recipes.openai import read_prompt_examples
15 |
16 | CSS_FILE_PATH = Path(__file__).parent / "style.css"
17 | DEFAULT_PROMPT_PATH = (
18 | Path(__file__).parent.parent / "templates" / "textcat_prompt.jinja2"
19 | )
20 | HTML_TEMPLATE = """<div class="cleaned">
21 | {{ #label }}
22 | OpenAI GPT-3 says: {{ meta.answer }}
23 | {{ /label }}
24 | <details>
25 | <summary>Show the prompt for OpenAI</summary>
26 | <pre>{{openai.prompt}}</pre>
27 | </details>
28 | <details>
29 | <summary>Show the response from OpenAI</summary>
30 | <pre>{{openai.response}}</pre>
31 | </details>
32 | </div>
33 | 
34 | 
35 | 
36 | """
37 |
38 | # Set up openai
39 | load_dotenv() # take environment variables from .env.
40 |
41 |
42 | @dataclass
43 | class TextCatPromptExample(PromptExample):
44 | """An example to be passed into an OpenAI TextCat prompt."""
45 |
46 | text: str
47 | answer: str
48 | reason: str
49 |
50 | @classmethod
51 | def from_prodigy(cls, example: _ItemT, labels: List[str]) -> "PromptExample":
52 | """Create a prompt example from Prodigy's format."""
53 | if "text" not in example:
54 | raise ValueError("Cannot make PromptExample without text")
55 |
56 | full_text = example["text"]
57 | reason = example["meta"].get("reason")
58 | if len(labels) == 1:
59 | answer = example.get("answer", "reject")
60 | else:
61 | answer = ",".join(example.get("accept", []))
62 | return cls(text=full_text, answer=answer, reason=reason)
63 |
64 |
65 | def make_textcat_response_parser(labels: List[str]) -> Callable:
66 | def _parse_response(text: str, example: Optional[Dict] = None) -> Dict:
67 | response: Dict[str, str] = {}
68 | if text and any(k in text.lower() for k in ("answer", "reason")):
69 | for line in text.strip().split("\n"):
70 | if line and ":" in line:
71 | key, value = line.split(":", 1)
72 | response[key.strip().lower()] = value.strip()
73 | else:
74 | response = {"answer": "", "reason": ""}
75 |
76 | example = _fmt_binary(response) if len(labels) == 1 else _fmt_multi(response)
77 | return example
78 |
79 | def _fmt_binary(response: Dict[str, str]) -> Dict:
80 | """Parse binary TextCat where the 'answer' key means it's a positive class."""
81 | return {
82 | "answer": response["answer"].lower(),
83 | "label": labels[0],
84 | "meta": {
85 | "answer": response["answer"].upper(),
86 | "reason": response["reason"],
87 | },
88 | }
89 |
90 | def _fmt_multi(response: Dict[str, str]) -> Dict:
91 | """Parse multilabel TextCat where the 'accept' key is a list of positive labels."""
92 | return {
93 | "options": [{"id": label, "text": label} for label in labels],
94 | "answer": "accept",
95 | "meta": {
96 | "reason": response.get("reason", ""),
97 | "gpt_answer": response.get("answer", ""),
98 | },
99 | "accept": list(
100 | filter(
101 | None,
102 | [normalize_label(s.strip()) for s in response["answer"].split(",")],
103 | )
104 | ),
105 | }
106 |
107 | return _parse_response
108 |
109 |
110 | @prodigy.recipe(
111 | # fmt: off
112 | "textcat.openai.correct",
113 | dataset=("Dataset to save answers to", "positional", None, str),
114 | filepath=("Path to jsonl data to annotate", "positional", None, Path),
115 | labels=("Labels (comma delimited)", "option", "L", lambda s: s.split(",")),
116 | lang=("Language to initialize spaCy model", "option", "l", str),
117 | model=("GPT-3 model to use for completion", "option", "m", str),
118 | batch_size=("Batch size to send to OpenAI API", "option", "b", int),
119 | segment=("Split sentences", "flag", "S", bool),
120 | prompt_path=("Path to the .jinja2 prompt template", "option", "p", Path),
121 | examples_path=("Examples file to help define the task", "option", "e", Path),
122 | max_examples=("Max examples to include in prompt", "option", "n", int),
123 | exclusive_classes=("Make the classification task exclusive", "flag", "E", bool),
124 | verbose=("Print extra information to terminal", "flag", "v", bool),
125 | # fmt: on
126 | )
127 | def textcat_openai_correct(
128 | dataset: str,
129 | filepath: Path,
130 | labels: List[str],
131 | lang: str = "en",
132 | model: str = "text-davinci-003",
133 | batch_size: int = 10,
134 | segment: bool = False,
135 | prompt_path: Path = DEFAULT_PROMPT_PATH,
136 | examples_path: Optional[Path] = None,
137 | max_examples: int = 2,
138 | exclusive_classes: bool = False,
139 | verbose: bool = False,
140 | ):
141 | api_key, api_org = get_api_credentials(model)
142 | examples = read_prompt_examples(examples_path, example_class=TextCatPromptExample)
143 | nlp = spacy.blank(lang)
144 |
145 | if segment:
146 | nlp.add_pipe("sentencizer")
147 |
148 | if not exclusive_classes and len(labels) == 1:
149 | msg.warn(
150 | "Binary classification should always be exclusive. Setting "
151 | "`exclusive_classes` parameter to True"
152 | )
153 | exclusive_classes = True
154 |
155 | # Create OpenAISuggester with GPT-3 parameters
156 | openai = OpenAISuggester(
157 | response_parser=make_textcat_response_parser(labels=labels),
158 | prompt_template=load_template(prompt_path),
159 | labels=labels,
160 | max_examples=max_examples,
161 | segment=segment,
162 | openai_api_org=api_org,
163 | openai_api_key=api_key,
164 | openai_n=1,
165 | openai_model=model,
166 | openai_retry_timeout_s=10,
167 | openai_read_timeout_s=20,
168 | openai_n_retries=10,
169 | render_vars={"exclusive_classes": exclusive_classes},
170 | prompt_example_class=TextCatPromptExample,
171 | verbose=verbose,
172 | )
173 | for eg in examples:
174 | openai.add_example(eg)
175 | if max_examples >= 1:
176 | db = prodigy.components.db.connect()
177 | db_examples = db.get_dataset(dataset)
178 | if db_examples:
179 | for eg in db_examples:
180 | if PromptExample.is_flagged(eg):
181 | openai.add_example(PromptExample.from_prodigy(eg, openai.labels))
182 |
183 | # Set up the stream
184 | stream = prodigy.get_stream(filepath)
185 | stream = openai(tqdm.tqdm(stream), batch_size=batch_size, nlp=nlp)
186 |
187 | # Set up the Prodigy UI
188 | return {
189 | "dataset": dataset,
190 | "view_id": "blocks",
191 | "stream": stream,
192 | "update": openai.update,
193 | "config": {
194 | "labels": openai.labels,
195 | "batch_size": batch_size,
196 | "exclude_by": "input",
197 | "choice_style": "single" if exclusive_classes else "multiple",
198 | "blocks": [
199 | {"view_id": "classification" if len(labels) == 1 else "choice"},
200 | {"view_id": "html", "html_template": HTML_TEMPLATE},
201 | ],
202 | "show_flag": True,
203 | "global_css": CSS_FILE_PATH.read_text(),
204 | },
205 | }
206 |
207 |
208 | @prodigy.recipe(
209 | # fmt: off
210 | "textcat.openai.fetch",
211 | filepath=("Path to jsonl data to annotate", "positional", None, Path),
212 | output_path=("Path to save the output", "positional", None, Path),
213 | labels=("Labels (comma delimited)", "option", "L", lambda s: s.split(",")),
214 | lang=("Language to use for tokenizer.", "option", "l", str),
215 | model=("GPT-3 model to use for completion", "option", "m", str),
216 | prompt_path=("Path to jinja2 prompt template", "option", "p", Path),
217 | examples_path=("Examples file to help define the task", "option", "e", Path),
218 | max_examples=("Max examples to include in prompt", "option", "n", int),
219 | batch_size=("Batch size to send to OpenAI API", "option", "b", int),
220 | segment=("Split sentences", "flag", "S", bool),
221 | exclusive_classes=("Make the classification task exclusive", "flag", "E", bool),
222 | verbose=("Print extra information to terminal", "flag", "v", bool),
223 | # fmt: on
224 | )
225 | def textcat_openai_fetch(
226 | filepath: Path,
227 | output_path: Path,
228 | labels: List[str],
229 | lang: str = "en",
230 | model: str = "text-davinci-003",
231 | batch_size: int = 10,
232 | segment: bool = False,
233 | prompt_path: Path = DEFAULT_PROMPT_PATH,
234 | examples_path: Optional[Path] = None,
235 | max_examples: int = 2,
236 | exclusive_classes: bool = False,
237 | verbose: bool = False,
238 | ):
239 | api_key, api_org = get_api_credentials(model)
240 | examples = read_prompt_examples(examples_path, example_class=TextCatPromptExample)
241 | nlp = spacy.blank(lang)
242 |
243 | if segment:
244 | nlp.add_pipe("sentencizer")
245 |
246 | if not exclusive_classes and len(labels) == 1:
247 | msg.warn(
248 | "Binary classification should always be exclusive. Setting "
249 | "`exclusive_classes` parameter to True"
250 | )
251 | exclusive_classes = True
252 |
253 | # Create OpenAISuggester with GPT-3 parameters
254 | openai = OpenAISuggester(
255 | response_parser=make_textcat_response_parser(labels=labels),
256 | prompt_template=load_template(prompt_path),
257 | labels=labels,
258 | max_examples=max_examples,
259 | segment=segment,
260 | openai_api_org=api_org,
261 | openai_api_key=api_key,
262 | openai_n=1,
263 | openai_model=model,
264 | openai_retry_timeout_s=10,
265 | openai_read_timeout_s=20,
266 | openai_n_retries=10,
267 | render_vars={"exclusive_classes": exclusive_classes},
268 | prompt_example_class=TextCatPromptExample,
269 | verbose=verbose,
270 | )
271 | for eg in examples:
272 | openai.add_example(eg)
273 |
274 | # Set up the stream
275 | stream = prodigy.get_stream(filepath)
276 | stream = openai(tqdm.tqdm(stream), batch_size=batch_size, nlp=nlp)
277 | srsly.write_jsonl(output_path, stream)
278 |
--------------------------------------------------------------------------------
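As a sanity check on the parsing contract above, a well-formed binary response maps onto a Prodigy classification task roughly as follows (a sketch, assuming the `recipes` package is importable):

```python
from recipes.openai_textcat import make_textcat_response_parser

parse = make_textcat_response_parser(labels=["recipe"])
response = "answer: accept\nreason: It lists ingredients and cooking steps."

task = parse(response)
# task -> {
#     "answer": "accept",
#     "label": "recipe",
#     "meta": {"answer": "ACCEPT", "reason": "It lists ingredients and cooking steps."},
# }
```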
/recipes/style.css:
--------------------------------------------------------------------------------
1 | .prodigy-container > .prodigy-content {
2 | white-space: normal;
3 | border-top: 1px solid #ddd;
4 | }
5 |
6 | .cleaned {
7 | text-align: left;
8 | font-size: 14px;
9 | }
10 |
11 | .cleaned pre {
12 | background-color: #eee;
13 | font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;
14 | padding: 15px 20px;
15 | border-radius: 15px;
16 | white-space: pre-wrap;
17 | }
18 |
19 | summary {
20 | font-weight: bold;
21 | cursor: pointer;
22 | font-size: 1.2em;
23 | }
24 |
25 | details {
26 | margin-bottom: 1rem;
27 | }
28 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | httpx
2 | rich==12.6.0
3 | python-dotenv==0.21.0
4 | openai==0.25.0
5 | pytest>=5.2.0,!=7.1.0
6 | pytest-httpx==0.21.3
--------------------------------------------------------------------------------
/requirements_train.txt:
--------------------------------------------------------------------------------
1 | evaluate
2 | seqeval
3 | transformers
4 | torch
--------------------------------------------------------------------------------
/scripts/train_hf_ner.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Dict, List, Optional
3 |
4 | import evaluate
5 | import numpy as np
6 | import spacy.vocab
7 | import typer
8 | from spacy.tokens import DocBin
9 | from transformers import (
10 | AutoModelForTokenClassification,
11 | AutoTokenizer,
12 | DataCollatorForTokenClassification,
13 | Trainer,
14 | TrainingArguments,
15 | )
16 | from transformers.tokenization_utils_base import BatchEncoding
17 |
18 | app = typer.Typer()
19 |
20 | # The seqeval metric is loaded via `evaluate` rather than imported like a normal library
21 | seqeval = evaluate.load("seqeval")
22 |
23 |
24 | def spacy2hf(
25 | binary_file: Path,
26 | label2id: Dict[str, int],
27 | tokenizer: AutoTokenizer,
28 | ) -> List[BatchEncoding]:
29 | """Given a path to a .spacy file, a label mapping, and an HF tokenizer,
30 | return HF tokens with NER labels.
31 | """
32 |
33 | db = DocBin().from_disk(binary_file)
34 |
35 | hfdocs = []
36 | # Normally you would use the vocab from an nlp object, but we just need
37 | # this for deserialization before conversion.
38 | vocab = spacy.vocab.Vocab()
39 | # first, make ids for all labels
40 | for doc in db.get_docs(vocab):
41 | labels = []
42 | toks = []
43 | for tok in doc:
44 | toks.append(tok.text)
45 | if tok.ent_type == 0:
46 | labels.append("O")
47 | continue
48 | ent_label = f"{tok.ent_iob_}-{tok.ent_type_}"
49 | if ent_label not in label2id:
50 | label2id[ent_label] = len(label2id)
51 | labels.append(ent_label)
52 |
53 | # now do the hf tokenization
54 | tokens_hf = tokenizer(toks, truncation=True, is_split_into_words=True)
55 | labels_hf = []
56 |
57 | for word_id in tokens_hf.word_ids():
58 | if word_id is None:
59 | # for things like [CLS]
60 | labels_hf.append(-100)
61 | else:
62 | # The docs note it's common to assign -100 to subwords after the
63 | # first inside an entity, but this does the simpler thing.
64 | label = label2id[labels[word_id]]
65 | labels_hf.append(label)
66 | tokens_hf["labels"] = labels_hf
67 |
68 | hfdocs.append(tokens_hf)
69 |
70 | return hfdocs
71 |
72 |
73 | def build_compute_metrics(label_list: Dict[int, str]):
74 | def compute_metrics(p):
75 | prediction_matrix, label_matrix = p
76 | prediction_matrix = np.argmax(prediction_matrix, axis=2)
77 |
78 | all_predictions = [
79 | [label_list[pred] for (pred, label) in zip(predictions, labels) if label != -100]
80 | for predictions, labels in zip(prediction_matrix, label_matrix)
81 | ]
82 | gold_labels = [
83 | [label_list[label] for (pred, label) in zip(predictions, labels) if label != -100]
84 | for predictions, labels in zip(prediction_matrix, label_matrix)
85 | ]
86 |
87 | results = seqeval.compute(predictions=all_predictions, references=gold_labels, zero_division=0)
88 | return {
89 | "precision": _round_float(results["overall_precision"]),
90 | "recall": _round_float(results["overall_recall"]),
91 | "f1": _round_float(results["overall_f1"]),
92 | "accuracy": _round_float(results["overall_accuracy"]),
93 | }
94 |
95 | return compute_metrics
96 |
97 |
98 | def _round_float(number: float) -> float:
99 | return round(number, 3)
100 |
101 |
102 | def train_ner(
103 | base_model: str,
104 | tokenizer: AutoTokenizer,
105 | id2label: Dict[int, str],
106 | train_data: List[BatchEncoding],
107 | test_data: List[BatchEncoding],
108 | ) -> Trainer:
109 | """Fine-tune an existing HF model."""
110 | model = AutoModelForTokenClassification.from_pretrained(
111 | base_model, num_labels=len(id2label)
112 | )
113 |
114 | batch_size = 16
115 |
116 | args = TrainingArguments(
117 | f"test-ner",
118 | evaluation_strategy="epoch",
119 | learning_rate=1e-4,
120 | per_device_train_batch_size=batch_size,
121 | per_device_eval_batch_size=batch_size,
122 | num_train_epochs=20,
123 | weight_decay=1e-5,
124 | disable_tqdm=True,
125 | # specify the optimizer to avoid a deprecation warning
126 | optim="adamw_torch",
127 | )
128 |
129 | data_collator = DataCollatorForTokenClassification(tokenizer)
130 | compute_metrics = build_compute_metrics(id2label)
131 |
132 | trainer = Trainer(
133 | model,
134 | args,
135 | train_dataset=train_data,
136 | eval_dataset=test_data,
137 | data_collator=data_collator,
138 | tokenizer=tokenizer,
139 | compute_metrics=compute_metrics,
140 | )
141 |
142 | trainer.train()
143 | trainer.evaluate()
144 | return trainer
145 |
146 |
147 | @app.command("train_hf_ner", context_settings={"allow_extra_args": False})
148 | def train_hf_ner(
149 | # fmt: off
150 | train_file: Path = typer.Argument(..., help="Binary .spacy file containing training data", exists=True, allow_dash=False),
151 | dev_file: Path = typer.Argument(..., help="Binary .spacy file containing dev evaluation data", exists=True, allow_dash=False),
152 | output_path: Optional[Path] = typer.Option(None, "--output", "-o", help="Output directory to store trained pipeline in"),
153 | base_model: str = typer.Option("distilbert-base-uncased", "--base", "-b", help="Base transformer model to start from"),
154 | # fmt: on
155 | ):
156 | """Fine-tune a HuggingFace NER model using a .spacy file as input."""
157 | # prep the data
158 | tokenizer = AutoTokenizer.from_pretrained(base_model)
159 | label2id = {"O": 0}
160 | train = spacy2hf(train_file, label2id, tokenizer)
161 | test = spacy2hf(dev_file, label2id, tokenizer)
162 | # handle the mapping
163 | id2label = {v: k for k, v in label2id.items()}
164 | # actually train
165 | trainer = train_ner(base_model, tokenizer, id2label, train, test)
166 | if output_path:
167 | trainer.save_model(str(output_path))
168 |
169 |
170 | if __name__ == "__main__":
171 | app()
172 |
--------------------------------------------------------------------------------
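The training script expects binary `.spacy` files; typically these would be exported from an annotated Prodigy dataset (e.g. via `prodigy data-to-spacy`), but a minimal sketch of producing one that `spacy2hf` can consume looks like this (the entity offsets and file name are illustrative):

```python
import spacy
from spacy.tokens import DocBin

nlp = spacy.blank("en")
doc = nlp("David Bowie lived in Berlin in the 1970s.")
# char_span maps character offsets onto token boundaries, as the NER recipes do.
doc.ents = [
    doc.char_span(0, 11, label="PERSON"),
    doc.char_span(21, 27, label="PLACE"),
]
DocBin(docs=[doc]).to_disk("train.spacy")
```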
/templates/ab/input.jinja2:
--------------------------------------------------------------------------------
1 | A haiku about {{topic}}.
--------------------------------------------------------------------------------
/templates/ab/prompt1.jinja2:
--------------------------------------------------------------------------------
1 | Write a haiku about {{topic}}.
2 |
--------------------------------------------------------------------------------
/templates/ab/prompt2.jinja2:
--------------------------------------------------------------------------------
1 | Write an incredibly hilarious haiku about {{topic}}.
--------------------------------------------------------------------------------
/templates/ner_prompt.jinja2:
--------------------------------------------------------------------------------
1 | From the text below, extract the following entities in the following format:
2 | {# whitespace #}
3 | {%- for label in labels -%}
4 | {{label}}:
5 | {# whitespace #}
6 | {%- endfor -%}
7 | {# whitespace #}
8 | Text:
9 | """
10 | {{text}}
11 | """
12 | {# whitespace #}
13 | {%- if examples -%}
14 | {# whitespace #}
15 | For example:
16 | {# whitespace #}
17 | {# whitespace #}
18 | {%- for example in examples -%}
19 | Text:
20 | """
21 | {{ example.text }}
22 | """
23 | {# whitespace #}
24 | {%- for label, substrings in example.entities.items() -%}
25 | {{ label }}: {{ ", ".join(substrings) }}
26 | {# whitespace #}
27 | {%- endfor -%}
28 | {# whitespace #}
29 | {% endfor -%}
30 | {%- endif -%}
31 |
--------------------------------------------------------------------------------
/templates/terms_prompt.jinja2:
--------------------------------------------------------------------------------
1 | Generate me {{n}} examples of {{description}}.
2 |
3 | Here are the examples:
4 | {%- if examples -%}
5 | {%- for example in examples -%}
6 | {# whitespace #}
7 | - {{example}}
8 | {%- endfor -%}
9 | {%- endif -%}
10 | {# whitespace #}
11 | -
12 |
--------------------------------------------------------------------------------
/templates/textcat_prompt.jinja2:
--------------------------------------------------------------------------------
1 | {% if labels|length == 1 %}
2 | {% set label = labels[0] %}
3 | From the text below, determine whether or not it contains a {{ label }}. If it is
4 | a {{ label }}, answer "accept." If it is not a {{ label }}, answer "reject."
5 | {% else %}
6 | Classify the text below to any of the following labels: {{ labels|join(", ") }}
7 | {% if exclusive_classes %}
8 | The task is exclusive, so only choose one label from what I provided.
9 | {% else %}
10 | The task is non-exclusive, so you can provide more than one label as long as
11 | they're comma-delimited. For example: Label1, Label2, Label3.
12 | {% endif %}
13 | {% endif %}
14 | {# whitespace #}
15 | Your answer should only be in the following format:
16 | {# whitespace #}
17 | answer:
18 | reason:
19 | {# whitespace #}
20 | {% if examples %}
21 | Below are some examples (only use these as a guide):
22 | {# whitespace #}
23 | {# whitespace #}
24 | {% for example in examples %}
25 | Text:
26 | """
27 | {{ example.text }}
28 | """
29 | {# whitespace #}
30 | answer: {{ example.answer }}
31 | reason: {{ example.reason }}
32 | {% endfor %}
33 | {% endif %}
34 | {# whitespace #}
35 | Here is the text that needs classification
36 | {# whitespace #}
37 | Text:
38 | """
39 | {{text}}
40 | """
41 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/explosion/prodigy-openai-recipes/a0866f0d05f0a744cab15ad202201c5b691d48b9/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_ner.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from pathlib import Path
3 | from typing import Dict, List, Tuple
4 |
5 | import pytest
6 | import spacy
7 |
8 | from recipes.openai import load_template
9 | from recipes.openai_ner import DEFAULT_PROMPT_PATH, NERPromptExample
10 | from recipes.openai_ner import _find_substrings
11 | from tests.utils import make_suggester
12 |
13 | from recipes.openai_ner import make_ner_response_parser
14 |
15 |
16 | def test_multiple_substrings():
17 | text = "The Blargs is the debut album by rock band The Blargs."
18 | substrings = ["The Blargs", "rock"]
19 | res = _find_substrings(text, substrings, single_match=False)
20 | assert res == [(0, 10), (43, 53), (33, 37)]
21 | res = _find_substrings(text, substrings, single_match=True)
22 | assert res == [(0, 10), (33, 37)]
23 |
24 |
25 | def test_substrings_case():
26 | text = "A, a, B, b, a,b,c,d"
27 | substrings = ["a,"]
28 | res = _find_substrings(text, substrings, single_match=False, case_sensitive=True)
29 | assert res == [(3, 5), (12, 14)]
30 | res = _find_substrings(text, substrings, single_match=False, case_sensitive=False)
31 | assert res == [(0, 2), (3, 5), (12, 14)]
32 | res = _find_substrings(text, substrings, single_match=True, case_sensitive=True)
33 | assert res == [(3, 5)]
34 | res = _find_substrings(text, substrings, single_match=True, case_sensitive=False)
35 | assert res == [(0, 2)]
36 |
37 |
38 | def test_template_no_examples():
39 | text = "David Bowie lived in Berlin in the 1970s."
40 | labels = ["PERSON", "PLACE", "PERIOD"]
41 | examples = []
42 | path = Path(__file__).parent.parent / "templates" / "ner_prompt.jinja2"
43 | template = load_template(path)
44 | prompt = template.render(text=text, labels=labels, examples=examples)
45 | assert (
46 | prompt
47 | == f"""
48 | From the text below, extract the following entities in the following format:
49 | PERSON:
50 | PLACE:
51 | PERIOD:
52 |
53 | Text:
54 | \"\"\"
55 | David Bowie lived in Berlin in the 1970s.
56 | \"\"\"
57 | """.lstrip()
58 | )
59 |
60 |
61 | def test_template_two_examples():
62 | text = "David Bowie lived in Berlin in the 1970s."
63 | labels = ["PERSON", "PLACE", "PERIOD"]
64 | examples = [
65 | NERPromptExample(
66 | text="New York is a large city.", entities={"PLACE": ["New York"]}
67 | ),
68 | NERPromptExample(
69 | text="David Hasslehoff and Helena Fischer are big in Germany.",
70 | entities={
71 | "PERSON": ["David Hasslehoff", "Helena Fischer"],
72 | "PLACE": ["Germany"],
73 | },
74 | ),
75 | ]
76 | path = Path(__file__).parent.parent / "templates" / "ner_prompt.jinja2"
77 | template = load_template(path)
78 | prompt = template.render(text=text, labels=labels, examples=examples)
79 | assert (
80 | prompt
81 | == f"""
82 | From the text below, extract the following entities in the following format:
83 | PERSON:
84 | PLACE:
85 | PERIOD:
86 |
87 | Text:
88 | \"\"\"
89 | David Bowie lived in Berlin in the 1970s.
90 | \"\"\"
91 |
92 | For example:
93 |
94 | Text:
95 | \"\"\"
96 | New York is a large city.
97 | \"\"\"
98 | PLACE: New York
99 |
100 | Text:
101 | \"\"\"
102 | David Hasslehoff and Helena Fischer are big in Germany.
103 | \"\"\"
104 | PERSON: David Hasslehoff, Helena Fischer
105 | PLACE: Germany
106 |
107 | """.lstrip()
108 | )
109 |
110 |
111 | @pytest.mark.parametrize(
112 | "comment,text,raw_spans,filtered_spans",
113 | [
114 | (
115 | "Does it take longest when nested and shorter is first?",
116 | "one two three four",
117 | [("a", 0, 1), ("b", 0, 2)],
118 | [("b", 0, 2)],
119 | ),
120 | (
121 | "Does it take longest when nested and shorter is last?",
122 | "one two three four",
123 | [("b", 0, 2), ("a", 0, 1)],
124 | [("b", 0, 2)],
125 | ),
126 | (
127 | "Does it take first when overlapping and shorter is first?",
128 | "one two three four",
129 | [("b", 0, 2), ("a", 1, 3)],
130 | [("b", 0, 2)],
131 | ),
132 | (
133 | "Does it take first when overlapping and shorter is last?",
134 | "one two three four",
135 | [("a", 0, 2), ("b", 1, 3)],
136 | [("a", 0, 2)],
137 | ),
138 | ],
139 | )
140 | def test_one_token_per_span(
141 | comment: str,
142 | text: str,
143 | raw_spans: List[Tuple[str, int, int]],
144 | filtered_spans: List[Tuple[int, int]],
145 | ):
146 | labels = list(sorted(set(label for label, _, _1 in raw_spans)))
147 | suggester = make_suggester(
148 | response_parser=make_ner_response_parser(labels=labels, lang="en"),
149 | prompt_path=DEFAULT_PROMPT_PATH,
150 | labels=labels,
151 | openai_api_key="fake api key",
152 | openai_api_org="fake api org",
153 | prompt_example_class=NERPromptExample,
154 | )
155 | prompt = suggester._get_prompt(text, labels=labels, examples=[])
156 | response = _get_response(text, labels, raw_spans)
157 | stream = [{"text": text, "openai": {"prompt": prompt, "response": response}}]
158 | stream = list(suggester.format_suggestions(stream, nlp=spacy.blank("en")))
159 | output_spans = [
160 | (s["label"], s["token_start"], s["token_end"] + 1) for s in stream[0]["spans"]
161 | ]
162 | assert output_spans == filtered_spans
163 | # Also check there's no overlaps
164 | seen_tokens = set()
165 | for _, start, end in output_spans:
166 | for i in range(start, end):
167 | assert i not in seen_tokens, "Overlapping or nested spans found"
168 | seen_tokens.add(i)
169 |
170 |
171 | def _get_response(text: str, labels, spans: List[Tuple[str, int, int]]) -> str:
172 | # Get table of start and end character offsets for the test spans.
173 | tokens = text.split()
174 | start_chars, end_chars = _get_token_char_maps(tokens, [True for _ in tokens])
175 | spans_by_label = defaultdict(list)
176 | for label, start, end in spans:
177 | spans_by_label[label].append(text[start_chars[start] : end_chars[end - 1]])
178 | response_lines = []
179 | for label in labels:
180 | response_lines.append(f"{label}: {', '.join(spans_by_label[label])}")
181 | return "\n".join(response_lines)
182 |
183 |
184 | def _get_token_char_maps(
185 | tokens: List[str], whitespace: List[bool]
186 | ) -> Tuple[Dict[int, int], Dict[int, int]]:
187 | idx = 0
188 | start_chars = {}
189 | end_chars = {}
190 | for i, token in enumerate(tokens):
191 | start_chars[i] = idx
192 | idx += len(token) + int(whitespace[i])
193 | end_chars[i] = idx
194 | return start_chars, end_chars
195 |
--------------------------------------------------------------------------------
/tests/test_openai.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import pytest
3 | import jinja2
4 | import httpx
5 | from pytest_httpx import HTTPXMock
6 |
7 | from recipes.openai import OpenAISuggester, get_api_credentials
8 |
9 |
10 | # Setup the template and the suggester
11 | environment = jinja2.Environment()
12 | template = environment.from_string("Prompt: {{ text }}")
13 |
14 | openai = OpenAISuggester(
15 | prompt_template=template,
16 | labels=["label1", "label2"],
17 | max_examples=1,
18 | segment=False,
19 | openai_model="text-davinci-003",
20 | openai_api_key="Fake api key",
21 | openai_api_org="Fake api org",
22 | response_parser=lambda x: {"key": "value"},
23 | openai_n_retries=1,
24 | openai_read_timeout_s=1,
25 | openai_retry_timeout_s=1,
26 | prompt_example_class=None,
27 | )
28 |
29 |
30 | @pytest.mark.parametrize(
31 | "prompts,response_text",
32 | [
33 | (["A single prompt"], ["A single response"]),
34 | (["A batch", "of prompts"], ["A batch", "of responses"]),
35 | ],
36 | )
37 | def test_openai_response_follows_contract(
38 |     httpx_mock: HTTPXMock, prompts: List[str], response_text: List[str]
39 | ):
40 | """Test happy path where OpenAI follows the contract and we can parse it
41 | https://beta.openai.com/docs/api-reference/completions
42 | """
43 |
44 | httpx_mock.add_response(
45 | method="POST",
46 | json={
47 | "choices": [
48 | {
49 | "text": text,
50 | "index": index,
51 | "logprobs": 0.1,
52 | "finish_reason": "length",
53 | }
54 | for index, text in enumerate(response_text)
55 | ]
56 | },
57 | )
58 |
59 |     openai_response = openai._get_openai_response(prompts=prompts)
60 |     assert len(openai_response) == len(prompts)
61 |     assert set(openai_response) == set(response_text)
62 |
63 |
64 | @pytest.mark.parametrize("error_code", openai.RETRY_ERROR_CODES)
65 | def test_retry_function_when_calls_fail(httpx_mock, error_code):
66 | """Test if status error shows up after all failed retries."""
67 | httpx_mock.add_response(status_code=error_code)
68 | with pytest.raises(httpx.HTTPStatusError):
69 | openai._get_openai_response(prompts=["Some prompt"])
70 |
71 |
72 | @pytest.mark.parametrize(
73 | "key,org", [(None, "fake api org"), ("fake api key", None), (None, None)]
74 | )
75 | def test_get_api_credentials_error_handling_envvars(monkeypatch, key, org):
76 | """Ensure that auth fails whenever key or org is none."""
77 | monkeypatch.setenv("OPENAI_KEY", key)
78 | monkeypatch.setenv("OPENAI_ORG", org)
79 | with pytest.raises(SystemExit):
80 | get_api_credentials(model="text-davinci-003")
81 |
82 |
83 | @pytest.mark.parametrize("error_code", [422, 500, 501])
84 | def test_get_api_credentials_calls_fail(httpx_mock, monkeypatch, error_code):
85 | """Ensure that auth fails when we encounter an error code."""
86 | httpx_mock.add_response(status_code=error_code)
87 | monkeypatch.setenv("OPENAI_KEY", "fake api key")
88 | monkeypatch.setenv("OPENAI_ORG", "fake api org")
89 |
90 | with pytest.raises(SystemExit):
91 | get_api_credentials(model="text-davinci-003")
92 |
93 |
94 | def test_get_api_credentials_model_does_not_exist(httpx_mock, monkeypatch):
95 | """Ensure that auth fails when model passed does not exist.
96 | https://beta.openai.com/docs/api-reference/models/list
97 | """
98 | httpx_mock.add_response(
99 | method="GET",
100 | json={
101 | "data": [{"id": "model-id-0"}, {"id": "model-id-1"}, {"id": "model-id-2"}]
102 | },
103 | )
104 | monkeypatch.setenv("OPENAI_KEY", "fake api key")
105 | monkeypatch.setenv("OPENAI_ORG", "fake api org")
106 |
107 | with pytest.raises(SystemExit):
108 | get_api_credentials(model="a-model-that-does-not-exist")
109 |
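110 | # Note (added for clarity): the pytest-httpx `httpx_mock` fixture intercepts
111 | # every outgoing httpx request, so none of the tests above contact the real
112 | # OpenAI API; the mocked JSON bodies stand in for its responses.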
--------------------------------------------------------------------------------
/tests/test_terms.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from recipes.openai_terms import _parse_terms
4 |
5 | # We cannot assume that `risk` is a complete term: the completion may have
6 | # been cut off because OpenAI ran out of tokens
7 | base_completion = """monopoly
8 | - Scrabble
9 | - risk
10 | """
11 |
12 | # Added some trailing spaces in this one
13 | base_completion_with_trailing_spaces = """monopoly
14 | - scrabble
15 | - risk
16 | """
17 |
18 | # An example where the tokens may have been exhausted
19 | # note the capitalisation
20 | trailing_token_completion = """monopoly
21 | - scrabble
22 | - Risk
23 | - carcas
24 | """
25 |
26 | # This can also happen
27 | single_line_completion = "monopoly"
28 |
29 |
30 | @pytest.mark.parametrize(
31 | "comment,completion,expectation",
32 | [
33 | (
34 | "Base OpenAI completion with capitalisation",
35 | base_completion,
36 | ["monopoly", "Scrabble"],
37 | ),
38 | (
39 | "Check trailing spaces",
40 | base_completion_with_trailing_spaces,
41 | ["monopoly", "scrabble"],
42 | ),
43 | (
44 | "Completion with bad final item",
45 | trailing_token_completion,
46 | ["monopoly", "scrabble", "Risk"],
47 | ),
48 | (
49 | "Example of a single-line OpenAI completion",
50 | single_line_completion,
51 | ["monopoly"],
52 | ),
53 | ],
54 | )
55 | def test_parse_terms(comment, completion, expectation):
56 | assert _parse_terms(completion=completion) == expectation
57 |
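58 | # As the expected outputs above suggest, _parse_terms treats the final item of
59 | # a multi-line completion as potentially truncated (the model may have run out
60 | # of tokens mid-word) and drops it, while a single-line completion is kept.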
--------------------------------------------------------------------------------
/tests/test_textcat.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from recipes.openai_textcat import make_textcat_response_parser
4 |
5 |
6 | @pytest.mark.parametrize(
7 | # fmt: off
8 | "response,answer",
9 | [
10 | ("Answer: accept\nReason: The text is a recipe.", "accept"),
11 | ("Answer: Accept\nReason: The text is a recipe.", "accept"),
12 | ("Answer: reject\nReason: The text is not a recipe.", "reject"),
13 | ("Answer: Reject\nReason: The text is not a recipe.", "reject"),
14 | ("answer: reject\nreason: The text is not a recipe.", "reject"),
15 | ("answer: Reject\nreason: The text is not a recipe.\nI don't know what it's about.", "reject"),
16 | ],
17 | # fmt: on
18 | )
19 | def test_parse_response_binary(response, answer):
20 | """Test if parse response works for common examples"""
21 | labels = ["recipe"] # binary
22 | parser = make_textcat_response_parser(labels=labels)
23 | example = parser(response)
24 | assert example.get("answer") == answer
25 |
26 |
27 | @pytest.mark.parametrize(
28 | # fmt: off
29 | "response,answer",
30 | [
31 | ("Answer: recipe,feedback,question\nReason: It is all three!", ["recipe", "feedback", "question"]),
32 | ("Answer: recipe\nReason: The text is a recipe.", ["recipe"]),
33 | ("Answer: recipe,\nReason: The text is a recipe.", ["recipe"]),
34 | ("Answer: recipe,feedback,\nReason: The text is a feedback about a recipe.\nThat's what I think...", ["recipe", "feedback"]),
35 | ("answer: recipe,feedback,\nreason: The text is a feedback about a recipe.\nThat's what I think...", ["recipe", "feedback"]),
36 | ("answer: \nreason: It's none of the above.", []),
37 | ],
38 | # fmt: on
39 | )
40 | def test_parser_response_multi(response, answer):
41 | """Test if parse response works for common examples"""
42 | labels = ["recipe", "feedback", "question"] # multiclass
43 | parser = make_textcat_response_parser(labels=labels)
44 | example = parser(response)
45 | assert set(example.get("accept")) == set(answer)
46 |
47 |
48 | @pytest.mark.parametrize("labels", [["binary"], ["multi1", "multi2"]])
49 | def test_parser_no_answer(labels):
50 | """Test if parse response works for common examples"""
51 | empty_response = ""
52 | parser = make_textcat_response_parser(labels=labels)
53 | example = parser(empty_response)
54 | assert not example.get("accept")
55 | assert not example.get("reason")
56 |
57 |
58 | @pytest.mark.parametrize("labels", [["binary"], ["multi1", "multi2"]])
59 | @pytest.mark.parametrize(
60 | "response", ["asdfghjklmnop", "I am now a sentient robot. Bow before me."]
61 | )
62 | def test_parser_openai_returns_arbitrary_text(labels, response):
63 | """Test if parser response works for any arbitrary text"""
64 | parser = make_textcat_response_parser(labels=labels)
65 | example = parser(response)
66 | assert not example.get("accept")
67 | assert not example.get("reason")
68 |
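69 | # For reference (inferred from the assertions above): the binary parser sets an
70 | # "answer" key ("accept"/"reject"), the multilabel parser sets an "accept" list
71 | # of labels, and both may include a "reason"; unparseable text yields neither.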
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Callable, List, Optional, Type
3 |
4 | from recipes.openai import OpenAISuggester, PromptExample, get_api_credentials
5 | from recipes.openai import load_template
6 |
7 |
8 | def make_suggester(
9 | response_parser: Callable,
10 | labels: List[str],
11 | prompt_path: Path,
12 |     prompt_example_class: Optional[Type[PromptExample]] = None,
13 | model: str = "text-davinci-003",
14 | **kwargs
15 | ) -> OpenAISuggester:
16 | if "openai_api_key" not in kwargs or "openai_api_org" not in kwargs:
17 | api_key, api_org = get_api_credentials(model)
18 | if "openai_api_key" not in kwargs:
19 | kwargs["openai_api_key"] = api_key
20 | if "openai_api_org" not in kwargs:
21 | kwargs["openai_api_org"] = api_org
22 | if "max_examples" not in kwargs:
23 | kwargs["max_examples"] = 0
24 | if "prompt_template" not in kwargs:
25 | kwargs["prompt_template"] = load_template(prompt_path)
26 | if "segment" not in kwargs:
27 | kwargs["segment"] = False
28 | if "openai_model" not in kwargs:
29 | kwargs["openai_model"] = "text-davinci-003"
30 |
31 | openai = OpenAISuggester(
32 | response_parser=response_parser,
33 | labels=labels,
34 | prompt_example_class=prompt_example_class,
35 | **kwargs
36 | )
37 | return openai
38 |
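39 | # Usage sketch (added for reference, mirroring tests/test_ner.py): passing both
40 | # a fake key and org skips the get_api_credentials() call, so a suggester can
41 | # be built in tests without any network access, e.g.
42 | #
43 | #   suggester = make_suggester(
44 | #       response_parser=lambda response: {},
45 | #       labels=["label"],
46 | #       prompt_path=Path("templates/ner_prompt.jinja2"),
47 | #       openai_api_key="fake api key",
48 | #       openai_api_org="fake api org",
49 | #   )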
--------------------------------------------------------------------------------