├── .gitignore ├── .gitmodules ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── recipe-template └── README.md └── recipes ├── confluent └── README.md ├── crewai ├── .streamlit │ └── secrets.toml ├── README.md ├── requirements.txt ├── streamlit.app.py ├── tools │ ├── __init__.py │ ├── browser_tools.py │ ├── calculator_tools.py │ └── search_tools.py ├── trip_agents.py └── trip_tasks.py ├── llamaindex ├── .gitignore ├── README.md ├── arc_finetuning_st │ ├── __init__.py │ ├── cli │ │ ├── __init__.py │ │ ├── command_line.py │ │ ├── evaluation.py │ │ └── finetune.py │ ├── finetuning │ │ ├── __init__.py │ │ ├── finetuning_example.py │ │ └── templates.py │ ├── streamlit │ │ ├── __init__.py │ │ ├── app.py │ │ └── controller.py │ └── workflows │ │ ├── __init__.py │ │ ├── arc_task_solver.py │ │ ├── events.py │ │ ├── models.py │ │ └── prompts.py ├── poetry.lock └── pyproject.toml ├── ollama ├── requirements.txt └── streamlit_app.py ├── openai └── README.md ├── replicate ├── .streamlit │ └── secrets_template.toml ├── README.md ├── requirements.txt └── streamlit_app.py ├── replit └── README.md ├── trulens ├── .env.template ├── README.md ├── app.py ├── base.py ├── feedback.py ├── requirements.txt └── vector_store.py └── weaviate ├── .streamlit └── secrets_template.toml ├── README.md ├── demo_app.py ├── helpers ├── add_data.py ├── data │ └── 1950_2024_movies_info.json ├── demo_movie_query.graphql └── verify_data.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Editors 2 | .vscode/ 3 | 4 | # Mac/OSX 5 | .DS_Store 6 | 7 | # Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # Environments 14 | replicatevenv/ 15 | weaviatevenv/ 16 | 17 | # Secrets 18 | *secrets.toml 19 | 20 | # Recipe specific 21 | recipes/weaviate/helpers/data/posters 22 | 23 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "recipes/replit/ab-testing"] 2 | path = recipes/replit/ab-testing 3 | url = https://github.com/mattppal/ab-testing/ 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | - Using welcoming and inclusive language 18 | - Being respectful of differing viewpoints and experiences 19 | - Gracefully accepting constructive criticism 20 | - Focusing on what is best for the community 21 | - Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | - The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | - Trolling, insulting/derogatory comments, and personal or political attacks 28 | - Public or private harassment 29 | - Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | - Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at hello@streamlit.io. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 📖 Streamlit Cookbook 2 | 3 | Streamlit Cookbook is a compilation of Streamlit app templates that you can use as boilerplate code to jumpstart your own app creation endeavors. 4 | 5 | In a nutshell, each recipe allows you to quickly spin up a working Streamlit app that implements the library/tool of interest. 6 | 7 | ## 🧑‍🍳 How to use this cookbook? 8 | For your convenience, we've categorized tool integrations by library name. For instance, a Streamlit app built using Replicate lives in the `replicate` sub-folder, i.e., `recipes/replicate`. 9 | 10 | ## 🍪 List of Recipes 11 | 12 | ### AI Recipes 13 | | Tool | Description | Resources | 14 | | -- | -- | -- | 15 | | [CrewAI](https://github.com/streamlit/cookbook/tree/main/recipes/crewai) | Build the VacAIgent app where agents collaboratively decide on cities and craft a complete itinerary for your trip based on specified preferences. | | 16 | | [Ollama](https://github.com/streamlit/cookbook/tree/main/recipes/ollama) | Build a Streamlit Ollama chatbot using a local LLM model. | | 17 | | [Replicate](https://github.com/streamlit/cookbook/tree/main/recipes/replicate) | Build a Streamlit Replicate chatbot that allows users to select an LLM model of their choice for response generation. | [Video](https://youtu.be/zsQ7EN10zj8), [Blog](https://blog.streamlit.io/how-to-recommendation-app-vector-database-weaviate/) | 18 | | [Trulens](https://github.com/streamlit/cookbook/tree/main/recipes/trulens) | Ask questions about the Pacific Northwest and get TruLens evaluations on the app response. | | 19 | | [Weaviate](https://github.com/streamlit/cookbook/tree/main/recipes/weaviate) | Build a movie explorer app that leverages the Weaviate vector database. | [Video](https://youtu.be/SQD-aWlhqvM), [Blog](https://blog.streamlit.io/how-to-recommendation-app-vector-database-weaviate/) | 20 | 21 | ### Data Recipes 22 | | Tool | Description | Resources | 23 | | -- | -- | -- | 24 | | [Replit](https://github.com/streamlit/cookbook/tree/main/recipes/replit) | Build a statistical app that allows users to select a suitable sample size for A/B testing. | [Video](https://youtu.be/CJ9E0Sm_hy4) | 25 | 26 | 27 | ## 🏃 How to run the app?
28 | To run the app, start by installing Streamlit on the command line using the `pip install streamlit` command. Next, change into the recipe's directory and launch the app via `streamlit run streamlit_app.py`. 29 | 30 | ## 📥 How to contribute to this repo? 31 | We welcome contributions to this repository. Contributions can come in many forms, whether it be suggesting a recipe idea (please suggest it on the [issues page](https://github.com/streamlit/cookbook/issues)), writing a new recipe, improving a pre-existing recipe, or fixing a typo/error. 32 | 33 | Simply submit a pull request [here](https://github.com/streamlit/cookbook/pulls) to get started. 34 | -------------------------------------------------------------------------------- /recipe-template/README.md: -------------------------------------------------------------------------------- 1 | # How to run the demo app (Optional: Add what kind of app it is/the tech the app is built with) 2 | ## 🔍 Overview 3 | This is a recipe for a [TODO: Add the kind of app it is]. TODO: Add one sentence describing what the app does. 4 | 5 | Other ways to explore this recipe: 6 | * [Deployed app](TODO: URL of deployed app) 7 | * [Blog post](TODO: URL of blog post) 8 | * [Video](TODO: URL of video) 9 | 10 | ## 📝 Prerequisites 11 | * Python >=3.8, !=3.9.7 12 | * TODO: List additional prerequisites 13 | 14 | ## 🌎 Environment setup 15 | ### Local setup 16 | 1. Clone the Cookbook repo: `git clone https://github.com/streamlit/cookbook.git` 17 | 2. From the Cookbook root directory, change directory into the recipe: `cd recipes/TODO: Add recipe directory` 18 | 3. Add secrets to the `.streamlit/secrets_template.toml` file 19 | 4. Update the filename from `secrets_template.toml` to `secrets.toml`: `mv .streamlit/secrets_template.toml .streamlit/secrets.toml` 20 | 21 | (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).) 22 | 5. Create a virtual environment: `python -m venv TODO: Add name of virtual environment` 23 | 6. Activate the virtual environment: `source TODO: Add name of virtual environment/bin/activate` 24 | 7. Install the dependencies: `pip install -r requirements.txt` 25 | 26 | ### GitHub Codespaces setup 27 | 1. Create a new codespace by selecting the `Codespaces` option from the `Code` button 28 | 2. Once the codespace has been generated, add your secrets to the `recipes/TODO: Add recipe directory/.streamlit/secrets_template.toml` file 29 | 3. Update the filename from `secrets_template.toml` to `secrets.toml` 30 | 31 | (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).) 32 | 4. From the Cookbook root directory, change directory into the recipe: `cd recipes/TODO: Add recipe directory` 33 | 5.
Install the dependencies: `pip install -r requirements.txt` 34 | -------------------------------------------------------------------------------- /recipes/confluent/README.md: -------------------------------------------------------------------------------- 1 | # Confluent recipe 2 | 3 | - 🕹️ Demo app https://flink-st-kafka.streamlit.app/ 4 | - 🐙 GitHub repo https://github.com/confluentinc/demo-scene/tree/master/flink-streamlit 5 | -------------------------------------------------------------------------------- /recipes/crewai/.streamlit/secrets.toml: -------------------------------------------------------------------------------- 1 | SERPER_API_KEY="API_KEY_HERE" # https://serper.dev/ (free tier) 2 | BROWSERLESS_API_KEY="API_KEY_HERE" # https://www.browserless.io/ (free tier) 3 | OPENAI_API_KEY="API_KEY_HERE" -------------------------------------------------------------------------------- /recipes/crewai/README.md: -------------------------------------------------------------------------------- 1 | ## CrewAI Framework 2 | 3 | CrewAI simplifies the orchestration of role-playing AI agents. In VacAIgent, these agents collaboratively decide on cities and craft a complete itinerary for your trip based on specified preferences, all accessible via a streamlined Streamlit user interface. 4 | 5 | ## Running the Application 6 | 7 | - **Configure Environment**: Set up the environment variables for [Browserless](https://www.browserless.io/), [Serper](https://serper.dev/), and [OpenAI](https://openai.com/). Add your keys to the `.streamlit/secrets.toml` file. 8 | 9 | - **Install Dependencies**: Execute `pip install -r requirements.txt` in your terminal. 10 | - **Launch the App**: Run `streamlit run streamlit.app.py` to start the Streamlit interface. 11 | 12 | ★ **Disclaimer**: The application uses GPT-4 by default. Ensure you have access to OpenAI's API and be aware of the associated costs. 13 | 14 | ## Details & Explanation 15 | 16 | - **Components**: 17 | - `./trip_tasks.py`: Contains task prompts for the agents. 18 | - `./trip_agents.py`: Manages the creation of agents. 19 | - `./tools directory`: Houses tool classes used by agents. 20 | - `./streamlit.app.py`: The heart of the Streamlit app. 21 | 22 | ## Using GPT-4o mini 23 | 24 | To switch from GPT-4 to GPT-4o-mini, pass the `llm` argument to the agent constructor: 25 | 26 | ```python 27 | from langchain.chat_models import ChatOpenAI 28 | 29 | llm = ChatOpenAI(model='gpt-4o-mini') # See more OpenAI models at https://platform.openai.com/docs/models/ 30 | 31 | class TripAgents: 32 | # ... existing methods 33 | 34 | def local_expert(self): 35 | return Agent( 36 | role='Local Expert', 37 | goal='Provide insights about the selected city', 38 | tools=[SearchTools.search_internet, BrowserTools.scrape_and_summarize_website], 39 | llm=llm, 40 | verbose=True 41 | ) 42 | 43 | ``` 44 | 45 | ## Using Local Models with Ollama 46 | 47 | For enhanced privacy and customization, you can integrate local models like Ollama: 48 | 49 | ### Setting Up Ollama 50 | 51 | - **Installation**: Follow Ollama's guide for installation. 52 | - **Configuration**: Customize the model as per your requirements. 53 | 54 | ### Integrating Ollama with CrewAI 55 | 56 | Pass the Ollama model to agents in the CrewAI framework: 57 | 58 | ```python 59 | from langchain.llms import Ollama 60 | 61 | ollama_model = Ollama(model="agent") 62 | 63 | class TripAgents: 64 | # ...
existing methods 65 | 66 | def local_expert(self): 67 | return Agent( 68 | role='Local Expert', 69 | tools=[SearchTools.search_internet, BrowserTools.scrape_and_summarize_website], 70 | llm=ollama_model, 71 | verbose=True 72 | ) 73 | 74 | ``` 75 | 76 | -------------------------------------------------------------------------------- /recipes/crewai/requirements.txt: -------------------------------------------------------------------------------- 1 | crewai 2 | streamlit 3 | openai 4 | unstructured 5 | langchain 6 | pyowm 7 | tools -------------------------------------------------------------------------------- /recipes/crewai/streamlit.app.py: -------------------------------------------------------------------------------- 1 | from crewai import Crew 2 | from trip_agents import TripAgents, StreamToExpander 3 | from trip_tasks import TripTasks 4 | import streamlit as st 5 | import datetime 6 | import sys 7 | 8 | st.set_page_config(page_icon="✈️", layout="wide") 9 | 10 | 11 | class TripCrew: 12 | def __init__(self, origin, cities, date_range, interests): 13 | self.cities = cities 14 | self.origin = origin 15 | self.interests = interests 16 | self.date_range = date_range 17 | self.output_placeholder = st.empty() 18 | 19 | def run(self): 20 | agents = TripAgents() 21 | tasks = TripTasks() 22 | 23 | city_selector_agent = agents.city_selection_agent() 24 | local_expert_agent = agents.local_expert() 25 | travel_concierge_agent = agents.travel_concierge() 26 | 27 | identify_task = tasks.identify_task( 28 | city_selector_agent, 29 | self.origin, 30 | self.cities, 31 | self.interests, 32 | self.date_range 33 | ) 34 | 35 | gather_task = tasks.gather_task( 36 | local_expert_agent, 37 | self.origin, 38 | self.interests, 39 | self.date_range 40 | ) 41 | 42 | plan_task = tasks.plan_task( 43 | travel_concierge_agent, 44 | self.origin, 45 | self.interests, 46 | self.date_range 47 | ) 48 | 49 | crew = Crew( 50 | agents=[ 51 | city_selector_agent, local_expert_agent, travel_concierge_agent 52 | ], 53 | tasks=[identify_task, gather_task, plan_task], 54 | verbose=True 55 | ) 56 | 57 | result = crew.kickoff() 58 | self.output_placeholder.markdown(result) 59 | 60 | return result 61 | 62 | 63 | if __name__ == "__main__": 64 | st.title("AI Travel Agent") 65 | st.subheader("Let AI agents plan your next vacation!", 66 | divider="rainbow", anchor=False) 67 | 68 | import datetime 69 | 70 | today = datetime.datetime.now().date() 71 | next_year = today.year + 1 72 | jan_16_next_year = datetime.date(next_year, 1, 10) 73 | 74 | with st.sidebar: 75 | st.header("👇 Enter your trip details") 76 | with st.form("my_form"): 77 | location = st.text_input( 78 | "Current location", placeholder="San Mateo, CA") 79 | cities = st.text_input( 80 | "Destination", placeholder="Bali, Indonesia") 81 | date_range = st.date_input( 82 | "Travel date range", 83 | min_value=today, 84 | value=(today, jan_16_next_year + datetime.timedelta(days=6)), 85 | format="MM/DD/YYYY", 86 | ) 87 | interests = st.text_area("High level interests and hobbies or extra details about your trip?", 88 | placeholder="2 adults who love swimming, dancing, hiking, and eating") 89 | 90 | submitted = st.form_submit_button("Submit") 91 | 92 | if submitted: 93 | with st.status("🤖 **Agents at work...**", state="running", expanded=True) as status: 94 | with st.container(height=500, border=False): 95 | sys.stdout = StreamToExpander(st) 96 | trip_crew = TripCrew(location, cities, date_range, interests) 97 | result = trip_crew.run() 98 | status.update(label="✅ Trip Plan Ready!", 
99 | state="complete", expanded=False) 100 | 101 | st.subheader("Here is your Trip Plan", anchor=False, divider="rainbow") 102 | st.markdown(result) -------------------------------------------------------------------------------- /recipes/crewai/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/crewai/tools/__init__.py -------------------------------------------------------------------------------- /recipes/crewai/tools/browser_tools.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | import streamlit as st 4 | from crewai import Agent, Task 5 | from langchain.tools import tool 6 | from unstructured.partition.html import partition_html 7 | 8 | 9 | class BrowserTools(): 10 | 11 | @tool("Scrape website content") 12 | def scrape_and_summarize_website(website): 13 | """Useful to scrape and summarize a website's content""" 14 | url = f"https://chrome.browserless.io/content?token={st.secrets['BROWSERLESS_API_KEY']}" 15 | payload = json.dumps({"url": website}) 16 | headers = {'cache-control': 'no-cache', 'content-type': 'application/json'} 17 | response = requests.request("POST", url, headers=headers, data=payload) 18 | elements = partition_html(text=response.text) 19 | content = "\n\n".join([str(el) for el in elements]) 20 | content = [content[i:i + 8000] for i in range(0, len(content), 8000)] 21 | summaries = [] 22 | for chunk in content: 23 | agent = Agent( 24 | role='Principal Researcher', 25 | goal= 26 | 'Do amazing research and summaries based on the content you are working with', 27 | backstory= 28 | "You're a Principal Researcher at a big company and you need to do research about a given topic.", 29 | allow_delegation=False) 30 | task = Task( 31 | agent=agent, 32 | description= 33 | f'Analyze and summarize the content below, make sure to include the most relevant information in the summary, return only the summary nothing else.\n\nCONTENT\n----------\n{chunk}' 34 | ) 35 | summary = task.execute() 36 | summaries.append(summary) 37 | return "\n\n".join(summaries) -------------------------------------------------------------------------------- /recipes/crewai/tools/calculator_tools.py: -------------------------------------------------------------------------------- 1 | from langchain.tools import tool 2 | 3 | 4 | class CalculatorTools(): 5 | 6 | @tool("Make a calculation") 7 | def calculate(operation): 8 | """Useful to perform any mathematical calculations, 9 | like sum, minus, multiplication, division, etc.
10 | The input to this tool should be a mathematical 11 | expression, a couple of examples are `200*7` or `5000/2*10` 12 | """ 13 | return eval(operation)  # note: eval() executes arbitrary code; acceptable for this demo, unsafe for untrusted input -------------------------------------------------------------------------------- /recipes/crewai/tools/search_tools.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | import streamlit as st 4 | from langchain.tools import tool 5 | 6 | 7 | class SearchTools(): 8 | 9 | @tool("Search the internet") 10 | def search_internet(query): 11 | """Useful to search the internet 12 | about a given topic and return relevant results""" 13 | top_result_to_return = 4 14 | url = "https://google.serper.dev/search" 15 | payload = json.dumps({"q": query}) 16 | headers = { 17 | 'X-API-KEY': st.secrets['SERPER_API_KEY'], 18 | 'content-type': 'application/json' 19 | } 20 | response = requests.request("POST", url, headers=headers, data=payload) 21 | # check if there is an organic key 22 | if 'organic' not in response.json(): 23 | return "Sorry, I couldn't find anything about that, there could be an error with your Serper API key." 24 | else: 25 | results = response.json()['organic'] 26 | string = [] 27 | for result in results[:top_result_to_return]: 28 | try: 29 | string.append('\n'.join([ 30 | f"Title: {result['title']}", f"Link: {result['link']}", 31 | f"Snippet: {result['snippet']}", "\n-----------------" 32 | ])) 33 | except KeyError: 34 | continue 35 | 36 | return '\n'.join(string) -------------------------------------------------------------------------------- /recipes/crewai/trip_agents.py: -------------------------------------------------------------------------------- 1 | from crewai import Agent 2 | import re 3 | import streamlit as st 4 | from langchain_community.llms import OpenAI 5 | from tools.browser_tools import BrowserTools 6 | from tools.calculator_tools import CalculatorTools 7 | from tools.search_tools import SearchTools 8 | 9 | 10 | class TripAgents(): 11 | def city_selection_agent(self): 12 | return Agent( 13 | role='City Selection Expert', 14 | goal='Select the best city based on weather, season, and prices', 15 | backstory='An expert in analyzing travel data to pick ideal destinations', 16 | tools=[ 17 | SearchTools.search_internet, 18 | BrowserTools.scrape_and_summarize_website, 19 | ], 20 | verbose=True, 21 | ) 22 | 23 | def local_expert(self): 24 | return Agent( 25 | role='Local Expert at this city', 26 | goal='Provide the BEST insights about the selected city', 27 | backstory="""A knowledgeable local guide with extensive information 28 | about the city, its attractions and customs""", 29 | tools=[ 30 | SearchTools.search_internet, 31 | BrowserTools.scrape_and_summarize_website, 32 | ], 33 | verbose=True, 34 | ) 35 | 36 | def travel_concierge(self): 37 | return Agent( 38 | role='Amazing Travel Concierge', 39 | goal="""Create the most amazing travel itineraries with budget and 40 | packing suggestions for the city""", 41 | backstory="""Specialist in travel planning and logistics with 42 | decades of experience""", 43 | tools=[ 44 | SearchTools.search_internet, 45 | BrowserTools.scrape_and_summarize_website, 46 | CalculatorTools.calculate, 47 | ], 48 | verbose=True, 49 | ) 50 | 51 | class StreamToExpander: 52 | def __init__(self, expander): 53 | self.expander = expander 54 | self.buffer = [] 55 | self.colors = ['red', 'green', 'blue', 'orange'] # Define a list of colors 56 | self.color_index = 0 # Initialize color index 57 | 58 | def write(self, data): 59 | # Filter out
ANSI escape codes using a regular expression 60 | cleaned_data = re.sub(r'\x1B\[[0-9;]*[mK]', '', data) 61 | 62 | # Check if the data contains 'task' information 63 | task_match_object = re.search(r'\"task\"\s*:\s*\"(.*?)\"', cleaned_data, re.IGNORECASE) 64 | task_match_input = re.search(r'task\s*:\s*([^\n]*)', cleaned_data, re.IGNORECASE) 65 | task_value = None 66 | if task_match_object: 67 | task_value = task_match_object.group(1) 68 | elif task_match_input: 69 | task_value = task_match_input.group(1).strip() 70 | 71 | if task_value: 72 | st.toast(":robot_face: " + task_value) 73 | 74 | # Check if the text contains the specified phrase and apply color 75 | if "Entering new CrewAgentExecutor chain" in cleaned_data: 76 | # Apply different color and switch color index 77 | self.color_index = (self.color_index + 1) % len(self.colors) # Increment color index and wrap around if necessary 78 | 79 | cleaned_data = cleaned_data.replace("Entering new CrewAgentExecutor chain", f":{self.colors[self.color_index]}[Entering new CrewAgentExecutor chain]") 80 | 81 | if "City Selection Expert" in cleaned_data: 82 | # Apply different color 83 | cleaned_data = cleaned_data.replace("City Selection Expert", f":{self.colors[self.color_index]}[City Selection Expert]") 84 | if "Local Expert at this city" in cleaned_data: 85 | cleaned_data = cleaned_data.replace("Local Expert at this city", f":{self.colors[self.color_index]}[Local Expert at this city]") 86 | if "Amazing Travel Concierge" in cleaned_data: 87 | cleaned_data = cleaned_data.replace("Amazing Travel Concierge", f":{self.colors[self.color_index]}[Amazing Travel Concierge]") 88 | if "Finished chain." in cleaned_data: 89 | cleaned_data = cleaned_data.replace("Finished chain.", f":{self.colors[self.color_index]}[Finished chain.]") 90 | 91 | self.buffer.append(cleaned_data) 92 | if "\n" in data: 93 | self.expander.markdown(''.join(self.buffer), unsafe_allow_html=True) 94 | self.buffer = [] -------------------------------------------------------------------------------- /recipes/crewai/trip_tasks.py: -------------------------------------------------------------------------------- 1 | from crewai import Task 2 | from textwrap import dedent 3 | 4 | 5 | class TripTasks(): 6 | 7 | def identify_task(self, agent, origin, cities, interests, range): 8 | return Task(description=dedent(f""" 9 | Analyze and select the best city for the trip based 10 | on specific criteria such as weather patterns, seasonal 11 | events, and travel costs. This task involves comparing 12 | multiple cities, considering factors like current weather 13 | conditions, upcoming cultural or seasonal events, and 14 | overall travel expenses. 15 | 16 | Your final answer must be a detailed 17 | report on the chosen city, and everything you found out 18 | about it, including the actual flight costs, weather 19 | forecast and attractions. 20 | {self.__tip_section()} 21 | 22 | Traveling from: {origin} 23 | City Options: {cities} 24 | Trip Date: {range} 25 | Traveler Interests: {interests} 26 | """), 27 | expected_output="A detailed report on the chosen city with flight costs, weather forecast, and attractions.", 28 | agent=agent) 29 | 30 | def gather_task(self, agent, origin, interests, range): 31 | return Task(description=dedent(f""" 32 | As a local expert on this city you must compile an 33 | in-depth guide for someone traveling there and wanting 34 | to have THE BEST trip ever! 35 | Gather information about key attractions, local customs, 36 | special events, and daily activity recommendations. 
37 | Find the best spots to go to, the kind of place only a 38 | local would know. 39 | This guide should provide a thorough overview of what 40 | the city has to offer, including hidden gems, cultural 41 | hotspots, must-visit landmarks, weather forecasts, and 42 | high-level costs. 43 | 44 | The final answer must be a comprehensive city guide, 45 | rich in cultural insights and practical tips, 46 | tailored to enhance the travel experience. 47 | {self.__tip_section()} 48 | 49 | Trip Date: {range} 50 | Traveling from: {origin} 51 | Traveler Interests: {interests} 52 | """), 53 | expected_output="A comprehensive city guide with cultural insights and practical tips.", 54 | agent=agent) 55 | 56 | def plan_task(self, agent, origin, interests, range): 57 | return Task(description=dedent(f""" 58 | Expand this guide into a full travel 59 | itinerary for this time {range} with detailed per-day plans, including 60 | weather forecasts, places to eat, packing suggestions, 61 | and a budget breakdown. 62 | 63 | You MUST suggest actual places to visit, actual hotels 64 | to stay at, and actual restaurants to go to. 65 | 66 | This itinerary should cover all aspects of the trip, 67 | from arrival to departure, integrating the city guide 68 | information with practical travel logistics. 69 | 70 | Your final answer MUST be a complete expanded travel plan, 71 | formatted as markdown, encompassing a daily schedule, 72 | anticipated weather conditions, recommended clothing and 73 | items to pack, and a detailed budget, ensuring THE BEST 74 | TRIP EVER. Be specific and give a reason why you picked 75 | each place and what makes it special! {self.__tip_section()} 76 | 77 | Trip Date: {range} 78 | Traveling from: {origin} 79 | Traveler Interests: {interests} 80 | """), 81 | expected_output="A complete 7-day travel plan, formatted as markdown, with a daily schedule and budget.", 82 | agent=agent) 83 | 84 | def __tip_section(self): 85 | return "If you do your BEST WORK, I'll tip you $100 and grant you any wish you want!" -------------------------------------------------------------------------------- /recipes/llamaindex/.gitignore: -------------------------------------------------------------------------------- 1 | .env.* 2 | .env 3 | index.html.* 4 | index.html 5 | task_results 6 | .ipynb_checkpoints/ 7 | secrets.yaml 8 | Dockerfile.local 9 | docker-compose.local.yml 10 | pyproject.local.toml 11 | __pycache__ 12 | data 13 | notebooks 14 | .DS_Store 15 | finetuning_examples 16 | finetuning_assets 17 | -------------------------------------------------------------------------------- /recipes/llamaindex/README.md: -------------------------------------------------------------------------------- 1 | # ARC Task (LLM) Solver With Human Input 2 | 3 | The Abstraction and Reasoning Corpus ([ARC](https://github.com/fchollet/ARC-AGI)) for Artificial General Intelligence 4 | benchmark aims to measure an AI system's ability to efficiently learn new skills. 5 | Each task within the ARC benchmark contains a unique puzzle that systems 6 | attempt to solve. Currently, the best AI systems achieve 34% solve rates, whereas 7 | humans are able to achieve 85% ([source](https://www.kaggle.com/competitions/arc-prize-2024/overview/prizes)). 8 | 9 |

10 | cover 11 |

12 | 13 | Motivated by this large disparity, we built this app with the goal of injecting 14 | human-level reasoning into LLMs on this benchmark. Specifically, this app enables 15 | LLMs and humans to collaborate on solving an ARC task, and these collaborations 16 | can then be used for fine-tuning the LLM. 17 | 18 | The Solver itself is a LlamaIndex `Workflow` that relies on successive runs, with 19 | `Context` maintained from one run to the next. Doing so allows for an 20 | effective implementation of the human-in-the-loop pattern. 21 | 22 |

23 | cover 24 |

25 | 26 | ## Running The App 27 | 28 | Before running the streamlit app, we first must download the ARC dataset. The 29 | below command will download the dataset and store it in a directory named `data/`: 30 | 31 | ```sh 32 | wget https://github.com/fchollet/ARC-AGI/archive/refs/heads/master.zip -O ./master.zip 33 | unzip ./master.zip -d ./ 34 | mv ARC-AGI-master/data ./ 35 | rm -rf ARC-AGI-master 36 | rm master.zip 37 | ``` 38 | 39 | Next, we must install the app's dependencies. To do so, we can use `poetry`: 40 | 41 | ```sh 42 | poetry shell 43 | poetry install 44 | ``` 45 | 46 | Finally, to run the streamlit app: 47 | 48 | ```sh 49 | export OPENAI_API_KEY= && streamlit run arc_finetuning_st/streamlit/app.py 50 | ``` 51 | 52 | ## How To Use The App 53 | 54 | In the next two sections, we discuss how to use the app in order to solve a given 55 | ARC task. 56 | 57 |

58 | cover 59 |

60 | 61 | ### Solving an ARC Task 62 | 63 | Each ARC task consists of training examples, each of which consists of input and 64 | output pairs. There exists a common pattern between these input and output pairs, 65 | and the problem is solved by uncovering this pattern, which can be verified by 66 | the included test examples. 67 | 68 | To solve the task, we cycle through the following three steps: 69 | 70 | 1. Prediction (of test output grid) 71 | 2. Evaluation 72 | 3. Critique (human in the loop) 73 | 74 | (Under the hood, a LlamaIndex `Workflow` implements these three `steps`.) 75 | 76 | Step 1 makes use of an LLM to produce the Prediction, whereas Step 2 is 77 | deterministic and is a mere comparison between the ground truth test output and 78 | the Prediction. If the Prediction doesn't match the ground truth grid, then Step 3 79 | is performed. Similar to Step 1, an LLM is prompted to generate a Critique on the 80 | Prediction as to why it may not match the pattern underlying the train input and 81 | output pairs. However, we also allow for a human in the loop to override this 82 | LLM-generated Critique. 83 | 84 | The Critique is carried over from the previous cycle into the next in order to 85 | generate an improved and hopefully correct next Prediction. 86 | 87 | To begin, click the `Start` button found in the top-right corner. If the 88 | prediction is incorrect, you can view the Critique produced by the LLM in the 89 | designated text area. You can choose to use this Critique or supply your own by 90 | overwriting the text and applying the change. Once ready to produce the next 91 | prediction, hit the `Continue` button. 92 | 93 | ### Saving solutions for fine-tuning 94 | 95 | Any collaboration session involving the LLM and human can be saved and used to 96 | finetune an LLM. In this app, we use OpenAI LLMs, and so the finetuning examples 97 | adhere to the [OpenAI fine-tuning API](https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset). 98 | Click the `fine-tuning example` button during a session to see the current 99 | example that can be used for fine-tuning (a sketch of the record format follows 100 | the screenshot below). 101 |

102 | cover 103 |

104 | 105 | ## Fine-tuning (with `arc-finetuning-cli`) 106 | 107 | After you've created your finetuning examples (you'll need at least 10 of them), 108 | you can submit a job to OpenAI to finetune an LLM on them. To do so, we have a 109 | convenient command line tool that is powered by LlamaIndex plugins such as 110 | `llama-index-finetuning`. 111 | 112 | ```sh 113 | arc finetuning cli tool. 114 | 115 | options: 116 | -h, --help show this help message and exit 117 | 118 | commands: 119 | {evaluate,finetune,job-status} 120 | evaluate Evaluation of ARC Task predictions with LLM and ARCTaskSolverWorkflow. 121 | finetune Finetune OpenAI LLM on ARC Task Solver examples. 122 | job-status Check the status of finetuning job. 123 | ``` 124 | 125 | ### Submitting a fine-tuning job 126 | 127 | To submit a fine-tuning job, use any of the following three `finetune` commands: 128 | 129 | ```sh 130 | # submit a new finetune job using the specified llm 131 | arc-finetuning-cli finetune --llm gpt-4o-2024-08-06 132 | 133 | # submit a new finetune job that continues from a previously finetuned model 134 | arc-finetuning-cli finetune --llm gpt-4o-2024-08-06 --start-job-id ftjob-TqJd5Nfe3GIiScyTTJH56l61 135 | 136 | # submit a new finetune job that continues from the most recent finetuned model 137 | arc-finetuning-cli finetune --continue-latest 138 | ``` 139 | 140 | The commands above will take care of compiling all of the single finetuning json 141 | examples (i.e., those stored in `finetuning_examples/`) into a single `jsonl` file that 142 | is then passed to the OpenAI finetuning API. 143 | 144 | ### Checking the status of a fine-tuning job 145 | 146 | After submitting a job, you can check its status using the CLI commands below: 147 | 148 | ```sh 149 | arc-finetuning-cli job-status -j ftjob-WYySY3iGYpfiTbSDeKDZO0YL -m gpt-4o-2024-08-06 150 | 151 | # or check status of the latest job submission 152 | arc-finetuning-cli job-status --latest 153 | ``` 154 | 155 | ## Evaluation 156 | 157 | You can evaluate the `ARCTaskSolverWorkflow` and a specified LLM on the ARC test 158 | dataset. You can even supply a fine-tuned LLM here. 159 | 160 | ```sh 161 | # evaluate ARCTaskSolverWorkflow single attempt with gpt-4o 162 | arc-finetuning-cli evaluate --llm gpt-4o-2024-08-06 163 | 164 | # evaluate ARCTaskSolverWorkflow single attempt with a previously fine-tuned gpt-4o 165 | arc-finetuning-cli evaluate --llm gpt-4o-2024-08-06 --start-job-id ftjob-TqJd5Nfe3GIiScyTTJH56l61 166 | ``` 167 | 168 | You can also specify certain parameters to control the speed of the execution so 169 | as to not run into `RateLimitError`s from OpenAI. 170 | 171 | ```sh 172 | arc-finetuning-cli evaluate --llm gpt-4o-2024-08-06 --batch-size 5 --num-workers 3 --sleep 10 173 | ``` 174 | 175 | In the above command, `batch-size` refers to the number of test cases handled in a 176 | single batch. In total, there are 400 test cases. Moreover, `num-workers` is the 177 | maximum number of async calls allowed to be made to the OpenAI API at any given moment. 178 | Finally, `sleep` is the amount of time in seconds the execution halts before moving 179 | onto the next batch of test cases.
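For intuition, the interaction of these three flags can be sketched as follows. This is a simplified version written for this README; the actual implementation is `batch_runner` in `arc_finetuning_st/cli/evaluation.py` below:

```python
# Simplified sketch of the CLI's batching knobs; not the actual implementation.
import asyncio
from typing import Any, Awaitable, Callable, List


async def run_batched(
    tasks: List[Callable[[asyncio.Semaphore], Awaitable[Any]]],
    batch_size: int = 5,   # --batch-size: test cases gathered per batch
    num_workers: int = 3,  # --num-workers: max concurrent OpenAI calls
    sleep: int = 10,       # --sleep: seconds to pause between batches
) -> List[Any]:
    # Each task coroutine is expected to acquire the semaphore before
    # calling the API, capping concurrency at num_workers.
    sem = asyncio.Semaphore(num_workers)
    results: List[Any] = []
    for i in range(0, len(tasks), batch_size):
        chunk = [make_task(sem) for make_task in tasks[i : i + batch_size]]
        results.extend(await asyncio.gather(*chunk))
        await asyncio.sleep(sleep)  # back off before the next batch
    return results
```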
180 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/__init__.py -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/cli/__init__.py -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/cli/command_line.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import json 4 | from os import listdir 5 | from pathlib import Path 6 | from typing import Any, List, Optional, cast 7 | 8 | from llama_index.llms.openai import OpenAI 9 | 10 | from arc_finetuning_st.cli.evaluation import batch_runner 11 | from arc_finetuning_st.cli.finetune import ( 12 | FINETUNE_JOBS_FILENAME, 13 | check_job_status, 14 | prepare_finetuning_jsonl_file, 15 | submit_finetune_job, 16 | ) 17 | from arc_finetuning_st.workflows.arc_task_solver import ( 18 | ARCTaskSolverWorkflow, 19 | WorkflowOutput, 20 | ) 21 | 22 | SINGLE_EXAMPLE_JSON_PATH = Path( 23 | Path(__file__).parents[2].absolute(), "finetuning_examples" 24 | ) 25 | FINETUNING_ASSETS_PATH = Path( 26 | Path(__file__).parents[2].absolute(), "finetuning_assets" 27 | ) 28 | 29 | 30 | def handle_evaluate( 31 | llm: str, 32 | batch_size: int, 33 | num_workers: int, 34 | verbose: bool, 35 | sleep: int, 36 | **kwargs: Any, 37 | ) -> None: 38 | data_path = Path( 39 | Path(__file__).parents[2].absolute(), "data", "evaluation" 40 | ) 41 | task_paths = [data_path / t for t in listdir(data_path)] 42 | llm = OpenAI(llm) 43 | w = ARCTaskSolverWorkflow(llm=llm, timeout=None) 44 | results = asyncio.run( 45 | batch_runner( 46 | w, 47 | task_paths[:10], 48 | verbose=verbose, 49 | batch_size=batch_size, 50 | num_workers=num_workers, 51 | sleep=sleep, 52 | ) 53 | ) 54 | results = cast(List[WorkflowOutput], results) 55 | num_solved = sum(el.passing for el in results) 56 | print( 57 | f"Solved: {num_solved}\nTotal Tasks:{len(results)}\nAverage Solve Rate: {float(num_solved) / len(results)}" 58 | ) 59 | 60 | 61 | def handle_finetune_job_submit( 62 | llm: str, 63 | start_job_id: Optional[str], 64 | continue_latest: bool = False, 65 | **kwargs: Any, 66 | ) -> None: 67 | prepare_finetuning_jsonl_file( 68 | json_path=SINGLE_EXAMPLE_JSON_PATH, assets_path=FINETUNING_ASSETS_PATH 69 | ) 70 | if continue_latest: 71 | try: 72 | with open(FINETUNING_ASSETS_PATH / FINETUNE_JOBS_FILENAME) as f: 73 | lines = f.read().splitlines() 74 | metadata_str = lines[-1] 75 | metadata = json.loads(metadata_str) 76 | start_job_id = metadata["start_job_id"] 77 | llm = metadata["model"] 78 | except FileNotFoundError: 79 | # no previous finetune model 80 | raise ValueError( 81 | "Missing `finetuning_jobs.jsonl` file. Have you submitted a prior job?" 
82 | ) 83 | 84 | submit_finetune_job( 85 | llm=llm, 86 | start_job_id=start_job_id, 87 | assets_path=FINETUNING_ASSETS_PATH, 88 | ) 89 | 90 | 91 | def handle_check_finetune_job( 92 | start_job_id: Optional[str], 93 | llm: Optional[str], 94 | latest: bool, 95 | **kwargs: Any, 96 | ) -> None: 97 | if latest: 98 | try: 99 | with open(FINETUNING_ASSETS_PATH / FINETUNE_JOBS_FILENAME) as f: 100 | lines = f.read().splitlines() 101 | metadata_str = lines[-1] 102 | metadata = json.loads(metadata_str) 103 | start_job_id = metadata["start_job_id"] 104 | llm = metadata["model"] 105 | except FileNotFoundError: 106 | raise ValueError( 107 | "No finetuning_jobs.json file exists. You likely haven't submitted a job yet." 108 | ) 109 | if not latest and (start_job_id is None or llm is None): 110 | raise ValueError( 111 | "If not `use_latest` then must provide `start_job_id` and `llm`." 112 | ) 113 | 114 | # make type checking happy 115 | if start_job_id and llm: 116 | check_job_status( 117 | start_job_id=start_job_id, 118 | llm=llm, 119 | assets_path=FINETUNING_ASSETS_PATH, 120 | ) 121 | 122 | 123 | def main() -> None: 124 | parser = argparse.ArgumentParser(description="arc-finetuning cli tool.") 125 | 126 | # Subparsers 127 | subparsers = parser.add_subparsers( 128 | title="commands", dest="command", required=True 129 | ) 130 | 131 | # evaluate command 132 | evaluate_parser = subparsers.add_parser( 133 | "evaluate", 134 | help="Evaluation of ARC Task predictions with LLM and ARCTaskSolverWorkflow.", 135 | ) 136 | evaluate_parser.add_argument( 137 | "-m", 138 | "--llm", 139 | type=str, 140 | default="gpt-4o", 141 | help="The OpenAI LLM model to use with the Workflow.", 142 | ) 143 | evaluate_parser.add_argument("-b", "--batch-size", type=int, default=5) 144 | evaluate_parser.add_argument("-w", "--num-workers", type=int, default=3) 145 | evaluate_parser.add_argument( 146 | "-v", "--verbose", action=argparse.BooleanOptionalAction 147 | ) 148 | evaluate_parser.add_argument("-s", "--sleep", type=int, default=10) 149 | evaluate_parser.set_defaults( 150 | func=lambda args: handle_evaluate(**vars(args)) 151 | ) 152 | 153 | # finetune command 154 | finetune_parser = subparsers.add_parser( 155 | "finetune", help="Finetune OpenAI LLM on ARC Task Solver examples." 156 | ) 157 | finetune_parser.add_argument( 158 | "-m", 159 | "--llm", 160 | type=str, 161 | default="gpt-4o-2024-08-06", 162 | help="The OpenAI LLM model to finetune.", 163 | ) 164 | finetune_parser.add_argument( 165 | "-j", 166 | "--start-job-id", 167 | type=str, 168 | default=None, 169 | help="Previously started job id, to continue finetuning.", 170 | ) 171 | finetune_parser.add_argument( 172 | "--continue-latest", action=argparse.BooleanOptionalAction 173 | ) 174 | finetune_parser.set_defaults( 175 | func=lambda args: handle_finetune_job_submit(**vars(args)) 176 | ) 177 | 178 | # job status command 179 | job_status_parser = subparsers.add_parser( 180 | "job-status", help="Check the status of finetuning job." 
181 | ) 182 | job_status_parser.add_argument( 183 | "-j", 184 | "--start-job-id", 185 | type=str, 186 | default=None, 187 | help="Previously started job id, to continue finetuning.", 188 | ) 189 | job_status_parser.add_argument( 190 | "-m", 191 | "--llm", 192 | type=str, 193 | default="gpt-4o-2024-08-06", 194 | help="The OpenAI LLM model to finetune.", 195 | ) 196 | job_status_parser.add_argument( 197 | "--latest", 198 | action=argparse.BooleanOptionalAction, 199 | help="If set, checks the status of the last submitted job.", 200 | ) 201 | job_status_parser.set_defaults( 202 | func=lambda args: handle_check_finetune_job(**vars(args)) 203 | ) 204 | 205 | # Parse the command-line arguments 206 | args = parser.parse_args() 207 | 208 | # Call the appropriate function based on the command 209 | args.func(args) 210 | 211 | 212 | if __name__ == "__main__": 213 | main() 214 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/cli/evaluation.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from os import listdir 3 | from pathlib import Path 4 | from typing import Any, List, cast 5 | 6 | from llama_index.core.async_utils import chunks 7 | from llama_index.llms.openai import OpenAI 8 | 9 | from arc_finetuning_st.workflows.arc_task_solver import ( 10 | ARCTaskSolverWorkflow, 11 | WorkflowOutput, 12 | ) 13 | 14 | DATA_PATH = Path(Path(__file__).parents[1].absolute(), "data", "evaluation") 15 | 16 | 17 | async def batch_runner( 18 | workflow: ARCTaskSolverWorkflow, 19 | task_paths: List[Path], 20 | batch_size: int = 5, 21 | verbose: bool = False, 22 | sleep: int = 10, 23 | num_workers: int = 3, 24 | ) -> List[Any]: 25 | output: List[Any] = [] 26 | sem = asyncio.Semaphore(num_workers) 27 | for task_chunk in chunks(task_paths, batch_size): 28 | task_chunk = ( 29 | workflow.load_and_run_task(task_path=task_path, sem=sem) 30 | for task_path in task_chunk 31 | if task_path is not None 32 | ) 33 | output_chunk = await asyncio.gather(*task_chunk) 34 | output.extend(output_chunk) 35 | if verbose: 36 | print( 37 | f"Completed {len(output)} out of {len(task_paths)} tasks", 38 | flush=True, 39 | ) 40 | await asyncio.sleep(sleep) 41 | return output 42 | 43 | 44 | async def main() -> None: 45 | task_paths = [DATA_PATH / t for t in listdir(DATA_PATH)] 46 | w = ARCTaskSolverWorkflow( 47 | timeout=None, verbose=False, llm=OpenAI("gpt-4o") 48 | ) 49 | results = await batch_runner(w, task_paths[:10], verbose=True) 50 | results = cast(List[WorkflowOutput], results) 51 | num_solved = sum(el.passing for el in results) 52 | print( 53 | f"Solved: {num_solved}\nTotal Tasks:{len(results)}\nAverage Solve Rate: {float(num_solved) / len(results)}" 54 | ) 55 | 56 | 57 | if __name__ == "__main__": 58 | asyncio.run(main()) 59 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/cli/finetune.py: -------------------------------------------------------------------------------- 1 | import json 2 | from os import listdir 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | from llama_index.finetuning import OpenAIFinetuneEngine 7 | 8 | SINGLE_EXAMPLE_JSON_PATH = Path( 9 | Path(__file__).parents[1].absolute(), "finetuning_examples" 10 | ) 11 | 12 | FINETUNING_ASSETS_PATH = Path( 13 | Path(__file__).parents[1].absolute(), "finetuning_assets" 14 | ) 15 | 16 | FINETUNE_JSONL_FILENAME = "finetuning.jsonl" 17 | FINETUNE_JOBS_FILENAME = 
"finetuning_jobs.jsonl" 18 | 19 | 20 | def prepare_finetuning_jsonl_file( 21 | json_path: Path = SINGLE_EXAMPLE_JSON_PATH, 22 | assets_path: Path = FINETUNING_ASSETS_PATH, 23 | ) -> None: 24 | """Read all json files from data path and write a jsonl file.""" 25 | with open(assets_path / FINETUNE_JSONL_FILENAME, "w") as jsonl_out: 26 | for json_name in listdir(json_path): 27 | with open(json_path / json_name) as f: 28 | for line in f: 29 | jsonl_out.write(line) 30 | jsonl_out.write("\n") 31 | 32 | 33 | def submit_finetune_job( 34 | llm: str = "gpt-4o-2024-08-06", 35 | start_job_id: Optional[str] = None, 36 | assets_path: Path = FINETUNING_ASSETS_PATH, 37 | ) -> None: 38 | """Submit finetuning job.""" 39 | finetune_engine = OpenAIFinetuneEngine( 40 | llm, 41 | (assets_path / FINETUNE_JSONL_FILENAME).as_posix(), 42 | start_job_id=start_job_id, 43 | validate_json=False, 44 | ) 45 | finetune_engine.finetune() 46 | 47 | with open(assets_path / FINETUNE_JOBS_FILENAME, "a+") as f: 48 | metadata = { 49 | "model": llm, 50 | "start_job_id": finetune_engine._start_job.id, 51 | } 52 | json.dump(metadata, f) 53 | f.write("\n") 54 | 55 | print(finetune_engine.get_current_job()) 56 | 57 | 58 | def check_job_status( 59 | start_job_id: str, 60 | llm: str = "gpt-4o-2024-08-06", 61 | assets_path: Path = FINETUNING_ASSETS_PATH, 62 | ) -> None: 63 | """Check on status of most recent submitted finetuning job.""" 64 | 65 | finetune_engine = OpenAIFinetuneEngine( 66 | llm, 67 | (assets_path / FINETUNE_JSONL_FILENAME).as_posix(), 68 | start_job_id=start_job_id, 69 | validate_json=False, 70 | ) 71 | 72 | print(finetune_engine.get_current_job()) 73 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/finetuning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/finetuning/__init__.py -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/finetuning/finetuning_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Annotated, Any, Callable, List 4 | 5 | from llama_index.core.base.llms.types import ChatMessage, MessageRole 6 | from llama_index.core.bridge.pydantic import BaseModel, Field, WrapSerializer 7 | 8 | from arc_finetuning_st.finetuning.templates import ( 9 | ASSISTANT_TEMPLATE, 10 | SYSTEM_MESSAGE, 11 | USER_CRITIQUE_TEMPLATE, 12 | USER_TASK_TEMPLATE, 13 | ) 14 | from arc_finetuning_st.workflows.models import Attempt 15 | 16 | 17 | def remove_additional_kwargs(value: Any, handler: Callable, info: Any) -> Any: 18 | partial_result = handler(value, info) 19 | del partial_result["additional_kwargs"] 20 | return partial_result 21 | 22 | 23 | class FineTuningExample(BaseModel): 24 | messages: List[ 25 | Annotated[ChatMessage, WrapSerializer(remove_additional_kwargs)] 26 | ] 27 | task_name: str = Field(exclude=True) 28 | 29 | @classmethod 30 | def from_attempts( 31 | cls, 32 | task_name: str, 33 | examples: str, 34 | test_input: str, 35 | attempts: List[Attempt], 36 | system_message: str = SYSTEM_MESSAGE, 37 | user_task_template: str = USER_TASK_TEMPLATE, 38 | user_critique_template: str = USER_CRITIQUE_TEMPLATE, 39 | assistant_template: str = ASSISTANT_TEMPLATE, 40 | ) -> "FineTuningExample": 41 | messages = [ 42 
| ChatMessage(role=MessageRole.SYSTEM, content=system_message), 43 | ChatMessage( 44 | role=MessageRole.USER, 45 | content=user_task_template.format( 46 | examples=examples, test_input=test_input 47 | ), 48 | ), 49 | ] 50 | for a in attempts: 51 | messages.extend( 52 | [ 53 | ChatMessage( 54 | role=MessageRole.ASSISTANT, 55 | content=assistant_template.format( 56 | predicted_output=str(a.prediction), 57 | rationale=a.prediction.rationale, 58 | ), 59 | ), 60 | ChatMessage( 61 | role=MessageRole.USER, 62 | content=user_critique_template.format( 63 | critique=str(a.critique) 64 | ), 65 | ), 66 | ] 67 | ) 68 | 69 | # always end with an assistant message, or else the OpenAI finetuning job will fail 70 | if a.critique == "This predicted output is correct.": 71 | final_asst_message = ChatMessage( 72 | role=MessageRole.ASSISTANT, 73 | content="Glad we were able to solve the puzzle!", 74 | ) 75 | else: 76 | final_asst_message = ChatMessage( 77 | role=MessageRole.ASSISTANT, 78 | content="Thanks for the feedback. I'll incorporate this into my next prediction.", 79 | ) 80 | 81 | messages.append(final_asst_message) 82 | return cls(messages=messages, task_name=task_name) 83 | 84 | def to_json(self) -> str: 85 | data = self.model_dump() 86 | return json.dumps(data, indent=4) 87 | 88 | def write_json(self, dirpath: Path) -> None: 89 | data = self.model_dump() 90 | with open(Path(dirpath, self.task_name), "w") as f: 91 | json.dump(data, f) 92 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/finetuning/templates.py: -------------------------------------------------------------------------------- 1 | SYSTEM_MESSAGE = """You are a bot that is very good at solving puzzles. You will work with the user, who will present you a new puzzle to solve. 2 | The puzzle consists of a list of EXAMPLES, each containing an INPUT/OUTPUT pair describing a pattern which is shared amongst all examples. 3 | The user will also provide a TEST INPUT for which you are to produce a predicted OUTPUT that follows the common pattern of all the examples. 4 | 5 | Your task is to collaborate with the user in order to solve the problem.
6 | """ 7 | 8 | USER_TASK_TEMPLATE = """Here is a new task to solve: 9 | EXAMPLES: 10 | {examples} 11 | 12 | TEST INPUT: 13 | {test_input} 14 | """ 15 | 16 | # attempt.prediction 17 | ASSISTANT_TEMPLATE = """ 18 | PREDICTED OUTPUT: 19 | {predicted_output} 20 | 21 | RATIONALE: 22 | {rationale} 23 | """ 24 | 25 | # attempt.critique 26 | USER_CRITIQUE_TEMPLATE = """ 27 | {critique} 28 | """ 29 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/streamlit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/streamlit/__init__.py -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/streamlit/app.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import streamlit as st 4 | from llama_index.core.tools.function_tool import async_to_sync 5 | 6 | from arc_finetuning_st.streamlit.controller import Controller 7 | 8 | # startup 9 | st.set_page_config(layout="wide") 10 | 11 | 12 | @st.cache_resource 13 | def startup() -> Tuple[Controller,]: 14 | controller = Controller() 15 | return (controller,) 16 | 17 | 18 | (controller,) = startup() 19 | 20 | # states 21 | if "show_finetuning_preview_dialog" not in st.session_state: 22 | st.session_state["show_finetuning_preview_dialog"] = True 23 | if "disable_continue_button" not in st.session_state: 24 | st.session_state["disable_continue_button"] = True 25 | if "disable_start_button" not in st.session_state: 26 | st.session_state["disable_start_button"] = False 27 | if "disable_abort_button" not in st.session_state: 28 | st.session_state["disable_abort_button"] = True 29 | if "disable_preview_button" not in st.session_state: 30 | st.session_state["disable_preview_button"] = True 31 | if "metric_value" not in st.session_state: 32 | st.session_state["metric_value"] = "N/A" 33 | 34 | logo = '[](https://github.com/run-llama/llama-agents "Check out the llama-agents Github repo!")' 35 | st.title("ARC Task Solver with Human Input") 36 | st.markdown( 37 | f"{logo}   _Powered with LlamaIndex Worfklows_", 38 | unsafe_allow_html=True, 39 | ) 40 | 41 | # sidebar 42 | with st.sidebar: 43 | task_selection = st.radio( 44 | label="Tasks", 45 | options=controller.task_file_names, 46 | index=0, 47 | on_change=controller.selectbox_selection_change_handler, 48 | key="selected_task", 49 | format_func=controller.radio_format_task_name, 50 | ) 51 | 52 | train_col, test_col = st.columns( 53 | [1, 1], vertical_alignment="top", gap="medium" 54 | ) 55 | 56 | with train_col: 57 | st.subheader("Train Examples") 58 | with st.container(): 59 | selected_task = st.session_state.selected_task 60 | if selected_task: 61 | task = controller.load_task(selected_task) 62 | num_examples = len(task["train"]) 63 | tabs = st.tabs( 64 | [f"Example {ix}" for ix in range(1, num_examples + 1)] 65 | ) 66 | for ix, tab in enumerate(tabs): 67 | with tab: 68 | left, right = st.columns( 69 | [1, 1], vertical_alignment="top", gap="medium" 70 | ) 71 | with left: 72 | ex = task["train"][ix] 73 | grid = ex["input"] 74 | fig = Controller.plot_grid(grid, kind="input") 75 | st.plotly_chart(fig, use_container_width=True) 76 | 77 | with right: 78 | ex = task["train"][ix] 79 | grid = ex["output"] 80 | fig = Controller.plot_grid(grid, kind="output") 81 | 
st.plotly_chart(fig, use_container_width=True) 82 | 83 | 84 | with test_col: 85 | header_col, start_col, preview_col = st.columns( 86 | [4, 1, 2], vertical_alignment="bottom", gap="small" 87 | ) 88 | with header_col: 89 | st.subheader("Test") 90 | with start_col: 91 | st.button( 92 | "start", 93 | on_click=async_to_sync(controller.handle_prediction_click), 94 | use_container_width=True, 95 | type="primary", 96 | disabled=st.session_state.get("disable_start_button"), 97 | ) 98 | with preview_col: 99 | st.button( 100 | "fine-tuning example", 101 | on_click=async_to_sync(controller.handle_finetuning_preview_click), 102 | use_container_width=True, 103 | disabled=st.session_state.get("disable_preview_button"), 104 | key="preview_button", 105 | ) 106 | 107 | with st.container(): 108 | selected_task = st.session_state.selected_task 109 | if selected_task: 110 | task = controller.load_task(selected_task) 111 | num_cases = len(task["test"]) 112 | tabs = st.tabs( 113 | [f"Test Case {ix}" for ix in range(1, num_cases + 1)] 114 | ) 115 | for ix, tab in enumerate(tabs): 116 | with tab: 117 | left, right = st.columns( 118 | [1, 1], vertical_alignment="top", gap="medium" 119 | ) 120 | with left: 121 | ex = task["test"][ix] 122 | grid = ex["input"] 123 | fig = Controller.plot_grid(grid, kind="input") 124 | st.plotly_chart(fig, use_container_width=True) 125 | 126 | with right: 127 | prediction_fig = st.session_state.get( 128 | "prediction", None 129 | ) 130 | if prediction_fig: 131 | st.plotly_chart( 132 | prediction_fig, 133 | use_container_width=True, 134 | key="prediction", 135 | ) 136 | 137 | # metrics and past attempts 138 | with st.container(): 139 | metric_col, critique_col = st.columns( 140 | [1, 7], vertical_alignment="top" 141 | ) 142 | with metric_col: 143 | metric_value = st.session_state.get("metric_value") 144 | st.markdown(body="Passing") 145 | st.markdown(body=f"# {metric_value}") 146 | 147 | with critique_col: 148 | st.markdown(body="Critique of Attempt") 149 | st.text_area( 150 | label="This critique is passed to the LLM to generate a new prediction.", 151 | key="critique", 152 | help=( 153 | "An LLM was prompted to critique the prediction on why it might not fit the pattern. " 154 | "This critique is passed in the PROMPT in the next prediction attempt. " 155 | "Feel free to make edits to the critique or use your own." 156 | ), 157 | ) 158 | 159 | with st.expander("Past Attempts"): 160 | st.dataframe( 161 | controller.attempts_history_df, 162 | hide_index=True, 163 | selection_mode="single-row", 164 | on_select=controller.handle_workflow_run_selection, 165 | column_order=( 166 | "attempt #", 167 | "passing", 168 | "critique", 169 | "rationale", 170 | ), 171 | key="attempts_history_df", 172 | use_container_width=True, 173 | ) 174 | 175 | with st.container(): 176 | continue_col, abort_col = st.columns([3, 1]) 177 | with continue_col: 178 | st.button( 179 | "continue", 180 | on_click=async_to_sync(controller.handle_prediction_click), 181 | use_container_width=True, 182 | disabled=st.session_state.get("disable_continue_button"), 183 | key="continue_button", 184 | type="primary", 185 | ) 186 | 187 | with abort_col: 188 | 189 | @st.dialog("Are you sure you want to abort the session?") 190 | def abort_solving() -> None: 191 | st.write( 192 | "Confirm that you want to abort the session by clicking 'confirm' button below." 
193 | ) 194 | if st.button("Confirm"): 195 | controller.reset() 196 | st.rerun() 197 | 198 | st.button( 199 | "abort", 200 | on_click=abort_solving, 201 | use_container_width=True, 202 | disabled=st.session_state.get("disable_abort_button"), 203 | ) 204 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/streamlit/controller.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | from os import listdir 5 | from pathlib import Path 6 | from typing import Any, List, Literal, Optional, cast 7 | 8 | import pandas as pd 9 | import plotly.express as px 10 | import streamlit as st 11 | from llama_index.core.workflow.handler import WorkflowHandler 12 | from llama_index.llms.openai import OpenAI 13 | 14 | from arc_finetuning_st.finetuning.finetuning_example import FineTuningExample 15 | from arc_finetuning_st.workflows.arc_task_solver import ( 16 | ARCTaskSolverWorkflow, 17 | WorkflowOutput, 18 | ) 19 | from arc_finetuning_st.workflows.models import Attempt, Prediction 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class Controller: 25 | def __init__(self) -> None: 26 | self._handler: Optional[WorkflowHandler] = None 27 | self._attempts: List[Attempt] = [] 28 | self._passing_results: List[bool] = [] 29 | self._data_path = Path( 30 | Path(__file__).parents[2].absolute(), "data", "training" 31 | ) 32 | self._finetuning_examples_path = Path( 33 | Path(__file__).parents[2].absolute(), "finetuning_examples" 34 | ) 35 | self._finetuning_examples_path.mkdir(exist_ok=True, parents=True) 36 | 37 | def reset(self) -> None: 38 | # clear prediction 39 | st.session_state.prediction = None 40 | st.session_state.disable_continue_button = True 41 | st.session_state.disable_abort_button = True 42 | st.session_state.disable_preview_button = True 43 | st.session_state.disable_start_button = False 44 | st.session_state.critique = None 45 | st.session_state.metric_value = "N/A" 46 | 47 | self._handler = None 48 | self._attempts = [] 49 | self._passing_results = [] 50 | 51 | def selectbox_selection_change_handler(self) -> None: 52 | # only reset states 53 | # loading of task is delegated to relevant calls made with each 54 | # streamlit element 55 | self.reset() 56 | 57 | @staticmethod 58 | def plot_grid( 59 | grid: List[List[int]], 60 | kind: Literal["input", "output", "prediction", "latest prediction"], 61 | ) -> Any: 62 | m = len(grid) 63 | n = len(grid[0]) 64 | fig = px.imshow( 65 | grid, 66 | text_auto=True, 67 | labels={"x": f"{kind.title()}
{m}x{n}"}, 68 | ) 69 | fig.update_coloraxes(showscale=False) 70 | fig.update_layout( 71 | yaxis={"visible": False}, 72 | xaxis={"visible": True, "showticklabels": False}, 73 | margin=dict( 74 | l=20, 75 | r=20, 76 | b=20, 77 | t=20, 78 | ), 79 | ) 80 | return fig 81 | 82 | async def show_progress_bar(self, handler: WorkflowHandler) -> None: 83 | progress_text_template = "{event} completed. Next step in progress." 84 | my_bar = st.progress(0, text="Workflow run in progress. Please wait.") 85 | num_steps = 5.0 86 | current_step = 1 87 | async for ev in handler.stream_events(): 88 | my_bar.progress( 89 | current_step / num_steps, 90 | text=progress_text_template.format(event=type(ev).__name__), 91 | ) 92 | current_step += 1 93 | my_bar.empty() 94 | 95 | def handle_abort_click(self) -> None: 96 | self.reset() 97 | 98 | async def handle_prediction_click(self) -> None: 99 | """Run workflow to generate prediction.""" 100 | selected_task = st.session_state.selected_task 101 | if selected_task: 102 | task = self.load_task(selected_task) 103 | w = ARCTaskSolverWorkflow( 104 | timeout=None, verbose=False, llm=OpenAI("gpt-4o") 105 | ) 106 | 107 | if not self._handler: # start a new solver 108 | handler = w.run(task=task) 109 | 110 | else: # continuing from past Workflow execution 111 | # need to reset this queue otherwise will use nested event loops 112 | self._handler.ctx._streaming_queue = asyncio.Queue() 113 | 114 | # use the critique and prediction str from streamlit 115 | critique = st.session_state.get("critique") 116 | self._attempts[-1].critique = critique 117 | await self._handler.ctx.set("attempts", self._attempts) 118 | 119 | # run Workflow 120 | handler = w.run(ctx=self._handler.ctx, task=task) 121 | 122 | # progress bar 123 | _ = asyncio.create_task(self.show_progress_bar(handler)) 124 | 125 | res: WorkflowOutput = await handler 126 | 127 | handler = cast(WorkflowHandler, handler) 128 | self._handler = handler 129 | self._passing_results.append(res.passing) 130 | self._attempts = res.attempts 131 | 132 | # update streamlit states 133 | grid = Prediction.prediction_str_to_int_array( 134 | prediction=str(res.attempts[-1].prediction) 135 | ) 136 | prediction_fig = Controller.plot_grid( 137 | grid, kind="latest prediction" 138 | ) 139 | st.session_state.prediction = prediction_fig 140 | st.session_state.critique = str(res.attempts[-1].critique) 141 | st.session_state.disable_continue_button = False 142 | st.session_state.disable_abort_button = False 143 | st.session_state.disable_preview_button = False 144 | st.session_state.disable_start_button = True 145 | metric_value = "✅" if res.passing else "❌" 146 | st.session_state.metric_value = metric_value 147 | 148 | @property 149 | def saved_finetuning_examples(self) -> List[str]: 150 | return listdir(self._finetuning_examples_path) 151 | 152 | @property 153 | def task_file_names(self) -> List[str]: 154 | return listdir(self._data_path) 155 | 156 | def radio_format_task_name(self, selected_task: str) -> str: 157 | if selected_task in self.saved_finetuning_examples: 158 | return f"{selected_task} ✅" 159 | return selected_task 160 | 161 | def load_task(self, selected_task: str) -> Any: 162 | task_path = Path(self._data_path, selected_task) 163 | 164 | with open(task_path) as f: 165 | task = json.load(f) 166 | return task 167 | 168 | @property 169 | def passing(self) -> Optional[bool]: 170 | if self._passing_results: 171 | return self._passing_results[-1] 172 | return None 173 | 174 | @property 175 | def attempts_history_df( 176 | self, 177 | ) -> 
pd.DataFrame: 178 | if self._attempts: 179 | attempt_number_list: List[int] = [] 180 | passings: List[str] = [] 181 | rationales: List[str] = [] 182 | critiques: List[str] = [] 183 | predictions: List[str] = [] 184 | for ix, a in enumerate(self._attempts): 185 | passings = ["✅" if a.passing else "❌"] + passings 186 | rationales = [a.prediction.rationale] + rationales 187 | predictions = [str(a.prediction)] + predictions 188 | critiques = [str(a.critique)] + critiques 189 | attempt_number_list = [ix + 1] + attempt_number_list 190 | return pd.DataFrame( 191 | { 192 | "attempt #": attempt_number_list, 193 | "passing": passings, 194 | "rationale": rationales, 195 | "critique": critiques, 196 | # hidden from UI 197 | "prediction": predictions, 198 | } 199 | ) 200 | return pd.DataFrame( 201 | { 202 | "attempt #": [], 203 | "passing": [], 204 | "rationale": [], 205 | "critique": [], 206 | # hidden from UI 207 | "prediction": [], 208 | } 209 | ) 210 | 211 | def handle_workflow_run_selection(self) -> None: 212 | @st.dialog("Past Attempt") 213 | def _display_attempt( 214 | fig: Any, rationale: str, critique: str, passing: bool 215 | ) -> None: 216 | st.plotly_chart( 217 | fig, 218 | use_container_width=True, 219 | key="prediction", 220 | ) 221 | st.markdown(body=f"### Passing\n{passing}") 222 | st.markdown(body=f"### Rationale\n{rationale}") 223 | st.markdown(body=f"### Critique\n{critique}") 224 | 225 | selected_rows = ( 226 | st.session_state.get("attempts_history_df") 227 | .get("selection") 228 | .get("rows") 229 | ) 230 | 231 | if selected_rows: 232 | row_ix = selected_rows[0] 233 | df_row = self.attempts_history_df.iloc[row_ix] 234 | 235 | grid = Prediction.prediction_str_to_int_array( 236 | prediction=df_row["prediction"] 237 | ) 238 | prediction_fig = Controller.plot_grid(grid, kind="prediction") 239 | 240 | _display_attempt( 241 | fig=prediction_fig, 242 | rationale=df_row["rationale"], 243 | critique=df_row["critique"], 244 | passing=df_row["passing"], 245 | ) 246 | 247 | async def handle_finetuning_preview_click(self) -> None: 248 | if self._handler: 249 | st.session_state.show_finetuning_preview_dialog = True 250 | prompt_vars = await self._handler.ctx.get("prompt_vars") 251 | 252 | @st.dialog("Finetuning Example", width="large") 253 | def _display_finetuning_example() -> None: 254 | nonlocal prompt_vars 255 | 256 | finetuning_example = FineTuningExample.from_attempts( 257 | task_name=st.session_state.selected_task, 258 | attempts=self._attempts, 259 | examples=prompt_vars["examples"], 260 | test_input=prompt_vars["test_input"], 261 | ) 262 | 263 | with st.container(height=500, border=False): 264 | save_col, close_col = st.columns([1, 1]) 265 | with save_col: 266 | if st.button("Save", use_container_width=True): 267 | finetuning_example.write_json( 268 | self._finetuning_examples_path, 269 | ) 270 | st.session_state.show_finetuning_preview_dialog = ( 271 | False 272 | ) 273 | st.rerun() 274 | with close_col: 275 | if st.button("Close", use_container_width=True): 276 | st.session_state.show_finetuning_preview_dialog = ( 277 | False 278 | ) 279 | st.rerun() 280 | 281 | st.code( 282 | finetuning_example.to_json(), 283 | language="json", 284 | wrap_lines=True, 285 | ) 286 | 287 | if st.session_state.show_finetuning_preview_dialog: 288 | _display_finetuning_example() 289 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/workflows/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/workflows/__init__.py -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/workflows/arc_task_solver.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | from pathlib import Path 4 | from typing import Any, Dict, List, Optional, cast 5 | 6 | from llama_index.core.bridge.pydantic import BaseModel 7 | from llama_index.core.llms import LLM 8 | from llama_index.core.workflow import ( 9 | Context, 10 | StartEvent, 11 | StopEvent, 12 | Workflow, 13 | step, 14 | ) 15 | 16 | from arc_finetuning_st.workflows.events import ( 17 | EvaluationEvent, 18 | FormatTaskEvent, 19 | PredictionEvent, 20 | ) 21 | from arc_finetuning_st.workflows.models import Attempt, Critique, Prediction 22 | from arc_finetuning_st.workflows.prompts import ( 23 | CORRECTION_PROMPT_TEMPLATE, 24 | PREDICTION_PROMPT_TEMPLATE, 25 | REFLECTION_PROMPT_TEMPLATE, 26 | ) 27 | 28 | EXAMPLE_TEMPLATE = """=== 29 | EXAMPLE 30 | 31 | INPUT: 32 | {input} 33 | 34 | OUTPUT: 35 | {output} 36 | """ 37 | 38 | PAST_ATTEMPT_TEMPLATE = """◦◦◦ 39 | PAST ATTEMPT {past_attempt_number} 40 | 41 | PREDICTED_OUTPUT: 42 | {past_predicted_output} 43 | 44 | CRITIQUE: 45 | {past_critique} 46 | """ 47 | 48 | 49 | class WorkflowOutput(BaseModel): 50 | passing: bool 51 | attempts: List[Attempt] 52 | 53 | 54 | class ARCTaskSolverWorkflow(Workflow): 55 | def __init__(self, llm: LLM, max_attempts: int = 3, **kwargs: Any) -> None: 56 | super().__init__(**kwargs) 57 | self.llm = llm 58 | self._max_attempts = max_attempts 59 | 60 | def _format_past_attempt(self, attempt: Attempt, attempt_num: int) -> str: 61 | return PAST_ATTEMPT_TEMPLATE.format( 62 | past_attempt_number=attempt_num, 63 | past_predicted_output=str(attempt.prediction), 64 | past_critique=str(attempt.critique) if attempt.critique else "", 65 | ) 66 | 67 | @step 68 | async def format_task( 69 | self, ctx: Context, ev: StartEvent 70 | ) -> FormatTaskEvent: 71 | ctx.write_event_to_stream(ev) 72 | 73 | def _format_row(row: List[int]) -> str: 74 | return ",".join(str(el) for el in row) 75 | 76 | def pretty_print_grid(grid: List[List[int]]) -> str: 77 | formatted_rows = [_format_row(row) for row in grid] 78 | return "\n".join(formatted_rows) 79 | 80 | def format_train_example(train_pair: Dict) -> str: 81 | return EXAMPLE_TEMPLATE.format( 82 | input=pretty_print_grid(train_pair["input"]), 83 | output=pretty_print_grid(train_pair["output"]), 84 | ) 85 | 86 | task = ev.get("task", {}) 87 | await ctx.set("task", task) 88 | 89 | # prepare prompt_vars 90 | attempts = await ctx.get("attempts", []) 91 | if attempts: 92 | # update past predictions 93 | prompt_vars = await ctx.get("prompt_vars") 94 | formatted_past_attempts = [ 95 | self._format_past_attempt(a, ix + 1) 96 | for ix, a in enumerate(attempts) 97 | ] 98 | prompt_vars.update( 99 | past_attempts="\n".join(formatted_past_attempts) 100 | ) 101 | else: 102 | examples = [format_train_example(t) for t in task["train"]] 103 | prompt_vars = { 104 | "test_input": pretty_print_grid(task["test"][0]["input"]), 105 | "examples": "\n".join(examples), 106 | } 107 | await ctx.set("prompt_vars", prompt_vars) 108 | 109 | return FormatTaskEvent() 110 | 111 | @step 112 | async def prediction( 113 | self, ctx: Context, ev: FormatTaskEvent 114 | ) -> PredictionEvent | StopEvent: 115 | ctx.write_event_to_stream(ev) 
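        # NOTE: "attempts" is carried in the workflow Context across runs; an empty
        # list means this is a fresh solve, while a non-empty list means we are
        # resuming for a human-in-the-loop correction round.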
116 | attempts = await ctx.get("attempts", []) 117 | attempts = cast(List[Attempt], attempts) 118 | prompt_vars = await ctx.get("prompt_vars") 119 | 120 | if attempts: 121 | # generating a correction from the last Workflow run 122 | correction: Prediction = await self.llm.astructured_predict( 123 | Prediction, CORRECTION_PROMPT_TEMPLATE, **prompt_vars 124 | ) 125 | attempts.append(Attempt(prediction=correction)) 126 | else: 127 | # making a first prediction, since there are no previous Workflow runs 128 | pred: Prediction = await self.llm.astructured_predict( 129 | Prediction, PREDICTION_PROMPT_TEMPLATE, **prompt_vars 130 | ) 131 | attempts = [Attempt(prediction=pred)] 132 | 133 | await ctx.set("attempts", attempts) 134 | return PredictionEvent() 135 | 136 | @step 137 | async def evaluation( 138 | self, ctx: Context, ev: PredictionEvent 139 | ) -> EvaluationEvent: 140 | ctx.write_event_to_stream(ev) 141 | task = await ctx.get("task") 142 | attempts: List[Attempt] = await ctx.get("attempts") 143 | latest_prediction = attempts[-1].prediction 144 | latest_prediction_as_array = Prediction.prediction_str_to_int_array( 145 | str(latest_prediction) 146 | ) 147 | ground_truth = task["test"][0]["output"] 148 | 149 | return EvaluationEvent( 150 | passing=(latest_prediction_as_array == ground_truth) 151 | ) 152 | 153 | @step 154 | async def reflection(self, ctx: Context, ev: EvaluationEvent) -> StopEvent: 155 | ctx.write_event_to_stream(ev) 156 | attempts = await ctx.get("attempts") 157 | attempts = cast(List[Attempt], attempts) 158 | latest_attempt = attempts[-1] 159 | 160 | # check if passing 161 | if not ev.passing: 162 | prompt_vars = await ctx.get("prompt_vars") 163 | formatted_past_attempts = [ 164 | self._format_past_attempt(a, ix + 1) 165 | for ix, a in enumerate(attempts) 166 | ] 167 | prompt_vars.update( 168 | past_attempts="\n".join(formatted_past_attempts) 169 | ) 170 | 171 | # generate critique 172 | critique: Critique = await self.llm.astructured_predict( 173 | Critique, REFLECTION_PROMPT_TEMPLATE, **prompt_vars 174 | ) 175 | 176 | # update states 177 | latest_attempt.critique = critique 178 | else: 179 | latest_attempt.critique = "This predicted output is correct."
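        # Record the evaluation outcome on the attempt and persist the updated
        # attempts list back to the Context so a follow-up run can resume from it.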
180 | 181 | latest_attempt.passing = ev.passing 182 | attempts[-1] = latest_attempt 183 | await ctx.set("attempts", attempts) 184 | 185 | result = WorkflowOutput(passing=ev.passing, attempts=attempts) 186 | return StopEvent(result=result) 187 | 188 | async def load_and_run_task( 189 | self, 190 | task_path: Path, 191 | ctx: Optional[Context] = None, 192 | sem: Optional[asyncio.Semaphore] = None, 193 | ) -> Any: 194 | """Convenience function for loading a task json and running it.""" 195 | with open(task_path) as f: 196 | task = json.load(f) 197 | 198 | async def _run_workflow() -> Any: 199 | return await self.run(ctx=ctx, task=task) 200 | 201 | if sem: # in case running in batch with other workflow runs 202 | await sem.acquire() 203 | try: 204 | res = await _run_workflow() 205 | finally: 206 | sem.release() 207 | else: 208 | res = await _run_workflow() 209 | 210 | return res 211 | 212 | 213 | async def _test_workflow() -> None: 214 | import json 215 | from pathlib import Path 216 | 217 | from llama_index.llms.openai import OpenAI 218 | 219 | task_path = Path( 220 | Path(__file__).parents[2].absolute(), "data/training/0a938d79.json" 221 | ) 222 | with open(task_path) as f: 223 | task = json.load(f) 224 | 225 | w = ARCTaskSolverWorkflow( 226 | timeout=None, verbose=False, llm=OpenAI("gpt-4o") 227 | ) 228 | attempts = await w.run(task=task) 229 | 230 | print(attempts) 231 | 232 | 233 | if __name__ == "__main__": 234 | import asyncio 235 | 236 | asyncio.run(_test_workflow()) 237 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/workflows/events.py: -------------------------------------------------------------------------------- 1 | from llama_index.core.workflow import Event 2 | 3 | 4 | class FormatTaskEvent(Event): 5 | ... 6 | 7 | 8 | class PredictionEvent(Event): 9 | ... 10 | 11 | 12 | class EvaluationEvent(Event): 13 | passing: bool 14 | 15 | 16 | class CorrectionEvent(Event): 17 | ... 18 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/workflows/models.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from typing import List, Optional 3 | 4 | from llama_index.core.bridge.pydantic import BaseModel, Field 5 | 6 | 7 | class Prediction(BaseModel): 8 | """Prediction data class for LLM structured predict.""" 9 | 10 | rationale: str = Field( 11 | description="Brief description of pattern and why prediction was made. Limit to 250 words." 12 | ) 13 | prediction: str = Field( 14 | description="Predicted grid as a single string. e.g. '0,0,1\n1,1,1\n0,0,0'" 15 | ) 16 | 17 | def __str__(self) -> str: 18 | return self.prediction 19 | 20 | @staticmethod 21 | def prediction_str_to_int_array(prediction: str) -> List[List[int]]: 22 | return [ 23 | [int(a) for a in el.split(",")] for el in prediction.split("\n") 24 | ] 25 | 26 | 27 | class Critique(BaseModel): 28 | """Critique data class for LLM structured predict.""" 29 | 30 | critique: str = Field( 31 | description="Brief critique of the previous prediction and rationale. Limit to 250 words." 
32 | ) 33 | 34 | def __str__(self) -> str: 35 | return self.critique 36 | 37 | 38 | class Attempt(BaseModel): 39 | """Container class for a single solution attempt.""" 40 | 41 | id_: str = Field(default_factory=lambda: str(uuid.uuid4())) 42 | prediction: Prediction 43 | critique: Optional[Critique] = Field(default=None) 44 | passing: bool = Field(default=False) 45 | -------------------------------------------------------------------------------- /recipes/llamaindex/arc_finetuning_st/workflows/prompts.py: -------------------------------------------------------------------------------- 1 | from llama_index.core.prompts import PromptTemplate 2 | 3 | PREDICTION_PROMPT_TEMPLATE = PromptTemplate( 4 | """You are a bot that is very good at solving puzzles. Below is a list of input and output pairs with a pattern. 5 | Identify the pattern in the training examples and predict the output for the provided TEST INPUT. 6 | 7 | EXAMPLES: 8 | {examples} 9 | 10 | TEST INPUT: 11 | {test_input} 12 | 13 | OUTPUT FORMAT: 14 | {{ 15 | "output": 16 | }} 17 | 18 | Return your response in JSON format given above. DO NOT RETURN markdown code. 19 | """ 20 | ) 21 | 22 | 23 | REFLECTION_PROMPT_TEMPLATE = PromptTemplate( 24 | """You are a bot that is very good at solving puzzles. Below is a list of input and output pairs that share a 25 | common pattern. The TEST INPUT also shares this common pattern, and you've previously predicted the output for it. 26 | Your task now is to critique the latest prediction, explaining why it might not fit the pattern inherent in the example input/output pairs. 27 | 28 | EXAMPLES: 29 | {examples} 30 | 31 | TEST INPUT: 32 | {test_input} 33 | 34 | PAST ATTEMPTS: 35 | {past_attempts} 36 | 37 | OUTPUT FORMAT: 38 | {{ 39 | "critique": ... 40 | }} 41 | 42 | Return your response in JSON format given above. DO NOT RETURN markdown code.""" 43 | ) 44 | 45 | CORRECTION_PROMPT_TEMPLATE = PromptTemplate( 46 | """You are a bot that is very good at solving puzzles. Below is a list of input and output pairs that share a 47 | common pattern. The TEST INPUT also shares this common pattern, and you've previously predicted the output for it. 48 | The predicted output was found to be incorrect and a critique has been articulated offering a potential 49 | reason as to why it may have been a flawed prediction. 50 | 51 | Your task now is to create a new prediction that corrects the previous attempts. Use 52 | the last attempt and its critique. 53 | 54 | EXAMPLES: 55 | {examples} 56 | 57 | TEST INPUT: 58 | {test_input} 59 | 60 | PAST ATTEMPTS: 61 | {past_attempts} 62 | 63 | OUTPUT FORMAT: 64 | {{ 65 | "prediction": ... 66 | }} 67 | 68 | Return your response in JSON format given above. 
DO NOT RETURN markdown code.""" 69 | ) 70 | -------------------------------------------------------------------------------- /recipes/llamaindex/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "arc-finetuning-st" 7 | version = "0.1.0" 8 | description = "" 9 | authors = ["Andrei Fajardo "] 10 | readme = "README.md" 11 | 12 | [tool.poetry.dependencies] 13 | python = "^3.10" 14 | llama-index-core = "^0.11.11" 15 | llama-index-finetuning = "^0.2.1" 16 | llama-index-llms-openai = "^0.2.7" 17 | llama-index-program-openai = "^0.2.0" 18 | streamlit = "^1.38.0" 19 | plotly = "^5.24.1" 20 | pandas = "^2.2.2" 21 | 22 | [tool.poetry.group.dev.dependencies] 23 | jupyterlab = "^4.2.5" 24 | 25 | [tool.poetry.scripts] 26 | arc-finetuning-cli = 'arc_finetuning_st.cli.command_line:main' 27 | -------------------------------------------------------------------------------- /recipes/ollama/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.33.0 2 | ollama 3 | openai 4 | -------------------------------------------------------------------------------- /recipes/ollama/streamlit_app.py: -------------------------------------------------------------------------------- 1 | import ollama 2 | import streamlit as st 3 | from openai import OpenAI 4 | 5 | st.set_page_config( 6 | page_title="Streamlit Ollama Chatbot", 7 | page_icon="💬", 8 | layout="wide", 9 | initial_sidebar_state="expanded", 10 | ) 11 | 12 | 13 | def extract_model_names(models_info: dict) -> tuple: 14 | """ 15 | Extracts the model names from the models information. 16 | 17 | :param models_info: The response from `ollama.list()`, a dictionary containing the models' information. 18 | 19 | Returns: 20 | A tuple containing the model names. 21 | """ 22 | 23 | return tuple(model["name"] for model in models_info["models"]) 24 | 25 | 26 | def main(): 27 | """ 28 | The main function that runs the application. 
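    It lists the locally pulled Ollama models and streams chat responses through Ollama's OpenAI-compatible endpoint at http://localhost:11434/v1.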
29 | """ 30 | st.subheader("Streamlit Ollama Chatbot", divider="red", anchor=False) 31 | 32 | client = OpenAI( 33 | base_url="http://localhost:11434/v1", 34 | api_key="ollama", # required, but unused 35 | ) 36 | 37 | models_info = ollama.list() 38 | available_models = extract_model_names(models_info) 39 | 40 | if available_models: 41 | selected_model = st.selectbox( 42 | "Pick a model available locally on your system ↓", available_models 43 | ) 44 | 45 | else: 46 | st.warning("You have not pulled any model from Ollama yet!", icon="⚠️") 47 | st.page_link("https://ollama.com/library", label="Pull model(s) from Ollama", icon="🦙") 48 | 49 | message_container = st.container(height=500, border=True) 50 | 51 | if "messages" not in st.session_state: 52 | st.session_state.messages = [] 53 | 54 | for message in st.session_state.messages: 55 | avatar = "🤖" if message["role"] == "assistant" else "😎" 56 | with message_container.chat_message(message["role"], avatar=avatar): 57 | st.markdown(message["content"]) 58 | 59 | if prompt := st.chat_input("Enter a prompt here..."): 60 | try: 61 | st.session_state.messages.append( 62 | {"role": "user", "content": prompt}) 63 | 64 | message_container.chat_message("user", avatar="😎").markdown(prompt) 65 | 66 | with message_container.chat_message("assistant", avatar="🤖"): 67 | with st.spinner("Give i a moment..."): 68 | stream = client.chat.completions.create( 69 | model=selected_model, 70 | messages=[ 71 | {"role": m["role"], "content": m["content"]} 72 | for m in st.session_state.messages 73 | ], 74 | stream=True, 75 | ) 76 | # stream response 77 | response = st.write_stream(stream) 78 | st.session_state.messages.append( 79 | {"role": "assistant", "content": response}) 80 | 81 | except Exception as e: 82 | st.error(e, icon="⛔️") 83 | 84 | 85 | if __name__ == "__main__": 86 | main() -------------------------------------------------------------------------------- /recipes/openai/README.md: -------------------------------------------------------------------------------- 1 | # OpenAI 2 | -------------------------------------------------------------------------------- /recipes/replicate/.streamlit/secrets_template.toml: -------------------------------------------------------------------------------- 1 | #API tokens 2 | 3 | REPLICATE_API_TOKEN = "replace-this-with-your-own-token" -------------------------------------------------------------------------------- /recipes/replicate/README.md: -------------------------------------------------------------------------------- 1 | # How to run the demo Replicate Streamlit chatbot app 2 | This is a recipe for a Replicate Streamlit chatbot app. The app uses a single API to access 3 different LLMs and adjust parameters such as temperature and top-p. 3 | 4 | Other ways to explore this recipe: 5 | * [Deployed app](https://replicate-recipe.streamlit.app/) 6 | 7 | (Requires [Replicate API token](https://replicate.com/signin?next=/account/api-tokens)) 8 | * [Blog post](https://blog.streamlit.io/how-to-create-an-ai-chatbot-llm-api-replicate-streamlit/) 9 | * [Video](https://youtu.be/zsQ7EN10zj8?si=fGxg4zH7mJyrasaT) 10 | 11 | ## Prerequisites 12 | * Python >=3.8, !=3.9.7 13 | * A [Replicate API token](https://replicate.com/signin?next=/account/api-tokens) 14 | (Please note that a payment method is required to access features beyond the free trial limits.) 15 | 16 | ## Environment setup 17 | ### Local setup 18 | 1. Clone the Cookbook repo: `git clone https://github.com/streamlit/cookbook.git` 19 | 2. 
From the Cookbook root directory, change directory into the Replicate recipe: `cd recipes/replicate` 20 | 3. Add your Replicate API token to the `.streamlit/secrets_template.toml` file 21 | 4. Update the filename from `secrets_template.toml` to `secrets.toml`: `mv .streamlit/secrets_template.toml .streamlit/secrets.toml` 22 | 23 | (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).) 24 | 5. Create a virtual environment: `python -m venv replicatevenv` 25 | 6. Activate the virtual environment: `source replicatevenv/bin/activate` 26 | 7. Install the dependencies: `pip install -r requirements.txt` 27 | 28 | ### GitHub Codespaces setup 29 | 1. Create a new codespace by selecting the `Codespaces` option from the `Code` button 30 | 2. Once the codespace has been generated, add your Replicate API token to the `recipes/replicate/.streamlit/secrets_template.toml` file 31 | 3. Update the filename from `secrets_template.toml` to `secrets.toml` 32 | 33 | (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).) 34 | 4. From the Cookbook root directory, change directory into the Replicate recipe: `cd recipes/replicate` 35 | 5. Install the dependencies: `pip install -r requirements.txt` 36 | -------------------------------------------------------------------------------- /recipes/replicate/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit 2 | replicate 3 | transformers 4 | -------------------------------------------------------------------------------- /recipes/replicate/streamlit_app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import replicate 3 | import os 4 | from transformers import AutoTokenizer 5 | 6 | # App title 7 | st.set_page_config(page_title="Streamlit Replicate Chatbot", page_icon="💬") 8 | 9 | # Replicate Credentials 10 | with st.sidebar: 11 | st.title('💬 Streamlit Replicate Chatbot') 12 | st.write('Create chatbots using various LLM models hosted at [Replicate](https://replicate.com/).') 13 | if 'REPLICATE_API_TOKEN' in st.secrets: 14 | replicate_api = st.secrets['REPLICATE_API_TOKEN'] 15 | else: 16 | replicate_api = st.text_input('Enter Replicate API token:', type='password') 17 | if not (replicate_api.startswith('r8_') and len(replicate_api)==40): 18 | st.warning('Please enter your Replicate API token.', icon='⚠️') 19 | st.markdown("**Don't have an API token?** Head over to [Replicate](https://replicate.com) to sign up for one.") 20 | os.environ['REPLICATE_API_TOKEN'] = replicate_api 21 | 22 | st.subheader("Models and parameters") 23 | model = st.selectbox("Select a model",("meta/meta-llama-3-70b-instruct", "mistralai/mistral-7b-instruct-v0.2", "google-deepmind/gemma-2b-it"), key="model") 24 | if model == "google-deepmind/gemma-2b-it": 25 | model = "google-deepmind/gemma-2b-it:dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626" 26 | 27 | temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=5.0, value=0.7, step=0.01, help="Randomness of generated output") 28 | if temperature >= 1: 29 | st.warning('Values exceeding 1 produce more creative and random output, as well as an increased likelihood of hallucination.') 30 | if temperature < 0.1: 31 | st.warning('Values approaching 0 produce deterministic output. 
Recommended starting value is 0.7.') 32 | 33 | top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01, help="Top p percentage of most likely tokens for output generation") 34 | 35 | # Store LLM-generated responses 36 | if "messages" not in st.session_state.keys(): 37 | st.session_state.messages = [{"role": "assistant", "content": "Ask me anything."}] 38 | 39 | # Display or clear chat messages 40 | for message in st.session_state.messages: 41 | with st.chat_message(message["role"]): 42 | st.write(message["content"]) 43 | 44 | def clear_chat_history(): 45 | st.session_state.messages = [{"role": "assistant", "content": "Ask me anything."}] 46 | 47 | st.sidebar.button('Clear chat history', on_click=clear_chat_history) 48 | 49 | @st.cache_resource(show_spinner=False) 50 | def get_tokenizer(): 51 | """Get a tokenizer to make sure we're not sending too much text 52 | to the model.""" 53 | return AutoTokenizer.from_pretrained("huggyllama/llama-7b") 54 | 55 | def get_num_tokens(prompt): 56 | """Get the number of tokens in a given prompt""" 57 | tokenizer = get_tokenizer() 58 | tokens = tokenizer.tokenize(prompt) 59 | return len(tokens) 60 | 61 | # Function for generating model response 62 | def generate_response(): 63 | prompt = [] 64 | for dict_message in st.session_state.messages: 65 | if dict_message["role"] == "user": 66 | prompt.append("<|im_start|>user\n" + dict_message["content"] + "<|im_end|>") 67 | else: 68 | prompt.append("<|im_start|>assistant\n" + dict_message["content"] + "<|im_end|>") 69 | 70 | prompt.append("<|im_start|>assistant") 71 | prompt.append("") 72 | prompt_str = "\n".join(prompt) 73 | 74 | if get_num_tokens(prompt_str) >= 3072: 75 | st.error("Conversation length too long. Please keep it under 3072 tokens.") 76 | st.button('Clear chat history', on_click=clear_chat_history, key="clear_chat_history") 77 | st.stop() 78 | 79 | for event in replicate.stream(model, 80 | input={"prompt": prompt_str, 81 | "prompt_template": r"{prompt}", 82 | "temperature": temperature, 83 | "top_p": top_p, 84 | }): 85 | yield str(event) 86 | 87 | # User-provided prompt 88 | if prompt := st.chat_input(disabled=not replicate_api): 89 | st.session_state.messages.append({"role": "user", "content": prompt}) 90 | with st.chat_message("user"): 91 | st.write(prompt) 92 | 93 | # Generate a new response if last message is not from assistant 94 | if st.session_state.messages[-1]["role"] != "assistant": 95 | with st.chat_message("assistant"): 96 | response = generate_response() 97 | full_response = st.write_stream(response) 98 | message = {"role": "assistant", "content": full_response} 99 | st.session_state.messages.append(message) 100 | -------------------------------------------------------------------------------- /recipes/replit/README.md: -------------------------------------------------------------------------------- 1 | # Building and deploying on Replit 2 | 3 |
4 | 5 |
6 | 7 | 8 | Replit is a cloud-based development and deployment platform that allows you to build and deploy in a single, integrated environment. 9 | It comes with handy tools, like [secrets management](https://docs.replit.com/programming-ide/storing-sensitive-information-environment-variables) and [AI assisted coding](https://docs.replit.com/programming-ide/ai-code-completion), and supports _any_ language you can think of. Replit also has cloud services like [object storage](https://docs.replit.com/storage/replit-database), [embedded Postgres](https://docs.replit.com/hosting/databases/postgresql-database), and [key-value databases](https://docs.replit.com/hosting/databases/replit-database) that make it easy to build and ship apps. For a 1-minute overview of Replit, see [this video](https://www.youtube.com/watch?v=TiHq41h3nDo). 10 | 11 |
12 | 13 |
14 | 15 | ## Setup 16 | 17 | To get started on Replit, follow these easy steps: 18 | 19 | 1. Create a Replit [account](https://replit.com/), 20 | 2. Fork this Replit [template](https://replit.com/@matt/Streamlit?v=1#README.md), which contains the same code in this folder. 21 | 3. To run the app click `Run` in the nav bar. 22 | 23 | A web view pane will open up with your application. Simply edit the code in the editor to see changes in real time. 24 | 25 | **An important note:** because Replit is a cloud-based editor, the [development URL](https://docs.replit.com/additional-resources/add-a-made-with-replit-badge-to-your-webview#what-is-the-webview) is accessible from any device (while your Repl is running). That means you can test your app out on desktop and mobile simultaneously. 26 | 27 |
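Under the hood, the `Run` button executes whatever command is configured in your Repl's hidden `.replit` file. A minimal sketch of what that entry can look like for a Streamlit app (the entrypoint filename and port here are illustrative; the template ships with its own configuration):

```toml
# Hypothetical .replit excerpt; adjust the filename and port to match your app.
run = "streamlit run streamlit_app.py --server.address 0.0.0.0 --server.port 8080"
```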
28 | 29 |
30 | 31 | ## Deployment 32 | 33 | Deployment on Replit is straightforward and can be done in a few steps: 34 | 35 | 1. Click `Deploy` in the top navigation bar of your Repl. 36 | 2. Choose a deployment type: 37 | - For frontend apps, select "Autoscale" deployments. These are cost-effective and automatically scale based on traffic. 38 | - For services requiring continuous execution (e.g., backends, APIs, bots), choose "Reserved VM" deployments. 39 | 3. Select a custom subdomain for your app or use the auto-generated one. 40 | 4. Configure your deployment settings, including environment variables if needed. 41 | 5. Review and confirm your deployment configuration. 42 | 43 | Important notes: 44 | - You'll need to add a payment method to deploy your app. 45 | - Additional options can be configured in the `.replit` file present in your project; note that you may have to reveal hidden files if you can't see it. For more information on configuring the `.replit` file, refer to the [docs](https://docs.replit.com/programming-ide/configuring-repl). 46 | - For detailed instructions and advanced configuration options, refer to the [official Replit deployment documentation](https://docs.replit.com/hosting/deployments/about-deployments). 47 | 48 | ### Custom Domains 49 | 50 | Replit automatically assigns your website a unique domain name such as `your-app-name-your-username.replit.app`. However, if you want to use your own domain that you've purchased through services like [Cloudflare](https://www.cloudflare.com/) or [Namecheap](https://www.namecheap.com/), you can configure it in your Replit deployment settings. Here's how: 51 | 52 | 1. Go to your Repl's "Deployment" tab. 53 | 2. Click on "Configure custom domain". 54 | 3. Enter your custom domain name and follow the instructions to set up DNS records with your domain provider. 55 | 56 | Note the difference between setting up a root domain and a subdomain. Root domains don't have any prefixes before the main site name (e.g., `example.com`), while subdomains do (e.g., `subdomain.example.com`). The setup process might differ slightly for each. Always refer to your domain registrar's documentation for specific instructions on how to set up DNS records for your custom domain. 57 | 58 | For more detailed information on setting up custom domains in Replit, you can refer to their official documentation [here](https://docs.replit.com/hosting/deployments/custom-domains). 59 | -------------------------------------------------------------------------------- /recipes/trulens/.env.template: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY= 2 | -------------------------------------------------------------------------------- /recipes/trulens/README.md: -------------------------------------------------------------------------------- 1 | # Streamlit × Trulens demo 2 | 3 | Ask questions about the Pacific Northwest, and get `TruLens` evaluations on the app response. 4 | 5 | To run the demo: 6 | 7 | 1. Clone the cookbook repository, and navigate to the `trulens` recipe. 8 | 2. Copy the contents of `.env.template` to `.env`, and fill in your credentials. 9 | 3. 
Launch the app with `streamlit run app.py` 10 | -------------------------------------------------------------------------------- /recipes/trulens/app.py: -------------------------------------------------------------------------------- 1 | __import__('pysqlite3') 2 | import sys 3 | sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') 4 | 5 | import streamlit as st 6 | import trulens.dashboard.streamlit as trulens_st 7 | from trulens.core import TruSession 8 | 9 | from base import rag, filtered_rag, tru_rag, filtered_tru_rag 10 | 11 | st.set_page_config( 12 | page_title="Use TruLens in Streamlit", 13 | page_icon="🦑", 14 | ) 15 | 16 | st.title("TruLens ❤️ Streamlit") 17 | 18 | st.write("Learn about the Pacific Northwest, and view tracing & evaluation metrics powered by TruLens 🦑.") 19 | 20 | tru = TruSession() 21 | 22 | with_filters = st.toggle("Use Context Filter Guardrails", value=False) 23 | 24 | def generate_response(input_text): 25 | if with_filters: 26 | app = filtered_tru_rag 27 | with filtered_tru_rag as recording: 28 | response = filtered_rag.query(input_text) 29 | else: 30 | app = tru_rag 31 | with tru_rag as recording: 32 | response = rag.query(input_text) 33 | 34 | record = recording.get() 35 | 36 | return record, response 37 | 38 | with st.form("my_form"): 39 | text = st.text_area( 40 | "Enter text:", "When was the University of Washington founded?" 41 | ) 42 | submitted = st.form_submit_button("Submit") 43 | if submitted: 44 | record, response = generate_response(text) 45 | st.info(response) 46 | 47 | if submitted: 48 | with st.expander("See the trace of this record 👀"): 49 | trulens_st.trulens_trace(record=record) 50 | 51 | trulens_st.trulens_feedback(record=record) 52 | 53 | -------------------------------------------------------------------------------- /recipes/trulens/base.py: -------------------------------------------------------------------------------- 1 | __import__('pysqlite3') 2 | import sys 3 | sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') 4 | 5 | import streamlit as st 6 | from openai import OpenAI 7 | import numpy as np 8 | 9 | from trulens.core import TruSession 10 | from trulens.core.guardrails.base import context_filter 11 | from trulens.apps.custom import instrument 12 | from trulens.apps.app import TruApp 13 | from trulens.providers.openai import OpenAI as OpenAIProvider 14 | from trulens.core import Feedback 15 | from trulens.core import Select 16 | from trulens.core.guardrails.base import context_filter 17 | 18 | from feedback import feedbacks, f_guardrail 19 | from vector_store import vector_store 20 | 21 | from dotenv import load_dotenv 22 | 23 | load_dotenv() 24 | 25 | oai_client = OpenAI() 26 | 27 | tru = TruSession() 28 | 29 | class RAG_from_scratch: 30 | @instrument 31 | def retrieve(self, query: str) -> list: 32 | """ 33 | Retrieve relevant text from vector store. 34 | """ 35 | results = vector_store.query(query_texts=query, n_results=4) 36 | # Flatten the list of lists into a single list 37 | return [doc for sublist in results["documents"] for doc in sublist] 38 | 39 | @instrument 40 | def generate_completion(self, query: str, context_str: list) -> str: 41 | """ 42 | Generate answer from context. 43 | """ 44 | completion = ( 45 | oai_client.chat.completions.create( 46 | model="gpt-3.5-turbo", 47 | temperature=0, 48 | messages=[ 49 | { 50 | "role": "user", 51 | "content": f"We have provided context information below. 
\n" 52 | f"---------------------\n" 53 | f"{context_str}" 54 | f"\n---------------------\n" 55 | f"First, say hello and that you're happy to help. \n" 56 | f"\n---------------------\n" 57 | f"Then, given this information, please answer the question: {query}", 58 | } 59 | ], 60 | ) 61 | .choices[0] 62 | .message.content 63 | ) 64 | return completion 65 | 66 | @instrument 67 | def query(self, query: str) -> str: 68 | context_str = self.retrieve(query) 69 | completion = self.generate_completion(query, context_str) 70 | return completion 71 | 72 | class filtered_RAG_from_scratch: 73 | @instrument 74 | @context_filter(f_guardrail, 0.75, keyword_for_prompt="query") 75 | def retrieve(self, query: str) -> list: 76 | """ 77 | Retrieve relevant text from vector store. 78 | """ 79 | results = vector_store.query(query_texts=query, n_results=4) 80 | return [doc for sublist in results["documents"] for doc in sublist] 81 | 82 | @instrument 83 | def generate_completion(self, query: str, context_str: list) -> str: 84 | """ 85 | Generate answer from context. 86 | """ 87 | completion = ( 88 | oai_client.chat.completions.create( 89 | model="gpt-3.5-turbo", 90 | temperature=0, 91 | messages=[ 92 | { 93 | "role": "user", 94 | "content": f"We have provided context information below. \n" 95 | f"---------------------\n" 96 | f"{context_str}" 97 | f"\n---------------------\n" 98 | f"Given this information, please answer the question: {query}", 99 | } 100 | ], 101 | ) 102 | .choices[0] 103 | .message.content 104 | ) 105 | return completion 106 | 107 | @instrument 108 | def query(self, query: str) -> str: 109 | context_str = self.retrieve(query=query) 110 | completion = self.generate_completion( 111 | query=query, context_str=context_str 112 | ) 113 | return completion 114 | 115 | 116 | filtered_rag = filtered_RAG_from_scratch() 117 | 118 | rag = RAG_from_scratch() 119 | 120 | tru_rag = TruApp( 121 | rag, 122 | app_name="RAG", 123 | app_version="v1", 124 | feedbacks=feedbacks, 125 | ) 126 | 127 | filtered_tru_rag = TruApp( 128 | filtered_rag, 129 | app_name="RAG", 130 | app_version="v2", 131 | feedbacks=feedbacks, 132 | ) 133 | -------------------------------------------------------------------------------- /recipes/trulens/feedback.py: -------------------------------------------------------------------------------- 1 | __import__('pysqlite3') 2 | import sys 3 | sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') 4 | 5 | import numpy as np 6 | from trulens.core import Feedback 7 | from trulens.core import Select 8 | from trulens.providers.openai import OpenAI as OpenAIProvider 9 | 10 | from dotenv import load_dotenv 11 | 12 | load_dotenv() 13 | 14 | provider = OpenAIProvider(model_engine="gpt-4o-mini") 15 | 16 | # Define a groundedness feedback function 17 | f_groundedness = ( 18 | Feedback( 19 | provider.groundedness_measure_with_cot_reasons, name="Groundedness" 20 | ) 21 | .on(Select.RecordCalls.retrieve.rets.collect()) 22 | .on_output() 23 | ) 24 | # Question/answer relevance between overall question and answer. 25 | f_answer_relevance = ( 26 | Feedback(provider.relevance_with_cot_reasons, name="Answer Relevance") 27 | .on_input() 28 | .on_output() 29 | ) 30 | 31 | # Context relevance between question and each context chunk. 
/recipes/trulens/feedback.py: -------------------------------------------------------------------------------- 1 | __import__('pysqlite3') 2 | import sys 3 | sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') 4 | 5 | import numpy as np 6 | from trulens.core import Feedback 7 | from trulens.core import Select 8 | from trulens.providers.openai import OpenAI as OpenAIProvider 9 | 10 | from dotenv import load_dotenv 11 | 12 | load_dotenv() 13 | 14 | provider = OpenAIProvider(model_engine="gpt-4o-mini") 15 | 16 | # Define a groundedness feedback function 17 | f_groundedness = ( 18 | Feedback( 19 | provider.groundedness_measure_with_cot_reasons, name="Groundedness" 20 | ) 21 | .on(Select.RecordCalls.retrieve.rets.collect()) 22 | .on_output() 23 | ) 24 | # Question/answer relevance between overall question and answer. 25 | f_answer_relevance = ( 26 | Feedback(provider.relevance_with_cot_reasons, name="Answer Relevance") 27 | .on_input() 28 | .on_output() 29 | ) 30 | 31 | # Context relevance between question and each context chunk. 32 | f_context_relevance = ( 33 | Feedback( 34 | provider.context_relevance_with_cot_reasons, name="Context Relevance" 35 | ) 36 | .on_input() 37 | .on(Select.RecordCalls.retrieve.rets[:]) 38 | .aggregate(np.mean) # choose a different aggregation method if you wish 39 | ) 40 | 41 | feedbacks = [f_groundedness, f_answer_relevance, f_context_relevance] 42 | 43 | # Note: a feedback function used as a guardrail must return only a score, not reasons 44 | f_guardrail = Feedback( 45 | provider.context_relevance, name="Context Relevance" 46 | ) 47 | --------------------------------------------------------------------------------
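The selectors in `feedback.py` determine which pieces of each record a feedback function sees: `.on_input()` and `.on_output()` bind the app's main input and output, while `Select.RecordCalls.retrieve.rets[:]` binds each return value of the instrumented `retrieve` method, with `.aggregate(np.mean)` averaging the per-chunk scores. The same wiring should work for plain Python scorers too; a hedged sketch follows, where `keyword_overlap` is a hypothetical custom scorer, not a TruLens builtin.

```
import numpy as np
from trulens.core import Feedback, Select

def keyword_overlap(question: str, chunk: str) -> float:
    """Toy relevance: fraction of question words that appear in the chunk."""
    words = question.lower().split()
    return sum(word in chunk.lower() for word in words) / max(len(words), 1)

f_overlap = (
    Feedback(keyword_overlap, name="Keyword Overlap")
    .on_input()                               # question -> first argument
    .on(Select.RecordCalls.retrieve.rets[:])  # each retrieved chunk -> second argument
    .aggregate(np.mean)                       # average the per-chunk scores
)
```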
/recipes/trulens/requirements.txt: -------------------------------------------------------------------------------- 1 | pip==24.2 2 | openai 3 | pysqlite3-binary 4 | chromadb 5 | trulens-core @ git+https://github.com/truera/trulens#egg=trulens-core&subdirectory=src/core/ 6 | trulens-feedback @ git+https://github.com/truera/trulens#egg=trulens-feedback&subdirectory=src/feedback/ 7 | trulens-providers-openai @ git+https://github.com/truera/trulens#egg=trulens-providers-openai&subdirectory=src/providers/openai/ 8 | trulens-dashboard @ git+https://github.com/truera/trulens#egg=trulens-dashboard&subdirectory=src/dashboard/ -------------------------------------------------------------------------------- /recipes/trulens/vector_store.py: -------------------------------------------------------------------------------- 1 | __import__('pysqlite3') 2 | import sys 3 | sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') 4 | 5 | import openai 6 | import os 7 | import chromadb 8 | from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction 9 | 10 | from dotenv import load_dotenv 11 | 12 | load_dotenv() 13 | 14 | uw_info = """ 15 | The University of Washington, founded in 1861 in Seattle, is a public research university 16 | with over 45,000 students across three campuses in Seattle, Tacoma, and Bothell. 17 | As the flagship institution of the six public universities in Washington state, 18 | UW encompasses over 500 buildings and 20 million square feet of space, 19 | including one of the largest library systems in the world. 20 | """ 21 | 22 | wsu_info = """ 23 | Washington State University, commonly known as WSU, founded in 1890, is a public research university in Pullman, Washington. 24 | With multiple campuses across the state, it is the state's second largest institution of higher education. 25 | WSU is known for its programs in veterinary medicine, agriculture, engineering, architecture, and pharmacy. 26 | """ 27 | 28 | seattle_info = """ 29 | Seattle, a city on Puget Sound in the Pacific Northwest, is surrounded by water, mountains and evergreen forests, and contains thousands of acres of parkland. 30 | It's home to a large tech industry, with Microsoft and Amazon headquartered in its metropolitan area. 31 | The futuristic Space Needle, a legacy of the 1962 World's Fair, is its most iconic landmark. 32 | """ 33 | 34 | starbucks_info = """ 35 | Starbucks Corporation is an American multinational chain of coffeehouses and roastery reserves headquartered in Seattle, Washington. 36 | As the world's largest coffeehouse chain, Starbucks is seen to be the main representation of the United States' second wave of coffee culture. 37 | """ 38 | 39 | newzealand_info = """ 40 | New Zealand is an island country located in the southwestern Pacific Ocean. It comprises two main landmasses—the North Island and the South Island—and over 700 smaller islands. 41 | The country is known for its stunning landscapes, ranging from lush forests and mountains to beaches and lakes. New Zealand has a rich cultural heritage, with influences from 42 | both the indigenous Māori people and European settlers. The capital city is Wellington, while the largest city is Auckland. New Zealand is also famous for its adventure tourism, 43 | including activities like bungee jumping, skiing, and hiking. 44 | """ 45 | 46 | embedding_function = OpenAIEmbeddingFunction( 47 | api_key=os.environ.get("OPENAI_API_KEY"), 48 | model_name="text-embedding-ada-002", 49 | ) 50 | 51 | 52 | chroma_client = chromadb.Client() 53 | vector_store = chroma_client.get_or_create_collection( 54 | name="Washington", embedding_function=embedding_function 55 | ) 56 | 57 | vector_store.add("uw_info", documents=uw_info) 58 | vector_store.add("wsu_info", documents=wsu_info) 59 | vector_store.add("seattle_info", documents=seattle_info) 60 | vector_store.add("starbucks_info", documents=starbucks_info) 61 | vector_store.add("newzealand_info", documents=newzealand_info) --------------------------------------------------------------------------------
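Since `chromadb.Client()` in `vector_store.py` is the in-memory client, importing the module rebuilds and re-embeds the collection in the current process. A quick smoke test of the collection, assuming `OPENAI_API_KEY` is set as in the recipe's `.env`:

```
from vector_store import vector_store  # importing the module builds the collection

results = vector_store.query(query_texts="Where is the Space Needle?", n_results=2)
for doc_id, doc in zip(results["ids"][0], results["documents"][0]):
    print(doc_id, "->", doc.strip()[:60], "...")
# "seattle_info" should rank at or near the top
```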
/recipes/weaviate/.streamlit/secrets_template.toml: -------------------------------------------------------------------------------- 1 | WEAVIATE_API_KEY = "your weaviate key goes here" 2 | WEAVIATE_URL = "your weaviate url goes here" 3 | COHERE_API_KEY = "your cohere api key goes here" -------------------------------------------------------------------------------- /recipes/weaviate/README.md: -------------------------------------------------------------------------------- 1 | # How to run the demo app 2 | This is a recipe for a movie recommendation app. The app uses [Weaviate](https://weaviate.io/) to create a vector database of movie titles and [Streamlit](https://streamlit.io/) to create a recommendation chatbot. 3 | 4 | Other ways to explore this recipe: 5 | * [Deployed app](https://weaviate-movie-magic.streamlit.app/) 6 | * [Video](https://youtu.be/SQD-aWlhqvM?si=t54W53G1gWnTAiwx) 7 | * [Blog post](https://blog.streamlit.io/how-to-recommendation-app-vector-database-weaviate/) 8 | 9 | ## Prerequisites 10 | * Python >=3.8, !=3.9.7 11 | * [A Weaviate API key and URL](https://auth.wcs.api.weaviate.io/auth/realms/SeMI/login-actions/registration?client_id=wcs-frontend&tab_id=5bw6GQTdWU0) 12 | * [A Cohere API key](https://dashboard.cohere.com/welcome/register) 13 | 14 | ## Environment setup 15 | ### Local setup 16 | 17 | #### Create a virtual environment 18 | 1. Clone the Cookbook repo: `git clone https://github.com/streamlit/cookbook.git` 19 | 2. From the Cookbook root directory, change directory into the recipe: `cd recipes/weaviate` 20 | 3. Add your secrets to the `.streamlit/secrets_template.toml` file 21 | 4. Rename the file from `secrets_template.toml` to `secrets.toml`: `mv .streamlit/secrets_template.toml .streamlit/secrets.toml` 22 | 23 | (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).) 24 | 5. Create a virtual environment: `python3 -m venv weaviatevenv` 25 | 6. Activate the virtual environment: `source weaviatevenv/bin/activate` 26 | 7. Install the dependencies: `pip install -r requirements.txt` 27 | 28 | #### Add data to your Weaviate Cloud 29 | 1. Create a Weaviate Cloud [Collection](https://weaviate.io/developers/weaviate/config-refs/schema#introduction) and add data to it: `python3 helpers/add_data.py` 30 | 2. (Optional) Verify the data: `python3 helpers/verify_data.py` 31 | 3. (Optional) Use the Weaviate Cloud UI to [query the Collection](https://weaviate.io/developers/weaviate/connections/connect-query#example-query): 32 | ``` 33 | { Get {MovieDemo (limit: 3 34 | where: { path: ["release_year"], 35 | operator: Equal, 36 | valueInt: 1985}) { 37 | budget 38 | movie_id 39 | overview 40 | release_year 41 | revenue 42 | tagline 43 | title 44 | vote_average 45 | vote_count 46 | }}} 47 | ``` 48 | 49 | #### Run the app 50 | 1. Run the app with: `streamlit run demo_app.py` 51 | 2. The app should open in a new browser tab 52 | 53 | (Please note that this version of the demo app does not feature the poster images, so it will look different from the [deployed app](https://weaviate-movie-magic.streamlit.app/).) 54 | --------------------------------------------------------------------------------
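For reference, the GraphQL query in step 3 of the README can also be expressed with the v4 Python client this recipe uses elsewhere (compare `helpers/verify_data.py`). A hedged sketch with placeholder credentials:

```
import weaviate
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter

client = weaviate.connect_to_weaviate_cloud(
    cluster_url="YOUR_WEAVIATE_URL",                         # placeholder
    auth_credentials=Auth.api_key("YOUR_WEAVIATE_API_KEY"),  # placeholder
)
movies = client.collections.get("MovieDemo")
response = movies.query.fetch_objects(
    filters=Filter.by_property("release_year").equal(1985),  # same filter as the GraphQL example
    limit=3,
    return_properties=["title", "tagline", "release_year"],
)
for obj in response.objects:
    print(obj.properties["title"])
client.close()
```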
/recipes/weaviate/demo_app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import time 3 | import sys 4 | import os 5 | import base64 6 | from st_weaviate_connection import WeaviateConnection, WeaviateFilter 7 | from weaviate.classes.query import Filter 8 | 9 | # Constants 10 | ENV_VARS = ["WEAVIATE_URL", "WEAVIATE_API_KEY", "COHERE_API_KEY"] 11 | NUM_RECOMMENDATIONS_PER_ROW = 5 12 | SEARCH_LIMIT = 10 13 | 14 | # Search Mode descriptions 15 | SEARCH_MODES = { 16 | "Keyword": ("Keyword search (BM25) ranks documents based on the relative frequencies of search terms.", 0), 17 | "Semantic": ("Semantic (vector) search ranks results based on their similarity to your search query.", 1), 18 | "Hybrid": ("Hybrid search combines vector and BM25 searches to offer best-of-both-worlds search results.", 0.7), 19 | } 20 | 21 | # Functions 22 | def get_env_vars(env_vars): 23 | """Retrieve environment variables""" 24 | env_vars = {var: os.environ.get(var, "") for var in env_vars} 25 | for var, value in env_vars.items(): 26 | if not value: 27 | st.error(f"{var} not set", icon="🚨") 28 | sys.exit(f"{var} not set") 29 | return env_vars 30 | 31 | def display_chat_messages(): 32 | """Print message history""" 33 | for message in st.session_state.messages: 34 | with st.chat_message(message["role"]): 35 | st.markdown(message["content"]) 36 | if "images" in message: 37 | for i in range(0, len(message["images"]), NUM_RECOMMENDATIONS_PER_ROW): 38 | cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW) 39 | for j, col in enumerate(cols): 40 | if i + j < len(message["images"]): 41 | col.image(message["images"][i + j], width=200) 42 | if "titles" in message: 43 | for i in range(0, len(message["titles"]), NUM_RECOMMENDATIONS_PER_ROW): 44 | cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW) 45 | for j, col in enumerate(cols): 46 | if i + j < len(message["titles"]): 47 | col.write(message["titles"][i + j]) 48 | 49 | 50 | def base64_to_image(base64_str): 51 | """Build an inline image data URI from a base64-encoded image""" 52 | return f"data:image/png;base64,{base64_str}" 53 | 54 | def clean_input(input_text): 55 | """Clean user input""" 56 | return input_text.replace('"', "").replace("'", "") 57 | 58 | def setup_sidebar(): 59 | """Setup sidebar elements""" 60 | with st.sidebar: 61 | st.title("🎥🍿 Movie Magic") 62 | st.subheader("The RAG Recommender") 63 | st.markdown("Your Weaviate & AI powered movie recommender. Find the perfect film for any occasion. Just tell us what you're looking for!") 64 | st.header("Settings") 65 | 66 | mode = st.radio("Search Mode", options=list(SEARCH_MODES.keys()), index=2) 67 | year_range = st.slider("Year range", min_value=1950, max_value=2024, value=(1990, 2024)) 68 | st.info(SEARCH_MODES[mode][0]) 69 | st.success("Connected to Weaviate", icon="💚") 70 | 71 | return mode, year_range 72 | 73 | def setup_weaviate_connection(env_vars): 74 | """Setup Weaviate connection""" 75 | return st.connection( 76 | "weaviate", 77 | type=WeaviateConnection, 78 | url=env_vars["WEAVIATE_URL"], 79 | api_key=env_vars["WEAVIATE_API_KEY"], 80 | additional_headers={"X-Cohere-Api-Key": env_vars["COHERE_API_KEY"]}, 81 | ) 82 | 83 | def display_example_prompts(): 84 | """Display example prompt buttons""" 85 | example_prompts = [ 86 | ("sci-fi adventure", "movie night with friends"), 87 | ("romantic comedy", "date night"), 88 | ("animated family film", "family viewing"), 89 | ("classic thriller", "solo watching"), 90 | ("historical drama", "educational evening"), 91 | ("indie comedy-drama", "film club discussion"), 92 | ] 93 | 94 | example_prompts_help = [ 95 | "Search for sci-fi adventure movies suitable for a group viewing", 96 | "Find romantic comedies perfect for a date night", 97 | "Look for animated movies great for family entertainment", 98 | "Discover classic thrillers for a solo movie night", 99 | "Explore historical dramas for an educational movie experience", 100 | "Find indie comedy-dramas ideal for film club discussions", 101 | ] 102 | 103 | st.markdown("---") 104 | st.write("Select an example prompt or enter your own, then **click `Search`** to get recommendations.") 105 | 106 | button_cols = st.columns(3) 107 | button_cols_2 = st.columns(3) 108 | 109 | for i, ((movie_type, occasion), help_text) in enumerate(zip(example_prompts, example_prompts_help)): 110 | col = button_cols[i] if i < 3 else button_cols_2[i-3] 111 | if col.button(f"{movie_type} for a {occasion}", help=help_text): 112 | st.session_state.example_movie_type = movie_type 113 | st.session_state.example_occasion = occasion 114 | return True 115 | return False 116 | 117 | def perform_search(conn, movie_type, rag_prompt, year_range, mode): 118 | """Perform search and display results""" 119 | df = conn.query( 120 | "MovieDemo", 121 | query=movie_type, 122 | # Uncomment the line below if you want to use this with poster images 123 | # return_properties=["title", "tagline", "poster"], 124 | # Comment out the line below if you want to use this with poster images 125 | return_properties=["title", "tagline"], 126 | filters=( 127 | WeaviateFilter.by_property("release_year").greater_or_equal(year_range[0]) & 128 | WeaviateFilter.by_property("release_year").less_or_equal(year_range[1]) 129 | ), 130 | limit=SEARCH_LIMIT, 131 | alpha=SEARCH_MODES[mode][1], 132 | ) 133 | 134 | images = [] 135 | titles = [] 136 | 137 | if df is None or df.empty: 138 | with st.chat_message("assistant"): 139 | st.write(f"No movies found matching '{movie_type}' with {mode} search. Please try again.") 140 | st.session_state.messages.append({"role": "assistant", "content": "No movies found. Please try again."}) 141 | return 142 | else: 143 | with st.chat_message("assistant"): 144 | st.write("Raw search results.") 145 | cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW) 146 | for index, row in df.iterrows(): 147 | col = cols[index % NUM_RECOMMENDATIONS_PER_ROW] 148 | if "poster" in row and row["poster"]: 149 | col.image(base64_to_image(row["poster"]), width=200) 150 | images.append(base64_to_image(row["poster"])) 151 | else: 152 | col.write(f"{row['title']}") 153 | titles.append(row["title"]) 154 | 155 | st.write("Now generating recommendation from these: ...") 156 | 157 | st.session_state.messages.append( 158 | {"role": "assistant", "content": "Raw search results. Generating recommendation from these: ...", "images": images, "titles": titles}) 159 | 160 | with conn.client() as client: 161 | collection = client.collections.get("MovieDemo") 162 | response = collection.generate.hybrid( 163 | query=movie_type, 164 | filters=( 165 | Filter.by_property("release_year").greater_or_equal(year_range[0]) & 166 | Filter.by_property("release_year").less_or_equal(year_range[1]) 167 | ), 168 | limit=SEARCH_LIMIT, 169 | alpha=SEARCH_MODES[mode][1], 170 | grouped_task=rag_prompt, 171 | grouped_properties=["title", "tagline"], 172 | ) 173 | 174 | rag_response = response.generated 175 | 176 | with st.chat_message("assistant"): 177 | message_placeholder = st.empty() 178 | full_response = "" 179 | for chunk in rag_response.split(): 180 | full_response += chunk + " " 181 | time.sleep(0.02) 182 | message_placeholder.markdown(full_response + "▌") 183 | message_placeholder.markdown(full_response) 184 | 185 | st.session_state.messages.append( 186 | {"role": "assistant", "content": "Recommendation from these search results: " + full_response} 187 | ) 188 | 189 | def main(): 190 | st.title("🎥🍿 Movie Magic") 191 | 192 | env_vars = get_env_vars(ENV_VARS) 193 | conn = setup_weaviate_connection(env_vars) 194 | mode, year_range = setup_sidebar() 195 | 196 | if "messages" not in st.session_state: 197 | st.session_state.messages = [] 198 | st.session_state.greetings = False 199 | 200 | display_chat_messages() 201 | 202 | if not st.session_state.greetings: 203 | with st.chat_message("assistant"): 204 | intro = "👋 Welcome to Movie Magic! I'm your AI movie recommender. Tell me what kind of film you're in the mood for and the occasion, and I'll suggest some great options." 205 | st.markdown(intro) 206 | st.session_state.messages.append({"role": "assistant", "content": intro}) 207 | st.session_state.greetings = True 208 | 209 | if "example_movie_type" not in st.session_state: 210 | st.session_state.example_movie_type = "" 211 | if "example_occasion" not in st.session_state: 212 | st.session_state.example_occasion = "" 213 | 214 | example_selected = display_example_prompts() 215 | 216 | movie_type = clean_input(st.text_input( 217 | "What movies are you looking for?", 218 | value=st.session_state.example_movie_type, 219 | placeholder="E.g., sci-fi adventure, romantic comedy" 220 | )) 221 | 222 | viewing_occasion = clean_input(st.text_input( 223 | "What occasion is the movie for?", 224 | value=st.session_state.example_occasion, 225 | placeholder="E.g., movie night with friends, date night" 226 | )) 227 | 228 | if st.button("Search") and movie_type and viewing_occasion: 229 | rag_prompt = f"Suggest one or two movies from the following list for a {viewing_occasion}. Give a concise yet fun and positive recommendation." 230 | prompt = f"Searching for: {movie_type} for {viewing_occasion}" 231 | with st.chat_message("user"): 232 | st.markdown(prompt) 233 | st.session_state.messages.append({"role": "user", "content": prompt}) 234 | 235 | perform_search(conn, movie_type, rag_prompt, year_range, mode) 236 | st.rerun() 237 | 238 | if example_selected: 239 | st.rerun() 240 | 241 | if __name__ == "__main__": 242 | main() 243 | --------------------------------------------------------------------------------
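The `alpha` values in `SEARCH_MODES` map directly onto Weaviate's hybrid scoring: `alpha=0` is pure BM25 keyword ranking, `alpha=1` is pure vector similarity, and `0.7` blends the two with a bias toward the vector score. A hedged sketch of comparing the three modes outside the app; `movies` is assumed to be a collection handle obtained as in `helpers/verify_data.py`:

```
# movies = client.collections.get("MovieDemo")  # assumed, as in verify_data.py
for mode, alpha in [("Keyword", 0.0), ("Hybrid", 0.7), ("Semantic", 1.0)]:
    response = movies.query.hybrid(query="space adventure", alpha=alpha, limit=3)
    print(mode, [obj.properties["title"] for obj in response.objects])
```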
Please try again."}) 141 | return 142 | else: 143 | with st.chat_message("assistant"): 144 | st.write("Raw search results.") 145 | cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW) 146 | for index, row in df.iterrows(): 147 | col = cols[index % NUM_RECOMMENDATIONS_PER_ROW] 148 | if "poster" in row and row["poster"]: 149 | col.image(base64_to_image(row["poster"]), width=200) 150 | images.append(base64_to_image(row["poster"])) 151 | else: 152 | col.write(f"{row['title']}") 153 | titles.append(row["title"]) 154 | 155 | st.write("Now generating recommendation from these: ...") 156 | 157 | st.session_state.messages.append( 158 | {"role": "assistant", "content": "Raw search results. Generating recommendation from these: ...", "images": images, "titles": titles}) 159 | 160 | with conn.client() as client: 161 | collection = client.collections.get("MovieDemo") 162 | response = collection.generate.hybrid( 163 | query=movie_type, 164 | filters=( 165 | Filter.by_property("release_year").greater_or_equal(year_range[0]) & 166 | Filter.by_property("release_year").less_or_equal(year_range[1]) 167 | ), 168 | limit=SEARCH_LIMIT, 169 | alpha=SEARCH_MODES[mode][1], 170 | grouped_task=rag_prompt, 171 | grouped_properties=["title", "tagline"], 172 | ) 173 | 174 | rag_response = response.generated 175 | 176 | with st.chat_message("assistant"): 177 | message_placeholder = st.empty() 178 | full_response = "" 179 | for chunk in rag_response.split(): 180 | full_response += chunk + " " 181 | time.sleep(0.02) 182 | message_placeholder.markdown(full_response + "▌") 183 | message_placeholder.markdown(full_response) 184 | 185 | st.session_state.messages.append( 186 | {"role": "assistant", "content": "Recommendation from these search results: " + full_response} 187 | ) 188 | 189 | def main(): 190 | st.title("🎥🍿 Movie Magic") 191 | 192 | env_vars = get_env_vars(ENV_VARS) 193 | conn = setup_weaviate_connection(env_vars) 194 | mode, year_range = setup_sidebar() 195 | 196 | if "messages" not in st.session_state: 197 | st.session_state.messages = [] 198 | st.session_state.greetings = False 199 | 200 | display_chat_messages() 201 | 202 | if not st.session_state.greetings: 203 | with st.chat_message("assistant"): 204 | intro = "👋 Welcome to Movie Magic! I'm your AI movie recommender. Tell me what kind of film you're in the mood for and the occasion, and I'll suggest some great options." 205 | st.markdown(intro) 206 | st.session_state.messages.append({"role": "assistant", "content": intro}) 207 | st.session_state.greetings = True 208 | 209 | if "example_movie_type" not in st.session_state: 210 | st.session_state.example_movie_type = "" 211 | if "example_occasion" not in st.session_state: 212 | st.session_state.example_occasion = "" 213 | 214 | example_selected = display_example_prompts() 215 | 216 | movie_type = clean_input(st.text_input( 217 | "What movies are you looking for?", 218 | value=st.session_state.example_movie_type, 219 | placeholder="E.g., sci-fi adventure, romantic comedy" 220 | )) 221 | 222 | viewing_occasion = clean_input(st.text_input( 223 | "What occasion is the movie for?", 224 | value=st.session_state.example_occasion, 225 | placeholder="E.g., movie night with friends, date night" 226 | )) 227 | 228 | if st.button("Search") and movie_type and viewing_occasion: 229 | rag_prompt = f"Suggest one to two movies out of the following list, for a {viewing_occasion}. Give a concise yet fun and positive recommendation." 
230 | prompt = f"Searching for: {movie_type} for {viewing_occasion}" 231 | with st.chat_message("user"): 232 | st.markdown(prompt) 233 | st.session_state.messages.append({"role": "user", "content": prompt}) 234 | 235 | perform_search(conn, movie_type, rag_prompt, year_range, mode) 236 | st.rerun() 237 | 238 | if example_selected: 239 | st.rerun() 240 | 241 | if __name__ == "__main__": 242 | main() 243 | -------------------------------------------------------------------------------- /recipes/weaviate/helpers/add_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pathlib import Path 3 | import weaviate 4 | from weaviate.classes.config import Configure, DataType, Property 5 | from weaviate.classes.query import Filter 6 | from weaviate.util import generate_uuid5 7 | from datetime import datetime, timezone 8 | import toml 9 | import os 10 | from weaviate.classes.init import Auth 11 | from tqdm import tqdm 12 | import base64 13 | 14 | # Construct the path to the toml file 15 | current_dir = os.path.dirname(__file__) 16 | parent_dir = os.path.dirname(current_dir) 17 | toml_file_path = os.path.join(parent_dir, ".streamlit/secrets.toml") 18 | 19 | config = toml.load(toml_file_path) 20 | 21 | # Access values from the toml file 22 | weaviate_api_key = config["WEAVIATE_API_KEY"] 23 | weaviate_url = config["WEAVIATE_URL"] 24 | cohere_api_key = config["COHERE_API_KEY"] 25 | 26 | # Client for Weaviate Cloud 27 | client = weaviate.connect_to_weaviate_cloud( 28 | cluster_url=weaviate_url, 29 | auth_credentials=Auth.api_key(weaviate_api_key), 30 | headers={ 31 | "X-Cohere-Api-Key": cohere_api_key 32 | } 33 | ) 34 | 35 | # If you are using a local instance of Weaviate, you can use the following code 36 | # client = weaviate.connect_to_local(headers={"X-Cohere-Api-Key": cohere_apikey}) 37 | 38 | # Delete any existing MovieDemo Collection to prevent errors 39 | client.collections.delete(["MovieDemo"]) 40 | 41 | # Create the MovieDemo Collection 42 | movies = client.collections.create( 43 | name="MovieDemo", 44 | properties=[ 45 | Property( 46 | name="title", 47 | data_type=DataType.TEXT, 48 | ), 49 | Property( 50 | name="overview", 51 | data_type=DataType.TEXT, 52 | ), 53 | Property( 54 | name="tagline", 55 | data_type=DataType.TEXT, 56 | ), 57 | Property( 58 | name="movie_id", 59 | data_type=DataType.INT, 60 | skip_vectorization=True, 61 | ), 62 | Property( 63 | name="release_year", 64 | data_type=DataType.INT, 65 | ), 66 | Property( 67 | name="genres", 68 | data_type=DataType.TEXT_ARRAY, 69 | ), 70 | Property( 71 | name="vote_average", 72 | data_type=DataType.NUMBER, 73 | ), 74 | Property( 75 | name="vote_count", 76 | data_type=DataType.INT, 77 | ), 78 | Property( 79 | name="revenue", 80 | data_type=DataType.INT, 81 | ), 82 | Property( 83 | name="budget", 84 | data_type=DataType.INT, 85 | ), 86 | # Uncomment the lines below if you want to use this with poster images 87 | # Property( 88 | # name="poster", 89 | # data_type=DataType.BLOB 90 | # ), 91 | ], 92 | vectorizer_config=Configure.Vectorizer.text2vec_cohere(), 93 | vector_index_config=Configure.VectorIndex.hnsw( 94 | quantizer=Configure.VectorIndex.Quantizer.bq() 95 | ), 96 | generative_config=Configure.Generative.cohere(model="command-r-plus"), 97 | ) 98 | 99 | # Add objects to the MovieDemo collection from the JSON file and directory of poster images 100 | json_file_path = os.path.join(os.getcwd(), "helpers/data/1950_2024_movies_info.json") 101 | movies_df = 
/recipes/weaviate/helpers/demo_movie_query.graphql: -------------------------------------------------------------------------------- 1 | { Get {MovieDemo (limit: 3 2 | where: { path: ["release_year"], 3 | operator: Equal, 4 | valueInt: 1985}) { 5 | budget 6 | movie_id 7 | overview 8 | release_year 9 | revenue 10 | tagline 11 | title 12 | vote_average 13 | vote_count 14 | }}} -------------------------------------------------------------------------------- /recipes/weaviate/helpers/verify_data.py: -------------------------------------------------------------------------------- 1 | import weaviate 2 | from weaviate.classes.query import Filter 3 | from weaviate.classes.init import Auth 4 | import toml 5 | import os 6 | 7 | 8 | # Construct the path to the toml file 9 | current_dir = os.path.dirname(__file__) 10 | parent_dir = os.path.dirname(current_dir) 11 | toml_file_path = os.path.join(parent_dir, ".streamlit/secrets.toml") 12 | 13 | config = toml.load(toml_file_path) 14 | 15 | # Access values from the TOML file 16 | weaviate_api_key = config["WEAVIATE_API_KEY"] 17 | weaviate_url = config["WEAVIATE_URL"] 18 | cohere_api_key = config["COHERE_API_KEY"] 19 | 20 | client = weaviate.connect_to_weaviate_cloud( 21 | cluster_url=weaviate_url, 22 | auth_credentials=Auth.api_key(weaviate_api_key), 23 | headers={ 24 | "X-Cohere-Api-Key": cohere_api_key 25 | } 26 | ) 27 | 28 | # If you are using a local instance of Weaviate, you can use the following code 29 | # client = weaviate.connect_to_local( 30 | # headers={ 31 | # "X-Cohere-Api-Key": cohere_api_key 32 | # } 33 | # ) 34 | 35 | movies = client.collections.get("MovieDemo") 36 | 37 | print(movies.aggregate.over_all(total_count=True)) 38 | 39 | r = movies.query.fetch_objects(limit=1, return_properties=["title"]) 40 | print(r.objects[0].properties["title"]) 41 | 42 | client.close() -------------------------------------------------------------------------------- /recipes/weaviate/requirements.txt:
-------------------------------------------------------------------------------- 1 | streamlit 2 | st-weaviate-connection 3 | tqdm --------------------------------------------------------------------------------