├── LICENSE ├── README.md ├── __pycache__ ├── prompts.cpython-310.pyc ├── prompts.cpython-39.pyc ├── search.cpython-310.pyc └── search.cpython-39.pyc ├── agent.py ├── config.yaml ├── memory.json ├── prompts.py ├── requirements.txt ├── schema └── Agent Schema.png └── search.py /LICENSE: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 2 | 3 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 4 | 5 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Custom Agent 3 | 4 | A custom websearch agent useable with Ollama, OpenAI and vLLM. 5 | 6 | ### Agent Schema: 7 | ![Agent Schema](schema/Agent%20Schema.png) 8 | 9 | 10 | ### Prerequisites 11 | 12 | #### Environment Setup 13 | 1. **Install Anaconda:** 14 | Download Anaconda from [https://www.anaconda.com/](https://www.anaconda.com/). 15 | 16 | 2. **Create a Virtual Environment:** 17 | ```bash 18 | conda create -n agent_env python=3.10 pip 19 | ``` 20 | 21 | 3. **Activate the Virtual Environment:** 22 | ```bash 23 | conda activate agent_env 24 | ``` 25 | #### Setup Ollama Server 26 | 1. **Download Ollama:** 27 | Download [https://ollama.com/download](Ollama) 28 | 29 | 2. **Download an Ollama Model:** 30 | ```bash 31 | curl http://localhost:11434/api/pull -d "{\"name\": \"llama3\"}" 32 | ``` 33 | Ollama[https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models](API documentionation) 34 | 35 | ### Clone and Navigate to the Repository 36 | 1. **Clone the Repo:** 37 | ```bash 38 | git clone https://github.com/john-adeojo/custom_agent_tutorial.git 39 | ``` 40 | 41 | 2. **Navigate to the Repo:** 42 | ```bash 43 | cd /path/to/your-repo/custom_agent_tutorial 44 | ``` 45 | 46 | 3. **Install Requirements:** 47 | ```bash 48 | pip install -r requirements.txt 49 | ``` 50 | 51 | ### Configure API Keys 52 | 1. **Open the `config.yaml`:** 53 | ```bash 54 | nano config.yaml 55 | ``` 56 | 57 | 2. **Enter API Keys:** 58 | - **Serper API Key:** Get it from [https://serper.dev/](https://serper.dev/) 59 | - **OpenAI API Key:** Get it from [https://openai.com/](https://openai.com/) 60 | 61 | ### Run Your Query 62 | ```bash 63 | python agent.py run 64 | ``` 65 | Then enter your query. 66 | -------------------------------------------------------------------------------- /__pycache__/prompts.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/john-adeojo/custom_websearch_agent/0857ee53a4eb351c4595fd22bcfa25b85bce0857/__pycache__/prompts.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/prompts.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/john-adeojo/custom_websearch_agent/0857ee53a4eb351c4595fd22bcfa25b85bce0857/__pycache__/prompts.cpython-39.pyc -------------------------------------------------------------------------------- /__pycache__/search.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/john-adeojo/custom_websearch_agent/0857ee53a4eb351c4595fd22bcfa25b85bce0857/__pycache__/search.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/search.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/john-adeojo/custom_websearch_agent/0857ee53a4eb351c4595fd22bcfa25b85bce0857/__pycache__/search.cpython-39.pyc -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import json 4 | import requests 5 | from datetime import datetime, timezone 6 | from termcolor import colored 7 | from prompts import planning_agent_prompt, integration_agent_prompt, check_response_prompt, check_response_json 8 | from search import WebSearcher 9 | import ast 10 | 11 | 12 | def load_config(file_path): 13 | with open(file_path, 'r') as file: 14 | config = yaml.safe_load(file) 15 | for key, value in config.items(): 16 | os.environ[key] = value 17 | 18 | def get_current_utc_datetime(): 19 | now_utc = datetime.now(timezone.utc) 20 | current_time_utc = now_utc.strftime("%Y-%m-%d %H:%M:%S %Z") 21 | return current_time_utc 22 | 23 | def save_feedback(response, json_filename="memory.json"): 24 | # Create a dictionary with the response 25 | feedback_entry = {"feedback": response} 26 | 27 | # Load existing data from the JSON file if it exists 28 | if os.path.exists(json_filename): 29 | with open(json_filename, "r") as json_file: 30 | data = json.load(json_file) 31 | else: 32 | data = [] 33 | 34 | # Append the new feedback entry to the data 35 | data.append(feedback_entry) 36 | 37 | # Write the updated data back to the JSON file 38 | with open(json_filename, "w") as json_file: 39 | json.dump(data, json_file, indent=4) 40 | 41 | def read_feedback(json_filename="memory.json"): 42 | if os.path.exists(json_filename): 43 | with open(json_filename, "r") as json_file: 44 | data = json.load(json_file) 45 | # Convert the JSON data to a pretty-printed string 46 | json_string = json.dumps(data, indent=4) 47 | # json_string = str(data) 48 | return json_string 49 | else: 50 | return "" 51 | 52 | def clear_json_file(json_filename="memory.json"): 53 | # Open the file in write mode to clear its contents 54 | with open(json_filename, "w") as json_file: 55 | json.dump([], json_file) 56 | 57 | def initialize_json_file(json_filename="memory.json"): 58 | if not os.path.exists(json_filename) or os.path.getsize(json_filename) == 0: 59 | with open(json_filename, "w") as json_file: 60 | json.dump([], json_file) 61 | 62 | # Call this function at the beginning of your script 63 | initialize_json_file() 64 | 65 | 66 | class Agent: 67 | def __init__(self, model, model_tool, model_qa, tool, temperature=0, max_tokens=1000, planning_agent_prompt=None, integration_agent_prompt=None, check_response_prompt=None, verbose=False, iterations=5, model_endpoint=None, server=None, stop=None): 68 | self.server = server 69 | self.model_endpoint = model_endpoint 70 | 71 | if server == 'openai': 72 | load_config('config.yaml') 73 | self.api_key = os.getenv('OPENAI_API_KEY') 74 | self.headers = { 75 | 'Content-Type': 'application/json', 76 | 'Authorization': f'Bearer {self.api_key}' 77 | } 78 | else: 79 | self.headers = {"Content-Type": "application/json"} 80 | 81 | self.temperature = temperature 82 | self.max_tokens = max_tokens 83 | self.tool_specs = tool.__doc__ 84 | self.planning_agent_prompt = planning_agent_prompt 85 | self.integration_agent_prompt = integration_agent_prompt 86 | self.model = model 87 | self.tool = tool(model=model_tool, verbose=verbose, model_endpoint=model_endpoint, server=server, stop=stop) 88 | self.iterations = iterations 89 | self.model_qa = model_qa 90 | self.stop = stop 91 | 92 | def run_planning_agent(self, query, plan=None, feedback=None): 93 | 94 | system_prompt = self.planning_agent_prompt.format( 95 | plan=plan, 96 | feedback=feedback, 97 | tool_specs=self.tool_specs, 98 | datetime=get_current_utc_datetime() 99 | ) 100 | 101 | if self.server == 'ollama': 102 | payload = { 103 | "model": self.model, 104 | "prompt": query, 105 | "system": system_prompt, 106 | "stream": False, 107 | "temperature": 0, 108 | } 109 | 110 | if self.server == 'runpod' or self.server == 'openai': 111 | 112 | prefix = self.model.split('/')[0] 113 | exception_models = ['microsoft/Phi-3-medium-128k-instruct', 114 | 'microsoft/Phi-3-mini-128k-instruct', 115 | 'microsoft/Phi-3-medium-4k-instruct', 116 | 'microsoft/Phi-3-mini-4k-instruct', 117 | ] 118 | 119 | if prefix == 'mistralai' or self.model in exception_models: 120 | payload = { 121 | "model": self.model, 122 | "messages": [ 123 | { 124 | "role": "user", 125 | "content": f"system_prompt:{system_prompt}\n\n query: {query}" 126 | } 127 | ], 128 | "temperature": 0, 129 | "stop": None 130 | } 131 | 132 | else: 133 | payload = { 134 | "model": self.model, 135 | "messages": [ 136 | { 137 | "role": "system", 138 | "content": system_prompt 139 | }, 140 | { 141 | "role": "user", 142 | "content": query 143 | } 144 | ], 145 | "stream": False, 146 | "temperature": 0, 147 | "stop": self.stop 148 | } 149 | 150 | if self.server == 'openai': 151 | del payload["stop"] 152 | 153 | try: 154 | response = requests.post(self.model_endpoint, headers=self.headers, data=json.dumps(payload)) 155 | print("Response_DEBUG:", response) 156 | 157 | try: 158 | response_dict = response.json() 159 | except json.JSONDecodeError as e: 160 | response_dict = ast.literal_eval(response) 161 | 162 | if self.server == 'ollama': 163 | response = response_dict['response'] 164 | 165 | if self.server == 'runpod' or self.server == 'openai': 166 | response = response_dict['choices'][0]['message']['content'] 167 | 168 | print(colored(f"Planning Agent: {response}", 'green')) 169 | return response 170 | 171 | except Exception as e: 172 | print("Error in response:", response_dict) 173 | return "Error generating plan {e}" 174 | 175 | def run_integration_agent(self, query, plan, outputs, reason, previous_response): 176 | 177 | system_prompt = self.integration_agent_prompt.format( 178 | outputs=outputs, 179 | plan=plan, 180 | reason=reason, 181 | sources=outputs.get('sources', ''), 182 | previous_response=previous_response, 183 | datetime=get_current_utc_datetime(), 184 | query=query 185 | ) 186 | 187 | if self.server == 'ollama': 188 | payload = { 189 | "model": self.model, 190 | "prompt": query, 191 | "system": system_prompt, 192 | "stream": False, 193 | "temperature": 0, 194 | } 195 | 196 | if self.server == 'runpod' or self.server == 'openai': 197 | 198 | prefix = self.model.split('/')[0] 199 | exception_models = ['microsoft/Phi-3-medium-128k-instruct', 200 | 'microsoft/Phi-3-mini-128k-instruct', 201 | 'microsoft/Phi-3-medium-4k-instruct', 202 | 'microsoft/Phi-3-mini-4k-instruct', 203 | ] 204 | 205 | if prefix == 'mistralai' or self.model in exception_models: 206 | payload = { 207 | "model": self.model, 208 | "messages": [ 209 | { 210 | "role": "user", 211 | "content": f"system_prompt:{system_prompt}\n\n query: {query}" 212 | } 213 | ], 214 | "temperature": 0, 215 | "stop": None 216 | } 217 | 218 | else: 219 | payload = { 220 | "model": self.model, 221 | "messages": [ 222 | { 223 | "role": "system", 224 | "content": system_prompt 225 | }, 226 | { 227 | "role": "user", 228 | "content": query 229 | } 230 | ], 231 | "stream": False, 232 | "temperature": 0, 233 | "stop": self.stop 234 | } 235 | 236 | if self.server == 'openai': 237 | del payload["stop"] 238 | 239 | try: 240 | response = requests.post(self.model_endpoint, headers=self.headers, data=json.dumps(payload)) 241 | try: 242 | response_dict = response.json() 243 | except json.JSONDecodeError as e: 244 | response_dict = ast.literal_eval(response) 245 | 246 | if self.server == 'ollama': 247 | response = response_dict['response'] 248 | 249 | if self.server == 'runpod' or self.server == 'openai': 250 | response = response_dict['choices'][0]['message']['content'] 251 | 252 | print(colored(f"Integration Agent: {response}", 'cyan')) 253 | 254 | return response 255 | 256 | except Exception as e: 257 | print("Error in response:", response_dict) 258 | return "Error generating plan {e}" 259 | 260 | def check_response(self, response, query, previous_response, datetime=get_current_utc_datetime()): 261 | 262 | if self.server == 'ollama': 263 | payload = { 264 | "model": self.model_qa, 265 | "prompt": f"query: {query}\n\nresponse: {response}", 266 | "format": "json", 267 | "system": check_response_prompt, 268 | "stream": False, 269 | "temperature": 0, 270 | "stop": self.stop 271 | } 272 | 273 | if self.server == 'runpod' or self.server == 'openai': 274 | 275 | prefix = self.model.split('/')[0] 276 | exception_models = ['microsoft/Phi-3-medium-128k-instruct', 277 | 'microsoft/Phi-3-mini-128k-instruct', 278 | 'microsoft/Phi-3-medium-4k-instruct', 279 | 'microsoft/Phi-3-mini-4k-instruct', 280 | ] 281 | 282 | if prefix == 'mistralai' or self.model in exception_models: 283 | payload = { 284 | "model": self.model_qa, 285 | "messages": [ 286 | { 287 | "role": "user", 288 | "content": f"system:{check_response_prompt}\n\n query: {query}\n\nresponse: {response} \n\n previous response: {previous_response} \n\n current datetime: {datetime}" 289 | } 290 | ], 291 | "temperature": 0, 292 | "stop": None, 293 | "guided_json": check_response_json 294 | } 295 | 296 | else: 297 | payload = { 298 | "model": self.model, 299 | "response_format": {"type": "json_object"}, 300 | "messages": [ 301 | { 302 | "role": "system", 303 | "content": check_response_prompt 304 | }, 305 | { 306 | "role": "user", 307 | "content": f"query: {query} \n\nresponse: {response} \n\nprevious response: {previous_response} \n\n current datetime: {datetime}" 308 | } 309 | ], 310 | "temperature": 0, 311 | "stop": self.stop, 312 | "guided_json": check_response_json 313 | } 314 | 315 | if self.server == 'openai': 316 | del payload["stop"] 317 | del payload["guided_json"] 318 | 319 | try: 320 | response = requests.post(self.model_endpoint, headers=self.headers, data=json.dumps(payload)) 321 | try: 322 | response_dict = response.json() 323 | except json.JSONDecodeError as e: 324 | response_dict = ast.literal_eval(response) 325 | 326 | print(f"check_response response_dict type: {type(response_dict)}") 327 | 328 | if self.server == 'ollama': 329 | decision_dict = json.loads(response_dict['response']) 330 | 331 | if self.server == 'runpod' or self.server == 'openai': 332 | response_content = response_dict['choices'][0]['message']['content'] 333 | 334 | try: 335 | decision_dict = json.loads(response_content) 336 | except json.JSONDecodeError as e: 337 | decision_dict = ast.literal_eval(response_content) 338 | 339 | print("Response Quality Assessment:", decision_dict) 340 | return decision_dict 341 | 342 | except Exception as e: 343 | print("Error in assessing response quality:", response_dict) 344 | return "Error in assessing response quality" 345 | 346 | def execute(self): 347 | query = input("Enter your query: ") 348 | meets_requirements = False 349 | plan = None 350 | outputs = None 351 | integration_agent_response = None 352 | reason = None 353 | iterations = 0 354 | visited_sites = [] 355 | failed_sites = [] 356 | 357 | while not meets_requirements and iterations < self.iterations: 358 | iterations += 1 359 | feedback = read_feedback(json_filename="memory.json") 360 | plan = self.run_planning_agent(query, plan=plan, feedback=feedback) 361 | outputs = self.tool.use_tool(plan=plan, query=query, visited_sites=visited_sites, failed_sites=failed_sites) 362 | visited_sites.append(outputs.get('source', '')) 363 | print("VISITED_SITES",visited_sites) 364 | 365 | integration_agent_response = self.run_integration_agent(query=query, plan=plan, outputs=outputs, reason=reason, previous_response=feedback) 366 | save_feedback(integration_agent_response, json_filename="memory.json") 367 | response_dict = self.check_response(response=integration_agent_response, query=query, previous_response=feedback) 368 | meets_requirements = response_dict.get('pass', '') 369 | print(f"Response meets requirements: {meets_requirements}") 370 | if meets_requirements == 'True': 371 | meets_requirements = True 372 | else: 373 | meets_requirements = False 374 | reason = response_dict.get('reason', '') 375 | 376 | clear_json_file() 377 | print(colored(f"Final Response: {integration_agent_response}", 'cyan')) 378 | 379 | 380 | if __name__ == '__main__': 381 | 382 | # Params for Ollama 383 | # model = "llama3:instruct" 384 | # model_tool = "llama3:instruct" 385 | # model_qa = "llama3:instruct" 386 | # model_endpoint = 'http://localhost:11434/api/generate' 387 | # stop = None 388 | # server = 'ollama' 389 | 390 | # Params for RunPod 391 | # model = "mistralai/Codestral-22B-v0.1" 392 | # model_tool = "mistralai/Codestral-22B-v0.1" 393 | # model_qa = "mistralai/Codestral-22B-v0.1" 394 | # runpod_endpoint = 'https://dtalj4mdnqx7vh-8000.proxy.runpod.net/' # Add your RunPod endpoint here 395 | # completions_endpoint = 'v1/chat/completions' 396 | # model_endpoint = runpod_endpoint + completions_endpoint 397 | # stop = "" 398 | # server = 'runpod' 399 | 400 | # Params for OpenAI 401 | model = 'gpt-4o' 402 | model_tool = 'gpt-4o' 403 | model_qa = 'gpt-4o' 404 | model_endpoint = 'https://api.openai.com/v1/chat/completions' 405 | stop = None 406 | server = 'openai' 407 | 408 | agent = Agent(model=model, 409 | model_tool=model_tool, 410 | model_qa=model_qa, 411 | tool=WebSearcher, 412 | planning_agent_prompt=planning_agent_prompt, 413 | integration_agent_prompt=integration_agent_prompt, 414 | verbose=False, 415 | iterations=6, 416 | model_endpoint=model_endpoint, 417 | server=server 418 | ) 419 | agent.execute() 420 | 421 | 422 | 423 | 424 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY: "YOUR_API_KEY" 2 | SERPER_DEV_API_KEY: "YOUR_API_KEY" -------------------------------------------------------------------------------- /memory.json: -------------------------------------------------------------------------------- 1 | [] -------------------------------------------------------------------------------- /prompts.py: -------------------------------------------------------------------------------- 1 | planning_agent_prompt = """ 2 | You are an AI planning agent working with an integration agent. 3 | 4 | Your job is to come up with the searches you can use in a search engine to answer the query. 5 | 6 | You must not answer the query, only generate the questions. 7 | 8 | If there are multiple searches, highlight the single most important search. 9 | 10 | Ensure your response takes into account any feedback (if available). 11 | 12 | Here is your previous plan: `{plan}` 13 | 14 | Here is the feedback: `{feedback}` 15 | 16 | You MUST carefully consider the feedback and adjust or change your plan based on the feedback provided. 17 | 18 | For example, if the feedback is that the plan is missing a key element, you should adjust the plan to include that element. 19 | 20 | You should be aware of today's date to help you answer questions that require current information. 21 | Here is today's date and time (Timezone: UTC): `{datetime}` 22 | """ 23 | 24 | integration_agent_prompt = """ 25 | You are an AI Integration Agent working with a planning agent. 26 | 27 | Your job is to compile a response to the original query based entirely on the research provided to you. 28 | 29 | If the research is insufficient, provide explicit feedback to the planning agent to refine the plan. 30 | 31 | This feedback should include the specific information that is missing from the research. 32 | 33 | Your feedback should state which questions have already been answered by the research and which questions are still unanswered. 34 | 35 | If the research is sufficient, provide a comprehensive response to the query with citations. 36 | 37 | In your comprehensive response, you MUST do the following: 38 | 1. Only use the research provided to you to generate the response. 39 | 2. Directly provide the source of the information in the response. 40 | The research is a dictionary that provides research content alongside its source. 41 | 42 | research: `{outputs}` 43 | 44 | Here is the plan from the planning agent: `{plan}` 45 | 46 | You must fully cite the sources provided in the research. 47 | 48 | Sources from research: `{sources}` 49 | 50 | Do not use sources that have not been provided in the research. 51 | 52 | Example Response: 53 | Based on the information gathered, here is the comprehensive response to the query: 54 | 55 | The sky appears blue because of a phenomenon called Rayleigh scattering, which causes shorter wavelengths of light (blue) to scatter more than longer wavelengths (red). This scattering causes the sky to look blue most of the time . 56 | 57 | Additionally, during sunrise and sunset, the sky can appear red or orange because the light has to pass through more atmosphere, scattering the shorter blue wavelengths out of the line of sight and allowing the longer red wavelengths to dominate . 58 | 59 | Sources: 60 | : https://example.com/science/why-is-the-sky-blue 61 | : https://example.com/science/sunrise-sunset-colors 62 | 63 | There is a quality assurance process to check your response meets the requirements. 64 | 65 | Here are the results of the last quality assurance check: `{reason}` 66 | 67 | Take these into account when generating your response. 68 | 69 | Here are all your previous responses: `{previous_response}` 70 | 71 | Your previous responses may partially answer the original user query, you should consider this when generating your response. 72 | 73 | Here is today's date and time (Timezone: UTC): `{datetime}` 74 | 75 | Here's a reminder of the original user query: `{query}` 76 | """ 77 | 78 | 79 | check_response_prompt = """ 80 | Check if the response meets all of the requirements of the query based on the following: 81 | 1. The response must be relevant to the query. 82 | if the response is not relevant, return pass as 'False' and state the 'relevant' as 'Not relevant'. 83 | 2. The response must be coherent and well-structured. 84 | if the response is not coherent and well-structured, return pass as 'False' and state the 'coherent' as 'Incoherent'. 85 | 3. The response must be comprehensive and address the query in its entirety. 86 | if the response is not comprehensive and doesn't address the query in its entirety, return pass as 'False' and state the 'comprehensive' as 'Incomprehensive'. 87 | 4. The response must have Citations and links to sources. 88 | if the response does not have citations and links to sources, return pass as 'False' and state the 'citations' as 'No citations'. 89 | 5. Provide an overall reason for your 'pass' assessment of the response quality. 90 | 91 | The previous responses may partially answer the original user query. The response may likely contain some information from previous responses. 92 | You should consider this when checking the response quality. 93 | 94 | The json object should have the following format: 95 | { 96 | 'pass': 'True' or 'False' 97 | 'relevant': 'Relevant' or 'Not relevant' 98 | 'coherent': 'Coherent' or 'Incoherent' 99 | 'comprehensive': 'Comprehensive' or 'Incomprehensive' 100 | 'citations': 'Citations' or 'No citations' 101 | 'reason': 'Provide a reason for the response quality.' 102 | } 103 | """ 104 | 105 | 106 | generate_searches_prompt = """ 107 | Return a json object that gives the input to a google search engine that could be used to find an answer to the Query based on the Plan. 108 | You may be given a multiple questions to answer, but you should only generate the search engine query for the single most important question according to the Plan and query. 109 | The json object should have the following format: 110 | { 111 | 'response': 'search engine query' 112 | } 113 | """ 114 | 115 | 116 | get_search_page_prompt = """ 117 | Return a json object that gives the URL of the best website source to answer the Query, 118 | Plan and Search Results. The URL MUST be selected 119 | from the Search Results provided. 120 | YOU MUST NOT SELECT A URL FROM THE FAILED SITES! 121 | YOU MUST NOT SELECT A URL FROM THE VISITED SITES! 122 | Do not select any of these sites: 123 | The json object should have the following format: 124 | { 125 | 'response': 'Best website source URL' 126 | } 127 | """ 128 | 129 | 130 | check_response_json = { 131 | "type": "object", 132 | "properties": { 133 | "pass": { 134 | "type": "string" 135 | }, 136 | "relevant": { 137 | "type": "string" 138 | }, 139 | "coherent": { 140 | "type": "string" 141 | }, 142 | "comprehensive": { 143 | "type": "string" 144 | }, 145 | "citations": { 146 | "type": "string" 147 | }, 148 | "reason": { 149 | "type": "string" 150 | } 151 | }, 152 | "required": ["pass", "relevant", "coherent", "comprehensive", "citations", "reason"] 153 | } 154 | 155 | generate_searches_json = { 156 | "type": "object", 157 | "properties": { 158 | "response": { 159 | "type": "string" 160 | } 161 | }, 162 | "required": ["response"] 163 | } 164 | 165 | get_search_page_json = { 166 | "type": "object", 167 | "properties": { 168 | "response": { 169 | "type": "string" 170 | } 171 | }, 172 | "required": ["response"] 173 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | termcolor==2.4.0 2 | PyYAML==6.0.1 3 | requests==2.31.0 4 | beautifulsoup4==4.12.3 5 | chardet==5.2.0 6 | -------------------------------------------------------------------------------- /schema/Agent Schema.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/john-adeojo/custom_websearch_agent/0857ee53a4eb351c4595fd22bcfa25b85bce0857/schema/Agent Schema.png -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from bs4 import BeautifulSoup 3 | import json 4 | import yaml 5 | from termcolor import colored 6 | import os 7 | import chardet 8 | import string 9 | import ast 10 | from prompts import generate_searches_prompt, get_search_page_prompt, generate_searches_json, get_search_page_json 11 | 12 | 13 | def load_config(file_path): 14 | with open(file_path, 'r') as file: 15 | config = yaml.safe_load(file) 16 | for key, value in config.items(): 17 | os.environ[key] = value 18 | 19 | class WebSearcher: 20 | """ 21 | Input: 22 | Search Engine Query: The primary input to the tool is a search engine query intended for Google Search. This query is generated based on a specified plan and user query. 23 | Output: 24 | Dictionary of Website Content: The output of the tool is a dictionary where: 25 | The key is the URL of the website that is deemed most relevant based on the search results. 26 | The value is the content scraped from that website, presented as plain text. 27 | The source is useful for citation purposes in the final response to the user query. 28 | The content is used to generate a comprehensive response to the user query. 29 | """ 30 | def __init__(self, model, verbose=False, model_endpoint=None, server=None, stop=None): 31 | self.server = server 32 | self.model_endpoint = model_endpoint 33 | load_config('config.yaml') 34 | if server == 'openai': 35 | self.api_key = os.getenv("OPENAI_API_KEY") 36 | self.headers = { 37 | 'Content-Type': 'application/json', 38 | 'Authorization': f'Bearer {self.api_key}' 39 | } 40 | else: 41 | self.headers = {"Content-Type": "application/json"} 42 | self.model = model 43 | self.verbose = verbose 44 | 45 | # self.failed_sites = [] 46 | self.stop = stop 47 | 48 | def generate_searches(self, plan, query): 49 | 50 | if self.server == 'ollama': 51 | 52 | payload = { 53 | "model": self.model, 54 | "prompt": f"Query: {query}\n\nPlan: {plan}", 55 | "format": "json", 56 | "system": generate_searches_prompt, 57 | "stream": False, 58 | "temperature": 0, 59 | } 60 | 61 | if self.server == 'runpod' or self.server == 'openai': 62 | 63 | prefix = self.model.split('/')[0] 64 | exception_models = ['microsoft/Phi-3-medium-128k-instruct', 65 | 'microsoft/Phi-3-mini-128k-instruct', 66 | 'microsoft/Phi-3-medium-4k-instruct', 67 | 'microsoft/Phi-3-mini-4k-instruct', 68 | ] 69 | 70 | if prefix == 'mistralai' or self.model in exception_models: 71 | payload = { 72 | "model": self.model, 73 | "response_format": {"type": "json_object"}, 74 | "messages": [ 75 | { 76 | "role": "user", 77 | "content": f"System: {generate_searches_prompt} \n\n\ Query: {query}\n\nPlan: {plan}" 78 | } 79 | ], 80 | "temperature": 0, 81 | "stop": None, 82 | "guided_json": generate_searches_json 83 | } 84 | 85 | else: 86 | payload = { 87 | "model": self.model, 88 | "response_format": {"type": "json_object"}, 89 | "messages": [ 90 | { 91 | "role": "system", 92 | "content": generate_searches_prompt 93 | }, 94 | { 95 | "role": "user", 96 | "content": f"Query: {query}\n\nPlan: {plan}" 97 | } 98 | ], 99 | "temperature": 0, 100 | "stop": self.stop, 101 | "guided_json": generate_searches_json 102 | 103 | } 104 | 105 | if self.server == 'openai': 106 | del payload["stop"] 107 | del payload["guided_json"] 108 | 109 | try: 110 | response = requests.post(self.model_endpoint, headers=self.headers, data=json.dumps(payload)) 111 | print(f"Response_DEBUG: {response}") 112 | try: 113 | response_dict = response.json() 114 | except json.JSONDecodeError: 115 | response_dict = ast.literal_eval(response.content) 116 | 117 | if self.server == 'ollama': 118 | response_json = json.loads(response_dict['response']) 119 | print(f"Response JSON: {response_json}") 120 | search_query = response_json.get('response', '') 121 | 122 | if self.server == 'runpod' or self.server == 'openai': 123 | response_content = response_dict['choices'][0]['message']['content'] 124 | 125 | try: 126 | response_json = json.loads(response_content) 127 | except json.JSONDecodeError: 128 | response_json = ast.literal_eval(response_content) 129 | 130 | search_query = response_json.get('response', '') 131 | 132 | print(f"Search Query: {search_query}") 133 | 134 | return search_query 135 | 136 | except Exception as e: 137 | print("Error in response:", response_dict) 138 | return "Error generating search query" 139 | 140 | def get_search_page(self, plan, query, search_results, failed_sites=[], visited_sites=[]): 141 | 142 | if self.server == 'ollama': 143 | payload = { 144 | "model": self.model, 145 | "prompt": f"Query: {query}\n\nPlan: {plan} \n\nSearch Results: {search_results}\n\nFailed Sites: {failed_sites}\n\nVisited Sites: {visited_sites}", 146 | "format": "json", 147 | "system": get_search_page_prompt, 148 | "stream": False, 149 | "temperature": 0, 150 | } 151 | 152 | if self.server == 'runpod' or self.server == 'openai': 153 | 154 | prefix = self.model.split('/')[0] 155 | exception_models = ['microsoft/Phi-3-medium-128k-instruct', 156 | 'microsoft/Phi-3-mini-128k-instruct', 157 | 'microsoft/Phi-3-medium-4k-instruct', 158 | 'microsoft/Phi-3-mini-4k-instruct', 159 | ] 160 | 161 | if prefix == 'mistralai' or self.model in exception_models: 162 | payload = { 163 | "model": self.model, 164 | "response_format": {"type": "json_object"}, 165 | "messages": [ 166 | { 167 | "role": "user", 168 | "content": f"System: {get_search_page_prompt} \n\n\ Query: {query}\n\nPlan: {plan}\n\nSearch Results: {search_results} \n\nFailed Sites: {failed_sites}\n\nVisited Sites: {visited_sites}" 169 | } 170 | ], 171 | "temperature": 0, 172 | "stop": None, 173 | "guided_json": get_search_page_json 174 | } 175 | 176 | else: 177 | payload = { 178 | "model": self.model, 179 | "response_format": {"type": "json_object"}, 180 | "messages": [ 181 | { 182 | "role": "system", 183 | "content": get_search_page_prompt 184 | }, 185 | { 186 | "role": "user", 187 | "content": f"Query: {query}\n\nPlan: {plan}\n\nSearch Results: {search_results} \n\nFailed Sites: {failed_sites}\n\nVisited Sites: {visited_sites}" 188 | } 189 | ], 190 | "temperature": 0, 191 | "stop": self.stop, 192 | "guided_json": get_search_page_json 193 | } 194 | 195 | if self.server == 'openai': 196 | del payload["stop"] 197 | del payload["guided_json"] 198 | 199 | try: 200 | response = requests.post(self.model_endpoint, headers=self.headers, data=json.dumps(payload)) 201 | try: 202 | response_dict = response.json() 203 | except json.JSONDecodeError: 204 | response_dict = ast.literal_eval(response.content) 205 | 206 | if self.server == 'ollama': 207 | response_json = json.loads(response_dict['response']) 208 | search_query = response_json.get('response', '') 209 | 210 | if self.server == 'runpod' or self.server == 'openai': 211 | response_content = response_dict['choices'][0]['message']['content'] 212 | 213 | try: 214 | response_json = json.loads(response_content) 215 | except json.JSONDecodeError: 216 | response_json = ast.literal_eval(response_content) 217 | 218 | search_query = response_json.get('response', '') 219 | 220 | return search_query 221 | 222 | except Exception as e: 223 | print("Error in response:", response_dict) 224 | return "Error getting search page URL" 225 | 226 | def format_results(self, organic_results): 227 | 228 | result_strings = [] 229 | for result in organic_results: 230 | title = result.get('title', 'No Title') 231 | link = result.get('link', '#') 232 | snippet = result.get('snippet', 'No snippet available.') 233 | result_strings.append(f"Title: {title}\nLink: {link}\nSnippet: {snippet}\n---") 234 | 235 | return '\n'.join(result_strings) 236 | 237 | def fetch_search_results(self, search_queries): 238 | 239 | search_url = "https://google.serper.dev/search" 240 | headers = { 241 | 'Content-Type': 'application/json', 242 | 'X-API-KEY': os.environ['SERPER_DEV_API_KEY'] # Ensure this environment variable is set with your API key 243 | } 244 | payload = json.dumps({"q": search_queries}) 245 | 246 | # Attempt to make the HTTP POST request 247 | try: 248 | response = requests.post(search_url, headers=headers, data=payload) 249 | response.raise_for_status() # Raise an HTTPError for bad responses (4XX, 5XX) 250 | results = response.json() 251 | 252 | # Check if 'organic' results are in the response 253 | if 'organic' in results: 254 | return self.format_results(results['organic']) 255 | else: 256 | return "No organic results found." 257 | 258 | except requests.exceptions.HTTPError as http_err: 259 | return f"HTTP error occurred: {http_err}" 260 | except requests.exceptions.RequestException as req_err: 261 | return f"Request exception occurred: {req_err}" 262 | except KeyError as key_err: 263 | return f"Key error in handling response: {key_err}" 264 | 265 | def scrape_website_content(self, website_url, failed_sites=[]): 266 | headers = { 267 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36', 268 | 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9', 269 | 'Accept-Language': 'en-US,en;q=0.9', 270 | 'Referer': 'https://www.google.com/', 271 | 'Connection': 'keep-alive', 272 | 'Upgrade-Insecure-Requests': '1', 273 | 'Accept-Encoding': 'gzip, deflate, br' 274 | } 275 | 276 | def is_garbled(text): 277 | # Count non-ASCII characters 278 | non_ascii_chars = sum(1 for char in text if char not in string.printable) 279 | try: 280 | # Calculate the proportion of non-ASCII characters 281 | return non_ascii_chars / len(text) > 0.2 282 | except ZeroDivisionError: 283 | # If the text is empty, it cannot be garbled 284 | return False 285 | 286 | 287 | try: 288 | # Making a GET request to the website 289 | response = requests.get(website_url, headers=headers, timeout=15) 290 | response.raise_for_status() # This will raise an exception for HTTP errors 291 | 292 | # Detecting encoding using chardet 293 | detected_encoding = chardet.detect(response.content) 294 | response.encoding = detected_encoding['encoding'] if detected_encoding['confidence'] > 0.5 else 'utf-8' 295 | 296 | # Handling possible issues with encoding detection 297 | try: 298 | content = response.text 299 | except UnicodeDecodeError: 300 | content = response.content.decode('utf-8', errors='replace') 301 | 302 | # Parsing the page content using BeautifulSoup 303 | soup = BeautifulSoup(content, 'html.parser') 304 | text = soup.get_text(separator='\n') 305 | # Cleaning up the text: removing excess whitespace 306 | clean_text = '\n'.join([line.strip() for line in text.splitlines() if line.strip()]) 307 | split_text = clean_text.split() 308 | first_5k_words = split_text[:4000] 309 | clean_text_5k = ' '.join(first_5k_words) 310 | 311 | if is_garbled(clean_text): 312 | print(f"Failed to retrieve content from {website_url} due to garbled text.") 313 | failed = {"source": website_url, "content": "Failed to retrieve content due to garbled text"} 314 | failed_sites.append(website_url) 315 | return failed, failed_sites, False 316 | 317 | 318 | return {"source": website_url, "content": clean_text_5k}, "N/A", True 319 | 320 | except requests.exceptions.RequestException as e: 321 | print(f"Error retrieving content from {website_url}: {e}") 322 | failed = {"source": website_url, "content": f"Failed to retrieve content due to an error: {e}"} 323 | failed_sites.append(website_url) 324 | return failed, failed_sites, False 325 | 326 | def use_tool(self, plan=None, query=None, visited_sites=[], failed_sites=[]): 327 | 328 | search_queries = self.generate_searches(plan, query) 329 | search_results = self.fetch_search_results(search_queries) 330 | best_page = self.get_search_page(plan, query, search_results, visited_sites=visited_sites) 331 | results_dict, failed_sites, response = self.scrape_website_content(best_page, failed_sites=failed_sites) 332 | 333 | attempts = 0 334 | 335 | while not response and attempts < 5: 336 | print(f"Failed to retrieve content from {best_page}...Trying a different page") 337 | print(f"Failed Sites: {failed_sites}") 338 | best_page = self.get_search_page(plan, query, search_results, failed_sites=failed_sites) 339 | results_dict, failed_sites, response = self.scrape_website_content(best_page) 340 | 341 | attempts += 1 342 | 343 | 344 | if self.verbose: 345 | print(f"Search Engine Query: {search_queries}") 346 | print(colored(f"SEARCH RESULTS {search_results}", 'yellow')) 347 | print(f"BEST PAGE {best_page}") 348 | print(f"Scraping URL: {best_page}") 349 | print(colored(f"RESULTS DICT {results_dict}", 'yellow')) 350 | 351 | return results_dict 352 | 353 | 354 | if __name__ == '__main__': 355 | 356 | search = WebSearcher() 357 | search.use_tool() 358 | --------------------------------------------------------------------------------