├── .github └── workflows │ └── minichain.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── Makefile ├── README.md ├── docs ├── examples └── index.md ├── examples ├── agent.pmpt.tpl ├── agent.py ├── backtrack.ipynb ├── bash.ipynb ├── bash.pmpt.tpl ├── bash.py ├── chat.ipynb ├── chat.pmpt.tpl ├── chat.py ├── chatgpt.pmpt.tpl ├── data.json ├── gatsby.ipynb ├── gatsby.pmpt.tpl ├── gatsby.py ├── gatsby │ ├── data-00000-of-00001.arrow │ ├── dataset_info.json │ └── state.json ├── gradio_example.py ├── math.pmpt.tpl ├── math_demo.ipynb ├── math_demo.py ├── ner.ipynb ├── ner.pmpt.tpl ├── ner.py ├── olympics.data │ ├── data-00000-of-00001.arrow │ ├── dataset_info.json │ └── state.json ├── pal.ipynb ├── pal.pmpt.tpl ├── pal.py ├── parallel.py ├── qa.ipynb ├── qa.pmpt.tpl ├── qa.py ├── selfask.ipynb ├── selfask.pmpt.tpl ├── selfask.py ├── sixers.txt ├── stats.ipynb ├── stats.pmpt.tpl ├── stats.py ├── summary.ipynb ├── summary.pmpt.tpl ├── summary.py ├── table.pmpt.txt ├── table.py └── type_prompt.pmpt.tpl ├── minichain ├── __init__.py ├── backend.py ├── base.py ├── gradio.py └── templates │ ├── prompt.html.tpl │ └── type_prompt.pmpt.tpl ├── mkdocs.yml ├── requirements-docs.txt ├── requirements.txt ├── setup.cfg └── setup.py /.github/workflows/minichain.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: [3.8] 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 23 | - name: Lint with pre-commit 24 | uses: pre-commit/action@v2.0.3 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # User Defined 2 | .vscode 3 | 4 | # Temporary files 5 | *.swp 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | .archives/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | *.py~ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # To use: 2 | # 3 | # pre-commit run -a 4 | # 5 | # Or: 6 | # 7 | # pre-commit install # (runs every time you commit in git) 8 | # 9 | # To update this file: 10 | # 11 | # pre-commit autoupdate 12 | # 13 | # See https://github.com/pre-commit/pre-commit 14 | 15 | repos: 16 | # Standard hooks 17 | - repo: https://github.com/pre-commit/pre-commit-hooks 18 | rev: v4.3.0 19 | hooks: 20 | - id: check-added-large-files 21 | - id: check-case-conflict 22 | - id: check-docstring-first 23 | - id: check-merge-conflict 24 | - id: check-symlinks 25 | - id: check-toml 26 | - id: debug-statements 27 | - id: mixed-line-ending 28 | - id: requirements-txt-fixer 29 | - id: trailing-whitespace 30 | 31 | - repo: https://github.com/timothycrosley/isort 32 | rev: 5.12.0 33 | hooks: 34 | - id: isort 35 | exclude: ^(docs/)|(project/)|(assignments/)|(project/interface/)|(examples/) 36 | 37 | - repo: https://github.com/pre-commit/mirrors-mypy 38 | rev: v0.971 39 | hooks: 40 | - id: mypy 41 | exclude: ^(docs/)|(project/)|(assignments/)|(project/interface/)|(examples/) 42 | 43 | 44 | # Black, the code formatter, natively supports pre-commit 45 | - repo: https://github.com/psf/black 46 | rev: 22.6.0 47 | hooks: 48 | - id: black 49 | exclude: ^(docs/)|(project/)|(assignments/)|(project/interface/)|(examples/) 50 | # Flake8 also supports pre-commit natively (same author) 51 | - repo: https://github.com/PyCQA/flake8 52 | rev: 5.0.4 53 | hooks: 54 | - id: flake8 55 | additional_dependencies: 56 | - pep8-naming 57 | exclude: 
^(docs/)|(assignments/)|(project/interface/)|(examples/) 58 | 59 | # Doc linters 60 | - repo: https://github.com/terrencepreilly/darglint 61 | rev: v1.8.1 62 | hooks: 63 | - id: darglint 64 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Sasha Rush 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | INPUTS = $(wildcard examples/*.py) 2 | 3 | OUTPUTS = $(patsubst %.py,%.ipynb,$(INPUTS)) 4 | 5 | examples/%.ipynb : examples/%.py 6 | python examples/process.py < $< > /tmp/out.py 7 | jupytext --to notebook /tmp/out.py -o $@ 8 | 9 | examples/%.md : examples/%.py 10 | jupytext --to markdown $< 11 | 12 | all: $(OUTPUTS) 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | A tiny library for coding with **large** language models. Check out the [MiniChain Zoo](https://srush-minichain.hf.space/) to get a sense of how it works. 4 | 5 | ## Coding 6 | 7 | * Code ([math_demo.py](https://github.com/srush/MiniChain/blob/main/examples/math_demo.py)): Annotate Python functions that call language models. 8 | 9 | ```python 10 | @prompt(OpenAI(), template_file="math.pmpt.tpl") 11 | def math_prompt(model, question): 12 | "Prompt to call GPT with a Jinja template" 13 | return model(dict(question=question)) 14 | 15 | @prompt(Python(), template="import math\n{{code}}") 16 | def python(model, code): 17 | "Prompt to call Python interpreter" 18 | code = "\n".join(code.strip().split("\n")[1:-1]) 19 | return model(dict(code=code)) 20 | 21 | def math_demo(question): 22 | "Chain them together" 23 | return python(math_prompt(question)) 24 | ``` 25 | 26 | * Chains ([Space](https://srush-minichain.hf.space/)): MiniChain builds a graph (think like PyTorch) of all the calls you make for debugging and error handling. 
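Because this graph is lazy, nothing runs until you ask for a result. Here is a minimal sketch using the `math_demo` chain defined above (the question string is just an illustrative input):

```python
# Calling the chain only builds the graph of prompt calls;
# run() executes it and returns the final output.
result = math_demo("What is 231 * 3?").run()
print(result)
```

The `show` call below renders the same call graph as an interactive Gradio demo: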
27 | 28 | 29 | 30 | ```python 31 | show(math_demo, 32 | examples=["What is the sum of the powers of 3 (3^i) that are smaller than 100?", 33 | "What is the sum of the 10 first positive integers?"], 34 | subprompts=[math_prompt, python], 35 | out_type="markdown").queue().launch() 36 | ``` 37 | 38 | 39 | * Template ([math.pmpt.tpl](https://github.com/srush/MiniChain/blob/main/examples/math.pmpt.tpl)): Prompts are separated from code. 40 | 41 | ``` 42 | ... 43 | Question: 44 | A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take? 45 | Code: 46 | 2 + 2/2 47 | 48 | Question: 49 | {{question}} 50 | Code: 51 | ``` 52 | 53 | * Installation 54 | 55 | ```bash 56 | pip install minichain 57 | export OPENAI_API_KEY="sk-***" 58 | ``` 59 | 60 | ## Examples 61 | 62 | This library allows us to implement several popular approaches in a few lines of code. 63 | 64 | * [Retrieval-Augmented QA](https://srush.github.io/MiniChain/examples/qa/) 65 | * [Chat with memory](https://srush.github.io/MiniChain/examples/chatgpt/) 66 | * [Information Extraction](https://srush.github.io/MiniChain/examples/ner/) 67 | * [Interleaved Code (PAL)](https://srush.github.io/MiniChain/examples/pal/) - [(Gao et al 2022)](https://arxiv.org/pdf/2211.10435.pdf) 68 | * [Search Augmentation (Self-Ask)](https://srush.github.io/MiniChain/examples/selfask/) - [(Press et al 2022)](https://ofir.io/self-ask.pdf) 69 | * [Chain-of-Thought](https://srush.github.io/MiniChain/examples/bash/) - [(Wei et al 2022)](https://arxiv.org/abs/2201.11903) 70 | 71 | It supports the following backends. 72 | 73 | * OpenAI (Completions / Embeddings) 74 | * Hugging Face 🤗 75 | * Google Search 76 | * Python 77 | * Manifest-ML (AI21, Cohere, Together) 78 | * Bash 79 | 80 | ## Why Mini-Chain? 81 | 82 | There are several very popular libraries for prompt chaining, 83 | notably: [LangChain](https://langchain.readthedocs.io/en/latest/), 84 | [Promptify](https://github.com/promptslab/Promptify), and 85 | [GPTIndex](https://gpt-index.readthedocs.io/en/latest/reference/prompts.html). 86 | These libraries are useful, but they are extremely large and 87 | complex. MiniChain aims to implement the core prompt chaining 88 | functionality in a tiny, digestible library. 89 | 90 | 91 | ## Tutorial 92 | 93 | Mini-chain is based on annotating functions as prompts. 94 | 95 | ![image](https://user-images.githubusercontent.com/35882/221280012-d58c186d-4da2-4cb6-96af-4c4d9069943f.png) 96 | 97 | 98 | ```python 99 | @prompt(OpenAI()) 100 | def color_prompt(model, input): 101 | return model(f"Answer 'Yes' if this is a color, {input}. Answer:") 102 | ``` 103 | 104 | Prompt functions act like Python functions, except they are lazy; to access the result you need to call `run()`. 105 | 106 | ```python 107 | if color_prompt("blue").run() == "Yes": 108 | print("It's a color") 109 | ``` 110 | Alternatively, you can chain prompts together. Prompts are lazy, so if you want to manipulate them you need to add `@transform()` to your function. For example: 111 | 112 | ```python 113 | @transform() 114 | def said_yes(input): 115 | return input == "Yes" 116 | ``` 117 | 118 | ![image](https://user-images.githubusercontent.com/35882/221281771-3770be96-02ce-4866-a6f8-c458c9a11c6f.png) 119 | 120 | ```python 121 | @prompt(OpenAI()) 122 | def adjective_prompt(model, input): 123 | return model(f"Give an adjective to describe {input}.
Answer:") 124 | ``` 125 | 126 | 127 | ```python 128 | adjective = adjective_prompt("rainbow") 129 | if said_yes(color_prompt(adjective)).run(): 130 | print("It's a color") 131 | ``` 132 | 133 | 134 | We also include an argument `template_file` which assumes the model uses a template written in the 135 | [Jinja](https://jinja.palletsprojects.com/en/3.1.x/templates/) templating language. 136 | This allows us to separate prompt text from the Python code. 137 | 138 | ```python 139 | @prompt(OpenAI(), template_file="math.pmpt.tpl") 140 | def math_prompt(model, question): 141 | return model(dict(question=question)) 142 | ``` 143 | 144 | ### Visualization 145 | 146 | MiniChain has a built-in prompt visualization system using `Gradio`. 147 | If you construct a function that calls a prompt chain you can visualize it 148 | by calling `show` and `launch`. This can be done directly in a notebook as well. 149 | 150 | ```python 151 | show(math_demo, 152 | examples=["What is the sum of the powers of 3 (3^i) that are smaller than 100?", 153 | "What is the sum of the 10 first positive integers?"], 154 | subprompts=[math_prompt, python], 155 | out_type="markdown").queue().launch() 156 | ``` 157 | 158 | 159 | ### Memory 160 | 161 | MiniChain does not build in an explicit stateful memory class. We recommend implementing it as a queue. 162 | 163 | ![image](https://user-images.githubusercontent.com/35882/221622653-7b13783e-0439-4d59-8f57-b98b82ab83c0.png) 164 | 165 | Here is a class you might find useful to keep track of responses. 166 | 167 | ```python 168 | @dataclass 169 | class State: 170 | memory: List[Tuple[str, str]] 171 | human_input: str = "" 172 | 173 | def push(self, response: str) -> "State": 174 | memory = self.memory if len(self.memory) < MEMORY_LIMIT else self.memory[1:] 175 | return State(memory + [(self.human_input, response)]) 176 | ``` 177 | 178 | See the full Chat example. 179 | It keeps track of the last two responses that it has seen. 180 | 181 | ### Tools and agents 182 | 183 | MiniChain does not provide `agents` or `tools`. If you want that functionality you can use the `tool_num` argument of `model`, which allows you to select from multiple possible backends. It's easy to add new backends of your own (see the Gradio example). 184 | 185 | ```python 186 | @prompt([Python(), Bash()]) 187 | def math_prompt(model, input, lang): 188 | return model(input, tool_num=0 if lang == "python" else 1) 189 | ``` 190 | 191 | ### Documents and Embeddings 192 | 193 | MiniChain does not manage documents and embeddings. We recommend using 194 | the [Hugging Face Datasets](https://huggingface.co/docs/datasets/index) library with 195 | built-in FAISS indexing. 196 | 197 | ![image](https://user-images.githubusercontent.com/35882/221387303-e3dd8456-a0f0-4a70-a1bb-657fe2240862.png) 198 | 199 | 200 | Here is the implementation. 201 | 202 | ```python 203 | # Load and index a dataset 204 | olympics = datasets.load_from_disk("olympics.data") 205 | olympics.add_faiss_index("embeddings") 206 | 207 | @prompt(OpenAIEmbed()) 208 | def get_neighbors(model, inp, k): 209 | embedding = model(inp) 210 | res = olympics.get_nearest_examples("embeddings", np.array(embedding), k) 211 | return res.examples["content"] 212 | ``` 213 | 214 | This creates a K-nearest neighbors (KNN) prompt that looks up the 215 | 3 closest documents based on embeddings of the question asked. 216 | See the full [Retrieval-Augmented QA](https://srush.github.io/MiniChain/examples/qa/) 217 | example.
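To turn retrieval into question answering, the retrieved neighbors can be fed into a second prompt. This is a rough sketch, assuming a Jinja template (here called `qa.pmpt.tpl`) that exposes `{{question}}` and `{{docs}}` variables:

```python
@prompt(OpenAI(), template_file="qa.pmpt.tpl")
def ask(model, question, neighbors):
    # The template interpolates the question and the retrieved passages.
    return model(dict(question=question, docs=neighbors))

def qa(question):
    # Chain retrieval and answering; call .run() on the result to execute.
    return ask(question, get_neighbors(question, 3))
```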
218 | 219 | 220 | We recommend creating these embeddings offline using the batch map functionality of the 221 | datasets library. 222 | 223 | ```python 224 | def embed(x): 225 | emb = openai.Embedding.create(input=x["content"], engine=EMBEDDING_MODEL) 226 | return {"embeddings": [np.array(emb['data'][i]['embedding']) 227 | for i in range(len(emb["data"]))]} 228 | x = dataset.map(embed, batch_size=BATCH_SIZE, batched=True) 229 | x.save_to_disk("olympics.data") 230 | ``` 231 | 232 | There are other ways to do this such as [sqlite-vss](https://github.com/asg017/sqlite-vss) 233 | or [Weaviate](https://weaviate.io/). 234 | 235 | 236 | ### Typed Prompts 237 | 238 | MiniChain can automatically generate a prompt header for you that aims to ensure the 239 | output follows a given typed specification. For example, if you run the following code 240 | MiniChain will produce a prompt that returns a list of `Player` objects. 241 | 242 | ```python 243 | class StatType(Enum): 244 | POINTS = 1 245 | REBOUNDS = 2 246 | ASSISTS = 3 247 | 248 | @dataclass 249 | class Stat: 250 | value: int 251 | stat: StatType 252 | 253 | @dataclass 254 | class Player: 255 | player: str 256 | stats: List[Stat] 257 | 258 | 259 | @prompt(OpenAI(), template_file="stats.pmpt.tpl", parser="json") 260 | def stats(model, passage): 261 | out = model(dict(passage=passage, typ=type_to_prompt(Player))) 262 | return [Player(**j) for j in out] 263 | ``` 264 | 265 | Specifically, it will provide your template with a string `typ` that you can use. For this example the string will be of the following form: 266 | 267 | 268 | ``` 269 | You are a highly intelligent and accurate information extraction system. You take passage as input and your task is to find parts of the passage to answer questions. 270 | 271 | You need to output a list of JSON encoded values 272 | 273 | You need to classify in to the following types for key: "color": 274 | 275 | RED 276 | GREEN 277 | BLUE 278 | 279 | 280 | Only select from the above list, or "Other". 281 | 282 | 283 | You need to classify in to the following types for key: "object": 284 | 285 | String 286 | 287 | 288 | 289 | You need to classify in to the following types for key: "explanation": 290 | 291 | String 292 | 293 | [{ "color" : "color" , "object" : "object" , "explanation" : "explanation"}, ...] 294 | 295 | Make sure every output is exactly seen in the document. Find as many as you can. 296 | ``` 297 | 298 | This will then be converted to an object automatically for you. 299 | 300 | 301 | -------------------------------------------------------------------------------- /docs/examples: -------------------------------------------------------------------------------- 1 | ../examples/ -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | hide: 3 | - navigation 4 | --- 5 | --8<-- "README.md" 6 | -------------------------------------------------------------------------------- /examples/agent.pmpt.tpl: -------------------------------------------------------------------------------- 1 | Assistant is a large language model trained by OpenAI. 2 | 3 | Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics.
As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. 4 | 5 | Assistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics. 6 | 7 | Overall, Assistant is a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist. 8 | 9 | TOOLS: 10 | ------ 11 | 12 | Assistant has access to the following tools: 13 | 14 | {% for tool in tools%} 15 | {{tool[0]}}: {{tool[1]}} 16 | {% endfor %} 17 | 18 | To use a tool, you MUST use exactly the following format: 19 | 20 | ``` 21 | Thought: Do I need to use a tool? Yes 22 | Action: the action to take, should be one of [{% for tool in tools%}{{tool[0]}}, {% endfor %}] 23 | Action Input: the input to the action 24 | Observation: the result of the action 25 | ``` 26 | 27 | When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format: 28 | 29 | ``` 30 | Thought: Do I need to use a tool? No 31 | AI: [your response here] 32 | ``` 33 | 34 | Do NOT output in any other format. Begin! 35 | 36 | New input: {{input}} 37 | {{agent_scratchpad}} -------------------------------------------------------------------------------- /examples/agent.py: -------------------------------------------------------------------------------- 1 | # + tags=["hide_inp"] 2 | 3 | desc = """ 4 | ### Agent 5 | 6 | Chain that executes different tools based on model decisions. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/bash.ipynb) 7 | 8 | (Adapted from LangChain ) 9 | """ 10 | # - 11 | 12 | # $ 13 | 14 | from minichain import Id, prompt, OpenAI, show, transform, Mock, Break 15 | from gradio_tools.tools import StableDiffusionTool, ImageCaptioningTool, ImageToMusicTool 16 | 17 | 18 | # class ImageCaptioningTool: 19 | # def run(self, inp): 20 | # return "This is a picture of a smiling huggingface logo." 
21 | 22 | # description = "Image Captioning" 23 | 24 | tools = [StableDiffusionTool(), ImageCaptioningTool(), ImageToMusicTool()] 25 | 26 | 27 | @prompt(OpenAI(stop=["Observation:"]), 28 | template_file="agent.pmpt.tpl") 29 | def agent(model, query, history): 30 | return model(dict(tools=[(str(tool.__class__.__name__), tool.description) 31 | for tool in tools], 32 | input=query, 33 | agent_scratchpad=history 34 | )) 35 | @transform() 36 | def tool_parse(out): 37 | lines = out.split("\n") 38 | if lines[0].split("?")[-1].strip() == "Yes": 39 | tool = lines[1].split(":", 1)[-1].strip() 40 | command = lines[2].split(":", 1)[-1].strip() 41 | return tool, command 42 | else: 43 | return Break() 44 | 45 | @prompt(tools) 46 | def tool_use(model, usage): 47 | selector, command = usage 48 | for i, tool in enumerate(tools): 49 | if selector == tool.__class__.__name__: 50 | return model(command, tool_num=i) 51 | return ("",) 52 | 53 | @transform() 54 | def append(history, new, observation): 55 | return history + "\n" + new + "Observation: " + observation 56 | 57 | def run(query): 58 | history = "" 59 | observations = [] 60 | for i in range(3): 61 | select_input = agent(query, history) 62 | observations.append(tool_use(tool_parse(select_input))) 63 | history = append(history, select_input, observations[i]) 64 | 65 | return observations[-1] 66 | 67 | # $ 68 | 69 | gradio = show(run, 70 | subprompts=[agent, tool_use] * 3, 71 | examples=[ 72 | "I would please like a photo of a dog riding a skateboard. " 73 | "Please caption this image and create a song for it.", 74 | 'Use an image generator tool to draw a cat.', 75 | 'Caption the image https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png from the internet'], 76 | out_type="markdown", 77 | description=desc, 78 | show_advanced=False 79 | ) 80 | if __name__ == "__main__": 81 | gradio.queue().launch() 82 | 83 | -------------------------------------------------------------------------------- /examples/backtrack.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "158928f5", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q git+https://github.com/srush/MiniChain\n", 11 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "b7be4150", 18 | "metadata": { 19 | "tags": [ 20 | "hide_inp" 21 | ] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "desc = \"\"\"\n", 26 | "### Backtrack on Failure\n", 27 | "\n", 28 | "Chain that backtracks on failure. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/backtrack.ipynb)\n", 29 | "\n", 30 | "\"\"\"" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "d9a87050", 37 | "metadata": { 38 | "lines_to_next_cell": 1 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "from minichain import prompt, Mock, show, OpenAI\n", 43 | "import minichain" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "c11692a3", 50 | "metadata": { 51 | "lines_to_next_cell": 1 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "@prompt(Mock([\"dog\", \"blue\", \"cat\"]))\n", 56 | "def prompt_generation(model):\n", 57 | " return model(\"\")" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "id": "2035f083", 64 | "metadata": { 65 | "lines_to_next_cell": 1 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "@prompt(OpenAI(), template=\"Answer 'yes' is {{query}} is a color. Answer:\")\n", 70 | "def prompt_validation(model, x):\n", 71 | " out = model(dict(query=x))\n", 72 | " if out.strip().lower().startswith(\"yes\"):\n", 73 | " return x\n", 74 | " return model.fail(1)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "2b9caa2b", 81 | "metadata": { 82 | "lines_to_next_cell": 1 83 | }, 84 | "outputs": [], 85 | "source": [ 86 | "def run():\n", 87 | " x = prompt_generation()\n", 88 | " return prompt_validation(x)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "d825edb7", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "gradio = show(run,\n", 99 | " examples = [],\n", 100 | " subprompts=[prompt_generation, prompt_validation],\n", 101 | " out_type=\"markdown\"\n", 102 | " )" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "1eac94d8", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "if __name__ == \"__main__\":\n", 113 | " gradio.launch()" 114 | ] 115 | } 116 | ], 117 | "metadata": { 118 | "jupytext": { 119 | "cell_metadata_filter": "tags,-all", 120 | "main_language": "python", 121 | "notebook_metadata_filter": "-all" 122 | } 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 5 126 | } 127 | -------------------------------------------------------------------------------- /examples/bash.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "0360e829", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q git+https://github.com/srush/MiniChain\n", 11 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "60519a97", 18 | "metadata": { 19 | "lines_to_next_cell": 2, 20 | "tags": [ 21 | "hide_inp" 22 | ] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "\n", 27 | "desc = \"\"\"\n", 28 | "### Bash Command Suggestion\n", 29 | "\n", 30 | "Chain that ask for a command-line question and then runs the bash command. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/bash.ipynb)\n", 31 | "\n", 32 | "(Adapted from LangChain [BashChain](https://langchain.readthedocs.io/en/latest/modules/chains/examples/llm_bash.html))\n", 33 | "\"\"\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "852ced62", 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "from minichain import show, prompt, OpenAI, Bash" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "b80a8444", 50 | "metadata": { 51 | "lines_to_next_cell": 1 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "@prompt(OpenAI(), template_file = \"bash.pmpt.tpl\")\n", 56 | "def cli_prompt(model, query):\n", 57 | " x = model(dict(question=query))\n", 58 | " return \"\\n\".join(x.strip().split(\"\\n\")[1:-1])" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "7d6cc929", 65 | "metadata": { 66 | "lines_to_next_cell": 1 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "@prompt(Bash())\n", 71 | "def bash_run(model, x):\n", 72 | " return model(x)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "3c841ddf", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def bash(query):\n", 83 | " return bash_run(cli_prompt(query))" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "id": "cd905126", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": null, 97 | "id": "8b3470a3", 98 | "metadata": { 99 | "lines_to_next_cell": 2 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "gradio = show(bash,\n", 104 | " subprompts=[cli_prompt, bash_run],\n", 105 | " examples=['Go up one directory, and then into the minichain directory,'\n", 106 | " 'and list the files in the directory',\n", 107 | " \"Please write a bash script that prints 'Hello World' to the console.\"],\n", 108 | " out_type=\"markdown\",\n", 109 | " description=desc,\n", 110 | " )\n", 111 | "if __name__ == \"__main__\":\n", 112 | " gradio.launch()" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "jupytext": { 118 | "cell_metadata_filter": "tags,-all", 119 | "main_language": "python", 120 | "notebook_metadata_filter": "-all" 121 | } 122 | }, 123 | "nbformat": 4, 124 | "nbformat_minor": 5 125 | } 126 | -------------------------------------------------------------------------------- /examples/bash.pmpt.tpl: -------------------------------------------------------------------------------- 1 | If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format: 2 | 3 | Question: "copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'" 4 | 5 | I need to take the following actions: 6 | - List all files in the directory 7 | - Create a new directory 8 | - Copy the files from the first directory into the second directory 9 | ```bash 10 | ls 11 | mkdir myNewDirectory 12 | cp -r target/* myNewDirectory 13 | ``` 14 | 15 | That is the format. Begin! 
16 | 17 | Question: "{{question}}" -------------------------------------------------------------------------------- /examples/bash.py: -------------------------------------------------------------------------------- 1 | # + tags=["hide_inp"] 2 | 3 | desc = """ 4 | ### Bash Command Suggestion 5 | 6 | Chain that ask for a command-line question and then runs the bash command. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/bash.ipynb) 7 | 8 | (Adapted from LangChain [BashChain](https://langchain.readthedocs.io/en/latest/modules/chains/examples/llm_bash.html)) 9 | """ 10 | # - 11 | 12 | # $ 13 | 14 | from minichain import show, prompt, OpenAI, Bash 15 | 16 | 17 | @prompt(OpenAI(), template_file = "bash.pmpt.tpl") 18 | def cli_prompt(model, query): 19 | return model(dict(question=query)) 20 | 21 | @prompt(Bash()) 22 | def bash_run(model, x): 23 | x = "\n".join(x.strip().split("\n")[1:-1]) 24 | return model(x) 25 | 26 | def bash(query): 27 | return bash_run(cli_prompt(query)) 28 | 29 | 30 | # $ 31 | 32 | gradio = show(bash, 33 | subprompts=[cli_prompt, bash_run], 34 | examples=['Go up one directory, and then into the minichain directory,' 35 | 'and list the files in the directory', 36 | "Please write a bash script that prints 'Hello World' to the console."], 37 | out_type="markdown", 38 | description=desc, 39 | code=open("bash.py", "r").read().split("$")[1].strip().strip("#").strip(), 40 | ) 41 | if __name__ == "__main__": 42 | gradio.queue().launch() 43 | 44 | -------------------------------------------------------------------------------- /examples/chat.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "1ca7b759", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q git+https://github.com/srush/MiniChain\n", 11 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "cc3f2a69", 18 | "metadata": { 19 | "lines_to_next_cell": 0, 20 | "tags": [ 21 | "hide_inp" 22 | ] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "desc = \"\"\"\n", 27 | "### Chat\n", 28 | "\n", 29 | "A chat-like example for multi-turn chat with state. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/chat.ipynb)\n", 30 | "\n", 31 | "(Adapted from [LangChain](https://langchain.readthedocs.io/en/latest/modules/memory/examples/chatgpt_clone.html)'s version of this [blog post](https://www.engraved.blog/building-a-virtual-machine-inside/).)\n", 32 | "\n", 33 | "\"\"\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "1ca0a1eb", 40 | "metadata": { 41 | "lines_to_next_cell": 2 42 | }, 43 | "outputs": [], 44 | "source": [] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "f4ac1417", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "from dataclasses import dataclass, replace\n", 54 | "from typing import List, Tuple\n", 55 | "from minichain import OpenAI, prompt, show" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "id": "56c8ab3d", 61 | "metadata": {}, 62 | "source": [ 63 | "Generic stateful Memory" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "7c3ffeaa", 70 | "metadata": { 71 | "lines_to_next_cell": 1 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "MEMORY = 2" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "c309a94f", 82 | "metadata": { 83 | "lines_to_next_cell": 1 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "@dataclass\n", 88 | "class State:\n", 89 | " memory: List[Tuple[str, str]]\n", 90 | " human_input: str = \"\"\n", 91 | "\n", 92 | " def push(self, response: str) -> \"State\":\n", 93 | " memory = self.memory if len(self.memory) < MEMORY else self.memory[1:]\n", 94 | " return State(memory + [(self.human_input, response)])" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "id": "4c9a82f1", 100 | "metadata": {}, 101 | "source": [ 102 | "Chat prompt with memory" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "279179dd", 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "@prompt(OpenAI(), template_file=\"chat.pmpt.tpl\")\n", 113 | "def chat_prompt(model, state: State) -> State:\n", 114 | " out = model(state)\n", 115 | " result = out.split(\"Assistant:\")[-1]\n", 116 | " return state.push(result)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "2d94b22d", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "examples = [\n", 127 | " \"ls ~\",\n", 128 | " \"cd ~\",\n", 129 | " \"{Please make a file jokes.txt inside and put some jokes inside}\",\n", 130 | " \"\"\"echo -e \"x=lambda y:y*5+3;print('Result:' + str(x(6)))\" > run.py && python3 run.py\"\"\",\n", 131 | " \"\"\"echo -e \"print(list(filter(lambda x: all(x%d for d in range(2,x)),range(2,3**10)))[:10])\" > run.py && python3 run.py\"\"\",\n", 132 | " \"\"\"echo -e \"echo 'Hello from Docker\" > entrypoint.sh && echo -e \"FROM ubuntu:20.04\\nCOPY entrypoint.sh entrypoint.sh\\nENTRYPOINT [\\\"/bin/sh\\\",\\\"entrypoint.sh\\\"]\">Dockerfile && docker build . 
-t my_docker_image && docker run -t my_docker_image\"\"\",\n", 133 | " \"nvidia-smi\"\n", 134 | "]" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "id": "d77406cf", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "gradio = show(lambda command, state: chat_prompt(replace(state, human_input=command)),\n", 145 | " initial_state=State([]),\n", 146 | " subprompts=[chat_prompt],\n", 147 | " examples=examples,\n", 148 | " out_type=\"json\",\n", 149 | " description=desc,\n", 150 | ")\n", 151 | "if __name__ == \"__main__\":\n", 152 | " gradio.launch()" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "bd255c7b", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [] 162 | } 163 | ], 164 | "metadata": { 165 | "jupytext": { 166 | "cell_metadata_filter": "tags,-all", 167 | "main_language": "python", 168 | "notebook_metadata_filter": "-all" 169 | } 170 | }, 171 | "nbformat": 4, 172 | "nbformat_minor": 5 173 | } 174 | -------------------------------------------------------------------------------- /examples/chat.pmpt.tpl: -------------------------------------------------------------------------------- 1 | Assistant is a large language model trained by OpenAI. 2 | 3 | Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. 4 | 5 | Assistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics. 6 | 7 | Overall, Assistant is a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist. 8 | 9 | I want you to act as a Linux terminal. I will type commands and you will reply with what the terminal should show. I want you to only reply with the terminal output inside one unique code block, and nothing else. Do not write explanations. Do not type commands unless I instruct you to do so. When I need to tell you something in English I will do so by putting text inside curly brackets {like this}. 10 | 11 | {% for d in memory %} 12 | Human: {{d[0]}} 13 | AI: {{d[1]}} 14 | {% endfor %} 15 | 16 | Human: {{human_input}} 17 | Assistant: -------------------------------------------------------------------------------- /examples/chat.py: -------------------------------------------------------------------------------- 1 | # + tags=["hide_inp"] 2 | desc = """ 3 | ### Chat 4 | 5 | A chat-like example for multi-turn chat with state. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/chat.ipynb) 6 | 7 | (Adapted from [LangChain](https://langchain.readthedocs.io/en/latest/modules/memory/examples/chatgpt_clone.html)'s version of this [blog post](https://www.engraved.blog/building-a-virtual-machine-inside/).) 8 | 9 | """ 10 | # - 11 | 12 | 13 | # $ 14 | 15 | from dataclasses import dataclass, replace 16 | from typing import List, Tuple 17 | from minichain import OpenAI, prompt, show, transform, Mock 18 | 19 | # Generic stateful Memory 20 | 21 | MEMORY = 2 22 | 23 | @dataclass 24 | class State: 25 | memory: List[Tuple[str, str]] 26 | human_input: str = "" 27 | 28 | def push(self, response: str) -> "State": 29 | memory = self.memory if len(self.memory) < MEMORY else self.memory[1:] 30 | return State(memory + [(self.human_input, response)]) 31 | 32 | def __str__(self): 33 | return self.memory[-1][-1] 34 | 35 | # Chat prompt with memory 36 | 37 | @prompt(OpenAI(), template_file="chat.pmpt.tpl") 38 | def chat_response(model, state: State) -> State: 39 | return model.stream(state) 40 | 41 | @transform() 42 | def update(state, chat_output): 43 | result = chat_output.split("Assistant:")[-1] 44 | return state.push(result) 45 | 46 | 47 | def chat(command, state): 48 | state = replace(state, human_input=command) 49 | return update(state, chat_response(state)) 50 | 51 | # $ 52 | 53 | examples = [ 54 | "ls ~", 55 | "cd ~", 56 | "{Please make a file jokes.txt inside and put some jokes inside}", 57 | """echo -e "x=lambda y:y*5+3;print('Result:' + str(x(6)))" > run.py && python3 run.py""", 58 | """echo -e "print(list(filter(lambda x: all(x%d for d in range(2,x)),range(2,3**10)))[:10])" > run.py && python3 run.py""", 59 | """echo -e "echo 'Hello from Docker" > entrypoint.sh && echo -e "FROM ubuntu:20.04\nCOPY entrypoint.sh entrypoint.sh\nENTRYPOINT [\"/bin/sh\",\"entrypoint.sh\"]">Dockerfile && docker build . -t my_docker_image && docker run -t my_docker_image""", 60 | "nvidia-smi" 61 | ] 62 | 63 | print(chat("ls", State([])).run()) 64 | 65 | gradio = show(chat, 66 | initial_state=State([]), 67 | subprompts=[chat_response], 68 | examples=examples, 69 | out_type="json", 70 | description=desc, 71 | code=open("chat.py", "r").read().split("$")[1].strip().strip("#").strip(), 72 | ) 73 | if __name__ == "__main__": 74 | gradio.queue().launch() 75 | 76 | 77 | -------------------------------------------------------------------------------- /examples/chatgpt.pmpt.tpl: -------------------------------------------------------------------------------- 1 | Assistant is a large language model trained by OpenAI. 2 | 3 | Assistant is designed to be able to assist with a wide range of tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. As a language model, Assistant is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. 4 | 5 | Assistant is constantly learning and improving, and its capabilities are constantly evolving. It is able to process and understand large amounts of text, and can use this knowledge to provide accurate and informative responses to a wide range of questions. 
Additionally, Assistant is able to generate its own text based on the input it receives, allowing it to engage in discussions and provide explanations and descriptions on a wide range of topics. 6 | 7 | Overall, Assistant is a powerful tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics. Whether you need help with a specific question or just want to have a conversation about a particular topic, Assistant is here to assist. 8 | 9 | {% for d in memory %} 10 | Human: {{d[0]}} 11 | AI: {{d[1]}} 12 | {% endfor %} 13 | 14 | Human: {{human_input}} 15 | Assistant: -------------------------------------------------------------------------------- /examples/gatsby.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "b57f44dc", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q git+https://github.com/srush/MiniChain\n", 11 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "f675b525", 18 | "metadata": { 19 | "lines_to_next_cell": 2, 20 | "tags": [ 21 | "hide_inp" 22 | ] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "desc = \"\"\"\n", 27 | "### Book QA\n", 28 | "\n", 29 | "Chain that does question answering with Hugging Face embeddings. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/gatsby.ipynb)\n", 30 | "\n", 31 | "(Adapted from the [LlamaIndex example](https://github.com/jerryjliu/gpt_index/blob/main/examples/gatsby/TestGatsby.ipynb).)\n", 32 | "\"\"\"" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "eab06e25", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import datasets\n", 43 | "import numpy as np\n", 44 | "from minichain import prompt, show, HuggingFaceEmbed, OpenAI" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "ad893c7e", 50 | "metadata": {}, 51 | "source": [ 52 | "Load data with embeddings (computed beforehand)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "c64cbf9f", 59 | "metadata": { 60 | "lines_to_next_cell": 1 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "gatsby = datasets.load_from_disk(\"gatsby\")\n", 65 | "gatsby.add_faiss_index(\"embeddings\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "20b8d756", 71 | "metadata": {}, 72 | "source": [ 73 | "Fast KNN retieval prompt" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "a786c893", 80 | "metadata": { 81 | "lines_to_next_cell": 1 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "@prompt(HuggingFaceEmbed(\"sentence-transformers/all-mpnet-base-v2\"))\n", 86 | "def get_neighbors(model, inp, k=1):\n", 87 | " embedding = model(inp)\n", 88 | " res = olympics.get_nearest_examples(\"embeddings\", np.array(embedding), k)\n", 89 | " return res.examples[\"passages\"]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "fac56e96", 96 | "metadata": { 97 | "lines_to_next_cell": 1 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "@prompt(OpenAI(),\n", 102 | " template_file=\"gatsby.pmpt.tpl\")\n", 103 | "def ask(model, query, neighbors):\n", 104 | " return model(dict(question=query, 
docs=neighbors))" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "4e46761b", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "def gatsby(query):\n", 115 | " n = get_neighbors(query)\n", 116 | " return ask(query, n)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "e9bafe2a", 123 | "metadata": { 124 | "lines_to_next_cell": 2 125 | }, 126 | "outputs": [], 127 | "source": [] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "id": "913793a8", 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "gradio = show(gatsby,\n", 137 | " subprompts=[get_neighbors, ask],\n", 138 | " examples=[\"What did Gatsby do before he met Daisy?\",\n", 139 | " \"What did the narrator do after getting back to Chicago?\"],\n", 140 | " keys={\"HF_KEY\"},\n", 141 | " description=desc,\n", 142 | " )\n", 143 | "if __name__ == \"__main__\":\n", 144 | " gradio.launch()" 145 | ] 146 | } 147 | ], 148 | "metadata": { 149 | "jupytext": { 150 | "cell_metadata_filter": "tags,-all", 151 | "main_language": "python", 152 | "notebook_metadata_filter": "-all" 153 | } 154 | }, 155 | "nbformat": 4, 156 | "nbformat_minor": 5 157 | } 158 | -------------------------------------------------------------------------------- /examples/gatsby.pmpt.tpl: -------------------------------------------------------------------------------- 1 | Context information is below. 2 | 3 | 4 | --------------------- 5 | 6 | 7 | {% for doc in docs %} 8 | * {{doc}} 9 | {% endfor %} 10 | 11 | --------------------- 12 | 13 | Given the context information and not prior knowledge, answer the question: {{question}} -------------------------------------------------------------------------------- /examples/gatsby.py: -------------------------------------------------------------------------------- 1 | # + tags=["hide_inp"] 2 | desc = """ 3 | ### Book QA 4 | 5 | Chain that does question answering with Hugging Face embeddings. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/gatsby.ipynb) 6 | 7 | (Adapted from the [LlamaIndex example](https://github.com/jerryjliu/gpt_index/blob/main/examples/gatsby/TestGatsby.ipynb).) 
8 | """ 9 | # - 10 | 11 | # $ 12 | 13 | import datasets 14 | import numpy as np 15 | from minichain import prompt, show, HuggingFaceEmbed, OpenAI, transform 16 | 17 | # Load data with embeddings (computed beforehand) 18 | 19 | gatsby = datasets.load_from_disk("gatsby") 20 | gatsby.add_faiss_index("embeddings") 21 | 22 | # Fast KNN retrieval prompt 23 | 24 | @prompt(HuggingFaceEmbed("sentence-transformers/all-mpnet-base-v2")) 25 | def embed(model, inp): 26 | return model(inp) 27 | 28 | @transform() 29 | def get_neighbors(embedding, k=1): 30 | res = gatsby.get_nearest_examples("embeddings", np.array(embedding), k) 31 | return res.examples["passages"] 32 | 33 | @prompt(OpenAI(), template_file="gatsby.pmpt.tpl") 34 | def ask(model, query, neighbors): 35 | return model(dict(question=query, docs=neighbors)) 36 | 37 | def gatsby_q(query): 38 | n = get_neighbors(embed(query)) 39 | return ask(query, n) 40 | 41 | 42 | # $ 43 | 44 | 45 | gradio = show(gatsby_q, 46 | subprompts=[ask], 47 | examples=["What did Gatsby do before he met Daisy?", 48 | "What did the narrator do after getting back to Chicago?"], 49 | keys={"HF_KEY"}, 50 | description=desc, 51 | code=open("gatsby.py", "r").read().split("$")[1].strip().strip("#").strip() 52 | ) 53 | if __name__ == "__main__": 54 | gradio.queue().launch() 55 | -------------------------------------------------------------------------------- /examples/gatsby/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srush/MiniChain/637d310ccd77dd7cb3197c826d0a304cafce65b2/examples/gatsby/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /examples/gatsby/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "citation": "", 3 | "description": "", 4 | "features": { 5 | "passages": { 6 | "dtype": "string", 7 | "_type": "Value" 8 | }, 9 | "embeddings": { 10 | "feature": { 11 | "dtype": "float64", 12 | "_type": "Value" 13 | }, 14 | "_type": "Sequence" 15 | } 16 | }, 17 | "homepage": "", 18 | "license": "" 19 | } -------------------------------------------------------------------------------- /examples/gatsby/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "58e539e18c1f1ec8", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": null 13 | } -------------------------------------------------------------------------------- /examples/gradio_example.py: -------------------------------------------------------------------------------- 1 | # + tags=["hide_inp"] 2 | 3 | desc = """ 4 | ### Gradio Tool 5 | 6 | Examples using the gradio tool [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/gradio_example.ipynb) 7 | 8 | """ 9 | # - 10 | 11 | # $ 12 | 13 | from minichain import show, prompt, OpenAI, GradioConf 14 | import gradio as gr 15 | from gradio_tools.tools import StableDiffusionTool, ImageCaptioningTool 16 | 17 | @prompt(OpenAI()) 18 | def picture(model, query): 19 | return model(query) 20 | 21 | @prompt(StableDiffusionTool(), 22 | gradio_conf=GradioConf( 23 | block_output= lambda: gr.Image(), 24 | block_input= lambda: gr.Textbox(show_label=False))) 25 
| def gen(model, query): 26 | return model(query) 27 | 28 | @prompt(ImageCaptioningTool(), 29 | gradio_conf=GradioConf( 30 | block_input= lambda: gr.Image(), 31 | block_output=lambda: gr.Textbox(show_label=False))) 32 | def caption(model, img_src): 33 | return model(img_src) 34 | 35 | def gradio_example(query): 36 | return caption(gen(picture(query))) 37 | 38 | 39 | # $ 40 | 41 | gradio = show(gradio_example, 42 | subprompts=[picture, gen, caption], 43 | examples=['Describe a one-sentence fantasy scene.', 44 | 'Describe a one-sentence scene happening on the moon.'], 45 | out_type="markdown", 46 | description=desc, 47 | show_advanced=False 48 | ) 49 | if __name__ == "__main__": 50 | gradio.queue().launch() 51 | 52 | -------------------------------------------------------------------------------- /examples/math.pmpt.tpl: -------------------------------------------------------------------------------- 1 | #### Question: 2 | 3 | * What is 37593 * 67? 4 | 5 | #### Code: 6 | 7 | ```python 8 | print(37593 * 67) 9 | ``` 10 | 11 | #### Question: 12 | 13 | * Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market? 14 | 15 | #### Code: 16 | 17 | ```python 18 | print((16-3-4)*2) 19 | ``` 20 | 21 | #### Question: 22 | 23 | * How many of the integers between 0 and 99 inclusive are divisible by 8? 24 | 25 | #### Code: 26 | 27 | ```python 28 | count = 0 29 | for i in range(0, 99+1): 30 | if i % 8 == 0: count += 1 31 | print(count) 32 | ``` 33 | 34 | #### Question: 35 | 36 | * A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take? 37 | 38 | #### Code: 39 | 40 | ```python 41 | print(2 + 2/2) 42 | ``` 43 | 44 | #### Question: 45 | 46 | * {{question}} 47 | 48 | #### Code: -------------------------------------------------------------------------------- /examples/math_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "a914a238", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q git+https://github.com/srush/MiniChain\n", 11 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "aa7b8a98", 18 | "metadata": { 19 | "lines_to_next_cell": 2, 20 | "tags": [ 21 | "hide_inp" 22 | ] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "desc = \"\"\"\n", 27 | "### Word Problem Solver\n", 28 | "\n", 29 | "Chain that solves a math word problem by first generating and then running Python code. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/math_demo.ipynb)\n", 30 | "\n", 31 | "(Adapted from Dust [maths-generate-code](https://dust.tt/spolu/a/d12ac33169))\n", 32 | "\"\"\"" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "1a34507a", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "from minichain import show, prompt, OpenAI, Python" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "id": "43cdbfc0", 49 | "metadata": { 50 | "lines_to_next_cell": 1 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "@prompt(OpenAI(), template_file=\"math.pmpt.tpl\")\n", 55 | "def math_prompt(model, question):\n", 56 | " \"Prompt to call GPT with a Jinja template\"\n", 57 | " return model(dict(question=question))" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "id": "14d6eb95", 64 | "metadata": { 65 | "lines_to_next_cell": 1 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "@prompt(Python(), template=\"import math\\n{{code}}\")\n", 70 | "def python(model, code):\n", 71 | " \"Prompt to call Python interpreter\"\n", 72 | " code = \"\\n\".join(code.strip().split(\"\\n\")[1:-1])\n", 73 | " return int(model(dict(code=code)))" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "dac2ba43", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "def math_demo(question):\n", 84 | " \"Chain them together\"\n", 85 | " return python(math_prompt(question))" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "4a4f5bbf", 92 | "metadata": { 93 | "lines_to_next_cell": 0, 94 | "tags": [ 95 | "hide_inp" 96 | ] 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "gradio = show(math_demo,\n", 101 | " examples=[\"What is the sum of the powers of 3 (3^i) that are smaller than 100?\",\n", 102 | " \"What is the sum of the 10 first positive integers?\",],\n", 103 | " # \"Carla is downloading a 200 GB file. She can download 2 GB/minute, but 40% of the way through the download, the download fails. Then Carla has to restart the download from the beginning. How load did it take her to download the file in minutes?\"],\n", 104 | " subprompts=[math_prompt, python],\n", 105 | " out_type=\"json\",\n", 106 | " description=desc,\n", 107 | " )\n", 108 | "if __name__ == \"__main__\":\n", 109 | " gradio.launch()" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "a478587c", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [] 119 | } 120 | ], 121 | "metadata": { 122 | "jupytext": { 123 | "cell_metadata_filter": "tags,-all", 124 | "main_language": "python", 125 | "notebook_metadata_filter": "-all" 126 | } 127 | }, 128 | "nbformat": 4, 129 | "nbformat_minor": 5 130 | } 131 | -------------------------------------------------------------------------------- /examples/math_demo.py: -------------------------------------------------------------------------------- 1 | # + tags=["hide_inp"] 2 | desc = """ 3 | ### Word Problem Solver 4 | 5 | Chain that solves a math word problem by first generating and then running Python code. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/math_demo.ipynb) 6 | 7 | (Adapted from Dust [maths-generate-code](https://dust.tt/spolu/a/d12ac33169)) 8 | """ 9 | # - 10 | 11 | # $ 12 | 13 | from minichain import show, prompt, OpenAI, Python, GradioConf 14 | import gradio as gr 15 | 16 | 17 | @prompt(OpenAI(), template_file="math.pmpt.tpl", 18 | gradio_conf=GradioConf(block_input=gr.Markdown)) 19 | def math_prompt(model, question 20 | ): 21 | "Prompt to call GPT with a Jinja template" 22 | return model(dict(question=question)) 23 | 24 | @prompt(Python(), template="import math\n{{code}}") 25 | def python(model, code): 26 | "Prompt to call Python interpreter" 27 | code = "\n".join(code.strip().split("\n")[1:-1]) 28 | return model(dict(code=code)) 29 | 30 | def math_demo(question): 31 | "Chain them together" 32 | return python(math_prompt(question)) 33 | 34 | # $ 35 | 36 | # + tags=["hide_inp"] 37 | gradio = show(math_demo, 38 | examples=["What is the sum of the powers of 3 (3^i) that are smaller than 100?", 39 | "What is the sum of the 10 first positive integers?",], 40 | # "Carla is downloading a 200 GB file. She can download 2 GB/minute, but 40% of the way through the download, the download fails. Then Carla has to restart the download from the beginning. How load did it take her to download the file in minutes?"], 41 | subprompts=[math_prompt, python], 42 | description=desc, 43 | code=open("math_demo.py", "r").read().split("$")[1].strip().strip("#").strip(), 44 | ) 45 | if __name__ == "__main__": 46 | gradio.queue().launch() 47 | # - 48 | 49 | -------------------------------------------------------------------------------- /examples/ner.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "8eccacc7", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q git+https://github.com/srush/MiniChain\n", 11 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "ebfe63f6", 18 | "metadata": { 19 | "lines_to_next_cell": 2, 20 | "tags": [ 21 | "hide_inp" 22 | ] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "\n", 27 | "desc = \"\"\"\n", 28 | "### Named Entity Recognition\n", 29 | "\n", 30 | "Chain that does named entity recognition with arbitrary labels. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/ner.ipynb)\n", 31 | "\n", 32 | "(Adapted from [promptify](https://github.com/promptslab/Promptify/blob/main/promptify/prompts/nlp/templates/ner.jinja)).\n", 33 | "\"\"\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "id": "45dd8a11", 40 | "metadata": { 41 | "lines_to_next_cell": 1 42 | }, 43 | "outputs": [], 44 | "source": [ 45 | "from minichain import prompt, show, OpenAI" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "9ada6ebb", 52 | "metadata": { 53 | "lines_to_next_cell": 1 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "@prompt(OpenAI(), template_file = \"ner.pmpt.tpl\", parser=\"json\")\n", 58 | "def ner_extract(model, kwargs):\n", 59 | " return model(kwargs)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "f6873c42", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "@prompt(OpenAI())\n", 70 | "def team_describe(model, inp):\n", 71 | " query = \"Can you describe these basketball teams? \" + \\\n", 72 | " \" \".join([i[\"E\"] for i in inp if i[\"T\"] ==\"Team\"])\n", 73 | " return model(query)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "a89fa41d", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "def ner(text_input, labels, domain):\n", 84 | " extract = ner_extract(dict(text_input=text_input, labels=labels, domain=domain))\n", 85 | " return team_describe(extract)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "3e8a0502", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "634fb50b", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "gradio = show(ner,\n", 104 | " examples=[[\"An NBA playoff pairing a year ago, the 76ers (39-20) meet the Miami Heat (32-29) for the first time this season on Monday night at home.\", \"Team, Date\", \"Sports\"]],\n", 105 | " description=desc,\n", 106 | " subprompts=[ner_extract, team_describe],\n", 107 | " )" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "fa353224", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "if __name__ == \"__main__\":\n", 118 | " gradio.launch()" 119 | ] 120 | } 121 | ], 122 | "metadata": { 123 | "jupytext": { 124 | "cell_metadata_filter": "tags,-all", 125 | "main_language": "python", 126 | "notebook_metadata_filter": "-all" 127 | } 128 | }, 129 | "nbformat": 4, 130 | "nbformat_minor": 5 131 | } 132 | -------------------------------------------------------------------------------- /examples/ner.pmpt.tpl: -------------------------------------------------------------------------------- 1 | You are a highly intelligent and accurate {{ domain }} domain Named-entity recognition(NER) system. You take Passage as input and your task is to recognize and extract specific types of {{ domain }} domain named entities in that given passage and classify into a set of following predefined entity types: 2 | 3 | {{labels}} 4 | 5 | Your output format is only {{ output_format|default('[{"T": type of entity from predefined entity types, "E": entity in the input text}]') }} form, no other form. 
6 | 7 | Input: {{ text_input }} 8 | Output: -------------------------------------------------------------------------------- /examples/ner.py: -------------------------------------------------------------------------------- 1 | # + tags=["hide_inp"] 2 | 3 | desc = """ 4 | ### Named Entity Recognition 5 | 6 | Chain that does named entity recognition with arbitrary labels. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/ner.ipynb) 7 | 8 | (Adapted from [promptify](https://github.com/promptslab/Promptify/blob/main/promptify/prompts/nlp/templates/ner.jinja)). 9 | """ 10 | # - 11 | 12 | # $ 13 | 14 | from minichain import prompt, transform, show, OpenAI 15 | import json 16 | 17 | @prompt(OpenAI(), template_file = "ner.pmpt.tpl") 18 | def ner_extract(model, kwargs): 19 | return model(kwargs) 20 | 21 | @transform() 22 | def to_json(chat_output): 23 | return json.loads(chat_output) 24 | 25 | @prompt(OpenAI()) 26 | def team_describe(model, inp): 27 | query = "Can you describe these basketball teams? " + \ 28 | " ".join([i["E"] for i in inp if i["T"] =="Team"]) 29 | return model(query) 30 | 31 | 32 | def ner(text_input, labels, domain): 33 | extract = to_json(ner_extract(dict(text_input=text_input, labels=labels, domain=domain))) 34 | return team_describe(extract) 35 | 36 | 37 | # $ 38 | 39 | gradio = show(ner, 40 | examples=[["An NBA playoff pairing a year ago, the 76ers (39-20) meet the Miami Heat (32-29) for the first time this season on Monday night at home.", "Team, Date", "Sports"]], 41 | description=desc, 42 | subprompts=[ner_extract, team_describe], 43 | code=open("ner.py", "r").read().split("$")[1].strip().strip("#").strip(), 44 | ) 45 | 46 | if __name__ == "__main__": 47 | gradio.queue().launch() 48 | -------------------------------------------------------------------------------- /examples/olympics.data/data-00000-of-00001.arrow: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srush/MiniChain/637d310ccd77dd7cb3197c826d0a304cafce65b2/examples/olympics.data/data-00000-of-00001.arrow -------------------------------------------------------------------------------- /examples/olympics.data/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "builder_name": "csv", 3 | "citation": "", 4 | "config_name": "default", 5 | "dataset_size": 2548363, 6 | "description": "", 7 | "download_checksums": { 8 | "https://cdn.openai.com/API/examples/data/olympics_sections_text.csv": { 9 | "num_bytes": 2503410, 10 | "checksum": null 11 | } 12 | }, 13 | "download_size": 2503410, 14 | "features": { 15 | "title": { 16 | "dtype": "string", 17 | "_type": "Value" 18 | }, 19 | "heading": { 20 | "dtype": "string", 21 | "_type": "Value" 22 | }, 23 | "content": { 24 | "dtype": "string", 25 | "_type": "Value" 26 | }, 27 | "tokens": { 28 | "dtype": "int64", 29 | "_type": "Value" 30 | }, 31 | "embeddings": { 32 | "feature": { 33 | "dtype": "float64", 34 | "_type": "Value" 35 | }, 36 | "_type": "Sequence" 37 | } 38 | }, 39 | "homepage": "", 40 | "license": "", 41 | "size_in_bytes": 5051773, 42 | "splits": { 43 | "train": { 44 | "name": "train", 45 | "num_bytes": 2548363, 46 | "num_examples": 3964, 47 | "dataset_name": "csv" 48 | } 49 | }, 50 | "version": { 51 | "version_str": "0.0.0", 52 | "major": 0, 53 | "minor": 0, 54 | "patch": 0 55 | } 56 | } 
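(Note: this `olympics.data` directory is a Hugging Face `datasets` save of the Olympics sections described by `dataset_info.json` above (title, heading, content, tokens, embeddings); `examples/qa.py` further below loads it and attaches a FAISS index. A minimal usage sketch follows; the placeholder query vector is an assumption made here so the snippet runs standalone, whereas qa.py embeds the user's question with `OpenAIEmbed`.)

```python
import datasets
import numpy as np

# Load the saved dataset and build a FAISS index over its "embeddings" column,
# exactly as examples/qa.py does.
olympics = datasets.load_from_disk("olympics.data")
olympics.add_faiss_index("embeddings")

# Placeholder query vector (assumption): reuse a stored embedding so shapes match.
# In qa.py the query embedding comes from the OpenAIEmbed backend instead.
query_vec = np.array(olympics[0]["embeddings"])

# k-nearest-neighbour retrieval of supporting passages.
res = olympics.get_nearest_examples("embeddings", query_vec, k=3)
print(res.examples["content"])
```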
-------------------------------------------------------------------------------- /examples/olympics.data/state.json: -------------------------------------------------------------------------------- 1 | { 2 | "_data_files": [ 3 | { 4 | "filename": "data-00000-of-00001.arrow" 5 | } 6 | ], 7 | "_fingerprint": "78ad0f5ec2d98f88", 8 | "_format_columns": null, 9 | "_format_kwargs": {}, 10 | "_format_type": null, 11 | "_output_all_columns": false, 12 | "_split": "train" 13 | } -------------------------------------------------------------------------------- /examples/pal.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "c0075889", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q git+https://github.com/srush/MiniChain\n", 11 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "033f7bd9", 18 | "metadata": { 19 | "lines_to_next_cell": 2 20 | }, 21 | "outputs": [], 22 | "source": [ 23 | "desc = \"\"\"\n", 24 | "### Prompt-aided Language Models\n", 25 | "\n", 26 | "Chain for answering complex problems by code generation and execution. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/pal.ipynb)\n", 27 | "\n", 28 | "(Adapted from Prompt-aided Language Models [PAL](https://arxiv.org/pdf/2211.10435.pdf)).\n", 29 | "\"\"\"" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "ada25bcd", 36 | "metadata": { 37 | "lines_to_next_cell": 1 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "from minichain import prompt, show, OpenAI, Python" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "3249f5ac", 48 | "metadata": { 49 | "lines_to_next_cell": 1 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "@prompt(OpenAI(), template_file=\"pal.pmpt.tpl\")\n", 54 | "def pal_prompt(model, question):\n", 55 | " return model(dict(question=question))" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "37f265c9", 62 | "metadata": { 63 | "lines_to_next_cell": 1 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "@prompt(Python())\n", 68 | "def python(model, inp):\n", 69 | " return float(model(inp + \"\\nprint(solution())\"))" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "e767c1eb", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "def pal(question):\n", 80 | " return python(pal_prompt(question))" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "f2a4f241", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "question = \"Melanie is a door-to-door saleswoman. She sold a third of her \" \\\n", 91 | " \"vacuum cleaners at the green house, 2 more to the red house, and half of \" \\\n", 92 | " \"what was left at the orange house. 
If Melanie has 5 vacuum cleaners left, \" \\\n", 93 | " \"how many did she start with?\"" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "id": "c22e7837", 100 | "metadata": {}, 101 | "outputs": [], 102 | "source": [ 103 | "gradio = show(pal,\n", 104 | " examples=[question],\n", 105 | " subprompts=[pal_prompt, python],\n", 106 | " description=desc,\n", 107 | " out_type=\"json\",\n", 108 | " )" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "8790a65e", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "if __name__ == \"__main__\":\n", 119 | " gradio.launch()" 120 | ] 121 | } 122 | ], 123 | "metadata": { 124 | "jupytext": { 125 | "cell_metadata_filter": "-all", 126 | "main_language": "python", 127 | "notebook_metadata_filter": "-all" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 5 132 | } 133 | -------------------------------------------------------------------------------- /examples/pal.pmpt.tpl: -------------------------------------------------------------------------------- 1 | Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? 2 | 3 | # solution in Python: 4 | 5 | 6 | def solution(): 7 | """Olivia has $23. She bought five bagels for $3 each. How much money does she have left?""" 8 | money_initial = 23 9 | bagels = 5 10 | bagel_cost = 3 11 | money_spent = bagels * bagel_cost 12 | money_left = money_initial - money_spent 13 | result = money_left 14 | return result 15 | 16 | 17 | 18 | 19 | 20 | Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday? 21 | 22 | # solution in Python: 23 | 24 | 25 | def solution(): 26 | """Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?""" 27 | golf_balls_initial = 58 28 | golf_balls_lost_tuesday = 23 29 | golf_balls_lost_wednesday = 2 30 | golf_balls_left = golf_balls_initial - golf_balls_lost_tuesday - golf_balls_lost_wednesday 31 | result = golf_balls_left 32 | return result 33 | 34 | 35 | 36 | 37 | 38 | Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room? 39 | 40 | # solution in Python: 41 | 42 | 43 | def solution(): 44 | """There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?""" 45 | computers_initial = 9 46 | computers_per_day = 5 47 | num_days = 4 # 4 days between monday and thursday 48 | computers_added = computers_per_day * num_days 49 | computers_total = computers_initial + computers_added 50 | result = computers_total 51 | return result 52 | 53 | 54 | 55 | 56 | 57 | Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? 58 | 59 | # solution in Python: 60 | 61 | 62 | def solution(): 63 | """Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?""" 64 | toys_initial = 5 65 | mom_toys = 2 66 | dad_toys = 2 67 | total_received = mom_toys + dad_toys 68 | total_toys = toys_initial + total_received 69 | result = total_toys 70 | return result 71 | 72 | 73 | 74 | 75 | 76 | Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. 
How many lollipops did Jason give to Denny? 77 | 78 | # solution in Python: 79 | 80 | 81 | def solution(): 82 | """Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?""" 83 | jason_lollipops_initial = 20 84 | jason_lollipops_after = 12 85 | denny_lollipops = jason_lollipops_initial - jason_lollipops_after 86 | result = denny_lollipops 87 | return result 88 | 89 | 90 | 91 | 92 | 93 | Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? 94 | 95 | # solution in Python: 96 | 97 | 98 | def solution(): 99 | """Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?""" 100 | leah_chocolates = 32 101 | sister_chocolates = 42 102 | total_chocolates = leah_chocolates + sister_chocolates 103 | chocolates_eaten = 35 104 | chocolates_left = total_chocolates - chocolates_eaten 105 | result = chocolates_left 106 | return result 107 | 108 | 109 | 110 | 111 | 112 | Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? 113 | 114 | # solution in Python: 115 | 116 | 117 | def solution(): 118 | """If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?""" 119 | cars_initial = 3 120 | cars_arrived = 2 121 | total_cars = cars_initial + cars_arrived 122 | result = total_cars 123 | return result 124 | 125 | 126 | 127 | 128 | 129 | Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? 130 | 131 | # solution in Python: 132 | 133 | 134 | def solution(): 135 | """There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?""" 136 | trees_initial = 15 137 | trees_after = 21 138 | trees_added = trees_after - trees_initial 139 | result = trees_added 140 | return result 141 | 142 | 143 | 144 | 145 | 146 | Q: {{question}} 147 | 148 | # solution in Python: -------------------------------------------------------------------------------- /examples/pal.py: -------------------------------------------------------------------------------- 1 | desc = """ 2 | ### Prompt-aided Language Models 3 | 4 | Chain for answering complex problems by code generation and execution. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/pal.ipynb) 5 | 6 | (Adapted from Prompt-aided Language Models [PAL](https://arxiv.org/pdf/2211.10435.pdf)). 7 | """ 8 | 9 | # $ 10 | 11 | from minichain import prompt, show, GradioConf, OpenAI, Python 12 | import gradio as gr 13 | 14 | @prompt(OpenAI(), template_file="pal.pmpt.tpl") 15 | def pal_prompt(model, question): 16 | return model(dict(question=question)) 17 | 18 | @prompt(Python(), 19 | gradio_conf=GradioConf(block_input = lambda: gr.Code(language="python"))) 20 | def python(model, inp): 21 | return model(inp + "\nprint(solution())") 22 | 23 | def pal(question): 24 | return python(pal_prompt(question)) 25 | 26 | # $ 27 | 28 | question = "Melanie is a door-to-door saleswoman. She sold a third of her " \ 29 | "vacuum cleaners at the green house, 2 more to the red house, and half of " \ 30 | "what was left at the orange house. 
If Melanie has 5 vacuum cleaners left, " \ 31 | "how many did she start with?" 32 | 33 | gradio = show(pal, 34 | examples=[question], 35 | subprompts=[pal_prompt, python], 36 | description=desc, 37 | out_type="json", 38 | code=open("pal.py", "r").read().split("$")[1].strip().strip("#").strip(), 39 | ) 40 | 41 | if __name__ == "__main__": 42 | gradio.queue().launch() 43 | -------------------------------------------------------------------------------- /examples/parallel.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srush/MiniChain/637d310ccd77dd7cb3197c826d0a304cafce65b2/examples/parallel.py -------------------------------------------------------------------------------- /examples/qa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "1f7e3a8e", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q git+https://github.com/srush/MiniChain\n", 11 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "49443595", 18 | "metadata": { 19 | "lines_to_next_cell": 2, 20 | "tags": [ 21 | "hide_inp" 22 | ] 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "desc = \"\"\"\n", 27 | "### Question Answering with Retrieval\n", 28 | "\n", 29 | "Chain that answers questions with embeedding based retrieval. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/qa.ipynb)\n", 30 | "\n", 31 | "(Adapted from [OpenAI Notebook](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).)\n", 32 | "\"\"\"" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "f5183ea7", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import datasets\n", 43 | "import numpy as np\n", 44 | "from minichain import prompt, show, OpenAIEmbed, OpenAI\n", 45 | "from manifest import Manifest" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "2bf59f0d", 51 | "metadata": {}, 52 | "source": [ 53 | "We use Hugging Face Datasets as the database by assigning\n", 54 | "a FAISS index." 
55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "f371a85e", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "olympics = datasets.load_from_disk(\"olympics.data\")\n", 65 | "olympics.add_faiss_index(\"embeddings\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "a1099002", 71 | "metadata": {}, 72 | "source": [ 73 | "Fast KNN retieval prompt" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "6881ae0e", 80 | "metadata": { 81 | "lines_to_next_cell": 1 82 | }, 83 | "outputs": [], 84 | "source": [ 85 | "@prompt(OpenAIEmbed())\n", 86 | "def get_neighbors(model, inp, k):\n", 87 | " embedding = model(inp)\n", 88 | " res = olympics.get_nearest_examples(\"embeddings\", np.array(embedding), k)\n", 89 | " return res.examples[\"content\"]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "59cc1355", 96 | "metadata": { 97 | "lines_to_next_cell": 1 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "@prompt(OpenAI(),\n", 102 | " template_file=\"qa.pmpt.tpl\")\n", 103 | "def get_result(model, query, neighbors):\n", 104 | " return model(dict(question=query, docs=neighbors))" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "id": "cb2f1101", 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "def qa(query):\n", 115 | " n = get_neighbors(query, 3)\n", 116 | " return get_result(query, n)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "5f70bac7", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "abdfcd87", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "questions = [\"Who won the 2020 Summer Olympics men's high jump?\",\n", 135 | " \"Why was the 2020 Summer Olympics originally postponed?\",\n", 136 | " \"In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?\",\n", 137 | " \"What is the total number of medals won by France?\",\n", 138 | " \"What is the tallest mountain in the world?\"]" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "id": "ddce3ec3", 145 | "metadata": { 146 | "lines_to_next_cell": 2 147 | }, 148 | "outputs": [], 149 | "source": [ 150 | "gradio = show(qa,\n", 151 | " examples=questions,\n", 152 | " subprompts=[get_neighbors, get_result],\n", 153 | " description=desc,\n", 154 | " )\n", 155 | "if __name__ == \"__main__\":\n", 156 | " gradio.launch()" 157 | ] 158 | } 159 | ], 160 | "metadata": { 161 | "jupytext": { 162 | "cell_metadata_filter": "tags,-all", 163 | "main_language": "python", 164 | "notebook_metadata_filter": "-all" 165 | } 166 | }, 167 | "nbformat": 4, 168 | "nbformat_minor": 5 169 | } 170 | -------------------------------------------------------------------------------- /examples/qa.pmpt.tpl: -------------------------------------------------------------------------------- 1 | Answer the question as truthfully as possible using the provided context, and if the answer is not contained within the text below, say "I don't know." 
2 | 3 | Context: 4 | 5 | {% for doc in docs %} 6 | * {{doc}} 7 | {% endfor %} 8 | 9 | Q: {{question}} 10 | 11 | A: 12 | 13 | -------------------------------------------------------------------------------- /examples/qa.py: -------------------------------------------------------------------------------- 1 | # + tags=["hide_inp"] 2 | desc = """ 3 | ### Question Answering with Retrieval 4 | 5 | Chain that answers questions with embeedding based retrieval. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/qa.ipynb) 6 | 7 | (Adapted from [OpenAI Notebook](https://github.com/openai/openai-cookbook/blob/main/examples/Question_answering_using_embeddings.ipynb).) 8 | """ 9 | # - 10 | 11 | # $ 12 | 13 | import datasets 14 | import numpy as np 15 | from minichain import prompt, transform, show, OpenAIEmbed, OpenAI 16 | from manifest import Manifest 17 | 18 | # We use Hugging Face Datasets as the database by assigning 19 | # a FAISS index. 20 | 21 | olympics = datasets.load_from_disk("olympics.data") 22 | olympics.add_faiss_index("embeddings") 23 | 24 | 25 | # Fast KNN retieval prompt 26 | 27 | @prompt(OpenAIEmbed()) 28 | def embed(model, inp): 29 | return model(inp) 30 | 31 | @transform() 32 | def get_neighbors(inp, k): 33 | res = olympics.get_nearest_examples("embeddings", np.array(inp), k) 34 | return res.examples["content"] 35 | 36 | @prompt(OpenAI(), template_file="qa.pmpt.tpl") 37 | def get_result(model, query, neighbors): 38 | return model(dict(question=query, docs=neighbors)) 39 | 40 | def qa(query): 41 | n = get_neighbors(embed(query), 3) 42 | return get_result(query, n) 43 | 44 | # $ 45 | 46 | 47 | questions = ["Who won the 2020 Summer Olympics men's high jump?", 48 | "Why was the 2020 Summer Olympics originally postponed?", 49 | "In the 2020 Summer Olympics, how many gold medals did the country which won the most medals win?", 50 | "What is the total number of medals won by France?", 51 | "What is the tallest mountain in the world?"] 52 | 53 | gradio = show(qa, 54 | examples=questions, 55 | subprompts=[embed, get_result], 56 | description=desc, 57 | code=open("qa.py", "r").read().split("$")[1].strip().strip("#").strip(), 58 | ) 59 | if __name__ == "__main__": 60 | gradio.queue().launch() 61 | 62 | -------------------------------------------------------------------------------- /examples/selfask.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "e9c5f0df", 7 | "metadata": { 8 | "lines_to_next_cell": 2 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install -q git+https://github.com/srush/MiniChain\n", 13 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "56932926", 20 | "metadata": { 21 | "lines_to_next_cell": 2 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "desc = \"\"\"\n", 26 | "### Self-Ask\n", 27 | "\n", 28 | " Notebook implementation of the self-ask + Google tool use prompt. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/selfask.ipynb)\n", 29 | "\n", 30 | " (Adapted from [Self-Ask repo](https://github.com/ofirpress/self-ask))\n", 31 | "\"\"\"" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "43d24799", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from dataclasses import dataclass, replace\n", 42 | "from typing import Optional\n", 43 | "from minichain import prompt, show, OpenAI, Google" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "506b8e41", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "@dataclass\n", 54 | "class State:\n", 55 | " question: str\n", 56 | " history: str = \"\"\n", 57 | " next_query: Optional[str] = None\n", 58 | " final_answer: Optional[str] = None" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "fb5fdf8e", 65 | "metadata": { 66 | "lines_to_next_cell": 1 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "@prompt(OpenAI(),\n", 71 | " template_file = \"selfask.pmpt.tpl\",\n", 72 | " stop_template = \"\\nIntermediate answer:\")\n", 73 | "def self_ask(model, state):\n", 74 | " out = model(state)\n", 75 | " res = out.split(\":\", 1)[1]\n", 76 | " if out.startswith(\"Follow up:\"):\n", 77 | " return replace(state, next_query=res)\n", 78 | " elif out.startswith(\"So the final answer is:\"):\n", 79 | " return replace(state, final_answer=res)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "f2d66476", 86 | "metadata": { 87 | "lines_to_next_cell": 1 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "@prompt(Google())\n", 92 | "def google(model, state):\n", 93 | " if state.next_query is None:\n", 94 | " return state\n", 95 | "\n", 96 | " result = model(state.next_query)\n", 97 | " return State(state.question,\n", 98 | " state.history + \"\\nIntermediate answer: \" + result + \"\\n\")" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "id": "e41d7a75", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "def selfask(question):\n", 109 | " state = State(question)\n", 110 | " for i in range(3):\n", 111 | " state = self_ask(state)\n", 112 | " state = google(state)\n", 113 | " return state" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "e6e4b06f", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "gradio = show(selfask,\n", 124 | " examples=[\"What is the zip code of the city where George Washington was born?\"],\n", 125 | " subprompts=[self_ask, google] * 3,\n", 126 | " description=desc,\n", 127 | " out_type=\"json\"\n", 128 | " )\n", 129 | "if __name__ == \"__main__\":\n", 130 | " gradio.launch()" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "6afd60de", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [] 140 | } 141 | ], 142 | "metadata": { 143 | "jupytext": { 144 | "cell_metadata_filter": "-all", 145 | "main_language": "python", 146 | "notebook_metadata_filter": "-all" 147 | } 148 | }, 149 | "nbformat": 4, 150 | "nbformat_minor": 5 151 | } 152 | -------------------------------------------------------------------------------- /examples/selfask.pmpt.tpl: -------------------------------------------------------------------------------- 1 | Question: Who lived longer, Muhammad 
Ali or Alan Turing? 2 | Are follow up questions needed here: Yes. 3 | Follow up: How old was Muhammad Ali when he died? 4 | Intermediate answer: Muhammad Ali was 74 years old when he died. 5 | Follow up: How old was Alan Turing when he died? 6 | Intermediate answer: Alan Turing was 41 years old when he died. 7 | So the final answer is: Muhammad Ali 8 | 9 | Question: When was the founder of craigslist born? 10 | Are follow up questions needed here: Yes. 11 | Follow up: Who was the founder of craigslist? 12 | Intermediate answer: Craigslist was founded by Craig Newmark. 13 | Follow up: When was Craig Newmark born? 14 | Intermediate answer: Craig Newmark was born on December 6, 1952. 15 | So the final answer is: December 6, 1952 16 | 17 | Question: Who was the maternal grandfather of George Washington? 18 | Are follow up questions needed here: Yes. 19 | Follow up: Who was the mother of George Washington? 20 | Intermediate answer: The mother of George Washington was Mary Ball Washington. 21 | Follow up: Who was the father of Mary Ball Washington? 22 | Intermediate answer: The father of Mary Ball Washington was Joseph Ball. 23 | So the final answer is: Joseph Ball 24 | 25 | Question: Are both the directors of Jaws and Casino Royale from the same country? 26 | Are follow up questions needed here: Yes. 27 | Follow up: Who is the director of Jaws? 28 | Intermediate answer: The director of Jaws is Steven Spielberg. 29 | Follow up: Where is Steven Spielberg from? 30 | Intermediate answer: The United States. 31 | Follow up: Who is the director of Casino Royale? 32 | Intermediate answer: The director of Casino Royale is Martin Campbell. 33 | Follow up: Where is Martin Campbell from? 34 | Intermediate answer: New Zealand. 35 | So the final answer is: No 36 | 37 | Question: {{question}} 38 | Are followup questions needed here: Yes 39 | {{history}} 40 | -------------------------------------------------------------------------------- /examples/selfask.py: -------------------------------------------------------------------------------- 1 | 2 | desc = """ 3 | ### Self-Ask 4 | 5 | Notebook implementation of the self-ask + Google tool use prompt. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/selfask.ipynb) 6 | 7 | (Adapted from [Self-Ask repo](https://github.com/ofirpress/self-ask)) 8 | """ 9 | 10 | # $ 11 | 12 | from dataclasses import dataclass, replace 13 | from typing import Optional 14 | from minichain import prompt, show, OpenAI, Google, transform 15 | 16 | 17 | @dataclass 18 | class State: 19 | question: str 20 | history: str = "" 21 | next_query: Optional[str] = None 22 | final_answer: Optional[str] = None 23 | 24 | 25 | @prompt(OpenAI(stop="\nIntermediate answer:"), 26 | template_file = "selfask.pmpt.tpl") 27 | def self_ask(model, state): 28 | return model(state) 29 | 30 | @transform() 31 | def next_step(ask): 32 | res = ask.split(":", 1)[1] 33 | if out.startswith("Follow up:"): 34 | return replace(state, next_query=res) 35 | elif out.startswith("So the final answer is:"): 36 | return replace(state, final_answer=res) 37 | 38 | @prompt(Google()) 39 | def google(model, state): 40 | if state.next_query is None: 41 | return "" 42 | 43 | return model(state.next_query) 44 | 45 | @transform() 46 | def update(state, result): 47 | if not result: 48 | return state 49 | return State(state.question, 50 | state.history + "\nIntermediate answer: " + result + "\n") 51 | 52 | def selfask(question): 53 | state = State(question) 54 | for i in range(3): 55 | state = next_step(self_ask(state)) 56 | state = update(google(state)) 57 | return state 58 | 59 | # $ 60 | 61 | gradio = show(selfask, 62 | examples=["What is the zip code of the city where George Washington was born?"], 63 | subprompts=[self_ask, google] * 3, 64 | description=desc, 65 | code=open("selfask.py", "r").read().split("$")[1].strip().strip("#").strip(), 66 | out_type="json" 67 | ) 68 | if __name__ == "__main__": 69 | gradio.queue().launch() 70 | 71 | 72 | -------------------------------------------------------------------------------- /examples/sixers.txt: -------------------------------------------------------------------------------- 1 | The Philadelphia 76ers entered their Christmas Day matinee at Madison Square Garden against the New York Knicks having won seven straight games, all at home. They can now add an eighth victory to that list. Led by a dominant fourth quarter, the 76ers (20-12) defeated the Knicks (18-16), 119-112. 2 | 3 | Joel Embiid led the way with 35 points and eight rebounds, while James Harden added 29 points and 13 assists in another dazzling performance. They got a boost off the bench from Georges Niang, who nailed four 3s in the fourth quarter to finish with 16 points. 4 | 5 | The Knicks have now lost three straight games following an eight-game unbeaten run. Julius Randle matched Embiid with 35 points. 6 | 7 | Follow here for updates and analysis from The Athletic's staff. 8 | -------------------------------------------------------------------------------- /examples/stats.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "247c85dd", 7 | "metadata": { 8 | "lines_to_next_cell": 2 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "!pip install -q git+https://github.com/srush/MiniChain\n", 13 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . 
" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "d36498d1", 20 | "metadata": { 21 | "lines_to_next_cell": 2 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "desc = \"\"\"\n", 26 | "### Typed Extraction\n", 27 | "\n", 28 | "Information extraction that is automatically generated from a typed specification. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/pal.ipynb)\n", 29 | "\n", 30 | "(Novel to MiniChain)\n", 31 | "\"\"\"" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "e894e4fb", 38 | "metadata": { 39 | "lines_to_next_cell": 1 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "from minichain import prompt, show, type_to_prompt, OpenAI\n", 44 | "from dataclasses import dataclass\n", 45 | "from typing import List\n", 46 | "from enum import Enum" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "id": "6ac1ff70", 52 | "metadata": {}, 53 | "source": [ 54 | "Data specification" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "64a00e69", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "class StatType(Enum):\n", 65 | " POINTS = 1\n", 66 | " REBOUNDS = 2\n", 67 | " ASSISTS = 3\n", 68 | "\n", 69 | "@dataclass\n", 70 | "class Stat:\n", 71 | " value: int\n", 72 | " stat: StatType\n", 73 | "\n", 74 | "@dataclass\n", 75 | "class Player:\n", 76 | " player: str\n", 77 | " stats: List[Stat]" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "35e0ed85", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "@prompt(OpenAI(), template_file=\"stats.pmpt.tpl\", parser=\"json\")\n", 88 | "def stats(model, passage):\n", 89 | " out = model(dict(passage=passage, typ=type_to_prompt(Player)))\n", 90 | " return [Player(**j) for j in out]" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "id": "7241bdc6", 97 | "metadata": { 98 | "lines_to_next_cell": 2 99 | }, 100 | "outputs": [], 101 | "source": [ 102 | "article = open(\"sixers.txt\").read()\n", 103 | "gradio = show(lambda passage: stats(passage),\n", 104 | " examples=[article],\n", 105 | " subprompts=[stats],\n", 106 | " out_type=\"json\",\n", 107 | " description=desc,\n", 108 | ")\n", 109 | "if __name__ == \"__main__\":\n", 110 | " gradio.launch()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "id": "18ed1b96", 116 | "metadata": {}, 117 | "source": [ 118 | "ExtractionPrompt().show({\"passage\": \"Harden had 10 rebounds.\"},\n", 119 | " '[{\"player\": \"Harden\", \"stats\": {\"value\": 10, \"stat\": 2}}]')" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "id": "02c96e7d", 125 | "metadata": {}, 126 | "source": [ 127 | "# View the run log." 
128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "id": "0ea427dd", 133 | "metadata": {}, 134 | "source": [ 135 | "minichain.show_log(\"bash.log\")" 136 | ] 137 | } 138 | ], 139 | "metadata": { 140 | "jupytext": { 141 | "cell_metadata_filter": "-all", 142 | "main_language": "python", 143 | "notebook_metadata_filter": "-all" 144 | } 145 | }, 146 | "nbformat": 4, 147 | "nbformat_minor": 5 148 | } 149 | -------------------------------------------------------------------------------- /examples/stats.pmpt.tpl: -------------------------------------------------------------------------------- 1 | {{typ | safe}} 2 | 3 | {{passage}} 4 | 5 | 6 | JSON Output: -------------------------------------------------------------------------------- /examples/stats.py: -------------------------------------------------------------------------------- 1 | 2 | desc = """ 3 | ### Typed Extraction 4 | 5 | Information extraction that is automatically generated from a typed specification. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/pal.ipynb) 6 | 7 | (Novel to MiniChain) 8 | """ 9 | 10 | # $ 11 | 12 | from minichain import prompt, show, OpenAI, transform 13 | from dataclasses import dataclass, is_dataclass, fields 14 | from typing import List, Type, Dict, Any, get_origin, get_args 15 | from enum import Enum 16 | from jinja2 import select_autoescape, FileSystemLoader, Environment 17 | import json 18 | 19 | def enum(x: Type[Enum]) -> Dict[str, int]: 20 | d = {e.name: e.value for e in x} 21 | return d 22 | 23 | 24 | def walk(x: Any) -> Any: 25 | if issubclass(x if get_origin(x) is None else get_origin(x), List): 26 | return {"_t_": "list", "t": walk(get_args(x)[0])} 27 | if issubclass(x, Enum): 28 | return enum(x) 29 | 30 | if is_dataclass(x): 31 | return {y.name: walk(y.type) for y in fields(x)} 32 | return x.__name__ 33 | 34 | 35 | def type_to_prompt(out: type) -> str: 36 | tmp = env.get_template("type_prompt.pmpt.tpl") 37 | d = walk(out) 38 | return tmp.render({"typ": d}) 39 | 40 | env = Environment( 41 | loader=FileSystemLoader("."), 42 | autoescape=select_autoescape(), 43 | extensions=["jinja2_highlight.HighlightExtension"], 44 | ) 45 | 46 | 47 | 48 | # Data specification 49 | 50 | # + 51 | class StatType(Enum): 52 | POINTS = 1 53 | REBOUNDS = 2 54 | ASSISTS = 3 55 | 56 | @dataclass 57 | class Stat: 58 | value: int 59 | stat: StatType 60 | 61 | @dataclass 62 | class Player: 63 | player: str 64 | stats: List[Stat] 65 | # - 66 | 67 | 68 | @prompt(OpenAI(), template_file="stats.pmpt.tpl") 69 | def stats(model, passage): 70 | return model.stream(dict(passage=passage, typ=type_to_prompt(Player))) 71 | 72 | @transform() 73 | def to_data(s:str): 74 | return [Player(**j) for j in json.loads(s)] 75 | 76 | # $ 77 | 78 | article = open("sixers.txt").read() 79 | gradio = show(lambda passage: to_data(stats(passage)), 80 | examples=[article], 81 | subprompts=[stats], 82 | out_type="json", 83 | description=desc, 84 | code=open("stats.py", "r").read().split("$")[1].strip().strip("#").strip(), 85 | ) 86 | if __name__ == "__main__": 87 | gradio.queue().launch() 88 | 89 | 90 | # ExtractionPrompt().show({"passage": "Harden had 10 rebounds."}, 91 | # '[{"player": "Harden", "stats": {"value": 10, "stat": 2}}]') 92 | 93 | # # View the run log. 
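# An illustrative worked example of what `walk` above computes for the `Player`
# spec, i.e. the dictionary that `type_to_prompt` renders via type_prompt.pmpt.tpl:
#
#   walk(Player)
#   # -> {"player": "str",
#   #     "stats": {"_t_": "list",
#   #               "t": {"value": "int",
#   #                     "stat": {"POINTS": 1, "REBOUNDS": 2, "ASSISTS": 3}}}}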
94 | 95 | # minichain.show_log("bash.log") 96 | -------------------------------------------------------------------------------- /examples/summary.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "49192d35", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "!pip install -q git+https://github.com/srush/MiniChain\n", 11 | "!git clone https://github.com/srush/MiniChain; cp -fr MiniChain/examples/* . " 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "bf1da24c", 17 | "metadata": {}, 18 | "source": [ 19 | "Summarize a long document by chunking and summarizing parts. Uses\n", 20 | "aynchronous calls to the API. Adapted from LangChain [Map-Reduce\n", 21 | "summary](https://langchain.readthedocs.io/en/stable/_modules/langchain/chains/mapreduce.html)." 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "cce74ed6", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import trio" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "f25908e4", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from minichain import TemplatePrompt, show_log, start_chain" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "174e7a29", 47 | "metadata": { 48 | "lines_to_next_cell": 2 49 | }, 50 | "source": [ 51 | "Prompt that asks LLM to produce a bash command." 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "12b26a26", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "class SummaryPrompt(TemplatePrompt):\n", 62 | " template_file = \"summary.pmpt.tpl\"" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "98747659", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "def chunk(f, width=4000, overlap=800):\n", 73 | " \"Split a documents into 4800 character overlapping chunks\"\n", 74 | " text = open(f).read().replace(\"\\n\\n\", \"\\n\")\n", 75 | " chunks = []\n", 76 | " for i in range(4):\n", 77 | " if i * width > len(text):\n", 78 | " break\n", 79 | " chunks.append({\"text\": text[i * width : (i + 1) * width + overlap]})\n", 80 | " return chunks" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "e0ccfddc", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "with start_chain(\"summary\") as backend:\n", 91 | " prompt = SummaryPrompt(backend.OpenAI())\n", 92 | " list_prompt = prompt.map()\n", 93 | "\n", 94 | " # Map - Summarize each chunk in parallel\n", 95 | " out = trio.run(list_prompt.arun, chunk(\"../state_of_the_union.txt\"))\n", 96 | "\n", 97 | " # Reduce - Summarize the summarized chunks\n", 98 | " print(prompt({\"text\": \"\\n\".join(out)}))" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "id": "e3ffd907", 105 | "metadata": { 106 | "tags": [ 107 | "hide_inp" 108 | ] 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "SummaryPrompt().show(\n", 113 | " {\"text\": \"One way to fight is to drive down wages and make Americans poorer.\"},\n", 114 | " \"Make Americans poorer\",\n", 115 | ")" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "52be8068", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "show_log(\"summary.log\")" 126 | ] 127 | } 128 | ], 129 | "metadata": { 130 | "jupytext": { 131 | 
"cell_metadata_filter": "tags,-all", 132 | "main_language": "python", 133 | "notebook_metadata_filter": "-all" 134 | } 135 | }, 136 | "nbformat": 4, 137 | "nbformat_minor": 5 138 | } 139 | -------------------------------------------------------------------------------- /examples/summary.pmpt.tpl: -------------------------------------------------------------------------------- 1 | Write a concise summary of the following: 2 | 3 | 4 | "{{text}}" 5 | 6 | 7 | CONCISE SUMMARY: -------------------------------------------------------------------------------- /examples/summary.py: -------------------------------------------------------------------------------- 1 | # Summarize a long document by chunking and summarizing parts. Uses 2 | # aynchronous calls to the API. Adapted from LangChain [Map-Reduce 3 | # summary](https://langchain.readthedocs.io/en/stable/_modules/langchain/chains/mapreduce.html). 4 | 5 | import trio 6 | 7 | from minichain import TemplatePrompt, show_log, start_chain 8 | 9 | # Prompt that asks LLM to produce a bash command. 10 | 11 | 12 | class SummaryPrompt(TemplatePrompt): 13 | template_file = "summary.pmpt.tpl" 14 | 15 | 16 | def chunk(f, width=4000, overlap=800): 17 | "Split a documents into 4800 character overlapping chunks" 18 | text = open(f).read().replace("\n\n", "\n") 19 | chunks = [] 20 | for i in range(4): 21 | if i * width > len(text): 22 | break 23 | chunks.append({"text": text[i * width : (i + 1) * width + overlap]}) 24 | return chunks 25 | 26 | 27 | with start_chain("summary") as backend: 28 | prompt = SummaryPrompt(backend.OpenAI()) 29 | list_prompt = prompt.map() 30 | 31 | # Map - Summarize each chunk in parallel 32 | out = trio.run(list_prompt.arun, chunk("../state_of_the_union.txt")) 33 | 34 | # Reduce - Summarize the summarized chunks 35 | print(prompt({"text": "\n".join(out)})) 36 | 37 | # + tags=["hide_inp"] 38 | SummaryPrompt().show( 39 | {"text": "One way to fight is to drive down wages and make Americans poorer."}, 40 | "Make Americans poorer", 41 | ) 42 | # - 43 | 44 | show_log("summary.log") 45 | -------------------------------------------------------------------------------- /examples/table.pmpt.txt: -------------------------------------------------------------------------------- 1 | You are a utility built to extract structured information from documents. You are returning a TSV table. Here are the headers . 2 | 3 | ---- 4 | {{type}} {% for k in player_keys %}{{k[0]}}{{"\t" if not loop.last}}{% endfor %} 5 | ---- 6 | 7 | Return the rest of the table in TSV format. Here are some examples 8 | 9 | {% for example in examples %} 10 | Example 11 | --- 12 | {{example.input}} 13 | --- 14 | 15 | Output 16 | --- 17 | {{example.output}} 18 | --- 19 | {% endfor %} 20 | 21 | Article: 22 | ---- 23 | {{passage}} 24 | ---- 25 | 26 | All other values should be numbers or _. 27 | Only include numbers that appear explicitly in the passage below. 28 | If you cannot find the value in the table, output _. Most cells will be _. 29 | 30 | Ok, here is the correctly valid TSV with headers and nothing else. Remember only include values that are directly written in the article. Do not guess or combine rows. 31 | 32 | -------------------------------------------------------------------------------- /examples/table.py: -------------------------------------------------------------------------------- 1 | # + tags=["hide_inp"] 2 | desc = """ 3 | ### Table 4 | 5 | Example of extracting tables from a textual document. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/table.ipynb) 6 | 7 | """ 8 | # - 9 | 10 | # $ 11 | import pandas as pd 12 | from minichain import prompt, Mock, show, OpenAI, GradioConf 13 | import minichain 14 | import json 15 | import gradio as gr 16 | import requests 17 | 18 | rotowire = requests.get("https://raw.githubusercontent.com/srush/text2table/main/data.json").json() 19 | names = { 20 | '3-pointer percentage': 'FG3_PCT', 21 | '3-pointers attempted': 'FG3A', 22 | '3-pointers made': 'FG3M', 23 | 'Assists': 'AST', 24 | 'Blocks': 'BLK', 25 | 'Field goal percentage': 'FG_PCT', 26 | 'Field goals attempted': 'FGA', 27 | 'Field goals made': 'FGM', 28 | 'Free throw percentage': 'FT_PCT', 29 | 'Free throws attempted': 'FTA', 30 | 'Free throws made': 'FTM', 31 | 'Minutes played': 'MIN', 32 | 'Personal fouls': 'PF', 33 | 'Points': 'PTS', 34 | 'Rebounds': 'REB', 35 | 'Rebounds (Defensive)': 'DREB', 36 | 'Rebounds (Offensive)': 'OREB', 37 | 'Steals': 'STL', 38 | 'Turnovers': 'TO' 39 | } 40 | # Convert an example to dataframe 41 | def to_df(d): 42 | players = {player for v in d.values() if v is not None for player, _ in v.items()} 43 | lookup = {k: {a: b for a, b in v.items()} for k,v in d.items()} 44 | rows = [dict(**{"player": p}, **{k: "_" if p not in lookup.get(k, []) else lookup[k][p] for k in names.keys()}) 45 | for p in players] 46 | return pd.DataFrame.from_dict(rows).astype("str").sort_values(axis=0, by="player", ignore_index=True).transpose() 47 | 48 | 49 | # Make few shot examples 50 | few_shot_examples = 2 51 | examples = [] 52 | for i in range(few_shot_examples): 53 | examples.append({"input": rotowire[i][1], 54 | "output": to_df(rotowire[i][0][1]).transpose().set_index("player").to_csv(sep="\t")}) 55 | 56 | def make_html(out): 57 | return "
" + out.replace("\t", "").replace("\n", "
") + "
" 58 | 59 | @prompt(OpenAI("gpt-4"), template_file="table.pmpt.txt", 60 | gradio_conf=GradioConf(block_output=gr.HTML, 61 | postprocess_output = make_html) 62 | ) 63 | def extract(model, passage, typ): 64 | return model(dict(player_keys=names.items(), examples=examples, passage=passage, type=typ)) 65 | 66 | def run(query): 67 | return extract(query, "Player") 68 | 69 | # $ 70 | 71 | import os 72 | gradio = show(run, 73 | examples = [rotowire[i][1] for i in range(50, 55)], 74 | subprompts=[extract], 75 | code=open("table.py" if os.path.exists("table.py") else "app.py", "r").read().split("$")[1].strip().strip("#").strip(), 76 | out_type="markdown" 77 | ) 78 | 79 | if __name__ == "__main__": 80 | gradio.queue().launch() 81 | -------------------------------------------------------------------------------- /examples/type_prompt.pmpt.tpl: -------------------------------------------------------------------------------- 1 | You are a highly intelligent and accurate information extraction system. You take passage as input and your task is to find parts of the passage to answer questions. 2 | 3 | {% macro describe(typ) -%} 4 | {% for key, val in typ.items() %} 5 | You need to classify in to the following types for key: "{{key}}": 6 | {% if val == "str" %}String 7 | {% elif val == "int" %}Int {% else %} 8 | {% if val.get("_t_") == "list" %}List{{describe(val["t"])}}{% else %} 9 | 10 | {% for k, v in val.items() %}{{k}} 11 | {% endfor %} 12 | 13 | Only select from the above list. 14 | {% endif %} 15 | {%endif%} 16 | {% endfor %} 17 | {% endmacro -%} 18 | {{describe(typ)}} 19 | {% macro json(typ) -%}{% for key, val in typ.items() %}{% if val in ["str", "int"] or val.get("_t_") != "list" %}"{{key}}" : "{{key}}" {% else %} "{{key}}" : [{ {{json(val["t"])}} }] {% endif %}{{"" if loop.last else ", "}} {% endfor %}{% endmacro -%} 20 | 21 | [{ {{json(typ)}} }, ...] 22 | 23 | 24 | 25 | Make sure every output is exactly seen in the document. Find as many as you can. 26 | You need to output only JSON. 
-------------------------------------------------------------------------------- /minichain/__init__.py: -------------------------------------------------------------------------------- 1 | from .backend import ( 2 | Backend, 3 | Bash, 4 | Google, 5 | HuggingFaceEmbed, 6 | Id, 7 | Manifest, 8 | Mock, 9 | OpenAI, 10 | OpenAIEmbed, 11 | Python, 12 | set_minichain_log, 13 | start_chain, 14 | ) 15 | from .base import Break, prompt, transform 16 | from .gradio import GradioConf, show 17 | -------------------------------------------------------------------------------- /minichain/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | from dataclasses import dataclass 5 | from types import TracebackType 6 | from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple 7 | 8 | from eliot import start_action, to_file 9 | 10 | if TYPE_CHECKING: 11 | import manifest 12 | 13 | 14 | class Backend: 15 | @property 16 | def description(self) -> str: 17 | return "" 18 | 19 | def run(self, request: str) -> str: 20 | raise NotImplementedError 21 | 22 | def run_stream(self, request: str) -> Iterator[str]: 23 | yield self.run(request) 24 | 25 | async def arun(self, request: str) -> str: 26 | return self.run(request) 27 | 28 | def _block_input(self, gr): # type: ignore 29 | return gr.Textbox(show_label=False) 30 | 31 | def _block_output(self, gr): # type: ignore 32 | return gr.Textbox(show_label=False) 33 | 34 | 35 | class Id(Backend): 36 | def run(self, request: str) -> str: 37 | return request 38 | 39 | 40 | class Mock(Backend): 41 | def __init__(self, answers: List[str] = []): 42 | self.i = -1 43 | self.answers = answers 44 | 45 | def run(self, request: str) -> str: 46 | self.i += 1 47 | return self.answers[self.i % len(self.answers)] 48 | 49 | def run_stream(self, request: str) -> Iterator[str]: 50 | self.i += 1 51 | result = self.answers[self.i % len(self.answers)] 52 | for c in result: 53 | yield c 54 | time.sleep(0.1) 55 | 56 | def __repr__(self) -> str: 57 | return f"Mocked Backend {self.answers}" 58 | 59 | 60 | class Google(Backend): 61 | def __init__(self) -> None: 62 | pass 63 | 64 | def run(self, request: str) -> str: 65 | from serpapi import GoogleSearch 66 | 67 | serpapi_key = os.environ.get("SERP_KEY") 68 | assert ( 69 | serpapi_key 70 | ), "Need a SERP_KEY. 
Get one here https://serpapi.com/users/welcome" 71 | self.serpapi_key = serpapi_key 72 | 73 | params = { 74 | "api_key": self.serpapi_key, 75 | "engine": "google", 76 | "q": request, 77 | "google_domain": "google.com", 78 | "gl": "us", 79 | "hl": "en", 80 | } 81 | 82 | search = GoogleSearch(params) 83 | res = search.get_dict() 84 | 85 | if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): 86 | toret = res["answer_box"]["answer"] 87 | elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): 88 | toret = res["answer_box"]["snippet"] 89 | elif ( 90 | "answer_box" in res.keys() 91 | and "snippet_highlighted_words" in res["answer_box"].keys() 92 | ): 93 | toret = res["answer_box"]["snippet_highlighted_words"][0] 94 | elif "snippet" in res["organic_results"][0].keys(): 95 | toret = res["organic_results"][0]["snippet"] 96 | else: 97 | toret = "" 98 | return str(toret) 99 | 100 | def __repr__(self) -> str: 101 | return "Google Search Backend" 102 | 103 | 104 | class Python(Backend): 105 | """Executes Python commands and returns the output.""" 106 | 107 | def _block_input(self, gr): # type: ignore 108 | return gr.Code() 109 | 110 | def _block_output(self, gr): # type: ignore 111 | return gr.Code() 112 | 113 | def run(self, request: str) -> str: 114 | """Run commands and return final output.""" 115 | from contextlib import redirect_stdout 116 | from io import StringIO 117 | 118 | p = request.strip() 119 | if p.startswith("```"): 120 | p = "\n".join(p.strip().split("\n")[1:-1]) 121 | 122 | f = StringIO() 123 | with redirect_stdout(f): 124 | exec(p) 125 | s = f.getvalue() 126 | return s 127 | 128 | def __repr__(self) -> str: 129 | return "Python-Backend" 130 | 131 | 132 | class Bash(Backend): 133 | """Executes bash commands and returns the output.""" 134 | 135 | def _block_input(self, gr): # type: ignore 136 | return gr.Code() 137 | 138 | def _block_output(self, gr): # type: ignore 139 | return gr.Code() 140 | 141 | def __init__(self, strip_newlines: bool = False, return_err_output: bool = False): 142 | """Initialize with stripping newlines.""" 143 | self.strip_newlines = strip_newlines 144 | self.return_err_output = return_err_output 145 | 146 | def run(self, request: str) -> str: 147 | """Run commands and return final output.""" 148 | try: 149 | output = subprocess.run( 150 | request, 151 | shell=True, 152 | check=True, 153 | stdout=subprocess.PIPE, 154 | stderr=subprocess.STDOUT, 155 | ).stdout.decode() 156 | except subprocess.CalledProcessError as error: 157 | if self.return_err_output: 158 | return str(error.stdout.decode()) 159 | return str(error) 160 | if self.strip_newlines: 161 | output = output.strip() 162 | return output 163 | 164 | def __repr__(self) -> str: 165 | return "Bash-Backend" 166 | 167 | 168 | class OpenAIBase(Backend): 169 | def __init__( 170 | self, 171 | model: str = "gpt-3.5-turbo", 172 | max_tokens: int = 256, 173 | temperature: float = 0.0, 174 | stop: Optional[List[str]] = None, 175 | ) -> None: 176 | self.model = model 177 | self.stop = stop 178 | self.options = dict( 179 | model=model, 180 | max_tokens=max_tokens, 181 | temperature=temperature, 182 | ) 183 | 184 | def __repr__(self) -> str: 185 | return f"OpenAI Backend {self.options}" 186 | 187 | 188 | class OpenAI(OpenAIBase): 189 | def run(self, request: str) -> str: 190 | import manifest 191 | 192 | chat = {"gpt-4", "gpt-3.5-turbo"} 193 | manifest = manifest.Manifest( 194 | client_name="openaichat" if self.model in chat else "openai", 195 | max_tokens=self.options["max_tokens"], 
196 | cache_name="sqlite", 197 | cache_connection=f"{MinichainContext.name}", 198 | ) 199 | 200 | ans = manifest.run( 201 | request, 202 | stop_sequences=self.stop, 203 | ) 204 | return str(ans) 205 | 206 | def run_stream(self, prompt: str) -> Iterator[str]: 207 | import openai 208 | 209 | self.api_key = os.environ.get("OPENAI_API_KEY") 210 | assert ( 211 | self.api_key 212 | ), "Need an OPENAI_API_KEY. Get one here https://openai.com/api/" 213 | openai.api_key = self.api_key 214 | 215 | for chunk in openai.ChatCompletion.create( 216 | model=self.model, 217 | messages=[{"role": "user", "content": prompt}], 218 | stream=True, 219 | stop=self.stop, 220 | ): 221 | content = chunk["choices"][0].get("delta", {}).get("content") 222 | if content is not None: 223 | yield content 224 | 225 | 226 | class OpenAIEmbed(OpenAIBase): 227 | def _block_output(self, gr): # type: ignore 228 | return gr.Textbox(label="Embedding") 229 | 230 | def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None: 231 | super().__init__(model, **kwargs) 232 | 233 | def run(self, request: str) -> str: 234 | import openai 235 | 236 | self.api_key = os.environ.get("OPENAI_API_KEY") 237 | assert ( 238 | self.api_key 239 | ), "Need an OPENAI_API_KEY. Get one here https://openai.com/api/" 240 | openai.api_key = self.api_key 241 | 242 | ans = openai.Embedding.create( 243 | engine=self.model, 244 | input=request, 245 | ) 246 | return ans["data"][0]["embedding"] # type: ignore 247 | 248 | 249 | class HuggingFaceBase(Backend): 250 | def __init__(self, model: str = "gpt2") -> None: 251 | self.model = model 252 | 253 | 254 | class HuggingFace(HuggingFaceBase): 255 | def run(self, request: str) -> str: 256 | 257 | from huggingface_hub.inference_api import InferenceApi 258 | 259 | self.api_key = os.environ.get("HF_KEY") 260 | assert self.api_key, "Need an HF_KEY. Get one here https://huggingface.co/" 261 | 262 | self.client = InferenceApi( 263 | token=self.api_key, repo_id=self.model, task="text-generation" 264 | ) 265 | response = self.client(inputs=request) 266 | return response # type: ignore 267 | 268 | 269 | class HuggingFaceEmbed(HuggingFaceBase): 270 | def run(self, request: str) -> str: 271 | 272 | from huggingface_hub.inference_api import InferenceApi 273 | 274 | self.api_key = os.environ.get("HF_KEY") 275 | assert self.api_key, "Need an HF_KEY. Get one here https://huggingface.co/" 276 | 277 | self.client = InferenceApi( 278 | token=self.api_key, repo_id=self.model, task="feature-extraction" 279 | ) 280 | response = self.client(inputs=request) 281 | return response # type: ignore 282 | 283 | 284 | class Manifest(Backend): 285 | def __init__(self, client: "manifest.Manifest") -> None: 286 | "Client from [Manifest-ML](https://github.com/HazyResearch/manifest)." 287 | self.client = client 288 | 289 | def run(self, request: str) -> str: 290 | try: 291 | import manifest 292 | except ImportError: 293 | raise ImportError("`pip install manifest-ml` to use the Manifest Backend.") 294 | assert isinstance( 295 | self.client, manifest.Manifest 296 | ), "Client must be a `manifest.Manifest` instance." 
297 | 298 | return self.client.run(request) # type: ignore 299 | 300 | 301 | @dataclass 302 | class RunLog: 303 | request: str = "" 304 | response: Optional[str] = "" 305 | output: str = "" 306 | dynamic: int = 0 307 | 308 | 309 | @dataclass 310 | class PromptSnap: 311 | input_: Any = "" 312 | run_log: RunLog = RunLog() 313 | output: Any = "" 314 | 315 | 316 | class MinichainContext: 317 | id_: int = 0 318 | prompt_store: Dict[Tuple[int, int], List[PromptSnap]] = {} 319 | prompt_count: Dict[int, int] = {} 320 | name: str = "" 321 | 322 | 323 | def set_minichain_log(name: str) -> None: 324 | to_file(open(f"{name}.log", "w")) 325 | 326 | 327 | class MiniChain: 328 | """ 329 | MiniChain session object with backends. Make backend by calling 330 | `minichain.OpenAI()` with args for `OpenAI` class. 331 | """ 332 | 333 | def __init__(self, name: str): 334 | to_file(open(f"{name}.log", "w")) 335 | self.name = name 336 | 337 | def __enter__(self) -> "MiniChain": 338 | MinichainContext.prompt_store = {} 339 | MinichainContext.prompt_count = {} 340 | MinichainContext.name = self.name 341 | self.action = start_action(action_type=self.name) 342 | return self 343 | 344 | def __exit__( 345 | self, 346 | type: type, 347 | exception: Optional[BaseException], 348 | traceback: Optional[TracebackType], 349 | ) -> None: 350 | self.action.finish() 351 | self.prompt_store = dict(MinichainContext.prompt_store) 352 | MinichainContext.prompt_store = {} 353 | MinichainContext.prompt_count = {} 354 | MinichainContext.name = "" 355 | 356 | 357 | def start_chain(name: str) -> MiniChain: 358 | """ 359 | Initialize a chain. Logs to {name}.log. Returns a `MiniChain` that 360 | holds LLM backends.. 361 | """ 362 | return MiniChain(name) 363 | 364 | 365 | # def show_log(filename: str, o: Callable[[str], Any] = sys.stderr.write) -> None: 366 | # """ 367 | # Write out the full asynchronous log from file `filename`. 
368 | # """ 369 | # render_tasks( 370 | # o, 371 | # tasks_from_iterable([json.loads(line) for line in open(filename)]), 372 | # colorize=True, 373 | # human_readable=True, 374 | # ) 375 | -------------------------------------------------------------------------------- /minichain/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass 2 | from itertools import count 3 | from typing import ( 4 | Any, 5 | Callable, 6 | Generic, 7 | Iterable, 8 | Iterator, 9 | List, 10 | Optional, 11 | TypeVar, 12 | Union, 13 | ) 14 | 15 | from jinja2 import Environment, FileSystemLoader, Template 16 | 17 | from .backend import Backend, MinichainContext, PromptSnap, RunLog 18 | 19 | Input = TypeVar("Input") 20 | Output = TypeVar("Output") 21 | FnOutput = TypeVar("FnOutput") 22 | 23 | 24 | @dataclass 25 | class History: 26 | expand: Callable[[List[Any]], Iterator[Any]] 27 | inputs: List[Any] 28 | 29 | 30 | @dataclass 31 | class Break: 32 | pass 33 | 34 | 35 | @dataclass 36 | class Chain: 37 | history: History 38 | name: str 39 | cache: Any = None 40 | 41 | def run_gen(self) -> Any: 42 | # Lazily instantiate all the inputs 43 | args = [] 44 | for i, base_input in enumerate(self.history.inputs): 45 | function_input = base_input 46 | if isinstance(base_input, Chain): 47 | if base_input.cache is not None: 48 | function_input = base_input.cache 49 | if isinstance(function_input, Break): 50 | yield Break() 51 | return 52 | else: 53 | for function_input in base_input.run_gen(): 54 | if isinstance(function_input, Break): 55 | base_input.cache = Break() 56 | yield Break() 57 | return 58 | yield None 59 | 60 | base_input.cache = function_input 61 | args.append(function_input) 62 | # Run the current prompt 63 | for out in self.history.expand(*args): 64 | if isinstance(out, Break): 65 | yield Break() 66 | return 67 | 68 | yield None 69 | yield out 70 | 71 | def run(self) -> Any: 72 | for x in self.run_gen(): 73 | pass 74 | return x 75 | 76 | 77 | class Prompt(Generic[Input, Output, FnOutput]): 78 | counter = count() 79 | 80 | def __init__( 81 | self, 82 | fn: Callable[[Callable[[Input], Output]], Iterable[FnOutput]], 83 | backend: Union[List[Backend], Backend], 84 | template_file: Optional[str], 85 | template: Optional[str], 86 | gradio_conf: Any = None, 87 | ): 88 | self.fn = fn 89 | if not isinstance(backend, List): 90 | self.backend = [backend] 91 | else: 92 | self.backend = backend 93 | 94 | self.template_file: Optional[str] = template_file 95 | self.template: Optional[str] = template 96 | self.gradio_conf = gradio_conf 97 | 98 | self._fn: str = fn.__name__ 99 | self._id: int = Prompt.counter.__next__() 100 | 101 | def run(self, request: str, tool_num: int = 0) -> Iterator[RunLog]: 102 | if not hasattr(self.backend[tool_num], "run_stream"): 103 | yield RunLog(request, None) 104 | response: Union[str, Any] = self.backend[tool_num].run(request) 105 | yield RunLog(request, response) 106 | else: 107 | yield RunLog(request, None) 108 | for r in self.backend[tool_num].run_stream(request): 109 | yield RunLog(request, r) 110 | 111 | def template_fill(self, inp: Any) -> str: 112 | kwargs = inp 113 | if self.template_file: 114 | tmp = Environment(loader=FileSystemLoader(".")).get_template( 115 | name=self.template_file 116 | ) 117 | elif self.template: 118 | tmp = Template(self.template) 119 | 120 | return str(tmp.render(**kwargs)) 121 | 122 | def __call__(self, *args: Any) -> Chain: 123 | return Chain(History(self.expand, list(args)), 
self.fn.__name__) 124 | 125 | class Model: 126 | def __init__(self, prompt: "Prompt[Input, Output, FnOutput]", data: Any): 127 | self.prompt = prompt 128 | self.data = data 129 | self.run_log = RunLog() 130 | 131 | def __call__(self, model_input: Any, tool_num: int = 0) -> Any: 132 | for r in self.stream(model_input, tool_num): 133 | yield r 134 | 135 | # print("hello tool") 136 | # for out in self.prompt.dynamic[tool_num].expand(*model_input): 137 | # self.run_log = self.prompt.dynamic[tool_num].model.run_log 138 | # self.run_log.dynamic = tool_num 139 | # yield out 140 | 141 | def stream( 142 | self, model_input: Any, tool_num: int = 0 143 | ) -> Iterator[Optional[str]]: 144 | if ( 145 | self.prompt.template is not None 146 | or self.prompt.template_file is not None 147 | ): 148 | if not isinstance(model_input, dict): 149 | model_input = asdict(model_input) 150 | result = self.prompt.template_fill(model_input) 151 | else: 152 | result = model_input 153 | 154 | for run_log in self.prompt.run(result, tool_num): 155 | r = self.run_log.response 156 | if run_log.response is None: 157 | out = r 158 | elif not r: 159 | out = run_log.response 160 | else: 161 | out = r + run_log.response 162 | self.run_log = RunLog(run_log.request, out, dynamic=tool_num) 163 | yield self.run_log.response 164 | 165 | def expand( 166 | self, *args: List[Any], data: Any = None 167 | ) -> Iterator[Optional[FnOutput]]: 168 | # Times prompt has been used. 169 | MinichainContext.prompt_count.setdefault(self._id, -1) 170 | MinichainContext.prompt_count[self._id] += 1 171 | count = MinichainContext.prompt_count[self._id] 172 | 173 | # Snap of the prompt 174 | MinichainContext.prompt_store.setdefault((self._id, count), []) 175 | MinichainContext.prompt_store[self._id, count].append(PromptSnap()) 176 | 177 | # Model to be passed to function 178 | model = self.Model(self, data) 179 | for output in self.fn(model, *args): 180 | t = model.run_log 181 | assert model.run_log, str(model) 182 | snap = PromptSnap(args, t, output) 183 | count = MinichainContext.prompt_count[self._id] 184 | MinichainContext.prompt_store.setdefault((self._id, count), []) 185 | MinichainContext.prompt_store[self._id, count][-1] = snap 186 | yield None 187 | 188 | assert model.run_log, str(model) 189 | t = model.run_log 190 | snap = PromptSnap(args, t, output) 191 | MinichainContext.prompt_store[self._id, count][-1] = snap 192 | yield output 193 | 194 | 195 | def prompt( 196 | backend: List[Backend] = [], 197 | template_file: Optional[str] = None, 198 | template: Optional[str] = None, 199 | gradio_conf: Optional[Any] = None, 200 | ) -> Callable[[Any], Prompt[Input, Output, FnOutput]]: 201 | return lambda fn: Prompt(fn, backend, template_file, template, gradio_conf) 202 | 203 | 204 | def transform(): # type: ignore 205 | return lambda fn: lambda *args: Chain( 206 | History(lambda *x: iter((fn(*x),)), list(args)), fn.__name__ 207 | ) 208 | -------------------------------------------------------------------------------- /minichain/gradio.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os 3 | from dataclasses import dataclass, field 4 | from typing import Any, Callable, Dict, List, Set, Tuple, Union 5 | 6 | import gradio as gr 7 | from gradio.blocks import Block 8 | 9 | from minichain import start_chain 10 | 11 | from .backend import MinichainContext 12 | from .base import Prompt 13 | 14 | CSS = """ 15 | #clean div.form {border: 0px} 16 | #response {border: 0px; background: #ffeec6} 17 | 
#prompt {border: 0px;background: aliceblue} 18 | #json {border: 0px} 19 | span.head {font-size: 60pt; font-family: cursive;} 20 | div.gradio-container {color: black} 21 | div.form {background: inherit} 22 | div.form div.block {padding: 0px; background: #fcfcfc} 23 | """ 24 | 25 | 26 | @dataclass 27 | class GradioConf: 28 | block_input: Callable[[], gr.Blocks] = lambda: gr.Textbox(show_label=False) 29 | block_output: Callable[[], gr.Blocks] = lambda: gr.Textbox(show_label=False) 30 | postprocess_output: Callable[[Any], Any] = lambda x: x 31 | preprocess_input: Callable[[Any], Any] = lambda x: x 32 | 33 | 34 | @dataclass 35 | class HTML: 36 | html: str 37 | 38 | def _repr_html_(self) -> str: 39 | return self.html 40 | 41 | 42 | @dataclass 43 | class DisplayOptions: 44 | markdown: bool = True 45 | 46 | 47 | all_data = gr.State({}) 48 | final_output = gr.State({}) 49 | 50 | 51 | @dataclass 52 | class Constructor: 53 | fns: List[Callable[[Dict[Block, Any]], Dict[Block, Any]]] = field( 54 | default_factory=list 55 | ) 56 | inputs: Set[Block] = field(default_factory=set) 57 | outputs: Set[Block] = field(default_factory=set) 58 | 59 | def merge(self, other: "Constructor") -> "Constructor": 60 | return Constructor( 61 | self.fns + other.fns, 62 | self.inputs | other.inputs, 63 | self.outputs | other.outputs, 64 | ) 65 | 66 | def add_inputs(self, inputs: List[Block]) -> "Constructor": 67 | return Constructor(self.fns, self.inputs | set(inputs), self.outputs) 68 | 69 | def fn(self, data: Dict[Block, Any]) -> Dict[Block, Any]: 70 | out: Dict[Block, Any] = {} 71 | for fn in self.fns: 72 | out = {**out, **fn(data)} 73 | return out 74 | 75 | 76 | def to_gradio_block( 77 | base_prompt: Prompt[Any, Any, Any], 78 | i: int, 79 | display_options: DisplayOptions = DisplayOptions(), 80 | show_advanced: bool = True, 81 | ) -> Constructor: 82 | prompts = [] 83 | results = [] 84 | bp = base_prompt 85 | with gr.Accordion( 86 | label=f"👩 Prompt: {str(base_prompt._fn)}", elem_id="prompt", visible=False 87 | ) as accordion_in: 88 | for backend in base_prompt.backend: 89 | if bp.gradio_conf is not None: 90 | prompt = bp.gradio_conf.block_input() 91 | elif hasattr(backend, "_block_input"): 92 | prompt: gr.Blocks = backend._block_input(gr) # type: ignore 93 | else: 94 | prompt = GradioConf().block_input() 95 | prompts.append(prompt) 96 | 97 | with gr.Accordion(label="💻", elem_id="response", visible=False) as accordion_out: 98 | for backend in base_prompt.backend: 99 | if bp.gradio_conf is not None: 100 | result = bp.gradio_conf.block_output() 101 | elif hasattr(backend, "_block_output"): 102 | result: gr.Blocks = backend._block_output(gr) # type: ignore 103 | else: 104 | result = GradioConf().block_output() 105 | results.append(result) 106 | 107 | with gr.Accordion(label="...", open=False, visible=show_advanced): 108 | gr.Markdown(f"Backend: {base_prompt.backend}", elem_id="json") 109 | input = gr.JSON(elem_id="json", label="Input") 110 | json = gr.JSON(elem_id="json", label="Output") 111 | 112 | if base_prompt.template_file: 113 | gr.Code( 114 | label=f"Template: {base_prompt.template_file}", 115 | value=open(base_prompt.template_file).read(), 116 | elem_id="inner", 117 | ) 118 | # btn = gr.Button("Modify Template") 119 | # if base_prompt.template_file is not None: 120 | 121 | # def update_template(template: str) -> None: 122 | # if base_prompt.template_file is not None: 123 | # with open(base_prompt.template_file, "w") as doc: 124 | # doc.write(template) 125 | 126 | # btn.click(update_template, inputs=c) 127 | 128 | 
def update(data: Dict[Block, Any]) -> Dict[Block, Any]: 129 | "Update the prompt block" 130 | prev_request_ = "" 131 | if (base_prompt._id, i) not in data[all_data]: 132 | ret = {} 133 | for p, r in zip(prompts, results): 134 | ret[p] = gr.update(visible=False) 135 | ret[r] = gr.update(visible=False) 136 | return ret 137 | 138 | if (base_prompt._id, i - 1) in data[all_data]: 139 | prev_request_ = data[all_data][base_prompt._id, i - 1][-1].run_log.request 140 | 141 | snap = data[all_data][base_prompt._id, i][-1] 142 | input_, request_, response_, output_ = ( 143 | snap.input_, 144 | snap.run_log.request, 145 | snap.run_log.response, 146 | snap.output, 147 | ) 148 | 149 | def format(s: Any) -> Any: 150 | if isinstance(s, str): 151 | return {"string": s} 152 | return s 153 | 154 | def mark(s: Any) -> Any: 155 | return str(s) # f"```text\n{s}\n```" 156 | 157 | j = 0 158 | for (a, b) in zip(request_, prev_request_): 159 | if a != b: 160 | break 161 | j += 1 162 | 163 | if base_prompt.gradio_conf is not None: 164 | request_ = base_prompt.gradio_conf.preprocess_input(request_) 165 | output_ = base_prompt.gradio_conf.postprocess_output(output_) 166 | # if j > 30: 167 | # new_prompt = "...\n" + request_[j:] 168 | # else: 169 | new_prompt = request_ 170 | 171 | ret = { 172 | input: format(input_), 173 | json: format(response_), 174 | accordion_in: gr.update(visible=True), 175 | accordion_out: gr.update(visible=bool(output_)), 176 | } 177 | for j, (prompt, result) in enumerate(zip(prompts, results)): 178 | if j == snap.run_log.dynamic: 179 | ret[prompt] = gr.update(value=new_prompt, visible=True) 180 | if output_: 181 | ret[result] = gr.update(value=output_, visible=True) 182 | else: 183 | ret[result] = gr.update(visible=True) 184 | else: 185 | ret[prompt] = gr.update(visible=False) 186 | ret[result] = gr.update(visible=False) 187 | 188 | return ret 189 | 190 | return Constructor( 191 | [update], 192 | set(), 193 | {accordion_in, accordion_out, input, json} | set(prompts) | set(results), 194 | ) 195 | 196 | 197 | def chain_blocks( 198 | prompts: List[Prompt[Any, Any, Any]], show_advanced: bool = True 199 | ) -> Constructor: 200 | cons = Constructor() 201 | count: Dict[int, int] = {} 202 | for p in prompts: 203 | count.setdefault(p._id, 0) 204 | i = count[p._id] 205 | cons = cons.merge(to_gradio_block(p, i, show_advanced=show_advanced)) 206 | count[p._id] += 1 207 | return cons 208 | 209 | 210 | def api_keys(keys: Set[str] = {"OPENAI_API_KEY"}) -> None: 211 | if all([k in os.environ for k in keys]): 212 | return 213 | key_names = {} 214 | 215 | with gr.Accordion(label="API Keys", elem_id="json", open=False): 216 | if "OPENAI_API_KEY" in keys and "OPENAI_API_KEY" not in os.environ: 217 | key_names["OPENAI_API_KEY"] = gr.Textbox( 218 | os.environ.get("OPENAI_API_KEY"), 219 | label="OpenAI Key", 220 | elem_id="json", 221 | type="password", 222 | ) 223 | gr.Markdown( 224 | """ 225 | * [OpenAI Key](https://platform.openai.com/account/api-keys) 226 | """ 227 | ) 228 | 229 | if "HF_KEY" in keys: 230 | gr.Markdown( 231 | """ 232 | * [Hugging Face Key](https://huggingface.co/settings/tokens) 233 | """ 234 | ) 235 | 236 | key_names["HF_KEY"] = gr.Textbox( 237 | os.environ.get("HF_KEY"), 238 | label="Hugging Face Key", 239 | elem_id="inner", 240 | type="password", 241 | ) 242 | 243 | if "SERP_KEY" in keys: 244 | gr.Markdown( 245 | """ 246 | * [Search Key](https://serpapi.com/users/sign_in) 247 | """ 248 | ) 249 | key_names["SERP_KEY"] = gr.Textbox( 250 | os.environ.get("SERP_KEY"), 251 | label="Search Key", 252 
| elem_id="inner", 253 | type="password", 254 | ) 255 | 256 | api_btn = gr.Button("Save") 257 | 258 | def api_run(data): # type: ignore 259 | for k, v in key_names.items(): 260 | if data[v] is not None and data[v] != "": 261 | os.environ[k] = data[v] 262 | return {} 263 | 264 | api_btn.click(api_run, inputs=set(key_names.values())) 265 | 266 | 267 | def show( 268 | prompt: Prompt[Any, Any, Any], 269 | examples: Union[List[str], List[Tuple[str]]] = [""], 270 | subprompts: List[Prompt[Any, Any, Any]] = [], 271 | fields: List[str] = [], 272 | initial_state: Any = None, 273 | out_type: str = "markdown", 274 | keys: Set[str] = {"OPENAI_API_KEY"}, 275 | description: str = "", 276 | code: str = "", 277 | css: str = "", 278 | show_advanced: bool = True, 279 | ) -> gr.Blocks: 280 | """ 281 | Constructs a gradio component to show a prompt chain. 282 | 283 | Args: 284 | prompt: A prompt or prompt chain to display. 285 | examples: A list of example inputs, either string or tuples of fields 286 | subprompts: The `Prompt` objects to display. 287 | fields: The names of the field input to the prompt. 288 | initial_state: For stateful prompts, the initial value. 289 | out_type: type of final output 290 | keys: user keys required 291 | description: description of the model 292 | code: code to display 293 | css : additional css 294 | show_advanced : show the "..." advanced elements 295 | 296 | Returns: 297 | Gradio block 298 | """ 299 | fields = [arg for arg in inspect.getfullargspec(prompt).args if arg != "state"] 300 | with gr.Blocks(css=CSS + "\n" + css, theme=gr.themes.Monochrome()) as demo: 301 | # API Keys 302 | api_keys() 303 | 304 | constructor = Constructor() 305 | 306 | # Collect all the inputs 307 | state = gr.State(initial_state) 308 | constructor = constructor.merge(Constructor(inputs={state}, outputs={state})) 309 | 310 | # Show the description 311 | gr.Markdown(description) 312 | 313 | # Build the top query box with one input for each field. 
314 | inputs = list([gr.Textbox(label=f) for f in fields]) 315 | examples = gr.Examples(examples=examples, inputs=inputs) 316 | query_btn = gr.Button(value="Run") 317 | constructor = constructor.add_inputs(inputs) 318 | 319 | with gr.Group(): 320 | # Intermediate prompt displays 321 | constructor = constructor.merge( 322 | chain_blocks(subprompts, show_advanced=show_advanced) 323 | ) 324 | 325 | # Final Output result 326 | # with gr.Accordion(label="✔️", elem_id="result"): 327 | # typ = gr.JSON if out_type == "json" else gr.Markdown 328 | # output = typ(elem_id="inner") 329 | 330 | def output_fn(data: Dict[Block, Any]) -> Dict[Block, Any]: 331 | final = data[final_output] 332 | return {state: final} # output: final} 333 | 334 | constructor = constructor.merge(Constructor([output_fn], set(), set())) 335 | 336 | def run(data): # type: ignore 337 | prompt_inputs = {k: data[v] for k, v in zip(fields, inputs)} 338 | if initial_state is not None: 339 | prompt_inputs["state"] = data[state] 340 | 341 | with start_chain("temp"): 342 | 343 | for output in prompt(**prompt_inputs).run_gen(): 344 | data[all_data] = dict(MinichainContext.prompt_store) 345 | data[final_output] = output 346 | yield constructor.fn(data) 347 | if output is not None: 348 | break 349 | yield constructor.fn(data) 350 | 351 | query_btn.click(run, inputs=constructor.inputs, outputs=constructor.outputs) 352 | 353 | if code: 354 | gr.Code(code, language="python", elem_id="inner") 355 | 356 | return demo 357 | -------------------------------------------------------------------------------- /minichain/templates/prompt.html.tpl: -------------------------------------------------------------------------------- 1 | 2 | 3 |
[HTML markup was lost in extraction; the recoverable content of prompt.html.tpl is:]

{{name}}

Input:
{% highlight 'python' %}
{{input}}
{% endhighlight %}

Full Prompt:
Prompt
{{prompt | safe}}

Response:
{{response | replace("\n", "<br>") | safe}}

Value:
{% highlight 'python' %}
{{output}}
{% endhighlight %}
36 | 37 | -------------------------------------------------------------------------------- /minichain/templates/type_prompt.pmpt.tpl: -------------------------------------------------------------------------------- 1 | You are a highly intelligent and accurate information extraction system. You take passage as input and your task is to find parts of the passage to answer questions. 2 | 3 | {% macro describe(typ) -%} 4 | {% for key, val in typ.items() %} 5 | You need to classify in to the following types for key: "{{key}}": 6 | {% if val == "str" %}String 7 | {% elif val == "int" %}Int {% else %} 8 | {% if val.get("_t_") == "list" %}List{{describe(val["t"])}}{% else %} 9 | 10 | {% for k, v in val.items() %}{{k}} 11 | {% endfor %} 12 | 13 | Only select from the above list. 14 | {% endif %} 15 | {%endif%} 16 | {% endfor %} 17 | {% endmacro -%} 18 | {{describe(typ)}} 19 | {% macro json(typ) -%}{% for key, val in typ.items() %}{% if val in ["str", "int"] or val.get("_t_") != "list" %}"{{key}}" : "{{key}}" {% else %} "{{key}}" : [{ {{json(val["t"])}} }] {% endif %}{{"" if loop.last else ", "}} {% endfor %}{% endmacro -%} 20 | 21 | [{ {{json(typ)}} }, ...] 22 | 23 | 24 | 25 | Make sure every output is exactly seen in the document. Find as many as you can. 26 | You need to output only JSON. -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Mini-Chain 2 | site_url: https://srush.github.io/minichain 3 | site_description: Mini-Chain 4 | site_author: Sasha Rush 5 | 6 | ### Repository 7 | repo_url: https://github.com/srush/minichain 8 | edit_uri: '' # comment this out to disable allowing editing of the docs from the website. 9 | remote_branch: gh-pages 10 | remote_name: origin 11 | 12 | ### Copyright 13 | copyright: | 14 | Maintained by Sasha Rush. 15 | 16 | ### Preview Controls 17 | use_directory_urls: true 18 | strict: false 19 | dev_addr: localhost:8000 20 | 21 | ### Configuration 22 | docs_dir: docs 23 | # watch a list of directories for changes 24 | # and automatically regenerate the docs 25 | watch: 26 | - minichain 27 | 28 | ### Theme 29 | theme: 30 | name: material 31 | include_sidebar: true 32 | custom_dir: docs/overrides 33 | #custom_dir: overrides 34 | palette: 35 | - media: "(prefers-color-scheme: light)" 36 | scheme: default 37 | primary: white 38 | accent: amber 39 | # toggle: 40 | # # icon: material/lightbulb-outline 41 | # icon: material/toggle-switch-off-outline 42 | # name: Switch to dark mode 43 | # - media: "(prefers-color-scheme: dark)" 44 | # scheme: slate 45 | # primary: white 46 | # accent: amber 47 | # toggle: 48 | # # icon: material/lightbulb 49 | # icon: material/toggle-switch 50 | # name: Switch to light mode 51 | features: 52 | - content.code.annotate 53 | - content.tabs.link 54 | # - header.autohide 55 | # - navigation.expand 56 | - navigation.indexes # @regular 57 | - navigation.instant # @regular | enables "instant-loading"; good for a very large docs repo. 58 | - navigation.sections # @regular | extending top level sections. 59 | - navigation.tabs # @regular | enables showing toplevel sections as tabs (horizontal). 60 | - navigation.tabs.sticky # @regular | keeps the tabs visible even when you have scrolled down. 61 | - navigation.top # @regular | adds a "back-to-top" is shown after the user scrolls down and then starts to come back up again. 
62 | - navigation.tracking # @insiders 63 | - search.highlight 64 | - search.share 65 | - search.suggest 66 | - toc.integrate: false # @regular | integrates the nav (on-left) with toc (on-right) and places the integrated nav+toc on-left. 67 | icon: 68 | # repo: fontawesome/brands/git-square 69 | repo: fontawesome/brands/git-alt 70 | # repo: fontawesome/brands/github 71 | # repo: fontawesome/brands/github-alt 72 | # repo: fontawesome/brands/github-square 73 | logo: https://user-images.githubusercontent.com/35882/218286642-67985b6f-d483-49be-825b-f62b72c469cd.png # img/icon-white.svg 74 | # favicon: logo.png # img/favicon.png 75 | font: 76 | text: Roboto 77 | code: Roboto Mono # Source Code Pro, JetBrains Mono, Roboto Mono 78 | language: en 79 | 80 | ### Plugins 81 | plugins: 82 | - exclude: 83 | glob: 84 | - 'examples/*.py' 85 | - search: 86 | indexing: 'full' # 'full' (default), 'sections', 'titles' 87 | - autorefs 88 | # - git-revision-date 89 | # macros must be placed after plugin: git-revision-date 90 | # - social # @insiders 91 | - mkdocs-jupyter: 92 | include_source: true 93 | ignore_h1_titles: true 94 | execute: false 95 | - mkdocstrings: 96 | handlers: 97 | python: 98 | options: 99 | heading_level: 2 100 | show_root_full_path: false 101 | show_root_heading: true 102 | show_source: false 103 | show_signature: true 104 | show_signature_annotations: true 105 | 106 | ### Extensions 107 | markdown_extensions: 108 | # - abbr 109 | # - admonition 110 | # - attr_list 111 | # - codehilite 112 | # - def_list 113 | # - extra 114 | # - footnotes 115 | # - meta 116 | # - md_in_html 117 | # - smarty 118 | # - tables 119 | # - toc 120 | ##! Controls: markdown.extensions 121 | - markdown.extensions.abbr # same as: - abbr 122 | - markdown.extensions.admonition # same as: - admonition 123 | - markdown.extensions.attr_list # same as: - attr_list 124 | - markdown.extensions.codehilite: # same as: - codehilite 125 | guess_lang: false 126 | - markdown.extensions.def_list # same as: - def_list 127 | - markdown.extensions.extra # same as: - extra 128 | - markdown.extensions.footnotes # same as: - footnotes 129 | - markdown.extensions.meta: # same as: - meta 130 | - markdown.extensions.md_in_html # same as: - md_in_html 131 | - markdown.extensions.smarty: # same as: - smarty 132 | smart_quotes: false 133 | - markdown.extensions.tables # same as: - tables 134 | - markdown.extensions.toc: # same as: - toc 135 | slugify: !!python/name:pymdownx.slugs.uslugify 136 | permalink: true 137 | toc_depth: 4 # default: 6 138 | #separator: "-" 139 | 140 | - markdown_include.include: 141 | base_path: docs 142 | 143 | ##! Controls: mdx 144 | - mdx_include: 145 | base_path: docs 146 | - mdx_truly_sane_lists: 147 | nested_indent: 2 148 | truly_sane: true 149 | 150 | ##! Controls: pymdownx 151 | - pymdownx.arithmatex: 152 | generic: true 153 | # - pymdownx.b64: 154 | # base_path: '.' 155 | - pymdownx.betterem: 156 | smart_enable: all # default: 'underscore' ; options: 'underscore', 'all', 'asterisk', or 'none' 157 | - pymdownx.caret: # "super^script^" will render as superscript text: superscript. 
158 | smart_insert: true # default: true 159 | insert: true # default: true 160 | superscript: true # default: true 161 | - pymdownx.critic 162 | - pymdownx.details 163 | - pymdownx.emoji: 164 | emoji_index: !!python/name:materialx.emoji.twemoji 165 | emoji_generator: !!python/name:materialx.emoji.to_svg 166 | - pymdownx.escapeall: 167 | hardbreak: false 168 | nbsp: false 169 | # Uncomment these 2 lines during development to more easily add highlights 170 | - pymdownx.highlight: 171 | use_pygments: true # this uses pygments 172 | linenums: false # Set "linenums" to true for enabling 173 | # code-block line-numbering 174 | # globally. 175 | # None: only enable line-numbering on a per code-block basis. 176 | # False: disable line-numbering globally. 177 | auto_title: false 178 | auto_title_map: { 179 | "Python Console Session": "Python", # lang: pycon 180 | } 181 | linenums_style: pymdownx-inline # table or pymdownx-inline 182 | - pymdownx.keys: 183 | separator: "\uff0b" 184 | - pymdownx.magiclink: 185 | repo_url_shortener: true 186 | repo_url_shorthand: true # 187 | social_url_shorthand: true 188 | social_url_shortener: true 189 | user: !ENV REPO_OWNER # sugatoray, danoneata (github userid) 190 | repo: chalk # 191 | normalize_issue_symbols: true 192 | - pymdownx.mark: 193 | smart_mark: true 194 | - pymdownx.pathconverter: 195 | base_path: 'chalk' # default: '' 196 | relative_path: '' # default '' 197 | absolute: true # default: false 198 | tags: 'a script img link object embed' 199 | - pymdownx.progressbar: 200 | level_class: true 201 | add_classes: '' 202 | #'progress-0plus progress-10plus progress-20plus progress-30plus progress-40plus progress-50plus progress-60plus progress-70plus progress-80plus progress-90plus progress-100plus' 203 | progress_increment: 10 204 | - pymdownx.saneheaders 205 | - pymdownx.superfences: 206 | # highlight_code: true # This was removed from pymdownx v9.0 207 | preserve_tabs: false 208 | disable_indented_code_blocks: false # default: false | set this to "true" 209 | # if you only use fenced code-blocks. 210 | custom_fences: 211 | - name: mermaid 212 | class: mermaid 213 | format: !!python/name:pymdownx.superfences.fence_code_format '' 214 | # - name: md-render 215 | # class: md-render 216 | # format: !!python/name:tools.pymdownx_md_render.md_sub_render 217 | - pymdownx.smartsymbols 218 | - pymdownx.snippets: 219 | base_path: 220 | - '.' 221 | - './docs_src' 222 | - './LICENSE' 223 | - './README.md' 224 | - './doc' # [TODO: move the contents of this folder to the docs folder] 225 | encoding: 'utf-8' # Encoding to use when reading in the snippets. 226 | check_paths: true # Make the build fail if a snippet can't be found. 227 | - pymdownx.striphtml 228 | - pymdownx.tabbed 229 | - pymdownx.tasklist: 230 | custom_checkbox: true 231 | - pymdownx.tasklist: 232 | custom_checkbox: true 233 | - pymdownx.tilde # ~~text~~ will render as strikethrough text. "sub~script" will render as subscript text: subscript. 
234 | 235 | 236 | 237 | 238 | 239 | ### Extra CSS 240 | extra_css: 241 | ## for: termynal (terminal animation) 242 | - assets/css-js/termynal/css/termynal.css 243 | - assets/css-js/termynal/css/custom.css 244 | ## for: pymdownx.progressbar 245 | - assets/css-js/general/css/progressbar.css 246 | # - assets/css-js/pymdownx-extras/css/extra.css # (for striped progress bar) 247 | ## for: mkdocs-tooltips 248 | - assets/css-js/mkdocs-tooltips/css/hint.min.css 249 | - assets/css-js/mkdocs-tooltips/css/custom.css 250 | ## for: mkdocs-material using highlight.js 251 | - https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.7.2/styles/default.min.css 252 | ## for: fastapi like side-theme 253 | - assets/css-js/fastapi/custom.css 254 | 255 | 256 | ### Extra JS 257 | extra_javascript: 258 | ## for: pymdownx.arithmatex 259 | ## - https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-MML-AM_CHTML 260 | ## for: markdown.extensions.tables 261 | - https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js 262 | - assets/css-js/general/js/tables.js 263 | ## for: termynal (terminal animation) 264 | - assets/css-js/termynal/js/termynal.js 265 | - assets/css-js/termynal/js/custom.js 266 | # Set the environment variable "FONTAWESOME_KIT" with the value of the kit. 267 | - !ENV FONTAWESOME_KIT 268 | ## for: lottiefiles 269 | - https://unpkg.com/@lottiefiles/lottie-player@latest/dist/lottie-player.js 270 | ## for: mkdocs-material using highlight.js 271 | - https://cdnjs.cloudflare.com/ajax/libs/highlight.js/10.7.2/highlight.min.js 272 | - assets/css-js/general/js/highlight-config.js 273 | ## for: mkdocs-markmap 274 | - https://unpkg.com/d3@6.7.0/dist/d3.min.js 275 | - https://unpkg.com/markmap-lib@0.11.5/dist/browser/index.min.js 276 | - https://unpkg.com/markmap-view@0.2.6/dist/index.min.j 277 | ## Others 278 | - https://polyfill.io/v3/polyfill.min.js?features=es6 279 | # - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js 280 | ## for: fastapi like side-theme 281 | - assets/css-js/fastapi/custom.js 282 | - assets/css-js/fastapi/chat.js 283 | 284 | ### Pages: Navigation 285 | 286 | ## @@ Begin NAVIGATION 287 | nav: 288 | - Home: index.md 289 | 290 | 291 | ## @@ End NAVIGATION 292 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | jupyter 2 | markdown-include>=0.6.0 3 | mdx-include>=1.4.1 4 | mdx_truly_sane_lists>=1.2 5 | 6 | 7 | ## API documentation building 8 | mkapi 9 | mkautodoc 10 | mkdocs 11 | mkdocs-autorefs>=0.3.1 12 | mkdocs-awesome-pages-plugin>=2.5.0 13 | 14 | 15 | ## Citations & bibliography 16 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#citations--bibliography 17 | mkdocs-bibtex>=1.0.0 18 | mkdocs-coverage>=0.2.4 19 | mkdocs-drawio-exporter>=0.8.0 20 | mkdocs-enumerate-headings-plugin>=0.4.5 21 | 22 | ## Navigation & page building 23 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#navigation--page-building 24 | mkdocs-exclude>=1.0.2 25 | mkdocs-gen-files>=0.3.3 26 | mkdocs-git-revision-date-plugin>=0.3.1 27 | 28 | 29 | ## Reusing content, snippets & includes 30 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#reusing-content-snippets--includes 31 | mkdocs-include-markdown-plugin>=3.2.3 32 | mkdocs-jupyter>=0.18.2 33 | 34 | 35 | ## Images, Tables, Charts & Graphs 36 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#images-tables-charts--graphs 37 | 
mkdocs-kroki-plugin>=0.2.0 38 | 39 | 40 | ## Git repos and info 41 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#git-repos--info 42 | mkdocs-macros-plugin>=0.6.0 43 | mkdocs-markdownextradata-plugin>=0.2.4 44 | mkdocs-markdownextradata-plugin>=0.2.4 45 | mkdocs-markmap>=2.1.2 46 | mkdocs-material==8.1.3 47 | mkdocs-material-extensions>=1.0.3 48 | 49 | 50 | ## HTML processing & CSS styling 51 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#html-processing--css-styling 52 | mkdocs-minify-plugin>=0.4.1 53 | 54 | 55 | ## PDF & site conversion 56 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#pdf--site-conversion 57 | # https://doc.courtbouillon.org/weasyprint/latest/first_steps.html#linux 58 | mkdocs-pdf-export-plugin>=0.5.9 59 | 60 | 61 | ## Links & references 62 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#links--references 63 | mkdocs-redirects>=1.0.3 64 | mkdocs-table-reader-plugin>=0.6 65 | 66 | 67 | ## Other 68 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#other 69 | mkdocs-tooltips>=0.1.0 70 | mkdocs-tooltipster-links-plugin>=0.1.0 71 | mkdocstrings[python]>=0.16.2 72 | 73 | 74 | ## Code execution, variables & templating 75 | # source: https://github.com/mkdocs/mkdocs/wiki/MkDocs-Plugins#code-execution-variables--templating 76 | pydoc-markdown==4.6.3 77 | pygments 78 | pymdown-extensions>=9.0 79 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | async-openai 2 | datasets 3 | eliot 4 | eliot-tree 5 | faiss-cpu 6 | google-search-results 7 | gradio==4.7 8 | huggingface-hub 9 | jinja2 10 | jinja2-highlight 11 | openai==0.28 12 | parsita==1.7.1 13 | trio 14 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bdist_wheel] 2 | # what is this for? 3 | # - https://packaging.python.org/en/latest/guides/distributing-packages-using-setuptools/#wheels 4 | universal=0 5 | 6 | [metadata] 7 | license_file = LICENSE 8 | 9 | [black] 10 | line-length = 79 11 | # exclude = ''' 12 | # /( 13 | # \.archive 14 | # | \.git 15 | # | examples/* 16 | # | \.hg 17 | # | \.mypy_cache 18 | # | \.tox 19 | # | \.venv 20 | # | \.vscode 21 | # | _build 22 | # | buck-out 23 | # | build 24 | # | dist 25 | # | migrations 26 | # | site 27 | # )/ 28 | # ''' 29 | 30 | [isort] 31 | # make it compatible with black 32 | profile = black 33 | # # Make sure this matches `*.py` in .editorconfig 34 | # ensure_newline_before_comments = true 35 | # force_single_line = true 36 | # lines_after_imports = 3 37 | # include_trailing_comma = true 38 | # use_parentheses = true 39 | 40 | [flake8] 41 | per-file-ignores=minichain/__init__.py: F401 42 | max-line-length = 88 43 | extend-ignore = E203 44 | 45 | [darglint] 46 | ##? Source: https://github.com/terrencepreilly/darglint 47 | ## Ignore properties 48 | ignore_properties = 1 49 | ## Ignore private methods 50 | ignore_regex = ^_(.*) 51 | ## Use message template 52 | # message_template = {msg_id}@{path}:{line} 53 | ## Docstring style to use: 54 | # - google (default) 55 | # - sphinx 56 | # - numpy 57 | docstring_style = google 58 | ## How strict? 59 | # short: One-line descriptions are acceptable; anything 60 | # more and the docstring will be fully checked. 
61 | # 62 | # long: One-line descriptions and descriptions without 63 | # arguments/returns/yields/etc. sections will be 64 | # allowed. Anything more, and the docstring will 65 | # be fully checked. 66 | # 67 | # full: (Default) Docstrings will be fully checked. 68 | strictness = long 69 | ## Ignore common exceptions 70 | # ignore_raise = ValueError,MyCustomError 71 | ## Ignore Specific Error Codes 72 | # Example: ignore = DAR402,DAR103 73 | #------------------------------------------------------------------------ 74 | # DAR001 # The docstring was not parsed correctly due to a syntax error. 75 | # DAR002 # An argument/exception lacks a description 76 | # DAR003 # A line is under-indented or over-indented. 77 | # DAR004 # The docstring contains an extra newline where it shouldn't. 78 | # DAR005 # The item contains a type section (parentheses), but no type. 79 | # DAR101 # The docstring is missing a parameter in the definition. 80 | # DAR102 # The docstring contains a parameter not in function. 81 | # DAR103 # The docstring parameter type doesn't match function. 82 | # DAR104 # (disabled) The docstring parameter has no type specified 83 | # DAR105 # The docstring parameter type is malformed. 84 | # DAR201 # The docstring is missing a return from definition. 85 | # DAR202 # The docstring has a return not in definition. 86 | # DAR203 # The docstring parameter type doesn't match function. 87 | # DAR301 # The docstring is missing a yield present in definition. 88 | # DAR302 # The docstring has a yield not in definition. 89 | # DAR401 # The docstring is missing an exception raised. 90 | # DAR402 # The docstring describes an exception not explicitly raised. 91 | # DAR501 # The docstring describes a variable which is not defined. 92 | #------------------------------------------------------------------------ 93 | ignore = DAR103 94 | 95 | [mypy] 96 | strict = true 97 | warn_unreachable = true 98 | pretty = true 99 | show_column_numbers = true 100 | show_error_codes = true 101 | show_error_context = true 102 | 103 | [mypy-minichain] 104 | implicit_reexport = true 105 | [mypy-minichain.shapes] 106 | implicit_reexport = true 107 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from setuptools import find_packages, setup 4 | 5 | LICENSE: str = "MIT" 6 | README: str = pathlib.Path("README.md").read_text(encoding="utf-8") 7 | 8 | LIBNAME: str = "minichain" 9 | 10 | setup( 11 | name=LIBNAME, 12 | version="0.3.1", 13 | packages=find_packages( 14 | include=["minichain", "minichain*"], 15 | exclude=["examples", "docs", "test*"], 16 | ), 17 | description="A tiny library for large language models", 18 | extras_require={}, 19 | long_description=README, 20 | long_description_content_type="text/markdown", 21 | include_package_data=True, 22 | package_data={"minichain": ["templates/*.tpl"]}, 23 | author="Sasha Rush", 24 | author_email="srush.research@gmail.com", 25 | url="https://github.com/srush/minichain", 26 | project_urls={ 27 | "Documentation": "https://srush.github.io/minichain", 28 | "Source Code": "https://github.com/srush/minichain", 29 | "Issue Tracker": "https://github.com/srush/minichain/issues", 30 | }, 31 | license=LICENSE, 32 | license_files=("LICENSE",), 33 | classifiers=[ 34 | "Intended Audience :: Science/Research", 35 | "Operating System :: OS Independent", 36 | "Programming Language :: Python :: 3", 37 | "Programming Language :: Python :: 3.7", 
38 | "Programming Language :: Python :: 3.8", 39 | "Programming Language :: Python :: 3.9", 40 | f"License :: OSI Approved :: {LICENSE} License", 41 | "Topic :: Scientific/Engineering", 42 | ], 43 | install_requires=[ 44 | "manifest-ml", 45 | "datasets", 46 | "gradio", 47 | "faiss-cpu", 48 | "eliot", 49 | "eliot-tree", 50 | "google-search-results", 51 | "jinja2", 52 | "jinja2-highlight", 53 | "openai==0.28", 54 | "trio", 55 | ], 56 | ) 57 | --------------------------------------------------------------------------------