├── .github └── workflows │ └── unit_tests.yml ├── .gitignore ├── README.md ├── data ├── ab_example.jsonl └── reddit_r_cooking_sample.jsonl ├── examples ├── ner.yaml ├── textcat_binary.yaml └── textcat_multilabel.yaml ├── recipes ├── __init__.py ├── openai.py ├── openai_ab.py ├── openai_ner.py ├── openai_terms.py ├── openai_textcat.py └── style.css ├── requirements.txt ├── requirements_train.txt ├── scripts └── train_hf_ner.py ├── templates ├── ab │ ├── input.jinja2 │ ├── prompt1.jinja2 │ └── prompt2.jinja2 ├── ner_prompt.jinja2 ├── terms_prompt.jinja2 └── textcat_prompt.jinja2 └── tests ├── __init__.py ├── test_ner.py ├── test_openai.py ├── test_terms.py ├── test_textcat.py └── utils.py /.github/workflows/unit_tests.yml: -------------------------------------------------------------------------------- 1 | name: Unit Tests 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | 9 | jobs: 10 | setup: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - name: Set up Python 3.9 15 | uses: actions/setup-python@v4 16 | with: 17 | python-version: 3.9 18 | cache: "pip" # caching pip dependencies 19 | 20 | - name: Install dependencies 21 | run: | 22 | pip install --upgrade pip 23 | pip install -r requirements.txt 24 | pip install -r requirements_train.txt 25 | 26 | - name: Check out Prodigy 27 | uses: actions/checkout@v3 28 | with: 29 | repository: explosion/prodigy 30 | ref: v1.11.8 31 | path: ./prodigy 32 | ssh-key: ${{ secrets.GHA_PRODIGY_READ }} 33 | 34 | - name: Install prodigy 35 | run: | 36 | ls -la 37 | pip install ./prodigy 38 | 39 | - name: Test with pytest 40 | run: | 41 | python -m pytest --ignore=prodigy 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | .env 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | *.jsonl 132 | *.ipynb -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Archival notice 4 | 5 | The recipes in this repository have since moved to [Prodigy](https://prodi.gy/) and are being maintained there. They will soon even get an upgrade with the advent of [spacy-llm](https://github.com/explosion/spacy-llm) support, which features better prompts and multiple LLM providers. That is why we've opted to archive this repo, so that we may focus on maintaining these recipes as part of spaCy and Prodigy directly. 6 | 7 | You can learn more by checking out the [large language models section](https://prodi.gy/docs/large-language-models) on the docs. 8 | 9 | # Prodigy OpenAI recipes 10 | 11 | This repository contains example code on how to combine **zero- and few-shot learning 12 | with a small annotation effort** to obtain a **high-quality dataset with maximum efficiency**. Specifically, we use large language models available from [OpenAI](https://openai.com) to provide us with an initial set of predictions, 13 | then spin up a [Prodigy](https://prodi.gy) instance on our local machine 14 | to go through these predictions and curate them. This allows us to obtain a 15 | gold-standard dataset pretty quickly, and train a smaller, supervised model that fits 16 | our exact needs and use-case. 17 | 18 | ![](https://user-images.githubusercontent.com/13643239/208497043-178beb47-f7c6-4b3e-a253-9e12e2f0c807.png) 19 | 20 | https://user-images.githubusercontent.com/13643239/208504034-0ab6bcbe-6d2b-415d-8257-233f2074ba31.mp4 21 | 22 | ## ⏳ Setup and Install 23 | 24 | Make sure to [install Prodigy](https://prodi.gy/docs/install) as well as a few additional Python dependencies: 25 | 26 | ```bash 27 | python -m pip install prodigy -f https://XXXX-XXXX-XXXX-XXXX@download.prodi.gy 28 | python -m pip install -r requirements.txt 29 | ``` 30 | 31 | With `XXXX-XXXX-XXXX-XXXX` being your personal Prodigy license key. 32 | 33 | Then, create a new API key from [openai.com](https://beta.openai.com/account/api-keys) or fetch an existing 34 | one. Record the secret key as well as the [organization key](https://beta.openai.com/account/org-settings) 35 | and make sure these are available as environmental variables. For instance, set them in a `.env` file in the 36 | root directory: 37 | 38 | ``` 39 | OPENAI_ORG = "org-..." 40 | OPENAI_KEY = "sk-..." 
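# The values above are placeholders - the recipes read OPENAI_ORG and OPENAI_KEY
# from the environment, loading this .env file via python-dotenv's load_dotenv().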
41 | ``` 42 | 43 | ## 📋 Named-entity recognition (NER) 44 | 45 | ### `ner.openai.correct`: NER annotation with zero- or few-shot learning 46 | 47 | This recipe marks entity predictions obtained from a large language model and allows you to flag them as correct, or to 48 | manually curate them. This allows you to quickly gather a gold-standard dataset through zero-shot or few-shot learning. 49 | It's very much like using the standard [`ner.correct`](https://prodi.gy/docs/recipes#ner-correct) recipe in Prodi.gy, 50 | but we're using GPT-3 as a backend model to make predictions. 51 | 52 | ```bash 53 | python -m prodigy ner.openai.correct dataset filepath labels [--options] -F ./recipes/openai_ner.py 54 | ``` 55 | 56 | | Argument | Type | Description | Default | 57 | | ----------------------- | ---- | ----------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------- | 58 | | `dataset` | str | Prodigy dataset to save annotations to. | | 59 | | `filepath` | Path | Path to `.jsonl` data to annotate. The data should at least contain a `"text"` field. | | 60 | | `labels` | str | Comma-separated list defining the NER labels the model should predict. | | 61 | | `--lang`, `-l` | str | Language of the input data - will be used to obtain a relevant tokenizer. | `"en"` | 62 | | `--segment`, `-S` | bool | Flag to set when examples should be split into sentences. By default, the full input article is shown. | `False` | 63 | | `--model`, `-m` | str | GPT-3 model to use for initial predictions. | `"text-davinci-003"` | 64 | | `--prompt_path`, `-p` | Path | Path to the `.jinja2` [prompt template](templates). | `./templates/ner_prompt.jinja2` | 65 | | `--examples-path`, `-e` | Path | Path to examples to help define the task. The file can be a .yml, .yaml or .json. If set to `None`, zero-shot learning is applied. | `None` | 66 | | `--max-examples`, `-n` | int | Max number of examples to include in the prompt to OpenAI. If set to 0, zero-shot learning is always applied, even when examples are available. | 2 | 67 | | `--batch-size`, `-b` | int | Batch size of queries to send to the OpenAI API. | 10 | 68 | | `--verbose`, `-v` | bool | Flag to print extra information to the terminal. | `False` | 69 | 70 | #### Example usage 71 | 72 | Let's say we want to recognize dishes, ingredients and cooking equipment from some text we obtained from a cooking subreddit. 73 | We'll send the text to GPT-3, hosted by OpenAI, and provide an annotation prompt to explain 74 | to the language model the type of predictions we want. Something like: 75 | 76 | ``` 77 | From the text below, extract the following entities in the following format: 78 | dish: 79 | ingredient: 80 | equipment: 81 | 82 | Text: 83 | ... 84 | ``` 85 | 86 | We define the definition of this prompt in a .jinja2 file which also describes how to append examples for few-shot learning. 87 | You can create your own [template](templates) and provide it to the recipe with the `--prompt-path` or `-p` option. 
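If you're curious what rendering such a template looks like in code, here is a minimal sketch that mirrors the `load_template` and `_get_prompt` helpers in `recipes/openai.py`; the input text is just a made-up example:

```python
from pathlib import Path

import jinja2

# Load the NER prompt template the same way the recipes do and render it
# with an example text, the label set and (optionally) few-shot examples.
template_text = Path("templates/ner_prompt.jinja2").read_text(encoding="utf8")
template = jinja2.Template(template_text, undefined=jinja2.DebugUndefined)

prompt = template.render(
    text="Sear the steak in a cast-iron skillet with plenty of butter.",  # hypothetical input
    labels=["dish", "ingredient", "equipment"],
    examples=[],  # flagged examples are appended here for few-shot learning
)
print(prompt)
```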
88 | Additionally, with `--examples-path` or `-e` you can set the file path of a .y(a)ml or .json file that contains additional examples:
89 | 
90 | ```bash
91 | python -m prodigy ner.openai.correct my_ner_data ./data/reddit_r_cooking_sample.jsonl "dish,ingredient,equipment" -p ./templates/ner_prompt.jinja2 -e ./examples/ner.yaml -n 2 -F ./recipes/openai_ner.py
92 | ```
93 | 
94 | After receiving the results from the OpenAI API, the Prodigy recipe converts the predictions into an annotation task
95 | that can be rendered with Prodigy. The task even shows the original prompt as well as the raw answer we obtained
96 | from the language model.
97 | 
98 | 
99 | 
100 | Here, we see that the model is able to correctly recognize dishes, ingredients and cooking equipment right from the start!
101 | 
102 | The recipe also offers a `--verbose` or `-v` option that includes the exact prompt and response on the terminal as traffic is received.
103 | Note that because the requests to the API are batched, you might have to scroll back a bit to find the current prompt.
104 | 
105 | ### Interactively tune the prompt examples
106 | 
107 | At some point, you might notice a mistake in the predictions of the OpenAI language model. For instance, we noticed an error
108 | in the recognition of cooking equipment in this example:
109 | 
110 | 
111 | 
112 | If you see these kinds of systematic errors, you can steer the predictions in the right direction by correcting the example and then selecting the small "flag" icon
113 | in the top right of the Prodigy UI:
114 | 
115 | 
116 | 
117 | Once you hit accept on the Prodigy interface, the flagged example will be automatically picked up and added to the examples
118 | that are sent to the OpenAI API as part of the prompt.
119 | 
120 | > **Note**
121 | > Because Prodigy batches these requests, the prompt will be updated with a slight
122 | > delay, after the next batch of prompts is sent to OpenAI. You can experiment
123 | > with making the batch size (`--batch-size` or `-b`) smaller to have the change
124 | > come into effect sooner, but this might negatively impact the speed of the
125 | > annotation workflow.
126 | 
127 | ### `ner.openai.fetch`: Fetch examples up-front
128 | 
129 | The `ner.openai.correct` recipe fetches examples from OpenAI while annotating, but we've also included a recipe that can fetch a large batch of examples upfront.
130 | 
131 | ```bash
132 | python -m prodigy ner.openai.fetch input_data.jsonl predictions.jsonl "dish,ingredient,equipment" -F ./recipes/openai_ner.py
133 | ```
134 | 
135 | This will create a `predictions.jsonl` file that can be loaded with the [`ner.manual`](https://prodi.gy/docs/recipes#ner-manual) recipe.
136 | 
137 | Note that the OpenAI API might return "429 Too Many Requests" errors when requesting too much data at once - in this case it's best to ensure you only request
138 | 100 or so examples at a time.
139 | 
140 | ### Exporting the annotations and training an NER model
141 | 
142 | After you've curated a set of predictions, you can export the results with [`db-out`](https://prodi.gy/docs/recipes#db-out):
143 | 
144 | ```bash
145 | python -m prodigy db-out my_ner_data > ner_data.jsonl
146 | ```
147 | 
148 | The format of the exported annotations contains all the data you need to train a smaller model downstream. Each example
149 | in the dataset contains the original text, the tokens, span annotations denoting the entities, etc.
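For example, a quick script like the one below can summarize the curated entities before you train anything. It's a minimal sketch that assumes the usual Prodigy export format, where each record has an `"answer"` field and a list of `"spans"`:

```python
from collections import Counter

import srsly

# Count the entity labels in the accepted annotations exported via db-out.
# Assumes ner_data.jsonl follows Prodigy's NER format ("answer" + "spans").
label_counts = Counter()
for eg in srsly.read_jsonl("ner_data.jsonl"):
    if eg.get("answer") != "accept":
        continue
    for span in eg.get("spans", []):
        label_counts[span["label"]] += 1

print(label_counts.most_common())
```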
150 | 151 | You can also export the data to spaCy's [binary format](https://spacy.io/api/data-formats#training), using [`data-to-spacy`](https://prodi.gy/docs/recipes#data-to-spacy). This format lets you load in the annotations as spaCy `Doc` objects, which can be convenient for further conversion. The `data-to-spacy` command also makes it easy to train an NER model with spaCy. First you export the data, specifying the train data as 20% of the total: 152 | 153 | ```bash 154 | python -m prodigy data-to-spacy ./data/annotations/ --ner my_ner_data -es 0.2 155 | ``` 156 | 157 | Then you can train a model with spaCy or [Prodigy](https://prodi.gy/docs/recipes/#training): 158 | 159 | ```bash 160 | python -m spacy train ./data/annotations/config.cfg --paths.train ./data/annotations/train.spacy --paths.dev ./data/annotations/dev.spacy -o ner-model 161 | ``` 162 | 163 | This will save a model to the `ner-model/` directory. 164 | 165 | We've also included an experimental script to load in the `.spacy` binary format and train a model with the HuggingFace `transformers` library. You can use the same data you just exported and run the script like this: 166 | 167 | ```bash 168 | # First you need to install the HuggingFace library and requirements 169 | pip install -r requirements_train.txt 170 | python ./scripts/train_hf_ner.py ./data/annotations/train.spacy ./data/annotations/dev.spacy -o hf-ner-model 171 | ``` 172 | 173 | The resulting model will be saved to the `hf-ner-model/` directory. 174 | 175 | ## 📋 Text categorization (Textcat) 176 | 177 | ### `textcat.openai.correct`: Textcat annotation with zero- or few-shot learning 178 | 179 | This recipe enables us to classify texts faster with the help of a large 180 | language model. It also provides a "reason" to explain why a particular label 181 | was chosen. 182 | 183 | ```bash 184 | python -m prodigy textcat.openai.correct dataset filepath labels [--options] -F ./recipes/openai_textcat.py 185 | ``` 186 | 187 | | Argument | Type | Description | Default | 188 | | --------------------------- | ---- | ----------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------- | 189 | | `dataset` | str | Prodigy dataset to save annotations to. | | 190 | | `filepath` | Path | Path to `.jsonl` data to annotate. The data should at least contain a `"text"` field. | | 191 | | `labels` | str | Comma-separated list defining the text categorization labels the model should predict. | | 192 | | `--lang`, `-l` | str | Language of the input data - will be used to obtain a relevant tokenizer. | `"en"` | 193 | | `--segment`, `-S` | bool | Flag to set when examples should be split into sentences. By default, the full input article is shown. | `False` | 194 | | `--model`, `-m` | str | GPT-3 model to use for initial predictions. | `"text-davinci-003"` | 195 | | `--prompt-path`, `-p` | Path | Path to the `.jinja2` [prompt template](templates). | `./templates/textcat_prompt.jinja2` | 196 | | `--examples-path`, `-e` | Path | Path to examples to help define the task. The file can be a .yml, .yaml or .json. If set to `None`, zero-shot learning is applied. | `None` | 197 | | `--max-examples`, `-n` | int | Max number of examples to include in the prompt to OpenAI. If set to 0, zero-shot learning is always applied, even when examples are available. | 2 | 198 | | `--batch-size`, `-b` | int | Batch size of queries to send to the OpenAI API. 
| 10 |
199 | | `--exclusive-classes`, `-E` | bool | Flag to make the classification task exclusive. | `False` |
200 | | `--verbose`, `-v` | bool | Flag to print extra information to the terminal. | `False` |
201 | 
202 | 
203 | #### Example usage
204 | 
205 | The `textcat` recipes can be used for binary, multiclass, and multilabel text
206 | categorization. You can set this by passing the appropriate number of labels in
207 | the `--labels` parameter; for example, passing a single label turns it into
208 | binary classification and so on. We will talk about each one in the following
209 | sections.
210 | 
211 | ##### Binary text categorization
212 | 
213 | Suppose we want to know if a particular Reddit comment talks about a food
214 | recipe. We'll send the text to GPT-3 and provide a prompt that describes the
215 | predictions we want.
216 | 
217 | ```
218 | From the text below, determine whether or not it contains a recipe. If it is a
219 | recipe, answer "accept." If it is not a recipe, answer "reject."
220 | 
221 | Your answer should only be in the following format:
222 | answer: 
223 | reason: 
224 | 
225 | Text:
226 | ```
227 | 
228 | For binary classification, we want GPT-3 to return "accept" if a given text is a
229 | food recipe and "reject" otherwise. GPT-3's suggestion is then displayed
230 | prominently in the UI. We can press the ACCEPT (check mark) button to
231 | include the text as a positive example or press the REJECT (cross
232 | mark) button if it is a negative example.
233 | 
234 | 
235 | ```sh
236 | python -m prodigy textcat.openai.correct my_binary_textcat_data data/reddit_r_cooking_sample.jsonl --labels recipe -F recipes/openai_textcat.py
237 | ```
238 | 
239 | 
240 | 
241 | 
242 | ##### Multilabel and multiclass text categorization
243 | 
244 | Now, suppose we want to classify Reddit comments as recipe, feedback, or
245 | question. We can write the following prompt:
246 | 
247 | ```
248 | Classify the text below to any of the following labels: recipe, feedback, question.
249 | The task is exclusive, so only choose one label from what I provided.
250 | 
251 | Your answer should only be in the following format:
252 | answer: 
253 | reason: 
254 | 
255 | Text:
256 | ```
257 | 
258 | Then, we can use this recipe to handle multilabel and multiclass cases by
259 | passing the three labels to the `--labels` parameter. We should also set the
260 | `--exclusive-classes` flag to render a single-choice UI:
261 | 
262 | ```sh
263 | python -m prodigy textcat.openai.correct my_multi_textcat_data data/reddit_r_cooking_sample.jsonl \
264 | --labels recipe,feedback,question \
265 | --exclusive-classes \
266 | -F recipes/openai_textcat.py
267 | ```
268 | 
269 | 
270 | 
271 | ### Writing templates
272 | 
273 | We write these prompts as a .jinja2 template that can also take in examples for
274 | few-shot learning. You can create your own [template](templates) and provide it
275 | to the recipe with the `--prompt-path` or `-p` option. Additionally, with
276 | `--examples-path` or `-e` you can set the file path of a .y(a)ml or .json file
277 | that contains additional examples. You can also add context to these examples, as
278 | we observed this to improve the output:
279 | 
280 | ```bash
281 | python -m prodigy textcat.openai.correct my_binary_textcat_data \
282 | ./data/reddit_r_cooking_sample.jsonl \
283 | --labels recipe \
284 | --prompt-path ./templates/textcat_prompt.jinja2 \
285 | --examples-path ./examples/textcat_binary.yaml -n 2 \
286 | -F ./recipes/openai_textcat.py
287 | ```
288 | 
289 | Similar to the NER recipe, this recipe also converts the predictions into an
290 | annotation task that can be rendered with Prodigy. For binary classification, we
291 | use the [`classification`](https://prodi.gy/docs/api-interfaces#classification)
292 | interface with custom HTML elements, while for multilabel or multiclass text
293 | categorization, we use the
294 | [`choice`](https://prodi.gy/docs/api-interfaces#choice) annotation interface.
295 | Notice that we include the original prompt and the OpenAI response in the UI.
296 | 
297 | 
298 | Lastly, you can use the `--verbose` or `-v` flag to show the exact prompt and
299 | response on the terminal. Note that because the requests to the API are batched,
300 | you might have to scroll back a bit to find the current prompt.
301 | 
302 | 
303 | ### Interactively tune the prompt examples
304 | 
305 | Similar to the NER recipes, you can also steer the predictions in the right
306 | direction by correcting the example and then selecting the small "flag" icon in
307 | the top right of the Prodigy UI:
308 | 
309 | 
310 | 
311 | Once you hit the accept button on the Prodigy interface, the flagged
312 | example will be picked up and added to the few-shot examples sent to the OpenAI
313 | API as part of the prompt.
314 | 
315 | > **Note**
316 | > Because Prodigy batches these requests, the prompt will be updated with a slight
317 | > delay, after the next batch of prompts is sent to OpenAI. You can experiment
318 | > with making the batch size (`--batch-size` or `-b`) smaller to have the change
319 | > come into effect sooner, but this might negatively impact the speed of the
320 | > annotation workflow.
321 | 
322 | ### `textcat.openai.fetch`: Fetch text categorization examples up-front
323 | 
324 | The `textcat.openai.fetch` recipe allows us to fetch a large batch of examples
325 | upfront. This is helpful when you are working with highly imbalanced data and are
326 | only interested in the rare examples.
327 | 
328 | ```bash
329 | python -m prodigy textcat.openai.fetch input_data.jsonl predictions.jsonl --labels Recipe -F ./recipes/openai_textcat.py
330 | ```
331 | 
332 | This will create a `predictions.jsonl` file that can be loaded with the
333 | [`textcat.manual`](https://prodi.gy/docs/recipes#textcat-manual) recipe.
334 | 
335 | Note that the OpenAI API might return "429 Too Many Requests" errors when
336 | requesting too much data at once - in this case it's best to ensure you only
337 | request 100 or so examples at a time and take a look at the [API's rate
338 | limits](https://help.openai.com/en/articles/5955598-is-api-usage-subject-to-any-rate-limits).
339 | 
340 | #### Working with imbalanced data
341 | 
342 | The `textcat.openai.fetch` recipe is suitable for working with datasets where
343 | there is severe class imbalance. Usually, you'd want to find examples of the
344 | rare class rather than annotating a random sample. From there, you can
345 | upsample them to train a decent model and so on.
346 | 
347 | This is where large language models like the ones from OpenAI can help.
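One possible workflow is to fetch suggestions for the whole corpus and then only review the texts that the model marked as positive. A rough sketch of that filtering step is shown below; the `"answer"` key is an assumption, so inspect your own `predictions.jsonl` to confirm which field holds the suggestion:

```python
import srsly

# Hypothetical post-processing of the textcat.openai.fetch output:
# keep only the records GPT-3 suggested as positive so that manual review
# concentrates on the rare class.
positives = [
    eg for eg in srsly.read_jsonl("predictions.jsonl")
    if eg.get("answer") == "accept"  # assumed field name - check your file
]
srsly.write_jsonl("predictions_positive.jsonl", positives)
print(f"Kept {len(positives)} suggested positives for review")
```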
348 | 
349 | Using the [Reddit r/cooking dataset](data), we prompted OpenAI to
350 | look for comments that resemble a food recipe. Instead of annotating 10,000
351 | examples, we ran `textcat.openai.fetch` and obtained 145 positive predictions. Out
352 | of those 145 examples, 114 turned out to be true positives (79% precision). We
353 | then checked 1,000 negative examples and found 12 false negative cases (98%
354 | recall).
355 | 
356 | Ideally, once we have fully annotated the dataset, we can train a supervised model
357 | that is better suited for production than relying on zero-shot predictions: the
358 | running cost is lower and it's easier to manage.
359 | 
360 | ### Exporting the annotations and training a text categorization model
361 | 
362 | After you've curated a set of predictions, you can export the results with
363 | [`db-out`](https://prodi.gy/docs/recipes#db-out):
364 | 
365 | ```bash
366 | python -m prodigy db-out my_textcat_data > textcat_data.jsonl
367 | ```
368 | 
369 | The format of the exported annotations contains all the data you need to train a
370 | smaller model downstream. Each example in the dataset contains the original
371 | text, the tokens, the annotated labels, etc.
372 | 
373 | You can also export the data to spaCy's [binary
374 | format](https://spacy.io/api/data-formats#training), using
375 | [`data-to-spacy`](https://prodi.gy/docs/recipes#data-to-spacy). This format lets
376 | you load in the annotations as spaCy `Doc` objects, which can be convenient for
377 | further conversion. The `data-to-spacy` command also makes it easy to train a
378 | text categorization model with spaCy. First you export the data, holding out
379 | 20% of the total as evaluation data:
380 | 
381 | ```bash
382 | # For binary textcat
383 | python -m prodigy data-to-spacy ./data/annotations/ --textcat my_textcat_data -es 0.2
384 | # For multilabel textcat
385 | python -m prodigy data-to-spacy ./data/annotations/ --textcat-multilabel my_textcat_data -es 0.2
386 | ```
387 | Then you can train a model with spaCy or [Prodigy](https://prodi.gy/docs/recipes/#training):
388 | 
389 | ```bash
390 | python -m spacy train ./data/annotations/config.cfg --paths.train ./data/annotations/train.spacy --paths.dev ./data/annotations/dev.spacy -o textcat-model
391 | ```
392 | 
393 | This will save a model to the `textcat-model/` directory.
394 | 
395 | ## 📋 Terms
396 | 
397 | ### `terms.openai.fetch`: Fetch phrases and terms based on a query
398 | 
399 | This recipe generates terms and phrases obtained from a large language model. These
400 | terms can be curated and turned into patterns files, which can help with downstream annotation tasks.
401 | 
402 | ```bash
403 | python -m prodigy terms.openai.fetch query filepath [--options] -F ./recipes/openai_terms.py
404 | ```
405 | 
406 | | Argument | Type | Description | Default |
407 | |:---------------------:|-------|-----------------------------------------------------|---------------------------------|
408 | | `query` | str | Query to send to OpenAI | |
409 | | `output_path` | Path | Path to save the output | |
410 | | `--seeds`,`-s` | str | One or more comma-separated seed phrases. | `""` |
| `""` | 411 | | `--n`,`-n` | int | Minimum number of items to generate | `100` | 412 | | `--model`, `-m` | str | GPT-3 model to use for completion | `"text-davinci-003"` | 413 | | `--prompt-path`, `-p` | Path | Path to jinja2 prompt template | `templates/terms_prompt.jinja2` | 414 | | `--verbose`,`-v` | bool | Print extra information to terminal | `False` | 415 | | `--resume`, `-r` | bool | Resume by loading in text examples from output file | `False` | 416 | | `--progress`,`-pb` | bool | Print progress of the recipe. | `False` | 417 | | `--temperature`,`-t` | float | OpenAI temperature param | `1.0` | 418 | | `--top-p`, `--tp` | float | OpenAI top_p param | `1.0` | 419 | | `--best-of`, `-bo` | int | OpenAI best_of param" | `10` | 420 | | `--n-batch`,`-nb` | int | OpenAI batch size param | `10` | 421 | | `--max-tokens`, `-mt` | int | Max tokens to generate per call | `100` | 422 | 423 | #### Example usage 424 | 425 | Suppose you're interested in detecting skateboard tricks in text, then you might want to start 426 | with a term list of known tricks. You might want to start with the following query: 427 | 428 | ```bash 429 | # Base behavior, fetch at least 100 terms/phrases 430 | python -m prodigy terms.openai.fetch "skateboard tricks" tricks.jsonl --n 100 --prompt-path templates/terms_prompt.jinja2 -F recipes/openai_terms.py 431 | ``` 432 | 433 | This will generate a prompt to OpenAI that asks to try and generate at least 100 examples of "skateboard tricks". 434 | There's an upper limit to the amount of tokens that can be generated by OpenAI, but this recipe will try and keep 435 | collecting terms until it reached the amount specified. 436 | 437 | You can choose to make the query more elaborate if you want to try to be more precise, but you can alternatively 438 | also choose to add some seed terms via `--seeds`. These will act as starting examples to help steer OpenAI 439 | in the right direction. 440 | 441 | ```bash 442 | # Base behavior but with seeds 443 | python -m prodigy terms.openai.fetch "skateboard tricks" tricks.jsonl --n 100 --seeds "kickflip,ollie" --prompt-path templates/terms_prompt.jinja2 -F recipes/openai_terms.py 444 | ``` 445 | 446 | Collecting many examples can take a while, so it can be helpful to show the progress, via `--progress` 447 | as requests are sent. 448 | 449 | ```bash 450 | # Adding progress output as we wait for 500 examples 451 | python -m prodigy terms.openai.fetch "skateboard tricks" tricks.jsonl --n 500 --progress --seeds "kickflip,ollie" --prompt-path templates/terms_prompt.jinja2 -F recipes/openai_terms.py 452 | ``` 453 | 454 | After collecting a few examples, you might want to generate more. You can choose to continue from a 455 | previous output file. This will effectively re-use those examples as seeds for the prompt to OpenAI. 
456 | 457 | ```bash 458 | # Use the `--resume` flag to re-use previous examples 459 | python -m prodigy terms.openai.fetch "skateboard tricks" tricks.jsonl --n 50 --resume --prompt-path templates/terms_prompt.jinja2 -F recipes/openai_terms.py 460 | ``` 461 | 462 | When the recipe is done, you'll have a `tricks.jsonl` file that has contents that look like this: 463 | 464 | ```json 465 | {"text":"pop shove it","meta":{"openai_query":"skateboard tricks"}} 466 | {"text":"switch flip","meta":{"openai_query":"skateboard tricks"}} 467 | {"text":"nose slides","meta":{"openai_query":"skateboard tricks"}} 468 | {"text":"lazerflip","meta":{"openai_query":"skateboard tricks"}} 469 | {"text":"lipslide","meta":{"openai_query":"skateboard tricks"}} 470 | ... 471 | ``` 472 | 473 | ### Towards Patterns 474 | 475 | You now have a `tricks.jsonl` file on disk that contains skateboard tricks, but you cannot 476 | assume that all of these will be accurate. The next step would be to review the terms and you 477 | can use the [`textcat.manual`](https://prodi.gy/docs/recipes/#textcat-manual) recipe that comes 478 | with Prodigy for that. 479 | 480 | ```bash 481 | # The tricks.jsonl was fetched from OpenAI beforehand 482 | python -m prodigy textcat.manual skateboard-tricks-list tricks.jsonl --label skateboard-tricks 483 | ``` 484 | 485 | This generates an interface that looks like this: 486 | 487 | 488 | 489 | You can manually accept or reject each example and when you're done annotating you can export 490 | the annotated text into a patterns file via the [`terms.to-patterns`](https://prodi.gy/docs/recipes/#terms-to-patterns) recipe. 491 | 492 | ```bash 493 | # Generate a `patterns.jsonl` file. 494 | python -m prodigy terms.to-patterns skateboard-tricks-list patterns.jsonl --label skateboard-tricks --spacy-model blank:en 495 | ``` 496 | 497 | When the recipe is done, you'll have a `patterns.jsonl` file that has contents that look like this: 498 | 499 | ```json 500 | {"label":"skateboard-tricks","pattern":[{"lower":"pop"},{"lower":"shove"},{"lower":"it"}]} 501 | {"label":"skateboard-tricks","pattern":[{"lower":"switch"},{"lower":"flip"}]} 502 | {"label":"skateboard-tricks","pattern":[{"lower":"nose"},{"lower":"slides"}]} 503 | {"label":"skateboard-tricks","pattern":[{"lower":"lazerflip"}]} 504 | {"label":"skateboard-tricks","pattern":[{"lower":"lipslide"}]} 505 | ... 506 | ``` 507 | 508 | ### Known Limitations 509 | 510 | OpenAI has a hard limit on the prompt size. You cannot have a prompt larger than 4079 tokens. Unfortunately 511 | that means that there is a limit to the size of term lists that you can generate. The recipe will report 512 | an error when this happens, but it's good to be aware of this limitation. 513 | 514 | ## 📋 Prompt A/B evaluation 515 | 516 | ### `ab.openai.prompts`: A/B evaluation of prompts 517 | 518 | The goal of this recipe is to quickly allow someone to compare the quality of outputs from two prompts 519 | in a quantifiable and blind way. 
520 | 521 | ```bash 522 | python -m prodigy ab.openai.prompts dataset inputs_path display_template_path prompt1_template_path prompt2_template_path [--options] -F ./recipes/openai_ab.py 523 | ``` 524 | 525 | | Argument | Type | Description | Default | 526 | |:----------------------:|-------|------------------------------------------------------|----------------------| 527 | | `dataset` | str | Prodigy dataset to save answers into | | 528 | | `inputs_path` | Path | Path to jsonl inputs | | 529 | | `display_template_path`| Path | Template for summarizing the arguments | | 530 | | `prompt1_template_path`| Path | Path to the first jinja2 prompt template | | 531 | | `prompt2_template_path`| Path | Path to the second jinja2 prompt template | | 532 | | `--model`, `-m` | str | GPT-3 model to use for completion | `"text-davinci-003"` | 533 | | `--batch-size`, `-b` | int | Batch size to send to OpenAI API | `10` | 534 | | `--verbose`,`-v` | bool | Print extra information to terminal | `False` | 535 | | `--no-random`,`-NR` | bool | Don't randomize which annotation is shown as correct | `False` | 536 | | `--repeat`, `-r` | int | How often to send the same prompt to OpenAI | `1` | 537 | 538 | #### Example usage 539 | 540 | As an example, let's try to generate humorous haikus. To do that we first need to 541 | construct two jinja files that represent the prompt to send to OpenAI. 542 | 543 | ##### `templates/ab/prompt1.jinja2` 544 | 545 | ``` 546 | Write a haiku about {{topic}}. 547 | ``` 548 | 549 | ##### `templates/ab/prompt2.jinja2` 550 | 551 | ``` 552 | Write an incredibly hilarious haiku about {{topic}}. So funny! 553 | ``` 554 | 555 | You can provide variables for these prompts by constructing a .jsonl file with the required 556 | parameters. In this case we need to make sure that `{{topic}}` is accounted for. 557 | 558 | Here's an example `.jsonl` file that could work. 559 | 560 | ##### `data/ab_example.jsonl` 561 | 562 | ```json 563 | {"id": 0, "prompt_args": {"topic": "star wars"}} 564 | {"id": 0, "prompt_args": {"topic": "kittens"}} 565 | {"id": 0, "prompt_args": {"topic": "the python programming language"}} 566 | {"id": 0, "prompt_args": {"topic": "maths"}} 567 | ``` 568 | 569 | > **Note** 570 | > 571 | > All the arguments under `prompt_args` will be passed to render the jinja templates. 572 | > The `id` is mandatory and can be used to identify groups in later analysis. 573 | 574 | We're nearly ready to evaluate, but this recipe requires one final jinja2 template. 575 | This one won't be used to generate a prompt, but it will generate a useful title 576 | that reminds the annotator of the current task. Here's an example of such a template. 577 | 578 | ##### `templates/ab/input.jinja2` 579 | 580 | ``` 581 | A haiku about {{topic}}. 582 | ``` 583 | 584 | When you put all of these templates together you can start annotating. The command below 585 | starts the annotation interface and also uses the `--repeat 4` option. This will ensure 586 | that each topic will be used to generate a prompt at least 4 times. 
587 | 588 | ``` 589 | python -m prodigy ab.openai.prompts haiku data/ab_example.jsonl templates/ab/input.jinja2 templates/ab/prompt1.jinja2 templates/ab/prompt2.jinja2 --repeat 5 -F recipes/openai_ab.py 590 | ``` 591 | 592 | This is what the annotation interface looks like: 593 | 594 | ![](https://user-images.githubusercontent.com/1019791/216607308-97a0b82d-03ea-4d09-ab79-0ec6b26cc033.png) 595 | 596 | When you look at this interface you'll notice that the title template is rendered and that 597 | you're able to pick from two options. Both options are responses from OpenAI that were 598 | generated by the two prompt templates. You can also see the `prompt_args` rendered in 599 | the lower right corner of the choice menu. 600 | 601 | From here you can annotate your favorite examples and gather data that might help you 602 | decide on which prompt is best. 603 | 604 | #### Results 605 | 606 | Once you're done annotating you'll be presented with an overview of the results. 607 | 608 | ``` 609 | =========================== ✨ Evaluation results =========================== 610 | ✔ You preferred prompt1.jinja2 611 | 612 | prompt1.jinja2 11 613 | prompt2.jinja2 5 614 | ``` 615 | 616 | But you can also fetch the raw annotations from the database for further analysis. 617 | 618 | ``` 619 | python -m prodigy db-out haiku 620 | ``` 621 | 622 | ## ❓ What's next? 623 | 624 | There’s lots of interesting follow-up experiments to this, and lots of ways to adapt the basic idea to different tasks or data sets. We’re also interested to try out different prompts. It’s unclear how much the format the annotations are requested in might change the model’s predictions, or whether there’s a shorter prompt that might perform just as well. We also want to run some end-to-end experiments. 625 | -------------------------------------------------------------------------------- /data/ab_example.jsonl: -------------------------------------------------------------------------------- 1 | {"id": 0, "prompt_args": {"topic": "star wars"}} 2 | {"id": 0, "prompt_args": {"topic": "kittens"}} 3 | {"id": 0, "prompt_args": {"topic": "the python programming language"}} 4 | {"id": 0, "prompt_args": {"topic": "maths"}} -------------------------------------------------------------------------------- /examples/ner.yaml: -------------------------------------------------------------------------------- 1 | - 2 | text: "You can't get a great chocolate flavor with carob." 3 | 4 | entities: 5 | Ingredient: 6 | - carob 7 | 8 | - 9 | text: "You can probably sand-blast it if it's an anodized aluminum pan, which it probably is." 10 | entities: 11 | Ingredient: [] 12 | Equipment: ["anodized aluminum pan"] 13 | -------------------------------------------------------------------------------- /examples/textcat_binary.yaml: -------------------------------------------------------------------------------- 1 | - text: "This is a recipe for scrambled egg: 2 eggs, 1 butter, batter them, and then fry in a hot pan for 2 minutes" 2 | answer: "accept" 3 | reason: "This is a recipe for making a scrambled egg" 4 | - text: "This is a recipe for fried rice: 1 cup of day old rice, 1 butter, 2 cloves of garlic: put them all in a wok and stir them together." 5 | answer: "accept" 6 | reason: "This is a recipe for making a fried rice" 7 | - text: "I tried it and it's not good" 8 | answer: "reject" 9 | reason: "It doesn't talk about a recipe." 
10 | -------------------------------------------------------------------------------- /examples/textcat_multilabel.yaml: -------------------------------------------------------------------------------- 1 | - text: "Can someone try this recipe?" 2 | answer: "question" 3 | reason: "It is a question about trying a recipe." 4 | - text: "1 cup of rice then egg and then mix them well. Should I add garlic last?" 5 | answer: "question,recipe" 6 | reason: "It is a question about the steps in making a fried rice." 7 | -------------------------------------------------------------------------------- /recipes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explosion/prodigy-openai-recipes/a0866f0d05f0a744cab15ad202201c5b691d48b9/recipes/__init__.py -------------------------------------------------------------------------------- /recipes/openai.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import copy 3 | import os 4 | import sys 5 | import time 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, TypeVar 9 | 10 | import httpx 11 | import jinja2 12 | import rich 13 | import srsly 14 | from prodigy.components import preprocess 15 | from prodigy.util import msg, set_hashes 16 | from rich.panel import Panel 17 | from spacy.language import Language 18 | 19 | _ItemT = TypeVar("_ItemT") 20 | _PromptT = TypeVar("_PromptT", bound="PromptExample") 21 | 22 | 23 | @dataclass 24 | class PromptExample(abc.ABC): 25 | """An example to be passed into an OpenAI prompt. 26 | When inheriting this dataclass, you should implement the `from_prodigy` 27 | function that takes in a Prodigy task example and formats it back 28 | into a dataclass that can fill a prompt template. 29 | You can refer to Prodigy's API Interfaces documentation 30 | (https://prodi.gy/docs/api-interfaces) to see how most examples are structured 31 | for each task. 32 | """ 33 | 34 | @staticmethod 35 | def is_flagged(example: _ItemT) -> bool: 36 | """Check whether a Prodigy example is flagged for use 37 | in the prompt.""" 38 | 39 | return ( 40 | example.get("flagged") is True 41 | and example.get("answer") == "accept" 42 | and "text" in example 43 | ) 44 | 45 | @classmethod 46 | def from_prodigy(cls, example: _ItemT, labels: Iterable[str]) -> "PromptExample": 47 | """Create a prompt example from Prodigy's format.""" 48 | ... 49 | 50 | 51 | def normalize_label(label: str) -> str: 52 | return label.lower() 53 | 54 | 55 | class OpenAISuggester: 56 | """Suggest annotations using OpenAI's GPT-3 57 | prompt_template (jinja2.Template): A Jinja2 template that contains the 58 | prompt to send to OpenAI's GPT-3 API. 59 | model (str): The GPT-3 model ID to use for completion. Check the OpenAI 60 | documentation for more information https://beta.openai.com/docs/models/overview. 61 | labels (List[str]): List of labels for annotation. 62 | max_examples (int): The maximum number of examples to stream in the Prodigy UI. 63 | segment (bool): If set to True, segment the documents into sentences. 64 | verbose (bool): Show verbose output in the command-line, including the prompt and response from OpenAI. 65 | openai_api_org (str): The OpenAI API organization. 66 | openai_api_key (str): The OpenAI API key. 67 | openai_temperature (float): The temperature parameter (from 0 to 1) that controls the 68 | randomness of GPT-3's output. 
69 | openai_max_tokens (int): The maximum amout of tokens that GPT-3's 70 | completion API can generate. 71 | openai_n (int): The number of completions to generate for each prompt. 72 | openai_n_retries (int): The number of retries whenever a 429 error occurs. 73 | openai_retry_timeout_s (int): The amount of time before attempting another request whenever we 74 | encounter a 429 error. Increases exponentially for each retry. 75 | openai_read_timeout_s (int): The amount of time to wait a response output during a request. 76 | examples (List[PromptExample]): A list of examples to add to the prompt to guide GPT-3 output. 77 | response_parser (Callable[str] -> Dict): A function that accepts a string that represents 78 | GPT-3's raw response, and parses it into a dictionary that is compatible to Prodigy's 79 | annotation interfaces. 80 | render_vars (Dict[str, Any]): A dictionary containing additional variables to render in the 81 | Jinja2 template. By default, the Jinja2 template will render the text (str), some labels (List[str]), 82 | and examples (PromptExample). If you wish to add other task-specific variables, you should supply 83 | them in this variable. 84 | """ 85 | 86 | prompt_template: jinja2.Template 87 | model: str 88 | labels: List[str] 89 | max_examples: int 90 | segment: bool 91 | verbose: bool 92 | openai_api_org: str 93 | openai_api_key: str 94 | openai_temperature: float 95 | openai_max_tokens: int 96 | openai_retry_timeout_s: int 97 | openai_read_timeout_s: int 98 | openai_n_retries: int 99 | openai_n: int 100 | examples: List[PromptExample] 101 | response_parser: Callable 102 | render_vars: Dict[str, Any] 103 | prompt_example_class: PromptExample 104 | 105 | OPENAI_COMPLETIONS_ENDPOINT: str = "https://api.openai.com/v1/completions" 106 | RETRY_ERROR_CODES: List[int] = [429, 503] 107 | 108 | def __init__( 109 | self, 110 | prompt_template: jinja2.Template, 111 | *, 112 | labels: List[str], 113 | max_examples: int, 114 | segment: bool, 115 | openai_api_org: str, 116 | openai_api_key: str, 117 | openai_model: str, 118 | response_parser: Callable, 119 | prompt_example_class: PromptExample, 120 | openai_temperature: int = 0, 121 | openai_max_tokens: int = 500, 122 | openai_retry_timeout_s: int = 1, 123 | openai_read_timeout_s: int = 30, 124 | openai_n_retries: int = 10, 125 | openai_n: int = 1, 126 | render_vars: Optional[Dict[str, Any]] = None, 127 | verbose: bool = False, 128 | ): 129 | self.prompt_template = prompt_template 130 | self.model = openai_model 131 | self.labels = [normalize_label(label) for label in labels] 132 | self.max_examples = max_examples 133 | self.verbose = verbose 134 | self.segment = segment 135 | self.examples = [] 136 | self.openai_api_org = openai_api_org 137 | self.openai_api_key = openai_api_key 138 | self.openai_temperature = openai_temperature 139 | self.openai_max_tokens = openai_max_tokens 140 | self.openai_retry_timeout_s = openai_retry_timeout_s 141 | self.openai_read_timeout_s = openai_read_timeout_s 142 | self.openai_n = openai_n 143 | self.openai_n_retries = openai_n_retries 144 | self.response_parser = response_parser 145 | self.prompt_example_class = prompt_example_class 146 | self.render_vars = {} if render_vars is None else render_vars 147 | 148 | def __call__( 149 | self, 150 | stream: Iterable[_ItemT], 151 | *, 152 | nlp: Language, 153 | batch_size: int, 154 | **kwargs, 155 | ) -> Iterable[_ItemT]: 156 | if self.segment: 157 | stream = preprocess.split_sentences(nlp, stream) # type: ignore 158 | 159 | stream = self.pipe(stream, nlp, 
batch_size, **kwargs) 160 | stream = self.set_hashes(stream) 161 | return stream 162 | 163 | def pipe( 164 | self, stream: Iterable[_ItemT], nlp: Language, batch_size: int, **kwargs 165 | ) -> Iterable[_ItemT]: 166 | """Process the stream and add suggestions from OpenAI.""" 167 | stream = self.stream_suggestions(stream, batch_size) 168 | stream = self.format_suggestions(stream, nlp=nlp) 169 | return stream 170 | 171 | def set_hashes(self, stream: Iterable[_ItemT]) -> Iterable[_ItemT]: 172 | for example in stream: 173 | yield set_hashes(example) 174 | 175 | def update(self, examples: Iterable[_ItemT]) -> float: 176 | """Update the examples that will be used in the prompt based on user flags.""" 177 | for eg in examples: 178 | if PromptExample.is_flagged(eg): 179 | self.add_example( 180 | self.prompt_example_class.from_prodigy(eg, self.labels) 181 | ) 182 | return 0.0 183 | 184 | def add_example(self, example: PromptExample) -> None: 185 | """Add an example for use in the prompts. Examples are pruned to the most recent max_examples.""" 186 | if self.max_examples and example: 187 | self.examples.append(example) 188 | if len(self.examples) > self.max_examples: 189 | self.examples = self.examples[-self.max_examples :] 190 | 191 | def stream_suggestions( 192 | self, stream: Iterable[_ItemT], batch_size: int 193 | ) -> Iterable[_ItemT]: 194 | """Get zero-shot or few-shot annotations from OpenAI. 195 | Given a stream of input examples, we define a prompt, get a response from OpenAI, 196 | and yield each example with their predictions to the output stream. 197 | """ 198 | for batch in batch_sequence(stream, batch_size): 199 | prompts = [ 200 | self._get_prompt(eg["text"], labels=self.labels, examples=self.examples) 201 | for eg in batch 202 | ] 203 | responses = self._get_openai_response(prompts) 204 | for eg, prompt, response in zip(batch, prompts, responses): 205 | if self.verbose: 206 | rich.print(Panel(prompt, title="Prompt to OpenAI")) 207 | eg["openai"] = {"prompt": prompt, "response": response} 208 | if self.verbose: 209 | rich.print(Panel(response, title="Response from OpenAI")) 210 | yield eg 211 | 212 | def format_suggestions( 213 | self, stream: Iterable[_ItemT], *, nlp: Language 214 | ) -> Iterable[_ItemT]: 215 | """Parse the examples in the stream and set up labels 216 | to display in the Prodigy UI. 
217 | """ 218 | stream = preprocess.add_tokens(nlp, stream, skip=True) # type: ignore 219 | for example in stream: 220 | example = copy.deepcopy(example) 221 | if "meta" not in example: 222 | example["meta"] = {} 223 | 224 | response = example["openai"].get("response", "") 225 | example.update(self.response_parser(response, example)) 226 | yield example 227 | 228 | def _get_prompt( 229 | self, text: str, labels: List[str], examples: List[PromptExample] 230 | ) -> str: 231 | """Generate a prompt for GPT-3 OpenAI.""" 232 | return self.prompt_template.render( 233 | text=text, labels=labels, examples=examples, **self.render_vars 234 | ) 235 | 236 | def _get_openai_response(self, prompts: List[str]) -> List[str]: 237 | headers = { 238 | "Authorization": f"Bearer {self.openai_api_key}", 239 | "OpenAI-Organization": self.openai_api_org, 240 | "Content-Type": "application/json", 241 | } 242 | r = retry( 243 | lambda: httpx.post( 244 | self.OPENAI_COMPLETIONS_ENDPOINT, 245 | headers=headers, 246 | json={ 247 | "model": self.model, 248 | "prompt": prompts, 249 | "temperature": self.openai_temperature, 250 | "max_tokens": self.openai_max_tokens, 251 | "n": self.openai_n, 252 | }, 253 | timeout=self.openai_read_timeout_s, 254 | ), 255 | n=self.openai_n_retries, 256 | timeout_s=self.openai_retry_timeout_s, 257 | error_codes=self.RETRY_ERROR_CODES, 258 | ) 259 | r.raise_for_status() 260 | responses = r.json() 261 | return [responses["choices"][i]["text"] for i in range(len(prompts))] 262 | 263 | 264 | def get_api_credentials(model: Optional[str] = None) -> Tuple[str, str]: 265 | # Fetch and check the key 266 | api_key = os.getenv("OPENAI_KEY") 267 | if api_key is None: 268 | m = ( 269 | "Could not find the API key to access the openai API. Ensure you have an API key " 270 | "set up via https://beta.openai.com/account/api-keys, then make it available as " 271 | "an environment variable 'OPENAI_KEY', for instance in a .env file." 272 | ) 273 | msg.fail(m) 274 | sys.exit(-1) 275 | # Fetch and check the org 276 | org = os.getenv("OPENAI_ORG") 277 | if org is None: 278 | m = ( 279 | "Could not find the organisation to access the openai API. Ensure you have an API key " 280 | "set up via https://beta.openai.com/account/api-keys, obtain its organization ID 'org-XXX' " 281 | "via https://beta.openai.com/account/org-settings, then make it available as " 282 | "an environment variable 'OPENAI_ORG', for instance in a .env file." 283 | ) 284 | msg.fail(m) 285 | sys.exit(-1) 286 | 287 | # Check the access and get a list of available models to verify the model argument (if not None) 288 | # Even if the model is None, this call is used as a healthcheck to verify access. 289 | headers = { 290 | "Authorization": f"Bearer {api_key}", 291 | "OpenAI-Organization": org, 292 | } 293 | r = retry( 294 | lambda: httpx.get( 295 | "https://api.openai.com/v1/models", 296 | headers=headers, 297 | ), 298 | n=1, 299 | timeout_s=1, 300 | ) 301 | if r.status_code == 422: 302 | m = ( 303 | "Could not access api.openai.com -- 422 permission denied." 304 | "Visit https://beta.openai.com/account/api-keys to check your API keys." 305 | ) 306 | msg.fail(m) 307 | sys.exit(-1) 308 | elif r.status_code != 200: 309 | m = "Error accessing api.openai.com" f"{r.status_code}: {r.text}" 310 | msg.fail(m) 311 | sys.exit(-1) 312 | 313 | if model is not None: 314 | response = r.json()["data"] 315 | models = [response[i]["id"] for i in range(len(response))] 316 | if model not in models: 317 | e = f"The specified model '{model}' is not available. 
Choices are: {sorted(set(models))}" 318 | msg.fail(e, exits=1) 319 | 320 | return api_key, org 321 | 322 | 323 | def read_prompt_examples( 324 | path: Optional[Path], *, example_class: Type[PromptExample] 325 | ) -> List[PromptExample]: 326 | if path is None: 327 | return [] 328 | elif path.suffix in (".yml", ".yaml"): 329 | return read_yaml_examples(path, example_class=example_class) 330 | elif path.suffix == ".json": 331 | data = srsly.read_json(path) 332 | assert isinstance(data, list) 333 | return [PromptExample(**eg) for eg in data] 334 | else: 335 | msg.fail( 336 | "The --examples-path (-e) parameter expects a .yml, .yaml or .json file." 337 | ) 338 | sys.exit(-1) 339 | 340 | 341 | def load_template(path: Path) -> jinja2.Template: 342 | # I know jinja has a lot of complex file loading stuff, 343 | # but we're not using the inheritance etc that makes 344 | # that stuff worthwhile. 345 | if not path.suffix == ".jinja2": 346 | msg.fail( 347 | "The --prompt-path (-p) parameter expects a .jinja2 file.", 348 | exits=1, 349 | ) 350 | with path.open("r", encoding="utf8") as file_: 351 | text = file_.read() 352 | return jinja2.Template(text, undefined=jinja2.DebugUndefined) 353 | 354 | 355 | def retry( 356 | call_api: Callable[[], httpx.Response], 357 | n: int, 358 | timeout_s: int, 359 | error_codes: List[int] = [429, 503], 360 | ) -> httpx.Response: 361 | """Retry a call to the OpenAI API if we get a non-ok status code. 362 | This function automatically retries a request if it catches a response 363 | with an error code in `error_codes`. The amount of timeout also increases 364 | exponentially every time we retry. 365 | """ 366 | assert n >= 0 367 | assert timeout_s >= 1 368 | r = call_api() 369 | i = -1 370 | # We don't want to retry on every non-ok status code. Some are about 371 | # incorrect inputs, etc. and we want to terminate on those. 372 | while i < n and r.status_code in error_codes: 373 | time.sleep(timeout_s) 374 | i += 1 375 | timeout_s = timeout_s * 2 # Increase timeout everytime you retry 376 | msg.text( 377 | f"Retrying call (retries left: {n-i}, timeout: {timeout_s}s). " 378 | f"Previous call returned: {r.status_code} ({r.reason_phrase})" 379 | ) 380 | return r 381 | 382 | 383 | def read_yaml_examples( 384 | path: Path, *, example_class: Type[PromptExample] 385 | ) -> List[PromptExample]: 386 | data = srsly.read_yaml(path) 387 | if not isinstance(data, list): 388 | msg.fail("Cannot interpret prompt examples from yaml", exits=True) 389 | assert isinstance(data, list) 390 | output = [example_class(**eg) for eg in data] 391 | return output 392 | 393 | 394 | def batch_sequence(items: Iterable[_ItemT], batch_size: int) -> Iterable[List[_ItemT]]: 395 | batch = [] 396 | for eg in items: 397 | batch.append(eg) 398 | if len(batch) >= batch_size: 399 | yield batch 400 | batch = [] 401 | if batch: 402 | yield batch 403 | -------------------------------------------------------------------------------- /recipes/openai_ab.py: -------------------------------------------------------------------------------- 1 | """A/B evaluation of OpenAI responses, for prompt engineering. 2 | 3 | A/B evaluation is basically a blind "taste test". Results are produced 4 | from two experimental conditions (in this case, responses from two 5 | prompts). The annotator is shown the response pair, without knowing which 6 | condition produced which response. The annotator marks which one is better, 7 | and the response is recorded. 
8 | 9 | At the end of annotation, the results are tallied up, so you can see whether 10 | one condition produces better results than the other. This lets you apply 11 | a sound methodology to subjective decisions. 12 | """ 13 | from collections import Counter 14 | import os 15 | import random 16 | import sys 17 | import time 18 | from pathlib import Path 19 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, cast 20 | 21 | import httpx 22 | import jinja2 23 | import rich 24 | from rich.panel import Panel 25 | 26 | import prodigy 27 | import pydantic 28 | import srsly 29 | from dotenv import load_dotenv 30 | from prodigy.util import msg 31 | 32 | _ItemT = TypeVar("_ItemT") 33 | # Set up openai 34 | load_dotenv() # take environment variables from .env. 35 | 36 | 37 | class PromptInput(pydantic.BaseModel): 38 | id: str 39 | prompt_args: Dict[str, Any] 40 | 41 | 42 | class OpenAIPromptAB: 43 | display: jinja2.Template 44 | prompts: Dict[str, jinja2.Template] 45 | inputs: Iterable[PromptInput] 46 | batch_size: int 47 | verbose: bool 48 | randomize: bool 49 | openai_api_org: str 50 | openai_api_key: str 51 | openai_temperature: float 52 | openai_max_tokens: int 53 | openai_timeout_s: int 54 | openai_n: int 55 | 56 | def __init__( 57 | self, 58 | display: jinja2.Template, 59 | prompts: Dict[str, jinja2.Template], 60 | inputs: Iterable[PromptInput], 61 | *, 62 | openai_api_org: str, 63 | openai_api_key: str, 64 | openai_model: str, 65 | batch_size: int = 10, 66 | verbose: bool = False, 67 | randomize: bool = True, 68 | openai_temperature: float = 0, 69 | openai_max_tokens: int = 500, 70 | openai_timeout_s: int = 50, 71 | openai_n: int = 1, 72 | repeat: int = 3 73 | ): 74 | self.display = display 75 | self.inputs = inputs 76 | self.prompts = prompts 77 | self.model = openai_model 78 | self.batch_size = batch_size 79 | self.verbose = verbose 80 | self.openai_api_org = openai_api_org 81 | self.openai_api_key = openai_api_key 82 | self.openai_temperature = openai_temperature 83 | self.openai_max_tokens = openai_max_tokens 84 | self.openai_timeout_s = openai_timeout_s 85 | self.openai_n = openai_n 86 | self.randomize = randomize 87 | self.repeat = repeat 88 | 89 | def __iter__(self) -> Iterable[Dict]: 90 | for input_batch in _batch_sequence(self.inputs, self.batch_size): 91 | for _ in range(self.repeat): 92 | response_batch = self._get_response_batch(input_batch) 93 | for input_, responses in zip(input_batch, response_batch): 94 | yield self._make_example( 95 | input_.id, 96 | self.display.render(**input_.prompt_args), 97 | responses, 98 | randomize=self.randomize, 99 | prompt_args=input_.prompt_args 100 | ) 101 | 102 | def on_exit(self, ctrl): 103 | examples = ctrl.db.get_dataset_examples(ctrl.dataset) 104 | counts = Counter({k: 0 for k in self.prompts.keys()}) 105 | # Get last example per ID 106 | for eg in examples: 107 | selected = eg.get("accept", []) 108 | if not selected or len(selected) != 1 or eg["answer"] != "accept": 109 | continue 110 | counts[selected[0]] += 1 111 | print("") 112 | if not counts: 113 | msg.warn("No answers found", exits=0) 114 | msg.divider("Evaluation results", icon="emoji") 115 | # Handle edge case when both are equal: 116 | nr1, nr2 = counts.most_common(2) 117 | if nr1[1] == nr2[1]: 118 | msg.good(f"It's a draw!") 119 | else: 120 | pref, _ = nr1 121 | msg.good(f"You preferred {pref}") 122 | rows = [(name, count) for name, count in counts.most_common()] 123 | msg.table(rows, aligns=("l", "r")) 124 | 125 | def _get_response_batch(self, 
inputs: List[PromptInput]) -> List[Dict[str, str]]: 126 | name1, name2 = self._choose_rivals() 127 | prompts = [] 128 | for input_ in inputs: 129 | prompts.append(self._get_prompt(name1, input_.prompt_args)) 130 | prompts.append(self._get_prompt(name2, input_.prompt_args)) 131 | if self.verbose: 132 | for prompt in prompts: 133 | rich.print(Panel(prompt, title="Prompt to OpenAI")) 134 | responses = self._get_responses(prompts) 135 | assert len(responses) == len(inputs) * 2 136 | output = [] 137 | # Pair out the responses. There's a fancy 138 | # zip way to do this but I think that's less 139 | # readable 140 | for i in range(0, len(responses), 2): 141 | output.append({name1: responses[i], name2: responses[i + 1]}) 142 | return output 143 | 144 | def _choose_rivals(self) -> Tuple[str, str]: 145 | assert len(self.prompts) == 2 146 | return tuple(sorted(self.prompts.keys())) 147 | 148 | def _get_prompt(self, name: str, args: Dict) -> str: 149 | return self.prompts[name].render(**args) 150 | 151 | def _get_responses(self, prompts: List[str]) -> List[str]: 152 | headers = { 153 | "Authorization": f"Bearer {self.openai_api_key}", 154 | "OpenAI-Organization": self.openai_api_org, 155 | "Content-Type": "application/json", 156 | } 157 | r = _retry429( 158 | lambda: httpx.post( 159 | "https://api.openai.com/v1/completions", 160 | headers=headers, 161 | json={ 162 | "model": self.model, 163 | "prompt": prompts, 164 | "temperature": self.openai_temperature, 165 | "max_tokens": self.openai_max_tokens, 166 | }, 167 | timeout=self.openai_timeout_s, 168 | ), 169 | n=self.openai_n, 170 | timeout_s=self.openai_timeout_s, 171 | ) 172 | r.raise_for_status() 173 | responses = r.json() 174 | return [responses["choices"][i]["text"].strip() for i in range(len(prompts))] 175 | 176 | def _make_example( 177 | self, id_str: str, input: str, responses: Dict[str, str], randomize: bool, prompt_args: Dict[str, Any] 178 | ) -> Dict: 179 | 180 | question = { 181 | "id": id_str, 182 | "text": input, 183 | "options": [], 184 | } 185 | response_pairs = list(responses.items()) 186 | if randomize: 187 | random.shuffle(response_pairs) 188 | else: 189 | response_pairs = list(sorted(response_pairs)) 190 | for name, value in response_pairs: 191 | question["options"].append({"id": name, "text": value}) 192 | question["meta"] = prompt_args 193 | return question 194 | 195 | 196 | @prodigy.recipe( 197 | # fmt: off 198 | "ab.openai.prompts", 199 | dataset=("Dataset to save answers to", "positional", None, str), 200 | inputs_path=("Path to jsonl inputs", "positional", None, Path), 201 | display_template_path=("Template for summarizing the arguments","positional", None, Path), 202 | prompt1_template_path=("Path to the first jinja2 prompt template","positional", None, Path), 203 | prompt2_template_path=("Path to the second jinja2 prompt template","positional", None, Path), 204 | model=("GPT-3 model to use for responses", "option", "m", str), 205 | batch_size=("Batch size to send to OpenAI API", "option", "b", int), 206 | verbose=("Print extra information to terminal", "flag", "v", bool), 207 | no_random=("Don't randomize which annotation is shown as correct","flag","NR",bool,), 208 | repeat=("How often to send the same prompt to OpenAI", "option", "r", int) 209 | # fmt: on 210 | ) 211 | def ab_openai_prompts( 212 | dataset: str, 213 | inputs_path: Path, 214 | display_template_path: Path, 215 | prompt1_template_path: Path, 216 | prompt2_template_path: Path, 217 | model: str = "text-davinci-003", 218 | batch_size: int = 10, 219 | verbose: 
bool = False, 220 | no_random: bool = False, 221 | repeat: int = 1, 222 | ): 223 | api_key, api_org = _get_api_credentials(model) 224 | inputs = [PromptInput(**x) for x in cast(List[Dict], srsly.read_jsonl(inputs_path))] 225 | 226 | display = _load_template(display_template_path) 227 | prompt1 = _load_template(prompt1_template_path) 228 | prompt2 = _load_template(prompt2_template_path) 229 | stream = OpenAIPromptAB( 230 | display=display, 231 | prompts={ 232 | prompt1_template_path.name: prompt1, 233 | prompt2_template_path.name: prompt2, 234 | }, 235 | inputs=inputs, 236 | openai_api_org=api_org, 237 | openai_api_key=api_key, 238 | openai_model=model, 239 | batch_size=batch_size, 240 | verbose=verbose, 241 | randomize=not no_random, 242 | openai_temperature=0.9, 243 | repeat=repeat 244 | ) 245 | return { 246 | "dataset": dataset, 247 | "view_id": "choice", 248 | "stream": stream, 249 | "on_exit": stream.on_exit, 250 | "config": { 251 | "batch_size": batch_size, 252 | "choice_auto_accept": True, 253 | "exclude_by": "input", 254 | "global_css": ".prodigy-content{line-height: 1.2;};" 255 | }, 256 | } 257 | 258 | def _get_api_credentials(model: str) -> Tuple[str, str]: 259 | # Fetch and check the key 260 | api_key = os.getenv("OPENAI_KEY") 261 | if api_key is None: 262 | m = ( 263 | "Could not find the API key to access the openai API. Ensure you have an API key " 264 | "set up via https://beta.openai.com/account/api-keys, then make it available as " 265 | "an environment variable 'OPENAI_KEY', for instance in a .env file." 266 | ) 267 | msg.fail(m) 268 | sys.exit(-1) 269 | # Fetch and check the org 270 | org = os.getenv("OPENAI_ORG") 271 | if org is None: 272 | m = ( 273 | "Could not find the organisation to access the openai API. Ensure you have an API key " 274 | "set up via https://beta.openai.com/account/api-keys, obtain its organization ID 'org-XXX' " 275 | "via https://beta.openai.com/account/org-settings, then make it available as " 276 | "an environment variable 'OPENAI_ORG', for instance in a .env file." 277 | ) 278 | msg.fail(m) 279 | sys.exit(-1) 280 | 281 | # Check the access and get a list of available models to verify the model argument (if not None) 282 | # Even if the model is None, this call is used as a healthcheck to verify access. 283 | headers = { 284 | "Authorization": f"Bearer {api_key}", 285 | "OpenAI-Organization": org, 286 | } 287 | r = _retry429( 288 | lambda: httpx.get( 289 | "https://api.openai.com/v1/models", 290 | headers=headers, 291 | ), 292 | n=1, 293 | timeout_s=1, 294 | ) 295 | if r.status_code == 422: 296 | m = ( 297 | "Could not access api.openai.com -- 422 permission denied." 298 | "Visit https://beta.openai.com/account/api-keys to check your API keys." 299 | ) 300 | msg.fail(m) 301 | sys.exit(-1) 302 | elif r.status_code != 200: 303 | m = "Error accessing api.openai.com" f"{r.status_code}: {r.text}" 304 | msg.fail(m) 305 | sys.exit(-1) 306 | 307 | if model is not None: 308 | response = r.json()["data"] 309 | models = [response[i]["id"] for i in range(len(response))] 310 | if model not in models: 311 | e = f"The specified model '{model}' is not available. Choices are: {sorted(set(models))}" 312 | msg.fail(e, exits=1) 313 | 314 | return api_key, org 315 | 316 | 317 | def _load_template(path: Path) -> jinja2.Template: 318 | # I know jinja has a lot of complex file loading stuff, 319 | # but we're not using the inheritance etc that makes 320 | # that stuff worthwhile. 
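# For this recipe a plain jinja2.Template is enough: the display and prompt
# templates under templates/ab/ only use simple {{variable}} substitution.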
321 | if not path.suffix == ".jinja2": 322 | msg.fail( 323 | f"The parameter expects a .jinja2 file. Received path='{path}'", 324 | exits=1, 325 | ) 326 | with path.open("r", encoding="utf8") as file_: 327 | text = file_.read() 328 | return jinja2.Template(text) 329 | 330 | 331 | def _retry429( 332 | call_api: Callable[[], httpx.Response], n: int, timeout_s: int 333 | ) -> httpx.Response: 334 | """Retry a call to the OpenAI API if we get a 429: Too many requests 335 | error. 336 | """ 337 | assert n >= 0 338 | assert timeout_s >= 1 339 | r = call_api() 340 | i = -1 341 | while i < n and r.status_code == 429: 342 | time.sleep(timeout_s) 343 | i += 1 344 | return r 345 | 346 | 347 | def _batch_sequence(items: Iterable[_ItemT], batch_size: int) -> Iterable[List[_ItemT]]: 348 | batch = [] 349 | for eg in items: 350 | batch.append(eg) 351 | if len(batch) >= batch_size: 352 | yield batch 353 | batch = [] 354 | if batch: 355 | yield batch 356 | -------------------------------------------------------------------------------- /recipes/openai_ner.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | from typing import Dict, Iterable, List, Optional, Tuple, cast, Callable 4 | from dataclasses import dataclass 5 | 6 | import prodigy 7 | import prodigy.components.db 8 | import prodigy.components.preprocess 9 | import prodigy.util 10 | import spacy 11 | import srsly 12 | import tqdm 13 | from dotenv import load_dotenv 14 | from spacy.language import Language 15 | from spacy.util import filter_spans 16 | 17 | from recipes.openai import OpenAISuggester, PromptExample, get_api_credentials 18 | from recipes.openai import load_template, normalize_label 19 | from recipes.openai import read_prompt_examples 20 | 21 | CSS_FILE_PATH = Path(__file__).parent / "style.css" 22 | DEFAULT_PROMPT_PATH = Path(__file__).parent.parent / "templates" / "ner_prompt.jinja2" 23 | HTML_TEMPLATE = """ 24 |
25 |   <details>
26 |     <summary>Show the prompt for OpenAI</summary>
27 |     <pre>{{openai.prompt}}</pre>
28 |   </details>
29 |   <details>
30 |     <summary>Show the response from OpenAI</summary>
31 |     <pre>{{openai.response}}</pre>
32 |   </details>
33 | </div>
34 | """ 35 | 36 | # Set up openai 37 | load_dotenv() # take environment variables from .env. 38 | 39 | 40 | @dataclass 41 | class NERPromptExample(PromptExample): 42 | """An example to be passed into an OpenAI NER prompt""" 43 | 44 | text: str 45 | entities: Dict[str, List[str]] 46 | 47 | @classmethod 48 | def from_prodigy(cls, example: Dict, labels: Iterable[str]) -> "PromptExample": 49 | """Create a prompt example from Prodigy's format. 50 | Only entities with a label from the given set will be retained. 51 | The given set of labels is assumed to be already normalized. 52 | """ 53 | if "text" not in example: 54 | raise ValueError("Cannot make PromptExample without text") 55 | entities_by_label = defaultdict(list) 56 | full_text = example["text"] 57 | for span in example.get("spans", []): 58 | label = normalize_label(span["label"]) 59 | if label in labels: 60 | mention = full_text[int(span["start"]) : int(span["end"])] 61 | entities_by_label[label].append(mention) 62 | 63 | return cls(text=full_text, entities=entities_by_label) 64 | 65 | 66 | def make_ner_response_parser(labels: List[str], lang: str) -> Callable: 67 | 68 | nlp = spacy.blank(lang) 69 | 70 | def _parse_response(text: str, example: Optional[Dict] = None) -> Dict: 71 | """Interpret OpenAI's NER response. It's supposed to be 72 | a list of lines, with each line having the form: 73 | Label: phrase1, phrase2, ... 74 | However, there's no guarantee that the model will give 75 | us well-formed output. It could say anything, it's an LM. 76 | So we need to be robust. 77 | """ 78 | output = [] 79 | for line in text.strip().split("\n"): 80 | if line and ":" in line: 81 | label, phrases = line.split(":", 1) 82 | label = normalize_label(label) 83 | if label in labels: 84 | if phrases.strip(): 85 | phrases = [ 86 | phrase.strip() for phrase in phrases.strip().split(",") 87 | ] 88 | output.append((label, phrases)) 89 | 90 | example = _fmt_response(output, example) 91 | return example 92 | 93 | def _fmt_response(response: List[Tuple[str, List[str]]], example: Dict): 94 | doc = nlp.make_doc(example["text"]) 95 | spacy_spans = [] 96 | for label, phrases in response: 97 | label = normalize_label(label) 98 | if label in labels: 99 | offsets = _find_substrings(doc.text, phrases) 100 | for start, end in offsets: 101 | span = doc.char_span( 102 | start, end, alignment_mode="contract", label=label 103 | ) 104 | if span is not None: 105 | spacy_spans.append(span) 106 | # This step prevents the same token from being used in multiple spans. 107 | # If there's a conflict, the longer span is preserved. 
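# For example, if the model suggests both "New York" and "York" over the same
# tokens, only the longer "New York" span survives the filtering below.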
108 | spacy_spans = filter_spans(spacy_spans) 109 | spans = [ 110 | { 111 | "label": span.label_, 112 | "start": span.start_char, 113 | "end": span.end_char, 114 | "token_start": span.start, 115 | "token_end": span.end - 1, 116 | } 117 | for span in spacy_spans 118 | ] 119 | return {"spans": spans} 120 | 121 | return _parse_response 122 | 123 | 124 | @prodigy.recipe( 125 | "ner.openai.correct", 126 | dataset=("Dataset to save answers to", "positional", None, str), 127 | filepath=("Path to jsonl data to annotate", "positional", None, Path), 128 | labels=("Labels (comma delimited)", "positional", None, lambda s: s.split(",")), 129 | model=("GPT-3 model to use for initial predictions", "option", "m", str), 130 | examples_path=("Path to examples to help define the task", "option", "e", Path), 131 | lang=("Language to use for tokenizer", "option", "l", str), 132 | max_examples=("Max examples to include in prompt", "option", "n", int), 133 | prompt_path=("Path to jinja2 prompt template", "option", "p", Path), 134 | batch_size=("Batch size to send to OpenAI API", "option", "b", int), 135 | segment=("Split articles into sentences", "flag", "S", bool), 136 | verbose=("Print extra information to terminal", "flag", "v", bool), 137 | ) 138 | def ner_openai_correct( 139 | dataset: str, 140 | filepath: Path, 141 | labels: List[str], 142 | lang: str = "en", 143 | model: str = "text-davinci-003", 144 | batch_size: int = 10, 145 | segment: bool = False, 146 | examples_path: Optional[Path] = None, 147 | prompt_path: Path = DEFAULT_PROMPT_PATH, 148 | max_examples: int = 2, 149 | verbose: bool = False, 150 | ): 151 | examples = read_prompt_examples(examples_path, example_class=NERPromptExample) 152 | nlp = spacy.blank(lang) 153 | if segment: 154 | nlp.add_pipe("sentencizer") 155 | api_key, api_org = get_api_credentials(model) 156 | openai = OpenAISuggester( 157 | response_parser=make_ner_response_parser(labels=labels, lang=lang), 158 | openai_model=model, 159 | labels=labels, 160 | max_examples=max_examples, 161 | prompt_template=load_template(prompt_path), 162 | segment=segment, 163 | verbose=verbose, 164 | openai_api_org=api_org, 165 | openai_api_key=api_key, 166 | openai_n=1, 167 | openai_retry_timeout_s=10, 168 | openai_read_timeout_s=20, 169 | openai_n_retries=10, 170 | prompt_example_class=NERPromptExample, 171 | ) 172 | for eg in examples: 173 | openai.add_example(eg) 174 | if max_examples >= 1: 175 | db = prodigy.components.db.connect() 176 | db_examples = db.get_dataset(dataset) 177 | if db_examples: 178 | for eg in db_examples: 179 | if PromptExample.is_flagged(eg): 180 | openai.add_example(PromptExample.from_prodigy(eg, openai.labels)) 181 | stream = cast(Iterable[Dict], srsly.read_jsonl(filepath)) 182 | return { 183 | "dataset": dataset, 184 | "view_id": "blocks", 185 | "stream": openai(stream, batch_size=batch_size, nlp=nlp), 186 | "update": openai.update, 187 | "config": { 188 | "labels": openai.labels, 189 | "batch_size": batch_size, 190 | "exclude_by": "input", 191 | "blocks": [ 192 | {"view_id": "ner_manual"}, 193 | {"view_id": "html", "html_template": HTML_TEMPLATE}, 194 | ], 195 | "show_flag": True, 196 | "global_css": CSS_FILE_PATH.read_text(), 197 | }, 198 | } 199 | 200 | 201 | @prodigy.recipe( 202 | "ner.openai.fetch", 203 | input_path=("Path to jsonl data to annotate", "positional", None, Path), 204 | output_path=("Path to save the output", "positional", None, Path), 205 | labels=("Labels (comma delimited)", "positional", None, lambda s: s.split(",")), 206 | lang=("Language to use for 
tokenizer.", "option", "l", str), 207 | model=("GPT-3 model to use for completion", "option", "m", str), 208 | examples_path=("Examples file to help define the task", "option", "e", Path), 209 | max_examples=("Max examples to include in prompt", "option", "n", int), 210 | prompt_path=("Path to jinja2 prompt template", "option", "p", Path), 211 | batch_size=("Batch size to send to OpenAI API", "option", "b", int), 212 | segment=("Split sentences", "flag", "S", bool), 213 | verbose=("Print extra information to terminal", "option", "flag", bool), 214 | ) 215 | def ner_openai_fetch( 216 | input_path: Path, 217 | output_path: Path, 218 | labels: List[str], 219 | lang: str = "en", 220 | model: str = "text-davinci-003", 221 | batch_size: int = 10, 222 | segment: bool = False, 223 | examples_path: Optional[Path] = None, 224 | prompt_path: Path = DEFAULT_PROMPT_PATH, 225 | max_examples: int = 2, 226 | verbose: bool = False, 227 | ): 228 | """Get bulk NER suggestions from an OpenAI API, using zero-shot or few-shot learning. 229 | The results can then be corrected using the `ner.manual` recipe. 230 | This approach lets you get the openai queries out of the way upfront, which can help 231 | if you want to use multiple annotators of if you want to make sure you don't have to 232 | wait on the OpenAI queries. The downside is that you can't flag examples to be integrated 233 | into the prompt during the annotation, unlike the ner.openai.correct recipe. 234 | """ 235 | api_key, api_org = get_api_credentials(model) 236 | examples = read_prompt_examples(examples_path, example_class=NERPromptExample) 237 | nlp = spacy.blank(lang) 238 | if segment: 239 | nlp.add_pipe("sentencizer") 240 | openai = OpenAISuggester( 241 | response_parser=make_ner_response_parser(labels=labels, lang=lang), 242 | openai_model=model, 243 | labels=labels, 244 | max_examples=max_examples, 245 | prompt_template=load_template(prompt_path), 246 | segment=segment, 247 | verbose=verbose, 248 | openai_api_org=api_org, 249 | openai_api_key=api_key, 250 | openai_n=1, 251 | openai_retry_timeout_s=10, 252 | openai_read_timeout_s=20, 253 | openai_n_retries=10, 254 | prompt_example_class=NERPromptExample, 255 | ) 256 | for eg in examples: 257 | openai.add_example(eg) 258 | stream = list(srsly.read_jsonl(input_path)) 259 | stream = openai(tqdm.tqdm(stream), batch_size=batch_size, nlp=nlp) 260 | srsly.write_jsonl(output_path, stream) 261 | 262 | 263 | def _find_substrings( 264 | text: str, 265 | substrings: List[str], 266 | *, 267 | case_sensitive: bool = False, 268 | single_match: bool = False, 269 | ) -> List[Tuple[int, int]]: 270 | """Given a list of substrings, find their character start and end positions in a text. The substrings are assumed to be sorted by the order of their occurrence in the text. 271 | text: The text to search over. 272 | substrings: The strings to find. 273 | case_sensitive: Whether to search without case sensitivity. 274 | single_match: If False, allow one substring to match multiple times in the text. If True, returns the first hit. 275 | """ 276 | # remove empty and duplicate strings, and lowercase everything if need be 277 | substrings = [s for s in substrings if s and len(s) > 0] 278 | if not case_sensitive: 279 | text = text.lower() 280 | substrings = [s.lower() for s in substrings] 281 | substrings = _unique(substrings) 282 | offsets = [] 283 | for substring in substrings: 284 | search_from = 0 285 | # Search until one hit is found. Continue only if single_match is False. 
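# For example, with single_match=False the substring "rock" is reported for every
# occurrence in the text; with single_match=True only the first occurrence is returned.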
286 | while True: 287 | start = text.find(substring, search_from) 288 | if start == -1: 289 | break 290 | end = start + len(substring) 291 | offsets.append((start, end)) 292 | if single_match: 293 | break 294 | search_from = end 295 | return offsets 296 | 297 | 298 | def _unique(items: List[str]) -> List[str]: 299 | """Remove duplicates without changing order""" 300 | seen = set() 301 | output = [] 302 | for item in items: 303 | if item not in seen: 304 | output.append(item) 305 | seen.add(item) 306 | return output 307 | -------------------------------------------------------------------------------- /recipes/openai_terms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from functools import reduce 4 | from pathlib import Path 5 | from typing import Callable, Dict, Iterable, List, TypeVar, Union 6 | 7 | import httpx 8 | import jinja2 9 | import prodigy 10 | import prodigy.components.preprocess 11 | import prodigy.util 12 | import rich 13 | import srsly 14 | from dotenv import load_dotenv 15 | from prodigy import set_hashes 16 | from prodigy.components.db import connect 17 | from prodigy.util import msg 18 | from rich.panel import Panel 19 | from rich.pretty import Pretty 20 | 21 | TERMS_DEFAULT_PROMPT_PATH = ( 22 | Path(__file__).parent.parent / "templates" / "terms_prompt.jinja2" 23 | ) 24 | 25 | # Set up openai 26 | load_dotenv() # take environment variables from .env. 27 | 28 | _ItemT = TypeVar("_ItemT") 29 | 30 | 31 | def _load_template(path: Path) -> jinja2.Template: 32 | # I know jinja has a lot of complex file loading stuff, 33 | # but we're not using the inheritance etc that makes 34 | # that stuff worthwhile. 35 | if not path.suffix == ".jinja2": 36 | msg.fail( 37 | "The --prompt-path (-p) parameter expects a .jinja2 file.", 38 | exits=1, 39 | ) 40 | with path.open("r", encoding="utf8") as file_: 41 | text = file_.read() 42 | return jinja2.Template(text) 43 | 44 | 45 | def _parse_terms(completion: str) -> List[str]: 46 | if "\n" not in completion: 47 | # Sometimes it only returns a single item. For example, when there are 48 | # many, many seeds around. 49 | lines = [completion] 50 | else: 51 | # Other times we cannot assume the final item will have had sufficient 52 | # tokens available to complete the term, so we have to discard it. 53 | lines = [item for item in completion.split("\n") if len(item)] 54 | lines = lines[:-1] 55 | return [item.replace("-", "").strip() for item in lines] 56 | 57 | 58 | def _retry429( 59 | call_api: Callable[[], httpx.Response], n: int, timeout_s: int 60 | ) -> httpx.Response: 61 | """Retry a call to the OpenAI API if we get a 429: Too many requests 62 | error. 
63 | """ 64 | assert n >= 0 65 | assert timeout_s >= 1 66 | r = call_api() 67 | i = -1 68 | while i < n and r.status_code == 429: 69 | time.sleep(timeout_s) 70 | i += 1 71 | return r 72 | 73 | 74 | def _generate_headers() -> Dict[str, str]: 75 | if not os.getenv("OPENAI_KEY"): 76 | msg.fail("The `OPENAI_KEY` is missing from your `.env` file.", exits=1) 77 | 78 | if not os.getenv("OPENAI_ORG"): 79 | msg.fail("The `OPENAI_ORG` is missing from your `.env` file.", exits=1) 80 | 81 | return { 82 | "Authorization": f"Bearer {os.getenv('OPENAI_KEY')}", 83 | "OpenAI-Organization": os.getenv("OPENAI_ORG"), 84 | "Content-Type": "application/json", 85 | } 86 | 87 | 88 | @prodigy.recipe( 89 | # fmt: off 90 | "terms.openai.fetch", 91 | query=("Query to send to OpenAI", "positional", None, str), 92 | output_path=("Path to save the output", "positional", None, Path), 93 | seeds=("One or more comma-seperated seed phrases.","option","s",lambda d: d.split(",")), 94 | n=("The minimum number of items to generate", "option", "n", int), 95 | model=("GPT-3 model to use for completion", "option", "m", str), 96 | prompt_path=("Path to jinja2 prompt template", "option", "p", Path), 97 | verbose=("Print extra information to terminal", "flag", "v", bool), 98 | resume=("Resume by loading in text examples from output file.", "flag", "r", bool), 99 | progress=("Print progress of the recipe.", "flag", "pb", bool), 100 | temperature=("OpenAI temperature param", "option", "t", float), 101 | top_p=("OpenAI top_p param", "option", "tp", float), 102 | best_of=("OpenAI best_of param", "option", "bo", int), 103 | n_batch=("OpenAI batch size param", "option", "nb", int), 104 | max_tokens=("Max tokens to generate per call", "option", "mt", int), 105 | # fmt: on 106 | ) 107 | def terms_openai_fetch( 108 | query: str, 109 | output_path: Path, 110 | seeds: List[str] = [], 111 | n: int = 100, 112 | model: str = "text-davinci-003", 113 | prompt_path: Path = TERMS_DEFAULT_PROMPT_PATH, 114 | verbose: bool = False, 115 | resume: bool = False, 116 | progress: bool = False, 117 | temperature=1.0, 118 | top_p=1.0, 119 | best_of=10, 120 | n_batch=10, 121 | max_tokens=100, 122 | ): 123 | """Get bulk term suggestions from the OpenAI API, using zero-shot learning. 124 | 125 | The results can then be corrected using the `prodigy textcat.manual` recipe and 126 | turned into patterns via `prodigy terms.to-patterns`. 127 | """ 128 | tic = time.time() 129 | template = _load_template(prompt_path) 130 | # The `best_of` param cannot be less than the amount we batch. 131 | if best_of < n_batch: 132 | best_of = n_batch 133 | 134 | # Start collection of terms. If we resume we also fill seed terms with file contents. 
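# For example, if an earlier run already wrote 40 terms to the output file, --resume
# loads them back in, re-uses them as prompt examples, and the loop below only
# continues until the total reaches n.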
135 | terms = [] 136 | if resume: 137 | if output_path.exists(): 138 | examples = srsly.read_jsonl(output_path) 139 | terms.extend([e["text"] for e in examples]) 140 | 141 | # Mimic behavior from Prodigy terms recipe to ensure that seed terms also appear in output 142 | for seed in seeds: 143 | if seed not in terms: 144 | srsly.write_jsonl( 145 | output_path, 146 | [{"text": seed, "meta": {"openai_query": query}}], 147 | append=True, 148 | append_new_line=False, 149 | ) 150 | 151 | # Ensure we have access to correct environment variables and construct headers 152 | if not os.getenv("OPENAI_KEY"): 153 | msg.fail("The `OPENAI_KEY` is missing from your `.env` file.", exits=1) 154 | 155 | if not os.getenv("OPENAI_ORG"): 156 | msg.fail("The `OPENAI_ORG` is missing from your `.env` file.", exits=1) 157 | 158 | headers = _generate_headers() 159 | 160 | # This recipe may overshoot the target, but we keep going until we have at least `n` 161 | while len(terms) < n: 162 | prompt = template.render(n=n, examples=seeds + terms, description=query) 163 | if verbose: 164 | rich.print(Panel(prompt, title="Prompt to OpenAI")) 165 | 166 | make_request = lambda: httpx.post( 167 | "https://api.openai.com/v1/completions", 168 | headers=headers, 169 | json={ 170 | "model": model, 171 | "prompt": [prompt], 172 | "temperature": temperature, 173 | "max_tokens": max_tokens, 174 | "top_p": top_p, 175 | "n": min(n_batch, best_of), 176 | "best_of": best_of, 177 | }, 178 | timeout=45, 179 | ) 180 | 181 | # Catch 429: too many request errors 182 | resp = _retry429(make_request, n=1, timeout_s=30) 183 | 184 | # Report on any other error that might happen, the most typical use-case is 185 | # catching the maximum context length of 4097 tokens when the prompt gets big. 186 | if resp.status_code != 200: 187 | msg.fail(f"Received status code {resp.status_code} from OpenAI. Details:") 188 | rich.print(resp.json()) 189 | exit(code=1) 190 | 191 | # Cast to a set to make sure we remove duplicates 192 | choices = resp.json()["choices"] 193 | sets_of_terms = [set(_parse_terms(c["text"])) for c in choices] 194 | parsed_terms = list(reduce(lambda a, b: a.union(b), sets_of_terms)) 195 | 196 | # Save intermediate results into file, in-case of a hiccup 197 | srsly.write_jsonl( 198 | output_path, 199 | [{"text": t, "meta": {"openai_query": query}} for t in parsed_terms], 200 | append=True, 201 | append_new_line=False, 202 | ) 203 | 204 | # Make the terms list bigger and re-use terms in next prompt. 205 | terms.extend(parsed_terms) 206 | if verbose: 207 | rich.print(Panel(Pretty(terms), title="Terms collected sofar.")) 208 | if progress: 209 | rich.print( 210 | f"Received {len(parsed_terms)} items, totalling {len(terms)} terms. " 211 | f"Progress at {round(len(terms)/n*100)}% after {round(time.time() - tic)}s." 
212 | ) 213 | -------------------------------------------------------------------------------- /recipes/openai_textcat.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pathlib import Path 3 | from typing import Callable, Dict, List, Optional 4 | 5 | import prodigy 6 | import spacy 7 | import srsly 8 | import tqdm 9 | from dotenv import load_dotenv 10 | from prodigy.util import msg 11 | 12 | from recipes.openai import _ItemT, OpenAISuggester, PromptExample, get_api_credentials 13 | from recipes.openai import load_template, normalize_label 14 | from recipes.openai import read_prompt_examples 15 | 16 | CSS_FILE_PATH = Path(__file__).parent / "style.css" 17 | DEFAULT_PROMPT_PATH = ( 18 | Path(__file__).parent.parent / "templates" / "textcat_prompt.jinja2" 19 | ) 20 | HTML_TEMPLATE = """ 21 |
22 | {{ #label }}
23 |   <h2>
24 |     OpenAI GPT-3 says: {{ meta.answer }}
25 |   </h2>
26 | {{ /label }}
27 |   <details>
28 |     <summary>Show the prompt for OpenAI</summary>
29 |     <pre>{{openai.prompt}}</pre>
30 |   </details>
31 |   <details>
32 |     <summary>Show the response from OpenAI</summary>
33 |     <pre>{{openai.response}}</pre>
34 |   </details>
35 | </div>
36 | """ 37 | 38 | # Set up openai 39 | load_dotenv() # take environment variables from .env. 40 | 41 | 42 | @dataclass 43 | class TextCatPromptExample(PromptExample): 44 | """An example to be passed into an OpenAI TextCat prompt.""" 45 | 46 | text: str 47 | answer: str 48 | reason: str 49 | 50 | @classmethod 51 | def from_prodigy(cls, example: _ItemT, labels: List[str]) -> "PromptExample": 52 | """Create a prompt example from Prodigy's format.""" 53 | if "text" not in example: 54 | raise ValueError("Cannot make PromptExample without text") 55 | 56 | full_text = example["text"] 57 | reason = example["meta"].get("reason") 58 | if len(labels) == 1: 59 | answer = example.get("answer", "reject") 60 | else: 61 | answer = ",".join(example.get("accept", [])) 62 | return cls(text=full_text, answer=answer, reason=reason) 63 | 64 | 65 | def make_textcat_response_parser(labels: List[str]) -> Callable: 66 | def _parse_response(text: str, example: Optional[Dict] = None) -> Dict: 67 | response: Dict[str, str] = {} 68 | if text and any(k in text.lower() for k in ("answer", "reason")): 69 | for line in text.strip().split("\n"): 70 | if line and ":" in line: 71 | key, value = line.split(":", 1) 72 | response[key.strip().lower()] = value.strip() 73 | else: 74 | response = {"answer": "", "reason": ""} 75 | 76 | example = _fmt_binary(response) if len(labels) == 1 else _fmt_multi(response) 77 | return example 78 | 79 | def _fmt_binary(response: Dict[str, str]) -> Dict: 80 | """Parse binary TextCat where the 'answer' key means it's a positive class.""" 81 | return { 82 | "answer": response["answer"].lower(), 83 | "label": labels[0], 84 | "meta": { 85 | "answer": response["answer"].upper(), 86 | "reason": response["reason"], 87 | }, 88 | } 89 | 90 | def _fmt_multi(response: Dict[str, str]) -> Dict: 91 | """Parse multilabel TextCat where the 'accept' key is a list of positive labels.""" 92 | return { 93 | "options": [{"id": label, "text": label} for label in labels], 94 | "answer": "accept", 95 | "meta": { 96 | "reason": response.get("reason", ""), 97 | "gpt_answer": response.get("answer", ""), 98 | }, 99 | "accept": list( 100 | filter( 101 | None, 102 | [normalize_label(s.strip()) for s in response["answer"].split(",")], 103 | ) 104 | ), 105 | } 106 | 107 | return _parse_response 108 | 109 | 110 | @prodigy.recipe( 111 | # fmt: off 112 | "textcat.openai.correct", 113 | dataset=("Dataset to save answers to", "positional", None, str), 114 | filepath=("Path to jsonl data to annotate", "positional", None, Path), 115 | labels=("Labels (comma delimited)", "option", "L", lambda s: s.split(",")), 116 | lang=("Language to initialize spaCy model", "option", "l", str), 117 | model=("GPT-3 model to use for completion", "option", "m", str), 118 | batch_size=("Batch size to send to OpenAI API", "option", "b", int), 119 | segment=("Split sentences", "flag", "S", bool), 120 | prompt_path=("Path to the .jinja2 prompt template", "option", "p", Path), 121 | examples_path=("Examples file to help define the task", "option", "e", Path), 122 | max_examples=("Max examples to include in prompt", "option", "n", int), 123 | exclusive_classes=("Make the classification task exclusive", "flag", "E", bool), 124 | verbose=("Print extra information to terminal", "flag", "v", bool), 125 | # fmt: on 126 | ) 127 | def textcat_openai_correct( 128 | dataset: str, 129 | filepath: Path, 130 | labels: List[str], 131 | lang: str = "en", 132 | model: str = "text-davinci-003", 133 | batch_size: int = 10, 134 | segment: bool = False, 135 | prompt_path: Path 
= DEFAULT_PROMPT_PATH, 136 | examples_path: Optional[Path] = None, 137 | max_examples: int = 2, 138 | exclusive_classes: bool = False, 139 | verbose: bool = False, 140 | ): 141 | api_key, api_org = get_api_credentials(model) 142 | examples = read_prompt_examples(examples_path, example_class=TextCatPromptExample) 143 | nlp = spacy.blank(lang) 144 | 145 | if segment: 146 | nlp.add_pipe("sentencizer") 147 | 148 | if not exclusive_classes and len(labels) == 1: 149 | msg.warn( 150 | "Binary classification should always be exclusive. Setting " 151 | "`exclusive_classes` parameter to True" 152 | ) 153 | exclusive_classes = True 154 | 155 | # Create OpenAISuggester with GPT-3 parameters 156 | openai = OpenAISuggester( 157 | response_parser=make_textcat_response_parser(labels=labels), 158 | prompt_template=load_template(prompt_path), 159 | labels=labels, 160 | max_examples=max_examples, 161 | segment=segment, 162 | openai_api_org=api_org, 163 | openai_api_key=api_key, 164 | openai_n=1, 165 | openai_model=model, 166 | openai_retry_timeout_s=10, 167 | openai_read_timeout_s=20, 168 | openai_n_retries=10, 169 | render_vars={"exclusive_classes": exclusive_classes}, 170 | prompt_example_class=TextCatPromptExample, 171 | verbose=verbose, 172 | ) 173 | for eg in examples: 174 | openai.add_example(eg) 175 | if max_examples >= 1: 176 | db = prodigy.components.db.connect() 177 | db_examples = db.get_dataset(dataset) 178 | if db_examples: 179 | for eg in db_examples: 180 | if PromptExample.is_flagged(eg): 181 | openai.add_example(PromptExample.from_prodigy(eg, openai.labels)) 182 | 183 | # Set up the stream 184 | stream = prodigy.get_stream(filepath) 185 | stream = openai(tqdm.tqdm(stream), batch_size=batch_size, nlp=nlp) 186 | 187 | # Set up the Prodigy UI 188 | return { 189 | "dataset": dataset, 190 | "view_id": "blocks", 191 | "stream": stream, 192 | "update": openai.update, 193 | "config": { 194 | "labels": openai.labels, 195 | "batch_size": batch_size, 196 | "exclude_by": "input", 197 | "choice_style": "single" if exclusive_classes else "multiple", 198 | "blocks": [ 199 | {"view_id": "classification" if len(labels) == 1 else "choice"}, 200 | {"view_id": "html", "html_template": HTML_TEMPLATE}, 201 | ], 202 | "show_flag": True, 203 | "global_css": CSS_FILE_PATH.read_text(), 204 | }, 205 | } 206 | 207 | 208 | @prodigy.recipe( 209 | # fmt: off 210 | "textcat.openai.fetch", 211 | filepath=("Path to jsonl data to annotate", "positional", None, Path), 212 | output_path=("Path to save the output", "positional", None, Path), 213 | labels=("Labels (comma delimited)", "option", "L", lambda s: s.split(",")), 214 | lang=("Language to use for tokenizer.", "option", "l", str), 215 | model=("GPT-3 model to use for completion", "option", "m", str), 216 | prompt_path=("Path to jinja2 prompt template", "option", "p", Path), 217 | examples_path=("Examples file to help define the task", "option", "e", Path), 218 | max_examples=("Max examples to include in prompt", "option", "n", int), 219 | batch_size=("Batch size to send to OpenAI API", "option", "b", int), 220 | segment=("Split sentences", "flag", "S", bool), 221 | exclusive_classes=("Make the classification task exclusive", "flag", "E", bool), 222 | verbose=("Print extra information to terminal", "flag", "v", bool), 223 | # fmt: on 224 | ) 225 | def textcat_openai_fetch( 226 | filepath: Path, 227 | output_path: Path, 228 | labels: List[str], 229 | lang: str = "en", 230 | model: str = "text-davinci-003", 231 | batch_size: int = 10, 232 | segment: bool = False, 233 | 
prompt_path: Path = DEFAULT_PROMPT_PATH, 234 | examples_path: Optional[Path] = None, 235 | max_examples: int = 2, 236 | exclusive_classes: bool = False, 237 | verbose: bool = False, 238 | ): 239 | api_key, api_org = get_api_credentials(model) 240 | examples = read_prompt_examples(examples_path, example_class=TextCatPromptExample) 241 | nlp = spacy.blank(lang) 242 | 243 | if segment: 244 | nlp.add_pipe("sentencizer") 245 | 246 | if not exclusive_classes and len(labels) == 1: 247 | msg.warn( 248 | "Binary classification should always be exclusive. Setting " 249 | "`exclusive_classes` parameter to True" 250 | ) 251 | exclusive_classes = True 252 | 253 | # Create OpenAISuggester with GPT-3 parameters 254 | openai = OpenAISuggester( 255 | response_parser=make_textcat_response_parser(labels=labels), 256 | prompt_template=load_template(prompt_path), 257 | labels=labels, 258 | max_examples=max_examples, 259 | segment=segment, 260 | openai_api_org=api_org, 261 | openai_api_key=api_key, 262 | openai_n=1, 263 | openai_model=model, 264 | openai_retry_timeout_s=10, 265 | openai_read_timeout_s=20, 266 | openai_n_retries=10, 267 | render_vars={"exclusive_classes": exclusive_classes}, 268 | prompt_example_class=TextCatPromptExample, 269 | verbose=verbose, 270 | ) 271 | for eg in examples: 272 | openai.add_example(eg) 273 | 274 | # Set up the stream 275 | stream = prodigy.get_stream(filepath) 276 | stream = openai(tqdm.tqdm(stream), batch_size=batch_size, nlp=nlp) 277 | srsly.write_jsonl(output_path, stream) 278 | -------------------------------------------------------------------------------- /recipes/style.css: -------------------------------------------------------------------------------- 1 | .prodigy-container > .prodigy-content { 2 | white-space: normal; 3 | border-top: 1px solid #ddd; 4 | } 5 | 6 | .cleaned { 7 | text-align: left; 8 | font-size: 14px; 9 | } 10 | 11 | .cleaned pre { 12 | background-color: #eee; 13 | font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; 14 | padding: 15px 20px; 15 | border-radius: 15px; 16 | white-space: pre-wrap; 17 | } 18 | 19 | summary { 20 | font-weight: bold; 21 | cursor: pointer; 22 | font-size: 1.2em; 23 | } 24 | 25 | details { 26 | margin-bottom: 1rem; 27 | } 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | httpx 2 | rich==12.6.0 3 | python-dotenv==0.21.0 4 | openai==0.25.0 5 | pytest>=5.2.0,!=7.1.0 6 | pytest-httpx==0.21.3 -------------------------------------------------------------------------------- /requirements_train.txt: -------------------------------------------------------------------------------- 1 | evaluate 2 | seqeval 3 | transformers 4 | torch -------------------------------------------------------------------------------- /scripts/train_hf_ner.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, List, Optional 3 | 4 | import evaluate 5 | import numpy as np 6 | import spacy.vocab 7 | import typer 8 | from spacy.tokens import DocBin 9 | from transformers import ( 10 | AutoModelForTokenClassification, 11 | AutoTokenizer, 12 | DataCollatorForTokenClassification, 13 | Trainer, 14 | TrainingArguments, 15 | ) 16 | from transformers.tokenization_utils_base import BatchEncoding 17 | 18 | app = typer.Typer() 19 | 20 | # This can't be imported like a normal library 21 | 
seqeval = evaluate.load("seqeval") 22 | 23 | 24 | def spacy2hf( 25 | binary_file: Path, 26 | label2id: Dict[str, int], 27 | tokenizer: AutoTokenizer, 28 | ) -> List[BatchEncoding]: 29 | """Given a path to a .spacy file, a label mapping, and an HF tokenizer, 30 | return HF tokens with NER labels. 31 | """ 32 | 33 | db = DocBin().from_disk(binary_file) 34 | 35 | hfdocs = [] 36 | # Normally you would use the vocab from an nlp object, but we just need 37 | # this for deserialization before conversion. 38 | vocab = spacy.vocab.Vocab() 39 | # first, make ids for all labels 40 | for doc in db.get_docs(vocab): 41 | labels = [] 42 | toks = [] 43 | for tok in doc: 44 | toks.append(tok.text) 45 | if tok.ent_type == 0: 46 | labels.append("O") 47 | continue 48 | ent_label = f"{tok.ent_iob_}-{tok.ent_type_}" 49 | if ent_label not in label2id: 50 | label2id[ent_label] = len(label2id) 51 | labels.append(ent_label) 52 | 53 | # now do the hf tokenization 54 | tokens_hf = tokenizer(toks, truncation=True, is_split_into_words=True) 55 | labels_hf = [] 56 | 57 | for word_id in tokens_hf.word_ids(): 58 | if word_id is None: 59 | # for things like [CLS] 60 | labels_hf.append(-100) 61 | else: 62 | # The docs note it's common to assign -100 to subwords after the 63 | # first inside an entity, but this does the simpler thing. 64 | label = label2id[labels[word_id]] 65 | labels_hf.append(label) 66 | tokens_hf["labels"] = labels_hf 67 | 68 | hfdocs.append(tokens_hf) 69 | 70 | return hfdocs 71 | 72 | 73 | def build_compute_metrics(label_list: Dict[int, str]): 74 | def compute_metrics(p): 75 | prediction_matrix, label_matrix = p 76 | prediction_matrix = np.argmax(prediction_matrix, axis=2) 77 | 78 | all_predictions = [ 79 | [label_list[pred] for (pred, label) in zip(predictions, labels) if label != -100] 80 | for predictions, labels in zip(prediction_matrix, label_matrix) 81 | ] 82 | gold_labels = [ 83 | [label_list[label] for (pred, label) in zip(predictions, labels) if label != -100] 84 | for predictions, labels in zip(prediction_matrix, label_matrix) 85 | ] 86 | 87 | results = seqeval.compute(predictions=all_predictions, references=gold_labels, zero_division=0) 88 | return { 89 | "precision": _round_float(results["overall_precision"]), 90 | "recall": _round_float(results["overall_recall"]), 91 | "f1": _round_float(results["overall_f1"]), 92 | "accuracy": _round_float(results["overall_accuracy"]), 93 | } 94 | 95 | return compute_metrics 96 | 97 | 98 | def _round_float(number: float) -> float: 99 | return round(number, 3) 100 | 101 | 102 | def train_ner( 103 | base_model: str, 104 | tokenizer: AutoTokenizer, 105 | id2label: Dict[int, str], 106 | train_data: List[BatchEncoding], 107 | test_data: List[BatchEncoding], 108 | ) -> Trainer: 109 | """Fine-tune an existing HF model.""" 110 | model = AutoModelForTokenClassification.from_pretrained( 111 | base_model, num_labels=len(id2label) 112 | ) 113 | 114 | batch_size = 16 115 | 116 | args = TrainingArguments( 117 | f"test-ner", 118 | evaluation_strategy="epoch", 119 | learning_rate=1e-4, 120 | per_device_train_batch_size=batch_size, 121 | per_device_eval_batch_size=batch_size, 122 | num_train_epochs=20, 123 | weight_decay=1e-5, 124 | disable_tqdm=True, 125 | # specify the optimizer to avoid a deprecation warning 126 | optim="adamw_torch", 127 | ) 128 | 129 | data_collator = DataCollatorForTokenClassification(tokenizer) 130 | compute_metrics = build_compute_metrics(id2label) 131 | 132 | trainer = Trainer( 133 | model, 134 | args, 135 | train_dataset=train_data, 136 | 
eval_dataset=test_data, 137 | data_collator=data_collator, 138 | tokenizer=tokenizer, 139 | compute_metrics=compute_metrics, 140 | ) 141 | 142 | trainer.train() 143 | trainer.evaluate() 144 | return trainer 145 | 146 | 147 | @app.command("train_hf_ner", context_settings={"allow_extra_args": False}) 148 | def train_hf_ner( 149 | # fmt: off 150 | train_file: Path = typer.Argument(..., help="Binary .spacy file containing training data", exists=True, allow_dash=False), 151 | dev_file: Path = typer.Argument(..., help="Binary .spacy file containing dev evaluation data", exists=True, allow_dash=False), 152 | output_path: Optional[Path] = typer.Option(None, "--output", "-o", help="Output directory to store trained pipeline in"), 153 | base_model: str = typer.Option("distilbert-base-uncased", "--base", "-b", help="Base transformer model to start from"), 154 | # fmt: on 155 | ): 156 | """Fine-tune a HuggingFace NER model using a .spacy file as input.""" 157 | # prep the data 158 | tokenizer = AutoTokenizer.from_pretrained(base_model) 159 | label2id = {"O": 0} 160 | train = spacy2hf(train_file, label2id, tokenizer) 161 | test = spacy2hf(dev_file, label2id, tokenizer) 162 | # handle the mapping 163 | id2label = {v: k for k, v in label2id.items()} 164 | # actually train 165 | trainer = train_ner(base_model, tokenizer, id2label, train, test) 166 | if output_path: 167 | trainer.save_model(str(output_path)) 168 | 169 | 170 | if __name__ == "__main__": 171 | app() 172 | -------------------------------------------------------------------------------- /templates/ab/input.jinja2: -------------------------------------------------------------------------------- 1 | A haiku about {{topic}}. -------------------------------------------------------------------------------- /templates/ab/prompt1.jinja2: -------------------------------------------------------------------------------- 1 | Write a haiku about {{topic}}. 2 | -------------------------------------------------------------------------------- /templates/ab/prompt2.jinja2: -------------------------------------------------------------------------------- 1 | Write an incredibly hilarious haiku about {{topic}}. -------------------------------------------------------------------------------- /templates/ner_prompt.jinja2: -------------------------------------------------------------------------------- 1 | From the text below, extract the following entities in the following format: 2 | {# whitespace #} 3 | {%- for label in labels -%} 4 | {{label}}: 5 | {# whitespace #} 6 | {%- endfor -%} 7 | {# whitespace #} 8 | Text: 9 | """ 10 | {{text}} 11 | """ 12 | {# whitespace #} 13 | {%- if examples -%} 14 | {# whitespace #} 15 | For example: 16 | {# whitespace #} 17 | {# whitespace #} 18 | {%- for example in examples -%} 19 | Text: 20 | """ 21 | {{ example.text }} 22 | """ 23 | {# whitespace #} 24 | {%- for label, substrings in example.entities.items() -%} 25 | {{ label }}: {{ ", ".join(substrings) }} 26 | {# whitespace #} 27 | {%- endfor -%} 28 | {# whitespace #} 29 | {% endfor -%} 30 | {%- endif -%} 31 | -------------------------------------------------------------------------------- /templates/terms_prompt.jinja2: -------------------------------------------------------------------------------- 1 | Generate me {{n}} examples of {{description}}. 
2 | 3 | Here are the examples: 4 | {%- if examples -%} 5 | {%- for example in examples -%} 6 | {# whitespace #} 7 | - {{example}} 8 | {%- endfor -%} 9 | {%- endif -%} 10 | {# whitespace #} 11 | - 12 | -------------------------------------------------------------------------------- /templates/textcat_prompt.jinja2: -------------------------------------------------------------------------------- 1 | {% if labels|length == 1 %} 2 | {% set label = labels[0] %} 3 | From the text below, determine whether or not it contains a {{ label }}. If it is 4 | a {{ label }}, answer "accept." If it is not a {{ label }}, answer "reject." 5 | {% else %} 6 | Classify the text below to any of the following labels: {{ labels|join(", ") }} 7 | {% if exclusive_classes %} 8 | The task is exclusive, so only choose one label from what I provided. 9 | {% else %} 10 | The task is non-exclusive, so you can provide more than one label as long as 11 | they're comma-delimited. For example: Label1, Label2, Label3. 12 | {% endif %} 13 | {% endif %} 14 | {# whitespace #} 15 | Your answer should only be in the following format: 16 | {# whitespace #} 17 | answer: 18 | reason: 19 | {# whitespace #} 20 | {% if examples %} 21 | Below are some examples (only use these as a guide): 22 | {# whitespace #} 23 | {# whitespace #} 24 | {% for example in examples %} 25 | Text: 26 | """ 27 | {{ example.text }} 28 | """ 29 | {# whitespace #} 30 | answer: {{ example.answer }} 31 | reason: {{ example.reason }} 32 | {% endfor %} 33 | {% endif %} 34 | {# whitespace #} 35 | Here is the text that needs classification 36 | {# whitespace #} 37 | Text: 38 | """ 39 | {{text}} 40 | """ 41 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/explosion/prodigy-openai-recipes/a0866f0d05f0a744cab15ad202201c5b691d48b9/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_ner.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | from typing import Dict, List, Tuple 4 | 5 | import pytest 6 | import spacy 7 | 8 | from recipes.openai import load_template 9 | from recipes.openai_ner import DEFAULT_PROMPT_PATH, NERPromptExample 10 | from recipes.openai_ner import _find_substrings 11 | from tests.utils import make_suggester 12 | 13 | from recipes.openai_ner import make_ner_response_parser 14 | 15 | 16 | def test_multiple_substrings(): 17 | text = "The Blargs is the debut album by rock band The Blargs." 
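# "The Blargs" occurs at characters 0-10 and 43-53, "rock" at 33-37, so we expect
# three hits without single_match and two with it.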
18 | substrings = ["The Blargs", "rock"] 19 | res = _find_substrings(text, substrings, single_match=False) 20 | assert res == [(0, 10), (43, 53), (33, 37)] 21 | res = _find_substrings(text, substrings, single_match=True) 22 | assert res == [(0, 10), (33, 37)] 23 | 24 | 25 | def test_substrings_case(): 26 | text = "A, a, B, b, a,b,c,d" 27 | substrings = ["a,"] 28 | res = _find_substrings(text, substrings, single_match=False, case_sensitive=True) 29 | assert res == [(3, 5), (12, 14)] 30 | res = _find_substrings(text, substrings, single_match=False, case_sensitive=False) 31 | assert res == [(0, 2), (3, 5), (12, 14)] 32 | res = _find_substrings(text, substrings, single_match=True, case_sensitive=True) 33 | assert res == [(3, 5)] 34 | res = _find_substrings(text, substrings, single_match=True, case_sensitive=False) 35 | assert res == [(0, 2)] 36 | 37 | 38 | def test_template_no_examples(): 39 | text = "David Bowie lived in Berlin in the 1970s." 40 | labels = ["PERSON", "PLACE", "PERIOD"] 41 | examples = [] 42 | path = Path(__file__).parent.parent / "templates" / "ner_prompt.jinja2" 43 | template = load_template(path) 44 | prompt = template.render(text=text, labels=labels, examples=examples) 45 | assert ( 46 | prompt 47 | == f""" 48 | From the text below, extract the following entities in the following format: 49 | PERSON: 50 | PLACE: 51 | PERIOD: 52 | 53 | Text: 54 | \"\"\" 55 | David Bowie lived in Berlin in the 1970s. 56 | \"\"\" 57 | """.lstrip() 58 | ) 59 | 60 | 61 | def test_template_two_examples(): 62 | text = "David Bowie lived in Berlin in the 1970s." 63 | labels = ["PERSON", "PLACE", "PERIOD"] 64 | examples = [ 65 | NERPromptExample( 66 | text="New York is a large city.", entities={"PLACE": ["New York"]} 67 | ), 68 | NERPromptExample( 69 | text="David Hasslehoff and Helena Fischer are big in Germany.", 70 | entities={ 71 | "PERSON": ["David Hasslehoff", "Helena Fischer"], 72 | "PLACE": ["Germany"], 73 | }, 74 | ), 75 | ] 76 | path = Path(__file__).parent.parent / "templates" / "ner_prompt.jinja2" 77 | template = load_template(path) 78 | prompt = template.render(text=text, labels=labels, examples=examples) 79 | assert ( 80 | prompt 81 | == f""" 82 | From the text below, extract the following entities in the following format: 83 | PERSON: 84 | PLACE: 85 | PERIOD: 86 | 87 | Text: 88 | \"\"\" 89 | David Bowie lived in Berlin in the 1970s. 90 | \"\"\" 91 | 92 | For example: 93 | 94 | Text: 95 | \"\"\" 96 | New York is a large city. 97 | \"\"\" 98 | PLACE: New York 99 | 100 | Text: 101 | \"\"\" 102 | David Hasslehoff and Helena Fischer are big in Germany. 
103 | \"\"\" 104 | PERSON: David Hasslehoff, Helena Fischer 105 | PLACE: Germany 106 | 107 | """.lstrip() 108 | ) 109 | 110 | 111 | @pytest.mark.parametrize( 112 | "comment,text,raw_spans,filtered_spans", 113 | [ 114 | ( 115 | "Does it take longest when nested and shorter is first?", 116 | "one two three four", 117 | [("a", 0, 1), ("b", 0, 2)], 118 | [("b", 0, 2)], 119 | ), 120 | ( 121 | "Does it take longest when nested and shorter is last?", 122 | "one two three four", 123 | [("b", 0, 2), ("a", 0, 1)], 124 | [("b", 0, 2)], 125 | ), 126 | ( 127 | "Does it take first when overlapping and shorter is first?", 128 | "one two three four", 129 | [("b", 0, 2), ("a", 1, 3)], 130 | [("b", 0, 2)], 131 | ), 132 | ( 133 | "Does it take first when overlapping and shorter is last?", 134 | "one two three four", 135 | [("a", 0, 2), ("b", 1, 3)], 136 | [("a", 0, 2)], 137 | ), 138 | ], 139 | ) 140 | def test_one_token_per_span( 141 | comment: str, 142 | text: str, 143 | raw_spans: List[Tuple[str, int, int]], 144 | filtered_spans: List[Tuple[int, int]], 145 | ): 146 | labels = list(sorted(set(label for label, _, _1 in raw_spans))) 147 | suggester = make_suggester( 148 | response_parser=make_ner_response_parser(labels=labels, lang="en"), 149 | prompt_path=DEFAULT_PROMPT_PATH, 150 | labels=labels, 151 | openai_api_key="fake api key", 152 | openai_api_org="fake api org", 153 | prompt_example_class=NERPromptExample, 154 | ) 155 | prompt = suggester._get_prompt(text, labels=labels, examples=[]) 156 | response = _get_response(text, labels, raw_spans) 157 | stream = [{"text": text, "openai": {"prompt": prompt, "response": response}}] 158 | stream = list(suggester.format_suggestions(stream, nlp=spacy.blank("en"))) 159 | output_spans = [ 160 | (s["label"], s["token_start"], s["token_end"] + 1) for s in stream[0]["spans"] 161 | ] 162 | assert output_spans == filtered_spans 163 | # Also check there's no overlaps 164 | seen_tokens = set() 165 | for _, start, end in output_spans: 166 | for i in range(start, end): 167 | assert i not in seen_tokens, "Overlapping or nested spans found" 168 | seen_tokens.add(i) 169 | 170 | 171 | def _get_response(text: str, labels, spans: List[Tuple[str, int, int]]) -> str: 172 | # Get table of start and end character offsets for the test spans. 
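# For example, the span ("a", 0, 1) over "one two three four" is rendered as the
# response line "a: one".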
173 |     tokens = text.split()
174 |     start_chars, end_chars = _get_token_char_maps(tokens, [True for _ in tokens])
175 |     spans_by_label = defaultdict(list)
176 |     for label, start, end in spans:
177 |         spans_by_label[label].append(text[start_chars[start] : end_chars[end - 1]])
178 |     response_lines = []
179 |     for label in labels:
180 |         response_lines.append(f"{label}: {', '.join(spans_by_label[label])}")
181 |     return "\n".join(response_lines)
182 | 
183 | 
184 | def _get_token_char_maps(
185 |     tokens: List[str], whitespace: List[bool]
186 | ) -> Tuple[Dict[int, int], Dict[int, int]]:
187 |     idx = 0
188 |     start_chars = {}
189 |     end_chars = {}
190 |     for i, token in enumerate(tokens):
191 |         start_chars[i] = idx
192 |         idx += len(token) + int(whitespace[i])
193 |         end_chars[i] = idx
194 |     return start_chars, end_chars
195 | 
--------------------------------------------------------------------------------
/tests/test_openai.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import pytest
3 | import jinja2
4 | import httpx
5 | from pytest_httpx import HTTPXMock
6 | 
7 | from recipes.openai import OpenAISuggester, get_api_credentials
8 | 
9 | 
10 | # Setup the template and the suggester
11 | environment = jinja2.Environment()
12 | template = environment.from_string("Prompt: {{ text }}")
13 | 
14 | openai = OpenAISuggester(
15 |     prompt_template=template,
16 |     labels=["label1", "label2"],
17 |     max_examples=1,
18 |     segment=False,
19 |     openai_model="text-davinci-003",
20 |     openai_api_key="Fake api key",
21 |     openai_api_org="Fake api org",
22 |     response_parser=lambda x: {"key": "value"},
23 |     openai_n_retries=1,
24 |     openai_read_timeout_s=1,
25 |     openai_retry_timeout_s=1,
26 |     prompt_example_class=None,
27 | )
28 | 
29 | 
30 | @pytest.mark.parametrize(
31 |     "prompts,response_text",
32 |     [
33 |         (["A single prompt"], ["A single response"]),
34 |         (["A batch", "of prompts"], ["A batch", "of responses"]),
35 |     ],
36 | )
37 | def test_openai_response_follows_contract(
38 |     httpx_mock: HTTPXMock, prompts: List[str], response_text
39 | ):
40 |     """Test happy path where OpenAI follows the contract and we can parse it
41 |     https://beta.openai.com/docs/api-reference/completions
42 |     """
43 | 
44 |     httpx_mock.add_response(
45 |         method="POST",
46 |         json={
47 |             "choices": [
48 |                 {
49 |                     "text": text,
50 |                     "index": index,
51 |                     "logprobs": 0.1,
52 |                     "finish_reason": "length",
53 |                 }
54 |                 for index, text in enumerate(response_text)
55 |             ]
56 |         },
57 |     )
58 | 
59 |     chatgpt_response = openai._get_openai_response(prompts=prompts)
60 |     assert len(chatgpt_response) == len(prompts)
61 |     assert set(chatgpt_response) == set(response_text)
62 | 
63 | 
64 | @pytest.mark.parametrize("error_code", openai.RETRY_ERROR_CODES)
65 | def test_retry_function_when_calls_fail(httpx_mock, error_code):
66 |     """Test if status error shows up after all failed retries."""
67 |     httpx_mock.add_response(status_code=error_code)
68 |     with pytest.raises(httpx.HTTPStatusError):
69 |         openai._get_openai_response(prompts=["Some prompt"])
70 | 
71 | 
72 | @pytest.mark.parametrize(
73 |     "key,org", [(None, "fake api org"), ("fake api key", None), (None, None)]
74 | )
75 | def test_get_api_credentials_error_handling_envvars(monkeypatch, key, org):
76 |     """Ensure that auth fails whenever key or org is none."""
77 |     monkeypatch.setenv("OPENAI_KEY", key)
78 |     monkeypatch.setenv("OPENAI_ORG", org)
79 |     with pytest.raises(SystemExit):
80 |         get_api_credentials(model="text-davinci-003")
81 | 
82 | @pytest.mark.parametrize("error_code", [422, 500, 501])
83 | def test_get_api_credentials_calls_fail(httpx_mock, monkeypatch, error_code):
84 |     """Ensure that auth fails when we encounter an error code."""
85 |     httpx_mock.add_response(status_code=error_code)
86 |     monkeypatch.setenv("OPENAI_KEY", "fake api key")
87 |     monkeypatch.setenv("OPENAI_ORG", "fake api org")
88 | 
89 | 
90 |     with pytest.raises(SystemExit):
91 |         get_api_credentials(model="text-davinci-003")
92 | 
93 | 
94 | def test_get_api_credentials_model_does_not_exist(httpx_mock, monkeypatch):
95 |     """Ensure that auth fails when model passed does not exist.
96 |     https://beta.openai.com/docs/api-reference/models/list
97 |     """
98 |     httpx_mock.add_response(
99 |         method="GET",
100 |         json={
101 |             "data": [{"id": "model-id-0"}, {"id": "model-id-1"}, {"id": "model-id-2"}]
102 |         },
103 |     )
104 |     monkeypatch.setenv("OPENAI_KEY", "fake api key")
105 |     monkeypatch.setenv("OPENAI_ORG", "fake api org")
106 | 
107 |     with pytest.raises(SystemExit):
108 |         get_api_credentials(model="a-model-that-does-not-exist")
109 | 
--------------------------------------------------------------------------------
/tests/test_terms.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from recipes.openai_terms import _parse_terms
4 | 
5 | # We cannot assume that `risk` is a complete game because
6 | # OpenAI may have exhausted the tokens
7 | base_completion = """monopoly
8 | - Scrabble
9 | - risk
10 | """
11 | 
12 | # Added some trailing spaces in this one
13 | base_completion_with_trailing_spaces = """monopoly
14 | - scrabble
15 | - risk
16 | """
17 | 
18 | # An example where the tokens may have been exhausted
19 | # note the capitalisation
20 | trailing_token_completion = """monopoly
21 | - scrabble
22 | - Risk
23 | - carcas
24 | """
25 | 
26 | # This can also happen
27 | single_line_completion = "monopoly"
28 | 
29 | 
30 | @pytest.mark.parametrize(
31 |     "comment,completion,expectation",
32 |     [
33 |         (
34 |             "Base OpenAI completion with capitalisation",
35 |             base_completion,
36 |             ["monopoly", "Scrabble"],
37 |         ),
38 |         (
39 |             "Check trailing spaces",
40 |             base_completion_with_trailing_spaces,
41 |             ["monopoly", "scrabble"],
42 |         ),
43 |         (
44 |             "Completion with bad final item",
45 |             trailing_token_completion,
46 |             ["monopoly", "scrabble", "Risk"],
47 |         ),
48 |         (
49 |             "Example of a single-line OpenAI completion",
50 |             single_line_completion,
51 |             ["monopoly"],
52 |         ),
53 |     ],
54 | )
55 | def test_parse_terms(comment, completion, expectation):
56 |     assert _parse_terms(completion=completion) == expectation
57 | 
--------------------------------------------------------------------------------
/tests/test_textcat.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from recipes.openai_textcat import make_textcat_response_parser
4 | 
5 | 
6 | @pytest.mark.parametrize(
7 |     # fmt: off
8 |     "response,answer",
9 |     [
10 |         ("Answer: accept\nReason: The text is a recipe.", "accept"),
11 |         ("Answer: Accept\nReason: The text is a recipe.", "accept"),
12 |         ("Answer: reject\nReason: The text is not a recipe.", "reject"),
13 |         ("Answer: Reject\nReason: The text is not a recipe.", "reject"),
14 |         ("answer: reject\nreason: The text is not a recipe.", "reject"),
15 |         ("answer: Reject\nreason: The text is not a recipe.\nI don't know what it's about.", "reject"),
16 |     ],
17 |     # fmt: on
18 | )
19 | def test_parse_response_binary(response, answer):
20 |     """Test if parse response works for common examples"""
21 |     labels = ["recipe"]  # binary
22 |     parser = make_textcat_response_parser(labels=labels)
23 |     example = parser(response)
24 |     assert example.get("answer") == answer
25 | 
26 | 
27 | @pytest.mark.parametrize(
28 |     # fmt: off
29 |     "response,answer",
30 |     [
31 |         ("Answer: recipe,feedback,question\nReason: It is all three!", ["recipe", "feedback", "question"]),
32 |         ("Answer: recipe\nReason: The text is a recipe.", ["recipe"]),
33 |         ("Answer: recipe,\nReason: The text is a recipe.", ["recipe"]),
34 |         ("Answer: recipe,feedback,\nReason: The text is a feedback about a recipe.\nThat's what I think...", ["recipe", "feedback"]),
35 |         ("answer: recipe,feedback,\nreason: The text is a feedback about a recipe.\nThat's what I think...", ["recipe", "feedback"]),
36 |         ("answer: \nreason: It's none of the above.", []),
37 |     ],
38 |     # fmt: on
39 | )
40 | def test_parser_response_multi(response, answer):
41 |     """Test if parse response works for common examples"""
42 |     labels = ["recipe", "feedback", "question"]  # multiclass
43 |     parser = make_textcat_response_parser(labels=labels)
44 |     example = parser(response)
45 |     assert set(example.get("accept")) == set(answer)
46 | 
47 | 
48 | @pytest.mark.parametrize("labels", [["binary"], ["multi1", "multi2"]])
49 | def test_parser_no_answer(labels):
50 |     """Test if parse response works for common examples"""
51 |     empty_response = ""
52 |     parser = make_textcat_response_parser(labels=labels)
53 |     example = parser(empty_response)
54 |     assert not example.get("accept")
55 |     assert not example.get("reason")
56 | 
57 | 
58 | @pytest.mark.parametrize("labels", [["binary"], ["multi1", "multi2"]])
59 | @pytest.mark.parametrize(
60 |     "response", ["asdfghjklmnop", "I am now a sentient robot. Bow before me."]
61 | )
62 | def test_parser_openai_returns_arbitrary_text(labels, response):
63 |     """Test if parser response works for any arbitrary text"""
64 |     parser = make_textcat_response_parser(labels=labels)
65 |     example = parser(response)
66 |     assert not example.get("accept")
67 |     assert not example.get("reason")
68 | 
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Callable, List, Optional
3 | 
4 | from recipes.openai import OpenAISuggester, PromptExample, get_api_credentials
5 | from recipes.openai import load_template
6 | 
7 | 
8 | def make_suggester(
9 |     response_parser: Callable,
10 |     labels: List[str],
11 |     prompt_path: Path,
12 |     prompt_example_class: Optional[PromptExample] = None,
13 |     model: str = "text-davinci-003",
14 |     **kwargs
15 | ) -> OpenAISuggester:
16 |     if "openai_api_key" not in kwargs or "openai_api_org" not in kwargs:
17 |         api_key, api_org = get_api_credentials(model)
18 |         if "openai_api_key" not in kwargs:
19 |             kwargs["openai_api_key"] = api_key
20 |         if "openai_api_org" not in kwargs:
21 |             kwargs["openai_api_org"] = api_org
22 |     if "max_examples" not in kwargs:
23 |         kwargs["max_examples"] = 0
24 |     if "prompt_template" not in kwargs:
25 |         kwargs["prompt_template"] = load_template(prompt_path)
26 |     if "segment" not in kwargs:
27 |         kwargs["segment"] = False
28 |     if "openai_model" not in kwargs:
29 |         kwargs["openai_model"] = "text-davinci-003"
30 | 
31 |     openai = OpenAISuggester(
32 |         response_parser=response_parser,
33 |         labels=labels,
34 |         prompt_example_class=prompt_example_class,
35 |         **kwargs
36 |     )
37 |     return openai
38 | 
--------------------------------------------------------------------------------
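The textcat tests above pin down the contract that `make_textcat_response_parser(labels=...)` has to satisfy: case-insensitive `answer:` / `reason:` lines, an `"answer"` key for the binary case, an `"accept"` list for the multilabel case, and empty results for arbitrary or empty completions. The sketch below is illustrative only — the function name `make_sketch_response_parser` and its internals are assumptions, not the actual parser in `recipes/openai_textcat.py` (which is not shown in this excerpt); it simply re-implements that contract so the expected behaviour is easy to see in one place.

```python
# Minimal sketch of a parser satisfying the contract encoded in tests/test_textcat.py.
# Not the repo's implementation; names and structure here are illustrative assumptions.
from typing import Callable, Dict, List, Union


def make_sketch_response_parser(labels: List[str]) -> Callable[[str], Dict]:
    """Build a parser for completions made of 'answer:' and 'reason:' lines."""
    binary = len(labels) == 1

    def parse(response: str) -> Dict[str, Union[str, List[str]]]:
        answer, reason = "", ""
        for line in response.splitlines():
            lowered = line.lower()
            if lowered.startswith("answer:"):
                answer = line.split(":", 1)[1].strip()
            elif lowered.startswith("reason:"):
                reason = line.split(":", 1)[1].strip()
        if binary:
            # Binary case: the tests check example.get("answer") == "accept"/"reject".
            return {"answer": answer.lower(), "reason": reason}
        # Multilabel case: the tests check the accepted labels under "accept".
        accept = [label.strip() for label in answer.split(",") if label.strip()]
        return {"accept": accept, "reason": reason}

    return parse


if __name__ == "__main__":
    parser = make_sketch_response_parser(["recipe", "feedback", "question"])
    print(parser("Answer: recipe,feedback,\nReason: Feedback about a recipe."))
    # -> {'accept': ['recipe', 'feedback'], 'reason': 'Feedback about a recipe.'}
```

Keeping the binary and multilabel return shapes distinct (`"answer"` vs. `"accept"`) mirrors what the assertions in `test_parse_response_binary` and `test_parser_response_multi` check, and it is also why `test_parser_no_answer` can simply assert that both `accept` and `reason` are falsy for an empty completion.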