├── .gitignore
├── CMakeLists.txt
├── Kconfig
├── README.md
├── assets
├── chat_with_llm.mp4
└── works.png
├── example
└── test
│ ├── CMakeLists.txt
│ ├── idf_component.yml
│ └── main.c
├── homeagent.py
├── idf_component.yml
├── include
├── HomeRPC.h
├── rpc_data.h
├── rpc_log.h
├── rpc_mdns.h
├── rpc_mesh.h
└── rpc_mqtt.h
├── license.txt
├── qwen_agent
├── __init__.py
├── agent.py
├── agents
│ ├── __init__.py
│ ├── article_agent.py
│ ├── assistant.py
│ ├── docqa_agent.py
│ ├── fncall_agent.py
│ ├── group_chat.py
│ ├── group_chat_auto_router.py
│ ├── group_chat_creator.py
│ ├── react_chat.py
│ ├── router.py
│ ├── user_agent.py
│ └── write_from_scratch.py
├── llm
│ ├── __init__.py
│ ├── base.py
│ ├── function_calling.py
│ ├── oai.py
│ ├── qwen_dashscope.py
│ ├── qwenvl_dashscope.py
│ ├── schema.py
│ └── text_base.py
├── log.py
├── memory
│ ├── __init__.py
│ └── memory.py
├── prompts
│ ├── __init__.py
│ ├── continue_writing.py
│ ├── doc_qa.py
│ ├── expand_writing.py
│ ├── gen_keyword.py
│ └── outline_writing.py
├── tools
│ ├── __init__.py
│ ├── amap_weather.py
│ ├── base.py
│ ├── code_interpreter.py
│ ├── doc_parser.py
│ ├── image_gen.py
│ ├── resource
│ │ ├── AlibabaPuHuiTi-3-45-Light.ttf
│ │ ├── code_interpreter_init_kernel.py
│ │ └── image_service.py
│ ├── retrieval.py
│ ├── similarity_search.py
│ ├── storage.py
│ └── web_extractor.py
└── utils
│ ├── __init__.py
│ ├── doc_parser.py
│ ├── qwen.tiktoken
│ ├── tokenization_qwen.py
│ └── utils.py
├── requirements.txt
├── server.py
├── server
├── __init__.py
├── base.py
├── broker.py
├── log.py
└── rule.py
└── src
├── home_rpc.c
├── rpc_data.c
├── rpc_log.c
├── rpc_mdns.c
├── rpc_mesh.c
└── rpc_mqtt.c
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__/
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | file(GLOB SOURCES "src/*.c")
2 |
3 | idf_component_register(SRCS "${SOURCES}"
4 | INCLUDE_DIRS "include"
5 | REQUIRES mdns esp_wifi nvs_flash esp_netif esp_event iot_bridge mesh_lite mqtt
6 | )
7 |
--------------------------------------------------------------------------------
/Kconfig:
--------------------------------------------------------------------------------
1 | menu "HomeRPC configuration"
2 |
3 | config ROUTER_SSID
4 | string "Router SSID"
5 | default "ROUTER_SSID"
6 | help
7 | Router SSID.
8 |
9 | config ROUTER_PASSWORD
10 | string "Router password"
11 | default "ROUTER_PASSWORD"
12 | help
13 | Router password.
14 |
15 | config BROKER_URL
16 | string "Broker URL"
17 | default "homerpc"
18 | help
19 | Broker URL.
20 |
21 | config PARAMS_MAX
22 | int "Params max"
23 | default 10
24 | help
25 | Params max.
26 |
27 | config TOPIC_LEN_MAX
28 | int "Topic len max"
29 | default 128
30 | help
31 | Topic len max.
32 |
33 | endmenu
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # *HomeRPC*
4 |
5 | [](https://www.freertos.org/)
6 | [](https://docs.espressif.com/projects/esp-idf/en/latest/esp32/)
7 | [](https://mqtt.org/)
8 | [](./LICENSE)
9 |
10 | ### **适用于智能家居场景的嵌入式 RPC 框架**
11 |
12 |
13 |
14 | ## 🚀 项目介绍
15 |
16 | **`HomeRPC`** 是一个面向智能家居场景的嵌入式 RPC 框架,旨在提供一种简单、高效的远程调用解决方案,非常适合与 **大语言模型** 进行结合,实现智能家居场景下的对话式交互。
17 |
18 | ## 🎨 效果展示
19 |
20 | https://github.com/guidons-master/HomeRPC/assets/41904603/89fafaa5-8a3b-4159-9780-6227521c16d5
21 |
22 | ```bash
23 | streamlit run homeagent.py
24 | ```
25 |
26 | ## ⚙️ 实现原理
27 |
28 | 
29 |
30 | ## 🛠️ 使用说明
31 |
32 | `HomeRPC` 支持以下基本数据类型作为命令参数:
33 |
34 | | 类型 | 签名 | 示例 |
35 | | ----------------------- | ---- | ----- |
36 | | char(字符) | c | 'a' |
37 | | short、int、long(数字) | i | 123 |
38 | | float(单精度浮点数) | f | 3.14 |
39 | | double(双精度浮点数) | d | 3.141 |
40 |
41 | 其中 `rpc_any_t` 类型的定义如下:
42 |
43 | ```
44 | typedef union {
45 | char c;
46 | unsigned char uc;
47 | short s;
48 | unsigned short us;
49 | int i;
50 | unsigned int ui;
51 | long l;
52 | float f;
53 | double d;
54 | } __attribute__((aligned(1))) rpc_any_t;
55 | ```
56 |
57 | ### 📚 示例代码
58 |
59 | `ESP32` 客户端示例代码如下:
60 | ```c
61 | // file: main.c
62 | #include "freertos/FreeRTOS.h"
63 | #include "driver/gpio.h"
64 | #include "HomeRPC.h"
65 |
66 | #define BLINK_GPIO 2
67 |
68 | static uint8_t s_led_state = 0;
69 |
70 | static void configure_led(void) {
71 | gpio_reset_pin(BLINK_GPIO);
72 | gpio_set_direction(BLINK_GPIO, GPIO_MODE_OUTPUT);
73 | }
74 |
75 | // 触发LED
76 | static rpc_any_t trigger_led(rpc_any_t state) {
77 | gpio_set_level(BLINK_GPIO, state.uc);
78 | rpc_any_t ret;
79 | ret.i = 0;
80 | return ret;
81 | }
82 |
83 | // 获取LED状态
84 | static rpc_any_t led_status(void) {
85 | rpc_any_t ret;
86 | ret.uc = s_led_state;
87 | return ret;
88 | }
89 |
90 | void app_main(void) {
91 | configure_led();
92 | // 启动HomeRPC
93 | HomeRPC.start();
94 | // 服务列表
95 | Service_t services[] = {
96 | {
97 | .func = trigger_led,
98 | .input_type = "i",
99 | .output_type = 'i',
100 | .name = "trigger",
101 | .desc = "open the light",
102 | },
103 | {
104 | .func = led_status,
105 | .input_type = "",
106 | .output_type = 'i',
107 | .name = "status",
108 | .desc = "check the light status, return 1 if on, 0 if off",
109 | }
110 | };
111 | // 设备信息
112 | Device_t led = {
113 | .place = "room",
114 | .type = "light",
115 | .id = 1,
116 | .services = services,
117 | .services_num = sizeof(services) / sizeof(Service_t)
118 | };
119 | // 注册设备
120 | HomeRPC.addDevice(&led);
121 | // 调用服务
122 | Device_t led2 = {
123 | .place = "room",
124 | .type = "light",
125 | .id = 1
126 | };
127 |
128 | while (1) {
129 | vTaskDelay(5000 / portTICK_PERIOD_MS);
130 | // 调用服务
131 | // rpc_any_t status = HomeRPC.callService(&led2, "status", NULL, 10);
132 | // printf("led status: %d\n", status.i);
133 | }
134 | }
135 | ```
136 |
137 | `Broker` 服务端示例代码如下:
138 | ```Python
139 | # file: server.py
140 | from server import HomeRPC
141 |
142 | if __name__ == '__main__':
143 | # 启动HomeRPC
144 | HomeRPC.setup(ip = "192.168.43.9", log = True)
145 |
146 | # 等待ESP32连接
147 | input("Waiting for ESP32 to connect...")
148 |
149 | place = HomeRPC.place("room")
150 | # 调用ESP32客户端服务
151 | place.device("light").id(1).call("trigger", 1, timeout_s = 10)
152 | print("led status: ", place.device("light").id(1).call("status", timeout_s = 10))
153 | ```
154 | ## 📦 安装方法
155 |
156 | 1. 将 `HomeRPC` 组件添加到您的 `ESP-IDF` 项目中:
157 | ```bash
158 | cd ~/my_esp_idf_project
159 | mkdir components
160 | cd components
161 | git clone https://github.com/guidons-master/HomeRPC.git
162 | ```
163 | 2. 在 `menuconfig` 中配置 `HomeRPC`
164 |
165 | ## 🧑💻 维护人员
166 |
167 | - [@guidons](https://github.com/guidons-master)
168 | - [@Hexin Lv](https://github.com/Mondaylv)
169 |
--------------------------------------------------------------------------------
/assets/chat_with_llm.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidons-master/HomeRPC/07147c3d4fc554b46bec8296e393d727a7b6238c/assets/chat_with_llm.mp4
--------------------------------------------------------------------------------
/assets/works.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidons-master/HomeRPC/07147c3d4fc554b46bec8296e393d727a7b6238c/assets/works.png
--------------------------------------------------------------------------------
/example/test/CMakeLists.txt:
--------------------------------------------------------------------------------
1 |
2 | idf_component_register(SRCS "main.c"
3 | INCLUDE_DIRS ".")
4 |
--------------------------------------------------------------------------------
/example/test/idf_component.yml:
--------------------------------------------------------------------------------
1 | dependencies:
2 | espressif/mdns: "^1.2.5"
3 | idf:
4 | version: '>=4.3'
5 | mesh_lite:
6 | version: '*'
7 | usb_device:
8 | git: https://github.com/espressif/esp-iot-bridge.git
9 | path: components/usb/usb_device
10 | rules:
11 | - if: target in [esp32s2, esp32s3]
12 |
--------------------------------------------------------------------------------
/example/test/main.c:
--------------------------------------------------------------------------------
1 | #include "freertos/FreeRTOS.h"
2 | #include "driver/gpio.h"
3 | #include "HomeRPC.h"
4 |
5 | #define BLINK_GPIO 2
6 |
7 | static uint8_t s_led_state = 0;
8 |
9 | static void configure_led(void) {
10 | gpio_reset_pin(BLINK_GPIO);
11 | gpio_set_direction(BLINK_GPIO, GPIO_MODE_OUTPUT);
12 | }
13 |
14 | // 触发LED
15 | static rpc_any_t trigger_led(rpc_any_t state) {
16 | gpio_set_level(BLINK_GPIO, state.uc);
17 | rpc_any_t ret;
18 | ret.i = 0;
19 | return ret;
20 | }
21 |
22 | // 获取LED状态
23 | static rpc_any_t led_status(void) {
24 | rpc_any_t ret;
25 | ret.uc = s_led_state;
26 | return ret;
27 | }
28 |
29 | void app_main(void) {
30 | configure_led();
31 | // 启动HomeRPC
32 | HomeRPC.start();
33 | // 服务列表
34 | Service_t services[] = {
35 | {
36 | .func = trigger_led,
37 | .input_type = "i",
38 | .output_type = 'i',
39 | .name = "trigger",
40 | .desc = "open the light",
41 | },
42 | {
43 | .func = led_status,
44 | .input_type = "",
45 | .output_type = 'i',
46 | .name = "status",
47 | .desc = "check the light status, return 1 if on, 0 if off",
48 | }
49 | };
50 | // 设备信息
51 | Device_t led = {
52 | .place = "room",
53 | .type = "light",
54 | .id = 1,
55 | .services = services,
56 | .services_num = sizeof(services) / sizeof(Service_t)
57 | };
58 | // 注册设备
59 | HomeRPC.addDevice(&led);
60 | // 调用服务
61 | Device_t led2 = {
62 | .place = "room",
63 | .type = "light",
64 | .id = 1
65 | };
66 |
67 | while (1) {
68 | vTaskDelay(5000 / portTICK_PERIOD_MS);
69 | // 调用服务
70 | // rpc_any_t status = HomeRPC.callService(&led2, "status", NULL, 10);
71 | // printf("led status: %d\n", status.i);
72 | }
73 | }
--------------------------------------------------------------------------------
/homeagent.py:
--------------------------------------------------------------------------------
1 | import json
2 | from qwen_agent.llm import get_chat_model
3 | import streamlit as st
4 | from server import HomeRPC
5 |
6 | def function_call(func, place = None, device_type = None, device_id = None, input_type = None):
7 | # Call the HomeRPC function
8 | if not (func and place and device_type and device_id):
9 | return json.dumps({
10 | "status": "error",
11 | "message": "Missing required parameters"
12 | })
13 |
14 | if input_type:
15 | ret = HomeRPC.place(place).device(device_type).id(device_id).call(func, input_type)
16 | else:
17 | ret = HomeRPC.place(place).device(device_type).id(device_id).call(func)
18 |
19 | if ret is None:
20 | return json.dumps({
21 | "status": "error",
22 | "message": "Function call failed"
23 | })
24 |
25 | return json.dumps({
26 | "status": "success",
27 | "return": ret
28 | })
29 |
30 | if __name__ == "__main__":
31 | st.set_page_config(
32 | page_title="HomeAgent",
33 | page_icon=":robot:",
34 | layout="wide"
35 | )
36 |
37 | @st.cache_resource
38 | def get_model():
39 | llm = get_chat_model({
40 | 'model': 'qwen-max',
41 | 'model_server': 'dashscope',
42 | 'api_key': 'xxxxx-xxxxx-xxxxx-xxxxx-xxxxx-xxxxx-xxxxx-xxxxx',
43 | 'generate_cfg': {
44 | 'top_p': 0.8,
45 | 'temperature': 0.8
46 | }
47 | })
48 |
49 | HomeRPC.setup(ip = "192.168.43.9", log = True)
50 |
51 | return llm
52 |
53 | llm = get_model()
54 |
55 | if "history" not in st.session_state:
56 | st.session_state.history = []
57 |
58 | st.sidebar.header("HomeAgent", divider = False)
59 |
60 | st.sidebar.divider()
61 |
62 | funcName = st.sidebar.markdown("**函数调用:**")
63 | funcArgs = st.sidebar.markdown("**参数:**")
64 | callStatus = st.sidebar.markdown("**调用状态:**")
65 |
66 | st.sidebar.divider()
67 |
68 | buttonClean = st.sidebar.button("清理会话历史", key="clean")
69 | if buttonClean:
70 | st.session_state.history = []
71 | st.rerun()
72 |
73 | for i, message in enumerate(st.session_state.history):
74 | if message["role"] == "user":
75 | with st.chat_message(name="user", avatar="user"):
76 | st.markdown(message["content"])
77 | elif message["role"] == "assistant":
78 | if message["content"]:
79 | with st.chat_message(name="assistant", avatar="assistant"):
80 | st.markdown(message["content"])
81 |
82 | with st.chat_message(name="user", avatar="user"):
83 | input_placeholder = st.empty()
84 | with st.chat_message(name="assistant", avatar="assistant"):
85 | message_placeholder = st.empty()
86 |
87 | prompt_text = st.chat_input("需要什么帮助吗?")
88 |
89 | if prompt_text:
90 |
91 | st.session_state.history.append({
92 | "role": "user",
93 | "content": prompt_text
94 | })
95 |
96 | input_placeholder.markdown(prompt_text)
97 | history = st.session_state.history
98 | responses = []
99 |
100 | for responses in llm.chat(messages=history,
101 | functions=HomeRPC.funcs(),
102 | stream=True):
103 | if responses[0]["content"]:
104 | message_placeholder.markdown(responses[0]["content"])
105 |
106 | history.extend(responses)
107 |
108 | last_response = history[-1]
109 |
110 | if last_response.get('function_call', None):
111 |
112 | try:
113 | function_name = last_response['function_call']['name']
114 | funcName.markdown(f"**函数调用:**```{function_name}```")
115 | function_args = json.loads(last_response['function_call']['arguments'])
116 | funcArgs.markdown(f"**参数:**```{function_args}```")
117 |
118 | function_response = function_call(
119 | function_name,
120 | **function_args
121 | )
122 |
123 | if json.loads(function_response).get('status') == 'success':
124 | callStatus.markdown(f"**调用状态:**```{function_response}```", unsafe_allow_html=True)
125 | else:
126 | callStatus.markdown(f"**调用状态:**```{function_response}```", unsafe_allow_html=True)
127 |
128 | history.append({
129 | 'role': 'function',
130 | 'name': function_name,
131 | 'content': function_response
132 | })
133 |
134 | for responses in llm.chat(
135 | messages=history,
136 | functions=HomeRPC.funcs(),
137 | stream=True,
138 | ):
139 | message_placeholder.markdown(responses[0]["content"])
140 |
141 | history.extend(responses)
142 |
143 | except Exception as e:
144 |
145 | callStatus.markdown(f'**调用状态:**```{e}```', unsafe_allow_html=True)
--------------------------------------------------------------------------------
/idf_component.yml:
--------------------------------------------------------------------------------
1 | dependencies:
2 | espressif/mdns: "^1.2.5"
3 | idf:
4 | version: '>=4.3'
5 | mesh_lite:
6 | version: '*'
7 | usb_device:
8 | git: https://github.com/espressif/esp-iot-bridge.git
9 | path: components/usb/usb_device
10 | rules:
11 | - if: target in [esp32s2, esp32s3]
--------------------------------------------------------------------------------
/include/HomeRPC.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef __cplusplus
4 | extern "C" {
5 | #endif
6 |
7 | #include "esp_system.h"
8 | #include "esp_log.h"
9 | #include "cJSON.h"
10 | #include "freertos/FreeRTOS.h"
11 | #include "freertos/event_groups.h"
12 |
13 | #define RPC_ERROR_CHECK(TAG, err) if ((err)) { \
14 | rpc_log.log_error((TAG), "Error occurred: %s", esp_err_to_name((err))); \
15 | esp_restart(); \
16 | }
17 |
18 | typedef union {
19 | char c;
20 | unsigned char uc;
21 | short s;
22 | unsigned short us;
23 | int i;
24 | unsigned int ui;
25 | long l;
26 | float f;
27 | double d;
28 | // char* str; // todo
29 | } __attribute__((aligned(1))) rpc_any_t;
30 |
31 | typedef rpc_any_t (*rpc_func_t)(void);
32 |
33 | typedef struct {
34 | /* Public */
35 | rpc_func_t func;
36 | char *input_type;
37 | char output_type;
38 | const char *name;
39 | const char *desc;
40 | /* Private */
41 | cJSON *_input;
42 | EventBits_t _wait;
43 | } Service_t;
44 |
45 | typedef struct {
46 | char *place;
47 | char *type;
48 | unsigned int id;
49 | Service_t *services;
50 | unsigned int services_num;
51 | } Device_t;
52 |
53 | struct List_t {
54 | Device_t* device;
55 | struct List_t *next;
56 | };
57 | typedef struct List_t DeviceList_t;
58 |
59 | typedef struct {
60 | esp_log_level_t log_level;
61 | uint8_t log_enable;
62 | void (*start)(void);
63 | void (*addDevice)(const Device_t *);
64 | rpc_any_t (*_callService)(const Device_t *, const char *, const rpc_any_t *, unsigned int, TickType_t);
65 | } HomeRPC_t;
66 |
67 | #define callService(device, service, params, timeout_s) _callService((device), (service), (params), ((params) && (sizeof((params)) / sizeof(rpc_any_t))), (timeout_s) * 1000)
68 |
69 | extern HomeRPC_t HomeRPC;
70 |
71 | #ifdef __cplusplus
72 | }
73 | #endif
--------------------------------------------------------------------------------
/include/rpc_data.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef __cplusplus
4 | extern "C" {
5 | #endif
6 |
7 | #include "HomeRPC.h"
8 | #include "cJSON.h"
9 |
10 | char* serialize_device(const Device_t*);
11 | char* serialize_service(const char*, const rpc_any_t*, unsigned int);
12 | int deserialize_service(cJSON*, char**, rpc_any_t*, unsigned int*);
13 | cJSON* rpc_any_to_json(rpc_any_t);
14 | rpc_any_t json_to_rpc_any(cJSON *);
15 |
16 | #ifdef __cplusplus
17 | }
18 | #endif
--------------------------------------------------------------------------------
/include/rpc_log.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef __cplusplus
4 | extern "C" {
5 | #endif
6 |
7 | typedef struct {
8 | void (*log_info)(const char* tag, const char* format, ...);
9 | void (*log_error)(const char* tag, const char* format, ...);
10 | void (*log_warn)(const char* tag, const char* format, ...);
11 | } rpc_log_t;
12 |
13 | extern rpc_log_t rpc_log;
14 |
15 | #ifdef __cplusplus
16 | }
17 | #endif
--------------------------------------------------------------------------------
/include/rpc_mdns.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef __cplusplus
4 | extern "C" {
5 | #endif
6 |
7 | #include "esp_err.h"
8 | #include "mdns.h"
9 | #include "esp_netif.h"
10 |
11 | esp_err_t rpc_mdns_init(void);
12 | esp_err_t rpc_mdns_search(const char *, esp_ip4_addr_t *);
13 | void rpc_mdns_stop(void);
14 |
15 | #ifdef __cplusplus
16 | }
17 | #endif
--------------------------------------------------------------------------------
/include/rpc_mesh.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef __cplusplus
4 | extern "C" {
5 | #endif
6 |
7 | void rpc_mesh_init(esp_event_handler_t);
8 |
9 | #ifdef __cplusplus
10 | }
11 | #endif
--------------------------------------------------------------------------------
/include/rpc_mqtt.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #ifdef __cplusplus
4 | extern "C" {
5 | #endif
6 |
7 | #include "HomeRPC.h"
8 |
9 | extern char callback_topic[CONFIG_TOPIC_LEN_MAX];
10 | void rpc_mqtt_call(const Device_t *, const char *, const rpc_any_t*, const unsigned int);
11 | void rpc_mqtt_task(void *);
12 |
13 | #ifdef __cplusplus
14 | }
15 | #endif
--------------------------------------------------------------------------------
/license.txt:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/qwen_agent/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.0.2'
2 | from .agent import Agent
3 |
4 | __all__ = ['Agent']
5 |
--------------------------------------------------------------------------------
/qwen_agent/agent.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import traceback
4 | from abc import ABC, abstractmethod
5 | from typing import Dict, Iterator, List, Optional, Tuple, Union
6 |
7 | from qwen_agent.llm import get_chat_model
8 | from qwen_agent.llm.base import BaseChatModel
9 | from qwen_agent.llm.schema import (CONTENT, DEFAULT_SYSTEM_MESSAGE, ROLE,
10 | SYSTEM, ContentItem, Message)
11 | from qwen_agent.log import logger
12 | from qwen_agent.tools import TOOL_REGISTRY, BaseTool
13 | from qwen_agent.utils.utils import has_chinese_chars
14 |
15 |
16 | class Agent(ABC):
17 | """A base class for Agent.
18 |
19 | An agent can receive messages and provide response by LLM or Tools.
20 | Different agents have distinct workflows for processing messages and generating responses in the `_run` method.
21 | """
22 |
23 | def __init__(self,
24 | function_list: Optional[List[Union[str, Dict,
25 | BaseTool]]] = None,
26 | llm: Optional[Union[Dict, BaseChatModel]] = None,
27 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
28 | name: Optional[str] = None,
29 | description: Optional[str] = None,
30 | **kwargs):
31 | """Initialization the agent.
32 |
33 | Args:
34 | function_list: One list of tool name, tool configuration or Tool object,
35 | such as 'code_interpreter', {'name': 'code_interpreter', 'timeout': 10}, or CodeInterpreter().
36 | llm: The LLM model configuration or LLM model object.
37 | Set the configuration as {'model': '', 'api_key': '', 'model_server': ''}.
38 | system_message: The specified system message for LLM chat.
39 | name: The name of this agent.
40 | description: The description of this agent, which will be used for multi_agent.
41 | """
42 | if isinstance(llm, dict):
43 | self.llm = get_chat_model(llm)
44 | else:
45 | self.llm = llm
46 |
47 | self.function_map = {}
48 | if function_list:
49 | for tool in function_list:
50 | self._init_tool(tool)
51 |
52 | self.system_message = system_message
53 | self.name = name
54 | self.description = description
55 |
56 | def run(self, messages: List[Union[Dict, Message]],
57 | **kwargs) -> Union[Iterator[List[Message]], Iterator[List[Dict]]]:
58 | """Return one response generator based on the received messages.
59 |
60 | This method performs a uniform type conversion for the inputted messages,
61 | and calls the _run method to generate a reply.
62 |
63 | Args:
64 | messages: A list of messages.
65 |
66 | Yields:
67 | The response generator.
68 | """
69 | messages = copy.deepcopy(messages)
70 | _return_message_type = 'dict'
71 | new_messages = []
72 | # Only return dict when all input messages are dict
73 | if not messages:
74 | _return_message_type = 'message'
75 | for msg in messages:
76 | if isinstance(msg, dict):
77 | new_messages.append(Message(**msg))
78 | else:
79 | new_messages.append(msg)
80 | _return_message_type = 'message'
81 |
82 | if new_messages and 'lang' not in kwargs:
83 | if has_chinese_chars([new_messages[-1][CONTENT], kwargs]):
84 | kwargs['lang'] = 'zh'
85 | else:
86 | kwargs['lang'] = 'en'
87 |
88 | for rsp in self._run(messages=new_messages, **kwargs):
89 | for i in range(len(rsp)):
90 | if not rsp[i].name and self.name:
91 | rsp[i].name = self.name
92 | if _return_message_type == 'message':
93 | yield [Message(**x) if isinstance(x, dict) else x for x in rsp]
94 | else:
95 | yield [
96 | x.model_dump() if not isinstance(x, dict) else x
97 | for x in rsp
98 | ]
99 |
100 | @abstractmethod
101 | def _run(self,
102 | messages: List[Message],
103 | lang: str = 'en',
104 | **kwargs) -> Iterator[List[Message]]:
105 | """Return one response generator based on the received messages.
106 |
107 | The workflow for an agent to generate a reply.
108 | Each agent subclass needs to implement this method.
109 |
110 | Args:
111 | messages: A list of messages.
112 | lang: Language, which will be used to select the language of the prompt
113 | during the agent's execution process.
114 |
115 | Yields:
116 | The response generator.
117 | """
118 | raise NotImplementedError
119 |
120 | def _call_llm(
121 | self,
122 | messages: List[Message],
123 | functions: Optional[List[Dict]] = None,
124 | stream: bool = True,
125 | ) -> Iterator[List[Message]]:
126 | """The interface of calling LLM for the agent.
127 |
128 | We prepend the system_message of this agent to the messages, and call LLM.
129 |
130 | Args:
131 | messages: A list of messages.
132 | functions: The list of functions provided to LLM.
133 | stream: LLM streaming output or non-streaming output.
134 | For consistency, we default to using streaming output across all agents.
135 |
136 | Yields:
137 | The response generator of LLM.
138 | """
139 | messages = copy.deepcopy(messages)
140 | if messages[0][ROLE] != SYSTEM:
141 | messages.insert(0, Message(role=SYSTEM,
142 | content=self.system_message))
143 | elif isinstance(messages[0][CONTENT], str):
144 | messages[0][CONTENT] = self.system_message + messages[0][CONTENT]
145 | else:
146 | assert isinstance(messages[0][CONTENT], list)
147 | messages[0][CONTENT] = [ContentItem(text=self.system_message)
148 | ] + messages[0][CONTENT]
149 | return self.llm.chat(messages=messages,
150 | functions=functions,
151 | stream=stream)
152 |
153 | def _call_tool(self,
154 | tool_name: str,
155 | tool_args: Union[str, dict] = '{}',
156 | **kwargs) -> str:
157 | """The interface of calling tools for the agent.
158 |
159 | Args:
160 | tool_name: The name of one tool.
161 | tool_args: Model generated or user given tool parameters.
162 |
163 | Returns:
164 | The output of tools.
165 | """
166 | if tool_name not in self.function_map:
167 | return f'Tool {tool_name} does not exists.'
168 | tool = self.function_map[tool_name]
169 | try:
170 | tool_result = tool.call(tool_args, **kwargs)
171 | except Exception as ex:
172 | exception_type = type(ex).__name__
173 | exception_message = str(ex)
174 | traceback_info = ''.join(traceback.format_tb(ex.__traceback__))
175 | error_message = f'An error occurred when calling tool `{tool_name}`:\n' \
176 | f'{exception_type}: {exception_message}\n' \
177 | f'Traceback:\n{traceback_info}'
178 | return error_message
179 |
180 | if isinstance(tool_result, str):
181 | return tool_result
182 | else:
183 | return json.dumps(tool_result, ensure_ascii=False, indent=4)
184 |
185 | def _init_tool(self, tool: Union[str, Dict, BaseTool]):
186 | if isinstance(tool, BaseTool):
187 | tool_name = tool.name
188 | if tool_name in self.function_map:
189 | logger.warning(
190 | f'Repeatedly adding tool {tool_name}, will use the newest tool in function list'
191 | )
192 | self.function_map[tool_name] = tool
193 | else:
194 | if isinstance(tool, dict):
195 | tool_name = tool['name']
196 | tool_cfg = tool
197 | else:
198 | tool_name = tool
199 | tool_cfg = None
200 | if tool_name not in TOOL_REGISTRY:
201 | raise ValueError(f'Tool {tool_name} is not registered.')
202 |
203 | if tool_name in self.function_map:
204 | logger.warning(
205 | f'Repeatedly adding tool {tool_name}, will use the newest tool in function list'
206 | )
207 | self.function_map[tool_name] = TOOL_REGISTRY[tool_name](tool_cfg)
208 |
209 | def _detect_tool(self, message: Message) -> Tuple[bool, str, str, str]:
210 | """A built-in tool call detection for func_call format message.
211 |
212 | Args:
213 | message: one message generated by LLM.
214 |
215 | Returns:
216 | Need to call tool or not, tool name, tool args, text replies.
217 | """
218 | func_name = None
219 | func_args = None
220 |
221 | if message.function_call:
222 | func_call = message.function_call
223 | func_name = func_call.name
224 | func_args = func_call.arguments
225 | text = message.content
226 | if not text:
227 | text = ''
228 |
229 | return (func_name is not None), func_name, func_args, text
230 |
--------------------------------------------------------------------------------
/qwen_agent/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from .article_agent import ArticleAgent
2 | from .assistant import Assistant
3 | from .docqa_agent import DocQAAgent
4 | from .fncall_agent import FnCallAgent
5 | from .group_chat import GroupChat
6 | from .group_chat_auto_router import GroupChatAutoRouter
7 | from .group_chat_creator import GroupChatCreator
8 | from .react_chat import ReActChat
9 | from .router import Router
10 | from .user_agent import UserAgent
11 | from .write_from_scratch import WriteFromScratch
12 |
13 | __all__ = [
14 | 'DocQAAgent', 'Assistant', 'ArticleAgent', 'ReActChat', 'Router',
15 | 'UserAgent', 'GroupChat', 'WriteFromScratch', 'GroupChatCreator',
16 | 'GroupChatAutoRouter', 'FnCallAgent'
17 | ]
18 |
--------------------------------------------------------------------------------
/qwen_agent/agents/article_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, List
2 |
3 | from qwen_agent.agents.assistant import Assistant
4 | from qwen_agent.agents.write_from_scratch import WriteFromScratch
5 | from qwen_agent.llm.schema import ASSISTANT, CONTENT, Message
6 | from qwen_agent.prompts import ContinueWriting
7 |
8 |
9 | class ArticleAgent(Assistant):
10 | """This is an agent for writing articles.
11 |
12 | It can write a thematic essay or continue writing an article based on reference materials
13 | """
14 |
15 | def _run(self,
16 | messages: List[Message],
17 | lang: str = 'en',
18 | max_ref_token: int = 4000,
19 | full_article: bool = False,
20 | **kwargs) -> Iterator[List[Message]]:
21 |
22 | # Need to use Memory agent for data management
23 | *_, last = self.mem.run(messages=messages,
24 | max_ref_token=max_ref_token,
25 | **kwargs)
26 | _ref = last[-1][CONTENT]
27 |
28 | response = []
29 | if _ref:
30 | response.append(
31 | Message(ASSISTANT,
32 | f'>\n> Search for relevant information: \n{_ref}\n'))
33 | yield response
34 |
35 | if full_article:
36 | writing_agent = WriteFromScratch(llm=self.llm)
37 | else:
38 | writing_agent = ContinueWriting(llm=self.llm)
39 | response.append(Message(ASSISTANT, '>\n> Writing Text: \n'))
40 | yield response
41 |
42 | for trunk in writing_agent.run(messages=messages,
43 | lang=lang,
44 | knowledge=_ref):
45 | if trunk:
46 | yield response + trunk
47 |
--------------------------------------------------------------------------------
/qwen_agent/agents/assistant.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Iterator, List
3 |
4 | from qwen_agent.llm.schema import CONTENT, ROLE, SYSTEM, Message
5 | from qwen_agent.log import logger
6 | from qwen_agent.utils.utils import format_knowledge_to_source_and_content
7 |
8 | from .fncall_agent import FnCallAgent
9 |
10 | KNOWLEDGE_SNIPPET_ZH = """## 来自 {source} 的内容:
11 |
12 | ```
13 | {content}
14 | ```"""
15 | KNOWLEDGE_TEMPLATE_ZH = """
16 |
17 | # 知识库
18 |
19 | {knowledge}"""
20 |
21 | KNOWLEDGE_SNIPPET_EN = """## The content from {source}:
22 |
23 | ```
24 | {content}
25 | ```"""
26 | KNOWLEDGE_TEMPLATE_EN = """
27 |
28 | # Knowledge Base
29 |
30 | {knowledge}"""
31 |
32 | KNOWLEDGE_SNIPPET = {'zh': KNOWLEDGE_SNIPPET_ZH, 'en': KNOWLEDGE_SNIPPET_EN}
33 | KNOWLEDGE_TEMPLATE = {'zh': KNOWLEDGE_TEMPLATE_ZH, 'en': KNOWLEDGE_TEMPLATE_EN}
34 |
35 |
36 | class Assistant(FnCallAgent):
37 | """This is a widely applicable agent integrated with RAG capabilities and function call ability."""
38 |
39 | def _run(self,
40 | messages: List[Message],
41 | lang: str = 'en',
42 | max_ref_token: int = 4000,
43 | **kwargs) -> Iterator[List[Message]]:
44 |
45 | new_messages = self._prepend_knowledge_prompt(messages, lang,
46 | max_ref_token, **kwargs)
47 | return super()._run(messages=new_messages,
48 | lang=lang,
49 | max_ref_token=max_ref_token,
50 | **kwargs)
51 |
52 | def _prepend_knowledge_prompt(self,
53 | messages: List[Message],
54 | lang: str = 'en',
55 | max_ref_token: int = 4000,
56 | **kwargs) -> List[Message]:
57 | messages = copy.deepcopy(messages)
58 | # Retrieval knowledge from files
59 | *_, last = self.mem.run(messages=messages, max_ref_token=max_ref_token, lang=lang, **kwargs)
60 | knowledge = last[-1][CONTENT]
61 |
62 | logger.debug(
63 | f'Retrieved knowledge of type `{type(knowledge).__name__}`:\n{knowledge}'
64 | )
65 | if knowledge:
66 | knowledge = format_knowledge_to_source_and_content(knowledge)
67 | logger.debug(
68 | f'Formatted knowledge into type `{type(knowledge).__name__}`:\n{knowledge}'
69 | )
70 | else:
71 | knowledge = []
72 | snippets = []
73 | for k in knowledge:
74 | snippets.append(KNOWLEDGE_SNIPPET[lang].format(
75 | source=k['source'], content=k['content']))
76 | knowledge_prompt = ''
77 | if snippets:
78 | knowledge_prompt = KNOWLEDGE_TEMPLATE[lang].format(
79 | knowledge='\n\n'.join(snippets))
80 |
81 | if knowledge_prompt:
82 | if messages[0][ROLE] == SYSTEM:
83 | messages[0][CONTENT] += knowledge_prompt
84 | else:
85 | messages = [Message(role=SYSTEM, content=knowledge_prompt)
86 | ] + messages
87 | return messages
88 |
--------------------------------------------------------------------------------
/qwen_agent/agents/docqa_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Iterator, List, Optional, Union
2 |
3 | from qwen_agent.agents.assistant import Assistant
4 | from qwen_agent.llm.base import BaseChatModel
5 | from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, Message
6 | from qwen_agent.prompts import DocQA
7 | from qwen_agent.tools import BaseTool
8 |
9 |
10 | class DocQAAgent(Assistant):
11 | """This is an agent for doc QA."""
12 |
13 | def __init__(self,
14 | function_list: Optional[List[Union[str, Dict,
15 | BaseTool]]] = None,
16 | llm: Optional[Union[Dict, BaseChatModel]] = None,
17 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
18 | name: Optional[str] = None,
19 | description: Optional[str] = None,
20 | files: Optional[List[str]] = None):
21 | super().__init__(function_list=function_list,
22 | llm=llm,
23 | system_message=system_message,
24 | name=name,
25 | description=description,
26 | files=files)
27 |
28 | self.doc_qa = DocQA(llm=self.llm)
29 |
30 | def _run(self,
31 | messages: List[Message],
32 | lang: str = 'en',
33 | max_ref_token: int = 4000,
34 | **kwargs) -> Iterator[List[Message]]:
35 |
36 | # Need to use Memory agent for data management
37 | *_, last = self.mem.run(messages=messages,
38 | max_ref_token=max_ref_token,
39 | **kwargs)
40 | _ref = last[-1][CONTENT]
41 |
42 | # Use RetrievalQA agent
43 | response = self.doc_qa.run(messages=messages,
44 | lang=lang,
45 | knowledge=_ref)
46 |
47 | return response
48 |
--------------------------------------------------------------------------------
/qwen_agent/agents/fncall_agent.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Dict, Iterator, List, Optional, Union
3 |
4 | from qwen_agent import Agent
5 | from qwen_agent.llm import BaseChatModel
6 | from qwen_agent.llm.schema import DEFAULT_SYSTEM_MESSAGE, FUNCTION, Message
7 | from qwen_agent.memory import Memory
8 | from qwen_agent.tools import BaseTool
9 |
10 | MAX_LLM_CALL_PER_RUN = 8
11 |
12 |
13 | class FnCallAgent(Agent):
14 | """This is a widely applicable function call agent integrated with llm and tool use ability."""
15 |
16 | def __init__(self,
17 | function_list: Optional[List[Union[str, Dict,
18 | BaseTool]]] = None,
19 | llm: Optional[Union[Dict, BaseChatModel]] = None,
20 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
21 | name: Optional[str] = None,
22 | description: Optional[str] = None,
23 | files: Optional[List[str]] = None):
24 | """Initialization the agent.
25 |
26 | Args:
27 | function_list: One list of tool name, tool configuration or Tool object,
28 | such as 'code_interpreter', {'name': 'code_interpreter', 'timeout': 10}, or CodeInterpreter().
29 | llm: The LLM model configuration or LLM model object.
30 | Set the configuration as {'model': '', 'api_key': '', 'model_server': ''}.
31 | system_message: The specified system message for LLM chat.
32 | name: The name of this agent.
33 | description: The description of this agent, which will be used for multi_agent.
34 | files: A file url list. The initialized files for the agent.
35 | """
36 | super().__init__(function_list=function_list,
37 | llm=llm,
38 | system_message=system_message,
39 | name=name,
40 | description=description)
41 |
42 | # Default to use Memory to manage files
43 | self.mem = Memory(llm=self.llm, files=files)
44 |
45 | def _run(self,
46 | messages: List[Message],
47 | lang: str = 'en',
48 | **kwargs) -> Iterator[List[Message]]:
49 | messages = copy.deepcopy(messages)
50 | num_llm_calls_available = MAX_LLM_CALL_PER_RUN
51 | response = []
52 | while True and num_llm_calls_available > 0:
53 | num_llm_calls_available -= 1
54 | output_stream = self._call_llm(
55 | messages=messages,
56 | functions=[
57 | func.function for func in self.function_map.values()
58 | ])
59 | output: List[Message] = []
60 | for output in output_stream:
61 | if output:
62 | yield response + output
63 | if output:
64 | response.extend(output)
65 | messages.extend(output)
66 | use_tool, action, action_input, _ = self._detect_tool(response[-1])
67 | if use_tool:
68 | observation = self._call_tool(action,
69 | action_input,
70 | messages=messages)
71 | fn_msg = Message(
72 | role=FUNCTION,
73 | name=action,
74 | content=observation,
75 | )
76 | messages.append(fn_msg)
77 | response.append(fn_msg)
78 | yield response
79 | else:
80 | break
81 |
82 | def _call_tool(self,
83 | tool_name: str,
84 | tool_args: Union[str, dict] = '{}',
85 | **kwargs) -> str:
86 | # Temporary plan: Check if it is necessary to transfer files to the tool
87 | # Todo: This should be changed to parameter passing, and the file URL should be determined by the model
88 | if self.function_map[tool_name].file_access:
89 | assert 'messages' in kwargs
90 | files = self.mem.get_all_files_of_messages(
91 | kwargs['messages']) + self.mem.system_files
92 | return super()._call_tool(tool_name,
93 | tool_args,
94 | files=files,
95 | **kwargs)
96 | else:
97 | return super()._call_tool(tool_name, tool_args, **kwargs)
98 |
--------------------------------------------------------------------------------
/qwen_agent/agents/group_chat_auto_router.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Iterator, List, Optional, Union
2 |
3 | from qwen_agent import Agent
4 | from qwen_agent.llm import BaseChatModel
5 | from qwen_agent.llm.schema import Message
6 | from qwen_agent.tools import BaseTool
7 | from qwen_agent.utils.utils import has_chinese_chars
8 |
9 | PROMPT_TEMPLATE_ZH = '''你扮演角色扮演游戏的上帝,你的任务是选择合适的发言角色。有如下角色:
10 | {agent_descs}
11 |
12 | 角色间的对话历史格式如下,越新的对话越重要:
13 | 角色名: 说话内容
14 |
15 | 请阅读对话历史,并选择下一个合适的发言角色,从 [{agent_names}] 里选,当真实用户最近表明了停止聊天时,或话题应该终止时,请返回“[STOP]”,用户很懒,非必要不要选真实用户。
16 | 仅返回角色名或“[STOP]”,不要返回其余内容。'''
17 |
18 | PROMPT_TEMPLATE_EN = '''You are in a role play game. The following roles are available:
19 | {agent_descs}
20 |
21 | The format of dialogue history between roles is as follows:
22 | Role Name: Speech Content
23 |
24 | Please read the dialogue history and choose the next suitable role to speak.
25 | When the user indicates to stop chatting or when the topic should be terminated, please return '[STOP]'.
26 | Only return the role name from [{agent_names}] or '[STOP]'. Do not reply any other content.'''
27 |
28 | PROMPT_TEMPLATE = {
29 | 'zh': PROMPT_TEMPLATE_ZH,
30 | 'en': PROMPT_TEMPLATE_EN,
31 | }
32 |
33 |
34 | class GroupChatAutoRouter(Agent):
35 |
36 | def __init__(self,
37 | function_list: Optional[List[Union[str, Dict,
38 | BaseTool]]] = None,
39 | llm: Optional[Union[Dict, BaseChatModel]] = None,
40 | agents: List[Agent] = None,
41 | name: Optional[str] = None,
42 | description: Optional[str] = None,
43 | **kwargs):
44 | # This agent need prepend special system message according to inputted agents
45 | agent_descs = '\n'.join([f'{x.name}: {x.description}' for x in agents])
46 | lang = 'en'
47 | if has_chinese_chars(agent_descs):
48 | lang = 'zh'
49 | system_prompt = PROMPT_TEMPLATE[lang].format(
50 | agent_descs=agent_descs,
51 | agent_names=', '.join([x.name for x in agents]))
52 |
53 | super().__init__(function_list=function_list,
54 | llm=llm,
55 | system_message=system_prompt,
56 | name=name,
57 | description=description,
58 | **kwargs)
59 |
60 | def _run(self,
61 | messages: List[Message],
62 | lang: str = 'en',
63 | **kwargs) -> Iterator[List[Message]]:
64 |
65 | dialogue = []
66 | for msg in messages:
67 | if msg.role == 'function' or not msg.content:
68 | continue
69 | if isinstance(msg.content, list):
70 | content = '\n'.join(
71 | [x.text if x.text else '' for x in msg.content]).strip()
72 | else:
73 | content = msg.content.strip()
74 | display_name = msg.role
75 | if msg.name:
76 | display_name = msg.name
77 | if dialogue and dialogue[-1].startswith(display_name):
78 | dialogue[-1] += f'\n{content}'
79 | else:
80 | dialogue.append(f'{display_name}: {content}')
81 |
82 | if not dialogue:
83 | dialogue.append('对话刚开始,请任意选择一个发言人,别选真实用户')
84 | new_messages = [Message('user', '\n'.join(dialogue))]
85 |
86 | return self._call_llm(messages=new_messages)
87 |
--------------------------------------------------------------------------------
/qwen_agent/agents/group_chat_creator.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | from typing import Dict, Iterator, List, Optional, Tuple, Union
4 |
5 | import json5
6 |
7 | from qwen_agent import Agent
8 | from qwen_agent.llm import BaseChatModel
9 | from qwen_agent.llm.schema import Message
10 | from qwen_agent.tools import BaseTool
11 |
12 | CONFIG_SCHEMA = {
13 | 'name': '... # 角色名字,5字左右',
14 | 'description': '... # 角色简介,10字左右',
15 | 'instructions': '... # 对角色的具体功能要求,30字左右,以第二人称称呼角色'
16 | }
17 |
18 | CONFIG_EXAMPLE = {
19 | 'name': '小红书写作专家',
20 | 'description': '我会写小红书爆款',
21 | 'instructions': '你是小红书爆款写作专家,创作会先产5个标题(含emoji),再产正文(每段落含emoji,文末有tag)。'
22 | }
23 |
24 | BACKGROUND_TOKEN = ''
25 | CONFIG_TOKEN = ''
26 | ANSWER_TOKEN = ''
27 |
28 | ROLE_CREATE_SYSTEM = '''你扮演创建群聊的助手,请你根据用户输入的聊天主题,创建n个合适的虚拟角色,这些角色将在一个聊天室内对话,你需要和用户进行对话,明确用户对这些角色的要求。
29 |
30 | 配置文件为json格式:
31 | {config_schema}
32 |
33 | 一个优秀的RichConfig样例如下:
34 | {config_example}
35 |
36 | 在接下来的对话中,请在回答时严格使用如下格式,先生成群聊背景,然后依次生成所有角色的配置文件,最后再作出回复,除此之外不要回复其他任何内容:
37 | {background_token}: ... # 生成的群聊背景,包括人物关系,预设故事背景等信息。
38 | {config_token}: ... # 生成的第一个角色的配置文件,严格按照以上json格式,禁止为空。保证name和description不为空。instructions内容比description具体,如果用户给出了详细指令,请完全保留,用第二人称描述角色,例如“你是xxx,你具有xxx能力。
39 | {config_token}: ... # 生成的第二个角色的配置文件,要求同上。
40 | ...
41 | {config_token}: ... # 生成的第n个角色的配置文件,要求同上,如果用户没有明确指出n的数量,则n等于3;要求每个角色的名字不相同。
42 | {answer_token}: ... # 你希望对用户说的话,用于询问用户对角色的要求,禁止为空,问题要广泛,不要重复问类似的问题。
43 |
44 | 如果群聊背景或某个角色的配置文件不需要更新,可以不重复输出{background_token}和对应的{config_token}的内容、只输出{answer_token}和需要修改的{config_token}的内容。'''.format(
45 | config_schema=json.dumps(CONFIG_SCHEMA, ensure_ascii=False, indent=2),
46 | config_example=json.dumps(CONFIG_EXAMPLE, ensure_ascii=False, indent=2),
47 | background_token=BACKGROUND_TOKEN,
48 | config_token=CONFIG_TOKEN,
49 | answer_token=ANSWER_TOKEN,
50 | )
51 | assert CONFIG_TOKEN in ROLE_CREATE_SYSTEM
52 | assert ANSWER_TOKEN in ROLE_CREATE_SYSTEM
53 |
54 |
55 | class GroupChatCreator(Agent):
56 |
57 | def __init__(self,
58 | function_list: Optional[List[Union[str, Dict,
59 | BaseTool]]] = None,
60 | llm: Optional[Union[Dict, BaseChatModel]] = None,
61 | name: Optional[str] = None,
62 | description: Optional[str] = None,
63 | **kwargs):
64 | super().__init__(function_list=function_list,
65 | llm=llm,
66 | system_message=ROLE_CREATE_SYSTEM,
67 | name=name,
68 | description=description,
69 | **kwargs)
70 |
71 | def _run(self,
72 | messages: List[Message],
73 | agents: List[Agent] = None,
74 | lang: str = 'en',
75 | **kwargs) -> Iterator[List[Message]]:
76 | messages = copy.deepcopy(messages)
77 | messages = self._preprocess_messages(messages)
78 |
79 | for rsp in self._call_llm(messages=messages):
80 | yield self._postprocess_messages(rsp)
81 |
82 | def _preprocess_messages(self, messages: List[Message]) -> List[Message]:
83 | new_messages = []
84 | content = []
85 | for message in messages:
86 | if message.role != 'assistant':
87 | new_messages.append(message)
88 | else:
89 | if message.name == 'background':
90 | content.append(f'{BACKGROUND_TOKEN}: {message.content}')
91 | elif message.name == 'role_config':
92 | content.append(f'{CONFIG_TOKEN}: {message.content}')
93 | else:
94 | content.append(f'{ANSWER_TOKEN}: {message.content}')
95 | assert new_messages[-1].role == 'user'
96 | new_messages.append(
97 | Message('assistant', '\n'.join(content)))
98 | content = []
99 | return new_messages
100 |
101 | def _postprocess_messages(self, messages: List[Message]) -> List[Message]:
102 | new_messages = []
103 | assert len(messages) == 1
104 | message = messages[-1]
105 | background, cfgs, answer = self._extract_role_config_and_answer(
106 | message.content)
107 | if background:
108 | new_messages.append(
109 | Message(message.role, background, name='background'))
110 | if cfgs:
111 | for cfg in cfgs:
112 | new_messages.append(
113 | Message(message.role, cfg, name='role_config'))
114 |
115 | new_messages.append(Message(message.role, answer, name=message.name))
116 | return new_messages
117 |
118 | def _extract_role_config_and_answer(
119 | self, text: str) -> Tuple[str, List[str], str]:
120 | background, cfgs, answer = '', [], ''
121 | back_pos, cfg_pos, ans_pos = text.find(
122 | f'{BACKGROUND_TOKEN}: '), text.find(
123 | f'{CONFIG_TOKEN}: '), text.find(f'{ANSWER_TOKEN}: ')
124 |
125 | if ans_pos > -1:
126 | answer = text[ans_pos + len(f'{ANSWER_TOKEN}: '):]
127 | else:
128 | ans_pos = len(text)
129 |
130 | if back_pos > -1:
131 | if cfg_pos > back_pos:
132 | background = text[back_pos +
133 | len(f'{BACKGROUND_TOKEN}: '):cfg_pos]
134 | else:
135 | background = text[back_pos +
136 | len(f'{BACKGROUND_TOKEN}: '):ans_pos]
137 | text = text[:ans_pos]
138 |
139 | tmp = text.split(f'{CONFIG_TOKEN}: ')
140 | for t in tmp:
141 | if t.strip():
142 | try:
143 | _ = json5.loads(t.strip())
144 | cfgs.append(t.strip())
145 | except Exception:
146 | continue
147 |
148 | if not (background or cfgs or answer):
149 | # There should always be ANSWER_TOKEN, if not, treat the entire content as answer
150 | answer = text
151 | return background, cfgs, answer
152 |
--------------------------------------------------------------------------------
/qwen_agent/agents/react_chat.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Dict, Iterator, List, Optional, Tuple, Union
3 |
4 | from qwen_agent.agents.fncall_agent import MAX_LLM_CALL_PER_RUN, FnCallAgent
5 | from qwen_agent.llm import BaseChatModel
6 | from qwen_agent.llm.schema import (ASSISTANT, CONTENT, DEFAULT_SYSTEM_MESSAGE,
7 | ROLE, ContentItem, Message)
8 | from qwen_agent.tools import BaseTool
9 | from qwen_agent.utils.utils import (get_basename_from_url,
10 | get_function_description,
11 | has_chinese_chars)
12 |
13 | PROMPT_REACT = """Answer the following questions as best you can. You have access to the following tools:
14 |
15 | {tool_descs}
16 |
17 | Use the following format:
18 |
19 | Question: the input question you must answer
20 | Thought: you should always think about what to do
21 | Action: the action to take, should be one of [{tool_names}]
22 | Action Input: the input to the action
23 | Observation: the result of the action
24 | ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
25 | Thought: I now know the final answer
26 | Final Answer: the final answer to the original input question
27 |
28 | Begin!
29 |
30 | Question: {query}"""
31 |
32 |
33 | class ReActChat(FnCallAgent):
34 | """This agent use ReAct format to call tools"""
35 |
36 | def __init__(self,
37 | function_list: Optional[List[Union[str, Dict,
38 | BaseTool]]] = None,
39 | llm: Optional[Union[Dict, BaseChatModel]] = None,
40 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
41 | name: Optional[str] = None,
42 | description: Optional[str] = None,
43 | files: Optional[List[str]] = None):
44 | super().__init__(function_list=function_list,
45 | llm=llm,
46 | system_message=system_message,
47 | name=name,
48 | description=description,
49 | files=files)
50 | stop = self.llm.generate_cfg.get('stop', [])
51 | fn_stop = ['Observation:', 'Observation:\n']
52 | self.llm.generate_cfg['stop'] = stop + [
53 | x for x in fn_stop if x not in stop
54 | ]
55 |
56 | def _run(self,
57 | messages: List[Message],
58 | lang: str = 'en',
59 | **kwargs) -> Iterator[List[Message]]:
60 | ori_messages = messages
61 | messages = self._preprocess_react_prompt(messages)
62 |
63 | num_llm_calls_available = MAX_LLM_CALL_PER_RUN
64 | response = []
65 | while True and num_llm_calls_available > 0:
66 | num_llm_calls_available -= 1
67 | output_stream = self._call_llm(messages=messages)
68 | output = []
69 |
70 | # Yield the streaming response
71 | response_tmp = copy.deepcopy(response)
72 | for output in output_stream:
73 | if output:
74 | if not response_tmp:
75 | yield output
76 | else:
77 | response_tmp[-1][CONTENT] = response[-1][
78 | CONTENT] + output[-1][CONTENT]
79 | yield response_tmp
80 | # Record the incremental response
81 | assert len(output) == 1 and output[-1][ROLE] == ASSISTANT
82 | if not response:
83 | response += output
84 | else:
85 | response[-1][CONTENT] += output[-1][CONTENT]
86 |
87 | output = output[-1][CONTENT]
88 |
89 | use_tool, action, action_input, text = self._detect_tool(output)
90 |
91 | if use_tool:
92 | observation = self._call_tool(action,
93 | action_input,
94 | messages=ori_messages)
95 | observation = f'\nObservation: {observation}\nThought: '
96 | response[-1][CONTENT] += observation
97 | yield response
98 | if isinstance(messages[-1][CONTENT], list):
99 | if not ('text' in messages[-1][CONTENT][-1]
100 | and messages[-1][CONTENT][-1]['text'].endswith(
101 | '\nThought: ')):
102 | if not text.startswith('\n'):
103 | text = '\n' + text
104 | messages[-1][CONTENT].append(
105 | ContentItem(
106 | text=text +
107 | f'\nAction: {action}\nAction Input:{action_input}'
108 | + observation))
109 | else:
110 | if not (messages[-1][CONTENT].endswith('\nThought: ')):
111 | if not text.startswith('\n'):
112 | text = '\n' + text
113 | messages[-1][
114 | CONTENT] += text + f'\nAction: {action}\nAction Input:{action_input}' + observation
115 | else:
116 | break
117 |
118 | def _detect_tool(self, text: str) -> Tuple[bool, str, str, str]:
119 | special_func_token = '\nAction:'
120 | special_args_token = '\nAction Input:'
121 | special_obs_token = '\nObservation:'
122 | func_name, func_args = None, None
123 | i = text.rfind(special_func_token)
124 | j = text.rfind(special_args_token)
125 | k = text.rfind(special_obs_token)
126 | if 0 <= i < j: # If the text has `Action` and `Action input`,
127 | if k < j: # but does not contain `Observation`,
128 | # then it is likely that `Observation` is ommited by the LLM,
129 | # because the output text may have discarded the stop word.
130 | text = text.rstrip() + special_obs_token # Add it back.
131 | k = text.rfind(special_obs_token)
132 | func_name = text[i + len(special_func_token):j].strip()
133 | func_args = text[j + len(special_args_token):k].strip()
134 | text = text[:i] # Return the response before tool call
135 |
136 | return (func_name is not None), func_name, func_args, text
137 |
138 | def _preprocess_react_prompt(self,
139 | messages: List[Message]) -> List[Message]:
140 | messages = copy.deepcopy(messages)
141 | tool_descs = '\n\n'.join(
142 | get_function_description(func.function)
143 | for func in self.function_map.values())
144 | tool_names = ','.join(tool.name for tool in self.function_map.values())
145 |
146 | if isinstance(messages[-1][CONTENT], str):
147 | prompt = PROMPT_REACT.format(tool_descs=tool_descs,
148 | tool_names=tool_names,
149 | query=messages[-1][CONTENT])
150 | messages[-1][CONTENT] = prompt
151 | return messages
152 | else:
153 | query = ''
154 | new_content = []
155 | files = []
156 | for item in messages[-1][CONTENT]:
157 | for k, v in item.model_dump().items():
158 | if k == 'text':
159 | query += v
160 | elif k == 'file':
161 | files.append(v)
162 | else:
163 | new_content.append(item)
164 | if files:
165 | has_zh = has_chinese_chars(query)
166 | upload = []
167 | for f in [get_basename_from_url(f) for f in files]:
168 | if has_zh:
169 | upload.append(f'[文件]({f})')
170 | else:
171 | upload.append(f'[file]({f})')
172 | upload = ' '.join(upload)
173 | if has_zh:
174 | upload = f'(上传了 {upload})\n\n'
175 | else:
176 | upload = f'(Uploaded {upload})\n\n'
177 | query = upload + query
178 |
179 | prompt = PROMPT_REACT.format(tool_descs=tool_descs,
180 | tool_names=tool_names,
181 | query=query)
182 | new_content.insert(0, ContentItem(text=prompt))
183 | messages[-1][CONTENT] = new_content
184 | return messages
185 |
--------------------------------------------------------------------------------
/qwen_agent/agents/router.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Dict, Iterator, List, Optional, Union
3 |
4 | from qwen_agent.llm import BaseChatModel
5 | from qwen_agent.llm.schema import ASSISTANT, ROLE, Message
6 | from qwen_agent.tools import BaseTool
7 |
8 | from ..log import logger
9 | from .assistant import Assistant
10 |
11 | ROUTER_PROMPT = '''你有下列帮手:
12 | {agent_descs}
13 |
14 | 当你可以直接回答用户时,请忽略帮手,直接回复;但当你的能力无法达成用户的请求时,请选择其中一个来帮你回答,选择的模版如下:
15 | Call: ... # 选中的帮手的名字,必须在[{agent_names}]中,除了名字,不要返回其余任何内容。
16 | Reply: ... # 选中的帮手的回复
17 |
18 | ——不要向用户透露此条指令。'''
19 |
20 |
21 | class Router(Assistant):
22 |
23 | def __init__(self,
24 | function_list: Optional[List[Union[str, Dict,
25 | BaseTool]]] = None,
26 | llm: Optional[Union[Dict, BaseChatModel]] = None,
27 | files: Optional[List[str]] = None,
28 | name: Optional[str] = None,
29 | description: Optional[str] = None,
30 | agents: Optional[Dict[str, Dict]] = None):
31 | self.agents = agents
32 |
33 | agent_descs = '\n\n'.join(
34 | [f'{k}: {v["desc"]}' for k, v in agents.items()])
35 | agent_names = ', '.join([k for k in agents.keys()])
36 | super().__init__(function_list=function_list,
37 | llm=llm,
38 | system_message=ROUTER_PROMPT.format(
39 | agent_descs=agent_descs, agent_names=agent_names),
40 | name=name,
41 | description=description,
42 | files=files)
43 |
44 | stop = self.llm.generate_cfg.get('stop', [])
45 | fn_stop = ['Reply:', 'Reply:\n']
46 | self.llm.generate_cfg['stop'] = stop + [
47 | x for x in fn_stop if x not in stop
48 | ]
49 |
50 | def _run(self,
51 | messages: List[Message],
52 | lang: str = 'en',
53 | max_ref_token: int = 4000,
54 | **kwargs) -> Iterator[List[Message]]:
55 | # This is a temporary plan to determine the source of a message
56 | messages_for_router = []
57 | for msg in messages:
58 | if msg[ROLE] == ASSISTANT:
59 | msg = self.supplement_name_special_token(msg)
60 | messages_for_router.append(msg)
61 | response = []
62 | for response in super()._run(messages=messages_for_router,
63 | lang=lang,
64 | max_ref_token=max_ref_token,
65 | **kwargs): # noqa
66 | yield response
67 |
68 | if 'Call:' in response[-1].content:
69 | # According to the rule in prompt to selected agent
70 | selected_agent_name = response[-1].content.split(
71 | 'Call:')[-1].strip()
72 | logger.info(f'Need help from {selected_agent_name}')
73 | selected_agent = self.agents[selected_agent_name]['obj']
74 | for response in selected_agent.run(messages=messages,
75 | lang=lang,
76 | max_ref_token=max_ref_token,
77 | **kwargs):
78 | for i in range(len(response)):
79 | if response[i].role == ASSISTANT:
80 | response[i].name = selected_agent_name
81 | yield response
82 |
83 | @staticmethod
84 | def supplement_name_special_token(message: Message) -> Message:
85 | message = copy.deepcopy(message)
86 | if not message.name:
87 | return message
88 |
89 | if isinstance(message['content'], str):
90 | message['content'] = 'Call: ' + message[
91 | 'name'] + '\nReply:' + message['content']
92 | return message
93 | assert isinstance(message['content'], list)
94 | for i, item in enumerate(message['content']):
95 | for k, v in item.model_dump().items():
96 | if k == 'text':
97 | message['content'][i][k] = 'Call: ' + message[
98 | 'name'] + '\nReply:' + message['content'][i][k]
99 | break
100 | return message
101 |
--------------------------------------------------------------------------------
/qwen_agent/agents/user_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Iterator, List
2 |
3 | from qwen_agent.agents.assistant import Assistant
4 | from qwen_agent.llm.schema import Message
5 |
6 | PENDING_USER_INPUT = ''
7 |
8 |
9 | class UserAgent(Assistant):
10 |
11 | def _run(self,
12 | messages: List[Message],
13 | lang: str = 'en',
14 | max_ref_token: int = 4000,
15 | **kwargs) -> Iterator[List[Message]]:
16 |
17 | yield [
18 | Message(role='user', content=PENDING_USER_INPUT, name=self.name)
19 | ]
20 |
--------------------------------------------------------------------------------
/qwen_agent/agents/write_from_scratch.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Iterator, List
3 |
4 | import json5
5 |
6 | from qwen_agent import Agent
7 | from qwen_agent.llm.schema import ASSISTANT, CONTENT, USER, Message
8 | from qwen_agent.prompts import DocQA, ExpandWriting, OutlineWriting
9 |
10 | default_plan = """{"action1": "summarize", "action2": "outline", "action3": "expand"}"""
11 |
12 |
13 | def is_roman_numeral(s):
14 | pattern = r'^(I|V|X|L|C|D|M)+'
15 | match = re.match(pattern, s)
16 | return match is not None
17 |
18 |
19 | class WriteFromScratch(Agent):
20 |
21 | def _run(self,
22 | messages: List[Message],
23 | knowledge: str = '',
24 | lang: str = 'en') -> Iterator[List[Message]]:
25 |
26 | response = [
27 | Message(ASSISTANT, f'>\n> Use Default plans: \n{default_plan}')
28 | ]
29 | yield response
30 | res_plans = json5.loads(default_plan)
31 |
32 | summ = ''
33 | outline = ''
34 | for plan_id in sorted(res_plans.keys()):
35 | plan = res_plans[plan_id]
36 | if plan == 'summarize':
37 | response.append(
38 | Message(ASSISTANT, '>\n> Summarize Browse Content: \n'))
39 | yield response
40 |
41 | if lang == 'zh':
42 | user_request = '总结参考资料的主要内容'
43 | elif lang == 'en':
44 | user_request = 'Summarize the main content of reference materials.'
45 | else:
46 | raise NotImplementedError
47 | sum_agent = DocQA(llm=self.llm)
48 | res_sum = sum_agent.run(messages=[Message(USER, user_request)],
49 | knowledge=knowledge,
50 | lang=lang)
51 | trunk = None
52 | for trunk in res_sum:
53 | yield response + trunk
54 | if trunk:
55 | response.extend(trunk)
56 | summ = trunk[-1][CONTENT]
57 | elif plan == 'outline':
58 | response.append(Message(ASSISTANT,
59 | '>\n> Generate Outline: \n'))
60 | yield response
61 |
62 | otl_agent = OutlineWriting(llm=self.llm)
63 | res_otl = otl_agent.run(messages=messages,
64 | knowledge=summ,
65 | lang=lang)
66 | trunk = None
67 | for trunk in res_otl:
68 | yield response + trunk
69 | if trunk:
70 | response.extend(trunk)
71 | outline = trunk[-1][CONTENT]
72 | elif plan == 'expand':
73 | response.append(Message(ASSISTANT, '>\n> Writing Text: \n'))
74 | yield response
75 |
76 | outline_list_all = outline.split('\n')
77 | outline_list = []
78 | for x in outline_list_all:
79 | if is_roman_numeral(x):
80 | outline_list.append(x)
81 |
82 | otl_num = len(outline_list)
83 | for i, v in enumerate(outline_list):
84 | response.append(Message(ASSISTANT, '>\n# '))
85 | yield response
86 |
87 | index = i + 1
88 | capture = v.strip()
89 | capture_later = ''
90 | if i < otl_num - 1:
91 | capture_later = outline_list[i + 1].strip()
92 | exp_agent = ExpandWriting(llm=self.llm)
93 | res_exp = exp_agent.run(
94 | messages=messages,
95 | knowledge=knowledge,
96 | outline=outline,
97 | index=str(index),
98 | capture=capture,
99 | capture_later=capture_later,
100 | lang=lang,
101 | )
102 | trunk = None
103 | for trunk in res_exp:
104 | yield response + trunk
105 | if trunk:
106 | response.extend(trunk)
107 | else:
108 | pass
109 |
--------------------------------------------------------------------------------
/qwen_agent/llm/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 |
3 | from qwen_agent.llm.base import LLM_REGISTRY
4 |
5 | from .base import BaseChatModel, ModelServiceError
6 | from .oai import TextChatAtOAI
7 | from .qwen_dashscope import QwenChatAtDS
8 | from .qwenvl_dashscope import QwenVLChatAtDS
9 |
10 |
11 | def get_chat_model(cfg: Optional[Dict] = None) -> BaseChatModel:
12 | """The interface of instantiating LLM objects.
13 |
14 | Args:
15 | cfg: The LLM configuration, one example is:
16 | llm_cfg = {
17 | # Use the model service provided by DashScope:
18 | 'model': 'qwen-max',
19 | 'model_server': 'dashscope',
20 | # Use your own model service compatible with OpenAI API:
21 | # 'model': 'Qwen',
22 | # 'model_server': 'http://127.0.0.1:7905/v1',
23 | # (Optional) LLM hyper-paramters:
24 | 'generate_cfg': {
25 | 'top_p': 0.8
26 | }
27 | }
28 |
29 | Returns:
30 | LLM object.
31 | """
32 | cfg = cfg or {}
33 | if 'model_type' in cfg:
34 | model_type = cfg['model_type']
35 | if model_type in LLM_REGISTRY:
36 | return LLM_REGISTRY[model_type](cfg)
37 | else:
38 | raise ValueError(
39 | f'Please set model_type from {str(LLM_REGISTRY.keys())}')
40 |
41 | # Deduce model_type from model and model_server if model_type is not provided:
42 |
43 | if 'model_server' in cfg:
44 | if cfg['model_server'].strip().startswith('http'):
45 | model_type = 'oai'
46 | return LLM_REGISTRY[model_type](cfg)
47 |
48 | model = cfg.get('model', '')
49 |
50 | if 'qwen-vl' in model:
51 | model_type = 'qwenvl_dashscope'
52 | return LLM_REGISTRY[model_type](cfg)
53 |
54 | if 'qwen' in model:
55 | model_type = 'qwen_dashscope'
56 | return LLM_REGISTRY[model_type](cfg)
57 |
58 | raise ValueError(f'Invalid model cfg: {cfg}')
59 |
60 |
61 | __all__ = [
62 | 'BaseChatModel',
63 | 'QwenChatAtDS',
64 | 'TextChatAtOAI',
65 | 'QwenVLChatAtDS',
66 | 'get_chat_model',
67 | 'ModelServiceError',
68 | ]
69 |
--------------------------------------------------------------------------------
/qwen_agent/llm/oai.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pprint import pformat
3 | from typing import Dict, Iterator, List, Optional
4 |
5 | import openai
6 |
7 | if openai.__version__.startswith('0.'):
8 | from openai.error import OpenAIError
9 | else:
10 | from openai import OpenAIError
11 |
12 | from qwen_agent.llm.base import ModelServiceError, register_llm
13 | from qwen_agent.llm.text_base import BaseTextChatModel
14 | from qwen_agent.log import logger
15 |
16 | from .schema import ASSISTANT, Message
17 |
18 |
19 | @register_llm('oai')
20 | class TextChatAtOAI(BaseTextChatModel):
21 |
22 | def __init__(self, cfg: Optional[Dict] = None):
23 | super().__init__(cfg)
24 | self.model = self.model or 'gpt-3.5-turbo'
25 | cfg = cfg or {}
26 |
27 | api_base = cfg.get(
28 | 'api_base',
29 | cfg.get(
30 | 'base_url',
31 | cfg.get('model_server', ''),
32 | ),
33 | ).strip()
34 |
35 | api_key = cfg.get('api_key', '')
36 | if not api_key:
37 | api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
38 | api_key = api_key.strip()
39 |
40 | if openai.__version__.startswith('0.'):
41 | if api_base:
42 | openai.api_base = api_base
43 | if api_key:
44 | openai.api_key = api_key
45 | self._chat_complete_create = openai.ChatCompletion.create
46 | else:
47 | api_kwargs = {}
48 | if api_base:
49 | api_kwargs['base_url'] = api_base
50 | if api_key:
51 | api_kwargs['api_key'] = api_key
52 |
53 | # OpenAI API v1 does not allow the following args, must pass by extra_body
54 | extra_params = ['top_k', 'repetition_penalty']
55 | if any((k in self.generate_cfg) for k in extra_params):
56 | self.generate_cfg['extra_body'] = {}
57 | for k in extra_params:
58 | if k in self.generate_cfg:
59 | self.generate_cfg['extra_body'][
60 | k] = self.generate_cfg.pop(k)
61 | if 'request_timeout' in self.generate_cfg:
62 | self.generate_cfg['timeout'] = self.generate_cfg.pop(
63 | 'request_timeout')
64 |
65 | def _chat_complete_create(*args, **kwargs):
66 | client = openai.OpenAI(**api_kwargs)
67 | return client.chat.completions.create(*args, **kwargs)
68 |
69 | self._chat_complete_create = _chat_complete_create
70 |
71 | def _chat_stream(
72 | self,
73 | messages: List[Message],
74 | delta_stream: bool = False,
75 | ) -> Iterator[List[Message]]:
76 | messages = [msg.model_dump() for msg in messages]
77 | logger.debug(f'*{pformat(messages, indent=2)}*')
78 | try:
79 | response = self._chat_complete_create(model=self.model,
80 | messages=messages,
81 | stream=True,
82 | **self.generate_cfg)
83 | if delta_stream:
84 | for chunk in response:
85 | if hasattr(chunk.choices[0].delta,
86 | 'content') and chunk.choices[0].delta.content:
87 | yield [
88 | Message(ASSISTANT, chunk.choices[0].delta.content)
89 | ]
90 | else:
91 | full_response = ''
92 | for chunk in response:
93 | if hasattr(chunk.choices[0].delta,
94 | 'content') and chunk.choices[0].delta.content:
95 | full_response += chunk.choices[0].delta.content
96 | yield [Message(ASSISTANT, full_response)]
97 | except OpenAIError as ex:
98 | raise ModelServiceError(exception=ex)
99 |
100 | def _chat_no_stream(self, messages: List[Message]) -> List[Message]:
101 | messages = [msg.model_dump() for msg in messages]
102 | logger.debug(f'*{pformat(messages, indent=2)}*')
103 | try:
104 | response = self._chat_complete_create(model=self.model,
105 | messages=messages,
106 | stream=False,
107 | **self.generate_cfg)
108 | return [Message(ASSISTANT, response.choices[0].message.content)]
109 | except OpenAIError as ex:
110 | raise ModelServiceError(exception=ex)
111 |
--------------------------------------------------------------------------------
/qwen_agent/llm/qwen_dashscope.py:
--------------------------------------------------------------------------------
1 | import os
2 | from http import HTTPStatus
3 | from pprint import pformat
4 | from typing import Dict, Iterator, List, Optional, Union
5 |
6 | import dashscope
7 |
8 | from qwen_agent.llm.base import ModelServiceError, register_llm
9 | from qwen_agent.log import logger
10 |
11 | from .schema import ASSISTANT, DEFAULT_SYSTEM_MESSAGE, SYSTEM, USER, Message
12 | from .text_base import BaseTextChatModel
13 |
14 |
15 | @register_llm('qwen_dashscope')
16 | class QwenChatAtDS(BaseTextChatModel):
17 |
18 | def __init__(self, cfg: Optional[Dict] = None):
19 | super().__init__(cfg)
20 | self.model = self.model or 'qwen-max'
21 |
22 | cfg = cfg or {}
23 | api_key = cfg.get('api_key', '')
24 | if not api_key:
25 | api_key = os.getenv('DASHSCOPE_API_KEY', 'EMPTY')
26 | api_key = api_key.strip()
27 | dashscope.api_key = api_key
28 |
29 | def _chat_stream(
30 | self,
31 | messages: List[Message],
32 | delta_stream: bool = False,
33 | ) -> Iterator[List[Message]]:
34 | messages = [msg.model_dump() for msg in messages]
35 | logger.debug(f'*{pformat(messages, indent=2)}*')
36 | response = dashscope.Generation.call(
37 | self.model,
38 | messages=messages, # noqa
39 | result_format='message',
40 | stream=True,
41 | **self.generate_cfg)
42 | if delta_stream:
43 | return self._delta_stream_output(response)
44 | else:
45 | return self._full_stream_output(response)
46 |
47 | def _chat_no_stream(
48 | self,
49 | messages: List[Message],
50 | ) -> List[Message]:
51 | messages = [msg.model_dump() for msg in messages]
52 | logger.debug(f'*{pformat(messages, indent=2)}*')
53 | response = dashscope.Generation.call(
54 | self.model,
55 | messages=messages, # noqa
56 | result_format='message',
57 | stream=False,
58 | **self.generate_cfg)
59 | if response.status_code == HTTPStatus.OK:
60 | return [
61 | Message(ASSISTANT, response.output.choices[0].message.content)
62 | ]
63 | else:
64 | raise ModelServiceError(code=response.code,
65 | message=response.message)
66 |
67 | def _chat_with_functions(
68 | self,
69 | messages: List[Message],
70 | functions: List[Dict],
71 | stream: bool = True,
72 | delta_stream: bool = False
73 | ) -> Union[List[Message], Iterator[List[Message]]]:
74 | if delta_stream:
75 | raise NotImplementedError
76 |
77 | messages = self._prepend_fncall_system(messages, functions)
78 |
79 | # Using text completion
80 | prompt = self._build_text_completion_prompt(messages)
81 | if stream:
82 | return self._text_completion_stream(prompt, delta_stream)
83 | else:
84 | return self._text_completion_no_stream(prompt)
85 |
86 | def _text_completion_no_stream(
87 | self,
88 | prompt: str,
89 | ) -> List[Message]:
90 | logger.debug(f'*{prompt}*')
91 | response = dashscope.Generation.call(self.model,
92 | prompt=prompt,
93 | result_format='message',
94 | stream=False,
95 | use_raw_prompt=True,
96 | **self.generate_cfg)
97 | if response.status_code == HTTPStatus.OK:
98 | return [
99 | Message(ASSISTANT, response.output.choices[0].message.content)
100 | ]
101 | else:
102 | raise ModelServiceError(code=response.code,
103 | message=response.message)
104 |
105 | def _text_completion_stream(
106 | self,
107 | prompt: str,
108 | delta_stream: bool = False,
109 | ) -> Iterator[List[Message]]:
110 | logger.debug(f'*{prompt}*')
111 | response = dashscope.Generation.call(
112 | self.model,
113 | prompt=prompt, # noqa
114 | result_format='message',
115 | stream=True,
116 | use_raw_prompt=True,
117 | **self.generate_cfg)
118 | if delta_stream:
119 | return self._delta_stream_output(response)
120 | else:
121 | return self._full_stream_output(response)
122 |
123 | @staticmethod
124 | def _build_text_completion_prompt(messages: List[Message]) -> str:
125 | im_start = '<|im_start|>'
126 | im_end = '<|im_end|>'
127 | if messages[0].role == SYSTEM:
128 | sys = messages[0].content
129 | assert isinstance(sys, str)
130 | prompt = f'{im_start}{SYSTEM}\n{sys}{im_end}'
131 | else:
132 | prompt = f'{im_start}{SYSTEM}\n{DEFAULT_SYSTEM_MESSAGE}{im_end}'
133 | if messages[-1].role != ASSISTANT:
134 | messages.append(Message(ASSISTANT, ''))
135 | for msg in messages:
136 | assert isinstance(msg.content, str)
137 | if msg.role == USER:
138 | query = msg.content.lstrip('\n').rstrip()
139 | prompt += f'\n{im_start}{USER}\n{query}{im_end}'
140 | elif msg.role == ASSISTANT:
141 | response = msg.content.lstrip('\n').rstrip()
142 | prompt += f'\n{im_start}{ASSISTANT}\n{response}{im_end}'
143 | assert prompt.endswith(im_end)
144 | prompt = prompt[:-len(im_end)]
145 | return prompt
146 |
147 | @staticmethod
148 | def _delta_stream_output(response) -> Iterator[List[Message]]:
149 | last_len = 0
150 | delay_len = 5
151 | in_delay = False
152 | text = ''
153 | for trunk in response:
154 | if trunk.status_code == HTTPStatus.OK:
155 | text = trunk.output.choices[0].message.content
156 | if (len(text) - last_len) <= delay_len:
157 | in_delay = True
158 | continue
159 | else:
160 | in_delay = False
161 | real_text = text[:-delay_len]
162 | now_rsp = real_text[last_len:]
163 | yield [Message(ASSISTANT, now_rsp)]
164 | last_len = len(real_text)
165 | else:
166 | raise ModelServiceError(code=trunk.code, message=trunk.message)
167 | if text and (in_delay or (last_len != len(text))):
168 | yield [Message(ASSISTANT, text[last_len:])]
169 |
170 | @staticmethod
171 | def _full_stream_output(response) -> Iterator[List[Message]]:
172 | for trunk in response:
173 | if trunk.status_code == HTTPStatus.OK:
174 | yield [
175 | Message(ASSISTANT, trunk.output.choices[0].message.content)
176 | ]
177 | else:
178 | raise ModelServiceError(code=trunk.code, message=trunk.message)
179 |
--------------------------------------------------------------------------------
/qwen_agent/llm/qwenvl_dashscope.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import re
4 | from http import HTTPStatus
5 | from pprint import pformat
6 | from typing import Dict, Iterator, List, Optional
7 |
8 | import dashscope
9 |
10 | from qwen_agent.llm.base import ModelServiceError, register_llm
11 | from qwen_agent.llm.function_calling import BaseFnCallModel
12 | from qwen_agent.llm.text_base import format_as_text_messages
13 | from qwen_agent.log import logger
14 |
15 | from .schema import ContentItem, Message
16 |
17 |
18 | @register_llm('qwenvl_dashscope')
19 | class QwenVLChatAtDS(BaseFnCallModel):
20 |
21 | def __init__(self, cfg: Optional[Dict] = None):
22 | super().__init__(cfg)
23 | self.model = self.model or 'qwen-vl-max'
24 |
25 | cfg = cfg or {}
26 | api_key = cfg.get('api_key', '')
27 | if not api_key:
28 | api_key = os.getenv('DASHSCOPE_API_KEY', 'EMPTY')
29 | api_key = api_key.strip()
30 | dashscope.api_key = api_key
31 |
32 | def _chat_stream(
33 | self,
34 | messages: List[Message],
35 | delta_stream: bool = False,
36 | ) -> Iterator[List[Message]]:
37 | if delta_stream:
38 | raise NotImplementedError
39 |
40 | messages = _format_local_files(messages)
41 | messages = [msg.model_dump() for msg in messages]
42 | logger.debug(f'*{pformat(messages, indent=2)}*')
43 | response = dashscope.MultiModalConversation.call(
44 | model=self.model,
45 | messages=messages,
46 | result_format='message',
47 | stream=True,
48 | **self.generate_cfg)
49 |
50 | for trunk in response:
51 | if trunk.status_code == HTTPStatus.OK:
52 | yield _extract_vl_response(trunk)
53 | else:
54 | raise ModelServiceError(code=trunk.code, message=trunk.message)
55 |
56 | def _chat_no_stream(
57 | self,
58 | messages: List[Message],
59 | ) -> List[Message]:
60 | messages = _format_local_files(messages)
61 | messages = [msg.model_dump() for msg in messages]
62 | logger.debug(f'*{pformat(messages, indent=2)}*')
63 | response = dashscope.MultiModalConversation.call(
64 | model=self.model,
65 | messages=messages,
66 | result_format='message',
67 | stream=False,
68 | **self.generate_cfg)
69 | if response.status_code == HTTPStatus.OK:
70 | return _extract_vl_response(response=response)
71 | else:
72 | raise ModelServiceError(code=response.code,
73 | message=response.message)
74 |
75 | def _postprocess_messages(self, messages: List[Message],
76 | fncall_mode: bool) -> List[Message]:
77 | messages = super()._postprocess_messages(messages,
78 | fncall_mode=fncall_mode)
79 | # Make VL return the same format as text models for easy usage
80 | messages = format_as_text_messages(messages)
81 | return messages
82 |
83 |
84 | # DashScope Qwen-VL requires the following format for local files:
85 | # Linux & Mac: file:///home/images/test.png
86 | # Windows: file://D:/images/abc.png
87 | def _format_local_files(messages: List[Message]) -> List[Message]:
88 | messages = copy.deepcopy(messages)
89 | for msg in messages:
90 | if isinstance(msg.content, list):
91 | for item in msg.content:
92 | if item.image:
93 | fname = item.image
94 | if not fname.startswith((
95 | 'http://',
96 | 'https://',
97 | 'file://',
98 | )):
99 | if fname.startswith('~'):
100 | fname = os.path.expanduser(fname)
101 | if re.match(r'^[A-Za-z]:\\', fname):
102 | fname = fname.replace('\\', '/')
103 | item.image = fname
104 | return messages
105 |
106 |
107 | def _extract_vl_response(response) -> List[Message]:
108 | output = response.output.choices[0].message
109 | text_content = []
110 | for item in output.content:
111 | if isinstance(item, str):
112 | text_content.append(ContentItem(text=item))
113 | else:
114 | for k, v in item.items():
115 | if k in ('text', 'box'):
116 | text_content.append(ContentItem(text=v))
117 | return [Message(role=output.role, content=text_content)]
118 |
--------------------------------------------------------------------------------
/qwen_agent/llm/schema.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Union
2 |
3 | from pydantic import BaseModel, field_validator, model_validator
4 |
5 | DEFAULT_SYSTEM_MESSAGE = 'You are a helpful assistant.'
6 |
7 | ROLE = 'role'
8 | CONTENT = 'content'
9 |
10 | SYSTEM = 'system'
11 | USER = 'user'
12 | ASSISTANT = 'assistant'
13 | FUNCTION = 'function'
14 |
15 | FILE = 'file'
16 | IMAGE = 'image'
17 |
18 |
19 | class BaseModelCompatibleDict(BaseModel):
20 |
21 | def __getitem__(self, item):
22 | return getattr(self, item)
23 |
24 | def __setitem__(self, key, value):
25 | setattr(self, key, value)
26 |
27 | def model_dump(self, **kwargs):
28 | return super().model_dump(exclude_none=True, **kwargs)
29 |
30 | def model_dump_json(self, **kwargs):
31 | return super().model_dump_json(exclude_none=True, **kwargs)
32 |
33 | def get(self, key, default=None):
34 | try:
35 | value = getattr(self, key)
36 | if value:
37 | return value
38 | else:
39 | return default
40 | except AttributeError:
41 | return default
42 |
43 | def __str__(self):
44 | return f'{self.model_dump()}'
45 |
46 |
47 | class FunctionCall(BaseModelCompatibleDict):
48 | name: str
49 | arguments: str
50 |
51 | def __init__(self, name: str, arguments: str):
52 | super().__init__(name=name, arguments=arguments)
53 |
54 | def __repr__(self):
55 | return f'FunctionCall({self.model_dump()})'
56 |
57 |
58 | class ContentItem(BaseModelCompatibleDict):
59 | text: Optional[str] = None
60 | image: Optional[str] = None
61 | file: Optional[str] = None
62 |
63 | def __init__(self,
64 | text: Optional[str] = None,
65 | image: Optional[str] = None,
66 | file: Optional[str] = None):
67 | super().__init__(text=text, image=image, file=file)
68 |
69 | @model_validator(mode='after')
70 | def check_exclusivity(self):
71 | provided_fields = 0
72 | if self.text is not None:
73 | provided_fields += 1
74 | if self.image:
75 | provided_fields += 1
76 | if self.file:
77 | provided_fields += 1
78 |
79 | if provided_fields != 1:
80 | raise ValueError(
81 | "Exactly one of 'text', 'image', or 'file' must be provided.")
82 | return self
83 |
84 | def __repr__(self):
85 | return f'ContentItem({self.model_dump()})'
86 |
87 | def get_type_and_value(self):
88 | (t, v), = self.model_dump().items()
89 | assert t in ('text', 'image', 'file')
90 | return t, v
91 |
92 |
93 | class Message(BaseModelCompatibleDict):
94 | role: str
95 | content: Union[str, List[ContentItem]]
96 | name: Optional[str] = None
97 | function_call: Optional[FunctionCall] = None
98 |
99 | def __init__(self,
100 | role: str,
101 | content: Optional[Union[str, List[ContentItem]]],
102 | name: Optional[str] = None,
103 | function_call: Optional[FunctionCall] = None,
104 | **kwargs):
105 | if content is None:
106 | content = ''
107 | super().__init__(role=role,
108 | content=content,
109 | name=name,
110 | function_call=function_call)
111 |
112 | def __repr__(self):
113 | return f'Message({self.model_dump()})'
114 |
115 | @field_validator('role')
116 | def role_checker(cls, value: str) -> str:
117 | if value not in [USER, ASSISTANT, SYSTEM, FUNCTION]:
118 | raise ValueError(
119 | f'{value} must be one of {",".join([USER, ASSISTANT, SYSTEM, FUNCTION])}'
120 | )
121 | return value
122 |
--------------------------------------------------------------------------------
/qwen_agent/llm/text_base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import List
3 |
4 | from qwen_agent.llm.function_calling import BaseFnCallModel
5 |
6 | from .schema import ASSISTANT, FUNCTION, SYSTEM, USER, Message
7 |
8 |
9 | class BaseTextChatModel(BaseFnCallModel, ABC):
10 |
11 | def _preprocess_messages(self, messages: List[Message]) -> List[Message]:
12 | messages = super()._preprocess_messages(messages)
13 | messages = format_as_text_messages(messages)
14 | return messages
15 |
16 | def _postprocess_messages(self, messages: List[Message],
17 | fncall_mode: bool) -> List[Message]:
18 | messages = super()._postprocess_messages(messages,
19 | fncall_mode=fncall_mode)
20 | messages = format_as_text_messages(messages)
21 | return messages
22 |
23 |
24 | def format_as_text_messages(
25 | multimodal_messages: List[Message]) -> List[Message]:
26 | text_messages = []
27 | for msg in multimodal_messages:
28 | assert msg.role in (USER, ASSISTANT, SYSTEM, FUNCTION)
29 | content = ''
30 | if isinstance(msg.content, str):
31 | content = msg.content
32 | elif isinstance(msg.content, list):
33 | for item in msg.content:
34 | if item.text:
35 | content += item.text
36 | # Discard multimodal content such as files and images
37 | else:
38 | raise TypeError
39 | text_messages.append(
40 | Message(role=msg.role,
41 | content=content,
42 | name=msg.name if msg.role == FUNCTION else None,
43 | function_call=msg.function_call))
44 | return text_messages
45 |
--------------------------------------------------------------------------------
/qwen_agent/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 |
5 | def setup_logger(level=None):
6 |
7 | if level is None:
8 | if int(os.getenv('QWEN_AGENT_DEBUG', '0').strip()):
9 | level = logging.DEBUG
10 | else:
11 | level = logging.INFO
12 |
13 | logger = logging.getLogger('qwen_agent_logger')
14 | logger.setLevel(level)
15 | handler = logging.StreamHandler()
16 | handler.setLevel(level)
17 | formatter = logging.Formatter(
18 | '%(asctime)s - %(filename)s - %(lineno)d - %(levelname)s - %(message)s'
19 | )
20 | handler.setFormatter(formatter)
21 | logger.addHandler(handler)
22 | return logger
23 |
24 |
25 | logger = setup_logger()
26 |
--------------------------------------------------------------------------------
/qwen_agent/memory/__init__.py:
--------------------------------------------------------------------------------
1 | from .memory import Memory
2 |
3 | __all__ = ['Memory']
4 |
--------------------------------------------------------------------------------
/qwen_agent/memory/memory.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Dict, Iterator, List, Optional, Union
3 |
4 | import json5
5 |
6 | from qwen_agent import Agent
7 | from qwen_agent.llm import BaseChatModel
8 | from qwen_agent.llm.schema import (ASSISTANT, DEFAULT_SYSTEM_MESSAGE, USER,
9 | Message)
10 | from qwen_agent.log import logger
11 | from qwen_agent.prompts import GenKeyword
12 | from qwen_agent.tools import BaseTool
13 | from qwen_agent.utils.utils import get_file_type
14 |
15 |
16 | class Memory(Agent):
17 | """Memory is special agent for file management.
18 |
19 | By default, this memory can use retrieval tool for RAG.
20 | """
21 |
22 | def __init__(self,
23 | function_list: Optional[List[Union[str, Dict,
24 | BaseTool]]] = None,
25 | llm: Optional[Union[Dict, BaseChatModel]] = None,
26 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
27 | files: Optional[List[str]] = None):
28 | function_list = function_list or []
29 | super().__init__(function_list=['retrieval'] + function_list,
30 | llm=llm,
31 | system_message=system_message)
32 |
33 | self.keygen = GenKeyword(llm=llm)
34 |
35 | self.system_files = files or []
36 |
37 | def _run(self,
38 | messages: List[Message],
39 | max_ref_token: int = 4000,
40 | lang: str = 'en',
41 | ignore_cache: bool = False) -> Iterator[List[Message]]:
42 | """This agent is responsible for processing the input files in the message.
43 |
44 | This method stores the files in the knowledge base, and retrievals the relevant parts
45 | based on the query and returning them.
46 | The currently supported file types include: .pdf, .docx, .pptx, and html.
47 |
48 | Args:
49 | messages: A list of messages.
50 | max_ref_token: Search window for reference materials.
51 | lang: Language.
52 | ignore_cache: Whether to reparse the same files.
53 |
54 | Yields:
55 | The message of retrieved documents.
56 | """
57 | # process files in messages
58 | session_files = self.get_all_files_of_messages(messages)
59 | files = self.system_files + session_files
60 | rag_files = []
61 | for file in files:
62 | if (file.split('.')[-1].lower() in [
63 | 'pdf', 'docx', 'pptx'
64 | ]) or get_file_type(file) == 'html':
65 | rag_files.append(file)
66 |
67 | if not rag_files:
68 | yield [Message(role=ASSISTANT, content='', name='memory')]
69 | else:
70 | query = ''
71 | # Only retrieval content according to the last user query if exists
72 | if messages and messages[-1].role == USER:
73 | if isinstance(messages[-1].content, str):
74 | query = messages[-1].content
75 | else:
76 | for item in messages[-1].content:
77 | if item.text:
78 | query += item.text
79 | if query:
80 | # Gen keyword
81 | *_, last = self.keygen.run([Message(USER, query)])
82 | keyword = last[-1].content
83 | keyword = keyword.strip()
84 | if keyword.startswith('```json'):
85 | keyword = keyword[len('```json'):]
86 | if keyword.endswith('```'):
87 | keyword = keyword[:-3]
88 | try:
89 | logger.info(keyword)
90 | keyword_dict = json5.loads(keyword)
91 | keyword_dict['text'] = query
92 | query = json.dumps(keyword_dict, ensure_ascii=False)
93 | except Exception:
94 | query = query
95 |
96 | content = self._call_tool('retrieval', {
97 | 'query': query,
98 | 'files': rag_files
99 | },
100 | ignore_cache=ignore_cache,
101 | max_token=max_ref_token)
102 |
103 | yield [Message(role=ASSISTANT, content=content, name='memory')]
104 |
105 | @staticmethod
106 | def get_all_files_of_messages(messages: List[Message]):
107 | files = []
108 | for msg in messages:
109 | if isinstance(msg.content, list):
110 | for item in msg.content:
111 | if item.file and item.file not in files:
112 | files.append(item.file)
113 | return files
114 |
--------------------------------------------------------------------------------
/qwen_agent/prompts/__init__.py:
--------------------------------------------------------------------------------
1 | """Prompts are special agents: using a prompt template to complete one QA."""
2 |
3 | from .continue_writing import ContinueWriting
4 | from .doc_qa import DocQA
5 | from .expand_writing import ExpandWriting
6 | from .gen_keyword import GenKeyword
7 | from .outline_writing import OutlineWriting
8 |
9 | __all__ = [
10 | 'DocQA', 'ContinueWriting', 'OutlineWriting', 'ExpandWriting', 'GenKeyword'
11 | ]
12 |
--------------------------------------------------------------------------------
/qwen_agent/prompts/continue_writing.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Iterator, List
3 |
4 | from qwen_agent import Agent
5 | from qwen_agent.llm.schema import CONTENT, Message
6 |
7 | PROMPT_TEMPLATE_ZH = """你是一个写作助手,请依据参考资料,根据给定的前置文本续写合适的内容。
8 | #参考资料:
9 | {ref_doc}
10 |
11 | #前置文本:
12 | {user_request}
13 |
14 | 保证续写内容和前置文本保持连贯,请开始续写:"""
15 |
16 | PROMPT_TEMPLATE_EN = """You are a writing assistant, please follow the reference materials and continue to write appropriate content based on the given previous text.
17 |
18 | # References:
19 | {ref_doc}
20 |
21 | # Previous text:
22 | {user_request}
23 |
24 | Please start writing directly, output only the continued text, do not repeat the previous text, do not say irrelevant words, and ensure that the continued content and the previous text remain consistent."""
25 |
26 | PROMPT_TEMPLATE = {
27 | 'zh': PROMPT_TEMPLATE_ZH,
28 | 'en': PROMPT_TEMPLATE_EN,
29 | }
30 |
31 |
32 | class ContinueWriting(Agent):
33 |
34 | def _run(self,
35 | messages: List[Message],
36 | knowledge: str = '',
37 | lang: str = 'en',
38 | **kwargs) -> Iterator[List[Message]]:
39 | messages = copy.deepcopy(messages)
40 | messages[-1][CONTENT] = PROMPT_TEMPLATE[lang].format(
41 | ref_doc=knowledge,
42 | user_request=messages[-1][CONTENT],
43 | )
44 | return self._call_llm(messages)
45 |
--------------------------------------------------------------------------------
/qwen_agent/prompts/doc_qa.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Iterator, List
3 |
4 | from qwen_agent import Agent
5 | from qwen_agent.llm.schema import CONTENT, ROLE, SYSTEM, Message
6 |
7 | PROMPT_TEMPLATE_ZH = """
8 | 请充分理解以下参考资料内容,组织出满足用户提问的条理清晰的回复。
9 | #参考资料:
10 | {ref_doc}
11 |
12 | """
13 |
14 | PROMPT_TEMPLATE_EN = """
15 | Please fully understand the content of the following reference materials and organize a clear response that meets the user's questions.
16 | # Reference materials:
17 | {ref_doc}
18 |
19 | """
20 |
21 | PROMPT_TEMPLATE = {
22 | 'zh': PROMPT_TEMPLATE_ZH,
23 | 'en': PROMPT_TEMPLATE_EN,
24 | }
25 |
26 |
27 | class DocQA(Agent):
28 |
29 | def _run(self,
30 | messages: List[Message],
31 | knowledge: str = '',
32 | lang: str = 'en',
33 | **kwargs) -> Iterator[List[Message]]:
34 | messages = copy.deepcopy(messages)
35 | system_prompt = PROMPT_TEMPLATE[lang].format(ref_doc=knowledge)
36 | if messages[0][ROLE] == SYSTEM:
37 | messages[0][CONTENT] += system_prompt
38 | else:
39 | messages.insert(0, Message(SYSTEM, system_prompt))
40 |
41 | return self._call_llm(messages=messages)
42 |
--------------------------------------------------------------------------------
/qwen_agent/prompts/expand_writing.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Iterator, List
3 |
4 | from qwen_agent import Agent
5 | from qwen_agent.llm.schema import CONTENT, Message
6 |
7 | PROMPT_TEMPLATE_ZH = """
8 | 你是一个写作助手,任务是依据参考资料,完成写作任务。
9 | #参考资料:
10 | {ref_doc}
11 |
12 | 写作标题是:{user_request}
13 | 大纲是:
14 | {outline}
15 |
16 | 此时你的任务是扩写第{index}个一级标题对应的章节:{capture}。注意每个章节负责撰写不同的内容,所以你不需要为了全面而涵盖之后的内容。请不要在这里生成大纲。只依据给定的参考资料来写,不要引入其余知识。
17 | """
18 |
19 | PROMPT_TEMPLATE_EN = """
20 | You are a writing assistant. Your task is to complete writing article based on reference materials.
21 |
22 | # References:
23 | {ref_doc}
24 |
25 | The title is: {user_request}
26 |
27 | The outline is:
28 | {outline}
29 |
30 | At this point, your task is to expand the chapter corresponding to the {index} first level title: {capture}.
31 | Note that each chapter is responsible for writing different content, so you don't need to cover the following content. Please do not generate an outline here. Write only based on the given reference materials and do not introduce other knowledge.
32 | """
33 |
34 | PROMPT_TEMPLATE = {
35 | 'zh': PROMPT_TEMPLATE_ZH,
36 | 'en': PROMPT_TEMPLATE_EN,
37 | }
38 |
39 |
40 | class ExpandWriting(Agent):
41 |
42 | def _run(self,
43 | messages: List[Message],
44 | knowledge: str = '',
45 | outline: str = '',
46 | index: str = '1',
47 | capture: str = '',
48 | capture_later: str = '',
49 | lang: str = 'en',
50 | **kwargs) -> Iterator[List[Message]]:
51 | messages = copy.deepcopy(messages)
52 | prompt = PROMPT_TEMPLATE[lang].format(
53 | ref_doc=knowledge,
54 | user_request=messages[-1][CONTENT],
55 | index=index,
56 | outline=outline,
57 | capture=capture,
58 | )
59 | if capture_later:
60 | if lang == 'zh':
61 | prompt = prompt + '请在涉及 ' + capture_later + ' 时停止。'
62 | elif lang == 'en':
63 | prompt = prompt + ' Please stop when writing ' + capture_later
64 | else:
65 | raise NotImplementedError
66 |
67 | messages[-1][CONTENT] = prompt
68 | return self._call_llm(messages)
69 |
--------------------------------------------------------------------------------
/qwen_agent/prompts/gen_keyword.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Dict, Iterator, List, Optional, Union
3 |
4 | from qwen_agent import Agent
5 | from qwen_agent.llm import get_chat_model
6 | from qwen_agent.llm.base import BaseChatModel
7 | from qwen_agent.llm.schema import CONTENT, DEFAULT_SYSTEM_MESSAGE, Message
8 | from qwen_agent.tools import BaseTool
9 |
10 | PROMPT_TEMPLATE_ZH = """请提取问题中的关键词,需要中英文均有,可以适量补充不在问题中但相关的关键词。关键词尽量切分为动词/名词/形容词等类型,不要长词组。关键词以JSON的格式给出,比如{{"keywords_zh": ["关键词1", "关键词2"], "keywords_en": ["keyword 1", "keyword 2"]}}
11 |
12 | Question: 这篇文章的作者是谁?
13 | Keywords: {{"keywords_zh": ["作者"], "keywords_en": ["author"]}}
14 | Observation: ...
15 |
16 | Question: 解释下图一
17 | Keywords: {{"keywords_zh": ["图一", "图 1"], "keywords_en": ["Figure 1"]}}
18 | Observation: ...
19 |
20 | Question: 核心公式
21 | Keywords: {{"keywords_zh": ["核心公式", "公式"], "keywords_en": ["core formula", "formula", "equation"]}}
22 | Observation: ...
23 |
24 | Question: {user_request}
25 | Keywords:
26 | """
27 |
28 | PROMPT_TEMPLATE_EN = """Please extract keywords from the question, both in Chinese and English, and supplement them appropriately with relevant keywords that are not in the question. Try to divide keywords into verb/noun/adjective types and avoid long phrases.
29 | Keywords are provided in JSON format, such as {{"keywords_zh": ["关键词1", "关键词2"], "keywords_en": ["keyword 1", "keyword 2"]}}
30 |
31 | Question: Who are the authors of this article?
32 | Keywords: {{"keywords_zh": ["作者"], "keywords_en": ["author"]}}
33 | Observation: ...
34 |
35 | Question: Explain Figure 1
36 | Keywords: {{"keywords_zh": ["图一", "图 1"], "keywords_en": ["Figure 1"]}}
37 | Observation: ...
38 |
39 | Question: core formula
40 | Keywords: {{"keywords_zh": ["核心公式", "公式"], "keywords_en": ["core formula", "formula", "equation"]}}
41 | Observation: ...
42 |
43 | Question: {user_request}
44 | Keywords:
45 | """
46 |
47 | PROMPT_TEMPLATE = {
48 | 'zh': PROMPT_TEMPLATE_ZH,
49 | 'en': PROMPT_TEMPLATE_EN,
50 | }
51 |
52 |
53 | class GenKeyword(Agent):
54 |
55 | # TODO: Adding a stop word is not conveient! We should fix this later.
56 | def __init__(self,
57 | function_list: Optional[List[Union[str, Dict,
58 | BaseTool]]] = None,
59 | llm: Optional[Union[Dict, BaseChatModel]] = None,
60 | system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
61 | **kwargs):
62 | if llm is not None: # TODO: Why this happens?
63 | llm = copy.deepcopy(llm)
64 | if isinstance(llm, dict):
65 | llm = get_chat_model(llm)
66 | stop = llm.generate_cfg.get('stop', [])
67 | key_stop = ['Observation:', 'Observation:\n']
68 | llm.generate_cfg['stop'] = stop + [
69 | x for x in key_stop if x not in stop
70 | ]
71 | super().__init__(function_list, llm, system_message, **kwargs)
72 |
73 | def _run(self,
74 | messages: List[Message],
75 | lang: str = 'en',
76 | **kwargs) -> Iterator[List[Message]]:
77 | messages = copy.deepcopy(messages)
78 | messages[-1][CONTENT] = PROMPT_TEMPLATE[lang].format(
79 | user_request=messages[-1][CONTENT])
80 | return self._call_llm(messages)
81 |
--------------------------------------------------------------------------------
/qwen_agent/prompts/outline_writing.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Iterator, List
3 |
4 | from qwen_agent import Agent
5 | from qwen_agent.llm.schema import CONTENT, Message
6 |
7 | PROMPT_TEMPLATE_ZH = """
8 | 你是一个写作助手,任务是充分理解参考资料,从而完成写作。
9 | #参考资料:
10 | {ref_doc}
11 |
12 | 写作标题是:{user_request}
13 |
14 | 为了完成以上写作任务,请先列出大纲。回复只需包含大纲。大纲的一级标题全部以罗马数字计数。只依据给定的参考资料来写,不要引入其余知识。
15 | """
16 |
17 | PROMPT_TEMPLATE_EN = """
18 | You are a writing assistant. Your task is to complete writing article based on reference materials.
19 |
20 | # References:
21 | {ref_doc}
22 |
23 | The title is: {user_request}
24 |
25 | In order to complete the above writing tasks, please provide an outline first. The reply only needs to include an outline. The first level titles of the outline are all counted in Roman numerals. Write only based on the given reference materials and do not introduce other knowledge.
26 | """
27 |
28 | PROMPT_TEMPLATE = {
29 | 'zh': PROMPT_TEMPLATE_ZH,
30 | 'en': PROMPT_TEMPLATE_EN,
31 | }
32 |
33 |
34 | class OutlineWriting(Agent):
35 |
36 | def _run(self,
37 | messages: List[Message],
38 | knowledge: str = '',
39 | lang: str = 'en',
40 | **kwargs) -> Iterator[List[Message]]:
41 | messages = copy.deepcopy(messages)
42 | messages[-1][CONTENT] = PROMPT_TEMPLATE[lang].format(
43 | ref_doc=knowledge,
44 | user_request=messages[-1][CONTENT],
45 | )
46 | return self._call_llm(messages)
47 |
--------------------------------------------------------------------------------
/qwen_agent/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .amap_weather import AmapWeather
2 | from .base import TOOL_REGISTRY, BaseTool
3 | # from .code_interpreter import CodeInterpreter
4 | from .doc_parser import DocParser
5 | from .image_gen import ImageGen
6 | from .retrieval import Retrieval
7 | from .similarity_search import SimilaritySearch
8 | from .storage import Storage
9 | from .web_extractor import WebExtractor
10 |
11 |
12 | def call_tool(plugin_name: str, plugin_args: str) -> str:
13 | if plugin_name in TOOL_REGISTRY:
14 | return TOOL_REGISTRY[plugin_name].call(plugin_args)
15 | else:
16 | raise NotImplementedError
17 |
18 |
19 | __all__ = [
20 | 'BaseTool', 'CodeInterpreter', 'ImageGen', 'AmapWeather', 'TOOL_REGISTRY',
21 | 'DocParser', 'SimilaritySearch', 'Storage', 'Retrieval', 'WebExtractor'
22 | ]
23 |
--------------------------------------------------------------------------------
/qwen_agent/tools/amap_weather.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 |
4 | import pandas as pd
5 | import requests
6 |
7 | from qwen_agent.tools.base import BaseTool, register_tool
8 |
9 |
10 | @register_tool('amap_weather')
11 | class AmapWeather(BaseTool):
12 | description = '获取对应城市的天气数据'
13 | parameters = [{
14 | 'name': 'location',
15 | 'type': 'string',
16 | 'description': 'get temperature for a specific location',
17 | 'required': True
18 | }]
19 |
20 | def __init__(self, cfg: Optional[Dict] = None):
21 | super().__init__(cfg)
22 |
23 | # remote call
24 | self.url = 'https://restapi.amap.com/v3/weather/weatherInfo?city={city}&key={key}'
25 | self.city_df = pd.read_excel(
26 | 'https://modelscope.oss-cn-beijing.aliyuncs.com/resource/agent/AMap_adcode_citycode.xlsx'
27 | )
28 |
29 | self.token = self.cfg.get('token', os.environ.get('AMAP_TOKEN', ''))
30 | assert self.token != '', 'weather api token must be acquired through ' \
31 | 'https://lbs.amap.com/api/webservice/guide/create-project/get-key and set by AMAP_TOKEN'
32 |
33 | def get_city_adcode(self, city_name):
34 | filtered_df = self.city_df[self.city_df['中文名'] == city_name]
35 | if len(filtered_df['adcode'].values) == 0:
36 | raise ValueError(
37 | f'location {city_name} not found, availables are {self.city_df["中文名"]}'
38 | )
39 | else:
40 | return filtered_df['adcode'].values[0]
41 |
42 | def call(self, params: Union[str, dict], **kwargs) -> str:
43 | params = self._verify_json_format_args(params)
44 |
45 | location = params['location']
46 | response = requests.get(
47 | self.url.format(city=self.get_city_adcode(location),
48 | key=self.token))
49 | data = response.json()
50 | if data['status'] == '0':
51 | raise RuntimeError(data)
52 | else:
53 | weather = data['lives'][0]['weather']
54 | temperature = data['lives'][0]['temperature']
55 | return f'{location}的天气是{weather}温度是{temperature}度。'
56 |
--------------------------------------------------------------------------------
/qwen_agent/tools/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, List, Optional, Union
3 |
4 | import json5
5 |
6 | from qwen_agent.utils.utils import logger
7 |
8 | TOOL_REGISTRY = {}
9 |
10 |
11 | def register_tool(name, allow_overwrite=False):
12 |
13 | def decorator(cls):
14 | if name in TOOL_REGISTRY:
15 | if allow_overwrite:
16 | logger.warning(
17 | f'Tool `{name}` already exists! Overwriting with class {cls}.'
18 | )
19 | else:
20 | raise ValueError(
21 | f'Tool `{name}` already exists! Please ensure that the tool name is unique.'
22 | )
23 | if cls.name and (cls.name != name):
24 | raise ValueError(
25 | f'{cls.__name__}.name="{cls.name}" conflicts with @register_tool(name="{name}").'
26 | )
27 | cls.name = name
28 | TOOL_REGISTRY[name] = cls
29 |
30 | return cls
31 |
32 | return decorator
33 |
34 |
35 | class BaseTool(ABC):
36 | name: str = ''
37 | description: str = ''
38 | parameters: List[Dict] = []
39 |
40 | def __init__(self, cfg: Optional[Dict] = None):
41 | self.cfg = cfg or {}
42 | if not self.name:
43 | raise ValueError(
44 | f'You must set {self.__class__.__name__}.name, either by @register_tool(name=...) or explicitly setting {self.__class__.__name__}.name'
45 | )
46 |
47 | self.name_for_human = self.cfg.get('name_for_human', self.name)
48 | if not hasattr(self, 'args_format'):
49 | self.args_format = self.cfg.get('args_format', '此工具的输入应为JSON对象。')
50 | self.function = self._build_function()
51 | self.file_access = False
52 |
53 | @abstractmethod
54 | def call(self, params: Union[str, dict],
55 | **kwargs) -> Union[str, list, dict]:
56 | """The interface for calling tools.
57 |
58 | Each tool needs to implement this function, which is the workflow of the tool.
59 |
60 | Args:
61 | params: The parameters of func_call.
62 | kwargs: Additional parameters for calling tools.
63 |
64 | Returns:
65 | The result returned by the tool, implemented in the subclass.
66 | """
67 | raise NotImplementedError
68 |
69 | def _verify_json_format_args(self,
70 | params: Union[str, dict]) -> Union[str, dict]:
71 | """Verify the parameters of the function call"""
72 | try:
73 | if isinstance(params, str):
74 | params_json = json5.loads(params)
75 | else:
76 | params_json = params
77 | for param in self.parameters:
78 | if 'required' in param and param['required']:
79 | if param['name'] not in params_json:
80 | raise ValueError('Parameters %s is required!' %
81 | param['name'])
82 | return params_json
83 | except Exception:
84 | raise ValueError('Parameters cannot be converted to Json Format!')
85 |
86 | def _build_function(self) -> dict:
87 | return {
88 | 'name_for_human': self.name_for_human,
89 | 'name': self.name,
90 | 'description': self.description,
91 | 'parameters': self.parameters,
92 | 'args_format': self.args_format
93 | }
94 |
--------------------------------------------------------------------------------
/qwen_agent/tools/code_interpreter.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import atexit
3 | import base64
4 | import glob
5 | import io
6 | import json
7 | import os
8 | import queue
9 | import re
10 | import shutil
11 | import signal
12 | import stat
13 | import subprocess
14 | import sys
15 | import time
16 | import uuid
17 | from pathlib import Path
18 | from typing import Dict, List, Optional, Union
19 |
20 | import json5
21 | import matplotlib
22 | import PIL.Image
23 | from jupyter_client import BlockingKernelClient
24 |
25 | from qwen_agent.log import logger
26 | from qwen_agent.tools.base import BaseTool, register_tool
27 | from qwen_agent.utils.utils import (extract_code, print_traceback,
28 | save_url_to_local_work_dir)
29 |
30 | WORK_DIR = os.getenv('M6_CODE_INTERPRETER_WORK_DIR',
31 | os.getcwd() + '/workspace/ci_workspace/')
32 |
33 |
34 | def _fix_secure_write_for_code_interpreter():
35 | if 'linux' in sys.platform.lower():
36 | os.makedirs(WORK_DIR, exist_ok=True)
37 | fname = os.path.join(WORK_DIR,
38 | f'test_file_permission_{os.getpid()}.txt')
39 | if os.path.exists(fname):
40 | os.remove(fname)
41 | with os.fdopen(
42 | os.open(fname, os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o0600),
43 | 'w') as f:
44 | f.write('test')
45 | file_mode = stat.S_IMODE(os.stat(fname).st_mode) & 0o6677
46 | if file_mode != 0o0600:
47 | os.environ['JUPYTER_ALLOW_INSECURE_WRITES'] = '1'
48 | if os.path.exists(fname):
49 | os.remove(fname)
50 |
51 |
52 | _fix_secure_write_for_code_interpreter()
53 |
54 | LAUNCH_KERNEL_PY = """
55 | from ipykernel import kernelapp as app
56 | app.launch_new_instance()
57 | """
58 |
59 | INIT_CODE_FILE = str(
60 | Path(__file__).absolute().parent / 'resource' /
61 | 'code_interpreter_init_kernel.py')
62 |
63 | ALIB_FONT_FILE = str(
64 | Path(__file__).absolute().parent / 'resource' /
65 | 'AlibabaPuHuiTi-3-45-Light.ttf')
66 |
67 | _KERNEL_CLIENTS: Dict[int, BlockingKernelClient] = {}
68 | _MISC_SUBPROCESSES: Dict[str, subprocess.Popen] = {}
69 |
70 |
71 | def _start_kernel(pid) -> BlockingKernelClient:
72 | connection_file = os.path.join(WORK_DIR,
73 | f'kernel_connection_file_{pid}.json')
74 | launch_kernel_script = os.path.join(WORK_DIR, f'launch_kernel_{pid}.py')
75 | for f in [connection_file, launch_kernel_script]:
76 | if os.path.exists(f):
77 | logger.info(f'WARNING: {f} already exists')
78 | os.remove(f)
79 |
80 | os.makedirs(WORK_DIR, exist_ok=True)
81 | with open(launch_kernel_script, 'w') as fout:
82 | fout.write(LAUNCH_KERNEL_PY)
83 |
84 | kernel_process = subprocess.Popen(
85 | [
86 | sys.executable,
87 | launch_kernel_script,
88 | '--IPKernelApp.connection_file',
89 | connection_file,
90 | '--matplotlib=inline',
91 | '--quiet',
92 | ],
93 | cwd=WORK_DIR,
94 | )
95 | _MISC_SUBPROCESSES[f'kc_{kernel_process.pid}'] = kernel_process
96 | logger.info(f"INFO: kernel process's PID = {kernel_process.pid}")
97 |
98 | # Wait for kernel connection file to be written
99 | while True:
100 | if not os.path.isfile(connection_file):
101 | time.sleep(0.1)
102 | else:
103 | # Keep looping if JSON parsing fails, file may be partially written
104 | try:
105 | with open(connection_file, 'r') as fp:
106 | json.load(fp)
107 | break
108 | except json.JSONDecodeError:
109 | pass
110 |
111 | # Client
112 | kc = BlockingKernelClient(connection_file=connection_file)
113 | asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
114 | kc.load_connection_file()
115 | kc.start_channels()
116 | kc.wait_for_ready()
117 | return kc
118 |
119 |
120 | def _kill_kernels_and_subprocesses(sig_num=None, _frame=None):
121 | for v in _KERNEL_CLIENTS.values():
122 | v.shutdown()
123 | for k in list(_KERNEL_CLIENTS.keys()):
124 | del _KERNEL_CLIENTS[k]
125 |
126 | for v in _MISC_SUBPROCESSES.values():
127 | v.terminate()
128 | for k in list(_MISC_SUBPROCESSES.keys()):
129 | del _MISC_SUBPROCESSES[k]
130 |
131 | if sig_num == signal.SIGINT:
132 | raise KeyboardInterrupt()
133 |
134 |
135 | atexit.register(_kill_kernels_and_subprocesses)
136 | signal.signal(signal.SIGTERM, _kill_kernels_and_subprocesses)
137 | signal.signal(signal.SIGINT, _kill_kernels_and_subprocesses)
138 |
139 |
140 | def _serve_image(image_base64: str) -> str:
141 | image_file = f'{uuid.uuid4()}.png'
142 | local_image_file = os.path.join(WORK_DIR, image_file)
143 |
144 | png_bytes = base64.b64decode(image_base64)
145 | assert isinstance(png_bytes, bytes)
146 | bytes_io = io.BytesIO(png_bytes)
147 | PIL.Image.open(bytes_io).save(local_image_file, 'png')
148 |
149 | static_url = os.getenv('M6_CODE_INTERPRETER_STATIC_URL',
150 | 'http://127.0.0.1:7865/static')
151 |
152 | # Hotfix: Temporarily generate image URL proxies for code interpreter to display in gradio
153 | # Todo: Generate real url
154 | if static_url == 'http://127.0.0.1:7865/static':
155 | if 'image_service' not in _MISC_SUBPROCESSES:
156 | try:
157 | # run a fastapi server for image show in gradio demo by http://127.0.0.1:7865/figure_name
158 | _MISC_SUBPROCESSES['image_service'] = subprocess.Popen([
159 | 'python',
160 | Path(__file__).absolute().parent / 'resource' /
161 | 'image_service.py'
162 | ])
163 | except Exception:
164 | print_traceback()
165 |
166 | image_url = f'{static_url}/{image_file}'
167 |
168 | return image_url
169 |
170 |
171 | def _escape_ansi(line: str) -> str:
172 | ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
173 | return ansi_escape.sub('', line)
174 |
175 |
176 | def _fix_matplotlib_cjk_font_issue():
177 | ttf_name = os.path.basename(ALIB_FONT_FILE)
178 | local_ttf = os.path.join(
179 | os.path.abspath(
180 | os.path.join(matplotlib.matplotlib_fname(), os.path.pardir)),
181 | 'fonts', 'ttf', ttf_name)
182 | if not os.path.exists(local_ttf):
183 | try:
184 | shutil.copy(ALIB_FONT_FILE, local_ttf)
185 | font_list_cache = os.path.join(matplotlib.get_cachedir(),
186 | 'fontlist-*.json')
187 | for cache_file in glob.glob(font_list_cache):
188 | with open(cache_file) as fin:
189 | cache_content = fin.read()
190 | if ttf_name not in cache_content:
191 | os.remove(cache_file)
192 | except Exception:
193 | print_traceback()
194 |
195 |
196 | def _execute_code(kc: BlockingKernelClient, code: str) -> str:
197 | kc.wait_for_ready()
198 | kc.execute(code)
199 | result = ''
200 | image_idx = 0
201 | while True:
202 | text = ''
203 | image = ''
204 | finished = False
205 | msg_type = 'error'
206 | try:
207 | msg = kc.get_iopub_msg()
208 | msg_type = msg['msg_type']
209 | if msg_type == 'status':
210 | if msg['content'].get('execution_state') == 'idle':
211 | finished = True
212 | elif msg_type == 'execute_result':
213 | text = msg['content']['data'].get('text/plain', '')
214 | if 'image/png' in msg['content']['data']:
215 | image_b64 = msg['content']['data']['image/png']
216 | image_url = _serve_image(image_b64)
217 | image_idx += 1
218 | image = '' % (image_idx, image_url)
219 | elif msg_type == 'display_data':
220 | if 'image/png' in msg['content']['data']:
221 | image_b64 = msg['content']['data']['image/png']
222 | image_url = _serve_image(image_b64)
223 | image_idx += 1
224 | image = '' % (image_idx, image_url)
225 | else:
226 | text = msg['content']['data'].get('text/plain', '')
227 | elif msg_type == 'stream':
228 | msg_type = msg['content']['name'] # stdout, stderr
229 | text = msg['content']['text']
230 | elif msg_type == 'error':
231 | text = _escape_ansi('\n'.join(msg['content']['traceback']))
232 | if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
233 | text = 'Timeout: Code execution exceeded the time limit.'
234 | except queue.Empty:
235 | text = 'Timeout: Code execution exceeded the time limit.'
236 | finished = True
237 | except Exception:
238 | text = 'The code interpreter encountered an unexpected error.'
239 | print_traceback()
240 | finished = True
241 | if text:
242 | result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
243 | if image:
244 | result += f'\n\n{image}'
245 | if finished:
246 | break
247 | result = result.lstrip('\n')
248 | return result
249 |
250 |
251 | @register_tool('code_interpreter')
252 | class CodeInterpreter(BaseTool):
253 | description = 'Python代码沙盒,可用于执行Python代码。'
254 | parameters = [{
255 | 'name': 'code',
256 | 'type': 'string',
257 | 'description': '待执行的代码',
258 | 'required': True
259 | }]
260 |
261 | def __init__(self, cfg: Optional[Dict] = None):
262 | self.args_format = '此工具的输入应为Markdown代码块。'
263 | super().__init__(cfg)
264 | self.file_access = True
265 |
266 | def call(self,
267 | params: Union[str, dict],
268 | files: List[str] = None,
269 | timeout: Optional[int] = 30,
270 | **kwargs) -> str:
271 | try:
272 | params = json5.loads(params)
273 | code = params['code']
274 | except Exception:
275 | code = extract_code(params)
276 |
277 | if not code.strip():
278 | return ''
279 | # download file
280 | if files:
281 | os.makedirs(WORK_DIR, exist_ok=True)
282 | for file in files:
283 | try:
284 | save_url_to_local_work_dir(file, WORK_DIR)
285 | except Exception:
286 | print_traceback()
287 |
288 | pid: int = os.getpid()
289 | if pid in _KERNEL_CLIENTS:
290 | kc = _KERNEL_CLIENTS[pid]
291 | else:
292 | _fix_matplotlib_cjk_font_issue()
293 | kc = _start_kernel(pid)
294 | with open(INIT_CODE_FILE) as fin:
295 | start_code = fin.read()
296 | start_code = start_code.replace('{{M6_FONT_PATH}}',
297 | repr(ALIB_FONT_FILE)[1:-1])
298 | logger.info(_execute_code(kc, start_code))
299 | _KERNEL_CLIENTS[pid] = kc
300 |
301 | if timeout:
302 | code = f'_M6CountdownTimer.start({timeout})\n{code}'
303 |
304 | fixed_code = []
305 | for line in code.split('\n'):
306 | fixed_code.append(line)
307 | if line.startswith('sns.set_theme('):
308 | fixed_code.append(
309 | 'plt.rcParams["font.family"] = _m6_font_prop.get_name()')
310 | fixed_code = '\n'.join(fixed_code)
311 | fixed_code += '\n\n' # Prevent code not executing in notebook due to no line breaks at the end
312 | result = _execute_code(kc, fixed_code)
313 |
314 | if timeout:
315 | _execute_code(kc, '_M6CountdownTimer.cancel()')
316 |
317 | return result if result.strip() else 'Finished execution.'
318 |
319 |
320 | #
321 | # The _BasePolicy and AnyThreadEventLoopPolicy below are borrowed from Tornado.
322 | # Ref: https://www.tornadoweb.org/en/stable/_modules/tornado/platform/asyncio.html#AnyThreadEventLoopPolicy
323 | #
324 |
325 | if sys.platform == 'win32' and hasattr(asyncio,
326 | 'WindowsSelectorEventLoopPolicy'):
327 | _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
328 | else:
329 | _BasePolicy = asyncio.DefaultEventLoopPolicy
330 |
331 |
332 | class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
333 | """Event loop policy that allows loop creation on any thread.
334 |
335 | The default `asyncio` event loop policy only automatically creates
336 | event loops in the main threads. Other threads must create event
337 | loops explicitly or `asyncio.get_event_loop` (and therefore
338 | `.IOLoop.current`) will fail. Installing this policy allows event
339 | loops to be created automatically on any thread.
340 |
341 | Usage::
342 | asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
343 | """
344 |
345 | def get_event_loop(self) -> asyncio.AbstractEventLoop:
346 | try:
347 | return super().get_event_loop()
348 | except RuntimeError:
349 | # "There is no current event loop in thread %r"
350 | loop = self.new_event_loop()
351 | self.set_event_loop(loop)
352 | return loop
353 |
--------------------------------------------------------------------------------
/qwen_agent/tools/doc_parser.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import json
3 | import os
4 | import re
5 | from typing import Dict, Optional, Union
6 | from urllib.parse import unquote, urlparse
7 |
8 | import json5
9 | from pydantic import BaseModel
10 |
11 | from qwen_agent.log import logger
12 | from qwen_agent.tools.base import BaseTool, register_tool
13 | from qwen_agent.tools.storage import Storage
14 | from qwen_agent.utils.doc_parser import parse_doc, parse_html_bs
15 | from qwen_agent.utils.utils import (get_file_type, hash_sha256, is_local_path,
16 | print_traceback,
17 | save_url_to_local_work_dir)
18 |
19 |
20 | class FileTypeNotImplError(NotImplementedError):
21 | pass
22 |
23 |
24 | class Record(BaseModel):
25 | url: str
26 | time: str
27 | source: str
28 | raw: list
29 | title: str
30 | topic: str
31 | checked: bool
32 | session: list
33 |
34 | def to_dict(self) -> dict:
35 | return {
36 | 'url': self.url,
37 | 'time': self.time,
38 | 'source': self.source,
39 | 'raw': self.raw,
40 | 'title': self.title,
41 | 'topic': self.topic,
42 | 'checked': self.checked,
43 | 'session': self.session
44 | }
45 |
46 |
47 | def sanitize_chrome_file_path(file_path: str) -> str:
48 | # For Linux and macOS.
49 | if os.path.exists(file_path):
50 | return file_path
51 |
52 | # For native Windows, drop the leading '/' in '/C:/'
53 | win_path = file_path
54 | if win_path.startswith('/'):
55 | win_path = win_path[1:]
56 | if os.path.exists(win_path):
57 | return win_path
58 |
59 | # For Windows + WSL.
60 | if re.match(r'^[A-Za-z]:/', win_path):
61 | wsl_path = f'/mnt/{win_path[0].lower()}/{win_path[3:]}'
62 | if os.path.exists(wsl_path):
63 | return wsl_path
64 |
65 | # For native Windows, replace / with \.
66 | win_path = win_path.replace('/', '\\')
67 | if os.path.exists(win_path):
68 | return win_path
69 |
70 | return file_path
71 |
72 |
73 | def process_file(url: str, db: Storage = None):
74 | logger.info('Starting cache pages...')
75 | url = url
76 | if url.split('.')[-1].lower() in ['pdf', 'docx', 'pptx']:
77 | date1 = datetime.datetime.now()
78 |
79 | if url.startswith('https://') or url.startswith('http://') or re.match(
80 | r'^[A-Za-z]:\\', url) or re.match(r'^[A-Za-z]:/', url):
81 | pdf_path = url
82 | else:
83 | parsed_url = urlparse(url)
84 | pdf_path = unquote(parsed_url.path)
85 | pdf_path = sanitize_chrome_file_path(pdf_path)
86 |
87 | try:
88 | if not is_local_path(url):
89 | # download
90 | file_tmp_path = save_url_to_local_work_dir(
91 | pdf_path,
92 | db.root,
93 | new_name=hash_sha256(url) + '.' +
94 | pdf_path.split('.')[-1].lower())
95 | pdf_content = parse_doc(file_tmp_path)
96 | else:
97 | pdf_content = parse_doc(pdf_path)
98 | date2 = datetime.datetime.now()
99 | logger.info('Parsing pdf time: ' + str(date2 - date1))
100 | content = pdf_content
101 | source = 'doc'
102 | title = pdf_path.split('/')[-1].split('\\')[-1].split('.')[0]
103 | except Exception:
104 | print_traceback()
105 | return 'failed'
106 | else:
107 | if not is_local_path(url):
108 | file_tmp_path = save_url_to_local_work_dir(
109 | url, db.root, new_name=hash_sha256(url))
110 | else:
111 | file_tmp_path = url
112 | file_source = get_file_type(file_tmp_path)
113 | if file_source == 'html':
114 | try:
115 | content = parse_html_bs(file_tmp_path)
116 | title = content[0]['metadata']['title']
117 | except Exception:
118 | print_traceback()
119 | return 'failed'
120 | source = 'html'
121 | else:
122 | raise FileTypeNotImplError
123 |
124 | # save real data
125 | now_time = str(datetime.date.today())
126 | new_record = Record(url=url,
127 | time=now_time,
128 | source=source,
129 | raw=content,
130 | title=title,
131 | topic='',
132 | checked=True,
133 | session=[]).to_dict()
134 | new_record_str = json.dumps(new_record, ensure_ascii=False)
135 | db.put(hash_sha256(url), new_record_str)
136 |
137 | return new_record
138 |
139 |
140 | @register_tool('doc_parser')
141 | class DocParser(BaseTool):
142 | description = '解析并存储一个文件,返回解析后的文件内容'
143 | parameters = [{
144 | 'name': 'url',
145 | 'type': 'string',
146 | 'description': '待解析的文件的路径',
147 | 'required': True
148 | }]
149 |
150 | def __init__(self, cfg: Optional[Dict] = None):
151 | super().__init__(cfg)
152 | self.data_root = self.cfg.get(
153 | 'path', 'workspace/default_doc_parser_data_path')
154 | self.db = Storage({'storage_root_path': self.data_root})
155 |
156 | def call(self,
157 | params: Union[str, dict],
158 | ignore_cache: bool = False) -> dict:
159 | """Parse file by url, and return the formatted content."""
160 |
161 | params = self._verify_json_format_args(params)
162 |
163 | if ignore_cache:
164 | record = process_file(url=params['url'], db=self.db)
165 | else:
166 | try:
167 | record = self.db.get(hash_sha256(params['url']))
168 | record = json5.loads(record)
169 |
170 | except Exception:
171 | record = process_file(url=params['url'], db=self.db)
172 |
173 | return record
174 |
--------------------------------------------------------------------------------
/qwen_agent/tools/image_gen.py:
--------------------------------------------------------------------------------
1 | import json
2 | import urllib.parse
3 | from typing import Union
4 |
5 | from qwen_agent.tools.base import BaseTool, register_tool
6 |
7 |
8 | @register_tool('image_gen')
9 | class ImageGen(BaseTool):
10 | description = 'AI绘画(图像生成)服务,输入文本描述和图像分辨率,返回根据文本信息绘制的图片URL。'
11 | parameters = [{
12 | 'name': 'prompt',
13 | 'type': 'string',
14 | 'description': '详细描述了希望生成的图像具有什么内容,例如人物、环境、动作等细节描述,使用英文',
15 | 'required': True
16 | }, {
17 | 'name':
18 | 'resolution',
19 | 'type':
20 | 'string',
21 | 'description':
22 | '格式是 数字*数字,表示希望生成的图像的分辨率大小,选项有[1024*1024, 720*1280, 1280*720]'
23 | }]
24 |
25 | def call(self, params: Union[str, dict], **kwargs) -> str:
26 | params = self._verify_json_format_args(params)
27 |
28 | prompt = params['prompt']
29 | prompt = urllib.parse.quote(prompt)
30 | return json.dumps(
31 | {'image_url': f'https://image.pollinations.ai/prompt/{prompt}'},
32 | ensure_ascii=False)
33 |
--------------------------------------------------------------------------------
/qwen_agent/tools/resource/AlibabaPuHuiTi-3-45-Light.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidons-master/HomeRPC/07147c3d4fc554b46bec8296e393d727a7b6238c/qwen_agent/tools/resource/AlibabaPuHuiTi-3-45-Light.ttf
--------------------------------------------------------------------------------
/qwen_agent/tools/resource/code_interpreter_init_kernel.py:
--------------------------------------------------------------------------------
1 | import json # noqa
2 | import math # noqa
3 | import os # noqa
4 | import re # noqa
5 | import signal
6 |
7 | import matplotlib # noqa
8 | import matplotlib.pyplot as plt
9 | import numpy as np # noqa
10 | import pandas as pd # noqa
11 | import seaborn as sns
12 | from matplotlib.font_manager import FontProperties
13 | from sympy import Eq, solve, symbols # noqa
14 |
15 |
16 | def input(*args, **kwargs): # noqa
17 | raise NotImplementedError('Python input() function is disabled.')
18 |
19 |
20 | def _m6_timout_handler(_signum=None, _frame=None):
21 | raise TimeoutError('M6_CODE_INTERPRETER_TIMEOUT')
22 |
23 |
24 | try:
25 | signal.signal(signal.SIGALRM, _m6_timout_handler)
26 | except AttributeError: # windows
27 | pass
28 |
29 |
30 | class _M6CountdownTimer:
31 |
32 | @classmethod
33 | def start(cls, timeout: int):
34 | try:
35 | signal.alarm(timeout)
36 | except AttributeError: # windows
37 | pass # TODO: I haven't found a solution that works with jupyter yet.
38 |
39 | @classmethod
40 | def cancel(cls):
41 | try:
42 | signal.alarm(0)
43 | except AttributeError: # windows
44 | pass # TODO
45 |
46 |
47 | sns.set_theme()
48 |
49 | _m6_font_prop = FontProperties(fname='{{M6_FONT_PATH}}')
50 | plt.rcParams['font.family'] = _m6_font_prop.get_name()
51 |
--------------------------------------------------------------------------------
/qwen_agent/tools/resource/image_service.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import uvicorn
4 | from fastapi import FastAPI
5 | from fastapi.middleware.cors import CORSMiddleware
6 | from fastapi.staticfiles import StaticFiles
7 |
8 | app = FastAPI()
9 |
10 | origins = ['http://127.0.0.1:7860']
11 |
12 | app.add_middleware(
13 | CORSMiddleware,
14 | allow_origins=origins,
15 | allow_credentials=True,
16 | allow_methods=['*'],
17 | allow_headers=['*'],
18 | )
19 |
20 | app.mount('/static',
21 | StaticFiles(directory=os.getcwd() + '/workspace/ci_workspace/'),
22 | name='static')
23 |
24 | if __name__ == '__main__':
25 | uvicorn.run(app='image_service:app', port=7865)
26 |
--------------------------------------------------------------------------------
/qwen_agent/tools/retrieval.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union
2 |
3 | import json5
4 |
5 | from qwen_agent.log import logger
6 | from qwen_agent.tools.base import BaseTool, register_tool
7 | from qwen_agent.utils.utils import get_basename_from_url, print_traceback
8 |
9 | from .doc_parser import DocParser, FileTypeNotImplError
10 | from .similarity_search import (RefMaterialInput, RefMaterialInputItem,
11 | SimilaritySearch)
12 |
13 |
14 | def format_records(records: List[Dict]):
15 | formatted_records = []
16 | for record in records:
17 | formatted_records.append(
18 | RefMaterialInput(url=get_basename_from_url(record['url']),
19 | text=[
20 | RefMaterialInputItem(
21 | content=x['page_content'],
22 | token=x['token']) for x in record['raw']
23 | ]))
24 | return formatted_records
25 |
26 |
27 | @register_tool('retrieval')
28 | class Retrieval(BaseTool):
29 | description = '从给定文件列表中检索出和问题相关的内容'
30 | parameters = [{
31 | 'name': 'query',
32 | 'type': 'string',
33 | 'description': '问题,需要从文档中检索和这个问题有关的内容'
34 | }, {
35 | 'name': 'files',
36 | 'type': 'array',
37 | 'items': {
38 | 'type': 'string'
39 | },
40 | 'description': '待解析的文件路径列表',
41 | 'required': True
42 | }]
43 |
44 | def __init__(self, cfg: Optional[Dict] = None):
45 | super().__init__(cfg)
46 | self.doc_parse = DocParser()
47 | self.search = SimilaritySearch()
48 |
49 | def call(self,
50 | params: Union[str, dict],
51 | ignore_cache: bool = False,
52 | max_token: int = 4000) -> list:
53 | """RAG tool.
54 |
55 | Step1: Parse and save files
56 | Step2: Retrieval related content according to query
57 |
58 | Args:
59 | params: The files and query.
60 | ignore_cache: When set to True, overwrite the same documents that have been parsed before.
61 | max_token: Maximum retrieval length.
62 |
63 | Returns:
64 | The retrieved file list.
65 | """
66 |
67 | params = self._verify_json_format_args(params)
68 | files = params.get('files', [])
69 | if isinstance(files, str):
70 | files = json5.loads(files)
71 | records = []
72 | for file in files:
73 | try:
74 | _record = self.doc_parse.call(params={'url': file},
75 | ignore_cache=ignore_cache)
76 | records.append(_record)
77 | except FileTypeNotImplError:
78 | logger.warning(
79 | 'Only Parsing the Following File Types: [\'web page\', \'.pdf\', \'.docx\', \'.pptx\'] to knowledge base!'
80 | )
81 | except Exception:
82 | print_traceback()
83 |
84 | query = params.get('query', '')
85 | if query and records:
86 | records = format_records(records)
87 | return self._retrieve_content(query, records, max_token)
88 | else:
89 | return records
90 |
91 | def _retrieve_content(self,
92 | query: str,
93 | records: List[RefMaterialInput],
94 | max_token=4000) -> List[Dict]:
95 | single_max_token = int(max_token / len(records))
96 | _ref_list = []
97 | for record in records:
98 | # Retrieval for query
99 | now_ref_list = self.search.call(params={'query': query},
100 | doc=record,
101 | max_token=single_max_token)
102 | _ref_list.append(now_ref_list)
103 | return _ref_list
104 |
--------------------------------------------------------------------------------
/qwen_agent/tools/similarity_search.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | from pydantic import BaseModel
4 |
5 | from qwen_agent.log import logger
6 | from qwen_agent.tools.base import BaseTool, register_tool
7 | from qwen_agent.utils.tokenization_qwen import count_tokens
8 | from qwen_agent.utils.utils import get_split_word, parse_keyword
9 |
10 |
11 | class RefMaterialOutput(BaseModel):
12 | """The knowledge data format output from the retrieval"""
13 | url: str
14 | text: list
15 |
16 | def to_dict(self) -> dict:
17 | return {
18 | 'url': self.url,
19 | 'text': self.text,
20 | }
21 |
22 |
23 | class RefMaterialInputItem(BaseModel):
24 | content: str
25 | token: int
26 |
27 | def to_dict(self) -> dict:
28 | return {'content': self.content, 'token': self.token}
29 |
30 |
31 | class RefMaterialInput(BaseModel):
32 | """The knowledge data format input to the retrieval"""
33 | url: str
34 | text: List[RefMaterialInputItem]
35 |
36 | def to_dict(self) -> dict:
37 | return {'url': self.url, 'text': [x.to_dict() for x in self.text]}
38 |
39 |
40 | def format_input_doc(doc: List[str]) -> RefMaterialInput:
41 | new_doc = []
42 | for x in doc:
43 | item = RefMaterialInputItem(content=x, token=count_tokens(x))
44 | new_doc.append(item)
45 | return RefMaterialInput(url='', text=new_doc)
46 |
47 |
48 | @register_tool('similarity_search')
49 | class SimilaritySearch(BaseTool):
50 | description = '从给定文档中检索和问题相关的部分'
51 | parameters = [{
52 | 'name': 'query',
53 | 'type': 'string',
54 | 'description': '问题,需要从文档中检索和这个问题有关的内容',
55 | 'required': True
56 | }]
57 |
58 | def call(self,
59 | params: Union[str, dict],
60 | doc: Union[RefMaterialInput, str, List[str]] = None,
61 | max_token: int = 4000) -> dict:
62 | params = self._verify_json_format_args(params)
63 |
64 | query = params['query']
65 | if not doc:
66 | return {}
67 | if isinstance(doc, str):
68 | doc = [doc]
69 | if isinstance(doc, list):
70 | doc = format_input_doc(doc)
71 |
72 | tokens = [page.token for page in doc.text]
73 | all_tokens = sum(tokens)
74 | logger.info(f'all tokens of {doc.url}: {all_tokens}')
75 | if all_tokens <= max_token:
76 | logger.info('use full ref')
77 | return RefMaterialOutput(url=doc.url,
78 | text=[x.content
79 | for x in doc.text]).to_dict()
80 |
81 | wordlist = parse_keyword(query)
82 | logger.info('wordlist: ' + ','.join(wordlist))
83 | if not wordlist:
84 | return self.get_top(doc, max_token)
85 |
86 | sims = []
87 | for i, page in enumerate(doc.text):
88 | sim = self.filter_section(page.content, wordlist)
89 | sims.append([i, sim])
90 | sims.sort(key=lambda item: item[1], reverse=True)
91 | assert len(sims) > 0
92 |
93 | res = []
94 | max_sims = sims[0][1]
95 | if max_sims != 0:
96 | manul = 0
97 | for i in range(min(manul, len(doc.text))):
98 | if max_token >= tokens[
99 | i] * 2: # Ensure that the first two pages do not fill up the window
100 | res.append(doc.text[i].content)
101 | max_token -= tokens[i]
102 | for i, x in enumerate(sims):
103 | if x[0] < manul:
104 | continue
105 | page = doc.text[x[0]]
106 | if max_token < page.token:
107 | use_rate = (max_token / page.token) * 0.2
108 | res.append(page.content[:int(len(page.content) *
109 | use_rate)])
110 | break
111 |
112 | res.append(page.content)
113 | max_token -= page.token
114 |
115 | logger.info(f'remaining slots: {max_token}')
116 | return RefMaterialOutput(url=doc.url, text=res).to_dict()
117 | else:
118 | return self.get_top(doc, max_token)
119 |
120 | def filter_section(self, text: str, wordlist: list) -> int:
121 | page_list = get_split_word(text)
122 | sim = self.jaccard_similarity(wordlist, page_list)
123 |
124 | return sim
125 |
126 | @staticmethod
127 | def jaccard_similarity(list1: list, list2: list) -> int:
128 | s1 = set(list1)
129 | s2 = set(list2)
130 | return len(s1.intersection(s2)) # avoid text length impact
131 | # return len(s1.intersection(s2)) / len(s1.union(s2)) # jaccard similarity
132 |
133 | @staticmethod
134 | def get_top(doc: RefMaterialInput, max_token=4000) -> dict:
135 | now_token = 0
136 | text = []
137 | for page in doc.text:
138 | if (now_token + page.token) <= max_token:
139 | text.append(page.content)
140 | now_token += page.token
141 | else:
142 | use_rate = ((max_token - now_token) / page.token) * 0.2
143 | text.append(page.content[:int(len(page.content) * use_rate)])
144 | break
145 | logger.info(f'remaining slots: {max_token-now_token}')
146 | return RefMaterialOutput(url=doc.url, text=text).to_dict()
147 |
--------------------------------------------------------------------------------
/qwen_agent/tools/storage.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Optional, Union
3 |
4 | from qwen_agent.tools.base import BaseTool, register_tool
5 | from qwen_agent.utils.utils import read_text_from_file, save_text_to_file
6 |
7 | DEFAULT_STORAGE_PATH = 'workspace/default_data_path'
8 | SUCCESS_MESSAGE = 'SUCCESS'
9 |
10 |
11 | @register_tool('storage')
12 | class Storage(BaseTool):
13 | """
14 | This is a special tool for data storage
15 | """
16 | description = '存储和读取数据的工具'
17 | parameters = [{
18 | 'name': 'operate',
19 | 'type': 'string',
20 | 'description':
21 | '数据操作类型,可选项为["put", "get", "delete", "scan"]之一,分别为存数据、取数据、删除数据、遍历数据',
22 | 'required': True
23 | }, {
24 | 'name': 'key',
25 | 'type': 'string',
26 | 'description':
27 | '数据的路径,类似于文件路径,是一份数据的唯一标识,不能为空,默认根目录为`/`。存数据时,应该合理的设计路径,保证路径含义清晰且唯一。',
28 | 'default': '/'
29 | }, {
30 | 'name': 'value',
31 | 'type': 'string',
32 | 'description': '数据的内容,仅存数据时需要'
33 | }]
34 |
35 | def __init__(self, cfg: Optional[Dict] = None):
36 | super().__init__(cfg)
37 | self.root = self.cfg.get('storage_root_path', DEFAULT_STORAGE_PATH)
38 | os.makedirs(self.root, exist_ok=True)
39 |
40 | def call(self, params: Union[str, dict], **kwargs) -> str:
41 | params = self._verify_json_format_args(params)
42 | operate = params['operate']
43 | key = params.get('key', '/')
44 | if key.startswith('/'):
45 | key = key[1:]
46 |
47 | if operate == 'put':
48 | assert 'value' in params
49 | return self.put(key, params['value'])
50 | elif operate == 'get':
51 | return self.get(key)
52 | elif operate == 'delete':
53 | return self.delete(key)
54 | else:
55 | return self.scan(key)
56 |
57 | def put(self, key: str, value: str, path: Optional[str] = None) -> str:
58 | path = path or self.root
59 |
60 | # one file for one key value pair
61 | path = os.path.join(path, key)
62 |
63 | path_dir = path[:path.rfind('/') + 1]
64 | if path_dir:
65 | os.makedirs(path_dir, exist_ok=True)
66 |
67 | save_text_to_file(path, value)
68 | return SUCCESS_MESSAGE
69 |
70 | def get(self, key: str, path: Optional[str] = None) -> str:
71 | path = path or self.root
72 | return read_text_from_file(os.path.join(path, key))
73 |
74 | def delete(self, key, path: Optional[str] = None) -> str:
75 | path = path or self.root
76 | path = os.path.join(path, key)
77 | if os.path.exists(path):
78 | os.remove(path)
79 | return f'Successfully deleted{key}'
80 | else:
81 | return f'Delete Failed: {key} does not exist'
82 |
83 | def scan(self, key: str, path: Optional[str] = None) -> str:
84 | path = path or self.root
85 | path = os.path.join(path, key)
86 | if os.path.exists(path):
87 | if not os.path.isdir(path):
88 | return 'Scan Failed: The scan operation requires passing in a key to a folder path'
89 | # All key-value pairs
90 | kvs = {}
91 | for root, dirs, files in os.walk(path):
92 | for file in files:
93 | k = os.path.join(root, file)[len(path):]
94 | if not k.startswith('/'):
95 | k = '/' + k
96 | v = read_text_from_file(os.path.join(root, file))
97 | kvs[k] = v
98 | return '\n'.join([f'{k}: {v}' for k, v in kvs.items()])
99 | else:
100 | return f'Scan Failed: {key} does not exist.'
101 |
--------------------------------------------------------------------------------
/qwen_agent/tools/web_extractor.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import requests
4 |
5 | from qwen_agent.tools.base import BaseTool, register_tool
6 |
7 |
8 | @register_tool('web_extractor')
9 | class WebExtractor(BaseTool):
10 | description = '根据网页URL,获取网页内容的工具'
11 | parameters = [{
12 | 'name': 'url',
13 | 'type': 'string',
14 | 'description': '网页URL',
15 | 'required': True
16 | }]
17 |
18 | def call(self, params: Union[str, dict], **kwargs) -> str:
19 | only_text = self.cfg.get('only_text', False)
20 |
21 | params = self._verify_json_format_args(params)
22 |
23 | url = params['url']
24 | headers = {
25 | 'User-Agent':
26 | 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
27 | }
28 | response = requests.get(url, headers=headers)
29 | if response.status_code == 200:
30 | if only_text:
31 |
32 | import justext
33 |
34 | paragraphs = justext.justext(response.text,
35 | justext.get_stoplist('English'))
36 | content = '\n\n'.join(
37 | [paragraph.text for paragraph in paragraphs]).strip()
38 | if content:
39 | return content
40 | else:
41 | return response.text
42 | else:
43 | return response.text
44 | else:
45 | return ''
46 |
--------------------------------------------------------------------------------
/qwen_agent/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/guidons-master/HomeRPC/07147c3d4fc554b46bec8296e393d727a7b6238c/qwen_agent/utils/__init__.py
--------------------------------------------------------------------------------
/qwen_agent/utils/doc_parser.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from qwen_agent.utils.tokenization_qwen import count_tokens
4 |
5 | ONE_PAGE_TOKEN = 500
6 |
7 |
8 | def rm_newlines(text):
9 | text = re.sub(r'(?<=[^\.。::])\n', ' ', text)
10 | return text
11 |
12 |
13 | def rm_cid(text):
14 | text = re.sub(r'\(cid:\d+\)', '', text)
15 | return text
16 |
17 |
18 | def rm_hexadecimal(text):
19 | text = re.sub(r'[0-9A-Fa-f]{21,}', '', text)
20 | return text
21 |
22 |
23 | def rm_continuous_placeholders(text):
24 | text = re.sub(r'(\.|-|——|。|_|\*){7,}', '...', text)
25 | return text
26 |
27 |
28 | def deal(text):
29 | text = rm_newlines(text)
30 | text = rm_cid(text)
31 | text = rm_hexadecimal(text)
32 | text = rm_continuous_placeholders(text)
33 | return text
34 |
35 |
36 | def parse_doc(path):
37 | if '.pdf' in path.lower():
38 | from pdfminer.high_level import extract_text
39 | text = extract_text(path)
40 | elif '.docx' in path.lower():
41 | import docx2txt
42 | text = docx2txt.process(path)
43 | elif '.pptx' in path.lower():
44 | from pptx import Presentation
45 | ppt = Presentation(path)
46 | text = []
47 | for slide in ppt.slides:
48 | for shape in slide.shapes:
49 | if hasattr(shape, 'text'):
50 | text.append(shape.text)
51 | text = '\n'.join(text)
52 | else:
53 | raise TypeError
54 |
55 | text = deal(text)
56 | return split_text_to_trunk(text, path)
57 |
58 |
59 | def pre_process_html(s):
60 | # replace multiple newlines
61 | s = re.sub('\n+', '\n', s)
62 | # replace special string
63 | s = s.replace("Add to Qwen's Reading List", '')
64 | return s
65 |
66 |
67 | def parse_html_bs(path):
68 | try:
69 | from bs4 import BeautifulSoup
70 | except Exception:
71 | raise ValueError('Please install bs4 by `pip install beautifulsoup4`')
72 | bs_kwargs = {'features': 'lxml'}
73 | with open(path, 'r', encoding='utf-8') as f:
74 | soup = BeautifulSoup(f, **bs_kwargs)
75 |
76 | text = soup.get_text()
77 |
78 | if soup.title:
79 | title = str(soup.title.string)
80 | else:
81 | title = ''
82 | text = pre_process_html(text)
83 | return split_text_to_trunk(text, path, title)
84 |
85 |
86 | def split_text_to_trunk(content: str, path: str, title: str = ''):
87 | all_tokens = count_tokens(content)
88 | all_pages = round(all_tokens / ONE_PAGE_TOKEN)
89 | if all_pages == 0:
90 | all_pages = 1
91 | len_content = len(content)
92 | len_one_page = int(len_content /
93 | all_pages) # Approximately equal to ONE_PAGE_TOKEN
94 |
95 | res = []
96 | for i in range(0, len_content, len_one_page):
97 | text = content[i:min(i + len_one_page, len_content)]
98 | res.append({
99 | 'page_content': text,
100 | 'metadata': {
101 | 'source': path,
102 | 'title': title,
103 | 'page': (i % len_one_page)
104 | },
105 | 'token': count_tokens(text)
106 | })
107 | return res
108 |
--------------------------------------------------------------------------------
/qwen_agent/utils/tokenization_qwen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | """Tokenization classes for QWen."""
6 |
7 | import base64
8 | import logging
9 | import os
10 | import unicodedata
11 | from dataclasses import dataclass, field
12 | from pathlib import Path
13 | from typing import Collection, Dict, List, Set, Tuple, Union
14 |
15 | import tiktoken
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 | VOCAB_FILES_NAMES = {'vocab_file': 'qwen.tiktoken'}
20 |
21 | PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
22 | ENDOFTEXT = '<|endoftext|>'
23 | IMSTART = '<|im_start|>'
24 | IMEND = '<|im_end|>'
25 | # as the default behavior is changed to allow special tokens in
26 | # regular texts, the surface forms of special tokens need to be
27 | # as different as possible to minimize the impact
28 | EXTRAS = tuple((f'<|extra_{i}|>' for i in range(205)))
29 | # changed to use actual index to avoid misconfiguration with vocabulary expansion
30 | SPECIAL_START_ID = 151643
31 | SPECIAL_TOKENS = tuple(
32 | enumerate(
33 | ((
34 | ENDOFTEXT,
35 | IMSTART,
36 | IMEND,
37 | ) + EXTRAS),
38 | start=SPECIAL_START_ID,
39 | ))
40 | SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
41 |
42 |
43 | def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
44 | with open(tiktoken_bpe_file, 'rb') as f:
45 | contents = f.read()
46 | return {
47 | base64.b64decode(token): int(rank)
48 | for token, rank in (line.split() for line in contents.splitlines()
49 | if line)
50 | }
51 |
52 |
53 | @dataclass(frozen=True, eq=True)
54 | class AddedToken:
55 | """
56 | AddedToken represents a token to be added to a Tokenizer An AddedToken can have special options defining the
57 | way it should behave.
58 | """
59 |
60 | content: str = field(default_factory=str)
61 | single_word: bool = False
62 | lstrip: bool = False
63 | rstrip: bool = False
64 | normalized: bool = True
65 |
66 | def __getstate__(self):
67 | return self.__dict__
68 |
69 |
70 | class QWenTokenizer:
71 | """QWen tokenizer."""
72 |
73 | vocab_files_names = VOCAB_FILES_NAMES
74 |
75 | def __init__(
76 | self,
77 | vocab_file=None,
78 | errors='replace',
79 | extra_vocab_file=None,
80 | **kwargs,
81 | ):
82 | if not vocab_file:
83 | vocab_file = VOCAB_FILES_NAMES['vocab_file']
84 | self._decode_use_source_tokenizer = False
85 |
86 | # how to handle errors in decoding UTF-8 byte sequences
87 | # use ignore if you are in streaming inference
88 | self.errors = errors
89 |
90 | self.mergeable_ranks = _load_tiktoken_bpe(
91 | vocab_file) # type: Dict[bytes, int]
92 | self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
93 |
94 | # try load extra vocab from file
95 | if extra_vocab_file is not None:
96 | used_ids = set(self.mergeable_ranks.values()) | set(
97 | self.special_tokens.values())
98 | extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
99 | for token, index in extra_mergeable_ranks.items():
100 | if token in self.mergeable_ranks:
101 | logger.info(f'extra token {token} exists, skipping')
102 | continue
103 | if index in used_ids:
104 | logger.info(
105 | f'the index {index} for extra token {token} exists, skipping'
106 | )
107 | continue
108 | self.mergeable_ranks[token] = index
109 | # the index may be sparse after this, but don't worry tiktoken.Encoding will handle this
110 |
111 | enc = tiktoken.Encoding(
112 | 'Qwen',
113 | pat_str=PAT_STR,
114 | mergeable_ranks=self.mergeable_ranks,
115 | special_tokens=self.special_tokens,
116 | )
117 | assert (
118 | len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
119 | ), f'{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding'
120 |
121 | self.decoder = {v: k
122 | for k, v in self.mergeable_ranks.items()
123 | } # type: dict[int, bytes|str]
124 | self.decoder.update({v: k for k, v in self.special_tokens.items()})
125 |
126 | self.tokenizer = enc # type: tiktoken.Encoding
127 |
128 | self.eod_id = self.tokenizer.eot_token
129 | self.im_start_id = self.special_tokens[IMSTART]
130 | self.im_end_id = self.special_tokens[IMEND]
131 |
132 | def __getstate__(self):
133 | # for pickle lovers
134 | state = self.__dict__.copy()
135 | del state['tokenizer']
136 | return state
137 |
138 | def __setstate__(self, state):
139 | # tokenizer is not python native; don't pass it; rebuild it
140 | self.__dict__.update(state)
141 | enc = tiktoken.Encoding(
142 | 'Qwen',
143 | pat_str=PAT_STR,
144 | mergeable_ranks=self.mergeable_ranks,
145 | special_tokens=self.special_tokens,
146 | )
147 | self.tokenizer = enc
148 |
149 | def __len__(self) -> int:
150 | return self.tokenizer.n_vocab
151 |
152 | def get_vocab(self) -> Dict[bytes, int]:
153 | return self.mergeable_ranks
154 |
155 | def convert_tokens_to_ids(
156 | self, tokens: Union[bytes, str, List[Union[bytes,
157 | str]]]) -> List[int]:
158 | ids = []
159 | if isinstance(tokens, (str, bytes)):
160 | if tokens in self.special_tokens:
161 | return self.special_tokens[tokens]
162 | else:
163 | return self.mergeable_ranks.get(tokens)
164 | for token in tokens:
165 | if token in self.special_tokens:
166 | ids.append(self.special_tokens[token])
167 | else:
168 | ids.append(self.mergeable_ranks.get(token))
169 | return ids
170 |
171 | def _add_tokens(
172 | self,
173 | new_tokens: Union[List[str], List[AddedToken]],
174 | special_tokens: bool = False,
175 | ) -> int:
176 | if not special_tokens and new_tokens:
177 | raise ValueError('Adding regular tokens is not supported')
178 | for token in new_tokens:
179 | surface_form = token.content if isinstance(token,
180 | AddedToken) else token
181 | if surface_form not in SPECIAL_TOKENS_SET:
182 | raise ValueError(
183 | 'Adding unknown special tokens is not supported')
184 | return 0
185 |
186 | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
187 | """
188 | Save only the vocabulary of the tokenizer (vocabulary).
189 |
190 | Returns:
191 | `Tuple(str)`: Paths to the files saved.
192 | """
193 | file_path = os.path.join(save_directory, 'qwen.tiktoken')
194 | with open(file_path, 'w', encoding='utf8') as w:
195 | for k, v in self.mergeable_ranks.items():
196 | line = base64.b64encode(k).decode('utf8') + ' ' + str(v) + '\n'
197 | w.write(line)
198 | return (file_path, )
199 |
200 | def tokenize(
201 | self,
202 | text: str,
203 | allowed_special: Union[Set, str] = 'all',
204 | disallowed_special: Union[Collection, str] = (),
205 | **kwargs,
206 | ) -> List[Union[bytes, str]]:
207 | """
208 | Converts a string in a sequence of tokens.
209 |
210 | Args:
211 | text (`str`):
212 | The sequence to be encoded.
213 | allowed_special (`Literal["all"]` or `set`):
214 | The surface forms of the tokens to be encoded as special tokens in regular texts.
215 | Default to "all".
216 | disallowed_special (`Literal["all"]` or `Collection`):
217 | The surface forms of the tokens that should not be in regular texts and trigger errors.
218 | Default to an empty tuple.
219 |
220 | kwargs (additional keyword arguments, *optional*):
221 | Will be passed to the underlying model specific encode method.
222 |
223 | Returns:
224 | `List[bytes|str]`: The list of tokens.
225 | """
226 | tokens = []
227 | text = unicodedata.normalize('NFC', text)
228 |
229 | # this implementation takes a detour: text -> token id -> token surface forms
230 | for t in self.tokenizer.encode(text,
231 | allowed_special=allowed_special,
232 | disallowed_special=disallowed_special):
233 | tokens.append(self.decoder[t])
234 | return tokens
235 |
236 | def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
237 | """
238 | Converts a sequence of tokens in a single string.
239 | """
240 | text = ''
241 | temp = b''
242 | for t in tokens:
243 | if isinstance(t, str):
244 | if temp:
245 | text += temp.decode('utf-8', errors=self.errors)
246 | temp = b''
247 | text += t
248 | elif isinstance(t, bytes):
249 | temp += t
250 | else:
251 | raise TypeError('token should only be of type types or str')
252 | if temp:
253 | text += temp.decode('utf-8', errors=self.errors)
254 | return text
255 |
256 | @property
257 | def vocab_size(self):
258 | return self.tokenizer.n_vocab
259 |
260 | def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
261 | """Converts an id to a token, special tokens included"""
262 | if index in self.decoder:
263 | return self.decoder[index]
264 | raise ValueError('unknown ids')
265 |
266 | def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
267 | """Converts a token to an id using the vocab, special tokens included"""
268 | if token in self.special_tokens:
269 | return self.special_tokens[token]
270 | if token in self.mergeable_ranks:
271 | return self.mergeable_ranks[token]
272 | raise ValueError('unknown token')
273 |
274 | def _tokenize(self, text: str, **kwargs):
275 | """
276 | Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
277 | vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
278 |
279 | Do NOT take care of added tokens.
280 | """
281 | raise NotImplementedError
282 |
283 | def _decode(
284 | self,
285 | token_ids: Union[int, List[int]],
286 | skip_special_tokens: bool = False,
287 | errors: str = None,
288 | **kwargs,
289 | ) -> str:
290 | if isinstance(token_ids, int):
291 | token_ids = [token_ids]
292 | if skip_special_tokens:
293 | token_ids = [i for i in token_ids if i < self.eod_id]
294 | return self.tokenizer.decode(token_ids, errors=errors or self.errors)
295 |
296 |
297 | tokenizer = QWenTokenizer(Path(__file__).resolve().parent / 'qwen.tiktoken')
298 |
299 |
300 | def count_tokens(text):
301 | tokens = tokenizer.tokenize(text)
302 | return len(tokens)
303 |
--------------------------------------------------------------------------------
/qwen_agent/utils/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import hashlib
3 | import json
4 | import os
5 | import re
6 | import shutil
7 | import socket
8 | import sys
9 | import traceback
10 | import urllib
11 | from typing import Dict, List, Literal, Optional, Union
12 | from urllib.parse import urlparse
13 |
14 | import jieba
15 | import json5
16 | import requests
17 | from jieba import analyse
18 |
19 | from qwen_agent.log import logger
20 |
21 |
22 | def get_local_ip():
23 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
24 | try:
25 | # doesn't even have to be reachable
26 | s.connect(('10.255.255.255', 1))
27 | ip = s.getsockname()[0]
28 | except Exception:
29 | ip = '127.0.0.1'
30 | finally:
31 | s.close()
32 | return ip
33 |
34 |
35 | def hash_sha256(key):
36 | hash_object = hashlib.sha256(key.encode())
37 | key = hash_object.hexdigest()
38 | return key
39 |
40 |
41 | def print_traceback(is_error=True):
42 | if is_error:
43 | logger.error(''.join(traceback.format_exception(*sys.exc_info())))
44 | else:
45 | logger.warning(''.join(traceback.format_exception(*sys.exc_info())))
46 |
47 |
48 | def has_chinese_chars(data) -> bool:
49 | text = f'{data}'
50 | return len(re.findall(r'[\u4e00-\u9fff]+', text)) > 0
51 |
52 |
53 | def get_basename_from_url(url: str) -> str:
54 | basename = os.path.basename(urlparse(url).path)
55 | basename = urllib.parse.unquote(basename)
56 | return basename.strip()
57 |
58 |
59 | def is_local_path(path):
60 | if path.startswith('https://') or path.startswith('http://'):
61 | return False
62 | return True
63 |
64 |
65 | def save_url_to_local_work_dir(url, base_dir, new_name=''):
66 | if not new_name:
67 | new_name = get_basename_from_url(url)
68 | new_path = os.path.join(base_dir, new_name)
69 | if os.path.exists(new_path):
70 | os.remove(new_path)
71 | logger.info(f'download {url} to {new_path}')
72 | start_time = datetime.datetime.now()
73 | if is_local_path(url):
74 | shutil.copy(url, new_path)
75 | else:
76 | headers = {
77 | 'User-Agent':
78 | 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
79 | }
80 | response = requests.get(url, headers=headers)
81 | if response.status_code == 200:
82 | with open(new_path, 'wb') as file:
83 | file.write(response.content)
84 | else:
85 | raise ValueError(
86 | 'Can not download this file. Please check your network or the file link.'
87 | )
88 | end_time = datetime.datetime.now()
89 | logger.info(f'Time: {str(end_time - start_time)}')
90 | return new_path
91 |
92 |
93 | def is_image(filename):
94 | filename = filename.lower()
95 | for ext in ['jpg', 'jpeg', 'png', 'webp']:
96 | if filename.endswith(ext):
97 | return True
98 | return False
99 |
100 |
101 | def get_current_date_str(
102 | lang: Literal['en', 'zh'] = 'en',
103 | hours_from_utc: Optional[int] = None,
104 | ) -> str:
105 | if hours_from_utc is None:
106 | cur_time = datetime.datetime.now()
107 | else:
108 | cur_time = datetime.datetime.utcnow() + datetime.timedelta(
109 | hours=hours_from_utc)
110 | if lang == 'en':
111 | date_str = 'Current date: ' + cur_time.strftime('%A, %B %d, %Y')
112 | elif lang == 'zh':
113 | cur_time = cur_time.timetuple()
114 | date_str = f'当前时间:{cur_time.tm_year}年{cur_time.tm_mon}月{cur_time.tm_mday}日,星期'
115 | date_str += ['一', '二', '三', '四', '五', '六', '日'][cur_time.tm_wday]
116 | date_str += '。'
117 | else:
118 | raise NotImplementedError
119 | return date_str
120 |
121 |
122 | def save_text_to_file(path, text):
123 | with open(path, 'w', encoding='utf-8') as fp:
124 | fp.write(text)
125 |
126 |
127 | def read_text_from_file(path):
128 | with open(path, 'r', encoding='utf-8') as file:
129 | file_content = file.read()
130 | return file_content
131 |
132 |
133 | def contains_html_tags(text):
134 | pattern = r'<(p|span|div|li|html|script)[^>]*?'
135 | return bool(re.search(pattern, text))
136 |
137 |
138 | def get_file_type(path):
139 | # This is a temporary plan
140 | if is_local_path(path):
141 | try:
142 | content = read_text_from_file(path)
143 | except Exception:
144 | print_traceback()
145 | return 'Unknown'
146 |
147 | if contains_html_tags(content):
148 | return 'html'
149 | else:
150 | return 'Unknown'
151 | else:
152 | headers = {
153 | 'User-Agent':
154 | 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
155 | }
156 | response = requests.get(path, headers=headers)
157 | if response.status_code == 200:
158 | if contains_html_tags(response.text):
159 | return 'html'
160 | else:
161 | return 'Unknown'
162 | else:
163 | print_traceback()
164 | return 'Unknown'
165 |
166 |
167 | ignore_words = [
168 | '', ' ', '\t', '\n', '\\', 'is', 'are', 'am', 'what', 'how', '的', '吗', '是',
169 | '了', '啊', '呢', '怎么', '如何', '什么', '?', '?', '!', '!', '“', '”', '‘', '’',
170 | "'", "'", '"', '"', ':', ':', '讲了', '描述', '讲', '说说', '讲讲', '介绍', '总结下',
171 | '总结一下', '文档', '文章', '文稿', '稿子', '论文', 'PDF', 'pdf', '这个', '这篇', '这', '我',
172 | '帮我', '那个', '下', '翻译'
173 | ]
174 |
175 |
176 | def get_split_word(text):
177 | text = text.lower()
178 | _wordlist = jieba.lcut(text.strip())
179 | wordlist = []
180 | for x in _wordlist:
181 | if x in ignore_words:
182 | continue
183 | wordlist.append(x)
184 | return wordlist
185 |
186 |
187 | def parse_keyword(text):
188 | try:
189 | res = json5.loads(text)
190 | except Exception:
191 | return get_split_word(text)
192 |
193 | # json format
194 | _wordlist = []
195 | try:
196 | if 'keywords_zh' in res and isinstance(res['keywords_zh'], list):
197 | _wordlist.extend([kw.lower() for kw in res['keywords_zh']])
198 | if 'keywords_en' in res and isinstance(res['keywords_en'], list):
199 | _wordlist.extend([kw.lower() for kw in res['keywords_en']])
200 | wordlist = []
201 | for x in _wordlist:
202 | if x in ignore_words:
203 | continue
204 | wordlist.append(x)
205 | wordlist.extend(get_split_word(res['text']))
206 | return wordlist
207 | except Exception:
208 | return get_split_word(text)
209 |
210 |
211 | def get_key_word(text):
212 | text = text.lower()
213 | _wordlist = analyse.extract_tags(text)
214 | wordlist = []
215 | for x in _wordlist:
216 | if x in ignore_words:
217 | continue
218 | wordlist.append(x)
219 | return wordlist
220 |
221 |
222 | def get_last_one_line_context(text):
223 | lines = text.split('\n')
224 | n = len(lines)
225 | res = ''
226 | for i in range(n - 1, -1, -1):
227 | if lines[i].strip():
228 | res = lines[i]
229 | break
230 | return res
231 |
232 |
233 | def extract_urls(text):
234 | pattern = re.compile(r'https?://\S+')
235 | urls = re.findall(pattern, text)
236 | return urls
237 |
238 |
239 | def extract_obs(text):
240 | k = text.rfind('\nObservation:')
241 | j = text.rfind('\nThought:')
242 | obs = text[k + len('\nObservation:'):j]
243 | return obs.strip()
244 |
245 |
246 | def extract_code(text):
247 | # Match triple backtick blocks first
248 | triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
249 | if triple_match:
250 | text = triple_match.group(1)
251 | else:
252 | try:
253 | text = json5.loads(text)['code']
254 | except Exception:
255 | print_traceback()
256 | # If no code blocks found, return original text
257 | return text
258 |
259 |
260 | def parse_latest_plugin_call(text):
261 | plugin_name, plugin_args = '', ''
262 | i = text.rfind('\nAction:')
263 | j = text.rfind('\nAction Input:')
264 | k = text.rfind('\nObservation:')
265 | if 0 <= i < j: # If the text has `Action` and `Action input`,
266 | if k < j: # but does not contain `Observation`,
267 | # then it is likely that `Observation` is ommited by the LLM,
268 | # because the output text may have discarded the stop word.
269 | text = text.rstrip() + '\nObservation:' # Add it back.
270 | k = text.rfind('\nObservation:')
271 | plugin_name = text[i + len('\nAction:'):j].strip()
272 | plugin_args = text[j + len('\nAction Input:'):k].strip()
273 | text = text[:k]
274 | return plugin_name, plugin_args, text
275 |
276 |
277 | def get_function_description(function: Dict) -> str:
278 | """
279 | Text description of function
280 | """
281 | tool_desc_template = {
282 | 'zh':
283 | '### {name_for_human}\n\n{name_for_model}: {description_for_model} 输入参数:{parameters} {args_format}',
284 | 'en':
285 | '### {name_for_human}\n\n{name_for_model}: {description_for_model} Parameters: {parameters} {args_format}'
286 | }
287 | if has_chinese_chars(function):
288 | tool_desc = tool_desc_template['zh']
289 | else:
290 | tool_desc = tool_desc_template['en']
291 |
292 | name = function.get('name', None)
293 | name_for_human = function.get('name_for_human', name)
294 | name_for_model = function.get('name_for_model', name)
295 | assert name_for_human and name_for_model
296 | args_format = function.get('args_format', '')
297 | return tool_desc.format(name_for_human=name_for_human,
298 | name_for_model=name_for_model,
299 | description_for_model=function['description'],
300 | parameters=json.dumps(function['parameters'],
301 | ensure_ascii=False),
302 | args_format=args_format).rstrip()
303 |
304 |
305 | def format_knowledge_to_source_and_content(
306 | result: Union[str, List[dict]]) -> List[dict]:
307 | knowledge = []
308 | if isinstance(result, str):
309 | result = f'{result}'.strip()
310 | docs = json5.loads(result)
311 | else:
312 | docs = result
313 | try:
314 | _tmp_knowledge = []
315 | assert isinstance(docs, list)
316 | for doc in docs:
317 | url, snippets = doc['url'], doc['text']
318 | assert isinstance(snippets, list)
319 | _tmp_knowledge.append({'source': f'[文件]({url})', 'content': '\n\n...\n\n'.join(snippets)})
320 | knowledge.extend(_tmp_knowledge)
321 | except Exception:
322 | print_traceback()
323 | knowledge.append({'source': '上传的文档', 'content': result})
324 | return knowledge
325 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/Yakifo/amqtt#egg=amqtt
2 |
3 | colorlog
4 | zeroconf
5 | streamlit
6 | anyio>=3.7.1
7 | beautifulsoup4
8 | dashscope>=1.11.0
9 | fastapi>=0.103.1
10 | html2text
11 | jieba
12 | json5
13 | jsonlines
14 | jupyter>=1.0.0
15 | matplotlib
16 | numpy
17 | openai
18 | pandas
19 | pdfminer-six
20 | pillow
21 | pydantic>=2.3.0
22 | python-docx
23 | python-pptx
24 | requests
25 | seaborn
26 | sympy
27 | tiktoken
--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------
1 | from server import HomeRPC
2 |
3 | if __name__ == '__main__':
4 | # 启动HomeRPC
5 | HomeRPC.setup(ip = "192.168.43.9", log = True)
6 |
7 | # 等待ESP32连接
8 | input("Waiting for ESP32 to connect...")
9 |
10 | place = HomeRPC.place("room")
11 | # 调用ESP32客户端服务
12 | place.device("light").id(1).call("trigger", 1, timeout_s = 10)
13 | print("led status: ", place.device("light").id(1).call("status", timeout_s = 10))
--------------------------------------------------------------------------------
/server/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import HomeRPC
2 |
3 | __all__ = ['HomeRPC']
--------------------------------------------------------------------------------
/server/base.py:
--------------------------------------------------------------------------------
1 | from .log import Log
2 | from .broker import MqttBroker
3 | from .rule import Rule
4 | from multiprocessing import Process, Queue
5 | from json import dumps, loads
6 | from struct import unpack, pack
7 | from typing import List
8 | import asyncio
9 |
10 | _topic_registry = {}
11 | _recv, _send = Queue(), Queue()
12 | _topic = Queue()
13 | _rpc_log = None
14 |
15 | def dict2funcs() -> List[dict]:
16 | global _topic_registry
17 | data = _topic_registry
18 |
19 | functions = []
20 |
21 | type_signature_mapping = {
22 | 'c': {'type': 'string', 'description': 'a character'},
23 | 'i': {'type': 'integer', 'description': 'a number'},
24 | 'f': {'type': 'number', 'description': 'a single-precision floating-point number'},
25 | 'd': {'type': 'number', 'description': 'a double-precision floating-point number'},
26 | }
27 |
28 | for place, devices in data.items():
29 | for device_type, device_list in devices.items():
30 | for device in device_list:
31 | for service in device['services']:
32 | function = {
33 | 'name': service['name'],
34 | 'description': service['desc'],
35 | 'parameters': {
36 | 'type': 'object',
37 | 'properties': {
38 | 'place': {
39 | 'type': 'string',
40 | 'description': 'where the device is located',
41 | 'enum': [place]
42 | },
43 | 'device_type': {
44 | 'type': 'string',
45 | 'description': 'what type of device',
46 | 'enum': [device_type]
47 | },
48 | 'device_id': {
49 | 'type': 'integer',
50 | 'description': 'the id of the device',
51 | 'enum': [device['id']]
52 | }
53 | },
54 | 'required': ['place', 'device_type', 'device_id']
55 | }
56 | }
57 | if service['input_type']:
58 | function['parameters']['properties']['input_type'] = type_signature_mapping[service['input_type']]
59 | function['parameters']['required'].append('input_type')
60 | functions.append(function)
61 | return functions
62 |
63 | def register_device(device: dict):
64 | place = device['place']
65 | type_ = device['type']
66 | id_ = device['id']
67 | services = device['services']
68 |
69 | if place not in _topic_registry:
70 | _topic_registry[place] = {}
71 |
72 | if type_ not in _topic_registry[place]:
73 | _topic_registry[place][type_] = []
74 |
75 | _topic_registry[place][type_].append({'id': id_, 'services': services})
76 |
77 | def update_registry():
78 | while True:
79 | try:
80 | topic = dict(_topic.get_nowait())
81 | except Exception:
82 | break
83 | else:
84 | register_device(topic)
85 | _rpc_log.log(f'Registered device: {_topic_registry}')
86 |
87 | def deserialize_rpc_any(buffer: str, data_type: str) -> any:
88 | res = bytes.fromhex(buffer.ljust(16, '0'))
89 |
90 | if data_type == 'f':
91 | return unpack('f', res[:4])[0]
92 | elif data_type == 'i':
93 | return unpack('q', res[:8])[0]
94 | elif data_type == 'd':
95 | return unpack('d', res[:8])[0]
96 | elif data_type == 'c':
97 | ascii_value = unpack('B', res[:1])[0]
98 | return chr(ascii_value)
99 | else:
100 | raise ValueError('Unsupported data type: ' + data_type)
101 |
102 | def serialize_rpc_any(data_list: List[any], data_type: str) -> List[str]:
103 | res = []
104 |
105 | for data, _type in zip(data_list, data_type):
106 | if _type == 'f':
107 | res.append(pack('f', data).hex().ljust(8, '0'))
108 | elif _type == 'i':
109 | res.append(pack('q', data).hex().ljust(16, '0'))
110 | elif _type == 'd':
111 | res.append(pack('d', data).hex().ljust(16, '0'))
112 | elif _type == 'c':
113 | res.append(pack('B', ord(data)).hex().ljust(2, '0'))
114 | else:
115 | raise ValueError('Unsupported data type: ' + _type)
116 |
117 | return res
118 |
119 | def _asyncio_loop(ip: str, log: bool, recv: Queue, send: Queue, topic: Queue):
120 | loop = asyncio.get_event_loop()
121 | rule_engine = Rule(loop, recv, send, topic, log)
122 |
123 | class Server(MqttBroker):
124 | def __init__(self, ip, log):
125 | super().__init__(loop, ip, log)
126 |
127 | def _init_task(self):
128 | rule_engine.listen()
129 |
130 | server = Server(ip, log)
131 | server.loop_forever()
132 |
133 | class HomeRPC():
134 | @staticmethod
135 | def setup(ip: str = None, log: bool = True):
136 | global _rpc_log
137 | _rpc_log = Log(disable = not log)
138 | Process(target = _asyncio_loop, args = (ip, log, _recv, _send, _topic), daemon = True).start()
139 |
140 | @staticmethod
141 | def place(name: str):
142 |
143 | class Place:
144 | def device(self, device_name: str):
145 |
146 | class Device:
147 | def id(self, device_id: int):
148 |
149 | class Id:
150 | def call(self, func: str, *args, **kwargs):
151 | update_registry()
152 | timeout = kwargs.get("timeout_s", 10) * 1000
153 |
154 | if name in _topic_registry and device_name in _topic_registry[name]:
155 | for device_info in _topic_registry[name][device_name]:
156 | if device_info['id'] != device_id:
157 | continue
158 |
159 | for service in device_info['services']:
160 | if service['name'] != func:
161 | continue
162 |
163 | if len(args) != len(service['input_type']):
164 | _rpc_log.log_error(f'Function {func} requires {len(service["input_type"])} arguments, but {len(args)} were given')
165 | return None
166 |
167 | message = {
168 | "callback": "/callback/master",
169 | }
170 | if len(args):
171 | message["params"] = serialize_rpc_any(args, service['input_type'])
172 |
173 | _send.put(dumps({
174 | "topic": f"/{name}/{device_name}/{device_id}/{func}",
175 | "message": dumps(message)
176 | }))
177 |
178 | try:
179 | data = loads(_recv.get(timeout = timeout))
180 | return deserialize_rpc_any(data['result'], service['output_type'])
181 | except Exception as e:
182 | _rpc_log.log_error(e)
183 | return None
184 |
185 | _rpc_log.log_error(f'Device {name}/{device_name}/{device_id} does not exist')
186 | return None
187 | return Id()
188 | return Device()
189 | return Place()
190 |
191 | @staticmethod
192 | def funcs():
193 | update_registry()
194 | return dict2funcs()
--------------------------------------------------------------------------------
/server/broker.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from amqtt.broker import Broker
3 | from zeroconf import ServiceInfo, Zeroconf, IPVersion
4 | from .log import Log
5 | from socket import inet_aton
6 |
7 | class MqttBroker:
8 | def __init__(self, loop: asyncio.AbstractEventLoop = None, ip: str = None , log: bool = True):
9 | self.zeroconf = Zeroconf(ip_version = IPVersion.All)
10 | self.info = ServiceInfo(
11 | "_mqtt._tcp.local.",
12 | "broker._mqtt._tcp.local.",
13 | addresses = [inet_aton(ip),],
14 | port = 3000,
15 | properties = { "name": "MQTT Broker", "path": "/mqtt" },
16 | server = "homerpc.local.",
17 | )
18 |
19 | self.log = Log(disable = not log)
20 |
21 | self.loop = loop or asyncio.get_event_loop()
22 |
23 | async def _broker_coro(self):
24 | broker = Broker(config = {
25 | "listeners": {
26 | "default": {
27 | "type": "ws",
28 | "bind": "0.0.0.0:%d" % self.info.port,
29 | "max_connections": 0
30 | }
31 | },
32 | "sys_interval": 10,
33 | "auth": {
34 | "allow-anonymous": True,
35 | "plugins": ['auth.anonymous']
36 | },
37 | "topic-check": {
38 | "enabled": False
39 | }
40 | })
41 | await broker.start()
42 |
43 | def _register_service(self):
44 | self.zeroconf.register_service(self.info)
45 |
46 | def _init_task(self):
47 | raise NotImplementedError
48 |
49 | def loop_forever(self):
50 | self._register_service()
51 |
52 | try:
53 | self.loop.run_until_complete(self._broker_coro())
54 | self._init_task()
55 | self.loop.run_forever()
56 | except Exception as e:
57 | self.loop.stop()
58 | self.log.log_error(e)
--------------------------------------------------------------------------------
/server/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import colorlog
3 | import traceback
4 |
5 | class Log:
6 | def __init__(self, log_file=None, log_level=logging.INFO, disable=False):
7 | if disable:
8 | log_level = logging.CRITICAL + 1
9 |
10 | log_format = (
11 | "%(log_color)s[%(asctime)s] :: %(levelname)s :: %(name)s :: %(message)s%(reset)s"
12 | )
13 |
14 | colorlog.basicConfig(
15 | filename=log_file,
16 | level=log_level,
17 | format=log_format,
18 | log_colors={
19 | "DEBUG": "green",
20 | "INFO": "cyan",
21 | "WARNING": "yellow",
22 | "ERROR": "red",
23 | "CRITICAL": "red,bg_white",
24 | },
25 | )
26 |
27 | def log(self, message):
28 | logging.info(message)
29 |
30 | def log_error(self, message):
31 | error_info = traceback.format_exc()
32 | logging.error(f"{message} => {error_info}")
--------------------------------------------------------------------------------
/server/rule.py:
--------------------------------------------------------------------------------
1 | from .log import Log
2 | import asyncio
3 | import json
4 | from amqtt.client import MQTTClient
5 | from amqtt.mqtt.constants import QOS_1
6 | from multiprocessing import Queue
7 |
8 | class Rule:
9 | def __init__(self, loop: asyncio.AbstractEventLoop = None, \
10 | recv: Queue = None, send: Queue = None, topic: Queue = None, log: Log = True):
11 | self.log = Log(disable = not log)
12 | self.loop = loop or asyncio.get_event_loop()
13 | (self.recv, self.send) = (recv, send)
14 | self.topic = topic
15 |
16 | async def _subscribe(self):
17 | self.client = MQTTClient(
18 | client_id = "self",
19 | config = {
20 | "keep_alive": 60,
21 | "reconnect_max_interval": 5,
22 | "reconnect_retries": 5,
23 | "ping_delay": 2,
24 | })
25 | try:
26 | await self.client.connect('ws://127.0.0.1:3000/')
27 | await self.client.subscribe([
28 | ('/topics/register', QOS_1),
29 | ('/callback/master', QOS_1),
30 | ])
31 | except Exception as ce:
32 | self.log.log_error(ce)
33 |
34 | async def _listen_mqtt(self):
35 | while True:
36 | message = await self.client.deliver_message()
37 | packet = message.publish_packet
38 | try:
39 | data = packet.payload.data.decode('utf-8')
40 | json_data = json.loads(data)
41 | self.log.log(json_data)
42 | if (packet.variable_header.topic_name == "/topics/register"):
43 | self.loop.run_in_executor(None, self.topic.put, json_data)
44 | elif (packet.variable_header.topic_name == "/callback/master"):
45 | self.loop.run_in_executor(None, self.recv.put, data)
46 | except Exception as e:
47 | self.log.log_error(e)
48 |
49 | async def _listen_queue(self):
50 | while True:
51 | data = await self.loop.run_in_executor(None, self.send.get)
52 | data = json.loads(data)
53 | self.log.log(data)
54 | await self._call(data['topic'], bytearray(data['message'], 'utf-8'))
55 |
56 | async def _call(self, topic, message):
57 | await self.client.subscribe([
58 | (topic, QOS_1),
59 | ])
60 | await self.client.publish(topic, message, qos = QOS_1)
61 |
62 | def listen(self):
63 | self.loop.run_until_complete(self._subscribe())
64 | self.loop.create_task(self._listen_mqtt())
65 | self.loop.create_task(self._listen_queue())
--------------------------------------------------------------------------------
/src/home_rpc.c:
--------------------------------------------------------------------------------
1 | #include "freertos/FreeRTOS.h"
2 | #include "freertos/task.h"
3 | #include "freertos/event_groups.h"
4 | #include "freertos/queue.h"
5 | #include
6 | #include
7 | #include "HomeRPC.h"
8 | #include "rpc_mdns.h"
9 | #include "rpc_mesh.h"
10 | #include "rpc_mqtt.h"
11 | #include "rpc_log.h"
12 | #include "rpc_data.h"
13 | #include "esp_event.h"
14 | #include "esp_err.h"
15 | #include "esp_netif.h"
16 | #include "esp_mac.h"
17 |
18 | static DeviceList_t *device_list_end = NULL;
19 | static SemaphoreHandle_t addDeviceMutex = NULL;
20 |
21 | static const char *TAG = "core";
22 | typedef struct {
23 | QueueHandle_t device_queue;
24 | EventGroupHandle_t event_group;
25 | QueueHandle_t response_queue;
26 | DeviceList_t *device_list;
27 | char uri[32];
28 | } SharedData_t;
29 | static SharedData_t SharedData;
30 |
31 | static TimerHandle_t timer;
32 |
33 | static void mdns_search() {
34 | static esp_ip4_addr_t addr = { 0 };
35 | if (addr.addr != 0) return;
36 | rpc_log.log_info(TAG, "MDNS Search");
37 | esp_err_t err = rpc_mdns_search(CONFIG_BROKER_URL, &addr);
38 | if (err == ESP_OK) {
39 | snprintf(SharedData.uri, sizeof(SharedData.uri), "ws://%d.%d.%d.%d:3000", IP2STR(&addr));
40 | xTaskCreatePinnedToCore(rpc_mqtt_task, "event_loop", 4096, (void *)&SharedData, 5, NULL, 0);
41 | xTimerDelete(timer, portMAX_DELAY);
42 | }
43 | }
44 |
45 | static void got_ip_event_handler(void *arg, esp_event_base_t event_base, int32_t event_id, void *event_data) {
46 | rpc_log.log_info(TAG, "Got IP");
47 | static bool task = false;
48 | if (!task) {
49 | RPC_ERROR_CHECK(TAG, rpc_mdns_init());
50 | timer = xTimerCreate("mdns_search", 1000 / portTICK_PERIOD_MS, pdTRUE, NULL, mdns_search);
51 | xTimerStart(timer, 0);
52 | task = true;
53 | }
54 | }
55 |
56 | // only called once
57 | static void rpc_start(void) {
58 | static bool initialized = false;
59 | if (initialized)
60 | return;
61 |
62 | rpc_log.log_info(TAG, "Setup");
63 |
64 | rpc_mesh_init(got_ip_event_handler);
65 |
66 | uint8_t mac[6];
67 | esp_efuse_mac_get_default(mac);
68 | snprintf(callback_topic, sizeof(callback_topic), "/callback/%02X%02X%02X", mac[3], mac[4], mac[5]);
69 |
70 | SharedData.device_queue = xQueueCreate(10, sizeof(Device_t *));
71 | SharedData.event_group = xEventGroupCreate();
72 | SharedData.response_queue = xQueueCreate(10, sizeof(rpc_any_t));
73 | addDeviceMutex = xSemaphoreCreateMutex();
74 |
75 | initialized = true;
76 | }
77 |
78 | // thread-safe
79 | void rpc_addDevice(const Device_t *dev) {
80 | if (xSemaphoreTake(addDeviceMutex, portMAX_DELAY) == pdTRUE) {
81 | DeviceList_t *new_device = malloc(sizeof(DeviceList_t));
82 | if (new_device == NULL) {
83 | rpc_log.log_error(TAG, "Failed to allocate memory for new device");
84 | xSemaphoreGive(addDeviceMutex);
85 | return;
86 | }
87 | new_device->device = malloc(sizeof(Device_t));
88 | if (new_device->device == NULL) {
89 | rpc_log.log_error(TAG, "Failed to allocate memory for new device data");
90 | free(new_device);
91 | xSemaphoreGive(addDeviceMutex);
92 | return;
93 | }
94 | memcpy(new_device->device, dev, sizeof(Device_t));
95 | new_device->next = NULL;
96 |
97 | static EventBits_t nextWaitBit = 1;
98 | for (unsigned int i = 0; i < new_device->device->services_num; i++) {
99 | new_device->device->services[i]._wait = nextWaitBit;
100 | nextWaitBit <<= 1;
101 | }
102 |
103 | if (SharedData.device_list == NULL) {
104 | SharedData.device_list = new_device;
105 | device_list_end = new_device;
106 | } else {
107 | device_list_end->next = new_device;
108 | device_list_end = new_device;
109 | }
110 |
111 | rpc_log.log_info(TAG, "Add Device");
112 | if (xQueueSendToBack(SharedData.device_queue, &new_device->device, 0) != pdTRUE) {
113 | rpc_log.log_error(TAG, "Failed to add device to queue");
114 | }
115 |
116 | xSemaphoreGive(addDeviceMutex);
117 | }
118 | }
119 |
120 | // thread-safe
121 | static rpc_any_t rpc_callService(const Device_t *dev, const char *name,
122 | const rpc_any_t *params, unsigned int params_num, const TickType_t timeout_s) {
123 | rpc_log.log_info(TAG, "Call Service");
124 | rpc_mqtt_call(dev, name, params, params_num);
125 | rpc_any_t res;
126 | if (xQueueReceive(SharedData.response_queue, &res, pdMS_TO_TICKS(timeout_s)) == pdTRUE) {
127 | rpc_log.log_info(TAG, "Service called");
128 | return res;
129 | }
130 | rpc_log.log_error(TAG, "Timeout waiting for response");
131 | return (rpc_any_t) { .i = -1 };
132 | }
133 |
134 | HomeRPC_t HomeRPC = {
135 | .log_enable = true,
136 | .log_level = ESP_LOG_ERROR,
137 | .start = rpc_start,
138 | .addDevice = rpc_addDevice,
139 | ._callService = rpc_callService
140 | };
--------------------------------------------------------------------------------
/src/rpc_data.c:
--------------------------------------------------------------------------------
1 | #include
2 | #include "HomeRPC.h"
3 | #include "rpc_data.h"
4 | #include
5 | #include
6 |
7 | char* serialize_device(const Device_t* device) {
8 | cJSON* json = cJSON_CreateObject();
9 | cJSON_AddStringToObject(json, "place", device->place);
10 | cJSON_AddStringToObject(json, "type", device->type);
11 | cJSON_AddNumberToObject(json, "id", device->id);
12 | cJSON* services = cJSON_CreateArray();
13 | for (unsigned int i = 0; i < device->services_num; i++) {
14 | cJSON* service = cJSON_CreateObject();
15 | cJSON_AddStringToObject(service, "input_type", device->services[i].input_type);
16 | cJSON_AddStringToObject(service, "output_type", &device->services[i].output_type);
17 | cJSON_AddStringToObject(service, "name", device->services[i].name);
18 | cJSON_AddStringToObject(service, "desc", device->services[i].desc);
19 | cJSON_AddItemToArray(services, service);
20 | }
21 | cJSON_AddItemToObject(json, "services", services);
22 | cJSON_AddNumberToObject(json, "services_num", device->services_num);
23 | char* json_string = cJSON_Print(json);
24 | cJSON_Delete(json);
25 | return json_string;
26 | }
27 |
28 | // need to free the returned buffer!
29 | static char* serializeRpcAny(const rpc_any_t* data) {
30 | char* buffer = (char*)malloc(sizeof(rpc_any_t) * 2 + 1);
31 | if (buffer == NULL)
32 | return NULL;
33 |
34 | for (size_t i = 0; i < sizeof(rpc_any_t); i++)
35 | sprintf(buffer + i * 2, "%02x", ((unsigned char*)data)[i]);
36 |
37 | buffer[sizeof(rpc_any_t) * 2] = '\0';
38 | return buffer;
39 | }
40 |
41 | char* serialize_service(const char* callback, const rpc_any_t* params, unsigned int params_num) {
42 | cJSON* json = cJSON_CreateObject();
43 | cJSON_AddStringToObject(json, "callback", callback);
44 | if (params_num == 0) {
45 | char* json_string = cJSON_Print(json);
46 | cJSON_Delete(json);
47 | return json_string;
48 | }
49 | cJSON* params_array = cJSON_CreateArray();
50 | char *param = NULL;
51 | for (unsigned int i = 0; i < params_num; i++) {
52 | param = serializeRpcAny(¶ms[i]);
53 | if (param == NULL) {
54 | cJSON_Delete(json);
55 | return NULL;
56 | }
57 | cJSON_AddItemToArray(params_array, cJSON_CreateString(param));
58 | free(param);
59 | }
60 | cJSON_AddItemToObject(json, "params", params_array);
61 | char* json_string = cJSON_Print(json);
62 | cJSON_Delete(json);
63 | return json_string;
64 | }
65 |
66 | static rpc_any_t deserializeRpcAny(const char* buffer) {
67 | rpc_any_t data;
68 | for (size_t i = 0; i < sizeof(rpc_any_t); i++) {
69 | sscanf(buffer + i * 2, "%02x", (unsigned int*)&((unsigned char*)&data)[i]);
70 | }
71 | return data;
72 | }
73 |
74 | int deserialize_service(cJSON *json, char** callback, rpc_any_t *params, unsigned int* params_num) {
75 | if (json == NULL) {
76 | return -1;
77 | }
78 |
79 | cJSON* callback_json = cJSON_GetObjectItem(json, "callback");
80 | if (callback_json == NULL) {
81 | return -1;
82 | }
83 | *callback = strdup(callback_json->valuestring);
84 |
85 | cJSON* params_array = cJSON_GetObjectItem(json, "params");
86 | if (params_array == NULL) {
87 | *params_num = 0;
88 | return 0;
89 | }
90 |
91 | *params_num = cJSON_GetArraySize(params_array);
92 | if (*params_num > CONFIG_PARAMS_MAX) {
93 | free(*callback);
94 | return -1;
95 | }
96 |
97 | for (unsigned int i = 0; i < *params_num; i++) {
98 | cJSON* param_json = cJSON_GetArrayItem(params_array, i);
99 | if (param_json == NULL) {
100 | free(*callback);
101 | return -1;
102 | }
103 | params[i] = deserializeRpcAny(param_json->valuestring);
104 | }
105 |
106 | return 0;
107 | }
108 |
109 | cJSON* rpc_any_to_json(rpc_any_t data) {
110 | cJSON* json = cJSON_CreateObject();
111 | if (json == NULL) {
112 | return NULL;
113 | }
114 |
115 | const char* str = serializeRpcAny(&data);
116 | if (str == NULL) {
117 | cJSON_Delete(json);
118 | return NULL;
119 | }
120 |
121 | cJSON* result = cJSON_AddStringToObject(json, "result", str);
122 | free((void*)str);
123 | if (result == NULL) {
124 | cJSON_Delete(json);
125 | return NULL;
126 | }
127 |
128 | return json;
129 | }
130 |
131 | rpc_any_t json_to_rpc_any(cJSON* json) {
132 | cJSON* result = cJSON_GetObjectItem(json, "result");
133 | if (result == NULL || strlen(result->valuestring) < sizeof(rpc_any_t) * 2) {
134 | return (rpc_any_t) { .i = -1 };
135 | }
136 |
137 | return deserializeRpcAny(result->valuestring);
138 | }
--------------------------------------------------------------------------------
/src/rpc_log.c:
--------------------------------------------------------------------------------
1 | #include "esp_log.h"
2 | #include "rpc_log.h"
3 | #include
4 | #include
5 | #include "HomeRPC.h"
6 |
7 | void log_info(const char* tag, const char* format, ...) {
8 | if (HomeRPC.log_enable && HomeRPC.log_level <= ESP_LOG_INFO) {
9 | va_list args;
10 | va_start(args, format);
11 | printf("%s[HomeRPC::%s]: ", LOG_COLOR_I, tag);
12 | vprintf(format, args);
13 | printf("%s\n", LOG_RESET_COLOR);
14 | va_end(args);
15 | }
16 | }
17 |
18 | void log_error(const char* tag, const char* format, ...) {
19 | if (HomeRPC.log_enable && HomeRPC.log_level <= ESP_LOG_ERROR) {
20 | va_list args;
21 | va_start(args, format);
22 | printf("%s[HomeRPC::%s]: ", LOG_COLOR_E, tag);
23 | vprintf(format, args);
24 | printf("%s\n", LOG_RESET_COLOR);
25 | va_end(args);
26 | }
27 | }
28 |
29 | void log_warn(const char* tag, const char* format, ...) {
30 | if (HomeRPC.log_enable && HomeRPC.log_level <= ESP_LOG_WARN) {
31 | va_list args;
32 | va_start(args, format);
33 | printf("%s[HomeRPC::%s]: ", LOG_COLOR_W, tag);
34 | vprintf(format, args);
35 | printf("%s\n", LOG_RESET_COLOR);
36 | va_end(args);
37 | }
38 | }
39 |
40 | rpc_log_t rpc_log = {
41 | .log_info = log_info,
42 | .log_error = log_error,
43 | .log_warn = log_warn,
44 | };
--------------------------------------------------------------------------------
/src/rpc_mdns.c:
--------------------------------------------------------------------------------
1 | #include "mdns.h"
2 | #include "rpc_mdns.h"
3 | #include "rpc_log.h"
4 | #include "esp_err.h"
5 | #include "esp_netif.h"
6 |
7 | static const char *TAG = "rpc_mdns";
8 |
9 | esp_err_t rpc_mdns_init(void)
10 | {
11 | esp_err_t err = mdns_init();
12 | if (err) {
13 | rpc_log.log_error(TAG, "MDNS Init failed");
14 | return err;
15 | }
16 | return ESP_OK;
17 | }
18 |
19 | esp_err_t rpc_mdns_search(const char *host, esp_ip4_addr_t *ip)
20 | {
21 | assert(ip != NULL);
22 | rpc_log.log_info(TAG, "Searching for %s", host);
23 | esp_err_t err = mdns_query_a(host, 500, ip);
24 | if (err) {
25 | if (err == ESP_ERR_NOT_FOUND)
26 | rpc_log.log_info(TAG, "Host was not found!");
27 | else
28 | rpc_log.log_info(TAG, "Query Failed");
29 | return err;
30 | } else {
31 | rpc_log.log_info(TAG, "Found %s at " IPSTR, host, IP2STR(ip));
32 | return ESP_OK;
33 | }
34 | }
35 |
36 | void rpc_mdns_stop(void)
37 | {
38 | mdns_free();
39 | }
--------------------------------------------------------------------------------
/src/rpc_mesh.c:
--------------------------------------------------------------------------------
1 | #include "rpc_mesh.h"
2 | #include "rpc_log.h"
3 | #include "HomeRPC.h"
4 | #include "freertos/FreeRTOS.h"
5 | #include "freertos/task.h"
6 | #include "freertos/timers.h"
7 |
8 | #include "esp_wifi.h"
9 | #include "nvs_flash.h"
10 | #include
11 |
12 | #if ESP_IDF_VERSION >= ESP_IDF_VERSION_VAL(4, 4, 0)
13 | #include "esp_mac.h"
14 | #endif
15 |
16 | #include "esp_bridge.h"
17 | #include "esp_mesh_lite.h"
18 |
19 | static const char *TAG = "rpc_mesh";
20 |
21 | static esp_err_t esp_storage_init(void) {
22 | esp_err_t ret = nvs_flash_init();
23 |
24 | if (ret == ESP_ERR_NVS_NO_FREE_PAGES || ret == ESP_ERR_NVS_NEW_VERSION_FOUND) {
25 | rpc_log.log_info(TAG, "NVS Flash Erase");
26 | RPC_ERROR_CHECK(TAG, nvs_flash_erase());
27 | ret = nvs_flash_init();
28 | }
29 |
30 | rpc_log.log_info(TAG, "NVS Flash Init");
31 | return ret;
32 | }
33 |
34 | static void wifi_init(void) {
35 | // Station
36 | wifi_config_t wifi_config = {
37 | .sta = {
38 | .ssid = CONFIG_ROUTER_SSID,
39 | .password = CONFIG_ROUTER_PASSWORD,
40 | },
41 | };
42 | rpc_log.log_info(TAG, "Setting WiFi Station Config");
43 | esp_bridge_wifi_set_config(WIFI_IF_STA, &wifi_config);
44 |
45 | // Softap
46 | snprintf((char *)wifi_config.ap.ssid, sizeof(wifi_config.ap.ssid), "%s", CONFIG_BRIDGE_SOFTAP_SSID);
47 | strlcpy((char *)wifi_config.ap.password, CONFIG_BRIDGE_SOFTAP_PASSWORD, sizeof(wifi_config.ap.password));
48 | rpc_log.log_info(TAG, "Setting WiFi SoftAP Config");
49 | esp_bridge_wifi_set_config(WIFI_IF_AP, &wifi_config);
50 | }
51 |
52 | void app_wifi_set_softap_info(void) {
53 | char softap_ssid[32];
54 | uint8_t softap_mac[6];
55 | esp_wifi_get_mac(WIFI_IF_AP, softap_mac);
56 | memset(softap_ssid, 0x0, sizeof(softap_ssid));
57 |
58 | #ifdef CONFIG_BRIDGE_SOFTAP_SSID_END_WITH_THE_MAC
59 | snprintf(softap_ssid, sizeof(softap_ssid), "%.25s_%02x%02x%02x", CONFIG_BRIDGE_SOFTAP_SSID, softap_mac[3], softap_mac[4], softap_mac[5]);
60 | #else
61 | snprintf(softap_ssid, sizeof(softap_ssid), "%.32s", CONFIG_BRIDGE_SOFTAP_SSID);
62 | #endif
63 | rpc_log.log_info(TAG, "Setting SoftAP Info");
64 | esp_mesh_lite_set_softap_ssid_to_nvs(softap_ssid);
65 | esp_mesh_lite_set_softap_psw_to_nvs(CONFIG_BRIDGE_SOFTAP_PASSWORD);
66 | esp_mesh_lite_set_softap_info(softap_ssid, CONFIG_BRIDGE_SOFTAP_PASSWORD);
67 | }
68 |
69 | void rpc_mesh_init(esp_event_handler_t got_ip_handler) {
70 | rpc_log.log_info(TAG, "Initializing RPC Mesh");
71 | esp_storage_init();
72 |
73 | RPC_ERROR_CHECK(TAG, esp_netif_init());
74 | RPC_ERROR_CHECK(TAG, esp_event_loop_create_default());
75 |
76 | esp_bridge_create_all_netif();
77 | rpc_log.log_info(TAG, "Creating Netif");
78 |
79 | wifi_init();
80 | rpc_log.log_info(TAG, "Setting WiFi Mode");
81 |
82 | esp_mesh_lite_config_t mesh_lite_config = ESP_MESH_LITE_DEFAULT_INIT();
83 | esp_mesh_lite_init(&mesh_lite_config);
84 |
85 | app_wifi_set_softap_info();
86 |
87 | esp_mesh_lite_start();
88 | rpc_log.log_info(TAG, "Starting Mesh Lite");
89 |
90 | RPC_ERROR_CHECK(TAG, esp_event_handler_instance_register(IP_EVENT, IP_EVENT_STA_GOT_IP, got_ip_handler, NULL, NULL));
91 | rpc_log.log_info(TAG, "RPC Mesh Initialized");
92 | }
--------------------------------------------------------------------------------