├── .gitattributes ├── requirements.txt ├── .pre-commit-config.yaml ├── LICENSE ├── database.py ├── README.md ├── .gitignore ├── dataset.ipynb └── features.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb filter=strip-notebook-output 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dspy-ai 2 | matplotlib 3 | arxiv-base 4 | git+https://github.com/dsdanielpark/arxiv2text.git 5 | ipykernel 6 | ipywidgets 7 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.1.14 4 | hooks: 5 | # Run the linter. 6 | - id: ruff 7 | types_or: [ python, pyi, jupyter ] 8 | args: [ --fix ] 9 | # Run the formatter. 
10 | - id: ruff-format 11 | types_or: [ python, pyi, jupyter ] 12 | - repo: https://github.com/compilerla/conventional-pre-commit 13 | rev: v3.1.0 14 | hooks: 15 | # Enforce ConventionalCommits.org 16 | - id: conventional-pre-commit 17 | stages: [commit-msg] 18 | args: [] 19 | 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 S1M0N38 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /database.py: -------------------------------------------------------------------------------- 1 | import chromadb 2 | import json 3 | import pathlib 4 | 5 | path_root = pathlib.Path(__file__).parent 6 | path_dataset = path_root / "dataset" 7 | 8 | 9 | # Assuming that chromadb is running on localhost:11435 ... 10 | client = chromadb.HttpClient(host="localhost", port="11435") 11 | # client.reset() # uncomment to reset chromadb. Destructive action! 12 | collection = client.get_or_create_collection(name="papers") 13 | 14 | ids, abstracts, metadatas = [], [], [] 15 | i = 0 16 | 17 | 18 | def clean(text: str) -> str: 19 | return text.strip().replace("\n", " ") 20 | 21 | 22 | with open(path_dataset / "arxiv.json") as file: 23 | for line in file: 24 | paper = json.loads(line.strip()) 25 | ids.append(paper["id"]) 26 | abstracts.append(clean(paper["abstract"])) 27 | metadatas.append({"title": clean(paper["title"]), "doi": paper["doi"] or ""}) 28 | i += 1 29 | 30 | # total num of line in arxiv.json: `wc -l arxiv.json` 31 | print(f"\r[{i:>5}/926193]", end="") 32 | 33 | if i % 1000 == 0: 34 | for id_ in collection.get(ids)["ids"]: 35 | idx = ids.index(id_) 36 | del ids[idx] 37 | del abstracts[idx] 38 | del metadatas[idx] 39 | if ids: 40 | collection.add(ids=ids, documents=abstracts, metadatas=metadatas) 41 | ids, abstracts, metadatas = [], [], [] 42 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dspy-arxiv 2 | 3 | Explore the use of [DSPy](https://github.com/stanfordnlp/dspy) for extracting features from PDFs. 4 | This repository provides a simple example of how to use this framework to predict the sub-category of a Computer Science paper from arXiv. 5 | 6 | ## Suggested Installation 7 | 8 | 1. Clone this repository. 9 | 2. Create a virtual environment. 
10 | 3. Install dependencies from *requirements.txt*. 11 | 4. Install the virtual environment as a Jupyter kernel. 12 | 13 | ## Build Dataset & Database 14 | 15 | The **dataset** is a selection of 150 arXiv papers (metadata + full text extracted from the PDF) from the computer science category. 16 | 17 | To build the database and the dataset: 18 | 19 | 1. Download the JSON file from [Kaggle](https://www.kaggle.com/datasets/Cornell-University/arxiv) into the `dspy-arxiv` directory. 20 | 2. Rename the file to `arxiv.json`. 21 | 3. Run the notebook `dataset.ipynb` from top to bottom. 22 | 23 | At the end, you should have two directories: 24 | - *dspy-arxiv/database* 25 | - *arxiv.json* - the original JSON file filtered to papers with at least one computer science category 26 | - *dspy-arxiv/dataset* 27 | - *trainset* - 50 JSON files with metadata + text used for "training" 28 | - *valset* - 50 JSON files with metadata + text used for "validation" 29 | - *testset* - 50 JSON files with metadata + text used for "testing" 30 | 31 | > If you want to add RAG to the pipeline, it's handy to have the data in a vector database for fast retrieval. 32 | > Check out *database.py* for an example script to set up [chromadb](https://docs.trychroma.com/) and populate it with arXiv metadata (install `chromadb` separately; it is not listed in *requirements.txt*). 33 | 34 | ## Feature Extraction 35 | 36 | The notebook *features.ipynb* can be seen as a simple tutorial on how to use DSPy to programmatically prompt an LLM for feature extraction (in this case, predicting the sub-category of a Computer Science paper from arXiv). 37 | 38 | You can also take a look at the [slides](https://s1m0n38.github.io/dspy-arxiv/#/) generated from this notebook.
39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into 
this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | ### macOS ### 163 | # General 164 | .DS_Store 165 | .AppleDouble 166 | .LSOverride 167 | 168 | # Icon must end with two \r 169 | Icon 170 | 171 | 172 | # Thumbnails 173 | ._* 174 | 175 | # Files that might appear in the root of a volume 176 | .DocumentRevisions-V100 177 | .fseventsd 178 | .Spotlight-V100 179 | .TemporaryItems 180 | .Trashes 181 | .VolumeIcon.icns 182 | .com.apple.timemachine.donotpresent 183 | 184 | # Directories potentially created on remote AFP share 185 | .AppleDB 186 | .AppleDesktop 187 | Network Trash Folder 188 | Temporary Items 189 | .apdisk 190 | 191 | ### macOS Patch ### 192 | # iCloud generated files 193 | *.icloud 194 | 195 | dspy-src 196 | local_cache 197 | database 198 | dataset 199 | -------------------------------------------------------------------------------- /dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "61f2c04d-d309-405c-899f-e666dc8516f1", 7 | "metadata": { 8 | "editable": true, 9 | "slideshow": { 10 | "slide_type": "" 11 | }, 12 | "tags": [] 13 | }, 14 | "outputs": [], 15 | "source": [ 16 | "import json\n", 17 | "import pathlib\n", 18 | "from itertools import chain\n", 19 | "from collections import Counter\n", 20 | "\n", 21 | "import pandas as pd\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "\n", 24 | "from arxiv2text import arxiv_to_text\n", 25 | "from arxiv.taxonomy import definitions\n", 26 | "\n", 27 | "path_root = pathlib.Path(\".\").parent\n", 28 | "path_dataset = path_root / \"dataset\"\n", 29 | "path_database = path_root / \"database\"\n", 30 | "path_dataset.mkdir(exist_ok=True)\n", 31 | "path_database.mkdir(exist_ok=True)\n", 32 | "\n", 33 | "CATEGORIES = {\n", 34 | " cat: meta\n", 35 | " for cat, meta in 
definitions.CATEGORIES_ACTIVE.items()\n", 36 | " if cat.split(\".\")[0] == \"cs\"\n", 37 | "}\n", 38 | "\n", 39 | "cat2idx = {cat: idx for idx, cat in enumerate(CATEGORIES)}\n", 40 | "idx2cat = {idx: cat for idx, cat in enumerate(CATEGORIES)}" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "id": "f9ca893f-fe2f-4def-80e3-cd8f564966e0", 46 | "metadata": {}, 47 | "source": [ 48 | "## Build Database\n", 49 | "\n", 50 | "First, we need to discard all metadata that is unrelated to Computer Science papers. This step is taken to restrict the size of the dataset, making it more manageable. By doing so, we also limit the number of categories in the classification task, enabling it to comfortably fit into the LLM prompt without muddling the context. If your prompt contains numerous categories, you can refer to [this notebook](https://colab.research.google.com/drive/1CpsOiLiLYKeGrhmq579_FmtGsD5uZ3Qe); dspy-arxiv heavily relies on the techniques discussed in this notebook." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "faec2f98-a934-4497-b572-47af2f7b6ea0", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "with open(path_database / \"arxiv.json\", \"w\") as outfile:\n", 61 | " with open(path_root / \"arxiv.json\") as infile:\n", 62 | " for line in infile:\n", 63 | " paper = json.loads(line.strip())\n", 64 | " cats = {cat for cat in paper[\"categories\"].split(\" \")}\n", 65 | " if cats & set(CATEGORIES.keys()):\n", 66 | " outfile.write(line)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "f8937de4-b085-4e14-b212-1ed24f1716e5", 72 | "metadata": {}, 73 | "source": [ 74 | "## Build Dataset\n", 75 | "\n", 76 | "In the following, we will perform a brief data analysis to better understand the data and determine which papers to select. The goal is to choose papers that belong to a single category and also include a few papers that have multiple categories." 
77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": null, 82 | "id": "7f756444-11db-4979-9aa2-06bdca396f65", 83 | "metadata": { 84 | "editable": true, 85 | "slideshow": { 86 | "slide_type": "" 87 | }, 88 | "tags": [] 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "def reader(path_file):\n", 93 | " with open(path_file, \"r\") as f:\n", 94 | " for line in f:\n", 95 | " paper = json.loads(line)\n", 96 | " yield {\n", 97 | " \"id\": paper[\"id\"],\n", 98 | " \"categories\": paper[\"categories\"].split(\" \"),\n", 99 | " }\n", 100 | "\n", 101 | "\n", 102 | "df = pd.DataFrame(reader(path_database / \"arxiv.json\"))\n", 103 | "df.set_index(\"id\", inplace=True)\n", 104 | "df.head()" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "355c3539-347f-4b74-92b9-74957fb42c80", 111 | "metadata": { 112 | "editable": true, 113 | "slideshow": { 114 | "slide_type": "" 115 | }, 116 | "tags": [] 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "occurences = Counter(\n", 121 | " s for s in chain.from_iterable(df.categories) if s.startswith(\"cs\")\n", 122 | ")\n", 123 | "occurences.update(CATEGORIES.keys())\n", 124 | "assert set(occurences.keys()) == set(CATEGORIES.keys())\n", 125 | "\n", 126 | "fig, ax = plt.subplots(figsize=(10, 10))\n", 127 | "numbers = list(occurences.keys())\n", 128 | "counts = list(occurences.values())\n", 129 | "ax.barh(numbers, counts, height=0.8)\n", 130 | "ax.set_ylabel(\"Number\")\n", 131 | "ax.set_xlabel(\"Frequency\")\n", 132 | "ax.set_xscale(\"log\")\n", 133 | "ax.set_title(\"Number of papers\")\n", 134 | "plt.show()" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "6dde45e4-1305-4cb3-99d2-5636fa31d5f6", 141 | "metadata": { 142 | "editable": true, 143 | "slideshow": { 144 | "slide_type": "" 145 | }, 146 | "tags": [] 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "df_single_cat = df[df.categories.apply(len) == 
1].categories.apply(lambda x: x[0])\n", 151 | "occurences = Counter(s for s in df_single_cat if s.startswith(\"cs\"))\n", 152 | "occurences.update(CATEGORIES.keys())\n", 153 | "assert set(occurences.keys()) == set(CATEGORIES.keys())\n", 154 | "# The category cs.IT always come with math.IT\n", 155 | "\n", 156 | "fig, ax = plt.subplots(figsize=(10, 10))\n", 157 | "numbers = list(occurences.keys())\n", 158 | "counts = list(occurences.values())\n", 159 | "ax.barh(numbers, counts, height=0.8)\n", 160 | "ax.set_ylabel(\"Number\")\n", 161 | "ax.set_xlabel(\"Frequency\")\n", 162 | "ax.set_xscale(\"log\")\n", 163 | "ax.set_title(\"Number of paper with single category\")\n", 164 | "plt.show()" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "metadata": {}, 170 | "source": [ 171 | "We create three splits:\n", 172 | "\n", 173 | "- `trainset` used for \"training\" the pipeline\n", 174 | "- `valset` used to evaluate the performance during training\n", 175 | "- `testset` used to evaluate the performance after training " 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "id": "30584475-5761-4430-bbcc-eaca1aabf953", 182 | "metadata": { 183 | "editable": true, 184 | "slideshow": { 185 | "slide_type": "" 186 | }, 187 | "tags": [] 188 | }, 189 | "outputs": [], 190 | "source": [ 191 | "# splits\n", 192 | "trainset, valset, testset = set(), set(), set()\n", 193 | "\n", 194 | "# Single category\n", 195 | "for cat in CATEGORIES:\n", 196 | " if cat == \"cs.IT\": # The category cs.IT always come with math.IT\n", 197 | " df_cat = df[df.categories.apply(lambda x: x == [\"cs.IT\", \"math.IT\"])]\n", 198 | " else:\n", 199 | " df_cat = df_single_cat[df_single_cat == cat]\n", 200 | " sample = df_cat.sample(n=3, random_state=1).index\n", 201 | " trainset.add(sample[0])\n", 202 | " valset.add(sample[1])\n", 203 | " testset.add(sample[2])\n", 204 | "\n", 205 | "# Multiple categories: add
multi-categories paper to reach 50 papers in each split\n", 207 | "# random_state 1 sample pdf that are not withdrawn\n", 208 | "num = 50 - len(CATEGORIES)\n", 209 | "sample = df[df.categories.apply(len) > 2].sample(n=3 * num, random_state=1).index\n", 210 | "trainset |= set(sample[:num])\n", 211 | "valset |= set(sample[num : num * 2])\n", 212 | "testset |= set(sample[num * 2 :])\n", 213 | "\n", 214 | "dataset = trainset | valset | testset\n", 215 | "\n", 216 | "# ensure no overlapping\n", 217 | "assert not (trainset & valset)\n", 218 | "assert not (trainset & testset)\n", 219 | "assert not (valset & testset)\n", 220 | "assert len(trainset) == len(valset) == len(testset)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "id": "3ba397a1-b64a-4903-a022-0f4c917ff498", 226 | "metadata": {}, 227 | "source": [ 228 | "Now, we want to download the PDFs of the selected papers (50 per split) and extract the full text body. \n", 229 | "For this task, we utilize [arxiv2text](https://github.com/dsdanielpark/arxiv2text).\n", 230 | "\n", 231 | "> In the simple example of the pipeline proposed in *features.ipynb*, we only utilize the title and abstract of the paper, without using the full text body. However, we have decided to include the full PDF text in the dataset to facilitate future experimentation with more complex pipelines that can process documents in chunks." 
232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "688f5c98-3702-47e5-9a3d-d25a1f957353", 238 | "metadata": { 239 | "editable": true, 240 | "slideshow": { 241 | "slide_type": "" 242 | }, 243 | "tags": [] 244 | }, 245 | "outputs": [], 246 | "source": [ 247 | "(path_dataset / \"trainset\").mkdir(exist_ok=True)\n", 248 | "(path_dataset / \"valset\").mkdir(exist_ok=True)\n", 249 | "(path_dataset / \"testset\").mkdir(exist_ok=True)\n", 250 | "\n", 251 | "i = 0\n", 252 | "with open(path_database / \"arxiv.json\") as file:\n", 253 | " for line in file:\n", 254 | " paper = json.loads(line.strip())\n", 255 | " if paper[\"id\"] in dataset:\n", 256 | " if paper[\"id\"] in trainset:\n", 257 | " path_paper = path_dataset / \"trainset\"\n", 258 | " if paper[\"id\"] in valset:\n", 259 | " path_paper = path_dataset / \"valset\"\n", 260 | " if paper[\"id\"] in testset:\n", 261 | " path_paper = path_dataset / \"testset\"\n", 262 | "\n", 263 | " i += 1\n", 264 | " path_paper = path_paper / f'{paper[\"id\"].replace(\"/\", \"-\")}.json'\n", 265 | " url = f\"https://arxiv.org/pdf/{paper['id']}.pdf\"\n", 266 | " print(f\"[{i:>3}/{len(dataset)}] Processing {path_paper.stem}\")\n", 267 | " if not path_paper.exists():\n", 268 | " paper[\"text\"] = arxiv_to_text(url)\n", 269 | " with open(path_paper, \"w\") as outfile:\n", 270 | " json.dump(paper, outfile, indent=4)" 271 | ] 272 | } 273 | ], 274 | "metadata": { 275 | "kernelspec": { 276 | "display_name": "dspy-arxiv", 277 | "language": "python", 278 | "name": "dspy-arxiv" 279 | }, 280 | "language_info": { 281 | "codemirror_mode": { 282 | "name": "ipython", 283 | "version": 3 284 | }, 285 | "file_extension": ".py", 286 | "mimetype": "text/x-python", 287 | "name": "python", 288 | "nbconvert_exporter": "python", 289 | "pygments_lexer": "ipython3", 290 | "version": "3.10.13" 291 | } 292 | }, 293 | "nbformat": 4, 294 | "nbformat_minor": 5 295 | } 296 | 
-------------------------------------------------------------------------------- /features.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "editable": true, 7 | "id": "3g7UMaMuGhZt", 8 | "slideshow": { 9 | "slide_type": "slide" 10 | }, 11 | "tags": [] 12 | }, 13 | "source": [ 14 | "# DSPy-arXiv\n", 15 | "\n", 16 | "Given an arXiv paper from the Computer Science (cs) section,\\\n", 17 | "extract its subcategories (e.g., cs.AI, cs.IR, ...)." 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "metadata": { 24 | "editable": true, 25 | "slideshow": { 26 | "slide_type": "skip" 27 | }, 28 | "tags": [] 29 | }, 30 | "outputs": [], 31 | "source": [ 32 | "import json\n", 33 | "import re\n", 34 | "import pathlib\n", 35 | "\n", 36 | "# dspy framework\n", 37 | "import dspy\n", 38 | "from dspy.teleprompt import BootstrapFewShotWithRandomSearch\n", 39 | "\n", 40 | "# arXiv utilites\n", 41 | "from arxiv.taxonomy import definitions\n", 42 | "\n", 43 | "# various paths of the project\n", 44 | "PATH_ROOT = pathlib.Path(\".\").parent\n", 45 | "PATH_DATASET = PATH_ROOT / \"dataset\"\n", 46 | "PATH_DATABASE = PATH_ROOT / \"database\"\n", 47 | "\n", 48 | "# ports where services are exposed\n", 49 | "PORT_LM = 11433\n", 50 | "\n", 51 | "# selected categories, i.e. 
just the ones from Computer Science (cs)\n", 52 | "CATEGORIES = {\n", 53 | " cat: meta\n", 54 | " for cat, meta in definitions.CATEGORIES_ACTIVE.items()\n", 55 | " if cat.split(\".\")[0] == \"cs\"\n", 56 | "}" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": { 62 | "editable": true, 63 | "id": "5wc3E6-aHVi5", 64 | "slideshow": { 65 | "slide_type": "slide" 66 | }, 67 | "tags": [] 68 | }, 69 | "source": [ 70 | "## Dataset\n", 71 | "\n", 72 | "We use the term **dataset** to refer to a small selection of papers that will be used to 'train' the pipeline.\n", 73 | "\n", 74 | "- 50 papers in `trainset` - used for pipeline training\n", 75 | "- 50 papers in `valset` - used for pipeline evaluation during training\n", 76 | "- 50 papers in `testset` - used for pipeline evaluation after training" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": { 82 | "editable": true, 83 | "slideshow": { 84 | "slide_type": "" 85 | }, 86 | "tags": [] 87 | }, 88 | "source": [ 89 | "To construct the dataset, please refer to the *dataset.ipynb* notebook."
90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": { 96 | "colab": { 97 | "base_uri": "https://localhost:8080/" 98 | }, 99 | "editable": true, 100 | "id": "4xa9DkssVrJm", 101 | "outputId": "266e4101-170d-46b4-9d49-d81c847024d8", 102 | "slideshow": { 103 | "slide_type": "skip" 104 | }, 105 | "tags": [] 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "def preprocess_example(example: dict) -> dspy.Example:\n", 110 | " \"\"\"\n", 111 | " Turn a paper (example) into dspy.Example.\n", 112 | " \"\"\"\n", 113 | " categories = set(example[\"categories\"].split(\" \")) & set(CATEGORIES)\n", 114 | " return {\n", 115 | " \"title\": example[\"title\"],\n", 116 | " \"abstract\": example[\"abstract\"],\n", 117 | " \"text\": example[\"text\"],\n", 118 | " \"categories\": categories,\n", 119 | " \"labels\": dspy.Example(categories=categories),\n", 120 | " }\n", 121 | "\n", 122 | "\n", 123 | "trainset, valset, testset = [], [], []\n", 124 | "\n", 125 | "for path in (PATH_DATASET / \"trainset\").glob(\"*.json\"):\n", 126 | " with open(path) as f:\n", 127 | " example = dspy.Example(preprocess_example(json.load(f)))\n", 128 | " example = example.with_inputs(\"title\", \"abstract\", \"text\", \"labels\")\n", 129 | " trainset.append(example)\n", 130 | "\n", 131 | "for path in (PATH_DATASET / \"valset\").glob(\"*.json\"):\n", 132 | " with open(path) as f:\n", 133 | " example = dspy.Example(preprocess_example(json.load(f)))\n", 134 | " example = example.with_inputs(\"title\", \"abstract\", \"text\")\n", 135 | " valset.append(example)\n", 136 | "\n", 137 | "for path in (PATH_DATASET / \"testset\").glob(\"*.json\"):\n", 138 | " with open(path) as f:\n", 139 | " example = dspy.Example(preprocess_example(json.load(f)))\n", 140 | " example = example.with_inputs(\"title\", \"abstract\", \"text\")\n", 141 | " testset.append(example)" 142 | ] 143 | }, 144 | { 145 | "cell_type": "markdown", 146 | "metadata": { 147 | "editable": true, 148 | "id": 
"13JR0wybbVgn", 149 | "slideshow": { 150 | "slide_type": "subslide" 151 | }, 152 | "tags": [] 153 | }, 154 | "source": [ 155 | "Each datapoint (paper + paper metadata) is a `dspy.Example`,\\\n", 156 | "a dict-like structure with *inputs* ($x$) and *labels* ($y$)." 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": { 162 | "editable": true, 163 | "id": "13JR0wybbVgn", 164 | "slideshow": { 165 | "slide_type": "" 166 | }, 167 | "tags": [] 168 | }, 169 | "source": [ 170 | "- Inputs:\n", 171 | " - `title`: Title of the paper.\n", 172 | " - `abstract`: Abstract of the paper.\n", 173 | " - `text`: Text body of the paper parsed from PDF with [arxiv2text](https://github.com/dsdanielpark/arxiv2text). (This is future work.)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": { 179 | "editable": true, 180 | "id": "13JR0wybbVgn", 181 | "slideshow": { 182 | "slide_type": "" 183 | }, 184 | "tags": [] 185 | }, 186 | "source": [ 187 | "- Labels:\n", 188 | " - `categories`: Set of associated categories." 
189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": { 195 | "colab": { 196 | "base_uri": "https://localhost:8080/" 197 | }, 198 | "editable": true, 199 | "hide_input": true, 200 | "id": "gTL1bTAKXOH6", 201 | "outputId": "acf93d87-8efd-4436-9158-6910d51efa7a", 202 | "slideshow": { 203 | "slide_type": "subslide" 204 | }, 205 | "tags": [] 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "print(re.sub(r\"[\\s\\n]+\", \" \", valset[0].title), \"\\n\")\n", 210 | "print(re.sub(r\"[\\s\\n]+\", \" \", valset[0].abstract)[:400], \"...\\n\")\n", 211 | "print(valset[0].labels().categories, \"\\n\")" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": { 217 | "editable": true, 218 | "slideshow": { 219 | "slide_type": "slide" 220 | }, 221 | "tags": [] 222 | }, 223 | "source": [ 224 | "## Metrics\n", 225 | "\n", 226 | "**Metrics** are scalar values that quantify the performance of a pipeline with respect to a given task." 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": null, 232 | "metadata": { 233 | "editable": true, 234 | "slideshow": { 235 | "slide_type": "" 236 | }, 237 | "tags": [] 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "def metric_fn(labels, preds, trace=None):\n", 242 | " preds: list[str] | str = preds.categories\n", 243 | " labels: list[str] = labels.categories\n", 244 | "\n", 245 | " # We assume that predicted categories are sorted by relevance\n", 246 | " # We selected top k predicted categories\n", 247 | " k = min(len(labels), len(preds))\n", 248 | " top_k_preds = preds[:k] if isinstance(preds, list) else [preds]\n", 249 | "\n", 250 | " # ground-truth labels are alphabetically sorted\n", 251 | " # so it make sense to look at the intesection with top_k_preds\n", 252 | " top_k_pred_set: set[str] = set(top_k_preds)\n", 253 | " lables_set: set[str] = set(labels)\n", 254 | "\n", 255 | " score: float = len(top_k_pred_set & lables_set) / len(labels)\n", 256 
| " return score" 257 | ] 258 | }, 259 | { 260 | "cell_type": "markdown", 261 | "metadata": { 262 | "editable": true, 263 | "slideshow": { 264 | "slide_type": "slide" 265 | }, 266 | "tags": [] 267 | }, 268 | "source": [ 269 | "## DSPy\n", 270 | "\n", 271 | "The DSPy framework resembles PyTorch.\n", 272 | "\n", 273 | "- I/O interface\n", 274 | "- Modular structure\n", 275 | "- Optimization" 276 | ] 277 | }, 278 | { 279 | "cell_type": "markdown", 280 | "metadata": { 281 | "editable": true, 282 | "slideshow": { 283 | "slide_type": "subslide" 284 | }, 285 | "tags": [] 286 | }, 287 | "source": [ 288 | "### I/O interface:\n", 289 | "- $(x, y)$ → pipeline → generated outputs\n", 290 | "- (`title & abstract`, `categories`) → pipeline → `preds`" 291 | ] 292 | }, 293 | { 294 | "cell_type": "markdown", 295 | "metadata": { 296 | "editable": true, 297 | "slideshow": { 298 | "slide_type": "subslide" 299 | }, 300 | "tags": [] 301 | }, 302 | "source": [ 303 | "### Modular structure:\n", 304 | "\n", 305 | "In PyTorch:\n", 306 | "- Tensor/s → Module → Tensor/s\n", 307 | "- Tensor/s → Module → Module → ... → Module → Tensor/s\n", 308 | "- e.g. `Linear`, `Dropout`, `ReLU`...\n", 309 | " \n", 310 | "In DSPy:\n", 311 | "- InputField/s → Module → OutputField/s\n", 312 | "- InputField/s → Module → Module → ... → Module → OutputField/s\n", 313 | "- e.g. `Predict`, `ChainOfThought`, `React`, custom, ..." 314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "editable": true, 320 | "slideshow": { 321 | "slide_type": "subslide" 322 | }, 323 | "tags": [] 324 | }, 325 | "source": [ 326 | "### \"Optimization\"\n", 327 | "\n", 328 | "In PyTorch:\n", 329 | "- Define `loss`. e.g., `nn.MSELoss`, `nn.CrossEntropyLoss`, ...\n", 330 | "- Define `optimizer`. e.g., `optim.SGD`, `optim.Adam`, ...\n", 331 | "- Minimize `loss` over `trainset` using `optimizer` by adjusting model parameters.\n", 332 | "\n", 333 | "In DSPy:\n", 334 | "- Define `metric`. 
e.g., `metric_fn`\n", 335 | "- Define `optimizer`. e.g., `BootstrapFewShot`, `SignatureOptimizer`, ...\n", 336 | "- Maximize `metric` over `trainset` using `optimizer` by adjusting text generation within modules.\n", 337 | "\n", 338 | "**DSPy heuristically searches for the most effective strategy to prompt an LLM to achieve the task according to the pipeline.**" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": { 344 | "editable": true, 345 | "slideshow": { 346 | "slide_type": "slide" 347 | }, 348 | "tags": [] 349 | }, 350 | "source": [ 351 | "## Pipeline 101\n", 352 | "\n", 353 | "(`title & abstract`, `categories`) → pipeline101 → `preds`\n", 354 | "\n", 355 | "- Just title & abstract, no text body of the paper.\n", 356 | "- No custom modules or creative modules usage.\n", 357 | "- No RAG.\n", 358 | "\n", 359 | "(But all the above can be easily added later.)" 360 | ] 361 | }, 362 | { 363 | "cell_type": "markdown", 364 | "metadata": { 365 | "editable": true, 366 | "slideshow": { 367 | "slide_type": "subslide" 368 | }, 369 | "tags": [] 370 | }, 371 | "source": [ 372 | "### Signature\n", 373 | "\n", 374 | "**Signatures** are like types in a programming language.\n", 375 | "\n", 376 | "- They define the module's input/output.\n", 377 | "- Their `__doc__` will be included in the LLM prompt, so they can specify the goal of a module." 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": null, 383 | "metadata": { 384 | "editable": true, 385 | "id": "TWMKdU8eQJu6", 386 | "slideshow": { 387 | "slide_type": "" 388 | }, 389 | "tags": [] 390 | }, 391 | "outputs": [], 392 | "source": [ 393 | "class PredictCategories(dspy.Signature):\n", 394 | " __doc__ = (\n", 395 | " f\"Given the abstract of a scientific paper, \"\n", 396 | " f\"identify most relevant categories. 
\"\n", 397 | " f\"Valid categories are {CATEGORIES.keys()}\"\n", 398 | " )\n", 399 | " title = dspy.InputField()\n", 400 | " abstract = dspy.InputField()\n", 401 | " categories = dspy.OutputField(\n", 402 | " desc=\"list of comma-separated categories\",\n", 403 | " format=lambda x: \", \".join(x) if isinstance(x, list) else x,\n", 404 | " )" 405 | ] 406 | }, 407 | { 408 | "cell_type": "markdown", 409 | "metadata": { 410 | "editable": true, 411 | "slideshow": { 412 | "slide_type": "subslide" 413 | }, 414 | "tags": [] 415 | }, 416 | "source": [ 417 | "Newer releases of DSPy make use of type hints to enforce structure in LM responses.\\\n", 418 | "So, the previous signature could be written as:\n", 419 | "```python\n", 420 | "from typing import Literal\n", 421 | "\n", 422 | "class PredictCategories(dspy.Signature):\n", 423 | " \"\"\"Given the abstract of a scientific paper, identify most relevant categories.\"\"\"\n", 424 | " title = dspy.InputField()\n", 425 | " abstract = dspy.InputField()\n", 426 | " categories: list[Literal[*CATEGORIES.keys()]] = dspy.OutputField()\n", 427 | "predict = dspy.functional.TypedChainOfThought(PredictCategories)\n", 428 | "```\n", 429 | "\n", 430 | "*Thanks to [thomasahle](https://github.com/thomasahle) for the [suggestion](https://github.com/S1M0N38/dspy-arxiv/issues/1).*" 431 | ] 432 | }, 433 | { 434 | "cell_type": "markdown", 435 | "metadata": { 436 | "editable": true, 437 | "slideshow": { 438 | "slide_type": "subslide" 439 | }, 440 | "tags": [] 441 | }, 442 | "source": [ 443 | "### Pipeline / Module\n", 444 | "\n", 445 | "The pipeline is a Module as well.\n", 446 | "\n", 447 | "Similar to PyTorch, it makes use of:\n", 448 | "- `__init__`: Here, the modules are instantiated.\n", 449 | "- `forward`: Here, it is defined how modules interact." 
450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "metadata": { 456 | "editable": true, 457 | "id": "hVrLgbZvbJ97", 458 | "slideshow": { 459 | "slide_type": "" 460 | }, 461 | "tags": [] 462 | }, 463 | "outputs": [], 464 | "source": [ 465 | "class Pipeline101(dspy.Module):\n", 466 | " def __init__(self):\n", 467 | " super().__init__()\n", 468 | " self.predict = dspy.ChainOfThought(PredictCategories)\n", 469 | "\n", 470 | " def forward(self, title, abstract, text=None, labels=None):\n", 471 | " categories = self.predict(title=title, abstract=abstract).completions.categories\n", 472 | " categories = [cat.strip() for cat in categories[0].split(\",\")]\n", 473 | " return dspy.Prediction(categories=categories)" 474 | ] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | "metadata": { 479 | "editable": true, 480 | "slideshow": { 481 | "slide_type": "slide" 482 | }, 483 | "tags": [] 484 | }, 485 | "source": [ 486 | "## Language Model\n", 487 | "\n", 488 | "The **Language Model (LM)** is at the core of the pipeline.\n", 489 | "\n", 490 | "- It is used for processing and generating text in the pipeline.\n", 491 | "- It is used by the optimizers to improve the pipeline itself.\n", 492 | "\n", 493 | "For simple tasks, it can be *fast* and *cheap* (many calls in the optimization).\n", 494 | "\n", 495 | "**DSPy caches all the calls to LM.**" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "metadata": { 502 | "editable": true, 503 | "id": "eBYBiXWZyPjy", 504 | "slideshow": { 505 | "slide_type": "" 506 | }, 507 | "tags": [] 508 | }, 509 | "outputs": [], 510 | "source": [ 511 | "# You can host local model with ollama.\n", 512 | "# Just change `model` and `api_base` accordingly.\n", 513 | "# For example: `model=\"gemma\"` & `api_base=\"http://localhost:11434/v1/\"`\n", 514 | "lm = dspy.OpenAI(\n", 515 | " model=\"gpt3.5-turbo\",\n", 516 | " api_base=f\"http://localhost:{PORT_LM}/v1/\",\n", 517 | " 
api_key=\"you-api-key\",\n", 518 | " model_type=\"chat\",\n", 519 | ")\n", 520 | "\n", 521 | "# configure dspy to use `lm` as Language Model\n", 522 | "dspy.settings.configure(lm=lm)\n", 523 | "\n", 524 | "# Just testing that LM works\n", 525 | "lm(\"What's red + yellow?\")" 526 | ] 527 | }, 528 | { 529 | "cell_type": "markdown", 530 | "metadata": { 531 | "editable": true, 532 | "slideshow": { 533 | "slide_type": "slide" 534 | }, 535 | "tags": [] 536 | }, 537 | "source": [ 538 | "## Optimization\n", 539 | "\n", 540 | "As suggest by the [docs](https://dspy-docs.vercel.app/docs/building-blocks/optimizers#which-optimizer-should-i-use), with 50 examples, we choose `BootstrapFewShotWithRandomSearch`." 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": null, 546 | "metadata": { 547 | "editable": true, 548 | "slideshow": { 549 | "slide_type": "" 550 | }, 551 | "tags": [] 552 | }, 553 | "outputs": [], 554 | "source": [ 555 | "# This is not optimized\n", 556 | "pipeline101 = Pipeline101()\n", 557 | "\n", 558 | "optimizer = BootstrapFewShotWithRandomSearch(\n", 559 | " metric=metric_fn,\n", 560 | " max_bootstrapped_demos=2,\n", 561 | " max_labeled_demos=0,\n", 562 | " max_rounds=1,\n", 563 | " num_candidate_programs=20,\n", 564 | " num_threads=8,\n", 565 | " teacher_settings=dict(lm=lm),\n", 566 | ")\n", 567 | "\n", 568 | "pipeline101_optimized = optimizer.compile(\n", 569 | " pipeline101,\n", 570 | " teacher=pipeline101,\n", 571 | " trainset=trainset,\n", 572 | " valset=valset,\n", 573 | ")" 574 | ] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "metadata": { 579 | "editable": true, 580 | "slideshow": { 581 | "slide_type": "subslide" 582 | }, 583 | "tags": [] 584 | }, 585 | "source": [ 586 | "## Results\n", 587 | "\n", 588 | "We simply compare the `metric_fn` on the `testset`:\n", 589 | "\n", 590 | "`pipeline101` *vs.* `pipeline101_optimized`" 591 | ] 592 | }, 593 | { 594 | "cell_type": "code", 595 | "execution_count": null, 596 | 
"metadata": { 597 | "editable": true, 598 | "slideshow": { 599 | "slide_type": "subslide" 600 | }, 601 | "tags": [] 602 | }, 603 | "outputs": [], 604 | "source": [ 605 | "scores_pipeline101 = []\n", 606 | "for example in testset:\n", 607 | " example_x = example.inputs()\n", 608 | " example_y = example.labels()\n", 609 | " prediction = pipeline101(**example_x)\n", 610 | " score = metric_fn(example_y, prediction)\n", 611 | " scores_pipeline101.append(score)\n", 612 | "\n", 613 | "# Inspcet the last prompt given to LLM\n", 614 | "lm.inspect_history()\n", 615 | "print(\"Ground-truth categories:\", example.labels().categories)\n", 616 | "print(\"Score:\", score)\n", 617 | "\n", 618 | "print(\"\\n\" * 5)" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": null, 624 | "metadata": { 625 | "editable": true, 626 | "slideshow": { 627 | "slide_type": "subslide" 628 | }, 629 | "tags": [] 630 | }, 631 | "outputs": [], 632 | "source": [ 633 | "scores_pipeline101_optimized = []\n", 634 | "for example in testset:\n", 635 | " example_x = example.inputs()\n", 636 | " example_y = example.labels()\n", 637 | " prediction = pipeline101_optimized(**example_x)\n", 638 | " score = metric_fn(example_y, prediction)\n", 639 | " scores_pipeline101_optimized.append(score)\n", 640 | "\n", 641 | "lm.inspect_history()\n", 642 | "print(\"Ground-truth categories:\", example.labels().categories)\n", 643 | "print(\"Score:\", score)\n", 644 | "print(\"\\n\" * 5)" 645 | ] 646 | }, 647 | { 648 | "cell_type": "code", 649 | "execution_count": null, 650 | "metadata": { 651 | "editable": true, 652 | "slideshow": { 653 | "slide_type": "skip" 654 | }, 655 | "tags": [] 656 | }, 657 | "outputs": [], 658 | "source": [ 659 | "print(\n", 660 | " \"pipeline101:\",\n", 661 | " sum(scores_pipeline101) / len(scores_pipeline101),\n", 662 | ")\n", 663 | "print(\n", 664 | " \"pipeline101_optimized:\",\n", 665 | " sum(scores_pipeline101_optimized) / len(scores_pipeline101_optimized),\n", 666 | ")" 
667 | ] 668 | }, 669 | { 670 | "cell_type": "markdown", 671 | "metadata": { 672 | "editable": true, 673 | "slideshow": { 674 | "slide_type": "subslide" 675 | }, 676 | "tags": [] 677 | }, 678 | "source": [ 679 | "While developing this notebook, we:\n", 680 | "- Processed 7,537,982 input tokens (0.0005/1K)\n", 681 | "- Generated 315,868 output tokens (0.0015/1K)\n", 682 | "- With an estimated cost of < $5\n", 683 | "\n", 684 | "---\n", 685 | "\n", 686 | "| Pipeline | Avg. metric_fn |\n", 687 | "|----------------------------|---------------:|\n", 688 | "| pipeline101 | 56% |\n", 689 | "| pipeline101_optimized | 65% |" 690 | ] 691 | }, 692 | { 693 | "cell_type": "markdown", 694 | "metadata": { 695 | "editable": true, 696 | "slideshow": { 697 | "slide_type": "slide" 698 | }, 699 | "tags": [] 700 | }, 701 | "source": [ 702 | "## Conclusions" 703 | ] 704 | }, 705 | { 706 | "cell_type": "markdown", 707 | "metadata": { 708 | "editable": true, 709 | "slideshow": { 710 | "slide_type": "subslide" 711 | }, 712 | "tags": [] 713 | }, 714 | "source": [ 715 | "### Future Work\n", 716 | "\n", 717 | "- Add RAG.\n", 718 | "\n", 719 | "- Utilize the category descriptions.\n", 720 | "\n", 721 | "- Use the full body of the paper.\n", 722 | " - Generate summaries.\n", 723 | " - Use a sliding window, process chunks, and aggregate.\n", 724 | " - Use a more capable language model with greater context length.\n", 725 | "\n", 726 | "- Validate data with the `Assert` module.\n", 727 | "\n", 728 | "- Use a smarter teacher (e.g., GPT-4).\n", 729 | "\n", 730 | "- Experiment with more creative pipelines." 731 | ] 732 | }, 733 | { 734 | "cell_type": "markdown", 735 | "metadata": { 736 | "editable": true, 737 | "slideshow": { 738 | "slide_type": "subslide" 739 | }, 740 | "tags": [] 741 | }, 742 | "source": [ 743 | "### Why DSPy?\n", 744 | "\n", 745 | "- It has promising core concepts.\n", 746 | "- It is actively being developed.\n", 747 | "- It is versatile." 
748 | ] 749 | }, 750 | { 751 | "cell_type": "markdown", 752 | "metadata": { 753 | "editable": true, 754 | "slideshow": { 755 | "slide_type": "" 756 | }, 757 | "tags": [] 758 | }, 759 | "source": [ 760 | "### Why Not DSPy?\n", 761 | "\n", 762 | "- It is not production-ready.\n", 763 | "- As of 23rd February 2024, it is not well-documented (see [#390](https://github.com/stanfordnlp/dspy/issues/390)).\n", 764 | "- Other alternatives exist for similar use cases." 765 | ] 766 | }, 767 | { 768 | "cell_type": "markdown", 769 | "metadata": { 770 | "editable": true, 771 | "slideshow": { 772 | "slide_type": "slide" 773 | }, 774 | "tags": [] 775 | }, 776 | "source": [ 777 | "## Alternatives\n", 778 | "\n", 779 | "Many frameworks exist that programmatically generate prompts and parse responses.\n", 780 | "\n", 781 | "- [Instructor](https://github.com/jxnl/instructor): Provides structured outputs for Large Language Models (LLMs).\n", 782 | "- [Guidance](https://github.com/guidance-ai/guidance?tab=readme-ov-file#constrained-generation): A guidance language for controlling large language models.\n", 783 | "- [LMQL](https://github.com/eth-sri/lmql): A language for constraint-guided and efficient LLM programming.\n", 784 | "- [Outlines](https://github.com/outlines-dev/outlines): Supports structured text generation.\n", 785 | "- ..." 786 | ] 787 | }, 788 | { 789 | "cell_type": "markdown", 790 | "metadata": { 791 | "editable": true, 792 | "slideshow": { 793 | "slide_type": "subslide" 794 | }, 795 | "tags": [] 796 | }, 797 | "source": [ 798 | "### Guidance\n", 799 | "\n", 800 | " \"...constrain generation (e.g. 
with regex and CFGs) as well as to interleave control (conditional, loops) and generation seamlessly.\"\n", 801 | "\n", 802 | " ```python\n", 803 | " from guidance import models, select\n", 804 | "\n", 805 | " # load a model\n", 806 | " llama2 = models.LlamaCpp(path)\n", 807 | "\n", 808 | " # a simple select between two options\n", 809 | " llama2 + f'Do you want a joke or a poem? A ' + select(['joke', 'poem'])\n", 810 | " ```\n", 811 | "\n", 812 | " > Do you want a joke or a poem? A **poem**" 813 | ] 814 | }, 815 | { 816 | "cell_type": "markdown", 817 | "metadata": { 818 | "editable": true, 819 | "slideshow": { 820 | "slide_type": "subslide" 821 | }, 822 | "tags": [] 823 | }, 824 | "source": [ 825 | "### Instructor\n", 826 | "\n", 827 | "Validate LLMs outputs to streamline data extraction.\n", 828 | "\n", 829 | "```python\n", 830 | "import instructor\n", 831 | "from openai import OpenAI\n", 832 | "from pydantic import BaseModel\n", 833 | "\n", 834 | "# Enables `response_model`\n", 835 | "client = instructor.patch(OpenAI())\n", 836 | "\n", 837 | "\n", 838 | "class UserDetail(BaseModel):\n", 839 | " name: str\n", 840 | " age: int\n", 841 | "\n", 842 | "\n", 843 | "user = client.chat.completions.create(\n", 844 | " model=\"gpt-3.5-turbo\",\n", 845 | " response_model=UserDetail,\n", 846 | " messages=[\n", 847 | " {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n", 848 | " ],\n", 849 | ")\n", 850 | "\n", 851 | "assert isinstance(user, UserDetail)\n", 852 | "assert user.name == \"Jason\"\n", 853 | "assert user.age == 25\n", 854 | "```" 855 | ] 856 | } 857 | ], 858 | "metadata": { 859 | "colab": { 860 | "provenance": [] 861 | }, 862 | "kernelspec": { 863 | "display_name": "dspy-arxiv", 864 | "language": "python", 865 | "name": "dspy-arxiv" 866 | }, 867 | "language_info": { 868 | "codemirror_mode": { 869 | "name": "ipython", 870 | "version": 3 871 | }, 872 | "file_extension": ".py", 873 | "mimetype": "text/x-python", 874 | "name": "python", 875 | 
"nbconvert_exporter": "python", 876 | "pygments_lexer": "ipython3", 877 | "version": "3.10.13" 878 | } 879 | }, 880 | "nbformat": 4, 881 | "nbformat_minor": 4 882 | } 883 | --------------------------------------------------------------------------------