├── .gitignore ├── CMakeLists.txt ├── Kconfig ├── README.md ├── assets ├── chat_with_llm.mp4 └── works.png ├── example └── test │ ├── CMakeLists.txt │ ├── idf_component.yml │ └── main.c ├── homeagent.py ├── idf_component.yml ├── include ├── HomeRPC.h ├── rpc_data.h ├── rpc_log.h ├── rpc_mdns.h ├── rpc_mesh.h └── rpc_mqtt.h ├── license.txt ├── qwen_agent ├── __init__.py ├── agent.py ├── agents │ ├── __init__.py │ ├── article_agent.py │ ├── assistant.py │ ├── docqa_agent.py │ ├── fncall_agent.py │ ├── group_chat.py │ ├── group_chat_auto_router.py │ ├── group_chat_creator.py │ ├── react_chat.py │ ├── router.py │ ├── user_agent.py │ └── write_from_scratch.py ├── llm │ ├── __init__.py │ ├── base.py │ ├── function_calling.py │ ├── oai.py │ ├── qwen_dashscope.py │ ├── qwenvl_dashscope.py │ ├── schema.py │ └── text_base.py ├── log.py ├── memory │ ├── __init__.py │ └── memory.py ├── prompts │ ├── __init__.py │ ├── continue_writing.py │ ├── doc_qa.py │ ├── expand_writing.py │ ├── gen_keyword.py │ └── outline_writing.py ├── tools │ ├── __init__.py │ ├── amap_weather.py │ ├── base.py │ ├── code_interpreter.py │ ├── doc_parser.py │ ├── image_gen.py │ ├── resource │ │ ├── AlibabaPuHuiTi-3-45-Light.ttf │ │ ├── code_interpreter_init_kernel.py │ │ └── image_service.py │ ├── retrieval.py │ ├── similarity_search.py │ ├── storage.py │ └── web_extractor.py └── utils │ ├── __init__.py │ ├── doc_parser.py │ ├── qwen.tiktoken │ ├── tokenization_qwen.py │ └── utils.py ├── requirements.txt ├── server.py ├── server ├── __init__.py ├── base.py ├── broker.py ├── log.py └── rule.py └── src ├── home_rpc.c ├── rpc_data.c ├── rpc_log.c ├── rpc_mdns.c ├── rpc_mesh.c └── rpc_mqtt.c /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__/ -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | file(GLOB SOURCES "src/*.c") 2 | 3 | idf_component_register(SRCS "${SOURCES}" 4 | INCLUDE_DIRS "include" 5 | REQUIRES mdns esp_wifi nvs_flash esp_netif esp_event iot_bridge mesh_lite mqtt 6 | ) 7 | -------------------------------------------------------------------------------- /Kconfig: -------------------------------------------------------------------------------- 1 | menu "HomeRPC configuration" 2 | 3 | config ROUTER_SSID 4 | string "Router SSID" 5 | default "ROUTER_SSID" 6 | help 7 | Router SSID. 8 | 9 | config ROUTER_PASSWORD 10 | string "Router password" 11 | default "ROUTER_PASSWORD" 12 | help 13 | Router password. 14 | 15 | config BROKER_URL 16 | string "Broker URL" 17 | default "homerpc" 18 | help 19 | Broker URL. 20 | 21 | config PARAMS_MAX 22 | int "Params max" 23 | default 10 24 | help 25 | Params max. 26 | 27 | config TOPIC_LEN_MAX 28 | int "Topic len max" 29 | default 128 30 | help 31 | Topic len max. 32 | 33 | endmenu -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # *HomeRPC* 4 | 5 | [![FreeRTOS](https://img.shields.io/badge/OS-FreeRTOS-brightgreen)](https://www.freertos.org/) 6 | [![ESP-IDF](https://img.shields.io/badge/SDK-ESP--IDF-blue)](https://docs.espressif.com/projects/esp-idf/en/latest/esp32/) 7 | [![MQTT](https://img.shields.io/badge/Protocol-MQTT-orange)](https://mqtt.org/) 8 | [![License](https://img.shields.io/badge/License-MIT-blue.svg)](./LICENSE) 9 | 10 | ### **适用于智能家居场景的嵌入式 RPC 框架** 11 | 12 |
13 | 14 | ## 🚀 项目介绍 15 | 16 | **`HomeRPC`** 是一个面向智能家居场景的嵌入式 RPC 框架,旨在提供一种简单、高效的远程调用解决方案,非常适合与 **大语言模型** 进行结合,实现智能家居场景下的对话式交互。 17 | 18 | ## 🎨 效果展示 19 | 20 | https://github.com/guidons-master/HomeRPC/assets/41904603/89fafaa5-8a3b-4159-9780-6227521c16d5 21 | 22 | ```bash 23 | streamlit run homeagent.py 24 | ``` 25 | 26 | ## ⚙️ 实现原理 27 | 28 | ![](./assets/works.png) 29 | 30 | ## 🛠️ 使用说明 31 | 32 | `HomeRPC` 支持以下基本数据类型作为命令参数: 33 | 34 | | 类型 | 签名 | 示例 | 35 | | ----------------------- | ---- | ----- | 36 | | char(字符) | c | 'a' | 37 | | short、int、long(数字) | i | 123 | 38 | | float(单精度浮点数) | f | 3.14 | 39 | | double(双精度浮点数) | d | 3.141 | 40 | 41 | 其中 `rpc_any_t` 类型的定义如下: 42 | 43 | ``` 44 | typedef union { 45 | char c; 46 | unsigned char uc; 47 | short s; 48 | unsigned short us; 49 | int i; 50 | unsigned int ui; 51 | long l; 52 | float f; 53 | double d; 54 | } __attribute__((aligned(1))) rpc_any_t; 55 | ``` 56 | 57 | ### 📚 示例代码 58 | 59 | `ESP32` 客户端示例代码如下: 60 | ```c 61 | // file: main.c 62 | #include "freertos/FreeRTOS.h" 63 | #include "driver/gpio.h" 64 | #include "HomeRPC.h" 65 | 66 | #define BLINK_GPIO 2 67 | 68 | static uint8_t s_led_state = 0; 69 | 70 | static void configure_led(void) { 71 | gpio_reset_pin(BLINK_GPIO); 72 | gpio_set_direction(BLINK_GPIO, GPIO_MODE_OUTPUT); 73 | } 74 | 75 | // 触发LED 76 | static rpc_any_t trigger_led(rpc_any_t state) { 77 | gpio_set_level(BLINK_GPIO, state.uc); 78 | rpc_any_t ret; 79 | ret.i = 0; 80 | return ret; 81 | } 82 | 83 | // 获取LED状态 84 | static rpc_any_t led_status(void) { 85 | rpc_any_t ret; 86 | ret.uc = s_led_state; 87 | return ret; 88 | } 89 | 90 | void app_main(void) { 91 | configure_led(); 92 | // 启动HomeRPC 93 | HomeRPC.start(); 94 | // 服务列表 95 | Service_t services[] = { 96 | { 97 | .func = trigger_led, 98 | .input_type = "i", 99 | .output_type = 'i', 100 | .name = "trigger", 101 | .desc = "open the light", 102 | }, 103 | { 104 | .func = led_status, 105 | .input_type = "", 106 | .output_type = 'i', 107 | .name = "status", 108 | .desc = "check the light status, return 1 if on, 0 if off", 109 | } 110 | }; 111 | // 设备信息 112 | Device_t led = { 113 | .place = "room", 114 | .type = "light", 115 | .id = 1, 116 | .services = services, 117 | .services_num = sizeof(services) / sizeof(Service_t) 118 | }; 119 | // 注册设备 120 | HomeRPC.addDevice(&led); 121 | // 调用服务 122 | Device_t led2 = { 123 | .place = "room", 124 | .type = "light", 125 | .id = 1 126 | }; 127 | 128 | while (1) { 129 | vTaskDelay(5000 / portTICK_PERIOD_MS); 130 | // 调用服务 131 | // rpc_any_t status = HomeRPC.callService(&led2, "status", NULL, 10); 132 | // printf("led status: %d\n", status.i); 133 | } 134 | } 135 | ``` 136 | 137 | `Broker` 服务端示例代码如下: 138 | ```Python 139 | # file: server.py 140 | from server import HomeRPC 141 | 142 | if __name__ == '__main__': 143 | # 启动HomeRPC 144 | HomeRPC.setup(ip = "192.168.43.9", log = True) 145 | 146 | # 等待ESP32连接 147 | input("Waiting for ESP32 to connect...") 148 | 149 | place = HomeRPC.place("room") 150 | # 调用ESP32客户端服务 151 | place.device("light").id(1).call("trigger", 1, timeout_s = 10) 152 | print("led status: ", place.device("light").id(1).call("status", timeout_s = 10)) 153 | ``` 154 | ## 📦 安装方法 155 | 156 | 1. 将 `HomeRPC` 组件添加到您的 `ESP-IDF` 项目中: 157 | ```bash 158 | cd ~/my_esp_idf_project 159 | mkdir components 160 | cd components 161 | git clone https://github.com/guidons-master/HomeRPC.git 162 | ``` 163 | 2. 在 `menuconfig` 中配置 `HomeRPC` 164 | 165 | ## 🧑‍💻 维护人员 166 | 167 | - [@guidons](https://github.com/guidons-master) 168 | - [@Hexin Lv](https://github.com/Mondaylv) 169 | -------------------------------------------------------------------------------- /assets/chat_with_llm.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidons-master/HomeRPC/07147c3d4fc554b46bec8296e393d727a7b6238c/assets/chat_with_llm.mp4 -------------------------------------------------------------------------------- /assets/works.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidons-master/HomeRPC/07147c3d4fc554b46bec8296e393d727a7b6238c/assets/works.png -------------------------------------------------------------------------------- /example/test/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | idf_component_register(SRCS "main.c" 3 | INCLUDE_DIRS ".") 4 | -------------------------------------------------------------------------------- /example/test/idf_component.yml: -------------------------------------------------------------------------------- 1 | dependencies: 2 | espressif/mdns: "^1.2.5" 3 | idf: 4 | version: '>=4.3' 5 | mesh_lite: 6 | version: '*' 7 | usb_device: 8 | git: https://github.com/espressif/esp-iot-bridge.git 9 | path: components/usb/usb_device 10 | rules: 11 | - if: target in [esp32s2, esp32s3] 12 | -------------------------------------------------------------------------------- /example/test/main.c: -------------------------------------------------------------------------------- 1 | #include "freertos/FreeRTOS.h" 2 | #include "driver/gpio.h" 3 | #include "HomeRPC.h" 4 | 5 | #define BLINK_GPIO 2 6 | 7 | static uint8_t s_led_state = 0; 8 | 9 | static void configure_led(void) { 10 | gpio_reset_pin(BLINK_GPIO); 11 | gpio_set_direction(BLINK_GPIO, GPIO_MODE_OUTPUT); 12 | } 13 | 14 | // 触发LED 15 | static rpc_any_t trigger_led(rpc_any_t state) { 16 | gpio_set_level(BLINK_GPIO, state.uc); 17 | rpc_any_t ret; 18 | ret.i = 0; 19 | return ret; 20 | } 21 | 22 | // 获取LED状态 23 | static rpc_any_t led_status(void) { 24 | rpc_any_t ret; 25 | ret.uc = s_led_state; 26 | return ret; 27 | } 28 | 29 | void app_main(void) { 30 | configure_led(); 31 | // 启动HomeRPC 32 | HomeRPC.start(); 33 | // 服务列表 34 | Service_t services[] = { 35 | { 36 | .func = trigger_led, 37 | .input_type = "i", 38 | .output_type = 'i', 39 | .name = "trigger", 40 | .desc = "open the light", 41 | }, 42 | { 43 | .func = led_status, 44 | .input_type = "", 45 | .output_type = 'i', 46 | .name = "status", 47 | .desc = "check the light status, return 1 if on, 0 if off", 48 | } 49 | }; 50 | // 设备信息 51 | Device_t led = { 52 | .place = "room", 53 | .type = "light", 54 | .id = 1, 55 | .services = services, 56 | .services_num = sizeof(services) / sizeof(Service_t) 57 | }; 58 | // 注册设备 59 | HomeRPC.addDevice(&led); 60 | // 调用服务 61 | Device_t led2 = { 62 | .place = "room", 63 | .type = "light", 64 | .id = 1 65 | }; 66 | 67 | while (1) { 68 | vTaskDelay(5000 / portTICK_PERIOD_MS); 69 | // 调用服务 70 | // rpc_any_t status = HomeRPC.callService(&led2, "status", NULL, 10); 71 | // printf("led status: %d\n", status.i); 72 | } 73 | } -------------------------------------------------------------------------------- /homeagent.py: -------------------------------------------------------------------------------- 1 | import json 2 | from qwen_agent.llm import get_chat_model 3 | import streamlit as st 4 | from server import HomeRPC 5 | 6 | def function_call(func, place = None, device_type = None, device_id = None, input_type = None): 7 | # Call the HomeRPC function 8 | if not (func and place and device_type and device_id): 9 | return json.dumps({ 10 | "status": "error", 11 | "message": "Missing required parameters" 12 | }) 13 | 14 | if input_type: 15 | ret = HomeRPC.place(place).device(device_type).id(device_id).call(func, input_type) 16 | else: 17 | ret = HomeRPC.place(place).device(device_type).id(device_id).call(func) 18 | 19 | if ret is None: 20 | return json.dumps({ 21 | "status": "error", 22 | "message": "Function call failed" 23 | }) 24 | 25 | return json.dumps({ 26 | "status": "success", 27 | "return": ret 28 | }) 29 | 30 | if __name__ == "__main__": 31 | st.set_page_config( 32 | page_title="HomeAgent", 33 | page_icon=":robot:", 34 | layout="wide" 35 | ) 36 | 37 | @st.cache_resource 38 | def get_model(): 39 | llm = get_chat_model({ 40 | 'model': 'qwen-max', 41 | 'model_server': 'dashscope', 42 | 'api_key': 'xxxxx-xxxxx-xxxxx-xxxxx-xxxxx-xxxxx-xxxxx-xxxxx', 43 | 'generate_cfg': { 44 | 'top_p': 0.8, 45 | 'temperature': 0.8 46 | } 47 | }) 48 | 49 | HomeRPC.setup(ip = "192.168.43.9", log = True) 50 | 51 | return llm 52 | 53 | llm = get_model() 54 | 55 | if "history" not in st.session_state: 56 | st.session_state.history = [] 57 | 58 | st.sidebar.header("HomeAgent", divider = False) 59 | 60 | st.sidebar.divider() 61 | 62 | funcName = st.sidebar.markdown("**函数调用:**") 63 | funcArgs = st.sidebar.markdown("**参数:**") 64 | callStatus = st.sidebar.markdown("**调用状态:**") 65 | 66 | st.sidebar.divider() 67 | 68 | buttonClean = st.sidebar.button("清理会话历史", key="clean") 69 | if buttonClean: 70 | st.session_state.history = [] 71 | st.rerun() 72 | 73 | for i, message in enumerate(st.session_state.history): 74 | if message["role"] == "user": 75 | with st.chat_message(name="user", avatar="user"): 76 | st.markdown(message["content"]) 77 | elif message["role"] == "assistant": 78 | if message["content"]: 79 | with st.chat_message(name="assistant", avatar="assistant"): 80 | st.markdown(message["content"]) 81 | 82 | with st.chat_message(name="user", avatar="user"): 83 | input_placeholder = st.empty() 84 | with st.chat_message(name="assistant", avatar="assistant"): 85 | message_placeholder = st.empty() 86 | 87 | prompt_text = st.chat_input("需要什么帮助吗?") 88 | 89 | if prompt_text: 90 | 91 | st.session_state.history.append({ 92 | "role": "user", 93 | "content": prompt_text 94 | }) 95 | 96 | input_placeholder.markdown(prompt_text) 97 | history = st.session_state.history 98 | responses = [] 99 | 100 | for responses in llm.chat(messages=history, 101 | functions=HomeRPC.funcs(), 102 | stream=True): 103 | if responses[0]["content"]: 104 | message_placeholder.markdown(responses[0]["content"]) 105 | 106 | history.extend(responses) 107 | 108 | last_response = history[-1] 109 | 110 | if last_response.get('function_call', None): 111 | 112 | try: 113 | function_name = last_response['function_call']['name'] 114 | funcName.markdown(f"**函数调用:**```{function_name}```") 115 | function_args = json.loads(last_response['function_call']['arguments']) 116 | funcArgs.markdown(f"**参数:**```{function_args}```") 117 | 118 | function_response = function_call( 119 | function_name, 120 | **function_args 121 | ) 122 | 123 | if json.loads(function_response).get('status') == 'success': 124 | callStatus.markdown(f"**调用状态:**```{function_response}```", unsafe_allow_html=True) 125 | else: 126 | callStatus.markdown(f"**调用状态:**```{function_response}```", unsafe_allow_html=True) 127 | 128 | history.append({ 129 | 'role': 'function', 130 | 'name': function_name, 131 | 'content': function_response 132 | }) 133 | 134 | for responses in llm.chat( 135 | messages=history, 136 | functions=HomeRPC.funcs(), 137 | stream=True, 138 | ): 139 | message_placeholder.markdown(responses[0]["content"]) 140 | 141 | history.extend(responses) 142 | 143 | except Exception as e: 144 | 145 | callStatus.markdown(f'**调用状态:**```{e}```', unsafe_allow_html=True) -------------------------------------------------------------------------------- /idf_component.yml: -------------------------------------------------------------------------------- 1 | dependencies: 2 | espressif/mdns: "^1.2.5" 3 | idf: 4 | version: '>=4.3' 5 | mesh_lite: 6 | version: '*' 7 | usb_device: 8 | git: https://github.com/espressif/esp-iot-bridge.git 9 | path: components/usb/usb_device 10 | rules: 11 | - if: target in [esp32s2, esp32s3] -------------------------------------------------------------------------------- /include/HomeRPC.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | #include "esp_system.h" 8 | #include "esp_log.h" 9 | #include "cJSON.h" 10 | #include "freertos/FreeRTOS.h" 11 | #include "freertos/event_groups.h" 12 | 13 | #define RPC_ERROR_CHECK(TAG, err) if ((err)) { \ 14 | rpc_log.log_error((TAG), "Error occurred: %s", esp_err_to_name((err))); \ 15 | esp_restart(); \ 16 | } 17 | 18 | typedef union { 19 | char c; 20 | unsigned char uc; 21 | short s; 22 | unsigned short us; 23 | int i; 24 | unsigned int ui; 25 | long l; 26 | float f; 27 | double d; 28 | // char* str; // todo 29 | } __attribute__((aligned(1))) rpc_any_t; 30 | 31 | typedef rpc_any_t (*rpc_func_t)(void); 32 | 33 | typedef struct { 34 | /* Public */ 35 | rpc_func_t func; 36 | char *input_type; 37 | char output_type; 38 | const char *name; 39 | const char *desc; 40 | /* Private */ 41 | cJSON *_input; 42 | EventBits_t _wait; 43 | } Service_t; 44 | 45 | typedef struct { 46 | char *place; 47 | char *type; 48 | unsigned int id; 49 | Service_t *services; 50 | unsigned int services_num; 51 | } Device_t; 52 | 53 | struct List_t { 54 | Device_t* device; 55 | struct List_t *next; 56 | }; 57 | typedef struct List_t DeviceList_t; 58 | 59 | typedef struct { 60 | esp_log_level_t log_level; 61 | uint8_t log_enable; 62 | void (*start)(void); 63 | void (*addDevice)(const Device_t *); 64 | rpc_any_t (*_callService)(const Device_t *, const char *, const rpc_any_t *, unsigned int, TickType_t); 65 | } HomeRPC_t; 66 | 67 | #define callService(device, service, params, timeout_s) _callService((device), (service), (params), ((params) && (sizeof((params)) / sizeof(rpc_any_t))), (timeout_s) * 1000) 68 | 69 | extern HomeRPC_t HomeRPC; 70 | 71 | #ifdef __cplusplus 72 | } 73 | #endif -------------------------------------------------------------------------------- /include/rpc_data.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | #include "HomeRPC.h" 8 | #include "cJSON.h" 9 | 10 | char* serialize_device(const Device_t*); 11 | char* serialize_service(const char*, const rpc_any_t*, unsigned int); 12 | int deserialize_service(cJSON*, char**, rpc_any_t*, unsigned int*); 13 | cJSON* rpc_any_to_json(rpc_any_t); 14 | rpc_any_t json_to_rpc_any(cJSON *); 15 | 16 | #ifdef __cplusplus 17 | } 18 | #endif -------------------------------------------------------------------------------- /include/rpc_log.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | typedef struct { 8 | void (*log_info)(const char* tag, const char* format, ...); 9 | void (*log_error)(const char* tag, const char* format, ...); 10 | void (*log_warn)(const char* tag, const char* format, ...); 11 | } rpc_log_t; 12 | 13 | extern rpc_log_t rpc_log; 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif -------------------------------------------------------------------------------- /include/rpc_mdns.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | #include "esp_err.h" 8 | #include "mdns.h" 9 | #include "esp_netif.h" 10 | 11 | esp_err_t rpc_mdns_init(void); 12 | esp_err_t rpc_mdns_search(const char *, esp_ip4_addr_t *); 13 | void rpc_mdns_stop(void); 14 | 15 | #ifdef __cplusplus 16 | } 17 | #endif -------------------------------------------------------------------------------- /include/rpc_mesh.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | void rpc_mesh_init(esp_event_handler_t); 8 | 9 | #ifdef __cplusplus 10 | } 11 | #endif -------------------------------------------------------------------------------- /include/rpc_mqtt.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef __cplusplus 4 | extern "C" { 5 | #endif 6 | 7 | #include "HomeRPC.h" 8 | 9 | extern char callback_topic[CONFIG_TOPIC_LEN_MAX]; 10 | void rpc_mqtt_call(const Device_t *, const char *, const rpc_any_t*, const unsigned int); 11 | void rpc_mqtt_task(void *); 12 | 13 | #ifdef __cplusplus 14 | } 15 | #endif -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /qwen_agent/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.2' 2 | from .agent import Agent 3 | 4 | __all__ = ['Agent'] 5 | -------------------------------------------------------------------------------- /qwen_agent/agent.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import traceback 4 | from abc import ABC, abstractmethod 5 | from typing import Dict, Iterator, List, Optional, Tuple, Union 6 | 7 | from qwen_agent.llm import get_chat_model 8 | from qwen_agent.llm.base import BaseChatModel 9 | from qwen_agent.llm.schema import (CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE, 10 | SYSTEM, ContentItem, Message) 11 | from qwen_agent.log import logger 12 | from qwen_agent.tools import TOOL_REGISTRY, BaseTool 13 | from qwen_agent.utils.utils import has_chinese_chars 14 | 15 | 16 | class Agent(ABC): 17 | """A base class for Agent. 18 | 19 | An agent can receive messages and provide response by LLM or Tools. 20 | Different agents have distinct workflows for processing messages and generating responses in the `_run` method. 21 | """ 22 | 23 | def __init__(self, 24 | function_list: Optional[List[Union[str, Dict, 25 | BaseTool]]] = None, 26 | llm: Optional[Union[Dict, BaseChatModel]] = None, 27 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, 28 | name: Optional[str] = None, 29 | description: Optional[str] = None, 30 | **kwargs): 31 | """Initialization the agent. 32 | 33 | Args: 34 | function_list: One list of tool name, tool configuration or Tool object, 35 | such as 'code_interpreter', {'name': 'code_interpreter', 'timeout': 10}, or CodeInterpreter(). 36 | llm: The LLM model configuration or LLM model object. 37 | Set the configuration as {'model': '', 'api_key': '', 'model_server': ''}. 38 | system_message: The specified system message for LLM chat. 39 | name: The name of this agent. 40 | description: The description of this agent, which will be used for multi_agent. 41 | """ 42 | if isinstance(llm, dict): 43 | self.llm = get_chat_model(llm) 44 | else: 45 | self.llm = llm 46 | 47 | self.function_map = {} 48 | if function_list: 49 | for tool in function_list: 50 | self._init_tool(tool) 51 | 52 | self.system_message = system_message 53 | self.name = name 54 | self.description = description 55 | 56 | def run(self, messages: List[Union[Dict, Message]], 57 | **kwargs) -> Union[Iterator[List[Message]], Iterator[List[Dict]]]: 58 | """Return one response generator based on the received messages. 59 | 60 | This method performs a uniform type conversion for the inputted messages, 61 | and calls the _run method to generate a reply. 62 | 63 | Args: 64 | messages: A list of messages. 65 | 66 | Yields: 67 | The response generator. 68 | """ 69 | messages = copy.deepcopy(messages) 70 | _return_message_type = 'dict' 71 | new_messages = [] 72 | # Only return dict when all input messages are dict 73 | if not messages: 74 | _return_message_type = 'message' 75 | for msg in messages: 76 | if isinstance(msg, dict): 77 | new_messages.append(Message(**msg)) 78 | else: 79 | new_messages.append(msg) 80 | _return_message_type = 'message' 81 | 82 | if new_messages and 'lang' not in kwargs: 83 | if has_chinese_chars([new_messages[-1][CONTENT], kwargs]): 84 | kwargs['lang'] = 'zh' 85 | else: 86 | kwargs['lang'] = 'en' 87 | 88 | for rsp in self._run(messages=new_messages, **kwargs): 89 | for i in range(len(rsp)): 90 | if not rsp[i].name and self.name: 91 | rsp[i].name = self.name 92 | if _return_message_type == 'message': 93 | yield [Message(**x) if isinstance(x, dict) else x for x in rsp] 94 | else: 95 | yield [ 96 | x.model_dump() if not isinstance(x, dict) else x 97 | for x in rsp 98 | ] 99 | 100 | @abstractmethod 101 | def _run(self, 102 | messages: List[Message], 103 | lang: str = 'en', 104 | **kwargs) -> Iterator[List[Message]]: 105 | """Return one response generator based on the received messages. 106 | 107 | The workflow for an agent to generate a reply. 108 | Each agent subclass needs to implement this method. 109 | 110 | Args: 111 | messages: A list of messages. 112 | lang: Language, which will be used to select the language of the prompt 113 | during the agent's execution process. 114 | 115 | Yields: 116 | The response generator. 117 | """ 118 | raise NotImplementedError 119 | 120 | def _call_llm( 121 | self, 122 | messages: List[Message], 123 | functions: Optional[List[Dict]] = None, 124 | stream: bool = True, 125 | ) -> Iterator[List[Message]]: 126 | """The interface of calling LLM for the agent. 127 | 128 | We prepend the system_message of this agent to the messages, and call LLM. 129 | 130 | Args: 131 | messages: A list of messages. 132 | functions: The list of functions provided to LLM. 133 | stream: LLM streaming output or non-streaming output. 134 | For consistency, we default to using streaming output across all agents. 135 | 136 | Yields: 137 | The response generator of LLM. 138 | """ 139 | messages = copy.deepcopy(messages) 140 | if messages[0][ROLE] != SYSTEM: 141 | messages.insert(0, Message(role=SYSTEM, 142 | content=self.system_message)) 143 | elif isinstance(messages[0][CONTENT], str): 144 | messages[0][CONTENT] = self.system_message + messages[0][CONTENT] 145 | else: 146 | assert isinstance(messages[0][CONTENT], list) 147 | messages[0][CONTENT] = [ContentItem(text=self.system_message) 148 | ] + messages[0][CONTENT] 149 | return self.llm.chat(messages=messages, 150 | functions=functions, 151 | stream=stream) 152 | 153 | def _call_tool(self, 154 | tool_name: str, 155 | tool_args: Union[str, dict] = '{}', 156 | **kwargs) -> str: 157 | """The interface of calling tools for the agent. 158 | 159 | Args: 160 | tool_name: The name of one tool. 161 | tool_args: Model generated or user given tool parameters. 162 | 163 | Returns: 164 | The output of tools. 165 | """ 166 | if tool_name not in self.function_map: 167 | return f'Tool {tool_name} does not exists.' 168 | tool = self.function_map[tool_name] 169 | try: 170 | tool_result = tool.call(tool_args, **kwargs) 171 | except Exception as ex: 172 | exception_type = type(ex).__name__ 173 | exception_message = str(ex) 174 | traceback_info = ''.join(traceback.format_tb(ex.__traceback__)) 175 | error_message = f'An error occurred when calling tool `{tool_name}`:\n' \ 176 | f'{exception_type}: {exception_message}\n' \ 177 | f'Traceback:\n{traceback_info}' 178 | return error_message 179 | 180 | if isinstance(tool_result, str): 181 | return tool_result 182 | else: 183 | return json.dumps(tool_result, ensure_ascii=False, indent=4) 184 | 185 | def _init_tool(self, tool: Union[str, Dict, BaseTool]): 186 | if isinstance(tool, BaseTool): 187 | tool_name = tool.name 188 | if tool_name in self.function_map: 189 | logger.warning( 190 | f'Repeatedly adding tool {tool_name}, will use the newest tool in function list' 191 | ) 192 | self.function_map[tool_name] = tool 193 | else: 194 | if isinstance(tool, dict): 195 | tool_name = tool['name'] 196 | tool_cfg = tool 197 | else: 198 | tool_name = tool 199 | tool_cfg = None 200 | if tool_name not in TOOL_REGISTRY: 201 | raise ValueError(f'Tool {tool_name} is not registered.') 202 | 203 | if tool_name in self.function_map: 204 | logger.warning( 205 | f'Repeatedly adding tool {tool_name}, will use the newest tool in function list' 206 | ) 207 | self.function_map[tool_name] = TOOL_REGISTRY[tool_name](tool_cfg) 208 | 209 | def _detect_tool(self, message: Message) -> Tuple[bool, str, str, str]: 210 | """A built-in tool call detection for func_call format message. 211 | 212 | Args: 213 | message: one message generated by LLM. 214 | 215 | Returns: 216 | Need to call tool or not, tool name, tool args, text replies. 217 | """ 218 | func_name = None 219 | func_args = None 220 | 221 | if message.function_call: 222 | func_call = message.function_call 223 | func_name = func_call.name 224 | func_args = func_call.arguments 225 | text = message.content 226 | if not text: 227 | text = '' 228 | 229 | return (func_name is not None), func_name, func_args, text 230 | -------------------------------------------------------------------------------- /qwen_agent/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .article_agent import ArticleAgent 2 | from .assistant import Assistant 3 | from .docqa_agent import DocQAAgent 4 | from .fncall_agent import FnCallAgent 5 | from .group_chat import GroupChat 6 | from .group_chat_auto_router import GroupChatAutoRouter 7 | from .group_chat_creator import GroupChatCreator 8 | from .react_chat import ReActChat 9 | from .router import Router 10 | from .user_agent import UserAgent 11 | from .write_from_scratch import WriteFromScratch 12 | 13 | __all__ = [ 14 | 'DocQAAgent', 'Assistant', 'ArticleAgent', 'ReActChat', 'Router', 15 | 'UserAgent', 'GroupChat', 'WriteFromScratch', 'GroupChatCreator', 16 | 'GroupChatAutoRouter', 'FnCallAgent' 17 | ] 18 | -------------------------------------------------------------------------------- /qwen_agent/agents/article_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List 2 | 3 | from qwen_agent.agents.assistant import Assistant 4 | from qwen_agent.agents.write_from_scratch import WriteFromScratch 5 | from qwen_agent.llm.schema import ASSISTANT, CONTENT, Message 6 | from qwen_agent.prompts import ContinueWriting 7 | 8 | 9 | class ArticleAgent(Assistant): 10 | """This is an agent for writing articles. 11 | 12 | It can write a thematic essay or continue writing an article based on reference materials 13 | """ 14 | 15 | def _run(self, 16 | messages: List[Message], 17 | lang: str = 'en', 18 | max_ref_token: int = 4000, 19 | full_article: bool = False, 20 | **kwargs) -> Iterator[List[Message]]: 21 | 22 | # Need to use Memory agent for data management 23 | *_, last = self.mem.run(messages=messages, 24 | max_ref_token=max_ref_token, 25 | **kwargs) 26 | _ref = last[-1][CONTENT] 27 | 28 | response = [] 29 | if _ref: 30 | response.append( 31 | Message(ASSISTANT, 32 | f'>\n> Search for relevant information: \n{_ref}\n')) 33 | yield response 34 | 35 | if full_article: 36 | writing_agent = WriteFromScratch(llm=self.llm) 37 | else: 38 | writing_agent = ContinueWriting(llm=self.llm) 39 | response.append(Message(ASSISTANT, '>\n> Writing Text: \n')) 40 | yield response 41 | 42 | for trunk in writing_agent.run(messages=messages, 43 | lang=lang, 44 | knowledge=_ref): 45 | if trunk: 46 | yield response + trunk 47 | -------------------------------------------------------------------------------- /qwen_agent/agents/assistant.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Iterator, List 3 | 4 | from qwen_agent.llm.schema import CONTENT, ROLE, SYSTEM, Message 5 | from qwen_agent.log import logger 6 | from qwen_agent.utils.utils import format_knowledge_to_source_and_content 7 | 8 | from .fncall_agent import FnCallAgent 9 | 10 | KNOWLEDGE_SNIPPET_ZH = """## 来自 {source} 的内容: 11 | 12 | ``` 13 | {content} 14 | ```""" 15 | KNOWLEDGE_TEMPLATE_ZH = """ 16 | 17 | # 知识库 18 | 19 | {knowledge}""" 20 | 21 | KNOWLEDGE_SNIPPET_EN = """## The content from {source}: 22 | 23 | ``` 24 | {content} 25 | ```""" 26 | KNOWLEDGE_TEMPLATE_EN = """ 27 | 28 | # Knowledge Base 29 | 30 | {knowledge}""" 31 | 32 | KNOWLEDGE_SNIPPET = {'zh': KNOWLEDGE_SNIPPET_ZH, 'en': KNOWLEDGE_SNIPPET_EN} 33 | KNOWLEDGE_TEMPLATE = {'zh': KNOWLEDGE_TEMPLATE_ZH, 'en': KNOWLEDGE_TEMPLATE_EN} 34 | 35 | 36 | class Assistant(FnCallAgent): 37 | """This is a widely applicable agent integrated with RAG capabilities and function call ability.""" 38 | 39 | def _run(self, 40 | messages: List[Message], 41 | lang: str = 'en', 42 | max_ref_token: int = 4000, 43 | **kwargs) -> Iterator[List[Message]]: 44 | 45 | new_messages = self._prepend_knowledge_prompt(messages, lang, 46 | max_ref_token, **kwargs) 47 | return super()._run(messages=new_messages, 48 | lang=lang, 49 | max_ref_token=max_ref_token, 50 | **kwargs) 51 | 52 | def _prepend_knowledge_prompt(self, 53 | messages: List[Message], 54 | lang: str = 'en', 55 | max_ref_token: int = 4000, 56 | **kwargs) -> List[Message]: 57 | messages = copy.deepcopy(messages) 58 | # Retrieval knowledge from files 59 | *_, last = self.mem.run(messages=messages, max_ref_token=max_ref_token, lang=lang, **kwargs) 60 | knowledge = last[-1][CONTENT] 61 | 62 | logger.debug( 63 | f'Retrieved knowledge of type `{type(knowledge).__name__}`:\n{knowledge}' 64 | ) 65 | if knowledge: 66 | knowledge = format_knowledge_to_source_and_content(knowledge) 67 | logger.debug( 68 | f'Formatted knowledge into type `{type(knowledge).__name__}`:\n{knowledge}' 69 | ) 70 | else: 71 | knowledge = [] 72 | snippets = [] 73 | for k in knowledge: 74 | snippets.append(KNOWLEDGE_SNIPPET[lang].format( 75 | source=k['source'], content=k['content'])) 76 | knowledge_prompt = '' 77 | if snippets: 78 | knowledge_prompt = KNOWLEDGE_TEMPLATE[lang].format( 79 | knowledge='\n\n'.join(snippets)) 80 | 81 | if knowledge_prompt: 82 | if messages[0][ROLE] == SYSTEM: 83 | messages[0][CONTENT] += knowledge_prompt 84 | else: 85 | messages = [Message(role=SYSTEM, content=knowledge_prompt) 86 | ] + messages 87 | return messages 88 | -------------------------------------------------------------------------------- /qwen_agent/agents/docqa_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterator, List, Optional, Union 2 | 3 | from qwen_agent.agents.assistant import Assistant 4 | from qwen_agent.llm.base import BaseChatModel 5 | from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, Message 6 | from qwen_agent.prompts import DocQA 7 | from qwen_agent.tools import BaseTool 8 | 9 | 10 | class DocQAAgent(Assistant): 11 | """This is an agent for doc QA.""" 12 | 13 | def __init__(self, 14 | function_list: Optional[List[Union[str, Dict, 15 | BaseTool]]] = None, 16 | llm: Optional[Union[Dict, BaseChatModel]] = None, 17 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, 18 | name: Optional[str] = None, 19 | description: Optional[str] = None, 20 | files: Optional[List[str]] = None): 21 | super().__init__(function_list=function_list, 22 | llm=llm, 23 | system_message=system_message, 24 | name=name, 25 | description=description, 26 | files=files) 27 | 28 | self.doc_qa = DocQA(llm=self.llm) 29 | 30 | def _run(self, 31 | messages: List[Message], 32 | lang: str = 'en', 33 | max_ref_token: int = 4000, 34 | **kwargs) -> Iterator[List[Message]]: 35 | 36 | # Need to use Memory agent for data management 37 | *_, last = self.mem.run(messages=messages, 38 | max_ref_token=max_ref_token, 39 | **kwargs) 40 | _ref = last[-1][CONTENT] 41 | 42 | # Use RetrievalQA agent 43 | response = self.doc_qa.run(messages=messages, 44 | lang=lang, 45 | knowledge=_ref) 46 | 47 | return response 48 | -------------------------------------------------------------------------------- /qwen_agent/agents/fncall_agent.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Dict, Iterator, List, Optional, Union 3 | 4 | from qwen_agent import Agent 5 | from qwen_agent.llm import BaseChatModel 6 | from qwen_agent.llm.schema import DEFAULT_SYSTEM_MESSAGE, FUNCTION, Message 7 | from qwen_agent.memory import Memory 8 | from qwen_agent.tools import BaseTool 9 | 10 | MAX_LLM_CALL_PER_RUN = 8 11 | 12 | 13 | class FnCallAgent(Agent): 14 | """This is a widely applicable function call agent integrated with llm and tool use ability.""" 15 | 16 | def __init__(self, 17 | function_list: Optional[List[Union[str, Dict, 18 | BaseTool]]] = None, 19 | llm: Optional[Union[Dict, BaseChatModel]] = None, 20 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, 21 | name: Optional[str] = None, 22 | description: Optional[str] = None, 23 | files: Optional[List[str]] = None): 24 | """Initialization the agent. 25 | 26 | Args: 27 | function_list: One list of tool name, tool configuration or Tool object, 28 | such as 'code_interpreter', {'name': 'code_interpreter', 'timeout': 10}, or CodeInterpreter(). 29 | llm: The LLM model configuration or LLM model object. 30 | Set the configuration as {'model': '', 'api_key': '', 'model_server': ''}. 31 | system_message: The specified system message for LLM chat. 32 | name: The name of this agent. 33 | description: The description of this agent, which will be used for multi_agent. 34 | files: A file url list. The initialized files for the agent. 35 | """ 36 | super().__init__(function_list=function_list, 37 | llm=llm, 38 | system_message=system_message, 39 | name=name, 40 | description=description) 41 | 42 | # Default to use Memory to manage files 43 | self.mem = Memory(llm=self.llm, files=files) 44 | 45 | def _run(self, 46 | messages: List[Message], 47 | lang: str = 'en', 48 | **kwargs) -> Iterator[List[Message]]: 49 | messages = copy.deepcopy(messages) 50 | num_llm_calls_available = MAX_LLM_CALL_PER_RUN 51 | response = [] 52 | while True and num_llm_calls_available > 0: 53 | num_llm_calls_available -= 1 54 | output_stream = self._call_llm( 55 | messages=messages, 56 | functions=[ 57 | func.function for func in self.function_map.values() 58 | ]) 59 | output: List[Message] = [] 60 | for output in output_stream: 61 | if output: 62 | yield response + output 63 | if output: 64 | response.extend(output) 65 | messages.extend(output) 66 | use_tool, action, action_input, _ = self._detect_tool(response[-1]) 67 | if use_tool: 68 | observation = self._call_tool(action, 69 | action_input, 70 | messages=messages) 71 | fn_msg = Message( 72 | role=FUNCTION, 73 | name=action, 74 | content=observation, 75 | ) 76 | messages.append(fn_msg) 77 | response.append(fn_msg) 78 | yield response 79 | else: 80 | break 81 | 82 | def _call_tool(self, 83 | tool_name: str, 84 | tool_args: Union[str, dict] = '{}', 85 | **kwargs) -> str: 86 | # Temporary plan: Check if it is necessary to transfer files to the tool 87 | # Todo: This should be changed to parameter passing, and the file URL should be determined by the model 88 | if self.function_map[tool_name].file_access: 89 | assert 'messages' in kwargs 90 | files = self.mem.get_all_files_of_messages( 91 | kwargs['messages']) + self.mem.system_files 92 | return super()._call_tool(tool_name, 93 | tool_args, 94 | files=files, 95 | **kwargs) 96 | else: 97 | return super()._call_tool(tool_name, tool_args, **kwargs) 98 | -------------------------------------------------------------------------------- /qwen_agent/agents/group_chat_auto_router.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterator, List, Optional, Union 2 | 3 | from qwen_agent import Agent 4 | from qwen_agent.llm import BaseChatModel 5 | from qwen_agent.llm.schema import Message 6 | from qwen_agent.tools import BaseTool 7 | from qwen_agent.utils.utils import has_chinese_chars 8 | 9 | PROMPT_TEMPLATE_ZH = '''你扮演角色扮演游戏的上帝,你的任务是选择合适的发言角色。有如下角色: 10 | {agent_descs} 11 | 12 | 角色间的对话历史格式如下,越新的对话越重要: 13 | 角色名: 说话内容 14 | 15 | 请阅读对话历史,并选择下一个合适的发言角色,从 [{agent_names}] 里选,当真实用户最近表明了停止聊天时,或话题应该终止时,请返回“[STOP]”,用户很懒,非必要不要选真实用户。 16 | 仅返回角色名或“[STOP]”,不要返回其余内容。''' 17 | 18 | PROMPT_TEMPLATE_EN = '''You are in a role play game. The following roles are available: 19 | {agent_descs} 20 | 21 | The format of dialogue history between roles is as follows: 22 | Role Name: Speech Content 23 | 24 | Please read the dialogue history and choose the next suitable role to speak. 25 | When the user indicates to stop chatting or when the topic should be terminated, please return '[STOP]'. 26 | Only return the role name from [{agent_names}] or '[STOP]'. Do not reply any other content.''' 27 | 28 | PROMPT_TEMPLATE = { 29 | 'zh': PROMPT_TEMPLATE_ZH, 30 | 'en': PROMPT_TEMPLATE_EN, 31 | } 32 | 33 | 34 | class GroupChatAutoRouter(Agent): 35 | 36 | def __init__(self, 37 | function_list: Optional[List[Union[str, Dict, 38 | BaseTool]]] = None, 39 | llm: Optional[Union[Dict, BaseChatModel]] = None, 40 | agents: List[Agent] = None, 41 | name: Optional[str] = None, 42 | description: Optional[str] = None, 43 | **kwargs): 44 | # This agent need prepend special system message according to inputted agents 45 | agent_descs = '\n'.join([f'{x.name}: {x.description}' for x in agents]) 46 | lang = 'en' 47 | if has_chinese_chars(agent_descs): 48 | lang = 'zh' 49 | system_prompt = PROMPT_TEMPLATE[lang].format( 50 | agent_descs=agent_descs, 51 | agent_names=', '.join([x.name for x in agents])) 52 | 53 | super().__init__(function_list=function_list, 54 | llm=llm, 55 | system_message=system_prompt, 56 | name=name, 57 | description=description, 58 | **kwargs) 59 | 60 | def _run(self, 61 | messages: List[Message], 62 | lang: str = 'en', 63 | **kwargs) -> Iterator[List[Message]]: 64 | 65 | dialogue = [] 66 | for msg in messages: 67 | if msg.role == 'function' or not msg.content: 68 | continue 69 | if isinstance(msg.content, list): 70 | content = '\n'.join( 71 | [x.text if x.text else '' for x in msg.content]).strip() 72 | else: 73 | content = msg.content.strip() 74 | display_name = msg.role 75 | if msg.name: 76 | display_name = msg.name 77 | if dialogue and dialogue[-1].startswith(display_name): 78 | dialogue[-1] += f'\n{content}' 79 | else: 80 | dialogue.append(f'{display_name}: {content}') 81 | 82 | if not dialogue: 83 | dialogue.append('对话刚开始,请任意选择一个发言人,别选真实用户') 84 | new_messages = [Message('user', '\n'.join(dialogue))] 85 | 86 | return self._call_llm(messages=new_messages) 87 | -------------------------------------------------------------------------------- /qwen_agent/agents/group_chat_creator.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | from typing import Dict, Iterator, List, Optional, Tuple, Union 4 | 5 | import json5 6 | 7 | from qwen_agent import Agent 8 | from qwen_agent.llm import BaseChatModel 9 | from qwen_agent.llm.schema import Message 10 | from qwen_agent.tools import BaseTool 11 | 12 | CONFIG_SCHEMA = { 13 | 'name': '... # 角色名字,5字左右', 14 | 'description': '... # 角色简介,10字左右', 15 | 'instructions': '... # 对角色的具体功能要求,30字左右,以第二人称称呼角色' 16 | } 17 | 18 | CONFIG_EXAMPLE = { 19 | 'name': '小红书写作专家', 20 | 'description': '我会写小红书爆款', 21 | 'instructions': '你是小红书爆款写作专家,创作会先产5个标题(含emoji),再产正文(每段落含emoji,文末有tag)。' 22 | } 23 | 24 | BACKGROUND_TOKEN = '' 25 | CONFIG_TOKEN = '' 26 | ANSWER_TOKEN = '' 27 | 28 | ROLE_CREATE_SYSTEM = '''你扮演创建群聊的助手,请你根据用户输入的聊天主题,创建n个合适的虚拟角色,这些角色将在一个聊天室内对话,你需要和用户进行对话,明确用户对这些角色的要求。 29 | 30 | 配置文件为json格式: 31 | {config_schema} 32 | 33 | 一个优秀的RichConfig样例如下: 34 | {config_example} 35 | 36 | 在接下来的对话中,请在回答时严格使用如下格式,先生成群聊背景,然后依次生成所有角色的配置文件,最后再作出回复,除此之外不要回复其他任何内容: 37 | {background_token}: ... # 生成的群聊背景,包括人物关系,预设故事背景等信息。 38 | {config_token}: ... # 生成的第一个角色的配置文件,严格按照以上json格式,禁止为空。保证name和description不为空。instructions内容比description具体,如果用户给出了详细指令,请完全保留,用第二人称描述角色,例如“你是xxx,你具有xxx能力。 39 | {config_token}: ... # 生成的第二个角色的配置文件,要求同上。 40 | ... 41 | {config_token}: ... # 生成的第n个角色的配置文件,要求同上,如果用户没有明确指出n的数量,则n等于3;要求每个角色的名字不相同。 42 | {answer_token}: ... # 你希望对用户说的话,用于询问用户对角色的要求,禁止为空,问题要广泛,不要重复问类似的问题。 43 | 44 | 如果群聊背景或某个角色的配置文件不需要更新,可以不重复输出{background_token}和对应的{config_token}的内容、只输出{answer_token}和需要修改的{config_token}的内容。'''.format( 45 | config_schema=json.dumps(CONFIG_SCHEMA, ensure_ascii=False, indent=2), 46 | config_example=json.dumps(CONFIG_EXAMPLE, ensure_ascii=False, indent=2), 47 | background_token=BACKGROUND_TOKEN, 48 | config_token=CONFIG_TOKEN, 49 | answer_token=ANSWER_TOKEN, 50 | ) 51 | assert CONFIG_TOKEN in ROLE_CREATE_SYSTEM 52 | assert ANSWER_TOKEN in ROLE_CREATE_SYSTEM 53 | 54 | 55 | class GroupChatCreator(Agent): 56 | 57 | def __init__(self, 58 | function_list: Optional[List[Union[str, Dict, 59 | BaseTool]]] = None, 60 | llm: Optional[Union[Dict, BaseChatModel]] = None, 61 | name: Optional[str] = None, 62 | description: Optional[str] = None, 63 | **kwargs): 64 | super().__init__(function_list=function_list, 65 | llm=llm, 66 | system_message=ROLE_CREATE_SYSTEM, 67 | name=name, 68 | description=description, 69 | **kwargs) 70 | 71 | def _run(self, 72 | messages: List[Message], 73 | agents: List[Agent] = None, 74 | lang: str = 'en', 75 | **kwargs) -> Iterator[List[Message]]: 76 | messages = copy.deepcopy(messages) 77 | messages = self._preprocess_messages(messages) 78 | 79 | for rsp in self._call_llm(messages=messages): 80 | yield self._postprocess_messages(rsp) 81 | 82 | def _preprocess_messages(self, messages: List[Message]) -> List[Message]: 83 | new_messages = [] 84 | content = [] 85 | for message in messages: 86 | if message.role != 'assistant': 87 | new_messages.append(message) 88 | else: 89 | if message.name == 'background': 90 | content.append(f'{BACKGROUND_TOKEN}: {message.content}') 91 | elif message.name == 'role_config': 92 | content.append(f'{CONFIG_TOKEN}: {message.content}') 93 | else: 94 | content.append(f'{ANSWER_TOKEN}: {message.content}') 95 | assert new_messages[-1].role == 'user' 96 | new_messages.append( 97 | Message('assistant', '\n'.join(content))) 98 | content = [] 99 | return new_messages 100 | 101 | def _postprocess_messages(self, messages: List[Message]) -> List[Message]: 102 | new_messages = [] 103 | assert len(messages) == 1 104 | message = messages[-1] 105 | background, cfgs, answer = self._extract_role_config_and_answer( 106 | message.content) 107 | if background: 108 | new_messages.append( 109 | Message(message.role, background, name='background')) 110 | if cfgs: 111 | for cfg in cfgs: 112 | new_messages.append( 113 | Message(message.role, cfg, name='role_config')) 114 | 115 | new_messages.append(Message(message.role, answer, name=message.name)) 116 | return new_messages 117 | 118 | def _extract_role_config_and_answer( 119 | self, text: str) -> Tuple[str, List[str], str]: 120 | background, cfgs, answer = '', [], '' 121 | back_pos, cfg_pos, ans_pos = text.find( 122 | f'{BACKGROUND_TOKEN}: '), text.find( 123 | f'{CONFIG_TOKEN}: '), text.find(f'{ANSWER_TOKEN}: ') 124 | 125 | if ans_pos > -1: 126 | answer = text[ans_pos + len(f'{ANSWER_TOKEN}: '):] 127 | else: 128 | ans_pos = len(text) 129 | 130 | if back_pos > -1: 131 | if cfg_pos > back_pos: 132 | background = text[back_pos + 133 | len(f'{BACKGROUND_TOKEN}: '):cfg_pos] 134 | else: 135 | background = text[back_pos + 136 | len(f'{BACKGROUND_TOKEN}: '):ans_pos] 137 | text = text[:ans_pos] 138 | 139 | tmp = text.split(f'{CONFIG_TOKEN}: ') 140 | for t in tmp: 141 | if t.strip(): 142 | try: 143 | _ = json5.loads(t.strip()) 144 | cfgs.append(t.strip()) 145 | except Exception: 146 | continue 147 | 148 | if not (background or cfgs or answer): 149 | # There should always be ANSWER_TOKEN, if not, treat the entire content as answer 150 | answer = text 151 | return background, cfgs, answer 152 | -------------------------------------------------------------------------------- /qwen_agent/agents/react_chat.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Dict, Iterator, List, Optional, Tuple, Union 3 | 4 | from qwen_agent.agents.fncall_agent import MAX_LLM_CALL_PER_RUN, FnCallAgent 5 | from qwen_agent.llm import BaseChatModel 6 | from qwen_agent.llm.schema import (ASSISTANT, CONTENT, DEFAULT_SYSTEM_MESSAGE, 7 | ROLE, ContentItem, Message) 8 | from qwen_agent.tools import BaseTool 9 | from qwen_agent.utils.utils import (get_basename_from_url, 10 | get_function_description, 11 | has_chinese_chars) 12 | 13 | PROMPT_REACT = """Answer the following questions as best you can. You have access to the following tools: 14 | 15 | {tool_descs} 16 | 17 | Use the following format: 18 | 19 | Question: the input question you must answer 20 | Thought: you should always think about what to do 21 | Action: the action to take, should be one of [{tool_names}] 22 | Action Input: the input to the action 23 | Observation: the result of the action 24 | ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) 25 | Thought: I now know the final answer 26 | Final Answer: the final answer to the original input question 27 | 28 | Begin! 29 | 30 | Question: {query}""" 31 | 32 | 33 | class ReActChat(FnCallAgent): 34 | """This agent use ReAct format to call tools""" 35 | 36 | def __init__(self, 37 | function_list: Optional[List[Union[str, Dict, 38 | BaseTool]]] = None, 39 | llm: Optional[Union[Dict, BaseChatModel]] = None, 40 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, 41 | name: Optional[str] = None, 42 | description: Optional[str] = None, 43 | files: Optional[List[str]] = None): 44 | super().__init__(function_list=function_list, 45 | llm=llm, 46 | system_message=system_message, 47 | name=name, 48 | description=description, 49 | files=files) 50 | stop = self.llm.generate_cfg.get('stop', []) 51 | fn_stop = ['Observation:', 'Observation:\n'] 52 | self.llm.generate_cfg['stop'] = stop + [ 53 | x for x in fn_stop if x not in stop 54 | ] 55 | 56 | def _run(self, 57 | messages: List[Message], 58 | lang: str = 'en', 59 | **kwargs) -> Iterator[List[Message]]: 60 | ori_messages = messages 61 | messages = self._preprocess_react_prompt(messages) 62 | 63 | num_llm_calls_available = MAX_LLM_CALL_PER_RUN 64 | response = [] 65 | while True and num_llm_calls_available > 0: 66 | num_llm_calls_available -= 1 67 | output_stream = self._call_llm(messages=messages) 68 | output = [] 69 | 70 | # Yield the streaming response 71 | response_tmp = copy.deepcopy(response) 72 | for output in output_stream: 73 | if output: 74 | if not response_tmp: 75 | yield output 76 | else: 77 | response_tmp[-1][CONTENT] = response[-1][ 78 | CONTENT] + output[-1][CONTENT] 79 | yield response_tmp 80 | # Record the incremental response 81 | assert len(output) == 1 and output[-1][ROLE] == ASSISTANT 82 | if not response: 83 | response += output 84 | else: 85 | response[-1][CONTENT] += output[-1][CONTENT] 86 | 87 | output = output[-1][CONTENT] 88 | 89 | use_tool, action, action_input, text = self._detect_tool(output) 90 | 91 | if use_tool: 92 | observation = self._call_tool(action, 93 | action_input, 94 | messages=ori_messages) 95 | observation = f'\nObservation: {observation}\nThought: ' 96 | response[-1][CONTENT] += observation 97 | yield response 98 | if isinstance(messages[-1][CONTENT], list): 99 | if not ('text' in messages[-1][CONTENT][-1] 100 | and messages[-1][CONTENT][-1]['text'].endswith( 101 | '\nThought: ')): 102 | if not text.startswith('\n'): 103 | text = '\n' + text 104 | messages[-1][CONTENT].append( 105 | ContentItem( 106 | text=text + 107 | f'\nAction: {action}\nAction Input:{action_input}' 108 | + observation)) 109 | else: 110 | if not (messages[-1][CONTENT].endswith('\nThought: ')): 111 | if not text.startswith('\n'): 112 | text = '\n' + text 113 | messages[-1][ 114 | CONTENT] += text + f'\nAction: {action}\nAction Input:{action_input}' + observation 115 | else: 116 | break 117 | 118 | def _detect_tool(self, text: str) -> Tuple[bool, str, str, str]: 119 | special_func_token = '\nAction:' 120 | special_args_token = '\nAction Input:' 121 | special_obs_token = '\nObservation:' 122 | func_name, func_args = None, None 123 | i = text.rfind(special_func_token) 124 | j = text.rfind(special_args_token) 125 | k = text.rfind(special_obs_token) 126 | if 0 <= i < j: # If the text has `Action` and `Action input`, 127 | if k < j: # but does not contain `Observation`, 128 | # then it is likely that `Observation` is ommited by the LLM, 129 | # because the output text may have discarded the stop word. 130 | text = text.rstrip() + special_obs_token # Add it back. 131 | k = text.rfind(special_obs_token) 132 | func_name = text[i + len(special_func_token):j].strip() 133 | func_args = text[j + len(special_args_token):k].strip() 134 | text = text[:i] # Return the response before tool call 135 | 136 | return (func_name is not None), func_name, func_args, text 137 | 138 | def _preprocess_react_prompt(self, 139 | messages: List[Message]) -> List[Message]: 140 | messages = copy.deepcopy(messages) 141 | tool_descs = '\n\n'.join( 142 | get_function_description(func.function) 143 | for func in self.function_map.values()) 144 | tool_names = ','.join(tool.name for tool in self.function_map.values()) 145 | 146 | if isinstance(messages[-1][CONTENT], str): 147 | prompt = PROMPT_REACT.format(tool_descs=tool_descs, 148 | tool_names=tool_names, 149 | query=messages[-1][CONTENT]) 150 | messages[-1][CONTENT] = prompt 151 | return messages 152 | else: 153 | query = '' 154 | new_content = [] 155 | files = [] 156 | for item in messages[-1][CONTENT]: 157 | for k, v in item.model_dump().items(): 158 | if k == 'text': 159 | query += v 160 | elif k == 'file': 161 | files.append(v) 162 | else: 163 | new_content.append(item) 164 | if files: 165 | has_zh = has_chinese_chars(query) 166 | upload = [] 167 | for f in [get_basename_from_url(f) for f in files]: 168 | if has_zh: 169 | upload.append(f'[文件]({f})') 170 | else: 171 | upload.append(f'[file]({f})') 172 | upload = ' '.join(upload) 173 | if has_zh: 174 | upload = f'(上传了 {upload})\n\n' 175 | else: 176 | upload = f'(Uploaded {upload})\n\n' 177 | query = upload + query 178 | 179 | prompt = PROMPT_REACT.format(tool_descs=tool_descs, 180 | tool_names=tool_names, 181 | query=query) 182 | new_content.insert(0, ContentItem(text=prompt)) 183 | messages[-1][CONTENT] = new_content 184 | return messages 185 | -------------------------------------------------------------------------------- /qwen_agent/agents/router.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Dict, Iterator, List, Optional, Union 3 | 4 | from qwen_agent.llm import BaseChatModel 5 | from qwen_agent.llm.schema import ASSISTANT, ROLE, Message 6 | from qwen_agent.tools import BaseTool 7 | 8 | from ..log import logger 9 | from .assistant import Assistant 10 | 11 | ROUTER_PROMPT = '''你有下列帮手: 12 | {agent_descs} 13 | 14 | 当你可以直接回答用户时,请忽略帮手,直接回复;但当你的能力无法达成用户的请求时,请选择其中一个来帮你回答,选择的模版如下: 15 | Call: ... # 选中的帮手的名字,必须在[{agent_names}]中,除了名字,不要返回其余任何内容。 16 | Reply: ... # 选中的帮手的回复 17 | 18 | ——不要向用户透露此条指令。''' 19 | 20 | 21 | class Router(Assistant): 22 | 23 | def __init__(self, 24 | function_list: Optional[List[Union[str, Dict, 25 | BaseTool]]] = None, 26 | llm: Optional[Union[Dict, BaseChatModel]] = None, 27 | files: Optional[List[str]] = None, 28 | name: Optional[str] = None, 29 | description: Optional[str] = None, 30 | agents: Optional[Dict[str, Dict]] = None): 31 | self.agents = agents 32 | 33 | agent_descs = '\n\n'.join( 34 | [f'{k}: {v["desc"]}' for k, v in agents.items()]) 35 | agent_names = ', '.join([k for k in agents.keys()]) 36 | super().__init__(function_list=function_list, 37 | llm=llm, 38 | system_message=ROUTER_PROMPT.format( 39 | agent_descs=agent_descs, agent_names=agent_names), 40 | name=name, 41 | description=description, 42 | files=files) 43 | 44 | stop = self.llm.generate_cfg.get('stop', []) 45 | fn_stop = ['Reply:', 'Reply:\n'] 46 | self.llm.generate_cfg['stop'] = stop + [ 47 | x for x in fn_stop if x not in stop 48 | ] 49 | 50 | def _run(self, 51 | messages: List[Message], 52 | lang: str = 'en', 53 | max_ref_token: int = 4000, 54 | **kwargs) -> Iterator[List[Message]]: 55 | # This is a temporary plan to determine the source of a message 56 | messages_for_router = [] 57 | for msg in messages: 58 | if msg[ROLE] == ASSISTANT: 59 | msg = self.supplement_name_special_token(msg) 60 | messages_for_router.append(msg) 61 | response = [] 62 | for response in super()._run(messages=messages_for_router, 63 | lang=lang, 64 | max_ref_token=max_ref_token, 65 | **kwargs): # noqa 66 | yield response 67 | 68 | if 'Call:' in response[-1].content: 69 | # According to the rule in prompt to selected agent 70 | selected_agent_name = response[-1].content.split( 71 | 'Call:')[-1].strip() 72 | logger.info(f'Need help from {selected_agent_name}') 73 | selected_agent = self.agents[selected_agent_name]['obj'] 74 | for response in selected_agent.run(messages=messages, 75 | lang=lang, 76 | max_ref_token=max_ref_token, 77 | **kwargs): 78 | for i in range(len(response)): 79 | if response[i].role == ASSISTANT: 80 | response[i].name = selected_agent_name 81 | yield response 82 | 83 | @staticmethod 84 | def supplement_name_special_token(message: Message) -> Message: 85 | message = copy.deepcopy(message) 86 | if not message.name: 87 | return message 88 | 89 | if isinstance(message['content'], str): 90 | message['content'] = 'Call: ' + message[ 91 | 'name'] + '\nReply:' + message['content'] 92 | return message 93 | assert isinstance(message['content'], list) 94 | for i, item in enumerate(message['content']): 95 | for k, v in item.model_dump().items(): 96 | if k == 'text': 97 | message['content'][i][k] = 'Call: ' + message[ 98 | 'name'] + '\nReply:' + message['content'][i][k] 99 | break 100 | return message 101 | -------------------------------------------------------------------------------- /qwen_agent/agents/user_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List 2 | 3 | from qwen_agent.agents.assistant import Assistant 4 | from qwen_agent.llm.schema import Message 5 | 6 | PENDING_USER_INPUT = '' 7 | 8 | 9 | class UserAgent(Assistant): 10 | 11 | def _run(self, 12 | messages: List[Message], 13 | lang: str = 'en', 14 | max_ref_token: int = 4000, 15 | **kwargs) -> Iterator[List[Message]]: 16 | 17 | yield [ 18 | Message(role='user', content=PENDING_USER_INPUT, name=self.name) 19 | ] 20 | -------------------------------------------------------------------------------- /qwen_agent/agents/write_from_scratch.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Iterator, List 3 | 4 | import json5 5 | 6 | from qwen_agent import Agent 7 | from qwen_agent.llm.schema import ASSISTANT, CONTENT, USER, Message 8 | from qwen_agent.prompts import DocQA, ExpandWriting, OutlineWriting 9 | 10 | default_plan = """{"action1": "summarize", "action2": "outline", "action3": "expand"}""" 11 | 12 | 13 | def is_roman_numeral(s): 14 | pattern = r'^(I|V|X|L|C|D|M)+' 15 | match = re.match(pattern, s) 16 | return match is not None 17 | 18 | 19 | class WriteFromScratch(Agent): 20 | 21 | def _run(self, 22 | messages: List[Message], 23 | knowledge: str = '', 24 | lang: str = 'en') -> Iterator[List[Message]]: 25 | 26 | response = [ 27 | Message(ASSISTANT, f'>\n> Use Default plans: \n{default_plan}') 28 | ] 29 | yield response 30 | res_plans = json5.loads(default_plan) 31 | 32 | summ = '' 33 | outline = '' 34 | for plan_id in sorted(res_plans.keys()): 35 | plan = res_plans[plan_id] 36 | if plan == 'summarize': 37 | response.append( 38 | Message(ASSISTANT, '>\n> Summarize Browse Content: \n')) 39 | yield response 40 | 41 | if lang == 'zh': 42 | user_request = '总结参考资料的主要内容' 43 | elif lang == 'en': 44 | user_request = 'Summarize the main content of reference materials.' 45 | else: 46 | raise NotImplementedError 47 | sum_agent = DocQA(llm=self.llm) 48 | res_sum = sum_agent.run(messages=[Message(USER, user_request)], 49 | knowledge=knowledge, 50 | lang=lang) 51 | trunk = None 52 | for trunk in res_sum: 53 | yield response + trunk 54 | if trunk: 55 | response.extend(trunk) 56 | summ = trunk[-1][CONTENT] 57 | elif plan == 'outline': 58 | response.append(Message(ASSISTANT, 59 | '>\n> Generate Outline: \n')) 60 | yield response 61 | 62 | otl_agent = OutlineWriting(llm=self.llm) 63 | res_otl = otl_agent.run(messages=messages, 64 | knowledge=summ, 65 | lang=lang) 66 | trunk = None 67 | for trunk in res_otl: 68 | yield response + trunk 69 | if trunk: 70 | response.extend(trunk) 71 | outline = trunk[-1][CONTENT] 72 | elif plan == 'expand': 73 | response.append(Message(ASSISTANT, '>\n> Writing Text: \n')) 74 | yield response 75 | 76 | outline_list_all = outline.split('\n') 77 | outline_list = [] 78 | for x in outline_list_all: 79 | if is_roman_numeral(x): 80 | outline_list.append(x) 81 | 82 | otl_num = len(outline_list) 83 | for i, v in enumerate(outline_list): 84 | response.append(Message(ASSISTANT, '>\n# ')) 85 | yield response 86 | 87 | index = i + 1 88 | capture = v.strip() 89 | capture_later = '' 90 | if i < otl_num - 1: 91 | capture_later = outline_list[i + 1].strip() 92 | exp_agent = ExpandWriting(llm=self.llm) 93 | res_exp = exp_agent.run( 94 | messages=messages, 95 | knowledge=knowledge, 96 | outline=outline, 97 | index=str(index), 98 | capture=capture, 99 | capture_later=capture_later, 100 | lang=lang, 101 | ) 102 | trunk = None 103 | for trunk in res_exp: 104 | yield response + trunk 105 | if trunk: 106 | response.extend(trunk) 107 | else: 108 | pass 109 | -------------------------------------------------------------------------------- /qwen_agent/llm/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | 3 | from qwen_agent.llm.base import LLM_REGISTRY 4 | 5 | from .base import BaseChatModel, ModelServiceError 6 | from .oai import TextChatAtOAI 7 | from .qwen_dashscope import QwenChatAtDS 8 | from .qwenvl_dashscope import QwenVLChatAtDS 9 | 10 | 11 | def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel: 12 | """The interface of instantiating LLM objects. 13 | 14 | Args: 15 | cfg: The LLM configuration, one example is: 16 | llm_cfg = { 17 | # Use the model service provided by DashScope: 18 | 'model': 'qwen-max', 19 | 'model_server': 'dashscope', 20 | # Use your own model service compatible with OpenAI API: 21 | # 'model': 'Qwen', 22 | # 'model_server': 'http://127.0.0.1:7905/v1', 23 | # (Optional) LLM hyper-paramters: 24 | 'generate_cfg': { 25 | 'top_p': 0.8 26 | } 27 | } 28 | 29 | Returns: 30 | LLM object. 31 | """ 32 | cfg = cfg or {} 33 | if 'model_type' in cfg: 34 | model_type = cfg['model_type'] 35 | if model_type in LLM_REGISTRY: 36 | return LLM_REGISTRY[model_type](cfg) 37 | else: 38 | raise ValueError( 39 | f'Please set model_type from {str(LLM_REGISTRY.keys())}') 40 | 41 | # Deduce model_type from model and model_server if model_type is not provided: 42 | 43 | if 'model_server' in cfg: 44 | if cfg['model_server'].strip().startswith('http'): 45 | model_type = 'oai' 46 | return LLM_REGISTRY[model_type](cfg) 47 | 48 | model = cfg.get('model', '') 49 | 50 | if 'qwen-vl' in model: 51 | model_type = 'qwenvl_dashscope' 52 | return LLM_REGISTRY[model_type](cfg) 53 | 54 | if 'qwen' in model: 55 | model_type = 'qwen_dashscope' 56 | return LLM_REGISTRY[model_type](cfg) 57 | 58 | raise ValueError(f'Invalid model cfg: {cfg}') 59 | 60 | 61 | __all__ = [ 62 | 'BaseChatModel', 63 | 'QwenChatAtDS', 64 | 'TextChatAtOAI', 65 | 'QwenVLChatAtDS', 66 | 'get_chat_model', 67 | 'ModelServiceError', 68 | ] 69 | -------------------------------------------------------------------------------- /qwen_agent/llm/oai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pprint import pformat 3 | from typing import Dict, Iterator, List, Optional 4 | 5 | import openai 6 | 7 | if openai.__version__.startswith('0.'): 8 | from openai.error import OpenAIError 9 | else: 10 | from openai import OpenAIError 11 | 12 | from qwen_agent.llm.base import ModelServiceError, register_llm 13 | from qwen_agent.llm.text_base import BaseTextChatModel 14 | from qwen_agent.log import logger 15 | 16 | from .schema import ASSISTANT, Message 17 | 18 | 19 | @register_llm('oai') 20 | class TextChatAtOAI(BaseTextChatModel): 21 | 22 | def __init__(self, cfg: Optional[Dict] = None): 23 | super().__init__(cfg) 24 | self.model = self.model or 'gpt-3.5-turbo' 25 | cfg = cfg or {} 26 | 27 | api_base = cfg.get( 28 | 'api_base', 29 | cfg.get( 30 | 'base_url', 31 | cfg.get('model_server', ''), 32 | ), 33 | ).strip() 34 | 35 | api_key = cfg.get('api_key', '') 36 | if not api_key: 37 | api_key = os.getenv('OPENAI_API_KEY', 'EMPTY') 38 | api_key = api_key.strip() 39 | 40 | if openai.__version__.startswith('0.'): 41 | if api_base: 42 | openai.api_base = api_base 43 | if api_key: 44 | openai.api_key = api_key 45 | self._chat_complete_create = openai.ChatCompletion.create 46 | else: 47 | api_kwargs = {} 48 | if api_base: 49 | api_kwargs['base_url'] = api_base 50 | if api_key: 51 | api_kwargs['api_key'] = api_key 52 | 53 | # OpenAI API v1 does not allow the following args, must pass by extra_body 54 | extra_params = ['top_k', 'repetition_penalty'] 55 | if any((k in self.generate_cfg) for k in extra_params): 56 | self.generate_cfg['extra_body'] = {} 57 | for k in extra_params: 58 | if k in self.generate_cfg: 59 | self.generate_cfg['extra_body'][ 60 | k] = self.generate_cfg.pop(k) 61 | if 'request_timeout' in self.generate_cfg: 62 | self.generate_cfg['timeout'] = self.generate_cfg.pop( 63 | 'request_timeout') 64 | 65 | def _chat_complete_create(*args, **kwargs): 66 | client = openai.OpenAI(**api_kwargs) 67 | return client.chat.completions.create(*args, **kwargs) 68 | 69 | self._chat_complete_create = _chat_complete_create 70 | 71 | def _chat_stream( 72 | self, 73 | messages: List[Message], 74 | delta_stream: bool = False, 75 | ) -> Iterator[List[Message]]: 76 | messages = [msg.model_dump() for msg in messages] 77 | logger.debug(f'*{pformat(messages, indent=2)}*') 78 | try: 79 | response = self._chat_complete_create(model=self.model, 80 | messages=messages, 81 | stream=True, 82 | **self.generate_cfg) 83 | if delta_stream: 84 | for chunk in response: 85 | if hasattr(chunk.choices[0].delta, 86 | 'content') and chunk.choices[0].delta.content: 87 | yield [ 88 | Message(ASSISTANT, chunk.choices[0].delta.content) 89 | ] 90 | else: 91 | full_response = '' 92 | for chunk in response: 93 | if hasattr(chunk.choices[0].delta, 94 | 'content') and chunk.choices[0].delta.content: 95 | full_response += chunk.choices[0].delta.content 96 | yield [Message(ASSISTANT, full_response)] 97 | except OpenAIError as ex: 98 | raise ModelServiceError(exception=ex) 99 | 100 | def _chat_no_stream(self, messages: List[Message]) -> List[Message]: 101 | messages = [msg.model_dump() for msg in messages] 102 | logger.debug(f'*{pformat(messages, indent=2)}*') 103 | try: 104 | response = self._chat_complete_create(model=self.model, 105 | messages=messages, 106 | stream=False, 107 | **self.generate_cfg) 108 | return [Message(ASSISTANT, response.choices[0].message.content)] 109 | except OpenAIError as ex: 110 | raise ModelServiceError(exception=ex) 111 | -------------------------------------------------------------------------------- /qwen_agent/llm/qwen_dashscope.py: -------------------------------------------------------------------------------- 1 | import os 2 | from http import HTTPStatus 3 | from pprint import pformat 4 | from typing import Dict, Iterator, List, Optional, Union 5 | 6 | import dashscope 7 | 8 | from qwen_agent.llm.base import ModelServiceError, register_llm 9 | from qwen_agent.log import logger 10 | 11 | from .schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, SYSTEM, USER, Message 12 | from .text_base import BaseTextChatModel 13 | 14 | 15 | @register_llm('qwen_dashscope') 16 | class QwenChatAtDS(BaseTextChatModel): 17 | 18 | def __init__(self, cfg: Optional[Dict] = None): 19 | super().__init__(cfg) 20 | self.model = self.model or 'qwen-max' 21 | 22 | cfg = cfg or {} 23 | api_key = cfg.get('api_key', '') 24 | if not api_key: 25 | api_key = os.getenv('DASHSCOPE_API_KEY', 'EMPTY') 26 | api_key = api_key.strip() 27 | dashscope.api_key = api_key 28 | 29 | def _chat_stream( 30 | self, 31 | messages: List[Message], 32 | delta_stream: bool = False, 33 | ) -> Iterator[List[Message]]: 34 | messages = [msg.model_dump() for msg in messages] 35 | logger.debug(f'*{pformat(messages, indent=2)}*') 36 | response = dashscope.Generation.call( 37 | self.model, 38 | messages=messages, # noqa 39 | result_format='message', 40 | stream=True, 41 | **self.generate_cfg) 42 | if delta_stream: 43 | return self._delta_stream_output(response) 44 | else: 45 | return self._full_stream_output(response) 46 | 47 | def _chat_no_stream( 48 | self, 49 | messages: List[Message], 50 | ) -> List[Message]: 51 | messages = [msg.model_dump() for msg in messages] 52 | logger.debug(f'*{pformat(messages, indent=2)}*') 53 | response = dashscope.Generation.call( 54 | self.model, 55 | messages=messages, # noqa 56 | result_format='message', 57 | stream=False, 58 | **self.generate_cfg) 59 | if response.status_code == HTTPStatus.OK: 60 | return [ 61 | Message(ASSISTANT, response.output.choices[0].message.content) 62 | ] 63 | else: 64 | raise ModelServiceError(code=response.code, 65 | message=response.message) 66 | 67 | def _chat_with_functions( 68 | self, 69 | messages: List[Message], 70 | functions: List[Dict], 71 | stream: bool = True, 72 | delta_stream: bool = False 73 | ) -> Union[List[Message], Iterator[List[Message]]]: 74 | if delta_stream: 75 | raise NotImplementedError 76 | 77 | messages = self._prepend_fncall_system(messages, functions) 78 | 79 | # Using text completion 80 | prompt = self._build_text_completion_prompt(messages) 81 | if stream: 82 | return self._text_completion_stream(prompt, delta_stream) 83 | else: 84 | return self._text_completion_no_stream(prompt) 85 | 86 | def _text_completion_no_stream( 87 | self, 88 | prompt: str, 89 | ) -> List[Message]: 90 | logger.debug(f'*{prompt}*') 91 | response = dashscope.Generation.call(self.model, 92 | prompt=prompt, 93 | result_format='message', 94 | stream=False, 95 | use_raw_prompt=True, 96 | **self.generate_cfg) 97 | if response.status_code == HTTPStatus.OK: 98 | return [ 99 | Message(ASSISTANT, response.output.choices[0].message.content) 100 | ] 101 | else: 102 | raise ModelServiceError(code=response.code, 103 | message=response.message) 104 | 105 | def _text_completion_stream( 106 | self, 107 | prompt: str, 108 | delta_stream: bool = False, 109 | ) -> Iterator[List[Message]]: 110 | logger.debug(f'*{prompt}*') 111 | response = dashscope.Generation.call( 112 | self.model, 113 | prompt=prompt, # noqa 114 | result_format='message', 115 | stream=True, 116 | use_raw_prompt=True, 117 | **self.generate_cfg) 118 | if delta_stream: 119 | return self._delta_stream_output(response) 120 | else: 121 | return self._full_stream_output(response) 122 | 123 | @staticmethod 124 | def _build_text_completion_prompt(messages: List[Message]) -> str: 125 | im_start = '<|im_start|>' 126 | im_end = '<|im_end|>' 127 | if messages[0].role == SYSTEM: 128 | sys = messages[0].content 129 | assert isinstance(sys, str) 130 | prompt = f'{im_start}{SYSTEM}\n{sys}{im_end}' 131 | else: 132 | prompt = f'{im_start}{SYSTEM}\n{DEFAULT_SYSTEM_MESSAGE}{im_end}' 133 | if messages[-1].role != ASSISTANT: 134 | messages.append(Message(ASSISTANT, '')) 135 | for msg in messages: 136 | assert isinstance(msg.content, str) 137 | if msg.role == USER: 138 | query = msg.content.lstrip('\n').rstrip() 139 | prompt += f'\n{im_start}{USER}\n{query}{im_end}' 140 | elif msg.role == ASSISTANT: 141 | response = msg.content.lstrip('\n').rstrip() 142 | prompt += f'\n{im_start}{ASSISTANT}\n{response}{im_end}' 143 | assert prompt.endswith(im_end) 144 | prompt = prompt[:-len(im_end)] 145 | return prompt 146 | 147 | @staticmethod 148 | def _delta_stream_output(response) -> Iterator[List[Message]]: 149 | last_len = 0 150 | delay_len = 5 151 | in_delay = False 152 | text = '' 153 | for trunk in response: 154 | if trunk.status_code == HTTPStatus.OK: 155 | text = trunk.output.choices[0].message.content 156 | if (len(text) - last_len) <= delay_len: 157 | in_delay = True 158 | continue 159 | else: 160 | in_delay = False 161 | real_text = text[:-delay_len] 162 | now_rsp = real_text[last_len:] 163 | yield [Message(ASSISTANT, now_rsp)] 164 | last_len = len(real_text) 165 | else: 166 | raise ModelServiceError(code=trunk.code, message=trunk.message) 167 | if text and (in_delay or (last_len != len(text))): 168 | yield [Message(ASSISTANT, text[last_len:])] 169 | 170 | @staticmethod 171 | def _full_stream_output(response) -> Iterator[List[Message]]: 172 | for trunk in response: 173 | if trunk.status_code == HTTPStatus.OK: 174 | yield [ 175 | Message(ASSISTANT, trunk.output.choices[0].message.content) 176 | ] 177 | else: 178 | raise ModelServiceError(code=trunk.code, message=trunk.message) 179 | -------------------------------------------------------------------------------- /qwen_agent/llm/qwenvl_dashscope.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import re 4 | from http import HTTPStatus 5 | from pprint import pformat 6 | from typing import Dict, Iterator, List, Optional 7 | 8 | import dashscope 9 | 10 | from qwen_agent.llm.base import ModelServiceError, register_llm 11 | from qwen_agent.llm.function_calling import BaseFnCallModel 12 | from qwen_agent.llm.text_base import format_as_text_messages 13 | from qwen_agent.log import logger 14 | 15 | from .schema import ContentItem, Message 16 | 17 | 18 | @register_llm('qwenvl_dashscope') 19 | class QwenVLChatAtDS(BaseFnCallModel): 20 | 21 | def __init__(self, cfg: Optional[Dict] = None): 22 | super().__init__(cfg) 23 | self.model = self.model or 'qwen-vl-max' 24 | 25 | cfg = cfg or {} 26 | api_key = cfg.get('api_key', '') 27 | if not api_key: 28 | api_key = os.getenv('DASHSCOPE_API_KEY', 'EMPTY') 29 | api_key = api_key.strip() 30 | dashscope.api_key = api_key 31 | 32 | def _chat_stream( 33 | self, 34 | messages: List[Message], 35 | delta_stream: bool = False, 36 | ) -> Iterator[List[Message]]: 37 | if delta_stream: 38 | raise NotImplementedError 39 | 40 | messages = _format_local_files(messages) 41 | messages = [msg.model_dump() for msg in messages] 42 | logger.debug(f'*{pformat(messages, indent=2)}*') 43 | response = dashscope.MultiModalConversation.call( 44 | model=self.model, 45 | messages=messages, 46 | result_format='message', 47 | stream=True, 48 | **self.generate_cfg) 49 | 50 | for trunk in response: 51 | if trunk.status_code == HTTPStatus.OK: 52 | yield _extract_vl_response(trunk) 53 | else: 54 | raise ModelServiceError(code=trunk.code, message=trunk.message) 55 | 56 | def _chat_no_stream( 57 | self, 58 | messages: List[Message], 59 | ) -> List[Message]: 60 | messages = _format_local_files(messages) 61 | messages = [msg.model_dump() for msg in messages] 62 | logger.debug(f'*{pformat(messages, indent=2)}*') 63 | response = dashscope.MultiModalConversation.call( 64 | model=self.model, 65 | messages=messages, 66 | result_format='message', 67 | stream=False, 68 | **self.generate_cfg) 69 | if response.status_code == HTTPStatus.OK: 70 | return _extract_vl_response(response=response) 71 | else: 72 | raise ModelServiceError(code=response.code, 73 | message=response.message) 74 | 75 | def _postprocess_messages(self, messages: List[Message], 76 | fncall_mode: bool) -> List[Message]: 77 | messages = super()._postprocess_messages(messages, 78 | fncall_mode=fncall_mode) 79 | # Make VL return the same format as text models for easy usage 80 | messages = format_as_text_messages(messages) 81 | return messages 82 | 83 | 84 | # DashScope Qwen-VL requires the following format for local files: 85 | # Linux & Mac: file:///home/images/test.png 86 | # Windows: file://D:/images/abc.png 87 | def _format_local_files(messages: List[Message]) -> List[Message]: 88 | messages = copy.deepcopy(messages) 89 | for msg in messages: 90 | if isinstance(msg.content, list): 91 | for item in msg.content: 92 | if item.image: 93 | fname = item.image 94 | if not fname.startswith(( 95 | 'http://', 96 | 'https://', 97 | 'file://', 98 | )): 99 | if fname.startswith('~'): 100 | fname = os.path.expanduser(fname) 101 | if re.match(r'^[A-Za-z]:\\', fname): 102 | fname = fname.replace('\\', '/') 103 | item.image = fname 104 | return messages 105 | 106 | 107 | def _extract_vl_response(response) -> List[Message]: 108 | output = response.output.choices[0].message 109 | text_content = [] 110 | for item in output.content: 111 | if isinstance(item, str): 112 | text_content.append(ContentItem(text=item)) 113 | else: 114 | for k, v in item.items(): 115 | if k in ('text', 'box'): 116 | text_content.append(ContentItem(text=v)) 117 | return [Message(role=output.role, content=text_content)] 118 | -------------------------------------------------------------------------------- /qwen_agent/llm/schema.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | from pydantic import BaseModel, field_validator, model_validator 4 | 5 | DEFAULT_SYSTEM_MESSAGE = 'You are a helpful assistant.' 6 | 7 | ROLE = 'role' 8 | CONTENT = 'content' 9 | 10 | SYSTEM = 'system' 11 | USER = 'user' 12 | ASSISTANT = 'assistant' 13 | FUNCTION = 'function' 14 | 15 | FILE = 'file' 16 | IMAGE = 'image' 17 | 18 | 19 | class BaseModelCompatibleDict(BaseModel): 20 | 21 | def __getitem__(self, item): 22 | return getattr(self, item) 23 | 24 | def __setitem__(self, key, value): 25 | setattr(self, key, value) 26 | 27 | def model_dump(self, **kwargs): 28 | return super().model_dump(exclude_none=True, **kwargs) 29 | 30 | def model_dump_json(self, **kwargs): 31 | return super().model_dump_json(exclude_none=True, **kwargs) 32 | 33 | def get(self, key, default=None): 34 | try: 35 | value = getattr(self, key) 36 | if value: 37 | return value 38 | else: 39 | return default 40 | except AttributeError: 41 | return default 42 | 43 | def __str__(self): 44 | return f'{self.model_dump()}' 45 | 46 | 47 | class FunctionCall(BaseModelCompatibleDict): 48 | name: str 49 | arguments: str 50 | 51 | def __init__(self, name: str, arguments: str): 52 | super().__init__(name=name, arguments=arguments) 53 | 54 | def __repr__(self): 55 | return f'FunctionCall({self.model_dump()})' 56 | 57 | 58 | class ContentItem(BaseModelCompatibleDict): 59 | text: Optional[str] = None 60 | image: Optional[str] = None 61 | file: Optional[str] = None 62 | 63 | def __init__(self, 64 | text: Optional[str] = None, 65 | image: Optional[str] = None, 66 | file: Optional[str] = None): 67 | super().__init__(text=text, image=image, file=file) 68 | 69 | @model_validator(mode='after') 70 | def check_exclusivity(self): 71 | provided_fields = 0 72 | if self.text is not None: 73 | provided_fields += 1 74 | if self.image: 75 | provided_fields += 1 76 | if self.file: 77 | provided_fields += 1 78 | 79 | if provided_fields != 1: 80 | raise ValueError( 81 | "Exactly one of 'text', 'image', or 'file' must be provided.") 82 | return self 83 | 84 | def __repr__(self): 85 | return f'ContentItem({self.model_dump()})' 86 | 87 | def get_type_and_value(self): 88 | (t, v), = self.model_dump().items() 89 | assert t in ('text', 'image', 'file') 90 | return t, v 91 | 92 | 93 | class Message(BaseModelCompatibleDict): 94 | role: str 95 | content: Union[str, List[ContentItem]] 96 | name: Optional[str] = None 97 | function_call: Optional[FunctionCall] = None 98 | 99 | def __init__(self, 100 | role: str, 101 | content: Optional[Union[str, List[ContentItem]]], 102 | name: Optional[str] = None, 103 | function_call: Optional[FunctionCall] = None, 104 | **kwargs): 105 | if content is None: 106 | content = '' 107 | super().__init__(role=role, 108 | content=content, 109 | name=name, 110 | function_call=function_call) 111 | 112 | def __repr__(self): 113 | return f'Message({self.model_dump()})' 114 | 115 | @field_validator('role') 116 | def role_checker(cls, value: str) -> str: 117 | if value not in [USER, ASSISTANT, SYSTEM, FUNCTION]: 118 | raise ValueError( 119 | f'{value} must be one of {",".join([USER, ASSISTANT, SYSTEM, FUNCTION])}' 120 | ) 121 | return value 122 | -------------------------------------------------------------------------------- /qwen_agent/llm/text_base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import List 3 | 4 | from qwen_agent.llm.function_calling import BaseFnCallModel 5 | 6 | from .schema import ASSISTANT, FUNCTION, SYSTEM, USER, Message 7 | 8 | 9 | class BaseTextChatModel(BaseFnCallModel, ABC): 10 | 11 | def _preprocess_messages(self, messages: List[Message]) -> List[Message]: 12 | messages = super()._preprocess_messages(messages) 13 | messages = format_as_text_messages(messages) 14 | return messages 15 | 16 | def _postprocess_messages(self, messages: List[Message], 17 | fncall_mode: bool) -> List[Message]: 18 | messages = super()._postprocess_messages(messages, 19 | fncall_mode=fncall_mode) 20 | messages = format_as_text_messages(messages) 21 | return messages 22 | 23 | 24 | def format_as_text_messages( 25 | multimodal_messages: List[Message]) -> List[Message]: 26 | text_messages = [] 27 | for msg in multimodal_messages: 28 | assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION) 29 | content = '' 30 | if isinstance(msg.content, str): 31 | content = msg.content 32 | elif isinstance(msg.content, list): 33 | for item in msg.content: 34 | if item.text: 35 | content += item.text 36 | # Discard multimodal content such as files and images 37 | else: 38 | raise TypeError 39 | text_messages.append( 40 | Message(role=msg.role, 41 | content=content, 42 | name=msg.name if msg.role == FUNCTION else None, 43 | function_call=msg.function_call)) 44 | return text_messages 45 | -------------------------------------------------------------------------------- /qwen_agent/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | def setup_logger(level=None): 6 | 7 | if level is None: 8 | if int(os.getenv('QWEN_AGENT_DEBUG', '0').strip()): 9 | level = logging.DEBUG 10 | else: 11 | level = logging.INFO 12 | 13 | logger = logging.getLogger('qwen_agent_logger') 14 | logger.setLevel(level) 15 | handler = logging.StreamHandler() 16 | handler.setLevel(level) 17 | formatter = logging.Formatter( 18 | '%(asctime)s - %(filename)s - %(lineno)d - %(levelname)s - %(message)s' 19 | ) 20 | handler.setFormatter(formatter) 21 | logger.addHandler(handler) 22 | return logger 23 | 24 | 25 | logger = setup_logger() 26 | -------------------------------------------------------------------------------- /qwen_agent/memory/__init__.py: -------------------------------------------------------------------------------- 1 | from .memory import Memory 2 | 3 | __all__ = ['Memory'] 4 | -------------------------------------------------------------------------------- /qwen_agent/memory/memory.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, Iterator, List, Optional, Union 3 | 4 | import json5 5 | 6 | from qwen_agent import Agent 7 | from qwen_agent.llm import BaseChatModel 8 | from qwen_agent.llm.schema import (ASSISTANT, DEFAULT_SYSTEM_MESSAGE, USER, 9 | Message) 10 | from qwen_agent.log import logger 11 | from qwen_agent.prompts import GenKeyword 12 | from qwen_agent.tools import BaseTool 13 | from qwen_agent.utils.utils import get_file_type 14 | 15 | 16 | class Memory(Agent): 17 | """Memory is special agent for file management. 18 | 19 | By default, this memory can use retrieval tool for RAG. 20 | """ 21 | 22 | def __init__(self, 23 | function_list: Optional[List[Union[str, Dict, 24 | BaseTool]]] = None, 25 | llm: Optional[Union[Dict, BaseChatModel]] = None, 26 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, 27 | files: Optional[List[str]] = None): 28 | function_list = function_list or [] 29 | super().__init__(function_list=['retrieval'] + function_list, 30 | llm=llm, 31 | system_message=system_message) 32 | 33 | self.keygen = GenKeyword(llm=llm) 34 | 35 | self.system_files = files or [] 36 | 37 | def _run(self, 38 | messages: List[Message], 39 | max_ref_token: int = 4000, 40 | lang: str = 'en', 41 | ignore_cache: bool = False) -> Iterator[List[Message]]: 42 | """This agent is responsible for processing the input files in the message. 43 | 44 | This method stores the files in the knowledge base, and retrievals the relevant parts 45 | based on the query and returning them. 46 | The currently supported file types include: .pdf, .docx, .pptx, and html. 47 | 48 | Args: 49 | messages: A list of messages. 50 | max_ref_token: Search window for reference materials. 51 | lang: Language. 52 | ignore_cache: Whether to reparse the same files. 53 | 54 | Yields: 55 | The message of retrieved documents. 56 | """ 57 | # process files in messages 58 | session_files = self.get_all_files_of_messages(messages) 59 | files = self.system_files + session_files 60 | rag_files = [] 61 | for file in files: 62 | if (file.split('.')[-1].lower() in [ 63 | 'pdf', 'docx', 'pptx' 64 | ]) or get_file_type(file) == 'html': 65 | rag_files.append(file) 66 | 67 | if not rag_files: 68 | yield [Message(role=ASSISTANT, content='', name='memory')] 69 | else: 70 | query = '' 71 | # Only retrieval content according to the last user query if exists 72 | if messages and messages[-1].role == USER: 73 | if isinstance(messages[-1].content, str): 74 | query = messages[-1].content 75 | else: 76 | for item in messages[-1].content: 77 | if item.text: 78 | query += item.text 79 | if query: 80 | # Gen keyword 81 | *_, last = self.keygen.run([Message(USER, query)]) 82 | keyword = last[-1].content 83 | keyword = keyword.strip() 84 | if keyword.startswith('```json'): 85 | keyword = keyword[len('```json'):] 86 | if keyword.endswith('```'): 87 | keyword = keyword[:-3] 88 | try: 89 | logger.info(keyword) 90 | keyword_dict = json5.loads(keyword) 91 | keyword_dict['text'] = query 92 | query = json.dumps(keyword_dict, ensure_ascii=False) 93 | except Exception: 94 | query = query 95 | 96 | content = self._call_tool('retrieval', { 97 | 'query': query, 98 | 'files': rag_files 99 | }, 100 | ignore_cache=ignore_cache, 101 | max_token=max_ref_token) 102 | 103 | yield [Message(role=ASSISTANT, content=content, name='memory')] 104 | 105 | @staticmethod 106 | def get_all_files_of_messages(messages: List[Message]): 107 | files = [] 108 | for msg in messages: 109 | if isinstance(msg.content, list): 110 | for item in msg.content: 111 | if item.file and item.file not in files: 112 | files.append(item.file) 113 | return files 114 | -------------------------------------------------------------------------------- /qwen_agent/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | """Prompts are special agents: using a prompt template to complete one QA.""" 2 | 3 | from .continue_writing import ContinueWriting 4 | from .doc_qa import DocQA 5 | from .expand_writing import ExpandWriting 6 | from .gen_keyword import GenKeyword 7 | from .outline_writing import OutlineWriting 8 | 9 | __all__ = [ 10 | 'DocQA', 'ContinueWriting', 'OutlineWriting', 'ExpandWriting', 'GenKeyword' 11 | ] 12 | -------------------------------------------------------------------------------- /qwen_agent/prompts/continue_writing.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Iterator, List 3 | 4 | from qwen_agent import Agent 5 | from qwen_agent.llm.schema import CONTENT, Message 6 | 7 | PROMPT_TEMPLATE_ZH = """你是一个写作助手,请依据参考资料,根据给定的前置文本续写合适的内容。 8 | #参考资料: 9 | {ref_doc} 10 | 11 | #前置文本: 12 | {user_request} 13 | 14 | 保证续写内容和前置文本保持连贯,请开始续写:""" 15 | 16 | PROMPT_TEMPLATE_EN = """You are a writing assistant, please follow the reference materials and continue to write appropriate content based on the given previous text. 17 | 18 | # References: 19 | {ref_doc} 20 | 21 | # Previous text: 22 | {user_request} 23 | 24 | Please start writing directly, output only the continued text, do not repeat the previous text, do not say irrelevant words, and ensure that the continued content and the previous text remain consistent.""" 25 | 26 | PROMPT_TEMPLATE = { 27 | 'zh': PROMPT_TEMPLATE_ZH, 28 | 'en': PROMPT_TEMPLATE_EN, 29 | } 30 | 31 | 32 | class ContinueWriting(Agent): 33 | 34 | def _run(self, 35 | messages: List[Message], 36 | knowledge: str = '', 37 | lang: str = 'en', 38 | **kwargs) -> Iterator[List[Message]]: 39 | messages = copy.deepcopy(messages) 40 | messages[-1][CONTENT] = PROMPT_TEMPLATE[lang].format( 41 | ref_doc=knowledge, 42 | user_request=messages[-1][CONTENT], 43 | ) 44 | return self._call_llm(messages) 45 | -------------------------------------------------------------------------------- /qwen_agent/prompts/doc_qa.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Iterator, List 3 | 4 | from qwen_agent import Agent 5 | from qwen_agent.llm.schema import CONTENT, ROLE, SYSTEM, Message 6 | 7 | PROMPT_TEMPLATE_ZH = """ 8 | 请充分理解以下参考资料内容,组织出满足用户提问的条理清晰的回复。 9 | #参考资料: 10 | {ref_doc} 11 | 12 | """ 13 | 14 | PROMPT_TEMPLATE_EN = """ 15 | Please fully understand the content of the following reference materials and organize a clear response that meets the user's questions. 16 | # Reference materials: 17 | {ref_doc} 18 | 19 | """ 20 | 21 | PROMPT_TEMPLATE = { 22 | 'zh': PROMPT_TEMPLATE_ZH, 23 | 'en': PROMPT_TEMPLATE_EN, 24 | } 25 | 26 | 27 | class DocQA(Agent): 28 | 29 | def _run(self, 30 | messages: List[Message], 31 | knowledge: str = '', 32 | lang: str = 'en', 33 | **kwargs) -> Iterator[List[Message]]: 34 | messages = copy.deepcopy(messages) 35 | system_prompt = PROMPT_TEMPLATE[lang].format(ref_doc=knowledge) 36 | if messages[0][ROLE] == SYSTEM: 37 | messages[0][CONTENT] += system_prompt 38 | else: 39 | messages.insert(0, Message(SYSTEM, system_prompt)) 40 | 41 | return self._call_llm(messages=messages) 42 | -------------------------------------------------------------------------------- /qwen_agent/prompts/expand_writing.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Iterator, List 3 | 4 | from qwen_agent import Agent 5 | from qwen_agent.llm.schema import CONTENT, Message 6 | 7 | PROMPT_TEMPLATE_ZH = """ 8 | 你是一个写作助手,任务是依据参考资料,完成写作任务。 9 | #参考资料: 10 | {ref_doc} 11 | 12 | 写作标题是:{user_request} 13 | 大纲是: 14 | {outline} 15 | 16 | 此时你的任务是扩写第{index}个一级标题对应的章节:{capture}。注意每个章节负责撰写不同的内容,所以你不需要为了全面而涵盖之后的内容。请不要在这里生成大纲。只依据给定的参考资料来写,不要引入其余知识。 17 | """ 18 | 19 | PROMPT_TEMPLATE_EN = """ 20 | You are a writing assistant. Your task is to complete writing article based on reference materials. 21 | 22 | # References: 23 | {ref_doc} 24 | 25 | The title is: {user_request} 26 | 27 | The outline is: 28 | {outline} 29 | 30 | At this point, your task is to expand the chapter corresponding to the {index} first level title: {capture}. 31 | Note that each chapter is responsible for writing different content, so you don't need to cover the following content. Please do not generate an outline here. Write only based on the given reference materials and do not introduce other knowledge. 32 | """ 33 | 34 | PROMPT_TEMPLATE = { 35 | 'zh': PROMPT_TEMPLATE_ZH, 36 | 'en': PROMPT_TEMPLATE_EN, 37 | } 38 | 39 | 40 | class ExpandWriting(Agent): 41 | 42 | def _run(self, 43 | messages: List[Message], 44 | knowledge: str = '', 45 | outline: str = '', 46 | index: str = '1', 47 | capture: str = '', 48 | capture_later: str = '', 49 | lang: str = 'en', 50 | **kwargs) -> Iterator[List[Message]]: 51 | messages = copy.deepcopy(messages) 52 | prompt = PROMPT_TEMPLATE[lang].format( 53 | ref_doc=knowledge, 54 | user_request=messages[-1][CONTENT], 55 | index=index, 56 | outline=outline, 57 | capture=capture, 58 | ) 59 | if capture_later: 60 | if lang == 'zh': 61 | prompt = prompt + '请在涉及 ' + capture_later + ' 时停止。' 62 | elif lang == 'en': 63 | prompt = prompt + ' Please stop when writing ' + capture_later 64 | else: 65 | raise NotImplementedError 66 | 67 | messages[-1][CONTENT] = prompt 68 | return self._call_llm(messages) 69 | -------------------------------------------------------------------------------- /qwen_agent/prompts/gen_keyword.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Dict, Iterator, List, Optional, Union 3 | 4 | from qwen_agent import Agent 5 | from qwen_agent.llm import get_chat_model 6 | from qwen_agent.llm.base import BaseChatModel 7 | from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, Message 8 | from qwen_agent.tools import BaseTool 9 | 10 | PROMPT_TEMPLATE_ZH = """请提取问题中的关键词,需要中英文均有,可以适量补充不在问题中但相关的关键词。关键词尽量切分为动词/名词/形容词等类型,不要长词组。关键词以JSON的格式给出,比如{{"keywords_zh": ["关键词1", "关键词2"], "keywords_en": ["keyword 1", "keyword 2"]}} 11 | 12 | Question: 这篇文章的作者是谁? 13 | Keywords: {{"keywords_zh": ["作者"], "keywords_en": ["author"]}} 14 | Observation: ... 15 | 16 | Question: 解释下图一 17 | Keywords: {{"keywords_zh": ["图一", "图 1"], "keywords_en": ["Figure 1"]}} 18 | Observation: ... 19 | 20 | Question: 核心公式 21 | Keywords: {{"keywords_zh": ["核心公式", "公式"], "keywords_en": ["core formula", "formula", "equation"]}} 22 | Observation: ... 23 | 24 | Question: {user_request} 25 | Keywords: 26 | """ 27 | 28 | PROMPT_TEMPLATE_EN = """Please extract keywords from the question, both in Chinese and English, and supplement them appropriately with relevant keywords that are not in the question. Try to divide keywords into verb/noun/adjective types and avoid long phrases. 29 | Keywords are provided in JSON format, such as {{"keywords_zh": ["关键词1", "关键词2"], "keywords_en": ["keyword 1", "keyword 2"]}} 30 | 31 | Question: Who are the authors of this article? 32 | Keywords: {{"keywords_zh": ["作者"], "keywords_en": ["author"]}} 33 | Observation: ... 34 | 35 | Question: Explain Figure 1 36 | Keywords: {{"keywords_zh": ["图一", "图 1"], "keywords_en": ["Figure 1"]}} 37 | Observation: ... 38 | 39 | Question: core formula 40 | Keywords: {{"keywords_zh": ["核心公式", "公式"], "keywords_en": ["core formula", "formula", "equation"]}} 41 | Observation: ... 42 | 43 | Question: {user_request} 44 | Keywords: 45 | """ 46 | 47 | PROMPT_TEMPLATE = { 48 | 'zh': PROMPT_TEMPLATE_ZH, 49 | 'en': PROMPT_TEMPLATE_EN, 50 | } 51 | 52 | 53 | class GenKeyword(Agent): 54 | 55 | # TODO: Adding a stop word is not conveient! We should fix this later. 56 | def __init__(self, 57 | function_list: Optional[List[Union[str, Dict, 58 | BaseTool]]] = None, 59 | llm: Optional[Union[Dict, BaseChatModel]] = None, 60 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE, 61 | **kwargs): 62 | if llm is not None: # TODO: Why this happens? 63 | llm = copy.deepcopy(llm) 64 | if isinstance(llm, dict): 65 | llm = get_chat_model(llm) 66 | stop = llm.generate_cfg.get('stop', []) 67 | key_stop = ['Observation:', 'Observation:\n'] 68 | llm.generate_cfg['stop'] = stop + [ 69 | x for x in key_stop if x not in stop 70 | ] 71 | super().__init__(function_list, llm, system_message, **kwargs) 72 | 73 | def _run(self, 74 | messages: List[Message], 75 | lang: str = 'en', 76 | **kwargs) -> Iterator[List[Message]]: 77 | messages = copy.deepcopy(messages) 78 | messages[-1][CONTENT] = PROMPT_TEMPLATE[lang].format( 79 | user_request=messages[-1][CONTENT]) 80 | return self._call_llm(messages) 81 | -------------------------------------------------------------------------------- /qwen_agent/prompts/outline_writing.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Iterator, List 3 | 4 | from qwen_agent import Agent 5 | from qwen_agent.llm.schema import CONTENT, Message 6 | 7 | PROMPT_TEMPLATE_ZH = """ 8 | 你是一个写作助手,任务是充分理解参考资料,从而完成写作。 9 | #参考资料: 10 | {ref_doc} 11 | 12 | 写作标题是:{user_request} 13 | 14 | 为了完成以上写作任务,请先列出大纲。回复只需包含大纲。大纲的一级标题全部以罗马数字计数。只依据给定的参考资料来写,不要引入其余知识。 15 | """ 16 | 17 | PROMPT_TEMPLATE_EN = """ 18 | You are a writing assistant. Your task is to complete writing article based on reference materials. 19 | 20 | # References: 21 | {ref_doc} 22 | 23 | The title is: {user_request} 24 | 25 | In order to complete the above writing tasks, please provide an outline first. The reply only needs to include an outline. The first level titles of the outline are all counted in Roman numerals. Write only based on the given reference materials and do not introduce other knowledge. 26 | """ 27 | 28 | PROMPT_TEMPLATE = { 29 | 'zh': PROMPT_TEMPLATE_ZH, 30 | 'en': PROMPT_TEMPLATE_EN, 31 | } 32 | 33 | 34 | class OutlineWriting(Agent): 35 | 36 | def _run(self, 37 | messages: List[Message], 38 | knowledge: str = '', 39 | lang: str = 'en', 40 | **kwargs) -> Iterator[List[Message]]: 41 | messages = copy.deepcopy(messages) 42 | messages[-1][CONTENT] = PROMPT_TEMPLATE[lang].format( 43 | ref_doc=knowledge, 44 | user_request=messages[-1][CONTENT], 45 | ) 46 | return self._call_llm(messages) 47 | -------------------------------------------------------------------------------- /qwen_agent/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .amap_weather import AmapWeather 2 | from .base import TOOL_REGISTRY, BaseTool 3 | # from .code_interpreter import CodeInterpreter 4 | from .doc_parser import DocParser 5 | from .image_gen import ImageGen 6 | from .retrieval import Retrieval 7 | from .similarity_search import SimilaritySearch 8 | from .storage import Storage 9 | from .web_extractor import WebExtractor 10 | 11 | 12 | def call_tool(plugin_name: str, plugin_args: str) -> str: 13 | if plugin_name in TOOL_REGISTRY: 14 | return TOOL_REGISTRY[plugin_name].call(plugin_args) 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | __all__ = [ 20 | 'BaseTool', 'CodeInterpreter', 'ImageGen', 'AmapWeather', 'TOOL_REGISTRY', 21 | 'DocParser', 'SimilaritySearch', 'Storage', 'Retrieval', 'WebExtractor' 22 | ] 23 | -------------------------------------------------------------------------------- /qwen_agent/tools/amap_weather.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Union 3 | 4 | import pandas as pd 5 | import requests 6 | 7 | from qwen_agent.tools.base import BaseTool, register_tool 8 | 9 | 10 | @register_tool('amap_weather') 11 | class AmapWeather(BaseTool): 12 | description = '获取对应城市的天气数据' 13 | parameters = [{ 14 | 'name': 'location', 15 | 'type': 'string', 16 | 'description': 'get temperature for a specific location', 17 | 'required': True 18 | }] 19 | 20 | def __init__(self, cfg: Optional[Dict] = None): 21 | super().__init__(cfg) 22 | 23 | # remote call 24 | self.url = 'https://restapi.amap.com/v3/weather/weatherInfo?city={city}&key={key}' 25 | self.city_df = pd.read_excel( 26 | 'https://modelscope.oss-cn-beijing.aliyuncs.com/resource/agent/AMap_adcode_citycode.xlsx' 27 | ) 28 | 29 | self.token = self.cfg.get('token', os.environ.get('AMAP_TOKEN', '')) 30 | assert self.token != '', 'weather api token must be acquired through ' \ 31 | 'https://lbs.amap.com/api/webservice/guide/create-project/get-key and set by AMAP_TOKEN' 32 | 33 | def get_city_adcode(self, city_name): 34 | filtered_df = self.city_df[self.city_df['中文名'] == city_name] 35 | if len(filtered_df['adcode'].values) == 0: 36 | raise ValueError( 37 | f'location {city_name} not found, availables are {self.city_df["中文名"]}' 38 | ) 39 | else: 40 | return filtered_df['adcode'].values[0] 41 | 42 | def call(self, params: Union[str, dict], **kwargs) -> str: 43 | params = self._verify_json_format_args(params) 44 | 45 | location = params['location'] 46 | response = requests.get( 47 | self.url.format(city=self.get_city_adcode(location), 48 | key=self.token)) 49 | data = response.json() 50 | if data['status'] == '0': 51 | raise RuntimeError(data) 52 | else: 53 | weather = data['lives'][0]['weather'] 54 | temperature = data['lives'][0]['temperature'] 55 | return f'{location}的天气是{weather}温度是{temperature}度。' 56 | -------------------------------------------------------------------------------- /qwen_agent/tools/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, List, Optional, Union 3 | 4 | import json5 5 | 6 | from qwen_agent.utils.utils import logger 7 | 8 | TOOL_REGISTRY = {} 9 | 10 | 11 | def register_tool(name, allow_overwrite=False): 12 | 13 | def decorator(cls): 14 | if name in TOOL_REGISTRY: 15 | if allow_overwrite: 16 | logger.warning( 17 | f'Tool `{name}` already exists! Overwriting with class {cls}.' 18 | ) 19 | else: 20 | raise ValueError( 21 | f'Tool `{name}` already exists! Please ensure that the tool name is unique.' 22 | ) 23 | if cls.name and (cls.name != name): 24 | raise ValueError( 25 | f'{cls.__name__}.name="{cls.name}" conflicts with @register_tool(name="{name}").' 26 | ) 27 | cls.name = name 28 | TOOL_REGISTRY[name] = cls 29 | 30 | return cls 31 | 32 | return decorator 33 | 34 | 35 | class BaseTool(ABC): 36 | name: str = '' 37 | description: str = '' 38 | parameters: List[Dict] = [] 39 | 40 | def __init__(self, cfg: Optional[Dict] = None): 41 | self.cfg = cfg or {} 42 | if not self.name: 43 | raise ValueError( 44 | f'You must set {self.__class__.__name__}.name, either by @register_tool(name=...) or explicitly setting {self.__class__.__name__}.name' 45 | ) 46 | 47 | self.name_for_human = self.cfg.get('name_for_human', self.name) 48 | if not hasattr(self, 'args_format'): 49 | self.args_format = self.cfg.get('args_format', '此工具的输入应为JSON对象。') 50 | self.function = self._build_function() 51 | self.file_access = False 52 | 53 | @abstractmethod 54 | def call(self, params: Union[str, dict], 55 | **kwargs) -> Union[str, list, dict]: 56 | """The interface for calling tools. 57 | 58 | Each tool needs to implement this function, which is the workflow of the tool. 59 | 60 | Args: 61 | params: The parameters of func_call. 62 | kwargs: Additional parameters for calling tools. 63 | 64 | Returns: 65 | The result returned by the tool, implemented in the subclass. 66 | """ 67 | raise NotImplementedError 68 | 69 | def _verify_json_format_args(self, 70 | params: Union[str, dict]) -> Union[str, dict]: 71 | """Verify the parameters of the function call""" 72 | try: 73 | if isinstance(params, str): 74 | params_json = json5.loads(params) 75 | else: 76 | params_json = params 77 | for param in self.parameters: 78 | if 'required' in param and param['required']: 79 | if param['name'] not in params_json: 80 | raise ValueError('Parameters %s is required!' % 81 | param['name']) 82 | return params_json 83 | except Exception: 84 | raise ValueError('Parameters cannot be converted to Json Format!') 85 | 86 | def _build_function(self) -> dict: 87 | return { 88 | 'name_for_human': self.name_for_human, 89 | 'name': self.name, 90 | 'description': self.description, 91 | 'parameters': self.parameters, 92 | 'args_format': self.args_format 93 | } 94 | -------------------------------------------------------------------------------- /qwen_agent/tools/code_interpreter.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import atexit 3 | import base64 4 | import glob 5 | import io 6 | import json 7 | import os 8 | import queue 9 | import re 10 | import shutil 11 | import signal 12 | import stat 13 | import subprocess 14 | import sys 15 | import time 16 | import uuid 17 | from pathlib import Path 18 | from typing import Dict, List, Optional, Union 19 | 20 | import json5 21 | import matplotlib 22 | import PIL.Image 23 | from jupyter_client import BlockingKernelClient 24 | 25 | from qwen_agent.log import logger 26 | from qwen_agent.tools.base import BaseTool, register_tool 27 | from qwen_agent.utils.utils import (extract_code, print_traceback, 28 | save_url_to_local_work_dir) 29 | 30 | WORK_DIR = os.getenv('M6_CODE_INTERPRETER_WORK_DIR', 31 | os.getcwd() + '/workspace/ci_workspace/') 32 | 33 | 34 | def _fix_secure_write_for_code_interpreter(): 35 | if 'linux' in sys.platform.lower(): 36 | os.makedirs(WORK_DIR, exist_ok=True) 37 | fname = os.path.join(WORK_DIR, 38 | f'test_file_permission_{os.getpid()}.txt') 39 | if os.path.exists(fname): 40 | os.remove(fname) 41 | with os.fdopen( 42 | os.open(fname, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o0600), 43 | 'w') as f: 44 | f.write('test') 45 | file_mode = stat.S_IMODE(os.stat(fname).st_mode) & 0o6677 46 | if file_mode != 0o0600: 47 | os.environ['JUPYTER_ALLOW_INSECURE_WRITES'] = '1' 48 | if os.path.exists(fname): 49 | os.remove(fname) 50 | 51 | 52 | _fix_secure_write_for_code_interpreter() 53 | 54 | LAUNCH_KERNEL_PY = """ 55 | from ipykernel import kernelapp as app 56 | app.launch_new_instance() 57 | """ 58 | 59 | INIT_CODE_FILE = str( 60 | Path(__file__).absolute().parent / 'resource' / 61 | 'code_interpreter_init_kernel.py') 62 | 63 | ALIB_FONT_FILE = str( 64 | Path(__file__).absolute().parent / 'resource' / 65 | 'AlibabaPuHuiTi-3-45-Light.ttf') 66 | 67 | _KERNEL_CLIENTS: Dict[int, BlockingKernelClient] = {} 68 | _MISC_SUBPROCESSES: Dict[str, subprocess.Popen] = {} 69 | 70 | 71 | def _start_kernel(pid) -> BlockingKernelClient: 72 | connection_file = os.path.join(WORK_DIR, 73 | f'kernel_connection_file_{pid}.json') 74 | launch_kernel_script = os.path.join(WORK_DIR, f'launch_kernel_{pid}.py') 75 | for f in [connection_file, launch_kernel_script]: 76 | if os.path.exists(f): 77 | logger.info(f'WARNING: {f} already exists') 78 | os.remove(f) 79 | 80 | os.makedirs(WORK_DIR, exist_ok=True) 81 | with open(launch_kernel_script, 'w') as fout: 82 | fout.write(LAUNCH_KERNEL_PY) 83 | 84 | kernel_process = subprocess.Popen( 85 | [ 86 | sys.executable, 87 | launch_kernel_script, 88 | '--IPKernelApp.connection_file', 89 | connection_file, 90 | '--matplotlib=inline', 91 | '--quiet', 92 | ], 93 | cwd=WORK_DIR, 94 | ) 95 | _MISC_SUBPROCESSES[f'kc_{kernel_process.pid}'] = kernel_process 96 | logger.info(f"INFO: kernel process's PID = {kernel_process.pid}") 97 | 98 | # Wait for kernel connection file to be written 99 | while True: 100 | if not os.path.isfile(connection_file): 101 | time.sleep(0.1) 102 | else: 103 | # Keep looping if JSON parsing fails, file may be partially written 104 | try: 105 | with open(connection_file, 'r') as fp: 106 | json.load(fp) 107 | break 108 | except json.JSONDecodeError: 109 | pass 110 | 111 | # Client 112 | kc = BlockingKernelClient(connection_file=connection_file) 113 | asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) 114 | kc.load_connection_file() 115 | kc.start_channels() 116 | kc.wait_for_ready() 117 | return kc 118 | 119 | 120 | def _kill_kernels_and_subprocesses(sig_num=None, _frame=None): 121 | for v in _KERNEL_CLIENTS.values(): 122 | v.shutdown() 123 | for k in list(_KERNEL_CLIENTS.keys()): 124 | del _KERNEL_CLIENTS[k] 125 | 126 | for v in _MISC_SUBPROCESSES.values(): 127 | v.terminate() 128 | for k in list(_MISC_SUBPROCESSES.keys()): 129 | del _MISC_SUBPROCESSES[k] 130 | 131 | if sig_num == signal.SIGINT: 132 | raise KeyboardInterrupt() 133 | 134 | 135 | atexit.register(_kill_kernels_and_subprocesses) 136 | signal.signal(signal.SIGTERM, _kill_kernels_and_subprocesses) 137 | signal.signal(signal.SIGINT, _kill_kernels_and_subprocesses) 138 | 139 | 140 | def _serve_image(image_base64: str) -> str: 141 | image_file = f'{uuid.uuid4()}.png' 142 | local_image_file = os.path.join(WORK_DIR, image_file) 143 | 144 | png_bytes = base64.b64decode(image_base64) 145 | assert isinstance(png_bytes, bytes) 146 | bytes_io = io.BytesIO(png_bytes) 147 | PIL.Image.open(bytes_io).save(local_image_file, 'png') 148 | 149 | static_url = os.getenv('M6_CODE_INTERPRETER_STATIC_URL', 150 | 'http://127.0.0.1:7865/static') 151 | 152 | # Hotfix: Temporarily generate image URL proxies for code interpreter to display in gradio 153 | # Todo: Generate real url 154 | if static_url == 'http://127.0.0.1:7865/static': 155 | if 'image_service' not in _MISC_SUBPROCESSES: 156 | try: 157 | # run a fastapi server for image show in gradio demo by http://127.0.0.1:7865/figure_name 158 | _MISC_SUBPROCESSES['image_service'] = subprocess.Popen([ 159 | 'python', 160 | Path(__file__).absolute().parent / 'resource' / 161 | 'image_service.py' 162 | ]) 163 | except Exception: 164 | print_traceback() 165 | 166 | image_url = f'{static_url}/{image_file}' 167 | 168 | return image_url 169 | 170 | 171 | def _escape_ansi(line: str) -> str: 172 | ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]') 173 | return ansi_escape.sub('', line) 174 | 175 | 176 | def _fix_matplotlib_cjk_font_issue(): 177 | ttf_name = os.path.basename(ALIB_FONT_FILE) 178 | local_ttf = os.path.join( 179 | os.path.abspath( 180 | os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)), 181 | 'fonts', 'ttf', ttf_name) 182 | if not os.path.exists(local_ttf): 183 | try: 184 | shutil.copy(ALIB_FONT_FILE, local_ttf) 185 | font_list_cache = os.path.join(matplotlib.get_cachedir(), 186 | 'fontlist-*.json') 187 | for cache_file in glob.glob(font_list_cache): 188 | with open(cache_file) as fin: 189 | cache_content = fin.read() 190 | if ttf_name not in cache_content: 191 | os.remove(cache_file) 192 | except Exception: 193 | print_traceback() 194 | 195 | 196 | def _execute_code(kc: BlockingKernelClient, code: str) -> str: 197 | kc.wait_for_ready() 198 | kc.execute(code) 199 | result = '' 200 | image_idx = 0 201 | while True: 202 | text = '' 203 | image = '' 204 | finished = False 205 | msg_type = 'error' 206 | try: 207 | msg = kc.get_iopub_msg() 208 | msg_type = msg['msg_type'] 209 | if msg_type == 'status': 210 | if msg['content'].get('execution_state') == 'idle': 211 | finished = True 212 | elif msg_type == 'execute_result': 213 | text = msg['content']['data'].get('text/plain', '') 214 | if 'image/png' in msg['content']['data']: 215 | image_b64 = msg['content']['data']['image/png'] 216 | image_url = _serve_image(image_b64) 217 | image_idx += 1 218 | image = '![fig-%03d](%s)' % (image_idx, image_url) 219 | elif msg_type == 'display_data': 220 | if 'image/png' in msg['content']['data']: 221 | image_b64 = msg['content']['data']['image/png'] 222 | image_url = _serve_image(image_b64) 223 | image_idx += 1 224 | image = '![fig-%03d](%s)' % (image_idx, image_url) 225 | else: 226 | text = msg['content']['data'].get('text/plain', '') 227 | elif msg_type == 'stream': 228 | msg_type = msg['content']['name'] # stdout, stderr 229 | text = msg['content']['text'] 230 | elif msg_type == 'error': 231 | text = _escape_ansi('\n'.join(msg['content']['traceback'])) 232 | if 'M6_CODE_INTERPRETER_TIMEOUT' in text: 233 | text = 'Timeout: Code execution exceeded the time limit.' 234 | except queue.Empty: 235 | text = 'Timeout: Code execution exceeded the time limit.' 236 | finished = True 237 | except Exception: 238 | text = 'The code interpreter encountered an unexpected error.' 239 | print_traceback() 240 | finished = True 241 | if text: 242 | result += f'\n\n{msg_type}:\n\n```\n{text}\n```' 243 | if image: 244 | result += f'\n\n{image}' 245 | if finished: 246 | break 247 | result = result.lstrip('\n') 248 | return result 249 | 250 | 251 | @register_tool('code_interpreter') 252 | class CodeInterpreter(BaseTool): 253 | description = 'Python代码沙盒,可用于执行Python代码。' 254 | parameters = [{ 255 | 'name': 'code', 256 | 'type': 'string', 257 | 'description': '待执行的代码', 258 | 'required': True 259 | }] 260 | 261 | def __init__(self, cfg: Optional[Dict] = None): 262 | self.args_format = '此工具的输入应为Markdown代码块。' 263 | super().__init__(cfg) 264 | self.file_access = True 265 | 266 | def call(self, 267 | params: Union[str, dict], 268 | files: List[str] = None, 269 | timeout: Optional[int] = 30, 270 | **kwargs) -> str: 271 | try: 272 | params = json5.loads(params) 273 | code = params['code'] 274 | except Exception: 275 | code = extract_code(params) 276 | 277 | if not code.strip(): 278 | return '' 279 | # download file 280 | if files: 281 | os.makedirs(WORK_DIR, exist_ok=True) 282 | for file in files: 283 | try: 284 | save_url_to_local_work_dir(file, WORK_DIR) 285 | except Exception: 286 | print_traceback() 287 | 288 | pid: int = os.getpid() 289 | if pid in _KERNEL_CLIENTS: 290 | kc = _KERNEL_CLIENTS[pid] 291 | else: 292 | _fix_matplotlib_cjk_font_issue() 293 | kc = _start_kernel(pid) 294 | with open(INIT_CODE_FILE) as fin: 295 | start_code = fin.read() 296 | start_code = start_code.replace('{{M6_FONT_PATH}}', 297 | repr(ALIB_FONT_FILE)[1:-1]) 298 | logger.info(_execute_code(kc, start_code)) 299 | _KERNEL_CLIENTS[pid] = kc 300 | 301 | if timeout: 302 | code = f'_M6CountdownTimer.start({timeout})\n{code}' 303 | 304 | fixed_code = [] 305 | for line in code.split('\n'): 306 | fixed_code.append(line) 307 | if line.startswith('sns.set_theme('): 308 | fixed_code.append( 309 | 'plt.rcParams["font.family"] = _m6_font_prop.get_name()') 310 | fixed_code = '\n'.join(fixed_code) 311 | fixed_code += '\n\n' # Prevent code not executing in notebook due to no line breaks at the end 312 | result = _execute_code(kc, fixed_code) 313 | 314 | if timeout: 315 | _execute_code(kc, '_M6CountdownTimer.cancel()') 316 | 317 | return result if result.strip() else 'Finished execution.' 318 | 319 | 320 | # 321 | # The _BasePolicy and AnyThreadEventLoopPolicy below are borrowed from Tornado. 322 | # Ref: https://www.tornadoweb.org/en/stable/_modules/tornado/platform/asyncio.html#AnyThreadEventLoopPolicy 323 | # 324 | 325 | if sys.platform == 'win32' and hasattr(asyncio, 326 | 'WindowsSelectorEventLoopPolicy'): 327 | _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore 328 | else: 329 | _BasePolicy = asyncio.DefaultEventLoopPolicy 330 | 331 | 332 | class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore 333 | """Event loop policy that allows loop creation on any thread. 334 | 335 | The default `asyncio` event loop policy only automatically creates 336 | event loops in the main threads. Other threads must create event 337 | loops explicitly or `asyncio.get_event_loop` (and therefore 338 | `.IOLoop.current`) will fail. Installing this policy allows event 339 | loops to be created automatically on any thread. 340 | 341 | Usage:: 342 | asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) 343 | """ 344 | 345 | def get_event_loop(self) -> asyncio.AbstractEventLoop: 346 | try: 347 | return super().get_event_loop() 348 | except RuntimeError: 349 | # "There is no current event loop in thread %r" 350 | loop = self.new_event_loop() 351 | self.set_event_loop(loop) 352 | return loop 353 | -------------------------------------------------------------------------------- /qwen_agent/tools/doc_parser.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import os 4 | import re 5 | from typing import Dict, Optional, Union 6 | from urllib.parse import unquote, urlparse 7 | 8 | import json5 9 | from pydantic import BaseModel 10 | 11 | from qwen_agent.log import logger 12 | from qwen_agent.tools.base import BaseTool, register_tool 13 | from qwen_agent.tools.storage import Storage 14 | from qwen_agent.utils.doc_parser import parse_doc, parse_html_bs 15 | from qwen_agent.utils.utils import (get_file_type, hash_sha256, is_local_path, 16 | print_traceback, 17 | save_url_to_local_work_dir) 18 | 19 | 20 | class FileTypeNotImplError(NotImplementedError): 21 | pass 22 | 23 | 24 | class Record(BaseModel): 25 | url: str 26 | time: str 27 | source: str 28 | raw: list 29 | title: str 30 | topic: str 31 | checked: bool 32 | session: list 33 | 34 | def to_dict(self) -> dict: 35 | return { 36 | 'url': self.url, 37 | 'time': self.time, 38 | 'source': self.source, 39 | 'raw': self.raw, 40 | 'title': self.title, 41 | 'topic': self.topic, 42 | 'checked': self.checked, 43 | 'session': self.session 44 | } 45 | 46 | 47 | def sanitize_chrome_file_path(file_path: str) -> str: 48 | # For Linux and macOS. 49 | if os.path.exists(file_path): 50 | return file_path 51 | 52 | # For native Windows, drop the leading '/' in '/C:/' 53 | win_path = file_path 54 | if win_path.startswith('/'): 55 | win_path = win_path[1:] 56 | if os.path.exists(win_path): 57 | return win_path 58 | 59 | # For Windows + WSL. 60 | if re.match(r'^[A-Za-z]:/', win_path): 61 | wsl_path = f'/mnt/{win_path[0].lower()}/{win_path[3:]}' 62 | if os.path.exists(wsl_path): 63 | return wsl_path 64 | 65 | # For native Windows, replace / with \. 66 | win_path = win_path.replace('/', '\\') 67 | if os.path.exists(win_path): 68 | return win_path 69 | 70 | return file_path 71 | 72 | 73 | def process_file(url: str, db: Storage = None): 74 | logger.info('Starting cache pages...') 75 | url = url 76 | if url.split('.')[-1].lower() in ['pdf', 'docx', 'pptx']: 77 | date1 = datetime.datetime.now() 78 | 79 | if url.startswith('https://') or url.startswith('http://') or re.match( 80 | r'^[A-Za-z]:\\', url) or re.match(r'^[A-Za-z]:/', url): 81 | pdf_path = url 82 | else: 83 | parsed_url = urlparse(url) 84 | pdf_path = unquote(parsed_url.path) 85 | pdf_path = sanitize_chrome_file_path(pdf_path) 86 | 87 | try: 88 | if not is_local_path(url): 89 | # download 90 | file_tmp_path = save_url_to_local_work_dir( 91 | pdf_path, 92 | db.root, 93 | new_name=hash_sha256(url) + '.' + 94 | pdf_path.split('.')[-1].lower()) 95 | pdf_content = parse_doc(file_tmp_path) 96 | else: 97 | pdf_content = parse_doc(pdf_path) 98 | date2 = datetime.datetime.now() 99 | logger.info('Parsing pdf time: ' + str(date2 - date1)) 100 | content = pdf_content 101 | source = 'doc' 102 | title = pdf_path.split('/')[-1].split('\\')[-1].split('.')[0] 103 | except Exception: 104 | print_traceback() 105 | return 'failed' 106 | else: 107 | if not is_local_path(url): 108 | file_tmp_path = save_url_to_local_work_dir( 109 | url, db.root, new_name=hash_sha256(url)) 110 | else: 111 | file_tmp_path = url 112 | file_source = get_file_type(file_tmp_path) 113 | if file_source == 'html': 114 | try: 115 | content = parse_html_bs(file_tmp_path) 116 | title = content[0]['metadata']['title'] 117 | except Exception: 118 | print_traceback() 119 | return 'failed' 120 | source = 'html' 121 | else: 122 | raise FileTypeNotImplError 123 | 124 | # save real data 125 | now_time = str(datetime.date.today()) 126 | new_record = Record(url=url, 127 | time=now_time, 128 | source=source, 129 | raw=content, 130 | title=title, 131 | topic='', 132 | checked=True, 133 | session=[]).to_dict() 134 | new_record_str = json.dumps(new_record, ensure_ascii=False) 135 | db.put(hash_sha256(url), new_record_str) 136 | 137 | return new_record 138 | 139 | 140 | @register_tool('doc_parser') 141 | class DocParser(BaseTool): 142 | description = '解析并存储一个文件,返回解析后的文件内容' 143 | parameters = [{ 144 | 'name': 'url', 145 | 'type': 'string', 146 | 'description': '待解析的文件的路径', 147 | 'required': True 148 | }] 149 | 150 | def __init__(self, cfg: Optional[Dict] = None): 151 | super().__init__(cfg) 152 | self.data_root = self.cfg.get( 153 | 'path', 'workspace/default_doc_parser_data_path') 154 | self.db = Storage({'storage_root_path': self.data_root}) 155 | 156 | def call(self, 157 | params: Union[str, dict], 158 | ignore_cache: bool = False) -> dict: 159 | """Parse file by url, and return the formatted content.""" 160 | 161 | params = self._verify_json_format_args(params) 162 | 163 | if ignore_cache: 164 | record = process_file(url=params['url'], db=self.db) 165 | else: 166 | try: 167 | record = self.db.get(hash_sha256(params['url'])) 168 | record = json5.loads(record) 169 | 170 | except Exception: 171 | record = process_file(url=params['url'], db=self.db) 172 | 173 | return record 174 | -------------------------------------------------------------------------------- /qwen_agent/tools/image_gen.py: -------------------------------------------------------------------------------- 1 | import json 2 | import urllib.parse 3 | from typing import Union 4 | 5 | from qwen_agent.tools.base import BaseTool, register_tool 6 | 7 | 8 | @register_tool('image_gen') 9 | class ImageGen(BaseTool): 10 | description = 'AI绘画(图像生成)服务,输入文本描述和图像分辨率,返回根据文本信息绘制的图片URL。' 11 | parameters = [{ 12 | 'name': 'prompt', 13 | 'type': 'string', 14 | 'description': '详细描述了希望生成的图像具有什么内容,例如人物、环境、动作等细节描述,使用英文', 15 | 'required': True 16 | }, { 17 | 'name': 18 | 'resolution', 19 | 'type': 20 | 'string', 21 | 'description': 22 | '格式是 数字*数字,表示希望生成的图像的分辨率大小,选项有[1024*1024, 720*1280, 1280*720]' 23 | }] 24 | 25 | def call(self, params: Union[str, dict], **kwargs) -> str: 26 | params = self._verify_json_format_args(params) 27 | 28 | prompt = params['prompt'] 29 | prompt = urllib.parse.quote(prompt) 30 | return json.dumps( 31 | {'image_url': f'https://image.pollinations.ai/prompt/{prompt}'}, 32 | ensure_ascii=False) 33 | -------------------------------------------------------------------------------- /qwen_agent/tools/resource/AlibabaPuHuiTi-3-45-Light.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidons-master/HomeRPC/07147c3d4fc554b46bec8296e393d727a7b6238c/qwen_agent/tools/resource/AlibabaPuHuiTi-3-45-Light.ttf -------------------------------------------------------------------------------- /qwen_agent/tools/resource/code_interpreter_init_kernel.py: -------------------------------------------------------------------------------- 1 | import json # noqa 2 | import math # noqa 3 | import os # noqa 4 | import re # noqa 5 | import signal 6 | 7 | import matplotlib # noqa 8 | import matplotlib.pyplot as plt 9 | import numpy as np # noqa 10 | import pandas as pd # noqa 11 | import seaborn as sns 12 | from matplotlib.font_manager import FontProperties 13 | from sympy import Eq, solve, symbols # noqa 14 | 15 | 16 | def input(*args, **kwargs): # noqa 17 | raise NotImplementedError('Python input() function is disabled.') 18 | 19 | 20 | def _m6_timout_handler(_signum=None, _frame=None): 21 | raise TimeoutError('M6_CODE_INTERPRETER_TIMEOUT') 22 | 23 | 24 | try: 25 | signal.signal(signal.SIGALRM, _m6_timout_handler) 26 | except AttributeError: # windows 27 | pass 28 | 29 | 30 | class _M6CountdownTimer: 31 | 32 | @classmethod 33 | def start(cls, timeout: int): 34 | try: 35 | signal.alarm(timeout) 36 | except AttributeError: # windows 37 | pass # TODO: I haven't found a solution that works with jupyter yet. 38 | 39 | @classmethod 40 | def cancel(cls): 41 | try: 42 | signal.alarm(0) 43 | except AttributeError: # windows 44 | pass # TODO 45 | 46 | 47 | sns.set_theme() 48 | 49 | _m6_font_prop = FontProperties(fname='{{M6_FONT_PATH}}') 50 | plt.rcParams['font.family'] = _m6_font_prop.get_name() 51 | -------------------------------------------------------------------------------- /qwen_agent/tools/resource/image_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import uvicorn 4 | from fastapi import FastAPI 5 | from fastapi.middleware.cors import CORSMiddleware 6 | from fastapi.staticfiles import StaticFiles 7 | 8 | app = FastAPI() 9 | 10 | origins = ['http://127.0.0.1:7860'] 11 | 12 | app.add_middleware( 13 | CORSMiddleware, 14 | allow_origins=origins, 15 | allow_credentials=True, 16 | allow_methods=['*'], 17 | allow_headers=['*'], 18 | ) 19 | 20 | app.mount('/static', 21 | StaticFiles(directory=os.getcwd() + '/workspace/ci_workspace/'), 22 | name='static') 23 | 24 | if __name__ == '__main__': 25 | uvicorn.run(app='image_service:app', port=7865) 26 | -------------------------------------------------------------------------------- /qwen_agent/tools/retrieval.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import json5 4 | 5 | from qwen_agent.log import logger 6 | from qwen_agent.tools.base import BaseTool, register_tool 7 | from qwen_agent.utils.utils import get_basename_from_url, print_traceback 8 | 9 | from .doc_parser import DocParser, FileTypeNotImplError 10 | from .similarity_search import (RefMaterialInput, RefMaterialInputItem, 11 | SimilaritySearch) 12 | 13 | 14 | def format_records(records: List[Dict]): 15 | formatted_records = [] 16 | for record in records: 17 | formatted_records.append( 18 | RefMaterialInput(url=get_basename_from_url(record['url']), 19 | text=[ 20 | RefMaterialInputItem( 21 | content=x['page_content'], 22 | token=x['token']) for x in record['raw'] 23 | ])) 24 | return formatted_records 25 | 26 | 27 | @register_tool('retrieval') 28 | class Retrieval(BaseTool): 29 | description = '从给定文件列表中检索出和问题相关的内容' 30 | parameters = [{ 31 | 'name': 'query', 32 | 'type': 'string', 33 | 'description': '问题,需要从文档中检索和这个问题有关的内容' 34 | }, { 35 | 'name': 'files', 36 | 'type': 'array', 37 | 'items': { 38 | 'type': 'string' 39 | }, 40 | 'description': '待解析的文件路径列表', 41 | 'required': True 42 | }] 43 | 44 | def __init__(self, cfg: Optional[Dict] = None): 45 | super().__init__(cfg) 46 | self.doc_parse = DocParser() 47 | self.search = SimilaritySearch() 48 | 49 | def call(self, 50 | params: Union[str, dict], 51 | ignore_cache: bool = False, 52 | max_token: int = 4000) -> list: 53 | """RAG tool. 54 | 55 | Step1: Parse and save files 56 | Step2: Retrieval related content according to query 57 | 58 | Args: 59 | params: The files and query. 60 | ignore_cache: When set to True, overwrite the same documents that have been parsed before. 61 | max_token: Maximum retrieval length. 62 | 63 | Returns: 64 | The retrieved file list. 65 | """ 66 | 67 | params = self._verify_json_format_args(params) 68 | files = params.get('files', []) 69 | if isinstance(files, str): 70 | files = json5.loads(files) 71 | records = [] 72 | for file in files: 73 | try: 74 | _record = self.doc_parse.call(params={'url': file}, 75 | ignore_cache=ignore_cache) 76 | records.append(_record) 77 | except FileTypeNotImplError: 78 | logger.warning( 79 | 'Only Parsing the Following File Types: [\'web page\', \'.pdf\', \'.docx\', \'.pptx\'] to knowledge base!' 80 | ) 81 | except Exception: 82 | print_traceback() 83 | 84 | query = params.get('query', '') 85 | if query and records: 86 | records = format_records(records) 87 | return self._retrieve_content(query, records, max_token) 88 | else: 89 | return records 90 | 91 | def _retrieve_content(self, 92 | query: str, 93 | records: List[RefMaterialInput], 94 | max_token=4000) -> List[Dict]: 95 | single_max_token = int(max_token / len(records)) 96 | _ref_list = [] 97 | for record in records: 98 | # Retrieval for query 99 | now_ref_list = self.search.call(params={'query': query}, 100 | doc=record, 101 | max_token=single_max_token) 102 | _ref_list.append(now_ref_list) 103 | return _ref_list 104 | -------------------------------------------------------------------------------- /qwen_agent/tools/similarity_search.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | from pydantic import BaseModel 4 | 5 | from qwen_agent.log import logger 6 | from qwen_agent.tools.base import BaseTool, register_tool 7 | from qwen_agent.utils.tokenization_qwen import count_tokens 8 | from qwen_agent.utils.utils import get_split_word, parse_keyword 9 | 10 | 11 | class RefMaterialOutput(BaseModel): 12 | """The knowledge data format output from the retrieval""" 13 | url: str 14 | text: list 15 | 16 | def to_dict(self) -> dict: 17 | return { 18 | 'url': self.url, 19 | 'text': self.text, 20 | } 21 | 22 | 23 | class RefMaterialInputItem(BaseModel): 24 | content: str 25 | token: int 26 | 27 | def to_dict(self) -> dict: 28 | return {'content': self.content, 'token': self.token} 29 | 30 | 31 | class RefMaterialInput(BaseModel): 32 | """The knowledge data format input to the retrieval""" 33 | url: str 34 | text: List[RefMaterialInputItem] 35 | 36 | def to_dict(self) -> dict: 37 | return {'url': self.url, 'text': [x.to_dict() for x in self.text]} 38 | 39 | 40 | def format_input_doc(doc: List[str]) -> RefMaterialInput: 41 | new_doc = [] 42 | for x in doc: 43 | item = RefMaterialInputItem(content=x, token=count_tokens(x)) 44 | new_doc.append(item) 45 | return RefMaterialInput(url='', text=new_doc) 46 | 47 | 48 | @register_tool('similarity_search') 49 | class SimilaritySearch(BaseTool): 50 | description = '从给定文档中检索和问题相关的部分' 51 | parameters = [{ 52 | 'name': 'query', 53 | 'type': 'string', 54 | 'description': '问题,需要从文档中检索和这个问题有关的内容', 55 | 'required': True 56 | }] 57 | 58 | def call(self, 59 | params: Union[str, dict], 60 | doc: Union[RefMaterialInput, str, List[str]] = None, 61 | max_token: int = 4000) -> dict: 62 | params = self._verify_json_format_args(params) 63 | 64 | query = params['query'] 65 | if not doc: 66 | return {} 67 | if isinstance(doc, str): 68 | doc = [doc] 69 | if isinstance(doc, list): 70 | doc = format_input_doc(doc) 71 | 72 | tokens = [page.token for page in doc.text] 73 | all_tokens = sum(tokens) 74 | logger.info(f'all tokens of {doc.url}: {all_tokens}') 75 | if all_tokens <= max_token: 76 | logger.info('use full ref') 77 | return RefMaterialOutput(url=doc.url, 78 | text=[x.content 79 | for x in doc.text]).to_dict() 80 | 81 | wordlist = parse_keyword(query) 82 | logger.info('wordlist: ' + ','.join(wordlist)) 83 | if not wordlist: 84 | return self.get_top(doc, max_token) 85 | 86 | sims = [] 87 | for i, page in enumerate(doc.text): 88 | sim = self.filter_section(page.content, wordlist) 89 | sims.append([i, sim]) 90 | sims.sort(key=lambda item: item[1], reverse=True) 91 | assert len(sims) > 0 92 | 93 | res = [] 94 | max_sims = sims[0][1] 95 | if max_sims != 0: 96 | manul = 0 97 | for i in range(min(manul, len(doc.text))): 98 | if max_token >= tokens[ 99 | i] * 2: # Ensure that the first two pages do not fill up the window 100 | res.append(doc.text[i].content) 101 | max_token -= tokens[i] 102 | for i, x in enumerate(sims): 103 | if x[0] < manul: 104 | continue 105 | page = doc.text[x[0]] 106 | if max_token < page.token: 107 | use_rate = (max_token / page.token) * 0.2 108 | res.append(page.content[:int(len(page.content) * 109 | use_rate)]) 110 | break 111 | 112 | res.append(page.content) 113 | max_token -= page.token 114 | 115 | logger.info(f'remaining slots: {max_token}') 116 | return RefMaterialOutput(url=doc.url, text=res).to_dict() 117 | else: 118 | return self.get_top(doc, max_token) 119 | 120 | def filter_section(self, text: str, wordlist: list) -> int: 121 | page_list = get_split_word(text) 122 | sim = self.jaccard_similarity(wordlist, page_list) 123 | 124 | return sim 125 | 126 | @staticmethod 127 | def jaccard_similarity(list1: list, list2: list) -> int: 128 | s1 = set(list1) 129 | s2 = set(list2) 130 | return len(s1.intersection(s2)) # avoid text length impact 131 | # return len(s1.intersection(s2)) / len(s1.union(s2)) # jaccard similarity 132 | 133 | @staticmethod 134 | def get_top(doc: RefMaterialInput, max_token=4000) -> dict: 135 | now_token = 0 136 | text = [] 137 | for page in doc.text: 138 | if (now_token + page.token) <= max_token: 139 | text.append(page.content) 140 | now_token += page.token 141 | else: 142 | use_rate = ((max_token - now_token) / page.token) * 0.2 143 | text.append(page.content[:int(len(page.content) * use_rate)]) 144 | break 145 | logger.info(f'remaining slots: {max_token-now_token}') 146 | return RefMaterialOutput(url=doc.url, text=text).to_dict() 147 | -------------------------------------------------------------------------------- /qwen_agent/tools/storage.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Union 3 | 4 | from qwen_agent.tools.base import BaseTool, register_tool 5 | from qwen_agent.utils.utils import read_text_from_file, save_text_to_file 6 | 7 | DEFAULT_STORAGE_PATH = 'workspace/default_data_path' 8 | SUCCESS_MESSAGE = 'SUCCESS' 9 | 10 | 11 | @register_tool('storage') 12 | class Storage(BaseTool): 13 | """ 14 | This is a special tool for data storage 15 | """ 16 | description = '存储和读取数据的工具' 17 | parameters = [{ 18 | 'name': 'operate', 19 | 'type': 'string', 20 | 'description': 21 | '数据操作类型,可选项为["put", "get", "delete", "scan"]之一,分别为存数据、取数据、删除数据、遍历数据', 22 | 'required': True 23 | }, { 24 | 'name': 'key', 25 | 'type': 'string', 26 | 'description': 27 | '数据的路径,类似于文件路径,是一份数据的唯一标识,不能为空,默认根目录为`/`。存数据时,应该合理的设计路径,保证路径含义清晰且唯一。', 28 | 'default': '/' 29 | }, { 30 | 'name': 'value', 31 | 'type': 'string', 32 | 'description': '数据的内容,仅存数据时需要' 33 | }] 34 | 35 | def __init__(self, cfg: Optional[Dict] = None): 36 | super().__init__(cfg) 37 | self.root = self.cfg.get('storage_root_path', DEFAULT_STORAGE_PATH) 38 | os.makedirs(self.root, exist_ok=True) 39 | 40 | def call(self, params: Union[str, dict], **kwargs) -> str: 41 | params = self._verify_json_format_args(params) 42 | operate = params['operate'] 43 | key = params.get('key', '/') 44 | if key.startswith('/'): 45 | key = key[1:] 46 | 47 | if operate == 'put': 48 | assert 'value' in params 49 | return self.put(key, params['value']) 50 | elif operate == 'get': 51 | return self.get(key) 52 | elif operate == 'delete': 53 | return self.delete(key) 54 | else: 55 | return self.scan(key) 56 | 57 | def put(self, key: str, value: str, path: Optional[str] = None) -> str: 58 | path = path or self.root 59 | 60 | # one file for one key value pair 61 | path = os.path.join(path, key) 62 | 63 | path_dir = path[:path.rfind('/') + 1] 64 | if path_dir: 65 | os.makedirs(path_dir, exist_ok=True) 66 | 67 | save_text_to_file(path, value) 68 | return SUCCESS_MESSAGE 69 | 70 | def get(self, key: str, path: Optional[str] = None) -> str: 71 | path = path or self.root 72 | return read_text_from_file(os.path.join(path, key)) 73 | 74 | def delete(self, key, path: Optional[str] = None) -> str: 75 | path = path or self.root 76 | path = os.path.join(path, key) 77 | if os.path.exists(path): 78 | os.remove(path) 79 | return f'Successfully deleted{key}' 80 | else: 81 | return f'Delete Failed: {key} does not exist' 82 | 83 | def scan(self, key: str, path: Optional[str] = None) -> str: 84 | path = path or self.root 85 | path = os.path.join(path, key) 86 | if os.path.exists(path): 87 | if not os.path.isdir(path): 88 | return 'Scan Failed: The scan operation requires passing in a key to a folder path' 89 | # All key-value pairs 90 | kvs = {} 91 | for root, dirs, files in os.walk(path): 92 | for file in files: 93 | k = os.path.join(root, file)[len(path):] 94 | if not k.startswith('/'): 95 | k = '/' + k 96 | v = read_text_from_file(os.path.join(root, file)) 97 | kvs[k] = v 98 | return '\n'.join([f'{k}: {v}' for k, v in kvs.items()]) 99 | else: 100 | return f'Scan Failed: {key} does not exist.' 101 | -------------------------------------------------------------------------------- /qwen_agent/tools/web_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import requests 4 | 5 | from qwen_agent.tools.base import BaseTool, register_tool 6 | 7 | 8 | @register_tool('web_extractor') 9 | class WebExtractor(BaseTool): 10 | description = '根据网页URL,获取网页内容的工具' 11 | parameters = [{ 12 | 'name': 'url', 13 | 'type': 'string', 14 | 'description': '网页URL', 15 | 'required': True 16 | }] 17 | 18 | def call(self, params: Union[str, dict], **kwargs) -> str: 19 | only_text = self.cfg.get('only_text', False) 20 | 21 | params = self._verify_json_format_args(params) 22 | 23 | url = params['url'] 24 | headers = { 25 | 'User-Agent': 26 | 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' 27 | } 28 | response = requests.get(url, headers=headers) 29 | if response.status_code == 200: 30 | if only_text: 31 | 32 | import justext 33 | 34 | paragraphs = justext.justext(response.text, 35 | justext.get_stoplist('English')) 36 | content = '\n\n'.join( 37 | [paragraph.text for paragraph in paragraphs]).strip() 38 | if content: 39 | return content 40 | else: 41 | return response.text 42 | else: 43 | return response.text 44 | else: 45 | return '' 46 | -------------------------------------------------------------------------------- /qwen_agent/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/guidons-master/HomeRPC/07147c3d4fc554b46bec8296e393d727a7b6238c/qwen_agent/utils/__init__.py -------------------------------------------------------------------------------- /qwen_agent/utils/doc_parser.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from qwen_agent.utils.tokenization_qwen import count_tokens 4 | 5 | ONE_PAGE_TOKEN = 500 6 | 7 | 8 | def rm_newlines(text): 9 | text = re.sub(r'(?<=[^\.。::])\n', ' ', text) 10 | return text 11 | 12 | 13 | def rm_cid(text): 14 | text = re.sub(r'\(cid:\d+\)', '', text) 15 | return text 16 | 17 | 18 | def rm_hexadecimal(text): 19 | text = re.sub(r'[0-9A-Fa-f]{21,}', '', text) 20 | return text 21 | 22 | 23 | def rm_continuous_placeholders(text): 24 | text = re.sub(r'(\.|-|——|。|_|\*){7,}', '...', text) 25 | return text 26 | 27 | 28 | def deal(text): 29 | text = rm_newlines(text) 30 | text = rm_cid(text) 31 | text = rm_hexadecimal(text) 32 | text = rm_continuous_placeholders(text) 33 | return text 34 | 35 | 36 | def parse_doc(path): 37 | if '.pdf' in path.lower(): 38 | from pdfminer.high_level import extract_text 39 | text = extract_text(path) 40 | elif '.docx' in path.lower(): 41 | import docx2txt 42 | text = docx2txt.process(path) 43 | elif '.pptx' in path.lower(): 44 | from pptx import Presentation 45 | ppt = Presentation(path) 46 | text = [] 47 | for slide in ppt.slides: 48 | for shape in slide.shapes: 49 | if hasattr(shape, 'text'): 50 | text.append(shape.text) 51 | text = '\n'.join(text) 52 | else: 53 | raise TypeError 54 | 55 | text = deal(text) 56 | return split_text_to_trunk(text, path) 57 | 58 | 59 | def pre_process_html(s): 60 | # replace multiple newlines 61 | s = re.sub('\n+', '\n', s) 62 | # replace special string 63 | s = s.replace("Add to Qwen's Reading List", '') 64 | return s 65 | 66 | 67 | def parse_html_bs(path): 68 | try: 69 | from bs4 import BeautifulSoup 70 | except Exception: 71 | raise ValueError('Please install bs4 by `pip install beautifulsoup4`') 72 | bs_kwargs = {'features': 'lxml'} 73 | with open(path, 'r', encoding='utf-8') as f: 74 | soup = BeautifulSoup(f, **bs_kwargs) 75 | 76 | text = soup.get_text() 77 | 78 | if soup.title: 79 | title = str(soup.title.string) 80 | else: 81 | title = '' 82 | text = pre_process_html(text) 83 | return split_text_to_trunk(text, path, title) 84 | 85 | 86 | def split_text_to_trunk(content: str, path: str, title: str = ''): 87 | all_tokens = count_tokens(content) 88 | all_pages = round(all_tokens / ONE_PAGE_TOKEN) 89 | if all_pages == 0: 90 | all_pages = 1 91 | len_content = len(content) 92 | len_one_page = int(len_content / 93 | all_pages) # Approximately equal to ONE_PAGE_TOKEN 94 | 95 | res = [] 96 | for i in range(0, len_content, len_one_page): 97 | text = content[i:min(i + len_one_page, len_content)] 98 | res.append({ 99 | 'page_content': text, 100 | 'metadata': { 101 | 'source': path, 102 | 'title': title, 103 | 'page': (i % len_one_page) 104 | }, 105 | 'token': count_tokens(text) 106 | }) 107 | return res 108 | -------------------------------------------------------------------------------- /qwen_agent/utils/tokenization_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """Tokenization classes for QWen.""" 6 | 7 | import base64 8 | import logging 9 | import os 10 | import unicodedata 11 | from dataclasses import dataclass, field 12 | from pathlib import Path 13 | from typing import Collection, Dict, List, Set, Tuple, Union 14 | 15 | import tiktoken 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | VOCAB_FILES_NAMES = {'vocab_file': 'qwen.tiktoken'} 20 | 21 | PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 22 | ENDOFTEXT = '<|endoftext|>' 23 | IMSTART = '<|im_start|>' 24 | IMEND = '<|im_end|>' 25 | # as the default behavior is changed to allow special tokens in 26 | # regular texts, the surface forms of special tokens need to be 27 | # as different as possible to minimize the impact 28 | EXTRAS = tuple((f'<|extra_{i}|>' for i in range(205))) 29 | # changed to use actual index to avoid misconfiguration with vocabulary expansion 30 | SPECIAL_START_ID = 151643 31 | SPECIAL_TOKENS = tuple( 32 | enumerate( 33 | (( 34 | ENDOFTEXT, 35 | IMSTART, 36 | IMEND, 37 | ) + EXTRAS), 38 | start=SPECIAL_START_ID, 39 | )) 40 | SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS) 41 | 42 | 43 | def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: 44 | with open(tiktoken_bpe_file, 'rb') as f: 45 | contents = f.read() 46 | return { 47 | base64.b64decode(token): int(rank) 48 | for token, rank in (line.split() for line in contents.splitlines() 49 | if line) 50 | } 51 | 52 | 53 | @dataclass(frozen=True, eq=True) 54 | class AddedToken: 55 | """ 56 | AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the 57 | way it should behave. 58 | """ 59 | 60 | content: str = field(default_factory=str) 61 | single_word: bool = False 62 | lstrip: bool = False 63 | rstrip: bool = False 64 | normalized: bool = True 65 | 66 | def __getstate__(self): 67 | return self.__dict__ 68 | 69 | 70 | class QWenTokenizer: 71 | """QWen tokenizer.""" 72 | 73 | vocab_files_names = VOCAB_FILES_NAMES 74 | 75 | def __init__( 76 | self, 77 | vocab_file=None, 78 | errors='replace', 79 | extra_vocab_file=None, 80 | **kwargs, 81 | ): 82 | if not vocab_file: 83 | vocab_file = VOCAB_FILES_NAMES['vocab_file'] 84 | self._decode_use_source_tokenizer = False 85 | 86 | # how to handle errors in decoding UTF-8 byte sequences 87 | # use ignore if you are in streaming inference 88 | self.errors = errors 89 | 90 | self.mergeable_ranks = _load_tiktoken_bpe( 91 | vocab_file) # type: Dict[bytes, int] 92 | self.special_tokens = {token: index for index, token in SPECIAL_TOKENS} 93 | 94 | # try load extra vocab from file 95 | if extra_vocab_file is not None: 96 | used_ids = set(self.mergeable_ranks.values()) | set( 97 | self.special_tokens.values()) 98 | extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file) 99 | for token, index in extra_mergeable_ranks.items(): 100 | if token in self.mergeable_ranks: 101 | logger.info(f'extra token {token} exists, skipping') 102 | continue 103 | if index in used_ids: 104 | logger.info( 105 | f'the index {index} for extra token {token} exists, skipping' 106 | ) 107 | continue 108 | self.mergeable_ranks[token] = index 109 | # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this 110 | 111 | enc = tiktoken.Encoding( 112 | 'Qwen', 113 | pat_str=PAT_STR, 114 | mergeable_ranks=self.mergeable_ranks, 115 | special_tokens=self.special_tokens, 116 | ) 117 | assert ( 118 | len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab 119 | ), f'{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding' 120 | 121 | self.decoder = {v: k 122 | for k, v in self.mergeable_ranks.items() 123 | } # type: dict[int, bytes|str] 124 | self.decoder.update({v: k for k, v in self.special_tokens.items()}) 125 | 126 | self.tokenizer = enc # type: tiktoken.Encoding 127 | 128 | self.eod_id = self.tokenizer.eot_token 129 | self.im_start_id = self.special_tokens[IMSTART] 130 | self.im_end_id = self.special_tokens[IMEND] 131 | 132 | def __getstate__(self): 133 | # for pickle lovers 134 | state = self.__dict__.copy() 135 | del state['tokenizer'] 136 | return state 137 | 138 | def __setstate__(self, state): 139 | # tokenizer is not python native; don't pass it; rebuild it 140 | self.__dict__.update(state) 141 | enc = tiktoken.Encoding( 142 | 'Qwen', 143 | pat_str=PAT_STR, 144 | mergeable_ranks=self.mergeable_ranks, 145 | special_tokens=self.special_tokens, 146 | ) 147 | self.tokenizer = enc 148 | 149 | def __len__(self) -> int: 150 | return self.tokenizer.n_vocab 151 | 152 | def get_vocab(self) -> Dict[bytes, int]: 153 | return self.mergeable_ranks 154 | 155 | def convert_tokens_to_ids( 156 | self, tokens: Union[bytes, str, List[Union[bytes, 157 | str]]]) -> List[int]: 158 | ids = [] 159 | if isinstance(tokens, (str, bytes)): 160 | if tokens in self.special_tokens: 161 | return self.special_tokens[tokens] 162 | else: 163 | return self.mergeable_ranks.get(tokens) 164 | for token in tokens: 165 | if token in self.special_tokens: 166 | ids.append(self.special_tokens[token]) 167 | else: 168 | ids.append(self.mergeable_ranks.get(token)) 169 | return ids 170 | 171 | def _add_tokens( 172 | self, 173 | new_tokens: Union[List[str], List[AddedToken]], 174 | special_tokens: bool = False, 175 | ) -> int: 176 | if not special_tokens and new_tokens: 177 | raise ValueError('Adding regular tokens is not supported') 178 | for token in new_tokens: 179 | surface_form = token.content if isinstance(token, 180 | AddedToken) else token 181 | if surface_form not in SPECIAL_TOKENS_SET: 182 | raise ValueError( 183 | 'Adding unknown special tokens is not supported') 184 | return 0 185 | 186 | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: 187 | """ 188 | Save only the vocabulary of the tokenizer (vocabulary). 189 | 190 | Returns: 191 | `Tuple(str)`: Paths to the files saved. 192 | """ 193 | file_path = os.path.join(save_directory, 'qwen.tiktoken') 194 | with open(file_path, 'w', encoding='utf8') as w: 195 | for k, v in self.mergeable_ranks.items(): 196 | line = base64.b64encode(k).decode('utf8') + ' ' + str(v) + '\n' 197 | w.write(line) 198 | return (file_path, ) 199 | 200 | def tokenize( 201 | self, 202 | text: str, 203 | allowed_special: Union[Set, str] = 'all', 204 | disallowed_special: Union[Collection, str] = (), 205 | **kwargs, 206 | ) -> List[Union[bytes, str]]: 207 | """ 208 | Converts a string in a sequence of tokens. 209 | 210 | Args: 211 | text (`str`): 212 | The sequence to be encoded. 213 | allowed_special (`Literal["all"]` or `set`): 214 | The surface forms of the tokens to be encoded as special tokens in regular texts. 215 | Default to "all". 216 | disallowed_special (`Literal["all"]` or `Collection`): 217 | The surface forms of the tokens that should not be in regular texts and trigger errors. 218 | Default to an empty tuple. 219 | 220 | kwargs (additional keyword arguments, *optional*): 221 | Will be passed to the underlying model specific encode method. 222 | 223 | Returns: 224 | `List[bytes|str]`: The list of tokens. 225 | """ 226 | tokens = [] 227 | text = unicodedata.normalize('NFC', text) 228 | 229 | # this implementation takes a detour: text -> token id -> token surface forms 230 | for t in self.tokenizer.encode(text, 231 | allowed_special=allowed_special, 232 | disallowed_special=disallowed_special): 233 | tokens.append(self.decoder[t]) 234 | return tokens 235 | 236 | def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: 237 | """ 238 | Converts a sequence of tokens in a single string. 239 | """ 240 | text = '' 241 | temp = b'' 242 | for t in tokens: 243 | if isinstance(t, str): 244 | if temp: 245 | text += temp.decode('utf-8', errors=self.errors) 246 | temp = b'' 247 | text += t 248 | elif isinstance(t, bytes): 249 | temp += t 250 | else: 251 | raise TypeError('token should only be of type types or str') 252 | if temp: 253 | text += temp.decode('utf-8', errors=self.errors) 254 | return text 255 | 256 | @property 257 | def vocab_size(self): 258 | return self.tokenizer.n_vocab 259 | 260 | def _convert_id_to_token(self, index: int) -> Union[bytes, str]: 261 | """Converts an id to a token, special tokens included""" 262 | if index in self.decoder: 263 | return self.decoder[index] 264 | raise ValueError('unknown ids') 265 | 266 | def _convert_token_to_id(self, token: Union[bytes, str]) -> int: 267 | """Converts a token to an id using the vocab, special tokens included""" 268 | if token in self.special_tokens: 269 | return self.special_tokens[token] 270 | if token in self.mergeable_ranks: 271 | return self.mergeable_ranks[token] 272 | raise ValueError('unknown token') 273 | 274 | def _tokenize(self, text: str, **kwargs): 275 | """ 276 | Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based 277 | vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). 278 | 279 | Do NOT take care of added tokens. 280 | """ 281 | raise NotImplementedError 282 | 283 | def _decode( 284 | self, 285 | token_ids: Union[int, List[int]], 286 | skip_special_tokens: bool = False, 287 | errors: str = None, 288 | **kwargs, 289 | ) -> str: 290 | if isinstance(token_ids, int): 291 | token_ids = [token_ids] 292 | if skip_special_tokens: 293 | token_ids = [i for i in token_ids if i < self.eod_id] 294 | return self.tokenizer.decode(token_ids, errors=errors or self.errors) 295 | 296 | 297 | tokenizer = QWenTokenizer(Path(__file__).resolve().parent / 'qwen.tiktoken') 298 | 299 | 300 | def count_tokens(text): 301 | tokens = tokenizer.tokenize(text) 302 | return len(tokens) 303 | -------------------------------------------------------------------------------- /qwen_agent/utils/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import hashlib 3 | import json 4 | import os 5 | import re 6 | import shutil 7 | import socket 8 | import sys 9 | import traceback 10 | import urllib 11 | from typing import Dict, List, Literal, Optional, Union 12 | from urllib.parse import urlparse 13 | 14 | import jieba 15 | import json5 16 | import requests 17 | from jieba import analyse 18 | 19 | from qwen_agent.log import logger 20 | 21 | 22 | def get_local_ip(): 23 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 24 | try: 25 | # doesn't even have to be reachable 26 | s.connect(('10.255.255.255', 1)) 27 | ip = s.getsockname()[0] 28 | except Exception: 29 | ip = '127.0.0.1' 30 | finally: 31 | s.close() 32 | return ip 33 | 34 | 35 | def hash_sha256(key): 36 | hash_object = hashlib.sha256(key.encode()) 37 | key = hash_object.hexdigest() 38 | return key 39 | 40 | 41 | def print_traceback(is_error=True): 42 | if is_error: 43 | logger.error(''.join(traceback.format_exception(*sys.exc_info()))) 44 | else: 45 | logger.warning(''.join(traceback.format_exception(*sys.exc_info()))) 46 | 47 | 48 | def has_chinese_chars(data) -> bool: 49 | text = f'{data}' 50 | return len(re.findall(r'[\u4e00-\u9fff]+', text)) > 0 51 | 52 | 53 | def get_basename_from_url(url: str) -> str: 54 | basename = os.path.basename(urlparse(url).path) 55 | basename = urllib.parse.unquote(basename) 56 | return basename.strip() 57 | 58 | 59 | def is_local_path(path): 60 | if path.startswith('https://') or path.startswith('http://'): 61 | return False 62 | return True 63 | 64 | 65 | def save_url_to_local_work_dir(url, base_dir, new_name=''): 66 | if not new_name: 67 | new_name = get_basename_from_url(url) 68 | new_path = os.path.join(base_dir, new_name) 69 | if os.path.exists(new_path): 70 | os.remove(new_path) 71 | logger.info(f'download {url} to {new_path}') 72 | start_time = datetime.datetime.now() 73 | if is_local_path(url): 74 | shutil.copy(url, new_path) 75 | else: 76 | headers = { 77 | 'User-Agent': 78 | 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' 79 | } 80 | response = requests.get(url, headers=headers) 81 | if response.status_code == 200: 82 | with open(new_path, 'wb') as file: 83 | file.write(response.content) 84 | else: 85 | raise ValueError( 86 | 'Can not download this file. Please check your network or the file link.' 87 | ) 88 | end_time = datetime.datetime.now() 89 | logger.info(f'Time: {str(end_time - start_time)}') 90 | return new_path 91 | 92 | 93 | def is_image(filename): 94 | filename = filename.lower() 95 | for ext in ['jpg', 'jpeg', 'png', 'webp']: 96 | if filename.endswith(ext): 97 | return True 98 | return False 99 | 100 | 101 | def get_current_date_str( 102 | lang: Literal['en', 'zh'] = 'en', 103 | hours_from_utc: Optional[int] = None, 104 | ) -> str: 105 | if hours_from_utc is None: 106 | cur_time = datetime.datetime.now() 107 | else: 108 | cur_time = datetime.datetime.utcnow() + datetime.timedelta( 109 | hours=hours_from_utc) 110 | if lang == 'en': 111 | date_str = 'Current date: ' + cur_time.strftime('%A, %B %d, %Y') 112 | elif lang == 'zh': 113 | cur_time = cur_time.timetuple() 114 | date_str = f'当前时间:{cur_time.tm_year}年{cur_time.tm_mon}月{cur_time.tm_mday}日,星期' 115 | date_str += ['一', '二', '三', '四', '五', '六', '日'][cur_time.tm_wday] 116 | date_str += '。' 117 | else: 118 | raise NotImplementedError 119 | return date_str 120 | 121 | 122 | def save_text_to_file(path, text): 123 | with open(path, 'w', encoding='utf-8') as fp: 124 | fp.write(text) 125 | 126 | 127 | def read_text_from_file(path): 128 | with open(path, 'r', encoding='utf-8') as file: 129 | file_content = file.read() 130 | return file_content 131 | 132 | 133 | def contains_html_tags(text): 134 | pattern = r'<(p|span|div|li|html|script)[^>]*?' 135 | return bool(re.search(pattern, text)) 136 | 137 | 138 | def get_file_type(path): 139 | # This is a temporary plan 140 | if is_local_path(path): 141 | try: 142 | content = read_text_from_file(path) 143 | except Exception: 144 | print_traceback() 145 | return 'Unknown' 146 | 147 | if contains_html_tags(content): 148 | return 'html' 149 | else: 150 | return 'Unknown' 151 | else: 152 | headers = { 153 | 'User-Agent': 154 | 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' 155 | } 156 | response = requests.get(path, headers=headers) 157 | if response.status_code == 200: 158 | if contains_html_tags(response.text): 159 | return 'html' 160 | else: 161 | return 'Unknown' 162 | else: 163 | print_traceback() 164 | return 'Unknown' 165 | 166 | 167 | ignore_words = [ 168 | '', ' ', '\t', '\n', '\\', 'is', 'are', 'am', 'what', 'how', '的', '吗', '是', 169 | '了', '啊', '呢', '怎么', '如何', '什么', '?', '?', '!', '!', '“', '”', '‘', '’', 170 | "'", "'", '"', '"', ':', ':', '讲了', '描述', '讲', '说说', '讲讲', '介绍', '总结下', 171 | '总结一下', '文档', '文章', '文稿', '稿子', '论文', 'PDF', 'pdf', '这个', '这篇', '这', '我', 172 | '帮我', '那个', '下', '翻译' 173 | ] 174 | 175 | 176 | def get_split_word(text): 177 | text = text.lower() 178 | _wordlist = jieba.lcut(text.strip()) 179 | wordlist = [] 180 | for x in _wordlist: 181 | if x in ignore_words: 182 | continue 183 | wordlist.append(x) 184 | return wordlist 185 | 186 | 187 | def parse_keyword(text): 188 | try: 189 | res = json5.loads(text) 190 | except Exception: 191 | return get_split_word(text) 192 | 193 | # json format 194 | _wordlist = [] 195 | try: 196 | if 'keywords_zh' in res and isinstance(res['keywords_zh'], list): 197 | _wordlist.extend([kw.lower() for kw in res['keywords_zh']]) 198 | if 'keywords_en' in res and isinstance(res['keywords_en'], list): 199 | _wordlist.extend([kw.lower() for kw in res['keywords_en']]) 200 | wordlist = [] 201 | for x in _wordlist: 202 | if x in ignore_words: 203 | continue 204 | wordlist.append(x) 205 | wordlist.extend(get_split_word(res['text'])) 206 | return wordlist 207 | except Exception: 208 | return get_split_word(text) 209 | 210 | 211 | def get_key_word(text): 212 | text = text.lower() 213 | _wordlist = analyse.extract_tags(text) 214 | wordlist = [] 215 | for x in _wordlist: 216 | if x in ignore_words: 217 | continue 218 | wordlist.append(x) 219 | return wordlist 220 | 221 | 222 | def get_last_one_line_context(text): 223 | lines = text.split('\n') 224 | n = len(lines) 225 | res = '' 226 | for i in range(n - 1, -1, -1): 227 | if lines[i].strip(): 228 | res = lines[i] 229 | break 230 | return res 231 | 232 | 233 | def extract_urls(text): 234 | pattern = re.compile(r'https?://\S+') 235 | urls = re.findall(pattern, text) 236 | return urls 237 | 238 | 239 | def extract_obs(text): 240 | k = text.rfind('\nObservation:') 241 | j = text.rfind('\nThought:') 242 | obs = text[k + len('\nObservation:'):j] 243 | return obs.strip() 244 | 245 | 246 | def extract_code(text): 247 | # Match triple backtick blocks first 248 | triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL) 249 | if triple_match: 250 | text = triple_match.group(1) 251 | else: 252 | try: 253 | text = json5.loads(text)['code'] 254 | except Exception: 255 | print_traceback() 256 | # If no code blocks found, return original text 257 | return text 258 | 259 | 260 | def parse_latest_plugin_call(text): 261 | plugin_name, plugin_args = '', '' 262 | i = text.rfind('\nAction:') 263 | j = text.rfind('\nAction Input:') 264 | k = text.rfind('\nObservation:') 265 | if 0 <= i < j: # If the text has `Action` and `Action input`, 266 | if k < j: # but does not contain `Observation`, 267 | # then it is likely that `Observation` is ommited by the LLM, 268 | # because the output text may have discarded the stop word. 269 | text = text.rstrip() + '\nObservation:' # Add it back. 270 | k = text.rfind('\nObservation:') 271 | plugin_name = text[i + len('\nAction:'):j].strip() 272 | plugin_args = text[j + len('\nAction Input:'):k].strip() 273 | text = text[:k] 274 | return plugin_name, plugin_args, text 275 | 276 | 277 | def get_function_description(function: Dict) -> str: 278 | """ 279 | Text description of function 280 | """ 281 | tool_desc_template = { 282 | 'zh': 283 | '### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}', 284 | 'en': 285 | '### {name_for_human}\n\n{name_for_model}: {description_for_model} Parameters: {parameters} {args_format}' 286 | } 287 | if has_chinese_chars(function): 288 | tool_desc = tool_desc_template['zh'] 289 | else: 290 | tool_desc = tool_desc_template['en'] 291 | 292 | name = function.get('name', None) 293 | name_for_human = function.get('name_for_human', name) 294 | name_for_model = function.get('name_for_model', name) 295 | assert name_for_human and name_for_model 296 | args_format = function.get('args_format', '') 297 | return tool_desc.format(name_for_human=name_for_human, 298 | name_for_model=name_for_model, 299 | description_for_model=function['description'], 300 | parameters=json.dumps(function['parameters'], 301 | ensure_ascii=False), 302 | args_format=args_format).rstrip() 303 | 304 | 305 | def format_knowledge_to_source_and_content( 306 | result: Union[str, List[dict]]) -> List[dict]: 307 | knowledge = [] 308 | if isinstance(result, str): 309 | result = f'{result}'.strip() 310 | docs = json5.loads(result) 311 | else: 312 | docs = result 313 | try: 314 | _tmp_knowledge = [] 315 | assert isinstance(docs, list) 316 | for doc in docs: 317 | url, snippets = doc['url'], doc['text'] 318 | assert isinstance(snippets, list) 319 | _tmp_knowledge.append({'source': f'[文件]({url})', 'content': '\n\n...\n\n'.join(snippets)}) 320 | knowledge.extend(_tmp_knowledge) 321 | except Exception: 322 | print_traceback() 323 | knowledge.append({'source': '上传的文档', 'content': result}) 324 | return knowledge 325 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/Yakifo/amqtt#egg=amqtt 2 | 3 | colorlog 4 | zeroconf 5 | streamlit 6 | anyio>=3.7.1 7 | beautifulsoup4 8 | dashscope>=1.11.0 9 | fastapi>=0.103.1 10 | html2text 11 | jieba 12 | json5 13 | jsonlines 14 | jupyter>=1.0.0 15 | matplotlib 16 | numpy 17 | openai 18 | pandas 19 | pdfminer-six 20 | pillow 21 | pydantic>=2.3.0 22 | python-docx 23 | python-pptx 24 | requests 25 | seaborn 26 | sympy 27 | tiktoken -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | from server import HomeRPC 2 | 3 | if __name__ == '__main__': 4 | # 启动HomeRPC 5 | HomeRPC.setup(ip = "192.168.43.9", log = True) 6 | 7 | # 等待ESP32连接 8 | input("Waiting for ESP32 to connect...") 9 | 10 | place = HomeRPC.place("room") 11 | # 调用ESP32客户端服务 12 | place.device("light").id(1).call("trigger", 1, timeout_s = 10) 13 | print("led status: ", place.device("light").id(1).call("status", timeout_s = 10)) -------------------------------------------------------------------------------- /server/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import HomeRPC 2 | 3 | __all__ = ['HomeRPC'] -------------------------------------------------------------------------------- /server/base.py: -------------------------------------------------------------------------------- 1 | from .log import Log 2 | from .broker import MqttBroker 3 | from .rule import Rule 4 | from multiprocessing import Process, Queue 5 | from json import dumps, loads 6 | from struct import unpack, pack 7 | from typing import List 8 | import asyncio 9 | 10 | _topic_registry = {} 11 | _recv, _send = Queue(), Queue() 12 | _topic = Queue() 13 | _rpc_log = None 14 | 15 | def dict2funcs() -> List[dict]: 16 | global _topic_registry 17 | data = _topic_registry 18 | 19 | functions = [] 20 | 21 | type_signature_mapping = { 22 | 'c': {'type': 'string', 'description': 'a character'}, 23 | 'i': {'type': 'integer', 'description': 'a number'}, 24 | 'f': {'type': 'number', 'description': 'a single-precision floating-point number'}, 25 | 'd': {'type': 'number', 'description': 'a double-precision floating-point number'}, 26 | } 27 | 28 | for place, devices in data.items(): 29 | for device_type, device_list in devices.items(): 30 | for device in device_list: 31 | for service in device['services']: 32 | function = { 33 | 'name': service['name'], 34 | 'description': service['desc'], 35 | 'parameters': { 36 | 'type': 'object', 37 | 'properties': { 38 | 'place': { 39 | 'type': 'string', 40 | 'description': 'where the device is located', 41 | 'enum': [place] 42 | }, 43 | 'device_type': { 44 | 'type': 'string', 45 | 'description': 'what type of device', 46 | 'enum': [device_type] 47 | }, 48 | 'device_id': { 49 | 'type': 'integer', 50 | 'description': 'the id of the device', 51 | 'enum': [device['id']] 52 | } 53 | }, 54 | 'required': ['place', 'device_type', 'device_id'] 55 | } 56 | } 57 | if service['input_type']: 58 | function['parameters']['properties']['input_type'] = type_signature_mapping[service['input_type']] 59 | function['parameters']['required'].append('input_type') 60 | functions.append(function) 61 | return functions 62 | 63 | def register_device(device: dict): 64 | place = device['place'] 65 | type_ = device['type'] 66 | id_ = device['id'] 67 | services = device['services'] 68 | 69 | if place not in _topic_registry: 70 | _topic_registry[place] = {} 71 | 72 | if type_ not in _topic_registry[place]: 73 | _topic_registry[place][type_] = [] 74 | 75 | _topic_registry[place][type_].append({'id': id_, 'services': services}) 76 | 77 | def update_registry(): 78 | while True: 79 | try: 80 | topic = dict(_topic.get_nowait()) 81 | except Exception: 82 | break 83 | else: 84 | register_device(topic) 85 | _rpc_log.log(f'Registered device: {_topic_registry}') 86 | 87 | def deserialize_rpc_any(buffer: str, data_type: str) -> any: 88 | res = bytes.fromhex(buffer.ljust(16, '0')) 89 | 90 | if data_type == 'f': 91 | return unpack('f', res[:4])[0] 92 | elif data_type == 'i': 93 | return unpack('q', res[:8])[0] 94 | elif data_type == 'd': 95 | return unpack('d', res[:8])[0] 96 | elif data_type == 'c': 97 | ascii_value = unpack('B', res[:1])[0] 98 | return chr(ascii_value) 99 | else: 100 | raise ValueError('Unsupported data type: ' + data_type) 101 | 102 | def serialize_rpc_any(data_list: List[any], data_type: str) -> List[str]: 103 | res = [] 104 | 105 | for data, _type in zip(data_list, data_type): 106 | if _type == 'f': 107 | res.append(pack('f', data).hex().ljust(8, '0')) 108 | elif _type == 'i': 109 | res.append(pack('q', data).hex().ljust(16, '0')) 110 | elif _type == 'd': 111 | res.append(pack('d', data).hex().ljust(16, '0')) 112 | elif _type == 'c': 113 | res.append(pack('B', ord(data)).hex().ljust(2, '0')) 114 | else: 115 | raise ValueError('Unsupported data type: ' + _type) 116 | 117 | return res 118 | 119 | def _asyncio_loop(ip: str, log: bool, recv: Queue, send: Queue, topic: Queue): 120 | loop = asyncio.get_event_loop() 121 | rule_engine = Rule(loop, recv, send, topic, log) 122 | 123 | class Server(MqttBroker): 124 | def __init__(self, ip, log): 125 | super().__init__(loop, ip, log) 126 | 127 | def _init_task(self): 128 | rule_engine.listen() 129 | 130 | server = Server(ip, log) 131 | server.loop_forever() 132 | 133 | class HomeRPC(): 134 | @staticmethod 135 | def setup(ip: str = None, log: bool = True): 136 | global _rpc_log 137 | _rpc_log = Log(disable = not log) 138 | Process(target = _asyncio_loop, args = (ip, log, _recv, _send, _topic), daemon = True).start() 139 | 140 | @staticmethod 141 | def place(name: str): 142 | 143 | class Place: 144 | def device(self, device_name: str): 145 | 146 | class Device: 147 | def id(self, device_id: int): 148 | 149 | class Id: 150 | def call(self, func: str, *args, **kwargs): 151 | update_registry() 152 | timeout = kwargs.get("timeout_s", 10) * 1000 153 | 154 | if name in _topic_registry and device_name in _topic_registry[name]: 155 | for device_info in _topic_registry[name][device_name]: 156 | if device_info['id'] != device_id: 157 | continue 158 | 159 | for service in device_info['services']: 160 | if service['name'] != func: 161 | continue 162 | 163 | if len(args) != len(service['input_type']): 164 | _rpc_log.log_error(f'Function {func} requires {len(service["input_type"])} arguments, but {len(args)} were given') 165 | return None 166 | 167 | message = { 168 | "callback": "/callback/master", 169 | } 170 | if len(args): 171 | message["params"] = serialize_rpc_any(args, service['input_type']) 172 | 173 | _send.put(dumps({ 174 | "topic": f"/{name}/{device_name}/{device_id}/{func}", 175 | "message": dumps(message) 176 | })) 177 | 178 | try: 179 | data = loads(_recv.get(timeout = timeout)) 180 | return deserialize_rpc_any(data['result'], service['output_type']) 181 | except Exception as e: 182 | _rpc_log.log_error(e) 183 | return None 184 | 185 | _rpc_log.log_error(f'Device {name}/{device_name}/{device_id} does not exist') 186 | return None 187 | return Id() 188 | return Device() 189 | return Place() 190 | 191 | @staticmethod 192 | def funcs(): 193 | update_registry() 194 | return dict2funcs() -------------------------------------------------------------------------------- /server/broker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from amqtt.broker import Broker 3 | from zeroconf import ServiceInfo, Zeroconf, IPVersion 4 | from .log import Log 5 | from socket import inet_aton 6 | 7 | class MqttBroker: 8 | def __init__(self, loop: asyncio.AbstractEventLoop = None, ip: str = None , log: bool = True): 9 | self.zeroconf = Zeroconf(ip_version = IPVersion.All) 10 | self.info = ServiceInfo( 11 | "_mqtt._tcp.local.", 12 | "broker._mqtt._tcp.local.", 13 | addresses = [inet_aton(ip),], 14 | port = 3000, 15 | properties = { "name": "MQTT Broker", "path": "/mqtt" }, 16 | server = "homerpc.local.", 17 | ) 18 | 19 | self.log = Log(disable = not log) 20 | 21 | self.loop = loop or asyncio.get_event_loop() 22 | 23 | async def _broker_coro(self): 24 | broker = Broker(config = { 25 | "listeners": { 26 | "default": { 27 | "type": "ws", 28 | "bind": "0.0.0.0:%d" % self.info.port, 29 | "max_connections": 0 30 | } 31 | }, 32 | "sys_interval": 10, 33 | "auth": { 34 | "allow-anonymous": True, 35 | "plugins": ['auth.anonymous'] 36 | }, 37 | "topic-check": { 38 | "enabled": False 39 | } 40 | }) 41 | await broker.start() 42 | 43 | def _register_service(self): 44 | self.zeroconf.register_service(self.info) 45 | 46 | def _init_task(self): 47 | raise NotImplementedError 48 | 49 | def loop_forever(self): 50 | self._register_service() 51 | 52 | try: 53 | self.loop.run_until_complete(self._broker_coro()) 54 | self._init_task() 55 | self.loop.run_forever() 56 | except Exception as e: 57 | self.loop.stop() 58 | self.log.log_error(e) -------------------------------------------------------------------------------- /server/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import colorlog 3 | import traceback 4 | 5 | class Log: 6 | def __init__(self, log_file=None, log_level=logging.INFO, disable=False): 7 | if disable: 8 | log_level = logging.CRITICAL + 1 9 | 10 | log_format = ( 11 | "%(log_color)s[%(asctime)s] :: %(levelname)s :: %(name)s :: %(message)s%(reset)s" 12 | ) 13 | 14 | colorlog.basicConfig( 15 | filename=log_file, 16 | level=log_level, 17 | format=log_format, 18 | log_colors={ 19 | "DEBUG": "green", 20 | "INFO": "cyan", 21 | "WARNING": "yellow", 22 | "ERROR": "red", 23 | "CRITICAL": "red,bg_white", 24 | }, 25 | ) 26 | 27 | def log(self, message): 28 | logging.info(message) 29 | 30 | def log_error(self, message): 31 | error_info = traceback.format_exc() 32 | logging.error(f"{message} => {error_info}") -------------------------------------------------------------------------------- /server/rule.py: -------------------------------------------------------------------------------- 1 | from .log import Log 2 | import asyncio 3 | import json 4 | from amqtt.client import MQTTClient 5 | from amqtt.mqtt.constants import QOS_1 6 | from multiprocessing import Queue 7 | 8 | class Rule: 9 | def __init__(self, loop: asyncio.AbstractEventLoop = None, \ 10 | recv: Queue = None, send: Queue = None, topic: Queue = None, log: Log = True): 11 | self.log = Log(disable = not log) 12 | self.loop = loop or asyncio.get_event_loop() 13 | (self.recv, self.send) = (recv, send) 14 | self.topic = topic 15 | 16 | async def _subscribe(self): 17 | self.client = MQTTClient( 18 | client_id = "self", 19 | config = { 20 | "keep_alive": 60, 21 | "reconnect_max_interval": 5, 22 | "reconnect_retries": 5, 23 | "ping_delay": 2, 24 | }) 25 | try: 26 | await self.client.connect('ws://127.0.0.1:3000/') 27 | await self.client.subscribe([ 28 | ('/topics/register', QOS_1), 29 | ('/callback/master', QOS_1), 30 | ]) 31 | except Exception as ce: 32 | self.log.log_error(ce) 33 | 34 | async def _listen_mqtt(self): 35 | while True: 36 | message = await self.client.deliver_message() 37 | packet = message.publish_packet 38 | try: 39 | data = packet.payload.data.decode('utf-8') 40 | json_data = json.loads(data) 41 | self.log.log(json_data) 42 | if (packet.variable_header.topic_name == "/topics/register"): 43 | self.loop.run_in_executor(None, self.topic.put, json_data) 44 | elif (packet.variable_header.topic_name == "/callback/master"): 45 | self.loop.run_in_executor(None, self.recv.put, data) 46 | except Exception as e: 47 | self.log.log_error(e) 48 | 49 | async def _listen_queue(self): 50 | while True: 51 | data = await self.loop.run_in_executor(None, self.send.get) 52 | data = json.loads(data) 53 | self.log.log(data) 54 | await self._call(data['topic'], bytearray(data['message'], 'utf-8')) 55 | 56 | async def _call(self, topic, message): 57 | await self.client.subscribe([ 58 | (topic, QOS_1), 59 | ]) 60 | await self.client.publish(topic, message, qos = QOS_1) 61 | 62 | def listen(self): 63 | self.loop.run_until_complete(self._subscribe()) 64 | self.loop.create_task(self._listen_mqtt()) 65 | self.loop.create_task(self._listen_queue()) -------------------------------------------------------------------------------- /src/home_rpc.c: -------------------------------------------------------------------------------- 1 | #include "freertos/FreeRTOS.h" 2 | #include "freertos/task.h" 3 | #include "freertos/event_groups.h" 4 | #include "freertos/queue.h" 5 | #include 6 | #include 7 | #include "HomeRPC.h" 8 | #include "rpc_mdns.h" 9 | #include "rpc_mesh.h" 10 | #include "rpc_mqtt.h" 11 | #include "rpc_log.h" 12 | #include "rpc_data.h" 13 | #include "esp_event.h" 14 | #include "esp_err.h" 15 | #include "esp_netif.h" 16 | #include "esp_mac.h" 17 | 18 | static DeviceList_t *device_list_end = NULL; 19 | static SemaphoreHandle_t addDeviceMutex = NULL; 20 | 21 | static const char *TAG = "core"; 22 | typedef struct { 23 | QueueHandle_t device_queue; 24 | EventGroupHandle_t event_group; 25 | QueueHandle_t response_queue; 26 | DeviceList_t *device_list; 27 | char uri[32]; 28 | } SharedData_t; 29 | static SharedData_t SharedData; 30 | 31 | static TimerHandle_t timer; 32 | 33 | static void mdns_search() { 34 | static esp_ip4_addr_t addr = { 0 }; 35 | if (addr.addr != 0) return; 36 | rpc_log.log_info(TAG, "MDNS Search"); 37 | esp_err_t err = rpc_mdns_search(CONFIG_BROKER_URL, &addr); 38 | if (err == ESP_OK) { 39 | snprintf(SharedData.uri, sizeof(SharedData.uri), "ws://%d.%d.%d.%d:3000", IP2STR(&addr)); 40 | xTaskCreatePinnedToCore(rpc_mqtt_task, "event_loop", 4096, (void *)&SharedData, 5, NULL, 0); 41 | xTimerDelete(timer, portMAX_DELAY); 42 | } 43 | } 44 | 45 | static void got_ip_event_handler(void *arg, esp_event_base_t event_base, int32_t event_id, void *event_data) { 46 | rpc_log.log_info(TAG, "Got IP"); 47 | static bool task = false; 48 | if (!task) { 49 | RPC_ERROR_CHECK(TAG, rpc_mdns_init()); 50 | timer = xTimerCreate("mdns_search", 1000 / portTICK_PERIOD_MS, pdTRUE, NULL, mdns_search); 51 | xTimerStart(timer, 0); 52 | task = true; 53 | } 54 | } 55 | 56 | // only called once 57 | static void rpc_start(void) { 58 | static bool initialized = false; 59 | if (initialized) 60 | return; 61 | 62 | rpc_log.log_info(TAG, "Setup"); 63 | 64 | rpc_mesh_init(got_ip_event_handler); 65 | 66 | uint8_t mac[6]; 67 | esp_efuse_mac_get_default(mac); 68 | snprintf(callback_topic, sizeof(callback_topic), "/callback/%02X%02X%02X", mac[3], mac[4], mac[5]); 69 | 70 | SharedData.device_queue = xQueueCreate(10, sizeof(Device_t *)); 71 | SharedData.event_group = xEventGroupCreate(); 72 | SharedData.response_queue = xQueueCreate(10, sizeof(rpc_any_t)); 73 | addDeviceMutex = xSemaphoreCreateMutex(); 74 | 75 | initialized = true; 76 | } 77 | 78 | // thread-safe 79 | void rpc_addDevice(const Device_t *dev) { 80 | if (xSemaphoreTake(addDeviceMutex, portMAX_DELAY) == pdTRUE) { 81 | DeviceList_t *new_device = malloc(sizeof(DeviceList_t)); 82 | if (new_device == NULL) { 83 | rpc_log.log_error(TAG, "Failed to allocate memory for new device"); 84 | xSemaphoreGive(addDeviceMutex); 85 | return; 86 | } 87 | new_device->device = malloc(sizeof(Device_t)); 88 | if (new_device->device == NULL) { 89 | rpc_log.log_error(TAG, "Failed to allocate memory for new device data"); 90 | free(new_device); 91 | xSemaphoreGive(addDeviceMutex); 92 | return; 93 | } 94 | memcpy(new_device->device, dev, sizeof(Device_t)); 95 | new_device->next = NULL; 96 | 97 | static EventBits_t nextWaitBit = 1; 98 | for (unsigned int i = 0; i < new_device->device->services_num; i++) { 99 | new_device->device->services[i]._wait = nextWaitBit; 100 | nextWaitBit <<= 1; 101 | } 102 | 103 | if (SharedData.device_list == NULL) { 104 | SharedData.device_list = new_device; 105 | device_list_end = new_device; 106 | } else { 107 | device_list_end->next = new_device; 108 | device_list_end = new_device; 109 | } 110 | 111 | rpc_log.log_info(TAG, "Add Device"); 112 | if (xQueueSendToBack(SharedData.device_queue, &new_device->device, 0) != pdTRUE) { 113 | rpc_log.log_error(TAG, "Failed to add device to queue"); 114 | } 115 | 116 | xSemaphoreGive(addDeviceMutex); 117 | } 118 | } 119 | 120 | // thread-safe 121 | static rpc_any_t rpc_callService(const Device_t *dev, const char *name, 122 | const rpc_any_t *params, unsigned int params_num, const TickType_t timeout_s) { 123 | rpc_log.log_info(TAG, "Call Service"); 124 | rpc_mqtt_call(dev, name, params, params_num); 125 | rpc_any_t res; 126 | if (xQueueReceive(SharedData.response_queue, &res, pdMS_TO_TICKS(timeout_s)) == pdTRUE) { 127 | rpc_log.log_info(TAG, "Service called"); 128 | return res; 129 | } 130 | rpc_log.log_error(TAG, "Timeout waiting for response"); 131 | return (rpc_any_t) { .i = -1 }; 132 | } 133 | 134 | HomeRPC_t HomeRPC = { 135 | .log_enable = true, 136 | .log_level = ESP_LOG_ERROR, 137 | .start = rpc_start, 138 | .addDevice = rpc_addDevice, 139 | ._callService = rpc_callService 140 | }; -------------------------------------------------------------------------------- /src/rpc_data.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include "HomeRPC.h" 3 | #include "rpc_data.h" 4 | #include 5 | #include 6 | 7 | char* serialize_device(const Device_t* device) { 8 | cJSON* json = cJSON_CreateObject(); 9 | cJSON_AddStringToObject(json, "place", device->place); 10 | cJSON_AddStringToObject(json, "type", device->type); 11 | cJSON_AddNumberToObject(json, "id", device->id); 12 | cJSON* services = cJSON_CreateArray(); 13 | for (unsigned int i = 0; i < device->services_num; i++) { 14 | cJSON* service = cJSON_CreateObject(); 15 | cJSON_AddStringToObject(service, "input_type", device->services[i].input_type); 16 | cJSON_AddStringToObject(service, "output_type", &device->services[i].output_type); 17 | cJSON_AddStringToObject(service, "name", device->services[i].name); 18 | cJSON_AddStringToObject(service, "desc", device->services[i].desc); 19 | cJSON_AddItemToArray(services, service); 20 | } 21 | cJSON_AddItemToObject(json, "services", services); 22 | cJSON_AddNumberToObject(json, "services_num", device->services_num); 23 | char* json_string = cJSON_Print(json); 24 | cJSON_Delete(json); 25 | return json_string; 26 | } 27 | 28 | // need to free the returned buffer! 29 | static char* serializeRpcAny(const rpc_any_t* data) { 30 | char* buffer = (char*)malloc(sizeof(rpc_any_t) * 2 + 1); 31 | if (buffer == NULL) 32 | return NULL; 33 | 34 | for (size_t i = 0; i < sizeof(rpc_any_t); i++) 35 | sprintf(buffer + i * 2, "%02x", ((unsigned char*)data)[i]); 36 | 37 | buffer[sizeof(rpc_any_t) * 2] = '\0'; 38 | return buffer; 39 | } 40 | 41 | char* serialize_service(const char* callback, const rpc_any_t* params, unsigned int params_num) { 42 | cJSON* json = cJSON_CreateObject(); 43 | cJSON_AddStringToObject(json, "callback", callback); 44 | if (params_num == 0) { 45 | char* json_string = cJSON_Print(json); 46 | cJSON_Delete(json); 47 | return json_string; 48 | } 49 | cJSON* params_array = cJSON_CreateArray(); 50 | char *param = NULL; 51 | for (unsigned int i = 0; i < params_num; i++) { 52 | param = serializeRpcAny(¶ms[i]); 53 | if (param == NULL) { 54 | cJSON_Delete(json); 55 | return NULL; 56 | } 57 | cJSON_AddItemToArray(params_array, cJSON_CreateString(param)); 58 | free(param); 59 | } 60 | cJSON_AddItemToObject(json, "params", params_array); 61 | char* json_string = cJSON_Print(json); 62 | cJSON_Delete(json); 63 | return json_string; 64 | } 65 | 66 | static rpc_any_t deserializeRpcAny(const char* buffer) { 67 | rpc_any_t data; 68 | for (size_t i = 0; i < sizeof(rpc_any_t); i++) { 69 | sscanf(buffer + i * 2, "%02x", (unsigned int*)&((unsigned char*)&data)[i]); 70 | } 71 | return data; 72 | } 73 | 74 | int deserialize_service(cJSON *json, char** callback, rpc_any_t *params, unsigned int* params_num) { 75 | if (json == NULL) { 76 | return -1; 77 | } 78 | 79 | cJSON* callback_json = cJSON_GetObjectItem(json, "callback"); 80 | if (callback_json == NULL) { 81 | return -1; 82 | } 83 | *callback = strdup(callback_json->valuestring); 84 | 85 | cJSON* params_array = cJSON_GetObjectItem(json, "params"); 86 | if (params_array == NULL) { 87 | *params_num = 0; 88 | return 0; 89 | } 90 | 91 | *params_num = cJSON_GetArraySize(params_array); 92 | if (*params_num > CONFIG_PARAMS_MAX) { 93 | free(*callback); 94 | return -1; 95 | } 96 | 97 | for (unsigned int i = 0; i < *params_num; i++) { 98 | cJSON* param_json = cJSON_GetArrayItem(params_array, i); 99 | if (param_json == NULL) { 100 | free(*callback); 101 | return -1; 102 | } 103 | params[i] = deserializeRpcAny(param_json->valuestring); 104 | } 105 | 106 | return 0; 107 | } 108 | 109 | cJSON* rpc_any_to_json(rpc_any_t data) { 110 | cJSON* json = cJSON_CreateObject(); 111 | if (json == NULL) { 112 | return NULL; 113 | } 114 | 115 | const char* str = serializeRpcAny(&data); 116 | if (str == NULL) { 117 | cJSON_Delete(json); 118 | return NULL; 119 | } 120 | 121 | cJSON* result = cJSON_AddStringToObject(json, "result", str); 122 | free((void*)str); 123 | if (result == NULL) { 124 | cJSON_Delete(json); 125 | return NULL; 126 | } 127 | 128 | return json; 129 | } 130 | 131 | rpc_any_t json_to_rpc_any(cJSON* json) { 132 | cJSON* result = cJSON_GetObjectItem(json, "result"); 133 | if (result == NULL || strlen(result->valuestring) < sizeof(rpc_any_t) * 2) { 134 | return (rpc_any_t) { .i = -1 }; 135 | } 136 | 137 | return deserializeRpcAny(result->valuestring); 138 | } -------------------------------------------------------------------------------- /src/rpc_log.c: -------------------------------------------------------------------------------- 1 | #include "esp_log.h" 2 | #include "rpc_log.h" 3 | #include 4 | #include 5 | #include "HomeRPC.h" 6 | 7 | void log_info(const char* tag, const char* format, ...) { 8 | if (HomeRPC.log_enable && HomeRPC.log_level <= ESP_LOG_INFO) { 9 | va_list args; 10 | va_start(args, format); 11 | printf("%s[HomeRPC::%s]: ", LOG_COLOR_I, tag); 12 | vprintf(format, args); 13 | printf("%s\n", LOG_RESET_COLOR); 14 | va_end(args); 15 | } 16 | } 17 | 18 | void log_error(const char* tag, const char* format, ...) { 19 | if (HomeRPC.log_enable && HomeRPC.log_level <= ESP_LOG_ERROR) { 20 | va_list args; 21 | va_start(args, format); 22 | printf("%s[HomeRPC::%s]: ", LOG_COLOR_E, tag); 23 | vprintf(format, args); 24 | printf("%s\n", LOG_RESET_COLOR); 25 | va_end(args); 26 | } 27 | } 28 | 29 | void log_warn(const char* tag, const char* format, ...) { 30 | if (HomeRPC.log_enable && HomeRPC.log_level <= ESP_LOG_WARN) { 31 | va_list args; 32 | va_start(args, format); 33 | printf("%s[HomeRPC::%s]: ", LOG_COLOR_W, tag); 34 | vprintf(format, args); 35 | printf("%s\n", LOG_RESET_COLOR); 36 | va_end(args); 37 | } 38 | } 39 | 40 | rpc_log_t rpc_log = { 41 | .log_info = log_info, 42 | .log_error = log_error, 43 | .log_warn = log_warn, 44 | }; -------------------------------------------------------------------------------- /src/rpc_mdns.c: -------------------------------------------------------------------------------- 1 | #include "mdns.h" 2 | #include "rpc_mdns.h" 3 | #include "rpc_log.h" 4 | #include "esp_err.h" 5 | #include "esp_netif.h" 6 | 7 | static const char *TAG = "rpc_mdns"; 8 | 9 | esp_err_t rpc_mdns_init(void) 10 | { 11 | esp_err_t err = mdns_init(); 12 | if (err) { 13 | rpc_log.log_error(TAG, "MDNS Init failed"); 14 | return err; 15 | } 16 | return ESP_OK; 17 | } 18 | 19 | esp_err_t rpc_mdns_search(const char *host, esp_ip4_addr_t *ip) 20 | { 21 | assert(ip != NULL); 22 | rpc_log.log_info(TAG, "Searching for %s", host); 23 | esp_err_t err = mdns_query_a(host, 500, ip); 24 | if (err) { 25 | if (err == ESP_ERR_NOT_FOUND) 26 | rpc_log.log_info(TAG, "Host was not found!"); 27 | else 28 | rpc_log.log_info(TAG, "Query Failed"); 29 | return err; 30 | } else { 31 | rpc_log.log_info(TAG, "Found %s at " IPSTR, host, IP2STR(ip)); 32 | return ESP_OK; 33 | } 34 | } 35 | 36 | void rpc_mdns_stop(void) 37 | { 38 | mdns_free(); 39 | } -------------------------------------------------------------------------------- /src/rpc_mesh.c: -------------------------------------------------------------------------------- 1 | #include "rpc_mesh.h" 2 | #include "rpc_log.h" 3 | #include "HomeRPC.h" 4 | #include "freertos/FreeRTOS.h" 5 | #include "freertos/task.h" 6 | #include "freertos/timers.h" 7 | 8 | #include "esp_wifi.h" 9 | #include "nvs_flash.h" 10 | #include 11 | 12 | #if ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(4, 4, 0) 13 | #include "esp_mac.h" 14 | #endif 15 | 16 | #include "esp_bridge.h" 17 | #include "esp_mesh_lite.h" 18 | 19 | static const char *TAG = "rpc_mesh"; 20 | 21 | static esp_err_t esp_storage_init(void) { 22 | esp_err_t ret = nvs_flash_init(); 23 | 24 | if (ret == ESP_ERR_NVS_NO_FREE_PAGES || ret == ESP_ERR_NVS_NEW_VERSION_FOUND) { 25 | rpc_log.log_info(TAG, "NVS Flash Erase"); 26 | RPC_ERROR_CHECK(TAG, nvs_flash_erase()); 27 | ret = nvs_flash_init(); 28 | } 29 | 30 | rpc_log.log_info(TAG, "NVS Flash Init"); 31 | return ret; 32 | } 33 | 34 | static void wifi_init(void) { 35 | // Station 36 | wifi_config_t wifi_config = { 37 | .sta = { 38 | .ssid = CONFIG_ROUTER_SSID, 39 | .password = CONFIG_ROUTER_PASSWORD, 40 | }, 41 | }; 42 | rpc_log.log_info(TAG, "Setting WiFi Station Config"); 43 | esp_bridge_wifi_set_config(WIFI_IF_STA, &wifi_config); 44 | 45 | // Softap 46 | snprintf((char *)wifi_config.ap.ssid, sizeof(wifi_config.ap.ssid), "%s", CONFIG_BRIDGE_SOFTAP_SSID); 47 | strlcpy((char *)wifi_config.ap.password, CONFIG_BRIDGE_SOFTAP_PASSWORD, sizeof(wifi_config.ap.password)); 48 | rpc_log.log_info(TAG, "Setting WiFi SoftAP Config"); 49 | esp_bridge_wifi_set_config(WIFI_IF_AP, &wifi_config); 50 | } 51 | 52 | void app_wifi_set_softap_info(void) { 53 | char softap_ssid[32]; 54 | uint8_t softap_mac[6]; 55 | esp_wifi_get_mac(WIFI_IF_AP, softap_mac); 56 | memset(softap_ssid, 0x0, sizeof(softap_ssid)); 57 | 58 | #ifdef CONFIG_BRIDGE_SOFTAP_SSID_END_WITH_THE_MAC 59 | snprintf(softap_ssid, sizeof(softap_ssid), "%.25s_%02x%02x%02x", CONFIG_BRIDGE_SOFTAP_SSID, softap_mac[3], softap_mac[4], softap_mac[5]); 60 | #else 61 | snprintf(softap_ssid, sizeof(softap_ssid), "%.32s", CONFIG_BRIDGE_SOFTAP_SSID); 62 | #endif 63 | rpc_log.log_info(TAG, "Setting SoftAP Info"); 64 | esp_mesh_lite_set_softap_ssid_to_nvs(softap_ssid); 65 | esp_mesh_lite_set_softap_psw_to_nvs(CONFIG_BRIDGE_SOFTAP_PASSWORD); 66 | esp_mesh_lite_set_softap_info(softap_ssid, CONFIG_BRIDGE_SOFTAP_PASSWORD); 67 | } 68 | 69 | void rpc_mesh_init(esp_event_handler_t got_ip_handler) { 70 | rpc_log.log_info(TAG, "Initializing RPC Mesh"); 71 | esp_storage_init(); 72 | 73 | RPC_ERROR_CHECK(TAG, esp_netif_init()); 74 | RPC_ERROR_CHECK(TAG, esp_event_loop_create_default()); 75 | 76 | esp_bridge_create_all_netif(); 77 | rpc_log.log_info(TAG, "Creating Netif"); 78 | 79 | wifi_init(); 80 | rpc_log.log_info(TAG, "Setting WiFi Mode"); 81 | 82 | esp_mesh_lite_config_t mesh_lite_config = ESP_MESH_LITE_DEFAULT_INIT(); 83 | esp_mesh_lite_init(&mesh_lite_config); 84 | 85 | app_wifi_set_softap_info(); 86 | 87 | esp_mesh_lite_start(); 88 | rpc_log.log_info(TAG, "Starting Mesh Lite"); 89 | 90 | RPC_ERROR_CHECK(TAG, esp_event_handler_instance_register(IP_EVENT, IP_EVENT_STA_GOT_IP, got_ip_handler, NULL, NULL)); 91 | rpc_log.log_info(TAG, "RPC Mesh Initialized"); 92 | } --------------------------------------------------------------------------------