├── .gitignore
├── .gitmodules
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── recipe-template
│   └── README.md
└── recipes
    ├── confluent
    │   └── README.md
    ├── crewai
    │   ├── .streamlit
    │   │   └── secrets.toml
    │   ├── README.md
    │   ├── requirements.txt
    │   ├── streamlit_app.py
    │   ├── tools
    │   │   ├── __init__.py
    │   │   ├── browser_tools.py
    │   │   ├── calculator_tools.py
    │   │   └── search_tools.py
    │   ├── trip_agents.py
    │   └── trip_tasks.py
    ├── llamaindex
    │   ├── .gitignore
    │   ├── README.md
    │   ├── arc_finetuning_st
    │   │   ├── __init__.py
    │   │   ├── cli
    │   │   │   ├── __init__.py
    │   │   │   ├── command_line.py
    │   │   │   ├── evaluation.py
    │   │   │   └── finetune.py
    │   │   ├── finetuning
    │   │   │   ├── __init__.py
    │   │   │   ├── finetuning_example.py
    │   │   │   └── templates.py
    │   │   ├── streamlit
    │   │   │   ├── __init__.py
    │   │   │   ├── app.py
    │   │   │   └── controller.py
    │   │   └── workflows
    │   │       ├── __init__.py
    │   │       ├── arc_task_solver.py
    │   │       ├── events.py
    │   │       ├── models.py
    │   │       └── prompts.py
    │   ├── poetry.lock
    │   └── pyproject.toml
    ├── ollama
    │   ├── requirements.txt
    │   └── streamlit_app.py
    ├── openai
    │   └── README.md
    ├── replicate
    │   ├── .streamlit
    │   │   └── secrets_template.toml
    │   ├── README.md
    │   ├── requirements.txt
    │   └── streamlit_app.py
    ├── replit
    │   └── README.md
    ├── trulens
    │   ├── .env.template
    │   ├── README.md
    │   ├── app.py
    │   ├── base.py
    │   ├── feedback.py
    │   ├── requirements.txt
    │   └── vector_store.py
    └── weaviate
        ├── .streamlit
        │   └── secrets_template.toml
        ├── README.md
        ├── demo_app.py
        ├── helpers
        │   ├── add_data.py
        │   ├── data
        │   │   └── 1950_2024_movies_info.json
        │   ├── demo_movie_query.graphql
        │   └── verify_data.py
        └── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Editors
2 | .vscode/
3 |
4 | # Mac/OSX
5 | .DS_Store
6 |
7 | # Source for the following rules: https://raw.githubusercontent.com/github/gitignore/master/Python.gitignore
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # Environments
14 | replicatevenv/
15 | weaviatevenv/
16 |
17 | # Secrets
18 | *secrets.toml
19 |
20 | # Recipe specific
21 | recipes/weaviate/helpers/data/posters
22 |
23 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "recipes/replit/ab-testing"]
2 | path = recipes/replit/ab-testing
3 | url = https://github.com/mattppal/ab-testing/
4 |
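5 | # Usage note (editor's addition, not part of the original file): after cloning the
6 | # repo, fetch this submodule with `git submodule update --init --recursive`.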
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | - Using welcoming and inclusive language
18 | - Being respectful of differing viewpoints and experiences
19 | - Gracefully accepting constructive criticism
20 | - Focusing on what is best for the community
21 | - Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | - The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | - Trolling, insulting/derogatory comments, and personal or political attacks
28 | - Public or private harassment
29 | - Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | - Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at hello@streamlit.io. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 📖 Streamlit Cookbook
2 |
3 | Streamlit Cookbook is a compilation of Streamlit app templates that you can use as boilerplate code to jumpstart your own app creation endeavors.
4 |
5 | In a nutshell, each recipe allows you to quickly spin up a working Streamlit app that implements the library/tool of interest.
6 |
7 | ## 🧑‍🍳 How to use this cookbook?
8 | For your convenience, we've categorized tool integrations by library name. For instance, the Streamlit app built with Replicate lives in the `replicate` sub-folder, thus `recipes/replicate`.
9 |
10 | ## 🍪 List of Recipes
11 |
12 | ### AI Recipes
13 | | Tool | Description | Resources |
14 | | -- | -- | -- |
15 | | [CrewAI](https://github.com/streamlit/cookbook/tree/main/recipes/crewai) | Build the VacAIgent app where agents collaboratively decide on cities and craft a complete itinerary for your trip based on specified preferences. | |
16 | | [Ollama](https://github.com/streamlit/cookbook/tree/main/recipes/ollama) | Build a Streamlit Ollama chatbot using a local LLM model. | |
17 | | [Replicate](https://github.com/streamlit/cookbook/tree/main/recipes/replicate) | Build a Streamlit Replicate chatbot that allows users to select an LLM model of their choice for response generation. | [Video](https://youtu.be/zsQ7EN10zj8), [Blog](https://blog.streamlit.io/how-to-recommendation-app-vector-database-weaviate/) |
18 | | [Trulens](https://github.com/streamlit/cookbook/tree/main/recipes/trulens) | Ask questions about the Pacific Northwest and get TruLens evaluations on the app response. | |
19 | | [Weaviate](https://github.com/streamlit/cookbook/tree/main/recipes/weaviate) | Build a movie explorer app that leverages the Weaviate vector database. | [Video](https://youtu.be/SQD-aWlhqvM), [Blog](https://blog.streamlit.io/how-to-recommendation-app-vector-database-weaviate/) |
20 |
21 | ### Data Recipes
22 | | Tool | Description | Resources |
23 | | -- | -- | -- |
24 | | [Replit](https://github.com/streamlit/cookbook/tree/main/recipes/replit) | Build a statistical app that allows users to select a suitable sample size for A/B testing | [Video](https://youtu.be/CJ9E0Sm_hy4) |
25 |
26 |
27 | ## 🏃 How to run the app?
28 | To run an app, start by installing Streamlit on the command line using the `pip install streamlit` command. Next, change into the recipe's directory and launch the app via `streamlit run streamlit_app.py`.
29 |
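30 | For example, to run the Replicate recipe (a minimal sketch; the same pattern applies to any recipe that ships a `streamlit_app.py`):
31 | 
32 | ```sh
33 | pip install streamlit
34 | cd recipes/replicate
35 | pip install -r requirements.txt
36 | streamlit run streamlit_app.py
37 | ```
38 | 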
39 | ## 📥 How to contribute to this repo?
40 | We welcome contributions to this repository! Contributions can come in many forms, whether it be suggesting a recipe idea (please suggest it on the [issues page](https://github.com/streamlit/streamlit-cookbook/issues)), writing a new recipe, improving a pre-existing recipe, or fixing a typo/error.
41 | 
42 | Simply submit a pull request [here](https://github.com/streamlit/streamlit-cookbook/pulls) to get started.
43 | 
--------------------------------------------------------------------------------
/recipe-template/README.md:
--------------------------------------------------------------------------------
1 | # How to run the demo app (Optional: Add what kind of app it is/the tech the app is built with)
2 | ## 🔍 Overview
3 | This is a recipe for a [TODO: Add the kind of app it is]. TODO: Add one sentence describing what the app does.
4 |
5 | Other ways to explore this recipe:
6 | * [Deployed app](TODO: URL of deployed app)
7 | * [Blog post](TODO: URL of blog post)
8 | * [Video](TODO: URL of video)
9 |
10 | ## 📝 Prerequisites
11 | * Python >=3.8, !=3.9.7
12 | * TODO: List additional prerequisites
13 |
14 | ## 🌎 Environment setup
15 | ### Local setup
16 | 1. Clone the Cookbook repo: `git clone https://github.com/streamlit/cookbook.git`
17 | 2. From the Cookbook root directory, change directory into the recipe: `cd recipes/TODO: Add recipe directory`
18 | 3. Add secrets to the `.streamlit/secrets_template.toml` file
19 | 4. Update the filename from `secrets_template.toml` to `secrets.toml`: `mv .streamlit/secrets_template.toml .streamlit/secrets.toml`
20 |
21 | (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).)
22 | 5. Create a virtual environment: `python -m venv TODO: Add name of virtual environment`
23 | 6. Activate the virtual environment: `source TODO: Add name of virtual environment/bin/activate`
24 | 7. Install the dependencies: `pip install -r requirements.txt`
25 |
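26 | Putting the local setup together, a minimal sketch (using `myrecipe` and `venv` as placeholder names):
27 | 
28 | ```sh
29 | git clone https://github.com/streamlit/cookbook.git
30 | cd cookbook/recipes/myrecipe
31 | # fill in .streamlit/secrets_template.toml, then rename it:
32 | mv .streamlit/secrets_template.toml .streamlit/secrets.toml
33 | python -m venv venv
34 | source venv/bin/activate
35 | pip install -r requirements.txt
36 | ```
37 | 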
38 | ### GitHub Codespaces setup
39 | 1. Create a new codespace by selecting the `Codespaces` option from the `Code` button
40 | 2. Once the codespace has been generated, add your secrets to the `recipes/TODO: Add recipe directory/.streamlit/secrets_template.toml` file
41 | 3. Update the filename from `secrets_template.toml` to `secrets.toml`
42 | 
43 | (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).)
44 | 4. From the Cookbook root directory, change directory into the recipe: `cd recipes/TODO: Add recipe directory`
45 | 5. Install the dependencies: `pip install -r requirements.txt`
46 | 
--------------------------------------------------------------------------------
/recipes/confluent/README.md:
--------------------------------------------------------------------------------
1 | # Confluent recipe
2 |
3 | - 🕹️ Demo app https://flink-st-kafka.streamlit.app/
4 | - 🐙 GitHub repo https://github.com/confluentinc/demo-scene/tree/master/flink-streamlit
5 |
--------------------------------------------------------------------------------
/recipes/crewai/.streamlit/secrets.toml:
--------------------------------------------------------------------------------
1 | SERPER_API_KEY="API_KEY_HERE" # https://serper.dev/ (free tier)
2 | BROWSERLESS_API_KEY="API_KEY_HERE" # https://www.browserless.io/ (free tier)
3 | OPENAI_API_KEY="API_KEY_HERE"
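4 | # Note (editor's addition): replace the placeholders with your real keys, and take care not to commit them.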
--------------------------------------------------------------------------------
/recipes/crewai/README.md:
--------------------------------------------------------------------------------
1 | ## CrewAI Framework
2 |
3 | CrewAI simplifies the orchestration of role-playing AI agents. In VacAIgent, these agents collaboratively decide on cities and craft a complete itinerary for your trip based on specified preferences, all accessible via a streamlined Streamlit user interface.
4 |
5 | ## Running the Application
6 |
7 | - **Configure Environment**: Set up the environment variables for [Browserless](https://www.browserless.io/), [Serper](https://serper.dev/), and [OpenAI](https://openai.com/). Add your keys to the `.streamlit/secrets.toml` file.
8 |
9 | - **Install Dependencies**: Execute `pip install -r requirements.txt` in your terminal.
10 | - **Launch the App**: Run `streamlit run streamlit_app.py` to start the Streamlit interface.
11 |
12 | ★ **Disclaimer**: The application uses GPT-4 by default. Ensure you have access to OpenAI's API and be aware of the associated costs.
13 |
14 | ## Details & Explanation
15 |
16 | - **Components**:
17 | - `./trip_tasks.py`: Contains task prompts for the agents.
18 | - `./trip_agents.py`: Manages the creation of agents.
19 |   - `./tools` directory: Houses the tool classes used by the agents.
20 | - `./streamlit_app.py`: The heart of the Streamlit app.
21 |
22 | ## Using GPT-4o mini
23 |
24 | To switch from GPT-4 to GPT-4o-mini, pass the `llm` argument to the agent constructor:
25 |
26 | ```python
27 | from langchain.chat_models import ChatOpenAI
28 |
29 | llm = ChatOpenAI(model='gpt-4o-mini') # See more OpenAI models at https://platform.openai.com/docs/models/
30 |
31 | class TripAgents:
32 | # ... existing methods
33 |
34 | def local_expert(self):
35 | return Agent(
36 | role='Local Expert',
37 | goal='Provide insights about the selected city',
38 | tools=[SearchTools.search_internet, BrowserTools.scrape_and_summarize_website],
39 | llm=llm,
40 | verbose=True
41 | )
42 |
43 | ```
44 |
45 | ## Using Local Models with Ollama
46 |
47 | For enhanced privacy and customization, you can integrate locally served models via Ollama:
48 |
49 | ### Setting Up Ollama
50 |
51 | - **Installation**: Follow Ollama's guide for installation.
52 | - **Configuration**: Customize the model as per your requirements.
53 |
54 | ### Integrating Ollama with CrewAI
55 |
56 | Pass the Ollama model to agents in the CrewAI framework:
57 |
58 | ```python
59 | from langchain.llms import Ollama
60 |
61 | ollama_model = Ollama(model="agent")
62 |
63 | class TripAgents:
64 | # ... existing methods
65 |
66 | def local_expert(self):
67 | return Agent(
68 | role='Local Expert',
69 | tools=[SearchTools.search_internet, BrowserTools.scrape_and_summarize_website],
70 | llm=ollama_model,
71 | verbose=True
72 | )
73 |
74 | ```
75 |
76 |
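77 | Note: the name passed to `Ollama(model=...)` must match a model available in your local Ollama installation; `"agent"` above is assumed to be a custom model created or pulled via the `ollama` CLI beforehand.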
--------------------------------------------------------------------------------
/recipes/crewai/requirements.txt:
--------------------------------------------------------------------------------
1 | crewai
2 | streamlit
3 | openai
4 | unstructured
5 | langchain
6 | pyowm
7 | tools
--------------------------------------------------------------------------------
/recipes/crewai/streamlit_app.py:
--------------------------------------------------------------------------------
1 | from crewai import Crew
2 | from trip_agents import TripAgents, StreamToExpander
3 | from trip_tasks import TripTasks
4 | import streamlit as st
5 | import datetime
6 | import sys
7 |
8 | st.set_page_config(page_icon="✈️", layout="wide")
9 |
10 |
11 | class TripCrew:
12 | def __init__(self, origin, cities, date_range, interests):
13 | self.cities = cities
14 | self.origin = origin
15 | self.interests = interests
16 | self.date_range = date_range
17 | self.output_placeholder = st.empty()
18 |
19 | def run(self):
20 | agents = TripAgents()
21 | tasks = TripTasks()
22 |
23 | city_selector_agent = agents.city_selection_agent()
24 | local_expert_agent = agents.local_expert()
25 | travel_concierge_agent = agents.travel_concierge()
26 |
27 | identify_task = tasks.identify_task(
28 | city_selector_agent,
29 | self.origin,
30 | self.cities,
31 | self.interests,
32 | self.date_range
33 | )
34 |
35 | gather_task = tasks.gather_task(
36 | local_expert_agent,
37 | self.origin,
38 | self.interests,
39 | self.date_range
40 | )
41 |
42 | plan_task = tasks.plan_task(
43 | travel_concierge_agent,
44 | self.origin,
45 | self.interests,
46 | self.date_range
47 | )
48 |
49 | crew = Crew(
50 | agents=[
51 | city_selector_agent, local_expert_agent, travel_concierge_agent
52 | ],
53 | tasks=[identify_task, gather_task, plan_task],
54 | verbose=True
55 | )
56 |
57 | result = crew.kickoff()
58 | self.output_placeholder.markdown(result)
59 |
60 | return result
61 |
62 |
63 | if __name__ == "__main__":
64 | st.title("AI Travel Agent")
65 | st.subheader("Let AI agents plan your next vacation!",
66 | divider="rainbow", anchor=False)
67 |
68 |     # datetime is imported at the top of the file; compute the default picker dates
69 | 
70 |     today = datetime.datetime.now().date()
71 |     next_year = today.year + 1
72 |     jan_16_next_year = datetime.date(next_year, 1, 16)
73 |
74 | with st.sidebar:
75 | st.header("👇 Enter your trip details")
76 | with st.form("my_form"):
77 | location = st.text_input(
78 | "Current location", placeholder="San Mateo, CA")
79 | cities = st.text_input(
80 | "Destination", placeholder="Bali, Indonesia")
81 | date_range = st.date_input(
82 | "Travel date range",
83 | min_value=today,
84 |                 value=(today, jan_16_next_year),
85 | format="MM/DD/YYYY",
86 | )
87 | interests = st.text_area("High level interests and hobbies or extra details about your trip?",
88 | placeholder="2 adults who love swimming, dancing, hiking, and eating")
89 |
90 | submitted = st.form_submit_button("Submit")
91 |
92 | if submitted:
93 | with st.status("🤖 **Agents at work...**", state="running", expanded=True) as status:
94 | with st.container(height=500, border=False):
95 |                 # Redirect stdout so the agents' console logs stream into the Streamlit container
96 |                 sys.stdout = StreamToExpander(st)
97 |                 trip_crew = TripCrew(location, cities, date_range, interests)
98 |                 result = trip_crew.run()
99 |             status.update(label="✅ Trip Plan Ready!",
100 |                           state="complete", expanded=False)
101 | 
102 |         st.subheader("Here is your Trip Plan", anchor=False, divider="rainbow")
103 |         st.markdown(result)
--------------------------------------------------------------------------------
/recipes/crewai/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/crewai/tools/__init__.py
--------------------------------------------------------------------------------
/recipes/crewai/tools/browser_tools.py:
--------------------------------------------------------------------------------
1 | import json
2 | import requests
3 | import streamlit as st
4 | from crewai import Agent, Task
5 | from langchain.tools import tool
6 | from unstructured.partition.html import partition_html
7 |
8 |
9 | class BrowserTools():
10 |
11 | @tool("Scrape website content")
12 | def scrape_and_summarize_website(website):
13 |     """Useful to scrape and summarize a website's content"""
14 | url = f"https://chrome.browserless.io/content?token={st.secrets['BROWSERLESS_API_KEY']}"
15 | payload = json.dumps({"url": website})
16 | headers = {'cache-control': 'no-cache', 'content-type': 'application/json'}
17 | response = requests.request("POST", url, headers=headers, data=payload)
18 | elements = partition_html(text=response.text)
19 | content = "\n\n".join([str(el) for el in elements])
20 |     # split the scraped text into 8000-character chunks to fit the LLM context window
21 |     content = [content[i:i + 8000] for i in range(0, len(content), 8000)]
22 |     summaries = []
23 |     for chunk in content:
24 |       agent = Agent(
25 |           role='Principal Researcher',
26 |           goal='Do amazing research and summaries based on the content you are working with',
27 |           backstory="You're a Principal Researcher at a big company and you need to do research about a given topic.",
28 |           allow_delegation=False)
29 |       task = Task(
30 |           agent=agent,
31 |           description=f'Analyze and summarize the content below. Make sure to include the most relevant information in the summary. Return only the summary, nothing else.\n\nCONTENT\n----------\n{chunk}'
32 |       )
33 |       summary = task.execute()
34 |       summaries.append(summary)
35 |     return "\n\n".join(summaries)
--------------------------------------------------------------------------------
/recipes/crewai/tools/calculator_tools.py:
--------------------------------------------------------------------------------
1 | from langchain.tools import tool
2 |
3 |
4 | class CalculatorTools():
5 |
6 |   @tool("Make a calculation")
7 | def calculate(operation):
8 | """Useful to perform any mathematical calculations,
9 | like sum, minus, multiplication, division, etc.
10 | The input to this tool should be a mathematical
11 | expression, a couple examples are `200*7` or `5000/2*10`
12 | """
13 |     # Note: eval() executes arbitrary Python; acceptable in a demo, unsafe on untrusted input.
14 |     return eval(operation)
--------------------------------------------------------------------------------
/recipes/crewai/tools/search_tools.py:
--------------------------------------------------------------------------------
1 | import json
2 | import requests
3 | import streamlit as st
4 | from langchain.tools import tool
5 |
6 |
7 | class SearchTools():
8 |
9 | @tool("Search the internet")
10 | def search_internet(query):
11 |     """Useful to search the internet
12 |     about a given topic and return relevant results"""
13 | top_result_to_return = 4
14 | url = "https://google.serper.dev/search"
15 | payload = json.dumps({"q": query})
16 | headers = {
17 | 'X-API-KEY': st.secrets['SERPER_API_KEY'],
18 | 'content-type': 'application/json'
19 | }
20 | response = requests.request("POST", url, headers=headers, data=payload)
21 | # check if there is an organic key
22 | if 'organic' not in response.json():
23 |       return "Sorry, I couldn't find anything about that, there could be an error with your Serper API key."
24 | else:
25 | results = response.json()['organic']
26 | string = []
27 | for result in results[:top_result_to_return]:
28 | try:
29 | string.append('\n'.join([
30 | f"Title: {result['title']}", f"Link: {result['link']}",
31 | f"Snippet: {result['snippet']}", "\n-----------------"
32 | ]))
33 | except KeyError:
34 |           continue  # skip results missing expected fields
35 |
36 | return '\n'.join(string)
--------------------------------------------------------------------------------
/recipes/crewai/trip_agents.py:
--------------------------------------------------------------------------------
1 | from crewai import Agent
2 | import re
3 | import streamlit as st
4 | from langchain_community.llms import OpenAI
5 | from tools.browser_tools import BrowserTools
6 | from tools.calculator_tools import CalculatorTools
7 | from tools.search_tools import SearchTools
8 |
9 |
10 | class TripAgents():
11 | def city_selection_agent(self):
12 | return Agent(
13 | role='City Selection Expert',
14 | goal='Select the best city based on weather, season, and prices',
15 | backstory='An expert in analyzing travel data to pick ideal destinations',
16 | tools=[
17 | SearchTools.search_internet,
18 | BrowserTools.scrape_and_summarize_website,
19 | ],
20 | verbose=True,
21 | )
22 |
23 | def local_expert(self):
24 | return Agent(
25 | role='Local Expert at this city',
26 | goal='Provide the BEST insights about the selected city',
27 | backstory="""A knowledgeable local guide with extensive information
28 |         about the city, its attractions and customs""",
29 | tools=[
30 | SearchTools.search_internet,
31 | BrowserTools.scrape_and_summarize_website,
32 | ],
33 | verbose=True,
34 | )
35 |
36 | def travel_concierge(self):
37 | return Agent(
38 | role='Amazing Travel Concierge',
39 | goal="""Create the most amazing travel itineraries with budget and
40 | packing suggestions for the city""",
41 | backstory="""Specialist in travel planning and logistics with
42 | decades of experience""",
43 | tools=[
44 | SearchTools.search_internet,
45 | BrowserTools.scrape_and_summarize_website,
46 | CalculatorTools.calculate,
47 | ],
48 | verbose=True,
49 | )
50 |
51 | class StreamToExpander:
52 | def __init__(self, expander):
53 | self.expander = expander
54 | self.buffer = []
55 | self.colors = ['red', 'green', 'blue', 'orange'] # Define a list of colors
56 | self.color_index = 0 # Initialize color index
57 |
58 | def write(self, data):
59 | # Filter out ANSI escape codes using a regular expression
60 | cleaned_data = re.sub(r'\x1B\[[0-9;]*[mK]', '', data)
61 |
62 | # Check if the data contains 'task' information
63 | task_match_object = re.search(r'\"task\"\s*:\s*\"(.*?)\"', cleaned_data, re.IGNORECASE)
64 | task_match_input = re.search(r'task\s*:\s*([^\n]*)', cleaned_data, re.IGNORECASE)
65 | task_value = None
66 | if task_match_object:
67 | task_value = task_match_object.group(1)
68 | elif task_match_input:
69 | task_value = task_match_input.group(1).strip()
70 |
71 | if task_value:
72 | st.toast(":robot_face: " + task_value)
73 |
74 | # Check if the text contains the specified phrase and apply color
75 | if "Entering new CrewAgentExecutor chain" in cleaned_data:
76 | # Apply different color and switch color index
77 | self.color_index = (self.color_index + 1) % len(self.colors) # Increment color index and wrap around if necessary
78 |
79 | cleaned_data = cleaned_data.replace("Entering new CrewAgentExecutor chain", f":{self.colors[self.color_index]}[Entering new CrewAgentExecutor chain]")
80 |
81 | if "City Selection Expert" in cleaned_data:
82 | # Apply different color
83 | cleaned_data = cleaned_data.replace("City Selection Expert", f":{self.colors[self.color_index]}[City Selection Expert]")
84 | if "Local Expert at this city" in cleaned_data:
85 | cleaned_data = cleaned_data.replace("Local Expert at this city", f":{self.colors[self.color_index]}[Local Expert at this city]")
86 | if "Amazing Travel Concierge" in cleaned_data:
87 | cleaned_data = cleaned_data.replace("Amazing Travel Concierge", f":{self.colors[self.color_index]}[Amazing Travel Concierge]")
88 | if "Finished chain." in cleaned_data:
89 | cleaned_data = cleaned_data.replace("Finished chain.", f":{self.colors[self.color_index]}[Finished chain.]")
90 |
91 | self.buffer.append(cleaned_data)
92 | if "\n" in data:
93 | self.expander.markdown(''.join(self.buffer), unsafe_allow_html=True)
94 | self.buffer = []
--------------------------------------------------------------------------------
/recipes/crewai/trip_tasks.py:
--------------------------------------------------------------------------------
1 | from crewai import Task
2 | from textwrap import dedent
3 |
4 |
5 | class TripTasks():
6 |
7 | def identify_task(self, agent, origin, cities, interests, range):
8 | return Task(description=dedent(f"""
9 | Analyze and select the best city for the trip based
10 | on specific criteria such as weather patterns, seasonal
11 | events, and travel costs. This task involves comparing
12 | multiple cities, considering factors like current weather
13 | conditions, upcoming cultural or seasonal events, and
14 | overall travel expenses.
15 |
16 | Your final answer must be a detailed
17 | report on the chosen city, and everything you found out
18 | about it, including the actual flight costs, weather
19 | forecast and attractions.
20 | {self.__tip_section()}
21 |
22 | Traveling from: {origin}
23 | City Options: {cities}
24 | Trip Date: {range}
25 | Traveler Interests: {interests}
26 | """),
27 | expected_output="A detailed report on the chosen city with flight costs, weather forecast, and attractions.",
28 | agent=agent)
29 |
30 | def gather_task(self, agent, origin, interests, range):
31 | return Task(description=dedent(f"""
32 | As a local expert on this city you must compile an
33 | in-depth guide for someone traveling there and wanting
34 | to have THE BEST trip ever!
35 | Gather information about key attractions, local customs,
36 | special events, and daily activity recommendations.
37 | Find the best spots to go to, the kind of place only a
38 | local would know.
39 | This guide should provide a thorough overview of what
40 | the city has to offer, including hidden gems, cultural
41 | hotspots, must-visit landmarks, weather forecasts, and
42 | high level costs.
43 |
44 | The final answer must be a comprehensive city guide,
45 | rich in cultural insights and practical tips,
46 | tailored to enhance the travel experience.
47 | {self.__tip_section()}
48 |
49 | Trip Date: {range}
50 | Traveling from: {origin}
51 | Traveler Interests: {interests}
52 | """),
53 | expected_output="A comprehensive city guide with cultural insights and practical tips.",
54 | agent=agent)
55 |
56 | def plan_task(self, agent, origin, interests, range):
57 | return Task(description=dedent(f"""
58 | Expand this guide into a full travel
59 | itinerary for this time {range} with detailed per-day plans, including
60 | weather forecasts, places to eat, packing suggestions,
61 | and a budget breakdown.
62 |
63 | You MUST suggest actual places to visit, actual hotels
64 | to stay and actual restaurants to go to.
65 |
66 | This itinerary should cover all aspects of the trip,
67 | from arrival to departure, integrating the city guide
68 | information with practical travel logistics.
69 |
70 | Your final answer MUST be a complete expanded travel plan,
71 | formatted as markdown, encompassing a daily schedule,
72 | anticipated weather conditions, recommended clothing and
73 | items to pack, and a detailed budget, ensuring THE BEST
74 |             TRIP EVER. Be specific and give a reason why you picked
75 |             each place and what makes it special! {self.__tip_section()}
76 |
77 | Trip Date: {range}
78 | Traveling from: {origin}
79 | Traveler Interests: {interests}
80 | """),
81 | expected_output="A complete 7-day travel plan, formatted as markdown, with a daily schedule and budget.",
82 | agent=agent)
83 |
84 | def __tip_section(self):
85 | return "If you do your BEST WORK, I'll tip you $100 and grant you any wish you want!"
--------------------------------------------------------------------------------
/recipes/llamaindex/.gitignore:
--------------------------------------------------------------------------------
1 | .env.*
2 | .env
3 | index.html.*
4 | index.html
5 | task_results
6 | .ipynb_checkpoints/
7 | secrets.yaml
8 | Dockerfile.local
9 | docker-compose.local.yml
10 | pyproject.local.toml
11 | __pycache__
12 | data
13 | notebooks
14 | .DS_Store
15 | finetuning_examples
16 | finetuning_assets
17 |
--------------------------------------------------------------------------------
/recipes/llamaindex/README.md:
--------------------------------------------------------------------------------
1 | # ARC Task (LLM) Solver With Human Input
2 |
3 | The Abstraction and Reasoning Corpus for Artificial General Intelligence
4 | ([ARC-AGI](https://github.com/fchollet/ARC-AGI)) benchmark aims to measure an AI system's ability to efficiently learn new skills.
5 | Each task within the ARC benchmark contains a unique puzzle that systems
6 | attempt to solve. Currently, the best AI systems achieve 34% solve rates, whereas
7 | humans are able to achieve 85% ([source](https://www.kaggle.com/competitions/arc-prize-2024/overview/prizes)).
8 |
9 |
10 |
11 |
12 |
13 | Motivated by this large disparity, we built this app with the goal of injecting
14 | human-level reasoning on this benchmark to LLMs. Specifically, this app enables
15 | the collaboration of LLMs and humans to solve an ARC task; and these collaborations
16 | can then be used for fine-tuning the LLM.
17 |
18 | The Solver itself is a LlamaIndex `Workflow` that relies on successive runs, with
19 | `Context` maintained across runs. Doing so allows for an
20 | effective implementation of the human-in-the-loop pattern.
21 |
22 |
23 |
24 |
25 |
26 | ## Running The App
27 |
28 | Before running the streamlit app, we must first download the ARC dataset. The
29 | below command will download the dataset and store it in a directory named `data/`:
30 |
31 | ```sh
32 | wget https://github.com/fchollet/ARC-AGI/archive/refs/heads/master.zip -O ./master.zip
33 | unzip ./master.zip -d ./
34 | mv ARC-AGI-master/data ./
35 | rm -rf ARC-AGI-master
36 | rm master.zip
37 | ```
38 |
39 | Next, we must install the app's dependencies. To do so, we can use `poetry`:
40 |
41 | ```sh
42 | poetry shell
43 | poetry install
44 | ```
45 |
46 | Finally, to run the streamlit app:
47 |
48 | ```sh
49 | export OPENAI_API_KEY=<your-openai-api-key> && streamlit run arc_finetuning_st/streamlit/app.py
50 | ```
51 |
52 | ## How To Use The App
53 |
54 | In the next two sections, we discuss how to use the app in order to solve a given
55 | ARC task.
56 |
57 |
58 |
59 |
60 |
61 | ### Solving an ARC Task
62 |
63 | Each ARC task consists of training examples, each of which consist of input and
64 | output pairs. There exists a common pattern between these input and output pairs,
65 | and the problem is solved by uncovering this pattern, which can be verified by
66 | the included test examples.
67 |
68 | To solve the task, we cycle through the following three steps:
69 |
70 | 1. Prediction (of test output grid)
71 | 2. Evaluation
72 | 3. Critique (human in the loop)
73 |
74 | (Under the hood a LlamaIndex `Workflow` implements these three `steps`.)
75 |
76 | Step 1 makes use of an LLM to produce the Prediction, whereas Step 2 is
77 | deterministic and is a mere comparison between the ground truth test output and
78 | the Prediction. If the Prediction doesn't match the ground truth grid, then Step 3
79 | is performed. Similar to Step 1, an LLM is prompted to generate a Critique of the
80 | Prediction as to why it may not match the pattern underlying the train input and
81 | output pairs. However, we also allow for a human in the loop to override this
82 | LLM generated Critique.
83 |
84 | The Critique is carried over from one cycle to the next in order to
85 | generate an improved and hopefully correct next Prediction.
86 |
87 | To begin, click the `Start` button found in the top-right corner. If the
88 | prediction is incorrect, you can view the Critique produced by the LLM in the
89 | designated text area. You can choose to use this Critique or supply your own by
90 | overwriting the text and applying the change. Once ready to produce the next
91 | prediction, hit the `Continue` button.
92 |
93 | ### Saving solutions for fine-tuning
94 |
95 | Any collaboration session involving the LLM and human can be saved and used to
96 | finetune an LLM. In this app, we use OpenAI LLMs, and so the finetuning examples
97 | adhere to the [OpenAI fine-tuning API](https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset).
98 | Click the `fine-tuning example` button during a session to see the current
99 | example that can be used for fine-tuning.
100 |
101 |
102 |
103 |
104 |
105 | ## Fine-tuning (with `arc-finetuning-cli`)
106 |
107 | After you've created your finetuning examples (you'll need at least 10 of them),
108 | you can submit a job to OpenAI to finetune an LLM on them. To do so, we have a
109 | convenient command line tool that is powered by LlamaIndex plugins such as
110 | `llama-index-finetuning`.
111 |
112 | ```sh
113 | arc-finetuning cli tool.
114 |
115 | options:
116 | -h, --help show this help message and exit
117 |
118 | commands:
119 | {evaluate,finetune,job-status}
120 | evaluate Evaluation of ARC Task predictions with LLM and ARCTaskSolverWorkflow.
121 | finetune Finetune OpenAI LLM on ARC Task Solver examples.
122 | job-status Check the status of finetuning job.
123 | ```
124 |
125 | ### Submitting a fine-tuning job
126 |
127 | To submit a fine-tuning job, use any of the following three `finetune` commands:
128 |
129 | ```sh
130 | # submit a new finetune job using the specified llm
131 | arc-finetuning-cli finetune --llm gpt-4o-2024-08-06
132 |
133 | # submit a new finetune job that continues from previously finetuned model
134 | arc-finetuning-cli finetune --llm gpt-4o-2024-08-06 --start-job-id ftjob-TqJd5Nfe3GIiScyTTJH56l61
135 |
136 | # submit a new finetune job that continues from the most recent finetuned model
137 | arc-finetuning-cli finetune --continue-latest
138 | ```
139 |
140 | The commands above will take care of compiling all of the individual finetuning JSON
141 | examples (i.e., those stored in `finetuning_examples/`) into a single `jsonl` file that
142 | is then passed to the OpenAI finetuning API.
143 |
144 | ### Checking the status of a fine-tuning job
145 |
146 | After submitting a job, you can check its status using the below cli commands:
147 |
148 | ```sh
149 | arc-finetuning-cli job-status -j ftjob-WYySY3iGYpfiTbSDeKDZO0YL -m gpt-4o-2024-08-06
150 |
151 | # or check status of the latest job submission
152 | arc-finetuning-cli job-status --latest
153 | ```
154 |
155 | ## Evaluation
156 |
157 | You can evaluate the `ARCTaskSolverWorkflow` and a specified LLM on the ARC test
158 | dataset. You can even supply a fine-tuned LLM here.
159 |
160 | ```sh
161 | # evaluate ARCTaskSolverWorkflow single attempt with gpt-4o
162 | arc-finetuning-cli evaluate --llm gpt-4o-2024-08-06
163 |
164 | # evaluate ARCTaskSolverWorkflow single attempt with a previously fine-tuned gpt-4o
165 | arc-finetuning-cli evaluate --llm gpt-4o-2024-08-06 --start-job-id ftjob-TqJd5Nfe3GIiScyTTJH56l61
166 | ```
167 |
168 | You can also specify certain parameters to control the speed of the execution so
169 | as to not run into `RateLimitError`'s from OpenAI.
170 |
171 | ```sh
172 | arc-finetuning-cli evaluate --llm gpt-4o-2024-08-06 --batch-size 5 --num-workers 3 --sleep 10
173 | ```
174 |
175 | In the above command, `batch-size` refers to the number of test cases handled in a
176 | single batch. In total, there are 400 test cases. Moreover, `num-workers` is the
177 | maximum number of async calls allowed to be made to the OpenAI API at any given moment.
178 | Finally, `sleep` is the amount of time in seconds the execution pauses before moving
179 | on to the next batch of test cases.
180 |
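181 | As a quick sanity check on these settings: with all 400 test cases and `--batch-size 5`, the run makes 80 batches, so `--sleep 10` adds roughly 800 seconds of idle time in total.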
--------------------------------------------------------------------------------
/recipes/llamaindex/arc_finetuning_st/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/__init__.py
--------------------------------------------------------------------------------
/recipes/llamaindex/arc_finetuning_st/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/cli/__init__.py
--------------------------------------------------------------------------------
/recipes/llamaindex/arc_finetuning_st/cli/command_line.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import asyncio
3 | import json
4 | from os import listdir
5 | from pathlib import Path
6 | from typing import Any, List, Optional, cast
7 |
8 | from llama_index.llms.openai import OpenAI
9 |
10 | from arc_finetuning_st.cli.evaluation import batch_runner
11 | from arc_finetuning_st.cli.finetune import (
12 | FINETUNE_JOBS_FILENAME,
13 | check_job_status,
14 | prepare_finetuning_jsonl_file,
15 | submit_finetune_job,
16 | )
17 | from arc_finetuning_st.workflows.arc_task_solver import (
18 | ARCTaskSolverWorkflow,
19 | WorkflowOutput,
20 | )
21 |
22 | SINGLE_EXAMPLE_JSON_PATH = Path(
23 | Path(__file__).parents[2].absolute(), "finetuning_examples"
24 | )
25 | FINETUNING_ASSETS_PATH = Path(
26 | Path(__file__).parents[2].absolute(), "finetuning_assets"
27 | )
28 |
29 |
30 | def handle_evaluate(
31 | llm: str,
32 | batch_size: int,
33 | num_workers: int,
34 | verbose: bool,
35 | sleep: int,
36 | **kwargs: Any,
37 | ) -> None:
38 | data_path = Path(
39 | Path(__file__).parents[2].absolute(), "data", "evaluation"
40 | )
41 | task_paths = [data_path / t for t in listdir(data_path)]
42 | llm = OpenAI(llm)
43 | w = ARCTaskSolverWorkflow(llm=llm, timeout=None)
44 | results = asyncio.run(
45 | batch_runner(
46 | w,
47 | task_paths[:10],
48 | verbose=verbose,
49 | batch_size=batch_size,
50 | num_workers=num_workers,
51 | sleep=sleep,
52 | )
53 | )
54 | results = cast(List[WorkflowOutput], results)
55 | num_solved = sum(el.passing for el in results)
56 | print(
57 | f"Solved: {num_solved}\nTotal Tasks:{len(results)}\nAverage Solve Rate: {float(num_solved) / len(results)}"
58 | )
59 |
60 |
61 | def handle_finetune_job_submit(
62 | llm: str,
63 | start_job_id: Optional[str],
64 | continue_latest: bool = False,
65 | **kwargs: Any,
66 | ) -> None:
67 | prepare_finetuning_jsonl_file(
68 | json_path=SINGLE_EXAMPLE_JSON_PATH, assets_path=FINETUNING_ASSETS_PATH
69 | )
70 | if continue_latest:
71 | try:
72 | with open(FINETUNING_ASSETS_PATH / FINETUNE_JOBS_FILENAME) as f:
73 | lines = f.read().splitlines()
74 | metadata_str = lines[-1]
75 | metadata = json.loads(metadata_str)
76 | start_job_id = metadata["start_job_id"]
77 | llm = metadata["model"]
78 | except FileNotFoundError:
79 | # no previous finetune model
80 | raise ValueError(
81 | "Missing `finetuning_jobs.jsonl` file. Have you submitted a prior job?"
82 | )
83 |
84 | submit_finetune_job(
85 | llm=llm,
86 | start_job_id=start_job_id,
87 | assets_path=FINETUNING_ASSETS_PATH,
88 | )
89 |
90 |
91 | def handle_check_finetune_job(
92 | start_job_id: Optional[str],
93 | llm: Optional[str],
94 | latest: bool,
95 | **kwargs: Any,
96 | ) -> None:
97 | if latest:
98 | try:
99 | with open(FINETUNING_ASSETS_PATH / FINETUNE_JOBS_FILENAME) as f:
100 | lines = f.read().splitlines()
101 | metadata_str = lines[-1]
102 | metadata = json.loads(metadata_str)
103 | start_job_id = metadata["start_job_id"]
104 | llm = metadata["model"]
105 | except FileNotFoundError:
106 | raise ValueError(
107 |                 "No finetuning_jobs.jsonl file exists. You likely haven't submitted a job yet."
108 | )
109 | if not latest and (start_job_id is None or llm is None):
110 | raise ValueError(
111 |             "If not using `--latest`, then must provide `start_job_id` and `llm`."
112 | )
113 |
114 | # make type checking happy
115 | if start_job_id and llm:
116 | check_job_status(
117 | start_job_id=start_job_id,
118 | llm=llm,
119 | assets_path=FINETUNING_ASSETS_PATH,
120 | )
121 |
122 |
123 | def main() -> None:
124 | parser = argparse.ArgumentParser(description="arc-finetuning cli tool.")
125 |
126 | # Subparsers
127 | subparsers = parser.add_subparsers(
128 | title="commands", dest="command", required=True
129 | )
130 |
131 | # evaluate command
132 | evaluate_parser = subparsers.add_parser(
133 | "evaluate",
134 | help="Evaluation of ARC Task predictions with LLM and ARCTaskSolverWorkflow.",
135 | )
136 | evaluate_parser.add_argument(
137 | "-m",
138 | "--llm",
139 | type=str,
140 | default="gpt-4o",
141 | help="The OpenAI LLM model to use with the Workflow.",
142 | )
143 | evaluate_parser.add_argument("-b", "--batch-size", type=int, default=5)
144 | evaluate_parser.add_argument("-w", "--num-workers", type=int, default=3)
145 | evaluate_parser.add_argument(
146 | "-v", "--verbose", action=argparse.BooleanOptionalAction
147 | )
148 | evaluate_parser.add_argument("-s", "--sleep", type=int, default=10)
149 | evaluate_parser.set_defaults(
150 | func=lambda args: handle_evaluate(**vars(args))
151 | )
152 |
153 | # finetune command
154 | finetune_parser = subparsers.add_parser(
155 | "finetune", help="Finetune OpenAI LLM on ARC Task Solver examples."
156 | )
157 | finetune_parser.add_argument(
158 | "-m",
159 | "--llm",
160 | type=str,
161 | default="gpt-4o-2024-08-06",
162 | help="The OpenAI LLM model to finetune.",
163 | )
164 | finetune_parser.add_argument(
165 | "-j",
166 | "--start-job-id",
167 | type=str,
168 | default=None,
169 | help="Previously started job id, to continue finetuning.",
170 | )
171 | finetune_parser.add_argument(
172 | "--continue-latest", action=argparse.BooleanOptionalAction
173 | )
174 | finetune_parser.set_defaults(
175 | func=lambda args: handle_finetune_job_submit(**vars(args))
176 | )
177 |
178 | # job status command
179 | job_status_parser = subparsers.add_parser(
180 | "job-status", help="Check the status of finetuning job."
181 | )
182 | job_status_parser.add_argument(
183 | "-j",
184 | "--start-job-id",
185 | type=str,
186 | default=None,
187 | help="Previously started job id, to continue finetuning.",
188 | )
189 | job_status_parser.add_argument(
190 | "-m",
191 | "--llm",
192 | type=str,
193 | default="gpt-4o-2024-08-06",
194 | help="The OpenAI LLM model to finetune.",
195 | )
196 | job_status_parser.add_argument(
197 | "--latest",
198 | action=argparse.BooleanOptionalAction,
199 | help="If set, checks the status of the last submitted job.",
200 | )
201 | job_status_parser.set_defaults(
202 | func=lambda args: handle_check_finetune_job(**vars(args))
203 | )
204 |
205 | # Parse the command-line arguments
206 | args = parser.parse_args()
207 |
208 | # Call the appropriate function based on the command
209 | args.func(args)
210 |
211 |
212 | if __name__ == "__main__":
213 | main()
214 |
--------------------------------------------------------------------------------
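A note on the dispatch pattern above: `set_defaults(func=lambda args: handler(**vars(args)))` forwards *every* parsed attribute, including argparse's own `command` and `func` entries, so each handler has to accept them. A minimal, self-contained sketch of the pattern (the `handle_greet` handler is hypothetical):

```python
# Minimal sketch of the subparser dispatch pattern used in command_line.py.
# Note the handler must accept the bookkeeping attributes `command` and
# `func` that vars(args) also carries.
import argparse


def handle_greet(command: str, func: object, name: str) -> None:
    print(f"hello, {name}")


parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(title="commands", dest="command", required=True)
greet_parser = subparsers.add_parser("greet")
greet_parser.add_argument("-n", "--name", default="world")
greet_parser.set_defaults(func=lambda args: handle_greet(**vars(args)))

args = parser.parse_args(["greet", "-n", "ARC"])
args.func(args)  # prints "hello, ARC"
```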
/recipes/llamaindex/arc_finetuning_st/cli/evaluation.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from os import listdir
3 | from pathlib import Path
4 | from typing import Any, List, cast
5 |
6 | from llama_index.core.async_utils import chunks
7 | from llama_index.llms.openai import OpenAI
8 |
9 | from arc_finetuning_st.workflows.arc_task_solver import (
10 | ARCTaskSolverWorkflow,
11 | WorkflowOutput,
12 | )
13 |
14 | DATA_PATH = Path(Path(__file__).parents[1].absolute(), "data", "evaluation")
15 |
16 |
17 | async def batch_runner(
18 | workflow: ARCTaskSolverWorkflow,
19 | task_paths: List[Path],
20 | batch_size: int = 5,
21 | verbose: bool = False,
22 | sleep: int = 10,
23 | num_workers: int = 3,
24 | ) -> List[Any]:
25 | output: List[Any] = []
26 | sem = asyncio.Semaphore(num_workers)
27 | for task_chunk in chunks(task_paths, batch_size):
28 |         task_runs = (
29 |             workflow.load_and_run_task(task_path=task_path, sem=sem)
30 |             for task_path in task_chunk
31 |             if task_path is not None
32 |         )
33 |         output_chunk = await asyncio.gather(*task_runs)
34 | output.extend(output_chunk)
35 | if verbose:
36 | print(
37 | f"Completed {len(output)} out of {len(task_paths)} tasks",
38 | flush=True,
39 | )
40 | await asyncio.sleep(sleep)
41 | return output
42 |
43 |
44 | async def main() -> None:
45 | task_paths = [DATA_PATH / t for t in listdir(DATA_PATH)]
46 | w = ARCTaskSolverWorkflow(
47 | timeout=None, verbose=False, llm=OpenAI("gpt-4o")
48 | )
49 | results = await batch_runner(w, task_paths[:10], verbose=True)
50 | results = cast(List[WorkflowOutput], results)
51 | num_solved = sum(el.passing for el in results)
52 | print(
53 |         f"Solved: {num_solved}\nTotal Tasks: {len(results)}\nAverage Solve Rate: {float(num_solved) / len(results)}"
54 | )
55 |
56 |
57 | if __name__ == "__main__":
58 | asyncio.run(main())
59 |
--------------------------------------------------------------------------------
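The `batch_runner` above combines two throttles: chunked batching with a sleep between batches, and a semaphore capping in-flight tasks within each batch. A toy, self-contained illustration of that pattern (no LLM calls; `fake_task` is a stand-in):

```python
import asyncio


async def fake_task(i: int, sem: asyncio.Semaphore) -> int:
    async with sem:
        await asyncio.sleep(0.01)  # stand-in for a slow LLM call
        return i * i


async def run_all() -> None:
    sem = asyncio.Semaphore(3)  # at most 3 tasks in flight at once
    items = list(range(10))
    batch_size = 5
    results = []
    for start in range(0, len(items), batch_size):
        chunk = items[start : start + batch_size]
        results.extend(await asyncio.gather(*(fake_task(i, sem) for i in chunk)))
    print(results)


asyncio.run(run_all())
```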
/recipes/llamaindex/arc_finetuning_st/cli/finetune.py:
--------------------------------------------------------------------------------
1 | import json
2 | from os import listdir
3 | from pathlib import Path
4 | from typing import Optional
5 |
6 | from llama_index.finetuning import OpenAIFinetuneEngine
7 |
8 | SINGLE_EXAMPLE_JSON_PATH = Path(
9 | Path(__file__).parents[1].absolute(), "finetuning_examples"
10 | )
11 |
12 | FINETUNING_ASSETS_PATH = Path(
13 | Path(__file__).parents[1].absolute(), "finetuning_assets"
14 | )
15 |
16 | FINETUNE_JSONL_FILENAME = "finetuning.jsonl"
17 | FINETUNE_JOBS_FILENAME = "finetuning_jobs.jsonl"
18 |
19 |
20 | def prepare_finetuning_jsonl_file(
21 | json_path: Path = SINGLE_EXAMPLE_JSON_PATH,
22 | assets_path: Path = FINETUNING_ASSETS_PATH,
23 | ) -> None:
24 | """Read all json files from data path and write a jsonl file."""
25 | with open(assets_path / FINETUNE_JSONL_FILENAME, "w") as jsonl_out:
26 | for json_name in listdir(json_path):
27 | with open(json_path / json_name) as f:
28 | for line in f:
29 | jsonl_out.write(line)
30 | jsonl_out.write("\n")
31 |
32 |
33 | def submit_finetune_job(
34 | llm: str = "gpt-4o-2024-08-06",
35 | start_job_id: Optional[str] = None,
36 | assets_path: Path = FINETUNING_ASSETS_PATH,
37 | ) -> None:
38 | """Submit finetuning job."""
39 | finetune_engine = OpenAIFinetuneEngine(
40 | llm,
41 | (assets_path / FINETUNE_JSONL_FILENAME).as_posix(),
42 | start_job_id=start_job_id,
43 | validate_json=False,
44 | )
45 | finetune_engine.finetune()
46 |
47 | with open(assets_path / FINETUNE_JOBS_FILENAME, "a+") as f:
48 | metadata = {
49 | "model": llm,
50 | "start_job_id": finetune_engine._start_job.id,
51 | }
52 | json.dump(metadata, f)
53 | f.write("\n")
54 |
55 | print(finetune_engine.get_current_job())
56 |
57 |
58 | def check_job_status(
59 | start_job_id: str,
60 | llm: str = "gpt-4o-2024-08-06",
61 | assets_path: Path = FINETUNING_ASSETS_PATH,
62 | ) -> None:
63 | """Check on status of most recent submitted finetuning job."""
64 |
65 | finetune_engine = OpenAIFinetuneEngine(
66 | llm,
67 | (assets_path / FINETUNE_JSONL_FILENAME).as_posix(),
68 | start_job_id=start_job_id,
69 | validate_json=False,
70 | )
71 |
72 | print(finetune_engine.get_current_job())
73 |
--------------------------------------------------------------------------------
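For reference, `submit_finetune_job` appends one JSON object per line to `finetuning_jobs.jsonl`, and the CLI's `--latest`/`--continue-latest` flags read the last line back. A sketch of that round trip (the job ids here are hypothetical):

```python
import json

# what the file accumulates, one submission per line
lines = [
    '{"model": "gpt-4o-2024-08-06", "start_job_id": "ftjob-abc123"}',
    '{"model": "gpt-4o-2024-08-06", "start_job_id": "ftjob-def456"}',
]

# how the CLI recovers the most recent job
latest = json.loads(lines[-1])
print(latest["model"], latest["start_job_id"])  # gpt-4o-2024-08-06 ftjob-def456
```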
/recipes/llamaindex/arc_finetuning_st/finetuning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/finetuning/__init__.py
--------------------------------------------------------------------------------
/recipes/llamaindex/arc_finetuning_st/finetuning/finetuning_example.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import Annotated, Any, Callable, List
4 |
5 | from llama_index.core.base.llms.types import ChatMessage, MessageRole
6 | from llama_index.core.bridge.pydantic import BaseModel, Field, WrapSerializer
7 |
8 | from arc_finetuning_st.finetuning.templates import (
9 | ASSISTANT_TEMPLATE,
10 | SYSTEM_MESSAGE,
11 | USER_CRITIQUE_TEMPLATE,
12 | USER_TASK_TEMPLATE,
13 | )
14 | from arc_finetuning_st.workflows.models import Attempt
15 |
16 |
17 | def remove_additional_kwargs(value: Any, handler: Callable, info: Any) -> Any:
18 | partial_result = handler(value, info)
19 | del partial_result["additional_kwargs"]
20 | return partial_result
21 |
22 |
23 | class FineTuningExample(BaseModel):
24 | messages: List[
25 | Annotated[ChatMessage, WrapSerializer(remove_additional_kwargs)]
26 | ]
27 | task_name: str = Field(exclude=True)
28 |
29 | @classmethod
30 | def from_attempts(
31 | cls,
32 | task_name: str,
33 | examples: str,
34 | test_input: str,
35 | attempts: List[Attempt],
36 | system_message: str = SYSTEM_MESSAGE,
37 | user_task_template: str = USER_TASK_TEMPLATE,
38 | user_critique_template: str = USER_CRITIQUE_TEMPLATE,
39 | assistant_template: str = ASSISTANT_TEMPLATE,
40 | ) -> "FineTuningExample":
41 | messages = [
42 | ChatMessage(role=MessageRole.SYSTEM, content=system_message),
43 | ChatMessage(
44 | role=MessageRole.USER,
45 | content=user_task_template.format(
46 | examples=examples, test_input=test_input
47 | ),
48 | ),
49 | ]
50 | for a in attempts:
51 | messages.extend(
52 | [
53 | ChatMessage(
54 | role=MessageRole.ASSISTANT,
55 | content=assistant_template.format(
56 | predicted_output=str(a.prediction),
57 | rationale=a.prediction.rationale,
58 | ),
59 | ),
60 | ChatMessage(
61 | role=MessageRole.USER,
62 | content=user_critique_template.format(
63 | critique=str(a.critique)
64 | ),
65 | ),
66 | ]
67 | )
68 |
69 |         # always end with an assistant message or else the openai finetuning job will fail
70 | if a.critique == "This predicted output is correct.":
71 | final_asst_message = ChatMessage(
72 | role=MessageRole.ASSISTANT,
73 |                 content="Glad we were able to solve the puzzle!",
74 | )
75 | else:
76 | final_asst_message = ChatMessage(
77 | role=MessageRole.ASSISTANT,
78 | content="Thanks for the feedback. I'll incorporate this into my next prediction.",
79 | )
80 |
81 | messages.append(final_asst_message)
82 | return cls(messages=messages, task_name=task_name)
83 |
84 | def to_json(self) -> str:
85 | data = self.model_dump()
86 | return json.dumps(data, indent=4)
87 |
88 | def write_json(self, dirpath: Path) -> None:
89 | data = self.model_dump()
90 | with open(Path(dirpath, self.task_name), "w") as f:
91 | json.dump(data, f)
92 |
--------------------------------------------------------------------------------
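Since `task_name` is excluded from the dump, each saved example is roughly a single `{"messages": [...]}` object, which matches the shape OpenAI's chat fine-tuning format expects (one such object per line of the training JSONL). An abridged illustration:

```python
# Abridged illustration of a dumped FineTuningExample; the content strings
# are placeholders for the formatted templates above.
example = {
    "messages": [
        {"role": "system", "content": "You are a bot that is very good at solving puzzles. ..."},
        {"role": "user", "content": "Here is a new task to solve: ..."},
        {"role": "assistant", "content": "PREDICTED OUTPUT: ... RATIONALE: ..."},
        {"role": "user", "content": "...critique of the attempt..."},
        {"role": "assistant", "content": "Thanks for the feedback. I'll incorporate this into my next prediction."},
    ]
}
```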
/recipes/llamaindex/arc_finetuning_st/finetuning/templates.py:
--------------------------------------------------------------------------------
1 | SYSTEM_MESSAGE = """You are a bot that is very good at solving puzzles. You will work with the user, who will present you with a new puzzle to solve.
2 | The puzzle consists of a list of EXAMPLES, each containing an INPUT/OUTPUT pair describing a pattern which is shared amongst all examples.
3 | The user will also provide a TEST INPUT for which you are to produce a predicted OUTPUT that follows the common pattern of all the examples.
4 |
5 | Your task is to collaborate with the user in order to solve the problem.
6 | """
7 |
8 | USER_TASK_TEMPLATE = """Here is a new task to solve:
9 | EXAMPLES:
10 | {examples}
11 |
12 | TEST INPUT:
13 | {test_input}
14 | """
15 |
16 | # attempt.prediction
17 | ASSISTANT_TEMPLATE = """
18 | PREDICTED OUTPUT:
19 | {predicted_output}
20 |
21 | RATIONALE:
22 | {rationale}
23 | """
24 |
25 | # attempt.critique
26 | USER_CRITIQUE_TEMPLATE = """
27 | {critique}
28 | """
29 |
--------------------------------------------------------------------------------
/recipes/llamaindex/arc_finetuning_st/streamlit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/streamlit/__init__.py
--------------------------------------------------------------------------------
/recipes/llamaindex/arc_finetuning_st/streamlit/app.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | import streamlit as st
4 | from llama_index.core.tools.function_tool import async_to_sync
5 |
6 | from arc_finetuning_st.streamlit.controller import Controller
7 |
8 | # startup
9 | st.set_page_config(layout="wide")
10 |
11 |
12 | @st.cache_resource
13 | def startup() -> Tuple[Controller,]:
14 | controller = Controller()
15 | return (controller,)
16 |
17 |
18 | (controller,) = startup()
19 |
20 | # states
21 | if "show_finetuning_preview_dialog" not in st.session_state:
22 | st.session_state["show_finetuning_preview_dialog"] = True
23 | if "disable_continue_button" not in st.session_state:
24 | st.session_state["disable_continue_button"] = True
25 | if "disable_start_button" not in st.session_state:
26 | st.session_state["disable_start_button"] = False
27 | if "disable_abort_button" not in st.session_state:
28 | st.session_state["disable_abort_button"] = True
29 | if "disable_preview_button" not in st.session_state:
30 | st.session_state["disable_preview_button"] = True
31 | if "metric_value" not in st.session_state:
32 | st.session_state["metric_value"] = "N/A"
33 |
34 | logo = '[llama-agents](https://github.com/run-llama/llama-agents "Check out the llama-agents Github repo!")'
35 | st.title("ARC Task Solver with Human Input")
36 | st.markdown(
37 |     f"{logo} _Powered by LlamaIndex Workflows_",
38 | unsafe_allow_html=True,
39 | )
40 |
41 | # sidebar
42 | with st.sidebar:
43 | task_selection = st.radio(
44 | label="Tasks",
45 | options=controller.task_file_names,
46 | index=0,
47 | on_change=controller.selectbox_selection_change_handler,
48 | key="selected_task",
49 | format_func=controller.radio_format_task_name,
50 | )
51 |
52 | train_col, test_col = st.columns(
53 | [1, 1], vertical_alignment="top", gap="medium"
54 | )
55 |
56 | with train_col:
57 | st.subheader("Train Examples")
58 | with st.container():
59 | selected_task = st.session_state.selected_task
60 | if selected_task:
61 | task = controller.load_task(selected_task)
62 | num_examples = len(task["train"])
63 | tabs = st.tabs(
64 | [f"Example {ix}" for ix in range(1, num_examples + 1)]
65 | )
66 | for ix, tab in enumerate(tabs):
67 | with tab:
68 | left, right = st.columns(
69 | [1, 1], vertical_alignment="top", gap="medium"
70 | )
71 | with left:
72 | ex = task["train"][ix]
73 | grid = ex["input"]
74 | fig = Controller.plot_grid(grid, kind="input")
75 | st.plotly_chart(fig, use_container_width=True)
76 |
77 | with right:
78 | ex = task["train"][ix]
79 | grid = ex["output"]
80 | fig = Controller.plot_grid(grid, kind="output")
81 | st.plotly_chart(fig, use_container_width=True)
82 |
83 |
84 | with test_col:
85 | header_col, start_col, preview_col = st.columns(
86 | [4, 1, 2], vertical_alignment="bottom", gap="small"
87 | )
88 | with header_col:
89 | st.subheader("Test")
90 | with start_col:
91 | st.button(
92 | "start",
93 | on_click=async_to_sync(controller.handle_prediction_click),
94 | use_container_width=True,
95 | type="primary",
96 | disabled=st.session_state.get("disable_start_button"),
97 | )
98 | with preview_col:
99 | st.button(
100 | "fine-tuning example",
101 | on_click=async_to_sync(controller.handle_finetuning_preview_click),
102 | use_container_width=True,
103 | disabled=st.session_state.get("disable_preview_button"),
104 | key="preview_button",
105 | )
106 |
107 | with st.container():
108 | selected_task = st.session_state.selected_task
109 | if selected_task:
110 | task = controller.load_task(selected_task)
111 | num_cases = len(task["test"])
112 | tabs = st.tabs(
113 | [f"Test Case {ix}" for ix in range(1, num_cases + 1)]
114 | )
115 | for ix, tab in enumerate(tabs):
116 | with tab:
117 | left, right = st.columns(
118 | [1, 1], vertical_alignment="top", gap="medium"
119 | )
120 | with left:
121 | ex = task["test"][ix]
122 | grid = ex["input"]
123 | fig = Controller.plot_grid(grid, kind="input")
124 | st.plotly_chart(fig, use_container_width=True)
125 |
126 | with right:
127 | prediction_fig = st.session_state.get(
128 | "prediction", None
129 | )
130 | if prediction_fig:
131 | st.plotly_chart(
132 | prediction_fig,
133 | use_container_width=True,
134 | key="prediction",
135 | )
136 |
137 | # metrics and past attempts
138 | with st.container():
139 | metric_col, critique_col = st.columns(
140 | [1, 7], vertical_alignment="top"
141 | )
142 | with metric_col:
143 | metric_value = st.session_state.get("metric_value")
144 | st.markdown(body="Passing")
145 | st.markdown(body=f"# {metric_value}")
146 |
147 | with critique_col:
148 | st.markdown(body="Critique of Attempt")
149 | st.text_area(
150 | label="This critique is passed to the LLM to generate a new prediction.",
151 | key="critique",
152 | help=(
153 | "An LLM was prompted to critique the prediction on why it might not fit the pattern. "
154 | "This critique is passed in the PROMPT in the next prediction attempt. "
155 | "Feel free to make edits to the critique or use your own."
156 | ),
157 | )
158 |
159 | with st.expander("Past Attempts"):
160 | st.dataframe(
161 | controller.attempts_history_df,
162 | hide_index=True,
163 | selection_mode="single-row",
164 | on_select=controller.handle_workflow_run_selection,
165 | column_order=(
166 | "attempt #",
167 | "passing",
168 | "critique",
169 | "rationale",
170 | ),
171 | key="attempts_history_df",
172 | use_container_width=True,
173 | )
174 |
175 | with st.container():
176 | continue_col, abort_col = st.columns([3, 1])
177 | with continue_col:
178 | st.button(
179 | "continue",
180 | on_click=async_to_sync(controller.handle_prediction_click),
181 | use_container_width=True,
182 | disabled=st.session_state.get("disable_continue_button"),
183 | key="continue_button",
184 | type="primary",
185 | )
186 |
187 | with abort_col:
188 |
189 | @st.dialog("Are you sure you want to abort the session?")
190 | def abort_solving() -> None:
191 | st.write(
192 | "Confirm that you want to abort the session by clicking 'confirm' button below."
193 | )
194 | if st.button("Confirm"):
195 | controller.reset()
196 | st.rerun()
197 |
198 | st.button(
199 | "abort",
200 | on_click=abort_solving,
201 | use_container_width=True,
202 | disabled=st.session_state.get("disable_abort_button"),
203 | )
204 |
--------------------------------------------------------------------------------
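One detail worth calling out: Streamlit `on_click` callbacks must be synchronous, which is why the async controller handlers are wrapped with llama_index's `async_to_sync`. Conceptually the wrapper does something like the sketch below (a simplification, not the library's exact implementation):

```python
import asyncio
from typing import Any, Callable, Coroutine


def async_to_sync_sketch(fn: Callable[..., Coroutine[Any, Any, Any]]) -> Callable[..., Any]:
    """Wrap an async function so it can be used as a sync callback."""

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # run the coroutine to completion before returning to the caller
        return asyncio.run(fn(*args, **kwargs))

    return wrapper
```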
/recipes/llamaindex/arc_finetuning_st/streamlit/controller.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import logging
4 | from os import listdir
5 | from pathlib import Path
6 | from typing import Any, List, Literal, Optional, cast
7 |
8 | import pandas as pd
9 | import plotly.express as px
10 | import streamlit as st
11 | from llama_index.core.workflow.handler import WorkflowHandler
12 | from llama_index.llms.openai import OpenAI
13 |
14 | from arc_finetuning_st.finetuning.finetuning_example import FineTuningExample
15 | from arc_finetuning_st.workflows.arc_task_solver import (
16 | ARCTaskSolverWorkflow,
17 | WorkflowOutput,
18 | )
19 | from arc_finetuning_st.workflows.models import Attempt, Prediction
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | class Controller:
25 | def __init__(self) -> None:
26 | self._handler: Optional[WorkflowHandler] = None
27 | self._attempts: List[Attempt] = []
28 | self._passing_results: List[bool] = []
29 | self._data_path = Path(
30 | Path(__file__).parents[2].absolute(), "data", "training"
31 | )
32 | self._finetuning_examples_path = Path(
33 | Path(__file__).parents[2].absolute(), "finetuning_examples"
34 | )
35 | self._finetuning_examples_path.mkdir(exist_ok=True, parents=True)
36 |
37 | def reset(self) -> None:
38 | # clear prediction
39 | st.session_state.prediction = None
40 | st.session_state.disable_continue_button = True
41 | st.session_state.disable_abort_button = True
42 | st.session_state.disable_preview_button = True
43 | st.session_state.disable_start_button = False
44 | st.session_state.critique = None
45 | st.session_state.metric_value = "N/A"
46 |
47 | self._handler = None
48 | self._attempts = []
49 | self._passing_results = []
50 |
51 | def selectbox_selection_change_handler(self) -> None:
52 | # only reset states
53 | # loading of task is delegated to relevant calls made with each
54 | # streamlit element
55 | self.reset()
56 |
57 | @staticmethod
58 | def plot_grid(
59 | grid: List[List[int]],
60 | kind: Literal["input", "output", "prediction", "latest prediction"],
61 | ) -> Any:
62 | m = len(grid)
63 | n = len(grid[0])
64 | fig = px.imshow(
65 | grid,
66 | text_auto=True,
67 |             labels={"x": f"{kind.title()}<br>{m}x{n}"},
68 | )
69 | fig.update_coloraxes(showscale=False)
70 | fig.update_layout(
71 | yaxis={"visible": False},
72 | xaxis={"visible": True, "showticklabels": False},
73 | margin=dict(
74 | l=20,
75 | r=20,
76 | b=20,
77 | t=20,
78 | ),
79 | )
80 | return fig
81 |
82 | async def show_progress_bar(self, handler: WorkflowHandler) -> None:
83 | progress_text_template = "{event} completed. Next step in progress."
84 | my_bar = st.progress(0, text="Workflow run in progress. Please wait.")
85 | num_steps = 5.0
86 | current_step = 1
87 | async for ev in handler.stream_events():
88 | my_bar.progress(
89 | current_step / num_steps,
90 | text=progress_text_template.format(event=type(ev).__name__),
91 | )
92 | current_step += 1
93 | my_bar.empty()
94 |
95 | def handle_abort_click(self) -> None:
96 | self.reset()
97 |
98 | async def handle_prediction_click(self) -> None:
99 | """Run workflow to generate prediction."""
100 | selected_task = st.session_state.selected_task
101 | if selected_task:
102 | task = self.load_task(selected_task)
103 | w = ARCTaskSolverWorkflow(
104 | timeout=None, verbose=False, llm=OpenAI("gpt-4o")
105 | )
106 |
107 | if not self._handler: # start a new solver
108 | handler = w.run(task=task)
109 |
110 | else: # continuing from past Workflow execution
111 | # need to reset this queue otherwise will use nested event loops
112 | self._handler.ctx._streaming_queue = asyncio.Queue()
113 |
114 | # use the critique and prediction str from streamlit
115 | critique = st.session_state.get("critique")
116 | self._attempts[-1].critique = critique
117 | await self._handler.ctx.set("attempts", self._attempts)
118 |
119 | # run Workflow
120 | handler = w.run(ctx=self._handler.ctx, task=task)
121 |
122 | # progress bar
123 | _ = asyncio.create_task(self.show_progress_bar(handler))
124 |
125 | res: WorkflowOutput = await handler
126 |
127 | handler = cast(WorkflowHandler, handler)
128 | self._handler = handler
129 | self._passing_results.append(res.passing)
130 | self._attempts = res.attempts
131 |
132 | # update streamlit states
133 | grid = Prediction.prediction_str_to_int_array(
134 | prediction=str(res.attempts[-1].prediction)
135 | )
136 | prediction_fig = Controller.plot_grid(
137 | grid, kind="latest prediction"
138 | )
139 | st.session_state.prediction = prediction_fig
140 | st.session_state.critique = str(res.attempts[-1].critique)
141 | st.session_state.disable_continue_button = False
142 | st.session_state.disable_abort_button = False
143 | st.session_state.disable_preview_button = False
144 | st.session_state.disable_start_button = True
145 | metric_value = "✅" if res.passing else "❌"
146 | st.session_state.metric_value = metric_value
147 |
148 | @property
149 | def saved_finetuning_examples(self) -> List[str]:
150 | return listdir(self._finetuning_examples_path)
151 |
152 | @property
153 | def task_file_names(self) -> List[str]:
154 | return listdir(self._data_path)
155 |
156 | def radio_format_task_name(self, selected_task: str) -> str:
157 | if selected_task in self.saved_finetuning_examples:
158 | return f"{selected_task} ✅"
159 | return selected_task
160 |
161 | def load_task(self, selected_task: str) -> Any:
162 | task_path = Path(self._data_path, selected_task)
163 |
164 | with open(task_path) as f:
165 | task = json.load(f)
166 | return task
167 |
168 | @property
169 | def passing(self) -> Optional[bool]:
170 | if self._passing_results:
171 | return self._passing_results[-1]
172 | return None
173 |
174 | @property
175 | def attempts_history_df(
176 | self,
177 | ) -> pd.DataFrame:
178 | if self._attempts:
179 | attempt_number_list: List[int] = []
180 | passings: List[str] = []
181 | rationales: List[str] = []
182 | critiques: List[str] = []
183 | predictions: List[str] = []
184 | for ix, a in enumerate(self._attempts):
185 | passings = ["✅" if a.passing else "❌"] + passings
186 | rationales = [a.prediction.rationale] + rationales
187 | predictions = [str(a.prediction)] + predictions
188 | critiques = [str(a.critique)] + critiques
189 | attempt_number_list = [ix + 1] + attempt_number_list
190 | return pd.DataFrame(
191 | {
192 | "attempt #": attempt_number_list,
193 | "passing": passings,
194 | "rationale": rationales,
195 | "critique": critiques,
196 | # hidden from UI
197 | "prediction": predictions,
198 | }
199 | )
200 | return pd.DataFrame(
201 | {
202 | "attempt #": [],
203 | "passing": [],
204 | "rationale": [],
205 | "critique": [],
206 | # hidden from UI
207 | "prediction": [],
208 | }
209 | )
210 |
211 | def handle_workflow_run_selection(self) -> None:
212 | @st.dialog("Past Attempt")
213 | def _display_attempt(
214 | fig: Any, rationale: str, critique: str, passing: bool
215 | ) -> None:
216 | st.plotly_chart(
217 | fig,
218 | use_container_width=True,
219 | key="prediction",
220 | )
221 | st.markdown(body=f"### Passing\n{passing}")
222 | st.markdown(body=f"### Rationale\n{rationale}")
223 | st.markdown(body=f"### Critique\n{critique}")
224 |
225 | selected_rows = (
226 | st.session_state.get("attempts_history_df")
227 | .get("selection")
228 | .get("rows")
229 | )
230 |
231 | if selected_rows:
232 | row_ix = selected_rows[0]
233 | df_row = self.attempts_history_df.iloc[row_ix]
234 |
235 | grid = Prediction.prediction_str_to_int_array(
236 | prediction=df_row["prediction"]
237 | )
238 | prediction_fig = Controller.plot_grid(grid, kind="prediction")
239 |
240 | _display_attempt(
241 | fig=prediction_fig,
242 | rationale=df_row["rationale"],
243 | critique=df_row["critique"],
244 | passing=df_row["passing"],
245 | )
246 |
247 | async def handle_finetuning_preview_click(self) -> None:
248 | if self._handler:
249 | st.session_state.show_finetuning_preview_dialog = True
250 | prompt_vars = await self._handler.ctx.get("prompt_vars")
251 |
252 | @st.dialog("Finetuning Example", width="large")
253 | def _display_finetuning_example() -> None:
254 | nonlocal prompt_vars
255 |
256 | finetuning_example = FineTuningExample.from_attempts(
257 | task_name=st.session_state.selected_task,
258 | attempts=self._attempts,
259 | examples=prompt_vars["examples"],
260 | test_input=prompt_vars["test_input"],
261 | )
262 |
263 | with st.container(height=500, border=False):
264 | save_col, close_col = st.columns([1, 1])
265 | with save_col:
266 | if st.button("Save", use_container_width=True):
267 | finetuning_example.write_json(
268 | self._finetuning_examples_path,
269 | )
270 | st.session_state.show_finetuning_preview_dialog = (
271 | False
272 | )
273 | st.rerun()
274 | with close_col:
275 | if st.button("Close", use_container_width=True):
276 | st.session_state.show_finetuning_preview_dialog = (
277 | False
278 | )
279 | st.rerun()
280 |
281 | st.code(
282 | finetuning_example.to_json(),
283 | language="json",
284 | wrap_lines=True,
285 | )
286 |
287 | if st.session_state.show_finetuning_preview_dialog:
288 | _display_finetuning_example()
289 |
--------------------------------------------------------------------------------
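`plot_grid` leans on `px.imshow`, which treats a list of lists as an image and can annotate each cell. A standalone snippet to try the idea on a small grid (requires `plotly`):

```python
import plotly.express as px

grid = [[0, 0, 1], [1, 1, 1], [0, 0, 0]]
fig = px.imshow(grid, text_auto=True, labels={"x": "Input<br>3x3"})
fig.update_coloraxes(showscale=False)  # hide the color bar, as plot_grid does
fig.show()
```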
/recipes/llamaindex/arc_finetuning_st/workflows/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/streamlit/cookbook/cae857225ae429b62351a281c49abfbd5346d08f/recipes/llamaindex/arc_finetuning_st/workflows/__init__.py
--------------------------------------------------------------------------------
/recipes/llamaindex/arc_finetuning_st/workflows/arc_task_solver.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | from pathlib import Path
4 | from typing import Any, Dict, List, Optional, cast
5 |
6 | from llama_index.core.bridge.pydantic import BaseModel
7 | from llama_index.core.llms import LLM
8 | from llama_index.core.workflow import (
9 | Context,
10 | StartEvent,
11 | StopEvent,
12 | Workflow,
13 | step,
14 | )
15 |
16 | from arc_finetuning_st.workflows.events import (
17 | EvaluationEvent,
18 | FormatTaskEvent,
19 | PredictionEvent,
20 | )
21 | from arc_finetuning_st.workflows.models import Attempt, Critique, Prediction
22 | from arc_finetuning_st.workflows.prompts import (
23 | CORRECTION_PROMPT_TEMPLATE,
24 | PREDICTION_PROMPT_TEMPLATE,
25 | REFLECTION_PROMPT_TEMPLATE,
26 | )
27 |
28 | EXAMPLE_TEMPLATE = """===
29 | EXAMPLE
30 |
31 | INPUT:
32 | {input}
33 |
34 | OUTPUT:
35 | {output}
36 | """
37 |
38 | PAST_ATTEMPT_TEMPLATE = """◦◦◦
39 | PAST ATTEMPT {past_attempt_number}
40 |
41 | PREDICTED_OUTPUT:
42 | {past_predicted_output}
43 |
44 | CRITIQUE:
45 | {past_critique}
46 | """
47 |
48 |
49 | class WorkflowOutput(BaseModel):
50 | passing: bool
51 | attempts: List[Attempt]
52 |
53 |
54 | class ARCTaskSolverWorkflow(Workflow):
55 | def __init__(self, llm: LLM, max_attempts: int = 3, **kwargs: Any) -> None:
56 | super().__init__(**kwargs)
57 | self.llm = llm
58 | self._max_attempts = max_attempts
59 |
60 | def _format_past_attempt(self, attempt: Attempt, attempt_num: int) -> str:
61 | return PAST_ATTEMPT_TEMPLATE.format(
62 | past_attempt_number=attempt_num,
63 | past_predicted_output=str(attempt.prediction),
64 | past_critique=str(attempt.critique) if attempt.critique else "",
65 | )
66 |
67 | @step
68 | async def format_task(
69 | self, ctx: Context, ev: StartEvent
70 | ) -> FormatTaskEvent:
71 | ctx.write_event_to_stream(ev)
72 |
73 | def _format_row(row: List[int]) -> str:
74 | return ",".join(str(el) for el in row)
75 |
76 | def pretty_print_grid(grid: List[List[int]]) -> str:
77 | formatted_rows = [_format_row(row) for row in grid]
78 | return "\n".join(formatted_rows)
79 |
80 | def format_train_example(train_pair: Dict) -> str:
81 | return EXAMPLE_TEMPLATE.format(
82 | input=pretty_print_grid(train_pair["input"]),
83 | output=pretty_print_grid(train_pair["output"]),
84 | )
85 |
86 | task = ev.get("task", {})
87 | await ctx.set("task", task)
88 |
89 | # prepare prompt_vars
90 | attempts = await ctx.get("attempts", [])
91 | if attempts:
92 | # update past predictions
93 | prompt_vars = await ctx.get("prompt_vars")
94 | formatted_past_attempts = [
95 | self._format_past_attempt(a, ix + 1)
96 | for ix, a in enumerate(attempts)
97 | ]
98 | prompt_vars.update(
99 | past_attempts="\n".join(formatted_past_attempts)
100 | )
101 | else:
102 | examples = [format_train_example(t) for t in task["train"]]
103 | prompt_vars = {
104 | "test_input": pretty_print_grid(task["test"][0]["input"]),
105 | "examples": "\n".join(examples),
106 | }
107 | await ctx.set("prompt_vars", prompt_vars)
108 |
109 | return FormatTaskEvent()
110 |
111 | @step
112 | async def prediction(
113 | self, ctx: Context, ev: FormatTaskEvent
114 | ) -> PredictionEvent | StopEvent:
115 | ctx.write_event_to_stream(ev)
116 | attempts = await ctx.get("attempts", [])
117 | attempts = cast(List[Attempt], attempts)
118 | prompt_vars = await ctx.get("prompt_vars")
119 |
120 | if attempts:
121 | # generating a correction from last Workflow run
122 | correction: Prediction = await self.llm.astructured_predict(
123 | Prediction, CORRECTION_PROMPT_TEMPLATE, **prompt_vars
124 | )
125 | attempts.append(Attempt(prediction=correction))
126 | else:
127 | # starting a new correction with no previous Workflow runs
128 | pred: Prediction = await self.llm.astructured_predict(
129 | Prediction, PREDICTION_PROMPT_TEMPLATE, **prompt_vars
130 | )
131 | attempts = [Attempt(prediction=pred)]
132 |
133 | await ctx.set("attempts", attempts)
134 | return PredictionEvent()
135 |
136 | @step
137 | async def evaluation(
138 | self, ctx: Context, ev: PredictionEvent
139 | ) -> EvaluationEvent:
140 | ctx.write_event_to_stream(ev)
141 | task = await ctx.get("task")
142 | attempts: List[Attempt] = await ctx.get("attempts")
143 | latest_prediction = attempts[-1].prediction
144 | latest_prediction_as_array = Prediction.prediction_str_to_int_array(
145 | str(latest_prediction)
146 | )
147 | ground_truth = task["test"][0]["output"]
148 |
149 | return EvaluationEvent(
150 | passing=(latest_prediction_as_array == ground_truth)
151 | )
152 |
153 | @step
154 | async def reflection(self, ctx: Context, ev: EvaluationEvent) -> StopEvent:
155 | ctx.write_event_to_stream(ev)
156 | attempts = await ctx.get("attempts")
157 | attempts = cast(List[Attempt], attempts)
158 | latest_attempt = attempts[-1]
159 |
160 | # check if passing
161 | if not ev.passing:
162 | prompt_vars = await ctx.get("prompt_vars")
163 | formatted_past_attempts = [
164 | self._format_past_attempt(a, ix + 1)
165 | for ix, a in enumerate(attempts)
166 | ]
167 | prompt_vars.update(
168 | past_attempts="\n".join(formatted_past_attempts)
169 | )
170 |
171 | # generate critique
172 | critique: Critique = await self.llm.astructured_predict(
173 | Critique, REFLECTION_PROMPT_TEMPLATE, **prompt_vars
174 | )
175 |
176 | # update states
177 | latest_attempt.critique = critique
178 | else:
179 | latest_attempt.critique = "This predicted output is correct."
180 |
181 | latest_attempt.passing = ev.passing
182 | attempts[-1] = latest_attempt
183 | await ctx.set("attempts", attempts)
184 |
185 | result = WorkflowOutput(passing=ev.passing, attempts=attempts)
186 | return StopEvent(result=result)
187 |
188 | async def load_and_run_task(
189 | self,
190 | task_path: Path,
191 | ctx: Optional[Context] = None,
192 | sem: Optional[asyncio.Semaphore] = None,
193 | ) -> Any:
194 | """Convenience function for loading a task json and running it."""
195 | with open(task_path) as f:
196 | task = json.load(f)
197 |
198 | async def _run_workflow() -> Any:
199 | return await self.run(ctx=ctx, task=task)
200 |
201 | if sem: # in case running in batch with other workflow runs
202 | await sem.acquire()
203 | try:
204 | res = await _run_workflow()
205 | finally:
206 | sem.release()
207 | else:
208 | res = await _run_workflow()
209 |
210 | return res
211 |
212 |
213 | async def _test_workflow() -> None:
214 | import json
215 | from pathlib import Path
216 |
217 | from llama_index.llms.openai import OpenAI
218 |
219 | task_path = Path(
220 | Path(__file__).parents[2].absolute(), "data/training/0a938d79.json"
221 | )
222 | with open(task_path) as f:
223 | task = json.load(f)
224 |
225 | w = ARCTaskSolverWorkflow(
226 | timeout=None, verbose=False, llm=OpenAI("gpt-4o")
227 | )
228 | attempts = await w.run(task=task)
229 |
230 | print(attempts)
231 |
232 |
233 | if __name__ == "__main__":
234 | import asyncio
235 |
236 | asyncio.run(_test_workflow())
237 |
--------------------------------------------------------------------------------
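If the step/event wiring above is new to you: a workflow's `@step` methods are chained purely by their type annotations, with the runtime dispatching each emitted event to the step that accepts it until a `StopEvent` is returned. A toy two-step workflow in the same style, using the same public API the solver imports:

```python
import asyncio

from llama_index.core.workflow import Event, StartEvent, StopEvent, Workflow, step


class Doubled(Event):
    value: int


class ToyFlow(Workflow):
    @step
    async def double(self, ev: StartEvent) -> Doubled:
        # kwargs passed to run() are readable off the StartEvent
        return Doubled(value=ev.get("x", 0) * 2)

    @step
    async def finish(self, ev: Doubled) -> StopEvent:
        return StopEvent(result=ev.value + 1)


async def demo() -> None:
    print(await ToyFlow(timeout=None).run(x=4))  # -> 9


asyncio.run(demo())
```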
/recipes/llamaindex/arc_finetuning_st/workflows/events.py:
--------------------------------------------------------------------------------
1 | from llama_index.core.workflow import Event
2 |
3 |
4 | class FormatTaskEvent(Event):
5 | ...
6 |
7 |
8 | class PredictionEvent(Event):
9 | ...
10 |
11 |
12 | class EvaluationEvent(Event):
13 | passing: bool
14 |
15 |
16 | class CorrectionEvent(Event):
17 | ...
18 |
--------------------------------------------------------------------------------
/recipes/llamaindex/arc_finetuning_st/workflows/models.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from typing import List, Optional
3 |
4 | from llama_index.core.bridge.pydantic import BaseModel, Field
5 |
6 |
7 | class Prediction(BaseModel):
8 | """Prediction data class for LLM structured predict."""
9 |
10 | rationale: str = Field(
11 | description="Brief description of pattern and why prediction was made. Limit to 250 words."
12 | )
13 | prediction: str = Field(
14 | description="Predicted grid as a single string. e.g. '0,0,1\n1,1,1\n0,0,0'"
15 | )
16 |
17 | def __str__(self) -> str:
18 | return self.prediction
19 |
20 | @staticmethod
21 | def prediction_str_to_int_array(prediction: str) -> List[List[int]]:
22 | return [
23 | [int(a) for a in el.split(",")] for el in prediction.split("\n")
24 | ]
25 |
26 |
27 | class Critique(BaseModel):
28 | """Critique data class for LLM structured predict."""
29 |
30 | critique: str = Field(
31 | description="Brief critique of the previous prediction and rationale. Limit to 250 words."
32 | )
33 |
34 | def __str__(self) -> str:
35 | return self.critique
36 |
37 |
38 | class Attempt(BaseModel):
39 | """Container class of a single solution attempt."""
40 |
41 | id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
42 | prediction: Prediction
43 | critique: Optional[Critique] = Field(default=None)
44 | passing: bool = Field(default=False)
45 |
--------------------------------------------------------------------------------
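The grid round trip these models rely on: a `Prediction` stores the grid as a comma/newline-delimited string, and `prediction_str_to_int_array` recovers the `List[List[int]]` that the evaluation step compares against ground truth:

```python
s = "0,0,1\n1,1,1\n0,0,0"
grid = [[int(a) for a in row.split(",")] for row in s.split("\n")]
assert grid == [[0, 0, 1], [1, 1, 1], [0, 0, 0]]
```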
/recipes/llamaindex/arc_finetuning_st/workflows/prompts.py:
--------------------------------------------------------------------------------
1 | from llama_index.core.prompts import PromptTemplate
2 |
3 | PREDICTION_PROMPT_TEMPLATE = PromptTemplate(
4 | """You are a bot that is very good at solving puzzles. Below is a list of input and output pairs with a pattern.
5 | Identify the pattern in the training examples and predict the output for the provided TEST INPUT.
6 |
7 | EXAMPLES:
8 | {examples}
9 |
10 | TEST INPUT:
11 | {test_input}
12 |
13 | OUTPUT FORMAT:
14 | {{
15 |     "output": ...
16 | }}
17 |
18 | Return your response in JSON format given above. DO NOT RETURN markdown code.
19 | """
20 | )
21 |
22 |
23 | REFLECTION_PROMPT_TEMPLATE = PromptTemplate(
24 | """You are a bot that is very good at solving puzzles. Below is a list of input and output pairs that share a
25 | common pattern. The TEST INPUT also shares this common pattern, and you've previously predicted the output for it.
26 | Your task now is to critique the latest prediction, explaining why it might not fit the pattern inherent in the example input/output pairs.
27 |
28 | EXAMPLES:
29 | {examples}
30 |
31 | TEST INPUT:
32 | {test_input}
33 |
34 | PAST ATTEMPTS:
35 | {past_attempts}
36 |
37 | OUTPUT FORMAT:
38 | {{
39 | "critique": ...
40 | }}
41 |
42 | Return your response in JSON format given above. DO NOT RETURN markdown code."""
43 | )
44 |
45 | CORRECTION_PROMPT_TEMPLATE = PromptTemplate(
46 | """You are a bot that is very good at solving puzzles. Below is a list of input and output pairs that share a
47 | common pattern. The TEST INPUT also shares this common pattern, and you've previously predicted the output for it.
48 | The predicted output was found to be incorrect, and a critique has been articulated offering a potential
49 | reason why the prediction was flawed.
50 |
51 | Your task now is to create a new prediction that corrects the previous attempts. Use
52 | the last attempt and critique.
53 |
54 | EXAMPLES:
55 | {examples}
56 |
57 | TEST INPUT:
58 | {test_input}
59 |
60 | PAST ATTEMPTS:
61 | {past_attempts}
62 |
63 | OUTPUT FORMAT:
64 | {{
65 | "prediction": ...
66 | }}
67 |
68 | Return your response in JSON format given above. DO NOT RETURN markdown code."""
69 | )
70 |
--------------------------------------------------------------------------------
/recipes/llamaindex/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["poetry-core"]
3 | build-backend = "poetry.core.masonry.api"
4 |
5 | [tool.poetry]
6 | name = "arc-finetuning-st"
7 | version = "0.1.0"
8 | description = ""
9 | authors = ["Andrei Fajardo "]
10 | readme = "README.md"
11 |
12 | [tool.poetry.dependencies]
13 | python = "^3.10"
14 | llama-index-core = "^0.11.11"
15 | llama-index-finetuning = "^0.2.1"
16 | llama-index-llms-openai = "^0.2.7"
17 | llama-index-program-openai = "^0.2.0"
18 | streamlit = "^1.38.0"
19 | plotly = "^5.24.1"
20 | pandas = "^2.2.2"
21 |
22 | [tool.poetry.group.dev.dependencies]
23 | jupyterlab = "^4.2.5"
24 |
25 | [tool.poetry.scripts]
26 | arc-finetuning-cli = 'arc_finetuning_st.cli.command_line:main'
27 |
--------------------------------------------------------------------------------
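Once the package is installed (e.g. via `poetry install`), the `[tool.poetry.scripts]` entry exposes the CLI defined in `command_line.py`, so the three subcommands can be invoked as, for example, `arc-finetuning-cli evaluate -m gpt-4o -b 5 -w 3`, `arc-finetuning-cli finetune --continue-latest`, and `arc-finetuning-cli job-status --latest`.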
/recipes/ollama/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit==1.33.0
2 | ollama
3 | openai
4 |
--------------------------------------------------------------------------------
/recipes/ollama/streamlit_app.py:
--------------------------------------------------------------------------------
1 | import ollama
2 | import streamlit as st
3 | from openai import OpenAI
4 |
5 | st.set_page_config(
6 | page_title="Streamlit Ollama Chatbot",
7 | page_icon="💬",
8 | layout="wide",
9 | initial_sidebar_state="expanded",
10 | )
11 |
12 |
13 | def extract_model_names(models_info: dict) -> tuple:
14 | """
15 | Extracts the model names from the models information.
16 |
17 | :param models_info: A dictionary containing the models' information.
18 |
19 |     Returns:
20 | A tuple containing the model names.
21 | """
22 |
23 | return tuple(model["name"] for model in models_info["models"])
24 |
25 |
26 | def main():
27 | """
28 | The main function that runs the application.
29 | """
30 | st.subheader("Streamlit Ollama Chatbot", divider="red", anchor=False)
31 |
32 | client = OpenAI(
33 | base_url="http://localhost:11434/v1",
34 | api_key="ollama", # required, but unused
35 | )
36 |
37 | models_info = ollama.list()
38 | available_models = extract_model_names(models_info)
39 |
40 | if available_models:
41 | selected_model = st.selectbox(
42 | "Pick a model available locally on your system ↓", available_models
43 | )
44 |
45 | else:
46 | st.warning("You have not pulled any model from Ollama yet!", icon="⚠️")
47 |         st.page_link("https://ollama.com/library", label="Pull model(s) from Ollama", icon="🦙")
48 |         st.stop()  # halt the script run here: selected_model would be undefined below
49 | message_container = st.container(height=500, border=True)
50 |
51 | if "messages" not in st.session_state:
52 | st.session_state.messages = []
53 |
54 | for message in st.session_state.messages:
55 | avatar = "🤖" if message["role"] == "assistant" else "😎"
56 | with message_container.chat_message(message["role"], avatar=avatar):
57 | st.markdown(message["content"])
58 |
59 | if prompt := st.chat_input("Enter a prompt here..."):
60 | try:
61 | st.session_state.messages.append(
62 | {"role": "user", "content": prompt})
63 |
64 | message_container.chat_message("user", avatar="😎").markdown(prompt)
65 |
66 | with message_container.chat_message("assistant", avatar="🤖"):
67 |                 with st.spinner("Give it a moment..."):
68 | stream = client.chat.completions.create(
69 | model=selected_model,
70 | messages=[
71 | {"role": m["role"], "content": m["content"]}
72 | for m in st.session_state.messages
73 | ],
74 | stream=True,
75 | )
76 | # stream response
77 | response = st.write_stream(stream)
78 | st.session_state.messages.append(
79 | {"role": "assistant", "content": response})
80 |
81 | except Exception as e:
82 | st.error(e, icon="⛔️")
83 |
84 |
85 | if __name__ == "__main__":
86 | main()
--------------------------------------------------------------------------------
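Note that `extract_model_names` assumes the older `ollama` client response shape, a dict with a `"models"` list of dicts keyed by `"name"`; newer client releases return a typed response whose entries expose a `.model` attribute instead, in which case the helper needs adjusting. A quick check against the assumed shape:

```python
# Assumed (older ollama-python) response shape; adjust for newer clients.
models_info = {
    "models": [
        {"name": "llama3:latest"},
        {"name": "mistral:7b"},
    ]
}
print(tuple(m["name"] for m in models_info["models"]))  # ('llama3:latest', 'mistral:7b')
```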
/recipes/openai/README.md:
--------------------------------------------------------------------------------
1 | # OpenAI
2 |
--------------------------------------------------------------------------------
/recipes/replicate/.streamlit/secrets_template.toml:
--------------------------------------------------------------------------------
1 | # API tokens
2 |
3 | REPLICATE_API_TOKEN = "replace-this-with-your-own-token"
--------------------------------------------------------------------------------
/recipes/replicate/README.md:
--------------------------------------------------------------------------------
1 | # How to run the demo Replicate Streamlit chatbot app
2 | This is a recipe for a Replicate Streamlit chatbot app. The app uses a single API to access three different LLMs, and lets you adjust parameters such as temperature and top-p.
3 |
4 | Other ways to explore this recipe:
5 | * [Deployed app](https://replicate-recipe.streamlit.app/)
6 |
7 |   (Requires a [Replicate API token](https://replicate.com/signin?next=/account/api-tokens))
8 | * [Blog post](https://blog.streamlit.io/how-to-create-an-ai-chatbot-llm-api-replicate-streamlit/)
9 | * [Video](https://youtu.be/zsQ7EN10zj8?si=fGxg4zH7mJyrasaT)
10 |
11 | ## Prerequisites
12 | * Python >=3.8, !=3.9.7
13 | * A [Replicate API token](https://replicate.com/signin?next=/account/api-tokens)
14 | (Please note that a payment method is required to access features beyond the free trial limits.)
15 |
16 | ## Environment setup
17 | ### Local setup
18 | 1. Clone the Cookbook repo: `git clone https://github.com/streamlit/cookbook.git`
19 | 2. From the Cookbook root directory, change directory into the Replicate recipe: `cd recipes/replicate`
20 | 3. Add your Replicate API token to the `.streamlit/secrets_template.toml` file
21 | 4. Update the filename from `secrets_template.toml` to `secrets.toml`: `mv .streamlit/secrets_template.toml .streamlit/secrets.toml`
22 |
23 | (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).)
24 | 5. Create a virtual environment: `python -m venv replicatevenv`
25 | 6. Activate the virtual environment: `source replicatevenv/bin/activate`
26 | 7. Install the dependencies: `pip install -r requirements.txt`
27 |
28 | ### GitHub Codespaces setup
29 | 1. Create a new codespace by selecting the `Codespaces` option from the `Code` button
30 | 2. Once the codespace has been generated, add your Replicate API token to the `recipes/replicate/.streamlit/secrets_template.toml` file
31 | 3. Update the filename from `secrets_template.toml` to `secrets.toml`
32 |
33 | (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).)
34 | 4. From the Cookbook root directory, change directory into the Replicate recipe: `cd recipes/replicate`
35 | 5. Install the dependencies: `pip install -r requirements.txt`
36 |
--------------------------------------------------------------------------------
/recipes/replicate/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit
2 | replicate
3 | transformers
4 |
--------------------------------------------------------------------------------
/recipes/replicate/streamlit_app.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import replicate
3 | import os
4 | from transformers import AutoTokenizer
5 |
6 | # App title
7 | st.set_page_config(page_title="Streamlit Replicate Chatbot", page_icon="💬")
8 |
9 | # Replicate Credentials
10 | with st.sidebar:
11 | st.title('💬 Streamlit Replicate Chatbot')
12 | st.write('Create chatbots using various LLM models hosted at [Replicate](https://replicate.com/).')
13 | if 'REPLICATE_API_TOKEN' in st.secrets:
14 | replicate_api = st.secrets['REPLICATE_API_TOKEN']
15 | else:
16 | replicate_api = st.text_input('Enter Replicate API token:', type='password')
17 | if not (replicate_api.startswith('r8_') and len(replicate_api)==40):
18 | st.warning('Please enter your Replicate API token.', icon='⚠️')
19 | st.markdown("**Don't have an API token?** Head over to [Replicate](https://replicate.com) to sign up for one.")
20 | os.environ['REPLICATE_API_TOKEN'] = replicate_api
21 |
22 | st.subheader("Models and parameters")
23 | model = st.selectbox("Select a model",("meta/meta-llama-3-70b-instruct", "mistralai/mistral-7b-instruct-v0.2", "google-deepmind/gemma-2b-it"), key="model")
24 | if model == "google-deepmind/gemma-2b-it":
25 | model = "google-deepmind/gemma-2b-it:dff94eaf770e1fc211e425a50b51baa8e4cac6c39ef074681f9e39d778773626"
26 |
27 | temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=5.0, value=0.7, step=0.01, help="Randomness of generated output")
28 | if temperature >= 1:
29 |         st.warning('Values exceeding 1 produce more creative and random output, as well as an increased likelihood of hallucination.')
30 | if temperature < 0.1:
31 |         st.warning('Values approaching 0 produce deterministic output. The recommended starting value is 0.7.')
32 |
33 | top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01, help="Top p percentage of most likely tokens for output generation")
34 |
35 | # Store LLM-generated responses
36 | if "messages" not in st.session_state.keys():
37 | st.session_state.messages = [{"role": "assistant", "content": "Ask me anything."}]
38 |
39 | # Display or clear chat messages
40 | for message in st.session_state.messages:
41 | with st.chat_message(message["role"]):
42 | st.write(message["content"])
43 |
44 | def clear_chat_history():
45 | st.session_state.messages = [{"role": "assistant", "content": "Ask me anything."}]
46 |
47 | st.sidebar.button('Clear chat history', on_click=clear_chat_history)
48 |
49 | @st.cache_resource(show_spinner=False)
50 | def get_tokenizer():
51 |     """Get a tokenizer to make sure we're not sending too much text
52 |     to the model. Eventually we will replace this with ArcticTokenizer.
53 |     """
54 | return AutoTokenizer.from_pretrained("huggyllama/llama-7b")
55 |
56 | def get_num_tokens(prompt):
57 | """Get the number of tokens in a given prompt"""
58 | tokenizer = get_tokenizer()
59 | tokens = tokenizer.tokenize(prompt)
60 | return len(tokens)
61 |
62 | # Function for generating model response
63 | def generate_response():
64 | prompt = []
65 | for dict_message in st.session_state.messages:
66 | if dict_message["role"] == "user":
67 | prompt.append("<|im_start|>user\n" + dict_message["content"] + "<|im_end|>")
68 | else:
69 | prompt.append("<|im_start|>assistant\n" + dict_message["content"] + "<|im_end|>")
70 |
71 | prompt.append("<|im_start|>assistant")
72 | prompt.append("")
73 | prompt_str = "\n".join(prompt)
74 |
75 | if get_num_tokens(prompt_str) >= 3072:
76 | st.error("Conversation length too long. Please keep it under 3072 tokens.")
77 | st.button('Clear chat history', on_click=clear_chat_history, key="clear_chat_history")
78 | st.stop()
79 |
80 | for event in replicate.stream(model,
81 | input={"prompt": prompt_str,
82 | "prompt_template": r"{prompt}",
83 | "temperature": temperature,
84 | "top_p": top_p,
85 | }):
86 | yield str(event)
87 |
88 | # User-provided prompt
89 | if prompt := st.chat_input(disabled=not replicate_api):
90 | st.session_state.messages.append({"role": "user", "content": prompt})
91 | with st.chat_message("user"):
92 | st.write(prompt)
93 |
94 | # Generate a new response if last message is not from assistant
95 | if st.session_state.messages[-1]["role"] != "assistant":
96 | with st.chat_message("assistant"):
97 | response = generate_response()
98 | full_response = st.write_stream(response)
99 | message = {"role": "assistant", "content": full_response}
100 | st.session_state.messages.append(message)
101 |
--------------------------------------------------------------------------------
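For clarity, here is what `generate_response`'s prompt assembly produces for a short history: a ChatML-style transcript that ends with an open assistant turn for the model to complete:

```python
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Ask me anything."},
]
prompt = []
for m in messages:
    prompt.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>")
prompt.append("<|im_start|>assistant")
prompt.append("")  # leave the assistant turn open
print("\n".join(prompt))
```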
/recipes/replit/README.md:
--------------------------------------------------------------------------------
1 | # Building and deploying on Replit
2 |
3 |
4 |
5 |
6 |
7 |
8 | Replit is a cloud-based development and deployment platform that allows you to build and deploy in a single, integrated environment.
9 | It comes with handy tools, like [secrets management](https://docs.replit.com/programming-ide/storing-sensitive-information-environment-variables) and [AI assisted coding](https://docs.replit.com/programming-ide/ai-code-completion), and supports _any_ language you can think of. Replit also has cloud services like [object storage](https://docs.replit.com/storage/replit-database), [embedded Postgres](https://docs.replit.com/hosting/databases/postgresql-database), and [key-value databases](https://docs.replit.com/hosting/databases/replit-database) that make it easy to build and ship apps. For a 1-minute overview of Replit, see [this video](https://www.youtube.com/watch?v=TiHq41h3nDo).
10 |
11 |
12 |
13 |
14 |
15 | ## Setup
16 |
17 | To get started on Replit, follow these easy steps:
18 |
19 | 1. Create a Replit [account](https://replit.com/).
20 | 2. Fork this Replit [template](https://replit.com/@matt/Streamlit?v=1#README.md), which contains the same code as this folder.
21 | 3. To run the app, click `Run` in the nav bar.
22 |
23 | A web view pane will open up with your application. Simply edit the code in the editor to see changes in real time.
24 |
25 | **An important note:** because Replit is a cloud-based editor, the [development URL](https://docs.replit.com/additional-resources/add-a-made-with-replit-badge-to-your-webview#what-is-the-webview) is accessible from any device (while your Repl is running). That means you can test your app out on desktop and mobile simultaneously.
26 |
27 |
28 |
29 |
30 |
31 | ## Deployment
32 |
33 | Deployment on Replit is straightforward and can be done in a few steps:
34 |
35 | 1. Click `Deploy` in the top navigation bar of your Repl.
36 | 2. Choose a deployment type:
37 | - For frontend apps, select "Autoscale" deployments. These are cost-effective and automatically scale based on traffic.
38 | - For services requiring continuous execution (e.g., backends, APIs, bots), choose "Reserved VM" deployments.
39 | 3. Select a custom subdomain for your app or use the auto-generated one.
40 | 4. Configure your deployment settings, including environment variables if needed.
41 | 5. Review and confirm your deployment configuration.
42 |
43 | Important notes:
44 | - You'll need to add a payment method to deploy your app.
45 | - Additional options can be configured in the `.replit` file present in your project; note that you may have to reveal hidden files if you can't see it. For more information on configuring the `.replit` file, refer to the [docs](https://docs.replit.com/programming-ide/configuring-repl).
46 | - For detailed instructions and advanced configuration options, refer to the [official Replit deployment documentation](https://docs.replit.com/hosting/deployments/about-deployments).
47 |
48 | ### Custom Domains
49 |
50 | Replit automatically assigns your website a unique domain name such as `your-app-name-your-username.replit.app`. However, if you want to use your own domain purchased through services like [Cloudflare](https://www.cloudflare.com/) or [Namecheap](https://www.namecheap.com/), you can configure it in your Replit deployment settings. Here's how:
51 |
52 | 1. Go to your Repl's "Deployment" tab.
53 | 2. Click on "Configure custom domain".
54 | 3. Enter your custom domain name and follow the instructions to set up DNS records with your domain provider.
55 |
56 | Note the difference between setting up a root domain and a subdomain. Root domains don't have any prefixes before the main site name (e.g., `example.com`), while subdomains do (e.g., `subdomain.example.com`). The setup process might differ slightly for each, so always refer to your domain registrar's documentation for specific instructions on how to set up DNS records for your custom domain.
57 |
58 | For more detailed information on setting up custom domains in Replit, you can refer to their official documentation [here](https://docs.replit.com/hosting/deployments/custom-domains).
59 |
--------------------------------------------------------------------------------
/recipes/trulens/.env.template:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=
2 |
--------------------------------------------------------------------------------
/recipes/trulens/README.md:
--------------------------------------------------------------------------------
1 | # Streamlit × TruLens demo
2 |
3 | Ask questions about the Pacific Northwest, and get `TruLens` evaluations on the app response.
4 |
5 | To run the demo:
6 |
7 | 1. Clone the cookbook repository, and navigate to the `trulens` recipe.
8 | 2. Copy the contents of `.env.template` to `.env`, and fill in your credentials.
9 | 3. Launch the app with `streamlit run app.py`
10 |
--------------------------------------------------------------------------------
/recipes/trulens/app.py:
--------------------------------------------------------------------------------
1 | __import__('pysqlite3')
2 | import sys
3 | sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4 |
5 | import streamlit as st
6 | import trulens.dashboard.streamlit as trulens_st
7 | from trulens.core import TruSession
8 |
9 | from base import rag, filtered_rag, tru_rag, filtered_tru_rag
10 |
11 | st.set_page_config(
12 | page_title="Use TruLens in Streamlit",
13 | page_icon="🦑",
14 | )
15 |
16 | st.title("TruLens ❤️ Streamlit")
17 |
18 | st.write("Learn about the Pacific Northwest, and view tracing & evaluation metrics powered by TruLens 🦑.")
19 |
20 | tru = TruSession()
21 |
22 | with_filters = st.toggle("Use Context Filter Guardrails", value=False)
23 |
24 | def generate_response(input_text):
25 | if with_filters:
26 | app = filtered_tru_rag
27 | with filtered_tru_rag as recording:
28 | response = filtered_rag.query(input_text)
29 | else:
30 | app = tru_rag
31 | with tru_rag as recording:
32 | response = rag.query(input_text)
33 |
34 | record = recording.get()
35 |
36 | return record, response
37 |
38 | with st.form("my_form"):
39 | text = st.text_area(
40 | "Enter text:", "When was the University of Washington founded?"
41 | )
42 | submitted = st.form_submit_button("Submit")
43 | if submitted:
44 | record, response = generate_response(text)
45 | st.info(response)
46 |
47 | if submitted:
48 | with st.expander("See the trace of this record 👀"):
49 | trulens_st.trulens_trace(record=record)
50 |
51 | trulens_st.trulens_feedback(record=record)
52 |
53 |
--------------------------------------------------------------------------------
/recipes/trulens/base.py:
--------------------------------------------------------------------------------
1 | __import__('pysqlite3')
2 | import sys
3 | sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4 |
5 | import streamlit as st
6 | from openai import OpenAI
7 | import numpy as np
8 |
9 | from trulens.core import TruSession
10 | from trulens.core.guardrails.base import context_filter
11 | from trulens.apps.custom import instrument
12 | from trulens.apps.app import TruApp
13 | from trulens.providers.openai import OpenAI as OpenAIProvider
14 | from trulens.core import Feedback
15 | from trulens.core import Select
16 |
17 |
18 | from feedback import feedbacks, f_guardrail
19 | from vector_store import vector_store
20 |
21 | from dotenv import load_dotenv
22 |
23 | load_dotenv()
24 |
25 | oai_client = OpenAI()
26 |
27 | tru = TruSession()
28 |
29 | class RAG_from_scratch:
30 | @instrument
31 | def retrieve(self, query: str) -> list:
32 | """
33 | Retrieve relevant text from vector store.
34 | """
35 | results = vector_store.query(query_texts=query, n_results=4)
36 | # Flatten the list of lists into a single list
37 | return [doc for sublist in results["documents"] for doc in sublist]
38 |
39 | @instrument
40 | def generate_completion(self, query: str, context_str: list) -> str:
41 | """
42 | Generate answer from context.
43 | """
44 | completion = (
45 | oai_client.chat.completions.create(
46 | model="gpt-3.5-turbo",
47 | temperature=0,
48 | messages=[
49 | {
50 | "role": "user",
51 | "content": f"We have provided context information below. \n"
52 | f"---------------------\n"
53 | f"{context_str}"
54 | f"\n---------------------\n"
55 | f"First, say hello and that you're happy to help. \n"
56 | f"\n---------------------\n"
57 | f"Then, given this information, please answer the question: {query}",
58 | }
59 | ],
60 | )
61 | .choices[0]
62 | .message.content
63 | )
64 | return completion
65 |
66 | @instrument
67 | def query(self, query: str) -> str:
68 | context_str = self.retrieve(query)
69 | completion = self.generate_completion(query, context_str)
70 | return completion
71 |
72 | class filtered_RAG_from_scratch:
73 | @instrument
74 | @context_filter(f_guardrail, 0.75, keyword_for_prompt="query")
75 | def retrieve(self, query: str) -> list:
76 | """
77 | Retrieve relevant text from vector store.
78 | """
79 | results = vector_store.query(query_texts=query, n_results=4)
80 | return [doc for sublist in results["documents"] for doc in sublist]
81 |
82 | @instrument
83 | def generate_completion(self, query: str, context_str: list) -> str:
84 | """
85 | Generate answer from context.
86 | """
87 | completion = (
88 | oai_client.chat.completions.create(
89 | model="gpt-3.5-turbo",
90 | temperature=0,
91 | messages=[
92 | {
93 | "role": "user",
94 | "content": f"We have provided context information below. \n"
95 | f"---------------------\n"
96 | f"{context_str}"
97 | f"\n---------------------\n"
98 | f"Given this information, please answer the question: {query}",
99 | }
100 | ],
101 | )
102 | .choices[0]
103 | .message.content
104 | )
105 | return completion
106 |
107 | @instrument
108 | def query(self, query: str) -> str:
109 | context_str = self.retrieve(query=query)
110 | completion = self.generate_completion(
111 | query=query, context_str=context_str
112 | )
113 | return completion
114 |
115 |
116 | filtered_rag = filtered_RAG_from_scratch()
117 |
118 | rag = RAG_from_scratch()
119 |
120 | tru_rag = TruApp(
121 | rag,
122 | app_name="RAG",
123 | app_version="v1",
124 | feedbacks=feedbacks,
125 | )
126 |
127 | filtered_tru_rag = TruApp(
128 | filtered_rag,
129 | app_name="RAG",
130 | app_version="v2",
131 | feedbacks=feedbacks,
132 | )
133 |
--------------------------------------------------------------------------------
/recipes/trulens/feedback.py:
--------------------------------------------------------------------------------
1 | __import__('pysqlite3')
2 | import sys
3 | sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4 |
5 | import numpy as np
6 | from trulens.core import Feedback
7 | from trulens.core import Select
8 | from trulens.providers.openai import OpenAI as OpenAIProvider
9 |
10 | from dotenv import load_dotenv
11 |
12 | load_dotenv()
13 |
14 | provider = OpenAIProvider(model_engine="gpt-4o-mini")
15 |
16 | # Define a groundedness feedback function
17 | f_groundedness = (
18 | Feedback(
19 | provider.groundedness_measure_with_cot_reasons, name="Groundedness"
20 | )
21 | .on(Select.RecordCalls.retrieve.rets.collect())
22 | .on_output()
23 | )
24 | # Question/answer relevance between overall question and answer.
25 | f_answer_relevance = (
26 | Feedback(provider.relevance_with_cot_reasons, name="Answer Relevance")
27 | .on_input()
28 | .on_output()
29 | )
30 |
31 | # Context relevance between question and each context chunk.
32 | f_context_relevance = (
33 | Feedback(
34 | provider.context_relevance_with_cot_reasons, name="Context Relevance"
35 | )
36 | .on_input()
37 | .on(Select.RecordCalls.retrieve.rets[:])
38 | .aggregate(np.mean) # choose a different aggregation method if you wish
39 | )
40 |
41 | feedbacks = [f_groundedness, f_answer_relevance, f_context_relevance]
42 |
43 | # Note: a feedback function used as a guardrail must return only a score, not reasons.
44 | f_guardrail = Feedback(
45 | provider.context_relevance, name="Context Relevance"
46 | )
47 |
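48 | # These feedbacks are attached to the instrumented apps via TruApp(..., feedbacks=feedbacks) in base.py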
--------------------------------------------------------------------------------
/recipes/trulens/requirements.txt:
--------------------------------------------------------------------------------
1 | pip==24.2
2 | openai
3 | pysqlite3-binary
4 | chromadb
5 | trulens-core @ git+https://github.com/truera/trulens#egg=trulens-core&subdirectory=src/core/
6 | trulens-feedback @ git+https://github.com/truera/trulens#egg=trulens-feedback&subdirectory=src/feedback/
7 | trulens-providers-openai @ git+https://github.com/truera/trulens#egg=trulens-providers-openai&subdirectory=src/providers/openai/
8 | trulens-dashboard @ git+https://github.com/truera/trulens#egg=trulens-dashboard&subdirectory=src/dashboard/
9 | streamlit
10 | python-dotenv
11 | numpy
--------------------------------------------------------------------------------
/recipes/trulens/vector_store.py:
--------------------------------------------------------------------------------
1 | __import__('pysqlite3')
2 | import sys
3 | sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
4 |
5 | import os
6 |
7 | import chromadb
8 | from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
9 |
10 | from dotenv import load_dotenv
11 |
12 | load_dotenv()
13 |
14 | uw_info = """
15 | The University of Washington, founded in 1861 in Seattle, is a public research university
16 | with over 45,000 students across three campuses in Seattle, Tacoma, and Bothell.
17 | As the flagship institution of the six public universities in Washington state,
18 | UW encompasses over 500 buildings and 20 million square feet of space,
19 | including one of the largest library systems in the world.
20 | """
21 |
22 | wsu_info = """
23 | Washington State University, commonly known as WSU, founded in 1890, is a public research university in Pullman, Washington.
24 | With multiple campuses across the state, it is the state's second largest institution of higher education.
25 | WSU is known for its programs in veterinary medicine, agriculture, engineering, architecture, and pharmacy.
26 | """
27 |
28 | seattle_info = """
29 | Seattle, a city on Puget Sound in the Pacific Northwest, is surrounded by water, mountains and evergreen forests, and contains thousands of acres of parkland.
30 | It's home to a large tech industry, with Microsoft and Amazon headquartered in its metropolitan area.
31 | The futuristic Space Needle, a legacy of the 1962 World's Fair, is its most iconic landmark.
32 | """
33 |
34 | starbucks_info = """
35 | Starbucks Corporation is an American multinational chain of coffeehouses and roastery reserves headquartered in Seattle, Washington.
36 | As the world's largest coffeehouse chain, Starbucks is seen to be the main representation of the United States' second wave of coffee culture.
37 | """
38 | # Off-topic relative to the Pacific Northwest; useful for exercising the context filter
39 | newzealand_info = """
40 | New Zealand is an island country located in the southwestern Pacific Ocean. It comprises two main landmasses—the North Island and the South Island—and over 700 smaller islands.
41 | The country is known for its stunning landscapes, ranging from lush forests and mountains to beaches and lakes. New Zealand has a rich cultural heritage, with influences from
42 | both the indigenous Māori people and European settlers. The capital city is Wellington, while the largest city is Auckland. New Zealand is also famous for its adventure tourism,
43 | including activities like bungee jumping, skiing, and hiking.
44 | """
45 |
46 | embedding_function = OpenAIEmbeddingFunction(
47 | api_key=os.environ.get("OPENAI_API_KEY"),
48 | model_name="text-embedding-ada-002",
49 | )
50 |
51 |
52 | chroma_client = chromadb.Client()
53 | vector_store = chroma_client.get_or_create_collection(
54 | name="Washington", embedding_function=embedding_function
55 | )
56 |
57 | vector_store.add("uw_info", documents=uw_info)
58 | vector_store.add("wsu_info", documents=wsu_info)
59 | vector_store.add("seattle_info", documents=seattle_info)
60 | vector_store.add("starbucks_info", documents=starbucks_info)
61 | vector_store.add("newzealand_info", documents=newzealand_info)
--------------------------------------------------------------------------------
/recipes/weaviate/.streamlit/secrets_template.toml:
--------------------------------------------------------------------------------
1 | WEAVIATE_API_KEY = "your weaviate key goes here"
2 | WEAVIATE_URL = "your weaviate url goes here"
3 | COHERE_API_KEY = "your cohere api key goes here"
--------------------------------------------------------------------------------
/recipes/weaviate/README.md:
--------------------------------------------------------------------------------
1 | # How to run the demo app
2 | This is a recipe for a movie recommendation app. The app uses [Weaviate](https://weaviate.io/) to create a vector database of movie titles and [Streamlit](https://streamlit.io/) to create a recommendation chatbot.
3 |
4 | Other ways to explore this recipe:
5 | * [Deployed app](https://weaviate-movie-magic.streamlit.app/)
6 | * [Video](https://youtu.be/SQD-aWlhqvM?si=t54W53G1gWnTAiwx)
7 | * [Blog post](https://blog.streamlit.io/how-to-recommendation-app-vector-database-weaviate/)
8 |
9 | ## Prerequisites
10 | * Python >=3.8, !=3.9.7
11 | * [A Weaviate API key and URL](https://auth.wcs.api.weaviate.io/auth/realms/SeMI/login-actions/registration?client_id=wcs-frontend&tab_id=5bw6GQTdWU0)
12 | * [A Cohere API key](https://dashboard.cohere.com/welcome/register)
13 |
14 | ## Environment setup
15 | ### Local setup
16 |
17 | #### Create a virtual environment
18 | 1. Clone the Cookbook repo: `git clone https://github.com/streamlit/cookbook.git`
19 | 2. From the Cookbook root directory, change directory into the recipe: `cd recipes/weaviate`
20 | 3. Add your secrets to the `.streamlit/secrets_template.toml` file
21 | 4. Rename `secrets_template.toml` to `secrets.toml`: `mv .streamlit/secrets_template.toml .streamlit/secrets.toml`
22 |
23 |    (To learn more about secrets handling in Streamlit, refer to the documentation [here](https://docs.streamlit.io/develop/concepts/connections/secrets-management).)
24 | 5. Create a virtual environment: `python3 -m venv weaviatevenv`
25 | 6. Activate the virtual environment: `source weaviatevenv/bin/activate`
26 | 7. Install the dependencies: `pip install -r requirements.txt`
27 |
28 | #### Add data to your Weaviate Cloud instance
29 | 1. Create a Weaviate Cloud [Collection](https://weaviate.io/developers/weaviate/config-refs/schema#introduction) and add data to it: `python3 helpers/add_data.py`
30 | 2. (Optional) Verify the data: `python3 helpers/verify_data.py`
31 | 3. (Optional) Use the Weaviate Cloud UI to [query the Collection](https://weaviate.io/developers/weaviate/connections/connect-query#example-query):
32 | ```
33 | {
34 |   Get {
35 |     MovieDemo(
36 |       limit: 3
37 |       where: { path: ["release_year"], operator: Equal, valueInt: 1985 }
38 |     ) {
39 |       budget
40 |       movie_id
41 |       overview
42 |       release_year
43 |       revenue
44 |       tagline
45 |       title
46 |       vote_average
47 |       vote_count
48 |     }
49 |   }
50 | }
51 | ```
52 |
53 | #### Run the app
54 | 1. Run the app with: `streamlit run demo_app.py`
55 | 2. The app should spin up in a new browser tab
56 |
57 | (Please note that this version of the demo app does not feature the poster images, so it will look different from the [deployed app](https://weaviate-movie-magic.streamlit.app/).)
58 |
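59 | For reference, `demo_app.py` connects to Weaviate through Streamlit's `st.connection`. A minimal sketch of that pattern (the `alpha=0.7` value mirrors the recipe's Hybrid search mode):
60 |
61 | ```python
62 | import os
63 |
64 | import streamlit as st
65 | from st_weaviate_connection import WeaviateConnection
66 |
67 | conn = st.connection(
68 |     "weaviate",
69 |     type=WeaviateConnection,
70 |     url=os.environ["WEAVIATE_URL"],
71 |     api_key=os.environ["WEAVIATE_API_KEY"],
72 |     additional_headers={"X-Cohere-Api-Key": os.environ["COHERE_API_KEY"]},
73 | )
74 |
75 | # Hybrid search over the MovieDemo collection
76 | df = conn.query(
77 |     "MovieDemo",
78 |     query="sci-fi adventure",
79 |     return_properties=["title", "tagline"],
80 |     limit=5,
81 |     alpha=0.7,
82 | )
83 | ```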
--------------------------------------------------------------------------------
/recipes/weaviate/demo_app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 |
5 | import streamlit as st
6 | from st_weaviate_connection import WeaviateConnection, WeaviateFilter
7 | from weaviate.classes.query import Filter
8 |
9 | # Constants
10 | ENV_VARS = ["WEAVIATE_URL", "WEAVIATE_API_KEY", "COHERE_API_KEY"]
11 | NUM_RECOMMENDATIONS_PER_ROW = 5
12 | SEARCH_LIMIT = 10
13 |
14 | # Search mode descriptions: label -> (description, hybrid alpha); alpha 0 = pure keyword (BM25), 1 = pure vector
15 | SEARCH_MODES = {
16 | "Keyword": ("Keyword search (BM25) ranks documents based on the relative frequencies of search terms.", 0),
17 | "Semantic": ("Semantic (vector) search ranks results based on their similarity to your search query.", 1),
18 | "Hybrid": ("Hybrid search combines vector and BM25 searches to offer best-of-both-worlds search results.", 0.7),
19 | }
20 |
21 | # Functions
22 | def get_env_vars(env_vars):
23 | """Retrieve environment variables"""
24 | env_vars = {var: os.environ.get(var, "") for var in env_vars}
25 | for var, value in env_vars.items():
26 | if not value:
27 | st.error(f"{var} not set", icon="🚨")
28 | sys.exit(f"{var} not set")
29 | return env_vars
30 |
31 | def display_chat_messages():
32 | """Print message history"""
33 | for message in st.session_state.messages:
34 | with st.chat_message(message["role"]):
35 | st.markdown(message["content"])
36 | if "images" in message:
37 | for i in range(0, len(message["images"]), NUM_RECOMMENDATIONS_PER_ROW):
38 | cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW)
39 | for j, col in enumerate(cols):
40 | if i + j < len(message["images"]):
41 | col.image(message["images"][i + j], width=200)
42 | if "titles" in message:
43 | for i in range(0, len(message["titles"]), NUM_RECOMMENDATIONS_PER_ROW):
44 | cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW)
45 | for j, col in enumerate(cols):
46 | if i + j < len(message["titles"]):
47 | col.write(message["titles"][i + j])
48 |
49 |
50 | def base64_to_image(base64_str):
51 |     """Wrap a base64 string in a data URI that st.image can render"""
52 | return f"data:image/png;base64,{base64_str}"
53 |
54 | def clean_input(input_text):
55 |     """Strip quote characters from user input"""
56 | return input_text.replace('"', "").replace("'", "")
57 |
58 | def setup_sidebar():
59 |     """Set up the sidebar elements"""
60 | with st.sidebar:
61 | st.title("🎥🍿 Movie Magic")
62 | st.subheader("The RAG Recommender")
63 | st.markdown("Your Weaviate & AI powered movie recommender. Find the perfect film for any occasion. Just tell us what you're looking for!")
64 | st.header("Settings")
65 |
66 | mode = st.radio("Search Mode", options=list(SEARCH_MODES.keys()), index=2)
67 | year_range = st.slider("Year range", min_value=1950, max_value=2024, value=(1990, 2024))
68 | st.info(SEARCH_MODES[mode][0])
69 | st.success("Connected to Weaviate", icon="💚")
70 |
71 | return mode, year_range
72 |
73 | def setup_weaviate_connection(env_vars):
74 |     """Set up the Weaviate connection"""
75 | return st.connection(
76 | "weaviate",
77 | type=WeaviateConnection,
78 | url=env_vars["WEAVIATE_URL"],
79 | api_key=env_vars["WEAVIATE_API_KEY"],
80 | additional_headers={"X-Cohere-Api-Key": env_vars["COHERE_API_KEY"]},
81 | )
82 |
83 | def display_example_prompts():
84 | """Display example prompt buttons"""
85 | example_prompts = [
86 | ("sci-fi adventure", "movie night with friends"),
87 | ("romantic comedy", "date night"),
88 | ("animated family film", "family viewing"),
89 | ("classic thriller", "solo watching"),
90 | ("historical drama", "educational evening"),
91 | ("indie comedy-drama", "film club discussion"),
92 | ]
93 |
94 | example_prompts_help = [
95 | "Search for sci-fi adventure movies suitable for a group viewing",
96 | "Find romantic comedies perfect for a date night",
97 | "Look for animated movies great for family entertainment",
98 | "Discover classic thrillers for a solo movie night",
99 | "Explore historical dramas for an educational movie experience",
100 | "Find indie comedy-dramas ideal for film club discussions",
101 | ]
102 |
103 | st.markdown("---")
104 | st.write("Select an example prompt or enter your own, then **click `Search`** to get recommendations.")
105 |
106 | button_cols = st.columns(3)
107 | button_cols_2 = st.columns(3)
108 |
109 | for i, ((movie_type, occasion), help_text) in enumerate(zip(example_prompts, example_prompts_help)):
110 | col = button_cols[i] if i < 3 else button_cols_2[i-3]
111 | if col.button(f"{movie_type} for a {occasion}", help=help_text):
112 | st.session_state.example_movie_type = movie_type
113 | st.session_state.example_occasion = occasion
114 | return True
115 | return False
116 |
117 | def perform_search(conn, movie_type, rag_prompt, year_range, mode):
118 | """Perform search and display results"""
119 | df = conn.query(
120 | "MovieDemo",
121 | query=movie_type,
122 | # Uncomment the line below if you want to use this with poster images
123 | # return_properties=["title", "tagline", "poster"],
124 | # Comment out the line below if you want to use this with poster images
125 | return_properties=["title", "tagline"],
126 | filters=(
127 | WeaviateFilter.by_property("release_year").greater_or_equal(year_range[0]) &
128 | WeaviateFilter.by_property("release_year").less_or_equal(year_range[1])
129 | ),
130 | limit=SEARCH_LIMIT,
131 | alpha=SEARCH_MODES[mode][1],
132 | )
133 |
134 | images = []
135 | titles = []
136 |
137 | if df is None or df.empty:
138 | with st.chat_message("assistant"):
139 |             st.write(f"No movies found matching '{movie_type}' with {mode} search. Please try again.")
140 | st.session_state.messages.append({"role": "assistant", "content": "No movies found. Please try again."})
141 | return
142 | else:
143 | with st.chat_message("assistant"):
144 | st.write("Raw search results.")
145 | cols = st.columns(NUM_RECOMMENDATIONS_PER_ROW)
146 | for index, row in df.iterrows():
147 | col = cols[index % NUM_RECOMMENDATIONS_PER_ROW]
148 | if "poster" in row and row["poster"]:
149 | col.image(base64_to_image(row["poster"]), width=200)
150 | images.append(base64_to_image(row["poster"]))
151 | else:
152 | col.write(f"{row['title']}")
153 | titles.append(row["title"])
154 |
155 | st.write("Now generating recommendation from these: ...")
156 |
157 | st.session_state.messages.append(
158 | {"role": "assistant", "content": "Raw search results. Generating recommendation from these: ...", "images": images, "titles": titles})
159 |
160 | with conn.client() as client:
161 | collection = client.collections.get("MovieDemo")
162 | response = collection.generate.hybrid(
163 | query=movie_type,
164 | filters=(
165 | Filter.by_property("release_year").greater_or_equal(year_range[0]) &
166 | Filter.by_property("release_year").less_or_equal(year_range[1])
167 | ),
168 | limit=SEARCH_LIMIT,
169 | alpha=SEARCH_MODES[mode][1],
170 | grouped_task=rag_prompt,
171 | grouped_properties=["title", "tagline"],
172 | )
173 |
174 | rag_response = response.generated
175 |
176 | with st.chat_message("assistant"):
177 | message_placeholder = st.empty()
178 | full_response = ""
179 | for chunk in rag_response.split():
180 | full_response += chunk + " "
181 | time.sleep(0.02)
182 | message_placeholder.markdown(full_response + "▌")
183 | message_placeholder.markdown(full_response)
184 |
185 | st.session_state.messages.append(
186 | {"role": "assistant", "content": "Recommendation from these search results: " + full_response}
187 | )
188 |
189 | def main():
190 | st.title("🎥🍿 Movie Magic")
191 |
192 | env_vars = get_env_vars(ENV_VARS)
193 | conn = setup_weaviate_connection(env_vars)
194 | mode, year_range = setup_sidebar()
195 |
196 | if "messages" not in st.session_state:
197 | st.session_state.messages = []
198 | st.session_state.greetings = False
199 |
200 | display_chat_messages()
201 |
202 | if not st.session_state.greetings:
203 | with st.chat_message("assistant"):
204 | intro = "👋 Welcome to Movie Magic! I'm your AI movie recommender. Tell me what kind of film you're in the mood for and the occasion, and I'll suggest some great options."
205 | st.markdown(intro)
206 | st.session_state.messages.append({"role": "assistant", "content": intro})
207 | st.session_state.greetings = True
208 |
209 | if "example_movie_type" not in st.session_state:
210 | st.session_state.example_movie_type = ""
211 | if "example_occasion" not in st.session_state:
212 | st.session_state.example_occasion = ""
213 |
214 | example_selected = display_example_prompts()
215 |
216 | movie_type = clean_input(st.text_input(
217 | "What movies are you looking for?",
218 | value=st.session_state.example_movie_type,
219 | placeholder="E.g., sci-fi adventure, romantic comedy"
220 | ))
221 |
222 | viewing_occasion = clean_input(st.text_input(
223 | "What occasion is the movie for?",
224 | value=st.session_state.example_occasion,
225 | placeholder="E.g., movie night with friends, date night"
226 | ))
227 |
228 | if st.button("Search") and movie_type and viewing_occasion:
229 | rag_prompt = f"Suggest one to two movies out of the following list, for a {viewing_occasion}. Give a concise yet fun and positive recommendation."
230 | prompt = f"Searching for: {movie_type} for {viewing_occasion}"
231 | with st.chat_message("user"):
232 | st.markdown(prompt)
233 | st.session_state.messages.append({"role": "user", "content": prompt})
234 |
235 | perform_search(conn, movie_type, rag_prompt, year_range, mode)
236 | st.rerun()
237 |
238 | if example_selected:
239 | st.rerun()
240 |
241 | if __name__ == "__main__":
242 | main()
243 |
--------------------------------------------------------------------------------
/recipes/weaviate/helpers/add_data.py:
--------------------------------------------------------------------------------
1 | import base64  # used by the optional poster-image code below
2 | import os
3 | from datetime import datetime, timezone
4 | from pathlib import Path
5 |
6 | import pandas as pd
7 | import toml
8 | import weaviate
9 | from tqdm import tqdm
10 | from weaviate.classes.config import Configure, DataType, Property
11 | from weaviate.classes.init import Auth
12 | from weaviate.util import generate_uuid5
13 |
14 | # Construct the path to the toml file
15 | current_dir = os.path.dirname(__file__)
16 | parent_dir = os.path.dirname(current_dir)
17 | toml_file_path = os.path.join(parent_dir, ".streamlit/secrets.toml")
18 |
19 | config = toml.load(toml_file_path)
20 |
21 | # Access values from the toml file
22 | weaviate_api_key = config["WEAVIATE_API_KEY"]
23 | weaviate_url = config["WEAVIATE_URL"]
24 | cohere_api_key = config["COHERE_API_KEY"]
25 |
26 | # Client for Weaviate Cloud
27 | client = weaviate.connect_to_weaviate_cloud(
28 | cluster_url=weaviate_url,
29 | auth_credentials=Auth.api_key(weaviate_api_key),
30 | headers={
31 | "X-Cohere-Api-Key": cohere_api_key
32 | }
33 | )
34 |
35 | # If you are using a local instance of Weaviate, you can use the following code
36 | # client = weaviate.connect_to_local(headers={"X-Cohere-Api-Key": cohere_api_key})
37 |
38 | # Delete any existing MovieDemo Collection to prevent errors
39 | client.collections.delete(["MovieDemo"])
40 |
41 | # Create the MovieDemo Collection
42 | movies = client.collections.create(
43 | name="MovieDemo",
44 | properties=[
45 | Property(
46 | name="title",
47 | data_type=DataType.TEXT,
48 | ),
49 | Property(
50 | name="overview",
51 | data_type=DataType.TEXT,
52 | ),
53 | Property(
54 | name="tagline",
55 | data_type=DataType.TEXT,
56 | ),
57 | Property(
58 | name="movie_id",
59 | data_type=DataType.INT,
60 | skip_vectorization=True,
61 | ),
62 | Property(
63 | name="release_year",
64 | data_type=DataType.INT,
65 | ),
66 | Property(
67 | name="genres",
68 | data_type=DataType.TEXT_ARRAY,
69 | ),
70 | Property(
71 | name="vote_average",
72 | data_type=DataType.NUMBER,
73 | ),
74 | Property(
75 | name="vote_count",
76 | data_type=DataType.INT,
77 | ),
78 | Property(
79 | name="revenue",
80 | data_type=DataType.INT,
81 | ),
82 | Property(
83 | name="budget",
84 | data_type=DataType.INT,
85 | ),
86 | # Uncomment the lines below if you want to use this with poster images
87 | # Property(
88 | # name="poster",
89 | # data_type=DataType.BLOB
90 | # ),
91 | ],
92 | vectorizer_config=Configure.Vectorizer.text2vec_cohere(),
93 | vector_index_config=Configure.VectorIndex.hnsw(
94 | quantizer=Configure.VectorIndex.Quantizer.bq()
95 | ),
96 | generative_config=Configure.Generative.cohere(model="command-r-plus"),
97 | )
98 |
99 | # Add objects to the MovieDemo collection from the JSON file and directory of poster images
100 | json_file_path = os.path.join(os.getcwd(), "helpers/data/1950_2024_movies_info.json")
101 | movies_df = pd.read_json(json_file_path)
102 |
103 | img_dir = Path(os.path.join(os.getcwd(), "helpers/data/posters"))
104 |
105 |
106 | # Insert movie objects in batches of 100
107 | with movies.batch.fixed_size(100) as batch:
108 | for i, movie_row in tqdm(movies_df.iterrows()):
109 | try:
110 | date_object = datetime.strptime(movie_row["release_date"], "%Y-%m-%d").replace(
111 | tzinfo=timezone.utc
112 | )
113 | # Uncomment the lines below if you want to use this with poster images
114 | # img_path = (img_dir / f"{movie_row['id']}_poster.jpg")
115 | # with open(img_path, "rb") as file:
116 | # poster_b64 = base64.b64encode(file.read()).decode("utf-8")
117 |
118 | props = {
119 | k: movie_row[k]
120 | for k in [
121 | "title",
122 | "overview",
123 | "tagline",
124 | "vote_count",
125 | "vote_average",
126 | "revenue",
127 | "budget",
128 | ]
129 | }
130 | props["movie_id"] = movie_row["id"]
131 | props["release_year"] = date_object.year
132 | props["genres"] = [genre["name"] for genre in movie_row["genres"]]
133 | # Uncomment the line below if you want to use this with poster images
134 | # props["poster"] = poster_b64
135 |
136 | batch.add_object(properties=props, uuid=generate_uuid5(movie_row["id"]))
137 | except Exception as e:
138 | print(f"Error: {e}")
139 |         movies_df = movies_df.drop(i)  # drop rows that fail to parse (e.g., a bad release_date) and persist the cleaned data
140 | movies_df.to_json(json_file_path, orient="records")
141 | continue
142 |
143 | # Close the connection to Weaviate Cloud
144 | client.close()
--------------------------------------------------------------------------------
/recipes/weaviate/helpers/demo_movie_query.graphql:
--------------------------------------------------------------------------------
1 | {
2 |   Get {
3 |     MovieDemo(
4 |       limit: 3
5 |       where: { path: ["release_year"], operator: Equal, valueInt: 1985 }
6 |     ) {
7 |       budget
8 |       movie_id
9 |       overview
10 |       release_year
11 |       revenue
12 |       tagline
13 |       title
14 |       vote_average
15 |       vote_count
16 |     }
17 |   }
18 | }
--------------------------------------------------------------------------------
/recipes/weaviate/helpers/verify_data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import toml
4 | import weaviate
5 | from weaviate.classes.init import Auth
6 |
7 |
8 | # Construct the path to the toml file
9 | current_dir = os.path.dirname(__file__)
10 | parent_dir = os.path.dirname(current_dir)
11 | toml_file_path = os.path.join(parent_dir, ".streamlit/secrets.toml")
12 |
13 | config = toml.load(toml_file_path)
14 |
15 | # Access values from the TOML file
16 | weaviate_api_key = config["WEAVIATE_API_KEY"]
17 | weaviate_url = config["WEAVIATE_URL"]
18 | cohere_api_key = config["COHERE_API_KEY"]
19 |
20 | client = weaviate.connect_to_weaviate_cloud(
21 | cluster_url=weaviate_url,
22 | auth_credentials=Auth.api_key(weaviate_api_key),
23 | headers={
24 | "X-Cohere-Api-Key": cohere_api_key
25 | }
26 | )
27 |
28 | # If you are using a local instance of Weaviate, you can use the following code
29 | # client = weaviate.connect_to_local(
30 | # headers={
31 | # "X-Cohere-Api-Key": cohere_api_key
32 | # }
33 | # )
34 |
35 | movies = client.collections.get("MovieDemo")
36 |
37 | print(movies.aggregate.over_all(total_count=True))
38 |
39 | r = movies.query.fetch_objects(limit=1, return_properties=["title"])
40 | print(r.objects[0].properties["title"])
41 |
42 | client.close()
--------------------------------------------------------------------------------
/recipes/weaviate/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit
2 | st-weaviate-connection
3 | tqdm
4 | pandas
5 | toml
--------------------------------------------------------------------------------