├── .gitattributes
├── .github
│   ├── README.md
│   └── workflows
│       └── sync-to-hf.yaml
├── .gitignore
├── LICENSE
├── README-main.md
├── README.md
├── autoagents
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── agents
│   │   │   ├── __init__.py
│   │   │   ├── search.py
│   │   │   ├── search_v3.py
│   │   │   └── wiki_agent.py
│   │   ├── models
│   │   │   └── custom.py
│   │   ├── spaces
│   │   │   └── app.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   └── tools.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── constants.py
│   │       └── logger.py
│   ├── data
│   │   ├── action_name_transformation.py
│   │   ├── create_sft_dataset.py
│   │   ├── dataset.py
│   │   ├── generate_action_data.py
│   │   └── generate_action_tasks
│   │       ├── REACT.ipynb
│   │       ├── README.md
│   │       ├── generate_data.py
│   │       ├── generate_data_chat_api.py
│   │       ├── goals_test.json
│   │       ├── goals_train.json
│   │       ├── goals_valid.json
│   │       ├── previous_goals.json
│   │       ├── prompt.txt
│   │       ├── requirements.txt
│   │       ├── run_genenerate_data.sh
│   │       ├── seed_tasks_test.jsonl
│   │       ├── seed_tasks_train.jsonl
│   │       ├── seed_tasks_valid.jsonl
│   │       └── utils.py
│   ├── eval
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── bamboogle.py
│   │   ├── hotpotqa
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── constants.py
│   │   │   ├── eval_async.py
│   │   │   ├── hotpotqa_eval.py
│   │   │   └── run_eval.py
│   │   ├── metrics.py
│   │   ├── reward
│   │   │   ├── eval.py
│   │   │   └── get_scores.py
│   │   └── test.py
│   ├── serve
│   │   ├── README.md
│   │   ├── action_api_server.py
│   │   ├── action_model_worker.py
│   │   ├── controller.sh
│   │   ├── model_worker.sh
│   │   ├── openai_api.sh
│   │   └── serve_rescale.py
│   └── train
│       ├── README.md
│       ├── scripts
│       │   ├── action_finetuning.sh
│       │   ├── action_finetuning_v3.sh
│       │   ├── conv_finetuning.sh
│       │   ├── longchat_action_finetuning.sh
│       │   └── longchat_conv_finetuning.sh
│       ├── test_v3_preprocess.py
│       ├── train.py
│       └── train_v3.py
├── requirements.txt
└── setup.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tflite filter=lfs diff=lfs merge=lfs -text
29 | *.tgz filter=lfs diff=lfs merge=lfs -text
30 | *.wasm filter=lfs diff=lfs merge=lfs -text
31 | *.xz filter=lfs diff=lfs merge=lfs -text
32 | *.zip filter=lfs diff=lfs merge=lfs -text
33 | *.zst filter=lfs diff=lfs merge=lfs -text
34 | *tfevents* filter=lfs diff=lfs merge=lfs -text
35 |
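Each pattern above routes a class of large binary artifacts (archives, model checkpoints, serialized arrays) through Git LFS, which the Hugging Face Space this repo syncs to expects for big files. As a sketch, assuming `git-lfs` is installed, an additional pattern (the `*.gguf` glob here is purely illustrative) would normally be registered with:

```bash
# Appends a matching `filter=lfs diff=lfs merge=lfs -text` rule to .gitattributes
git lfs track "*.gguf"
```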
-------------------------------------------------------------------------------- /.github/README.md: -------------------------------------------------------------------------------- 1 | ../README-main.md -------------------------------------------------------------------------------- /.github/workflows/sync-to-hf.yaml: -------------------------------------------------------------------------------- 1 | name: Sync to Hugging Face hub 2 | on: 3 | push: 4 | branches: [main] 5 | 6 | # to run this workflow manually from the Actions tab 7 | workflow_dispatch: 8 | 9 | jobs: 10 | sync-to-hub: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | with: 15 | fetch-depth: 0 16 | lfs: true 17 | ref: hf-active 18 | - name: Push to hub 19 | env: 20 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 21 | run: git push https://omkarenator:$HF_TOKEN@huggingface.co/spaces/AutoLLM/AutoAgents hf-active:main 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | .DS_Store 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 AutoLLM 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README-main.md: -------------------------------------------------------------------------------- 1 | # AutoAgents 2 | 3 |

4 |
5 | Unlock complex question answering in LLMs with enhanced chain-of-thought reasoning and information-seeking capabilities.
6 |
7 | ## 👉 Overview
8 |
9 | The purpose of this project is to extend LLMs' ability to answer more complex questions through chain-of-thought reasoning and information-seeking actions.
10 |
11 | We are excited to release the initial version of AutoAgents, a proof of concept of what can be achieved with only well-written prompts. This is the first step toward our big milestone: releasing and open-sourcing the AutoAgents 7B model!
12 |
13 | Come try out our [Huggingface Space](https://huggingface.co/spaces/AutoLLM/AutoAgents)!
14 |
15 |
16 |
17 | ## 🤖 The AutoAgents Project
18 |
19 | This project demonstrates an LLM's capability to execute a complex user goal: understand a user's goal, generate a plan, use proper tools, and deliver a final result.
20 |
21 | For simplicity, our first attempt starts with a Web Search Agent.
22 |
23 |
24 |
25 | ## 💫 How it works:
26 |
27 |

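*(Figure omitted from this dump: a flowchart of the agent's plan-act-observe loop.)*

The loop itself is visible in `autoagents/agents/agents/search.py`: on every turn the model is prompted with the goal, the available tools, and the action history, and must reply with a single JSON object that picks the next action (`Tool_Search`, `Tool_Notepad`, or `Tool_Finish`). The chosen tool's observation is appended to the history and the cycle repeats until `Tool_Finish` returns the final answer with citations. A representative intermediate step (the field values are illustrative) looks like:

```json
{
  "thought": "I need up-to-date showtimes, so I should search first.",
  "reasoning": "Theater schedules change daily and are not in my training data.",
  "plan": [
    "search for kid-friendly movies playing near Sunnyvale",
    "note down showtimes and ticket links",
    "finish with a recommendation"
  ],
  "action": "Tool_Search",
  "action_input": "kid friendly movies playing near Sunnyvale showtimes"
}
```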
28 |
29 |
30 |
31 | ## 📔 Examples
32 |
33 | Ask your AutoAgent to do what a real person would do using the internet:
34 |
35 | For example:
36 |
37 | *1. Recommend a kid-friendly movie that is playing at a theater near Sunnyvale. Give me the showtimes and a link to purchase the tickets.*
38 |
39 | *2. What is the average age of the past three presidents when they took office?*
40 |
41 | *3. What is the mortgage rate right now, and how does that compare to the past two years?*
42 |
43 |
44 |
45 | ## 💁 Roadmap
46 |
47 | * ~~HuggingFace Space demo using OpenAI models~~ [LINK](https://huggingface.co/spaces/AutoLLM/AutoAgents)
48 | * AutoAgents [7B] Model
49 |   * Initial Release:
50 |     * Finetune and release a 7B parameter fine-tuned search model
51 | * AutoAgents Dataset
52 |   * A high-quality dataset for a diverse set of search scenarios (why quality and diversity? [1](https://arxiv.org/abs/2305.11206))
53 | * Reduce Model Inference Overhead
54 |   * Affordance Modeling [2](https://en.wikipedia.org/wiki/Affordance)
55 | * Extend Support to Additional Tools
56 |   * Customizable Document Search set (e.g. personal documents)
57 | * Support Multi-turn Dialogue
58 | * Advanced Flow Control in Plan Execution
59 |
60 | We are actively developing a few interesting things; check back here or follow us on [Twitter](https://twitter.com/AutoLLM) for new developments.
61 |
62 | If you are interested in any other problems, feel free to open an issue.
63 |
64 |
65 |
66 | ## 🧭 How to use this repo?
67 |
68 | This repo contains all the code needed to run the search agent from your local browser. All you need to begin is an OpenAI API key.
69 |
70 | To run the search agent locally:
71 |
72 | 1. Clone the repo and change into the directory
73 |
74 | ```bash
75 | git clone https://github.com/AutoLLM/AutoAgents.git
76 | cd AutoAgents
77 | ```
78 |
79 | 2. Install the dependencies
80 |
81 | ```bash
82 | pip install -r requirements.txt
83 | ```
84 |
85 | 3. Install the `autoagents` package
86 |
87 | ```bash
88 | pip install -e .
89 | ```
90 |
91 | 4. Make sure your OpenAI API key is set as an environment variable. Alternatively, you can enter it in the text box in the sidebar.
92 |
93 | ```bash
94 | export OPENAI_API_KEY=sk-xxxxxx
95 | ```
96 |
97 | 5. Run the Streamlit app
98 |
99 | ```bash
100 | streamlit run autoagents/agents/spaces/app.py
101 | ```
102 |
103 | This should open a browser window where you can type your search query.
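Beyond the Streamlit UI, the same agent can be driven headlessly. A minimal sketch, assuming `OPENAI_API_KEY` is set and reusing the `ActionRunner` wiring from `autoagents/agents/spaces/app.py` (the goal string and the finish check are illustrative):

```python
import asyncio
import json

from langchain.chat_models import ChatOpenAI
from autoagents.agents.agents.search import ActionRunner


async def main():
    outputq = asyncio.Queue()  # ActionRunner streams every model step and tool result here
    llm = ChatOpenAI(temperature=0, model_name="gpt-3.5-turbo")
    runner = ActionRunner(outputq, llm, persist_logs=False)
    task = asyncio.create_task(runner.run("What is the tallest building in the world?", outputq))
    while True:
        step = await outputq.get()
        if isinstance(step, Exception):
            raise step
        print(step)
        try:
            parsed = json.loads(step)
        except (TypeError, ValueError):
            continue  # tool observations are not always valid JSON; keep draining the queue
        if isinstance(parsed, dict) and parsed.get("action") == "Tool_Finish":
            break  # the model has produced its final answer
    print("Final answer:", await task)


if __name__ == "__main__":
    asyncio.run(main())
```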
104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: AutoAgents 3 | emoji: 🐢 4 | colorFrom: green 5 | colorTo: purple 6 | sdk: streamlit 7 | sdk_version: 1.21.0 8 | python_version: 3.10.11 9 | app_file: autoagents/spaces/app.py 10 | pinned: true 11 | --- 12 | 13 | Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference 14 | -------------------------------------------------------------------------------- /autoagents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutoLLM/AutoAgents/e1210c7884f951f1b90254e40c162e81fc1442f3/autoagents/__init__.py -------------------------------------------------------------------------------- /autoagents/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutoLLM/AutoAgents/e1210c7884f951f1b90254e40c162e81fc1442f3/autoagents/agents/__init__.py -------------------------------------------------------------------------------- /autoagents/agents/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutoLLM/AutoAgents/e1210c7884f951f1b90254e40c162e81fc1442f3/autoagents/agents/agents/__init__.py -------------------------------------------------------------------------------- /autoagents/agents/agents/search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import uuid 4 | import re 5 | from datetime import date 6 | import asyncio 7 | from collections import defaultdict 8 | from pprint import pprint 9 | from typing import List, Union, Any, Optional, Dict 10 | 11 | from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser 12 | from langchain.prompts import StringPromptTemplate 13 | from langchain import LLMChain 14 | from langchain.chat_models import ChatOpenAI 15 | from langchain.schema import AgentAction, AgentFinish 16 | from langchain.callbacks import get_openai_callback 17 | from langchain.callbacks.base import AsyncCallbackHandler 18 | from langchain.callbacks.manager import AsyncCallbackManager 19 | from langchain.base_language import BaseLanguageModel 20 | 21 | from autoagents.agents.tools.tools import search_tool, note_tool, rewrite_search_query, finish_tool 22 | from autoagents.agents.utils.logger import InteractionsLogger 23 | from autoagents.agents.utils.constants import LOG_SAVE_DIR 24 | 25 | from pydantic import BaseModel, ValidationError, Extra # pydantic==1.10.11 26 | 27 | 28 | class InterOutputSchema(BaseModel): 29 | thought: str 30 | reasoning: str 31 | plan: List[str] 32 | action: str 33 | action_input: str 34 | class Config: 35 | extra = Extra.forbid 36 | 37 | 38 | class FinalOutputSchema(BaseModel): 39 | thought: str 40 | reasoning: str 41 | plan: List[str] 42 | action: str 43 | action_input: str 44 | citations: List[str] 45 | class Config: 46 | extra = Extra.forbid 47 | 48 | 49 | def check_valid(o): 50 | try: 51 | if o.get("action") == "Tool_Finish": 52 | FinalOutputSchema(**o) 53 | else: 54 | InterOutputSchema(**o) 55 | except ValidationError: 56 | return False 57 | return True 58 | 59 | 60 | # Set up the base template 61 | template = """We are working together to satisfy the user's original goal 62 | step-by-step. 
Play to your strengths as an LLM. Make sure the plan is 63 | achievable using the available tools. The final answer should be descriptive, 64 | and should include all relevant details. 65 | 66 | Today is {today}. 67 | 68 | ## Goal: 69 | {input} 70 | 71 | If you require assistance or additional information, you should use *only* one 72 | of the following tools: {tools}. 73 | 74 | ## History 75 | {agent_scratchpad} 76 | 77 | Do not repeat any past actions in History, because you will not get additional 78 | information. If the last action is Tool_Search, then you should use Tool_Notepad to keep 79 | critical information. If you have gathered all information in your plannings 80 | to satisfy the user's original goal, then respond immediately with the Finish 81 | Action. 82 | 83 | ## Output format 84 | You MUST produce JSON output with below keys: 85 | "thought": "current train of thought", 86 | "reasoning": "reasoning", 87 | "plan": [ 88 | "short bulleted", 89 | "list that conveys", 90 | "next-step plan", 91 | ], 92 | "action": "the action to take", 93 | "action_input": "the input to the Action", 94 | """ 95 | 96 | 97 | # Set up a prompt template 98 | class CustomPromptTemplate(StringPromptTemplate): 99 | # The template to use 100 | template: str 101 | # The list of tools available 102 | tools: List[Tool] 103 | ialogger: InteractionsLogger 104 | 105 | def format(self, **kwargs) -> str: 106 | # Get the intermediate steps [(AgentAction, Observation)] 107 | # Format them in a particular way 108 | intermediate_steps = kwargs.pop("intermediate_steps") 109 | history = [] 110 | # Set the agent_scratchpad variable to that value 111 | for i, (action, observation) in enumerate(intermediate_steps): 112 | if action.tool not in [tool.name for tool in self.tools]: 113 | raise Exception("Invalid tool requested by the model.") 114 | parsed = json.loads(action.log) 115 | if i == len(intermediate_steps) - 1: 116 | # Add observation only for the last action 117 | parsed["observation"] = observation 118 | history.append(parsed) 119 | self.ialogger.add_history(history) 120 | kwargs["agent_scratchpad"] = json.dumps(history) 121 | # Create a tools variable from the list of tools provided 122 | kwargs["tools"] = "\n".join([f"{tool.name}: {tool.description}" for tool in self.tools]) 123 | # Create a list of tool names for the tools provided 124 | kwargs["tool_names"] = ", ".join([tool.name for tool in self.tools]) 125 | kwargs["today"] = date.today() 126 | final_prompt = self.template.format(**kwargs) 127 | self.ialogger.add_system(final_prompt) 128 | return final_prompt 129 | 130 | 131 | class CustomOutputParser(AgentOutputParser): 132 | class Config: 133 | arbitrary_types_allowed = True 134 | ialogger: InteractionsLogger 135 | llm: BaseLanguageModel 136 | new_action_input: Optional[str] 137 | action_history = defaultdict(set) 138 | 139 | def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]: 140 | self.ialogger.add_ai(llm_output) 141 | parsed = json.loads(llm_output) 142 | if not check_valid(parsed): 143 | raise ValueError(f"Could not parse LLM output: `{llm_output}`") 144 | 145 | # Parse out the action and action input 146 | action = parsed["action"] 147 | action_input = parsed["action_input"] 148 | 149 | if action == "Tool_Finish": 150 | return AgentFinish(return_values={"output": action_input}, log=llm_output) 151 | 152 | if action_input in self.action_history[action]: 153 | new_action_input = rewrite_search_query(action_input, 154 | self.action_history[action], 155 | self.llm) 156 | 
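            # The model repeated an earlier query verbatim; log the rewrite and surface the new query so the next search can return fresh results.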
self.ialogger.add_message({"query_rewrite": True}) 157 | self.new_action_input = new_action_input 158 | self.action_history[action].add(new_action_input) 159 | return AgentAction(tool=action, tool_input=new_action_input, log=llm_output) 160 | else: 161 | # Return the action and action input 162 | self.action_history[action].add(action_input) 163 | return AgentAction(tool=action, tool_input=action_input, log=llm_output) 164 | 165 | 166 | class ActionRunner: 167 | def __init__(self, 168 | outputq, 169 | llm: BaseLanguageModel, 170 | persist_logs: bool = False, 171 | prompt_template: str = template, 172 | tools: List[Tool] = [search_tool, note_tool, finish_tool]): 173 | self.ialogger = InteractionsLogger(name=f"{uuid.uuid4().hex[:6]}", persist=persist_logs) 174 | prompt = CustomPromptTemplate(template=prompt_template, 175 | tools=tools, 176 | input_variables=["input", "intermediate_steps"], 177 | ialogger=self.ialogger) 178 | 179 | output_parser = CustomOutputParser(ialogger=self.ialogger, llm=llm) 180 | self.model_name = llm.model_name 181 | 182 | class MyCustomHandler(AsyncCallbackHandler): 183 | def __init__(self): 184 | pass 185 | 186 | async def on_chain_end(self, outputs, **kwargs) -> None: 187 | if "text" in outputs: 188 | await outputq.put(outputs["text"]) 189 | 190 | async def on_agent_action( 191 | self, 192 | action: AgentAction, 193 | *, 194 | run_id: uuid.UUID, 195 | parent_run_id: Optional[uuid.UUID] = None, 196 | **kwargs: Any, 197 | ) -> None: 198 | if (new_action_input := output_parser.new_action_input): 199 | await outputq.put(RuntimeWarning(f"Action Input Rewritten: {new_action_input}")) 200 | # Notify users 201 | output_parser.new_action_input = None 202 | 203 | async def on_tool_start( 204 | self, 205 | serialized: Dict[str, Any], 206 | input_str: str, 207 | *, 208 | run_id: uuid.UUID, 209 | parent_run_id: Optional[uuid.UUID] = None, 210 | **kwargs: Any, 211 | ) -> None: 212 | pass 213 | 214 | async def on_tool_end( 215 | self, 216 | output: str, 217 | *, 218 | run_id: uuid.UUID, 219 | parent_run_id: Optional[uuid.UUID] = None, 220 | **kwargs: Any, 221 | ) -> None: 222 | await outputq.put(output) 223 | 224 | handler = MyCustomHandler() 225 | 226 | llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler]) 227 | tool_names = [tool.name for tool in tools] 228 | for tool in tools: 229 | tool.callbacks = [handler] 230 | 231 | agent = LLMSingleActionAgent( 232 | llm_chain=llm_chain, 233 | output_parser=output_parser, 234 | stop=["0xdeadbeef"], # required 235 | allowed_tools=tool_names 236 | ) 237 | callback_manager = AsyncCallbackManager([handler]) 238 | 239 | # Finally create the Executor 240 | self.agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, 241 | tools=tools, 242 | verbose=False, 243 | callback_manager=callback_manager) 244 | 245 | async def run(self, goal: str, outputq, save_dir=LOG_SAVE_DIR): 246 | self.ialogger.set_goal(goal) 247 | try: 248 | with get_openai_callback() as cb: 249 | output = await self.agent_executor.arun(goal) 250 | self.ialogger.add_cost({"total_tokens": cb.total_tokens, 251 | "prompt_tokens": cb.prompt_tokens, 252 | "completion_tokens": cb.completion_tokens, 253 | "total_cost": cb.total_cost, 254 | "successful_requests": cb.successful_requests}) 255 | self.ialogger.save(save_dir) 256 | except Exception as e: 257 | self.ialogger.add_message({"error": str(e)}) 258 | self.ialogger.save(save_dir) 259 | await outputq.put(e) 260 | return 261 | return output 262 | 
-------------------------------------------------------------------------------- /autoagents/agents/agents/search_v3.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | import json 3 | import uuid 4 | from collections import defaultdict 5 | from typing import List, Union, Any, Optional, Dict 6 | 7 | from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser 8 | from langchain.prompts import StringPromptTemplate 9 | from langchain import LLMChain 10 | from langchain.schema import AgentAction, AgentFinish 11 | from langchain.callbacks import get_openai_callback 12 | from langchain.callbacks.base import AsyncCallbackHandler 13 | from langchain.callbacks.manager import AsyncCallbackManager 14 | from langchain.base_language import BaseLanguageModel 15 | 16 | from autoagents.agents.tools.tools import search_tool_v3, note_tool_v3, finish_tool_v3 17 | from autoagents.agents.utils.logger import InteractionsLogger 18 | from autoagents.agents.utils.constants import LOG_SAVE_DIR 19 | from autoagents.agents.agents.search import check_valid 20 | 21 | 22 | # Set up a prompt template 23 | class CustomPromptTemplate(StringPromptTemplate): 24 | # The list of tools available 25 | tools: List[Tool] 26 | ialogger: InteractionsLogger 27 | 28 | def format(self, **kwargs) -> str: 29 | # Get the intermediate steps [(AgentAction, Observation)] 30 | # Format them in a particular way 31 | intermediate_steps = kwargs.pop("intermediate_steps") 32 | history = [] 33 | # Set the agent_scratchpad variable to that value 34 | for i, (action, observation) in enumerate(intermediate_steps): 35 | if action.tool not in [tool.name for tool in self.tools]: 36 | raise Exception("Invalid tool requested by the model.") 37 | parsed = json.loads(action.log) 38 | if i == len(intermediate_steps) - 1: 39 | # Add observation only for the last action 40 | parsed["observation"] = observation 41 | history.append(parsed) 42 | self.ialogger.add_history(history) 43 | goal = kwargs["input"] 44 | goal = f"Today is {date.today()}. 
{goal}" 45 | list_prompt =[] 46 | list_prompt.append({"role": "goal", "content": goal}) 47 | list_prompt.append({"role": "tools", "content": [{tool.name: tool.description} for tool in self.tools]}) 48 | list_prompt.append({"role": "history", "content": history}) 49 | return json.dumps(list_prompt) 50 | 51 | 52 | class CustomOutputParser(AgentOutputParser): 53 | class Config: 54 | arbitrary_types_allowed = True 55 | ialogger: InteractionsLogger 56 | llm: BaseLanguageModel 57 | new_action_input: Optional[str] 58 | action_history = defaultdict(set) 59 | 60 | def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]: 61 | try: 62 | parsed = json.loads(llm_output) 63 | except json.decoder.JSONDecodeError: 64 | raise ValueError(f"Could not parse LLM output: `{llm_output}`") 65 | if not check_valid(parsed): 66 | raise ValueError(f"Could not parse LLM output: `{llm_output}`") 67 | self.ialogger.add_ai(llm_output) 68 | # Parse out the action and action input 69 | action = parsed["action"] 70 | action_input = parsed["action_input"] 71 | 72 | if action == "Tool_Finish": 73 | return AgentFinish(return_values={"output": action_input}, log=llm_output) 74 | return AgentAction(tool=action, tool_input=action_input, log=llm_output) 75 | 76 | 77 | class ActionRunnerV3: 78 | def __init__(self, 79 | outputq, 80 | llm: BaseLanguageModel, 81 | persist_logs: bool = False, 82 | tools = [search_tool_v3, note_tool_v3, finish_tool_v3]): 83 | self.ialogger = InteractionsLogger(name=f"{uuid.uuid4().hex[:6]}", persist=persist_logs) 84 | self.ialogger.set_tools([{tool.name: tool.description} for tool in tools]) 85 | prompt = CustomPromptTemplate(tools=tools, 86 | input_variables=["input", "intermediate_steps"], 87 | ialogger=self.ialogger) 88 | output_parser = CustomOutputParser(ialogger=self.ialogger, llm=llm) 89 | 90 | class MyCustomHandler(AsyncCallbackHandler): 91 | def __init__(self): 92 | pass 93 | 94 | async def on_chain_end(self, outputs, **kwargs) -> None: 95 | if "text" in outputs: 96 | await outputq.put(outputs["text"]) 97 | 98 | async def on_agent_action( 99 | self, 100 | action: AgentAction, 101 | *, 102 | run_id: uuid.UUID, 103 | parent_run_id: Optional[uuid.UUID] = None, 104 | **kwargs: Any, 105 | ) -> None: 106 | pass 107 | 108 | async def on_tool_start( 109 | self, 110 | serialized: Dict[str, Any], 111 | input_str: str, 112 | *, 113 | run_id: uuid.UUID, 114 | parent_run_id: Optional[uuid.UUID] = None, 115 | **kwargs: Any, 116 | ) -> None: 117 | pass 118 | 119 | async def on_tool_end( 120 | self, 121 | output: str, 122 | *, 123 | run_id: uuid.UUID, 124 | parent_run_id: Optional[uuid.UUID] = None, 125 | **kwargs: Any, 126 | ) -> None: 127 | await outputq.put(output) 128 | 129 | handler = MyCustomHandler() 130 | llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler]) 131 | tool_names = [tool.name for tool in tools] 132 | for tool in tools: 133 | tool.callbacks = [handler] 134 | 135 | agent = LLMSingleActionAgent( 136 | llm_chain=llm_chain, 137 | output_parser=output_parser, 138 | stop=["0xdeadbeef"], # required 139 | allowed_tools=tool_names 140 | ) 141 | callback_manager = AsyncCallbackManager([handler]) 142 | # Finally create the Executor 143 | self.agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, 144 | tools=tools, 145 | verbose=False, 146 | callback_manager=callback_manager) 147 | 148 | async def run(self, goal: str, outputq, save_dir=LOG_SAVE_DIR): 149 | goal = f"Today is {date.today()}. 
{goal}" 150 | self.ialogger.set_goal(goal) 151 | try: 152 | with get_openai_callback() as cb: 153 | output = await self.agent_executor.arun(goal) 154 | self.ialogger.add_cost({"total_tokens": cb.total_tokens, 155 | "prompt_tokens": cb.prompt_tokens, 156 | "completion_tokens": cb.completion_tokens, 157 | "total_cost": cb.total_cost, 158 | "successful_requests": cb.successful_requests}) 159 | self.ialogger.save(save_dir) 160 | except Exception as e: 161 | self.ialogger.add_message({"error": str(e)}) 162 | self.ialogger.save(save_dir) 163 | await outputq.put(e) 164 | return 165 | return output 166 | -------------------------------------------------------------------------------- /autoagents/agents/agents/wiki_agent.py: -------------------------------------------------------------------------------- 1 | from langchain.base_language import BaseLanguageModel 2 | from autoagents.agents.agents.search import ActionRunner 3 | from autoagents.agents.agents.search_v3 import ActionRunnerV3 4 | from autoagents.agents.tools.tools import wiki_dump_search_tool, wiki_note_tool, finish_tool 5 | 6 | 7 | # Set up the base template 8 | template = """We are working together to satisfy the user's original goal 9 | step-by-step. Play to your strengths as an LLM. Make sure the plan is 10 | achievable using the available tools. The final answer should be descriptive, 11 | and should include all relevant details. 12 | 13 | Today is {today}. 14 | 15 | ## Goal: 16 | {input} 17 | 18 | If you require assistance or additional information, you should use *only* one 19 | of the following tools: {tools}. 20 | 21 | ## History 22 | {agent_scratchpad} 23 | 24 | Do not repeat any past actions in History, because you will not get additional 25 | information. If the last action is Tool_Wikipedia, then you should use Tool_Notepad to keep 26 | critical information. If you have gathered all information in your plannings 27 | to satisfy the user's original goal, then respond immediately with the Finish 28 | Action. 29 | 30 | ## Output format 31 | You MUST produce JSON output with below keys: 32 | "thought": "current train of thought", 33 | "reasoning": "reasoning", 34 | "plan": [ 35 | "short bulleted", 36 | "list that conveys", 37 | "next-step plan", 38 | ], 39 | "action": "the action to take", 40 | "action_input": "the input to the Action", 41 | """ 42 | 43 | 44 | class WikiActionRunner(ActionRunner): 45 | 46 | def __init__(self, outputq, llm: BaseLanguageModel, persist_logs: bool = False): 47 | 48 | super().__init__( 49 | outputq, llm, persist_logs, 50 | prompt_template=template, 51 | tools=[wiki_dump_search_tool, wiki_note_tool, finish_tool] 52 | ) 53 | 54 | class WikiActionRunnerV3(ActionRunnerV3): 55 | def __init__(self, outputq, llm: BaseLanguageModel, persist_logs: bool = False): 56 | 57 | super().__init__( 58 | outputq, llm, persist_logs, 59 | tools=[wiki_dump_search_tool, wiki_note_tool, finish_tool] 60 | ) 61 | -------------------------------------------------------------------------------- /autoagents/agents/models/custom.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | 4 | from langchain.llms.base import LLM 5 | 6 | 7 | class CustomLLM(LLM): 8 | model_name: str 9 | completions_url: str = "http://localhost:8000/v1/chat/completions" 10 | temperature: float = 0. 
11 | max_tokens: int = 1024 12 | 13 | @property 14 | def _llm_type(self) -> str: 15 | return "custom" 16 | 17 | def _call(self, prompt: str, stop=None) -> str: 18 | r = requests.post( 19 | self.completions_url, 20 | json={ 21 | "model": self.model_name, 22 | "messages": [{"role": "user", "content": prompt}], 23 | "stop": stop, 24 | "temperature": self.temperature, 25 | "max_tokens": self.max_tokens 26 | }, 27 | ) 28 | result = r.json() 29 | try: 30 | return result["choices"][0]["message"]["content"] 31 | except: 32 | raise RuntimeError(result) 33 | 34 | async def _acall(self, prompt: str, stop=None) -> str: 35 | return self._call(prompt, stop) 36 | 37 | 38 | class CustomLLMV3(LLM): 39 | model_name: str 40 | completions_url: str = "http://localhost:8004/v1/completions" 41 | temperature: float = 0. 42 | max_tokens: int = 1024 43 | 44 | @property 45 | def _llm_type(self) -> str: 46 | return "custom" 47 | 48 | def _call(self, prompt: str, stop=None) -> str: 49 | r = requests.post( 50 | self.completions_url, 51 | json={ 52 | "model": self.model_name, 53 | "prompt": json.loads(prompt), 54 | "stop": "\n\n", 55 | "temperature": self.temperature, 56 | "max_tokens": self.max_tokens 57 | }, 58 | ) 59 | result = r.json() 60 | if result.get("object") == "error": 61 | raise RuntimeError(result.get("message")) 62 | else: 63 | return result["choices"][0]["text"] 64 | 65 | async def _acall(self, prompt: str, stop=None) -> str: 66 | return self._call(prompt, stop) 67 | -------------------------------------------------------------------------------- /autoagents/agents/spaces/app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import json 4 | import random 5 | from datetime import date, datetime, timezone, timedelta 6 | from ast import literal_eval 7 | 8 | import streamlit as st 9 | import openai 10 | 11 | from autoagents.agents.utils.constants import MAIN_HEADER, MAIN_CAPTION, SAMPLE_QUESTIONS 12 | from autoagents.agents.agents.search import ActionRunner 13 | 14 | from langchain.chat_models import ChatOpenAI 15 | 16 | 17 | async def run(): 18 | output_acc = "" 19 | st.session_state["random"] = random.randint(0, 99) 20 | if "task" not in st.session_state: 21 | st.session_state.task = None 22 | if "model_name" not in st.session_state: 23 | st.session_state.model_name = "gpt-3.5-turbo" 24 | 25 | st.set_page_config( 26 | page_title="Search Agent", 27 | page_icon="🤖", 28 | layout="wide", 29 | initial_sidebar_state="expanded", 30 | ) 31 | 32 | st.title(MAIN_HEADER) 33 | st.caption(MAIN_CAPTION) 34 | 35 | with st.form("my_form", clear_on_submit=False): 36 | st.markdown("", unsafe_allow_html=True) 37 | user_input = st.text_input( 38 | "You: ", 39 | key="input", 40 | placeholder="Ask me anything ...", 41 | label_visibility="hidden", 42 | ) 43 | 44 | submitted = st.form_submit_button( 45 | "Search", help="Hit to submit the search query." 
46 | ) 47 | 48 | # Ask the user to enter their OpenAI API key 49 | if (api_key := st.sidebar.text_input("OpenAI api-key", type="password")): 50 | api_org = None 51 | else: 52 | api_key, api_org = os.getenv("OPENAI_API_KEY"), os.getenv("OPENAI_API_ORG") 53 | with st.sidebar: 54 | model_dict = { 55 | "gpt-3.5-turbo": "GPT-3.5-turbo", 56 | "gpt-4": "GPT-4 (Better but slower)", 57 | } 58 | st.radio( 59 | "OpenAI model", 60 | model_dict.keys(), 61 | key="model_name", 62 | format_func=lambda x: model_dict[x], 63 | ) 64 | 65 | time_zone = str(datetime.now(timezone(timedelta(0))).astimezone().tzinfo) 66 | st.markdown(f"**The system time zone is {time_zone} and the date is {date.today()}**") 67 | 68 | st.markdown("**Example Queries:**") 69 | for q in SAMPLE_QUESTIONS: 70 | st.markdown(f"*{q}*") 71 | 72 | if not api_key: 73 | st.warning( 74 | "API key required to try this app. The API key is not stored in any form. [This](https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key) might help." 75 | ) 76 | elif api_org and st.session_state.model_name == "gpt-4": 77 | st.warning( 78 | "The free API key does not support GPT-4. Please switch to GPT-3.5-turbo or input your own API key." 79 | ) 80 | else: 81 | outputq = asyncio.Queue() 82 | runner = ActionRunner(outputq, 83 | ChatOpenAI(openai_api_key=api_key, 84 | openai_organization=api_org, 85 | temperature=0, 86 | model_name=st.session_state.model_name), 87 | persist_logs=True) # log to HF-dataset 88 | 89 | async def cleanup(e): 90 | st.error(e) 91 | await st.session_state.task 92 | st.session_state.task = None 93 | st.stop() 94 | 95 | placeholder = st.empty() 96 | 97 | if user_input and submitted: 98 | if st.session_state.task is not None: 99 | with placeholder.container(): 100 | st.session_state.task.cancel() 101 | st.warning("Previous search aborted", icon="⚠️") 102 | 103 | st.session_state.task = asyncio.create_task( 104 | runner.run(user_input, outputq) 105 | ) 106 | iterations = 0 107 | with st.expander("Search Results", expanded=True): 108 | while True: 109 | with st.spinner("Wait for it..."): 110 | output = await outputq.get() 111 | placeholder.empty() 112 | if isinstance(output, Exception): 113 | if isinstance(output, openai.error.AuthenticationError): 114 | await cleanup(f"AuthenticationError: Invalid OpenAI API key.") 115 | elif isinstance(output, openai.error.InvalidRequestError) \ 116 | and output._message == "The model: `gpt-4` does not exist": 117 | await cleanup(f"The free API key does not support GPT-4. Please switch to GPT-3.5-turbo or input your own API key.") 118 | elif isinstance(output, openai.error.OpenAIError): 119 | await cleanup(output) 120 | elif isinstance(output, RuntimeWarning): 121 | st.warning(output) 122 | continue 123 | else: 124 | await cleanup("Something went wrong. Please try searching again.") 125 | return 126 | try: 127 | parsed = json.loads(output) 128 | st.json(output, expanded=True) 129 | st.write("---") 130 | iterations += 1 131 | if parsed.get("action") == "Finish": 132 | break 133 | except: 134 | output_fmt = literal_eval(output) 135 | st.json(output_fmt, expanded=False) 136 | if iterations >= runner.agent_executor.max_iterations: 137 | await cleanup( 138 | f"Maximum iterations ({iterations}) exceeded. You can try running the search again or try a variation of the query." 
139 | ) 140 | return 141 | # Found the answer 142 | final_answer = await st.session_state.task 143 | final_answer = final_answer.replace("$", "\$") 144 | # st.success accepts md 145 | st.success(final_answer, icon="✅") 146 | st.balloons() 147 | st.session_state.task = None 148 | st.stop() 149 | 150 | if __name__ == "__main__": 151 | loop = asyncio.new_event_loop() 152 | loop.set_debug(enabled=False) 153 | loop.run_until_complete(run()) 154 | -------------------------------------------------------------------------------- /autoagents/agents/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutoLLM/AutoAgents/e1210c7884f951f1b90254e40c162e81fc1442f3/autoagents/agents/tools/__init__.py -------------------------------------------------------------------------------- /autoagents/agents/tools/tools.py: -------------------------------------------------------------------------------- 1 | import wikipedia 2 | import requests 3 | from elasticsearch import Elasticsearch 4 | 5 | from duckpy import Client 6 | from langchain import PromptTemplate, LLMChain, Wikipedia 7 | from langchain.agents import Tool 8 | from langchain.agents.react.base import DocstoreExplorer 9 | 10 | from langchain.base_language import BaseLanguageModel 11 | 12 | 13 | MAX_SEARCH_RESULTS = 5 # Number of search results to observe at a time 14 | 15 | INDEX_NAME = "wiki-dump-2017" 16 | 17 | # Create the client instance 18 | es_client = None 19 | 20 | search_description = """ Useful for when you need to ask with search. Use direct language and be 21 | EXPLICIT in what you want to search. Do NOT use filler words. 22 | 23 | ## Examples of incorrect use 24 | { 25 | "action": "Tool_Search", 26 | "action_input": "[name of bagel shop] menu" 27 | } 28 | 29 | The action_input cannot be None or empty. 30 | """ 31 | 32 | notepad_description = """ Useful for when you need to note-down specific 33 | information for later reference. Please provide the website and full 34 | information you want to note-down in the action_input and all future prompts 35 | will remember it. This is the mandatory tool after using the Tool_Search. 36 | Using Tool_Notepad does not always lead to a final answer. 37 | 38 | ## Examples of using Notepad tool 39 | { 40 | "action": "Tool_Notepad", 41 | "action_input": "(www.website.com) the information you want to note-down" 42 | } 43 | """ 44 | 45 | wiki_notepad_description = """ Useful for when you need to note-down specific 46 | information for later reference. Please provide the website and full 47 | information you want to note-down in the action_input and all future prompts 48 | will remember it. This is the mandatory tool after using the Tool_Wikipedia. 49 | Using Tool_Notepad does not always lead to a final answer. 50 | 51 | ## Examples of using Notepad tool 52 | { 53 | "action": "Tool_Notepad", 54 | "action_input": "(www.website.com) the information you want to note-down" 55 | } 56 | """ 57 | 58 | wiki_search_description = """ Useful for when you need to get some information about a certain entity. Use direct language and be 59 | concise about what you want to retrieve. Note: the action input MUST be a wikipedia entity instead of a long sentence. 60 | 61 | ## Examples of correct use 62 | 1. Action: Tool_Wikipedia 63 | Action Input: Colorado orogeny 64 | 65 | The Action Input cannot be None or empty. 
66 | """ 67 | 68 | wiki_lookup_description = """ This tool is helpful when you want to retrieve sentences containing a specific text snippet after checking a Wikipedia entity. 69 | It should be utilized when a successful Wikipedia search does not provide sufficient information. 70 | Keep your lookup concise, using no more than three words. 71 | 72 | ## Examples of correct use 73 | 1. Action: Tool_Lookup 74 | Action Input: eastern sector 75 | 76 | The Action Input cannot be None or empty. 77 | """ 78 | 79 | 80 | async def ddg(query: str): 81 | if query is None or query.lower().strip().strip('"') == "none" or query.lower().strip().strip('"') == "null": 82 | x = "The action_input field is empty. Please provide a search query." 83 | return [x] 84 | else: 85 | client = Client() 86 | return client.search(query)[:MAX_SEARCH_RESULTS] 87 | 88 | docstore=DocstoreExplorer(Wikipedia()) 89 | 90 | async def notepad(x: str) -> str: 91 | return f"{[x]}" 92 | 93 | async def wikisearch(x: str) -> str: 94 | title_list = wikipedia.search(x) 95 | if not title_list: 96 | return docstore.search(x) 97 | title = title_list[0] 98 | return f"Wikipedia Page Title: {title}\nWikipedia Page Content: {docstore.search(title)}" 99 | 100 | async def wikilookup(x: str) -> str: 101 | return docstore.lookup(x) 102 | 103 | async def wikidumpsearch_es(x: str) -> str: 104 | global es_client 105 | if es_client is None: 106 | es_client = Elasticsearch("http://localhost:9200") 107 | resp = es_client.search( 108 | index=INDEX_NAME, query={"match": {"text": x}}, size=MAX_SEARCH_RESULTS 109 | ) 110 | res = [] 111 | for hit in resp['hits']['hits']: 112 | doc = hit["_source"] 113 | res.append({ 114 | "title": doc["title"], 115 | "text": ''.join(sent for sent in doc["text"][1]), 116 | "url": doc["url"] 117 | }) 118 | if doc["title"] == x: 119 | return [{ 120 | "title": doc["title"], 121 | "text": '\n'.join(''.join(paras) for paras in doc["text"][1:3]) 122 | if len(doc["text"]) > 2 123 | else '\n'.join(''.join(paras) for paras in doc["text"]), 124 | "url": doc["url"] 125 | }] 126 | return res 127 | 128 | async def wikidumpsearch_embed(x: str) -> str: 129 | res = [] 130 | for obj in vector_search(x): 131 | paras = obj["text"].split('\n') 132 | cur = { 133 | "title": obj["sources"][0]["title"], 134 | "text": paras[min(1, len(paras) - 1)], 135 | "url": obj["sources"][0]["url"] 136 | } 137 | res.append(cur) 138 | if cur["title"] == x: 139 | return [{ 140 | "title": obj["sources"][0]["title"], 141 | "text": '\n'.join(paras[1:3] if len(paras) > 2 else paras), 142 | "url": obj["sources"][0]["url"] 143 | }] 144 | return res 145 | 146 | def vector_search( 147 | query: str, 148 | url: str = "http://0.0.0.0:8080/query", 149 | max_candidates: int = MAX_SEARCH_RESULTS 150 | ): 151 | response = requests.post( 152 | url=url, json={"query_list": [query]} 153 | ).json()["result"][0]["top_answers"] 154 | return response[:min(max_candidates, len(response))] 155 | 156 | 157 | search_tool = Tool(name="Tool_Search", 158 | func=lambda x: x, 159 | coroutine=ddg, 160 | description=search_description) 161 | 162 | note_tool = Tool(name="Tool_Notepad", 163 | func=lambda x: x, 164 | coroutine=notepad, 165 | description=notepad_description) 166 | 167 | wiki_note_tool = Tool(name="Tool_Notepad", 168 | func=lambda x: x, 169 | coroutine=notepad, 170 | description=wiki_notepad_description) 171 | 172 | wiki_search_tool = Tool( 173 | name="Tool_Wikipedia", 174 | func=lambda x: x, 175 | coroutine=wikisearch, 176 | description=wiki_search_description 177 | ) 178 | 179 | 
wiki_lookup_tool = Tool( 180 | name="Tool_Lookup", 181 | func=lambda x: x, 182 | coroutine=wikilookup, 183 | description=wiki_lookup_description 184 | ) 185 | 186 | wiki_dump_search_tool = Tool( 187 | name="Tool_Wikipedia", 188 | func=lambda x: x, 189 | coroutine=wikidumpsearch_embed, 190 | description=wiki_search_description 191 | ) 192 | 193 | async def final(x: str): 194 | pass 195 | 196 | finish_description = """ Useful when you have enough information to produce a 197 | final answer that achieves the original Goal. 198 | 199 | You must also include this key in the output for the Tool_Finish action 200 | "citations": ["www.example.com/a/list/of/websites: what facts you got from the website", 201 | "www.example.com/used/to/produce/the/action/and/action/input: "what facts you got from the website", 202 | "www.webiste.com/include/the/citations/from/the/previous/steps/as/well: "what facts you got from the website", 203 | "www.website.com": "this section is only needed for the final answer"] 204 | 205 | ## Examples of using Finish tool 206 | { 207 | "action": "Tool_Finish", 208 | "action_input": "final answer", 209 | "citations": ["www.example.com: what facts you got from the website"] 210 | } 211 | """ 212 | 213 | finish_tool = Tool(name="Tool_Finish", 214 | func=lambda x: x, 215 | coroutine=final, 216 | description=finish_description) 217 | 218 | def rewrite_search_query(q: str, search_history, llm: BaseLanguageModel) -> str: 219 | history_string = '\n'.join(search_history) 220 | template ="""We are using the Search tool. 221 | # Previous queries: 222 | {history_string}. \n\n Rewrite query {action_input} to be 223 | different from the previous queries.""" 224 | prompt = PromptTemplate(template=template, 225 | input_variables=["action_input", "history_string"]) 226 | llm_chain = LLMChain(prompt=prompt, llm=llm) 227 | result = llm_chain.predict(action_input=q, history_string=history_string) 228 | return result 229 | 230 | 231 | ### Prompt V3 tools 232 | 233 | search_description_v3 = """Useful for when you need to ask with search.""" 234 | notepad_description_v3 = """ Useful for when you need to note-down specific information for later reference.""" 235 | finish_description_v3 = """Useful when you have enough information to produce a final answer that achieves the original Goal.""" 236 | 237 | search_tool_v3 = Tool(name="Tool_Search", 238 | func=lambda x: x, 239 | coroutine=ddg, 240 | description=search_description_v3) 241 | 242 | note_tool_v3 = Tool(name="Tool_Notepad", 243 | func=lambda x: x, 244 | coroutine=notepad, 245 | description=notepad_description_v3) 246 | 247 | finish_tool_v3 = Tool(name="Tool_Finish", 248 | func=lambda x: x, 249 | coroutine=final, 250 | description=finish_description_v3) 251 | -------------------------------------------------------------------------------- /autoagents/agents/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutoLLM/AutoAgents/e1210c7884f951f1b90254e40c162e81fc1442f3/autoagents/agents/utils/__init__.py -------------------------------------------------------------------------------- /autoagents/agents/utils/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | MAIN_HEADER = "Web Search Agent" 5 | 6 | MAIN_CAPTION = """This is a proof-of-concept search agent that reasons, plans, 7 | and executes web searches to collect information on your behalf. 
It aims to 8 | resolve your question by breaking it down into step-by-step subtasks. All the 9 | intermediate results will be presented. 10 | 11 | *DISCLAIMER*: We are collecting search queries, so please refrain from 12 | providing any personal information. If you wish to avoid this, you can run the 13 | app locally by following the instructions on our 14 | [Github](https://github.com/AutoLLM/AutoAgents).""" 15 | 16 | SAMPLE_QUESTIONS = [ 17 | "Recommend me a movie in theater now to watch with kids.", 18 | "Who is the most recent NBA MVP? Which team does he play for? What are his career stats?", 19 | "Who is the head coach of AC Milan now? How long has he been coaching the team?", 20 | "What is the mortgage rate right now and how does that compare to the past two years?", 21 | "What is the weather like in San Francisco today? What about tomorrow?", 22 | "When and where is the upcoming concert for Taylor Swift? Share a link to purchase tickets.", 23 | "Find me recent studies focusing on hallucination in large language models. Provide the link to each study found.", 24 | ] 25 | 26 | LOG_SAVE_DIR: str = os.path.join(os.getcwd(), "data") 27 | -------------------------------------------------------------------------------- /autoagents/agents/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import Dict, Any, List 4 | import uuid 5 | from datetime import datetime 6 | import pytz 7 | 8 | import huggingface_hub 9 | from huggingface_hub import Repository 10 | 11 | from autoagents.agents.utils.constants import LOG_SAVE_DIR 12 | 13 | 14 | class InteractionsLogger: 15 | def __init__(self, name: str, persist=False): 16 | self.persist = persist 17 | self.messages = [] 18 | self.counter = 0 19 | self.name = name # unique id 20 | HF_TOKEN = os.environ.get("HF_TOKEN") 21 | HF_DATASET_REPO_URL = os.environ.get("HF_DATASET_REPO_URL") 22 | if (HF_TOKEN is not None) and (HF_DATASET_REPO_URL is not None): 23 | self.repo = Repository( 24 | local_dir="data", clone_from=HF_DATASET_REPO_URL, use_auth_token=HF_TOKEN 25 | ) 26 | else: 27 | self.repo = None 28 | 29 | def set_goal(self, goal: str): 30 | self.messages.append({"goal": goal}) 31 | 32 | def set_tools(self, tools: List): 33 | self.messages.append({"tools": tools}) 34 | 35 | def add_history(self, hist: Dict): 36 | self.convos = [{"from": "history", "value": hist}] 37 | 38 | def add_ai(self, msg: Dict): 39 | self.convos.append({"from": "ai", "value": msg}) 40 | self.messages.append({"id": f"{self.name}_{self.counter}", "conversations": self.convos}) 41 | self.counter += 1 42 | 43 | def add_system(self, more: Dict): 44 | self.convos.append({"from": "system", "value": more}) 45 | 46 | def add_message(self, data: Dict[str, Any]): 47 | self.messages.append(data) 48 | 49 | def save(self, save_dir=LOG_SAVE_DIR): 50 | self.add_message({"datetime": datetime.now(pytz.utc).strftime("%m/%d/%Y %H:%M:%S %Z%z")}) 51 | if self.persist: 52 | if not os.path.isdir(save_dir): 53 | os.mkdir(save_dir) 54 | # TODO: want to add retry in a loop? 
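            # Pull before writing so the local dataset clone is current; each run gets a unique file name to avoid collisions.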
55 | if self.repo is not None: 56 | self.repo.git_pull() 57 | fname = uuid.uuid4().hex[:16] 58 | with open(os.path.join(save_dir, f"{fname}.json"), "w") as f: 59 | json.dump(self.messages, f, indent=2) 60 | if self.repo is not None: 61 | commit_url = self.repo.push_to_hub() 62 | 63 | def add_cost(self, cost): 64 | self.messages.append({"metrics": cost}) 65 | -------------------------------------------------------------------------------- /autoagents/data/action_name_transformation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | import argparse 4 | 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('input') 8 | parser.add_argument('output') 9 | 10 | args = parser.parse_args() 11 | 12 | 13 | def transform(conversation, search_word, notepad_word): 14 | for i, message in enumerate(conversation): 15 | conversation[i]["value"] = \ 16 | message["value"].replace( 17 | "Tool_Search", search_word).replace( 18 | "Tool_Notepad", notepad_word) 19 | return conversation 20 | 21 | 22 | input_file = args.input 23 | output_file = args.output 24 | 25 | with open(input_file, "r") as f: 26 | body = json.load(f) 27 | 28 | result = [] 29 | for elem in body: 30 | search_word = str(uuid.uuid4())[:6] 31 | notepad_word = str(uuid.uuid4())[:6] 32 | elem = { 33 | "id": elem["id"], 34 | "conversations": transform(elem["conversations"], search_word, notepad_word)} 35 | result.append(elem) 36 | with open(output_file, "w") as f: 37 | json.dump(result, f, indent=2) 38 | -------------------------------------------------------------------------------- /autoagents/data/create_sft_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | from collections import Counter 4 | import argparse 5 | 6 | counts = Counter() 7 | 8 | Goals = set() 9 | 10 | def process_file(data, name): 11 | if name.startswith("train_data") or name.startswith("final") or name.startswith("ab_"): 12 | return 13 | for d in data: 14 | if "error" in d: 15 | counts["error"] += 1 16 | return 17 | elif "query_rewrite" in d: 18 | counts["query_rewrite"] += 1 19 | return 20 | output = json.loads(data[-3]["conversations"][2]["value"]) 21 | if output["action"] != "Tool_Finish": 22 | counts["no_final_answer"] += 1 23 | return 24 | # remove dups in case 25 | goal = data[0]["goal"] 26 | if goal not in Goals: 27 | Goals.add(goal) 28 | else: 29 | return 30 | costs = data[-2] 31 | data = data[1:-2] 32 | 33 | counts["conv_len"] += len(data) 34 | counts["total"] += 1 35 | 36 | counts["totals_cost"] += costs["metrics"]["total_cost"] 37 | data_new = [] 38 | for d in data: 39 | convs = [] 40 | for conv in d["conversations"]: 41 | k, v = conv.values() 42 | if k == "system": 43 | convs.append({"from": "human", "value": v}) 44 | elif k == "ai": 45 | convs.append({"from": "gpt", "value": v}) 46 | assert len(convs) == 2 47 | data_new.append({"id": d["id"], "conversations": convs}) 48 | return data_new 49 | 50 | def main(dir_path, save=False): 51 | assert dir_path is not None 52 | train_data = [] 53 | for name in glob.glob(f"{dir_path}/*.json"): 54 | dname = name.split("/")[-1].split(".")[0] 55 | with open(f"{dir_path}/{dname}.json", "r") as file: 56 | data = json.load(file) 57 | if (filtered_data := process_file(data, dname)): 58 | train_data += filtered_data 59 | if save: 60 | with open(f"{dir_path}/sft_data.json", "w") as f: 61 | json.dump(train_data, f, indent=2) 62 | print(counts) 63 | 64 | if __name__ == "__main__": 65 | 
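    # Usage: python create_sft_dataset.py <data_dir_path> [--save]; with --save, writes <data_dir_path>/sft_data.json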
parser = argparse.ArgumentParser() 66 | parser.add_argument('data_dir_path', type=str) 67 | parser.add_argument("--save", action="store_true") 68 | args = parser.parse_args() 69 | main(args.data_dir_path, args.save) 70 | 71 | -------------------------------------------------------------------------------- /autoagents/data/dataset.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | BAMBOOGLE = { 4 | "questions": ['Who was president of the United States in the year that Citibank was founded?', 'What rocket was the first spacecraft that ever approached Uranus launched on?', 'In what year was the company that was founded as Sound of Music added to the S&P 500?', 'Who was the first African American mayor of the most populous city in the United States?', "When did the last king from Britain's House of Hanover die?", 'When did the president who set the precedent of a two term limit enter office?', 'When did the president who set the precedent of a two term limit leave office?', 'How many people died in the second most powerful earthquake ever recorded?', 'Can people who have celiac eat camel meat?', 'What was the final book written by the author of On the Origin of Species?', 'When was the company that built the first steam locomotive to carry passengers on a public rail line founded?', 'Which Theranos whistleblower is not related to a senior American government official?', 'What is the fastest air-breathing manned aircraft mostly made out of?', 'Who built the fastest air-breathing manned aircraft?', 'When was the author of The Population Bomb born?', 'When did the author of Annabel Lee enlist in the army?', 'What was the religion of the inventor of the Polio vaccine?', 'Who was the second wife of the founder of CNN?', 'When did the first prime minister of the Russian Empire come into office?', 'What is the primary male hormone derived from?', 'The Filipino statesman who established the government-in-exile during the outbreak of World War II was also the mayor of what city?', 'Where was the person who shared the Nobel Prize in Physics in 1954 with Max Born born?', 'When was the person who shared the Nobel Prize in Physics in 1954 with Max Born born?', 'What was the founding date of the university in which Plotonium was discovered?', 'The material out of which the Great Sphinx of Giza is made of is mainly composed of what mineral?', 'The husband of Lady Godiva was Earl of which Anglic kingdom?', 'The machine used to extract honey from honeycombs uses which physical force?', 'What is the third letter of the top level domain of the military?', 'In what year was the government department where the internet originated at founded?', 'The main actor of Indiana Jones is a licensed what?', 'When was the person after which the Hubble Space Telescope is named after born?', 'When did the person who gave the Checkers speech die?', 'When was the philosopher that formulated the hard problem of consciousness born?', 'What is the capital of the second largest state in the US by area?', 'What is the maximum airspeed (in km/h) of the third fastest bird?', 'Who founded the city where the founder of geometry lived?', 'What is the capital of the country where yoga originated?', 'The fourth largest city in Germany was originally called what?', "When did Nirvana's second most selling studio album come out?", 'What was the job of the father of the founder of psychoanalysis?', 'How much protein in four boiled egg yolks?', 'What is the political party of the 
American president who entered into the Paris agreement?', 'The most populous city in Punjab is how large (area wise)?', 'What was the death toll of the second largest volcanic eruption in the 20th century?', 'What was the death toll of the most intense Atlantic hurricane?', 'Who was the head of NASA during Apollo 11?', 'Who is the father of the father of George Washington?', 'Who is the mother of the father of George Washington?', 'Who is the father of the father of Barack Obama?', 'Who is the mother of the father of Barack Obama?', 'Who was mayor of New York City when Fiorello H. La Guardia was born?', 'Who was president of the U.S. when superconductivity was discovered?', 'When was the person Russ Hanneman is based on born?', "When was the first location of the world's largest coffeehouse chain opened?", 'Who directed the highest grossing film?', 'When was the longest bridge in the world opened?', 'Which company was responsible for the largest pharmaceutical settlement?', 'In what year was the tallest self-supporting tower completed?', 'In what year was the tallest fixed steel structure completed?', 'In what year was the tallest lattice tower completed?', 'In what year was the current tallest wooden lattice tower completed?', 'In what country is the second tallest statue in the world?', 'When was the tallest ferris wheel in the world completed?', 'In what year was the tallest lighthouse completed?', 'In what country is the world largest desalination plant?', 'The most populous national capital city was established in what year?', 'The third largest river (by discharge) in the world is in what countries?', 'What is the highest elevation (in meters) of the second largest island in the world?', 'What is the length of the second deepest river in the world?', 'In what country is the third largest stadium in the world?', 'Who is the largest aircraft carrier in the world is named after?', 'In what year did the oldest cat ever recorded with the Cat of the Year award?', 'In what year was the country that is the third largest exporter of coffee founded?', 'Who was the commander for the space mission that had the first spacewalk?', 'Who is the predecessor of the longest-reigning British monarch?', 'In 2016, who was the host of the longest running talk show?', 'In 2016, who was the host of the longest running American game show?', 'Who wrote the novel on which the longest running show in Broadway history is based on?', 'In what country was the only cruise line that flies the American flag incorporated in?', 'In what year did work begin on the second longest road tunnel in the world?', 'What is the official color of the third oldest surviving university?', 'Who succeeded the longest reigning Roman emperor?', 'Who preceded the Roman emperor that declared war on the sea?', 'Who produced the longest running video game franchise?', 'Who was the father of the father of psychoanalysis?', 'Who was the father of the father of empiricism?', 'Who is the father of the father of observational astronomy?', 'Who is the father of the father of modern Hebrew?', 'Who is the father of the father of modern experimental psychology?', 'Who is the father of the originator of cybernetics?', 'Who is the father of the father of the hydrogen bomb?', 'Who was the father of the father of computer science?', 'Who was the father of the father of behaviorism?', 'Who was the father of the founder of modern human anatomy?', 'What was the father of the last surviving Canadian father of Confederation?', 'When was the person who said 
“Now, I am become Death, the destroyer of worlds.” born?', 'Who was the father of the father of information theory?', 'When was the person who delivered the "Quit India" speech born?', 'When did the president who warned about the military industrial complex die?', 'When did the president who said Tear Down This Wall die?', 'What is the lowest elevation of the longest railway tunnel?', 'When did the person who said "Cogito, ergo sum." die?', 'When did the person who delivered the Gettysburg Address die?', 'Who was governor of Florida during Hurricane Irma?', "For which club did the winner of the 2007 Ballon d'Or play for in 2012?", "What's the capital city of the country that was the champion of the 2010 World Cup?", 'When was the anime studio that made Sword Art Online founded?', 'Who was the first king of the longest Chinese dynasty?', 'Who was the last emperor of the dynasty that succeeded the Song dynasty?', "What's the motto of the oldest California State university?", "What's the capital of the state that the College of William & Mary is in?", "What's the capital of the state that Washington University in St. Louis is in?", "What's the capital of the state that Harvard University is in?", "What's the capital of the state that the Space Needle is at?", "Which team won in women's volleyball in the most recent Summer Olympics that was held in London?", 'What is the nickname of the easternmost U.S. state?', 'What is the nickname for the state that is the home to the “Avocado Capital of the World"?', 'What rocket was used for the mission that landed the first humans on the moon?', 'When did the war that Neil Armstrong served in end?', 'What is the nickname for the state that Mount Rainier is located in?', 'When was the composer of Carol of the Bells born?', 'Who is the father of the scientist at MIT that won the Queen Elizabeth Prize for Engineering in 2013?', 'Who was the mother of the emperor of Japan during World War I?', 'Which element has an atomic number that is double that of hydrogen?', 'What was the motto of the Olympics that had Fuwa as the mascots?'], 5 | "answers": ['james madison', 'Titan IIIE', 1999, 'David Dinkins', '20 June 1837', 'April 30, 1789', 'March 4, 1797', 131, 'Yes', 'The Formation of Vegetable Mould Through the Action of Worms', 1823, 'Erika Cheung', 'Titanium', 'Lockheed Corporation', datetime.datetime(1932, 5, 29, 0, 0), 1827, 'Jewish', 'Jane Shirley Smith', datetime.datetime(1905, 11, 6, 0, 0), 'cholesterol', 'Quezon City', 'Oranienburg, Germany', 'January 8, 1891', 'March 23, 1868', 'calcite', 'Mercia', 'Centrifugal Force', 'l', 1947, 'pilot', 'November 20, 1889', datetime.datetime(1994, 4, 22, 0, 0), datetime.datetime(1966, 4, 20, 0, 0), 'Austin', '320 km/h', 'Alexander the Great', 'New Delhi', 'Colonia Claudia Ara Agrippinensium', datetime.datetime(1993, 9, 13, 0, 0), 'wool merchant', 10.8, 'Democratic Party', '310 square kilometers', 847, 52, 'Thomas O. Paine', 'Lawrence Washington', 'Mildred Warner', 'Hussein Onyango Obama', 'Habiba Akumu Nyanjango', 'William Russell Grace', 'William Howard Taft', datetime.datetime(1958, 7, 31, 0, 0), datetime.datetime(1971, 3, 30, 0, 0), 'James Cameroon', datetime.datetime(2011, 6, 30, 0, 0), 'GlaxoSmithKline', 2012, 1988, 2012, 1935, 'China', 2021, 1902, 'Saudi Arabia', '1045 BC', 'India and Bangladesh', '4,884 m', '6,300 km', 'United States', 'Gerald R. 
Ford', 1999, 1810, 'Pavel Belyayev', 'George VI\n', 'Jimmy Fallon', 'Drew Carey', 'Gaston Leroux', 'Bermuda', 1992, 'Cambridge Blue', 'Tiberius', 'Tiberius', 'MECC', 'Jacob Freud', 'Sir Nicholas Bacon', 'Vincenzo Galilei', 'Yehuda Leib', 'Maximilian Wundt', 'Leo Wiener', 'Max Teller', 'Julius Mathison Turing', 'Pickens Butler Watson', 'Anders van Wesel', 'Charles Tupper Sr.', datetime.datetime(1904, 4, 22, 0, 0), 'Claude Sr.', 'October 2, 1869', datetime.datetime(1969, 3, 28, 0, 0), datetime.datetime(2004, 6, 5, 0, 0), '312 m', 'February 11, 1650', 'April 15, 1865', 'Rick Scott', 'Real Madrid', 'Madrid', datetime.datetime(2005, 5, 9, 0, 0), 'King Wu of Zhou', 'Toghon Temür', 'Powering Silicon Valley', 'Richmond', 'Jefferson City', 'Boston', 'Olympia', 'Brazil', 'Pine Tree State', 'Golden State', 'Saturn V', datetime.datetime(1953, 7, 27, 0, 0), 'Evergreen State', 'December 13, 1877', 'Conway Berners-Lee', 'Yanagiwara Naruko', 'Helium', 'One World, One Dream'] 6 | } 7 | 8 | DEFAULT_Q = [ 9 | (0, "list 3 cities and their current populations where Paramore is playing this year."), 10 | (1, "Who is Leo DiCaprio's girlfriend? What is her current age raised to the 0.43 power?"), 11 | (2, "How many watermelons can fit in a Tesla Model S?"), 12 | (3, "Recommend me some laptops suitable for UI designers under $2000. Please include brand and price."), 13 | (4, "Build me a vacation plan for Rome and Milan this summer for seven days. Include place to visit and hotels to stay. "), 14 | (5, "What is the sum of ages of the wives of Barack Obama and Donald Trump?"), 15 | (6, "Who is the most recent NBA MVP? Which team does he play for? What is his season stats?"), 16 | (7, "What were the scores for the last three games for the Los Angeles Lakers? Provide the dates and opposing teams."), 17 | (8, "Which team won in women's volleyball in the Summer Olympics that was held in London?"), 18 | (9, "Provide a summary of the latest COVID-19 research paper published. Include the title, authors and abstract."), 19 | (10, "What is the top grossing movie in theatres this week? Provide the movie title, director, and a brief synopsis of the movie's plot. Attach a review for this movie."), 20 | (11, "Recommend a bagel shop near the Strip district in Pittsburgh that offer vegan food"), 21 | (12, "Who are some top researchers in the field of machine learning systems nowadays?"), 22 | ] 23 | 24 | FT = [ 25 | (0, "Briefly explain the current global climate change adaptation strategy and its effectiveness."), 26 | (1, "What steps should be taken to prepare a backyard garden for spring planting?"), 27 | (2, "Report the critical reception of the latest superhero movie."), 28 | (3, "When is the next NBA or NFL finals game scheduled?"), 29 | (4, "Which national parks or nature reserves are currently open for visitors near Denver, Colorado?"), 30 | (5, "Who are the most recent Nobel Prize winners in physics, chemistry, and medicine, and what are their respective contributions?"), 31 | ] 32 | 33 | HF = [ 34 | (0, "Recommend me a movie in theater now to watch with kids."), 35 | (1, "Who is the most recent NBA MVP? Which team does he play for? What are his career stats?"), 36 | (2, "Who is the head coach of AC Milan now? How long has he been coaching the team?"), 37 | (3, "What is the mortgage rate right now and how does that compare to the past two years?"), 38 | (4, "What is the weather like in San Francisco today? What about tomorrow?"), 39 | (5, "When and where is the upcoming concert for Taylor Swift? 
Share a link to purchase tickets."), 40 | (6, "Find me recent studies focusing on hallucination in large language models. Provide the link to each study found."), 41 | ] 42 | -------------------------------------------------------------------------------- /autoagents/data/generate_action_data.py: -------------------------------------------------------------------------------- 1 | # Script that generates action data from goals by calling GPT-4 2 | import os 3 | import asyncio 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | from multiprocessing import Pool 8 | 9 | from autoagents.agents.agents.search import ActionRunner 10 | from autoagents.eval.test import AWAIT_TIMEOUT 11 | from langchain.chat_models import ChatOpenAI 12 | import json 13 | 14 | 15 | async def work(user_input): 16 | outputq = asyncio.Queue() 17 | llm = ChatOpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"), 18 | openai_organization=os.getenv("OPENAI_API_ORG"), 19 | temperature=0., 20 | model_name="gpt-4") 21 | runner = ActionRunner(outputq, llm=llm, persist_logs=True) 22 | task = asyncio.create_task(runner.run(user_input, outputq)) 23 | 24 | while True: 25 | try: 26 | output = await asyncio.wait_for(outputq.get(), AWAIT_TIMEOUT) 27 | except asyncio.TimeoutError: 28 | return 29 | if isinstance(output, RuntimeWarning): 30 | print(output) 31 | continue 32 | elif isinstance(output, Exception): 33 | print(output) 34 | return 35 | try: 36 | parsed = json.loads(output) 37 | if parsed["action"] in ("Tool_Finish", "Tool_Abort"): 38 | break 39 | except (json.JSONDecodeError, TypeError, KeyError): # not a final action record; keep streaming 40 | pass 41 | await task 42 | 43 | def main(q): 44 | asyncio.run(work(q)) 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--goals', type=str, help="file containing JSON array of goals", required=True) 49 | parser.add_argument("--num_data", type=int, default=-1, help="number of goals for generation") 50 | args = parser.parse_args() 51 | with open(args.goals, "r") as file: 52 | data = json.load(file) 53 | if args.num_data > -1 and len(data) > args.num_data: 54 | data = data[:args.num_data] 55 | with Pool(processes=4) as pool: 56 | for _ in tqdm(pool.imap_unordered(main, data), total=len(data)): 57 | pass -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/README.md: -------------------------------------------------------------------------------- 1 | # Generate_action_tasks 2 | Generate action tasks for AutoGPT following self-instruct. 3 | 4 | 5 | ## What does this repo do 6 | 7 | This repo only generates the tasks that the agent needs to complete, not the full action data that includes the reasoning, planning, and execution of tools. That part of the code is implemented in the AutoAgents repo. 8 | 9 | 10 | ## Repo Structure 11 | 12 | * `REACT.ipynb` is an exploration of creating action data. 13 | * `generate_data_chat_api.py` is the main entry point; it uses OpenAI's chat API to generate more tasks via in-context learning. 14 | * `prompt.txt` is the prompt used by `generate_data_chat_api.py`. 15 | * `seed_tasks.jsonl`: the manually labeled tasks for in-context learning. 16 | * `generate_data.py`: uses the completion API, not the chat API. 17 | 18 | 19 | ## Usage 20 | Simply run 21 | `sh run_genenerate_data.sh`. Parameters can be modified.
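Under the hood the script runs the following command (copied from `run_genenerate_data.sh`), which you can invoke directly to override any of the flags:

```sh
python -m generate_data_chat_api generate_agents_data \
    --output_dir ./new_data \
    --seed_tasks_path ./seed_tasks.jsonl \
    --num_agents_to_generate 1000 \
    --model_name="gpt-4"
```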
22 | 23 | -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/generate_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | run: 3 | python -m generate_data generate_agents_data \ 4 | --output_dir ./ \ 5 | --num_agents_to_generate 10 \ 6 | --model_name="text-davinci-003" \ 7 | """ 8 | import time 9 | import json 10 | import os 11 | import random 12 | import re 13 | import string 14 | from functools import partial 15 | from multiprocessing import Pool 16 | 17 | import numpy as np 18 | import tqdm 19 | from rouge_score import rouge_scorer 20 | import utils 21 | 22 | import fire 23 | import pdb 24 | 25 | ''' 26 | The tools used by Auto-GPT are available at: 27 | https://github.com/Significant-Gravitas/Auto-GPT/tree/master/autogpt/commands 28 | ''' 29 | 30 | def encode_prompt(prompt_agents): 31 | """Encode multiple prompt instructions into a single string.""" 32 | prompt = open("./prompt.txt").read() + "\n" 33 | 34 | for idx, task_dict in enumerate(prompt_agents): 35 | (name, goal) = task_dict["name"], task_dict["goal"] 36 | if not goal: 37 | raise ValueError("every seed task must define a non-empty goal") 38 | prompt += f"###\n" 39 | prompt += f"{idx + 1}. Name: {name}\n" 40 | prompt += f"{idx + 1}. Goal:\n{goal}\n" 41 | prompt += f"###\n" 42 | prompt += f"{idx + 2}. Name:" 43 | return prompt 44 | 45 | 46 | def post_process_gpt_response(num_prompt_agents, response): 47 | print("post_process_gpt_response") 48 | if response is None: 49 | return [] 50 | raw_instructions = f"{num_prompt_agents+1}. Name:" + response["text"] 51 | raw_instructions = re.split("###", raw_instructions) 52 | agents = [] 53 | for idx, inst in enumerate(raw_instructions): 54 | # if the decoding stops due to length, the last example is likely truncated so we discard it 55 | if idx == len(raw_instructions) - 1 and response["finish_reason"] == "length": 56 | continue 57 | idx += num_prompt_agents + 1 58 | splitted_data = re.split(rf"{idx}\.\s+(Name|Goal):", inst) 59 | if len(splitted_data) != 5: 60 | continue 61 | else: 62 | name = splitted_data[2].strip() 63 | role = splitted_data[4].strip() 64 | # goals = splitted_data[6].strip() 65 | # goals = "" if goals.lower() == "" else goals 66 | # filter out roles that are too short or too long 67 | if len(role.split()) <= 3 or len(role.split()) > 150: 68 | continue 69 | # filter based on keywords that are not suitable for language models.
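# A candidate role survives only if it avoids every blacklisted word, does not
# start with punctuation, and begins with an ASCII character; anything else is
# silently dropped before being appended to `agents`.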
70 | blacklist = [ 71 | "kill", 72 | "harm", 73 | "discriminate", 74 | ] 75 | blacklist += [] 76 | if any(find_word_in_string(word, role) for word in blacklist): 77 | continue 78 | # filter those starting with punctuation 79 | if role[0] in string.punctuation: 80 | continue 81 | # filter those starting with non-english character 82 | if not role[0].isascii(): 83 | continue 84 | agents.append({"name": name, "goal": role}) 85 | return agents 86 | 87 | 88 | def find_word_in_string(w, s): 89 | return re.compile(r"\b({0})\b".format(w), flags=re.IGNORECASE).search(s) 90 | 91 | 92 | def generate_agents_data( 93 | output_dir="./", 94 | seed_tasks_path="./seed_tasks.jsonl", 95 | num_agents_to_generate=20, 96 | model_name="text-davinci-003", 97 | num_prompt_agents=5, 98 | request_batch_size=1, 99 | temperature=1.0, 100 | top_p=1.0, 101 | num_cpus=16, 102 | ): 103 | print ("generate_instruction_following_data") 104 | seed_tasks = [json.loads(l) for l in open(seed_tasks_path, "r")] 105 | seed_agent_data = [ 106 | {"name": t["name"], "goal": t["goal"]} 107 | for t in seed_tasks 108 | ] 109 | print(f"Loaded {len(seed_agent_data)} human-written seed agents") 110 | 111 | os.makedirs(output_dir, exist_ok=True) 112 | request_idx = 0 113 | # load the LM-generated instructions 114 | machine_agent_data = [] 115 | if os.path.exists(os.path.join(output_dir, "regen.json")): 116 | machine_agent_data = utils.jload(os.path.join(output_dir, "regen.json")) 117 | print(f"Loaded {len(machine_agent_data)} machine-generated agents") 118 | 119 | # similarities = {} 120 | scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False) 121 | 122 | # now let's generate new instructions! 123 | progress_bar = tqdm.tqdm(total=num_agents_to_generate) 124 | if machine_agent_data: 125 | progress_bar.update(len(machine_agent_data)) 126 | 127 | # first we tokenize all the seed instructions and generated machine instructions 128 | all_roles = [d["goal"] for d in seed_agent_data] + [ 129 | d["goal"] for d in machine_agent_data 130 | ] 131 | all_instruction_tokens = [scorer._tokenizer.tokenize(role) for role in all_roles] 132 | 133 | while len(machine_agent_data) < num_agents_to_generate: 134 | request_idx += 1 135 | 136 | batch_inputs = [] 137 | for _ in range(request_batch_size): 138 | # only sampling from the seed tasks 139 | prompt_agents = random.sample(seed_agent_data, num_prompt_agents) 140 | prompt = encode_prompt(prompt_agents) 141 | batch_inputs.append(prompt) 142 | decoding_args = utils.OpenAIDecodingArguments( 143 | temperature=temperature, 144 | n=1, 145 | max_tokens=3072, # hard-code to maximize the length. 
the requests will be automatically adjusted 146 | top_p=top_p, 147 | stop=["\n20", "20."], 148 | ) 149 | request_start = time.time() 150 | results = utils.openai_completion( 151 | prompts=batch_inputs, 152 | model_name=model_name, 153 | batch_size=request_batch_size, 154 | decoding_args=decoding_args, 155 | logit_bias={"50256": -100}, # prevent the <|endoftext|> token from being generated 156 | ) 157 | request_duration = time.time() - request_start 158 | 159 | process_start = time.time() 160 | agent_data = [] 161 | for result in results: 162 | new_agents = post_process_gpt_response(num_prompt_agents, result) 163 | agent_data += new_agents 164 | 165 | total = len(agent_data) 166 | keep = 0 167 | for agent_data_entry in agent_data: 168 | # computing similarity with the pre-tokenized instructions 169 | new_agent_tokens = scorer._tokenizer.tokenize(agent_data_entry["goal"]) 170 | with Pool(num_cpus) as p: 171 | rouge_scores = p.map( 172 | partial(rouge_scorer._score_lcs, new_agent_tokens), 173 | all_instruction_tokens, 174 | ) 175 | rouge_scores = [score.fmeasure for score in rouge_scores] 176 | most_similar_instructions = { 177 | all_roles[i]: rouge_scores[i] for i in np.argsort(rouge_scores)[-10:][::-1] 178 | } 179 | if max(rouge_scores) > 0.7: 180 | continue 181 | else: 182 | keep += 1 183 | agent_data_entry["most_similar_instructions"] = most_similar_instructions 184 | agent_data_entry["avg_similarity_score"] = float(np.mean(rouge_scores)) 185 | machine_agent_data.append(agent_data_entry) 186 | all_roles.append(agent_data_entry["goal"]) 187 | all_instruction_tokens.append(new_agent_tokens) 188 | progress_bar.update(1) 189 | process_duration = time.time() - process_start 190 | print(f"Request {request_idx} took {request_duration:.2f}s, processing took {process_duration:.2f}s") 191 | print(f"Generated {total} agents, kept {keep} agents") 192 | utils.jdump(machine_agent_data, os.path.join(output_dir, "self-gen.json")) 193 | 194 | 195 | def main(task, **kwargs): 196 | globals()[task](**kwargs) 197 | 198 | 199 | if __name__ == "__main__": 200 | fire.Fire(main) -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/generate_data_chat_api.py: -------------------------------------------------------------------------------- 1 | """ 2 | run: 3 | python -m generate_data_chat_api generate_agents_data \ 4 | --output_dir ./new_data \ 5 | --seed_tasks_path ./seed_tasks.jsonl \ 6 | --num_agents_to_generate 1000 \ 7 | --model_name="gpt-4" \ 8 | """ 9 | import time 10 | import json 11 | import os 12 | import random 13 | import re 14 | import string 15 | from functools import partial 16 | from multiprocessing import Pool 17 | 18 | import numpy as np 19 | import tqdm 20 | from rouge_score import rouge_scorer 21 | import utils 22 | 23 | import fire 24 | import pdb 25 | 26 | ''' 27 | The tools used by Auto-GPT are available at: 28 | https://github.com/Significant-Gravitas/Auto-GPT/tree/master/autogpt/commands 29 | ''' 30 | 31 | def encode_prompt(prompt_agents): 32 | """Encode multiple prompt instructions into a single string.""" 33 | prompt = open("./prompt.txt").read() + "\n" 34 | 35 | for idx, task_dict in enumerate(prompt_agents): 36 | (name, goal) = task_dict["name"], task_dict["goal"] 37 | if not goal: 38 | raise ValueError("every seed task must define a non-empty goal") 39 | prompt += f"###\n" 40 | prompt += f"{idx + 1}. Name: {name}\n" 41 | prompt += f"{idx + 1}. Goal:\n{goal}\n" 42 | prompt += f"###\n" 43 | prompt += f"{idx + 2}. 
Name:" 44 | return prompt 45 | 46 | 47 | def post_process_chat_gpt_response(num_prompt_agents, response): 48 | print ("post_process_gpt_response") 49 | if response is None: 50 | return [] 51 | raw_instructions = f"{num_prompt_agents+1}. Name:" + response['message']['content'] 52 | raw_instructions = re.split("###", raw_instructions) 53 | agents = [] 54 | for idx, inst in enumerate(raw_instructions): 55 | # if the decoding stops due to length, the last example is likely truncated so we discard it 56 | if idx == len(raw_instructions) - 1 and response["finish_reason"] == "length": 57 | continue 58 | idx += num_prompt_agents + 1 59 | splitted_data = re.split(f"{idx}\.\s+(Name|Goal):", inst) 60 | if len(splitted_data) != 5: 61 | continue 62 | else: 63 | name = splitted_data[2].strip() 64 | goal = splitted_data[4].strip() 65 | # filter out too short or too long role 66 | if len(goal.split()) <= 3 or len(goal.split()) > 150: 67 | continue 68 | # filter based on keywords that are not suitable for language models. 69 | blacklist = [ 70 | "kill", 71 | "harm", 72 | "discriminate", 73 | "racist", 74 | "figure", 75 | "plot", 76 | "chart", 77 | "image", 78 | "images", 79 | "graph", 80 | "graphs", 81 | "picture", 82 | "pictures", 83 | "file", 84 | "files", 85 | "draw", 86 | "plot", 87 | "go to", 88 | "video", 89 | "audio", 90 | "flowchart", 91 | "diagram", 92 | ] 93 | blacklist += [] 94 | if any(find_word_in_string(word, goal) for word in blacklist): 95 | continue 96 | # filter those starting with punctuation 97 | if goal[0] in string.punctuation: 98 | continue 99 | # filter those starting with non-english character 100 | if not goal[0].isascii(): 101 | continue 102 | agents.append({"name": name, "goal": goal}) 103 | return agents 104 | 105 | 106 | def find_word_in_string(w, s): 107 | return re.compile(r"\b({0})\b".format(w), flags=re.IGNORECASE).search(s) 108 | 109 | 110 | def generate_agents_data( 111 | output_dir="./", 112 | seed_tasks_path="./new_seed_tasks.jsonl", 113 | num_agents_to_generate=50, 114 | model_name="gpt-3.5-turbo", 115 | num_prompt_agents=8, 116 | temperature=1.0, 117 | top_p=1.0, 118 | num_cpus=8, 119 | ): 120 | print ("generate_instruction_following_data") 121 | seed_tasks = [json.loads(l) for l in open(seed_tasks_path, "r")] 122 | seed_agent_data = [ 123 | {"name": t["name"], "goal": t["goal"], "task_id": t["task_id"]} 124 | for t in seed_tasks 125 | ] 126 | print(f"Loaded {len(seed_agent_data)} human-written seed agents") 127 | 128 | os.makedirs(output_dir, exist_ok=True) 129 | request_idx = 0 130 | # load the LM-generated instructions 131 | machine_agent_data = [] 132 | machine_data_path = os.path.join(output_dir, "self-gen-batch1.json") 133 | if os.path.exists(machine_data_path): 134 | # machine_agent_data = utils.jload(machine_data_path) 135 | machine_agent_data = [json.loads(l) for l in open(machine_data_path, "r")] 136 | print(f"Loaded {len(machine_agent_data)} machine-generated agents") 137 | 138 | # similarities = {} 139 | scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False) 140 | 141 | # now let's generate new instructions! 
142 | progress_bar = tqdm.tqdm(total=num_agents_to_generate) 143 | if machine_agent_data: 144 | progress_bar.update(len(machine_agent_data)) 145 | 146 | previous_goals = [] 147 | if os.path.isfile("previous_goals.json"): 148 | with open("previous_goals.json", 'r') as f: 149 | previous_goals = json.load(f) 150 | 151 | # first we tokenize all the seed instructions and generated machine instructions 152 | all_goals = [d["goal"] for d in seed_agent_data] + [ 153 | d["goal"] for d in machine_agent_data 154 | ] + previous_goals 155 | all_goals = list(set(all_goals)) 156 | all_instruction_tokens = [scorer._tokenizer.tokenize(role) for role in all_goals] 157 | 158 | while len(machine_agent_data) < num_agents_to_generate: 159 | request_idx += 1 160 | 161 | # only sampling from the seed tasks 162 | prompt_agents = random.sample(seed_agent_data, num_prompt_agents) 163 | prompt = encode_prompt(prompt_agents) 164 | 165 | decoding_args = utils.OpenAIDecodingArguments( 166 | temperature=temperature, 167 | n=1, 168 | max_tokens=3072, # hard-code to maximize the length. the requests will be automatically adjusted 169 | top_p=top_p, 170 | stop=["\n60", "60."], 171 | ) 172 | request_start = time.time() 173 | results = utils.openai_completion( 174 | prompts=prompt, 175 | model_name=model_name, 176 | batch_size=1, 177 | decoding_args=decoding_args, 178 | logit_bias={"100257": -100}, # prevent the <|endoftext|> from being generated 179 | # "100265":-100, "100276":-100 for <|im_end|> and token 180 | ) 181 | request_duration = time.time() - request_start 182 | 183 | process_start = time.time() 184 | agent_data = post_process_chat_gpt_response(num_prompt_agents, results) 185 | 186 | total = len(agent_data) 187 | keep = 0 188 | for agent_data_entry in agent_data: 189 | # computing similarity with the pre-tokenzied instructions 190 | new_agent_tokens = scorer._tokenizer.tokenize(agent_data_entry["goal"]) 191 | with Pool(num_cpus) as p: 192 | rouge_scores = p.map( 193 | partial(rouge_scorer._score_lcs, new_agent_tokens), 194 | all_instruction_tokens, 195 | ) 196 | rouge_scores = [score.fmeasure for score in rouge_scores] 197 | # most_similar_instructions = { 198 | # all_goals[i]: rouge_scores[i] for i in np.argsort(rouge_scores)[-10:][::-1] 199 | # } 200 | max_score = max(rouge_scores) 201 | if max_score > 0.40: 202 | continue 203 | else: 204 | keep += 1 205 | # agent_data_entry["most_similar_instructions"] = most_similar_instructions 206 | # agent_data_entry["avg_similarity_score"] = float(np.mean(rouge_scores)) 207 | agent_data_entry["max_similarity_score"] = max_score 208 | agent_data_entry["seed_tasks"] = [task["task_id"] for task in prompt_agents] 209 | machine_agent_data.append(agent_data_entry) 210 | all_goals.append(agent_data_entry["goal"]) 211 | all_instruction_tokens.append(new_agent_tokens) 212 | progress_bar.update(1) 213 | process_duration = time.time() - process_start 214 | print(f"Request {request_idx} took {request_duration:.2f}s, processing took {process_duration:.2f}s") 215 | print(f"Generated {total} agents, kept {keep} agents") 216 | utils.jdump(machine_agent_data, os.path.join(output_dir, "self-gen.json")) 217 | 218 | 219 | def main(task, **kwargs): 220 | globals()[task](**kwargs) 221 | 222 | 223 | if __name__ == "__main__": 224 | fire.Fire(main) -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/prompt.txt: -------------------------------------------------------------------------------- 1 | You have been asked to generate a 
set of 60 diverse tasks, each with defined goals and descriptions. 2 | These tasks will be given to a GPT model to carry out multi-step reasoning, planning, and execution of tools, without requiring user assistance. 3 | Today is Apr 30, 2023. 4 | 5 | Here are the requirements: 6 | 1. To maximize diversity, try not to repeat verbs and use diverse language for the goals. 7 | 2. The goals should be in English. 8 | 3. The goals should be 1 to 3 sentences long. Either an imperative sentence or a question is permitted. 9 | 4. Refrain from requesting the accomplishment of highly abstract or open-ended tasks, such as generating wealth or solving global issues. 10 | 5. The goal should be entirely autonomous, utilizing the given tools and eliminating the need for real-time user interaction. 11 | 6. Bear in mind that completing each goal may involve multiple steps instead of just a single step executed by the language model and tools. 12 | 7. The goal cannot be directly answerable and solvable using the language model. For example, "What is the capital of France?" and "Translate a paragraph of text from English to Spanish" are not acceptable. 13 | 8. The goal should include time-sensitive elements that may require up-to-date knowledge beyond what the current language model possesses, such as identifying upcoming events or reporting the latest news. 14 | 15 | The following tools will be provided and the system will return observations: 16 | 1. {search: useful for when you need to answer questions about current events. You should ask targeted questions, args: "query"} 17 | 2. {finish: use this to signal that you have finished all your objectives, args: "response": "Final Answer: get the final answer to achieve the original goal. "} 18 | 19 | List of 60 goals: 20 | -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | rouge_score 3 | fire 4 | openai 5 | transformers>=4.26.1 6 | torch 7 | sentencepiece 8 | tokenizers==0.12.1 9 | wandb 10 | -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/run_genenerate_data.sh: -------------------------------------------------------------------------------- 1 | # Use this for chat API: 2 | python -m generate_data_chat_api generate_agents_data \ 3 | --output_dir ./new_data \ 4 | --seed_tasks_path ./seed_tasks.jsonl \ 5 | --num_agents_to_generate 1000 \ 6 | --model_name="gpt-4" -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/seed_tasks_test.jsonl: -------------------------------------------------------------------------------- 1 | {"name": "urgent_car", "goal": "Find me the top 5 urgent care centers near the 750 7th Avenue. List the distances for each of them.", "task_id": "39f5f08e76fe4a66"} 2 | {"name": "pitchcer_filter", "goal": "I just got a Denali Water Pitcher (standard filter). 
When should I change the filter?", "task_id": "6b34a2619ead4216"} 3 | {"name": "a16z_investments", "goal": "List 3 most significant investments of a16z and their current combined valuation.", "task_id": "b475e11d5fd94049"} 4 | {"name": "musci_band", "goal": "What are the names of the current members of American heavy metal band who wrote the music for Hurt Locker The Musical?", "task_id": "faeadb6d51bc4c1a"} 5 | {"name": "international_traveler", "goal": "Recommend the top 5 national parks in the United States for international travelers. Provide a list of estimated cost for each destination.", "task_id": "6441ee70888e48f0"} 6 | {"name": "mars_explorer", "goal": "Come up with a reasonable, actionable plan to get me to Mars.", "task_id": "63d60d723b0e468b"} 7 | {"name": "calculus_mathematicians", "goal": "Name a few mathematicians who contributed to the development of Calculus and also provide their contributions.", "task_id": "aa80dbf861bc437b"} 8 | {"name": "car_purchaser", "goal": "I am deciding between purchasing a 2023 Ioniq 5 and a 2023 Solterra. I repeatedly have to drive from SF to LA so I need a car with good battery life and charges quickly. Which option should I choose?", "task_id": "174497295ebd4e32"} -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/seed_tasks_train.jsonl: -------------------------------------------------------------------------------- 1 | {"name": "holiday_chef", "goal": "Find the upcoming holiday and generate a home recipe to celebrate.", "task_id": "5c28888898184a75"} 2 | {"name": "pitt_weather_reporter", "goal": "Get weather forcast for San Francisco tomorrow.", "task_id": "479ce326b693458e"} 3 | {"name": "current_weather", "goal": "What is the current temperature and weather condition in Paris?", "task_id": "10825a7b70f24cdc"} 4 | {"name": "ucl_reporter", "goal": "Who is the most recent NBA MVP? Which team does he play for? What are his career stats?", "task_id": "67e21d88864040b6"} 5 | {"name": "nba_reporter", "goal": "Who is the leading goalscorer in the current season of the UEFA Champions League?", "task_id": "201cbef897f04aa8"} 6 | {"name": "sports_score", "goal": "What was the final score of the latest NBA game between the Los Angeles Lakers and Memphis Grizzlies?", "task_id": "27d3ff8da66b4638"} 7 | {"name": "sports_odds_maker", "goal": "What are the betting odds for the next UFC fight between Conor McGregor?", "task_id": "a4857bf2bd504c68"} 8 | {"name": "usopen_champion_bio", "goal": "What is the hometown of the reigning champion for U.S. Open, men's single?", "task_id": "19d05341e4cb4acd"} 9 | {"name": "stock_invester", "goal": "Can you find the stock that paid out highest dividend recently?. Should have source to support the reasons. Recommend at most three stocks", "task_id": "31e07ce5dc6f4413"} 10 | {"name": "active_stocks", "goal": "Recommend me three most actively traded stocks recently.", "task_id": "21651d01eb4a4a3e"} 11 | {"name": "stock_prices", "goal": "Get the current stock price for Apple Inc. (AAPL). How large is its volume today?", "task_id": "6ce1eb431d434730"} 12 | {"name": "best_movie", "goal": "What is the highest grossing movie of all time. Find one movie, including the movie name, director and description.", "task_id": "eb4d5283fd4b4250"} 13 | {"name": "family_movie", "goal": "Recommend a kid-friendly movie that is playing at a theater near Sunnyvale. 
Give me the showtimes and a link to purchase the tickets", "task_id": "1d8aae1cafb648b8"} 14 | {"name": "book_recommendation", "goal": "Recommend the 10 best-selling books on Amazon. Include books name and author.", "task_id": "9b33cf57d4194577"} 15 | {"name": "crime_reporter", "goal": "Give me the most recent crime alert in New York City. Include the crime time, location and description", "task_id": "405d201e150c4ea2"} 16 | {"name": "concert_location", "goal": "Give me information about the next concert performed in Las Vegas.", "task_id": "98e0661782424bdc"} 17 | {"name": "check_flights_nonstop", "goal": "What is the earlist non-stop flight from Seattle to San Francisco?", "task_id": "da2b7a5e52084210"} 18 | {"name": "flight_status", "goal": "What is the status of flight AA101 from New York to London scheduled today? Is there a delay or cancellation?", "task_id": "fb719d0ca75647c6"} 19 | {"name": "laptop_recommend", "goal": "Recommend me some laptops suitable for UI designers under $2000. Please include brand and price.", "task_id": "7cfbd8e07d764745"} 20 | {"name": "podcast_planner", "goal": "Write a podcast outline on a recent political event", "task_id": "1aaf9a6a8d0c494e"} 21 | {"name": "langchain", "goal": "What is the LangChain framework? Give me an example in Python.", "task_id": "f9f9f399a96f4b9b"} 22 | {"name": "Ilya_paper", "goal": "Can you give me the title of the paper the Ilya Sutskever wrote when he was about 30?", "task_id": "dee6e17fc9054a63"} 23 | {"name": "tom_cruise_pay", "goal": "How much did Tom Cruise make on his latest movie? Can you include both the movie names and the amount of money?", "task_id": "3749fd7592e846de"} 24 | {"name": "matt_damon_director", "goal": "Who is the director of the next movie with Matt Damon?", "task_id": "ace796ed525d4ec2"} 25 | {"name": "holiday_planner", "goal": "What are the top three places to visit in Europe during the summer of 2022? Provide a brief description of each location, including their popular attractions and must-see landmarks.", "task_id": "8d58b4facee24ed2"} 26 | {"name": "holiday_destination", "goal": "What is the best holiday destination for budget travelers in United States during summer of 2023? Provide an estimated cost for a week's stay.", "task_id": "bf1663cc8c844843"} 27 | {"name": "find_restaurant", "goal": "Suggest a fancy Italian restaurant in New York City suitable for a romantic dinner. Provide the address, phone number, and website for the restaurant.", "task_id": "4ae39cbeff834050"} 28 | {"name": "book_reader", "goal": "What is the most popular science-fiction book on the New York Times Best Sellers list? Provide the author's name, book title, and a brief description of the plot.", "task_id": "8f1f5ada0058409b"} 29 | {"name": "text_summarizer", "goal": "Can you find some latest articles on the COVID-19 vaccine?", "task_id": "97563fc2f3e04351"} 30 | {"name": "job_searcher", "goal": "Find a job opening for a software engineer in San Francisco with competitive salary, at a company with at least 100 employees.", "task_id": "ee8539e503ae4854"} 31 | {"name": "music_listener", "goal": "What is currently the #1 song on the Billboard Hot 100 chart? Provide the song title, artist, and genre.", "task_id": "62461e3a44dc43b2"} 32 | {"name": "university_finder", "goal": "Rank the top three universities in the United States for computer science programs in year 2023, including their location, tuition fees, and notable alumni. 
Provide a brief description of each university's computer science program.", "task_id": "d1bb40f375e645a6"} 33 | {"name": "video_game_reviews", "goal": "What are the top three video games released in the last six months, according to their user ratings on Metacritic? Provide a brief summary of each game's strengths and weaknesses.", "task_id": "bb44d84662824824"} 34 | {"name": "news_headline", "goal": "What is the headline of the most viewed news articles today?", "task_id": "8b2f4c86fb45452a"} 35 | {"name": "upcoming_concerts", "goal": "When and where is the upcoming concert for Taylor Swift? Share a link to purchase tickets.", "task_id": "07f3c15f494340a3"} 36 | {"name": "stock_analyst", "goal": "What's the stock price for Tesla as of yesterday's closing bell? Provide the stock symbol, closing price, and a brief analysis of the stock's performance.", "task_id": "3ee477d14f3944ea"} 37 | {"name": "book_store_locator", "goal": "Find the nearest Barnes & Noble bookstore to zip code 15213.", "task_id": "4b5f52a47d8c487e"} 38 | {"name": "covid_statistics", "goal": "Provide the current COVID-19 statistics for the United States.", "task_id": "7b86aba330614722"} 39 | {"name": "pet_adoption_finder", "goal": "What is the best animal shelter that is currently allowing adoptions of cats or dogs in Seattle?", "task_id": "ab57b69aeee84e05"} 40 | {"name": "car_buyer", "goal": "Find a high-rated sedan car model that has been released within the last year. Provide information about the car's fuel efficiency and safety features.", "task_id": "f1dfe422c7924439"} 41 | {"name": "latest_tech_gadgets", "goal": "What are the newest tech gadgets on the market?", "task_id": "b92921736a8b4b23"} 42 | {"name": "popular_streamer", "goal": "Who is the streamer that has the most subscribers currently one of the most popular and what games do they stream? How many?", "task_id": "8bde59a428a943bc"} 43 | {"name": "TV_show_recommendation", "goal": "Recommend a drama TV show that is currently streaming on Netflix. Provide a brief plot summary and the rating.", "task_id": "7f012ca4be67485c"} 44 | {"name": "age_calculator", "goal": "What is the sum of ages of the wives of Barack Obama and Donald Trump?", "task_id": "e30679fc81a641a2"} 45 | {"name": "volleyball_winnder", "goal": "Which team won in women's volleyball in the Summer Olympics that was held in London?", "task_id": "7a6dd09052194d1c"} 46 | {"name": "olympic_medal", "goal": "Which country won the most gold medals in the last Winter Olympics?", "task_id": "50ae10f10dad46b8"} 47 | {"name": "vegan_bagel", "goal": "Recommend a bagel shop near the Strip district in Pittsburgh that offer vegan food", "task_id": "d25eccefd8af462b"} -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/seed_tasks_valid.jsonl: -------------------------------------------------------------------------------- 1 | {"name": "autogpt_adoption", "goal": "What are the main reasons AutoGPT hasn't been adopted by the general public?", "task_id": "7bc03fce036e4302"} 2 | {"name": "tax_rate", "goal": "Find the latest local income tax rate for an individual working in the City of Pittsburgh. What is the deadline for tax return filing this year?", "task_id": "61b6dd397931404a"} 3 | {"name": "female_gymnast", "goal": "I love the Olympics and Women's Gymnastics is my favorite event. 
Which female gymnast alive today has the most medals?", "task_id": "ff3a57e32a324060"} 4 | {"name": "columbus_physicians", "goal": "Who are the top 5 primary care physicians in Columbus, OH? Include a detailed list for each physician and their expertise.", "task_id": "a42c239a3b8a44df"} 5 | {"name": "sangiovese_wine", "goal": "Suggest a couple of good, budget-friendly wines for someone who likes Sangiovese.", "task_id": "9fc48092050e4fef"} 6 | {"name": "khaleesi_dragons", "goal": "Name Khaleesi's three dragons and their character development throughout the seasons of Game of Thrones.", "task_id": "38a57d996ffc4d01"} 7 | {"name": "remote_worker", "goal": "If I am making \u00a3100000 working remotely in London, how much more spare money would I have if I had the same remote job in Bath?", "task_id": "ba5a3e97724f44fc"} 8 | {"name": "llama2_service", "goal": "I would like to deploy an inference service of LLAMA 2 for a week. Which cloud platforms should I choose and what are the price plans for different options?", "task_id": "e9db9ae025f84c8b"} 9 | {"name": "netflix_movie", "goal": "Provide a list of top-rated movies released this year from Netflix. Summarize their plots and reviews.", "task_id": "71cd3750f6854f61"} 10 | {"name": "civic_driver", "goal": "Suppose that I get my 2023 Honda Civic serviced, and then drive from Maine to San Diego. How many more miles should I drive before getting my car serviced again?", "task_id": "e41ab586a6294cc0"} 11 | {"name": "tbill_yield", "goal": "What is the current estimated yield for the treasury bills issued recently? Compare it with the yields of the treasury bills issued last month.", "task_id": "66898415ed24453a"} 12 | {"name": "book_reader", "goal": "Find me a book that explores similar themes to Spain's highest grossing movie in 2022. Why should I read it?", "task_id": "c1a7411b35544256"} 13 | {"name": "car_renter", "goal": "I plan to rent a car to drive from San Jose, CA to Los Angeles, CA with two friends. Recommend three economical options from different car rental companies and provide the estimated budget for each of the options.", "task_id": "2404514031784c7c"} 14 | {"name": "pittsburgh_events", "goal": "List 3 popular events happening in the city of Pittsburgh over the labor day weekend. Also provide entry fee, accessibility, bag policy, parking information etc about each of the events. Please try to include events involving food.", "task_id": "ff9f418eb3424e7a"} 15 | {"name": "nfl_team", "goal": "Which NFL team has the youngest opening day roster in history? 
What was that teams record at the end of the year?", "task_id": "bdd0b5aafac6442c"} -------------------------------------------------------------------------------- /autoagents/data/generate_action_tasks/utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import math 4 | import os 5 | import io 6 | import sys 7 | import time 8 | import json 9 | from typing import Optional, Sequence, Union 10 | 11 | import openai 12 | import tqdm 13 | from openai import openai_object 14 | import copy 15 | import pdb 16 | 17 | StrOrOpenAIObject = Union[str, openai_object.OpenAIObject] 18 | 19 | openai_org = os.getenv("OPENAI_ORG") 20 | if openai_org is not None: 21 | openai.organization = openai_org 22 | logging.warning(f"Switching to organization: {openai_org} for OAI API key.") 23 | 24 | 25 | @dataclasses.dataclass 26 | class OpenAIDecodingArguments(object): 27 | max_tokens: int = 1800 28 | temperature: float = 0.2 29 | top_p: float = 1.0 30 | n: int = 1 31 | stream: bool = False 32 | stop: Optional[Sequence[str]] = None 33 | presence_penalty: float = 0.0 34 | frequency_penalty: float = 0.0 35 | # logprobs: Optional[int] = None 36 | 37 | 38 | def openai_completion( 39 | prompts: Union[str, Sequence[str], Sequence[dict], dict], 40 | decoding_args: OpenAIDecodingArguments, 41 | model_name="text-davinci-003", 42 | sleep_time=2, 43 | batch_size=1, 44 | max_instances=sys.maxsize, 45 | max_batches=sys.maxsize, 46 | return_text=False, 47 | **decoding_kwargs, 48 | ): 49 | """Decode with OpenAI API. 50 | 51 | Args: 52 | prompts: A string or a list of strings to complete. If it is a chat model the strings should be formatted 53 | as explained here: https://github.com/openai/openai-python/blob/main/chatml.md. If it is a chat model 54 | it can also be a dictionary (or list thereof) as explained here: 55 | https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb 56 | decoding_args: Decoding arguments. 57 | model_name: Model name. Can be either in the format of "org/model" or just "model". 58 | sleep_time: Time to sleep once the rate-limit is hit. 59 | batch_size: Number of prompts to send in a single request. Only for non chat model. 60 | max_instances: Maximum number of prompts to decode. 61 | max_batches: Maximum number of batches to decode. This argument will be deprecated in the future. 62 | return_text: If True, return text instead of full completion object (which contains things like logprob). 63 | decoding_kwargs: Additional decoding arguments. Pass in `best_of` and `logit_bias` if you need them. 64 | 65 | Returns: 66 | A completion or a list of completions. 67 | Depending on return_text, return_openai_object, and decoding_args.n, the completion type can be one of 68 | - a string (if return_text is True) 69 | - an openai_object.OpenAIObject object (if return_text is False) 70 | - a list of objects of the above types (if decoding_args.n > 1) 71 | """ 72 | is_chat_model = "gpt-3.5" in model_name or "gpt-4" in model_name 73 | is_single_prompt = isinstance(prompts, (str, dict)) 74 | if is_single_prompt: 75 | prompts = [prompts] 76 | 77 | if max_batches < sys.maxsize: 78 | logging.warning( 79 | "`max_batches` will be deprecated in the future, please use `max_instances` instead." 80 | "Setting `max_instances` to `max_batches * batch_size` for now." 
81 | ) 82 | max_instances = max_batches * batch_size 83 | 84 | prompts = prompts[:max_instances] 85 | num_prompts = len(prompts) 86 | prompt_batches = [ 87 | prompts[batch_id * batch_size : (batch_id + 1) * batch_size] 88 | for batch_id in range(int(math.ceil(num_prompts / batch_size))) 89 | ] 90 | 91 | completions = [] 92 | for batch_id, prompt_batch in tqdm.tqdm( 93 | enumerate(prompt_batches), 94 | desc="prompt_batches", 95 | total=len(prompt_batches), 96 | ): 97 | batch_decoding_args = copy.deepcopy(decoding_args) # cloning the decoding_args 98 | 99 | while True: 100 | try: 101 | shared_kwargs = dict( 102 | model=model_name, 103 | **batch_decoding_args.__dict__, 104 | **decoding_kwargs, 105 | ) 106 | if is_chat_model: 107 | completion_batch = openai.ChatCompletion.create( 108 | messages=[ 109 | {"role": "system", "content": "You are a helpful assistant."}, 110 | {"role": "user", "content": prompt_batch[0]} 111 | ], 112 | **shared_kwargs 113 | ) 114 | else: 115 | completion_batch = openai.Completion.create(prompt=prompt_batch, **shared_kwargs) 116 | 117 | choices = completion_batch.choices 118 | 119 | for choice in choices: 120 | choice["total_tokens"] = completion_batch.usage.total_tokens 121 | completions.extend(choices) 122 | break 123 | except openai.error.OpenAIError as e: 124 | logging.warning(f"OpenAIError: {e}.") 125 | if "Please reduce your prompt" in str(e): 126 | batch_decoding_args.max_tokens = int(batch_decoding_args.max_tokens * 0.8) 127 | logging.warning(f"Reducing target length to {batch_decoding_args.max_tokens}, Retrying...") 128 | else: 129 | logging.warning("Hit request rate limit; retrying...") 130 | time.sleep(sleep_time) # Annoying rate limit on requests. 131 | 132 | if return_text: 133 | completions = [completion.text for completion in completions] 134 | if decoding_args.n > 1: 135 | # make completions a nested list, where each entry is a consecutive decoding_args.n of original entries. 136 | completions = [completions[i : i + decoding_args.n] for i in range(0, len(completions), decoding_args.n)] 137 | if is_single_prompt: 138 | # Return non-tuple if only 1 input and 1 generation. 139 | (completions,) = completions 140 | return completions 141 | 142 | 143 | def _make_w_io_base(f, mode: str): 144 | if not isinstance(f, io.IOBase): 145 | f_dirname = os.path.dirname(f) 146 | if f_dirname != "": 147 | os.makedirs(f_dirname, exist_ok=True) 148 | f = open(f, mode=mode) 149 | return f 150 | 151 | 152 | def _make_r_io_base(f, mode: str): 153 | if not isinstance(f, io.IOBase): 154 | f = open(f, mode=mode) 155 | return f 156 | 157 | 158 | def jdump(obj, f, mode="w", indent=4, default=str): 159 | """Dump a str or dictionary to a file in json format. 160 | 161 | Args: 162 | obj: An object to be written. 163 | f: A string path to the location on disk. 164 | mode: Mode for opening the file. 165 | indent: Indent for storing json dictionaries. 166 | default: A function to handle non-serializable entries; defaults to `str`. 
167 | """ 168 | f = _make_w_io_base(f, mode) 169 | if isinstance(obj, (dict, list)): 170 | json.dump(obj, f, indent=indent, default=default) 171 | elif isinstance(obj, str): 172 | f.write(obj) 173 | else: 174 | raise ValueError(f"Unexpected type: {type(obj)}") 175 | f.close() 176 | 177 | 178 | def jload(f, mode="r"): 179 | """Load a .json file into a dictionary.""" 180 | f = _make_r_io_base(f, mode) 181 | jdict = json.load(f) 182 | f.close() 183 | return jdict 184 | -------------------------------------------------------------------------------- /autoagents/eval/README.md: -------------------------------------------------------------------------------- 1 | Use test.py to evaluate models: 2 | 3 | ``` 4 | PYTHONPATH=`pwd` python autoagents/eval/test.py --help 5 | ``` 6 | ``` 7 | usage: test.py [-h] [--model MODEL] [--temperature TEMPERATURE] [--agent [{ddg,wiki}]] [--persist-logs] 8 | [--dataset [{default,hotpotqa,ft,hf,bamboogle}]] [--eval] [--prompt-version [{v2,v3}]] [--slice SLICE] 9 | 10 | optional arguments: 11 | -h, --help show this help message and exit 12 | --model MODEL model to be tested 13 | --temperature TEMPERATURE 14 | model temperature 15 | --agent [{ddg,wiki}] which action agent we want to interact with(default: ddg) 16 | --persist-logs persist logs on disk, enable this feature for later eval purpose 17 | --log-save-dir LOG_SAVE_DIR 18 | dir to save logs 19 | --dataset [{default,hotpotqa,ft,hf,bamboogle}] 20 | which dataset we want to interact with(default: default) 21 | --eval enable automatic eval 22 | --prompt-version [{v2,v3}] 23 | which version of prompt to use(default: v2) 24 | --slice SLICE slice the dataset from left, question list will start from index 0 to slice - 1 25 | ``` 26 | Sample command to eval on Hotpotqa dataset: 27 | ``` 28 | PYTHONPATH=`pwd` python autoagents/eval/test.py --model gpt-4 --temperature 0 --agent wiki --persist-logs --dataset hotpotqa --prompt-version v2 --eval 29 | ``` 30 | 31 | Sample command to eval on Bamboogle dataset: 32 | ``` 33 | PYTHONPATH=`pwd` python autoagents/eval/test.py --model gpt-4 --temperature 0 --agent ddg --persist-logs --dataset bamboogle --prompt-version v2 --eval 34 | ``` 35 | These commands will generate model logs under `data` folders automatically and run evaluation scripts on those logs. 36 | 37 | 38 | ## Common Metrics 39 | 40 | ### general counts 41 | 42 | - total_logs 43 | 44 | Total number of log json files evaluated 45 | 46 | - total_steps 47 | 48 | Total number of steps in all log files evaluated 49 | 50 | - total_rewrites 51 | 52 | Total number of rewrites triggered 53 | 54 | - total_valid 55 | 56 | Number of valid log files. In most cases, a log is valid when it does not contain any errors. 57 | 58 | - valid_steps 59 | 60 | Aggregated number of steps in valid log files. 61 | 62 | - search_invoked 63 | 64 | Number of times `Tool_Search` or `Tool_Wikipedia` is invoked. 65 | 66 | - notepad_invoked 67 | 68 | Number of times `Tool_Notepad` is invoked. 69 | 70 | - Endwith_{action/tool} 71 | 72 | Number of times a conversation ends with a specific tool. 73 | 74 | - visit_in_plan 75 | 76 | Number of plans that start with `Visit` 77 | 78 | - len_hist 79 | 80 | Aggregated length of history trace 81 | 82 | - duplicate_actions 83 | 84 | Number of duplicate {action}+{action_inputs} pairs 85 | 86 | - Finish_with_dups 87 | 88 | Number of duplicate `Tool_Finish`+{action_inputs} pairs 89 | 90 | - average_answer_missing 91 | 92 | The ratio of times when the agent fails to produce a final answer for an input sample. 
93 | 94 | - average_steps 95 | 96 | Average number of steps in a conversation to reach the final answer. 97 | 98 | - total_samples 99 | 100 | The total number of samples/goals evaluated. 101 | 102 | - finished_samples 103 | 104 | The number of samples where the agent is able to call `Tool_Finish` for it. 105 | 106 | ### error counts 107 | 108 | Count the number of times a specific pattern of error occurs in the error log. 109 | 110 | - invalid_tools_error 111 | 112 | Check whether error log contains "Invalid tool requested by the model.". 113 | 114 | - context_len_error 115 | 116 | Check whether error log contains "This model's maximum context length". 117 | 118 | - dns_error 119 | 120 | Check whether error log contains "[Errno -3] Temporary failure in name resolution". 121 | 122 | - parse_error 123 | 124 | Check whether error log contains "Could not parse LLM output:". 125 | 126 | - rate_limit_error 127 | 128 | Check whether error log contains "Rate limit reached for ". 129 | 130 | - connection_error 131 | 132 | Check whether error log contains "[Errno 111] Connection refused". 133 | 134 | - other_error 135 | 136 | Any other kinds of uncaught exceptions will be marked as other_error. 137 | 138 | ### plan patterns 139 | 140 | Occurrence of action sequences 141 | 142 | - Tool_Search->Tool_Notepad 143 | 144 | - Tool_Search->Tool_Search->Tool_Notepad 145 | 146 | - Tool_Search->Tool_Search->Tool_Search->Tool_Notepad 147 | 148 | - … 149 | 150 | ### histograms 151 | 152 | - len_history_trace 153 | 154 | histogram of the lengths of history trace 155 | 156 | - len_initial_plan 157 | 158 | histogram of the lengths of initial plans 159 | -------------------------------------------------------------------------------- /autoagents/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutoLLM/AutoAgents/e1210c7884f951f1b90254e40c162e81fc1442f3/autoagents/eval/__init__.py -------------------------------------------------------------------------------- /autoagents/eval/bamboogle.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import pprint 5 | import shutil 6 | 7 | from autoagents.data.dataset import BAMBOOGLE 8 | from autoagents.eval.metrics import get_common_stats 9 | from autoagents.eval.hotpotqa.eval_async import check_answer_equivalency 10 | from autoagents.agents.utils.constants import LOG_SAVE_DIR 11 | from tqdm import tqdm 12 | from langchain.chat_models import ChatOpenAI 13 | 14 | 15 | async def eval(eval_results_path: str=LOG_SAVE_DIR): 16 | files = glob.glob(f"{eval_results_path}/*.json") 17 | evalllm = ChatOpenAI( 18 | openai_api_key=os.getenv("OPENAI_API_KEY"), 19 | openai_organization=os.getenv("OPENAI_API_ORG"), 20 | temperature=0, 21 | model="gpt-3.5-turbo", 22 | request_timeout=120 23 | ) 24 | print(f"Found {len(files)} log files! 
Starts to analyze......") 25 | common_stats = get_common_stats(files) 26 | print(common_stats) 27 | accuracy = 0 28 | correct_res_dir, wrong_res_dir, err_res_dir = f"{eval_results_path}-eval/correct", f"{eval_results_path}-eval/wrong", f"{eval_results_path}-eval/error" 29 | os.makedirs(correct_res_dir, exist_ok=True) 30 | os.makedirs(wrong_res_dir, exist_ok=True) 31 | os.makedirs(err_res_dir, exist_ok=True) 32 | for file in tqdm(files): 33 | finish = False 34 | with open(file, "r") as f: 35 | log_data = json.load(f) 36 | has_error = any([True if "error" in entry else False for entry in log_data]) 37 | for entry in log_data: 38 | if not has_error: 39 | if "goal" in entry: 40 | question = entry["goal"] 41 | if "conversations" in entry: 42 | output = json.loads(entry["conversations"][-1]["value"]) 43 | if output["action"] == "Tool_Finish": 44 | finish = True 45 | action_input = output["action_input"] 46 | for i in range(len(BAMBOOGLE["questions"])): 47 | if question == BAMBOOGLE["questions"][i]: 48 | answer = BAMBOOGLE["answers"][i] 49 | resp_obj = await check_answer_equivalency(question, answer, action_input, evalllm) 50 | is_correct = int(resp_obj.get("is_inferable", 0)) 51 | if is_correct: 52 | shutil.copy2(file, correct_res_dir) 53 | else: 54 | shutil.copy2(file, wrong_res_dir) 55 | accuracy += is_correct 56 | else: 57 | shutil.copy2(file, err_res_dir) 58 | if not finish: 59 | shutil.copy2(file, wrong_res_dir) 60 | counts = common_stats["counts"] 61 | total_samples = counts["total_samples"] 62 | finished_samples = counts["finished_samples"] 63 | print(f'accuracy overall is {accuracy}/{total_samples}={accuracy/total_samples}') 64 | print(f'accuracy on finished samples is {accuracy}/{finished_samples}={accuracy/finished_samples}') 65 | counts["accuracy on finished samples"] = accuracy/finished_samples 66 | counts["accuracy"] = accuracy/total_samples 67 | counts["average_answer_missing"] = (total_samples - finished_samples) / total_samples 68 | pprint.pprint(common_stats) 69 | with open(f"{eval_results_path}-eval/stats.json", "w") as f: 70 | json.dump(common_stats, f) 71 | -------------------------------------------------------------------------------- /autoagents/eval/hotpotqa/README.md: -------------------------------------------------------------------------------- 1 | # Search LLM Evaluation 2 | 3 | ## Overview 4 | 5 | ### Dataset 6 | - [HotpotQA](https://hotpotqa.github.io/) 7 | - [Fullwiki Dev Set](http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json) 8 | 9 | ### Metrics 10 | - Recall/Precision/F1 11 | - Mean Reciprocal Rank (MRR) 12 | - Accuracy 13 | - Missing Rate 14 | 15 | ## Results 16 | 17 | - Overall 18 | | Metric | LLM | 30 Samples | 100 Samples | 500 Samples | 19 | | --- | --- | --- | --- | --- | 20 | | LLM Accuracy | GPT-3.5-turbo | 0.4 | 0.35 | 0.328| 21 | | LLM Accuracy | GPT-4 | **0.7** | **0.55** | **0.414** | 22 | | Supporting facts recall | GPT-3.5-turbo | 0.3833 | 0.355 | 0.313 | 23 | | Supporting facts recall | GPT-4 | **0.5333** | **0.44** | **0.353** | 24 | | Max MRR | GPT-3.5-turbo | 0.2606 | 0.2862 | 0.2412 | 25 | | Max MRR | GPT-4 | **0.3531** | **0.3229** | **0.2740** | 26 | | First MRR | GPT-3.5-turbo | 0.2522 | 0.2799 | 0.2355 | 27 | | First MRR | GPT-4 | **0.3531** | **0.3204** | **0.2704** | 28 | | Last MRR | GPT-3.5-turbo | 0.2592 | 0.2743 | 0.2288 | 29 | | Last MRR | GPT-4 | **0.3481** | **0.3127** | **0.2663** | 30 | 31 | - Only on output with final answers 32 | | Metric | LLM | 30 Samples | 100 Samples | 500 Samples | 33 | | --- | --- | 
14 | 
15 | ## Results
16 | 
17 | - Overall
18 | | Metric | LLM | 30 Samples | 100 Samples | 500 Samples |
19 | | --- | --- | --- | --- | --- |
20 | | LLM Accuracy | GPT-3.5-turbo | 0.4 | 0.35 | 0.328 |
21 | | LLM Accuracy | GPT-4 | **0.7** | **0.55** | **0.414** |
22 | | Supporting facts recall | GPT-3.5-turbo | 0.3833 | 0.355 | 0.313 |
23 | | Supporting facts recall | GPT-4 | **0.5333** | **0.44** | **0.353** |
24 | | Max MRR | GPT-3.5-turbo | 0.2606 | 0.2862 | 0.2412 |
25 | | Max MRR | GPT-4 | **0.3531** | **0.3229** | **0.2740** |
26 | | First MRR | GPT-3.5-turbo | 0.2522 | 0.2799 | 0.2355 |
27 | | First MRR | GPT-4 | **0.3531** | **0.3204** | **0.2704** |
28 | | Last MRR | GPT-3.5-turbo | 0.2592 | 0.2743 | 0.2288 |
29 | | Last MRR | GPT-4 | **0.3481** | **0.3127** | **0.2663** |
30 | 
31 | - Only on output with final answers
32 | | Metric | LLM | 30 Samples | 100 Samples | 500 Samples |
33 | | --- | --- | --- | --- | --- |
34 | | LLM Accuracy | GPT-3.5-turbo | 0.48 | 0.4667 | 0.5031 |
35 | | LLM Accuracy | GPT-4 | **0.8077** | **0.7534** | **0.6635** |
36 | | Supporting facts recall | GPT-3.5-turbo | 0.46 | 0.4733 | 0.4801 |
37 | | Supporting facts recall | GPT-4 | **0.6154** | **0.6027** | **0.5657** |
38 | | Max MRR | GPT-3.5-turbo | 0.3127 | 0.3816 | 0.3700 |
39 | | Max MRR | GPT-4 | **0.4074** | **0.4424** | **0.4390** |
40 | | First MRR | GPT-3.5-turbo | 0.3027 | 0.3732 | 0.3611 |
41 | | First MRR | GPT-4 | **0.4074** | **0.4389** | **0.4333** |
42 | | Last MRR | GPT-3.5-turbo | 0.311 | 0.3657 | 0.3510 |
43 | | Last MRR | GPT-4 | **0.4016** | **0.4283** | **0.4267** |
44 | 
45 | - Error rate
46 | | Metric | LLM | 30 Samples | 100 Samples | 500 Samples |
47 | | --- | --- | --- | --- | --- |
48 | | Parsing error rate | GPT-3.5-turbo | 0.0667 | 0.01 | 0.06 |
49 | | Parsing error rate | GPT-4 | 0.0667 | **0** | **0.008** |
50 | | Missing rate | GPT-3.5-turbo | 0.1667 | **0.25** | **0.348** |
51 | | Missing rate | GPT-4 | **0.1333** | 0.27 | 0.376 |
52 | 
53 | 
--------------------------------------------------------------------------------
/autoagents/eval/hotpotqa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AutoLLM/AutoAgents/e1210c7884f951f1b90254e40c162e81fc1442f3/autoagents/eval/hotpotqa/__init__.py
--------------------------------------------------------------------------------
/autoagents/eval/hotpotqa/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | 
4 | MODEL_NAME: str = "gpt-3.5-turbo"
5 | 
6 | PERSIST_LOGS: bool = True
7 | 
8 | EVAL_MODEL_NAME: str = "gpt-3.5-turbo"
9 | 
10 | TEMPERATURE: float = 0
11 | 
12 | NUM_SAMPLES_TOTAL: int = 200
13 | 
14 | AWAIT_TIMEOUT: int = 360
15 | 
16 | ROUND_WAITTIME: int = 10
17 | 
18 | MAX_RETRY_ROUND: int = 1
19 | 
20 | MAX_ROUND_STEPS: int = 30
21 | 
22 | OPENAI_MODEL_NAMES = {"gpt-3.5-turbo", "gpt-4"}
23 | 
24 | PARENT_DIRECTORY: str = os.path.dirname(os.path.abspath(__file__))
25 | 
26 | GT_FILE: str = os.path.join(PARENT_DIRECTORY, "hotpot_dev_fullwiki_v1.json")
27 | 
28 | GT_URL: str = "http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json"
29 | 
30 | RESULTS_DIR: str = os.path.join(PARENT_DIRECTORY, f"results_{MODEL_NAME}")
31 | 
32 | OUTPUT_FILE: str = os.path.join(RESULTS_DIR, "prediction.json")
33 | 
34 | RUN_EVAL_LOG_FILE: str = os.path.join(RESULTS_DIR, "run_eval.log")
35 | 
36 | WRONG_ANS_OUTPUT_FILE: str = os.path.join(RESULTS_DIR, "wrong_answers.json")
37 | 
38 | NEW_LOG_DIR: str = os.path.join(RESULTS_DIR, "data")
39 | 
--------------------------------------------------------------------------------
/autoagents/eval/hotpotqa/eval_async.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import asyncio
4 | import requests
5 | from argparse import ArgumentParser
6 | from typing import Optional, Union
7 | from tqdm.asyncio import tqdm_asyncio
8 | from langchain.schema import HumanMessage
9 | from langchain.chat_models import ChatOpenAI
10 | 
11 | from autoagents.eval.hotpotqa.constants import *
12 | from autoagents.agents.utils.constants import LOG_SAVE_DIR
13 | from autoagents.eval.metrics import get_summary_from_log_data
14 | from autoagents.eval.hotpotqa.hotpotqa_eval import eval
15 | 
16 | 
17 | class HotpotqaAsyncEval:
18 | 
19 |     def __init__(
20 |         self,
21 |         model: str,
22 |         ckpt_dir: Optional[str] = None,
23 | 
pred_file: Optional[str] = None 24 | ): 25 | 26 | if ckpt_dir is None: 27 | ckpt_dir = os.path.join(PARENT_DIRECTORY, f"results_{model}") 28 | 29 | if not os.path.isdir(ckpt_dir): 30 | os.mkdir(ckpt_dir) 31 | 32 | self.pred_file = pred_file or os.path.join(ckpt_dir, "prediction.json") 33 | self.new_log_dir = os.path.join(ckpt_dir, "data") 34 | self.wrong_ans_file = os.path.join(ckpt_dir, "wrong_ans.json") 35 | 36 | def get_questions(self, total: Optional[int] = None): 37 | dataset = prepare_dataset(total=total, pred_ckpt=self.pred_file) 38 | return [data["question"] for data in dataset] 39 | 40 | def run(self, log_dir: Optional[str] = None): 41 | 42 | if log_dir is None: 43 | if not os.path.isdir(self.new_log_dir): 44 | os.mkdir(self.new_log_dir) 45 | if os.path.isdir(LOG_SAVE_DIR): 46 | for log_file in os.listdir(LOG_SAVE_DIR): 47 | os.rename( 48 | src=os.path.join(LOG_SAVE_DIR, log_file), 49 | dst=os.path.join(self.new_log_dir, log_file) 50 | ) 51 | os.rmdir(LOG_SAVE_DIR) 52 | log_dir = self.new_log_dir 53 | 54 | pred_dict = predict_log_dir(log_dir=log_dir, pred_ckpt=self.pred_file) 55 | 56 | self.save_output(pred_dict=pred_dict) 57 | 58 | eval(self.pred_file, GT_FILE) 59 | 60 | def save_output(self, pred_dict: dict): 61 | 62 | with open(self.pred_file, 'w') as f: 63 | json.dump(pred_dict, f, indent=2) 64 | 65 | wrong_ans = [] 66 | for qid, stat in pred_dict["statistics"].items(): 67 | if stat["summary"]["counts"].get("equivalency", 0) == 0: 68 | wrong_ans.append({ 69 | "question": stat["question"], 70 | "gt_answer": stat["gt_answer"], 71 | "prediction": pred_dict["answer"].get(qid, ''), 72 | "reasoning": stat["reasoning"] 73 | }) 74 | with open(self.wrong_ans_file, 'w') as f: 75 | json.dump(wrong_ans, f, indent=2) 76 | 77 | 78 | def get_pred_dict(pred_ckpt: Optional[str] = None): 79 | 80 | if pred_ckpt is not None and os.path.isfile(pred_ckpt): 81 | with open(pred_ckpt, 'r') as f: 82 | return json.load(f) 83 | 84 | return {"answer": {}, "statistics": {}, "sp": {}, "error": {}} 85 | 86 | 87 | def prepare_dataset( 88 | total: Optional[int] = None, 89 | pred_ckpt: Optional[Union[str, dict]] = None, 90 | log_dir: Optional[str] = None 91 | ): 92 | 93 | full_dataset = get_hotpotqa_fullwiki_devset() 94 | filtered_dataset = [] 95 | 96 | if total is None: 97 | total = len(full_dataset) 98 | 99 | if log_dir is not None and os.path.isdir(log_dir): 100 | goal_set: set = set() 101 | for log_file in os.listdir(log_dir): 102 | with open(os.path.join(log_dir, log_file), 'r') as f: 103 | try: 104 | log_data = json.load(f) 105 | except json.decoder.JSONDecodeError: 106 | continue 107 | if log_data and isinstance(log_data, list): 108 | goal = None 109 | for entry in log_data: 110 | if "goal" in entry: 111 | goal = entry["goal"] 112 | break 113 | if goal: 114 | goal_set.add(goal) 115 | for data in full_dataset: 116 | for goal in goal_set: 117 | if data["question"] in goal: 118 | filtered_dataset.append(data) 119 | return filtered_dataset 120 | 121 | if isinstance(pred_ckpt, dict): 122 | pred_dict = pred_ckpt 123 | else: 124 | pred_dict = get_pred_dict(pred_ckpt=pred_ckpt) 125 | 126 | dataset = [] 127 | num_new_ids = 0 128 | for data in full_dataset: 129 | if data["_id"] not in pred_dict["statistics"]: 130 | if len(pred_dict["statistics"]) + num_new_ids >= total: 131 | break 132 | dataset.append(data) 133 | num_new_ids += 1 134 | elif data["_id"] in pred_dict.get("error", []): 135 | dataset.append(data) 136 | 137 | return dataset 138 | 139 | 140 | def get_hotpotqa_fullwiki_devset(file: str = GT_FILE, url: 
str = GT_URL): 141 | 142 | if not os.path.isfile(file): 143 | response = requests.get(url) 144 | with open(file, 'wb') as f: 145 | f.write(response.content) 146 | 147 | with open(file, 'r') as f: 148 | return json.load(f) 149 | 150 | 151 | def evaluate_log_dir( 152 | log_dir: str = LOG_SAVE_DIR, 153 | pred_ckpt: Optional[str] = None 154 | ): 155 | pred_ckpt = pred_ckpt or os.path.join(PARENT_DIRECTORY, "prediction.json") 156 | pred_dict = predict_log_dir(log_dir=log_dir, pred_ckpt=pred_ckpt) 157 | with open(pred_ckpt, 'w') as f: 158 | json.dump(pred_dict, f, indent=2) 159 | eval(pred_ckpt, GT_FILE) 160 | 161 | 162 | def predict_log_dir( 163 | log_dir: str = LOG_SAVE_DIR, 164 | pred_ckpt: Optional[str] = None 165 | ): 166 | dataset = { 167 | data["question"]: data for data in prepare_dataset(log_dir=log_dir) 168 | } 169 | 170 | pred_dict = get_pred_dict(pred_ckpt=pred_ckpt) 171 | 172 | asyncio.run(collect_metrics( 173 | pred_dict=pred_dict, dataset=dataset, log_files=[ 174 | os.path.join(log_dir, file) for file in os.listdir(log_dir) 175 | ] 176 | )) 177 | 178 | return pred_dict 179 | 180 | 181 | async def collect_metrics(pred_dict: dict, dataset: dict, log_files: list): 182 | 183 | semaphore = asyncio.Semaphore(10) 184 | 185 | async def process_log_file(log_file: str): 186 | async with semaphore: 187 | with open(log_file, "r") as f: 188 | try: 189 | log_data = json.load(f) 190 | except json.decoder.JSONDecodeError: 191 | return 192 | await evaluate_log_data(log_data, pred_dict, dataset) 193 | 194 | await tqdm_asyncio.gather(*[ 195 | process_log_file(log_file) for log_file in log_files 196 | ]) 197 | 198 | 199 | async def evaluate_log_data( 200 | log_data: dict, pred_dict: dict, dataset: dict 201 | ): 202 | 203 | if not log_data or not isinstance(log_data, list): 204 | return 205 | summary = get_summary_from_log_data(log_data=log_data) 206 | question = summary["question"] 207 | if question is None: 208 | return 209 | gt = None 210 | for q in dataset: 211 | if q in question: 212 | gt = dataset[q] 213 | break 214 | if gt is None: 215 | return 216 | qid = gt["_id"] 217 | if qid in pred_dict["answer"]: 218 | return 219 | for key in list(pred_dict.keys()): 220 | if qid in pred_dict[key]: 221 | del pred_dict[key][qid] 222 | 223 | titles = [] 224 | statistics = { 225 | "reasoning": '', 226 | "question": question, 227 | "gt_answer": gt["answer"], 228 | "gt_citations": [fact[0] for fact in gt["supporting_facts"]], 229 | "raw_citation_urls": [], 230 | "citations": {}, 231 | "summary": summary 232 | } 233 | 234 | if summary["answer"] is not None: 235 | pred_dict["answer"][gt["_id"]] = summary["answer"] 236 | if gt["_id"] in pred_dict["error"]: 237 | del pred_dict["error"][gt["_id"]] 238 | await evaluate_final_answer(summary["answer"], gt, pred_dict, statistics) 239 | 240 | for entry in log_data: 241 | 242 | if "error" in entry: 243 | pred_dict["error"][qid] = entry["error"] 244 | 245 | if "conversations" in entry: 246 | await process_conversation_log( 247 | entry["conversations"], statistics, titles 248 | ) 249 | 250 | if titles: 251 | pred_dict["sp"][qid] = titles 252 | if isinstance(statistics["citations"], dict): 253 | statistics["citations"] = [] 254 | pred_dict["statistics"][qid] = statistics 255 | if qid not in pred_dict["answer"] and qid not in pred_dict["error"]: 256 | pred_dict["error"][qid] = json.dumps(statistics, indent=2) 257 | 258 | 259 | async def process_conversation_log( 260 | conversations: list, statistics: dict, titles: list 261 | ): 262 | try: 263 | observation = 
conversations[0]["value"][-1]["observation"]
264 |         titles.append([doc["title"] for doc in observation])
265 |         for doc in observation:
266 |             statistics["citations"][doc["url"]] = doc["title"]
267 |     except:
268 |         pass
269 | 
270 |     try:
271 |         prediction = json.loads(conversations[-1]["value"])
272 |     except json.decoder.JSONDecodeError:
273 |         statistics["summary"]["error_counts"]["parse_error"] += 1
274 |         return
275 |     if prediction["action"] == "Tool_Finish":
276 |         # Get list of citations
277 |         citations = []
278 |         for citation in prediction.get("citations", []):
279 |             if ": " not in citation:
280 |                 continue
281 |             url = citation.split(": ")[0]
282 |             statistics["raw_citation_urls"].append(url)
283 |             if url in statistics["citations"]:
284 |                 citations.append(statistics["citations"].get(url))
285 |         statistics["citations"] = citations
286 | 
287 | 
288 | async def evaluate_final_answer(
289 |     final_answer: str, data: dict, pred_dict, statistics, llm=None
290 | ):
291 | 
292 |     question: str = data["question"]
293 |     gt_answer: str = data["answer"]
294 | 
295 |     try:
296 |         # Use GPT to determine if the final output is equivalent to the ground truth
297 |         resp_obj = await check_answer_equivalency(question, gt_answer, final_answer, llm)
298 |         statistics["summary"]["counts"]["equivalency"] = int(resp_obj.get("is_inferable", 0))
299 |         statistics["reasoning"] = resp_obj.get("reasoning", '')
300 | 
301 |     except Exception as e:
302 |         pred_dict["error"][data["_id"]] = f"Error during evaluation: {e}"
303 | 
304 | 
305 | async def check_answer_equivalency(question: str, answer1: str, answer2: str, llm=None):
306 | 
307 |     if llm is None:
308 |         llm = ChatOpenAI(
309 |             openai_api_key=os.getenv("OPENAI_API_KEY"),
310 |             openai_organization=os.getenv("OPENAI_API_ORG"),
311 |             temperature=0,
312 |             model=EVAL_MODEL_NAME,
313 |             request_timeout=AWAIT_TIMEOUT
314 |         )
315 | 
316 |     # Use GPT to determine if answer1 is equivalent to answer2
317 |     resp = await llm.agenerate([[HumanMessage(
318 |         content=f"Given a question and a pair of answers, determine if Answer1 can be strictly inferred from Answer2. Return False if, given the information in Answer2, we cannot determine whether Answer1 is right. Add a detailed explanation and reasoning. Format your answer in JSON with a boolean field called 'is_inferable' and a string field 'reasoning' that can be loaded in python.\n\nQuestion: '{question}'\n\nAnswer1: '{answer1}'\n\nAnswer2: '{answer2}'"
319 |     )]])
320 |     return json.loads(resp.generations[0][0].text.strip())
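# Usage sketch for check_answer_equivalency (illustrative values): it returns
# the judge model's parsed JSON verdict, e.g.
#
#   resp_obj = await check_answer_equivalency(
#       "Who wrote Hamlet?", "William Shakespeare", "It was Shakespeare.", llm
#   )
#   resp_obj.get("is_inferable")  # truthy when Answer1 follows from Answer2
#   resp_obj.get("reasoning")     # free-form explanation string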
321 | 
322 | 
323 | def main():
324 | 
325 |     parser = ArgumentParser()
326 |     parser.add_argument("log_dir", type=str, help="path of the log directory")
327 |     parser.add_argument("--pred_ckpt", type=str, help="path of the prediction checkpoint file")
328 |     args = parser.parse_args()
329 |     evaluate_log_dir(args.log_dir, args.pred_ckpt)
330 | 
331 | if __name__ == "__main__":
332 |     main()
333 | 
--------------------------------------------------------------------------------
/autoagents/eval/hotpotqa/hotpotqa_eval.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import numpy as np
3 | import ujson as json
4 | import re
5 | import string
6 | from collections import Counter
7 | from pprint import pprint
8 | 
9 | 
10 | def normalize_answer(s):
11 | 
12 |     def remove_articles(text):
13 |         return re.sub(r'\b(a|an|the)\b', ' ', text)
14 | 
15 |     def white_space_fix(text):
16 |         return ' '.join(text.split())
17 | 
18 |     def remove_punc(text):
19 |         exclude = set(string.punctuation)
20 |         return ''.join(ch for ch in text if ch not in exclude)
21 | 
22 |     def lower(text):
23 |         return text.lower()
24 | 
25 |     return white_space_fix(remove_articles(remove_punc(lower(s))))
26 | 
27 | 
28 | def f1_score(prediction, ground_truth):
29 |     normalized_prediction = normalize_answer(prediction)
30 |     normalized_ground_truth = normalize_answer(ground_truth)
31 | 
32 |     ZERO_METRIC = (0, 0, 0)
33 | 
34 |     # if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
35 |     #     return ZERO_METRIC
36 |     # if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
37 |     #     return ZERO_METRIC
38 | 
39 |     prediction_tokens = normalized_prediction.split()
40 |     ground_truth_tokens = normalized_ground_truth.split()
41 |     common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
42 |     num_same = sum(common.values())
43 |     if num_same == 0:
44 |         return ZERO_METRIC
45 |     precision = 1.0 * num_same / len(prediction_tokens)
46 |     recall = 1.0 * num_same / len(ground_truth_tokens)
47 |     f1 = (2 * precision * recall) / (precision + recall)
48 |     return f1, precision, recall
49 | 
50 | 
51 | def exact_match_score(prediction, ground_truth):
52 |     return (normalize_answer(prediction) == normalize_answer(ground_truth))
53 | 
54 | def update_answer(metrics, prediction, gold):
55 |     em = exact_match_score(prediction, gold)
56 |     f1, prec, recall = f1_score(prediction, gold)
57 |     metrics['em'] += float(em)
58 |     metrics['f1'] += f1
59 |     metrics['prec'] += prec
60 |     metrics['recall'] += recall
61 |     return em, prec, recall
62 | 
63 | def update_sp(metrics, prediction, gold, statistics):
64 | 
65 |     # Only match titles
66 |     cur_sp_pred = set(title for rank in prediction for title in rank)
67 |     gold_sp_pred = set(x[0] for x in gold)
68 | 
69 |     tp, fp, fn = 0, 0, 0
70 |     for e in cur_sp_pred:
71 |         if e in gold_sp_pred:
72 |             tp += 1
73 |         else:
74 |             fp += 1
75 |     for e in gold_sp_pred:
76 |         if e not in cur_sp_pred:
77 |             fn += 1
78 |     prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
79 |     recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
80 |     f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
81 |     em = 1.0 if fp + fn == 0 else 0.0
82 |     metrics['sp_em'] += 
em 83 | metrics['sp_f1'] += f1 84 | metrics['sp_prec'] += prec 85 | metrics['sp_recall'] += recall 86 | 87 | if all(e in cur_sp_pred for e in gold_sp_pred) and \ 88 | statistics["summary"]["counts"].get("equivalency", 0) == 0: 89 | metrics["wrong_infer"] += 1 90 | 91 | title_to_ranks = {} 92 | for title_list in prediction: 93 | for i, title in enumerate(title_list): 94 | if title not in gold_sp_pred: 95 | continue 96 | if title not in title_to_ranks: 97 | title_to_ranks[title] = [i + 1] * 3 98 | title_to_ranks[title][0] = min(i + 1, title_to_ranks[title][0]) 99 | title_to_ranks[title][2] = i + 1 100 | n_gt_titles = len(gold_sp_pred) 101 | cur_ranks = title_to_ranks.values() 102 | metrics["max_mrr"] += sum(1 / ranks[0] for ranks in cur_ranks) / n_gt_titles 103 | metrics["first_mrr"] += sum(1 / ranks[1] for ranks in cur_ranks) / n_gt_titles 104 | metrics["last_mrr"] += sum(1 / ranks[2] for ranks in cur_ranks) / n_gt_titles 105 | 106 | return em, prec, recall 107 | 108 | def update_last_sp(metrics, statistics, gold): 109 | 110 | # Only match titles 111 | cur_sp_pred = set(title for title in statistics.get("citations", [])) 112 | gold_sp_pred = set(x[0] for x in gold) 113 | 114 | tp, fp, fn = 0, 0, 0 115 | for e in cur_sp_pred: 116 | if e in gold_sp_pred: 117 | tp += 1 118 | else: 119 | fp += 1 120 | for e in gold_sp_pred: 121 | if e not in cur_sp_pred: 122 | fn += 1 123 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 124 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 125 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 126 | em = 1.0 if fp + fn == 0 else 0.0 127 | metrics['last_sp_em'] += em 128 | metrics['last_sp_f1'] += f1 129 | metrics['last_sp_prec'] += prec 130 | metrics['last_sp_recall'] += recall 131 | 132 | return em, prec, recall 133 | 134 | def eval(prediction_file, gold_file): 135 | with open(prediction_file) as f: 136 | prediction = json.load(f) 137 | print(f"len answer = {len(prediction['answer'])}, len error = {len(prediction['error'])}, len sp = {len(prediction['sp'])}, len statistics = {len(prediction['statistics'])}") 138 | 139 | with open(gold_file) as f: 140 | gold = [] 141 | for data in json.load(f): 142 | if data["_id"] in prediction["statistics"]: 143 | gold.append(data) 144 | 145 | metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0, 146 | 'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0, 147 | 'last_sp_em': 0, 'last_sp_f1': 0, 'last_sp_prec': 0, 'last_sp_recall': 0, 148 | 'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0, 149 | "max_mrr": 0, "first_mrr": 0, "last_mrr": 0, "wrong_infer": 0} 150 | stats = { 151 | "counts": Counter(), # general counters 152 | "error_counts": Counter(), # error counters 153 | "plan_counts": Counter(), # plan patterns 154 | "len_history_trace": [], 155 | "len_initial_plan": [] 156 | } 157 | for dp in gold: 158 | cur_id = dp['_id'] 159 | can_eval_joint = True 160 | if cur_id not in prediction['answer']: 161 | print('missing answer {}'.format(cur_id)) 162 | can_eval_joint = False 163 | else: 164 | em, prec, recall = update_answer( 165 | metrics, prediction['answer'][cur_id], dp['answer']) 166 | summary = prediction['statistics'][cur_id]["summary"] 167 | stats["counts"] += summary["counts"] 168 | stats["error_counts"] += summary["error_counts"] 169 | stats["plan_counts"] += summary["plan_counts"] 170 | stats["len_history_trace"].extend(summary["len_history_trace"]) 171 | stats["len_initial_plan"].extend(summary["len_initial_plan"]) 172 | 173 | if cur_id not in prediction['sp']: 174 | 
print('missing sp fact {}'.format(cur_id)) 175 | can_eval_joint = False 176 | else: 177 | sp_em, sp_prec, sp_recall = update_sp( 178 | metrics, prediction['sp'][cur_id], dp['supporting_facts'], prediction['statistics'][cur_id]) 179 | update_last_sp( 180 | metrics, prediction['statistics'].get(cur_id, {}), dp['supporting_facts'] 181 | ) 182 | 183 | if can_eval_joint: 184 | joint_prec = prec * sp_prec 185 | joint_recall = recall * sp_recall 186 | if joint_prec + joint_recall > 0: 187 | joint_f1 = 2 * joint_prec * joint_recall / (joint_prec + joint_recall) 188 | else: 189 | joint_f1 = 0. 190 | joint_em = em * sp_em 191 | 192 | metrics['joint_em'] += joint_em 193 | metrics['joint_f1'] += joint_f1 194 | metrics['joint_prec'] += joint_prec 195 | metrics['joint_recall'] += joint_recall 196 | 197 | N = len(gold) 198 | for k in metrics.keys(): 199 | metrics[k] /= N 200 | 201 | hist, rng = np.histogram(stats["len_history_trace"], bins=range(0, 16)) 202 | stats["len_history_trace"] = hist.tolist() 203 | hist, rng = np.histogram(stats["len_initial_plan"], bins=range(0, 16)) 204 | stats["len_initial_plan"] = hist.tolist() 205 | 206 | stats["error_rate"] = { 207 | error: cnt / N 208 | for error, cnt in stats["error_counts"].items() 209 | } 210 | stats["avg_metrics"] = { 211 | metric: cnt / N 212 | for metric, cnt in stats["counts"].items() 213 | } 214 | metrics.update(stats) 215 | 216 | metrics["ans_missing_rate"] = 1 - len(prediction["answer"]) / N 217 | metrics["sp_missing_rate"] = 1 - len(prediction["sp"]) / N 218 | metrics["num_evaluated"] = N 219 | metrics["llm_accuracy_on_finished_samples"] = stats["avg_metrics"]["equivalency"] / (1 - metrics["ans_missing_rate"]) 220 | 221 | pprint(metrics) 222 | 223 | if __name__ == '__main__': 224 | eval(sys.argv[1], sys.argv[2]) 225 | -------------------------------------------------------------------------------- /autoagents/eval/hotpotqa/run_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import asyncio 4 | import json 5 | import logging 6 | from langchain.chat_models import ChatOpenAI 7 | from pprint import pformat 8 | from ast import literal_eval 9 | from multiprocessing import Pool, Manager 10 | from functools import partial 11 | from tqdm import tqdm 12 | 13 | from autoagents.agents.agents.wiki_agent import WikiActionRunner 14 | from autoagents.agents.models.custom import CustomLLM 15 | from autoagents.eval.hotpotqa.eval_async import ( 16 | evaluate_final_answer, prepare_dataset 17 | ) 18 | from autoagents.eval.hotpotqa.hotpotqa_eval import eval 19 | from autoagents.eval.hotpotqa.constants import * 20 | from autoagents.agents.utils.constants import LOG_SAVE_DIR 21 | 22 | 23 | if not os.path.isdir(RESULTS_DIR): 24 | os.mkdir(RESULTS_DIR) 25 | 26 | logger = logging.getLogger(__name__) 27 | logger.setLevel(level=logging.DEBUG) 28 | log_filehandler = logging.FileHandler(RUN_EVAL_LOG_FILE) 29 | log_filehandler.setLevel(logging.DEBUG) 30 | log_filehandler.setFormatter( 31 | logging.Formatter('%(asctime)s - %(name)s - %(levelname)s \n%(message)s') 32 | ) 33 | logger.addHandler(log_filehandler) 34 | 35 | 36 | def get_llms(): 37 | if MODEL_NAME not in OPENAI_MODEL_NAMES: 38 | llm = CustomLLM( 39 | model_name=MODEL_NAME, 40 | temperature=TEMPERATURE, 41 | request_timeout=AWAIT_TIMEOUT 42 | ) 43 | else: 44 | llm = ChatOpenAI( 45 | openai_api_key=os.getenv("OPENAI_API_KEY"), 46 | openai_organization=os.getenv("OPENAI_API_ORG"), 47 | temperature=TEMPERATURE, 48 | model_name=MODEL_NAME, 49 | 
request_timeout=AWAIT_TIMEOUT 50 | ) 51 | 52 | evalllm = ChatOpenAI( 53 | openai_api_key=os.getenv("OPENAI_API_KEY"), 54 | openai_organization=os.getenv("OPENAI_API_ORG"), 55 | temperature=0, 56 | model=EVAL_MODEL_NAME, 57 | request_timeout=AWAIT_TIMEOUT 58 | ) 59 | return llm, evalllm 60 | 61 | 62 | async def work(data, pred_dict): 63 | outputq = asyncio.Queue() 64 | user_input = data["question"] 65 | 66 | llm, evalllm = get_llms() 67 | runner = WikiActionRunner(outputq, llm=llm, persist_logs=PERSIST_LOGS) 68 | task = asyncio.create_task(runner.run(user_input, outputq)) 69 | 70 | titles = [] 71 | statistics = { 72 | "steps": 0, "equivalency": 0, "reasoning": '', "question": user_input, "gt_answer": data["answer"], "raw_citation_urls": [], "citations": {}, "rewritten": 0, "search_invoked": 0, "notepad_invoked": 0, "multi_tools": 0, "parse_error": 0, "invalid_tool": 0, "context_len_err": 0 73 | } 74 | for _ in range(runner.agent_executor.max_iterations or MAX_ROUND_STEPS): 75 | 76 | try: 77 | output = await asyncio.wait_for(outputq.get(), AWAIT_TIMEOUT) 78 | except asyncio.TimeoutError: 79 | logger.error(f"Question: {user_input}\nError: Timed out waiting for output from queue\n") 80 | pred_dict["error"][data["_id"]] = "Timed out waiting for output from queue." 81 | break 82 | statistics["steps"] += 1 83 | 84 | if isinstance(output, Exception): 85 | logger.error(f"Question: {user_input}\nError: {output}\n") 86 | if isinstance(output, RuntimeWarning) and "Action Input Rewritten: " in str(output): 87 | statistics["rewritten"] += 1 88 | continue 89 | else: 90 | if "Could not parse LLM output: " in str(output): 91 | statistics["parse_error"] += 1 92 | elif "Invalid tool requested by the model." in str(output): 93 | statistics["invalid_tool"] += 1 94 | elif "This model's maximum context length is" in str(output): 95 | statistics["context_len_err"] += 1 96 | pred_dict["error"][data["_id"]] = str(output) 97 | break 98 | 99 | parsed = get_parsed_output(user_input, output, statistics, titles) 100 | 101 | if isinstance(parsed, dict) and parsed.get("action") == "Tool_Finish": 102 | final_answer: str = parsed["action_input"] 103 | logger.info(f"Question: {user_input}\nFinal Output: {final_answer}\n") 104 | 105 | # Get list of citations 106 | citations = [] 107 | for citation in parsed.get("citations", []): 108 | if ": " not in citation: 109 | continue 110 | url = citation.split(": ")[0] 111 | statistics["raw_citation_urls"].append(url) 112 | if url in statistics["citations"]: 113 | citations.append(statistics["citations"].get(url)) 114 | statistics["citations"] = citations 115 | 116 | await evaluate_final_answer(final_answer, data, pred_dict, statistics, evalllm) 117 | 118 | break 119 | if titles: 120 | pred_dict["sp"][data["_id"]] = json.dumps(titles) 121 | if isinstance(statistics["citations"], dict): 122 | statistics["citations"] = [] 123 | pred_dict["statistics"][data["_id"]] = json.dumps(statistics) 124 | if data["_id"] not in pred_dict["answer"] and data["_id"] not in pred_dict["error"]: 125 | pred_dict["error"][data["_id"]] = json.dumps(statistics, indent=2) 126 | 127 | # await task 128 | try: 129 | return await asyncio.wait_for(task, AWAIT_TIMEOUT) 130 | except asyncio.TimeoutError: 131 | logger.error(f"Question: {user_input}\nError: Timed out waiting for task to complete\n") 132 | pred_dict["error"][data["_id"]] = "Timed out waiting for task to complete." 
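    # Each item read from `outputq` above is either an Exception from the
    # runner (RuntimeWarnings mark query rewrites; any other exception aborts
    # this attempt) or a string of model output; the loop stops once a
    # Tool_Finish action is parsed, an error is recorded, or the step budget
    # (max_iterations / MAX_ROUND_STEPS) is exhausted.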
133 | 134 | 135 | def get_parsed_output(user_input, output, statistics, titles): 136 | parsed = None 137 | try: 138 | parsed = json.loads(output) 139 | logger.debug(f"Question: {user_input}\n{json.dumps(parsed, indent=2)}") 140 | if parsed["action"] == "Tool_Wikipedia": 141 | statistics["search_invoked"] += 1 142 | elif parsed["action"] == "Tool_Notepad": 143 | statistics["notepad_invoked"] += 1 144 | except: 145 | try: 146 | parsed = literal_eval(output) 147 | logger.debug(f"Question: {user_input}\n{json.dumps(parsed, indent=2)}") 148 | if isinstance(parsed, list) and isinstance(parsed[0], dict) and "title" in parsed[0]: 149 | titles.append([doc["title"] for doc in parsed]) 150 | for doc in parsed: 151 | statistics["citations"][doc["url"]] = doc["title"] 152 | except: 153 | logger.debug(f"Question: {user_input}\n{output}") 154 | return parsed 155 | 156 | 157 | def save_output(): 158 | 159 | if PERSIST_LOGS: 160 | for log_file in os.listdir(LOG_SAVE_DIR): 161 | os.rename( 162 | src=os.path.join(LOG_SAVE_DIR, log_file), 163 | dst=os.path.join(NEW_LOG_DIR, log_file) 164 | ) 165 | os.rmdir(LOG_SAVE_DIR) 166 | 167 | output_dict = dict(pred_dict) 168 | for k in list(output_dict.keys()): 169 | output_dict[k] = dict(output_dict[k]) 170 | if k in ("sp", "statistics"): 171 | for qid in output_dict[k]: 172 | output_dict[k][qid] = json.loads(output_dict[k][qid]) 173 | if isinstance(output_dict[k][qid], str): 174 | output_dict[k][qid] = json.loads(output_dict[k][qid]) 175 | 176 | logger.info(pformat(output_dict, indent=2)) 177 | with open(OUTPUT_FILE, 'w') as f: 178 | json.dump(output_dict, f, indent=2) 179 | 180 | wrong_ans = [] 181 | for qid, stat in output_dict["statistics"].items(): 182 | if stat["equivalency"] == 0: 183 | wrong_ans.append({ 184 | "question": stat["question"], 185 | "gt_answer": stat["gt_answer"], 186 | "prediction": output_dict["answer"].get(qid, ''), 187 | "reasoning": stat["reasoning"] 188 | }) 189 | with open(WRONG_ANS_OUTPUT_FILE, 'w') as f: 190 | json.dump(wrong_ans, f, indent=2) 191 | 192 | 193 | def initialize_pred_dict(): 194 | 195 | pred_dict["answer"] = manager.dict() 196 | pred_dict["statistics"] = manager.dict() 197 | pred_dict["sp"] = manager.dict() 198 | pred_dict["error"] = manager.dict() 199 | 200 | cur_dict = {} 201 | if os.path.isfile(OUTPUT_FILE): 202 | with open(OUTPUT_FILE, 'r') as f: 203 | cur_dict = json.load(f) 204 | pred_dict["answer"].update(cur_dict["answer"]) 205 | for _id, sp in cur_dict["sp"].items(): 206 | pred_dict["sp"][_id] = json.dumps(sp) 207 | for _id, stat in cur_dict["statistics"].items(): 208 | pred_dict["statistics"][_id] = json.dumps(stat) 209 | 210 | 211 | def retry(dataset): 212 | 213 | # Retry until we get all the final answers 214 | round = 0 215 | while pred_dict["error"] and round < MAX_RETRY_ROUND: 216 | 217 | logger.info( 218 | f"Round {round}. 
Start retrying failed samples: " 219 | f"{json.dumps(dict(pred_dict['error']), indent=2)}" 220 | ) 221 | 222 | retry_data = [] 223 | for i in range(len(dataset)): 224 | if dataset[i]["_id"] in pred_dict["error"]: 225 | retry_data.append(dataset[i]) 226 | del pred_dict["error"][dataset[i]["_id"]] 227 | 228 | time.sleep(ROUND_WAITTIME) 229 | 230 | with Pool(processes=10) as pool: 231 | for _ in tqdm(pool.imap_unordered( 232 | partial(main, pred_dict=pred_dict), retry_data 233 | ), total=len(retry_data)): 234 | pass 235 | 236 | round += 1 237 | 238 | 239 | def main(data, pred_dict): 240 | asyncio.run(work(data, pred_dict)) 241 | 242 | 243 | if __name__ == "__main__": 244 | 245 | manager = Manager() 246 | 247 | pred_dict = manager.dict() 248 | 249 | initialize_pred_dict() 250 | 251 | dataset = prepare_dataset(total=NUM_SAMPLES_TOTAL, pred_ckpt=pred_dict) 252 | 253 | if PERSIST_LOGS: 254 | if not os.path.isdir(LOG_SAVE_DIR): 255 | os.mkdir(LOG_SAVE_DIR) 256 | if not os.path.isdir(NEW_LOG_DIR): 257 | os.mkdir(NEW_LOG_DIR) 258 | 259 | with Pool(processes=10) as pool: 260 | for _ in tqdm(pool.imap_unordered( 261 | partial(main, pred_dict=pred_dict), dataset 262 | ), total=len(dataset)): 263 | pass 264 | 265 | retry(dataset=dataset) 266 | 267 | save_output() 268 | 269 | eval(OUTPUT_FILE, GT_FILE) 270 | -------------------------------------------------------------------------------- /autoagents/eval/metrics.py: -------------------------------------------------------------------------------- 1 | import json 2 | import glob 3 | from argparse import ArgumentParser 4 | from collections import Counter, defaultdict 5 | import numpy as np 6 | from pprint import pprint 7 | 8 | 9 | def get_common_stats(log_files): 10 | stats = { 11 | "counts": Counter(), # general counters 12 | "error_counts": Counter(), # error counters 13 | "plan_counts": Counter(), # plan patterns 14 | "len_history_trace": [], 15 | "len_initial_plan": [] 16 | } 17 | samples = set() 18 | finished_samples = set() 19 | for file in log_files: 20 | with open(file, "r") as f: 21 | try: 22 | log_data = json.load(f) 23 | except json.decoder.JSONDecodeError: 24 | continue 25 | summary = get_summary_from_log_data(log_data=log_data) 26 | stats["counts"] += summary["counts"] 27 | stats["error_counts"] += summary["error_counts"] 28 | stats["plan_counts"] += summary["plan_counts"] 29 | stats["len_history_trace"].extend(summary["len_history_trace"]) 30 | stats["len_initial_plan"].extend(summary["len_initial_plan"]) 31 | if summary["question"] is not None: 32 | samples.add(summary["question"]) 33 | if summary["answer"] is not None: 34 | finished_samples.add(summary["question"]) 35 | stats["counts"]["total_samples"] = len(samples) 36 | stats["counts"]["finished_samples"] = len(finished_samples) 37 | 38 | hist, rng = np.histogram(stats["len_history_trace"], bins=range(0, 16)) 39 | stats["len_history_trace"] = hist.tolist() 40 | 41 | hist, rng = np.histogram(stats["len_initial_plan"], bins=range(0, 16)) 42 | stats["len_initial_plan"] = hist.tolist() 43 | 44 | return stats 45 | 46 | 47 | def get_summary_from_log_data(log_data: list): 48 | 49 | counts = Counter() # general counters 50 | error_counts = Counter() # error counters 51 | plan_counts = Counter() # plan patterns 52 | len_initial_plan = [] 53 | len_history_trace = [] 54 | 55 | summary = dict( 56 | counts=counts, 57 | error_counts=error_counts, 58 | plan_counts=plan_counts, 59 | len_history_trace=len_history_trace, 60 | len_initial_plan=len_initial_plan, 61 | question=None, 62 | answer=None 63 | ) 64 
| 65 | # Handle errors and rewrites 66 | is_valid: bool = True 67 | counts["total_logs"] += 1 68 | for entry in log_data: 69 | if "id" in entry: 70 | counts["total_steps"] += 1 71 | if "goal" in entry: 72 | summary["question"] = entry["goal"] 73 | if "error" in entry: 74 | if "Expecting value" in entry["error"]: 75 | # This is the old rewrite error 76 | pass 77 | elif "Invalid tool requested by the model." in entry["error"]: 78 | error_counts["invalid_tools_error"] += 1 79 | is_valid = False 80 | elif "This model's maximum context length" in entry["error"]: 81 | error_counts["context_len_error"] += 1 82 | if len(log_data) < 4: 83 | is_valid = False 84 | elif "[Errno -3] Temporary failure in name resolution" in entry["error"]: 85 | error_counts["dns_error"] += 1 86 | elif "Could not parse LLM output:" in entry["error"]: 87 | error_counts["parse_error"] += 1 88 | is_valid = False 89 | elif "Rate limit reached for " in entry["error"]: 90 | error_counts["rate_limit_error"] += 1 91 | is_valid = False 92 | elif "[Errno 111] Connection refused" in entry["error"]: 93 | error_counts["connection_error"] += 1 94 | is_valid = False 95 | else: 96 | error_counts["other_error"] += 1 97 | is_valid = False 98 | elif "query_rewrite" in entry: 99 | counts["total_rewrites"] += 1 100 | 101 | if not is_valid: 102 | return summary 103 | counts["total_valid"] += 1 104 | 105 | for entry in log_data: 106 | if "conversations" in entry: 107 | counts["valid_steps"] += 1 108 | prediction = json.loads(entry["conversations"][-1]["value"]) 109 | action = prediction["action"] 110 | if action == "Tool_Search" or action == "Tool_Wikipedia": 111 | counts["search_invoked"] += 1 112 | elif action == "Tool_Notepad": 113 | counts["notepad_invoked"] += 1 114 | elif action == "Tool_Finish": 115 | summary["answer"] = prediction["action_input"] 116 | 117 | # do last-step history analysis, log_data[-3] 118 | try: 119 | last_convo = log_data[-3]["conversations"] 120 | if last_convo[1]["from"] in ("gpt", "ai"): 121 | # we don't have the system key in prompt_v3 122 | output = json.loads(last_convo[1]["value"]) 123 | else: 124 | output = json.loads(last_convo[2]["value"]) 125 | except: 126 | return summary 127 | 128 | counts[f"EndWith_{output['action']}"] += 1 129 | 130 | if last_convo[0]["from"] == "history": 131 | hist = last_convo[0]["value"] 132 | actions = [h["action"] for h in hist] 133 | if len(actions) < 5 and len(actions) > 0: 134 | actions_str = "->".join(actions) 135 | plan_counts[actions_str] += 1 136 | if actions_str == "Tool_Notepad": 137 | pass 138 | if actions_str == "Tool_Search->Tool_Search->Tool_Notepad": 139 | pass 140 | if actions_str == "Tool_Search->Tool_Search->Tool_Search->Tool_Notepad": 141 | pass 142 | plans = [] 143 | for plan in [h["plan"] for h in hist]: 144 | plans.extend(plan) 145 | for plan in plans: 146 | if plan.startswith("Visit"): 147 | counts["visit_in_plan"] += 1 148 | break 149 | 150 | len_hist = len(hist) 151 | if len_hist > 0: 152 | len_plan0 = len(hist[0]["plan"]) 153 | len_initial_plan.append(len_plan0) 154 | if len_plan0 == 1: 155 | pass 156 | counts["len_hist"] += len_hist 157 | len_history_trace.append(len_hist) 158 | 159 | # find out if there are duplicate action+action_inputs 160 | inputs = defaultdict(set) 161 | plans = set() 162 | for h in hist + [output]: 163 | if h["action"] in inputs: 164 | if h["action_input"] in inputs[h["action"]]: 165 | if output["action"] == "Tool_Finish": 166 | counts["Finish_with_dups"] += 1 167 | break 168 | else: # only count duplicates that didn't finish 169 
| counts["duplicate_actions"] += 1
170 |                         break
171 |             inputs[h["action"]].add(h["action_input"])
172 | 
173 |     return summary
174 | 
175 | 
176 | def main():
177 | 
178 |     parser = ArgumentParser()
179 |     parser.add_argument("log_dir", type=str, help="path of the log directory")
180 |     args = parser.parse_args()
181 | 
182 |     stats = get_common_stats(log_files=glob.glob(f"{args.log_dir}/*.json"))
183 |     pprint(stats)
184 | 
185 | 
186 | if __name__ == "__main__":
187 |     main()
188 | 
--------------------------------------------------------------------------------
/autoagents/eval/reward/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import asyncio
4 | import argparse
5 | from tqdm.asyncio import tqdm_asyncio
6 | from langchain.schema import HumanMessage
7 | from langchain.chat_models import ChatOpenAI
8 | 
9 | 
10 | PARENT_DIR: str = os.path.dirname(os.path.abspath(__file__))
11 | DATASET_FILE: str = os.path.join(PARENT_DIR, "transformed.json")
12 | NUM_DATA: int = 100
13 | EVAL_MODEL_NAME: str = "gpt-4"
14 | AWAIT_TIMEOUT: int = 360
15 | RETRY_COUNT = 2
16 | 
17 | 
18 | def get_dataset(dataset_file: str = DATASET_FILE):
19 |     with open(dataset_file, 'r') as f:
20 |         dataset = json.load(f)
21 |     return dataset
22 | 
23 | 
24 | async def evaluate(data, model, results):
25 | 
26 |     conversations = data["conversations"]
27 |     history = conversations[0]["value"]
28 |     model_output = conversations[1]["value"]
29 | 
30 |     prompt = f"""Given an input with a final goal, a set of detailed instructions and a history of the query/thoughts/plan/action/observation, the search agent needs to generate a next step plan to fulfill the user's input query. Evaluate the next step plan generated by the search agent. The input and the generated plan are both delimited by triple backticks.
31 | 
32 | Input:
33 | ```
34 | {history}
35 | ```
36 | 
37 | Next step plan of Agent:
38 | ```
39 | {model_output}
40 | ```
41 | 
42 | Evaluate the next step plan based on the following aspects/metrics:
43 | - Clarity: The plan should be clear and understandable.
44 | - Effectiveness: The plan should be correct and move towards the final goal to answer the input query.
45 | 
46 | Your answer should include a score on a scale of 0 to 5, where 0 means terrible, 1 means bad, 2 means acceptable, 3 means okay, 4 means good, and 5 means wonderful. Format your output as a JSON object consisting of overall_score, overall_judgement, clarity_score, clarity_reasoning, effectiveness_score, and effectiveness_reasoning.
47 | """ 48 | 49 | llm = ChatOpenAI( 50 | openai_api_key=os.getenv("OPENAI_API_KEY"), 51 | openai_organization=os.getenv("OPENAI_API_ORG"), 52 | temperature=0, 53 | model=model, 54 | request_timeout=AWAIT_TIMEOUT 55 | ) 56 | 57 | retry_cnt = 0 58 | while retry_cnt <= RETRY_COUNT: 59 | try: 60 | resp = await llm.agenerate([[HumanMessage(content=prompt)]]) 61 | resp_obj = json.loads(resp.generations[0][0].text.strip()) 62 | resp_obj["data_id"] = data["id"] 63 | 64 | results.append(resp_obj) 65 | break 66 | 67 | except Exception as e: 68 | print(e) 69 | retry_cnt += 1 70 | 71 | 72 | async def main(dataset, model, num_data): 73 | 74 | results = [] 75 | 76 | semaphore = asyncio.Semaphore(10) 77 | 78 | async def process_data(data): 79 | async with semaphore: 80 | await evaluate(data, model, results) 81 | 82 | await tqdm_asyncio.gather(*[process_data(data) for data in dataset]) 83 | 84 | with open(f"response_eval_{model}_{num_data}.json", 'w') as f: 85 | json.dump(results, f, indent=2) 86 | 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument( 91 | "--dataset_file", type=str, default=DATASET_FILE, 92 | help="file containing the dataset" 93 | ) 94 | parser.add_argument("--eval_model", type=str, default=EVAL_MODEL_NAME) 95 | parser.add_argument("--num_data", type=int, default=NUM_DATA) 96 | args = parser.parse_args() 97 | 98 | dataset = get_dataset(dataset_file=args.dataset_file)[:args.num_data] 99 | asyncio.run(main(dataset, args.eval_model, args.num_data)) 100 | -------------------------------------------------------------------------------- /autoagents/eval/reward/get_scores.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "input_file", type=str, 10 | help="file containing the evaluation results" 11 | ) 12 | args = parser.parse_args() 13 | 14 | if os.path.isfile(args.input_file): 15 | with open(args.input_file, 'r') as f: 16 | results = json.load(f) 17 | num = len(results) 18 | 19 | stats = {"avg_overall": 0, "avg_clarity": 0, "avg_effectiveness": 0} 20 | for obj in results: 21 | stats["avg_overall"] += obj.get("overall_score", 0) 22 | stats["avg_clarity"] += obj.get("clarity_score", 0) 23 | stats["avg_effectiveness"] += obj.get("effectiveness_score", 0) 24 | 25 | for key in stats: 26 | stats[key] /= num 27 | stats["num"] = num 28 | print(stats) 29 | -------------------------------------------------------------------------------- /autoagents/eval/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import json 4 | import os 5 | from tqdm.asyncio import tqdm_asyncio 6 | 7 | from autoagents.agents.agents.search import ActionRunner 8 | from autoagents.agents.agents.wiki_agent import WikiActionRunner, WikiActionRunnerV3 9 | from autoagents.agents.agents.search_v3 import ActionRunnerV3 10 | from autoagents.agents.models.custom import CustomLLM, CustomLLMV3 11 | from autoagents.agents.utils.constants import LOG_SAVE_DIR 12 | from autoagents.data.dataset import BAMBOOGLE, DEFAULT_Q, FT, HF 13 | from autoagents.eval.bamboogle import eval as eval_bamboogle 14 | from autoagents.eval.hotpotqa.eval_async import HotpotqaAsyncEval, NUM_SAMPLES_TOTAL 15 | from langchain.chat_models import ChatOpenAI 16 | 17 | 18 | OPENAI_MODEL_NAMES = {"gpt-3.5-turbo", "gpt-4"} 19 | AWAIT_TIMEOUT: int = 120 20 
| MAX_RETRIES: int = 2
21 | 
22 | 
23 | async def work(user_input: str, model: str, temperature: float, agent: str, prompt_version: str, persist_logs: bool, log_save_dir: str):
24 |     if model not in OPENAI_MODEL_NAMES:
25 |         if prompt_version == "v2":
26 |             llm = CustomLLM(
27 |                 model_name=model,
28 |                 temperature=temperature,
29 |                 request_timeout=AWAIT_TIMEOUT
30 |             )
31 |         elif prompt_version == "v3":
32 |             llm = CustomLLMV3(
33 |                 model_name=model,
34 |                 temperature=temperature,
35 |                 request_timeout=AWAIT_TIMEOUT
36 |             )
37 |     else:
38 |         llm = ChatOpenAI(
39 |             openai_api_key=os.getenv("OPENAI_API_KEY"),
40 |             openai_organization=os.getenv("OPENAI_API_ORG"),
41 |             temperature=temperature,
42 |             model_name=model,
43 |             request_timeout=AWAIT_TIMEOUT
44 |         )
45 | 
46 |     retry_count = 0
47 |     while retry_count < MAX_RETRIES:
48 |         outputq = asyncio.Queue()
49 |         if agent == "ddg":
50 |             if prompt_version == "v2":
51 |                 runner = ActionRunner(outputq, llm=llm, persist_logs=persist_logs)
52 |             elif prompt_version == "v3":
53 |                 runner = ActionRunnerV3(outputq, llm=llm, persist_logs=persist_logs)
54 |         elif agent == "wiki":
55 |             if prompt_version == "v2":
56 |                 runner = WikiActionRunner(outputq, llm=llm, persist_logs=persist_logs)
57 |             elif prompt_version == "v3":
58 |                 runner = WikiActionRunnerV3(outputq, llm=llm, persist_logs=persist_logs)
59 |         task = asyncio.create_task(runner.run(user_input, outputq, log_save_dir))
60 |         while True:
61 |             try:
62 |                 output = await asyncio.wait_for(outputq.get(), AWAIT_TIMEOUT)
63 |             except asyncio.TimeoutError:
64 |                 task.cancel()
65 |                 retry_count += 1
66 |                 break
67 |             if isinstance(output, RuntimeWarning):
68 |                 print(f"Question: {user_input}")
69 |                 print(output)
70 |                 continue
71 |             elif isinstance(output, Exception):
72 |                 task.cancel()
73 |                 print(f"Question: {user_input}")
74 |                 print(output)
75 |                 retry_count += 1
76 |                 break
77 |             try:
78 |                 parsed = json.loads(output)
79 |                 print(json.dumps(parsed, indent=2))
80 |                 print("-----------------------------------------------------------")
81 |                 if parsed["action"] == "Tool_Finish":
82 |                     return await task
83 |             except:
84 |                 print(f"Question: {user_input}")
85 |                 print(output)
86 |                 print("-----------------------------------------------------------")
87 | 
88 | 
89 | async def main(questions, args):
90 |     sem = asyncio.Semaphore(10)
91 | 
92 |     async def safe_work(user_input: str, model: str, temperature: float, agent: str, prompt_version: str, persist_logs: bool, log_save_dir: str):
93 |         async with sem:
94 |             return await work(user_input, model, temperature, agent, prompt_version, persist_logs, log_save_dir)
95 | 
96 |     persist_logs = True if args.persist_logs else False
97 |     await tqdm_asyncio.gather(*[safe_work(q, args.model, args.temperature, args.agent, args.prompt_version, persist_logs, args.log_save_dir) for q in questions])
98 | 
99 | 
100 | if __name__ == "__main__":
101 |     parser = argparse.ArgumentParser()
102 |     parser.add_argument("--model", type=str, default="gpt-3.5-turbo", help="model to be tested")
103 |     parser.add_argument("--temperature", type=float, default=0, help="model temperature")
104 |     parser.add_argument("--agent",
105 |         default="ddg",
106 |         const="ddg",
107 |         nargs="?",
108 |         choices=("ddg", "wiki"),
109 |         help='which action agent we want to interact with (default: ddg)'
110 |     )
111 |     parser.add_argument("--persist-logs", action="store_true", help="persist logs on disk, enable this feature for later eval purpose")
112 |     parser.add_argument("--log-save-dir", type=str, default=LOG_SAVE_DIR, help="dir to save logs")
113 |     parser.add_argument("--dataset",
114 |         default="default",
115 |         const="default",
116 |         nargs="?",
117 |         choices=("default", "hotpotqa", "ft", "hf", "bamboogle"),
118 |         help='which dataset we want to interact with (default: default)'
119 |     )
120 |     parser.add_argument("--eval", action="store_true", help="enable automatic eval")
121 |     parser.add_argument("--prompt-version",
122 |         default="v2",
123 |         const="v3",
124 |         nargs="?",
125 |         choices=("v2", "v3"),
126 |         help='which version of prompt to use (default: v2)'
127 |     )
128 |     parser.add_argument("--slice", type=int, help="slice the dataset from left, question list will start from index 0 to slice - 1")
129 |     args = parser.parse_args()
130 |     print(args)
131 |     if args.prompt_version == "v3" and args.model in OPENAI_MODEL_NAMES:
132 |         raise ValueError("Prompt v3 is not compatible with OpenAI models, please adjust your settings!")
133 |     if not args.persist_logs and args.eval:
134 |         raise ValueError("Please enable the persist-logs feature to allow the eval code to run!")
135 |     if not args.log_save_dir and args.persist_logs:
136 |         raise ValueError("Please provide a log save dir via --log-save-dir when persist-logs is enabled!")
137 |     questions = []
138 |     if args.dataset == "ft":
139 |         questions = [q for _, q in FT]
140 |     elif args.dataset == "hf":
141 |         questions = [q for _, q in HF]
142 |     elif args.dataset == "hotpotqa":
143 |         hotpotqa_eval = HotpotqaAsyncEval(model=args.model)
144 |         questions = hotpotqa_eval.get_questions(args.slice or NUM_SAMPLES_TOTAL)
145 |     elif args.dataset == "bamboogle":
146 |         questions = BAMBOOGLE["questions"]
147 |     else:
148 |         questions = [q for _, q in DEFAULT_Q]
149 |     if args.slice and args.dataset != "hotpotqa":
150 |         questions = questions[:args.slice]
151 |     asyncio.run(main(questions, args))
152 |     if args.eval:
153 |         if args.dataset == "bamboogle":
154 |             if args.log_save_dir:
155 |                 asyncio.run(eval_bamboogle(args.log_save_dir))
156 |             else:
157 |                 asyncio.run(eval_bamboogle())
158 |         elif args.dataset == "hotpotqa":
159 |             if args.log_save_dir:
160 |                 hotpotqa_eval.run(args.log_save_dir)
161 |             else:
162 |                 hotpotqa_eval.run()
163 | 
--------------------------------------------------------------------------------
/autoagents/serve/README.md:
--------------------------------------------------------------------------------
1 | ## SERVING
2 | 
3 | To serve the models, run the following simultaneously (in multiple terminals or using background tasks)
4 | 
5 | 
6 | ```
7 | bash autoagents/serve/controller.sh
8 | ```
9 | 
10 | ```
11 | bash autoagents/serve/openai_api.sh
12 | ```
13 | 
14 | ```
15 | MODEL_PATH=/some/path/to/your/model CONDENSE_RESCALE=1 bash autoagents/serve/model_worker.sh
16 | ```
17 | 
18 | You may have multiple `model_worker.sh` instances. If you are using LongChat, set `CONDENSE_RESCALE` to whatever scaling you are using (e.g. 4 or 8)
19 | 
20 | ### Prompt V3 serving
21 | 
22 | 1. Start the model server
23 | ```
24 | python3 autoagents/serve/action_model_worker.py --model-path /path/to/model/checkpoint --controller http://localhost:21001 --port 31008 --worker http://localhost:31008
25 | ```
26 | 
27 | 2. Start the completion API server, default address http://localhost:8004
28 | ```
29 | python3 autoagents/serve/action_api_server.py
30 | ```
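Once the controller, a worker, and the API server are all up, the completion endpoint can be smoke-tested with an OpenAI-style request. A minimal sketch (the model name must match whatever your worker registered; the role-tagged prompt mirrors the goal/tools/history roles that `action_api_server.py` maps onto the conversation template):

```python
import json

import requests

# Default API address from step 2; "prompt" may be a plain string or a
# list of role-tagged messages ("goal", "tools", "history").
resp = requests.post(
    "http://localhost:8004/v1/completions",
    json={
        "model": "your-model-name",
        "prompt": [{"role": "goal", "content": "your goal here"}],
        "max_tokens": 256,
    },
)
print(json.dumps(resp.json(), indent=2))
```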
--------------------------------------------------------------------------------
/autoagents/serve/action_api_server.py:
--------------------------------------------------------------------------------
1 | """A server that provides OpenAI-compatible RESTful APIs. It supports:
2 | 
3 | - Completions
4 | 
5 | Usage:
6 | python3 -m autoagents.serve.action_api_server
7 | """
8 | import asyncio
9 | import argparse
10 | import json
11 | import logging
12 | from typing import Optional, Union, Dict, List, Any
13 | 
14 | import fastapi
15 | from fastapi.middleware.cors import CORSMiddleware
16 | from fastapi.responses import StreamingResponse, JSONResponse
17 | import httpx
18 | from pydantic import BaseSettings
19 | import shortuuid
20 | import tiktoken
21 | import uvicorn
22 | 
23 | from fastchat.constants import (
24 |     WORKER_API_TIMEOUT,
25 |     ErrorCode,
26 | )
27 | from fastchat.conversation import Conversation, SeparatorStyle
28 | from fastapi.exceptions import RequestValidationError
29 | from fastchat.protocol.openai_api_protocol import (
30 |     CompletionRequest,
31 |     CompletionResponse,
32 |     CompletionResponseChoice,
33 |     CompletionResponseStreamChoice,
34 |     CompletionStreamResponse,
35 |     ErrorResponse,
36 |     ModelCard,
37 |     ModelList,
38 |     ModelPermission,
39 |     UsageInfo,
40 | )
41 | 
42 | logger = logging.getLogger(__name__)
43 | 
44 | conv_template_map = {}
45 | 
46 | 
47 | class AppSettings(BaseSettings):
48 |     # The address of the model controller.
49 |     controller_address: str = "http://localhost:21001"
50 |     api_keys: Optional[List[str]] = None
51 | 
52 | 
53 | app_settings = AppSettings()
54 | app = fastapi.FastAPI()
55 | headers = {"User-Agent": "FastChat API Server"}
56 | 
57 | 
58 | def create_error_response(code: int, message: str) -> JSONResponse:
59 |     return JSONResponse(
60 |         ErrorResponse(message=message, code=code).dict(), status_code=400
61 |     )
62 | 
63 | 
64 | @app.exception_handler(RequestValidationError)
65 | async def validation_exception_handler(request, exc):
66 |     return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
67 | 
68 | 
69 | async def check_model(request) -> Optional[JSONResponse]:
70 |     controller_address = app_settings.controller_address
71 |     ret = None
72 |     async with httpx.AsyncClient() as client:
73 |         try:
74 |             _worker_addr = await get_worker_address(request.model, client)
75 |         except:
76 |             models_ret = await client.post(controller_address + "/list_models")
77 |             models = models_ret.json()["models"]
78 |             ret = create_error_response(
79 |                 ErrorCode.INVALID_MODEL,
80 |                 f"Only {'&&'.join(models)} allowed now, your model {request.model}",
81 |             )
82 |     return ret
83 | 
84 | 
85 | async def check_length(request, prompt, max_tokens):
86 |     async with httpx.AsyncClient() as client:
87 |         worker_addr = await get_worker_address(request.model, client)
88 | 
89 |         response = await client.post(
90 |             worker_addr + "/model_details",
91 |             headers=headers,
92 |             json={"model": request.model},
93 |             timeout=WORKER_API_TIMEOUT,
94 |         )
95 |         context_len = response.json()["context_length"]
96 | 
97 |         response = await client.post(
98 |             worker_addr + "/count_token",
99 |             headers=headers,
100 |             json={"model": request.model, "prompt": json.dumps(prompt)},
101 |             timeout=WORKER_API_TIMEOUT,
102 |         )
103 |         token_num = response.json()["count"]
104 | 
105 |     if token_num + max_tokens > context_len:
106 |         return create_error_response(
107 |             ErrorCode.CONTEXT_OVERFLOW,
108 |             f"This model's maximum context length is {context_len} tokens. "
109 |             f"However, you requested {max_tokens + token_num} tokens "
110 |             f"({token_num} in the messages, "
111 |             f"{max_tokens} in the completion). "
112 |             f"Please reduce the length of the messages or completion.",
113 |         )
114 |     else:
115 |         return None
116 | 
117 | 
118 | def check_requests(request) -> Optional[JSONResponse]:
119 |     # Check all params
120 |     if request.max_tokens is not None and request.max_tokens <= 0:
121 |         return create_error_response(
122 |             ErrorCode.PARAM_OUT_OF_RANGE,
123 |             f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
124 |         )
125 |     if request.n is not None and request.n <= 0:
126 |         return create_error_response(
127 |             ErrorCode.PARAM_OUT_OF_RANGE,
128 |             f"{request.n} is less than the minimum of 1 - 'n'",
129 |         )
130 |     if request.temperature is not None and request.temperature < 0:
131 |         return create_error_response(
132 |             ErrorCode.PARAM_OUT_OF_RANGE,
133 |             f"{request.temperature} is less than the minimum of 0 - 'temperature'",
134 |         )
135 |     if request.temperature is not None and request.temperature > 2:
136 |         return create_error_response(
137 |             ErrorCode.PARAM_OUT_OF_RANGE,
138 |             f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
139 |         )
140 |     if request.top_p is not None and request.top_p < 0:
141 |         return create_error_response(
142 |             ErrorCode.PARAM_OUT_OF_RANGE,
143 |             f"{request.top_p} is less than the minimum of 0 - 'top_p'",
144 |         )
145 |     if request.top_p is not None and request.top_p > 1:
146 |         return create_error_response(
147 |             ErrorCode.PARAM_OUT_OF_RANGE,
148 |             f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
149 |         )
150 |     if request.stop is not None and (
151 |         not isinstance(request.stop, str) and not isinstance(request.stop, list)
152 |     ):
153 |         return create_error_response(
154 |             ErrorCode.PARAM_OUT_OF_RANGE,
155 |             f"{request.stop} is not valid under any of the given schemas - 'stop'",
156 |         )
157 | 
158 |     return None
159 | 
160 | 
161 | def process_input(model_name, inp):
162 |     if isinstance(inp, str):
163 |         inp = [inp]
164 |     elif isinstance(inp, list):
165 |         if isinstance(inp[0], int):
166 |             decoding = tiktoken.model.encoding_for_model(model_name)
167 |             inp = [decoding.decode(inp)]
168 |         elif isinstance(inp[0], list):
169 |             decoding = tiktoken.model.encoding_for_model(model_name)
170 |             inp = [decoding.decode(text) for text in inp]
171 | 
172 |     return inp
173 | 
174 | 
175 | async def get_gen_params(
176 |     model_name: str,
177 |     messages: Union[str, List[Dict[str, str]]],
178 |     *,
179 |     temperature: float,
180 |     top_p: float,
181 |     max_tokens: Optional[int],
182 |     echo: Optional[bool],
183 |     stream: Optional[bool],
184 |     stop: Optional[Union[str, List[str]]],
185 | ) -> Dict[str, Any]:
186 |     conv = await get_conv(model_name)
187 |     conv = Conversation(
188 |         name=conv["name"],
189 |         system_message=conv["system_message"],
190 |         roles=conv["roles"],
191 |         messages=list(conv["messages"]),  # prevent in-place modification
192 |         offset=conv["offset"],
193 |         sep_style=SeparatorStyle(conv["sep_style"]),
194 |         sep=conv["sep"],
195 |         sep2=conv["sep2"],
196 |         stop_str=conv["stop_str"],
197 |         stop_token_ids=conv["stop_token_ids"],
198 |     )
199 | 
200 |     if isinstance(messages, str):
201 |         prompt = messages
202 |     else:
203 |         for message in messages:
204 |             msg_role = message["role"]
205 |             if msg_role == "goal":
206 |                 conv.append_message(conv.roles[0], json.dumps(message["content"]))
207 |             elif msg_role == "tools":
208 |                 conv.append_message(conv.roles[1], json.dumps(message["content"]))
209 |             elif msg_role == "history":
210 |                 conv.append_message(conv.roles[2], json.dumps(message["content"]))
211 |             else:
212 |                 raise ValueError(f"Unknown role: {msg_role}")
213 | 214 | # Add a blank message for the assistant. 215 | conv.append_message(conv.roles[3], None) 216 | prompt = conv.get_prompt() 217 | 218 | if max_tokens is None: 219 | max_tokens = 512 220 | gen_params = { 221 | "model": model_name, 222 | "prompt": prompt, 223 | "temperature": temperature, 224 | "top_p": top_p, 225 | "max_new_tokens": max_tokens, 226 | "echo": echo, 227 | "stream": stream, 228 | } 229 | 230 | if not stop: 231 | gen_params.update( 232 | {"stop": conv.stop_str, "stop_token_ids": conv.stop_token_ids} 233 | ) 234 | else: 235 | gen_params.update({"stop": stop}) 236 | 237 | logger.info(f"==== request ====\n{gen_params}") 238 | return gen_params 239 | 240 | 241 | async def get_worker_address(model_name: str, client: httpx.AsyncClient) -> str: 242 | """ 243 | Get worker address based on the requested model 244 | 245 | :param model_name: The worker's model name 246 | :param client: The httpx client to use 247 | :return: Worker address from the controller 248 | :raises: :class:`ValueError`: No available worker for requested model 249 | """ 250 | controller_address = app_settings.controller_address 251 | 252 | ret = await client.post( 253 | controller_address + "/get_worker_address", json={"model": model_name} 254 | ) 255 | worker_addr = ret.json()["address"] 256 | # No available worker 257 | if worker_addr == "": 258 | raise ValueError(f"No available worker for {model_name}") 259 | 260 | logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") 261 | return worker_addr 262 | 263 | 264 | async def get_conv(model_name: str): 265 | async with httpx.AsyncClient() as client: 266 | worker_addr = await get_worker_address(model_name, client) 267 | conv_template = conv_template_map.get((worker_addr, model_name)) 268 | if conv_template is None: 269 | response = await client.post( 270 | worker_addr + "/worker_get_conv_template", 271 | headers=headers, 272 | json={"model": model_name}, 273 | timeout=WORKER_API_TIMEOUT, 274 | ) 275 | conv_template = response.json()["conv"] 276 | conv_template_map[(worker_addr, model_name)] = conv_template 277 | return conv_template 278 | 279 | 280 | @app.get("/v1/models") 281 | async def show_available_models(): 282 | controller_address = app_settings.controller_address 283 | async with httpx.AsyncClient() as client: 284 | ret = await client.post(controller_address + "/refresh_all_workers") 285 | ret = await client.post(controller_address + "/list_models") 286 | models = ret.json()["models"] 287 | models.sort() 288 | # TODO: return real model permission details 289 | model_cards = [] 290 | for m in models: 291 | model_cards.append(ModelCard(id=m, root=m, permission=[ModelPermission()])) 292 | return ModelList(data=model_cards) 293 | 294 | 295 | 296 | @app.post("/v1/completions") 297 | async def create_completion(request: CompletionRequest): 298 | if (error_check_ret := await check_model(request)): 299 | return error_check_ret 300 | if (error_check_ret := check_requests(request)): 301 | return error_check_ret 302 | 303 | if (error_check_ret := await check_length(request, request.prompt, request.max_tokens)): 304 | return error_check_ret 305 | 306 | if request.stream: 307 | generator = generate_completion_stream_generator(request, request.n) 308 | return StreamingResponse(generator, media_type="text/event-stream") 309 | else: 310 | text_completions = [] 311 | gen_params = await get_gen_params( 312 | request.model, 313 | request.prompt, 314 | temperature=request.temperature, 315 | top_p=request.top_p, 316 | max_tokens=request.max_tokens, 317 | 
echo=request.echo, 318 | stream=request.stream, 319 | stop=request.stop, 320 | ) 321 | logger.info(f"{gen_params}") 322 | for i in range(request.n): 323 | content = asyncio.create_task(generate_completion(gen_params)) 324 | text_completions.append(content) 325 | 326 | try: 327 | all_tasks = await asyncio.wait_for(asyncio.gather(*text_completions), timeout=500) 328 | except Exception as e: 329 | return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) 330 | 331 | choices = [] 332 | usage = UsageInfo() 333 | for i, content in enumerate(all_tasks): 334 | if content["error_code"] != 0: 335 | return create_error_response(content["error_code"], content["text"]) 336 | choices.append( 337 | CompletionResponseChoice( 338 | index=i, 339 | text=content["text"], 340 | logprobs=content.get("logprobs", None), 341 | finish_reason=content.get("finish_reason", "stop"), 342 | ) 343 | ) 344 | task_usage = UsageInfo.parse_obj(content["usage"]) 345 | for usage_key, usage_value in task_usage.dict().items(): 346 | setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) 347 | 348 | return CompletionResponse( 349 | model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage) 350 | ) 351 | 352 | 353 | async def generate_completion_stream_generator(request: CompletionRequest, n: int): 354 | model_name = request.model 355 | request_id = f"cmpl-{shortuuid.random()}" 356 | finish_stream_events = [] 357 | for text in request.prompt: 358 | for i in range(n): 359 | previous_text = "" 360 | gen_params = await get_gen_params( 361 | request.model, 362 | text, 363 | temperature=request.temperature, 364 | top_p=request.top_p, 365 | max_tokens=request.max_tokens, 366 | echo=request.echo, 367 | stream=request.stream, 368 | stop=request.stop, 369 | ) 370 | async for content in generate_completion_stream(gen_params): 371 | if content["error_code"] != 0: 372 | yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n" 373 | yield "data: [DONE]\n\n" 374 | return 375 | decoded_unicode = content["text"].replace("\ufffd", "") 376 | delta_text = decoded_unicode[len(previous_text) :] 377 | previous_text = decoded_unicode 378 | # TODO: the choice index is not apparent from the worker response 379 | choice_data = CompletionResponseStreamChoice( 380 | index=i, 381 | text=delta_text, 382 | logprobs=content.get("logprobs", None), 383 | finish_reason=content.get("finish_reason", None), 384 | ) 385 | chunk = CompletionStreamResponse( 386 | id=request_id, 387 | object="text_completion", 388 | choices=[choice_data], 389 | model=model_name, 390 | ) 391 | if len(delta_text) == 0: 392 | if content.get("finish_reason", None) is not None: 393 | finish_stream_events.append(chunk) 394 | continue 395 | yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" 396 | # The last delta message has no "content" field, so serialize with exclude_unset to drop unset fields. 
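# Illustrative note, not part of the original source: each chunk yielded
# above is framed as a server-sent event, e.g. (values hypothetical):
#
#   data: {"id": "cmpl-abc123", "object": "text_completion", "choices": [{"index": 0, "text": "..."}], "model": "action-model"}
#
# and, after the finish events below are flushed, the stream ends with:
#
#   data: [DONE]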
397 | for finish_chunk in finish_stream_events: 398 | yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n" 399 | yield "data: [DONE]\n\n" 400 | 401 | 402 | async def generate_completion_stream(payload: Dict[str, Any]): 403 | async with httpx.AsyncClient() as client: 404 | worker_addr = await get_worker_address(payload["model"], client) 405 | delimiter = b"\0" 406 | async with client.stream( 407 | "POST", 408 | worker_addr + "/worker_generate_stream", 409 | headers=headers, 410 | json=payload, 411 | timeout=WORKER_API_TIMEOUT, 412 | ) as response: 413 | # content = await response.aread() 414 | async for raw_chunk in response.aiter_raw(): 415 | for chunk in raw_chunk.split(delimiter): 416 | if not chunk: 417 | continue 418 | data = json.loads(chunk.decode()) 419 | yield data 420 | 421 | 422 | async def generate_completion(payload: Dict[str, Any]): 423 | async with httpx.AsyncClient() as client: 424 | worker_addr = await get_worker_address(payload["model"], client) 425 | 426 | response = await client.post( 427 | worker_addr + "/worker_generate", 428 | headers=headers, 429 | json=payload, 430 | timeout=WORKER_API_TIMEOUT, 431 | ) 432 | completion = response.json() 433 | return completion 434 | 435 | 436 | if __name__ == "__main__": 437 | parser = argparse.ArgumentParser( 438 | description="FastChat ChatGPT-Compatible RESTful API server." 439 | ) 440 | parser.add_argument("--host", type=str, default="localhost", help="host name") 441 | parser.add_argument("--port", type=int, default=8004, help="port number") 442 | parser.add_argument( 443 | "--controller-address", type=str, default="http://localhost:21001" 444 | ) 445 | parser.add_argument( 446 | "--allow-credentials", action="store_true", help="allow credentials" 447 | ) 448 | parser.add_argument( 449 | "--allowed-origins", type=json.loads, default=["*"], help="allowed origins" 450 | ) 451 | parser.add_argument( 452 | "--allowed-methods", type=json.loads, default=["*"], help="allowed methods" 453 | ) 454 | parser.add_argument( 455 | "--allowed-headers", type=json.loads, default=["*"], help="allowed headers" 456 | ) 457 | args = parser.parse_args() 458 | 459 | app.add_middleware( 460 | CORSMiddleware, 461 | allow_origins=args.allowed_origins, 462 | allow_credentials=args.allow_credentials, 463 | allow_methods=args.allowed_methods, 464 | allow_headers=args.allowed_headers, 465 | ) 466 | app_settings.controller_address = args.controller_address 467 | 468 | logger.info(f"args: {args}") 469 | 470 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 471 | -------------------------------------------------------------------------------- /autoagents/serve/action_model_worker.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | import fastchat.serve.model_worker as model_worker 3 | from fastchat.conversation import ( 4 | SeparatorStyle, 5 | Conversation, 6 | register_conv_template, 7 | get_conv_template 8 | ) 9 | from fastchat.model.model_adapter import ( 10 | register_model_adapter, 11 | BaseModelAdapter 12 | ) 13 | from transformers import ( 14 | AutoConfig, 15 | AutoModelForCausalLM, 16 | AutoTokenizer 17 | ) 18 | 19 | class ActionAdapter(BaseModelAdapter): 20 | """The model adapter for Action Vicuna""" 21 | 22 | def match(self, model_path: str): 23 | return "action" in model_path 24 | 25 | def load_model(self, model_path: str, from_pretrained_kwargs: dict): 26 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 27 | tokenizer = 
AutoTokenizer.from_pretrained( 28 | model_path, config=config, trust_remote_code=True 29 | ) 30 | model = AutoModelForCausalLM.from_pretrained( 31 | model_path, 32 | config=config, 33 | trust_remote_code=True, 34 | low_cpu_mem_usage=True, 35 | **from_pretrained_kwargs, 36 | ) 37 | return model, tokenizer 38 | 39 | def get_default_conv_template(self, model_path: str) -> Conversation: 40 | return get_conv_template("action") 41 | 42 | 43 | if __name__ == "__main__": 44 | # Action LLM default template 45 | register_conv_template( 46 | Conversation( 47 | name="action", 48 | system_message="Below is a goal you need to achieve. Given the available tools and history of past actions provide the next action to perform.", 49 | roles=("### Goal", "### Tools", "### History", "### Next action"), 50 | messages=(), 51 | offset=0, 52 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 53 | sep="\n\n", # separator between roles 54 | ) 55 | ) 56 | register_model_adapter(ActionAdapter) 57 | args, model_worker.worker = model_worker.create_model_worker() 58 | # hardcode the conv template 59 | args.conv_template = "action" 60 | uvicorn.run(model_worker.app, host=args.host, port=args.port, log_level="info") 61 | -------------------------------------------------------------------------------- /autoagents/serve/controller.sh: -------------------------------------------------------------------------------- 1 | python3 -m fastchat.serve.controller 2 | -------------------------------------------------------------------------------- /autoagents/serve/model_worker.sh: -------------------------------------------------------------------------------- 1 | CONDENSE_RESCALE=1 python3 serve/serve_rescale.py --model-path $MODEL_PATH 2 | -------------------------------------------------------------------------------- /autoagents/serve/openai_api.sh: -------------------------------------------------------------------------------- 1 | python3 -m fastchat.serve.openai_api_server --host localhost --port 8000 2 | -------------------------------------------------------------------------------- /autoagents/serve/serve_rescale.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | 3 | from fastchat.train.llama_flash_attn_monkey_patch import ( 4 | replace_llama_attn_with_flash_attn, 5 | ) 6 | 7 | import os 8 | 9 | # TODO: change from env var to proper args (passing the rest of the args to create_model_worker) 10 | rescale = int(os.environ.get("CONDENSE_RESCALE", 1)) 11 | if rescale > 1: 12 | from longchat.train.monkey_patch.llama_condense_monkey_patch import replace_llama_with_condense 13 | replace_llama_with_condense(rescale) 14 | 15 | 16 | 17 | from fastchat.serve.model_worker import app, create_model_worker 18 | 19 | if __name__ == "__main__": 20 | args, worker = create_model_worker() 21 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 22 | -------------------------------------------------------------------------------- /autoagents/train/README.md: -------------------------------------------------------------------------------- 1 | ## Fine Tuning 2 | 3 | To train an action-finetuned model from Llama 2, do the following: 4 | 5 | 1. Ensure that your cluster has enough GPU RAM for your model. Modify the `--nproc_per_node` and `--nnodes` values in each of the `train/scripts/` files to match your topology. If you are using only a single machine, you may hardcode `$RANK` to 0. Otherwise, ensure that `--master_addr` is set to the address of the rank 0 worker. 6 | 2. 
Modify `--data_path`, `--output_dir`, and `--model_name_or_path` to match the data and models you are using. Note that the output of `conv_finetuning.sh` and `longchat_conv_finetuning.sh` will be used as the model input for `action_finetuning.sh` and `longchat_action_finetuning.sh`. If using LongChat, make sure that `CONDENSE_RESCALE` is set to the right value. 7 | 3. Run either `conv_finetuning.sh` or `longchat_conv_finetuning.sh` with the appropriate environment variables set (see step 1). You will need to run it from the root directory of this repository. 8 | 4. Run either `action_finetuning.sh` or `longchat_action_finetuning.sh` with the appropriate environment variables set (see step 1). Make sure that if you use LongChat for the first finetuning, you use LongChat for the second. You will need to run it from the root directory of this repository. 9 | 10 | ### For the action V3 prompt 11 | 12 | Follow steps 1 and 2 above, and run `action_finetuning_v3.sh`. 13 | -------------------------------------------------------------------------------- /autoagents/train/scripts/action_finetuning.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.run --nproc_per_node=8 \ 2 | --master_port=20005 --nnodes=1 --node_rank=$RANK --master_addr=127.0.0.1 \ 3 | autoagents/train/train.py \ 4 | --model_name_or_path input-model \ 5 | --data_path json-data.json \ 6 | --bf16 True \ 7 | --output_dir output-path \ 8 | --num_train_epochs 2 \ 9 | --per_device_train_batch_size 1 \ 10 | --per_device_eval_batch_size 1 \ 11 | --gradient_accumulation_steps 16 \ 12 | --evaluation_strategy "no" \ 13 | --save_strategy "steps" \ 14 | --save_steps 40 \ 15 | --save_total_limit 20 \ 16 | --learning_rate 2e-5 \ 17 | --weight_decay 0. \ 18 | --warmup_ratio 0.03 \ 19 | --lr_scheduler_type "cosine" \ 20 | --logging_steps 1 \ 21 | --fsdp "full_shard auto_wrap" \ 22 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 23 | --tf32 True \ 24 | --model_max_length 4096 \ 25 | --gradient_checkpointing True \ 26 | --lazy_preprocess True 27 | -------------------------------------------------------------------------------- /autoagents/train/scripts/action_finetuning_v3.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.run --nproc_per_node=8 \ 2 | --master_port=20005 --nnodes=1 --node_rank=$RANK --master_addr=127.0.0.1 \ 3 | autoagents/train/train_v3.py \ 4 | --model_name_or_path input-model \ 5 | --data_path json-data.json \ 6 | --bf16 True \ 7 | --output_dir output-path \ 8 | --num_train_epochs 2 \ 9 | --per_device_train_batch_size 1 \ 10 | --per_device_eval_batch_size 1 \ 11 | --gradient_accumulation_steps 16 \ 12 | --evaluation_strategy "no" \ 13 | --save_strategy "steps" \ 14 | --save_steps 40 \ 15 | --save_total_limit 20 \ 16 | --learning_rate 2e-5 \ 17 | --weight_decay 0. 
\ 18 | --warmup_ratio 0.03 \ 19 | --lr_scheduler_type "cosine" \ 20 | --logging_steps 1 \ 21 | --fsdp "full_shard auto_wrap" \ 22 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 23 | --tf32 True \ 24 | --model_max_length 4096 \ 25 | --gradient_checkpointing True \ 26 | --lazy_preprocess True 27 | -------------------------------------------------------------------------------- /autoagents/train/scripts/conv_finetuning.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.run --nproc_per_node=8 \ 2 | --master_port=20005 --nnodes=1 --node_rank=$RANK --master_addr=127.0.0.1 \ 3 | autoagents/train/train.py \ 4 | --model_name_or_path meta-llama/Llama-2-13b-hf \ 5 | --data_path path-to-sharegpt-data.json \ 6 | --bf16 True \ 7 | --output_dir output-directory-path \ 8 | --num_train_epochs 1 \ 9 | --per_device_train_batch_size 1 \ 10 | --per_device_eval_batch_size 1 \ 11 | --gradient_accumulation_steps 16 \ 12 | --evaluation_strategy "no" \ 13 | --save_strategy "steps" \ 14 | --save_steps 40 \ 15 | --save_total_limit 20 \ 16 | --learning_rate 2e-5 \ 17 | --weight_decay 0. \ 18 | --warmup_ratio 0.03 \ 19 | --lr_scheduler_type "cosine" \ 20 | --logging_steps 1 \ 21 | --tf32 True \ 22 | --fsdp "full_shard auto_wrap" \ 23 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 24 | --model_max_length 4096 \ 25 | --gradient_checkpointing True \ 26 | --lazy_preprocess True \ 27 | -------------------------------------------------------------------------------- /autoagents/train/scripts/longchat_action_finetuning.sh: -------------------------------------------------------------------------------- 1 | CONDENSE_RESCALE=4 python3 -m torch.distributed.run --nproc_per_node=8 \ 2 | --master_port=20005 --nnodes=1 --node_rank=$RANK --master_addr=127.0.0.1 \ 3 | autoagents/train/train.py \ 4 | --model_name_or_path input-model \ 5 | --data_path json-data.json \ 6 | --bf16 True \ 7 | --output_dir output-path \ 8 | --num_train_epochs 2 \ 9 | --per_device_train_batch_size 1 \ 10 | --per_device_eval_batch_size 1 \ 11 | --gradient_accumulation_steps 16 \ 12 | --evaluation_strategy "no" \ 13 | --save_strategy "steps" \ 14 | --save_steps 40 \ 15 | --save_total_limit 20 \ 16 | --learning_rate 2e-5 \ 17 | --weight_decay 0. \ 18 | --warmup_ratio 0.03 \ 19 | --lr_scheduler_type "cosine" \ 20 | --logging_steps 1 \ 21 | --fsdp "full_shard auto_wrap" \ 22 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 23 | --tf32 True \ 24 | --model_max_length 4096 \ 25 | --gradient_checkpointing True \ 26 | --lazy_preprocess True \ 27 | -------------------------------------------------------------------------------- /autoagents/train/scripts/longchat_conv_finetuning.sh: -------------------------------------------------------------------------------- 1 | CONDENSE_RESCALE=4 python3 -m torch.distributed.run --nproc_per_node=8 \ 2 | --master_port=20005 --nnodes=1 --node_rank=$RANK --master_addr=127.0.0.1 \ 3 | autoagents/train/train.py \ 4 | --model_name_or_path meta-llama/Llama-2-13b-hf \ 5 | --data_path path-to-sharegpt-data.json \ 6 | --bf16 True \ 7 | --output_dir output-directory-path \ 8 | --num_train_epochs 1 \ 9 | --per_device_train_batch_size 1 \ 10 | --per_device_eval_batch_size 1 \ 11 | --gradient_accumulation_steps 16 \ 12 | --evaluation_strategy "no" \ 13 | --save_strategy "steps" \ 14 | --save_steps 40 \ 15 | --save_total_limit 20 \ 16 | --learning_rate 2e-5 \ 17 | --weight_decay 0. 
\ 18 | --warmup_ratio 0.03 \ 19 | --lr_scheduler_type "cosine" \ 20 | --logging_steps 1 \ 21 | --tf32 True \ 22 | --fsdp "full_shard auto_wrap" \ 23 | --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \ 24 | --model_max_length 4096 \ 25 | --gradient_checkpointing True \ 26 | --lazy_preprocess True \ 27 | -------------------------------------------------------------------------------- /autoagents/train/test_v3_preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import torch 4 | import click 5 | import tempfile 6 | import transformers 7 | from transformers.trainer_pt_utils import LabelSmoother 8 | from torch.utils.data import Dataset 9 | from typing import Dict 10 | 11 | from fastchat.conversation import ( 12 | SeparatorStyle, 13 | Conversation, 14 | register_conv_template 15 | ) 16 | from fastchat.model.model_adapter import ( 17 | get_conversation_template, 18 | register_model_adapter 19 | ) 20 | 21 | from train_v3 import ActionAdapter 22 | 23 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 24 | 25 | def rank0_print(*x): 26 | print(*x) 27 | 28 | 29 | def preprocess( 30 | sources, 31 | tokenizer: transformers.PreTrainedTokenizer, 32 | ) -> Dict: 33 | conv = get_conversation_template("action") 34 | roles = {"goal": conv.roles[0], 35 | "tools": conv.roles[1], 36 | "history": conv.roles[2], 37 | "next_action": conv.roles[3]} 38 | 39 | # Apply prompt templates 40 | conversations = [] 41 | for i, source in enumerate(sources): 42 | if roles[source[0]["from"]] != conv.roles[0]: 43 | # Skip the first one if it is not from human 44 | source = source[1:] 45 | 46 | conv.messages = [] 47 | for j, sentence in enumerate(source): 48 | role = roles[sentence["from"]] 49 | conv.append_message(role, sentence["value"]) 50 | conversations.append(conv.get_prompt()) 51 | 52 | # Tokenize conversations 53 | input_ids = tokenizer( 54 | conversations, 55 | return_tensors="pt", 56 | padding="max_length", 57 | max_length=tokenizer.model_max_length, 58 | truncation=True, 59 | ).input_ids 60 | targets = input_ids.clone() 61 | 62 | assert conv.sep_style == SeparatorStyle.ADD_COLON_SINGLE 63 | 64 | show_index = 3 65 | # Mask targets 66 | sep = conv.sep + conv.roles[-1] + ": " 67 | for i, (conversation, target) in enumerate(zip(conversations, targets)): 68 | if i == show_index: 69 | z = target.clone() 70 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 71 | rank0_print(tokenizer.decode(z)) 72 | 73 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) # = len(tokenizer(conversation).input_ids) 74 | inputs, outputs = conversation.split(sep) # [.., goal, history | next_action] 75 | cur_len = 0 76 | instruction_len = len(tokenizer(inputs + sep).input_ids) - 1 77 | target[cur_len:cur_len+instruction_len] = IGNORE_TOKEN_ID 78 | cur_len += instruction_len - 1 79 | outputs_len = len(tokenizer(outputs).input_ids) 80 | target[cur_len+outputs_len:] = IGNORE_TOKEN_ID 81 | cur_len += outputs_len 82 | 83 | if i == show_index: 84 | z = target.clone() 85 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 86 | rank0_print(tokenizer.decode(z)) 87 | 88 | if cur_len < tokenizer.model_max_length: 89 | if cur_len != total_len: 90 | target[:] = IGNORE_TOKEN_ID 91 | rank0_print( 92 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
93 | f" (ignored)" 94 | ) 95 | 96 | return dict( 97 | input_ids=input_ids, 98 | labels=targets, 99 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 100 | ) 101 | 102 | class SupervisedDataset(Dataset): 103 | """Dataset for supervised fine-tuning.""" 104 | 105 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): 106 | super(SupervisedDataset, self).__init__() 107 | 108 | rank0_print("Formatting inputs...") 109 | sources = [example["conversations"] for example in raw_data] 110 | data_dict = preprocess(sources, tokenizer) 111 | 112 | self.input_ids = data_dict["input_ids"] 113 | self.labels = data_dict["labels"] 114 | self.attention_mask = data_dict["attention_mask"] 115 | 116 | def __len__(self): 117 | return len(self.input_ids) 118 | 119 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 120 | return dict( 121 | input_ids=self.input_ids[i], 122 | labels=self.labels[i], 123 | attention_mask=self.attention_mask[i], 124 | ) 125 | 126 | 127 | def make_supervised_data_module( 128 | tokenizer: transformers.PreTrainedTokenizer, data_path 129 | ) -> Dict: 130 | """Make dataset and collator for supervised fine-tuning.""" 131 | dataset_cls = ( 132 | SupervisedDataset 133 | ) 134 | rank0_print("Loading data...") 135 | raw_data = json.load(open(data_path, "r")) 136 | 137 | # Split train/test 138 | np.random.seed(0) 139 | perm = np.random.permutation(len(raw_data)) 140 | split = int(len(perm) * 0.98) 141 | train_indices = perm[:split] 142 | eval_indices = perm[split:] 143 | train_raw_data = [raw_data[i] for i in train_indices] 144 | eval_raw_data = [raw_data[i] for i in eval_indices] 145 | rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}") 146 | 147 | train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer) 148 | eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer) 149 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) 150 | 151 | 152 | @click.command() 153 | @click.option('--model_name_or_path', required=True) 154 | @click.option('--data_path', required=True) 155 | @click.option('--model_max_length', default=4096) 156 | def main(model_name_or_path, data_path, model_max_length): 157 | tokenizer = transformers.AutoTokenizer.from_pretrained( 158 | model_name_or_path, 159 | cache_dir=tempfile.TemporaryDirectory().name, 160 | model_max_length=model_max_length, 161 | padding_side="right", 162 | use_fast=False, 163 | legacy=False, 164 | ) 165 | 166 | tokenizer.pad_token = tokenizer.unk_token 167 | make_supervised_data_module(tokenizer=tokenizer, data_path=data_path) 168 | 169 | if __name__ == "__main__": 170 | # Action LLM default template 171 | register_conv_template( 172 | Conversation( 173 | name="action", 174 | system_message="Below is a goal you need to achieve. 
Given the available tools and history of past actions provide the next action to perform.", 175 | roles=("### Goal", "### Tools", "### History", "### Next action"), 176 | messages=(), 177 | offset=0, 178 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 179 | sep="\n\n", # separator between roles 180 | ) 181 | ) 182 | register_model_adapter(ActionAdapter) 183 | main() 184 | -------------------------------------------------------------------------------- /autoagents/train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | rescale = int(os.environ.get("CONDENSE_RESCALE", 1)) 4 | if rescale > 1: 5 | from longchat.train.monkey_patch.llama_condense_monkey_patch import replace_llama_with_condense 6 | replace_llama_with_condense(rescale) 7 | 8 | from longchat.train.monkey_patch.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 9 | 10 | replace_llama_attn_with_flash_attn() 11 | 12 | from fastchat.train.train import train 13 | 14 | if __name__ == "__main__": 15 | train() 16 | -------------------------------------------------------------------------------- /autoagents/train/train_v3.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 2 | # Need to call this before importing transformers. 3 | try: 4 | from fastchat.train.llama_flash_attn_monkey_patch import ( 5 | replace_llama_attn_with_flash_attn, 6 | ) 7 | 8 | replace_llama_attn_with_flash_attn() 9 | except ImportError: 10 | pass # ignore if flash-attn not installed 11 | 12 | from typing import Dict 13 | import torch 14 | from fastchat.conversation import ( 15 | SeparatorStyle, 16 | Conversation, 17 | register_conv_template, 18 | get_conv_template 19 | ) 20 | from fastchat.model.model_adapter import ( 21 | get_conversation_template, 22 | BaseModelAdapter, 23 | register_model_adapter 24 | ) 25 | from transformers import ( 26 | AutoConfig, 27 | AutoModelForCausalLM, 28 | AutoTokenizer, 29 | PreTrainedTokenizer 30 | ) 31 | 32 | import fastchat.train.train as train 33 | from fastchat.train.train import rank0_print 34 | from fastchat.train.train import IGNORE_TOKEN_ID 35 | 36 | class ActionAdapter(BaseModelAdapter): 37 | """The model adapter for Action Vicuna""" 38 | 39 | def match(self, model_path: str): 40 | return "action" in model_path 41 | 42 | def load_model(self, model_path: str, from_pretrained_kwargs: dict): 43 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 44 | tokenizer = AutoTokenizer.from_pretrained( 45 | model_path, config=config, trust_remote_code=True 46 | ) 47 | model = AutoModelForCausalLM.from_pretrained( 48 | model_path, 49 | config=config, 50 | trust_remote_code=True, 51 | low_cpu_mem_usage=True, 52 | **from_pretrained_kwargs, 53 | ) 54 | return model, tokenizer 55 | 56 | def get_default_conv_template(self, model_path: str) -> Conversation: 57 | return get_conv_template("action") 58 | 59 | 60 | def preprocess( 61 | sources, 62 | tokenizer: PreTrainedTokenizer, 63 | ) -> Dict: 64 | conv = get_conversation_template("action") 65 | roles = {"goal": conv.roles[0], 66 | "tools": conv.roles[1], 67 | "history": conv.roles[2], 68 | "next_action": conv.roles[3]} 69 | 70 | # Apply prompt templates 71 | conversations = [] 72 | for i, source in enumerate(sources): 73 | if roles[source[0]["from"]] != conv.roles[0]: 74 | # Skip the first one if it is not from human 75 | source = source[1:] 76 | 77 | conv.messages = [] 78 | 
for j, sentence in enumerate(source): 79 | role = roles[sentence["from"]] 80 | conv.append_message(role, sentence["value"]) 81 | conversations.append(conv.get_prompt()) 82 | 83 | # Tokenize conversations 84 | input_ids = tokenizer( 85 | conversations, 86 | return_tensors="pt", 87 | padding="max_length", 88 | max_length=tokenizer.model_max_length, 89 | truncation=True, 90 | ).input_ids 91 | targets = input_ids.clone() 92 | 93 | assert conv.sep_style == SeparatorStyle.ADD_COLON_SINGLE 94 | 95 | # Mask targets 96 | sep = conv.sep + conv.roles[-1] + ": " 97 | for conversation, target in zip(conversations, targets): 98 | if False: 99 | z = target.clone() 100 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 101 | rank0_print(tokenizer.decode(z)) 102 | 103 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 104 | inputs, outputs = conversation.split(sep) # [.., goal, history | next_action] 105 | cur_len = 0 106 | instruction_len = len(tokenizer(inputs + sep).input_ids) - 1 107 | target[cur_len:cur_len+instruction_len] = IGNORE_TOKEN_ID 108 | cur_len += instruction_len - 1 109 | outputs_len = len(tokenizer(outputs).input_ids) 110 | target[cur_len+outputs_len:] = IGNORE_TOKEN_ID 111 | cur_len += outputs_len 112 | 113 | if False: 114 | z = target.clone() 115 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 116 | rank0_print(tokenizer.decode(z)) 117 | 118 | if cur_len < tokenizer.model_max_length: 119 | if cur_len != total_len: 120 | target[:] = IGNORE_TOKEN_ID 121 | rank0_print( 122 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 123 | f" (ignored)" 124 | ) 125 | 126 | return dict( 127 | input_ids=input_ids, 128 | labels=targets, 129 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 130 | ) 131 | 132 | 133 | 134 | if __name__ == "__main__": 135 | # Action LLM default template 136 | register_conv_template( 137 | Conversation( 138 | name="action", 139 | system_message="Below is a goal you need to achieve. Given the available tools and history of past actions provide the next action to perform.", 140 | roles=("### Goal", "### Tools", "### History", "### Next action"), 141 | messages=(), 142 | offset=0, 143 | sep_style=SeparatorStyle.ADD_COLON_SINGLE, 144 | sep="\n\n", # separator between roles 145 | ) 146 | ) 147 | register_model_adapter(ActionAdapter) 148 | 149 | # Monkeypatch preprocessing which handles the action roles 150 | train.preprocess = preprocess 151 | train.train() 152 | 153 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai>=0.27.7 2 | langchain>=0.0.193 3 | duckpy 4 | huggingface_hub 5 | pytz 6 | click 7 | pydantic==1.10.11 8 | git+https://github.com/lm-sys/FastChat@v0.2.25 9 | git+https://github.com/DachengLi1/LongChat@0a4d022 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='autoagents', 5 | version='0.1.0', 6 | packages=find_packages(include=['autoagents', 'autoagents.*']) 7 | ) 8 | --------------------------------------------------------------------------------
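As a closing usage illustration (a hedged sketch, not a file from this repository): once the controller (`controller.sh`), an action model worker (`action_model_worker.py`), and the completions API server (`python3 -m autoagents.serve.action_api_server`, which defaults to port 8004) are running, the endpoint can be exercised roughly as follows. The model name and message contents are invented for illustration, and the snippet assumes the request schema accepts the role-tagged message list that `get_gen_params` expects.

import httpx

# Hypothetical client for the action completions endpoint; the goal/tools/
# history roles mirror those dispatched in get_gen_params above.
payload = {
    "model": "action-vicuna-13b",  # hypothetical; must match a worker registered with the controller
    "prompt": [
        {"role": "goal", "content": "What is the tallest mountain in Africa?"},
        {"role": "tools", "content": [{"name": "search", "description": "web search"}]},
        {"role": "history", "content": []},
    ],
    "max_tokens": 256,
    "temperature": 0.0,
}

resp = httpx.post("http://localhost:8004/v1/completions", json=payload, timeout=120.0)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])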