├── .gitignore
├── README.md
├── babyagi.py
├── img
│   └── robot.png
├── pyproject.toml
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 |
5 | .env
6 | .env.*
7 | env/
8 | .venv
9 | venv/
10 |
11 | .vscode/
12 | .idea/
13 |
14 | *.lock
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | babyagi-streamlit
3 |
4 |
5 |
6 | # Demo
7 | https://user-images.githubusercontent.com/5876695/230781482-cb03a6b2-b2c9-49c8-ab6d-0066c0d82fd6.mov
8 |
9 |
10 | # Installation
11 |
12 | Install the required packages:
13 | ````
14 | poetry install
15 | ````
16 |
17 | # Usage
18 |
19 | Run the Streamlit app:
20 | ````
21 | poetry run streamlit run babyagi.py
22 | ````
23 |
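24 | Streamlit then prints the app's URL; open it in your browser.
25 |
26 | Local URL: http://localhost:8501
27 |
28 | Enter your OpenAI API key in the sidebar, fill in the ultimate goal, the first task and the maximum number of iterations, then click **Run**. To stop the Streamlit server, press Ctrl+C in the terminal.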
29 |
30 | # Acknowledgments
31 |
32 | I would like to express my gratitude to the developers whose code I referenced in creating this repo.
33 |
34 | Special thanks to:
35 |
36 | @yoheinakajima (https://github.com/yoheinakajima/babyagi)
37 |
38 | @hinthornw (https://github.com/hwchase17/langchain/pull/2559)
39 |
40 | ---
41 | Roboto Logo Icon by Icons Mind (https://iconscout.com/contributors/icons-mind) on IconScout (https://iconscout.com)
42 |
--------------------------------------------------------------------------------
/babyagi.py:
--------------------------------------------------------------------------------
from collections import deque
from pathlib import Path
from typing import Dict, List, Optional

import streamlit as st
from langchain import LLMChain, OpenAI, PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import BaseLLM
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from pydantic import BaseModel, Field
10 |
class TaskCreationChain(LLMChain):
    """LLM chain that proposes new tasks from the result of the last completed task."""

    @classmethod
    def from_llm(cls, llm: BaseLLM, objective: str, verbose: bool = True) -> LLMChain:
        """Build the chain with the task-creation prompt and ``objective`` bound in.

        Args:
            llm: Language model used to generate new tasks.
            objective: Overall goal; bound into the prompt as a partial variable.
            verbose: Passed through to the chain constructor.

        Returns:
            A ``TaskCreationChain`` that expects ``result``, ``task_description``
            and ``incomplete_tasks`` when run.
        """
        task_creation_template = (
            "You are an task creation AI that uses the result of an execution agent"
            " to create new tasks with the following objective: {objective},"
            " The last completed task has the result: {result}."
            " This result was based on this task description: {task_description}."
            " These are incomplete tasks: {incomplete_tasks}."
            " Based on the result, create new tasks to be completed"
            " by the AI system that do not overlap with incomplete tasks."
            " Return the tasks as an array."
        )
        prompt = PromptTemplate(
            template=task_creation_template,
            partial_variables={"objective": objective},
            input_variables=["result", "task_description", "incomplete_tasks"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)

    def get_next_task(self, result: str, task_description: str, task_list: List[str]) -> List[Dict]:
        """Ask the LLM for follow-up tasks.

        Args:
            result: Output text of the task that just finished.
            task_description: Name of the task that produced ``result``.
            task_list: Names of tasks still pending; sent to the LLM as a
                comma-separated string so it can avoid duplicates.

        Returns:
            One ``{"task_name": ...}`` dict per non-blank line of the LLM reply.
        """
        incomplete_tasks = ", ".join(task_list)
        response = self.run(result=result, task_description=task_description, incomplete_tasks=incomplete_tasks)
        # Trim each line so stray indentation / trailing spaces from the LLM
        # do not end up inside task names shown in the UI and prompts.
        return [
            {"task_name": line.strip()}
            for line in response.split("\n")
            if line.strip()
        ]
38 |
39 |
class TaskPrioritizationChain(LLMChain):
    """Chain to prioritize tasks."""

    @classmethod
    def from_llm(cls, llm: BaseLLM, objective: str, verbose: bool = True) -> LLMChain:
        """Build the chain with the prioritization prompt and ``objective`` bound in.

        Args:
            llm: Language model used to reorder the tasks.
            objective: Overall goal; bound into the prompt as a partial variable.
            verbose: Passed through to the chain constructor.

        Returns:
            A ``TaskPrioritizationChain`` that expects ``task_names`` and
            ``next_task_id`` when run.
        """
        task_prioritization_template = (
            "You are an task prioritization AI tasked with cleaning the formatting of and reprioritizing"
            " the following tasks: {task_names}."
            " Consider the ultimate objective of your team: {objective}."
            " Do not remove any tasks. Return the result as a numbered list, like:"
            " #. First task"
            " #. Second task"
            " Start the task list with number {next_task_id}."
        )
        prompt = PromptTemplate(
            template=task_prioritization_template,
            partial_variables={"objective": objective},
            input_variables=["task_names", "next_task_id"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose)

    def prioritize_tasks(self, this_task_id: int, task_list: List[Dict]) -> List[Dict]:
        """Ask the LLM to reorder ``task_list`` and parse its numbered reply.

        Args:
            this_task_id: ID of the task that just finished; the LLM is asked
                to start numbering at ``this_task_id + 1``.
            task_list: Pending tasks, each a dict with a ``"task_name"`` key.

        Returns:
            Dicts with ``"task_id"`` (a string of decimal digits) and
            ``"task_name"``. Reply lines that are not of the form
            ``<number>. <task>`` are skipped.
        """
        task_names = [t["task_name"] for t in task_list]
        next_task_id = int(this_task_id) + 1
        response = self.run(task_names=task_names, next_task_id=next_task_id)
        prioritized_task_list = []
        for task_string in response.split("\n"):
            task_id, sep, task_name = task_string.strip().partition(".")
            task_id = task_id.strip()
            task_name = task_name.strip()
            # BabyAGI.run calls int(task["task_id"]), so a non-numeric prefix
            # (e.g. a literal "#." copied from the prompt example, or a sentence
            # that merely contains a period) must not become a task id.
            if not sep or not task_id.isdecimal() or not task_name:
                continue
            prioritized_task_list.append({"task_id": task_id, "task_name": task_name})
        return prioritized_task_list
78 |
79 |
class ExecutionChain(LLMChain):
    """LLM chain that carries out one task, using stored past work as context."""

    vectorstore: VectorStore = Field(init=False)

    @classmethod
    def from_llm(cls, llm: BaseLLM, vectorstore: VectorStore, verbose: bool = True) -> LLMChain:
        """Construct the chain around the execution prompt.

        Args:
            llm: Language model that performs the task.
            vectorstore: Store queried for previously completed tasks.
            verbose: Passed through to the chain constructor.

        Returns:
            An ``ExecutionChain`` expecting ``objective``, ``context`` and ``task``.
        """
        prompt = PromptTemplate(
            template=(
                "You are an AI who performs one task based on the following objective: {objective}."
                " Take into account these previously completed tasks: {context}."
                " Your task: {task}."
                " Response:"
            ),
            input_variables=["objective", "context", "task"],
        )
        return cls(prompt=prompt, llm=llm, verbose=verbose, vectorstore=vectorstore)

    def _get_top_tasks(self, query: str, k: int) -> List[str]:
        """Return the ``"task"`` metadata of up to ``k`` stored entries matching ``query``."""
        scored_docs = self.vectorstore.similarity_search_with_score(query, k=k)
        if not scored_docs:
            return []
        # NOTE(review): for distance-based stores such as FAISS a lower score
        # means "closer", so reverse=True lists the least similar hits first —
        # confirm this ordering is intended.
        ranked = sorted(scored_docs, key=lambda pair: pair[1], reverse=True)
        return [str(doc.metadata['task']) for doc, _score in ranked]

    def execute_task(self, objective: str, task: str, k: int = 5) -> str:
        """Run the LLM on ``task``, giving it up to ``k`` related past tasks as context."""
        return self.run(
            objective=objective,
            context=self._get_top_tasks(query=objective, k=k),
            task=task,
        )
112 |
113 |
class Message:
    """A labelled, expanded Streamlit expander with the BabyAGI icon beside it.

    Usable as a context manager; leaving the ``with`` block does nothing extra
    and does not suppress exceptions.
    """

    exp: st.expander
    # Resolve the icon next to this script instead of the current working
    # directory, so the image still loads when Streamlit is launched from
    # another folder (img/robot.png sits beside babyagi.py in the repo).
    ai_icon = str(Path(__file__).resolve().parent / "img" / "robot.png")

    def __init__(self, label: str):
        """Create the two-column row and open an expander titled ``label``."""
        # 10:1 width split: message on the left, icon on the right.
        message_area, icon_area = st.columns([10, 1])
        icon_area.image(self.ai_icon, caption="BabyAGI")

        # Expander
        self.exp = message_area.expander(label=label, expanded=True)

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, trace):
        # Returning None lets any exception raised in the block propagate.
        return None

    def write(self, content):
        """Render ``content`` as Markdown inside the expander."""
        self.exp.markdown(content)
133 |
134 |
class BabyAGI(BaseModel):
    """Controller model for the BabyAGI agent.

    Holds a queue of pending tasks plus three chains: one that executes a
    task, one that creates follow-up tasks from its result, and one that
    reprioritizes the queue. Progress is rendered in Streamlit via ``Message``.
    """

    objective: str = Field(alias="objective")
    task_list: deque = Field(default_factory=deque)
    task_creation_chain: TaskCreationChain = Field(...)
    task_prioritization_chain: TaskPrioritizationChain = Field(...)
    execution_chain: ExecutionChain = Field(...)
    task_id_counter: int = Field(1)

    def add_task(self, task: Dict):
        """Append ``task`` (a dict with ``task_id`` and ``task_name``) to the queue."""
        self.task_list.append(task)

    def print_task_list(self):
        """Render all pending tasks."""
        with Message(label="Task List") as m:
            m.write("### Task List")
            for t in self.task_list:
                m.write("- " + str(t["task_id"]) + ": " + t["task_name"])
            m.write("")

    def print_next_task(self, task: Dict):
        """Render the task about to be executed."""
        with Message(label="Next Task") as m:
            m.write("### Next Task")
            m.write("- " + str(task["task_id"]) + ": " + task["task_name"])
            m.write("")

    def print_task_result(self, result: str):
        """Render the output of the task just executed."""
        with Message(label="Task Result") as m:
            m.write("### Task Result")
            m.write(result)
            m.write("")

    def print_task_ending(self):
        """Render the end-of-run marker."""
        with Message(label="Task Ending") as m:
            m.write("### Task Ending")
            m.write("")

    def run(self, max_iterations: Optional[int] = None):
        """Run the agent loop.

        Each iteration pops the first pending task, executes it, stores the
        result in the execution chain's vector store, creates follow-up tasks
        and reprioritizes the queue. A "Task Ending" message is shown when the
        loop stops.

        Args:
            max_iterations: Maximum number of tasks to execute. ``None`` means
                keep going until the task list is empty.
        """
        num_iters = 0
        # `<` (not `==`) so a non-positive limit cannot make the loop endless.
        while max_iterations is None or num_iters < max_iterations:
            if not self.task_list:
                # Queue exhausted (e.g. the prioritization reply could not be
                # parsed): stop instead of spinning without doing any work.
                break
            self.print_task_list()

            # Step 1: Pull the first task
            task = self.task_list.popleft()
            self.print_next_task(task)

            # Step 2: Execute the task
            result = self.execution_chain.execute_task(
                self.objective, task["task_name"]
            )
            this_task_id = int(task["task_id"])
            self.print_task_result(result)

            # Step 3: Store the result in the execution chain's vector store
            result_id = f"result_{task['task_id']}"
            self.execution_chain.vectorstore.add_texts(
                texts=[result],
                metadatas=[{"task": task["task_name"]}],
                ids=[result_id],
            )

            # Step 4: Create new tasks and reprioritize task list
            new_tasks = self.task_creation_chain.get_next_task(
                result, task["task_name"], [t["task_name"] for t in self.task_list]
            )
            for new_task in new_tasks:
                self.task_id_counter += 1
                new_task.update({"task_id": self.task_id_counter})
                self.add_task(new_task)
            self.task_list = deque(
                self.task_prioritization_chain.prioritize_tasks(
                    this_task_id, list(self.task_list)
                )
            )
            num_iters += 1
        self.print_task_ending()

    @classmethod
    def from_llm_and_objectives(
        cls,
        llm: BaseLLM,
        vectorstore: VectorStore,
        objective: str,
        first_task: str,
        verbose: bool = False,
    ) -> "BabyAGI":
        """Initialize the BabyAGI Controller.

        Builds the three chains from ``llm`` (execution uses ``vectorstore``)
        and seeds the queue with ``first_task`` as task 1.
        """
        task_creation_chain = TaskCreationChain.from_llm(
            llm, objective, verbose=verbose
        )
        task_prioritization_chain = TaskPrioritizationChain.from_llm(
            llm, objective, verbose=verbose
        )
        execution_chain = ExecutionChain.from_llm(llm, vectorstore, verbose=verbose)
        controller = cls(
            objective=objective,
            task_creation_chain=task_creation_chain,
            task_prioritization_chain=task_prioritization_chain,
            execution_chain=execution_chain,
        )
        controller.add_task({"task_id": 1, "task_name": first_task})
        return controller
242 |
243 |
@st.cache_resource
def _load_embedding_model() -> HuggingFaceEmbeddings:
    """Load the embedding model once and reuse it across Streamlit reruns."""
    return HuggingFaceEmbeddings()


def main():
    """Render the BabyAGI Streamlit UI and run the agent when "Run" is clicked."""
    st.set_page_config(
        initial_sidebar_state="expanded",
        page_title="BabyAGI Streamlit",
        layout="centered",
    )

    with st.sidebar:
        openai_api_key = st.text_input('Your OpenAI API KEY', type="password")

    st.title("BabyAGI Streamlit")
    objective = st.text_input("Input Ultimate goal", "Solve world hunger")
    first_task = st.text_input("Input Where to start", "Develop a task list")
    max_iterations = st.number_input("Max iterations", value=3, min_value=1, step=1)
    button = st.button("Run")

    if button:
        try:
            # Streamlit re-executes this script on every widget interaction, so
            # the (expensive) model is cached and the store is only built when
            # a run is actually requested. The store itself is fresh per run so
            # results from an earlier run do not leak into this one's context.
            # The placeholder text is tagged with the first task, presumably
            # because the store needs at least one entry to be created.
            embedding_model = _load_embedding_model()
            vectorstore = FAISS.from_texts(["_"], embedding_model, metadatas=[{"task": first_task}])

            baby_agi = BabyAGI.from_llm_and_objectives(
                llm=OpenAI(openai_api_key=openai_api_key),
                vectorstore=vectorstore,
                objective=objective,
                first_task=first_task,
                verbose=False
            )
            baby_agi.run(max_iterations=max_iterations)
        except Exception as e:
            st.error(e)
275 |
276 |
# Script entry point; the README runs this file with `streamlit run babyagi.py`.
if __name__ == "__main__":
    main()
279 |
--------------------------------------------------------------------------------
/img/robot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dory111111/babyagi-streamlit/0f05e59de24b0eb0891649adb40df4a1e61ebc41/img/robot.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "babyagi-streamlit"
3 | version = "1.0.0"
4 | description = ""
5 | authors = ["Dory "]
6 |
7 | [tool.poetry.dependencies]
8 | python = ">=3.10.10,<3.12"
9 | openai = "^0.27.0"
10 | langchain = ">=0.0.131"
11 | python-dotenv = "^1.0.0"
12 | faiss-cpu = "^1.7.3"
13 | sentence-transformers = "^2.2.2"
14 | streamlit = "^1.21.0"
15 |
16 | [tool.poetry.dev-dependencies]
17 |
18 | [tool.poetry.group.dev.dependencies]
19 | flake8 = "^6.0.0"
20 | black = "^23.1.0"
21 | isort = "^5.12.0"
22 |
23 | [build-system]
24 | requires = ["poetry-core>=1.0.0"]
25 | build-backend = "poetry.core.masonry.api"
26 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp
2 | aiosignal
3 | altair
4 | async-timeout
5 | attrs
6 | blinker
7 | cachetools
8 | certifi
9 | charset-normalizer
10 | click
11 | cmake
12 | colorama
13 | dataclasses-json
14 | decorator
15 | entrypoints
16 | faiss-cpu
17 | filelock
18 | frozenlist
19 | gitdb
20 | gitpython
21 | greenlet
22 | huggingface-hub
23 | idna
24 | importlib-metadata
25 | jinja2
26 | joblib
27 | jsonschema
28 | langchain
29 | lit
30 | markdown-it-py
31 | markupsafe
32 | marshmallow-enum
33 | marshmallow
34 | mdurl
35 | mpmath
36 | multidict
37 | mypy-extensions
38 | networkx
39 | nltk
40 | numpy
41 | nvidia-cublas-cu11
42 | nvidia-cuda-cupti-cu11
43 | nvidia-cuda-nvrtc-cu11
44 | nvidia-cuda-runtime-cu11
45 | nvidia-cudnn-cu11
46 | nvidia-cufft-cu11
47 | nvidia-curand-cu11
48 | nvidia-cusolver-cu11
49 | nvidia-cusparse-cu11
50 | nvidia-nccl-cu11
51 | nvidia-nvtx-cu11
52 | openai
53 | openapi-schema-pydantic
54 | packaging
55 | pandas
56 | pillow
57 | protobuf
58 | pyarrow
59 | pydantic
60 | pydeck
61 | pygments
62 | pympler
63 | pyrsistent
64 | python-dateutil
65 | python-dotenv
66 | pytz-deprecation-shim
67 | pytz
68 | pyyaml
69 | regex
70 | requests
71 | rich
72 | scikit-learn
73 | scipy
74 | sentence-transformers
75 | sentencepiece
76 | setuptools
77 | six
78 | smmap
79 | sqlalchemy
80 | streamlit
81 | sympy
82 | tenacity
83 | threadpoolctl
84 | tokenizers
85 | toml
86 | toolz
87 | torch
88 | torchvision
89 | tornado
90 | tqdm
91 | transformers
92 | triton
93 | typing-extensions
94 | typing-inspect
95 | tzdata
96 | tzlocal
97 | urllib3
98 | validators
99 | watchdog
100 | wheel
101 | yarl
102 | zipp
--------------------------------------------------------------------------------