├── README.md ├── LICENSE ├── .gitignore ├── context-free-grammar.ipynb └── constrained-decoding.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # A Guide to Structured Generation Using Constrained Decoding 2 | 3 | Article: [A Guide to Structured Generation Using Constrained Decoding](https://www.aidancooper.co.uk/constrained-decoding/) 4 | 5 | [Code samples notebook](constrained-decoding.ipynb) 6 | 7 | [Context-free grammar notebook](context-free-grammar.ipynb) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Aidan Cooper 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into 
this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | /code_html 163 | /images -------------------------------------------------------------------------------- /context-free-grammar.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from llama_cpp.llama import LlamaGrammar\n", 10 | "\n", 11 | "HARRY_POTTER_GBNF = r\"\"\"\n", 12 | "ws ::= ([ \\t\\n] ws)?\n", 13 | "\n", 14 | "string ::=\n", 15 | " \"\\\"\" (\n", 16 | " [^\"\\\\\\x7F\\x00-\\x1F] |\n", 17 | " \"\\\\\" ([\"\\\\/bfnrt] | \"u\" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])\n", 18 | " )* \"\\\"\"\n", 19 | "\n", 20 | "\n", 21 | "digit ::= \"0\" | \"1\" | \"2\" | \"3\" | \"4\" | \"5\" | \"6\" | \"7\" | \"8\" | \"9\"\n", 22 | "\n", 23 | "one-or-two-digits ::= digit | digit digit\n", 24 | "\n", 25 | "zero-to-two-digits ::= \"\" | digit | digit digit\n", 26 | "\n", 27 | "house ::= \"\\\"Gryffindor\\\"\" | \"\\\"Hufflepuff\\\"\" | \"\\\"Ravenclaw\\\"\" | \"\\\"Slytherin\\\"\"\n", 28 | "\n", 29 | "blood ::= \"\\\"Muggle-born\\\"\" | \"\\\"Half-blood\\\"\" | \"\\\"Pure-blood\\\"\"\n", 30 | "\n", 31 | "wand ::= (\n", 32 | " \"{\\n\" ws\n", 33 | " \"\\\"wood\\\": \" string \",\" ws\n", 34 | " \"\\\"core\\\": \" string \",\" ws\n", 35 | " \"\\\"length\\\": \" one-or-two-digits \".\" zero-to-two-digits ws\n", 36 | " \"}\"\n", 37 | ")\n", 38 | "\n", 39 | "character ::= (\n", 40 | " \"{\\n\" ws\n", 41 | " \"\\\"name\\\": \" string \",\" ws\n", 42 | " \"\\\"house\\\": \" house \",\" ws\n", 43 | " \"\\\"blood status\\\": \" blood \",\" ws\n", 44 | " \"\\\"wand\\\": \" wand ws\n", 45 | " \"}\"\n", 46 | ")\n", 47 | "\n", 48 | "root ::= \"[\\n \" character (\",\\n\" ws character)* ws \"]\"\n", 49 | "\"\"\"\n", 50 | "\n", 51 | "grammar = 
LlamaGrammar.from_string(HARRY_POTTER_GBNF, verbose=False)\n" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 2, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | "output_type": "stream", 62 | "text": [ 63 | "[\n", 64 | " {\n", 65 | " \"name\": \"Ronald Bilius Weasley\",\n", 66 | " \"house\": \"Gryffindor\",\n", 67 | " \"blood status\": \"Pure-blood\",\n", 68 | " \"wand\": {\n", 69 | " \"wood\": \"Holly\",\n", 70 | " \"core\": \"Veela hair\",\n", 71 | " \"length\": 10.5\n", 72 | " }\n", 73 | " },\n", 74 | " {\n", 75 | " \"name\": \"Severus Tobias Snape\",\n", 76 | " \"house\": \"Slytherin\",\n", 77 | " \"blood status\": \"Half-blood\",\n", 78 | " \"wand\": {\n", 79 | " \"wood\": \"Blackthorn\",\n", 80 | " \"core\": \"Unicorn hair\",\n", 81 | " \"length\": 11.75\n", 82 | " }\n", 83 | " }\n", 84 | "]\n" 85 | ] 86 | } 87 | ], 88 | "source": [ 89 | "from llama_cpp.llama import Llama\n", 90 | "\n", 91 | "llm = Llama(\n", 92 | " model_path=\"/Users/aidan/models/Meta-Llama-3-8B-Instruct.Q4_0.gguf\",\n", 93 | " n_gpu_layers=-1,\n", 94 | " verbose=False,\n", 95 | ")\n", 96 | "\n", 97 | "output = llm.create_chat_completion(\n", 98 | " messages=[\n", 99 | " {\n", 100 | " \"role\": \"user\",\n", 101 | " \"content\": \"Using JSON, describe these Harry Potter characters: \"\n", 102 | " + \"Ron Weasley, Snape.\",\n", 103 | " },\n", 104 | " ],\n", 105 | " grammar=grammar,\n", 106 | ")\n", 107 | "print(output[\"choices\"][0][\"message\"][\"content\"])\n", 108 | "# >>> [\n", 109 | "# >>> {\n", 110 | "# >>> \"name\": \"Ronald Bilius Weasley\",\n", 111 | "# >>> \"house\": \"Gryffindor\",\n", 112 | "# >>> \"blood status\": \"Pure-blood\",\n", 113 | "# >>> \"wand\": {\n", 114 | "# >>> \"wood\": \"Holly\",\n", 115 | "# >>> \"core\": \"Veela hair\",\n", 116 | "# >>> \"length\": 10.5\n", 117 | "# >>> }\n", 118 | "# >>> },\n", 119 | "# >>> {\n", 120 | "# >>> \"name\": \"Severus Tobias Snape\",\n", 121 | "# >>> \"house\": \"Slytherin\",\n", 122 
| "# >>> \"blood status\": \"Half-blood\",\n", 123 | "# >>> \"wand\": {\n", 124 | "# >>> \"wood\": \"Blackthorn\",\n", 125 | "# >>> \"core\": \"Unicorn hair\",\n", 126 | "# >>> \"length\": 11.75\n", 127 | "# >>> }\n", 128 | "# >>> }\n", 129 | "# >>> ]" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "venv", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 | "version": "3.11.4" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 2 161 | } 162 | -------------------------------------------------------------------------------- /constrained-decoding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | " Sure. 
Here's the character description based on JSON:\n", 13 | "\n", 14 | "```json\n", 15 | "{\n", 16 | " \"name\": \"Hermione Granger\",\n", 17 | " \"age\": 15,\n", 18 | " \"species\": \"Human\",\n", 19 | " \"fictional_status\": true,\n", 20 | " \"nationality\": \"British\",\n", 21 | " \"gender\": \"Female\",\n", 22 | " \"house\": \"Hufflepuff\",\n", 23 | " \"looking_after\": null,\n", 24 | " \"mbti\": \"INFP\",\n", 25 | " \"personality\": \"Intelligent, bookish, independent, curious, and compassionate\",\n", 26 | " \"status\": \"Alive\",\n", 27 | " \"birth_date\": \"1991-01-03\",\n", 28 | " \"death_date\": null\n", 29 | "}\n", 30 | "```\n", 31 | "\n", 32 | "Please note that this character description is based on the character's portrayal in the Harry Potter series of books and movies.\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "import sglang as sgl\n", 38 | "\n", 39 | "@sgl.function\n", 40 | "def harry_potter_gen(s, name):\n", 41 | " s += sgl.user(f\"Using JSON, describe the character {name} from Harry Potter.\")\n", 42 | " s += sgl.assistant(sgl.gen(\"json\", max_tokens=256))\n", 43 | "\n", 44 | "sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n", 45 | "state = harry_potter_gen.run(\"Hermione Granger\")\n", 46 | "print(state[\"json\"])\n", 47 | "# >>> Sure. 
Here's the character description based on JSON:\n", 48 | "# >>>\n", 49 | "# >>> ```json\n", 50 | "# >>> {\n", 51 | "# >>> \"name\": \"Hermione Granger\",\n", 52 | "# >>> \"age\": 15,\n", 53 | "# >>> \"species\": \"Human\",\n", 54 | "# >>> \"nationality\": \"British\",\n", 55 | "# >>> \"gender\": \"Female\",\n", 56 | "# >>> \"house\": \"Hufflepuff\",\n", 57 | "# >>> \"personality\": \"Intelligent, bookish, and compassionate\",\n", 58 | "# >>> \"status\": \"Alive\",\n", 59 | "# >>> \"birth_date\": \"1991-01-03\",\n", 60 | "# >>> \"death_date\": null\n", 61 | "# >>> }\n", 62 | "# >>> ```" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 1, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "import sglang as sgl\n", 72 | "\n", 73 | "character_regex = (\n", 74 | " r\"\"\"\\{\n", 75 | " \"name\": \"[\\w\\d\\s]{1,16}\",\n", 76 | " \"house\": \"(Gryffindor|Slytherin|Ravenclaw|Hufflepuff)\",\n", 77 | " \"blood status\": \"(Pure-blood|Half-blood|Muggle-born)\",\n", 78 | " \"wand\": \\{\n", 79 | " \"wood\": \"[\\w\\d\\s]{1,16}\",\n", 80 | " \"core\": \"[\\w\\d\\s]{1,16}\",\n", 81 | " \"length\": [0-9]{1,2}\\.[0-9]{0,2}\n", 82 | " \\}\n", 83 | "\\}\"\"\"\n", 84 | ")\n", 85 | "\n", 86 | "@sgl.function\n", 87 | "def harry_potter_gen(s, name):\n", 88 | " s += sgl.user(f\"Please describe the character {name} from Harry Potter.\")\n", 89 | " s += sgl.assistant(sgl.gen(\"json\", max_tokens=256, regex=character_regex))\n", 90 | "\n", 91 | "sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n", 92 | "state = harry_potter_gen.run(\"Hermione Granger\")\n", 93 | "character_json = state[\"json\"]" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 2, 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "{\n", 106 | " \"name\": \"Hermione Granger\",\n", 107 | " \"house\": \"Hufflepuff\",\n", 108 | " \"blood status\": \"Pure-blood\",\n", 109 | 
" \"wand\": {\n", 110 | " \"wood\": \"Bludger\",\n", 111 | " \"core\": \"Phoenix\",\n", 112 | " \"length\": 13.5\n", 113 | " }\n", 114 | "}\n" 115 | ] 116 | } 117 | ], 118 | "source": [ 119 | "print(character_json)" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 37, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "import sglang as sgl\n", 129 | "from sglang.srt.constrained import build_regex_from_object\n", 130 | "from pydantic import BaseModel\n", 131 | "\n", 132 | "class Wand(BaseModel):\n", 133 | " wood: str\n", 134 | " core: str\n", 135 | " length: float\n", 136 | "\n", 137 | "class Character(BaseModel):\n", 138 | " name: str\n", 139 | " house: str\n", 140 | " blood_status: str\n", 141 | " wand: Wand\n", 142 | "\n", 143 | "@sgl.function\n", 144 | "def harry_potter_gen(s, name):\n", 145 | " s += sgl.user(f\"Please describe the character {name} from Harry Potter.\")\n", 146 | " s += sgl.assistant(\n", 147 | " sgl.gen(\"json\", max_tokens=256, regex=build_regex_from_object(Character))\n", 148 | " )\n", 149 | "\n", 150 | "sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n", 151 | "state = harry_potter_gen.run(\"Hermione Granger\")\n", 152 | "character_json = state[\"json\"]" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 38, 158 | "metadata": {}, 159 | "outputs": [ 160 | { 161 | "name": "stdout", 162 | "output_type": "stream", 163 | "text": [ 164 | "{\"name\" : \"Hermione Granger\",se\"use\" : \"blood_status\"_statusand\" \n", 165 | " ,d\"and\" : {\"ore\"\":\"-\",\"h\"re\":\"2nd\" , \"length\":4} }\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "print(character_json) # does not work" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 5, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "import sglang as sgl\n", 180 | "\n", 181 | "@sgl.function\n", 182 | "def harry_potter_gen(s, name):\n", 183 | " s += 
sgl.user(f\"Please describe the character {name} from Harry Potter.\")\n", 184 | " s += sgl.assistant('''{\n", 185 | " \"name\": \"''' + sgl.gen(\"name\", max_tokens=32, stop='\"') + '''\",\n", 186 | " \"house\": \"''' + sgl.gen(\n", 187 | " \"house\", choices=[\"Gryffindor\", \"Slytherin\", \"Ravenclaw\", \"Hufflepuff\"]\n", 188 | " ) + '''\",\n", 189 | " \"blood status\": \"''' + sgl.gen(\n", 190 | " \"blood status\", choices=[\"Pure-blood\", \"Half-blood\", \"Muggle-born\"]\n", 191 | " ) + '''\",\n", 192 | " \"wand\": {\n", 193 | " \"wood\": \"''' + sgl.gen(\"wood\", regex=r\"[\\w\\d\\s]{1,16}\") + '''\",\n", 194 | " \"core\": \"''' + sgl.gen(\"core\", regex=r\"[\\w\\d\\s]{1,16}\") + '''\",\n", 195 | " \"length\": ''' + sgl.gen(\"length\", regex=r\"[0-9]{1,2}\\.[0-9]{0,2}\") + '''\n", 196 | " }\n", 197 | "}''')\n", 198 | "\n", 199 | "sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n", 200 | "state = harry_potter_gen.run(\"Hermione Granger\")\n", 201 | "character_json = state.messages()[1][\"content\"]" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 6, 207 | "metadata": {}, 208 | "outputs": [ 209 | { 210 | "name": "stdout", 211 | "output_type": "stream", 212 | "text": [ 213 | "{\n", 214 | " \"name\": \"Hermione Granger\",\n", 215 | " \"house\": \"Hufflepuff\",\n", 216 | " \"blood status\": \"Pure-blood\",\n", 217 | " \"wand\": {\n", 218 | " \"wood\": \"Oak treewood with an emerald core and dragon heartwood trim and grip with dragon\",\n", 219 | " \"core\": \"Eldritch Wood from a Goxdrich willow tree in the Forbidden Forest of\",\n", 220 | " \"length\": 11.5136738863437\n", 221 | " }\n", 222 | "}\n" 223 | ] 224 | } 225 | ], 226 | "source": [ 227 | "print(character_json)" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 7, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "France\n", 240 | "spain\n" 241 | ] 
242 | } 243 | ], 244 | "source": [ 245 | "import sglang as sgl\n", 246 | "\n", 247 | "@sgl.function\n", 248 | "def which_country_upper(s, city):\n", 249 | " s += sgl.user(f\"In which country is {city} located?\")\n", 250 | " s += sgl.assistant(sgl.gen(\"country\", choices=[\"France\", \"Spain\", \"Italy\"]))\n", 251 | "\n", 252 | "@sgl.function\n", 253 | "def which_country_lower(s, city):\n", 254 | " s += sgl.user(f\"In which country is {city} located?\")\n", 255 | " s += sgl.assistant(sgl.gen(\"country\", choices=[\"france\", \"spain\", \"italy\"]))\n", 256 | "\n", 257 | "sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n", 258 | "state_upper = which_country_upper.run(\"Paris\")\n", 259 | "state_lower = which_country_lower.run(\"Paris\")\n", 260 | "print(state_upper[\"country\"]) # >>> France\n", 261 | "print(state_lower[\"country\"]) # >>> spain" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 8, 267 | "metadata": {}, 268 | "outputs": [ 269 | { 270 | "name": "stdout", 271 | "output_type": "stream", 272 | "text": [ 273 | "Millard Fillmore\n", 274 | "Donald Duck\n" 275 | ] 276 | } 277 | ], 278 | "source": [ 279 | "import sglang as sgl\n", 280 | "\n", 281 | "@sgl.function\n", 282 | "def us_president_choices(s):\n", 283 | " s += sgl.user(\"Name a US president.\")\n", 284 | " s += sgl.assistant(\n", 285 | " \"An example of a US president is \" +\n", 286 | " sgl.gen(\"president\", choices=[\"Donald Duck\", \"Millard Fillmore\"])\n", 287 | " )\n", 288 | "\n", 289 | "@sgl.function\n", 290 | "def us_president_regex(s):\n", 291 | " s += sgl.user(\"Name a US president.\")\n", 292 | " s += sgl.assistant(\n", 293 | " \"An example of a US president is \" +\n", 294 | " sgl.gen(\"president\", regex=r\"(Donald Duck|Millard Fillmore)\")\n", 295 | " )\n", 296 | "\n", 297 | "sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n", 298 | "state_choices = us_president_choices.run()\n", 299 | "state_regex = 
us_president_regex.run()\n", 300 | "print(state_choices[\"president\"]) # >>> Millard Fillmore\n", 301 | "print(state_regex[\"president\"]) # >>> Donald Duck" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 35, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "import sglang as sgl\n", 311 | "\n", 312 | "@sgl.function\n", 313 | "def harry_potter_gen(s, names, max_characters=5):\n", 314 | " s += sgl.user(f\"Describe these Harry Potter characters: {', '.join(names)}\")\n", 315 | " n = 1\n", 316 | " while n <= max_characters:\n", 317 | " s += sgl.assistant('''{\n", 318 | " \"name\": \"''' + sgl.gen(\"name\", max_tokens=32, stop='\"') + '''\",\n", 319 | " \"house\": \"''' + sgl.gen(\n", 320 | " \"house\", choices=[\"Gryffindor\", \"Slytherin\", \"Ravenclaw\", \"Hufflepuff\"]\n", 321 | " ) + '''\",\n", 322 | " \"blood status\": \"''' + sgl.gen(\n", 323 | " \"blood status\", choices=[\"Pure-blood\", \"Half-blood\", \"Muggle-born\"]\n", 324 | " ) + '''\",\n", 325 | " \"wand\": {\n", 326 | " \"wood\": \"''' + sgl.gen(\"wood\", max_tokens=32, stop='\"') + '''\",\n", 327 | " \"core\": \"''' + sgl.gen(\"core\", max_tokens=32, stop='\"') + '''\",\n", 328 | " \"length\": ''' + sgl.gen(\"length\", regex=r\"[0-9]{1,2}\\.[0-9]{0,2}\") + '''\n", 329 | " }\n", 330 | "}''')\n", 331 | " s += sgl.user(\"Are there any more characters to describe? 
(Y/N)\")\n", 332 | " s += sgl.assistant(sgl.gen(f\"continue_{n}\", choices=[\"Y\", \"N\"]))\n", 333 | " if s[f\"continue_{n}\"] == \"N\":\n", 334 | " break\n", 335 | " n += 1\n", 336 | " s += sgl.user(f\"OK, describe the next character.\")\n", 337 | " s[\"n\"] = min(n, max_characters)\n", 338 | "\n", 339 | "sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n", 340 | "state = harry_potter_gen.run([\"Ron Weasley\", \"Snape\"])\n", 341 | "characters_json = [state.messages()[1+i*4][\"content\"] for i in range(state[\"n\"])]" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": 36, 347 | "metadata": {}, 348 | "outputs": [ 349 | { 350 | "name": "stdout", 351 | "output_type": "stream", 352 | "text": [ 353 | "{\n", 354 | " \"name\": \"Ron Weasley\",\n", 355 | " \"house\": \"Hufflepuff\",\n", 356 | " \"blood status\": \"Pure-blood\",\n", 357 | " \"wand\": {\n", 358 | " \"wood\": \"Flint\",\n", 359 | " \"core\": \"Dragonheart\",\n", 360 | " \"length\": 11.50\n", 361 | " }\n", 362 | "}\n", 363 | "\n", 364 | "{\n", 365 | " \"name\": \"Snape\",\n", 366 | " \"house\": \"Slytherin\",\n", 367 | " \"blood status\": \"Pure-blood\",\n", 368 | " \"wand\": {\n", 369 | " \"wood\": \"Ash\",\n", 370 | " \"core\": \"Phoenix tail\",\n", 371 | " \"length\": 13.50\n", 372 | " }\n", 373 | "}\n", 374 | "\n" 375 | ] 376 | } 377 | ], 378 | "source": [ 379 | "for character in characters_json:\n", 380 | " print(character)\n", 381 | " print()" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": 12, 387 | "metadata": {}, 388 | "outputs": [ 389 | { 390 | "name": "stdout", 391 | "output_type": "stream", 392 | "text": [ 393 | "Grey\n", 394 | "Forest green\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 399 | "import sglang as sgl\n", 400 | "\n", 401 | "@sgl.function\n", 402 | "def grass(s):\n", 403 | " s += sgl.user(f\"What colour is grass?\")\n", 404 | " s += sgl.assistant(sgl.gen(\"colour\", regex=f\"(Grey|Forest green)\"))\n", 405 
| "\n", 406 | "colours = {\"Grey\": \"Grey\", \"Forest green\": \"Forest green\", \"Green\": \"Forest green\"}\n", 407 | "\n", 408 | "@sgl.function\n", 409 | "def grass_aliases(s):\n", 410 | " s += sgl.user(f\"What colour is grass?\")\n", 411 | " s += sgl.assistant(sgl.gen(\"colour\", regex=f\"(Grey|Forest green|Green)\"))\n", 412 | "\n", 413 | "sgl.set_default_backend(sgl.RuntimeEndpoint(\"http://localhost:30000\"))\n", 414 | "state = grass.run()\n", 415 | "state_aliases = grass_aliases.run()\n", 416 | "print(state[\"colour\"]) # >>> Grey\n", 417 | "print(colours[state_aliases[\"colour\"]]) # >>> Forest green" 418 | ] 419 | }, 420 | { 421 | "cell_type": "code", 422 | "execution_count": null, 423 | "metadata": {}, 424 | "outputs": [], 425 | "source": [] 426 | } 427 | ], 428 | "metadata": { 429 | "kernelspec": { 430 | "display_name": "venv", 431 | "language": "python", 432 | "name": "python3" 433 | }, 434 | "language_info": { 435 | "codemirror_mode": { 436 | "name": "ipython", 437 | "version": 3 438 | }, 439 | "file_extension": ".py", 440 | "mimetype": "text/x-python", 441 | "name": "python", 442 | "nbconvert_exporter": "python", 443 | "pygments_lexer": "ipython3", 444 | "version": "3.10.10" 445 | } 446 | }, 447 | "nbformat": 4, 448 | "nbformat_minor": 2 449 | } 450 | --------------------------------------------------------------------------------