├── .devcontainer
│   └── devcontainer.json
├── .dockerignore
├── .github
│   ├── FUNDING.yml
│   ├── pull_request_template.md
│   └── workflows
│       └── ci-cd.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── README.md
├── docs
│   ├── index.md
│   ├── license.md
│   ├── todos.md
│   ├── usage.md
│   └── warning.md
├── examples
│   └── text_classification_imdb.ipynb
├── mkdocs.yml
├── pyproject.toml
├── src
│   └── opentrain
│       ├── __init__.py
│       ├── dataset.py
│       ├── inference.py
│       ├── schemas.py
│       ├── train.py
│       └── typing.py
└── tests
    ├── conftest.py
    ├── test_dataset.py
    ├── test_inference.py
    ├── test_schemas.py
    └── test_train.py

/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 |   "build": {
3 |     "dockerfile": "../Dockerfile"
4 |   },
5 |   "extensions": [
6 |     "ms-python.python",
7 |     "ms-python.vscode-pylance",
8 |     "eamodio.gitlens",
9 |     "bungcip.better-toml",
10 |     "ms-azuretools.vscode-docker",
11 |     "zhuangtongfa.material-theme",
12 |     "redhat.vscode-yaml",
13 |     "GitHub.copilot",
14 |     "znck.grammarly"
15 |   ]
16 | }
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .venv
2 | .mypy_cache
3 | .pytest_cache
4 | __pycache__
5 | .hatch
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: alvarobartt
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## ✨ Features
2 | 
3 | - List
4 | - implemented
5 | - features
6 | - here
7 | 
8 | ## 🐛 Bug Fixes
9 | 
10 | - List
11 | - fixed
12 | - bugs
13 | - here
14 | 
15 | ## 🔗 Linked Issue/s
16 | 
17 | Add here a reference to the issue/s addressed in this PR.
18 | 
19 | ## 🧪 Tests
20 | 
21 | - [ ] Did you implement unit tests if required?
22 | 
23 | If the above checkbox is checked, describe how you unit-tested it.
-------------------------------------------------------------------------------- /.github/workflows/ci-cd.yaml: -------------------------------------------------------------------------------- 1 | name: ci-cd 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: 7 | - main 8 | push: 9 | branches: 10 | - main 11 | paths: 12 | - .github/workflows/ci-cd.yaml 13 | - src/** 14 | - tests/** 15 | release: 16 | types: 17 | - published 18 | 19 | jobs: 20 | check-quality: 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - name: checkout 25 | uses: actions/checkout@v3 26 | 27 | - name: setup-python 28 | uses: actions/setup-python@v4 29 | with: 30 | python-version: 3.8 31 | 32 | - name: install-dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | pip install ".[quality]" 36 | 37 | - name: check-quality 38 | run: | 39 | black --check --diff --preview src tests 40 | ruff src tests 41 | 42 | run-tests: 43 | needs: check-quality 44 | 45 | runs-on: ubuntu-latest 46 | 47 | steps: 48 | - name: checkout 49 | uses: actions/checkout@v3 50 | 51 | - name: setup-python 52 | uses: actions/setup-python@v4 53 | with: 54 | python-version: 3.8 55 | 56 | - name: install-dependencies 57 | run: | 58 | python -m pip install --upgrade pip 59 | pip install ".[tests]" 60 | 61 | - name: run-tests 62 | run: pytest tests/ -s --durations 0 63 | env: 64 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 65 | 66 | deploy-docs: 67 | needs: run-tests 68 | if: github.event_name == 'release' 69 | 70 | runs-on: ubuntu-latest 71 | 72 | steps: 73 | - name: checkout 74 | uses: actions/checkout@v3 75 | 76 | - name: setup-python 77 | uses: actions/setup-python@v4 78 | with: 79 | python-version: 3.8 80 | 81 | - name: install-dependencies 82 | run: | 83 | python -m pip install --upgrade pip 84 | pip install -e ".[docs]" 85 | 86 | - name: deploy-to-gh-pages 87 | run: mkdocs gh-deploy --force 88 | 89 | publish-package: 90 | needs: deploy-docs 91 | if: github.event_name == 'release' 92 | 93 | runs-on: ubuntu-latest 94 | 95 | steps: 96 | - name: checkout 97 | uses: actions/checkout@v3 98 | 99 | - name: setup-python 100 | uses: actions/setup-python@v4 101 | with: 102 | python-version: 3.8 103 | 104 | - name: install-dependencies 105 | run: | 106 | python -m pip install --upgrade pip 107 | pip install hatch 108 | 109 | - name: build-package 110 | run: hatch build 111 | 112 | - name: publish-package 113 | run: hatch publish --user __token__ --auth $PYPI_TOKEN 114 | env: 115 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 116 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # PyCharm default idea folder 107 | .idea/ 108 | 109 | # VSCode Files 110 | .vscode/ 111 | 112 | # Hatch files 113 | .hatch/ 114 | 115 | # Ruff cache 116 | .ruff_cache/ 117 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: "v4.3.0" 4 | hooks: 5 | - id: check-added-large-files 6 | - id: check-toml 7 | - id: check-yaml 8 | 9 | - repo: https://github.com/psf/black 10 | rev: 22.10.0 11 | hooks: 12 | - id: black 13 | args: ["--preview"] 14 | language_version: python3 15 | 16 | - repo: https://github.com/charliermarsh/ruff-pre-commit 17 | rev: "v0.0.263" 18 | hooks: 19 | - id: ruff 20 | args: [--fix] 21 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.9-slim 2 | 3 | RUN apt-get update \ 4 | && apt-get install build-essential git -y --no-install-recommends 5 | 6 | ENV PYTHONUNBUFFERED=1 7 | 8 | RUN python -m pip install pip --upgrade \ 9 | && python -m pip install hatch 10 | 11 | COPY . . 12 | 13 | # https://github.com/gitpod-io/gitpod/issues/1997 14 | ENV PIP_USER false 15 | 16 | RUN python -m hatch env create 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023-present Alvaro Bartolome 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
<div align="center">
2 |   <h1>opentrain</h1>
3 | 
4 |   🚂 Fine-tune OpenAI models for text classification, question answering, and more
5 | 
6 | </div>
7 | 
8 | ---
9 | 
10 | `opentrain` is a simple Python package to fine-tune OpenAI models for task-specific purposes such as text classification, token classification, or question answering.
11 | 
12 | More information about OpenAI Fine-tuning at https://platform.openai.com/docs/guides/fine-tuning.
13 | 
14 | ## 💻 Usage
15 | 
16 | ### 📦 Data management
17 | 
18 | ```python
19 | import openai
20 | from opentrain import Dataset
21 | 
22 | openai.api_key = ""
23 | 
24 | dataset = Dataset.from_file("data.jsonl")
25 | dataset.info
26 | dataset.to_file(output_path="downloaded-data.jsonl")
27 | ```
28 | 
29 | ### 🦾 Fine-tune
30 | 
31 | ```python
32 | import openai
33 | from opentrain import Dataset, Train
34 | 
35 | openai.api_key = ""
36 | 
37 | dataset = Dataset.from_records(
38 |     [
39 |         {
40 |             "prompt": "I love to play soccer ->",
41 |             "completion": " soccer",
42 |         },
43 |         {
44 |             "prompt": "I love to play basketball ->",
45 |             "completion": " basketball",
46 |         },
47 |     ],
48 | )
49 | 
50 | trainer = Train(model="ada")
51 | trainer.train(dataset)
52 | ```
53 | 
54 | ### 🤖 Predict
55 | 
56 | ```python
57 | import openai
58 | from opentrain import Inference
59 | 
60 | openai.api_key = ""
61 | 
62 | predict = Inference(model="ada:ft-personal-2021-03-01-00-00-01")
63 | predict("I love to play ->")
64 | ```
65 | 
66 | ## ⚠️ Warning
67 | 
68 | Fine-tuning OpenAI models via their API may take from minutes to hours depending on the size of the
69 | training data, so please be patient. Also, bear in mind that in some cases you just won't need to
70 | fine-tune an OpenAI model for your task.
71 | 
72 | To keep track of all the models you fine-tuned, visit https://platform.openai.com/account/usage,
73 | select the date when you triggered the fine-tuning in the "Daily usage breakdown (UTC)", and click
74 | on "Fine-tune training" to see all the fine-tune training requests that you sent.
75 | 
76 | Besides that, in the OpenAI Playground at https://platform.openai.com/playground, you'll see a dropdown
77 | menu with all the available models, both the default ones and the ones you fine-tuned, usually in the format `<model>:ft-personal-<date>`, e.g. `ada:ft-personal-2021-03-01-00-00-01`.
78 | 
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | <div align="center">
2 |   <h1>opentrain</h1>
3 | 
4 |   🚂 Fine-tune OpenAI models for text classification, question answering, and more
5 | 
6 | </div>
7 | 8 | --- 9 | 10 | `opentrain` is a simple Python package to fine-tune OpenAI models for task-specific purposes such as text classification, token classification, or question answering. 11 | 12 | More information about OpenAI Fine-tuning at https://platform.openai.com/docs/guides/fine-tuning. 13 | -------------------------------------------------------------------------------- /docs/license.md: -------------------------------------------------------------------------------- 1 | # 📄 License 2 | 3 | MIT License 4 | 5 | Copyright (c) 2023-present Alvaro Bartolome 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | -------------------------------------------------------------------------------- /docs/todos.md: -------------------------------------------------------------------------------- 1 | # 🔮 v0.2.0 - TODOs 2 | 3 | - [ ] Add `Typer` CLI e.g. `opentrain train ...` 4 | - [ ] Add `Dataset` validation before actually uploading a `Dataset`/`File` to OpenAI. 5 | - [ ] Add `Dataset.from_datasets`, `Dataset.to_datasets`, and `Dataset.to_records`. 6 | - [ ] Add `fsspec` support for `Dataset.from_file`, and `Dataset.to_file`. 7 | - [ ] Allow different input paths such as `pathlib.Path` or `os.path` in `Dataset.from_file`. 8 | - [ ] Explore https://github.com/openai/openai-python/blob/c556584eff3b36c92278e6af62cfe02ebb68fb65/openai/api_resources/file.py#L218 to avoid uploading duplicated files to OpenAI. 9 | - [ ] Add `Trainer.for_text_classification`, `Trainer.for_question_answering`, `Trainer.for_text_summarization`, and more if applicable. 10 | - [ ] Add `wandb` as an optional dependency for tracking fine-tune runs. 11 | - [ ] Explore automatically uploaded files to OpenAI after fine-tuning with `purpose='fine-tune-results'`. 12 | - [ ] Differentiate between both file-purposes `fine-tune` and `fine-tune-results`. 
13 | 
--------------------------------------------------------------------------------
/docs/usage.md:
--------------------------------------------------------------------------------
1 | # 💻 Usage
2 | 
3 | ## 📦 Data management
4 | 
5 | ```python
6 | import openai
7 | from opentrain import Dataset
8 | 
9 | openai.api_key = ""
10 | 
11 | dataset = Dataset.from_file("data.jsonl")
12 | dataset.info
13 | dataset.to_file(output_path="downloaded-data.jsonl")
14 | ```
15 | 
16 | ## 🦾 Fine-tune
17 | 
18 | ```python
19 | import openai
20 | from opentrain import Dataset, Train
21 | 
22 | openai.api_key = ""
23 | 
24 | dataset = Dataset.from_records(
25 |     [
26 |         {
27 |             "prompt": "I love to play soccer ->",
28 |             "completion": " soccer",
29 |         },
30 |         {
31 |             "prompt": "I love to play basketball ->",
32 |             "completion": " basketball",
33 |         },
34 |     ],
35 | )
36 | 
37 | trainer = Train(model="ada")
38 | trainer.train(dataset)
39 | ```
40 | 
41 | ## 🤖 Predict
42 | 
43 | ```python
44 | import openai
45 | from opentrain import Inference
46 | 
47 | openai.api_key = ""
48 | 
49 | predict = Inference(model="ada:ft-personal-2021-03-01-00-00-01")
50 | predict("I love to play ->")
51 | ```
52 | 
--------------------------------------------------------------------------------
/docs/warning.md:
--------------------------------------------------------------------------------
1 | # ⚠️ Warning
2 | 
3 | Fine-tuning OpenAI models via their API may take from minutes to hours depending on the size of the
4 | training data, so please be patient. Also, bear in mind that in some cases you just won't need to
5 | fine-tune an OpenAI model for your task.
6 | 
7 | To keep track of all the models you fine-tuned, visit https://platform.openai.com/account/usage,
8 | select the date when you triggered the fine-tuning in the "Daily usage breakdown (UTC)", and click
9 | on "Fine-tune training" to see all the fine-tune training requests that you sent.
10 | 
11 | Besides that, in the OpenAI Playground at https://platform.openai.com/playground, you'll see a dropdown
12 | menu with all the available models, both the default ones and the ones you fine-tuned, usually in the format `<model>:ft-personal-<date>`, e.g. `ada:ft-personal-2021-03-01-00-00-01`.
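
If you'd rather check this from Python, `opentrain` also exposes a `list_fine_tunes` helper that lists every fine-tune linked to your account. Below is a minimal sketch (assuming `openai.api_key` has already been set); each entry is a `FineTune` schema object, so `fine_tuned_model` will be `None` while a run is still in progress:

```python
import openai
from opentrain import list_fine_tunes

openai.api_key = ""

# Print the ID, status, and resulting model name of every fine-tune
# linked to your account, e.g. `ft-... succeeded ada:ft-personal-...`.
for fine_tune in list_fine_tunes():
    print(fine_tune.id, fine_tune.status, fine_tune.fine_tuned_model)
```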
13 | -------------------------------------------------------------------------------- /examples/text_classification_imdb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 13, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "!pip install --upgrade pip -q\n", 10 | "!pip install opentrain datasets -q" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 14, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from datasets import load_dataset" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 38, 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "Found cached dataset imdb (/Users/alvarobartt/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)\n", 32 | "Loading cached shuffled indices for dataset at /Users/alvarobartt/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0/cache-9c48ce5d173413c7.arrow\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "ds = load_dataset(\"imdb\", split=\"train\").shuffle(seed=42)" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 39, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "label2idx = {label: idx for idx, label in enumerate(ds.features[\"label\"].names)}\n", 47 | "idx2label = {idx: label for label, idx in label2idx.items()}" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 40, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "ds = ds.select(range(1000))" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": 41, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "PROMPT_TEMPLATE = \"Categorize the following text from an IMDB review in the following categories based on its sentiment: 'pos', or 'neg'.\\n\\nReview: {text}\\n\\nLabel: \"" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 42, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "def text_and_label_to_prompt_and_completion(example) -> dict:\n", 75 | " return {\n", 76 | " \"prompt\": PROMPT_TEMPLATE.format(text=example[\"text\"]),\n", 77 | " \"completion\": idx2label[example[\"label\"]],\n", 78 | " }" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 43, 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "data": { 88 | "application/vnd.jupyter.widget-view+json": { 89 | "model_id": "629fdc06f6564cd1b3d8b5faaddba758", 90 | "version_major": 2, 91 | "version_minor": 0 92 | }, 93 | "text/plain": [ 94 | "Map: 0%| | 0/1000 [00:00\"" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 54, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "from opentrain import Train" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 55, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "trainer = Train(model=\"curie\")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 56, 171 | "metadata": {}, 172 | "outputs": [ 173 | { 174 | "name": "stderr", 175 | "output_type": "stream", 176 | "text": [ 177 | "/var/folders/tq/d6srhsd134l2b6x4sj01hpqr0000gn/T/ipykernel_14681/526391229.py:1: UserWarning: Since the OpenAI API may take from minutes to hours depending on the size of the training data, then from 
now on, you'll be able to check its progress via the following command: `openai api fine_tunes.follow -i ft-nbCAuynTlICBi7HfJ82XFB9F`. Once the training is completed, then you'll be able to use `Inference` with the either the fine tune id returned, or from the model name generated by OpenAI linked to your account.\n", 178 | " fine_tune_id = trainer.train(\"imdb.jsonl\", n_epochs=5, batch_size=64)\n" 179 | ] 180 | } 181 | ], 182 | "source": [ 183 | "fine_tune_id = trainer.train(\"imdb.jsonl\", n_epochs=5, batch_size=64)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 57, 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "data": { 193 | "text/plain": [ 194 | "'ft-nbCAuynTlICBi7HfJ82XFB9F'" 195 | ] 196 | }, 197 | "execution_count": 57, 198 | "metadata": {}, 199 | "output_type": "execute_result" 200 | } 201 | ], 202 | "source": [ 203 | "fine_tune_id" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 60, 209 | "metadata": {}, 210 | "outputs": [ 211 | { 212 | "name": "stdout", 213 | "output_type": "stream", 214 | "text": [ 215 | "{\n", 216 | " \"created_at\": 1681907691,\n", 217 | " \"level\": \"info\",\n", 218 | " \"message\": \"Created fine-tune: ft-nbCAuynTlICBi7HfJ82XFB9F\",\n", 219 | " \"object\": \"fine-tune-event\"\n", 220 | "}\n", 221 | "{\n", 222 | " \"created_at\": 1681907801,\n", 223 | " \"level\": \"info\",\n", 224 | " \"message\": \"Fine-tune costs $4.99\",\n", 225 | " \"object\": \"fine-tune-event\"\n", 226 | "}\n", 227 | "{\n", 228 | " \"created_at\": 1681907801,\n", 229 | " \"level\": \"info\",\n", 230 | " \"message\": \"Fine-tune enqueued. Queue number: 0\",\n", 231 | " \"object\": \"fine-tune-event\"\n", 232 | "}\n", 233 | "{\n", 234 | " \"created_at\": 1681907802,\n", 235 | " \"level\": \"info\",\n", 236 | " \"message\": \"Fine-tune started\",\n", 237 | " \"object\": \"fine-tune-event\"\n", 238 | "}\n", 239 | "{\n", 240 | " \"created_at\": 1681907902,\n", 241 | " \"level\": \"info\",\n", 242 | " \"message\": \"Completed epoch 1/5\",\n", 243 | " \"object\": \"fine-tune-event\"\n", 244 | "}\n", 245 | "{\n", 246 | " \"created_at\": 1681907935,\n", 247 | " \"level\": \"info\",\n", 248 | " \"message\": \"Completed epoch 2/5\",\n", 249 | " \"object\": \"fine-tune-event\"\n", 250 | "}\n", 251 | "{\n", 252 | " \"created_at\": 1681907968,\n", 253 | " \"level\": \"info\",\n", 254 | " \"message\": \"Completed epoch 3/5\",\n", 255 | " \"object\": \"fine-tune-event\"\n", 256 | "}\n", 257 | "{\n", 258 | " \"created_at\": 1681908000,\n", 259 | " \"level\": \"info\",\n", 260 | " \"message\": \"Completed epoch 4/5\",\n", 261 | " \"object\": \"fine-tune-event\"\n", 262 | "}\n", 263 | "{\n", 264 | " \"created_at\": 1681908032,\n", 265 | " \"level\": \"info\",\n", 266 | " \"message\": \"Completed epoch 5/5\",\n", 267 | " \"object\": \"fine-tune-event\"\n", 268 | "}\n", 269 | "{\n", 270 | " \"created_at\": 1681908056,\n", 271 | " \"level\": \"info\",\n", 272 | " \"message\": \"Uploaded model: curie:ft-personal-2023-04-19-12-40-56\",\n", 273 | " \"object\": \"fine-tune-event\"\n", 274 | "}\n", 275 | "{\n", 276 | " \"created_at\": 1681908057,\n", 277 | " \"level\": \"info\",\n", 278 | " \"message\": \"Uploaded result file: file-VfbjfQXHYzJ8Qhi9Ls5tPlKC\",\n", 279 | " \"object\": \"fine-tune-event\"\n", 280 | "}\n", 281 | "{\n", 282 | " \"created_at\": 1681908057,\n", 283 | " \"level\": \"info\",\n", 284 | " \"message\": \"Fine-tune succeeded\",\n", 285 | " \"object\": \"fine-tune-event\"\n", 286 | "}\n" 287 | ] 288 | } 289 | ], 290 | 
"source": [ 291 | "for event in trainer.track():\n", 292 | " print(event)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 61, 298 | "metadata": {}, 299 | "outputs": [], 300 | "source": [ 301 | "from opentrain import Inference" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 62, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "predict = Inference.from_fine_tune_id(fine_tune_id)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 63, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stderr", 320 | "output_type": "stream", 321 | "text": [ 322 | "Found cached dataset imdb (/Users/alvarobartt/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)\n" 323 | ] 324 | } 325 | ], 326 | "source": [ 327 | "ds = load_dataset(\"imdb\", split=\"test\").shuffle(seed=42)" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "execution_count": 67, 333 | "metadata": {}, 334 | "outputs": [], 335 | "source": [ 336 | "completion = predict(PROMPT_TEMPLATE.format(text=ds[0][\"text\"]))" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 68, 342 | "metadata": {}, 343 | "outputs": [ 344 | { 345 | "data": { 346 | "text/plain": [ 347 | "True" 348 | ] 349 | }, 350 | "execution_count": 68, 351 | "metadata": {}, 352 | "output_type": "execute_result" 353 | } 354 | ], 355 | "source": [ 356 | "idx2label[ds[0][\"label\"]] == completion" 357 | ] 358 | } 359 | ], 360 | "metadata": { 361 | "kernelspec": { 362 | "display_name": "Python 3", 363 | "language": "python", 364 | "name": "python3" 365 | }, 366 | "language_info": { 367 | "codemirror_mode": { 368 | "name": "ipython", 369 | "version": 3 370 | }, 371 | "file_extension": ".py", 372 | "mimetype": "text/x-python", 373 | "name": "python", 374 | "nbconvert_exporter": "python", 375 | "pygments_lexer": "ipython3", 376 | "version": "3.9.16" 377 | }, 378 | "orig_nbformat": 4 379 | }, 380 | "nbformat": 4, 381 | "nbformat_minor": 2 382 | } 383 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: opentrain 2 | site_url: https://github.com/alvarobartt/opentrain 3 | site_author: Alvaro Bartolome 4 | site_description: 🚂 Fine-tune OpenAI models for text classification, question answering, and more 5 | 6 | repo_name: alvarobartt/opentrain 7 | repo_url: https://github.com/alvarobartt/opentrain 8 | 9 | copyright: Copyright (c) 2023-present Alvaro Bartolome 10 | 11 | theme: 12 | name: material 13 | palette: 14 | - scheme: default 15 | toggle: 16 | icon: material/brightness-7 17 | name: Switch to dark mode 18 | - scheme: slate 19 | toggle: 20 | icon: material/brightness-4 21 | name: Switch to light mode 22 | font: 23 | text: Roboto 24 | code: Roboto Mono 25 | 26 | markdown_extensions: 27 | - pymdownx.highlight: 28 | anchor_linenums: true 29 | - pymdownx.superfences 30 | 31 | plugins: 32 | - search: 33 | - git-revision-date-localized: 34 | type: timeago 35 | enable_creation_date: true 36 | - mkdocstrings: 37 | 38 | extra: 39 | social: 40 | - icon: fontawesome/brands/python 41 | link: https://pypi.org/project/opentrain/ 42 | - icon: fontawesome/brands/github 43 | link: https://github.com/alvarobartt 44 | - icon: fontawesome/brands/twitter 45 | link: https://twitter.com/alvarobartt 46 | - icon: fontawesome/brands/linkedin 47 | link: 
https://www.linkedin.com/in/alvarobartt/ 48 | 49 | nav: 50 | - Home: index.md 51 | - Usage: usage.md 52 | - Warning: warning.md 53 | - TODOs: todos.md 54 | - License: license.md 55 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "hatchling.build" 3 | requires = ["hatchling"] 4 | 5 | [project] 6 | authors = [{name = "Alvaro Bartolome", email = "alvarobartt@gmail.com"}] 7 | classifiers = [ 8 | "Development Status :: 4 - Beta", 9 | "Programming Language :: Python", 10 | "Programming Language :: Python :: 3.8", 11 | "Programming Language :: Python :: 3.9", 12 | "Programming Language :: Python :: 3.10", 13 | "Programming Language :: Python :: Implementation :: CPython", 14 | "Programming Language :: Python :: Implementation :: PyPy", 15 | ] 16 | dependencies = ["openai~=0.27.4"] 17 | description = "🚂 Fine-tune OpenAI models for text classification, question answering, and more" 18 | dynamic = ["version"] 19 | keywords = [] 20 | license = "MIT" 21 | name = "opentrain" 22 | readme = "README.md" 23 | requires-python = ">=3.8,<3.11" 24 | 25 | [project.urls] 26 | Documentation = "https://alvarobartt.github.io/opentrain" 27 | Issues = "https://github.com/alvarobartt/opentrain/issues" 28 | Source = "https://github.com/alvarobartt/opentrain" 29 | 30 | [tool.hatch.version] 31 | path = "src/opentrain/__init__.py" 32 | 33 | [project.optional-dependencies] 34 | docs = [ 35 | "mkdocs~=1.4.0", 36 | "mkdocs-material~=8.5.4", 37 | "mkdocs-git-revision-date-localized-plugin~=1.1.0", 38 | "mkdocstrings[python]~=0.19.0", 39 | ] 40 | pydantic = ["pydantic>=1.10,<2"] 41 | quality = [ 42 | "black~=22.10.0", 43 | "ruff~=0.0.263", 44 | "pre-commit~=2.20.0", 45 | ] 46 | tests = [ 47 | "pytest~=7.1.2", 48 | ] 49 | 50 | [tool.hatch.envs.quality] 51 | features = [ 52 | "quality", 53 | ] 54 | 55 | [tool.hatch.envs.quality.scripts] 56 | check = [ 57 | "black --check --diff --preview src tests", 58 | "ruff src tests", 59 | ] 60 | style = [ 61 | "black --preview src tests", 62 | "ruff --fix src tests", 63 | "check", 64 | ] 65 | 66 | [tool.ruff] 67 | ignore = [ 68 | "E501", # line too long, handled by black 69 | "B008", # do not perform function calls in argument defaults 70 | "C901", # too complex 71 | ] 72 | select = [ 73 | "E", # pycodestyle errors 74 | "W", # pycodestyle warnings 75 | "F", # pyflakes 76 | "I", # isort 77 | "C", # flake8-comprehensions 78 | "B", # flake8-bugbear 79 | ] 80 | 81 | [tool.ruff.isort] 82 | known-first-party = ["opentrain"] 83 | 84 | [tool.hatch.envs.test] 85 | features = [ 86 | "tests", 87 | ] 88 | 89 | [tool.hatch.envs.test.scripts] 90 | run = "pytest tests/ --durations 0 -s" 91 | 92 | [[tool.hatch.envs.test.matrix]] 93 | python = ["38", "39", "310"] 94 | 95 | [tool.hatch.envs.docs] 96 | features = [ 97 | "docs", 98 | ] 99 | 100 | [tool.hatch.envs.docs.scripts] 101 | build = [ 102 | "mkdocs build", 103 | ] 104 | serve = [ 105 | "mkdocs serve", 106 | ] 107 | 108 | [tool.hatch.build.targets.sdist] 109 | exclude = [ 110 | "/.github", 111 | "/.vscode", 112 | "/docs", 113 | "/.devcontainer", 114 | "/.pre-commit-config.yaml", 115 | "/.gitignore", 116 | "/tests", 117 | "/Dockerfile", 118 | "/.dockerignore", 119 | ] 120 | -------------------------------------------------------------------------------- /src/opentrain/__init__.py: -------------------------------------------------------------------------------- 1 | """`opentrain `: 🚂 
Fine-tune OpenAI models for text classification, question answering, and more""" 2 | 3 | __author__ = "Alvaro Bartolome " 4 | __version__ = "0.1.0" 5 | 6 | from opentrain.dataset import Dataset, File, list_datasets, list_files 7 | from opentrain.inference import Inference, list_fine_tunes 8 | from opentrain.train import FineTune, Train 9 | 10 | __all__ = [ 11 | "Dataset", 12 | "File", 13 | "list_datasets", 14 | "list_files", 15 | "Inference", 16 | "list_fine_tunes", 17 | "Train", 18 | "FineTune", 19 | ] 20 | -------------------------------------------------------------------------------- /src/opentrain/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import warnings 3 | from functools import cached_property 4 | from pathlib import Path 5 | from time import sleep 6 | from typing import Any, Dict, List, Union 7 | from uuid import uuid4 8 | 9 | import openai 10 | from openai.error import TryAgain 11 | 12 | FILE_SIZE_WARNING = 500 * 1024 * 1024 13 | 14 | 15 | class Dataset: 16 | """The `Dataset` class is not just a wrapper around OpenAI's File API, but it also 17 | provides some useful methods to work with datasets. 18 | 19 | Args: 20 | file_id: the ID of the file previously uploaded to OpenAI. 21 | organization: the OpenAI organization name. Defaults to None. 22 | 23 | Attributes: 24 | file_id: the ID of the file previously uploaded to OpenAI. 25 | organization: the OpenAI organization name. 26 | info: the information of the file. 27 | 28 | Examples: 29 | >>> from opentrain import Dataset 30 | >>> dataset = Dataset(file_id="file-1234") 31 | >>> dataset.info 32 | >>> content = dataset.download() 33 | >>> dataset.delete() 34 | """ 35 | 36 | def __init__(self, file_id: str, organization: Union[str, None] = None) -> None: 37 | """Initializes the `Dataset` class. 38 | 39 | Args: 40 | file_id: the ID of the file previously uploaded to OpenAI. 41 | organization: the OpenAI organization name. Defaults to None. 42 | """ 43 | self.file_id = file_id 44 | self.organization = organization 45 | 46 | @cached_property 47 | def info(self) -> Dict[str, Any]: 48 | """Returns the information of the file uploaded to OpenAI. 49 | 50 | Returns: 51 | A dictionary with the information of the file. 52 | """ 53 | return openai.File.retrieve(id=self.file_id, organization=self.organization) 54 | 55 | def download(self) -> bytes: 56 | """Downloads the file from OpenAI. 57 | 58 | Returns: 59 | The content of the file as bytes. 60 | """ 61 | warnings.warn( 62 | "Dataset.download() is just available for paid/pro accounts, so bear in" 63 | " mind that this will fail if you're using a free tier.", 64 | stacklevel=2, 65 | ) 66 | return openai.File.download(id=self.file_id, organization=self.organization) 67 | 68 | def to_file(self, output_path: str) -> None: 69 | """Downloads the file from OpenAI and saves it to the specified path. 70 | 71 | Args: 72 | output_path: the path where the file will be saved. 
73 | """ 74 | content = self.download() 75 | with open(output_path, "wb") as f: 76 | f.write(content) 77 | del content 78 | 79 | def delete(self) -> None: 80 | """Deletes the file from OpenAI.""" 81 | file_deleted = False 82 | while file_deleted is False: 83 | try: 84 | openai.File.delete( 85 | sid=self.file_id, organization=self.organization, request_timeout=10 86 | ) 87 | file_deleted = True 88 | except TryAgain: 89 | sleep(1) 90 | 91 | @classmethod 92 | def from_file( 93 | cls, 94 | file_path: str, 95 | file_name: Union[str, None] = None, 96 | organization: Union[str, None] = None, 97 | ) -> "Dataset": 98 | """Uploads a file to OpenAI and returns a `Dataset` object. 99 | 100 | Args: 101 | file_path: the path of the file to be uploaded. 102 | file_name: the name of the file to be defined in OpenAI. Defaults to None. 103 | organization: the OpenAI organization name. Defaults to None. 104 | 105 | Returns: 106 | A `Dataset` object. 107 | """ 108 | upload_response = openai.File.create( 109 | file=open(file_path, "rb"), 110 | organization=organization, 111 | purpose="fine-tune", 112 | user_provided_filename=file_name, 113 | ) 114 | return cls(file_id=upload_response.id, organization=organization) 115 | 116 | @classmethod 117 | def from_records( 118 | cls, 119 | records: List[Dict[str, str]], 120 | file_name: Union[str, None] = None, 121 | organization: Union[str, None] = None, 122 | ) -> "Dataset": 123 | """Uploads a list of records to OpenAI and returns a `Dataset` object. Note 124 | that this function saves it first to a local file and then uploads it to 125 | OpenAI. 126 | 127 | Args: 128 | records: a list of dictionaries with the records to be uploaded. 129 | file_name: the name of the file to be defined in OpenAI. Defaults to None. 130 | organization: the OpenAI organization name. Defaults to None. 131 | 132 | Returns: 133 | A `Dataset` object. 134 | """ 135 | local_path = ( 136 | Path.home() / ".cache" / "opentrain" / f"{file_name or uuid4()}.jsonl" 137 | ) 138 | local_path.parent.mkdir(parents=True, exist_ok=True) 139 | 140 | with open(local_path.as_posix(), "w") as f: 141 | for record in records: 142 | json.dump(record, f) 143 | f.write("\n") 144 | 145 | if local_path.stat().st_size > FILE_SIZE_WARNING: 146 | warnings.warn( 147 | f"Your file is larger than {FILE_SIZE_WARNING / 1024 / 1024} MB, and" 148 | " the maximum total upload file size in OpenAI is 1GB, so please be" 149 | " aware that if you already have uploaded files to OpenAI this might" 150 | " fail. If you need to upload larger files or require more space," 151 | " please contact OpenAI as suggested at" 152 | " https://platform.openai.com/docs/api-reference/files/upload.", 153 | stacklevel=2, 154 | ) 155 | 156 | upload_response = openai.File.create( 157 | file=open(local_path.as_posix(), "rb"), 158 | organization=organization, 159 | purpose="fine-tune", 160 | user_provided_filename=file_name, 161 | ) 162 | return cls(file_id=upload_response.id, organization=organization) 163 | 164 | 165 | class File(Dataset): 166 | """This class is just a wrapper around `Dataset` with the same functionality. It's 167 | just here to keep the same naming convention as OpenAI.""" 168 | 169 | pass 170 | 171 | 172 | def list_datasets(organization: Union[str, None] = None) -> List[Dataset]: 173 | """Lists the datasets uploaded to your OpenAI or your organization's account. 174 | 175 | Args: 176 | organization: the OpenAI organization name. Defaults to None. 177 | 178 | Returns: 179 | A list of `Dataset` objects. 
180 | """ 181 | return [ 182 | Dataset(file_id=file["id"], organization=organization) 183 | for file in openai.File.list(organization=organization)["data"] 184 | ] 185 | 186 | 187 | def list_files(organization: Union[str, None] = None) -> List[File]: 188 | """This function is just a wrapper around `list_datasets` with the same 189 | functionality. It's just here to keep the same naming convention as OpenAI. 190 | 191 | Args: 192 | organization: the OpenAI organization name. Defaults to None. 193 | 194 | Returns: 195 | A list of `File` objects. 196 | """ 197 | return [ 198 | File(file_id=file["id"], organization=organization) 199 | for file in openai.File.list(organization=organization)["data"] 200 | ] 201 | -------------------------------------------------------------------------------- /src/opentrain/inference.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import List 3 | 4 | import openai 5 | 6 | from opentrain.schemas import FineTune 7 | 8 | warnings.simplefilter("once", category=UserWarning) 9 | 10 | 11 | class Inference: 12 | """The `Inference` class is a wrapper around OpenAI's Completion API, making it easy 13 | to generate completions for any given prompt, using an OpenAI fine-tuned model. 14 | 15 | Args: 16 | model: the name of the OpenAI model to use for the inference. 17 | 18 | Attributes: 19 | model: the name of the OpenAI model to use for the inference. 20 | 21 | Examples: 22 | >>> from opentrain import Inference 23 | >>> inference = Inference(model="curie:ft-personal-") 24 | >>> inference(prompt="This is a sample prompt.") 25 | 'This is a sample completion.' 26 | """ 27 | 28 | def __init__(self, model: str) -> None: 29 | """Initializes the `Inference` class. 30 | 31 | Args: 32 | model: the name of the OpenAI model to use for the inference. 33 | """ 34 | self.model = model 35 | 36 | def __call__(self, prompt: str, **kwargs) -> str: 37 | """Generates the completion for a given prompt. 38 | 39 | Args: 40 | prompt: the prompt to generate the completion for. Should be aligned 41 | with the one used/defined for the fine-tuning, if applicable. 42 | **kwargs: the keyword arguments to pass to the OpenAI API. See 43 | https://platform.openai.com/docs/api-reference/completions/create. 44 | 45 | Returns: 46 | The completion for the given prompt. 47 | """ 48 | kwargs.setdefault("temperature", 0.0) 49 | if kwargs["temperature"] != 0: 50 | warnings.warn( 51 | f"The `temperature` parameter is set to {kwargs['temperature']}," 52 | " instead of 0. That means the completion for a given prompt will be" 53 | " random, so the suggestion on fine-tuned models, unless desired" 54 | " otherwise, is to set the `temperature` to 0 so that the model is" 55 | " almost deterministic.", 56 | UserWarning, 57 | stacklevel=2, 58 | ) 59 | response = openai.Completion.create( 60 | model=self.model, 61 | prompt=prompt, 62 | **kwargs, 63 | ) 64 | return response.choices[0].text 65 | 66 | @classmethod 67 | def from_fine_tune_id(cls, fine_tune_id: str) -> "Inference": 68 | """Returns an `Inference` object from an OpenAI fine-tune ID. 69 | 70 | Args: 71 | fine_tune_id: the ID of the OpenAI fine-tune to use for the inference. 72 | 73 | Returns: 74 | An `Inference` object. 75 | """ 76 | model = openai.FineTune.retrieve(fine_tune_id).fine_tuned_model 77 | if model is None: 78 | raise ValueError( 79 | "The model is not fine-tuned yet! Please wait a few minutes and try" 80 | " again." 
81 | ) 82 | return cls(model=model) 83 | 84 | 85 | def list_fine_tunes() -> List[FineTune]: 86 | """List all fine-tuned models in your OpenAI account. 87 | 88 | Returns: 89 | A list of OpenAI fine-tunes, as `FineTune` objects. 90 | """ 91 | return [FineTune(**fine_tune) for fine_tune in openai.FineTune.list()["data"]] 92 | -------------------------------------------------------------------------------- /src/opentrain/schemas.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Union 3 | 4 | try: 5 | from pydantic import BaseModel 6 | 7 | has_pydantic = True 8 | except ImportError: 9 | has_pydantic = False 10 | 11 | if has_pydantic: 12 | 13 | class _HyperParams(BaseModel): 14 | batch_size: int 15 | learning_rate_multiplier: float 16 | n_epochs: int 17 | prompt_loss_weight: float 18 | 19 | class _File(BaseModel): 20 | bytes: int 21 | created_at: int 22 | filename: str 23 | id: str 24 | object: str 25 | purpose: str 26 | status: str 27 | status_details: Union[str, None] 28 | 29 | class FineTune(BaseModel): 30 | created_at: int 31 | fine_tuned_model: Union[str, None] 32 | hyperparams: _HyperParams 33 | id: str 34 | model: str 35 | object: str 36 | organization_id: str 37 | result_files: list 38 | status: str 39 | training_files: List[_File] 40 | updated_at: int 41 | validation_files: List[_File] 42 | 43 | class PromptCompletion(BaseModel): 44 | prompt: str 45 | completion: str 46 | 47 | else: 48 | 49 | @dataclass 50 | class _HyperParams: 51 | batch_size: int 52 | learning_rate_multiplier: float 53 | n_epochs: int 54 | prompt_loss_weight: float 55 | 56 | @dataclass 57 | class _File: 58 | bytes: int 59 | created_at: int 60 | filename: str 61 | id: str 62 | object: str 63 | purpose: str 64 | status: str 65 | status_details: Union[str, None] 66 | 67 | @dataclass 68 | class FineTune: 69 | created_at: int 70 | fine_tuned_model: Union[str, None] 71 | hyperparams: _HyperParams 72 | id: str 73 | model: str 74 | object: str 75 | organization_id: str 76 | result_files: list 77 | status: str 78 | training_files: List[_File] 79 | updated_at: int 80 | validation_files: List[_File] 81 | 82 | @dataclass 83 | class PromptCompletion: 84 | prompt: str 85 | completion: str 86 | -------------------------------------------------------------------------------- /src/opentrain/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Any, Dict, Iterator 3 | 4 | import openai 5 | 6 | from opentrain.dataset import Dataset 7 | from opentrain.typing import DatasetType 8 | 9 | warnings.simplefilter("once", category=UserWarning) 10 | 11 | DEFAULT_OPENAI_MODELS = ["ada", "babbage", "curie", "davinci"] 12 | 13 | 14 | class Train: 15 | """The `Train` class is a wrapper around OpenAI's FineTune API, and it also 16 | uses `opentrain.Dataset` to use OpenAI files for training/fine-tuning. 17 | 18 | Args: 19 | model: the OpenAI model name to be used for training/fine-tuning. 20 | 21 | Attributes: 22 | model: the OpenAI model name to be used for training/fine-tuning. 
23 | 
24 |     Examples:
25 |         >>> from opentrain import Train, Dataset
26 |         >>> trainer = Train(model="curie")
27 |         >>> dataset = Dataset(file_id="file-1234")
28 |         >>> trainer.train(dataset, n_epochs=5, batch_size=32)
29 |         >>> trainer.track()
30 | 
31 |         >>> from opentrain import Train
32 |         >>> trainer = Train(model="curie")
33 |         >>> trainer.train(
34 |             {
35 |                 "train": "file-1234",
36 |                 "eval": "file-5678",
37 |             },
38 |             n_epochs=5,
39 |             batch_size=32
40 |         )
41 |         >>> trainer.track()
42 |     """
43 |     def __init__(self, model: str) -> None:
44 |         """Initializes the `Train` class.
45 | 
46 |         Args:
47 |             model: the OpenAI model name to be used for training/fine-tuning.
48 |         """
49 |         assert model in DEFAULT_OPENAI_MODELS, (
50 |             "Invalid OpenAI model, it must be one of the following:"
51 |             f" {','.join(DEFAULT_OPENAI_MODELS)}."
52 |         )
53 |         self.model = model
54 |         self.fine_tune_id = None  # populated once `train` is called
55 | 
56 |     def train(
57 |         self,
58 |         dataset: DatasetType,
59 |         **kwargs,
60 |     ) -> None:
61 |         """Trains/Fine-tunes the OpenAI model with the given dataset/s.
62 | 
63 |         Args:
64 |             dataset: the dataset/s to be used for training/fine-tuning and/or evaluating it.
65 |             **kwargs: the keyword arguments to be passed to the OpenAI FineTune API. See
66 |                 https://platform.openai.com/docs/api-reference/fine-tunes
67 |         """
68 |         if isinstance(dataset, str):
69 |             train_file_id = dataset
70 |             validation_file_id = None
71 |         elif isinstance(dataset, Dataset):
72 |             train_file_id = dataset.file_id
73 |             validation_file_id = None
74 |         elif isinstance(dataset, dict):
75 |             train_file = dataset.get("train")
76 |             if train_file:
77 |                 train_file_id = (
78 |                     train_file.file_id
79 |                     if isinstance(train_file, Dataset)
80 |                     else train_file
81 |                     if isinstance(train_file, str)
82 |                     else None
83 |                 )
84 |             else:
85 |                 train_file_id = None
86 |             validation_file = dataset.get("eval")
87 |             if validation_file:
88 |                 validation_file_id = (
89 |                     validation_file.file_id
90 |                     if isinstance(validation_file, Dataset)
91 |                     else validation_file
92 |                     if isinstance(validation_file, str)
93 |                     else None
94 |                 )
95 |             else:
96 |                 validation_file_id = None
97 | 
98 |         if not train_file_id:
99 |             raise ValueError(
100 |                 "You must provide at least a training dataset to be used for training"
101 |                 " the model."
102 |             )
103 | 
104 |         fine_tune_args = {
105 |             "training_file": train_file_id,
106 |             "model": self.model,
107 |             **kwargs,
108 |         }
109 | 
110 |         if validation_file_id:
111 |             fine_tune_args["validation_file"] = validation_file_id
112 | 
113 |         fine_tune_response = openai.FineTune.create(**fine_tune_args)
114 |         self.fine_tune_id = fine_tune_response.id
115 | 
116 |         warnings.warn(
117 |             "Fine-tuning via the OpenAI API may take from minutes to hours depending"
118 |             " on the size of the training data, so from now on you can check its"
119 |             " progress via the following command: `openai api fine_tunes.follow -i"
120 |             f" {self.fine_tune_id}`. Once the training is completed, you'll be able"
121 |             " to use `Inference` with either the fine-tune ID returned or the model"
122 |             " name generated by OpenAI and linked to your account.",
123 |             stacklevel=2,
124 |         )
125 | 
126 |     def fine_tune(self, dataset: DatasetType, **kwargs) -> None:
127 |         """This function is just a wrapper around `train` with the same
128 |         functionality. It's just here to keep the same naming convention as OpenAI.
129 | 
130 |         Args:
131 |             dataset: the dataset/s to be used for training/fine-tuning and/or evaluating it.
132 |             **kwargs: the keyword arguments to be passed to the OpenAI FineTune API.
See 133 | https://platform.openai.com/docs/api-reference/fine-tunes 134 | """ 135 | self.train(dataset, **kwargs) 136 | 137 | def track(self) -> Iterator[Dict[str, Any]]: 138 | """Tracks the progress of the training/fine-tuning process. 139 | 140 | Returns: 141 | A list of events containing the progress of the training/fine-tuning 142 | process. 143 | 144 | Raises: 145 | ValueError: if the model training/fine-tuning hasn't started yet. 146 | """ 147 | if not self.fine_tune_id: 148 | raise ValueError( 149 | "You must call `train` before `track`, since nothing will be tracked as" 150 | " the training/fine-tuning hasn't started yet." 151 | ) 152 | return openai.FineTune.stream_events(self.fine_tune_id) 153 | 154 | 155 | class FineTune(Train): 156 | pass 157 | -------------------------------------------------------------------------------- /src/opentrain/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | from opentrain.dataset import Dataset, File 4 | 5 | DatasetType = Union[ 6 | str, Dict[str, str], Dataset, Dict[str, Dataset], File, Dict[str, File] 7 | ] 8 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import openai 4 | import pytest 5 | 6 | 7 | def pytest_sessionstart() -> None: 8 | openai.api_key = os.getenv("OPENAI_API_KEY") 9 | 10 | 11 | @pytest.fixture 12 | def training_data() -> list: 13 | """Mock training data to use as few tokens as possible.""" 14 | return [ 15 | { 16 | "prompt": "A", 17 | "completion": "B", 18 | }, 19 | ] 20 | 21 | 22 | @pytest.fixture 23 | def file_id() -> str: 24 | """File ID for an uploaded file in the OpenAI API.""" 25 | return "file-TnIO5MBmKZmzOtpM4ISCkEwx" 26 | 27 | 28 | @pytest.fixture 29 | def fine_tune_id() -> str: 30 | """Fine-tune ID for the IMDB text classification model trained in `examples/text_classification_imdb.ipynb`. 31 | """ 32 | return "ft-nbCAuynTlICBi7HfJ82XFB9F" 33 | 34 | 35 | @pytest.fixture 36 | def fine_tuned_model() -> str: 37 | """Fine-tuned model for the IMDB text classification model trained in `examples/text_classification_imdb.ipynb`. 38 | """ 39 | return "curie:ft-personal-2023-04-19-12-40-56" 40 | 41 | 42 | @pytest.fixture 43 | def prompt() -> str: 44 | """Prompt generated for a random entry in the IMDB test set.""" 45 | return ( 46 | "Categorize the following text from an IMDB review in the following categories" 47 | " based on its sentiment: 'pos', or 'neg'.\n\nReview: I love sci-fi and am" 48 | " willing to put up with a lot. Sci-fi movies/TV are usually underfunded," 49 | " under-appreciated and misunderstood. I tried to like this, I really did, but" 50 | " it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly" 51 | " prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn't match" 52 | " the background, and painfully one-dimensional characters cannot be overcome" 53 | " with a 'sci-fi' setting. (I'm sure there are those of you out there who think" 54 | " Babylon 5 is good sci-fi TV. It's not. It's clichéd and uninspiring.) While" 55 | " US viewers might like emotion and character development, sci-fi is a genre" 56 | " that does not take itself seriously (cf. Star Trek). It may treat important" 57 | " issues, yet not as a serious philosophy. 
It's really difficult to care about" 58 | " the characters here as they are not simply foolish, just missing a spark of" 59 | " life. Their actions and reactions are wooden and predictable, often painful" 60 | " to watch. The makers of Earth KNOW it's rubbish as they have to always say" 61 | " 'Gene Roddenberry's Earth...' otherwise people would not continue watching." 62 | " Roddenberry's ashes must be turning in their orbit as this dull, cheap," 63 | " poorly edited (watching it without advert breaks really brings this home)" 64 | " trudging Trabant of a show lumbers into space. Spoiler. So, kill off a main" 65 | " character. And then bring him back as another actor. Jeeez! Dallas all over" 66 | " again.\n\nLabel: " 67 | ) 68 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import tempfile 3 | 4 | import pytest 5 | 6 | from opentrain.dataset import Dataset, list_datasets 7 | 8 | 9 | class TestDatasetFromRecords: 10 | @pytest.fixture(autouse=True) 11 | @pytest.mark.usefixtures("training_data") 12 | def setup_method(self, training_data: list) -> None: 13 | self.dataset = Dataset.from_records( 14 | records=training_data, file_name="opentrain-test-dataset" 15 | ) 16 | 17 | def teardown_method(self) -> None: 18 | self.dataset.delete() 19 | del self.dataset 20 | 21 | def test_info(self) -> None: 22 | info = self.dataset.info 23 | assert isinstance(info, dict) 24 | assert info["id"] == self.dataset.file_id 25 | assert info["object"] == "file" 26 | assert info["purpose"] == "fine-tune" 27 | 28 | 29 | class TestDatasetFromFile: 30 | @pytest.fixture(autouse=True) 31 | @pytest.mark.usefixtures("training_data") 32 | def setup_method(self, training_data: dict) -> None: 33 | with tempfile.NamedTemporaryFile(suffix=".jsonl") as f: 34 | with open(f.name, "w") as f: 35 | for record in training_data: 36 | json.dump(record, f) 37 | f.write("\n") 38 | self.dataset = Dataset.from_file( 39 | file_path=f.name, file_name="opentrain-test-dataset" 40 | ) 41 | 42 | def teardown_method(self) -> None: 43 | self.dataset.delete() 44 | del self.dataset 45 | 46 | def test_info(self) -> None: 47 | info = self.dataset.info 48 | assert isinstance(info, dict) 49 | assert info["id"] == self.dataset.file_id 50 | assert info["object"] == "file" 51 | assert info["purpose"] == "fine-tune" 52 | 53 | 54 | def test_list_datasets() -> None: 55 | datasets = list_datasets() 56 | assert isinstance(datasets, list) 57 | assert len(datasets) > 0 58 | assert isinstance(datasets[0], Dataset) 59 | -------------------------------------------------------------------------------- /tests/test_inference.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | try: 4 | from pydantic import BaseModel 5 | 6 | has_pydantic = True 7 | except ImportError: 8 | has_pydantic = False 9 | 10 | from opentrain.inference import Inference, list_fine_tunes 11 | 12 | 13 | @pytest.mark.usefixtures("fine_tuned_model", "prompt") 14 | def test_inference(fine_tuned_model: str, prompt: str) -> None: 15 | inference = Inference(fine_tuned_model) 16 | completion = inference(prompt, temperature=0.0, max_tokens=1) 17 | assert isinstance(completion, str) 18 | assert completion in ["pos", "neg"] 19 | 20 | 21 | @pytest.mark.usefixtures("fine_tune_id", "prompt") 22 | def test_inference_from_fine_tune_id(fine_tune_id: str, prompt: str) -> None: 23 | inference = 
Inference.from_fine_tune_id(fine_tune_id) 24 | completion = inference(prompt, temperature=0.0, max_tokens=1) 25 | assert isinstance(completion, str) 26 | assert completion in ["pos", "neg"] 27 | 28 | 29 | def test_list_fine_tunes() -> None: 30 | fine_tunes = list_fine_tunes() 31 | assert isinstance(fine_tunes, list) 32 | if has_pydantic: 33 | assert isinstance(fine_tunes[0], BaseModel) 34 | else: 35 | assert hasattr(fine_tunes[0], "__dataclass_fields__") 36 | -------------------------------------------------------------------------------- /tests/test_schemas.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | try: 4 | from pydantic import ValidationError 5 | 6 | has_pydantic = True 7 | except ImportError: 8 | has_pydantic = False 9 | 10 | from opentrain.schemas import PromptCompletion 11 | 12 | 13 | def test_prompt_completion_schema() -> None: 14 | valid_schema = {"prompt": "Hello", "completion": "World"} 15 | assert PromptCompletion(**valid_schema) 16 | 17 | invalid_schema = {"not_prompt": "Hello", "not_completion": "World"} 18 | with pytest.raises(TypeError): 19 | PromptCompletion(**invalid_schema) 20 | 21 | if has_pydantic: 22 | invalid_values = {"prompt": 1, "completion": 1} 23 | with pytest.raises(ValidationError): 24 | PromptCompletion(**invalid_values) 25 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from opentrain.train import Train 4 | 5 | 6 | @pytest.mark.usefixtures("file_id") 7 | def test_train(file_id: str) -> None: 8 | trainer = Train(model="ada") 9 | with pytest.warns(UserWarning): 10 | trainer.train(file_id, n_epochs=1, batch_size=1) 11 | assert isinstance(trainer.fine_tune_id, str) 12 | assert trainer.fine_tune_id.startswith("ft-") 13 | --------------------------------------------------------------------------------