├── openai ├── py.typed ├── version.py ├── api_resources │ ├── experimental │ │ ├── __init__.py │ │ └── completion_config.py │ ├── model.py │ ├── answer.py │ ├── classification.py │ ├── abstract │ │ ├── updateable_api_resource.py │ │ ├── __init__.py │ │ ├── deletable_api_resource.py │ │ ├── createable_api_resource.py │ │ ├── listable_api_resource.py │ │ ├── nested_resource_class_methods.py │ │ ├── api_resource.py │ │ └── engine_api_resource.py │ ├── customer.py │ ├── moderation.py │ ├── search.py │ ├── error_object.py │ ├── __init__.py │ ├── completion.py │ ├── edit.py │ ├── engine.py │ ├── embedding.py │ ├── image.py │ ├── deployment.py │ ├── fine_tune.py │ └── file.py ├── object_classes.py ├── openai_response.py ├── tests │ ├── test_util.py │ ├── test_exceptions.py │ ├── test_file_cli.py │ ├── test_endpoints.py │ ├── test_long_examples_validator.py │ ├── test_api_requestor.py │ └── test_url_composition.py ├── upload_progress.py ├── __init__.py ├── _openai_scripts.py ├── error.py ├── util.py ├── embeddings_utils.py ├── openai_object.py ├── wandb_logger.py ├── api_requestor.py └── validators.py ├── public ├── Makefile └── setup.py ├── examples ├── codex │ └── backtranslation.py ├── finetuning │ ├── answers_with_ft.py │ ├── olympics-2-create-qa.ipynb │ ├── olympics-3-train-qa.ipynb │ ├── olympics-1-collect-data.ipynb │ └── finetuning-classification.ipynb ├── README.md ├── embeddings │ ├── Clustering.ipynb │ ├── Code_search.ipynb │ ├── Get_embeddings.ipynb │ ├── Obtain_dataset.ipynb │ ├── Recommendation.ipynb │ ├── Regression.ipynb │ ├── Visualize_in_3d.ipynb │ ├── Visualize_in_2d.ipynb │ ├── Classification.ipynb │ ├── User_and_product_embeddings.ipynb │ ├── Zero-shot_classification.ipynb │ └── Semantic_text_search_using_embeddings.ipynb └── azure │ ├── embeddings.ipynb │ └── finetuning.ipynb ├── pytest.ini ├── Makefile ├── pyproject.toml ├── .gitignore ├── LICENSE ├── setup.py └── README.md /openai/py.typed: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /openai/version.py: -------------------------------------------------------------------------------- 1 | VERSION = "0.25.0" 2 | -------------------------------------------------------------------------------- /public/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build upload 2 | 3 | build: 4 | OPENAI_UPLOAD=y python setup.py sdist 5 | 6 | upload: 7 | OPENAI_UPLOAD=y twine upload dist/* 8 | -------------------------------------------------------------------------------- /examples/codex/backtranslation.py: -------------------------------------------------------------------------------- 1 | # this example has moved to https://github.com/openai/openai-cookbook/blob/main/examples/Backtranslation_of_SQL_queries.py 2 | -------------------------------------------------------------------------------- /openai/api_resources/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | from openai.api_resources.experimental.completion_config import ( # noqa: F401 2 | CompletionConfig, 3 | ) 4 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | url: mark a test as part of the url composition tests. 4 | requestor: mark test as part of the api_requestor tests. 
5 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build upload 2 | 3 | build: 4 | python setup.py sdist 5 | 6 | upload: 7 | twine upload dist/openai-*.tar.gz 8 | rm dist/openai-*.tar.gz 9 | 10 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | target-version = ['py36'] 3 | exclude = '.*\.ipynb' 4 | 5 | [tool.isort] 6 | py_version = 36 7 | include_trailing_comma = "true" 8 | line_length = 88 9 | multi_line_output = 3 -------------------------------------------------------------------------------- /openai/api_resources/model.py: -------------------------------------------------------------------------------- 1 | from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource 2 | 3 | 4 | class Model(ListableAPIResource, DeletableAPIResource): 5 | OBJECT_NAME = "models" 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | .idea 3 | .python-version 4 | /public/dist 5 | __pycache__ 6 | build 7 | *.egg 8 | .vscode/settings.json 9 | .ipynb_checkpoints 10 | .vscode/launch.json 11 | examples/azure/training.jsonl 12 | examples/azure/validation.jsonl 13 | -------------------------------------------------------------------------------- /examples/finetuning/answers_with_ft.py: -------------------------------------------------------------------------------- 1 | # This code example has moved. 
You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) 2 | # at [examples/fine-tuned_qa](https://github.com/openai/openai-cookbook/tree/main/examples/fine-tuned_qa) 3 | -------------------------------------------------------------------------------- /openai/api_resources/answer.py: -------------------------------------------------------------------------------- 1 | from openai.openai_object import OpenAIObject 2 | 3 | 4 | class Answer(OpenAIObject): 5 | @classmethod 6 | def get_url(self): 7 | return "/answers" 8 | 9 | @classmethod 10 | def create(cls, **params): 11 | instance = cls() 12 | return instance.request("post", cls.get_url(), params) 13 | -------------------------------------------------------------------------------- /openai/api_resources/experimental/completion_config.py: -------------------------------------------------------------------------------- 1 | from openai.api_resources.abstract import ( 2 | CreateableAPIResource, 3 | DeletableAPIResource, 4 | ListableAPIResource, 5 | ) 6 | 7 | 8 | class CompletionConfig( 9 | CreateableAPIResource, ListableAPIResource, DeletableAPIResource 10 | ): 11 | OBJECT_NAME = "experimental.completion_configs" 12 | -------------------------------------------------------------------------------- /openai/api_resources/classification.py: -------------------------------------------------------------------------------- 1 | from openai.openai_object import OpenAIObject 2 | 3 | 4 | class Classification(OpenAIObject): 5 | @classmethod 6 | def get_url(self): 7 | return "/classifications" 8 | 9 | @classmethod 10 | def create(cls, **params): 11 | instance = cls() 12 | return instance.request("post", cls.get_url(), params) 13 | -------------------------------------------------------------------------------- /openai/object_classes.py: -------------------------------------------------------------------------------- 1 | from openai import api_resources 2 | from 
openai.api_resources.experimental.completion_config import CompletionConfig 3 | 4 | OBJECT_CLASSES = { 5 | "engine": api_resources.Engine, 6 | "experimental.completion_config": CompletionConfig, 7 | "file": api_resources.File, 8 | "fine-tune": api_resources.FineTune, 9 | "model": api_resources.Model, 10 | } 11 | -------------------------------------------------------------------------------- /openai/api_resources/abstract/updateable_api_resource.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import quote_plus 2 | 3 | from openai.api_resources.abstract.api_resource import APIResource 4 | 5 | 6 | class UpdateableAPIResource(APIResource): 7 | @classmethod 8 | def modify(cls, sid, **params): 9 | url = "%s/%s" % (cls.class_url(), quote_plus(sid)) 10 | return cls._static_request("post", url, **params) 11 | -------------------------------------------------------------------------------- /openai/api_resources/customer.py: -------------------------------------------------------------------------------- 1 | from openai.openai_object import OpenAIObject 2 | 3 | 4 | class Customer(OpenAIObject): 5 | @classmethod 6 | def get_url(self, customer, endpoint): 7 | return f"/customer/{customer}/{endpoint}" 8 | 9 | @classmethod 10 | def create(cls, customer, endpoint, **params): 11 | instance = cls() 12 | return instance.request("post", cls.get_url(customer, endpoint), params) 13 | -------------------------------------------------------------------------------- /public/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import setup 4 | 5 | if os.getenv("OPENAI_UPLOAD") != "y": 6 | raise RuntimeError( 7 | "This package is a placeholder package on the public PyPI instance, and is not the correct version to install. If you are having trouble figuring out the correct package to install, please contact us." 
8 | ) 9 | 10 | setup(name="openai", description="Placeholder package", version="0.0.1") 11 | -------------------------------------------------------------------------------- /openai/api_resources/abstract/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | from openai.api_resources.abstract.api_resource import APIResource 4 | from openai.api_resources.abstract.createable_api_resource import CreateableAPIResource 5 | from openai.api_resources.abstract.deletable_api_resource import DeletableAPIResource 6 | from openai.api_resources.abstract.listable_api_resource import ListableAPIResource 7 | from openai.api_resources.abstract.nested_resource_class_methods import ( 8 | nested_resource_class_methods, 9 | ) 10 | from openai.api_resources.abstract.updateable_api_resource import UpdateableAPIResource 11 | -------------------------------------------------------------------------------- /openai/openai_response.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class OpenAIResponse: 5 | def __init__(self, data, headers): 6 | self._headers = headers 7 | self.data = data 8 | 9 | @property 10 | def request_id(self) -> Optional[str]: 11 | return self._headers.get("request-id") 12 | 13 | @property 14 | def organization(self) -> Optional[str]: 15 | return self._headers.get("OpenAI-Organization") 16 | 17 | @property 18 | def response_ms(self) -> Optional[int]: 19 | h = self._headers.get("Openai-Processing-Ms") 20 | return None if h is None else round(float(h)) 21 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples have moved to the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/) 2 | 3 | Looking for code examples? 
Visit the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/), which shares examples of how to use the OpenAI Python library to accomplish common tasks. 4 | 5 | Prior to July 2022, code examples were hosted in this examples folder; going forward, code examples will be hosted in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/). 6 | 7 | This separation will help keep the [OpenAI Python library](https://github.com/openai/openai-python) simple and small, without extra files or dependencies. 8 | -------------------------------------------------------------------------------- /openai/tests/test_util.py: -------------------------------------------------------------------------------- 1 | from tempfile import NamedTemporaryFile 2 | 3 | import pytest 4 | 5 | import openai 6 | from openai import util 7 | 8 | 9 | @pytest.fixture(scope="function") 10 | def api_key_file(): 11 | saved_path = openai.api_key_path 12 | try: 13 | with NamedTemporaryFile(prefix="openai-api-key", mode="wt") as tmp: 14 | openai.api_key_path = tmp.name 15 | yield tmp 16 | finally: 17 | openai.api_key_path = saved_path 18 | 19 | 20 | def test_openai_api_key_path(api_key_file) -> None: 21 | print("sk-foo", file=api_key_file) 22 | api_key_file.flush() 23 | assert util.default_api_key() == "sk-foo" 24 | 25 | 26 | def test_openai_api_key_path_with_malformed_key(api_key_file) -> None: 27 | print("malformed-api-key", file=api_key_file) 28 | api_key_file.flush() 29 | with pytest.raises(ValueError, match="Malformed API key"): 30 | util.default_api_key() 31 | -------------------------------------------------------------------------------- /openai/api_resources/moderation.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | from openai.openai_object import OpenAIObject 4 | 5 | 6 | class Moderation(OpenAIObject): 7 | VALID_MODEL_NAMES: List[str] = ["text-moderation-stable", "text-moderation-latest"] 8 | 9 
| @classmethod 10 | def get_url(self): 11 | return "/moderations" 12 | 13 | @classmethod 14 | def create(cls, input: Union[str, List[str]], model: Optional[str] = None, api_key: Optional[str] = None): 15 | if model is not None and model not in cls.VALID_MODEL_NAMES: 16 | raise ValueError( 17 | f"The parameter model should be chosen from {cls.VALID_MODEL_NAMES} " 18 | f"and it is default to be None." 19 | ) 20 | 21 | instance = cls(api_key=api_key) 22 | params = {"input": input} 23 | if model is not None: 24 | params["model"] = model 25 | return instance.request("post", cls.get_url(), params) 26 | -------------------------------------------------------------------------------- /openai/api_resources/search.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from openai import util 4 | from openai.api_resources.abstract.engine_api_resource import EngineAPIResource 5 | from openai.error import TryAgain 6 | 7 | 8 | class Search(EngineAPIResource): 9 | OBJECT_NAME = "search" 10 | 11 | @classmethod 12 | def create(cls, *args, **kwargs): 13 | """ 14 | Creates a new search for the provided input and parameters. 15 | 16 | See https://beta.openai.com/docs/api-reference/search for a list 17 | of valid parameters. 
18 | """ 19 | 20 | start = time.time() 21 | timeout = kwargs.pop("timeout", None) 22 | 23 | while True: 24 | try: 25 | return super().create(*args, **kwargs) 26 | except TryAgain as e: 27 | if timeout is not None and time.time() > start + timeout: 28 | raise 29 | 30 | util.log_info("Waiting for model to warm up", error=e) 31 | -------------------------------------------------------------------------------- /openai/api_resources/error_object.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from openai.openai_object import OpenAIObject 4 | from openai.util import merge_dicts 5 | 6 | 7 | class ErrorObject(OpenAIObject): 8 | def refresh_from( 9 | self, 10 | values, 11 | api_key=None, 12 | api_version=None, 13 | api_type=None, 14 | organization=None, 15 | response_ms: Optional[int] = None, 16 | ): 17 | # Unlike most other API resources, the API will omit attributes in 18 | # error objects when they have a null value. We manually set default 19 | # values here to facilitate generic error handling. 
20 | values = merge_dicts({"message": None, "type": None}, values) 21 | return super(ErrorObject, self).refresh_from( 22 | values=values, 23 | api_key=api_key, 24 | api_version=api_version, 25 | api_type=api_type, 26 | organization=organization, 27 | response_ms=response_ms, 28 | ) 29 | -------------------------------------------------------------------------------- /openai/api_resources/__init__.py: -------------------------------------------------------------------------------- 1 | from openai.api_resources.answer import Answer # noqa: F401 2 | from openai.api_resources.classification import Classification # noqa: F401 3 | from openai.api_resources.completion import Completion # noqa: F401 4 | from openai.api_resources.customer import Customer # noqa: F401 5 | from openai.api_resources.deployment import Deployment # noqa: F401 6 | from openai.api_resources.edit import Edit # noqa: F401 7 | from openai.api_resources.embedding import Embedding # noqa: F401 8 | from openai.api_resources.engine import Engine # noqa: F401 9 | from openai.api_resources.error_object import ErrorObject # noqa: F401 10 | from openai.api_resources.file import File # noqa: F401 11 | from openai.api_resources.fine_tune import FineTune # noqa: F401 12 | from openai.api_resources.image import Image # noqa: F401 13 | from openai.api_resources.model import Model # noqa: F401 14 | from openai.api_resources.moderation import Moderation # noqa: F401 15 | from openai.api_resources.search import Search # noqa: F401 16 | -------------------------------------------------------------------------------- /openai/api_resources/completion.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from openai import util 4 | from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource 5 | from openai.api_resources.abstract.engine_api_resource import EngineAPIResource 6 | from openai.error import TryAgain 7 | 8 | 9 | class 
Completion(EngineAPIResource): 10 | OBJECT_NAME = "completions" 11 | 12 | @classmethod 13 | def create(cls, *args, **kwargs): 14 | """ 15 | Creates a new completion for the provided prompt and parameters. 16 | 17 | See https://beta.openai.com/docs/api-reference/completions/create for a list 18 | of valid parameters. 19 | """ 20 | start = time.time() 21 | timeout = kwargs.pop("timeout", None) 22 | 23 | while True: 24 | try: 25 | return super().create(*args, **kwargs) 26 | except TryAgain as e: 27 | if timeout is not None and time.time() > start + timeout: 28 | raise 29 | 30 | util.log_info("Waiting for model to warm up", error=e) 31 | -------------------------------------------------------------------------------- /examples/embeddings/Clustering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Clustering.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Clustering.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/embeddings/Code_search.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Code_search.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Code_search.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/finetuning/olympics-2-create-qa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/fine-tuned_qa/](https://github.com/openai/openai-cookbook/tree/main/examples/fine-tuned_qa)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/finetuning/olympics-3-train-qa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/fine-tuned_qa/](https://github.com/openai/openai-cookbook/tree/main/examples/fine-tuned_qa)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/azure/embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/azure/embeddings.ipynb](https://github.com/openai/openai-cookbook/tree/main/examples/azure/embeddings.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/azure/finetuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/azure/finetuning.ipynb](https://github.com/openai/openai-cookbook/tree/main/examples/azure/finetuning.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/finetuning/olympics-1-collect-data.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/fine-tuned_qa/](https://github.com/openai/openai-cookbook/tree/main/examples/fine-tuned_qa)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/embeddings/Get_embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Get_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Get_embeddings.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/embeddings/Obtain_dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Obtain_dataset.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Obtain_dataset.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/embeddings/Recommendation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Recommendation_using_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Recommendation_using_embeddings.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "interpreter": { 13 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 14 | }, 15 | "kernelspec": { 16 | "display_name": "Python 3.9.9 64-bit ('openai': virtualenv)", 17 | "language": "python", 18 | "name": "python3" 19 | }, 20 | "language_info": { 21 | "codemirror_mode": { 22 | "name": "ipython", 23 | "version": 3 24 | }, 25 | "file_extension": ".py", 26 | "mimetype": "text/x-python", 27 | "name": "python", 28 | "nbconvert_exporter": "python", 29 | "pygments_lexer": "ipython3", 30 | "version": "3.9.9" 31 | }, 32 | "orig_nbformat": 4 33 | }, 34 | "nbformat": 4, 35 | "nbformat_minor": 2 36 | } 37 | -------------------------------------------------------------------------------- /examples/embeddings/Regression.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Regression_using_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Regression_using_embeddings.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/embeddings/Visualize_in_3d.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b87d69b2", 6 | "metadata": {}, 7 | "source": [ 8 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Visualizing_embeddings_in_3D.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Visualizing_embeddings_in_3D.ipynb)." 
9 | ] 10 | } 11 | ], 12 | "metadata": { 13 | "kernelspec": { 14 | "display_name": "Python 3.9.9 ('openai')", 15 | "language": "python", 16 | "name": "python3" 17 | }, 18 | "language_info": { 19 | "codemirror_mode": { 20 | "name": "ipython", 21 | "version": 3 22 | }, 23 | "file_extension": ".py", 24 | "mimetype": "text/x-python", 25 | "name": "python", 26 | "nbconvert_exporter": "python", 27 | "pygments_lexer": "ipython3", 28 | "version": "3.9.9" 29 | }, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 5 38 | } 39 | -------------------------------------------------------------------------------- /examples/embeddings/Visualize_in_2d.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Visualizing_embeddings_in_2D.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Visualizing_embeddings_in_2D.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/finetuning/finetuning-classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Fine-tuned_classification.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Fine-tuned_classification.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/embeddings/Classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Classification_using_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Classification_using_embeddings.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/embeddings/User_and_product_embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/User_and_product_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/User_and_product_embeddings.ipynb)." 
8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) OpenAI (https://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/embeddings/Zero-shot_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Zero-shot_classification_with_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Zero-shot_classification_with_embeddings.ipynb)." 8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /examples/embeddings/Semantic_text_search_using_embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This code example has moved. 
You can now find it in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook) at [examples/Semantic_text_search_using_embeddings.ipynb](https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb)." 8 | ] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3.9.9 ('openai')", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | "codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.9.9" 28 | }, 29 | "orig_nbformat": 4, 30 | "vscode": { 31 | "interpreter": { 32 | "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97" 33 | } 34 | } 35 | }, 36 | "nbformat": 4, 37 | "nbformat_minor": 2 38 | } 39 | -------------------------------------------------------------------------------- /openai/api_resources/edit.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from openai import util, error 4 | from openai.api_resources.abstract.engine_api_resource import EngineAPIResource 5 | from openai.error import TryAgain 6 | 7 | 8 | class Edit(EngineAPIResource): 9 | OBJECT_NAME = "edits" 10 | 11 | @classmethod 12 | def create(cls, *args, **kwargs): 13 | """ 14 | Creates a new edit for the provided input, instruction, and parameters. 15 | """ 16 | start = time.time() 17 | timeout = kwargs.pop("timeout", None) 18 | 19 | api_type = kwargs.pop("api_type", None) 20 | typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0] 21 | if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD): 22 | raise error.InvalidAPIType( 23 | "This operation is not supported by the Azure OpenAI API yet." 
24 | ) 25 | 26 | while True: 27 | try: 28 | return super().create(*args, **kwargs) 29 | except TryAgain as e: 30 | if timeout is not None and time.time() > start + timeout: 31 | raise 32 | 33 | util.log_info("Waiting for model to warm up", error=e) 34 | -------------------------------------------------------------------------------- /openai/api_resources/abstract/deletable_api_resource.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import quote_plus 2 | 3 | from openai import error 4 | from openai.api_resources.abstract.api_resource import APIResource 5 | from openai.util import ApiType 6 | 7 | 8 | class DeletableAPIResource(APIResource): 9 | @classmethod 10 | def delete(cls, sid, api_type=None, api_version=None, **params): 11 | if isinstance(cls, APIResource): 12 | raise ValueError(".delete may only be called as a class method now.") 13 | 14 | base = cls.class_url() 15 | extn = quote_plus(sid) 16 | 17 | typed_api_type, api_version = cls._get_api_type_and_version( 18 | api_type, api_version 19 | ) 20 | if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): 21 | url = "/%s%s/%s?api-version=%s" % ( 22 | cls.azure_api_prefix, 23 | base, 24 | extn, 25 | api_version, 26 | ) 27 | elif typed_api_type == ApiType.OPEN_AI: 28 | url = "%s/%s" % (base, extn) 29 | else: 30 | raise error.InvalidAPIType("Unsupported API type %s" % api_type) 31 | 32 | return cls._static_request( 33 | "delete", url, api_type=api_type, api_version=api_version, **params 34 | ) 35 | -------------------------------------------------------------------------------- /openai/tests/test_exceptions.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import pytest 4 | 5 | import openai 6 | 7 | EXCEPTION_TEST_CASES = [ 8 | openai.InvalidRequestError( 9 | "message", 10 | "param", 11 | code=400, 12 | http_body={"test": "test1"}, 13 | http_status="fail", 14 | json_body={"text": "iono some 
text"}, 15 | headers={"request-id": "asasd"}, 16 | ), 17 | openai.error.AuthenticationError(), 18 | openai.error.PermissionError(), 19 | openai.error.RateLimitError(), 20 | openai.error.ServiceUnavailableError(), 21 | openai.error.SignatureVerificationError("message", "sig_header?"), 22 | openai.error.APIConnectionError("message!", should_retry=True), 23 | openai.error.TryAgain(), 24 | openai.error.Timeout(), 25 | openai.error.APIError( 26 | message="message", 27 | code=400, 28 | http_body={"test": "test1"}, 29 | http_status="fail", 30 | json_body={"text": "iono some text"}, 31 | headers={"request-id": "asasd"}, 32 | ), 33 | openai.error.OpenAIError(), 34 | ] 35 | 36 | 37 | class TestExceptions: 38 | @pytest.mark.parametrize("error", EXCEPTION_TEST_CASES) 39 | def test_exceptions_are_pickleable(self, error) -> None: 40 | assert error.__repr__() == pickle.loads(pickle.dumps(error)).__repr__() 41 | -------------------------------------------------------------------------------- /openai/upload_progress.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | 4 | class CancelledError(Exception): 5 | def __init__(self, msg): 6 | self.msg = msg 7 | Exception.__init__(self, msg) 8 | 9 | def __str__(self): 10 | return self.msg 11 | 12 | __repr__ = __str__ 13 | 14 | 15 | class BufferReader(io.BytesIO): 16 | def __init__(self, buf=b"", desc=None): 17 | self._len = len(buf) 18 | io.BytesIO.__init__(self, buf) 19 | self._progress = 0 20 | self._callback = progress(len(buf), desc=desc) 21 | 22 | def __len__(self): 23 | return self._len 24 | 25 | def read(self, n=-1): 26 | chunk = io.BytesIO.read(self, n) 27 | self._progress += len(chunk) 28 | if self._callback: 29 | try: 30 | self._callback(self._progress) 31 | except Exception as e: # catches exception from the callback 32 | raise CancelledError("The upload was cancelled: {}".format(e)) 33 | return chunk 34 | 35 | 36 | def progress(total, desc): 37 | import tqdm # type: ignore 
38 | 39 | meter = tqdm.tqdm(total=total, unit_scale=True, desc=desc) 40 | 41 | def incr(progress): 42 | meter.n = progress 43 | if progress == total: 44 | meter.close() 45 | else: 46 | meter.refresh() 47 | 48 | return incr 49 | 50 | 51 | def MB(i): 52 | return int(i // 1024**2) 53 | -------------------------------------------------------------------------------- /openai/tests/test_file_cli.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | import time 4 | from tempfile import NamedTemporaryFile 5 | 6 | STILL_PROCESSING = "File is still processing. Check back later." 7 | 8 | 9 | def test_file_cli() -> None: 10 | contents = json.dumps({"prompt": "1 + 3 =", "completion": "4"}) + "\n" 11 | with NamedTemporaryFile(suffix=".jsonl", mode="wb") as train_file: 12 | train_file.write(contents.encode("utf-8")) 13 | train_file.flush() 14 | create_output = subprocess.check_output( 15 | ["openai", "api", "files.create", "-f", train_file.name, "-p", "fine-tune"] 16 | ) 17 | file_obj = json.loads(create_output) 18 | assert file_obj["bytes"] == len(contents) 19 | file_id: str = file_obj["id"] 20 | assert file_id.startswith("file-") 21 | start_time = time.time() 22 | while True: 23 | delete_result = subprocess.run( 24 | ["openai", "api", "files.delete", "-i", file_id], 25 | stdout=subprocess.PIPE, 26 | stderr=subprocess.PIPE, 27 | encoding="utf-8", 28 | ) 29 | if delete_result.returncode == 0: 30 | break 31 | elif STILL_PROCESSING in delete_result.stderr: 32 | time.sleep(0.5) 33 | if start_time + 60 < time.time(): 34 | raise RuntimeError("timed out waiting for file to become available") 35 | continue 36 | else: 37 | raise RuntimeError( 38 | f"delete failed: stdout={delete_result.stdout} stderr={delete_result.stderr}" 39 | ) 40 | -------------------------------------------------------------------------------- /openai/api_resources/abstract/createable_api_resource.py: 
-------------------------------------------------------------------------------- 1 | from openai import api_requestor, util, error 2 | from openai.api_resources.abstract.api_resource import APIResource 3 | from openai.util import ApiType 4 | 5 | 6 | class CreateableAPIResource(APIResource): 7 | plain_old_data = False 8 | 9 | @classmethod 10 | def create( 11 | cls, 12 | api_key=None, 13 | api_base=None, 14 | api_type=None, 15 | request_id=None, 16 | api_version=None, 17 | organization=None, 18 | **params, 19 | ): 20 | requestor = api_requestor.APIRequestor( 21 | api_key, 22 | api_base=api_base, 23 | api_type=api_type, 24 | api_version=api_version, 25 | organization=organization, 26 | ) 27 | typed_api_type, api_version = cls._get_api_type_and_version( 28 | api_type, api_version 29 | ) 30 | 31 | if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): 32 | base = cls.class_url() 33 | url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version) 34 | elif typed_api_type == ApiType.OPEN_AI: 35 | url = cls.class_url() 36 | else: 37 | raise error.InvalidAPIType("Unsupported API type %s" % api_type) 38 | 39 | response, _, api_key = requestor.request( 40 | "post", url, params, request_id=request_id 41 | ) 42 | 43 | return util.convert_to_openai_object( 44 | response, 45 | api_key, 46 | api_version, 47 | organization, 48 | plain_old_data=cls.plain_old_data, 49 | ) 50 | -------------------------------------------------------------------------------- /openai/api_resources/engine.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | 4 | from openai import util 5 | from openai.api_resources.abstract import ListableAPIResource, UpdateableAPIResource 6 | from openai.error import InvalidAPIType, TryAgain 7 | from openai.util import ApiType 8 | 9 | 10 | class Engine(ListableAPIResource, UpdateableAPIResource): 11 | OBJECT_NAME = "engines" 12 | 13 | def generate(self, timeout=None, **params): 14 | start 
= time.time() 15 | while True: 16 | try: 17 | return self.request( 18 | "post", 19 | self.instance_url() + "/generate", 20 | params, 21 | stream=params.get("stream"), 22 | plain_old_data=True, 23 | ) 24 | except TryAgain as e: 25 | if timeout is not None and time.time() > start + timeout: 26 | raise 27 | 28 | util.log_info("Waiting for model to warm up", error=e) 29 | 30 | def search(self, **params): 31 | if self.typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): 32 | return self.request("post", self.instance_url("search"), params) 33 | elif self.typed_api_type == ApiType.OPEN_AI: 34 | return self.request("post", self.instance_url() + "/search", params) 35 | else: 36 | raise InvalidAPIType("Unsupported API type %s" % self.api_type) 37 | 38 | def embeddings(self, **params): 39 | warnings.warn( 40 | "Engine.embeddings is deprecated, use Embedding.create", DeprecationWarning 41 | ) 42 | return self.request("post", self.instance_url() + "/embeddings", params) 43 | -------------------------------------------------------------------------------- /openai/tests/test_endpoints.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | 4 | import pytest 5 | 6 | import openai 7 | from openai import error 8 | 9 | 10 | # FILE TESTS 11 | def test_file_upload(): 12 | result = openai.File.create( 13 | file=io.StringIO(json.dumps({"text": "test file data"})), 14 | purpose="search", 15 | ) 16 | assert result.purpose == "search" 17 | assert "id" in result 18 | 19 | result = openai.File.retrieve(id=result.id) 20 | assert result.status == "uploaded" 21 | 22 | 23 | # COMPLETION TESTS 24 | def test_completions(): 25 | result = openai.Completion.create(prompt="This was a test", n=5, engine="ada") 26 | assert len(result.choices) == 5 27 | 28 | 29 | def test_completions_multiple_prompts(): 30 | result = openai.Completion.create( 31 | prompt=["This was a test", "This was another test"], n=5, engine="ada" 32 | ) 33 | assert 
len(result.choices) == 10 34 | 35 | 36 | def test_completions_model(): 37 | result = openai.Completion.create(prompt="This was a test", n=5, model="ada") 38 | assert len(result.choices) == 5 39 | assert result.model.startswith("ada") 40 | 41 | 42 | def test_timeout_raises_error(): 43 | # A query that should take awhile to return 44 | with pytest.raises(error.Timeout): 45 | openai.Completion.create( 46 | prompt="test" * 1000, 47 | n=10, 48 | model="ada", 49 | max_tokens=100, 50 | request_timeout=0.01, 51 | ) 52 | 53 | 54 | def test_timeout_does_not_error(): 55 | # A query that should be fast 56 | openai.Completion.create( 57 | prompt="test", 58 | model="ada", 59 | request_timeout=10, 60 | ) 61 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | version_contents = {} 6 | version_path = os.path.join( 7 | os.path.abspath(os.path.dirname(__file__)), "openai/version.py" 8 | ) 9 | with open(version_path, "rt") as f: 10 | exec(f.read(), version_contents) 11 | 12 | setup( 13 | name="openai", 14 | description="Python client library for the OpenAI API", 15 | version=version_contents["VERSION"], 16 | install_requires=[ 17 | "requests>=2.20", # to get the patch for CVE-2018-18074 18 | "tqdm", # Needed for progress bars 19 | "pandas>=1.2.3", # Needed for CLI fine-tuning data preparation tool 20 | "pandas-stubs>=1.1.0.11", # Needed for type hints for mypy 21 | "openpyxl>=3.0.7", # Needed for CLI fine-tuning data preparation tool xlsx format 22 | "numpy", 23 | "typing_extensions", # Needed for type hints for mypy 24 | ], 25 | extras_require={ 26 | "dev": ["black~=21.6b0", "pytest==6.*"], 27 | "wandb": ["wandb"], 28 | "embeddings": [ 29 | "scikit-learn>=1.0.2", # Needed for embedding utils, versions >= 1.1 require python 3.8 30 | "tenacity>=8.0.1", 31 | "matplotlib", 32 | "sklearn", 33 | 
"plotly", 34 | ], 35 | }, 36 | python_requires=">=3.7.1", 37 | entry_points={ 38 | "console_scripts": [ 39 | "openai=openai._openai_scripts:main", 40 | ], 41 | }, 42 | packages=find_packages(exclude=["tests", "tests.*"]), 43 | package_data={ 44 | "openai": [ 45 | "py.typed", 46 | ] 47 | }, 48 | author="OpenAI", 49 | author_email="support@openai.com", 50 | url="https://github.com/openai/openai-python", 51 | ) 52 | -------------------------------------------------------------------------------- /openai/api_resources/abstract/listable_api_resource.py: -------------------------------------------------------------------------------- 1 | from openai import api_requestor, util, error 2 | from openai.api_resources.abstract.api_resource import APIResource 3 | from openai.util import ApiType 4 | 5 | 6 | class ListableAPIResource(APIResource): 7 | @classmethod 8 | def auto_paging_iter(cls, *args, **params): 9 | return cls.list(*args, **params).auto_paging_iter() 10 | 11 | @classmethod 12 | def list( 13 | cls, 14 | api_key=None, 15 | request_id=None, 16 | api_version=None, 17 | organization=None, 18 | api_base=None, 19 | api_type=None, 20 | **params, 21 | ): 22 | requestor = api_requestor.APIRequestor( 23 | api_key, 24 | api_base=api_base or cls.api_base(), 25 | api_version=api_version, 26 | api_type=api_type, 27 | organization=organization, 28 | ) 29 | 30 | typed_api_type, api_version = cls._get_api_type_and_version( 31 | api_type, api_version 32 | ) 33 | 34 | if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): 35 | base = cls.class_url() 36 | url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version) 37 | elif typed_api_type == ApiType.OPEN_AI: 38 | url = cls.class_url() 39 | else: 40 | raise error.InvalidAPIType("Unsupported API type %s" % api_type) 41 | 42 | response, _, api_key = requestor.request( 43 | "get", url, params, request_id=request_id 44 | ) 45 | openai_object = util.convert_to_openai_object( 46 | response, api_key, api_version, organization 47 
| ) 48 | openai_object._retrieve_params = params 49 | return openai_object 50 | -------------------------------------------------------------------------------- /openai/tests/test_long_examples_validator.py: -------------------------------------------------------------------------------- 1 | import json 2 | import subprocess 3 | from tempfile import NamedTemporaryFile 4 | 5 | 6 | def test_long_examples_validator() -> None: 7 | 8 | """ 9 | Ensures that long_examples_validator() handles previously applied recommendations, 10 | namely dropped duplicates, without resulting in a KeyError. 11 | """ 12 | 13 | # data 14 | short_prompt = "a prompt " 15 | long_prompt = short_prompt * 500 16 | 17 | short_completion = "a completion " 18 | long_completion = short_completion * 500 19 | 20 | # the order of these matters 21 | unprepared_training_data = [ 22 | {"prompt": long_prompt, "completion": long_completion}, # 1 of 2 duplicates 23 | {"prompt": short_prompt, "completion": short_completion}, 24 | {"prompt": long_prompt, "completion": long_completion}, # 2 of 2 duplicates 25 | 26 | ] 27 | 28 | with NamedTemporaryFile(suffix="jsonl", mode="w") as training_data: 29 | for prompt_completion_row in unprepared_training_data: 30 | training_data.write(json.dumps(prompt_completion_row) + "\n") 31 | training_data.flush() 32 | 33 | prepared_data_cmd_output = subprocess.run( 34 | [f"openai tools fine_tunes.prepare_data -f {training_data.name}"], 35 | stdout=subprocess.PIPE, 36 | text=True, 37 | input="y\ny\ny\ny\ny", # apply all recommendations, one at a time 38 | stderr=subprocess.PIPE, 39 | encoding="utf-8", 40 | shell=True 41 | ) 42 | 43 | # validate data was prepared successfully 44 | assert prepared_data_cmd_output.stderr == "" 45 | # validate get_long_indexes() applied during optional_fn() call in long_examples_validator() 46 | assert "indices of the long examples has changed" in prepared_data_cmd_output.stdout 47 | 48 | return prepared_data_cmd_output.stdout 
-------------------------------------------------------------------------------- /openai/api_resources/embedding.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import time 3 | 4 | import numpy as np 5 | 6 | from openai import util 7 | from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource 8 | from openai.api_resources.abstract.engine_api_resource import EngineAPIResource 9 | from openai.error import TryAgain 10 | 11 | 12 | class Embedding(EngineAPIResource): 13 | OBJECT_NAME = "embeddings" 14 | 15 | @classmethod 16 | def create(cls, *args, **kwargs): 17 | """ 18 | Creates a new embedding for the provided input and parameters. 19 | 20 | See https://beta.openai.com/docs/api-reference/embeddings for a list 21 | of valid parameters. 22 | """ 23 | start = time.time() 24 | timeout = kwargs.pop("timeout", None) 25 | 26 | user_provided_encoding_format = kwargs.get("encoding_format", None) 27 | 28 | # If encoding format was not explicitly specified, we opaquely use base64 for performance 29 | if not user_provided_encoding_format: 30 | kwargs["encoding_format"] = "base64" 31 | 32 | while True: 33 | try: 34 | response = super().create(*args, **kwargs) 35 | 36 | # If a user specifies base64, we'll just return the encoded string. 37 | # This is only for the default case. 
38 | if not user_provided_encoding_format: 39 | for data in response.data: 40 | 41 | # If an engine isn't using this optimization, don't do anything 42 | if type(data["embedding"]) == str: 43 | data["embedding"] = np.frombuffer( 44 | base64.b64decode(data["embedding"]), dtype="float32" 45 | ).tolist() 46 | 47 | return response 48 | except TryAgain as e: 49 | if timeout is not None and time.time() > start + timeout: 50 | raise 51 | 52 | util.log_info("Waiting for model to warm up", error=e) 53 | -------------------------------------------------------------------------------- /openai/__init__.py: -------------------------------------------------------------------------------- 1 | # OpenAI Python bindings. 2 | # 3 | # Originally forked from the MIT-licensed Stripe Python bindings. 4 | 5 | import os 6 | from typing import Optional 7 | 8 | from openai.api_resources import ( 9 | Answer, 10 | Classification, 11 | Completion, 12 | Customer, 13 | Edit, 14 | Deployment, 15 | Embedding, 16 | Engine, 17 | ErrorObject, 18 | File, 19 | FineTune, 20 | Image, 21 | Model, 22 | Moderation, 23 | Search, 24 | ) 25 | from openai.error import APIError, InvalidRequestError, OpenAIError 26 | 27 | api_key = os.environ.get("OPENAI_API_KEY") 28 | # Path of a file with an API key, whose contents can change. Supersedes 29 | # `api_key` if set. The main use case is volume-mounted Kubernetes secrets, 30 | # which are updated automatically. 31 | api_key_path: Optional[str] = os.environ.get("OPENAI_API_KEY_PATH") 32 | 33 | organization = os.environ.get("OPENAI_ORGANIZATION") 34 | api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1") 35 | api_type = os.environ.get("OPENAI_API_TYPE", "open_ai") 36 | api_version = ( 37 | "2022-03-01-preview" if api_type in ("azure", "azure_ad", "azuread") else None 38 | ) 39 | verify_ssl_certs = True # No effect. Certificates are always verified.
40 | proxy = None 41 | app_info = None 42 | enable_telemetry = False # Ignored; the telemetry feature was removed. 43 | ca_bundle_path = None # No longer used, feature was removed 44 | debug = False 45 | log = None # Set to either 'debug' or 'info', controls console logging 46 | 47 | __all__ = [ 48 | "APIError", 49 | "Answer", 50 | "Classification", 51 | "Completion", 52 | "Customer", 53 | "Edit", 54 | "Image", 55 | "Deployment", 56 | "Embedding", 57 | "Engine", 58 | "ErrorObject", 59 | "File", 60 | "FineTune", 61 | "InvalidRequestError", 62 | "Model", 63 | "Moderation", 64 | "OpenAIError", 65 | "Search", 66 | "api_base", 67 | "api_key", 68 | "api_type", 69 | "api_key_path", 70 | "api_version", 71 | "app_info", 72 | "ca_bundle_path", 73 | "debug", 74 | "enable_telemetry", 75 | "log", 76 | "organization", 77 | "proxy", 78 | "verify_ssl_certs", 79 | ] 80 | -------------------------------------------------------------------------------- /openai/_openai_scripts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import logging 4 | import sys 5 | 6 | import openai 7 | from openai.cli import api_register, display_error, tools_register, wandb_register 8 | 9 | logger = logging.getLogger() 10 | formatter = logging.Formatter("[%(asctime)s] %(message)s") 11 | handler = logging.StreamHandler(sys.stderr) 12 | handler.setFormatter(formatter) 13 | logger.addHandler(handler) 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser(description=None) 18 | parser.add_argument( 19 | "-v", 20 | "--verbose", 21 | action="count", 22 | dest="verbosity", 23 | default=0, 24 | help="Set verbosity.", 25 | ) 26 | parser.add_argument("-b", "--api-base", help="What API base url to use.") 27 | parser.add_argument("-k", "--api-key", help="What API key to use.") 28 | parser.add_argument( 29 | "-o", 30 | "--organization", 31 | help="Which organization to run as (will use your default organization if not
specified)", 32 | ) 33 | 34 | def help(args): 35 | parser.print_help() 36 | 37 | parser.set_defaults(func=help) 38 | 39 | subparsers = parser.add_subparsers() 40 | sub_api = subparsers.add_parser("api", help="Direct API calls") 41 | sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience") 42 | sub_wandb = subparsers.add_parser("wandb", help="Logging with Weights & Biases") 43 | 44 | api_register(sub_api) 45 | tools_register(sub_tools) 46 | wandb_register(sub_wandb) 47 | 48 | args = parser.parse_args() 49 | if args.verbosity == 1: 50 | logger.setLevel(logging.INFO) 51 | elif args.verbosity >= 2: 52 | logger.setLevel(logging.DEBUG) 53 | 54 | openai.debug = True 55 | if args.api_key is not None: 56 | openai.api_key = args.api_key 57 | if args.api_base is not None: 58 | openai.api_base = args.api_base 59 | if args.organization is not None: 60 | openai.organization = args.organization 61 | 62 | try: 63 | args.func(args) 64 | except openai.error.OpenAIError as e: 65 | display_error(e) 66 | return 1 67 | except KeyboardInterrupt: 68 | sys.stderr.write("\n") 69 | return 1 70 | return 0 71 | 72 | 73 | if __name__ == "__main__": 74 | sys.exit(main()) 75 | -------------------------------------------------------------------------------- /openai/tests/test_api_requestor.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pytest 4 | import requests 5 | from pytest_mock import MockerFixture 6 | 7 | from openai import Model 8 | from openai.api_requestor import APIRequestor 9 | 10 | 11 | @pytest.mark.requestor 12 | def test_requestor_sets_request_id(mocker: MockerFixture) -> None: 13 | # Fake out 'requests' and confirm that the X-Request-Id header is set. 
14 | 15 | got_headers = {} 16 | 17 | def fake_request(self, *args, **kwargs): 18 | nonlocal got_headers 19 | got_headers = kwargs["headers"] 20 | r = requests.Response() 21 | r.status_code = 200 22 | r.headers["content-type"] = "application/json" 23 | r._content = json.dumps({}).encode("utf-8") 24 | return r 25 | 26 | mocker.patch("requests.sessions.Session.request", fake_request) 27 | fake_request_id = "1234" 28 | Model.retrieve("xxx", request_id=fake_request_id) # arbitrary API resource 29 | got_request_id = got_headers.get("X-Request-Id") 30 | assert got_request_id == fake_request_id 31 | 32 | 33 | @pytest.mark.requestor 34 | def test_requestor_open_ai_headers() -> None: 35 | api_requestor = APIRequestor(key="test_key", api_type="open_ai") 36 | headers = {"Test_Header": "Unit_Test_Header"} 37 | headers = api_requestor.request_headers( 38 | method="get", extra=headers, request_id="test_id" 39 | ) 40 | assert "Test_Header" in headers 41 | assert headers["Test_Header"] == "Unit_Test_Header" 42 | assert "Authorization" in headers 43 | assert headers["Authorization"] == "Bearer test_key" 44 | 45 | 46 | @pytest.mark.requestor 47 | def test_requestor_azure_headers() -> None: 48 | api_requestor = APIRequestor(key="test_key", api_type="azure") 49 | headers = {"Test_Header": "Unit_Test_Header"} 50 | headers = api_requestor.request_headers( 51 | method="get", extra=headers, request_id="test_id" 52 | ) 53 | assert "Test_Header" in headers 54 | assert headers["Test_Header"] == "Unit_Test_Header" 55 | assert "api-key" in headers 56 | assert headers["api-key"] == "test_key" 57 | 58 | 59 | @pytest.mark.requestor 60 | def test_requestor_azure_ad_headers() -> None: 61 | api_requestor = APIRequestor(key="test_key", api_type="azure_ad") 62 | headers = {"Test_Header": "Unit_Test_Header"} 63 | headers = api_requestor.request_headers( 64 | method="get", extra=headers, request_id="test_id" 65 | ) 66 | assert "Test_Header" in headers 67 | assert headers["Test_Header"] == 
"Unit_Test_Header" 68 | assert "Authorization" in headers 69 | assert headers["Authorization"] == "Bearer test_key" 70 | -------------------------------------------------------------------------------- /openai/api_resources/image.py: -------------------------------------------------------------------------------- 1 | # WARNING: This interface is considered experimental and may changed in the future without warning. 2 | from typing import Any, List 3 | 4 | import openai 5 | from openai import api_requestor, util 6 | from openai.api_resources.abstract import APIResource 7 | 8 | 9 | class Image(APIResource): 10 | OBJECT_NAME = "images" 11 | 12 | @classmethod 13 | def _get_url(cls, action): 14 | return cls.class_url() + f"/{action}" 15 | 16 | @classmethod 17 | def create( 18 | cls, 19 | **params, 20 | ): 21 | instance = cls() 22 | return instance.request("post", cls._get_url("generations"), params) 23 | 24 | @classmethod 25 | def create_variation( 26 | cls, 27 | image, 28 | api_key=None, 29 | api_base=None, 30 | api_type=None, 31 | api_version=None, 32 | organization=None, 33 | **params, 34 | ): 35 | requestor = api_requestor.APIRequestor( 36 | api_key, 37 | api_base=api_base or openai.api_base, 38 | api_type=api_type, 39 | api_version=api_version, 40 | organization=organization, 41 | ) 42 | _, api_version = cls._get_api_type_and_version(api_type, api_version) 43 | 44 | url = cls._get_url("variations") 45 | 46 | files: List[Any] = [] 47 | for key, value in params.items(): 48 | files.append((key, (None, value))) 49 | files.append(("image", ("image", image, "application/octet-stream"))) 50 | 51 | response, _, api_key = requestor.request("post", url, files=files) 52 | 53 | return util.convert_to_openai_object( 54 | response, api_key, api_version, organization 55 | ) 56 | 57 | @classmethod 58 | def create_edit( 59 | cls, 60 | image, 61 | mask, 62 | api_key=None, 63 | api_base=None, 64 | api_type=None, 65 | api_version=None, 66 | organization=None, 67 | **params, 68 | ): 69 
| requestor = api_requestor.APIRequestor( 70 | api_key, 71 | api_base=api_base or openai.api_base, 72 | api_type=api_type, 73 | api_version=api_version, 74 | organization=organization, 75 | ) 76 | _, api_version = cls._get_api_type_and_version(api_type, api_version) 77 | 78 | url = cls._get_url("edits") 79 | 80 | files: List[Any] = [] 81 | for key, value in params.items(): 82 | files.append((key, (None, value))) 83 | files.append(("image", ("image", image, "application/octet-stream"))) 84 | files.append(("mask", ("mask", mask, "application/octet-stream"))) 85 | 86 | response, _, api_key = requestor.request("post", url, files=files) 87 | 88 | return util.convert_to_openai_object( 89 | response, api_key, api_version, organization 90 | ) 91 | -------------------------------------------------------------------------------- /openai/api_resources/deployment.py: -------------------------------------------------------------------------------- 1 | from openai import util 2 | from openai.api_resources.abstract import ( 3 | DeletableAPIResource, 4 | ListableAPIResource, 5 | CreateableAPIResource, 6 | ) 7 | from openai.error import InvalidRequestError, APIError 8 | 9 | 10 | class Deployment(CreateableAPIResource, ListableAPIResource, DeletableAPIResource): 11 | OBJECT_NAME = "deployments" 12 | 13 | @classmethod 14 | def create(cls, *args, **kwargs): 15 | """ 16 | Creates a new deployment for the provided prompt and parameters. 17 | """ 18 | typed_api_type, _ = cls._get_api_type_and_version( 19 | kwargs.get("api_type", None), None 20 | ) 21 | if typed_api_type not in (util.ApiType.AZURE, util.ApiType.AZURE_AD): 22 | raise APIError( 23 | "Deployment operations are only available for the Azure API type." 
24 | ) 25 | 26 | if kwargs.get("model", None) is None: 27 | raise InvalidRequestError( 28 | "Must provide a 'model' parameter to create a Deployment.", 29 | param="model", 30 | ) 31 | 32 | scale_settings = kwargs.get("scale_settings", None) 33 | if scale_settings is None: 34 | raise InvalidRequestError( 35 | "Must provide a 'scale_settings' parameter to create a Deployment.", 36 | param="scale_settings", 37 | ) 38 | 39 | if "scale_type" not in scale_settings or ( 40 | scale_settings["scale_type"].lower() == "manual" 41 | and "capacity" not in scale_settings 42 | ): 43 | raise InvalidRequestError( 44 | "The 'scale_settings' parameter contains invalid or incomplete values.", 45 | param="scale_settings", 46 | ) 47 | 48 | return super().create(*args, **kwargs) 49 | 50 | @classmethod 51 | def list(cls, *args, **kwargs): 52 | typed_api_type, _ = cls._get_api_type_and_version( 53 | kwargs.get("api_type", None), None 54 | ) 55 | if typed_api_type not in (util.ApiType.AZURE, util.ApiType.AZURE_AD): 56 | raise APIError( 57 | "Deployment operations are only available for the Azure API type." 58 | ) 59 | 60 | return super().list(*args, **kwargs) 61 | 62 | @classmethod 63 | def delete(cls, *args, **kwargs): 64 | typed_api_type, _ = cls._get_api_type_and_version( 65 | kwargs.get("api_type", None), None 66 | ) 67 | if typed_api_type not in (util.ApiType.AZURE, util.ApiType.AZURE_AD): 68 | raise APIError( 69 | "Deployment operations are only available for the Azure API type." 70 | ) 71 | 72 | return super().delete(*args, **kwargs) 73 | 74 | @classmethod 75 | def retrieve(cls, *args, **kwargs): 76 | typed_api_type, _ = cls._get_api_type_and_version( 77 | kwargs.get("api_type", None), None 78 | ) 79 | if typed_api_type not in (util.ApiType.AZURE, util.ApiType.AZURE_AD): 80 | raise APIError( 81 | "Deployment operations are only available for the Azure API type." 
82 | ) 83 | 84 | return super().retrieve(*args, **kwargs) 85 | -------------------------------------------------------------------------------- /openai/api_resources/fine_tune.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import quote_plus 2 | 3 | from openai import api_requestor, util, error 4 | from openai.api_resources.abstract import ( 5 | CreateableAPIResource, 6 | ListableAPIResource, 7 | nested_resource_class_methods, 8 | ) 9 | from openai.api_resources.abstract.deletable_api_resource import DeletableAPIResource 10 | from openai.openai_response import OpenAIResponse 11 | from openai.util import ApiType 12 | 13 | 14 | @nested_resource_class_methods("event", operations=["list"]) 15 | class FineTune(ListableAPIResource, CreateableAPIResource, DeletableAPIResource): 16 | OBJECT_NAME = "fine-tunes" 17 | 18 | @classmethod 19 | def cancel( 20 | cls, 21 | id, 22 | api_key=None, 23 | api_type=None, 24 | request_id=None, 25 | api_version=None, 26 | **params, 27 | ): 28 | base = cls.class_url() 29 | extn = quote_plus(id) 30 | 31 | typed_api_type, api_version = cls._get_api_type_and_version( 32 | api_type, api_version 33 | ) 34 | if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): 35 | url = "/%s%s/%s/cancel?api-version=%s" % ( 36 | cls.azure_api_prefix, 37 | base, 38 | extn, 39 | api_version, 40 | ) 41 | elif typed_api_type == ApiType.OPEN_AI: 42 | url = "%s/%s/cancel" % (base, extn) 43 | else: 44 | raise error.InvalidAPIType("Unsupported API type %s" % api_type) 45 | 46 | instance = cls(id, api_key, **params) 47 | return instance.request("post", url, request_id=request_id) 48 | 49 | @classmethod 50 | def stream_events( 51 | cls, 52 | id, 53 | api_key=None, 54 | api_base=None, 55 | api_type=None, 56 | request_id=None, 57 | api_version=None, 58 | organization=None, 59 | **params, 60 | ): 61 | base = cls.class_url() 62 | extn = quote_plus(id) 63 | 64 | requestor = api_requestor.APIRequestor( 65 | api_key, 
66 | api_base=api_base, 67 | api_type=api_type, 68 | api_version=api_version, 69 | organization=organization, 70 | ) 71 | 72 | typed_api_type, api_version = cls._get_api_type_and_version( 73 | api_type, api_version 74 | ) 75 | 76 | if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): 77 | url = "/%s%s/%s/events?stream=true&api-version=%s" % ( 78 | cls.azure_api_prefix, 79 | base, 80 | extn, 81 | api_version, 82 | ) 83 | elif typed_api_type == ApiType.OPEN_AI: 84 | url = "%s/%s/events?stream=true" % (base, extn) 85 | else: 86 | raise error.InvalidAPIType("Unsupported API type %s" % api_type) 87 | 88 | response, _, api_key = requestor.request( 89 | "get", url, params, stream=True, request_id=request_id 90 | ) 91 | 92 | assert not isinstance(response, OpenAIResponse) # must be an iterator 93 | return ( 94 | util.convert_to_openai_object( 95 | line, 96 | api_key, 97 | api_version, 98 | organization, 99 | ) 100 | for line in response 101 | ) 102 | -------------------------------------------------------------------------------- /openai/api_resources/abstract/nested_resource_class_methods.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import quote_plus 2 | 3 | from openai import api_requestor, util 4 | 5 | 6 | def nested_resource_class_methods( 7 | resource, path=None, operations=None, resource_plural=None 8 | ): 9 | if resource_plural is None: 10 | resource_plural = "%ss" % resource 11 | if path is None: 12 | path = resource_plural 13 | if operations is None: 14 | raise ValueError("operations list required") 15 | 16 | def wrapper(cls): 17 | def nested_resource_url(cls, id, nested_id=None): 18 | url = "%s/%s/%s" % (cls.class_url(), quote_plus(id), quote_plus(path)) 19 | if nested_id is not None: 20 | url += "/%s" % quote_plus(nested_id) 21 | return url 22 | 23 | resource_url_method = "%ss_url" % resource 24 | setattr(cls, resource_url_method, classmethod(nested_resource_url)) 25 | 26 | def 
nested_resource_request( 27 | cls, 28 | method, 29 | url, 30 | api_key=None, 31 | request_id=None, 32 | api_version=None, 33 | organization=None, 34 | **params, 35 | ): 36 | requestor = api_requestor.APIRequestor( 37 | api_key, api_version=api_version, organization=organization 38 | ) 39 | response, _, api_key = requestor.request( 40 | method, url, params, request_id=request_id 41 | ) 42 | return util.convert_to_openai_object( 43 | response, api_key, api_version, organization 44 | ) 45 | 46 | resource_request_method = "%ss_request" % resource 47 | setattr(cls, resource_request_method, classmethod(nested_resource_request)) 48 | 49 | for operation in operations: 50 | if operation == "create": 51 | 52 | def create_nested_resource(cls, id, **params): 53 | url = getattr(cls, resource_url_method)(id) 54 | return getattr(cls, resource_request_method)("post", url, **params) 55 | 56 | create_method = "create_%s" % resource 57 | setattr(cls, create_method, classmethod(create_nested_resource)) 58 | 59 | elif operation == "retrieve": 60 | 61 | def retrieve_nested_resource(cls, id, nested_id, **params): 62 | url = getattr(cls, resource_url_method)(id, nested_id) 63 | return getattr(cls, resource_request_method)("get", url, **params) 64 | 65 | retrieve_method = "retrieve_%s" % resource 66 | setattr(cls, retrieve_method, classmethod(retrieve_nested_resource)) 67 | 68 | elif operation == "update": 69 | 70 | def modify_nested_resource(cls, id, nested_id, **params): 71 | url = getattr(cls, resource_url_method)(id, nested_id) 72 | return getattr(cls, resource_request_method)("post", url, **params) 73 | 74 | modify_method = "modify_%s" % resource 75 | setattr(cls, modify_method, classmethod(modify_nested_resource)) 76 | 77 | elif operation == "delete": 78 | 79 | def delete_nested_resource(cls, id, nested_id, **params): 80 | url = getattr(cls, resource_url_method)(id, nested_id) 81 | return getattr(cls, resource_request_method)( 82 | "delete", url, **params 83 | ) 84 | 85 | 
delete_method = "delete_%s" % resource 86 | setattr(cls, delete_method, classmethod(delete_nested_resource)) 87 | 88 | elif operation == "list": 89 | 90 | def list_nested_resources(cls, id, **params): 91 | url = getattr(cls, resource_url_method)(id) 92 | return getattr(cls, resource_request_method)("get", url, **params) 93 | 94 | list_method = "list_%s" % resource_plural 95 | setattr(cls, list_method, classmethod(list_nested_resources)) 96 | 97 | else: 98 | raise ValueError("Unknown operation: %s" % operation) 99 | 100 | return cls 101 | 102 | return wrapper 103 | -------------------------------------------------------------------------------- /openai/api_resources/abstract/api_resource.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import quote_plus 2 | 3 | import openai 4 | from openai import api_requestor, error, util 5 | from openai.openai_object import OpenAIObject 6 | from openai.util import ApiType 7 | from typing import Optional 8 | 9 | 10 | class APIResource(OpenAIObject): 11 | api_prefix = "" 12 | azure_api_prefix = "openai" 13 | azure_deployments_prefix = "deployments" 14 | 15 | @classmethod 16 | def retrieve( 17 | cls, id, api_key=None, request_id=None, request_timeout=None, **params 18 | ): 19 | instance = cls(id, api_key, **params) 20 | instance.refresh(request_id=request_id, request_timeout=request_timeout) 21 | return instance 22 | 23 | def refresh(self, request_id=None, request_timeout=None): 24 | self.refresh_from( 25 | self.request( 26 | "get", 27 | self.instance_url(), 28 | request_id=request_id, 29 | request_timeout=request_timeout, 30 | ) 31 | ) 32 | return self 33 | 34 | @classmethod 35 | def class_url(cls): 36 | if cls == APIResource: 37 | raise NotImplementedError( 38 | "APIResource is an abstract class. You should perform actions on its subclasses." 39 | ) 40 | # Namespaces are separated in object names with periods (.) 
and in URLs 41 | # with forward slashes (/), so replace the former with the latter. 42 | base = cls.OBJECT_NAME.replace(".", "/") # type: ignore 43 | if cls.api_prefix: 44 | return "/%s/%s" % (cls.api_prefix, base) 45 | return "/%s" % (base) 46 | 47 | def instance_url(self, operation=None): 48 | id = self.get("id") 49 | 50 | if not isinstance(id, str): 51 | raise error.InvalidRequestError( 52 | "Could not determine which URL to request: %s instance " 53 | "has invalid ID: %r, %s. ID should be of type `str` (or" 54 | " `unicode`)" % (type(self).__name__, id, type(id)), 55 | "id", 56 | ) 57 | api_version = self.api_version or openai.api_version 58 | extn = quote_plus(id) 59 | 60 | if self.typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): 61 | if not api_version: 62 | raise error.InvalidRequestError( 63 | "An API version is required for the Azure API type." 64 | ) 65 | 66 | if not operation: 67 | base = self.class_url() 68 | return "/%s%s/%s?api-version=%s" % ( 69 | self.azure_api_prefix, 70 | base, 71 | extn, 72 | api_version, 73 | ) 74 | 75 | return "/%s/%s/%s/%s?api-version=%s" % ( 76 | self.azure_api_prefix, 77 | self.azure_deployments_prefix, 78 | extn, 79 | operation, 80 | api_version, 81 | ) 82 | 83 | elif self.typed_api_type == ApiType.OPEN_AI: 84 | base = self.class_url() 85 | return "%s/%s" % (base, extn) 86 | 87 | else: 88 | raise error.InvalidAPIType("Unsupported API type %s" % self.api_type) 89 | 90 | # The `method_` and `url_` arguments are suffixed with an underscore to 91 | # avoid conflicting with actual request parameters in `params`. 
92 | @classmethod 93 | def _static_request( 94 | cls, 95 | method_, 96 | url_, 97 | api_key=None, 98 | api_base=None, 99 | api_type=None, 100 | request_id=None, 101 | api_version=None, 102 | organization=None, 103 | **params, 104 | ): 105 | requestor = api_requestor.APIRequestor( 106 | api_key, 107 | api_version=api_version, 108 | organization=organization, 109 | api_base=api_base, 110 | api_type=api_type, 111 | ) 112 | response, _, api_key = requestor.request( 113 | method_, url_, params, request_id=request_id 114 | ) 115 | return util.convert_to_openai_object( 116 | response, api_key, api_version, organization 117 | ) 118 | 119 | @classmethod 120 | def _get_api_type_and_version( 121 | cls, api_type: Optional[str] = None, api_version: Optional[str] = None 122 | ): 123 | typed_api_type = ( 124 | ApiType.from_str(api_type) 125 | if api_type 126 | else ApiType.from_str(openai.api_type) 127 | ) 128 | typed_api_version = api_version or openai.api_version 129 | return (typed_api_type, typed_api_version) 130 | -------------------------------------------------------------------------------- /openai/api_resources/file.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import cast 4 | 5 | import openai 6 | from openai import api_requestor, util, error 7 | from openai.api_resources.abstract import DeletableAPIResource, ListableAPIResource 8 | from openai.util import ApiType 9 | 10 | 11 | class File(ListableAPIResource, DeletableAPIResource): 12 | OBJECT_NAME = "files" 13 | 14 | @classmethod 15 | def create( 16 | cls, 17 | file, 18 | purpose, 19 | model=None, 20 | api_key=None, 21 | api_base=None, 22 | api_type=None, 23 | api_version=None, 24 | organization=None, 25 | user_provided_filename=None, 26 | ): 27 | if purpose != "search" and model is not None: 28 | raise ValueError("'model' is only meaningful if 'purpose' is 'search'") 29 | requestor = api_requestor.APIRequestor( 30 | api_key, 31 | 
api_base=api_base or openai.api_base, 32 | api_type=api_type, 33 | api_version=api_version, 34 | organization=organization, 35 | ) 36 | typed_api_type, api_version = cls._get_api_type_and_version( 37 | api_type, api_version 38 | ) 39 | 40 | if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): 41 | base = cls.class_url() 42 | url = "/%s%s?api-version=%s" % (cls.azure_api_prefix, base, api_version) 43 | elif typed_api_type == ApiType.OPEN_AI: 44 | url = cls.class_url() 45 | else: 46 | raise error.InvalidAPIType("Unsupported API type %s" % api_type) 47 | 48 | # Set the filename on 'purpose' and 'model' to None so they are 49 | # interpreted as form data. 50 | files = [("purpose", (None, purpose))] 51 | if model is not None: 52 | files.append(("model", (None, model))) 53 | if user_provided_filename is not None: 54 | files.append( 55 | ("file", (user_provided_filename, file, "application/octet-stream")) 56 | ) 57 | else: 58 | files.append(("file", ("file", file, "application/octet-stream"))) 59 | response, _, api_key = requestor.request("post", url, files=files) 60 | return util.convert_to_openai_object( 61 | response, api_key, api_version, organization 62 | ) 63 | 64 | @classmethod 65 | def download( 66 | cls, 67 | id, 68 | api_key=None, 69 | api_base=None, 70 | api_type=None, 71 | api_version=None, 72 | organization=None, 73 | ): 74 | requestor = api_requestor.APIRequestor( 75 | api_key, 76 | api_base=api_base or openai.api_base, 77 | api_type=api_type, 78 | api_version=api_version, 79 | organization=organization, 80 | ) 81 | typed_api_type, api_version = cls._get_api_type_and_version( 82 | api_type, api_version 83 | ) 84 | 85 | if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD): 86 | base = cls.class_url() 87 | url = "/%s%s/%s/content?api-version=%s" % ( 88 | cls.azure_api_prefix, 89 | base, 90 | id, 91 | api_version, 92 | ) 93 | elif typed_api_type == ApiType.OPEN_AI: 94 | url = f"{cls.class_url()}/{id}/content" 95 | else: 96 | raise 
error.InvalidAPIType("Unsupported API type %s" % api_type) 97 | 98 | result = requestor.request_raw("get", url) 99 | if not 200 <= result.status_code < 300: 100 | raise requestor.handle_error_response( 101 | result.content, 102 | result.status_code, 103 | json.loads(cast(bytes, result.content)), 104 | result.headers, 105 | stream_error=False, 106 | ) 107 | return result.content 108 | 109 | @classmethod 110 | def find_matching_files( 111 | cls, 112 | name, 113 | bytes, 114 | purpose, 115 | api_key=None, 116 | api_base=None, 117 | api_type=None, 118 | api_version=None, 119 | organization=None, 120 | ): 121 | """Find already uploaded files with the same name, size, and purpose.""" 122 | all_files = cls.list( 123 | api_key=api_key, 124 | api_base=api_base or openai.api_base, 125 | api_type=api_type, 126 | api_version=api_version, 127 | organization=organization, 128 | ).get("data", []) 129 | matching_files = [] 130 | basename = os.path.basename(name) 131 | for f in all_files: 132 | if f["purpose"] != purpose: 133 | continue 134 | file_basename = os.path.basename(f["filename"]) 135 | if file_basename != basename: 136 | continue 137 | if "bytes" in f and f["bytes"] != bytes: 138 | continue 139 | if "size" in f and int(f["size"]) != bytes: 140 | continue 141 | matching_files.append(f) 142 | return matching_files 143 | -------------------------------------------------------------------------------- /openai/error.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | 4 | display_cause = os.environ["OPENAI_ERROR_DISPLAY_CAUSE"] 5 | 6 | 7 | class OpenAIError(Exception): 8 | def __init__( 9 | self, 10 | message=None, 11 | http_body=None, 12 | http_status=None, 13 | json_body=None, 14 | headers=None, 15 | code=None, 16 | ): 17 | super(OpenAIError, self).__init__(message) 18 | 19 | if http_body and hasattr(http_body, "decode"): 20 | try: 21 | http_body = http_body.decode("utf-8") 22 | except BaseException: 23 | 
http_body = ( 24 | "<Could not decode body as utf-8. " 25 | "Please report to support@openai.com>" 26 | ) 27 | 28 | self._message = message 29 | self.http_body = http_body 30 | self.http_status = http_status 31 | self.json_body = json_body 32 | self.headers = headers or {} 33 | self.code = code 34 | self.request_id = self.headers.get("request-id", None) 35 | self.error = self.construct_error_object() 36 | self.organization = self.headers.get("openai-organization", None) 37 | 38 | def __str__(self): 39 | msg = self._message or "<empty message>" 40 | if display_cause is not None and hasattr(self, "__cause__") and self.__cause__ is not None: 41 | msg += " (Cause: {0})".format(self.__cause__) 42 | if self.request_id is not None: 43 | return "Request {0}: {1}".format(self.request_id, msg) 44 | else: 45 | return msg 46 | 47 | # Returns the underlying `Exception` (base class) message, which is usually 48 | # the raw message returned by OpenAI's API. This was previously available 49 | # in python2 via `error.message`. Unlike `str(error)`, it omits "Request 50 | # req_..." from the beginning of the string.
51 | @property 52 | def user_message(self): 53 | return self._message 54 | 55 | def __repr__(self): 56 | return "%s(message=%r, http_status=%r, request_id=%r)" % ( 57 | self.__class__.__name__, 58 | self._message, 59 | self.http_status, 60 | self.request_id, 61 | ) 62 | 63 | def construct_error_object(self): 64 | if ( 65 | self.json_body is None 66 | or "error" not in self.json_body 67 | or not isinstance(self.json_body["error"], dict) 68 | ): 69 | return None 70 | 71 | return openai.api_resources.error_object.ErrorObject.construct_from( 72 | self.json_body["error"] 73 | ) 74 | 75 | 76 | class APIError(OpenAIError): 77 | pass 78 | 79 | 80 | class TryAgain(OpenAIError): 81 | pass 82 | 83 | 84 | class Timeout(OpenAIError): 85 | pass 86 | 87 | 88 | class APIConnectionError(OpenAIError): 89 | def __init__( 90 | self, 91 | message, 92 | http_body=None, 93 | http_status=None, 94 | json_body=None, 95 | headers=None, 96 | code=None, 97 | should_retry=False, 98 | ): 99 | super(APIConnectionError, self).__init__( 100 | message, http_body, http_status, json_body, headers, code 101 | ) 102 | self.should_retry = should_retry 103 | 104 | 105 | class InvalidRequestError(OpenAIError): 106 | def __init__( 107 | self, 108 | message, 109 | param, 110 | code=None, 111 | http_body=None, 112 | http_status=None, 113 | json_body=None, 114 | headers=None, 115 | ): 116 | super(InvalidRequestError, self).__init__( 117 | message, http_body, http_status, json_body, headers, code 118 | ) 119 | self.param = param 120 | 121 | def __repr__(self): 122 | return "%s(message=%r, param=%r, code=%r, http_status=%r, " "request_id=%r)" % ( 123 | self.__class__.__name__, 124 | self._message, 125 | self.param, 126 | self.code, 127 | self.http_status, 128 | self.request_id, 129 | ) 130 | 131 | def __reduce__(self): 132 | return type(self), ( 133 | self._message, 134 | self.param, 135 | self.code, 136 | self.http_body, 137 | self.http_status, 138 | self.json_body, 139 | self.headers, 140 | ) 141 | 142 | 143 | 
class AuthenticationError(OpenAIError): 144 | pass 145 | 146 | 147 | class PermissionError(OpenAIError): 148 | pass 149 | 150 | 151 | class RateLimitError(OpenAIError): 152 | pass 153 | 154 | 155 | class ServiceUnavailableError(OpenAIError): 156 | pass 157 | 158 | 159 | class InvalidAPIType(OpenAIError): 160 | pass 161 | 162 | 163 | class SignatureVerificationError(OpenAIError): 164 | def __init__(self, message, sig_header, http_body=None): 165 | super(SignatureVerificationError, self).__init__(message, http_body) 166 | self.sig_header = sig_header 167 | 168 | def __reduce__(self): 169 | return type(self), ( 170 | self._message, 171 | self.sig_header, 172 | self.http_body, 173 | ) 174 | -------------------------------------------------------------------------------- /openai/util.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | import sys 5 | from enum import Enum 6 | from typing import Optional 7 | 8 | import openai 9 | 10 | OPENAI_LOG = os.environ.get("OPENAI_LOG") 11 | 12 | logger = logging.getLogger("openai") 13 | 14 | __all__ = [ 15 | "log_info", 16 | "log_debug", 17 | "log_warn", 18 | "logfmt", 19 | ] 20 | 21 | api_key_to_header = ( 22 | lambda api, key: {"Authorization": f"Bearer {key}"} 23 | if api in (ApiType.OPEN_AI, ApiType.AZURE_AD) 24 | else {"api-key": f"{key}"} 25 | ) 26 | 27 | 28 | class ApiType(Enum): 29 | AZURE = 1 30 | OPEN_AI = 2 31 | AZURE_AD = 3 32 | 33 | @staticmethod 34 | def from_str(label): 35 | if label.lower() == "azure": 36 | return ApiType.AZURE 37 | elif label.lower() in ("azure_ad", "azuread"): 38 | return ApiType.AZURE_AD 39 | elif label.lower() in ("open_ai", "openai"): 40 | return ApiType.OPEN_AI 41 | else: 42 | raise openai.error.InvalidAPIType( 43 | "The API type provided in invalid. 
Please select one of the supported API types: 'azure', 'azure_ad', 'open_ai'" 44 | ) 45 | 46 | 47 | def _console_log_level(): 48 | if openai.log in ["debug", "info"]: 49 | return openai.log 50 | elif OPENAI_LOG in ["debug", "info"]: 51 | return OPENAI_LOG 52 | else: 53 | return None 54 | 55 | 56 | def log_debug(message, **params): 57 | msg = logfmt(dict(message=message, **params)) 58 | if _console_log_level() == "debug": 59 | print(msg, file=sys.stderr) 60 | logger.debug(msg) 61 | 62 | 63 | def log_info(message, **params): 64 | msg = logfmt(dict(message=message, **params)) 65 | if _console_log_level() in ["debug", "info"]: 66 | print(msg, file=sys.stderr) 67 | logger.info(msg) 68 | 69 | 70 | def log_warn(message, **params): 71 | msg = logfmt(dict(message=message, **params)) 72 | print(msg, file=sys.stderr) 73 | logger.warn(msg) 74 | 75 | 76 | def logfmt(props): 77 | def fmt(key, val): 78 | # Handle the case where val is bytes or a bytearray 79 | if hasattr(val, "decode"): 80 | val = val.decode("utf-8") 81 | # Check if val is already a string to avoid re-encoding into ascii. 82 | if not isinstance(val, str): 83 | val = str(val) 84 | if re.search(r"\s", val): 85 | val = repr(val) 86 | # key should already be a string 87 | if re.search(r"\s", key): 88 | key = repr(key) 89 | return "{key}={val}".format(key=key, val=val) 90 | 91 | return " ".join([fmt(key, val) for key, val in sorted(props.items())]) 92 | 93 | 94 | def get_object_classes(): 95 | # This is here to avoid a circular dependency 96 | from openai.object_classes import OBJECT_CLASSES 97 | 98 | return OBJECT_CLASSES 99 | 100 | 101 | def convert_to_openai_object( 102 | resp, 103 | api_key=None, 104 | api_version=None, 105 | organization=None, 106 | engine=None, 107 | plain_old_data=False, 108 | ): 109 | # If we get an OpenAIResponse, we'll want to return an OpenAIObject.
110 | 111 | response_ms: Optional[int] = None 112 | if isinstance(resp, openai.openai_response.OpenAIResponse): 113 | organization = resp.organization 114 | response_ms = resp.response_ms 115 | resp = resp.data 116 | 117 | if plain_old_data: 118 | return resp 119 | elif isinstance(resp, list): 120 | return [ 121 | convert_to_openai_object( 122 | i, api_key, api_version, organization, engine=engine 123 | ) 124 | for i in resp 125 | ] 126 | elif isinstance(resp, dict) and not isinstance( 127 | resp, openai.openai_object.OpenAIObject 128 | ): 129 | resp = resp.copy() 130 | klass_name = resp.get("object") 131 | if isinstance(klass_name, str): 132 | klass = get_object_classes().get( 133 | klass_name, openai.openai_object.OpenAIObject 134 | ) 135 | else: 136 | klass = openai.openai_object.OpenAIObject 137 | 138 | return klass.construct_from( 139 | resp, 140 | api_key=api_key, 141 | api_version=api_version, 142 | organization=organization, 143 | response_ms=response_ms, 144 | engine=engine, 145 | ) 146 | else: 147 | return resp 148 | 149 | 150 | def convert_to_dict(obj): 151 | """Converts a OpenAIObject back to a regular dict. 152 | 153 | Nested OpenAIObjects are also converted back to regular dicts. 154 | 155 | :param obj: The OpenAIObject to convert. 156 | 157 | :returns: The OpenAIObject as a dict. 158 | """ 159 | if isinstance(obj, list): 160 | return [convert_to_dict(i) for i in obj] 161 | # This works by virtue of the fact that OpenAIObjects _are_ dicts. The dict 162 | # comprehension returns a regular dict and recursively applies the 163 | # conversion to each value. 
164 | elif isinstance(obj, dict): 165 | return {k: convert_to_dict(v) for k, v in obj.items()} 166 | else: 167 | return obj 168 | 169 | 170 | def merge_dicts(x, y): 171 | z = x.copy() 172 | z.update(y) 173 | return z 174 | 175 | 176 | def default_api_key() -> str: 177 | if openai.api_key_path: 178 | with open(openai.api_key_path, "rt") as k: 179 | api_key = k.read().strip() 180 | if not api_key.startswith("sk-"): 181 | raise ValueError(f"Malformed API key in {openai.api_key_path}.") 182 | return api_key 183 | elif openai.api_key is not None: 184 | return openai.api_key 185 | else: 186 | raise openai.error.AuthenticationError( 187 | "No API key provided. You can set your API key in code using 'openai.api_key = <API-KEY>', or you can set the environment variable OPENAI_API_KEY=<API-KEY>). If your API key is stored in a file, you can point the openai module at it with 'openai.api_key_path = <PATH>'. You can generate API keys in the OpenAI web interface. See https://onboard.openai.com for details, or email support@openai.com if you have any questions."
# URL-composition tests for Completion and Engine resources across the
# "open_ai", "azure" and "azure_ad" API types.


@pytest.mark.url
def test_completions_url_composition_azure() -> None:
    url = Completion.class_url("test_engine", "azure", "2021-11-01-preview")
    assert (
        url
        == "/openai/deployments/test_engine/completions?api-version=2021-11-01-preview"
    )


@pytest.mark.url
def test_completions_url_composition_azure_ad() -> None:
    url = Completion.class_url("test_engine", "azure_ad", "2021-11-01-preview")
    assert (
        url
        == "/openai/deployments/test_engine/completions?api-version=2021-11-01-preview"
    )


@pytest.mark.url
def test_completions_url_composition_default() -> None:
    url = Completion.class_url("test_engine")
    assert url == "/engines/test_engine/completions"


@pytest.mark.url
def test_completions_url_composition_open_ai() -> None:
    url = Completion.class_url("test_engine", "open_ai")
    assert url == "/engines/test_engine/completions"


@pytest.mark.url
def test_completions_url_composition_invalid_type() -> None:
    # Only the raised exception matters; the result is never produced.
    with pytest.raises(Exception):
        Completion.class_url("test_engine", "invalid")


@pytest.mark.url
def test_completions_url_composition_instance_url_azure() -> None:
    completion = Completion(
        id="test_id",
        engine="test_engine",
        api_type="azure",
        api_version="2021-11-01-preview",
    )
    url = completion.instance_url()
    assert (
        url
        == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview"
    )


@pytest.mark.url
def test_completions_url_composition_instance_url_azure_ad() -> None:
    completion = Completion(
        id="test_id",
        engine="test_engine",
        api_type="azure_ad",
        api_version="2021-11-01-preview",
    )
    url = completion.instance_url()
    assert (
        url
        == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview"
    )


@pytest.mark.url
def test_completions_url_composition_instance_url_azure_no_version() -> None:
    completion = Completion(
        id="test_id", engine="test_engine", api_type="azure", api_version=None
    )
    with pytest.raises(Exception):
        completion.instance_url()


@pytest.mark.url
def test_completions_url_composition_instance_url_default() -> None:
    completion = Completion(id="test_id", engine="test_engine")
    url = completion.instance_url()
    assert url == "/engines/test_engine/completions/test_id"


@pytest.mark.url
def test_completions_url_composition_instance_url_open_ai() -> None:
    completion = Completion(
        id="test_id",
        engine="test_engine",
        api_type="open_ai",
        api_version="2021-11-01-preview",
    )
    url = completion.instance_url()
    assert url == "/engines/test_engine/completions/test_id"


@pytest.mark.url
def test_completions_url_composition_instance_url_invalid() -> None:
    completion = Completion(id="test_id", engine="test_engine", api_type="invalid")
    with pytest.raises(Exception):
        completion.instance_url()


@pytest.mark.url
def test_completions_url_composition_instance_url_timeout_azure() -> None:
    completion = Completion(
        id="test_id",
        engine="test_engine",
        api_type="azure",
        api_version="2021-11-01-preview",
    )
    completion["timeout"] = 12
    url = completion.instance_url()
    assert (
        url
        == "/openai/deployments/test_engine/completions/test_id?api-version=2021-11-01-preview&timeout=12"
    )


@pytest.mark.url
def test_completions_url_composition_instance_url_timeout_openai() -> None:
    completion = Completion(id="test_id", engine="test_engine", api_type="open_ai")
    completion["timeout"] = 12
    url = completion.instance_url()
    assert url == "/engines/test_engine/completions/test_id?timeout=12"


@pytest.mark.url
def test_engine_search_url_composition_azure() -> None:
    engine = Engine(id="test_id", api_type="azure", api_version="2021-11-01-preview")
    assert engine.api_type == "azure"
    assert engine.typed_api_type == ApiType.AZURE
    url = engine.instance_url("test_operation")
    assert (
        url
        == "/openai/deployments/test_id/test_operation?api-version=2021-11-01-preview"
    )


@pytest.mark.url
def test_engine_search_url_composition_azure_ad() -> None:
    engine = Engine(id="test_id", api_type="azure_ad", api_version="2021-11-01-preview")
    assert engine.api_type == "azure_ad"
    assert engine.typed_api_type == ApiType.AZURE_AD
    url = engine.instance_url("test_operation")
    assert (
        url
        == "/openai/deployments/test_id/test_operation?api-version=2021-11-01-preview"
    )


@pytest.mark.url
def test_engine_search_url_composition_azure_no_version() -> None:
    engine = Engine(id="test_id", api_type="azure", api_version=None)
    assert engine.api_type == "azure"
    assert engine.typed_api_type == ApiType.AZURE
    with pytest.raises(Exception):
        engine.instance_url("test_operation")


@pytest.mark.url
def test_engine_search_url_composition_azure_no_operation() -> None:
    engine = Engine(id="test_id", api_type="azure", api_version="2021-11-01-preview")
    assert engine.api_type == "azure"
    assert engine.typed_api_type == ApiType.AZURE
    assert (
        engine.instance_url()
        == "/openai/engines/test_id?api-version=2021-11-01-preview"
    )


@pytest.mark.url
def test_engine_search_url_composition_default() -> None:
    engine = Engine(id="test_id")
    # Identity comparison is the correct check for the None singleton.
    assert engine.api_type is None
    assert engine.typed_api_type == ApiType.OPEN_AI
    url = engine.instance_url()
    assert url == "/engines/test_id"


@pytest.mark.url
def test_engine_search_url_composition_open_ai() -> None:
    engine = Engine(id="test_id", api_type="open_ai")
    assert engine.api_type == "open_ai"
    assert engine.typed_api_type == ApiType.OPEN_AI
    url = engine.instance_url()
    assert url == "/engines/test_id"


@pytest.mark.url
def test_engine_search_url_composition_invalid_type() -> None:
    engine = Engine(id="test_id", api_type="invalid")
    assert engine.api_type == "invalid"
    # Reading typed_api_type is what is expected to raise here.
    with pytest.raises(Exception):
        assert engine.typed_api_type == ApiType.OPEN_AI


@pytest.mark.url
def test_engine_search_url_composition_invalid_search() -> None:
    engine = Engine(id="test_id", api_type="invalid")
    assert engine.api_type == "invalid"
    with pytest.raises(Exception):
        engine.search()
class EngineAPIResource(APIResource):
    """API resource whose URLs are scoped by an engine (or Azure deployment)."""

    # Passed through to util.convert_to_openai_object when building results.
    plain_old_data = False

    def __init__(self, engine: Optional[str] = None, **kwargs):
        super().__init__(engine=engine, **kwargs)

    @classmethod
    def class_url(
        cls,
        engine: Optional[str] = None,
        api_type: Optional[str] = None,
        api_version: Optional[str] = None,
    ):
        """Build the collection URL for this resource.

        :param engine: Engine name (deployment name for Azure). Required for
            the Azure API types, optional for ``open_ai``.
        :param api_type: API type string, resolved through
            ``cls._get_api_type_and_version``.
        :param api_version: API version; required for the Azure API types.

        :returns: The URL path (with ``api-version`` query for Azure).

        :raises error.InvalidRequestError: For Azure when the API version or
            the engine is missing.
        :raises error.InvalidAPIType: For an unsupported API type.
        """
        # Namespaces are separated in object names with periods (.) and in URLs
        # with forward slashes (/), so replace the former with the latter.
        base = cls.OBJECT_NAME.replace(".", "/")  # type: ignore
        typed_api_type, api_version = cls._get_api_type_and_version(
            api_type, api_version
        )

        if typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
            if not api_version:
                raise error.InvalidRequestError(
                    "An API version is required for the Azure API type."
                )
            if engine is None:
                raise error.InvalidRequestError(
                    "You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service"
                )
            extn = quote_plus(engine)
            return "/%s/%s/%s/%s?api-version=%s" % (
                cls.azure_api_prefix,
                cls.azure_deployments_prefix,
                extn,
                base,
                api_version,
            )

        elif typed_api_type == ApiType.OPEN_AI:
            if engine is None:
                return "/%s" % (base)

            extn = quote_plus(engine)
            return "/engines/%s/%s" % (extn, base)

        else:
            raise error.InvalidAPIType("Unsupported API type %s" % api_type)

    @classmethod
    def create(
        cls,
        api_key=None,
        api_base=None,
        api_type=None,
        request_id=None,
        api_version=None,
        organization=None,
        **params,
    ):
        """POST ``params`` to the resource's class URL and wrap the response.

        ``deployment_id``, ``engine``, ``timeout``, ``headers`` and
        ``request_timeout`` are consumed from ``params``; ``model`` and
        ``stream`` are read but left in the request body.

        :returns: A generator of converted objects when ``stream`` is truthy,
            otherwise a single converted object (after waiting on it when a
            timeout beyond ``MAX_TIMEOUT`` was requested).

        :raises error.InvalidRequestError: When neither an engine/deployment
            (Azure) nor an engine/model (other API types) is supplied.
        """
        deployment_id = params.pop("deployment_id", None)
        engine = params.pop("engine", deployment_id)
        model = params.get("model", None)
        timeout = params.pop("timeout", None)
        stream = params.get("stream", False)
        headers = params.pop("headers", None)
        request_timeout = params.pop("request_timeout", None)
        typed_api_type = cls._get_api_type_and_version(api_type=api_type)[0]
        if typed_api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
            if deployment_id is None and engine is None:
                raise error.InvalidRequestError(
                    "Must provide an 'engine' or 'deployment_id' parameter to create a %s"
                    % cls,
                    "engine",
                )
        else:
            if model is None and engine is None:
                raise error.InvalidRequestError(
                    "Must provide an 'engine' or 'model' parameter to create a %s"
                    % cls,
                    "engine",
                )

        if timeout is None:
            # No special timeout handling
            pass
        elif timeout > 0:
            # API only supports timeouts up to MAX_TIMEOUT
            params["timeout"] = min(timeout, MAX_TIMEOUT)
            # Whatever is left over is spent client-side in wait() below.
            timeout = (timeout - params["timeout"]) or None
        elif timeout == 0:
            params["timeout"] = MAX_TIMEOUT

        requestor = api_requestor.APIRequestor(
            api_key,
            api_base=api_base,
            api_type=api_type,
            api_version=api_version,
            organization=organization,
        )
        url = cls.class_url(engine, api_type, api_version)
        response, _, api_key = requestor.request(
            "post",
            url,
            params=params,
            headers=headers,
            stream=stream,
            request_id=request_id,
            request_timeout=request_timeout,
        )

        if stream:
            # must be an iterator
            assert not isinstance(response, OpenAIResponse)
            return (
                util.convert_to_openai_object(
                    line,
                    api_key,
                    api_version,
                    organization,
                    engine=engine,
                    plain_old_data=cls.plain_old_data,
                )
                for line in response
            )
        else:
            obj = util.convert_to_openai_object(
                response,
                api_key,
                api_version,
                organization,
                engine=engine,
                plain_old_data=cls.plain_old_data,
            )

            if timeout is not None:
                obj.wait(timeout=timeout or None)

        return obj

    def instance_url(self):
        """Build the URL of this specific resource instance.

        :returns: The instance URL, with a ``timeout`` query parameter
            appended when the object carries a ``timeout`` value.

        :raises error.InvalidRequestError: When the ID is not a string, or for
            Azure when the API version or the engine is missing.
        :raises error.InvalidAPIType: For an unsupported API type.
        """
        resource_id = self.get("id")

        if not isinstance(resource_id, str):
            raise error.InvalidRequestError(
                f"Could not determine which URL to request: {type(self).__name__} instance has invalid ID: {resource_id}, {type(resource_id)}. ID should be of type str.",
                "id",
            )

        extn = quote_plus(resource_id)
        params_connector = "?"

        if self.typed_api_type in (ApiType.AZURE, ApiType.AZURE_AD):
            api_version = self.api_version or openai.api_version
            if not api_version:
                raise error.InvalidRequestError(
                    "An API version is required for the Azure API type."
                )
            # Mirror class_url(): the deployment name is mandatory and must be
            # URL-quoted before it is placed in the path.
            if self.engine is None:
                raise error.InvalidRequestError(
                    "You must provide the deployment name in the 'engine' parameter to access the Azure OpenAI service"
                )
            base = self.OBJECT_NAME.replace(".", "/")
            url = "/%s/%s/%s/%s/%s?api-version=%s" % (
                self.azure_api_prefix,
                self.azure_deployments_prefix,
                quote_plus(self.engine),
                base,
                extn,
                api_version,
            )
            # The URL already has a query string, so further params use "&".
            params_connector = "&"

        elif self.typed_api_type == ApiType.OPEN_AI:
            base = self.class_url(self.engine, self.api_type, self.api_version)
            url = "%s/%s" % (base, extn)

        else:
            raise error.InvalidAPIType("Unsupported API type %s" % self.api_type)

        timeout = self.get("timeout")
        if timeout is not None:
            timeout = quote_plus(str(timeout))
            url += params_connector + "timeout={}".format(timeout)
        return url

    def wait(self, timeout=None):
        """Refresh the object until its status is ``"complete"``.

        :param timeout: Overall budget in seconds, or ``None`` to keep
            refreshing with ``MAX_TIMEOUT`` per request.

        :returns: ``self``.
        """
        start = time.time()
        while self.status != "complete":
            self.timeout = (
                min(timeout + start - time.time(), MAX_TIMEOUT)
                if timeout is not None
                else MAX_TIMEOUT
            )
            if self.timeout < 0:
                # ``del self.timeout`` would be routed to __delitem__, which
                # the object base class implements as "not supported"; remove
                # the key with dict.pop so an expired wait ends cleanly.
                self.pop("timeout", None)
                break
            self.refresh()
        return self
@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embedding(text: str, engine="text-similarity-davinci-001") -> List[float]:
    """Return the embedding of a single text string from the given engine."""
    # Newlines can negatively affect performance, so flatten them to spaces.
    single_line = text.replace("\n", " ")
    response = openai.Embedding.create(input=[single_line], engine=engine)
    return response["data"][0]["embedding"]


@retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
def get_embeddings(
    list_of_text: List[str], engine="text-similarity-babbage-001"
) -> List[List[float]]:
    """Return one embedding per input text, in the same order as the input."""
    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

    # Newlines can negatively affect performance, so flatten them to spaces.
    flattened = [entry.replace("\n", " ") for entry in list_of_text]

    response = openai.Embedding.create(input=flattened, engine=engine)
    # Sort by the returned index so the output lines up with the input order.
    ordered = sorted(response.data, key=lambda item: item["index"])
    return [item["embedding"] for item in ordered]


def cosine_similarity(a, b):
    """Return the cosine similarity of vectors ``a`` and ``b``."""
    dot_product = np.dot(a, b)
    return dot_product / (np.linalg.norm(a) * np.linalg.norm(b))


def plot_multiclass_precision_recall(
    y_score, y_true_untransformed, class_list, classifier_name
):
    """
    Precision-Recall plotting for a multiclass problem. It plots average precision-recall, per class precision recall and reference f1 contours.

    Code slightly modified, but heavily based on https://scikit-learn.org/stable/auto_examples/model_selection/plot_precision_recall.html
    """
    n_classes = len(class_list)
    # One boolean column per class: True where the sample belongs to it.
    class_columns = [
        (y_true_untransformed == class_list[idx]) for idx in range(n_classes)
    ]
    y_true = pd.concat(class_columns, axis=1).values

    # Per-class curves and average precision, indexed by class position.
    precision = {}
    recall = {}
    average_precision = {}
    for idx in range(n_classes):
        truth_col = y_true[:, idx]
        score_col = y_score[:, idx]
        precision[idx], recall[idx], _ = precision_recall_curve(truth_col, score_col)
        average_precision[idx] = average_precision_score(truth_col, score_col)

    # A "micro-average": quantifying score on all classes jointly
    precision_micro, recall_micro, _ = precision_recall_curve(
        y_true.ravel(), y_score.ravel()
    )
    average_precision_micro = average_precision_score(y_true, y_score, average="micro")
    summary = " - Average precision score over all classes: {0:0.2f}".format(
        average_precision_micro
    )
    print(str(classifier_name) + summary)

    # setup plot details
    plt.figure(figsize=(9, 10))
    handles = []
    legend_labels = []
    for f_score in np.linspace(0.2, 0.8, num=4):
        x = np.linspace(0.01, 1)
        y = f_score * x / (2 * x - f_score)
        keep = y >= 0
        (handle,) = plt.plot(x[keep], y[keep], color="gray", alpha=0.2)
        plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))

    # Only the last iso-f1 curve is registered, giving one legend entry.
    handles.append(handle)
    legend_labels.append("iso-f1 curves")

    (handle,) = plt.plot(recall_micro, precision_micro, color="gold", lw=2)
    handles.append(handle)
    legend_labels.append(
        "average Precision-recall (auprc = {0:0.2f})" "".format(average_precision_micro)
    )

    for idx in range(n_classes):
        (handle,) = plt.plot(recall[idx], precision[idx], lw=2)
        handles.append(handle)
        legend_labels.append(
            "Precision-recall for class `{0}` (auprc = {1:0.2f})"
            "".format(class_list[idx], average_precision[idx])
        )

    fig = plt.gcf()
    fig.subplots_adjust(bottom=0.25)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(f"{classifier_name}: Precision-Recall curve for each class")
    plt.legend(handles, legend_labels)
def distances_from_embeddings(
    query_embedding: List[float],
    embeddings: List[List[float]],
    distance_metric="cosine",
) -> List[float]:
    """Return the distances between a query embedding and a list of embeddings.

    :param query_embedding: The reference vector.
    :param embeddings: The vectors to measure against the query.
    :param distance_metric: One of ``"cosine"``, ``"L1"``, ``"L2"``, ``"Linf"``.

    :returns: One distance per entry of ``embeddings``, in input order.

    :raises KeyError: If ``distance_metric`` is not a supported name.
    """
    distance_metrics = {
        "cosine": spatial.distance.cosine,
        "L1": spatial.distance.cityblock,
        "L2": spatial.distance.euclidean,
        "Linf": spatial.distance.chebyshev,
    }
    # Look the metric up once rather than once per embedding.
    metric = distance_metrics[distance_metric]
    return [metric(query_embedding, embedding) for embedding in embeddings]


def indices_of_nearest_neighbors_from_distances(distances) -> np.ndarray:
    """Return a list of indices of nearest neighbors from a list of distances."""
    return np.argsort(distances)


def pca_components_from_embeddings(
    embeddings: List[List[float]], n_components=2
) -> np.ndarray:
    """Return the PCA components of a list of embeddings."""
    pca = PCA(n_components=n_components)
    array_of_embeddings = np.array(embeddings)
    return pca.fit_transform(array_of_embeddings)


def tsne_components_from_embeddings(
    embeddings: List[List[float]], n_components=2, **kwargs
) -> np.ndarray:
    """Returns t-SNE components of a list of embeddings.

    Extra keyword arguments are forwarded to ``TSNE``; ``init`` and
    ``learning_rate`` default to ``"pca"`` and ``"auto"`` when not given.
    """
    # use better defaults if not specified
    kwargs.setdefault("init", "pca")
    kwargs.setdefault("learning_rate", "auto")
    tsne = TSNE(n_components=n_components, **kwargs)
    array_of_embeddings = np.array(embeddings)
    return tsne.fit_transform(array_of_embeddings)
def _components_frame(components, axis_titles, labels, strings):
    """Build the DataFrame shared by the 2D and 3D component charts."""
    empty_list = ["" for _ in components]
    # One column per requested axis, taken from the matching component index.
    columns = {
        title: components[:, axis] for axis, title in enumerate(axis_titles)
    }
    columns["label"] = labels if labels else empty_list
    # Wrap long hover strings at 30 characters; "<br>" is the line break
    # understood by the chart's hover text.
    columns["string"] = (
        ["<br>".join(tr.wrap(string, width=30)) for string in strings]
        if strings
        else empty_list
    )
    return pd.DataFrame(columns)


def chart_from_components(
    components: np.ndarray,
    labels: Optional[List[str]] = None,
    strings: Optional[List[str]] = None,
    x_title="Component 0",
    y_title="Component 1",
    mark_size=5,
    **kwargs,
):
    """Return an interactive 2D chart of embedding components.

    :param components: Array whose columns 0 and 1 are plotted.
    :param labels: Optional per-point labels used for color and symbol.
    :param strings: Optional per-point strings shown on hover.
    :param x_title: Title of the x axis.
    :param y_title: Title of the y axis.
    :param mark_size: Marker size.
    :param kwargs: Forwarded to ``px.scatter``.
    """
    data = _components_frame(components, (x_title, y_title), labels, strings)
    chart = px.scatter(
        data,
        x=x_title,
        y=y_title,
        color="label" if labels else None,
        symbol="label" if labels else None,
        hover_data=["string"] if strings else None,
        **kwargs,
    ).update_traces(marker=dict(size=mark_size))
    return chart


def chart_from_components_3D(
    components: np.ndarray,
    labels: Optional[List[str]] = None,
    strings: Optional[List[str]] = None,
    x_title: str = "Component 0",
    y_title: str = "Component 1",
    z_title: str = "Component 2",
    mark_size: int = 5,
    **kwargs,
):
    """Return an interactive 3D chart of embedding components.

    :param components: Array whose columns 0, 1 and 2 are plotted.
    :param labels: Optional per-point labels used for color and symbol.
    :param strings: Optional per-point strings shown on hover.
    :param x_title: Title of the x axis.
    :param y_title: Title of the y axis.
    :param z_title: Title of the z axis.
    :param mark_size: Marker size.
    :param kwargs: Forwarded to ``px.scatter_3d``.
    """
    data = _components_frame(
        components, (x_title, y_title, z_title), labels, strings
    )
    chart = px.scatter_3d(
        data,
        x=x_title,
        y=y_title,
        z=z_title,
        color="label" if labels else None,
        symbol="label" if labels else None,
        hover_data=["string"] if strings else None,
        **kwargs,
    ).update_traces(marker=dict(size=mark_size))
    return chart
41 | 42 | # list engines 43 | engines = openai.Engine.list() 44 | 45 | # print the first engine's id 46 | print(engines.data[0].id) 47 | 48 | # create a completion 49 | completion = openai.Completion.create(engine="ada", prompt="Hello world") 50 | 51 | # print the completion 52 | print(completion.choices[0].text) 53 | ``` 54 | 55 | 56 | ### Params 57 | All endpoints have a `.create` method that supports a `request_timeout` param. This param takes a `Union[float, Tuple[float, float]]` and will raise an `openai.error.TimeoutError` error if the request exceeds that time in seconds (See: https://requests.readthedocs.io/en/latest/user/quickstart/#timeouts). 58 | 59 | ### Microsoft Azure Endpoints 60 | 61 | In order to use the library with Microsoft Azure endpoints, you need to set the api_type, api_base and api_version in addition to the api_key. The api_type must be set to 'azure' and the others correspond to the properties of your endpoint. 62 | In addition, the deployment name must be passed as the engine parameter. 63 | 64 | ```python 65 | import openai 66 | openai.api_type = "azure" 67 | openai.api_key = "..." 68 | openai.api_base = "https://example-endpoint.openai.azure.com" 69 | openai.api_version = "2021-11-01-preview" 70 | 71 | # create a completion 72 | completion = openai.Completion.create(engine="deployment-name", prompt="Hello world") 73 | 74 | # print the completion 75 | print(completion.choices[0].text) 76 | 77 | # create a search and pass the deployment-name as the engine Id. 78 | search = openai.Engine(id="deployment-name").search(documents=["White House", "hospital", "school"], query ="the president") 79 | 80 | # print the search 81 | print(search) 82 | ``` 83 | 84 | Please note that for the moment, the Microsoft Azure endpoints can only be used for completion, search and fine-tuning operations. 
85 | For a detailed example on how to use fine-tuning and other operations using Azure endpoints, please check out the following Jupyter notebooks: 86 | * [Using Azure fine-tuning](https://github.com/openai/openai-cookbook/tree/main/examples/azure/finetuning.ipynb) 87 | * [Using Azure embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/azure/embeddings.ipynb) 88 | 89 | ### Microsoft Azure Active Directory Authentication 90 | 91 | In order to use Microsoft Active Directory to authenticate to your Azure endpoint, you need to set the api_type to "azure_ad" and pass the acquired credential token to api_key. The rest of the parameters need to be set as specified in the previous section. 92 | 93 | 94 | ```python 95 | from azure.identity import DefaultAzureCredential 96 | import openai 97 | 98 | # Request credential 99 | default_credential = DefaultAzureCredential() 100 | token = default_credential.get_token("https://cognitiveservices.azure.com") 101 | 102 | # Setup parameters 103 | openai.api_type = "azure_ad" 104 | openai.api_key = token.token 105 | openai.api_base = "https://example-endpoint.openai.azure.com/" 106 | openai.api_version = "2022-03-01-preview" 107 | 108 | # ... 109 | ``` 110 | ### Command-line interface 111 | 112 | This library additionally provides an `openai` command-line utility 113 | which makes it easy to interact with the API from your terminal. Run 114 | `openai api -h` for usage. 115 | 116 | ```sh 117 | # list engines 118 | openai api engines.list 119 | 120 | # create a completion 121 | openai api completions.create -e ada -p "Hello world" 122 | 123 | # generate images via DALL·E API 124 | openai api image.create -p "two dogs playing chess, cartoon" -n 1 125 | ``` 126 | 127 | ## Example code 128 | 129 | Examples of how to use this Python library to accomplish various tasks can be found in the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/). 
It contains code examples for: 130 | 131 | * Classification using fine-tuning 132 | * Clustering 133 | * Code search 134 | * Customizing embeddings 135 | * Question answering from a corpus of documents 136 | * Recommendations 137 | * Visualization of embeddings 138 | * And more 139 | 140 | Prior to July 2022, this OpenAI Python library hosted code examples in its examples folder, but since then all examples have been migrated to the [OpenAI Cookbook](https://github.com/openai/openai-cookbook/). 141 | 142 | ### Embeddings 143 | 144 | In the OpenAI Python library, an embedding represents a text string as a fixed-length vector of floating point numbers. Embeddings are designed to measure the similarity or relevance between text strings. 145 | 146 | To get an embedding for a text string, you can use the embeddings method as follows in Python: 147 | 148 | ```python 149 | import openai 150 | openai.api_key = "sk-..." # supply your API key however you choose 151 | 152 | # choose text to embed 153 | text_string = "sample text" 154 | 155 | # choose an embedding 156 | model_id = "text-similarity-davinci-001" 157 | 158 | # compute the embedding of the text 159 | embedding = openai.Embedding.create(input=text_string, engine=model_id)['data'][0]['embedding'] 160 | ``` 161 | 162 | An example of how to call the embeddings method is shown in this [get embeddings notebook](https://github.com/openai/openai-cookbook/blob/main/examples/Get_embeddings.ipynb). 
163 | 164 | Examples of how to use embeddings are shared in the following Jupyter notebooks: 165 | 166 | - [Classification using embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Classification_using_embeddings.ipynb) 167 | - [Clustering using embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Clustering.ipynb) 168 | - [Code search using embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Code_search.ipynb) 169 | - [Semantic text search using embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb) 170 | - [User and product embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/User_and_product_embeddings.ipynb) 171 | - [Zero-shot classification using embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Zero-shot_classification_with_embeddings.ipynb) 172 | - [Recommendation using embeddings](https://github.com/openai/openai-cookbook/blob/main/examples/Recommendation_using_embeddings.ipynb) 173 | 174 | For more information on embeddings and the types of embeddings OpenAI offers, read the [embeddings guide](https://beta.openai.com/docs/guides/embeddings) in the OpenAI documentation. 175 | 176 | ### Fine tuning 177 | 178 | Fine tuning a model on training data can both improve the results (by giving the model more examples to learn from) and reduce the cost/latency of API calls (chiefly through reducing the need to include training examples in prompts). 
179 | 180 | Examples of fine tuning are shared in the following Jupyter notebooks: 181 | 182 | - [Classification with fine tuning](https://github.com/openai/openai-cookbook/blob/main/examples/Fine-tuned_classification.ipynb) (a simple notebook that shows the steps required for fine tuning) 183 | - Fine tuning a model that answers questions about the 2020 Olympics 184 | - [Step 1: Collecting data](https://github.com/openai/openai-cookbook/blob/main/examples/fine-tuned_qa/olympics-1-collect-data.ipynb) 185 | - [Step 2: Creating a synthetic Q&A dataset](https://github.com/openai/openai-cookbook/blob/main/examples/fine-tuned_qa/olympics-2-create-qa.ipynb) 186 | - [Step 3: Train a fine-tuning model specialized for Q&A](https://github.com/openai/openai-cookbook/blob/main/examples/fine-tuned_qa/olympics-3-train-qa.ipynb) 187 | 188 | Sync your fine-tunes to [Weights & Biases](https://wandb.me/openai-docs) to track experiments, models, and datasets in your central dashboard with: 189 | 190 | ```bash 191 | openai wandb sync 192 | ``` 193 | 194 | For more information on fine tuning, read the [fine-tuning guide](https://beta.openai.com/docs/guides/fine-tuning) in the OpenAI documentation. 195 | 196 | ## Image generation (DALL·E) 197 | 198 | ```python 199 | import openai 200 | openai.api_key = "sk-..." # supply your API key however you choose 201 | 202 | image_resp = openai.Image.create(prompt="two dogs playing chess, oil painting", n=4, size="512x512") 203 | 204 | ``` 205 | 206 | See the [usage guide](https://beta.openai.com/docs/guides/images) for more details. 207 | 208 | ## Requirements 209 | 210 | - Python 3.7.1+ 211 | 212 | In general, we want to support the versions of Python that our 213 | customers are using. If you run into problems with any version 214 | issues, please let us know at support@openai.com. 215 | 216 | ## Credit 217 | 218 | This library is forked from the [Stripe Python Library](https://github.com/stripe/stripe-python). 
class OpenAIObject(dict):
    """Dictionary-backed representation of an API resource.

    Response fields are stored as dictionary items and can also be read and
    written as attributes (``obj.field`` is ``obj["field"]``). Request context
    (``api_key``, ``api_version``, ``api_type``, ``organization``, ``engine``
    and the API base override) is stored as real instance attributes so that
    it never becomes part of the dictionary payload.
    """

    api_base_override = None

    def __init__(
        self,
        id=None,
        api_key=None,
        api_version=None,
        api_type=None,
        organization=None,
        response_ms: Optional[int] = None,
        api_base=None,
        engine=None,
        **params,
    ):
        """Create an empty object carrying the given request context.

        :param id: Resource id; stored under the ``"id"`` key when truthy.
        :param response_ms: Response time in milliseconds, or ``None``.
        :param api_base: Per-instance API base; stored as ``api_base_override``.
        :param params: Extra keyword arguments, kept as the default request
            parameters used by :meth:`request`.
        :raises TypeError: If ``response_ms`` is neither ``None`` nor an ``int``.
        """
        super().__init__()

        if response_ms is not None and not isinstance(response_ms, int):
            raise TypeError(f"response_ms is a {type(response_ms).__name__}.")
        self._response_ms = response_ms

        self._retrieve_params = params

        # object.__setattr__ bypasses our own __setattr__, which would
        # otherwise store these names as dictionary items.
        object.__setattr__(self, "api_key", api_key)
        object.__setattr__(self, "api_version", api_version)
        object.__setattr__(self, "api_type", api_type)
        object.__setattr__(self, "organization", organization)
        object.__setattr__(self, "api_base_override", api_base)
        object.__setattr__(self, "engine", engine)

        if id:
            self["id"] = id

    @property
    def response_ms(self) -> Optional[int]:
        """Response time in milliseconds recorded for this object, if any."""
        return self._response_ms

    def __setattr__(self, k, v):
        # Private names and existing instance attributes stay attributes;
        # everything else becomes a dictionary item.
        if k[0] == "_" or k in self.__dict__:
            return super().__setattr__(k, v)

        self[k] = v
        return None

    def __getattr__(self, k):
        # Only called when normal attribute lookup fails.
        if k[0] == "_":
            raise AttributeError(k)
        try:
            return self[k]
        except KeyError as err:
            raise AttributeError(*err.args)

    def __delattr__(self, k):
        if k[0] == "_" or k in self.__dict__:
            return super().__delattr__(k)
        else:
            del self[k]

    def __setitem__(self, k, v):
        """Store ``v`` under ``k``, rejecting empty strings.

        :raises ValueError: If ``v`` equals ``""``.
        """
        if v == "":
            raise ValueError(
                "You cannot set %s to an empty string. "
                "We interpret empty strings as None in requests. "
                "You may set %s.%s = None to delete the property" % (k, str(self), k)
            )
        super().__setitem__(k, v)

    def __delitem__(self, k):
        raise NotImplementedError("del is not supported")

    # Custom unpickling method that uses `update` to update the dictionary
    # without calling __setitem__, which would fail if any value is an empty
    # string
    def __setstate__(self, state):
        self.update(state)

    # Custom pickling method to ensure the instance is pickled as a custom
    # class and not as a dict, otherwise __setstate__ would not be called when
    # unpickling.
    def __reduce__(self):
        reduce_value = (
            type(self),  # callable
            (  # args
                self.get("id", None),
                self.api_key,
                self.api_version,
                self.api_type,
                self.organization,
            ),
            dict(self),  # state
        )
        return reduce_value

    @classmethod
    def construct_from(
        cls,
        values,
        api_key: Optional[str] = None,
        api_version=None,
        organization=None,
        engine=None,
        response_ms: Optional[int] = None,
    ):
        """Build an instance of ``cls`` and populate it from ``values``."""
        instance = cls(
            values.get("id"),
            api_key=api_key,
            api_version=api_version,
            organization=organization,
            engine=engine,
            response_ms=response_ms,
        )
        instance.refresh_from(
            values,
            api_key=api_key,
            api_version=api_version,
            organization=organization,
            response_ms=response_ms,
        )
        return instance

    def refresh_from(
        self,
        values,
        api_key=None,
        api_version=None,
        api_type=None,
        organization=None,
        response_ms: Optional[int] = None,
    ):
        """Replace this object's items and request context with ``values``.

        Explicit arguments win; otherwise the matching attribute of
        ``values`` is used when it has one.
        """
        self.api_key = api_key or getattr(values, "api_key", None)
        self.api_version = api_version or getattr(values, "api_version", None)
        self.api_type = api_type or getattr(values, "api_type", None)
        self.organization = organization or getattr(values, "organization", None)
        self._response_ms = response_ms or getattr(values, "_response_ms", None)

        # Wipe old state before setting new.
        self.clear()
        for k, v in values.items():
            # dict.__setitem__ is used on purpose: API data may contain
            # empty strings that our own __setitem__ rejects.
            super().__setitem__(
                k, util.convert_to_openai_object(v, api_key, api_version, organization)
            )

        self._previous = values

    @classmethod
    def api_base(cls):
        """Class-level API base; ``None`` here."""
        return None

    def request(
        self,
        method,
        url,
        params=None,
        headers=None,
        stream=False,
        plain_old_data=False,
        request_id: Optional[str] = None,
        request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
    ):
        """Send a request using this object's context and convert the result.

        When ``params`` is ``None`` the parameters captured at construction
        time are sent. A streamed response is returned as a generator of
        converted objects; otherwise a single converted object is returned.
        """
        if params is None:
            params = self._retrieve_params
        requestor = api_requestor.APIRequestor(
            key=self.api_key,
            api_base=self.api_base_override or self.api_base(),
            api_type=self.api_type,
            api_version=self.api_version,
            organization=self.organization,
        )
        response, stream, api_key = requestor.request(
            method,
            url,
            params=params,
            stream=stream,
            headers=headers,
            request_id=request_id,
            request_timeout=request_timeout,
        )

        if stream:
            assert not isinstance(response, OpenAIResponse)  # must be an iterator
            return (
                util.convert_to_openai_object(
                    line,
                    api_key,
                    self.api_version,
                    self.organization,
                    plain_old_data=plain_old_data,
                )
                for line in response
            )
        else:
            return util.convert_to_openai_object(
                response,
                api_key,
                self.api_version,
                self.organization,
                plain_old_data=plain_old_data,
            )

    def __repr__(self):
        ident_parts = [type(self).__name__]

        obj = self.get("object")
        if isinstance(obj, str):
            ident_parts.append(obj)

        if isinstance(self.get("id"), str):
            ident_parts.append("id=%s" % (self.get("id"),))

        unicode_repr = "<%s at %s> JSON: %s" % (
            " ".join(ident_parts),
            hex(id(self)),
            str(self),
        )

        return unicode_repr

    def __str__(self):
        obj = self.to_dict_recursive()
        return json.dumps(obj, sort_keys=True, indent=2)

    def to_dict(self):
        """Return a shallow ``dict`` copy of the items."""
        return dict(self)

    def to_dict_recursive(self):
        """Return a ``dict`` copy with nested objects converted as well.

        Direct ``OpenAIObject`` values and ``OpenAIObject`` elements of list
        values are converted; other values are left untouched.
        """
        d = dict(self)
        for k, v in d.items():
            if isinstance(v, OpenAIObject):
                d[k] = v.to_dict_recursive()
            elif isinstance(v, list):
                d[k] = [
                    e.to_dict_recursive() if isinstance(e, OpenAIObject) else e
                    for e in v
                ]
        return d

    @property
    def openai_id(self):
        return self.id

    @property
    def typed_api_type(self):
        """``ApiType`` for this object, falling back to ``openai.api_type``."""
        return (
            ApiType.from_str(self.api_type)
            if self.api_type
            else ApiType.from_str(openai.api_type)
        )

    # This class overrides __setitem__ to throw exceptions on inputs that it
    # doesn't like. This can cause problems when we try to copy an object
    # wholesale because some data that's returned from the API may not be valid
    # if it was set to be set manually. Here we override the class' copy
    # arguments so that we can bypass these possible exceptions on __setitem__.
    def __copy__(self):
        copied = OpenAIObject(
            self.get("id"),
            self.api_key,
            api_version=self.api_version,
            api_type=self.api_type,
            organization=self.organization,
            # Keep the routing context too, so that the copy issues requests
            # against the same API base and engine as the original.
            api_base=self.api_base_override,
            engine=self.engine,
        )

        copied._retrieve_params = self._retrieve_params
        # Assigned directly rather than through the constructor so that the
        # constructor's int-only validation cannot make copying fail.
        copied._response_ms = self._response_ms

        for k, v in self.items():
            # Call parent's __setitem__ to avoid checks that we've added in the
            # overridden version that can throw exceptions.
            super(OpenAIObject, copied).__setitem__(k, v)

        return copied

    # See the note on __copy__: items are written through the parent's
    # __setitem__ so that values the API returned (such as empty strings)
    # do not raise while copying.
    def __deepcopy__(self, memo):
        copied = self.__copy__()
        # Register the copy before recursing so self-references resolve to it.
        memo[id(self)] = copied

        for k, v in self.items():
            # Call parent's __setitem__ to avoid checks that we've added in the
            # overridden version that can throw exceptions.
            super(OpenAIObject, copied).__setitem__(k, deepcopy(v, memo))

        return copied
@classmethod
def _log_fine_tune(
    cls,
    fine_tune,
    project,
    entity,
    force,
    show_individual_warnings,
    **kwargs_wandb_init,
):
    """Log one fine-tune as a wandb run.

    Returns ``True`` when a run was logged and ``None`` when the fine-tune
    was skipped (not succeeded, no downloadable results, or already logged
    successfully without ``force``).

    :param fine_tune: Mapping describing the fine-tune; ``id``, ``status``,
        ``result_files`` and ``fine_tuned_model`` are read here.
    :param project: wandb project name.
    :param entity: wandb entity, or ``None`` for the default one.
    :param force: Log again even if a successful run already exists.
    :param show_individual_warnings: Print the reason when skipping.
    :param kwargs_wandb_init: Extra keyword arguments for ``wandb.init``.
    """
    fine_tune_id = fine_tune.get("id")
    status = fine_tune.get("status")

    # check run completed successfully
    if status != "succeeded":
        if show_individual_warnings:
            print(
                f'Fine-tune {fine_tune_id} has the status "{status}" and will not be logged'
            )
        return

    # check results are present
    try:
        results_id = fine_tune["result_files"][0]["id"]
        results = File.download(id=results_id).decode("utf-8")
    except Exception:
        # Best effort: any failure to obtain results means "skip", but
        # KeyboardInterrupt / SystemExit must still propagate.
        if show_individual_warnings:
            print(f"Fine-tune {fine_tune_id} has no results and will not be logged")
        return

    # check run has not been logged already
    run_path = f"{project}/{fine_tune_id}"
    if entity is not None:
        run_path = f"{entity}/{run_path}"
    wandb_run = cls._get_wandb_run(run_path)
    if wandb_run:
        wandb_status = wandb_run.summary.get("status")
        if show_individual_warnings:
            if wandb_status == "succeeded":
                print(
                    f"Fine-tune {fine_tune_id} has already been logged successfully at {wandb_run.url}"
                )
                if not force:
                    print(
                        'Use "--force" in the CLI or "force=True" in python if you want to overwrite previous run'
                    )
            else:
                print(
                    f"A run for fine-tune {fine_tune_id} was previously created but didn't end successfully"
                )
            if wandb_status != "succeeded" or force:
                print(
                    f"A new wandb run will be created for fine-tune {fine_tune_id} and previous run will be overwritten"
                )
        if wandb_status == "succeeded" and not force:
            return

    # start a wandb run
    wandb.init(
        job_type="fine-tune",
        config=cls._get_config(fine_tune),
        project=project,
        entity=entity,
        name=fine_tune_id,
        id=fine_tune_id,
        **kwargs_wandb_init,
    )

    # log results
    df_results = pd.read_csv(io.StringIO(results))
    for _, row in df_results.iterrows():
        # NaN cells are dropped so only the metrics present in a row are logged.
        metrics = {k: v for k, v in row.items() if not np.isnan(v)}
        # Default to None so a missing or NaN step does not raise KeyError
        # and the guard below is meaningful.
        step = metrics.pop("step", None)
        if step is not None:
            step = int(step)
        wandb.log(metrics, step=step)
    fine_tuned_model = fine_tune.get("fine_tuned_model")
    if fine_tuned_model is not None:
        wandb.summary["fine_tuned_model"] = fine_tuned_model

    # training/validation files and fine-tune details
    cls._log_artifacts(fine_tune, project, entity)

    # mark run as complete
    wandb.summary["status"] = "succeeded"

    wandb.finish()
    return True
@classmethod
def _log_artifact_inputs(cls, file, prefix, artifact_type, project, entity):
    """Attach a training/validation file to the current wandb run.

    An existing artifact for the file is reused when one can be fetched;
    otherwise the file is downloaded and a new artifact is built from it.

    :param file: Mapping describing the file; ``id`` and ``filename`` are read.
    :param prefix: Short tag used in the artifact name and in the
        ``n_<prefix>`` config key.
    :param artifact_type: wandb artifact type.
    :param project: wandb project name.
    :param entity: wandb entity, or ``None`` for the default one.
    """
    file_id = file["id"]
    filename = Path(file["filename"]).name
    stem = Path(file["filename"]).stem

    # get input artifact
    artifact_name = f"{prefix}-{filename}"
    # sanitize name to valid wandb artifact name
    artifact_name = re.sub(r"[^a-zA-Z0-9_\-.]", "_", artifact_name)
    artifact_alias = file_id
    artifact_path = f"{project}/{artifact_name}:{artifact_alias}"
    if entity is not None:
        artifact_path = f"{entity}/{artifact_path}"
    artifact = cls._get_wandb_artifact(artifact_path)

    # create artifact if file not already logged previously
    if artifact is None:
        # get file content
        try:
            file_content = File.download(id=file_id).decode("utf-8")
        except Exception:
            # Best effort: report and skip this file, but let
            # KeyboardInterrupt / SystemExit propagate.
            print(
                f"File {file_id} could not be retrieved. Make sure you are allowed to download training/validation files"
            )
            return
        artifact = wandb.Artifact(artifact_name, type=artifact_type, metadata=file)
        with artifact.new_file(filename, mode="w") as f:
            f.write(file_content)

        # create a Table
        try:
            table, n_items = cls._make_table(file_content)
            artifact.add(table, stem)
            wandb.config.update({f"n_{prefix}": n_items})
            artifact.metadata["items"] = n_items
        except Exception:
            # The raw file is already in the artifact; only the table is skipped.
            print(f"File {file_id} could not be read as a valid JSON file")
    else:
        # log number of items
        wandb.config.update({f"n_{prefix}": artifact.metadata.get("items")})

    wandb.run.use_artifact(artifact, aliases=["latest", artifact_alias])
def parse_stream(rbody):
    """Yield the text payload of each non-empty line in a streamed body.

    Empty lines are skipped, ``bytes`` lines are decoded as UTF-8, and a
    leading ``"data: "`` prefix is removed. The ``data: [DONE]`` terminator
    is never yielded, whether it arrives as ``bytes`` or ``str`` and whether
    or not it carries surrounding whitespace.

    :param rbody: Iterable of ``bytes`` or ``str`` lines.
    """
    for line in rbody:
        if not line:
            continue
        # Decode first so the sentinel check below is type-independent.
        if hasattr(line, "decode"):
            line = line.decode("utf-8")
        if line.strip() == "data: [DONE]":
            # return here will cause GeneratorExit exception in urllib3
            # and it will close http connection with TCP Reset
            continue
        if line.startswith("data: "):
            line = line[len("data: ") :]
        yield line
@classmethod
def format_app_info(cls, info):
    """Render application info as a User-Agent fragment.

    The result is ``name``, followed by ``/version`` and `` (url)`` for
    whichever of those values are present and truthy.

    :param info: Mapping with a required ``"name"`` key and optional
        ``"version"`` and ``"url"`` keys.
    :raises KeyError: If ``"name"`` is missing.
    """
    # A descriptive local name avoids shadowing the builtin ``str``.
    formatted = info["name"]
    # ``.get`` tolerates mappings that omit the optional keys.
    app_version = info.get("version")
    if app_version:
        formatted += "/%s" % (app_version,)
    app_url = info.get("url")
    if app_url:
        formatted += " (%s)" % (app_url,)
    return formatted
def handle_error_response(self, rbody, rcode, resp, rheaders, stream_error=False):
    """Build the exception that corresponds to an error response.

    The exception is returned, not raised, except when ``resp`` has no
    usable ``"error"`` entry: in that case ``error.APIError`` is raised
    immediately.

    :param rbody: Raw response body.
    :param rcode: HTTP status code.
    :param resp: Decoded response payload.
    :param rheaders: Response headers.
    :param stream_error: Whether the error arrived inside a stream.
    """
    try:
        error_data = resp["error"]
    except (KeyError, TypeError):
        invalid_msg = (
            "Invalid response object from API: %r (HTTP response code "
            "was %d)" % (rbody, rcode)
        )
        raise error.APIError(invalid_msg, rbody, rcode, resp)

    if "internal_message" in error_data:
        error_data["message"] += "\n\n" + error_data["internal_message"]

    message = error_data.get("message")
    util.log_info(
        "OpenAI API error received",
        error_code=error_data.get("code"),
        error_type=error_data.get("type"),
        error_message=message,
        error_param=error_data.get("param"),
        stream_error=stream_error,
    )

    tail = (rbody, rcode, resp, rheaders)

    if rcode in (400, 404, 415):
        return error.InvalidRequestError(
            message, error_data.get("param"), error_data.get("code"), *tail
        )

    # Status codes whose exception takes just (message, *tail); the class is
    # looked up by name only for the code that actually matched.
    # Rate limits were previously coded as 400's with code 'rate_limit'
    by_status = {
        429: "RateLimitError",
        401: "AuthenticationError",
        403: "PermissionError",
        409: "TryAgain",
    }
    if rcode in by_status:
        return getattr(error, by_status[rcode])(message, *tail)

    if stream_error:
        # TODO: we will soon attach status codes to stream errors
        parts = [message, "(Error occurred while streaming.)"]
        joined = " ".join([p for p in parts if p is not None])
        return error.APIError(joined, *tail)

    return error.APIError(
        f"{error_data.get('message')} {rbody} {rcode} {resp} {rheaders}", *tail
    )
def request_raw(
    self,
    method,
    url,
    *,
    params=None,
    supplied_headers: Optional[Dict[str, str]] = None,
    files=None,
    stream: bool = False,
    request_id: Optional[str] = None,
    request_timeout: Optional[Union[float, Tuple[float, float]]] = None,
) -> requests.Response:
    """Send one HTTP request and return the raw ``requests`` response.

    :param method: HTTP method name; compared against the lowercase names
        ``"get"``, ``"delete"``, ``"post"`` and ``"put"``.
    :param url: Path appended to ``self.api_base``.
    :param params: For get/delete, encoded into the query string (entries
        whose value is ``None`` are dropped); for post/put, sent as a JSON
        body.
    :param supplied_headers: Extra headers, validated before use.
    :param files: Passed through to the session as ``files``.
    :param stream: Passed through to the session as ``stream``.
    :param request_id: Forwarded to ``request_headers``.
    :param request_timeout: Timeout for the request; ``TIMEOUT_SECS`` is
        used when this is falsy.
    :raises ValueError: If both ``params`` and ``files`` are given for
        post/put.
    :raises error.APIConnectionError: For an unrecognized method, or when
        ``requests`` raises a ``RequestException`` other than a timeout.
    :raises error.Timeout: When ``requests`` reports a timeout.
    """
    abs_url = "%s%s" % (self.api_base, url)
    headers = self._validate_headers(supplied_headers)

    data = None
    if method == "get" or method == "delete":
        if params:
            encoded_params = urlencode(
                [(k, v) for k, v in params.items() if v is not None]
            )
            abs_url = _build_api_url(abs_url, encoded_params)
    elif method in {"post", "put"}:
        if params and files:
            raise ValueError("At most one of params and files may be specified.")
        if params:
            data = json.dumps(params).encode()
            headers["Content-Type"] = "application/json"
    else:
        raise error.APIConnectionError(
            "Unrecognized HTTP method %r. This may indicate a bug in the "
            "OpenAI bindings. Please contact support@openai.com for "
            "assistance." % (method,)
        )

    # Merge the validated caller headers into the standard request headers.
    headers = self.request_headers(method, headers, request_id)

    util.log_info("Request to OpenAI API", method=method, path=abs_url)
    util.log_debug("Post details", data=data, api_version=self.api_version)

    # One session per thread, created lazily on first use.
    if not hasattr(_thread_context, "session"):
        _thread_context.session = _make_session()
    try:
        result = _thread_context.session.request(
            method,
            abs_url,
            headers=headers,
            data=data,
            files=files,
            stream=stream,
            timeout=request_timeout if request_timeout else TIMEOUT_SECS,
        )
    except requests.exceptions.Timeout as e:
        # Timeout is handled before the broader RequestException below.
        raise error.Timeout("Request timed out") from e
    except requests.exceptions.RequestException as e:
        raise error.APIConnectionError("Error communicating with OpenAI") from e
    util.log_info(
        "OpenAI API response",
        path=abs_url,
        response_code=result.status_code,
        processing_ms=result.headers.get("OpenAI-Processing-Ms"),
        request_id=result.headers.get("X-Request-Id"),
    )
    # Don't read the whole stream for debug logging unless necessary.
    # NOTE(review): `result.content` presumably consumes the full body even
    # when stream=True, which would defeat streaming while openai.log is
    # "debug" - confirm against the requests documentation.
    if openai.log == "debug":
        util.log_debug(
            "API response body", body=result.content, headers=result.headers
        )
    return result
def num_examples_validator(df):
    """
    Report how many examples the dataframe holds, adding a recommendation to collect more when there are fewer than 100.
    """
    min_examples = 100
    n_rows = len(df)
    message = f"\n- Your file contains {n_rows} prompt-completion pairs"
    if n_rows < min_examples:
        message += ". In general, we recommend having at least a few hundred examples. We've found that performance tends to linearly increase for every doubling of the number of examples"
    return Remediation(name="num_examples", immediate_msg=message)
def additional_column_validator(df, fields=None):
    """
    This validator will remove additional columns from the dataframe.

    :param df: Dataframe to inspect.
    :param fields: Column names to keep; defaults to ``["prompt", "completion"]``.
    :return: A ``Remediation`` whose ``necessary_fn`` keeps only ``fields``
        when the dataframe has more than two columns; otherwise all of its
        optional members are ``None``.
    """
    # None sentinel instead of a shared mutable default; copy to a list so
    # that ``x[fields]`` below always selects columns.
    fields = ["prompt", "completion"] if fields is None else list(fields)
    additional_columns = []
    necessary_msg = None
    immediate_msg = None
    necessary_fn = None
    if len(df.columns) > 2:
        additional_columns = [c for c in df.columns if c not in fields]
        warn_message = ""
        # Warn only when an additional column's name contains one of the
        # kept field names (e.g. a duplicated "prompt" column read back as
        # "prompt.1"). Comparing the additional columns with themselves
        # always matched, so the warning fired for every extra column.
        for field in fields:
            dups = [c for c in additional_columns if field in str(c)]
            if dups:
                warn_message += f"\n WARNING: Some of the additional columns/keys contain `{field}` in their name. These will be ignored, and the column/key `{field}` will be used instead. This could also result from a duplicate column/key in the provided file."
        immediate_msg = f"\n- The input file should contain exactly two columns/keys per row. Additional columns/keys present are: {additional_columns}{warn_message}"
        necessary_msg = f"Remove additional columns/keys: {additional_columns}"

        def necessary_fn(x):
            return x[fields]

    return Remediation(
        name="additional_column",
        immediate_msg=immediate_msg,
        necessary_msg=necessary_msg,
        necessary_fn=necessary_fn,
    )
def long_examples_validator(df):
    """
    This validator will suggest to the user to remove examples that are too long.

    An example counts as long when the combined character length of its
    prompt and completion exceeds 10000. Nothing is suggested when the task
    type is inferred to be open-ended generation.

    Returns:
        A Remediation named "long_examples". When long examples exist, it
        carries a message listing their row positions and an optional_fn that
        returns the dataframe without those rows.
    """
    MAX_CHARS = 10000

    if infer_task_type(df) == "open-ended generation":
        return Remediation(name="long_examples")

    def long_mask(d):
        # One boolean per row, as a numpy array, so it can be applied
        # positionally whatever the dataframe's index labels are.
        combined_len = d["prompt"].str.len() + d["completion"].str.len()
        return (combined_len > MAX_CHARS).to_numpy()

    def positions(mask):
        return [pos for pos, is_long in enumerate(mask) if is_long]

    long_indexes = positions(long_mask(df))
    if not long_indexes:
        return Remediation(name="long_examples")

    immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
    optional_msg = f"Remove {len(long_indexes)} long examples"

    def optional_fn(x):
        # Recompute on the frame being fixed: earlier remediations may have
        # removed or altered rows since the analysis above.
        mask = long_mask(x)
        long_indexes_to_drop = positions(mask)
        if long_indexes != long_indexes_to_drop:
            sys.stdout.write(f"The indices of the long examples has changed as a result of a previously applied recommendation.\nThe {len(long_indexes_to_drop)} long examples to be dropped are now at the following indices: {long_indexes_to_drop}\n")
        # Filter by position. DataFrame.drop takes index labels, which stop
        # matching row positions once earlier fixes have removed rows.
        return x[~mask]

    return Remediation(
        name="long_examples",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
    )
def common_prompt_prefix_validator(df):
    """
    Report a prefix shared by every prompt and, when it is longer than 12
    characters, offer to strip it from all prompts.
    """
    max_prefix_len = 12
    shared = get_common_xfix(df.prompt, xfix="prefix")

    # Nothing to report when no prefix is shared, or when every prompt is
    # exactly the shared text (the suffix validator reports identical prompts).
    if shared == "" or (df.prompt == shared).all():
        return Remediation(name="common_prefix")

    def strip_prefix(frame):
        frame["prompt"] = frame["prompt"].str[len(shared) :]
        return frame

    message = f"\n- All prompts start with prefix `{shared}`"
    suggestion = None
    fix = None
    if len(shared) > max_prefix_len:
        message += ". Fine-tuning doesn't require the instruction specifying the task, or a few-shot example scenario. Most of the time you should only add the input data into the prompt, and the desired output into the completion"
        suggestion = f"Remove prefix `{shared}` from all prompts"
        fix = strip_prefix

    return Remediation(
        name="common_prompt_prefix",
        immediate_msg=message,
        optional_msg=suggestion,
        optional_fn=fix,
    )
def common_completion_suffix_validator(df):
    """
    Check whether all completions share a common ending and, when they do not,
    offer to append one. Skipped for data inferred to be open-ended generation
    or classification.
    """
    if infer_task_type(df) in ("open-ended generation", "classification"):
        return Remediation(name="common_suffix")

    shared = get_common_xfix(df.completion, xfix="suffix")
    if (df.completion == shared).all():
        return Remediation(
            name="common_suffix",
            error_msg=f"All completions are identical: `{shared}`\nEnsure completions are different, otherwise the model will just repeat `{shared}`",
        )

    # Pick the first candidate ending that occurs in no completion; fall back
    # to " [END]" when every candidate is already in use.
    candidates = ("\n", ".", " END", "***", "+++", "&&&", "$$$", "@@@", "%%%")
    suggested_suffix = next(
        (
            candidate
            for candidate in candidates
            if not df.completion.str.contains(candidate, regex=False).any()
        ),
        " [END]",
    )
    shown_suggestion = suggested_suffix.replace("\n", "\\n")

    def append_suggested(frame):
        frame["completion"] += suggested_suffix
        return frame

    optional_msg = None
    optional_fn = None
    if shared == "":
        immediate_msg = "\n- Your data does not contain a common ending at the end of your completions. Having a common ending string appended to the end of the completion makes it clearer to the fine-tuned model where the completion should end. See https://beta.openai.com/docs/guides/fine-tuning/preparing-your-dataset for more detail and examples."
        optional_msg = f"Add a suffix ending `{shown_suggestion}` to all completions"
        optional_fn = append_suggested
    else:
        shown_shared = shared.replace("\n", "\\n")
        immediate_msg = f"\n- All completions end with suffix `{shown_shared}`"
        if len(shared) > 10:
            immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{shown_suggestion}`"
        # Look for the ending inside the completion text that precedes it.
        bodies = df.completion.str[: -len(shared)]
        if bodies.str.contains(shared, regex=False).any():
            immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{shared}` more than once. We suggest that you review your completions and add a unique ending"

    return Remediation(
        name="common_completion_suffix",
        immediate_msg=immediate_msg,
        optional_msg=optional_msg,
        optional_fn=optional_fn,
        error_msg=None,
    )
def read_any_format(fname, fields=["prompt", "completion"]):
    """
    This function will read a file saved in .csv, .json, .txt, .xlsx or .tsv format using pandas.
     - for .xlsx it will read the first sheet
     - for .txt it will assume completions and split on newline

    Args:
        fname: Path of the file to read; the format is chosen from its ending.
        fields: Column names used when building a dataframe from a .txt file.
            The default list is only read, never mutated.

    Returns:
        A `(df, remediation)` tuple. `df` is None when the file does not exist
        or has an unsupported ending; in that case `remediation.error_msg`
        explains why. Missing values in a loaded dataframe are replaced by "".
    """
    necessary_msg = None
    immediate_msg = None
    error_msg = None
    df = None

    if os.path.isfile(fname):
        lowered = fname.lower()
        for ending, separator in [(".csv", ","), (".tsv", "\t")]:
            if lowered.endswith(ending):
                immediate_msg = f"\n- Based on your file extension, your file is formatted as a {ending[1:].upper()} file"
                necessary_msg = (
                    f"Your format `{ending[1:].upper()}` will be converted to `JSONL`"
                )
                df = pd.read_csv(fname, sep=separator, dtype=str)
        if lowered.endswith(".xlsx"):
            immediate_msg = "\n- Based on your file extension, your file is formatted as an Excel file"
            necessary_msg = "Your format `XLSX` will be converted to `JSONL`"
            # Context manager closes the workbook handle once reading is done.
            with pd.ExcelFile(fname) as xls:
                if len(xls.sheet_names) > 1:
                    immediate_msg += "\n- Your Excel file contains more than one sheet. Please either save as csv or ensure all data is present in the first sheet. WARNING: Reading only the first sheet..."
                df = pd.read_excel(xls, dtype=str)
        if lowered.endswith(".txt"):
            immediate_msg = "\n- Based on your file extension, you provided a text file"
            necessary_msg = "Your format `TXT` will be converted to `JSONL`"
            with open(fname, "r") as f:
                content = f.read()
            # Each line becomes one completion with an empty prompt.
            df = pd.DataFrame(
                [["", line] for line in content.split("\n")],
                columns=fields,
                dtype=str,
            )
        if lowered.endswith(("jsonl", "json")):
            try:
                df = pd.read_json(fname, lines=True, dtype=str)
            except (ValueError, TypeError):
                # Not line-delimited: fall back to reading a plain JSON document.
                df = pd.read_json(fname, dtype=str)
                immediate_msg = "\n- Your file appears to be in a .JSON format. Your file will be converted to JSONL format"
                necessary_msg = "Your format `JSON` will be converted to `JSONL`"

        if df is None:
            error_msg = (
                "Your file is not saved as a .CSV, .TSV, .XLSX, .TXT or .JSONL file."
            )
            # splitext reports the real final extension; splitting on the first
            # "." misreports names such as `a.b.c` or paths such as `./dir/file`.
            extension = os.path.splitext(fname)[1]
            if extension:
                error_msg += (
                    f" Your file `{fname}` appears to end with `{extension}`"
                )
            else:
                error_msg += f" Your file `{fname}` does not appear to have a file ending. Please ensure your filename ends with one of the supported file endings."
        else:
            df.fillna("", inplace=True)
    else:
        error_msg = f"File {fname} does not exist."

    remediation = Remediation(
        name="read_any_format",
        necessary_msg=necessary_msg,
        immediate_msg=immediate_msg,
        error_msg=error_msg,
    )
    return df, remediation
def accept_suggestion(input_text, auto_accept):
    """
    Show a yes/no prompt and report whether the suggestion was accepted.

    Args:
        input_text: Prompt text written to stdout.
        auto_accept: When true, "Y" is echoed and no input is read.

    Returns:
        True when accepted. Only an answer of "n" or "no" (any case,
        surrounding whitespace ignored) declines; anything else, including an
        empty answer, accepts.
    """
    sys.stdout.write(input_text)
    if auto_accept:
        sys.stdout.write("Y\n")
        return True
    # Normalise the answer so "N ", " n" and "no" all decline.
    return input().strip().lower() not in ("n", "no")


def apply_optional_remediation(df, remediation, auto_accept):
    """
    This function will apply an optional remediation to a dataframe, based on the user input.

    Args:
        df: Data to transform.
        remediation: Object exposing `optional_msg`, `optional_fn` and
            `necessary_msg`.
        auto_accept: Passed through to `accept_suggestion`.

    Returns:
        A `(df, optional_applied)` tuple. `optional_applied` is True only when
        the optional fix was offered and accepted.
    """
    optional_applied = False
    if remediation.optional_msg is not None:
        input_text = f"- [Recommended] {remediation.optional_msg} [Y/n]: "
        if accept_suggestion(input_text, auto_accept):
            df = remediation.optional_fn(df)
            optional_applied = True
    if remediation.necessary_msg is not None:
        # Necessary fixes are only announced here, not applied.
        sys.stdout.write(f"- [Necessary] {remediation.necessary_msg}\n")
    return df, optional_applied
def get_outfnames(fname, split):
    """
    Build output file names derived from `fname` that do not exist yet.

    The names are `<stem>_prepared.jsonl`, or the `_train` / `_valid` pair
    when `split` is true. If any candidate already exists, a ` (i)` counter is
    inserted before `.jsonl` and increased until none of the candidates exists.
    """
    stem = os.path.splitext(fname)[0] + "_prepared"
    parts = ("_train", "_valid") if split else ("",)

    attempt = 0
    candidates = [f"{stem}{part}.jsonl" for part in parts]
    while any(os.path.isfile(path) for path in candidates):
        attempt += 1
        candidates = [f"{stem}{part} ({attempt}).jsonl" for part in parts]
    return candidates


def get_classification_hyperparams(df):
    """
    Return `(n_classes, pos_class)` for the completion column.

    `pos_class` is the most frequent completion when there are exactly two
    distinct completions, otherwise None.
    """
    labels = df.completion
    n_classes = labels.nunique()
    pos_class = labels.value_counts().index[0] if n_classes == 2 else None
    return n_classes, pos_class
[Y/n]: " 654 | if ft_format == "classification": 655 | if accept_suggestion(input_text, auto_accept): 656 | split = True 657 | 658 | additional_params = "" 659 | common_prompt_suffix_new_line_handled = common_prompt_suffix.replace("\n", "\\n") 660 | common_completion_suffix_new_line_handled = common_completion_suffix.replace( 661 | "\n", "\\n" 662 | ) 663 | optional_ending_string = ( 664 | f' Make sure to include `stop=["{common_completion_suffix_new_line_handled}"]` so that the generated texts ends at the expected place.' 665 | if len(common_completion_suffix_new_line_handled) > 0 666 | else "" 667 | ) 668 | 669 | input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: " 670 | 671 | if not any_remediations and not split: 672 | sys.stdout.write( 673 | f'\nYou can use your file for fine-tuning:\n> openai api fine_tunes.create -t "{fname}"{additional_params}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n' 674 | ) 675 | estimate_fine_tuning_time(df) 676 | 677 | elif accept_suggestion(input_text, auto_accept): 678 | fnames = get_outfnames(fname, split) 679 | if split: 680 | assert len(fnames) == 2 and "train" in fnames[0] and "valid" in fnames[1] 681 | MAX_VALID_EXAMPLES = 1000 682 | n_train = max(len(df) - MAX_VALID_EXAMPLES, int(len(df) * 0.8)) 683 | df_train = df.sample(n=n_train, random_state=42) 684 | df_valid = df.drop(df_train.index) 685 | df_train[["prompt", "completion"]].to_json( 686 | fnames[0], lines=True, orient="records", force_ascii=False 687 | ) 688 | df_valid[["prompt", "completion"]].to_json( 689 | fnames[1], lines=True, orient="records", force_ascii=False 690 | ) 691 | 692 | n_classes, pos_class = get_classification_hyperparams(df) 693 | additional_params += " --compute_classification_metrics" 694 | if n_classes == 2: 695 | 
def write_out_search_file(df, fname, any_remediations, auto_accept, fields, purpose):
    """
    This function will write out a dataframe to a file, if the user would like to proceed.

    Args:
        df: Data to write.
        fname: Name of the input file; the output name is derived from it.
        any_remediations: When false nothing is written and the upload
            command for the original file is printed.
        auto_accept: Passed through to `accept_suggestion`.
        fields: Columns of `df` to write.
        purpose: Value printed after `-p` in the suggested upload command.
    """
    input_text = "\n\nYour data will be written to a new JSONL file. Proceed [Y/n]: "

    if not any_remediations:
        sys.stdout.write(
            f'\nYou can upload your file:\n> openai api files.create -f "{fname}" -p {purpose}\n'
        )

    elif accept_suggestion(input_text, auto_accept):
        fnames = get_outfnames(fname, split=False)

        assert len(fnames) == 1
        df[fields].to_json(fnames[0], lines=True, orient="records", force_ascii=False)

        # The file name is wrapped in a balanced pair of backticks, and the
        # message ends with a newline like the other messages do.
        sys.stdout.write(
            f'\nWrote modified file to `{fnames[0]}`\nFeel free to take a look!\n\nNow upload that file:\n> openai api files.create -f "{fnames[0]}" -p {purpose}\n'
        )
    else:
        sys.stdout.write("Aborting... did not write the file\n")


def infer_task_type(df):
    """
    Infer the likely fine-tuning task type from the data

    Returns one of "open-ended generation" (every prompt is empty),
    "classification" (fewer distinct completions than a third of the rows) or
    "conditional generation" (everything else).
    """
    CLASSIFICATION_THRESHOLD = 3  # min_average instances of each class
    # Lengths are non-negative, so "all zero" is the same as "sum is zero".
    if (df.prompt.str.len() == 0).all():
        return "open-ended generation"

    if len(df.completion.unique()) < len(df) / CLASSIFICATION_THRESHOLD:
        return "classification"

    return "conditional generation"


def get_common_xfix(series, xfix="suffix"):
    """
    Finds the longest common suffix or prefix of all the values in a series

    Args:
        series: pandas Series of strings.
        xfix: "suffix" to compare the ends of the values; any other value
            compares their beginnings.

    Returns:
        The longest shared suffix/prefix, or "" when nothing is shared.
    """
    common_xfix = ""
    while True:
        width = len(common_xfix) + 1
        # Slice one character more than what is already known to be shared.
        if xfix == "suffix":
            common_xfixes = series.str[-width:]
        else:
            common_xfixes = series.str[:width]
        if common_xfixes.nunique() != 1:
            # The values disagree at this width.
            break
        if common_xfix == common_xfixes.values[0]:
            # The slice did not grow: the whole first value is already shared.
            break
        common_xfix = common_xfixes.values[0]
    return common_xfix
def get_search_validators(required_fields, optional_fields):
    """
    Build the list of validators for a search-style file.

    Args:
        required_fields: Field names that must be present and non-empty.
        optional_fields: Further field names that are allowed to be present.

    Returns:
        A list of one-argument callables taking a dataframe: a
        necessary-column check and a non-empty check per required field, a
        duplicate-row check over the required fields, and an
        additional-column check over required plus optional fields.
    """
    # `field=field` binds each lambda to its own field. Without it every
    # lambda reads the loop variable when called and checks the last field.
    validators = [
        lambda x, field=field: necessary_column_validator(x, field)
        for field in required_fields
    ]
    validators += [
        lambda x, field=field: non_empty_field_validator(x, field)
        for field in required_fields
    ]
    validators += [lambda x: duplicated_rows_validator(x, required_fields)]
    validators += [
        lambda x: additional_column_validator(
            x, fields=required_fields + optional_fields
        ),
    ]

    return validators


def apply_validators(
    df,
    fname,
    remediation,
    validators,
    auto_accept,
    write_out_file_func,
):
    """
    Run every validator, apply the fixes, and hand the result to the writer.

    Args:
        df: Data to validate.
        fname: Name of the input file, forwarded to `write_out_file_func`.
        remediation: A remediation obtained before validation, or None.
        validators: Callables taking the dataframe and returning a
            remediation or None.
        auto_accept: Forwarded to the optional-remediation prompts and to
            `write_out_file_func`.
        write_out_file_func: Called last as
            `write_out_file_func(df, fname, any_applied, auto_accept)`.
    """
    optional_remediations = []
    if remediation is not None:
        optional_remediations.append(remediation)
    for validator in validators:
        remediation = validator(df)
        if remediation is not None:
            optional_remediations.append(remediation)
            # Necessary fixes are applied immediately so that later validators
            # see the corrected data.
            df = apply_necessary_remediation(df, remediation)

    any_necessary_applied = any(
        found.necessary_msg is not None for found in optional_remediations
    )
    any_optional_or_necessary_remediations = any_necessary_applied or any(
        found.optional_msg is not None for found in optional_remediations
    )
    any_optional_applied = False

    if any_optional_or_necessary_remediations:
        sys.stdout.write(
            "\n\nBased on the analysis we will perform the following actions:\n"
        )
        for remediation in optional_remediations:
            df, optional_applied = apply_optional_remediation(
                df, remediation, auto_accept
            )
            any_optional_applied = any_optional_applied or optional_applied
    else:
        sys.stdout.write("\n\nNo remediations found.\n")

    any_optional_or_necessary_applied = any_optional_applied or any_necessary_applied

    write_out_file_func(df, fname, any_optional_or_necessary_applied, auto_accept)