├── .gitignore ├── public └── img.png ├── readme.md ├── requirements.txt ├── src ├── base.py ├── config.py ├── config.sample.ini ├── handlers.py ├── llm.py ├── main.py ├── moudles │ ├── __init__.py │ ├── browser.py │ └── request.py ├── sessions.py ├── tools │ ├── __init__.py │ ├── bai_ke.py │ ├── news.py │ ├── program.py │ └── search_engine.py └── utils │ ├── enums.py │ ├── string_process.py │ └── utils.py └── tests ├── test_moudle.py ├── test_tool_baidu.py ├── test_tool_news.py ├── test_tool_program.py └── test_tool_search.py /.gitignore: -------------------------------------------------------------------------------- 1 | git commit -m "first commit"# Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 4 | # IDEs and editors 5 | .idea/ 6 | *.swp 7 | *~.nfs* 8 | *.bak 9 | *.cache 10 | *.dat 11 | *.db 12 | *.log 13 | *.patch 14 | *.orig.* 15 | *.rej.* 16 | *.tmp.* 17 | 18 | *.mp3 19 | *.wav 20 | 21 | env/ 22 | src/config.ini -------------------------------------------------------------------------------- /public/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/gpt-func-calling/403b3a264b7e302478cc43b9c97b1ff83c5ff0a8/public/img.png -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # gpt-func-calling 2 | 3 | ![](https://img.shields.io/badge/license-GPL-blue) 4 | 5 | ## 简介 6 | 利用ChatGPT最新的function-calling,实现类似LangChain Agent代理功能,通过tool补充上下文 7 | 8 | ## 安装 9 | ### 环境 10 | - win 10 11 | - python 3.8 12 | - vpn全局代理 13 | ### pip安装依赖 14 | ```shell 15 | git clone https://github.com/jiran214/gpt-func-calling 16 | cd src 17 | # 建议使用命令行或者pycharm创建虚拟环境,参考链接 https://blog.csdn.net/xp178171640/article/details/115950985 18 | python -m pip install --upgrade pip pip 19 | pip install -r .\requirements.txt 20 | ``` 21 | ### 新建config.ini 22 | - 
src目录下重命名config.sample.ini为config.ini 23 | - 更改api_key和proxy 24 | ## 快速开始 25 | - 运行 >> `cd src` 26 | - 修改 main.py 内容 27 | ```python 28 | 29 | ... 30 | # 新建会话窗口 31 | session = InteractiveSession() 32 | # 添加输出终端 33 | session.add_handler(Shell()) 34 | # 初始化工具 35 | tools = [ 36 | WangYiNews, # 网易新闻 37 | BaiduBaike, # 百度百科 38 | CSDN, # csdn 39 | JueJin, # 掘金 40 | GoogleSearch, # 谷歌 需要在config.ini -> google 配置api,详见 https://zhuanlan.zhihu.com/p/174666017 41 | BingSearch, # Bing 42 | # 持续开发中... 43 | ] 44 | # 创建代理 45 | agent = GPTAgent.from_tools( 46 | tools=tools, 47 | session=session 48 | ) 49 | # 添加到会话窗口 50 | session.get_input() 51 | # 启动代理 52 | agent.forever_run() 53 | ``` 54 | - 运行 >> `python main.py` 55 | - 在终端输入第一个问题 56 | 57 | ### 工具列表 58 | - 百度百科 59 |
60 |
61 |
62 | - 网易新闻 63 | - 百度百科 64 | - csdn 65 | - 掘金 66 | - 谷歌 67 | - Bing 68 | - 正在开发中... 69 | 70 | ## 更新日志 71 | ## to do list 72 | - [x] playwright 引入,简化爬虫 73 | - [ ] 工具输出的内容长度自由选择 74 | - [ ] 对工具输出做summary 75 | - [ ] 交互式对话,保留一定窗口大小 76 | - [ ] 前端界面 77 | ## Contact Me 78 | - 请先star本项目~~ 79 | - **如果你遇到各种问题,请提issues,一般的问题不要加我,感谢理解!** 80 | - 如果你有好的建议和想法,欢迎加我WX:yuchen59384 交流! 81 | 82 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiran214/gpt-func-calling/403b3a264b7e302478cc43b9c97b1ff83c5ff0a8/requirements.txt -------------------------------------------------------------------------------- /src/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from enum import Enum 3 | 4 | from pydantic import BaseModel 5 | 6 | python_type_2_json = { 7 | "str": "string", 8 | "float": "number", 9 | "int": "integer", 10 | "bool": "boolean", 11 | "list": "array", 12 | "dict": "object", 13 | } 14 | 15 | 16 | class ToolModel(BaseModel, ABC): 17 | 18 | def use(self, *args, **kwargs): 19 | raise NotImplementedError 20 | 21 | @classmethod 22 | def gpt_schema(cls, *args, **kwargs): 23 | cached = cls.__schema_cache__.get('function_calling_schema') 24 | if cached is not None: 25 | return cached 26 | model_class = cls 27 | properties = {} 28 | required = [] 29 | for field_name, field in model_class.__fields__.items(): 30 | properties[field_name] = { 31 | 'type': field.type_.__name__, 32 | 'description': field.field_info.description 33 | } 34 | if field.required: 35 | required.append(field_name) 36 | if issubclass(field.type_, Enum): 37 | for type_ in field.type_.mro(): 38 | if type_.__name__ not in {'Enum', 'object'}: 39 | properties[field_name]['type'] = python_type_2_json.get(type_.__name__, type_.__name__) 40 | properties[field_name]['enum'] = 
list(field.type_.__members__) 41 | else: 42 | properties[field_name]['type'] = python_type_2_json.get(field.type_.__name__, field.type_.__name__) 43 | 44 | schema = { 45 | 'name': model_class.Meta.name, 46 | 'description': model_class.Meta.description, 47 | 'parameters': { 48 | 'type': 'object', 49 | 'properties': properties, 50 | 'required': required, 51 | } 52 | } 53 | cls.__schema_cache__['function_calling_schema'] = schema 54 | return schema 55 | 56 | class Meta: 57 | name = "" # 工具名称 58 | description = "" # 工具描述 59 | 60 | 61 | class Observable(ABC): 62 | 63 | def __init__(self): 64 | self.observer_list = [] 65 | 66 | def add_handler(self, observer): 67 | self.observer_list.append(observer) 68 | 69 | def notify(self, *args, **kwargs): 70 | for observer in self.observer_list: 71 | observer.handle(*args, **kwargs) 72 | 73 | 74 | class Observer(ABC): 75 | def handle(self, *args, **kwargs): 76 | raise NotImplementedError 77 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os.path 3 | 4 | root_path = os.path.abspath(os.path.dirname(__file__)) 5 | _file_path = os.path.join(root_path, 'config.ini') 6 | 7 | _config = configparser.RawConfigParser() 8 | _config.read(_file_path) 9 | 10 | api_key = _config.get('openai', 'api_key') 11 | proxy = _config.get('openai', 'proxy') 12 | proxies = { 13 | 'http': f'http://{proxy}/', 14 | 'https': f'http://{proxy}/' 15 | } 16 | google_settings = dict(_config.items('google')) 17 | window_size = _config.getint('session', 'window_size') 18 | tool_output_limit = _config.getint('tool', 'output_limit') -------------------------------------------------------------------------------- /src/config.sample.ini: -------------------------------------------------------------------------------- 1 | [openai] 2 | # https://platform.openai.com/account/api-keys 3 | api_key = 
xxxxxxxxxxxxxxxxxxxxxxxxxx 4 | proxy = 127.0.0.1 5 | 6 | [google] 7 | # 参考https://zhuanlan.zhihu.com/p/174666017 8 | key = xxxxxxxxxxxxxxxxxxxxx 9 | cx = 11111111111111a -------------------------------------------------------------------------------- /src/handlers.py: -------------------------------------------------------------------------------- 1 | import config 2 | from utils.enums import Role 3 | from base import Observer 4 | from termcolor import colored 5 | 6 | from utils.utils import num_tokens_from_string, num_tokens_from_messages 7 | 8 | 9 | class Shell(Observer): 10 | 11 | role_to_color = { 12 | "system": "red", 13 | "user": "green", 14 | "assistant": "blue", 15 | "function": "magenta", 16 | } 17 | 18 | def handle(self, session, message: dict, role: Role): 19 | if role is Role.SYSTEM: 20 | formatted_message = f"[system]: {message['content']}" 21 | elif role is Role.USER: 22 | formatted_message = f"[user]: {message['content']}" 23 | elif role is Role.ASSISTANT and message.get("function_call"): 24 | # "判断用工具可用 25 | func_call = message['function_call'] 26 | formatted_message = f"[assistant] ({func_call['name']}): {str(func_call['arguments'])}" 27 | elif role is Role.ASSISTANT and not message.get("function_call"): 28 | formatted_message = f"[assistant]: {message['content']}" 29 | elif role is Role.FUNCTION: 30 | content = message['content'][:200].replace(' ', '') + '...' 
31 | formatted_message = f"[function] ({message['name']}): {content}" 32 | elif role is None: 33 | formatted_message = f'异常message: {message}' 34 | else: 35 | formatted_message = f'异常message: {message}' 36 | print( 37 | colored( 38 | formatted_message, 39 | self.role_to_color[role], 40 | force_color=True 41 | ) 42 | ) 43 | 44 | 45 | class SlidingWindowHandler: 46 | """滑动窗口,控制tk数量""" 47 | max_window_size = config.window_size 48 | 49 | def __init__(self): 50 | self.num_tk_list = [] 51 | self.current_window_size = 0 52 | 53 | def handle(self, session, message: dict, role: Role): 54 | num_tokens = 0 55 | for key, value in message.items(): 56 | if value: 57 | num_tokens += num_tokens_from_string(str(value)) 58 | self.num_tk_list.append(num_tokens) 59 | self.current_window_size += num_tokens 60 | while 1: 61 | if self.current_window_size > self.max_window_size: 62 | # 从左开始第一个非system角色的message删掉,直到低于max_window_size 63 | for index, message in enumerate(session.message_list): 64 | role = Role(message['role']) 65 | if role is not Role.SYSTEM: 66 | num_token = self.num_tk_list.pop(index) 67 | session.message_list.pop(index) 68 | self.current_window_size -= num_token 69 | else: 70 | break 71 | -------------------------------------------------------------------------------- /src/llm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import signal 4 | import sys 5 | 6 | from typing import Iterable, Dict, List, Type 7 | 8 | import openai 9 | from termcolor import colored 10 | 11 | import config 12 | from tenacity import retry, wait_random_exponential, stop_after_attempt 13 | 14 | from utils.enums import Role 15 | from sessions import Session, InteractiveMixin 16 | from base import ToolModel 17 | 18 | openai.api_key = config.api_key 19 | openai.proxy = config.proxies 20 | 21 | # os.environ["http_proxy"] = f'http://{config.proxy}/' 22 | # os.environ["https_proxy"] = f'http://{config.proxy}/' 23 | 24 | 25 | class 
GPTAgent: 26 | 27 | def __init__(self, tool_map: Dict[str, Type[ToolModel]], functions: List[dict], session: Session): 28 | self.tool_map = tool_map 29 | self.functions = functions 30 | self.function_call = "auto" 31 | self.model = "gpt-3.5-turbo-0613" 32 | self.session = session 33 | 34 | @classmethod 35 | def from_tools(cls, tools: Iterable[Type[ToolModel]], session: Session): 36 | tool_map = {} 37 | gpt_schema_list = [] 38 | for tool in tools: 39 | tool_map[tool.Meta.name] = tool 40 | gpt_schema_list.append(tool.gpt_schema()) 41 | return cls( 42 | tool_map=tool_map, 43 | functions=gpt_schema_list, 44 | session=session 45 | ) 46 | 47 | @retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3)) 48 | def _generate(self): 49 | try: 50 | response = openai.ChatCompletion.create( 51 | model=self.model, 52 | messages=self.session.message_list, 53 | functions=self.functions, 54 | function_call=self.function_call, 55 | ) 56 | return response 57 | except Exception as e: 58 | raise e 59 | 60 | def run(self): 61 | # Step 1, send model the user query and what functions it has access to 62 | response = self._generate() 63 | message = response["choices"][0]["message"] 64 | self.session.add_message(message) 65 | 66 | # Step 2, check if the model wants to call a function 67 | if message.get("function_call"): 68 | function_name = message["function_call"]["name"] 69 | function_args = json.loads(message["function_call"]["arguments"]) 70 | 71 | if function_name not in self.tool_map: 72 | raise 'tool_map key 不等于 Meta name' 73 | 74 | # Step 3, call the function 75 | tool = self.tool_map[function_name](**function_args) 76 | function_response = tool.use() 77 | # Step 4, send model the info on the function call and function response 78 | message = { 79 | "role": Role.FUNCTION.value, 80 | "name": function_name, 81 | "content": function_response, 82 | } 83 | self.session.add_message(message) 84 | second_response = self._generate() 85 | 
self.session.add_message(second_response["choices"][0]["message"]) 86 | 87 | self.session.clear() 88 | 89 | def forever_run(self): 90 | def sig_handler(signum, frame): 91 | is_exit = True 92 | sys.exit(0) 93 | 94 | signal.signal(signal.SIGINT, sig_handler) 95 | signal.signal(signal.SIGTERM, sig_handler) 96 | try: 97 | while 1: 98 | self.run() 99 | if isinstance(self.session, InteractiveMixin): 100 | self.session.get_input() 101 | except KeyboardInterrupt as e: 102 | print(colored('goodbye', 'yellow', force_color=True)) 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import config 2 | from utils.enums import Role 3 | from handlers import Shell, SlidingWindowHandler 4 | from llm import GPTAgent 5 | from sessions import Session, InteractiveSession 6 | from tools import ( 7 | WangYiNews, 8 | BaiduBaike, 9 | CSDN, 10 | JueJin, 11 | GoogleSearch, 12 | BingSearch, 13 | ) 14 | 15 | # 新建会话窗口 16 | session = InteractiveSession() 17 | # 添加输出终端 18 | session.add_handler(Shell()) 19 | if config.window_size: 20 | session.add_handler(SlidingWindowHandler()) 21 | # 初始化工具 22 | tools = [ 23 | WangYiNews, # 网易新闻 24 | BaiduBaike, # 百度百科 25 | CSDN, # csdn 26 | JueJin, # 掘金 27 | GoogleSearch, # 谷歌 需要在config.ini -> google 配置api,详见 https://zhuanlan.zhihu.com/p/174666017 28 | BingSearch, # Bing 29 | # 持续开发中... 
30 | ] 31 | # 创建代理 32 | agent = GPTAgent.from_tools( 33 | tools=tools, 34 | session=session 35 | ) 36 | # 添加到会话窗口 37 | session.get_input() 38 | # 启动代理 39 | agent.forever_run() 40 | -------------------------------------------------------------------------------- /src/moudles/__init__.py: -------------------------------------------------------------------------------- 1 | from moudles.request import session 2 | from moudles.browser import PlaywrightOperate 3 | 4 | __all__ = [ 5 | 'session', 6 | 'PlaywrightOperate' 7 | ] -------------------------------------------------------------------------------- /src/moudles/browser.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import queue 3 | 4 | import threading 5 | import time 6 | from contextlib import asynccontextmanager 7 | from typing import Callable 8 | from urllib.parse import urlparse 9 | from playwright.async_api import async_playwright 10 | 11 | 12 | class PlaywrightBrowser: 13 | _browser = None 14 | _context = None 15 | 16 | @asynccontextmanager 17 | async def init_browser_context(self): 18 | """ 19 | 初始化全局浏览器上下文对象 20 | """ 21 | if self._context: 22 | yield self._context 23 | async with async_playwright() as playwright: 24 | _browser = await playwright.chromium.launch(headless=False, slow_mo=50) 25 | self._context = await _browser.new_context() 26 | yield self._context 27 | 28 | async def goto(self, url: str) -> str: 29 | 30 | ctx = self.init_browser_context() 31 | context = await ctx.__aenter__() 32 | 33 | page = await context.new_page() 34 | await page.goto(url) 35 | source = await page.content() 36 | # await page.close() 37 | 38 | return source 39 | 40 | async def close(self): 41 | await self._context.close() 42 | 43 | 44 | _browser_queue = queue.Queue() 45 | 46 | 47 | async def browser_coroutine(): 48 | browser = PlaywrightBrowser() 49 | while 1: 50 | if _browser_queue.empty(): 51 | time.sleep(2) 52 | # _browser_queue.task_done() 53 | continue 54 | 
else: 55 | item = _browser_queue.get() 56 | 57 | if item == 'over': 58 | await browser.close() 59 | else: 60 | url, store = item 61 | content = await browser.goto(url) 62 | setattr(store, str(hash(url)), content) 63 | _browser_queue.task_done() 64 | 65 | 66 | def run_coroutine_in_thread(): 67 | asyncio.set_event_loop(asyncio.new_event_loop()) 68 | loop = asyncio.get_event_loop() 69 | loop.run_until_complete(browser_coroutine()) 70 | 71 | 72 | # todo 暂时不启用,等tool manager完成,管理不同工具的依赖 73 | if False: 74 | _browser_thread = threading.Thread(target=run_coroutine_in_thread) 75 | _browser_thread.start() 76 | 77 | 78 | class Store: 79 | ... 80 | 81 | 82 | class PlaywrightOperate: 83 | 84 | @staticmethod 85 | def put_task(url: str): 86 | _browser_queue.put((url, Store)) 87 | 88 | @staticmethod 89 | def end(): 90 | _browser_queue.put('over') 91 | 92 | @staticmethod 93 | def wait(url): 94 | store_key = str(hash(url)) 95 | _browser_queue.join() 96 | content = getattr(Store, store_key) 97 | delattr(Store, store_key) 98 | return content 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /src/moudles/request.py: -------------------------------------------------------------------------------- 1 | from requests import session 2 | 3 | session = session() 4 | session.trust_env = False 5 | 6 | 7 | -------------------------------------------------------------------------------- /src/sessions.py: -------------------------------------------------------------------------------- 1 | 2 | from termcolor import colored 3 | 4 | from base import Observable 5 | from utils.enums import Role 6 | 7 | 8 | class Session(Observable): 9 | 10 | def __init__(self): 11 | super().__init__() 12 | self.message_list = [] 13 | 14 | def clear(self): 15 | self.message_list = [] 16 | 17 | def add_message(self, message: dict, *arg, **kwargs): 18 | self.message_list.append(message) 19 | role = Role(message["role"]) 20 | self.notify(self, message, role, *arg, 
**kwargs) 21 | 22 | def add_extra_message(self, message: dict, *arg, **kwargs): 23 | self.notify(self, message, None, *arg, **kwargs) 24 | 25 | 26 | class InteractiveMixin: 27 | """交互式""" 28 | 29 | def get_input(self): 30 | message = { 31 | 'role': Role.USER.value, 32 | 'content': input( 33 | colored('\n请输入prompt: ', 'green', force_color=True) 34 | ) # 在这输入第一个问题 35 | } 36 | getattr(self, 'add_message')(message) 37 | 38 | 39 | class InteractiveSession(Session, InteractiveMixin): 40 | """交互式Session""" 41 | 42 | pass 43 | 44 | 45 | class MemoryMixin: 46 | """记忆混入""" 47 | 48 | @classmethod 49 | def from_memory(cls): 50 | ... 51 | 52 | def save_memory(self): 53 | ... 54 | 55 | -------------------------------------------------------------------------------- /src/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from tools.bai_ke import BaiduBaike 2 | from tools.news import WangYiNews 3 | from tools.program import JueJin, CSDN 4 | from tools.search_engine import GoogleSearch, BingSearch 5 | 6 | __all__ = [ 7 | 'WangYiNews', 8 | 'BaiduBaike', 9 | 'CSDN', 10 | 'JueJin', 11 | 'GoogleSearch', 12 | 'BingSearch', 13 | ] -------------------------------------------------------------------------------- /src/tools/bai_ke.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import requests 3 | from parsel import Selector 4 | from pydantic import Field 5 | 6 | from moudles import session 7 | from utils.string_process import filter_html 8 | from base import ToolModel 9 | 10 | 11 | class TemperatureUnit(str, Enum): 12 | celsius = 'celsius' 13 | fahrenheit = 'fahrenheit' 14 | 15 | 16 | headers = { 17 | 'User-Agent': """Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36 Edg/114.0.1823.43""", 18 | 'cookie': """zhishiTopicRequestTime=1686930411657; 
BAIKE_SHITONG=%7B%22data%22%3A%220f0cbad1e8052f10878993106a4bd1cc2b15888dc99fba71b59334d6494023de22a35fc6aa9a2804e9d63dec196399f86245fb7018b1614827aac90d56d32a5eacab26d5043e8f718a3be37e087b7d5d307ead060ab590ab2847e9615f7211d8%22%2C%22key_id%22%3A%2210%22%2C%22sign%22%3A%2260fcc606%22%7D; BAIDUID=713CB59BDB474EE9AFCC4E0CCEF4EFDF:FG=1; BIDUPSID=713CB59BDB474EE9AFCC4E0CCEF4EFDF; PSTM=1684554678; newlogin=1; BAIDUID_BFESS=713CB59BDB474EE9AFCC4E0CCEF4EFDF:FG=1; ZFY=fNTO5QaO8MX8m1r49ef3AzbGZd4uFmZHDMCXcxBHMbE:C; __bid_n=18837241abd266edd64207; BAIDU_WISE_UID=wapp_1684662096824_212; BDUSS=gtS2p-d2hlNm1vbVFSfm40MHVlanhmUDRKbVdrRWo2SkZHMXJRRi10ZUZ4cDVrRVFBQUFBJCQAAAAAAAAAAAEAAADKfHhasKGwobbuztLIpcilAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIU5d2SFOXdkST; BDUSS_BFESS=gtS2p-d2hlNm1vbVFSfm40MHVlanhmUDRKbVdrRWo2SkZHMXJRRi10ZUZ4cDVrRVFBQUFBJCQAAAAAAAAAAAEAAADKfHhasKGwobbuztLIpcilAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAIU5d2SFOXdkST; FPTOKEN=H/bnHT3J9MbygOPx3ofikribhOYmoHVzKqCzq/IGsybItDzqzSjWhez2lu7sHXD27iCsK06eQRxxZv9a/PQ5in7ln6QQpHho7Su8MQcA+lbRQuWaNCvkSIyQhEhmLidlBgMrXcNduftoBvAn09+H2Wp/ktl33j5CUo89C1DCA/dqD6ekrVEAnyBjxTR7n9WKC5aujx9ikgYgwB9z2Byjg3FzSHpTAKb/MPBuZaR3/0Igu9i3KPXA1lywFWKNKohkhhbYoB4IphZluBa85n/pK972HLyhzV3za7XnptqfR1tuW2J65gcYsFtJmsRVH2f1VsD1pibb/LCzKjM99ZzMruRYYlbLVPLyd/RHa4rakixaY9q7gwSPPzEUqeL0jAVAYrF90BJ6Myvg8F15ShzAJg==|9zXV1+87HTmz/CABUmTR38iSsf/eXf1+QqrNMbhB7dQ=|10|cb4a17febd4adba622cbbfd12afe1553; Hm_lvt_55b574651fcae74b0a9f1cf9c8d7c93a=1686845869,1686927810; zhishiTopicRequestTime=1686927813308; BCLID=12201278950458286643; BDSFRCVID=XL-OJeC626QNXsrfZhMnhwS0QqmGvpRTH6_vGfi489_Mbk2FNcQBEG0PDU8g0KA-8pxBogKKKgOTHICF_2uxOjjg8UtVJeC6EG0Ptf8g0M5; H_BDCLCKID_SF=fRAfoC-Mf-JEJb51q4o-bJD8KpOJK4J3HDo-LIv9BT6cOR5Jj6K-0fCRKp5hXtvuaDbfbKJl-R6nf4J-3MA-BnK1bxuJqTcdBCrWoqQkKMjIsq0x0MOle-bQyPLLqnOO0DOMahvc5h7xOhTJQlPK5JkgMx6MqpQJQeQ-5KQN3KJmfbL9bT3YjjISKx-_J6kJfRRP; BCLID_BFESS=12201278950458286643; 
BDSFRCVID_BFESS=XL-OJeC626QNXsrfZhMnhwS0QqmGvpRTH6_vGfi489_Mbk2FNcQBEG0PDU8g0KA-8pxBogKKKgOTHICF_2uxOjjg8UtVJeC6EG0Ptf8g0M5; H_BDCLCKID_SF_BFESS=fRAfoC-Mf-JEJb51q4o-bJD8KpOJK4J3HDo-LIv9BT6cOR5Jj6K-0fCRKp5hXtvuaDbfbKJl-R6nf4J-3MA-BnK1bxuJqTcdBCrWoqQkKMjIsq0x0MOle-bQyPLLqnOO0DOMahvc5h7xOhTJQlPK5JkgMx6MqpQJQeQ-5KQN3KJmfbL9bT3YjjISKx-_J6kJfRRP; BK_SEARCHLOG=%7B%22key%22%3A%5B%22win10%20%E5%85%B3%E9%97%AD%E8%80%81%E6%9D%BF%E9%94%AE%22%2C%22%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%22%2C%22%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95sww%22%2C%22%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95%22%2C%22%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E7%AE%97%E6%B3%95a%22%2C%22%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E9%98%BF%E8%90%A8%22%5D%7D; X_ST_FLOW=0; baikeVisitId=ecd738fe-c4b5-41b8-9959-088057c1ead6; Hm_lpvt_55b574651fcae74b0a9f1cf9c8d7c93a=1686930698; ab_sr=1.0.1_OWNlNDVhZTIxY2IzMzgyMjExNzg5MjYwOWZlYzFjZWU5NDA0MGI1ODE0MWFlMDNjMjhmZTQxZWUwYTNjMWUzMDAzN2FjODZhNGRiYTA3YjdlOTIwODkyM2Y5OGIzYzhhMDVmNDA5OGI0MGE3OTEwOTdjMDc3ZmIzN2RhMThiOWNjNDZiYzM3Yjk5YTIwZTA1NTg0OTkzOTIxMzY2ZWRjN2RlNzgzZWJkYjVkZDE0OWNmODJkYzczMjM5OGM0MGE5; RT="z=1&dm=baidu.com&si=125fd020-cb08-46c9-ad46-dae9eab25096&ss=liyp9d4b&sl=11&tt=1611&bcn=https%3A%2F%2Ffclog.baidu.com%2Flog%2Fweirwood%3Ftype%3Dperf&ld=1q1v4&ul=1q7o4\"""" 19 | } 20 | 21 | 22 | class BaiduBaike(ToolModel): 23 | search_input: str = Field(description='百度百科搜索输入') 24 | 25 | def use(self): 26 | url = f"https://baike.baidu.com/api/searchui/suggest?wd={self.search_input}&enc=utf8" 27 | r = session.get(url) 28 | data_list = r.json()['list'] 29 | if not data_list: 30 | return '未找到相关结果' 31 | lemma_title = data_list[0]['lemmaTitle'] 32 | lemma_id = data_list[0]['lemmaId'] 33 | new_url = f"https://baike.baidu.com/item/{lemma_title}/{lemma_id}" 34 | r = session.get(new_url, headers=headers) 35 | sl = Selector(text=r.text) 36 | text_list = sl.xpath("""//div[@class='main-content J-content']//text()""").getall() 37 | if not data_list: 38 | return '遇到反爬' 
39 | return filter_html(text_list[59:]) 40 | 41 | class Meta: 42 | name = "get_baidu_baike_info" 43 | description = "百度百科是一部内容开放、自由的网络百科全书,旨在创造一个涵盖所有领域知识,服务所有互联网用户的中文知识性百科全书。" 44 | 45 | 46 | -------------------------------------------------------------------------------- /src/tools/news.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from parsel import Selector 3 | from pydantic import Field 4 | 5 | from base import ToolModel 6 | from moudles import session 7 | from utils.string_process import filter_html 8 | 9 | 10 | headers = { 11 | 'User-Agent': """Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36 Edg/114.0.1823.43""", 12 | } 13 | 14 | 15 | class WangYiNews(ToolModel): 16 | search_input: str = Field(description='网易新闻内容搜索输入') 17 | 18 | def use(self): 19 | url = f"""https://www.163.com/search?keyword={self.search_input}""" 20 | r = session.get(url, headers=headers) 21 | sl = Selector(text=r.text) 22 | next_url = sl.xpath("""//div[@class="keyword_list "]/div[1]//div[@class="keyword_img"]/a/@href""").get() 23 | if not next_url: 24 | return '未找到相关结果' 25 | r = session.get(next_url, headers=headers) 26 | sl = Selector(text=r.text) 27 | test_list = sl.xpath("""//div[@class="post_body"]//text()""").getall() 28 | return filter_html(test_list) 29 | 30 | class Meta: 31 | name = "get_163_news_results" 32 | description = ( 33 | "网易是中国领先的互联网技术公司,为用户提供免费邮箱、游戏、搜索引擎服务,开设新闻、娱乐、体育等30多个内容频道,及博客、视频、论坛等互动交流,网聚人的力量。" 34 | ) 35 | 36 | -------------------------------------------------------------------------------- /src/tools/program.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from pprint import pprint 3 | 4 | import requests 5 | from parsel import Selector 6 | from pydantic import Field 7 | 8 | from utils.string_process import filter_html 9 | from base import ToolModel 10 | 11 | headers = { 12 | 
'User-Agent': """Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36 Edg/114.0.1823.43""", 13 | # 'referer': """https://blog.csdn.net/""", 14 | 'cookie': """Hm_lvt_e5ef47b9f471504959267fd614d579cd=1685848574; https_waf_cookie=af5fc8c6-fc81-4038bab001f5bb3900573931669bb9158ef2; uuid_tt_dd=10_19421407380-1686948782541-741135; dc_session_id=10_1686948782541.324157; dc_sid=b2a67b61c1e3cd4d4c7edd1f0de87920; https_ydclearance=fadabedfa3cac060ee94487d-ba96-4aa5-8192-3d71b72e5bfc-1686956026""" 15 | } 16 | 17 | 18 | class CSDN(ToolModel): 19 | search_input: str = Field(description='CSDN论坛搜索输入') 20 | 21 | def use(self): 22 | url = f"https://so.csdn.net/api/v3/search?q={self.search_input}&t=blog&p=1&s=0&tm=0&lv=-1&ft=0&l=&u=&ct=-1" \ 23 | f"&pnt=-1&ry=-1&ss=-1&dct=-1&vco=-1&cc=-1&sc=-1&akt=-1&art=-1&ca=-1&prs=&pre=&ecc=-1&ebc=-1&ia=1&dId" \ 24 | f"=&cl=-1&scl=-1&tcl=-1&platform=pc&ab_test_code_overlap=&ab_test_random_code=" 25 | r = requests.get(url, headers=headers) 26 | next_url = r.json()['result_vos'][0]['url'] 27 | if not next_url: 28 | return '未找到相关结果' 29 | r = requests.get(next_url, headers=headers) 30 | sl = Selector(text=r.text) 31 | text_list = sl.xpath("""//div[@id='article_content']//text()""").getall() 32 | if not text_list: 33 | raise '遇到反爬' 34 | return filter_html(text_list) 35 | 36 | class Meta: 37 | name = "get_csdn_blog_info" 38 | description = "CSDN是全球知名中文IT技术交流平台,包含原创博客、精品问答、技术论坛等产品服务,提供原创、优质、完整内容的专业IT技术开发社区." 
39 | 40 | 41 | class JueJin(ToolModel): 42 | search_input: str = Field(description='掘金论坛搜索输入') 43 | 44 | def use(self): 45 | url = f"""https://api.juejin.cn/search_api/v1/search?spider=0&query={self.search_input}&id_type=0&cursor=0&limit=20&search_type=0&sort_type=0&version=1""" 46 | r = requests.get(url) 47 | article_id = r.json()['data'][0]['result_model']['article_id'] 48 | next_url = f'https://juejin.cn/post/{article_id}' 49 | if not next_url: 50 | return '未找到相关结果' 51 | r = requests.get(next_url) 52 | sl = Selector(text=r.text) 53 | text_list = sl.xpath("""//div[@itemprop="articleBody"]/div/*[position()>2]//text()""").getall() 54 | if not text_list: 55 | raise '遇到反爬' 56 | return filter_html(text_list) 57 | 58 | class Meta: 59 | name = "get_juejin_blog_info" 60 | description = "掘金是面向全球中文开发者的技术内容分享与交流平台。我们通过技术文章、沸点、课程、直播等产品和服务,打造综合类技术社区。" 61 | -------------------------------------------------------------------------------- /src/tools/search_engine.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from parsel import Selector 3 | from pydantic import Field 4 | 5 | import config 6 | from base import ToolModel 7 | from config import google_settings 8 | from moudles import session 9 | 10 | 11 | class GoogleSearch(ToolModel): 12 | search_input: str = Field(description='谷歌搜索输入') 13 | 14 | def use(self): 15 | url = f"""https://www.googleapis.com/customsearch/v1?key={google_settings['key']}&cx={google_settings['cx']}&q={self.search_input}""" 16 | r = session.get(url, proxies=config.proxies) 17 | if not r: 18 | return '未找到相关结果' 19 | json_list = r.json()['items'][:4] 20 | results = [] 21 | for json_data in json_list: 22 | data_dict = { 23 | 'title': json_data['title'], 24 | 'link': json_data['link'], 25 | 'snippet': json_data['snippet'], 26 | 'html_snippet': json_data['htmlSnippet'] 27 | } 28 | results.append(data_dict) 29 | return str(results) 30 | 31 | class Meta: 32 | name = "get_google_search_results" 33 | 
description = ( 34 | "A wrapper around Google Search. " 35 | "Useful for when you need to answer questions about current events. " 36 | ) 37 | 38 | 39 | class BingSearch(ToolModel): 40 | 41 | search_input: str = Field(description='bing搜索输入') 42 | 43 | def use(self): 44 | url = f"""https://cn.bing.com/search?q={self.search_input}&aqs=edge.2.69i64i450l8.175106209j0j1&FORM=ANAB01&PC=HCTS""" 45 | r = session.get(url) 46 | if not r: 47 | return '未找到相关结果' 48 | html = r.text 49 | sl = Selector(text=html) 50 | items = sl.xpath("""//li[@class='b_algo']""") 51 | results = [] 52 | for item in items: 53 | data_dict = { 54 | 'title': ' '.join(item.xpath("""./div[1]//h2//text()""").getall()), 55 | 'link': item.xpath("""./div[1]//a/@href""").get(), 56 | 'snippet': ' '.join(item.xpath("""./div[2]//p//text()""").getall()), 57 | # 'html_snippet': json_data['htmlSnippet'] 58 | } 59 | results.append(data_dict) 60 | 61 | if len(results) == 7: 62 | break 63 | return str(results) 64 | 65 | class Meta: 66 | name = "get_bing_search_results" 67 | description = "通过必应的智能搜索,可以更轻松地快速查找所需内容并获得奖励。" 68 | -------------------------------------------------------------------------------- /src/utils/enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Role(str, Enum): 5 | SYSTEM = 'system' 6 | USER = 'user' 7 | ASSISTANT = 'assistant' 8 | FUNCTION = 'function' 9 | 10 | -------------------------------------------------------------------------------- /src/utils/string_process.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import config 4 | from utils.utils import num_tokens_from_string 5 | 6 | Limit = config.tool_output_limit 7 | 8 | 9 | def filter_html(text_list: List[str], length_func=num_tokens_from_string): 10 | """去除多余换行,限制总长度在""" 11 | new_text_list = [] 12 | current_len = 0 13 | is_newline_character_last = False 14 | for text in text_list: 15 | 
def num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613") -> int:
    """Return an approximate token count for a list of chat messages.

    Args:
        messages: iterable of dicts whose values are all strings
            (e.g. ``{'role': ..., 'content': ...}``).
        model: model name used to select the tokenizer; unknown names
            fall back to ``cl100k_base`` (the encoding shared by the
            gpt-4 / gpt-3.5-turbo / text-embedding-ada-002 family).

    Returns:
        Estimated number of tokens, including the 3-token priming of
        the assistant reply.
    """
    # Local import so the module can be loaded without tiktoken installed.
    import tiktoken
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Unknown model: degrade gracefully to the family-wide encoding
        # instead of raising.
        print("Warning: model not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
    num_tokens = 0
    for message in messages:
        # Only the values are tokenized; per-message/per-name overhead
        # from the original OpenAI cookbook sample was intentionally
        # dropped here, so this is an approximation.
        for value in message.values():
            num_tokens += len(encoding.encode(value))
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
import os

import requests
from requests import session

import config
from tools.search_engine import GoogleSearch, BaiduSearch, BingSearch

# Route all HTTP(S) traffic through the proxy configured in config.ini.
_proxy_url = f'http://{config.proxy}/'
os.environ["http_proxy"] = _proxy_url
os.environ["https_proxy"] = _proxy_url


def test_google():
    """GoogleSearch should yield a reasonably long result string."""
    result = GoogleSearch(search_input='蔡徐坤').use()
    print(result)
    assert len(result) > 200


def test_baidu():
    """BaiduSearch should yield a reasonably long result string."""
    result = BaiduSearch(search_input='蔡徐坤').use()
    print(result)
    assert len(result) > 200


def test_bing():
    """BingSearch should yield a reasonably long result string."""
    result = BingSearch(search_input='蔡徐坤').use()
    print(result)
    assert len(result) > 200