├── .env.example ├── .github └── workflows │ └── go.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── agent.gif ├── api └── router.go ├── cmd ├── cli │ └── main.go └── server │ └── main.go ├── examples ├── get_weather.go ├── google_search.go ├── home_assistant.go ├── rag.go └── webpage_summary.go ├── go.mod ├── go.sum ├── knowledge_base ├── apple.txt └── morocco.txt ├── pkg ├── agent │ └── agent.go ├── config │ └── config.go ├── embedding │ ├── default.go │ ├── embedding.go │ └── openai.go ├── ha │ ├── home_assistant.go │ └── model.go ├── llm │ ├── default.go │ ├── llm.go │ ├── message.go │ └── openai.go ├── memory │ ├── buffer.go │ ├── memory.go │ └── token_base.go ├── provider │ └── util.go ├── rag │ ├── document.go │ ├── rag.go │ ├── splitter.go │ └── store.go ├── tool │ ├── execute_ha_devices.go │ ├── get_ha_devices.go │ ├── get_weather.go │ ├── google_search.go │ ├── tool.go │ └── webpage_summary.go └── wechat │ ├── crypt.go │ ├── model.go │ └── signature.go ├── vercel.json └── web ├── chat_handler.go ├── keyword.go ├── router.go ├── user_memory.go ├── util.go └── wechat_handler.go /.env.example: -------------------------------------------------------------------------------- 1 | AGENT_TOOLS= 2 | LLM_MODEL= 3 | OPENAI_API_KEY= 4 | OPENAI_BASE_URL= 5 | WECHAT_TOKEN= 6 | -------------------------------------------------------------------------------- /.github/workflows/go.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a golang project 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go 3 | 4 | name: Go 5 | 6 | on: 7 | push: 8 | branches: [ "master" ] 9 | pull_request: 10 | branches: [ "master" ] 11 | 12 | jobs: 13 | 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | 19 | - name: Set up Go 20 | uses: actions/setup-go@v4 21 | with: 22 | go-version: '1.21' 23 | 24 | - name: 
Build 25 | run: go build -v ./... 26 | 27 | - name: Test 28 | run: go test -v ./... 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .idea/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM golang:1.21.8-alpine AS builder 2 | 3 | WORKDIR /app 4 | 5 | # Copy only the module files first so the dependency layer stays cached 6 | # until go.mod/go.sum change. 7 | COPY go.mod go.sum ./ 8 | # Download exactly what go.sum pins; `go mod tidy` would rewrite the module 9 | # files (and may resolve new versions) during the image build. 10 | RUN go mod download 11 | 12 | COPY . . 13 | 14 | RUN go build -o server ./cmd/server 15 | 16 | 17 | FROM alpine:3.18 18 | 19 | COPY --from=builder /app/server /server 20 | 21 | RUN chmod +x /server 22 | 23 | EXPOSE 8082 24 | 25 | CMD ["/server"] 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 tonnie17 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wxagent 2 | 3 | 让公众号瞬间变身为智能助手,支持特性: 4 | 5 | - 接近零成本部署,只需要一个域名即可绑定到公众号 6 | - 支持对话记忆,超时自动回复,以及对话结果回溯 7 | - 支持多种工具调用,包括:获取天气,关键字搜索,网页总结,Home Assistant 设备控制 8 | - 支持构建本地 RAG 知识库进行读取和检索 9 | 10 | 一键部署到 Vercel: 11 | 12 | [![Deploy with Vercel](https://vercel.com/button)](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Ftonnie17%2Fwxagent&env=WECHAT_TOKEN,LLM_MODEL,OPENAI_API_KEY,OPENAI_BASE_URL) 13 | 14 | 1. 参考[配置](#配置)填写环境变量,完成部署 15 | 2. 拿到Vercel生成的默认域名进行访问,没问题的话会输出`OK` 16 | 3. 绑定域名到服务器地址,然后配置公众号服务器地址为:`{domain}/wechat/receive/` 17 | 18 | ## 示例 19 | 20 |
点击查看对话效果 21 | 22 | 基础对话: 23 | 24 | basic_dialogue 25 | 26 | 获取天气: 27 | 28 | get_weather 29 | 30 | 文章总结: 31 | 32 | webpage_summary 33 | 34 | 信息搜索: 35 | 36 | google_search 37 | 38 | 知识库检索: 39 | 40 | knowledge_base 41 | 42 | 43 |
44 | 45 | ## 配置 46 | 47 | 所有配置通过环境变量指定,支持通过`.env`文件自动加载: 48 | 49 | - `WECHAT_TOKEN`:公众号服务器配置的令牌(Token) 50 | - `WECHAT_ALLOW_LIST`:允许交互的微信账号(openid),用逗号分隔,默认无限制 51 | - `WECHAT_MEM_TTL`:公众号单轮对话记忆保存时间,默认为`5m` 52 | - `WECHAT_MEM_MSG_SIZE`:公众号单轮对话记忆消息记录上限(包括工具消息),默认为`6` 53 | - `WECHAT_TIMEOUT`:公众号单轮对话超时时间(公众号限制回复时间不能超过5秒),默认为`4s` 54 | - `WECHAT_APP_ID`:公众号 AppID,安全模式下需要指定 55 | - `WECHAT_ENCODING_AES_KEY`:公众号消息加解密密钥 (EncodingAESKey),安全模式下需要指定 56 | - `LLM_PROVIDER`:LLM 提供者,支持:`openai`,默认为`openai` 57 | - `OPENAI_API_KEY`:OpenAI(兼容接口)API KEY 58 | - `OPENAI_BASE_URL`:OpenAI(兼容接口)Base URL 59 | - `SERVER_ADDR`:服务器模式的启动地址,默认为`0.0.0.0:8082` 60 | - `USE_RAG`:是否开启 RAG 从知识库检索查询 61 | - `EMBEDDING_PROVIDER`:文本嵌入模型提供者,支持:`openai`,默认为 `openai` 62 | - `KNOWLEDGE_BASE_PATH`:本地知识库目录路径,目前支持文件格式:`txt`,默认为 `./knowledge_base` 63 | 64 | ### Agent 配置 65 | 66 | - `AGENT_TOOLS`:Agent 可以使用的 Tools,用逗号分隔,需要配置相关的环境变量,支持: 67 | - `google_search`:Google 搜索 68 | - `get_weather`:天气查询 69 | - `webpage_summary`:网页文本总结 70 | - `get_devices`:获取 Home Assistant 设备列表 71 | - `execute_device`:执行 Home Assistant 设备动作 72 | - `AGENT_TIMEOUT`:Agent 对话超时时间,默认为`30s` 73 | - `MAX_TOOL_ITER`:Agent 调用工具最大迭代次数,默认为`3` 74 | - `TOOL_TIMEOUT`:工具调用超时时间,默认为`10s` 75 | - `LLM_MODEL`:LLM 模型名称,默认为`gpt-3.5-turbo` 76 | - `LLM_MAX_TOKENS`:最大输出 Token 数量,默认为`500` 77 | - `LLM_TEMPERATURE`:Temperature 参数,默认为`0.2` 78 | - `LLM_TOP_P`:Top P 参数,默认为`0.9` 79 | - `SYSTEM_PROMPT`:设置 Agent 对话的 System Prompt 80 | - `EMBEDDING_MODEL`:文本嵌入模型,支持:`openai`,默认为 `openai` 81 | 82 | 83 | ### 工具配置 84 | 85 | #### Google 搜索(google_search) 86 | 87 | - `GOOGLE_SEARCH_ENGINE`:Google 搜索引擎 88 | - `GOOGLE_SEARCH_API_KEY`:Google 搜索 API Key 89 | 90 | 91 | #### 天气查询(get_weather) 92 | 93 | - `OPENWEATHERMAP_API_KEY`:OpenWeatherMap API Key 94 | 95 | 96 | #### Home Assistant(get_devices,execute_device) 97 | 98 | - `HA_BASE_URL`:Home Assistant 服务器地址 99 | - `HA_BEARER_TOKEN`:Home Assistant API 验证 Bearer Token 100 | 101 | 102 | ## 扩展使用 103 | 104 |
与Agent交互 105 | 106 | ```go 107 | package main 108 | 109 | import ( 110 | "context" 111 | "fmt" 112 | "log" 113 | "time" 114 | 115 | "github.com/tonnie17/wxagent/pkg/agent" 116 | "github.com/tonnie17/wxagent/pkg/config" 117 | "github.com/tonnie17/wxagent/pkg/llm" 118 | "github.com/tonnie17/wxagent/pkg/memory" 119 | "github.com/tonnie17/wxagent/pkg/tool" 120 | ) 121 | 122 | func main() { 123 | tools := []tool.Tool{ 124 | tool.NewWebPageSummary(), 125 | } 126 | 127 | agent := agent.NewAgent(&config.AgentConfig{ 128 | AgentTools: []string{"webpage_summary"}, 129 | AgentTimeout: 30 * time.Second, 130 | MaxToolIter: 3, 131 | ToolTimeout: 10 * time.Second, 132 | Model: "qwen-plus", 133 | MaxTokens: 500, 134 | Temperature: 0.2, 135 | TopP: 0.9, 136 | }, llm.NewOpenAI(), memory.NewBuffer(6), tools, nil) 137 | 138 | output, err := agent.Chat(context.Background(), "总结一下:https://golangnote.com/golang/golang-stringsbuilder-vs-bytesbuffer") 139 | if err != nil { 140 | log.Fatalf("chat failed: %v", err) 141 | } 142 | 143 | fmt.Println(output) 144 | } 145 | ``` 146 | 147 |
148 | 149 |
自定义工具 150 | 151 | 152 | 要定义一个工具,需要实现`Tool`定义的接口: 153 | 154 | ```go 155 | type Tool interface { 156 | Name() string 157 | Description() string 158 | Schema() map[string]interface{} 159 | Execute(context.Context, string) (string, error) 160 | } 161 | ``` 162 | 163 | 方法含义: 164 | - `Name()`:工具名称 165 | - `Description()`:工具描述,描述尽量清晰以便模型了解选择工具进行调用 166 | - `Schema() map[string]interface{}`:提供工具的参数描述以及定义 167 | - `Execute(context.Context, string) (string, error)`:工具的执行逻辑,接收模型输入,返回执行结果 168 | 169 |
170 | 171 |
接入新模型 172 | 173 | 174 | 要接入新的模型,需要实现`LLM`定义的接口: 175 | 176 | ```go 177 | type LLM interface { 178 | Chat(ctx context.Context, model string, messages []*ChatMessage, options ...ChatOption) (*ChatMessage, error) 179 | } 180 | ``` 181 | 182 |
183 | 184 | ## 本地开发 185 | 186 | 创建配置文件`.env`,将项目配置逐项追加写入到文件(每行一项): 187 | 188 | ```sh 189 | echo "{CONFIG_KEY}={CONFIG_VALUE}" >> .env 190 | ``` 191 | 192 | 拉取依赖: 193 | 194 | ```sh 195 | go mod tidy 196 | ``` 197 | 198 | 启动服务器: 199 | 200 | ```sh 201 | go run cmd/server/main.go 202 | ``` 203 | 204 | 启动交互式命令行: 205 | 206 | ```sh 207 | go run cmd/cli/main.go 208 | ``` 209 | 210 | Streamlit 调用: 211 | 212 | ```python 213 | with httpx.stream("POST", server_addr, headers=headers, data=json.dumps({ 214 | "messages": st.session_state.messages 215 | }), timeout=30) as r: 216 | for line in r.iter_lines(): 217 | if not line.strip(): 218 | continue 219 | data = json.loads(line) 220 | # ... 221 | ``` 222 | 223 | 224 | 225 | ## 构建 226 | 227 | ### 本地构建 228 | 229 | 编译二进制文件: 230 | 231 | ```sh 232 | go build -o server ./cmd/server 233 | ``` 234 | 235 | 运行: 236 | 237 | ```sh 238 | ./server 239 | ``` 240 | 241 | ### Docker 构建 242 | 243 | 创建镜像: 244 | 245 | ```sh 246 | docker build -t wxagent . 247 | ``` 248 | 249 | 运行容器: 250 | 251 | ```sh 252 | docker run -it --rm -p 8082:8082 --env-file .env wxagent 253 | ``` 254 | -------------------------------------------------------------------------------- /agent.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tonnie17/wxagent/b8bcf2c707400bbf384739055baf3e4ab3476932/agent.gif -------------------------------------------------------------------------------- /api/router.go: -------------------------------------------------------------------------------- 1 | package api 2 | 3 | import ( 4 | "context" 5 | "github.com/go-chi/chi/v5" 6 | _ "github.com/joho/godotenv/autoload" 7 | "github.com/tonnie17/wxagent/pkg/config" 8 | "github.com/tonnie17/wxagent/pkg/embedding" 9 | "github.com/tonnie17/wxagent/pkg/rag" 10 | "github.com/tonnie17/wxagent/web" 11 | "log/slog" 12 | "net/http" 13 | "os" 14 | ) 15 | 16 | var router chi.Router 17 | 18 | func init() { 19 | logger :=
slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) 20 | slog.SetDefault(logger) 21 | 22 | cfg, err := config.LoadConfig() 23 | if err != nil { 24 | slog.Error("config load failed", slog.Any("err", err)) 25 | return 26 | } 27 | 28 | var ragClient *rag.Client 29 | if cfg.UseRAG { 30 | store, err := rag.NewPgVectorStore() 31 | if err != nil { 32 | slog.Error("init vector store failed", slog.Any("err", err)) 33 | return 34 | } 35 | 36 | ragClient = rag.NewClient(embedding.New(cfg.EmbeddingProvider), store) 37 | if err := ragClient.BuildKnowledgeBase(context.Background(), cfg.KnowledgeBasePath, cfg.EmbeddingModel, false); err != nil { 38 | slog.Error("load data failed", slog.Any("err", err)) 39 | return 40 | } 41 | } 42 | 43 | router = chi.NewRouter() 44 | web.SetupRouter(router, cfg, logger, ragClient) 45 | } 46 | 47 | // Handler serves every request through the router built in init. Each early 48 | // return in init leaves router nil, so respond 503 instead of panicking on a 49 | // method call through a nil interface. 50 | func Handler(w http.ResponseWriter, r *http.Request) { 51 | if router == nil { 52 | http.Error(w, "service not initialized", http.StatusServiceUnavailable) 53 | return 54 | } 55 | router.ServeHTTP(w, r) 56 | } 57 | -------------------------------------------------------------------------------- /cmd/cli/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/chzyer/readline" 7 | "github.com/tonnie17/wxagent/pkg/agent" 8 | "github.com/tonnie17/wxagent/pkg/config" 9 | "github.com/tonnie17/wxagent/pkg/llm" 10 | "github.com/tonnie17/wxagent/pkg/memory" 11 | "github.com/tonnie17/wxagent/pkg/tool" 12 | "log" 13 | "log/slog" 14 | "strings" 15 | ) 16 | 17 | func main() { 18 | rl, err := readline.NewEx(&readline.Config{ 19 | Prompt: "> ", 20 | }) 21 | 22 | if err != nil { 23 | log.Fatal("readline create failed:", err) 24 | } 25 | defer rl.Close() 26 | 27 | cfg, err := config.LoadConfig() 28 | if err != nil { 29 | slog.Error("config load failed", slog.Any("err", err)) 30 | return 31 | } 32 | 33 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tool.DefaultTools(), nil) 34 | for { 35 | line, err := rl.Readline()
36 | if err != nil { 37 | break 38 | } 39 | 40 | input := strings.TrimSpace(line) 41 | // Blank lines (just Enter) are not sent to the model. 42 | if input == "" { 43 | continue 44 | } 45 | 46 | output, err := a.Chat(context.Background(), input) 47 | if err != nil { 48 | // Report and keep the session alive: log.Fatalln calls os.Exit, 49 | // which would skip the deferred rl.Close(). 50 | slog.Error("chat failed", slog.Any("err", err)) 51 | continue 52 | } 53 | 54 | fmt.Println(output) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /cmd/server/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "github.com/go-chi/chi/v5" 7 | "github.com/tonnie17/wxagent/pkg/config" 8 | "github.com/tonnie17/wxagent/pkg/embedding" 9 | "github.com/tonnie17/wxagent/pkg/rag" 10 | "github.com/tonnie17/wxagent/web" 11 | "log/slog" 12 | "net/http" 13 | "os" 14 | "os/signal" 15 | "syscall" 16 | "time" 17 | 18 | _ "github.com/joho/godotenv/autoload" 19 | ) 20 | 21 | func main() { 22 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo})) 23 | slog.SetDefault(logger) 24 | 25 | cfg, err := config.LoadConfig() 26 | if err != nil { 27 | slog.Error("config load failed", slog.Any("err", err)) 28 | return 29 | } 30 | 31 | var ragClient *rag.Client 32 | if cfg.UseRAG { 33 | store, err := rag.NewPgVectorStore() 34 | if err != nil { 35 | slog.Error("init vector store failed", slog.Any("err", err)) 36 | return 37 | } 38 | defer store.Release() 39 | 40 | ragClient = rag.NewClient(embedding.New(cfg.EmbeddingProvider), store) 41 | if err := ragClient.BuildKnowledgeBase(context.Background(), cfg.KnowledgeBasePath, cfg.EmbeddingModel, false); err != nil { 42 | slog.Error("load data failed", slog.Any("err", err)) 43 | return 44 | } 45 | } 46 | 47 | r := chi.NewRouter() 48 | web.SetupRouter(r, cfg, logger, ragClient) 49 | 50 | server := &http.Server{ 51 | Addr: cfg.ServerAddr, 52 | Handler: r, 53 | // Bound header reads so slow clients cannot hold connections open 54 | // indefinitely. WriteTimeout is left unset because responses may be 55 | // streamed. 56 | ReadHeaderTimeout: 10 * time.Second, 57 | } 58 | 59 | // ListenAndServe failures (e.g. address already in use) are handed back to 60 | // main; otherwise main would block on the signal channel forever. 61 | serveErr := make(chan error, 1) 62 | go func() { 63 | slog.Info("running on " + cfg.ServerAddr) 64 | if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { 65 |
serveErr <- err 66 | } 67 | }() 68 | 69 | quit := make(chan os.Signal, 1) 70 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) 71 | select { 72 | case <-quit: 73 | case err := <-serveErr: 74 | slog.Error("http serve failed", slog.Any("err", err)) 75 | // Return (not os.Exit) so the deferred store.Release still runs. 76 | return 77 | } 78 | 79 | shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 80 | defer cancel() 81 | 82 | slog.Info("server shutdown...") 83 | if err := server.Shutdown(shutdownCtx); err != nil { 84 | slog.Error("server shutdown failed", slog.Any("err", err)) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /examples/get_weather.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | package main 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "github.com/tonnie17/wxagent/pkg/agent" 10 | "github.com/tonnie17/wxagent/pkg/config" 11 | "github.com/tonnie17/wxagent/pkg/llm" 12 | "github.com/tonnie17/wxagent/pkg/memory" 13 | "github.com/tonnie17/wxagent/pkg/tool" 14 | "log/slog" 15 | "os" 16 | 17 | _ "github.com/joho/godotenv/autoload" 18 | ) 19 | 20 | func main() { 21 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) 22 | slog.SetDefault(logger) 23 | 24 | cfg, err := config.LoadConfig() 25 | if err != nil { 26 | slog.Error("config load failed", slog.Any("err", err)) 27 | return 28 | } 29 | 30 | ctx := context.Background() 31 | tools := []tool.Tool{ 32 | tool.NewGetWeather(), 33 | } 34 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tools, nil) 35 | output, err := a.Chat(ctx, "天气怎么样") 36 | fmt.Println(output, err) 37 | output, err = a.Chat(ctx, "深圳") 38 | fmt.Println(output, err) 39 | } 40 | -------------------------------------------------------------------------------- /examples/google_search.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | package main 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 |
"github.com/tonnie17/wxagent/pkg/agent" 10 | "github.com/tonnie17/wxagent/pkg/config" 11 | "github.com/tonnie17/wxagent/pkg/llm" 12 | "github.com/tonnie17/wxagent/pkg/memory" 13 | "github.com/tonnie17/wxagent/pkg/tool" 14 | "log/slog" 15 | "os" 16 | 17 | _ "github.com/joho/godotenv/autoload" 18 | ) 19 | 20 | func main() { 21 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) 22 | slog.SetDefault(logger) 23 | 24 | cfg, err := config.LoadConfig() 25 | if err != nil { 26 | slog.Error("config load failed", slog.Any("err", err)) 27 | return 28 | } 29 | 30 | ctx := context.Background() 31 | tools := []tool.Tool{ 32 | tool.NewGoogleSearch(), 33 | } 34 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tools, nil) 35 | output, err := a.Chat(ctx, "搜索一下法国的首都在哪里") 36 | fmt.Println(output, err) 37 | } 38 | -------------------------------------------------------------------------------- /examples/home_assistant.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | package main 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "github.com/tonnie17/wxagent/pkg/agent" 10 | "github.com/tonnie17/wxagent/pkg/config" 11 | "github.com/tonnie17/wxagent/pkg/llm" 12 | "github.com/tonnie17/wxagent/pkg/memory" 13 | "github.com/tonnie17/wxagent/pkg/tool" 14 | "log/slog" 15 | "os" 16 | 17 | _ "github.com/joho/godotenv/autoload" 18 | ) 19 | 20 | func main() { 21 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) 22 | slog.SetDefault(logger) 23 | 24 | cfg, err := config.LoadConfig() 25 | if err != nil { 26 | slog.Error("config load failed", slog.Any("err", err)) 27 | return 28 | } 29 | 30 | ctx := context.Background() 31 | tools := []tool.Tool{ 32 | tool.NewGetHADevices(), 33 | tool.NewExecuteHADevice(), 34 | } 35 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tools, 
nil) 36 | output, err := a.Chat(ctx, "打开客厅灯和厨房灯") 37 | fmt.Println(output, err) 38 | } 39 | -------------------------------------------------------------------------------- /examples/rag.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | package main 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | "github.com/tonnie17/wxagent/pkg/agent" 10 | "github.com/tonnie17/wxagent/pkg/config" 11 | "github.com/tonnie17/wxagent/pkg/embedding" 12 | "github.com/tonnie17/wxagent/pkg/llm" 13 | "github.com/tonnie17/wxagent/pkg/memory" 14 | "github.com/tonnie17/wxagent/pkg/rag" 15 | "log/slog" 16 | "os" 17 | 18 | _ "github.com/joho/godotenv/autoload" 19 | ) 20 | 21 | func main() { 22 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) 23 | slog.SetDefault(logger) 24 | 25 | cfg, err := config.LoadConfig() 26 | if err != nil { 27 | slog.Error("config load failed", slog.Any("err", err)) 28 | return 29 | } 30 | 31 | store, err := rag.NewPgVectorStore() 32 | if err != nil { 33 | slog.Error("init vector store failed", slog.Any("err", err)) 34 | return 35 | } 36 | defer store.Release() 37 | 38 | client := rag.NewClient(embedding.NewOpenAI(), store) 39 | if err := client.BuildKnowledgeBase(context.Background(), cfg.KnowledgeBasePath, cfg.EmbeddingModel, false); err != nil { 40 | slog.Error("load failed", slog.Any("err", err)) 41 | return 42 | } 43 | 44 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), nil, client) 45 | output, err := a.Chat(context.Background(), "摩洛哥的公路有多少公里") 46 | fmt.Println(output, err) 47 | } 48 | -------------------------------------------------------------------------------- /examples/webpage_summary.go: -------------------------------------------------------------------------------- 1 | //go:build ignore 2 | // +build ignore 3 | 4 | package main 5 | 6 | import ( 7 | "context" 8 | "fmt" 9 | 
"github.com/tonnie17/wxagent/pkg/agent" 10 | "github.com/tonnie17/wxagent/pkg/config" 11 | "github.com/tonnie17/wxagent/pkg/llm" 12 | "github.com/tonnie17/wxagent/pkg/memory" 13 | "github.com/tonnie17/wxagent/pkg/tool" 14 | "log/slog" 15 | "os" 16 | 17 | _ "github.com/joho/godotenv/autoload" 18 | ) 19 | 20 | func main() { 21 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) 22 | slog.SetDefault(logger) 23 | 24 | cfg, err := config.LoadConfig() 25 | if err != nil { 26 | slog.Error("config load failed", slog.Any("err", err)) 27 | return 28 | } 29 | 30 | ctx := context.Background() 31 | tools := []tool.Tool{ 32 | tool.NewWebPageSummary(), 33 | } 34 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tools, nil) 35 | output, err := a.Chat(ctx, "总结一下: https://golangnote.com/golang/golang-stringsbuilder-vs-bytesbuffer") 36 | fmt.Println(output, err) 37 | } 38 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/tonnie17/wxagent 2 | 3 | go 1.21.8 4 | 5 | require ( 6 | github.com/PuerkitoBio/goquery v1.9.3 7 | github.com/caarlos0/env/v11 v11.2.2 8 | github.com/chzyer/readline v1.5.1 9 | github.com/go-chi/chi/v5 v5.1.0 10 | github.com/jackc/pgx/v5 v5.7.1 11 | github.com/joho/godotenv v1.5.1 12 | github.com/pgvector/pgvector-go v0.2.2 13 | github.com/pkoukk/tiktoken-go v0.1.7 14 | github.com/samber/slog-chi v1.12.3 15 | github.com/sashabaranov/go-openai v1.35.6 16 | ) 17 | 18 | require ( 19 | github.com/andybalholm/cascadia v1.3.2 // indirect 20 | github.com/dlclark/regexp2 v1.10.0 // indirect 21 | github.com/google/uuid v1.6.0 // indirect 22 | github.com/jackc/pgpassfile v1.0.0 // indirect 23 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect 24 | github.com/jackc/puddle/v2 v2.2.2 // indirect 25 | go.opentelemetry.io/otel v1.29.0 // 
indirect 26 | go.opentelemetry.io/otel/trace v1.29.0 // indirect 27 | golang.org/x/crypto v0.27.0 // indirect 28 | golang.org/x/net v0.29.0 // indirect 29 | golang.org/x/sync v0.8.0 // indirect 30 | golang.org/x/sys v0.27.0 // indirect 31 | golang.org/x/text v0.18.0 // indirect 32 | ) 33 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | entgo.io/ent v0.13.1 h1:uD8QwN1h6SNphdCCzmkMN3feSUzNnVvV/WIkHKMbzOE= 2 | entgo.io/ent v0.13.1/go.mod h1:qCEmo+biw3ccBn9OyL4ZK5dfpwg++l1Gxwac5B1206A= 3 | github.com/PuerkitoBio/goquery v1.9.3 h1:mpJr/ikUA9/GNJB/DBZcGeFDXUtosHRyRrwh7KGdTG0= 4 | github.com/PuerkitoBio/goquery v1.9.3/go.mod h1:1ndLHPdTz+DyQPICCWYlYQMPl0oXZj0G6D4LCYA6u4U= 5 | github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss= 6 | github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU= 7 | github.com/caarlos0/env/v11 v11.2.2 h1:95fApNrUyueipoZN/EhA8mMxiNxrBwDa+oAZrMWl3Kg= 8 | github.com/caarlos0/env/v11 v11.2.2/go.mod h1:JBfcdeQiBoI3Zh1QRAWfe+tpiNTmDtcCj/hHHHMx0vc= 9 | github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= 10 | github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= 11 | github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= 12 | github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= 13 | github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= 14 | github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= 15 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 16 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 17 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 18 | 
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= 19 | github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= 20 | github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= 21 | github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= 22 | github.com/go-pg/pg/v10 v10.11.0 h1:CMKJqLgTrfpE/aOVeLdybezR2om071Vh38OLZjsyMI0= 23 | github.com/go-pg/pg/v10 v10.11.0/go.mod h1:4BpHRoxE61y4Onpof3x1a2SQvi9c+q1dJnrNdMjsroA= 24 | github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU= 25 | github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo= 26 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= 27 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 28 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 29 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 30 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= 31 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= 32 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= 33 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= 34 | github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs= 35 | github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA= 36 | github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= 37 | github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= 38 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= 39 | github.com/jinzhu/inflection v1.0.0/go.mod 
h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= 40 | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= 41 | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= 42 | github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= 43 | github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= 44 | github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= 45 | github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= 46 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= 47 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= 48 | github.com/pgvector/pgvector-go v0.2.2 h1:Q/oArmzgbEcio88q0tWQksv/u9Gnb1c3F1K2TnalxR0= 49 | github.com/pgvector/pgvector-go v0.2.2/go.mod h1:u5sg3z9bnqVEdpe1pkTij8/rFhTaMCMNyQagPDLK8gQ= 50 | github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= 51 | github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= 52 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 53 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 54 | github.com/samber/slog-chi v1.12.3 h1:jZ09VSMCEytjdAIiNS2pMe7a1mZSfhjI0pdvZv0nTng= 55 | github.com/samber/slog-chi v1.12.3/go.mod h1:uKRFgoHdeoeme1SUzIbhxGoNfY/MluEtokKZigeG5Es= 56 | github.com/sashabaranov/go-openai v1.35.6 h1:oi0rwCvyxMxgFALDGnyqFTyCJm6n72OnEG3sybIFR0g= 57 | github.com/sashabaranov/go-openai v1.35.6/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= 58 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 59 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 60 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 61 | github.com/stretchr/testify 
v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 62 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 63 | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo= 64 | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs= 65 | github.com/uptrace/bun v1.1.12 h1:sOjDVHxNTuM6dNGaba0wUuz7KvDE1BmNu9Gqs2gJSXQ= 66 | github.com/uptrace/bun v1.1.12/go.mod h1:NPG6JGULBeQ9IU6yHp7YGELRa5Agmd7ATZdz4tGZ6z0= 67 | github.com/uptrace/bun/dialect/pgdialect v1.1.12 h1:m/CM1UfOkoBTglGO5CUTKnIKKOApOYxkcP2qn0F9tJk= 68 | github.com/uptrace/bun/dialect/pgdialect v1.1.12/go.mod h1:Ij6WIxQILxLlL2frUBxUBOZJtLElD2QQNDcu/PWDHTc= 69 | github.com/uptrace/bun/driver/pgdriver v1.1.12 h1:3rRWB1GK0psTJrHwxzNfEij2MLibggiLdTqjTtfHc1w= 70 | github.com/uptrace/bun/driver/pgdriver v1.1.12/go.mod h1:ssYUP+qwSEgeDDS1xm2XBip9el1y9Mi5mTAvLoiADLM= 71 | github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94= 72 | github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ= 73 | github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU= 74 | github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc= 75 | github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc= 76 | github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= 77 | github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= 78 | github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= 79 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= 80 | go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= 81 | go.opentelemetry.io/otel v1.29.0/go.mod 
h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= 82 | go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= 83 | go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= 84 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 85 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 86 | golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= 87 | golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= 88 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 89 | golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 90 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= 91 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= 92 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 93 | golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 94 | golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= 95 | golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= 96 | golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= 97 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 98 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 99 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 100 | golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= 101 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= 102 | 
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 103 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 104 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 105 | golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 106 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 107 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 108 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 109 | golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 110 | golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s= 111 | golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 112 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 113 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 114 | golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= 115 | golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= 116 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= 117 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= 118 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= 119 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 120 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 121 | golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= 122 | golang.org/x/text v0.18.0/go.mod 
h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= 123 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= 124 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= 125 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= 126 | golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= 127 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 128 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 129 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 130 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 131 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 132 | gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= 133 | gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= 134 | gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= 135 | gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= 136 | mellium.im/sasl v0.3.1 h1:wE0LW6g7U83vhvxjC1IY8DnXM+EU095yeo8XClvCdfo= 137 | mellium.im/sasl v0.3.1/go.mod h1:xm59PUYpZHhgQ9ZqoJ5QaCqzWMi8IeS49dhp6plPCzw= 138 | -------------------------------------------------------------------------------- /knowledge_base/apple.txt: -------------------------------------------------------------------------------- 1 | 苹果树(学名:Malus domestica)是蔷薇科苹果亚科苹果属植物,为落叶乔木,在世界上广泛种植。苹果,又称柰或林檎,是苹果树的果实,一般呈红色,但需视品种而定,富含矿物质和维生素,是人们最常食用的水果之一。人们根据需求的不同口感、用途(比如烹饪、生吃、酿苹果酒等)培育不同的品种,已知有超过7,500个苹果品种,拥有一系列人们需要的不同特性。 2 | 3 | 苹果起源于中亚,直到今天当地还可以找到苹果的野生祖先:新疆野苹果。苹果在亚洲和欧洲都有着数千年的种植历史,并由欧洲的殖民者带到了北美,是苹果属中生长最广泛的树种。 苹果在北欧、希腊、欧洲基督教传统等许多文化中都有宗教和神话的意义。 4 | 5 | 
苹果开花期是基于各地气候而定,但一般集中在4-5月份。苹果是异花授粉植物,大部分品种自花不能结成果实。一般苹果栽种后,于2-3年才开始结果。果实成长期之长短,一般早熟品种为65-87天,中熟品种为90-133天,晚熟品种则为137-168天。在一般情形下,栽种后苹果树可有10-1000年寿命。苹果树如果从种子开始长,就会长得较为高大,可高至15米,但栽培树木一般只高3-5米左右,因此,苹果的各种栽培品种通常通过嫁接在砧木上进行繁殖,砧木决定了最终树木的大小。树干呈灰褐色,树皮有一定程度的脱落。苹果树及其果实很容易出现许多真菌、细菌和害虫问题,这些问题可以通过多种有机和非有机手段加以控制。 2010年,人们对苹果的基因组进行了测序,作为苹果生产中疾病控制和选择育种研究的一部分。2017年全球苹果产量为8310万公吨,中国占总产量的一半。 -------------------------------------------------------------------------------- /knowledge_base/morocco.txt: -------------------------------------------------------------------------------- 1 | 摩洛哥是二元制君主立宪制国家,行政权归于政府,立法权归于政府和议会两院,国王仍然握有行政和立法权力,在军事、外交和宗教事务拥有较大话语权。伊斯兰教是摩洛哥的主要宗教,官方语言包括现代标准阿拉伯语和柏柏尔语,其中后者于2011年获得官方地位。摩洛哥人主要使用摩洛哥阿拉伯语作为口语,其和现代标准阿拉伯语有一定差异;由于历史原因,摩洛哥同时也使用法语和西班牙语。 2 | 3 | 摩洛哥现为阿拉伯国家联盟、非洲联盟和地中海联盟成员国,是非洲第五大经济体。 4 | 5 | 距今40万年前便有古人类在摩洛哥活动。考古发现证明这里很早就有人居住了。最早在这里定居的柏柏尔人的来历已无以考证。腓尼基人,迦太基人和罗马人先后占领这里。7世纪时,阿拉伯人到来,并在8世纪建立王国。中世纪时期在这里统治的朝代有些是柏柏尔人,有些是阿拉伯人。摩洛哥王国的国名来源于阿拉伯语中称呼西方的“马格里布”,中国古籍中有关摩洛哥的较早记载,有宋人赵汝适《诸蕃志》“默伽猎国”。明代《坤舆万国全图》把摩洛哥称为“马逻可国”。 6 | 7 | 从15世纪开始,摩洛哥受到多个西方国家的入侵。20世纪初期的摩洛哥危机是导致第一次世界大战的重要原因之一。1912年3月,法国占领摩洛哥。1956年3月,获得独立。 8 | 9 | 现今摩洛哥阿拉维王朝是在17世纪开始的,1957年,摩洛哥苏丹穆罕默德五世宣布改称国王,摩洛哥苏丹国改制为摩洛哥王国。实行君主立宪制,但君主较内阁和国会拥有相当大的权力,所以是二元君主制。 10 | 11 | 摩洛哥是君主立宪制的多党制的伊斯兰国家,国王拥有比较大的权力。议会分参议院和众议院两院,众议院325名议员由国民直接选举而出,任期为五年。参议院270名由地方代表选出,每三年更换三分之一。 12 | 13 | 穆罕默德六世从1999年登基后加强国家的民主化,对内试图缓和贫困和保证社会治安,对外持比较缓和的伊斯兰国家政策。摩洛哥是中东与欧美对话的重要中间国。2004年6月,美国总统布什宣布给予摩洛哥主要非北约盟友地位。 14 | 15 | 自2011年2月20日起,摩洛哥民众展开民主示威活动。2011年3月9日,穆罕默德六世国王宣布将进行全面的宪政改革,内容有首相产生方式和权力、提高国民的自由和人权、建立独立的司法体系等各个方面,以真正实现民主政治。6月17日,国王提出政改计划,将自限权力建立一个民主的君主立宪体制,计划将在7月1日交付公投。方案内容:将国王许多权力交给首相和国会;未来将实施直接普选,政府首长由胜选政党组成;由首相指派政府官员,包括行政部门和国营企业首长,并且首相将有权力解散国会等。 16 | 17 | 2011年7月1日,摩洛哥举行新宪法草案全民公决。2日凌晨,摩洛哥内政大臣沙尔阿维公布草案以98.49%的赞成票获得通过。这标志着该国家以和平方式迈出了宪政改革的重要一步。 18 | 19 | 新宪法草案的主要内容: 20 | 将议会的权限扩大,特别是加强了两院中众议院的主导地位。众议院在五分之一议员支持的情况下即可对重要官员展开调查,或在获得三分之一议员支持的情况下对大臣进行弹劾,同时还将取代国王行使大赦权力。 21 | 新宪法草案同时强调司法独立,国家的司法体系将由法官与全国人权委员会共同组成的最高委员会掌管,而司法大臣被排除在委员会之外。 22 | 
在新宪法草案中,摩洛哥国王仍为国家元首、陆海空三军最高统帅和宗教领袖,并分别担任“大臣委员会”和新成立的“国家安全委员会”主席,掌握着重大决策的最终决定权,同时仍拥有重要地方长官和驻外大使的任命权等。 23 | 24 | 位于非洲大陆西北部,直布罗陀海峡南岸,扼地中海入大西洋的门户,海岸线1700多公里。地形复杂,中部和北部为峻峭的阿特拉斯山脉,东部和南部是上高原和前撒哈拉高原,仅西北沿海一带为狭长低缓的平原。由于斜贯全境的阿特拉斯山阻挡了南部撒哈拉沙漠热浪的侵袭,摩洛哥常年气候宜人,花木繁茂,赢得“烈日下的清凉国土”的美誉,还享有“北非花园”的美称。 25 | 26 | 受副热带高压带控制和加那利寒流影响,形成干燥的热带沙漠气候;阿特拉斯山脉横贯全国,其中图卜卡勒峰(4,165米)是全国最高点。 27 | 28 | 摩洛哥磷酸盐资源极为丰富,估计储量1100亿吨,占世界储量的75%。其它矿产资源有铁、铅、锌、钴、锰、钡、铜、盐、磁铁矿、无烟煤、油页岩等。其中油页岩储量1000亿吨以上,含原油60亿吨。 29 | 30 | 摩洛哥是一个以第三产业经济为主,居中等收入水平的发展中国家,是非洲第五、北非第三大经济体。其主要的经济部门是旅游业、渔业和磷酸矿的出口,磷酸盐储量1100亿吨,占世界首位。农业与牧业也比较重要,但受气候影响比较大。摩洛哥在许多方面依靠外来资助,法国是其第二大援助国,西班牙是其第一大援助国。国民经济发展缓慢。2010年摩洛哥国内生产总值为917亿美元,人均2,839美元。经济增长率3.2%,通货膨胀率则为1.4%。2022年人均3,628美元。 31 | 32 | 陆路交通较发达。在国内运输业中占主导地位,90%的客运和75%的货运通过陆路交通完成。 33 | 34 | 铁路:投入运营线路1907公里,其中复线483公里,电气化铁路1014公里。另有765公里磷酸盐运输线。2003年,摩洛哥与西班牙达成协议,两国共同修建一条穿过直布罗陀海峡的海底复线铁路。该工程曾预计2010年开工,将是连接欧、非两大洲的首条铁路线,但直至2021年,此工程未有任何实质进展。 2018年,卡萨布兰卡-丹吉尔高速铁路通车,是摩洛哥的第一条高速铁路,也是非洲第一条高速铁路。 35 | 36 | 公路:总长64452公里,其中一级公路15907公里,二级公路9367公里,三级公路39178公里。至2009年中,高速公路916公里,有拉巴特-丹吉尔、拉巴特-卡萨布兰卡-塞达特、拉巴特-梅克内斯-非斯等多段高速公路。根据摩第二个全国乡村公路10年计划(2005-2015),每年将新建1500公里公路,届时将使全国80%的农村地区通公路。 37 | 38 | 水运:现拥有港口30个,其中11个为多功能港口,11个为运输、捕鱼用港口。主要港口有卡萨布兰卡、穆罕默迪耶、萨非、丹吉尔、阿加迪尔等,2009年全国总吞吐量6682万吨。其中卡萨布兰卡为全国最大港口,占全国港口总吞吐量的37%;丹吉尔-地中海港一期已完工,正在扩建二期,将成为非洲和地中海最大港口之一。 39 | 40 | 空运:全国共有机场28个,其中国际机场12个,如卡萨布兰卡穆罕默德五世国际机场、拉巴特-塞拉国际机场、马拉喀什-迈纳拉国际机场等。摩王家航空公司有飞机33架,开通75条航线,航线通往四大洲32个国家,总航线30多万公里。2008年客运量1060万人次。2005年,摩与欧盟签署“天空开放”协议,摩航空市场对欧洲航空公司开放。 41 | 42 | 1996年12月,经阿尔及利亚、摩洛哥至西班牙、葡萄牙的马格里布-欧洲天然气管道正式开通。管道全长1385公里,初期每年可输送天然气90亿立方米,摩每年可获10亿立方米天然气。2005年起,摩首次将该管道的过境天然气截留配额用于发电,可满足全国17%的电力需求。 43 | -------------------------------------------------------------------------------- /pkg/agent/agent.go: -------------------------------------------------------------------------------- 1 | package agent 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "errors" 8 | "fmt" 9 | _ "github.com/joho/godotenv/autoload" 10 | "github.com/tonnie17/wxagent/pkg/config" 11 | 
"github.com/tonnie17/wxagent/pkg/llm" 12 | "github.com/tonnie17/wxagent/pkg/memory" 13 | "github.com/tonnie17/wxagent/pkg/rag" 14 | "github.com/tonnie17/wxagent/pkg/tool" 15 | "html/template" 16 | "log/slog" 17 | "strings" 18 | "time" 19 | ) 20 | 21 | var ( 22 | ErrMemoryInUse = errors.New("memory in use") 23 | PromptFuncMap = template.FuncMap{ 24 | "now": time.Now, 25 | } 26 | ) 27 | 28 | type Agent struct { 29 | config *config.AgentConfig 30 | llm llm.LLM 31 | memory memory.Memory 32 | tools []tool.Tool 33 | ragClient *rag.Client 34 | } 35 | 36 | func NewAgent(config *config.AgentConfig, llm llm.LLM, memory memory.Memory, tools []tool.Tool, ragClient *rag.Client) *Agent { 37 | return &Agent{ 38 | config: config, 39 | llm: llm, 40 | memory: memory, 41 | tools: tools, 42 | ragClient: ragClient, 43 | } 44 | } 45 | 46 | func (a *Agent) Chat(ctx context.Context, input string) (string, error) { 47 | msg, err := a.Process(ctx, &llm.ChatMessage{ 48 | Role: llm.RoleUser, 49 | Content: input, 50 | }) 51 | if err != nil || msg == nil { 52 | return "", err 53 | } 54 | 55 | return msg.Content, nil 56 | } 57 | 58 | func (a *Agent) ChatContinue(ctx context.Context) (string, error) { 59 | if l, ok := a.memory.(memory.Lock); ok && l.IsLocked() { 60 | return "", ErrMemoryInUse 61 | } 62 | 63 | messages, err := a.memory.History() 64 | if err != nil { 65 | return "", err 66 | } 67 | // Scan newest-to-oldest, including index 0, for the latest assistant reply. 68 | for i := len(messages) - 1; i >= 0; i-- { 69 | msg := messages[i] 70 | if msg.Role == llm.RoleAssistant { 71 | return msg.Content, nil 72 | } 73 | } 74 | return "", nil 75 | } 76 | 77 | func (a *Agent) Process(ctx context.Context, message *llm.ChatMessage) (*llm.ChatMessage, error) { 78 | out := make(chan *llm.ChatMessage) 79 | // Buffered so the goroutine can always deliver ProcessStream's result and exit. 80 | errCh := make(chan error, 1) 81 | go func() { 82 | errCh <- a.ProcessStream(ctx, message, out) 83 | }() 84 | 85 | var res *llm.ChatMessage 86 | for msg := range out { 87 | res = msg 88 | } 89 | 90 | // out is closed before ProcessStream returns, so receive its error via errCh instead of racing on a shared variable. 91 | return res, <-errCh 92 | } 93 | 94 | func (a *Agent) ProcessStream(ctx
context.Context, message *llm.ChatMessage, outputChan chan<- *llm.ChatMessage) error { 95 | defer close(outputChan) 96 | 97 | if a.config.AgentTimeout != 0 { 98 | timeoutCtx, cancel := context.WithTimeout(ctx, a.config.AgentTimeout) 99 | defer cancel() 100 | ctx = timeoutCtx 101 | } 102 | 103 | if a.ragClient != nil && message != nil { 104 | documents, err := a.ragClient.Query(ctx, a.config.EmbeddingModel, message.Content, 3) 105 | if err != nil { 106 | return err 107 | } 108 | 109 | if len(documents) > 0 { 110 | contexts := make([]string, 0, len(documents)) 111 | for _, doc := range documents { 112 | contexts = append(contexts, doc.Content) 113 | } 114 | message.Content = a.buildRAGPrompt(contexts, message.Content) 115 | } 116 | } 117 | 118 | if l, ok := a.memory.(memory.Lock); ok { 119 | if l.Lock() { 120 | defer l.Release() 121 | } else { 122 | return ErrMemoryInUse 123 | } 124 | } 125 | 126 | var messages []*llm.ChatMessage 127 | if a.config.SystemPrompt != "" { 128 | systemPrompt := a.config.SystemPrompt 129 | if tmp, err := template.New("systemPrompt").Funcs(PromptFuncMap).Parse(systemPrompt); err == nil { 130 | promptTpl := new(bytes.Buffer) 131 | if tmp.Execute(promptTpl, nil) == nil { 132 | systemPrompt = promptTpl.String() 133 | } 134 | } 135 | messages = append(messages, &llm.ChatMessage{ 136 | Role: llm.RoleSystem, 137 | Content: systemPrompt, 138 | }) 139 | } 140 | 141 | history, err := a.memory.History() 142 | if err != nil { 143 | return err 144 | } 145 | 146 | messages = append(messages, history...) 
147 | if message != nil { 148 | messages = append(messages, message) 149 | } 150 | 151 | toolsMap := make(map[string]tool.Tool, len(a.tools)) 152 | for _, t := range a.tools { 153 | toolsMap[t.Name()] = t 154 | } 155 | 156 | for i := 0; i < a.config.MaxToolIter; i++ { 157 | select { 158 | case <-ctx.Done(): 159 | return ctx.Err() 160 | default: 161 | } 162 | 163 | slog.Debug("chat input", slog.String("model", a.config.Model), slog.Any("messages", messages), slog.Any("tools", a.tools)) 164 | msg, err := a.llm.Chat(ctx, a.config.Model, messages, 165 | llm.Tools(a.tools), 166 | llm.MaxTokens(a.config.MaxTokens), 167 | llm.Temperature(a.config.Temperature), 168 | llm.TopP(a.config.TopP), 169 | ) 170 | if err != nil { 171 | slog.Debug("chat error", slog.Any("err", err)) 172 | return err 173 | } 174 | slog.Debug("chat output", slog.Any("result", msg)) 175 | 176 | select { 177 | case outputChan <- msg: 178 | case <-ctx.Done(): 179 | return ctx.Err() 180 | } 181 | 182 | messages = append(messages, msg) 183 | 184 | if len(msg.ToolCalls) == 0 { 185 | break 186 | } 187 | 188 | for _, toolCall := range msg.ToolCalls { 189 | t, ok := toolsMap[toolCall.Name] 190 | if !ok { 191 | slog.Error("tool not exist", slog.String("tool_call_id", toolCall.ID)) 192 | continue 193 | } 194 | 195 | var ( 196 | toolCtx = ctx 197 | toolCtxCancel context.CancelFunc 198 | ) 199 | if a.config.ToolTimeout != 0 { 200 | toolCtx, toolCtxCancel = context.WithTimeout(ctx, a.config.ToolTimeout) 201 | } 202 | 203 | output, err := t.Execute(toolCtx, toolCall.Arguments) 204 | if toolCtxCancel != nil { 205 | toolCtxCancel() 206 | } 207 | 208 | if err != nil { 209 | slog.Error("tool call function execute failed", slog.String("tool_call_id", toolCall.ID), slog.Any("err", err)) 210 | } 211 | 212 | toolResponse := a.convertToolCallMessage(toolCall, output, err) 213 | messages = append(messages, toolResponse) 214 | 215 | select { 216 | case outputChan <- toolResponse: 217 | case <-ctx.Done(): 218 | return ctx.Err() 
219 | } 220 | } 221 | } 222 | // The system prompt is rebuilt on every call, so persist only the conversation; storing it would stack a stale copy in front of the history each turn. 223 | a.memory.Update(dropSystemMessages(messages)) 224 | return nil 225 | } 226 | 227 | func (a *Agent) convertToolCallMessage(toolCall *tool.Call, output string, err error) *llm.ChatMessage { 228 | status := "success" 229 | if err != nil { 230 | status = "failed" 231 | output = err.Error() 232 | } 233 | 234 | content, _ := json.Marshal(struct { 235 | Status string `json:"status"` 236 | Output string `json:"output"` 237 | Name string `json:"name"` 238 | }{ 239 | Status: status, 240 | Output: output, 241 | Name: toolCall.Name, 242 | }) 243 | 244 | toolMessage := &llm.ChatMessage{ 245 | Role: llm.RoleTool, 246 | Content: string(content), 247 | ToolCallID: toolCall.ID, 248 | } 249 | 250 | return toolMessage 251 | } 252 | 253 | func (a *Agent) buildRAGPrompt(contexts []string, question string) string { 254 | return fmt.Sprintf(`You are an assistant. Answer the question based on the given context. 255 | 256 | Context: 257 | %v 258 | 259 | Question: 260 | %v 261 | 262 | Answer:`, strings.Join(contexts, "\n"), question) 263 | } 264 | 265 | // dropSystemMessages returns a new slice holding every message of messages except those with RoleSystem. 266 | func dropSystemMessages(messages []*llm.ChatMessage) []*llm.ChatMessage { 267 | res := make([]*llm.ChatMessage, 0, len(messages)) 268 | for _, msg := range messages { 269 | if msg.Role != llm.RoleSystem { 270 | res = append(res, msg) 271 | } 272 | } 273 | return res 274 | } 275 | -------------------------------------------------------------------------------- /pkg/config/config.go: -------------------------------------------------------------------------------- 1 | package config 2 | 3 | import ( 4 | "github.com/caarlos0/env/v11" 5 | "time" 6 | ) 7 | 8 | type Config struct { 9 | ServerAddr string `env:"SERVER_ADDR" envDefault:"0.0.0.0:8082"` 10 | ChatAPIKEY string `env:"CHAT_API_KEY" envDefault:""` 11 | WechatAppID string `env:"WECHAT_APP_ID"` 12 | WechatEncodingAESKey string `env:"WECHAT_ENCODING_AES_KEY"` 13 | WechatToken string `env:"WECHAT_TOKEN"` 14 | WechatAllowList []string `env:"WECHAT_ALLOW_LIST"` 15 | WechatMemTTL time.Duration `env:"WECHAT_MEM_TTL" envDefault:"5m"` 16 | WechatMemMsgSize int `env:"WECHAT_MEM_MSG_SIZE" envDefault:"6"` 17 | WechatTimeout time.Duration `env:"WECHAT_TIMEOUT" envDefault:"4s"` 18 | LLMProvider string `env:"LLM_PROVIDER" envDefault:"openai"` 19 | UseRAG bool
`env:"USE_RAG" envDefault:"false"` 20 | EmbeddingProvider string `env:"EMBEDDING_PROVIDER" envDefault:"openai"` 21 | KnowledgeBasePath string `env:"KNOWLEDGE_BASE_PATH" envDefault:"./knowledge_base"` 22 | AgentConfig 23 | } 24 | 25 | type AgentConfig struct { 26 | AgentTools []string `env:"AGENT_TOOLS"` 27 | AgentTimeout time.Duration `env:"AGENT_TIMEOUT" envDefault:"30s"` 28 | MaxToolIter int `env:"MAX_TOOL_ITER" envDefault:"5"` 29 | ToolTimeout time.Duration `env:"TOOL_TIMEOUT" envDefault:"10s"` 30 | Model string `env:"LLM_MODEL" envDefault:"gpt-3.5-turbo"` 31 | MaxTokens int `env:"LLM_MAX_TOKENS" envDefault:"500"` 32 | Temperature float32 `env:"LLM_TEMPERATURE" envDefault:"0.2"` 33 | TopP float32 `env:"LLM_TOP_P" envDefault:"0.9"` 34 | SystemPrompt string `env:"SYSTEM_PROMPT" envDefault:"当前时间: {{now.UTC}}"` 35 | EmbeddingModel string `env:"EMBEDDING_MODEL" envDefault:"text-embedding-ada-002"` 36 | } 37 | 38 | func LoadConfig() (*Config, error) { 39 | cfg := &Config{} 40 | if err := env.Parse(cfg); err != nil { 41 | return nil, err 42 | } 43 | return cfg, nil 44 | } 45 | -------------------------------------------------------------------------------- /pkg/embedding/default.go: -------------------------------------------------------------------------------- 1 | package embedding 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | type NotImplemented struct { 9 | } 10 | 11 | func NewNotImplemented() Model { 12 | return &NotImplemented{} 13 | } 14 | 15 | func (o *NotImplemented) CreateEmbeddings(ctx context.Context, model string, content string) ([]float32, error) { 16 | return nil, fmt.Errorf("embedding not implemented") 17 | } 18 | -------------------------------------------------------------------------------- /pkg/embedding/embedding.go: -------------------------------------------------------------------------------- 1 | package embedding 2 | 3 | import "context" 4 | 5 | type Model interface { 6 | CreateEmbeddings(ctx context.Context, model string, content 
string) ([]float32, error) 7 | } 8 | 9 | func New(provider string) Model { 10 | switch provider { 11 | case "openai": 12 | return NewOpenAI() 13 | default: 14 | return NewNotImplemented() 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /pkg/embedding/openai.go: -------------------------------------------------------------------------------- 1 | package embedding 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/sashabaranov/go-openai" 7 | "github.com/tonnie17/wxagent/pkg/provider" 8 | ) 9 | 10 | type OpenAI struct { 11 | client *openai.Client 12 | } 13 | 14 | func NewOpenAI() Model { 15 | openaiConfig := openai.DefaultConfig(provider.GetAPIKey("openai")) 16 | openaiConfig.BaseURL = provider.GetAPIBaseURL("openai") 17 | client := openai.NewClientWithConfig(openaiConfig) 18 | return &OpenAI{ 19 | client: client, 20 | } 21 | } 22 | 23 | func (o *OpenAI) CreateEmbeddings(ctx context.Context, model string, content string) ([]float32, error) { 24 | resp, err := o.client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{ 25 | Input: []string{content}, 26 | Model: openai.EmbeddingModel(model), 27 | EncodingFormat: openai.EmbeddingEncodingFormatFloat, 28 | }) 29 | 30 | if err != nil { 31 | return nil, err 32 | } 33 | 34 | if len(resp.Data) == 0 { 35 | return nil, fmt.Errorf("data is empty") 36 | } 37 | 38 | return resp.Data[0].Embedding, nil 39 | } 40 | -------------------------------------------------------------------------------- /pkg/ha/home_assistant.go: -------------------------------------------------------------------------------- 1 | package ha 2 | 3 | import ( 4 | "bytes" 5 | "context" 6 | "encoding/json" 7 | "fmt" 8 | "io" 9 | "net/http" 10 | "os" 11 | "strings" 12 | ) 13 | 14 | func GetEntityStates(ctx context.Context, domains []string) ([]EntityState, error) { 15 | req, err := newHARequest(ctx, "GET", "/api/states", nil) 16 | if err != nil { 17 | return nil, err 18 | } 19 | client := &http.Client{} 
20 | resp, err := client.Do(req) 21 | if err != nil { 22 | return nil, err 23 | } 24 | defer resp.Body.Close() 25 | 26 | content, err := io.ReadAll(resp.Body) 27 | if err != nil { 28 | return nil, err 29 | } 30 | 31 | entityStates := make([]EntityState, 0) 32 | if err := json.Unmarshal(content, &entityStates); err != nil { 33 | return nil, err 34 | } 35 | 36 | res := make([]EntityState, 0) 37 | for _, entityState := range entityStates { 38 | if entityState.State == "unknown" { 39 | continue 40 | } 41 | split := strings.Split(entityState.EntityID, ".") 42 | if len(domains) > 0 && (len(split) < 2 || !containString(domains, split[0])) { 43 | continue 44 | } 45 | res = append(res, entityState) 46 | } 47 | 48 | return res, nil 49 | } 50 | 51 | func ExecuteService(ctx context.Context, domain string, service string, entityID string) ([]EntityState, error) { 52 | body := map[string]interface{}{ 53 | "entity_id": entityID, 54 | } 55 | bodyJSON, _ := json.Marshal(body) 56 | req, err := newHARequest(ctx, "POST", fmt.Sprintf("/api/services/%v/%v", domain, service), bytes.NewReader(bodyJSON)) 57 | if err != nil { 58 | return nil, err 59 | } 60 | resp, err := http.DefaultClient.Do(req) 61 | if err != nil { 62 | return nil, err 63 | } 64 | defer resp.Body.Close() 65 | content, err := io.ReadAll(resp.Body) 66 | if err != nil { 67 | return nil, err 68 | } 69 | 70 | entityStates := make([]EntityState, 0) 71 | if err := json.Unmarshal(content, &entityStates); err != nil { 72 | return nil, err 73 | } 74 | 75 | return entityStates, nil 76 | } 77 | 78 | func newHARequest(ctx context.Context, method string, apiPath string, body io.Reader) (*http.Request, error) { 79 | haBaseURL := os.Getenv("HA_BASE_URL") 80 | haToken := os.Getenv("HA_BEARER_TOKEN") 81 | 82 | apiURL := fmt.Sprintf("%v%v", haBaseURL, apiPath) 83 | req, err := http.NewRequest(method, apiURL, body) 84 | if err != nil { 85 | return nil, err 86 | } 87 | req = req.WithContext(ctx) 88 | req.Header.Set("Authorization", fmt.Sprintf("Bearer %v",
haToken)) 89 | return req, nil 90 | } 91 | 92 | func containString(target []string, s string) bool { 93 | for _, t := range target { 94 | if s == t { 95 | return true 96 | } 97 | } 98 | return false 99 | } 100 | -------------------------------------------------------------------------------- /pkg/ha/model.go: -------------------------------------------------------------------------------- 1 | package ha 2 | 3 | type EntityState struct { 4 | EntityID string `json:"entity_id"` 5 | State string `json:"state"` 6 | Attributes struct { 7 | FriendlyName string `json:"friendly_name"` 8 | } `json:"attributes"` 9 | LastChanged string `json:"last_changed"` 10 | } 11 | -------------------------------------------------------------------------------- /pkg/llm/default.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | ) 7 | 8 | type NotImplemented struct { 9 | } 10 | 11 | func NewNotImplemented() LLM { 12 | return &NotImplemented{} 13 | } 14 | 15 | func (o *NotImplemented) Chat(ctx context.Context, model string, chatMessages []*ChatMessage, options ...ChatOption) (*ChatMessage, error) { 16 | return nil, fmt.Errorf("llm not implemented") 17 | } 18 | -------------------------------------------------------------------------------- /pkg/llm/llm.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | import ( 4 | "context" 5 | "github.com/tonnie17/wxagent/pkg/tool" 6 | ) 7 | 8 | type LLM interface { 9 | Chat(ctx context.Context, model string, messages []*ChatMessage, options ...ChatOption) (*ChatMessage, error) 10 | } 11 | 12 | func New(provider string) LLM { 13 | switch provider { 14 | case "openai": 15 | return NewOpenAI() 16 | default: 17 | return NewNotImplemented() 18 | } 19 | } 20 | 21 | type chatOptions struct { 22 | tools []tool.Tool 23 | maxTokens int 24 | temperature float32 25 | topP float32 26 | } 27 | 28 | type ChatOption 
func(*chatOptions) 29 | 30 | func Tools(tools []tool.Tool) ChatOption { 31 | return func(o *chatOptions) { 32 | o.tools = tools 33 | } 34 | } 35 | 36 | func MaxTokens(maxTokens int) ChatOption { 37 | return func(o *chatOptions) { 38 | o.maxTokens = maxTokens 39 | } 40 | } 41 | 42 | func Temperature(temperature float32) ChatOption { 43 | return func(o *chatOptions) { 44 | o.temperature = temperature 45 | } 46 | } 47 | 48 | func TopP(topP float32) ChatOption { 49 | return func(o *chatOptions) { 50 | o.topP = topP 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /pkg/llm/message.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | import "github.com/tonnie17/wxagent/pkg/tool" 4 | 5 | type Role string 6 | 7 | const ( 8 | RoleUser Role = "user" 9 | RoleSystem Role = "system" 10 | RoleTool Role = "tool" 11 | RoleAssistant Role = "assistant" 12 | ) 13 | 14 | type ChatMessage struct { 15 | Role Role `json:"role"` 16 | Content string `json:"content"` 17 | ToolCallID string `json:"tool_call_id"` 18 | ToolCalls []*tool.Call `json:"tool_calls"` 19 | } 20 | -------------------------------------------------------------------------------- /pkg/llm/openai.go: -------------------------------------------------------------------------------- 1 | package llm 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/sashabaranov/go-openai" 7 | "github.com/tonnie17/wxagent/pkg/provider" 8 | "github.com/tonnie17/wxagent/pkg/tool" 9 | ) 10 | 11 | type OpenAI struct { 12 | client *openai.Client 13 | } 14 | 15 | func NewOpenAI() LLM { 16 | openaiConfig := openai.DefaultConfig(provider.GetAPIKey("openai")) 17 | openaiConfig.BaseURL = provider.GetAPIBaseURL("openai") 18 | client := openai.NewClientWithConfig(openaiConfig) 19 | return &OpenAI{ 20 | client: client, 21 | } 22 | } 23 | 24 | func (o *OpenAI) Chat(ctx context.Context, model string, chatMessages []*ChatMessage, options 
...ChatOption) (*ChatMessage, error) { 25 | message := make([]openai.ChatCompletionMessage, 0, len(chatMessages)) 26 | for _, chatMessage := range chatMessages { 27 | msg := openai.ChatCompletionMessage{ 28 | Role: string(chatMessage.Role), 29 | Content: chatMessage.Content, 30 | ToolCallID: chatMessage.ToolCallID, 31 | } 32 | for _, toolCall := range chatMessage.ToolCalls { 33 | msg.ToolCalls = append(msg.ToolCalls, openai.ToolCall{ 34 | ID: toolCall.ID, 35 | Type: openai.ToolType(toolCall.Type), 36 | Function: openai.FunctionCall{ 37 | Name: toolCall.Name, 38 | Arguments: toolCall.Arguments, 39 | }, 40 | }) 41 | } 42 | message = append(message, msg) 43 | } 44 | 45 | var option chatOptions 46 | for _, o := range options { 47 | o(&option) 48 | } 49 | 50 | requestTools := make([]openai.Tool, 0, len(option.tools)) 51 | for _, t := range option.tools { 52 | requestTools = append(requestTools, openai.Tool{ 53 | Type: openai.ToolTypeFunction, 54 | Function: &openai.FunctionDefinition{ 55 | Name: t.Name(), 56 | Description: t.Description(), 57 | Parameters: t.Schema(), 58 | }, 59 | }) 60 | } 61 | 62 | response, err := o.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{ 63 | Model: model, 64 | Messages: message, 65 | Tools: requestTools, 66 | MaxTokens: option.maxTokens, 67 | Temperature: option.temperature, 68 | TopP: option.topP, 69 | }) 70 | if err != nil { 71 | return nil, err 72 | } 73 | 74 | return o.convertResponse(response) 75 | } 76 | 77 | func (o *OpenAI) convertResponse(response openai.ChatCompletionResponse) (*ChatMessage, error) { 78 | if len(response.Choices) == 0 { 79 | return nil, fmt.Errorf("empty response") 80 | } 81 | 82 | message := response.Choices[0].Message 83 | if message.Content != "" { 84 | rc := &ChatMessage{ 85 | Role: Role(message.Role), 86 | Content: message.Content, 87 | } 88 | return rc, nil 89 | } 90 | 91 | if len(message.ToolCalls) > 0 { 92 | rc := &ChatMessage{ 93 | Role: Role(message.Role), 94 | Content: message.Content, 95 
| } 96 | for _, toolCall := range message.ToolCalls { 97 | rc.ToolCalls = append(rc.ToolCalls, &tool.Call{ 98 | ID: toolCall.ID, 99 | Type: string(toolCall.Type), 100 | Name: toolCall.Function.Name, 101 | Arguments: toolCall.Function.Arguments, 102 | }) 103 | } 104 | return rc, nil 105 | } 106 | 107 | return nil, fmt.Errorf("no content or tool calls") 108 | } 109 | -------------------------------------------------------------------------------- /pkg/memory/buffer.go: -------------------------------------------------------------------------------- 1 | package memory 2 | 3 | import ( 4 | "github.com/tonnie17/wxagent/pkg/llm" 5 | ) 6 | 7 | type Buffer struct { 8 | BaseLock 9 | maxMessages int 10 | messages []*llm.ChatMessage 11 | } 12 | 13 | func NewBuffer(maxMessages int) *Buffer { 14 | return &Buffer{ 15 | maxMessages: maxMessages, 16 | messages: []*llm.ChatMessage{}, 17 | } 18 | } 19 | 20 | func (m *Buffer) History() ([]*llm.ChatMessage, error) { 21 | return m.messages, nil 22 | } 23 | 24 | func (m *Buffer) Update(messages []*llm.ChatMessage) error { 25 | m.messages = messages 26 | m.truncate() 27 | return nil 28 | } 29 | 30 | func (m *Buffer) truncate() { 31 | // Drop the oldest message plus the tool/assistant replies directly following it, until at most maxMessages remain. 32 | for len(m.messages) > 0 && len(m.messages) > m.maxMessages { 33 | start := 1 34 | for start < len(m.messages) && (m.messages[start].Role == llm.RoleTool || m.messages[start].Role == llm.RoleAssistant) { 35 | start++ 36 | } 37 | m.messages = m.messages[start:] 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /pkg/memory/memory.go: -------------------------------------------------------------------------------- 1 | package memory 2 | 3 | import ( 4 | "github.com/tonnie17/wxagent/pkg/llm" 5 | "sync" 6 | "sync/atomic" 7 | ) 8 | 9 | type Memory interface { 10 | Update(messages []*llm.ChatMessage) error 11 | History() ([]*llm.ChatMessage, error) 12 | } 13 | 14 | type Lock interface { 15 | Lock() bool 16 | Release() 17 | IsLocked() bool 18 | } 19 |
20 | type BaseLock struct { 21 | isLock int32 22 | lock sync.Mutex 23 | } 24 | 25 | func (m *BaseLock) Lock() bool { 26 | lockSuccess := m.lock.TryLock() 27 | if lockSuccess { 28 | atomic.StoreInt32(&m.isLock, 1) 29 | } 30 | return lockSuccess 31 | } 32 | 33 | func (m *BaseLock) IsLocked() bool { 34 | return atomic.LoadInt32(&m.isLock) == 1 35 | } 36 | 37 | func (m *BaseLock) Release() { 38 | atomic.StoreInt32(&m.isLock, 0) 39 | m.lock.Unlock() 40 | } 41 | -------------------------------------------------------------------------------- /pkg/memory/token_base.go: -------------------------------------------------------------------------------- 1 | package memory 2 | 3 | import ( 4 | "github.com/pkoukk/tiktoken-go" 5 | "github.com/tonnie17/wxagent/pkg/llm" 6 | "log/slog" 7 | ) 8 | 9 | type TokenBase struct { 10 | BaseLock 11 | messages []*llm.ChatMessage 12 | maxTokens int 13 | totalTokens int 14 | encoding *tiktoken.Tiktoken 15 | } 16 | 17 | func NewTokenBase(model string, maxTokens int) *TokenBase { 18 | encoding, err := tiktoken.EncodingForModel(model) 19 | if err != nil { 20 | slog.Warn("failed to get encoding for model") 21 | encoding, _ = tiktoken.EncodingForModel("gpt-4o") 22 | } 23 | return &TokenBase{ 24 | messages: []*llm.ChatMessage{}, 25 | maxTokens: maxTokens, 26 | totalTokens: 0, 27 | encoding: encoding, 28 | } 29 | } 30 | 31 | func (m *TokenBase) History() ([]*llm.ChatMessage, error) { 32 | return m.messages, nil 33 | } 34 | 35 | func (m *TokenBase) Update(messages []*llm.ChatMessage) error { 36 | m.messages = messages 37 | m.truncate() 38 | m.totalTokens = m.getMessageTokens(m.messages) 39 | return nil 40 | } 41 | 42 | func (m *TokenBase) truncate() { 43 | // Drop the oldest message plus the tool/assistant replies directly following it, until the token budget is met. 44 | for len(m.messages) > 0 && m.maxTokens > 0 && m.getMessageTokens(m.messages) > m.maxTokens { 45 | start := 1 46 | for start < len(m.messages) && (m.messages[start].Role == llm.RoleTool || m.messages[start].Role == llm.RoleAssistant) { 47 | start++ 48 | } 49 | m.messages =
m.messages[start:] 50 | } 51 | } 52 | 53 | func (m *TokenBase) getMessageTokens(messages []*llm.ChatMessage) int { 54 | var total int 55 | for _, message := range messages { 56 | tokens := m.encoding.Encode(message.Content, nil, nil) 57 | total += len(tokens) 58 | } 59 | return total 60 | } 61 | -------------------------------------------------------------------------------- /pkg/provider/util.go: -------------------------------------------------------------------------------- 1 | package provider 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "strings" 7 | ) 8 | 9 | func GetAPIKey(provider string) string { 10 | return os.Getenv(fmt.Sprintf("%s_API_KEY", strings.ToUpper(provider))) 11 | } 12 | 13 | func GetAPIBaseURL(provider string) string { 14 | return os.Getenv(fmt.Sprintf("%s_BASE_URL", strings.ToUpper(provider))) 15 | } 16 | -------------------------------------------------------------------------------- /pkg/rag/document.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | type DocumentPart struct { 4 | DocumentID string 5 | PartIndex int 6 | Content string 7 | } 8 | -------------------------------------------------------------------------------- /pkg/rag/rag.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | "github.com/tonnie17/wxagent/pkg/embedding" 6 | "io" 7 | "io/fs" 8 | "log/slog" 9 | "os" 10 | "path/filepath" 11 | ) 12 | 13 | type Client struct { 14 | embeddingModel embedding.Model 15 | store *VectorStore 16 | } 17 | 18 | func NewClient(embeddingModel embedding.Model, store *VectorStore) *Client { 19 | return &Client{ 20 | embeddingModel: embeddingModel, 21 | store: store, 22 | } 23 | } 24 | 25 | func (c *Client) Query(ctx context.Context, model string, content string, limit int) ([]*DocumentPart, error) { 26 | embeddingData, err := c.embeddingModel.CreateEmbeddings(ctx, model, content) 27 | if err != nil { 28 | return 
nil, err 29 | } 30 | 31 | return c.store.GetMostRelevantDocuments(ctx, embeddingData, 1, limit) 32 | } 33 | 34 | func (c *Client) BuildKnowledgeBase(ctx context.Context, knowledgeBasePath string, model string, reBuild bool) error { 35 | if knowledgeBasePath == "" { 36 | return nil 37 | } 38 | if err := c.store.Init(ctx); err != nil { 39 | return err 40 | } 41 | return filepath.WalkDir(knowledgeBasePath, func(path string, d fs.DirEntry, e error) error { 42 | if d != nil && d.IsDir() { 43 | return nil 44 | } 45 | logger := slog.With(slog.String("file", path)) 46 | ext := filepath.Ext(path) 47 | if ext != ".txt" { 48 | return nil 49 | } 50 | logger.Info("build knowledge base") 51 | 52 | switch ext { 53 | case ".txt": 54 | documentID := filepath.Base(path) 55 | if reBuild { 56 | err := c.store.DeleteDocuments(ctx, documentID) 57 | if err != nil { 58 | logger.Error("delete documents failed", slog.Any("err", err)) 59 | return nil 60 | } 61 | } else { 62 | documentExist, err := c.store.CheckDocumentExist(ctx, documentID) 63 | if err != nil { 64 | logger.Error("check document exist failed", slog.Any("err", err)) 65 | return nil 66 | } 67 | if documentExist { 68 | return nil 69 | } 70 | } 71 | 72 | splits, err := processTextFile(path) 73 | if err != nil { 74 | logger.Error("process text file failed", slog.Any("err", err)) 75 | return nil 76 | } 77 | 78 | var partIndex int 79 | for _, content := range splits { 80 | embeddingData, err := c.embeddingModel.CreateEmbeddings(ctx, model, content) 81 | if err != nil { 82 | logger.Error("create embeddings failed", slog.Any("err", err)) 83 | return nil 84 | } 85 | partIndex++ 86 | if err := c.store.SaveDocumentEmbedding(ctx, documentID, partIndex, content, embeddingData); err != nil { 87 | logger.Error("save embedding failed", slog.Any("err", err)) 88 | return nil 89 | } 90 | } 91 | } 92 | 93 | return nil 94 | }) 95 | } 96 | 97 | func processTextFile(fileName string) ([]string, error) { 98 | f, err := os.Open(fileName) 99 | if err != 
nil { 100 | return nil, err 101 | } 102 | defer f.Close() 103 | 104 | content, err := io.ReadAll(f) 105 | if err != nil { 106 | return nil, err 107 | } 108 | 109 | return splitText(string(content), 500, 50), nil 110 | } 111 | -------------------------------------------------------------------------------- /pkg/rag/splitter.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "strings" 5 | "unicode/utf8" 6 | ) 7 | 8 | func splitText(text string, chunkSize int, chunkOverlap int) []string { 9 | separators := []string{"\n\n", "\n", " ", ""} 10 | separator := separators[len(separators)-1] 11 | 12 | var newSeparators []string 13 | for i, sep := range separators { 14 | if sep == "" || strings.Contains(text, sep) { 15 | separator = sep 16 | newSeparators = separators[i+1:] 17 | break 18 | } 19 | } 20 | 21 | var final []string 22 | goodSplits := make([]string, 0) 23 | for _, split := range strings.Split(text, separator) { 24 | if utf8.RuneCountInString(split) < chunkSize { 25 | goodSplits = append(goodSplits, split) 26 | continue 27 | } 28 | 29 | if len(goodSplits) > 0 { 30 | mergedSplits := mergeSplits(goodSplits, separator, chunkSize, chunkOverlap) 31 | 32 | final = append(final, mergedSplits...) 33 | goodSplits = make([]string, 0) 34 | } 35 | 36 | if len(newSeparators) == 0 { 37 | final = append(final, split) 38 | } else { 39 | other := splitText(split, chunkSize, chunkOverlap) 40 | final = append(final, other...) 41 | } 42 | } 43 | 44 | if len(goodSplits) > 0 { 45 | mergedSplits := mergeSplits(goodSplits, separator, chunkSize, chunkOverlap) 46 | final = append(final, mergedSplits...) 
47 | } 48 | 49 | return final 50 | } 51 | 52 | func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int) []string { 53 | docs := make([]string, 0) 54 | currentDoc := make([]string, 0) 55 | total := 0 56 | 57 | for _, split := range splits { 58 | splitLen := utf8.RuneCountInString(split) 59 | sepLen := utf8.RuneCountInString(separator) * compareDocsLen(currentDoc, 0) 60 | 61 | if total+splitLen+sepLen > chunkSize && len(currentDoc) > 0 { 62 | if doc := strings.TrimSpace(strings.Join(currentDoc, separator)); doc != "" { 63 | docs = append(docs, doc) 64 | } 65 | 66 | for len(currentDoc) > 0 && (total > chunkOverlap || 67 | (total+splitLen+utf8.RuneCountInString(separator)*compareDocsLen(currentDoc, 1) > chunkSize && total > 0)) { 68 | total -= utf8.RuneCountInString(currentDoc[0]) + utf8.RuneCountInString(separator)*compareDocsLen(currentDoc, 1) 69 | currentDoc = currentDoc[1:] 70 | } 71 | } 72 | 73 | currentDoc = append(currentDoc, split) 74 | total += utf8.RuneCountInString(split) 75 | total += utf8.RuneCountInString(separator) * compareDocsLen(currentDoc, 1) 76 | } 77 | 78 | if doc := strings.TrimSpace(strings.Join(currentDoc, separator)); doc != "" { 79 | docs = append(docs, doc) 80 | } 81 | return docs 82 | } 83 | 84 | func compareDocsLen(currentDocs []string, cmp int) int { 85 | if len(currentDocs) > cmp { 86 | return 1 87 | } 88 | return 0 89 | } 90 | -------------------------------------------------------------------------------- /pkg/rag/store.go: -------------------------------------------------------------------------------- 1 | package rag 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/jackc/pgx/v5" 7 | "github.com/jackc/pgx/v5/pgxpool" 8 | "github.com/pgvector/pgvector-go" 9 | pgxvector "github.com/pgvector/pgvector-go/pgx" 10 | "os" 11 | ) 12 | 13 | type VectorStore struct { 14 | pool *pgxpool.Pool 15 | } 16 | 17 | func NewPgVectorStore() (*VectorStore, error) { 18 | poolConfig, err := 
pgxpool.ParseConfig(os.Getenv("POSTGRES_URL")) 19 | if err != nil { 20 | return nil, err 21 | } 22 | poolConfig.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe 23 | poolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { 24 | return pgxvector.RegisterTypes(ctx, conn) 25 | } 26 | 27 | pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig) 28 | if err != nil { 29 | return nil, err 30 | } 31 | 32 | return &VectorStore{ 33 | pool: pool, 34 | }, nil 35 | } 36 | 37 | func (s *VectorStore) Init(ctx context.Context) error { 38 | query := ` 39 | CREATE TABLE IF NOT EXISTS knowledge_base ( 40 | id SERIAL PRIMARY KEY, 41 | document_id TEXT NOT NULL, 42 | part_index INT NOT NULL, 43 | content TEXT, 44 | embedding vector(1536), 45 | UNIQUE (document_id, part_index) 46 | ) 47 | ` 48 | if _, err := s.pool.Exec(ctx, query); err != nil { 49 | return err 50 | } 51 | 52 | return nil 53 | } 54 | 55 | func (s *VectorStore) GetMostRelevantDocuments(ctx context.Context, embedding []float32, threshold float32, limit int) ([]*DocumentPart, error) { 56 | query := fmt.Sprintf("SELECT document_id, part_index, content FROM knowledge_base WHERE embedding <-> $1 < $2 ORDER BY embedding <-> $1 LIMIT %v", limit) 57 | rows, err := s.pool.Query(ctx, query, pgvector.NewVector(embedding), threshold) 58 | if err != nil { 59 | return nil, err 60 | } 61 | defer rows.Close() 62 | 63 | var documents []*DocumentPart 64 | for rows.Next() { 65 | document := &DocumentPart{} 66 | if err := rows.Scan(&document.DocumentID, &document.PartIndex, &document.Content); err != nil { 67 | return nil, err 68 | } 69 | documents = append(documents, document) 70 | } 71 | 72 | if rows.Err() != nil { 73 | return nil, rows.Err() 74 | } 75 | 76 | return documents, nil 77 | } 78 | 79 | func (s *VectorStore) SaveDocumentEmbedding(ctx context.Context, documentID string, partIndex int, content string, embedding []float32) error { 80 | _, err := s.pool.Exec(ctx, ` 81 | INSERT INTO 
knowledge_base (document_id, part_index, content, embedding) 82 | VALUES ($1, $2, $3, $4) 83 | ON CONFLICT (document_id, part_index) 84 | DO UPDATE SET 85 | content = EXCLUDED.content, 86 | embedding = EXCLUDED.embedding; 87 | `, documentID, partIndex, content, pgvector.NewVector(embedding)) 88 | 89 | if err != nil { 90 | return err 91 | } 92 | 93 | return nil 94 | } 95 | 96 | func (s *VectorStore) CheckDocumentExist(ctx context.Context, documentID string) (bool, error) { 97 | var exists bool 98 | err := s.pool.QueryRow(ctx, ` 99 | SELECT EXISTS ( 100 | SELECT 1 FROM knowledge_base WHERE document_id = $1 101 | ) 102 | `, documentID).Scan(&exists) 103 | 104 | if err != nil { 105 | return false, nil 106 | } 107 | 108 | return exists, nil 109 | } 110 | 111 | func (s *VectorStore) DeleteDocuments(ctx context.Context, documentID string) error { 112 | _, err := s.pool.Exec(ctx, ` 113 | DELETE FROM knowledge_base WHERE document_id = $1 114 | `, documentID) 115 | 116 | if err != nil { 117 | return err 118 | } 119 | 120 | return nil 121 | } 122 | 123 | func (s *VectorStore) Release() { 124 | s.pool.Close() 125 | } 126 | -------------------------------------------------------------------------------- /pkg/tool/execute_ha_devices.go: -------------------------------------------------------------------------------- 1 | package tool 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "github.com/tonnie17/wxagent/pkg/ha" 7 | "log/slog" 8 | ) 9 | 10 | type ExecuteHADevice struct { 11 | } 12 | 13 | func NewExecuteHADevice() Tool { 14 | return &ExecuteHADevice{} 15 | } 16 | 17 | func (e *ExecuteHADevice) Name() string { 18 | return "execute_device" 19 | } 20 | 21 | func (e *ExecuteHADevice) Description() string { 22 | return "Use this function to execute service of devices in Home Assistant" 23 | } 24 | 25 | func (e *ExecuteHADevice) Schema() map[string]interface{} { 26 | return map[string]interface{}{ 27 | "type": "object", 28 | "properties": map[string]interface{}{ 29 | "list": 
map[string]interface{}{ 30 | "type": "array", 31 | "items": map[string]interface{}{ 32 | "type": "object", 33 | "properties": map[string]interface{}{ 34 | "domain": map[string]interface{}{ 35 | "type": "string", 36 | "description": "The domain of the service", 37 | }, 38 | "service": map[string]interface{}{ 39 | "type": "string", 40 | "description": "The service to be called", 41 | }, 42 | "entity_id": map[string]interface{}{ 43 | "type": "string", 44 | "description": "The entity_id retrieved from available devices. It must start with domain, followed by dot character", 45 | }, 46 | }, 47 | "required": []string{"domain", "service", "entity_id"}, 48 | }, 49 | }, 50 | }, 51 | } 52 | } 53 | 54 | func (e *ExecuteHADevice) Execute(ctx context.Context, input string) (string, error) { 55 | var arguments struct { 56 | List []struct { 57 | Domain string `json:"domain"` 58 | Service string `json:"service"` 59 | EntityID string `json:"entity_id"` 60 | } `json:"list"` 61 | } 62 | if err := json.Unmarshal([]byte(input), &arguments); err != nil { 63 | slog.Error("unmarshal failed", slog.Any("err", err)) 64 | return "", err 65 | } 66 | 67 | var entityStates []ha.EntityState 68 | for _, action := range arguments.List { 69 | states, err := ha.ExecuteService(ctx, action.Domain, action.Service, action.EntityID) 70 | if err != nil { 71 | return "", err 72 | } 73 | entityStates = append(entityStates, states...) 
74 | } 75 | 76 | statesJSON, _ := json.Marshal(entityStates) 77 | 78 | return string(statesJSON), nil 79 | } 80 | -------------------------------------------------------------------------------- /pkg/tool/get_ha_devices.go: -------------------------------------------------------------------------------- 1 | package tool 2 | 3 | import ( 4 | "context" 5 | "fmt" 6 | "github.com/tonnie17/wxagent/pkg/ha" 7 | "strings" 8 | ) 9 | 10 | type GetHADevices struct { 11 | } 12 | 13 | func NewGetHADevices() Tool { 14 | return &GetHADevices{} 15 | } 16 | 17 | func (g *GetHADevices) Name() string { 18 | return "get_devices" 19 | } 20 | 21 | func (g *GetHADevices) Description() string { 22 | return "Use this function to get devices in Home Assistant, including their state and entity_id" 23 | } 24 | 25 | func (g *GetHADevices) Schema() map[string]interface{} { 26 | return map[string]interface{}{ 27 | "type": "object", 28 | "properties": map[string]interface{}{}, 29 | "required": []string{}, 30 | } 31 | } 32 | 33 | func (g *GetHADevices) Execute(ctx context.Context, input string) (string, error) { 34 | entityStates, err := ha.GetEntityStates(ctx, g.defaultDomains()) 35 | if err != nil { 36 | return "", err 37 | } 38 | 39 | var builder strings.Builder 40 | builder.WriteString("An overview of the devices in this smart home:\n") 41 | builder.WriteString("```csv\n") 42 | builder.WriteString("entity_id,name,state\n") 43 | for _, entityState := range entityStates { 44 | builder.WriteString(fmt.Sprintf("%v,%v,%v\n", entityState.EntityID, entityState.Attributes.FriendlyName, entityState.State)) 45 | } 46 | builder.WriteString("```\n") 47 | 48 | return builder.String(), nil 49 | } 50 | 51 | func (g *GetHADevices) defaultDomains() []string { 52 | return []string{ 53 | "door", 54 | "lock", 55 | "occupancy", 56 | "motion", 57 | "climate", 58 | "light", 59 | "switch", 60 | "sensor", 61 | "speaker", 62 | "media_player", 63 | "temperature", 64 | "humidity", 65 | "battery", 66 | "tv", 67 | 
"remote", 68 | "light", 69 | "vacuum", 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /pkg/tool/get_weather.go: -------------------------------------------------------------------------------- 1 | package tool 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "log/slog" 9 | "net/http" 10 | "os" 11 | "strings" 12 | "unicode" 13 | ) 14 | 15 | type GetWeather struct { 16 | } 17 | 18 | func NewGetWeather() Tool { 19 | return &GetWeather{} 20 | } 21 | 22 | func (w *GetWeather) Name() string { 23 | return "get_weather" 24 | } 25 | 26 | func (w *GetWeather) Description() string { 27 | return "Retrieve the current weather information for a specified city and return the city name in English" 28 | } 29 | 30 | func (w *GetWeather) Schema() map[string]interface{} { 31 | return map[string]interface{}{ 32 | "type": "object", 33 | "properties": map[string]interface{}{ 34 | "city": map[string]interface{}{ 35 | "type": "string", 36 | "description": "city name", 37 | }, 38 | }, 39 | "required": []string{"city"}, 40 | } 41 | } 42 | 43 | func (w *GetWeather) Execute(ctx context.Context, input string) (string, error) { 44 | var arguments struct { 45 | City string `json:"city"` 46 | } 47 | if err := json.Unmarshal([]byte(input), &arguments); err != nil { 48 | slog.Error("unmarshal failed", slog.Any("err", err)) 49 | return "", err 50 | } 51 | 52 | if arguments.City == "" { 53 | return "", fmt.Errorf("city name is empty") 54 | } 55 | 56 | city := w.completeCNCity(arguments.City) 57 | 58 | appID := os.Getenv("OPENWEATHERMAP_API_KEY") 59 | apiURL := fmt.Sprintf("https://api.openweathermap.org/data/2.5/forecast?q=%v&units=metric&appid=%v&lang=zh_cn&cnt=5", city, appID) 60 | 61 | req, err := http.NewRequest("GET", apiURL, nil) 62 | if err != nil { 63 | return "", err 64 | } 65 | req = req.WithContext(ctx) 66 | 67 | client := &http.Client{} 68 | resp, err := client.Do(req) 69 | if err != nil { 70 | return "", err 71 
| } 72 | defer resp.Body.Close() 73 | 74 | content, err := io.ReadAll(resp.Body) 75 | if err != nil { 76 | return "", err 77 | } 78 | 79 | var weatherData struct { 80 | List []struct { 81 | Dt int `json:"dt"` 82 | Main struct { 83 | Temp float64 `json:"temp"` 84 | } `json:"main"` 85 | Weather []struct { 86 | Description string `json:"description"` 87 | } `json:"weather"` 88 | DtTxt string `json:"dt_txt"` 89 | } `json:"list"` 90 | City struct { 91 | Name string `json:"name"` 92 | Coord struct { 93 | Lat float64 `json:"lat"` 94 | Lon float64 `json:"lon"` 95 | } `json:"coord"` 96 | Sunrise int `json:"sunrise"` 97 | Sunset int `json:"sunset"` 98 | } `json:"city"` 99 | } 100 | 101 | if err := json.Unmarshal(content, &weatherData); err != nil { 102 | return "", err 103 | } 104 | 105 | mainContent, _ := json.Marshal(weatherData) 106 | 107 | return string(mainContent), nil 108 | } 109 | 110 | func (w *GetWeather) isCNCity(s string) bool { 111 | if s == "" { 112 | return false 113 | } 114 | 115 | for _, r := range s { 116 | if !unicode.Is(unicode.Han, r) { 117 | return false 118 | } 119 | } 120 | return true 121 | } 122 | 123 | func (w *GetWeather) completeCNCity(city string) string { 124 | if w.isCNCity(city) && !strings.HasSuffix(city, "市") { 125 | return city + "市" 126 | } 127 | return city 128 | } 129 | -------------------------------------------------------------------------------- /pkg/tool/google_search.go: -------------------------------------------------------------------------------- 1 | package tool 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "log/slog" 9 | "net/http" 10 | "net/url" 11 | "os" 12 | ) 13 | 14 | type GoogleSearch struct { 15 | } 16 | 17 | func NewGoogleSearch() Tool { 18 | return &GoogleSearch{} 19 | } 20 | 21 | func (w *GoogleSearch) Name() string { 22 | return "google_search" 23 | } 24 | 25 | func (w *GoogleSearch) Description() string { 26 | return "Make a query to the Google search engine to receive a list of 
results" 27 | } 28 | 29 | func (w *GoogleSearch) Schema() map[string]interface{} { 30 | return map[string]interface{}{ 31 | "type": "object", 32 | "properties": map[string]interface{}{ 33 | "query": map[string]interface{}{ 34 | "type": "string", 35 | "description": "The query to be passed to Google search", 36 | }, 37 | }, 38 | "required": []string{"query"}, 39 | } 40 | } 41 | 42 | func (w *GoogleSearch) Execute(ctx context.Context, input string) (string, error) { 43 | var arguments struct { 44 | Query string `json:"query"` 45 | } 46 | if err := json.Unmarshal([]byte(input), &arguments); err != nil { 47 | slog.Error("unmarshal failed", slog.Any("err", err)) 48 | return "", err 49 | } 50 | 51 | apiKey := os.Getenv("GOOGLE_SEARCH_API_KEY") 52 | engine := os.Getenv("GOOGLE_SEARCH_ENGINE") 53 | apiURL := fmt.Sprintf("https://www.googleapis.com/customsearch/v1?key=%v&cx=%v&q=%v&num=3", apiKey, engine, url.QueryEscape(arguments.Query)) 54 | 55 | req, err := http.NewRequest("GET", apiURL, nil) 56 | if err != nil { 57 | return "", err 58 | } 59 | req = req.WithContext(ctx) 60 | 61 | client := &http.Client{} 62 | resp, err := client.Do(req) 63 | if err != nil { 64 | return "", err 65 | } 66 | defer resp.Body.Close() 67 | 68 | content, err := io.ReadAll(resp.Body) 69 | if err != nil { 70 | return "", err 71 | } 72 | 73 | sr := &searchResult{} 74 | if err := json.Unmarshal(content, &sr); err != nil { 75 | return "", err 76 | } 77 | 78 | itemsJSON, _ := json.Marshal(sr.Items) 79 | return string(itemsJSON), nil 80 | } 81 | 82 | type searchResult struct { 83 | Items []*searchResultItem `json:"items"` 84 | } 85 | 86 | type searchResultItem struct { 87 | Title string `json:"title"` 88 | Link string `json:"link"` 89 | Snippet string `json:"snippet"` 90 | } 91 | -------------------------------------------------------------------------------- /pkg/tool/tool.go: -------------------------------------------------------------------------------- 1 | package tool 2 | 3 | import "context" 4 
| 5 | var tools = make(map[string]Tool) 6 | 7 | func init() { 8 | for _, tool := range []Tool{ 9 | NewGetWeather(), 10 | NewGoogleSearch(), 11 | NewWebPageSummary(), 12 | NewGetHADevices(), 13 | NewExecuteHADevice(), 14 | } { 15 | tools[tool.Name()] = tool 16 | } 17 | } 18 | 19 | type Tool interface { 20 | Name() string 21 | Description() string 22 | Schema() map[string]interface{} 23 | Execute(context.Context, string) (string, error) 24 | } 25 | 26 | type Call struct { 27 | ID string `json:"id"` 28 | Type string `json:"type"` 29 | Name string `json:"name"` 30 | Arguments string `json:"arguments"` 31 | } 32 | 33 | func DefaultTools() []Tool { 34 | res := make([]Tool, 0, len(tools)) 35 | for _, tool := range tools { 36 | res = append(res, tool) 37 | } 38 | return res 39 | } 40 | 41 | func GetTools(names []string) []Tool { 42 | res := make([]Tool, 0, len(names)) 43 | for _, name := range names { 44 | if _, ok := tools[name]; ok { 45 | res = append(res, tools[name]) 46 | } 47 | } 48 | return res 49 | } 50 | -------------------------------------------------------------------------------- /pkg/tool/webpage_summary.go: -------------------------------------------------------------------------------- 1 | package tool 2 | 3 | import ( 4 | "context" 5 | "encoding/json" 6 | "github.com/PuerkitoBio/goquery" 7 | "io" 8 | "log/slog" 9 | "net/http" 10 | "regexp" 11 | "strings" 12 | ) 13 | 14 | type WebPageSummary struct { 15 | } 16 | 17 | func NewWebPageSummary() Tool { 18 | return &WebPageSummary{} 19 | } 20 | 21 | func (w *WebPageSummary) Name() string { 22 | return "webpage_summary" 23 | } 24 | 25 | func (w *WebPageSummary) Description() string { 26 | return "Summaries the content of a web page" 27 | } 28 | 29 | func (w *WebPageSummary) Schema() map[string]interface{} { 30 | return map[string]interface{}{ 31 | "type": "object", 32 | "properties": map[string]interface{}{ 33 | "url": map[string]interface{}{ 34 | "type": "string", 35 | "description": "the URL of the web page", 36 
| }, 37 | }, 38 | "required": []string{"url"}, 39 | } 40 | } 41 | 42 | func (w *WebPageSummary) Execute(ctx context.Context, input string) (string, error) { 43 | var arguments struct { 44 | URL string `json:"url"` 45 | } 46 | if err := json.Unmarshal([]byte(input), &arguments); err != nil { 47 | slog.Error("unmarshal failed", slog.Any("err", err)) 48 | return "", err 49 | } 50 | 51 | req, err := http.NewRequest("GET", arguments.URL, nil) 52 | if err != nil { 53 | return "", err 54 | } 55 | req = req.WithContext(ctx) 56 | 57 | client := http.Client{} 58 | r, err := client.Do(req) 59 | if err != nil { 60 | return "", err 61 | } 62 | body, _ := io.ReadAll(r.Body) 63 | 64 | document, err := goquery.NewDocumentFromReader(strings.NewReader(string(body))) 65 | if err != nil { 66 | return "", err 67 | } 68 | document.Find("script, style, pre, code").Each(func(index int, item *goquery.Selection) { 69 | item.Remove() 70 | }) 71 | text := document.Text() 72 | 73 | re := regexp.MustCompile(`\s+`) 74 | text = re.ReplaceAllString(text, " ") 75 | 76 | return text, nil 77 | } 78 | -------------------------------------------------------------------------------- /pkg/wechat/crypt.go: -------------------------------------------------------------------------------- 1 | package wechat 2 | 3 | import ( 4 | "crypto/aes" 5 | "crypto/cipher" 6 | "encoding/base64" 7 | "encoding/binary" 8 | "errors" 9 | "fmt" 10 | "math/rand" 11 | "time" 12 | ) 13 | 14 | const ( 15 | blockSize = 32 16 | blockMask = blockSize - 1 17 | 18 | randomSize = 16 19 | contentLenSize = 4 20 | ) 21 | 22 | func init() { 23 | rand.New(rand.NewSource(time.Now().UnixNano())) 24 | } 25 | 26 | func EncryptMsg(appID string, msg string, encodingAesKey string) (string, error) { 27 | aesKey, err := base64.StdEncoding.DecodeString(encodingAesKey + "=") 28 | if err != nil { 29 | return "", err 30 | } 31 | 32 | if len(aesKey) != 32 { 33 | return "", errors.New("invalid aes key length") 34 | } 35 | 36 | appIDOffset := randomSize + 
contentLenSize + len(msg) 37 | contentLen := appIDOffset + len(appID) 38 | amountToPad := blockSize - contentLen&blockMask 39 | plaintextLen := contentLen + amountToPad 40 | plaintext := make([]byte, plaintextLen) 41 | 42 | copy(plaintext[:randomSize], randString(randomSize)) 43 | binary.BigEndian.PutUint32(plaintext[randomSize:randomSize+contentLenSize], uint32(len(msg))) 44 | copy(plaintext[randomSize+contentLenSize:], msg) 45 | copy(plaintext[appIDOffset:], appID) 46 | 47 | for i := contentLen; i < plaintextLen; i++ { 48 | plaintext[i] = byte(amountToPad) 49 | } 50 | 51 | block, err := aes.NewCipher(aesKey) 52 | if err != nil { 53 | return "", err 54 | } 55 | 56 | mode := cipher.NewCBCEncrypter(block, aesKey[:aes.BlockSize]) 57 | mode.CryptBlocks(plaintext, plaintext) 58 | 59 | return base64.StdEncoding.EncodeToString(plaintext), nil 60 | } 61 | 62 | func DecryptMsg(appID, encryptedMsg, encodingAesKey string) (string, error) { 63 | aesKey, err := base64.StdEncoding.DecodeString(encodingAesKey + "=") 64 | if err != nil { 65 | return "", err 66 | } 67 | 68 | if len(aesKey) != 32 { 69 | return "", errors.New("invalid aes key length") 70 | } 71 | 72 | cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) 73 | if err != nil { 74 | return "", fmt.Errorf("base64 decode error: %v", err) 75 | } 76 | 77 | block, err := aes.NewCipher(aesKey) 78 | if err != nil { 79 | return "", fmt.Errorf("failed to create aes cipher: %v", err) 80 | } 81 | 82 | plainBytes := make([]byte, len(cipherText)) 83 | mode := cipher.NewCBCDecrypter(block, aesKey[:aes.BlockSize]) 84 | mode.CryptBlocks(plainBytes, cipherText) 85 | 86 | pad := int(plainBytes[len(plainBytes)-1]) 87 | if pad < 1 || pad > blockSize { 88 | return "", errors.New("invalid padding byte") 89 | } 90 | plainBytes = plainBytes[:len(plainBytes)-pad] 91 | 92 | content := plainBytes[randomSize:] 93 | if len(content) < contentLenSize { 94 | return "", errors.New("invalid content length") 95 | } 96 | 97 | contentLen := 
binary.BigEndian.Uint32(content[:contentLenSize]) 98 | if len(content) < int(contentLenSize+contentLen) { 99 | return "", errors.New("invalid content length") 100 | } 101 | 102 | rawContent := content[contentLenSize : contentLenSize+contentLen] 103 | fromAppID := string(content[contentLenSize+contentLen:]) 104 | 105 | if fromAppID != appID { 106 | return "", errors.New("app id mismatch") 107 | } 108 | 109 | return string(rawContent), nil 110 | } 111 | 112 | func randString(n int) string { 113 | var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") 114 | b := make([]rune, n) 115 | for i := range b { 116 | b[i] = letterRunes[rand.Intn(len(letterRunes))] 117 | } 118 | return string(b) 119 | } 120 | -------------------------------------------------------------------------------- /pkg/wechat/model.go: -------------------------------------------------------------------------------- 1 | package wechat 2 | 3 | import "encoding/xml" 4 | 5 | const ( 6 | MsgTypeText = "text" 7 | ) 8 | 9 | type CommonMessage struct { 10 | XMLName xml.Name `xml:"xml"` 11 | Encrypt string `xml:"Encrypt" json:"Encrypt"` 12 | ToUserName string `xml:"ToUserName" json:"ToUserName"` 13 | FromUserName string `xml:"FromUserName" json:"FromUserName"` 14 | CreateTime int64 `xml:"CreateTime" json:"CreateTime"` 15 | MsgType string `xml:"MsgType" json:"MsgType"` 16 | MsgID string `xml:"MsgId" json:"MsgId"` 17 | } 18 | 19 | type TextMessage struct { 20 | CommonMessage 21 | Content string `xml:"Content" json:"Content"` 22 | } 23 | 24 | type ImageMessage struct { 25 | CommonMessage 26 | PicURL string `xml:"PicUrl" json:"PicUrl"` 27 | MediaId string `xml:"MediaId" json:"MediaId"` 28 | } 29 | 30 | type EncryptMessage struct { 31 | XMLName xml.Name `xml:"xml"` 32 | Encrypt string `xml:"Encrypt" json:"Encrypt"` 33 | MsgSignature string `xml:"MsgSignature" json:"MsgSignature"` 34 | Timestamp int64 `xml:"TimeStamp" json:"TimeStamp"` 35 | Nonce string `xml:"Nonce" json:"Nonce"` 36 | } 37 | 
-------------------------------------------------------------------------------- /pkg/wechat/signature.go: -------------------------------------------------------------------------------- 1 | package wechat 2 | 3 | import ( 4 | "crypto/sha1" 5 | "fmt" 6 | "sort" 7 | ) 8 | 9 | func Signature(params ...string) string { 10 | sort.Strings(params) 11 | h := sha1.New() 12 | for _, s := range params { 13 | h.Write([]byte(s)) 14 | } 15 | return fmt.Sprintf("%x", h.Sum(nil)) 16 | } 17 | -------------------------------------------------------------------------------- /vercel.json: -------------------------------------------------------------------------------- 1 | { 2 | "routes": [ 3 | { 4 | "src": "/.*", 5 | "dest": "/api/router.go" 6 | } 7 | ] 8 | } -------------------------------------------------------------------------------- /web/chat_handler.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "encoding/json" 5 | "fmt" 6 | "github.com/tonnie17/wxagent/pkg/agent" 7 | "github.com/tonnie17/wxagent/pkg/config" 8 | "github.com/tonnie17/wxagent/pkg/llm" 9 | "github.com/tonnie17/wxagent/pkg/memory" 10 | "github.com/tonnie17/wxagent/pkg/rag" 11 | "github.com/tonnie17/wxagent/pkg/tool" 12 | "io" 13 | "log/slog" 14 | "net/http" 15 | ) 16 | 17 | type ChatHandler struct { 18 | config *config.Config 19 | ragClient *rag.Client 20 | } 21 | 22 | func NewChatHandler(config *config.Config, ragClient *rag.Client) *ChatHandler { 23 | return &ChatHandler{ 24 | config: config, 25 | ragClient: ragClient, 26 | } 27 | } 28 | 29 | func (h *ChatHandler) Stream(w http.ResponseWriter, r *http.Request) { 30 | w.Header().Set("Content-Type", "text/event-stream") 31 | w.Header().Set("Cache-Control", "no-cache") 32 | w.Header().Set("Connection", "keep-alive") 33 | 34 | flusher, ok := w.(http.Flusher) 35 | if !ok { 36 | http.Error(w, "streaming unsupported", http.StatusInternalServerError) 37 | return 38 | } 39 | 40 | body, err := 
io.ReadAll(r.Body) 41 | if err != nil { 42 | http.Error(w, err.Error(), http.StatusInternalServerError) 43 | return 44 | } 45 | 46 | notify := r.Context().Done() 47 | req := &struct { 48 | Model string `json:"model"` 49 | MaxInputTokens int `json:"max_input_tokens"` 50 | MaxTokens int `json:"max_tokens"` 51 | Temperature float32 `json:"temperature"` 52 | TopP float32 `json:"top_p"` 53 | Messages []*llm.ChatMessage `json:"messages"` 54 | }{} 55 | 56 | if err := json.Unmarshal(body, &req); err != nil { 57 | http.Error(w, err.Error(), http.StatusInternalServerError) 58 | return 59 | } 60 | 61 | var maxInputTokens int 62 | if req.MaxInputTokens == 0 { 63 | maxInputTokens = 500 64 | } 65 | mem := memory.NewTokenBase(req.Model, maxInputTokens) 66 | mem.Update(req.Messages) 67 | 68 | a := agent.NewAgent(&h.config.AgentConfig, llm.New(h.config.LLMProvider), mem, tool.GetTools(h.config.AgentTools), h.ragClient) 69 | out := make(chan *llm.ChatMessage) 70 | go func() { 71 | if err = a.ProcessStream(r.Context(), nil, out); err != nil { 72 | return 73 | } 74 | }() 75 | 76 | for { 77 | select { 78 | case <-notify: 79 | slog.Info("client disconnected") 80 | return 81 | case msg, ok := <-out: 82 | if !ok { 83 | slog.Info("client finished") 84 | return 85 | } 86 | if err != nil { 87 | fmt.Fprintf(w, "%s\n\n", err.Error()) 88 | return 89 | } 90 | msgJSON, _ := json.Marshal(msg) 91 | fmt.Fprintf(w, "%s\n\n", string(msgJSON)) 92 | flusher.Flush() 93 | } 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /web/keyword.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "fmt" 5 | "strings" 6 | ) 7 | 8 | const ( 9 | continueKeyword string = "继续" 10 | ) 11 | 12 | func detectContinue(input string) bool { 13 | return strings.EqualFold(input, continueKeyword) 14 | } 15 | 16 | func getContinueHint() string { 17 | return fmt.Sprintf("处理时间过长,请稍后回复%s获取上一轮对话结果", continueKeyword) 
18 | } 19 | 20 | func getContinueEmptyHint() string { 21 | return "上一轮对话结果为空,请重试对话" 22 | } 23 | 24 | func getProcessingHint() string { 25 | return "上一轮对话正在处理中,请稍后再试" 26 | } 27 | -------------------------------------------------------------------------------- /web/router.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "github.com/go-chi/chi/v5" 5 | "github.com/go-chi/chi/v5/middleware" 6 | slogchi "github.com/samber/slog-chi" 7 | "github.com/tonnie17/wxagent/pkg/config" 8 | "github.com/tonnie17/wxagent/pkg/rag" 9 | "log/slog" 10 | "net/http" 11 | "strings" 12 | "time" 13 | ) 14 | 15 | func SetupRouter(r chi.Router, config *config.Config, logger *slog.Logger, ragClient *rag.Client) { 16 | r.Use(slogchi.New(logger)) 17 | r.Use(middleware.Recoverer) 18 | r.Use(middleware.StripSlashes) 19 | 20 | memStore := NewUserMemoryStore(config.WechatMemTTL) 21 | memStore.CheckAndClear(time.Second) 22 | 23 | wechatHandler := NewWechatHandler(config, memStore, ragClient) 24 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) }) 25 | r.Get("/wechat/receive", wechatHandler.Receive) 26 | r.Post("/wechat/receive", wechatHandler.Receive) 27 | 28 | chatHandler := NewChatHandler(config, ragClient) 29 | r.With(apiKeyAuth(config.ChatAPIKEY)).Post("/chat/stream", chatHandler.Stream) 30 | } 31 | 32 | func apiKeyAuth(apiKey string) func(http.Handler) http.Handler { 33 | return func(next http.Handler) http.Handler { 34 | fn := func(w http.ResponseWriter, r *http.Request) { 35 | var token string 36 | splits := strings.Split(r.Header.Get("Authorization"), "Bearer ") 37 | if len(splits) == 2 { 38 | token = splits[1] 39 | } 40 | if apiKey == "" || token != apiKey { 41 | http.Error(w, "miss api key", http.StatusUnauthorized) 42 | return 43 | } 44 | next.ServeHTTP(w, r) 45 | } 46 | return http.HandlerFunc(fn) 47 | } 48 | } 49 | 
-------------------------------------------------------------------------------- /web/user_memory.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "github.com/tonnie17/wxagent/pkg/memory" 5 | "sync" 6 | "time" 7 | ) 8 | 9 | // UserMemory pairs a user's conversation memory with the time it was last used. 10 | type UserMemory struct { 11 | mem memory.Memory 12 | // mu guards lastAccess, which request handlers write and the cleanup worker reads concurrently. 13 | mu sync.Mutex 14 | lastAccess time.Time 15 | } 16 | 17 | // touch records the current time as the last access. 18 | func (m *UserMemory) touch() { 19 | m.mu.Lock() 20 | m.lastAccess = time.Now() 21 | m.mu.Unlock() 22 | } 23 | 24 | // idleFor reports how long ago the memory was last accessed. 25 | func (m *UserMemory) idleFor() time.Duration { 26 | m.mu.Lock() 27 | defer m.mu.Unlock() 28 | return time.Since(m.lastAccess) 29 | } 30 | 31 | // UserMemoryStore keeps one Memory per user ID and evicts entries idle for longer than memTTL. 32 | type UserMemoryStore struct { 33 | memTTL time.Duration 34 | once sync.Once 35 | userMemory sync.Map 36 | } 37 | 38 | // NewUserMemoryStore creates a store whose entries expire after memTTL without access. 39 | func NewUserMemoryStore(memTTL time.Duration) *UserMemoryStore { 40 | return &UserMemoryStore{ 41 | memTTL: memTTL, 42 | } 43 | } 44 | 45 | // GetOrNew returns the memory stored for userID, creating it with memFactory if absent, 46 | // and refreshes its last-access time. 47 | func (s *UserMemoryStore) GetOrNew(userID string, memFactory func() memory.Memory) memory.Memory { 48 | v, ok := s.userMemory.Load(userID) 49 | if !ok { 50 | // LoadOrStore makes concurrent first requests for the same user share one memory 51 | // instead of each storing its own and one silently overwriting the other. 52 | v, _ = s.userMemory.LoadOrStore(userID, &UserMemory{mem: memFactory()}) 53 | } 54 | userMem := v.(*UserMemory) 55 | userMem.touch() 56 | return userMem.mem 57 | } 58 | 59 | // CheckAndClear starts, at most once per store, a background worker that every interval 60 | // deletes entries idle for longer than memTTL. The worker never stops. 61 | func (s *UserMemoryStore) CheckAndClear(interval time.Duration) { 62 | worker := func() { 63 | ticker := time.NewTicker(interval) 64 | defer ticker.Stop() 65 | 66 | for range ticker.C { 67 | s.userMemory.Range(func(key, value any) bool { 68 | if value.(*UserMemory).idleFor() > s.memTTL { 69 | s.userMemory.Delete(key) 70 | } 71 | return true 72 | }) 73 | } 74 | } 75 | s.once.Do(func() { 76 | go worker() 77 | }) 78 | } 79 |
-------------------------------------------------------------------------------- /web/util.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "encoding/xml" 5 | "io" 6 | "log/slog" 7 | ) 8 | 9 | // xmlParseRequest reads body fully and XML-decodes it into req, which must be a non-nil 10 | // pointer. Read and decode failures are logged and returned. 11 | func xmlParseRequest(body io.Reader, req interface{}) error { 12 | b, err := io.ReadAll(body) 13 | if err != nil { 14 | slog.Error("body read failed", slog.Any("err", err)) 15 | return err 16 | } 17 | 18 | // req is already a pointer; pass it directly rather than a pointer to the interface. 19 | if err := xml.Unmarshal(b, req); err != nil { 20 | slog.Error("request parse failed", slog.Any("err", err)) 21 | return err 22 | } 23 | 24 | return nil 25 | } 26 |
-------------------------------------------------------------------------------- /web/wechat_handler.go: -------------------------------------------------------------------------------- 1 | package web 2 | 3 | import ( 4 | "context" 5 | "encoding/xml" 6 | "errors" 7 | "github.com/tonnie17/wxagent/pkg/agent" 8 | "github.com/tonnie17/wxagent/pkg/config" 9 | "github.com/tonnie17/wxagent/pkg/llm" 10 | "github.com/tonnie17/wxagent/pkg/memory" 11 | "github.com/tonnie17/wxagent/pkg/rag" 12 | "github.com/tonnie17/wxagent/pkg/tool" 13 | "github.com/tonnie17/wxagent/pkg/wechat" 14 | "log/slog" 15 | "net/http" 16 | "strconv" 17 | "strings" 18 | "time" 19 | ) 20 | 21 | // WechatHandler handles WeChat server callbacks: URL verification and incoming messages. 22 | type WechatHandler struct { 23 | config *config.Config 24 | memStore *UserMemoryStore 25 | ragClient *rag.Client 26 | } 27 | 28 | // NewWechatHandler builds a handler that keeps per-user conversation memory in memStore. 29 | func NewWechatHandler(config *config.Config, memStore *UserMemoryStore, ragClient *rag.Client) *WechatHandler { 30 | return &WechatHandler{ 31 | config: config, 32 | memStore: memStore, 33 | ragClient: ragClient, 34 | } 35 | } 36 | 37 | // Receive verifies the request signature and answers the echostr handshake; otherwise it 38 | // decodes (and, in encrypted mode, decrypts) the message, runs text messages through the 39 | // agent, and replies with an XML text message. If the agent has not answered within 40 | // WechatTimeout, the reply is a hint to send the continue keyword later. 41 | func (h *WechatHandler) Receive(w http.ResponseWriter, r *http.Request) { 42 | signature := r.URL.Query().Get("signature") 43 | msgSignature := r.URL.Query().Get("msg_signature") 44 | timestamp := r.URL.Query().Get("timestamp") 45 | nonce := r.URL.Query().Get("nonce") 46 | echoStr := r.URL.Query().Get("echostr") 47 | 48 | if signature != wechat.Signature(h.config.WechatToken, timestamp, nonce) { 49 | slog.Error("signature check failed", 50 | slog.String("signature", signature), 51 | slog.String("timestamp", timestamp), 52 | slog.String("nonce", nonce), 53 | slog.String("echostr", echoStr), 54 | ) 55 | http.Error(w, "signature check failed", http.StatusUnauthorized) 56 | return 57 | } 58 | 59 | // A non-empty echostr is the URL verification handshake: echo it back unchanged. 60 | if echoStr != "" { 61 | w.Write([]byte(echoStr)) 62 | return 63 | } 64 | 65 | var reqMessage wechat.TextMessage 66 | if err := xmlParseRequest(r.Body, &reqMessage); err != nil { 67 | http.Error(w, err.Error(), http.StatusInternalServerError) 68 | return 69 | } 70 | 71 | // Encrypted mode: verify msg_signature, then decrypt and decode the inner message into the 72 | // same struct. xml.Unmarshal leaves absent fields untouched, so reqMessage.Encrypt is expected 73 | // to stay set and drive reply encryption below (assumes the plaintext has no Encrypt element). 74 | if reqMessage.Encrypt != "" { 75 | if msgSignature != wechat.Signature(h.config.WechatToken, timestamp, nonce, reqMessage.Encrypt) { 76 | slog.Error("msg signature check failed", 77 | slog.String("msg_signature", msgSignature), 78 | slog.String("timestamp", timestamp), 79 | slog.String("nonce", nonce), 80 | ) 81 | http.Error(w, "signature check failed", http.StatusUnauthorized) 82 | return 83 | } 84 | 85 | content, err := wechat.DecryptMsg(h.config.WechatAppID, reqMessage.Encrypt, h.config.WechatEncodingAESKey) 86 | if err != nil { 87 | http.Error(w, err.Error(), http.StatusInternalServerError) 88 | return 89 | } 90 | 91 | if err := xmlParseRequest(strings.NewReader(content), &reqMessage); err != nil { 92 | http.Error(w, err.Error(), http.StatusInternalServerError) 93 | return 94 | } 95 | } 96 | 97 | slog.Info("receive req", slog.Any("req", reqMessage)) 98 | 99 | if len(h.config.WechatAllowList) > 0 { 100 | var isAllow bool 101 | for _, userID := range h.config.WechatAllowList { 102 | if userID == reqMessage.FromUserName { 103 | isAllow = true 104 | break 105 | } 106 | } 107 | if !isAllow { 108 | http.Error(w, "access denied", http.StatusUnauthorized) 109 | return 110 | } 111 | } 112 | 113 | mem := h.memStore.GetOrNew(reqMessage.FromUserName, func() memory.Memory { 114 | return memory.NewBuffer(h.config.WechatMemMsgSize) 115 | }) 116 | 117 | a := agent.NewAgent(&h.config.AgentConfig, llm.New(h.config.LLMProvider), mem, tool.GetTools(h.config.AgentTools), h.ragClient) 118 | 119 | input := strings.TrimSpace(reqMessage.Content) 120 | // Buffered so the worker can always deliver its result and exit, even after the handler 121 | // has timed out below and stopped receiving; unbuffered, the send would block forever. 122 | result := make(chan string, 1) 123 | // The agent runs on context.Background() rather than r.Context() so it keeps working after 124 | // a timeout reply; presumably its output is then fetched via the continue keyword. 125 | go func() { 126 | switch reqMessage.MsgType { 127 | case wechat.MsgTypeText: 128 | var ( 129 | output string 130 | err error 131 | ) 132 | if detectContinue(input) { 133 | if output, err = a.ChatContinue(context.Background()); output == "" { 134 | output = getContinueEmptyHint() 135 | } 136 | } else { 137 | output, err = a.Chat(context.Background(), input) 138 | } 139 | if err != nil { 140 | if errors.Is(err, agent.ErrMemoryInUse) { 141 | result <- getProcessingHint() 142 | } else { 143 | result <- err.Error() 144 | } 145 | } else { 146 | result <- output 147 | } 148 | } 149 | // Non-text messages send nothing, so the receive below yields empty reply content. 150 | close(result) 151 | }() 152 | 153 | var content string 154 | // One-shot timer, stopped on return, instead of a never-stopped Ticker. 155 | timer := time.NewTimer(h.config.WechatTimeout) 156 | defer timer.Stop() 157 | select { 158 | case <-timer.C: 159 | content = getContinueHint() 160 | case content = <-result: 161 | } 162 | 163 | now := time.Now().Unix() 164 | var respMessage wechat.TextMessage 165 | respMessage.FromUserName = reqMessage.ToUserName 166 | respMessage.ToUserName = reqMessage.FromUserName 167 | respMessage.MsgType = reqMessage.MsgType 168 | respMessage.Content = content 169 | respMessage.CreateTime = now 170 | 171 | resp, err := xml.Marshal(respMessage) 172 | if err != nil { 173 | http.Error(w, err.Error(), http.StatusInternalServerError) 174 | return 175 | } 176 | 177 | if reqMessage.Encrypt != "" { 178 | encrypt, err := wechat.EncryptMsg(h.config.WechatAppID, string(resp), h.config.WechatEncodingAESKey) 179 | if err != nil { 180 | http.Error(w, err.Error(), http.StatusInternalServerError) 181 | return 182 | } 183 | encryptMessage := &wechat.EncryptMessage{ 184 | Encrypt: encrypt, 185 | MsgSignature: wechat.Signature(h.config.WechatToken, nonce, encrypt, strconv.FormatInt(now, 10)), 186 | Timestamp: now, 187 | Nonce: nonce, 188 | } 189 | resp, err = xml.Marshal(encryptMessage) 190 | if err != nil { 191 | http.Error(w, err.Error(), http.StatusInternalServerError) 192 | return 193 | } 194 | } 195 | 196 | w.Header().Set("Content-Type", "application/xml") 197 | w.WriteHeader(http.StatusOK) 198 | w.Write(resp) 199 | } 200 | --------------------------------------------------------------------------------