├── .env.example
├── .github
└── workflows
│ └── go.yml
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── agent.gif
├── api
└── router.go
├── cmd
├── cli
│ └── main.go
└── server
│ └── main.go
├── examples
├── get_weather.go
├── google_search.go
├── home_assistant.go
├── rag.go
└── webpage_summary.go
├── go.mod
├── go.sum
├── knowledge_base
├── apple.txt
└── morocco.txt
├── pkg
├── agent
│ └── agent.go
├── config
│ └── config.go
├── embedding
│ ├── default.go
│ ├── embedding.go
│ └── openai.go
├── ha
│ ├── home_assistant.go
│ └── model.go
├── llm
│ ├── default.go
│ ├── llm.go
│ ├── message.go
│ └── openai.go
├── memory
│ ├── buffer.go
│ ├── memory.go
│ └── token_base.go
├── provider
│ └── util.go
├── rag
│ ├── document.go
│ ├── rag.go
│ ├── splitter.go
│ └── store.go
├── tool
│ ├── execute_ha_devices.go
│ ├── get_ha_devices.go
│ ├── get_weather.go
│ ├── google_search.go
│ ├── tool.go
│ └── webpage_summary.go
└── wechat
│ ├── crypt.go
│ ├── model.go
│ └── signature.go
├── vercel.json
└── web
├── chat_handler.go
├── keyword.go
├── router.go
├── user_memory.go
├── util.go
└── wechat_handler.go
/.env.example:
--------------------------------------------------------------------------------
1 | AGENT_TOOLS=
2 | LLM_MODEL=
3 | OPENAI_API_KEY=
4 | OPENAI_BASE_URL=
5 | WECHAT_TOKEN=
6 |
--------------------------------------------------------------------------------
/.github/workflows/go.yml:
--------------------------------------------------------------------------------
1 | # This workflow will build a golang project
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-go
3 |
4 | name: Go
5 |
6 | on:
7 | push:
8 | branches: [ "master" ]
9 | pull_request:
10 | branches: [ "master" ]
11 |
12 | jobs:
13 |
14 | build:
15 | runs-on: ubuntu-latest
16 | steps:
17 | - uses: actions/checkout@v4
18 |
19 | - name: Set up Go
20 | uses: actions/setup-go@v4
21 | with:
22 | go-version: '1.21'
23 |
24 | - name: Build
25 | run: go build -v ./...
26 |
27 | - name: Test
28 | run: go test -v ./...
29 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | .idea/
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM golang:1.21.8-alpine AS builder
2 |
3 | WORKDIR /app
4 |
5 | COPY . .
6 |
7 | RUN go mod tidy && go build -o server cmd/server/main.go
8 |
9 |
10 | FROM alpine:3.18
11 |
12 | COPY --from=builder /app/server /server
13 |
14 | RUN chmod +x /server
15 |
16 | EXPOSE 8082
17 |
18 | CMD ["/server"]
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 tonnie17
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # wxagent
2 |
3 | 让公众号瞬间变身为智能助手,支持特性:
4 |
5 | - 接近零成本部署,只需要一个域名即可绑定到公众号
6 | - 支持对话记忆,超时自动回复,以及对话结果回溯
7 | - 支持多种工具调用,包括:获取天气,关键字搜索,网页总结,Home Assistant 设备控制
8 | - 支持构建本地 RAG 知识库进行读取和检索
9 |
10 | 一键部署到 Vercel:
11 |
12 | [](https://vercel.com/new/clone?repository-url=https%3A%2F%2Fgithub.com%2Ftonnie17%2Fwxagent&env=WECHAT_TOKEN,LLM_MODEL,OPENAI_API_KEY,OPENAI_BASE_URL)
13 |
14 | 1. 参考[配置](#配置)填写环境变量,完成部署
15 | 2. 拿到Vercel生成的默认域名进行访问,没问题的话会输出`OK`
16 | 3. 绑定域名到服务器地址,然后配置公众号服务器地址为:`{domain}/wechat/receive/`
17 |
18 | ## 示例
19 |
20 | 点击查看对话效果
21 |
22 | 基础对话:
23 |
24 |
25 |
26 | 获取天气:
27 |
28 |
29 |
30 | 文章总结:
31 |
32 |
33 |
34 | 信息搜索:
35 |
36 |
37 |
38 | 知识库检索:
39 |
40 |
41 |
42 |
43 |
44 |
45 | ## 配置
46 |
47 | 所有配置通过环境变量指定,支持通过`.env`文件自动加载:
48 |
49 | - `WECHAT_TOKEN`:公众号服务器配置的令牌(Token)
50 | - `WECHAT_ALLOW_LIST`:允许交互的微信账号(openid),用逗号分隔,默认无限制
51 | - `WECHAT_MEM_TTL`:公众号单轮对话记忆保存时间,默认为`5m`
52 | - `WECHAT_MEM_MSG_SIZE`:公众号单轮对话记忆消息记录上限(包括工具消息),默认为`6`
53 | - `WECHAT_TIMEOUT`:公众号单轮对话超时时间(公众号限制回复时间不能超过5秒),默认为`4s`
54 | - `WECHAT_APP_ID`:公众号 AppID,安全模式下需要指定
55 | - `WECHAT_ENCODING_AES_KEY`:公众号消息加解密密钥 (EncodingAESKey),安全模式下需要指定
56 | - `LLM_PROVIDER`:LLM 提供者,支持:`openai`,默认为`openai`
57 | - `OPENAI_API_KEY`:OpenAI(兼容接口)API KEY
58 | - `OPENAI_BASE_URL`:OpenAI(兼容接口)Base URL
59 | - `SERVER_ADDR`:服务器模式的启动地址,默认为`0.0.0.0:8082`
60 | - `USE_RAG`:是否开启 RAG 从知识库检索查询
61 | - `EMBEDDING_PROVIDER`:文本嵌入模型提供者,支持:`openai`,默认为 `openai`
62 | - `KNOWLEDGE_BASE_PATH`:本地知识库目录路径,目前支持文件格式:`txt`,默认为 `./knowledge_base`
63 |
64 | ### Agent 配置
65 |
66 | - `AGENT_TOOLS`:Agent 可以使用的 Tools,用逗号分隔,需要配置相关的环境变量,支持:
67 | - `google_search`:Google 搜索
68 | - `get_weather`:天气查询
69 | - `webpage_summary`:网页文本总结
70 | - `get_devices`:获取 Home Assistant 设备列表
71 | - `execute_device`:执行 Home Assistant 设备动作
72 | - `AGENT_TIMEOUT`:Agent 对话超时时间,默认为`30s`
73 | - `MAX_TOOL_ITER`:Agent 调用工具最大迭代次数,默认为`3`
74 | - `TOOL_TIMEOUT`:工具调用超时时间,默认为`10s`
75 | - `LLM_MODEL`:LLM 模型名称,默认为`gpt-3.5-turbo`
76 | - `LLM_MAX_TOKENS`:最大输出 Token 数量,默认为`500`
77 | - `LLM_TEMPERATURE`:Temperature 参数,默认为`0.2`
78 | - `LLM_TOP_P`:Top P 参数,默认为`0.9`
79 | - `SYSTEM_PROMPT`:设置 Agent 对话的 System Prompt
80 | - `EMBEDDING_MODEL`:文本嵌入模型,支持:`openai`,默认为 `openai`
81 |
82 |
83 | ### 工具配置
84 |
85 | #### Google 搜索(google_search)
86 |
87 | - `GOOGLE_SEARCH_ENGINE`:Google 搜索引擎
88 | - `GOOGLE_SEARCH_API_KEY`:Google 搜索 API Key
89 |
90 |
91 | #### 天气查询(get_weather)
92 |
93 | - `OPENWEATHERMAP_API_KEY`:OpenWeatherMap API Key
94 |
95 |
96 | #### Home Assistant(get_devices,execute_device)
97 |
98 | - `HA_BASE_URL`:Home Assistant 服务器地址
99 | - `HA_BEARER_TOKEN`:Home Assistant API 验证 Bearer Token
100 |
101 |
102 | ## 扩展使用
103 |
104 | 与Agent交互
105 |
106 | ```go
107 | package main
108 |
109 | import (
110 | "context"
111 | "fmt"
112 | "log"
113 | "time"
114 |
115 | "github.com/tonnie17/wxagent/pkg/agent"
116 | "github.com/tonnie17/wxagent/pkg/config"
117 | "github.com/tonnie17/wxagent/pkg/llm"
118 | "github.com/tonnie17/wxagent/pkg/memory"
119 | "github.com/tonnie17/wxagent/pkg/tool"
120 | )
121 |
122 | func main() {
123 | tools := []tool.Tool{
124 | tool.NewWebPageSummary(),
125 | }
126 |
127 | agent := agent.NewAgent(&config.AgentConfig{
128 | AgentTools: []string{"webpage_summary"},
129 | AgentTimeout: 30 * time.Second,
130 | MaxToolIter: 3,
131 | ToolTimeout: 10 * time.Second,
132 | Model: "qwen-plus",
133 | MaxTokens: 500,
134 | Temperature: 0.2,
135 | TopP: 0.9,
136 | }, llm.NewOpenAI(), memory.NewBuffer(6), tools, nil)
137 |
138 | output, err := agent.Chat(context.Background(), "总结一下:https://golangnote.com/golang/golang-stringsbuilder-vs-bytesbuffer")
139 | if err != nil {
140 | log.Fatalf("chat failed: %v", err)
141 | }
142 |
143 | fmt.Println(output)
144 | }
145 | ```
146 |
147 |
148 |
149 | 自定义工具
150 |
151 |
152 | 要定义一个工具,需要实现`Tool`定义的接口:
153 |
154 | ```go
155 | type Tool interface {
156 | Name() string
157 | Description() string
158 | Schema() map[string]interface{}
159 | Execute(context.Context, string) (string, error)
160 | }
161 | ```
162 |
163 | 方法含义:
164 | - `Name()`:工具名称
165 | - `Description()`:工具描述,描述尽量清晰以便模型了解选择工具进行调用
166 | - `Schema() map[string]interface{}`:提供工具的参数描述以及定义
167 | - `Execute(context.Context, string) (string, error)`:工具的执行逻辑,接收模型输入,返回执行结果
168 |
169 |
170 |
171 | 接入新模型
172 |
173 |
174 | 要接入新的模型,需要实现`LLM`定义的接口:
175 |
176 | ```go
177 | type LLM interface {
178 | Chat(context.Context, model string, messages []*ChatMessage, options ...ChatOption) (*ChatMessage, error)
179 | }
180 | ```
181 |
182 |
183 |
184 | ## 本地开发
185 |
186 | 创建配置文件`.env`,将项目配置写入到文件:
187 |
188 | ```sh
189 | echo "{CONFIG_KEY}={CONFIG_VALUE}" > .env
190 | ```
191 |
192 | 拉取依赖:
193 |
194 | ```sh
195 | go mod tidy
196 | ```
197 |
198 | 启动服务器:
199 |
200 | ```sh
201 | go run cmd/server/main.go
202 | ```
203 |
204 | 启动交互式命令行:
205 |
206 | ```sh
207 | go run cmd/cli/main.go
208 | ```
209 |
210 | Streamlit 调用:
211 |
212 | ```python
213 | with httpx.stream("POST", server_addr, headers=headers, data=json.dumps({
214 | "messages": st.session_state.messages
215 | }), timeout=30) as r:
216 | for line in r.iter_lines():
217 | if not line.strip():
218 | continue
219 | data = json.loads(line)
220 | # ...
221 | ```
222 |
223 |
224 |
225 | ## 构建
226 |
227 | ### 本地构建
228 |
229 | 编译二进制文件:
230 |
231 | ```sh
232 | go build -o server /cmd/server
233 | ```
234 |
235 | 运行:
236 |
237 | ```sh
238 | ./server
239 | ```
240 |
241 | ### Docker 构建
242 |
243 | 创建镜像:
244 |
245 | ```sh
246 | docker build -t wxagent .
247 | ```
248 |
249 | 运行容器:
250 |
251 | ```sh
252 | docker run -it --rm -p 8082:8082 --env-file .env wxagent
253 | ```
254 |
--------------------------------------------------------------------------------
/agent.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonnie17/wxagent/b8bcf2c707400bbf384739055baf3e4ab3476932/agent.gif
--------------------------------------------------------------------------------
/api/router.go:
--------------------------------------------------------------------------------
1 | package api
2 |
3 | import (
4 | "context"
5 | "github.com/go-chi/chi/v5"
6 | _ "github.com/joho/godotenv/autoload"
7 | "github.com/tonnie17/wxagent/pkg/config"
8 | "github.com/tonnie17/wxagent/pkg/embedding"
9 | "github.com/tonnie17/wxagent/pkg/rag"
10 | "github.com/tonnie17/wxagent/web"
11 | "log/slog"
12 | "net/http"
13 | "os"
14 | )
15 |
16 | var router chi.Router
17 |
18 | func init() {
19 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
20 | slog.SetDefault(logger)
21 |
22 | cfg, err := config.LoadConfig()
23 | if err != nil {
24 | slog.Error("config load failed", slog.Any("err", err))
25 | return
26 | }
27 |
28 | var ragClient *rag.Client
29 | if cfg.UseRAG {
30 | store, err := rag.NewPgVectorStore()
31 | if err != nil {
32 | slog.Error("init vector store failed", slog.Any("err", err))
33 | return
34 | }
35 |
36 | ragClient = rag.NewClient(embedding.New(cfg.EmbeddingProvider), store)
37 | if err := ragClient.BuildKnowledgeBase(context.Background(), cfg.KnowledgeBasePath, cfg.EmbeddingModel, false); err != nil {
38 | slog.Error("load data failed", slog.Any("err", err))
39 | return
40 | }
41 | }
42 |
43 | router = chi.NewRouter()
44 | web.SetupRouter(router, cfg, logger, ragClient)
45 | }
46 |
47 | func Handler(w http.ResponseWriter, r *http.Request) {
48 | router.ServeHTTP(w, r)
49 | }
50 |
--------------------------------------------------------------------------------
/cmd/cli/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/chzyer/readline"
7 | "github.com/tonnie17/wxagent/pkg/agent"
8 | "github.com/tonnie17/wxagent/pkg/config"
9 | "github.com/tonnie17/wxagent/pkg/llm"
10 | "github.com/tonnie17/wxagent/pkg/memory"
11 | "github.com/tonnie17/wxagent/pkg/tool"
12 | "log"
13 | "log/slog"
14 | "strings"
15 | )
16 |
17 | func main() {
18 | rl, err := readline.NewEx(&readline.Config{
19 | Prompt: "> ",
20 | })
21 |
22 | if err != nil {
23 | log.Fatal("readline create failed:", err)
24 | }
25 | defer rl.Close()
26 |
27 | cfg, err := config.LoadConfig()
28 | if err != nil {
29 | slog.Error("config load failed", slog.Any("err", err))
30 | return
31 | }
32 |
33 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tool.DefaultTools(), nil)
34 | for {
35 | line, err := rl.Readline()
36 | if err != nil {
37 | break
38 | }
39 |
40 | input := strings.TrimSpace(line)
41 |
42 | output, err := a.Chat(context.Background(), input)
43 | if err != nil {
44 | log.Fatalln(err)
45 | }
46 |
47 | fmt.Println(output)
48 | }
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/cmd/server/main.go:
--------------------------------------------------------------------------------
1 | package main
2 |
3 | import (
4 | "context"
5 | "errors"
6 | "github.com/go-chi/chi/v5"
7 | "github.com/tonnie17/wxagent/pkg/config"
8 | "github.com/tonnie17/wxagent/pkg/embedding"
9 | "github.com/tonnie17/wxagent/pkg/rag"
10 | "github.com/tonnie17/wxagent/web"
11 | "log/slog"
12 | "net/http"
13 | "os"
14 | "os/signal"
15 | "syscall"
16 | "time"
17 |
18 | _ "github.com/joho/godotenv/autoload"
19 | )
20 |
21 | func main() {
22 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))
23 | slog.SetDefault(logger)
24 |
25 | cfg, err := config.LoadConfig()
26 | if err != nil {
27 | slog.Error("config load failed", slog.Any("err", err))
28 | return
29 | }
30 |
31 | var ragClient *rag.Client
32 | if cfg.UseRAG {
33 | store, err := rag.NewPgVectorStore()
34 | if err != nil {
35 | slog.Error("init vector store failed", slog.Any("err", err))
36 | return
37 | }
38 | defer store.Release()
39 |
40 | ragClient = rag.NewClient(embedding.New(cfg.EmbeddingProvider), store)
41 | if err := ragClient.BuildKnowledgeBase(context.Background(), cfg.KnowledgeBasePath, cfg.EmbeddingModel, false); err != nil {
42 | slog.Error("load data failed", slog.Any("err", err))
43 | return
44 | }
45 | }
46 |
47 | r := chi.NewRouter()
48 | web.SetupRouter(r, cfg, logger, ragClient)
49 |
50 | server := &http.Server{
51 | Addr: cfg.ServerAddr,
52 | Handler: r,
53 | }
54 | go func() {
55 | slog.Info("running on " + cfg.ServerAddr)
56 | if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
57 | slog.Error("http serve failed", slog.Any("err", err))
58 | }
59 | }()
60 |
61 | quit := make(chan os.Signal, 1)
62 | signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
63 | <-quit
64 |
65 | shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
66 | defer cancel()
67 |
68 | slog.Info("server shutdown...")
69 | if err := server.Shutdown(shutdownCtx); err != nil {
70 | slog.Info("server shutdown failed")
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/examples/get_weather.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 | // +build ignore
3 |
4 | package main
5 |
6 | import (
7 | "context"
8 | "fmt"
9 | "github.com/tonnie17/wxagent/pkg/agent"
10 | "github.com/tonnie17/wxagent/pkg/config"
11 | "github.com/tonnie17/wxagent/pkg/llm"
12 | "github.com/tonnie17/wxagent/pkg/memory"
13 | "github.com/tonnie17/wxagent/pkg/tool"
14 | "log/slog"
15 | "os"
16 |
17 | _ "github.com/joho/godotenv/autoload"
18 | )
19 |
20 | func main() {
21 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
22 | slog.SetDefault(logger)
23 |
24 | cfg, err := config.LoadConfig()
25 | if err != nil {
26 | slog.Error("config load failed", slog.Any("err", err))
27 | return
28 | }
29 |
30 | ctx := context.Background()
31 | tools := []tool.Tool{
32 | tool.NewGetWeather(),
33 | }
34 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tools, nil)
35 | output, err := a.Chat(ctx, "天气怎么样")
36 | fmt.Println(output, err)
37 | output, err = a.Chat(ctx, "深圳")
38 | fmt.Println(output, err)
39 | }
40 |
--------------------------------------------------------------------------------
/examples/google_search.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 | // +build ignore
3 |
4 | package main
5 |
6 | import (
7 | "context"
8 | "fmt"
9 | "github.com/tonnie17/wxagent/pkg/agent"
10 | "github.com/tonnie17/wxagent/pkg/config"
11 | "github.com/tonnie17/wxagent/pkg/llm"
12 | "github.com/tonnie17/wxagent/pkg/memory"
13 | "github.com/tonnie17/wxagent/pkg/tool"
14 | "log/slog"
15 | "os"
16 |
17 | _ "github.com/joho/godotenv/autoload"
18 | )
19 |
20 | func main() {
21 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
22 | slog.SetDefault(logger)
23 |
24 | cfg, err := config.LoadConfig()
25 | if err != nil {
26 | slog.Error("config load failed", slog.Any("err", err))
27 | return
28 | }
29 |
30 | ctx := context.Background()
31 | tools := []tool.Tool{
32 | tool.NewGoogleSearch(),
33 | }
34 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tools, nil)
35 | output, err := a.Chat(ctx, "搜索一下法国的首都在哪里")
36 | fmt.Println(output, err)
37 | }
38 |
--------------------------------------------------------------------------------
/examples/home_assistant.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 | // +build ignore
3 |
4 | package main
5 |
6 | import (
7 | "context"
8 | "fmt"
9 | "github.com/tonnie17/wxagent/pkg/agent"
10 | "github.com/tonnie17/wxagent/pkg/config"
11 | "github.com/tonnie17/wxagent/pkg/llm"
12 | "github.com/tonnie17/wxagent/pkg/memory"
13 | "github.com/tonnie17/wxagent/pkg/tool"
14 | "log/slog"
15 | "os"
16 |
17 | _ "github.com/joho/godotenv/autoload"
18 | )
19 |
20 | func main() {
21 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
22 | slog.SetDefault(logger)
23 |
24 | cfg, err := config.LoadConfig()
25 | if err != nil {
26 | slog.Error("config load failed", slog.Any("err", err))
27 | return
28 | }
29 |
30 | ctx := context.Background()
31 | tools := []tool.Tool{
32 | tool.NewGetHADevices(),
33 | tool.NewExecuteHADevice(),
34 | }
35 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tools, nil)
36 | output, err := a.Chat(ctx, "打开客厅灯和厨房灯")
37 | fmt.Println(output, err)
38 | }
39 |
--------------------------------------------------------------------------------
/examples/rag.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 | // +build ignore
3 |
4 | package main
5 |
6 | import (
7 | "context"
8 | "fmt"
9 | "github.com/tonnie17/wxagent/pkg/agent"
10 | "github.com/tonnie17/wxagent/pkg/config"
11 | "github.com/tonnie17/wxagent/pkg/embedding"
12 | "github.com/tonnie17/wxagent/pkg/llm"
13 | "github.com/tonnie17/wxagent/pkg/memory"
14 | "github.com/tonnie17/wxagent/pkg/rag"
15 | "log/slog"
16 | "os"
17 |
18 | _ "github.com/joho/godotenv/autoload"
19 | )
20 |
21 | func main() {
22 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
23 | slog.SetDefault(logger)
24 |
25 | cfg, err := config.LoadConfig()
26 | if err != nil {
27 | slog.Error("config load failed", slog.Any("err", err))
28 | return
29 | }
30 |
31 | store, err := rag.NewPgVectorStore()
32 | if err != nil {
33 | slog.Error("init vector store failed", slog.Any("err", err))
34 | return
35 | }
36 | defer store.Release()
37 |
38 | client := rag.NewClient(embedding.NewOpenAI(), store)
39 | if err := client.BuildKnowledgeBase(context.Background(), cfg.KnowledgeBasePath, cfg.EmbeddingModel, false); err != nil {
40 | slog.Error("load failed", slog.Any("err", err))
41 | return
42 | }
43 |
44 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), nil, client)
45 | output, err := a.Chat(context.Background(), "摩洛哥的公路有多少公里")
46 | fmt.Println(output, err)
47 | }
48 |
--------------------------------------------------------------------------------
/examples/webpage_summary.go:
--------------------------------------------------------------------------------
1 | //go:build ignore
2 | // +build ignore
3 |
4 | package main
5 |
6 | import (
7 | "context"
8 | "fmt"
9 | "github.com/tonnie17/wxagent/pkg/agent"
10 | "github.com/tonnie17/wxagent/pkg/config"
11 | "github.com/tonnie17/wxagent/pkg/llm"
12 | "github.com/tonnie17/wxagent/pkg/memory"
13 | "github.com/tonnie17/wxagent/pkg/tool"
14 | "log/slog"
15 | "os"
16 |
17 | _ "github.com/joho/godotenv/autoload"
18 | )
19 |
20 | func main() {
21 | logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
22 | slog.SetDefault(logger)
23 |
24 | cfg, err := config.LoadConfig()
25 | if err != nil {
26 | slog.Error("config load failed", slog.Any("err", err))
27 | return
28 | }
29 |
30 | ctx := context.Background()
31 | tools := []tool.Tool{
32 | tool.NewWebPageSummary(),
33 | }
34 | a := agent.NewAgent(&cfg.AgentConfig, llm.NewOpenAI(), memory.NewBuffer(6), tools, nil)
35 | output, err := a.Chat(ctx, "总结一下: https://golangnote.com/golang/golang-stringsbuilder-vs-bytesbuffer")
36 | fmt.Println(output, err)
37 | }
38 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/tonnie17/wxagent
2 |
3 | go 1.21.8
4 |
5 | require (
6 | github.com/PuerkitoBio/goquery v1.9.3
7 | github.com/caarlos0/env/v11 v11.2.2
8 | github.com/chzyer/readline v1.5.1
9 | github.com/go-chi/chi/v5 v5.1.0
10 | github.com/jackc/pgx/v5 v5.7.1
11 | github.com/joho/godotenv v1.5.1
12 | github.com/pgvector/pgvector-go v0.2.2
13 | github.com/pkoukk/tiktoken-go v0.1.7
14 | github.com/samber/slog-chi v1.12.3
15 | github.com/sashabaranov/go-openai v1.35.6
16 | )
17 |
18 | require (
19 | github.com/andybalholm/cascadia v1.3.2 // indirect
20 | github.com/dlclark/regexp2 v1.10.0 // indirect
21 | github.com/google/uuid v1.6.0 // indirect
22 | github.com/jackc/pgpassfile v1.0.0 // indirect
23 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
24 | github.com/jackc/puddle/v2 v2.2.2 // indirect
25 | go.opentelemetry.io/otel v1.29.0 // indirect
26 | go.opentelemetry.io/otel/trace v1.29.0 // indirect
27 | golang.org/x/crypto v0.27.0 // indirect
28 | golang.org/x/net v0.29.0 // indirect
29 | golang.org/x/sync v0.8.0 // indirect
30 | golang.org/x/sys v0.27.0 // indirect
31 | golang.org/x/text v0.18.0 // indirect
32 | )
33 |
--------------------------------------------------------------------------------
/go.sum:
--------------------------------------------------------------------------------
1 | entgo.io/ent v0.13.1 h1:uD8QwN1h6SNphdCCzmkMN3feSUzNnVvV/WIkHKMbzOE=
2 | entgo.io/ent v0.13.1/go.mod h1:qCEmo+biw3ccBn9OyL4ZK5dfpwg++l1Gxwac5B1206A=
3 | github.com/PuerkitoBio/goquery v1.9.3 h1:mpJr/ikUA9/GNJB/DBZcGeFDXUtosHRyRrwh7KGdTG0=
4 | github.com/PuerkitoBio/goquery v1.9.3/go.mod h1:1ndLHPdTz+DyQPICCWYlYQMPl0oXZj0G6D4LCYA6u4U=
5 | github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss=
6 | github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU=
7 | github.com/caarlos0/env/v11 v11.2.2 h1:95fApNrUyueipoZN/EhA8mMxiNxrBwDa+oAZrMWl3Kg=
8 | github.com/caarlos0/env/v11 v11.2.2/go.mod h1:JBfcdeQiBoI3Zh1QRAWfe+tpiNTmDtcCj/hHHHMx0vc=
9 | github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM=
10 | github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ=
11 | github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI=
12 | github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
13 | github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
14 | github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
15 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
16 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
17 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
18 | github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
19 | github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
20 | github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
21 | github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
22 | github.com/go-pg/pg/v10 v10.11.0 h1:CMKJqLgTrfpE/aOVeLdybezR2om071Vh38OLZjsyMI0=
23 | github.com/go-pg/pg/v10 v10.11.0/go.mod h1:4BpHRoxE61y4Onpof3x1a2SQvi9c+q1dJnrNdMjsroA=
24 | github.com/go-pg/zerochecker v0.2.0 h1:pp7f72c3DobMWOb2ErtZsnrPaSvHd2W4o9//8HtF4mU=
25 | github.com/go-pg/zerochecker v0.2.0/go.mod h1:NJZ4wKL0NmTtz0GKCoJ8kym6Xn/EQzXRl2OnAe7MmDo=
26 | github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
27 | github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
28 | github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
29 | github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
30 | github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
31 | github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
32 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
33 | github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
34 | github.com/jackc/pgx/v5 v5.7.1 h1:x7SYsPBYDkHDksogeSmZZ5xzThcTgRz++I5E+ePFUcs=
35 | github.com/jackc/pgx/v5 v5.7.1/go.mod h1:e7O26IywZZ+naJtWWos6i6fvWK+29etgITqrqHLfoZA=
36 | github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
37 | github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
38 | github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
39 | github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
40 | github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
41 | github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
42 | github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
43 | github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
44 | github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
45 | github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
46 | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
47 | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
48 | github.com/pgvector/pgvector-go v0.2.2 h1:Q/oArmzgbEcio88q0tWQksv/u9Gnb1c3F1K2TnalxR0=
49 | github.com/pgvector/pgvector-go v0.2.2/go.mod h1:u5sg3z9bnqVEdpe1pkTij8/rFhTaMCMNyQagPDLK8gQ=
50 | github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw=
51 | github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
52 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
53 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
54 | github.com/samber/slog-chi v1.12.3 h1:jZ09VSMCEytjdAIiNS2pMe7a1mZSfhjI0pdvZv0nTng=
55 | github.com/samber/slog-chi v1.12.3/go.mod h1:uKRFgoHdeoeme1SUzIbhxGoNfY/MluEtokKZigeG5Es=
56 | github.com/sashabaranov/go-openai v1.35.6 h1:oi0rwCvyxMxgFALDGnyqFTyCJm6n72OnEG3sybIFR0g=
57 | github.com/sashabaranov/go-openai v1.35.6/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
58 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
59 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
60 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
61 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
62 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
63 | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc h1:9lRDQMhESg+zvGYmW5DyG0UqvY96Bu5QYsTLvCHdrgo=
64 | github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc/go.mod h1:bciPuU6GHm1iF1pBvUfxfsH0Wmnc2VbpgvbI9ZWuIRs=
65 | github.com/uptrace/bun v1.1.12 h1:sOjDVHxNTuM6dNGaba0wUuz7KvDE1BmNu9Gqs2gJSXQ=
66 | github.com/uptrace/bun v1.1.12/go.mod h1:NPG6JGULBeQ9IU6yHp7YGELRa5Agmd7ATZdz4tGZ6z0=
67 | github.com/uptrace/bun/dialect/pgdialect v1.1.12 h1:m/CM1UfOkoBTglGO5CUTKnIKKOApOYxkcP2qn0F9tJk=
68 | github.com/uptrace/bun/dialect/pgdialect v1.1.12/go.mod h1:Ij6WIxQILxLlL2frUBxUBOZJtLElD2QQNDcu/PWDHTc=
69 | github.com/uptrace/bun/driver/pgdriver v1.1.12 h1:3rRWB1GK0psTJrHwxzNfEij2MLibggiLdTqjTtfHc1w=
70 | github.com/uptrace/bun/driver/pgdriver v1.1.12/go.mod h1:ssYUP+qwSEgeDDS1xm2XBip9el1y9Mi5mTAvLoiADLM=
71 | github.com/vmihailenco/bufpool v0.1.11 h1:gOq2WmBrq0i2yW5QJ16ykccQ4wH9UyEsgLm6czKAd94=
72 | github.com/vmihailenco/bufpool v0.1.11/go.mod h1:AFf/MOy3l2CFTKbxwt0mp2MwnqjNEs5H/UxrkA5jxTQ=
73 | github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
74 | github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
75 | github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc=
76 | github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
77 | github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
78 | github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
79 | github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
80 | go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
81 | go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
82 | go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
83 | go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
84 | golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
85 | golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
86 | golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A=
87 | golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70=
88 | golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
89 | golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
90 | golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
91 | golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
92 | golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
93 | golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
94 | golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
95 | golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo=
96 | golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0=
97 | golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
98 | golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
99 | golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
100 | golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ=
101 | golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
102 | golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
103 | golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
104 | golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
105 | golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
106 | golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
107 | golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
108 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
109 | golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
110 | golang.org/x/sys v0.27.0 h1:wBqf8DvsY9Y/2P8gAfPDEYNuS30J4lPHJxXSb/nJZ+s=
111 | golang.org/x/sys v0.27.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
112 | golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
113 | golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
114 | golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
115 | golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY=
116 | golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
117 | golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
118 | golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
119 | golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
120 | golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
121 | golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224=
122 | golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
123 | golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
124 | golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
125 | golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
126 | golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
127 | golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
128 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
129 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
130 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
131 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
132 | gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo=
133 | gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0=
134 | gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls=
135 | gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
136 | mellium.im/sasl v0.3.1 h1:wE0LW6g7U83vhvxjC1IY8DnXM+EU095yeo8XClvCdfo=
137 | mellium.im/sasl v0.3.1/go.mod h1:xm59PUYpZHhgQ9ZqoJ5QaCqzWMi8IeS49dhp6plPCzw=
138 |
--------------------------------------------------------------------------------
/knowledge_base/apple.txt:
--------------------------------------------------------------------------------
1 | 苹果树(学名:Malus domestica)是蔷薇科苹果亚科苹果属植物,为落叶乔木,在世界上广泛种植。苹果,又称柰或林檎,是苹果树的果实,一般呈红色,但需视品种而定,富含矿物质和维生素,是人们最常食用的水果之一。人们根据需求的不同口感、用途(比如烹饪、生吃、酿苹果酒等)培育不同的品种,已知有超过7,500个苹果品种,拥有一系列人们需要的不同特性。
2 |
3 | 苹果起源于中亚,直到今天当地还可以找到苹果的野生祖先:新疆野苹果。苹果在亚洲和欧洲都有着数千年的种植历史,并由欧洲的殖民者带到了北美,是苹果属中生长最广泛的树种。 苹果在北欧、希腊、欧洲基督教传统等许多文化中都有宗教和神话的意义。
4 |
5 | 苹果开花期是基于各地气候而定,但一般集中在4-5月份。苹果是异花授粉植物,大部分品种自花不能结成果实。一般苹果栽种后,于2-3年才开始结果。果实成长期之长短,一般早熟品种为65-87天,中熟品种为90-133天,晚熟品种则为137-168天。在一般情形下,栽种后苹果树可有10-1000年寿命。苹果树如果从种子开始长,就会长得较为高大,可高至15米,但栽培树木一般只高3-5米左右,因此,苹果的各种栽培品种通常通过嫁接在砧木上进行繁殖,砧木决定了最终树木的大小。树干呈灰褐色,树皮有一定程度的脱落。苹果树及其果实很容易出现许多真菌、细菌和害虫问题,这些问题可以通过多种有机和非有机手段加以控制。 2010年,人们对苹果的基因组进行了测序,作为苹果生产中疾病控制和选择育种研究的一部分。2017年全球苹果产量为8310万公吨,中国占总产量的一半。
--------------------------------------------------------------------------------
/knowledge_base/morocco.txt:
--------------------------------------------------------------------------------
1 | 摩洛哥是二元制君主立宪制国家,行政权归于政府,立法权归于政府和议会两院,国王仍然握有行政和立法权力,在军事、外交和宗教事务拥有较大话语权。伊斯兰教是摩洛哥的主要宗教,官方语言包括现代标准阿拉伯语和柏柏尔语,其中后者于2011年获得官方地位。摩洛哥人主要使用摩洛哥阿拉伯语作为口语,其和现代标准阿拉伯语有一定差异;由于历史原因,摩洛哥同时也使用法语和西班牙语。
2 |
3 | 摩洛哥现为阿拉伯国家联盟、非洲联盟和地中海联盟成员国,是非洲第五大经济体。
4 |
5 | 距今40万年前便有古人类在摩洛哥活动。考古发现证明这里很早就有人居住了。最早在这里定居的柏柏尔人的来历已无以考证。腓尼基人,迦太基人和罗马人先后占领这里。7世纪时,阿拉伯人到来,并在8世纪建立王国。中世纪时期在这里统治的朝代有些是柏柏尔人,有些是阿拉伯人。摩洛哥王国的国名来源于阿拉伯语中称呼西方的“马格里布”,中国古籍中有关摩洛哥的较早记载,有宋人赵汝适《诸蕃志》“默伽猎国”。明代《坤舆万国全图》把摩洛哥称为“马逻可国”。
6 |
7 | 从15世纪开始,摩洛哥受到多个西方国家的入侵。20世纪初期的摩洛哥危机是导致第一次世界大战的重要原因之一。1912年3月,法国占领摩洛哥。1956年3月,获得独立。
8 |
9 | 现今摩洛哥阿拉维王朝是在17世纪开始的,1957年,摩洛哥苏丹穆罕默德五世宣布改称国王,摩洛哥苏丹国改制为摩洛哥王国。实行君主立宪制,但君主较内阁和国会拥有相当大的权力,所以是二元君主制。
10 |
11 | 摩洛哥是君主立宪制的多党制的伊斯兰国家,国王拥有比较大的权力。议会分参议院和众议院两院,众议院325名议员由国民直接选举而出,任期为五年。参议院270名由地方代表选出,每三年更换三分之一。
12 |
13 | 穆罕默德六世从1999年登基后加强国家的民主化,对内试图缓和贫困和保证社会治安,对外持比较缓和的伊斯兰国家政策。摩洛哥是中东与欧美对话的重要中间国。2004年6月,美国总统布什宣布给予摩洛哥主要非北约盟友地位。
14 |
15 | 自2011年2月20日起,摩洛哥民众展开民主示威活动。2011年3月9日,穆罕默德六世国王宣布将进行全面的宪政改革,内容有首相产生方式和权力、提高国民的自由和人权、建立独立的司法体系等各个方面,以真正实现民主政治。6月17日,国王提出政改计划,将自限权力建立一个民主的君主立宪体制,计划将在7月1日交付公投。方案内容:将国王许多权力交给首相和国会;未来将实施直接普选,政府首长由胜选政党组成;由首相指派政府官员,包括行政部门和国营企业首长,并且首相将有权力解散国会等。
16 |
17 | 2011年7月1日,摩洛哥举行新宪法草案全民公决。2日凌晨,摩洛哥内政大臣沙尔阿维公布草案以98.49%的赞成票获得通过。这标志着该国家以和平方式迈出了宪政改革的重要一步。
18 |
19 | 新宪法草案的主要内容:
20 | 将议会的权限扩大,特别是加强了两院中众议院的主导地位。众议院在五分之一议员支持的情况下即可对重要官员展开调查,或在获得三分之一议员支持的情况下对大臣进行弹劾,同时还将取代国王行使大赦权力。
21 | 新宪法草案同时强调司法独立,国家的司法体系将由法官与全国人权委员会共同组成的最高委员会掌管,而司法大臣被排除在委员会之外。
22 | 在新宪法草案中,摩洛哥国王仍为国家元首、陆海空三军最高统帅和宗教领袖,并分别担任“大臣委员会”和新成立的“国家安全委员会”主席,掌握着重大决策的最终决定权,同时仍拥有重要地方长官和驻外大使的任命权等。
23 |
24 | 位于非洲大陆西北部,直布罗陀海峡南岸,扼地中海入大西洋的门户,海岸线1700多公里。地形复杂,中部和北部为峻峭的阿特拉斯山脉,东部和南部是上高原和前撒哈拉高原,仅西北沿海一带为狭长低缓的平原。由于斜贯全境的阿特拉斯山阻挡了南部撒哈拉沙漠热浪的侵袭,摩洛哥常年气候宜人,花木繁茂,赢得“烈日下的清凉国土”的美誉,还享有“北非花园”的美称。
25 |
26 | 受副热带高压带控制和加那利寒流影响,形成干燥的热带沙漠气候;阿特拉斯山脉横贯全国,其中图卜卡勒峰(4,165米)是全国最高点。
27 |
28 | 摩洛哥磷酸盐资源极为丰富,估计储量1100亿吨,占世界储量的75%。其它矿产资源有铁、铅、锌、钴、锰、钡、铜、盐、磁铁矿、无烟煤、油页岩等。其中油页岩储量1000亿吨以上,含原油60亿吨。
29 |
30 | 摩洛哥是一个以第三产业经济为主,居中等收入水平的发展中国家,是非洲第五、北非第三大经济体。其主要的经济部门是旅游业、渔业和磷酸矿的出口,磷酸盐储量1100亿吨,占世界首位。农业与牧业也比较重要,但受气候影响比较大。摩洛哥在许多方面依靠外来资助,法国是其第二大援助国,西班牙是其第一大援助国。国民经济发展缓慢。2010年摩洛哥国内生产总值为917亿美元,人均2,839美元。经济增长率3.2%,通货膨胀率则为1.4%。2022年人均3,628美元。
31 |
32 | 陆路交通较发达。在国内运输业中占主导地位,90%的客运和75%的货运通过陆路交通完成。
33 |
34 | 铁路:投入运营线路1907公里,其中复线483公里,电气化铁路1014公里。另有765公里磷酸盐运输线。2003年,摩洛哥与西班牙达成协议,两国共同修建一条穿过直布罗陀海峡的海底复线铁路。该工程曾预计2010年开工,将是连接欧、非两大洲的首条铁路线,但直至2021年,此工程未有任何实质进展。 2018年,卡萨布兰卡-丹吉尔高速铁路通车,是摩洛哥的第一条高速铁路,也是非洲第一条高速铁路。
35 |
36 | 公路:总长64452公里,其中一级公路15907公里,二级公路9367公里,三级公路39178公里。至2009年中,高速公路916公里,有拉巴特-丹吉尔、拉巴特-卡萨布兰卡-塞达特、拉巴特-梅克内斯-非斯等多段高速公路。根据摩第二个全国乡村公路10年计划(2005-2015),每年将新建1500公里公路,届时将使全国80%的农村地区通公路。
37 |
38 | 水运:现拥有港口30个,其中11个为多功能港口,11个为运输、捕鱼用港口。主要港口有卡萨布兰卡、穆罕默迪耶、萨非、丹吉尔、阿加迪尔等,2009年全国总吞吐量6682万吨。其中卡萨布兰卡为全国最大港口,占全国港口总吞吐量的37%;丹吉尔-地中海港一期已完工,正在扩建二期,将成为非洲和地中海最大港口之一。
39 |
40 | 空运:全国共有机场28个,其中国际机场12个,如卡萨布兰卡穆罕默德五世国际机场、拉巴特-塞拉国际机场、马拉喀什-迈纳拉国际机场等。摩王家航空公司有飞机33架,开通75条航线,航线通往四大洲32个国家,总航线30多万公里。2008年客运量1060万人次。2005年,摩与欧盟签署“天空开放”协议,摩航空市场对欧洲航空公司开放。
41 |
42 | 1996年12月,经阿尔及利亚、摩洛哥至西班牙、葡萄牙的马格里布-欧洲天然气管道正式开通。管道全长1385公里,初期每年可输送天然气90亿立方米,摩每年可获10亿立方米天然气。2005年起,摩首次将该管道的过境天然气截留配额用于发电,可满足全国17%的电力需求。
43 |
--------------------------------------------------------------------------------
/pkg/agent/agent.go:
--------------------------------------------------------------------------------
1 | package agent
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "errors"
8 | "fmt"
9 | _ "github.com/joho/godotenv/autoload"
10 | "github.com/tonnie17/wxagent/pkg/config"
11 | "github.com/tonnie17/wxagent/pkg/llm"
12 | "github.com/tonnie17/wxagent/pkg/memory"
13 | "github.com/tonnie17/wxagent/pkg/rag"
14 | "github.com/tonnie17/wxagent/pkg/tool"
15 | "html/template"
16 | "log/slog"
17 | "strings"
18 | "time"
19 | )
20 |
21 | var (
22 | ErrMemoryInUse = errors.New("memory in use")
23 | PromptFuncMap = template.FuncMap{
24 | "now": time.Now,
25 | }
26 | )
27 |
28 | type Agent struct {
29 | config *config.AgentConfig
30 | llm llm.LLM
31 | memory memory.Memory
32 | tools []tool.Tool
33 | ragClient *rag.Client
34 | }
35 |
36 | func NewAgent(config *config.AgentConfig, llm llm.LLM, memory memory.Memory, tools []tool.Tool, ragClient *rag.Client) *Agent {
37 | return &Agent{
38 | config: config,
39 | llm: llm,
40 | memory: memory,
41 | tools: tools,
42 | ragClient: ragClient,
43 | }
44 | }
45 |
46 | func (a *Agent) Chat(ctx context.Context, input string) (string, error) {
47 | msg, err := a.Process(ctx, &llm.ChatMessage{
48 | Role: llm.RoleUser,
49 | Content: input,
50 | })
51 | if err != nil {
52 | return "", err
53 | }
54 |
55 | return msg.Content, nil
56 | }
57 |
58 | func (a *Agent) ChatContinue(ctx context.Context) (string, error) {
59 | if l, ok := a.memory.(memory.Lock); ok && l.IsLocked() {
60 | return "", ErrMemoryInUse
61 | }
62 |
63 | messages, err := a.memory.History()
64 | if err != nil {
65 | return "", err
66 | }
67 |
68 | for i := len(messages) - 1; i > 0; i-- {
69 | msg := messages[i]
70 | if msg.Role == llm.RoleAssistant {
71 | return msg.Content, nil
72 | }
73 | }
74 | return "", nil
75 | }
76 |
77 | func (a *Agent) Process(ctx context.Context, message *llm.ChatMessage) (*llm.ChatMessage, error) {
78 | var err error
79 | out := make(chan *llm.ChatMessage)
80 | go func() {
81 | if err = a.ProcessStream(ctx, message, out); err != nil {
82 | return
83 | }
84 | }()
85 |
86 | var res *llm.ChatMessage
87 | for msg := range out {
88 | res = msg
89 | }
90 |
91 | return res, err
92 | }
93 |
94 | func (a *Agent) ProcessStream(ctx context.Context, message *llm.ChatMessage, outputChan chan<- *llm.ChatMessage) error {
95 | defer close(outputChan)
96 |
97 | if a.config.AgentTimeout != 0 {
98 | timeoutCtx, cancel := context.WithTimeout(ctx, a.config.AgentTimeout)
99 | defer cancel()
100 | ctx = timeoutCtx
101 | }
102 |
103 | if a.ragClient != nil && message != nil {
104 | documents, err := a.ragClient.Query(ctx, a.config.EmbeddingModel, message.Content, 3)
105 | if err != nil {
106 | return err
107 | }
108 |
109 | if len(documents) > 0 {
110 | contexts := make([]string, 0, len(documents))
111 | for _, doc := range documents {
112 | contexts = append(contexts, doc.Content)
113 | }
114 | message.Content = a.buildRAGPrompt(contexts, message.Content)
115 | }
116 | }
117 |
118 | if l, ok := a.memory.(memory.Lock); ok {
119 | if l.Lock() {
120 | defer l.Release()
121 | } else {
122 | return ErrMemoryInUse
123 | }
124 | }
125 |
126 | var messages []*llm.ChatMessage
127 | if a.config.SystemPrompt != "" {
128 | systemPrompt := a.config.SystemPrompt
129 | if tmp, err := template.New("systemPrompt").Funcs(PromptFuncMap).Parse(systemPrompt); err == nil {
130 | promptTpl := new(bytes.Buffer)
131 | if tmp.Execute(promptTpl, nil) == nil {
132 | systemPrompt = promptTpl.String()
133 | }
134 | }
135 | messages = append(messages, &llm.ChatMessage{
136 | Role: llm.RoleSystem,
137 | Content: systemPrompt,
138 | })
139 | }
140 |
141 | history, err := a.memory.History()
142 | if err != nil {
143 | return err
144 | }
145 |
146 | messages = append(messages, history...)
147 | if message != nil {
148 | messages = append(messages, message)
149 | }
150 |
151 | toolsMap := make(map[string]tool.Tool, len(a.tools))
152 | for _, t := range a.tools {
153 | toolsMap[t.Name()] = t
154 | }
155 |
156 | for i := 0; i < a.config.MaxToolIter; i++ {
157 | select {
158 | case <-ctx.Done():
159 | return ctx.Err()
160 | default:
161 | }
162 |
163 | slog.Debug("chat input", slog.String("model", a.config.Model), slog.Any("messages", messages), slog.Any("tools", a.tools))
164 | msg, err := a.llm.Chat(ctx, a.config.Model, messages,
165 | llm.Tools(a.tools),
166 | llm.MaxTokens(a.config.MaxTokens),
167 | llm.Temperature(a.config.Temperature),
168 | llm.TopP(a.config.TopP),
169 | )
170 | if err != nil {
171 | slog.Debug("chat error", slog.Any("err", err))
172 | return err
173 | }
174 | slog.Debug("chat output", slog.Any("result", msg))
175 |
176 | select {
177 | case outputChan <- msg:
178 | case <-ctx.Done():
179 | return ctx.Err()
180 | }
181 |
182 | messages = append(messages, msg)
183 |
184 | if len(msg.ToolCalls) == 0 {
185 | break
186 | }
187 |
188 | for _, toolCall := range msg.ToolCalls {
189 | t, ok := toolsMap[toolCall.Name]
190 | if !ok {
191 | slog.Error("tool not exist", slog.String("tool_call_id", toolCall.ID))
192 | continue
193 | }
194 |
195 | var (
196 | toolCtx = ctx
197 | toolCtxCancel context.CancelFunc
198 | )
199 | if a.config.ToolTimeout != 0 {
200 | toolCtx, toolCtxCancel = context.WithTimeout(ctx, a.config.ToolTimeout)
201 | }
202 |
203 | output, err := t.Execute(toolCtx, toolCall.Arguments)
204 | if toolCtxCancel != nil {
205 | toolCtxCancel()
206 | }
207 |
208 | if err != nil {
209 | slog.Error("tool call function execute failed", slog.String("tool_call_id", toolCall.ID), slog.Any("err", err))
210 | }
211 |
212 | toolResponse := a.convertToolCallMessage(toolCall, output, err)
213 | messages = append(messages, toolResponse)
214 |
215 | select {
216 | case outputChan <- toolResponse:
217 | case <-ctx.Done():
218 | return ctx.Err()
219 | }
220 | }
221 | }
222 |
223 | a.memory.Update(messages)
224 | return nil
225 | }
226 |
227 | func (a *Agent) convertToolCallMessage(toolCall *tool.Call, output string, err error) *llm.ChatMessage {
228 | status := "success"
229 | if err != nil {
230 | status = "failed"
231 | output = err.Error()
232 | }
233 |
234 | content, _ := json.Marshal(struct {
235 | Status string `json:"status"`
236 | Output string `json:"output"`
237 | Name string `json:"name"`
238 | }{
239 | Status: status,
240 | Output: output,
241 | Name: toolCall.Name,
242 | })
243 |
244 | toolMessage := &llm.ChatMessage{
245 | Role: llm.RoleTool,
246 | Content: string(content),
247 | ToolCallID: toolCall.ID,
248 | }
249 |
250 | return toolMessage
251 | }
252 |
253 | func (a *Agent) buildRAGPrompt(contexts []string, question string) string {
254 | return fmt.Sprintf(`You are an assistant. Answer the question based on the given context.
255 |
256 | Context:
257 | %v
258 |
259 | Question:
260 | %v
261 |
262 | Answer:`, strings.Join(contexts, "\n"), question)
263 | }
264 |
--------------------------------------------------------------------------------
/pkg/config/config.go:
--------------------------------------------------------------------------------
1 | package config
2 |
3 | import (
4 | "github.com/caarlos0/env/v11"
5 | "time"
6 | )
7 |
8 | type Config struct {
9 | ServerAddr string `env:"SERVER_ADDR" envDefault:"0.0.0.0:8082"`
10 | ChatAPIKEY string `env:"CHAT_API_KEY" envDefault:""`
11 | WechatAppID string `env:"WECHAT_APP_ID"`
12 | WechatEncodingAESKey string `env:"WECHAT_ENCODING_AES_KEY"`
13 | WechatToken string `env:"WECHAT_TOKEN"`
14 | WechatAllowList []string `env:"WECHAT_ALLOW_LIST"`
15 | WechatMemTTL time.Duration `env:"WECHAT_MEM_TTL" envDefault:"5m"`
16 | WechatMemMsgSize int `env:"WECHAT_MEM_MSG_SIZE" envDefault:"6"`
17 | WechatTimeout time.Duration `env:"WECHAT_TIMEOUT" envDefault:"4s"`
18 | LLMProvider string `env:"LLM_PROVIDER" envDefault:"openai"`
19 | UseRAG bool `env:"USE_RAG" envDefault:"false"`
20 | EmbeddingProvider string `env:"EMBEDDING_PROVIDER" envDefault:"openai"`
21 | KnowledgeBasePath string `env:"KNOWLEDGE_BASE_PATH" envDefault:"./knowledge_base"`
22 | AgentConfig
23 | }
24 |
25 | type AgentConfig struct {
26 | AgentTools []string `env:"AGENT_TOOLS"`
27 | AgentTimeout time.Duration `env:"AGENT_TIMEOUT" envDefault:"30s"`
28 | MaxToolIter int `env:"MAX_TOOL_ITER" envDefault:"5"`
29 | ToolTimeout time.Duration `env:"TOOL_TIMEOUT" envDefault:"10s"`
30 | Model string `env:"LLM_MODEL" envDefault:"gpt-3.5-turbo"`
31 | MaxTokens int `env:"LLM_MAX_TOKENS" envDefault:"500"`
32 | Temperature float32 `env:"LLM_TEMPERATURE" envDefault:"0.2"`
33 | TopP float32 `env:"LLM_TOP_P" envDefault:"0.9"`
34 | SystemPrompt string `env:"SYSTEM_PROMPT" envDefault:"当前时间: {{now.UTC}}"`
35 | EmbeddingModel string `env:"EMBEDDING_MODEL" envDefault:"text-embedding-ada-002"`
36 | }
37 |
38 | func LoadConfig() (*Config, error) {
39 | cfg := &Config{}
40 | if err := env.Parse(cfg); err != nil {
41 | return nil, err
42 | }
43 | return cfg, nil
44 | }
45 |
--------------------------------------------------------------------------------
/pkg/embedding/default.go:
--------------------------------------------------------------------------------
1 | package embedding
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | )
7 |
8 | type NotImplemented struct {
9 | }
10 |
11 | func NewNotImplemented() Model {
12 | return &NotImplemented{}
13 | }
14 |
15 | func (o *NotImplemented) CreateEmbeddings(ctx context.Context, model string, content string) ([]float32, error) {
16 | return nil, fmt.Errorf("embedding not implemented")
17 | }
18 |
--------------------------------------------------------------------------------
/pkg/embedding/embedding.go:
--------------------------------------------------------------------------------
1 | package embedding
2 |
3 | import "context"
4 |
5 | type Model interface {
6 | CreateEmbeddings(ctx context.Context, model string, content string) ([]float32, error)
7 | }
8 |
9 | func New(provider string) Model {
10 | switch provider {
11 | case "openai":
12 | return NewOpenAI()
13 | default:
14 | return NewNotImplemented()
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/pkg/embedding/openai.go:
--------------------------------------------------------------------------------
1 | package embedding
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/sashabaranov/go-openai"
7 | "github.com/tonnie17/wxagent/pkg/provider"
8 | )
9 |
10 | type OpenAI struct {
11 | client *openai.Client
12 | }
13 |
14 | func NewOpenAI() Model {
15 | openaiConfig := openai.DefaultConfig(provider.GetAPIKey("openai"))
16 | openaiConfig.BaseURL = provider.GetAPIBaseURL("openai")
17 | client := openai.NewClientWithConfig(openaiConfig)
18 | return &OpenAI{
19 | client: client,
20 | }
21 | }
22 |
23 | func (o *OpenAI) CreateEmbeddings(ctx context.Context, model string, content string) ([]float32, error) {
24 | resp, err := o.client.CreateEmbeddings(ctx, openai.EmbeddingRequestStrings{
25 | Input: []string{content},
26 | Model: openai.EmbeddingModel(model),
27 | EncodingFormat: openai.EmbeddingEncodingFormatFloat,
28 | })
29 |
30 | if err != nil {
31 | return nil, err
32 | }
33 |
34 | if len(resp.Data) == 0 {
35 | return nil, fmt.Errorf("data is empty")
36 | }
37 |
38 | return resp.Data[0].Embedding, nil
39 | }
40 |
--------------------------------------------------------------------------------
/pkg/ha/home_assistant.go:
--------------------------------------------------------------------------------
1 | package ha
2 |
3 | import (
4 | "bytes"
5 | "context"
6 | "encoding/json"
7 | "fmt"
8 | "io"
9 | "net/http"
10 | "os"
11 | "strings"
12 | )
13 |
14 | func GetEntityStates(ctx context.Context, domains []string) ([]EntityState, error) {
15 | req, err := newHARequest(ctx, "GET", "/api/states", nil)
16 | if err != nil {
17 | return nil, err
18 | }
19 | client := &http.Client{}
20 | resp, err := client.Do(req)
21 | if err != nil {
22 | return nil, err
23 | }
24 | defer resp.Body.Close()
25 |
26 | content, err := io.ReadAll(resp.Body)
27 | if err != nil {
28 | return nil, err
29 | }
30 |
31 | entityStates := make([]EntityState, 0)
32 | if err := json.Unmarshal(content, &entityStates); err != nil {
33 | return nil, err
34 | }
35 |
36 | res := make([]EntityState, 0)
37 | for _, entityState := range entityStates {
38 | if entityState.State == "unknown" {
39 | continue
40 | }
41 | split := strings.Split(entityState.EntityID, ".")
42 | if len(domains) > 0 && (len(split) < 2 || !containString(domains, split[0])) {
43 | continue
44 | }
45 | res = append(res, entityState)
46 | }
47 |
48 | return res, nil
49 | }
50 |
51 | func ExecuteService(ctx context.Context, domain string, service string, entityID string) ([]EntityState, error) {
52 | body := map[string]interface{}{
53 | "entity_id": entityID,
54 | }
55 | bodyJSON, _ := json.Marshal(body)
56 | req, err := newHARequest(ctx, "POST", fmt.Sprintf("/api/services/%v/%v", domain, service), bytes.NewReader(bodyJSON))
57 |
58 | client := &http.Client{}
59 | resp, err := client.Do(req)
60 | if err != nil {
61 | return nil, err
62 | }
63 | defer resp.Body.Close()
64 |
65 | content, err := io.ReadAll(resp.Body)
66 | if err != nil {
67 | return nil, err
68 | }
69 |
70 | entityStates := make([]EntityState, 0)
71 | if err := json.Unmarshal(content, &entityStates); err != nil {
72 | return nil, err
73 | }
74 |
75 | return entityStates, nil
76 | }
77 |
78 | func newHARequest(ctx context.Context, method string, apiPath string, body io.Reader) (*http.Request, error) {
79 | haBaseURL := os.Getenv("HA_BASE_URL")
80 | haToken := os.Getenv("HA_BEARER_TOKEN")
81 |
82 | apiURL := fmt.Sprintf("%v%v", haBaseURL, apiPath)
83 | req, err := http.NewRequest(method, apiURL, body)
84 | if err != nil {
85 | return nil, err
86 | }
87 | req = req.WithContext(ctx)
88 | req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", haToken))
89 | return req, nil
90 | }
91 |
92 | func containString(target []string, s string) bool {
93 | for _, t := range target {
94 | if s == t {
95 | return true
96 | }
97 | }
98 | return false
99 | }
100 |
--------------------------------------------------------------------------------
/pkg/ha/model.go:
--------------------------------------------------------------------------------
1 | package ha
2 |
3 | type EntityState struct {
4 | EntityID string `json:"entity_id"`
5 | State string `json:"state"`
6 | Attributes struct {
7 | FriendlyName string `json:"friendly_name"`
8 | } `json:"attributes"`
9 | LastChanged string `json:"last_changed"`
10 | }
11 |
--------------------------------------------------------------------------------
/pkg/llm/default.go:
--------------------------------------------------------------------------------
1 | package llm
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | )
7 |
8 | type NotImplemented struct {
9 | }
10 |
11 | func NewNotImplemented() LLM {
12 | return &NotImplemented{}
13 | }
14 |
15 | func (o *NotImplemented) Chat(ctx context.Context, model string, chatMessages []*ChatMessage, options ...ChatOption) (*ChatMessage, error) {
16 | return nil, fmt.Errorf("llm not implemented")
17 | }
18 |
--------------------------------------------------------------------------------
/pkg/llm/llm.go:
--------------------------------------------------------------------------------
1 | package llm
2 |
3 | import (
4 | "context"
5 | "github.com/tonnie17/wxagent/pkg/tool"
6 | )
7 |
8 | type LLM interface {
9 | Chat(ctx context.Context, model string, messages []*ChatMessage, options ...ChatOption) (*ChatMessage, error)
10 | }
11 |
12 | func New(provider string) LLM {
13 | switch provider {
14 | case "openai":
15 | return NewOpenAI()
16 | default:
17 | return NewNotImplemented()
18 | }
19 | }
20 |
21 | type chatOptions struct {
22 | tools []tool.Tool
23 | maxTokens int
24 | temperature float32
25 | topP float32
26 | }
27 |
28 | type ChatOption func(*chatOptions)
29 |
30 | func Tools(tools []tool.Tool) ChatOption {
31 | return func(o *chatOptions) {
32 | o.tools = tools
33 | }
34 | }
35 |
36 | func MaxTokens(maxTokens int) ChatOption {
37 | return func(o *chatOptions) {
38 | o.maxTokens = maxTokens
39 | }
40 | }
41 |
42 | func Temperature(temperature float32) ChatOption {
43 | return func(o *chatOptions) {
44 | o.temperature = temperature
45 | }
46 | }
47 |
48 | func TopP(topP float32) ChatOption {
49 | return func(o *chatOptions) {
50 | o.topP = topP
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/pkg/llm/message.go:
--------------------------------------------------------------------------------
1 | package llm
2 |
3 | import "github.com/tonnie17/wxagent/pkg/tool"
4 |
5 | type Role string
6 |
7 | const (
8 | RoleUser Role = "user"
9 | RoleSystem Role = "system"
10 | RoleTool Role = "tool"
11 | RoleAssistant Role = "assistant"
12 | )
13 |
14 | type ChatMessage struct {
15 | Role Role `json:"role"`
16 | Content string `json:"content"`
17 | ToolCallID string `json:"tool_call_id"`
18 | ToolCalls []*tool.Call `json:"tool_calls"`
19 | }
20 |
--------------------------------------------------------------------------------
/pkg/llm/openai.go:
--------------------------------------------------------------------------------
1 | package llm
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/sashabaranov/go-openai"
7 | "github.com/tonnie17/wxagent/pkg/provider"
8 | "github.com/tonnie17/wxagent/pkg/tool"
9 | )
10 |
11 | type OpenAI struct {
12 | client *openai.Client
13 | }
14 |
15 | func NewOpenAI() LLM {
16 | openaiConfig := openai.DefaultConfig(provider.GetAPIKey("openai"))
17 | openaiConfig.BaseURL = provider.GetAPIBaseURL("openai")
18 | client := openai.NewClientWithConfig(openaiConfig)
19 | return &OpenAI{
20 | client: client,
21 | }
22 | }
23 |
24 | func (o *OpenAI) Chat(ctx context.Context, model string, chatMessages []*ChatMessage, options ...ChatOption) (*ChatMessage, error) {
25 | message := make([]openai.ChatCompletionMessage, 0, len(chatMessages))
26 | for _, chatMessage := range chatMessages {
27 | msg := openai.ChatCompletionMessage{
28 | Role: string(chatMessage.Role),
29 | Content: chatMessage.Content,
30 | ToolCallID: chatMessage.ToolCallID,
31 | }
32 | for _, toolCall := range chatMessage.ToolCalls {
33 | msg.ToolCalls = append(msg.ToolCalls, openai.ToolCall{
34 | ID: toolCall.ID,
35 | Type: openai.ToolType(toolCall.Type),
36 | Function: openai.FunctionCall{
37 | Name: toolCall.Name,
38 | Arguments: toolCall.Arguments,
39 | },
40 | })
41 | }
42 | message = append(message, msg)
43 | }
44 |
45 | var option chatOptions
46 | for _, o := range options {
47 | o(&option)
48 | }
49 |
50 | requestTools := make([]openai.Tool, 0, len(option.tools))
51 | for _, t := range option.tools {
52 | requestTools = append(requestTools, openai.Tool{
53 | Type: openai.ToolTypeFunction,
54 | Function: &openai.FunctionDefinition{
55 | Name: t.Name(),
56 | Description: t.Description(),
57 | Parameters: t.Schema(),
58 | },
59 | })
60 | }
61 |
62 | response, err := o.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
63 | Model: model,
64 | Messages: message,
65 | Tools: requestTools,
66 | MaxTokens: option.maxTokens,
67 | Temperature: option.temperature,
68 | TopP: option.topP,
69 | })
70 | if err != nil {
71 | return nil, err
72 | }
73 |
74 | return o.convertResponse(response)
75 | }
76 |
77 | func (o *OpenAI) convertResponse(response openai.ChatCompletionResponse) (*ChatMessage, error) {
78 | if len(response.Choices) == 0 {
79 | return nil, fmt.Errorf("empty response")
80 | }
81 |
82 | message := response.Choices[0].Message
83 | if message.Content != "" {
84 | rc := &ChatMessage{
85 | Role: Role(message.Role),
86 | Content: message.Content,
87 | }
88 | return rc, nil
89 | }
90 |
91 | if len(message.ToolCalls) > 0 {
92 | rc := &ChatMessage{
93 | Role: Role(message.Role),
94 | Content: message.Content,
95 | }
96 | for _, toolCall := range message.ToolCalls {
97 | rc.ToolCalls = append(rc.ToolCalls, &tool.Call{
98 | ID: toolCall.ID,
99 | Type: string(toolCall.Type),
100 | Name: toolCall.Function.Name,
101 | Arguments: toolCall.Function.Arguments,
102 | })
103 | }
104 | return rc, nil
105 | }
106 |
107 | return nil, fmt.Errorf("no content or tool calls")
108 | }
109 |
--------------------------------------------------------------------------------
/pkg/memory/buffer.go:
--------------------------------------------------------------------------------
1 | package memory
2 |
3 | import (
4 | "github.com/tonnie17/wxagent/pkg/llm"
5 | )
6 |
7 | type Buffer struct {
8 | BaseLock
9 | maxMessages int
10 | messages []*llm.ChatMessage
11 | }
12 |
13 | func NewBuffer(maxMessages int) *Buffer {
14 | return &Buffer{
15 | maxMessages: maxMessages,
16 | messages: []*llm.ChatMessage{},
17 | }
18 | }
19 |
20 | func (m *Buffer) History() ([]*llm.ChatMessage, error) {
21 | return m.messages, nil
22 | }
23 |
24 | func (m *Buffer) Update(messages []*llm.ChatMessage) error {
25 | m.messages = messages
26 | m.truncate()
27 | return nil
28 | }
29 |
30 | func (m *Buffer) truncate() {
31 | start := 0
32 | for start < len(m.messages) && len(m.messages) > m.maxMessages {
33 | start++
34 | for start < len(m.messages) && (m.messages[start].Role == llm.RoleTool || m.messages[start].Role == llm.RoleAssistant) {
35 | start++
36 | }
37 | m.messages = m.messages[start:]
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
/pkg/memory/memory.go:
--------------------------------------------------------------------------------
1 | package memory
2 |
3 | import (
4 | "github.com/tonnie17/wxagent/pkg/llm"
5 | "sync"
6 | "sync/atomic"
7 | )
8 |
9 | type Memory interface {
10 | Update(messages []*llm.ChatMessage) error
11 | History() ([]*llm.ChatMessage, error)
12 | }
13 |
14 | type Lock interface {
15 | Lock() bool
16 | Release()
17 | IsLocked() bool
18 | }
19 |
20 | type BaseLock struct {
21 | isLock int32
22 | lock sync.Mutex
23 | }
24 |
25 | func (m *BaseLock) Lock() bool {
26 | lockSuccess := m.lock.TryLock()
27 | if lockSuccess {
28 | atomic.StoreInt32(&m.isLock, 1)
29 | }
30 | return lockSuccess
31 | }
32 |
33 | func (m *BaseLock) IsLocked() bool {
34 | return atomic.LoadInt32(&m.isLock) == 1
35 | }
36 |
37 | func (m *BaseLock) Release() {
38 | atomic.StoreInt32(&m.isLock, 0)
39 | m.lock.Unlock()
40 | }
41 |
--------------------------------------------------------------------------------
/pkg/memory/token_base.go:
--------------------------------------------------------------------------------
1 | package memory
2 |
3 | import (
4 | "github.com/pkoukk/tiktoken-go"
5 | "github.com/tonnie17/wxagent/pkg/llm"
6 | "log/slog"
7 | )
8 |
9 | type TokenBase struct {
10 | BaseLock
11 | messages []*llm.ChatMessage
12 | maxTokens int
13 | totalTokens int
14 | encoding *tiktoken.Tiktoken
15 | }
16 |
17 | func NewTokenBase(model string, maxTokens int) *TokenBase {
18 | encoding, err := tiktoken.EncodingForModel(model)
19 | if err != nil {
20 | slog.Warn("failed to get encoding for model")
21 | encoding, _ = tiktoken.EncodingForModel("gpt-4o")
22 | }
23 | return &TokenBase{
24 | messages: []*llm.ChatMessage{},
25 | maxTokens: maxTokens,
26 | totalTokens: 0,
27 | encoding: encoding,
28 | }
29 | }
30 |
31 | func (m *TokenBase) History() ([]*llm.ChatMessage, error) {
32 | return m.messages, nil
33 | }
34 |
35 | func (m *TokenBase) Update(messages []*llm.ChatMessage) error {
36 | m.totalTokens = m.getMessageTokens(messages)
37 | m.messages = messages
38 | m.truncate()
39 | return nil
40 | }
41 |
42 | func (m *TokenBase) truncate() {
43 | start := 0
44 | for start < len(m.messages) && m.maxTokens > 0 && m.getMessageTokens(m.messages) > m.maxTokens {
45 | start++
46 | for start < len(m.messages) && (m.messages[start].Role == llm.RoleTool || m.messages[start].Role == llm.RoleAssistant) {
47 | start++
48 | }
49 | m.messages = m.messages[start:]
50 | }
51 | }
52 |
53 | func (m *TokenBase) getMessageTokens(messages []*llm.ChatMessage) int {
54 | var total int
55 | for _, message := range messages {
56 | tokens := m.encoding.Encode(message.Content, nil, nil)
57 | total += len(tokens)
58 | }
59 | return total
60 | }
61 |
--------------------------------------------------------------------------------
/pkg/provider/util.go:
--------------------------------------------------------------------------------
1 | package provider
2 |
3 | import (
4 | "fmt"
5 | "os"
6 | "strings"
7 | )
8 |
9 | func GetAPIKey(provider string) string {
10 | return os.Getenv(fmt.Sprintf("%s_API_KEY", strings.ToUpper(provider)))
11 | }
12 |
13 | func GetAPIBaseURL(provider string) string {
14 | return os.Getenv(fmt.Sprintf("%s_BASE_URL", strings.ToUpper(provider)))
15 | }
16 |
--------------------------------------------------------------------------------
/pkg/rag/document.go:
--------------------------------------------------------------------------------
1 | package rag
2 |
3 | type DocumentPart struct {
4 | DocumentID string
5 | PartIndex int
6 | Content string
7 | }
8 |
--------------------------------------------------------------------------------
/pkg/rag/rag.go:
--------------------------------------------------------------------------------
1 | package rag
2 |
3 | import (
4 | "context"
5 | "github.com/tonnie17/wxagent/pkg/embedding"
6 | "io"
7 | "io/fs"
8 | "log/slog"
9 | "os"
10 | "path/filepath"
11 | )
12 |
13 | type Client struct {
14 | embeddingModel embedding.Model
15 | store *VectorStore
16 | }
17 |
18 | func NewClient(embeddingModel embedding.Model, store *VectorStore) *Client {
19 | return &Client{
20 | embeddingModel: embeddingModel,
21 | store: store,
22 | }
23 | }
24 |
25 | func (c *Client) Query(ctx context.Context, model string, content string, limit int) ([]*DocumentPart, error) {
26 | embeddingData, err := c.embeddingModel.CreateEmbeddings(ctx, model, content)
27 | if err != nil {
28 | return nil, err
29 | }
30 |
31 | return c.store.GetMostRelevantDocuments(ctx, embeddingData, 1, limit)
32 | }
33 |
34 | func (c *Client) BuildKnowledgeBase(ctx context.Context, knowledgeBasePath string, model string, reBuild bool) error {
35 | if knowledgeBasePath == "" {
36 | return nil
37 | }
38 | if err := c.store.Init(ctx); err != nil {
39 | return err
40 | }
41 | return filepath.WalkDir(knowledgeBasePath, func(path string, d fs.DirEntry, e error) error {
42 | if d != nil && d.IsDir() {
43 | return nil
44 | }
45 | logger := slog.With(slog.String("file", path))
46 | ext := filepath.Ext(path)
47 | if ext != ".txt" {
48 | return nil
49 | }
50 | logger.Info("build knowledge base")
51 |
52 | switch ext {
53 | case ".txt":
54 | documentID := filepath.Base(path)
55 | if reBuild {
56 | err := c.store.DeleteDocuments(ctx, documentID)
57 | if err != nil {
58 | logger.Error("delete documents failed", slog.Any("err", err))
59 | return nil
60 | }
61 | } else {
62 | documentExist, err := c.store.CheckDocumentExist(ctx, documentID)
63 | if err != nil {
64 | logger.Error("check document exist failed", slog.Any("err", err))
65 | return nil
66 | }
67 | if documentExist {
68 | return nil
69 | }
70 | }
71 |
72 | splits, err := processTextFile(path)
73 | if err != nil {
74 | logger.Error("process text file failed", slog.Any("err", err))
75 | return nil
76 | }
77 |
78 | var partIndex int
79 | for _, content := range splits {
80 | embeddingData, err := c.embeddingModel.CreateEmbeddings(ctx, model, content)
81 | if err != nil {
82 | logger.Error("create embeddings failed", slog.Any("err", err))
83 | return nil
84 | }
85 | partIndex++
86 | if err := c.store.SaveDocumentEmbedding(ctx, documentID, partIndex, content, embeddingData); err != nil {
87 | logger.Error("save embedding failed", slog.Any("err", err))
88 | return nil
89 | }
90 | }
91 | }
92 |
93 | return nil
94 | })
95 | }
96 |
97 | func processTextFile(fileName string) ([]string, error) {
98 | f, err := os.Open(fileName)
99 | if err != nil {
100 | return nil, err
101 | }
102 | defer f.Close()
103 |
104 | content, err := io.ReadAll(f)
105 | if err != nil {
106 | return nil, err
107 | }
108 |
109 | return splitText(string(content), 500, 50), nil
110 | }
111 |
--------------------------------------------------------------------------------
/pkg/rag/splitter.go:
--------------------------------------------------------------------------------
1 | package rag
2 |
3 | import (
4 | "strings"
5 | "unicode/utf8"
6 | )
7 |
8 | func splitText(text string, chunkSize int, chunkOverlap int) []string {
9 | separators := []string{"\n\n", "\n", " ", ""}
10 | separator := separators[len(separators)-1]
11 |
12 | var newSeparators []string
13 | for i, sep := range separators {
14 | if sep == "" || strings.Contains(text, sep) {
15 | separator = sep
16 | newSeparators = separators[i+1:]
17 | break
18 | }
19 | }
20 |
21 | var final []string
22 | goodSplits := make([]string, 0)
23 | for _, split := range strings.Split(text, separator) {
24 | if utf8.RuneCountInString(split) < chunkSize {
25 | goodSplits = append(goodSplits, split)
26 | continue
27 | }
28 |
29 | if len(goodSplits) > 0 {
30 | mergedSplits := mergeSplits(goodSplits, separator, chunkSize, chunkOverlap)
31 |
32 | final = append(final, mergedSplits...)
33 | goodSplits = make([]string, 0)
34 | }
35 |
36 | if len(newSeparators) == 0 {
37 | final = append(final, split)
38 | } else {
39 | other := splitText(split, chunkSize, chunkOverlap)
40 | final = append(final, other...)
41 | }
42 | }
43 |
44 | if len(goodSplits) > 0 {
45 | mergedSplits := mergeSplits(goodSplits, separator, chunkSize, chunkOverlap)
46 | final = append(final, mergedSplits...)
47 | }
48 |
49 | return final
50 | }
51 |
52 | func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int) []string {
53 | docs := make([]string, 0)
54 | currentDoc := make([]string, 0)
55 | total := 0
56 |
57 | for _, split := range splits {
58 | splitLen := utf8.RuneCountInString(split)
59 | sepLen := utf8.RuneCountInString(separator) * compareDocsLen(currentDoc, 0)
60 |
61 | if total+splitLen+sepLen > chunkSize && len(currentDoc) > 0 {
62 | if doc := strings.TrimSpace(strings.Join(currentDoc, separator)); doc != "" {
63 | docs = append(docs, doc)
64 | }
65 |
66 | for len(currentDoc) > 0 && (total > chunkOverlap ||
67 | (total+splitLen+utf8.RuneCountInString(separator)*compareDocsLen(currentDoc, 1) > chunkSize && total > 0)) {
68 | total -= utf8.RuneCountInString(currentDoc[0]) + utf8.RuneCountInString(separator)*compareDocsLen(currentDoc, 1)
69 | currentDoc = currentDoc[1:]
70 | }
71 | }
72 |
73 | currentDoc = append(currentDoc, split)
74 | total += utf8.RuneCountInString(split)
75 | total += utf8.RuneCountInString(separator) * compareDocsLen(currentDoc, 1)
76 | }
77 |
78 | if doc := strings.TrimSpace(strings.Join(currentDoc, separator)); doc != "" {
79 | docs = append(docs, doc)
80 | }
81 | return docs
82 | }
83 |
84 | func compareDocsLen(currentDocs []string, cmp int) int {
85 | if len(currentDocs) > cmp {
86 | return 1
87 | }
88 | return 0
89 | }
90 |
--------------------------------------------------------------------------------
/pkg/rag/store.go:
--------------------------------------------------------------------------------
1 | package rag
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/jackc/pgx/v5"
7 | "github.com/jackc/pgx/v5/pgxpool"
8 | "github.com/pgvector/pgvector-go"
9 | pgxvector "github.com/pgvector/pgvector-go/pgx"
10 | "os"
11 | )
12 |
13 | type VectorStore struct {
14 | pool *pgxpool.Pool
15 | }
16 |
17 | func NewPgVectorStore() (*VectorStore, error) {
18 | poolConfig, err := pgxpool.ParseConfig(os.Getenv("POSTGRES_URL"))
19 | if err != nil {
20 | return nil, err
21 | }
22 | poolConfig.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeCacheDescribe
23 | poolConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
24 | return pgxvector.RegisterTypes(ctx, conn)
25 | }
26 |
27 | pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig)
28 | if err != nil {
29 | return nil, err
30 | }
31 |
32 | return &VectorStore{
33 | pool: pool,
34 | }, nil
35 | }
36 |
37 | func (s *VectorStore) Init(ctx context.Context) error {
38 | query := `
39 | CREATE TABLE IF NOT EXISTS knowledge_base (
40 | id SERIAL PRIMARY KEY,
41 | document_id TEXT NOT NULL,
42 | part_index INT NOT NULL,
43 | content TEXT,
44 | embedding vector(1536),
45 | UNIQUE (document_id, part_index)
46 | )
47 | `
48 | if _, err := s.pool.Exec(ctx, query); err != nil {
49 | return err
50 | }
51 |
52 | return nil
53 | }
54 |
55 | func (s *VectorStore) GetMostRelevantDocuments(ctx context.Context, embedding []float32, threshold float32, limit int) ([]*DocumentPart, error) {
56 | query := fmt.Sprintf("SELECT document_id, part_index, content FROM knowledge_base WHERE embedding <-> $1 < $2 ORDER BY embedding <-> $1 LIMIT %v", limit)
57 | rows, err := s.pool.Query(ctx, query, pgvector.NewVector(embedding), threshold)
58 | if err != nil {
59 | return nil, err
60 | }
61 | defer rows.Close()
62 |
63 | var documents []*DocumentPart
64 | for rows.Next() {
65 | document := &DocumentPart{}
66 | if err := rows.Scan(&document.DocumentID, &document.PartIndex, &document.Content); err != nil {
67 | return nil, err
68 | }
69 | documents = append(documents, document)
70 | }
71 |
72 | if rows.Err() != nil {
73 | return nil, rows.Err()
74 | }
75 |
76 | return documents, nil
77 | }
78 |
79 | func (s *VectorStore) SaveDocumentEmbedding(ctx context.Context, documentID string, partIndex int, content string, embedding []float32) error {
80 | _, err := s.pool.Exec(ctx, `
81 | INSERT INTO knowledge_base (document_id, part_index, content, embedding)
82 | VALUES ($1, $2, $3, $4)
83 | ON CONFLICT (document_id, part_index)
84 | DO UPDATE SET
85 | content = EXCLUDED.content,
86 | embedding = EXCLUDED.embedding;
87 | `, documentID, partIndex, content, pgvector.NewVector(embedding))
88 |
89 | if err != nil {
90 | return err
91 | }
92 |
93 | return nil
94 | }
95 |
96 | func (s *VectorStore) CheckDocumentExist(ctx context.Context, documentID string) (bool, error) {
97 | var exists bool
98 | err := s.pool.QueryRow(ctx, `
99 | SELECT EXISTS (
100 | SELECT 1 FROM knowledge_base WHERE document_id = $1
101 | )
102 | `, documentID).Scan(&exists)
103 |
104 | if err != nil {
105 | return false, nil
106 | }
107 |
108 | return exists, nil
109 | }
110 |
111 | func (s *VectorStore) DeleteDocuments(ctx context.Context, documentID string) error {
112 | _, err := s.pool.Exec(ctx, `
113 | DELETE FROM knowledge_base WHERE document_id = $1
114 | `, documentID)
115 |
116 | if err != nil {
117 | return err
118 | }
119 |
120 | return nil
121 | }
122 |
123 | func (s *VectorStore) Release() {
124 | s.pool.Close()
125 | }
126 |
--------------------------------------------------------------------------------
/pkg/tool/execute_ha_devices.go:
--------------------------------------------------------------------------------
1 | package tool
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "github.com/tonnie17/wxagent/pkg/ha"
7 | "log/slog"
8 | )
9 |
10 | type ExecuteHADevice struct {
11 | }
12 |
13 | func NewExecuteHADevice() Tool {
14 | return &ExecuteHADevice{}
15 | }
16 |
17 | func (e *ExecuteHADevice) Name() string {
18 | return "execute_device"
19 | }
20 |
21 | func (e *ExecuteHADevice) Description() string {
22 | return "Use this function to execute service of devices in Home Assistant"
23 | }
24 |
25 | func (e *ExecuteHADevice) Schema() map[string]interface{} {
26 | return map[string]interface{}{
27 | "type": "object",
28 | "properties": map[string]interface{}{
29 | "list": map[string]interface{}{
30 | "type": "array",
31 | "items": map[string]interface{}{
32 | "type": "object",
33 | "properties": map[string]interface{}{
34 | "domain": map[string]interface{}{
35 | "type": "string",
36 | "description": "The domain of the service",
37 | },
38 | "service": map[string]interface{}{
39 | "type": "string",
40 | "description": "The service to be called",
41 | },
42 | "entity_id": map[string]interface{}{
43 | "type": "string",
44 | "description": "The entity_id retrieved from available devices. It must start with domain, followed by dot character",
45 | },
46 | },
47 | "required": []string{"domain", "service", "entity_id"},
48 | },
49 | },
50 | },
51 | }
52 | }
53 |
54 | func (e *ExecuteHADevice) Execute(ctx context.Context, input string) (string, error) {
55 | var arguments struct {
56 | List []struct {
57 | Domain string `json:"domain"`
58 | Service string `json:"service"`
59 | EntityID string `json:"entity_id"`
60 | } `json:"list"`
61 | }
62 | if err := json.Unmarshal([]byte(input), &arguments); err != nil {
63 | slog.Error("unmarshal failed", slog.Any("err", err))
64 | return "", err
65 | }
66 |
67 | var entityStates []ha.EntityState
68 | for _, action := range arguments.List {
69 | states, err := ha.ExecuteService(ctx, action.Domain, action.Service, action.EntityID)
70 | if err != nil {
71 | return "", err
72 | }
73 | entityStates = append(entityStates, states...)
74 | }
75 |
76 | statesJSON, _ := json.Marshal(entityStates)
77 |
78 | return string(statesJSON), nil
79 | }
80 |
--------------------------------------------------------------------------------
/pkg/tool/get_ha_devices.go:
--------------------------------------------------------------------------------
1 | package tool
2 |
3 | import (
4 | "context"
5 | "fmt"
6 | "github.com/tonnie17/wxagent/pkg/ha"
7 | "strings"
8 | )
9 |
10 | type GetHADevices struct {
11 | }
12 |
13 | func NewGetHADevices() Tool {
14 | return &GetHADevices{}
15 | }
16 |
17 | func (g *GetHADevices) Name() string {
18 | return "get_devices"
19 | }
20 |
21 | func (g *GetHADevices) Description() string {
22 | return "Use this function to get devices in Home Assistant, including their state and entity_id"
23 | }
24 |
25 | func (g *GetHADevices) Schema() map[string]interface{} {
26 | return map[string]interface{}{
27 | "type": "object",
28 | "properties": map[string]interface{}{},
29 | "required": []string{},
30 | }
31 | }
32 |
33 | func (g *GetHADevices) Execute(ctx context.Context, input string) (string, error) {
34 | entityStates, err := ha.GetEntityStates(ctx, g.defaultDomains())
35 | if err != nil {
36 | return "", err
37 | }
38 |
39 | var builder strings.Builder
40 | builder.WriteString("An overview of the devices in this smart home:\n")
41 | builder.WriteString("```csv\n")
42 | builder.WriteString("entity_id,name,state\n")
43 | for _, entityState := range entityStates {
44 | builder.WriteString(fmt.Sprintf("%v,%v,%v\n", entityState.EntityID, entityState.Attributes.FriendlyName, entityState.State))
45 | }
46 | builder.WriteString("```\n")
47 |
48 | return builder.String(), nil
49 | }
50 |
51 | func (g *GetHADevices) defaultDomains() []string {
52 | return []string{
53 | "door",
54 | "lock",
55 | "occupancy",
56 | "motion",
57 | "climate",
58 | "light",
59 | "switch",
60 | "sensor",
61 | "speaker",
62 | "media_player",
63 | "temperature",
64 | "humidity",
65 | "battery",
66 | "tv",
67 | "remote",
68 | "light",
69 | "vacuum",
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/pkg/tool/get_weather.go:
--------------------------------------------------------------------------------
1 | package tool
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "io"
8 | "log/slog"
9 | "net/http"
10 | "os"
11 | "strings"
12 | "unicode"
13 | )
14 |
15 | type GetWeather struct {
16 | }
17 |
18 | func NewGetWeather() Tool {
19 | return &GetWeather{}
20 | }
21 |
22 | func (w *GetWeather) Name() string {
23 | return "get_weather"
24 | }
25 |
26 | func (w *GetWeather) Description() string {
27 | return "Retrieve the current weather information for a specified city and return the city name in English"
28 | }
29 |
30 | func (w *GetWeather) Schema() map[string]interface{} {
31 | return map[string]interface{}{
32 | "type": "object",
33 | "properties": map[string]interface{}{
34 | "city": map[string]interface{}{
35 | "type": "string",
36 | "description": "city name",
37 | },
38 | },
39 | "required": []string{"city"},
40 | }
41 | }
42 |
43 | func (w *GetWeather) Execute(ctx context.Context, input string) (string, error) {
44 | var arguments struct {
45 | City string `json:"city"`
46 | }
47 | if err := json.Unmarshal([]byte(input), &arguments); err != nil {
48 | slog.Error("unmarshal failed", slog.Any("err", err))
49 | return "", err
50 | }
51 |
52 | if arguments.City == "" {
53 | return "", fmt.Errorf("city name is empty")
54 | }
55 |
56 | city := w.completeCNCity(arguments.City)
57 |
58 | appID := os.Getenv("OPENWEATHERMAP_API_KEY")
59 | apiURL := fmt.Sprintf("https://api.openweathermap.org/data/2.5/forecast?q=%v&units=metric&appid=%v&lang=zh_cn&cnt=5", city, appID)
60 |
61 | req, err := http.NewRequest("GET", apiURL, nil)
62 | if err != nil {
63 | return "", err
64 | }
65 | req = req.WithContext(ctx)
66 |
67 | client := &http.Client{}
68 | resp, err := client.Do(req)
69 | if err != nil {
70 | return "", err
71 | }
72 | defer resp.Body.Close()
73 |
74 | content, err := io.ReadAll(resp.Body)
75 | if err != nil {
76 | return "", err
77 | }
78 |
79 | var weatherData struct {
80 | List []struct {
81 | Dt int `json:"dt"`
82 | Main struct {
83 | Temp float64 `json:"temp"`
84 | } `json:"main"`
85 | Weather []struct {
86 | Description string `json:"description"`
87 | } `json:"weather"`
88 | DtTxt string `json:"dt_txt"`
89 | } `json:"list"`
90 | City struct {
91 | Name string `json:"name"`
92 | Coord struct {
93 | Lat float64 `json:"lat"`
94 | Lon float64 `json:"lon"`
95 | } `json:"coord"`
96 | Sunrise int `json:"sunrise"`
97 | Sunset int `json:"sunset"`
98 | } `json:"city"`
99 | }
100 |
101 | if err := json.Unmarshal(content, &weatherData); err != nil {
102 | return "", err
103 | }
104 |
105 | mainContent, _ := json.Marshal(weatherData)
106 |
107 | return string(mainContent), nil
108 | }
109 |
110 | func (w *GetWeather) isCNCity(s string) bool {
111 | if s == "" {
112 | return false
113 | }
114 |
115 | for _, r := range s {
116 | if !unicode.Is(unicode.Han, r) {
117 | return false
118 | }
119 | }
120 | return true
121 | }
122 |
123 | func (w *GetWeather) completeCNCity(city string) string {
124 | if w.isCNCity(city) && !strings.HasSuffix(city, "市") {
125 | return city + "市"
126 | }
127 | return city
128 | }
129 |
--------------------------------------------------------------------------------
/pkg/tool/google_search.go:
--------------------------------------------------------------------------------
1 | package tool
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "fmt"
7 | "io"
8 | "log/slog"
9 | "net/http"
10 | "net/url"
11 | "os"
12 | )
13 |
14 | type GoogleSearch struct {
15 | }
16 |
17 | func NewGoogleSearch() Tool {
18 | return &GoogleSearch{}
19 | }
20 |
21 | func (w *GoogleSearch) Name() string {
22 | return "google_search"
23 | }
24 |
25 | func (w *GoogleSearch) Description() string {
26 | return "Make a query to the Google search engine to receive a list of results"
27 | }
28 |
29 | func (w *GoogleSearch) Schema() map[string]interface{} {
30 | return map[string]interface{}{
31 | "type": "object",
32 | "properties": map[string]interface{}{
33 | "query": map[string]interface{}{
34 | "type": "string",
35 | "description": "The query to be passed to Google search",
36 | },
37 | },
38 | "required": []string{"query"},
39 | }
40 | }
41 |
42 | func (w *GoogleSearch) Execute(ctx context.Context, input string) (string, error) {
43 | var arguments struct {
44 | Query string `json:"query"`
45 | }
46 | if err := json.Unmarshal([]byte(input), &arguments); err != nil {
47 | slog.Error("unmarshal failed", slog.Any("err", err))
48 | return "", err
49 | }
50 |
51 | apiKey := os.Getenv("GOOGLE_SEARCH_API_KEY")
52 | engine := os.Getenv("GOOGLE_SEARCH_ENGINE")
53 | apiURL := fmt.Sprintf("https://www.googleapis.com/customsearch/v1?key=%v&cx=%v&q=%v&num=3", apiKey, engine, url.QueryEscape(arguments.Query))
54 |
55 | req, err := http.NewRequest("GET", apiURL, nil)
56 | if err != nil {
57 | return "", err
58 | }
59 | req = req.WithContext(ctx)
60 |
61 | client := &http.Client{}
62 | resp, err := client.Do(req)
63 | if err != nil {
64 | return "", err
65 | }
66 | defer resp.Body.Close()
67 |
68 | content, err := io.ReadAll(resp.Body)
69 | if err != nil {
70 | return "", err
71 | }
72 |
73 | sr := &searchResult{}
74 | if err := json.Unmarshal(content, &sr); err != nil {
75 | return "", err
76 | }
77 |
78 | itemsJSON, _ := json.Marshal(sr.Items)
79 | return string(itemsJSON), nil
80 | }
81 |
82 | type searchResult struct {
83 | Items []*searchResultItem `json:"items"`
84 | }
85 |
86 | type searchResultItem struct {
87 | Title string `json:"title"`
88 | Link string `json:"link"`
89 | Snippet string `json:"snippet"`
90 | }
91 |
--------------------------------------------------------------------------------
/pkg/tool/tool.go:
--------------------------------------------------------------------------------
1 | package tool
2 |
3 | import "context"
4 |
5 | var tools = make(map[string]Tool)
6 |
7 | func init() {
8 | for _, tool := range []Tool{
9 | NewGetWeather(),
10 | NewGoogleSearch(),
11 | NewWebPageSummary(),
12 | NewGetHADevices(),
13 | NewExecuteHADevice(),
14 | } {
15 | tools[tool.Name()] = tool
16 | }
17 | }
18 |
19 | type Tool interface {
20 | Name() string
21 | Description() string
22 | Schema() map[string]interface{}
23 | Execute(context.Context, string) (string, error)
24 | }
25 |
26 | type Call struct {
27 | ID string `json:"id"`
28 | Type string `json:"type"`
29 | Name string `json:"name"`
30 | Arguments string `json:"arguments"`
31 | }
32 |
33 | func DefaultTools() []Tool {
34 | res := make([]Tool, 0, len(tools))
35 | for _, tool := range tools {
36 | res = append(res, tool)
37 | }
38 | return res
39 | }
40 |
41 | func GetTools(names []string) []Tool {
42 | res := make([]Tool, 0, len(names))
43 | for _, name := range names {
44 | if _, ok := tools[name]; ok {
45 | res = append(res, tools[name])
46 | }
47 | }
48 | return res
49 | }
50 |
--------------------------------------------------------------------------------
/pkg/tool/webpage_summary.go:
--------------------------------------------------------------------------------
1 | package tool
2 |
3 | import (
4 | "context"
5 | "encoding/json"
6 | "github.com/PuerkitoBio/goquery"
7 | "io"
8 | "log/slog"
9 | "net/http"
10 | "regexp"
11 | "strings"
12 | )
13 |
14 | type WebPageSummary struct {
15 | }
16 |
17 | func NewWebPageSummary() Tool {
18 | return &WebPageSummary{}
19 | }
20 |
21 | func (w *WebPageSummary) Name() string {
22 | return "webpage_summary"
23 | }
24 |
25 | func (w *WebPageSummary) Description() string {
26 | return "Summaries the content of a web page"
27 | }
28 |
29 | func (w *WebPageSummary) Schema() map[string]interface{} {
30 | return map[string]interface{}{
31 | "type": "object",
32 | "properties": map[string]interface{}{
33 | "url": map[string]interface{}{
34 | "type": "string",
35 | "description": "the URL of the web page",
36 | },
37 | },
38 | "required": []string{"url"},
39 | }
40 | }
41 |
42 | func (w *WebPageSummary) Execute(ctx context.Context, input string) (string, error) {
43 | var arguments struct {
44 | URL string `json:"url"`
45 | }
46 | if err := json.Unmarshal([]byte(input), &arguments); err != nil {
47 | slog.Error("unmarshal failed", slog.Any("err", err))
48 | return "", err
49 | }
50 |
51 | req, err := http.NewRequest("GET", arguments.URL, nil)
52 | if err != nil {
53 | return "", err
54 | }
55 | req = req.WithContext(ctx)
56 |
57 | client := http.Client{}
58 | r, err := client.Do(req)
59 | if err != nil {
60 | return "", err
61 | }
62 | body, _ := io.ReadAll(r.Body)
63 |
64 | document, err := goquery.NewDocumentFromReader(strings.NewReader(string(body)))
65 | if err != nil {
66 | return "", err
67 | }
68 | document.Find("script, style, pre, code").Each(func(index int, item *goquery.Selection) {
69 | item.Remove()
70 | })
71 | text := document.Text()
72 |
73 | re := regexp.MustCompile(`\s+`)
74 | text = re.ReplaceAllString(text, " ")
75 |
76 | return text, nil
77 | }
78 |
--------------------------------------------------------------------------------
/pkg/wechat/crypt.go:
--------------------------------------------------------------------------------
1 | package wechat
2 |
3 | import (
4 | "crypto/aes"
5 | "crypto/cipher"
6 | "encoding/base64"
7 | "encoding/binary"
8 | "errors"
9 | "fmt"
10 | "math/rand"
11 | "time"
12 | )
13 |
14 | const (
15 | blockSize = 32
16 | blockMask = blockSize - 1
17 |
18 | randomSize = 16
19 | contentLenSize = 4
20 | )
21 |
22 | func init() {
23 | rand.New(rand.NewSource(time.Now().UnixNano()))
24 | }
25 |
26 | func EncryptMsg(appID string, msg string, encodingAesKey string) (string, error) {
27 | aesKey, err := base64.StdEncoding.DecodeString(encodingAesKey + "=")
28 | if err != nil {
29 | return "", err
30 | }
31 |
32 | if len(aesKey) != 32 {
33 | return "", errors.New("invalid aes key length")
34 | }
35 |
36 | appIDOffset := randomSize + contentLenSize + len(msg)
37 | contentLen := appIDOffset + len(appID)
38 | amountToPad := blockSize - contentLen&blockMask
39 | plaintextLen := contentLen + amountToPad
40 | plaintext := make([]byte, plaintextLen)
41 |
42 | copy(plaintext[:randomSize], randString(randomSize))
43 | binary.BigEndian.PutUint32(plaintext[randomSize:randomSize+contentLenSize], uint32(len(msg)))
44 | copy(plaintext[randomSize+contentLenSize:], msg)
45 | copy(plaintext[appIDOffset:], appID)
46 |
47 | for i := contentLen; i < plaintextLen; i++ {
48 | plaintext[i] = byte(amountToPad)
49 | }
50 |
51 | block, err := aes.NewCipher(aesKey)
52 | if err != nil {
53 | return "", err
54 | }
55 |
56 | mode := cipher.NewCBCEncrypter(block, aesKey[:aes.BlockSize])
57 | mode.CryptBlocks(plaintext, plaintext)
58 |
59 | return base64.StdEncoding.EncodeToString(plaintext), nil
60 | }
61 |
62 | func DecryptMsg(appID, encryptedMsg, encodingAesKey string) (string, error) {
63 | aesKey, err := base64.StdEncoding.DecodeString(encodingAesKey + "=")
64 | if err != nil {
65 | return "", err
66 | }
67 |
68 | if len(aesKey) != 32 {
69 | return "", errors.New("invalid aes key length")
70 | }
71 |
72 | cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
73 | if err != nil {
74 | return "", fmt.Errorf("base64 decode error: %v", err)
75 | }
76 |
77 | block, err := aes.NewCipher(aesKey)
78 | if err != nil {
79 | return "", fmt.Errorf("failed to create aes cipher: %v", err)
80 | }
81 |
82 | plainBytes := make([]byte, len(cipherText))
83 | mode := cipher.NewCBCDecrypter(block, aesKey[:aes.BlockSize])
84 | mode.CryptBlocks(plainBytes, cipherText)
85 |
86 | pad := int(plainBytes[len(plainBytes)-1])
87 | if pad < 1 || pad > blockSize {
88 | return "", errors.New("invalid padding byte")
89 | }
90 | plainBytes = plainBytes[:len(plainBytes)-pad]
91 |
92 | content := plainBytes[randomSize:]
93 | if len(content) < contentLenSize {
94 | return "", errors.New("invalid content length")
95 | }
96 |
97 | contentLen := binary.BigEndian.Uint32(content[:contentLenSize])
98 | if len(content) < int(contentLenSize+contentLen) {
99 | return "", errors.New("invalid content length")
100 | }
101 |
102 | rawContent := content[contentLenSize : contentLenSize+contentLen]
103 | fromAppID := string(content[contentLenSize+contentLen:])
104 |
105 | if fromAppID != appID {
106 | return "", errors.New("app id mismatch")
107 | }
108 |
109 | return string(rawContent), nil
110 | }
111 |
112 | func randString(n int) string {
113 | var letterRunes = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
114 | b := make([]rune, n)
115 | for i := range b {
116 | b[i] = letterRunes[rand.Intn(len(letterRunes))]
117 | }
118 | return string(b)
119 | }
120 |
--------------------------------------------------------------------------------
/pkg/wechat/model.go:
--------------------------------------------------------------------------------
1 | package wechat
2 |
3 | import "encoding/xml"
4 |
5 | const (
6 | MsgTypeText = "text"
7 | )
8 |
9 | type CommonMessage struct {
10 | XMLName xml.Name `xml:"xml"`
11 | Encrypt string `xml:"Encrypt" json:"Encrypt"`
12 | ToUserName string `xml:"ToUserName" json:"ToUserName"`
13 | FromUserName string `xml:"FromUserName" json:"FromUserName"`
14 | CreateTime int64 `xml:"CreateTime" json:"CreateTime"`
15 | MsgType string `xml:"MsgType" json:"MsgType"`
16 | MsgID string `xml:"MsgId" json:"MsgId"`
17 | }
18 |
19 | type TextMessage struct {
20 | CommonMessage
21 | Content string `xml:"Content" json:"Content"`
22 | }
23 |
24 | type ImageMessage struct {
25 | CommonMessage
26 | PicURL string `xml:"PicUrl" json:"PicUrl"`
27 | MediaId string `xml:"MediaId" json:"MediaId"`
28 | }
29 |
30 | type EncryptMessage struct {
31 | XMLName xml.Name `xml:"xml"`
32 | Encrypt string `xml:"Encrypt" json:"Encrypt"`
33 | MsgSignature string `xml:"MsgSignature" json:"MsgSignature"`
34 | Timestamp int64 `xml:"TimeStamp" json:"TimeStamp"`
35 | Nonce string `xml:"Nonce" json:"Nonce"`
36 | }
37 |
--------------------------------------------------------------------------------
/pkg/wechat/signature.go:
--------------------------------------------------------------------------------
1 | package wechat
2 |
3 | import (
4 | "crypto/sha1"
5 | "fmt"
6 | "sort"
7 | )
8 |
9 | func Signature(params ...string) string {
10 | sort.Strings(params)
11 | h := sha1.New()
12 | for _, s := range params {
13 | h.Write([]byte(s))
14 | }
15 | return fmt.Sprintf("%x", h.Sum(nil))
16 | }
17 |
--------------------------------------------------------------------------------
/vercel.json:
--------------------------------------------------------------------------------
1 | {
2 | "routes": [
3 | {
4 | "src": "/.*",
5 | "dest": "/api/router.go"
6 | }
7 | ]
8 | }
--------------------------------------------------------------------------------
/web/chat_handler.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "encoding/json"
5 | "fmt"
6 | "github.com/tonnie17/wxagent/pkg/agent"
7 | "github.com/tonnie17/wxagent/pkg/config"
8 | "github.com/tonnie17/wxagent/pkg/llm"
9 | "github.com/tonnie17/wxagent/pkg/memory"
10 | "github.com/tonnie17/wxagent/pkg/rag"
11 | "github.com/tonnie17/wxagent/pkg/tool"
12 | "io"
13 | "log/slog"
14 | "net/http"
15 | )
16 |
17 | type ChatHandler struct {
18 | config *config.Config
19 | ragClient *rag.Client
20 | }
21 |
22 | func NewChatHandler(config *config.Config, ragClient *rag.Client) *ChatHandler {
23 | return &ChatHandler{
24 | config: config,
25 | ragClient: ragClient,
26 | }
27 | }
28 |
29 | func (h *ChatHandler) Stream(w http.ResponseWriter, r *http.Request) {
30 | w.Header().Set("Content-Type", "text/event-stream")
31 | w.Header().Set("Cache-Control", "no-cache")
32 | w.Header().Set("Connection", "keep-alive")
33 |
34 | flusher, ok := w.(http.Flusher)
35 | if !ok {
36 | http.Error(w, "streaming unsupported", http.StatusInternalServerError)
37 | return
38 | }
39 |
40 | body, err := io.ReadAll(r.Body)
41 | if err != nil {
42 | http.Error(w, err.Error(), http.StatusInternalServerError)
43 | return
44 | }
45 |
46 | notify := r.Context().Done()
47 | req := &struct {
48 | Model string `json:"model"`
49 | MaxInputTokens int `json:"max_input_tokens"`
50 | MaxTokens int `json:"max_tokens"`
51 | Temperature float32 `json:"temperature"`
52 | TopP float32 `json:"top_p"`
53 | Messages []*llm.ChatMessage `json:"messages"`
54 | }{}
55 |
56 | if err := json.Unmarshal(body, &req); err != nil {
57 | http.Error(w, err.Error(), http.StatusInternalServerError)
58 | return
59 | }
60 |
61 | var maxInputTokens int
62 | if req.MaxInputTokens == 0 {
63 | maxInputTokens = 500
64 | }
65 | mem := memory.NewTokenBase(req.Model, maxInputTokens)
66 | mem.Update(req.Messages)
67 |
68 | a := agent.NewAgent(&h.config.AgentConfig, llm.New(h.config.LLMProvider), mem, tool.GetTools(h.config.AgentTools), h.ragClient)
69 | out := make(chan *llm.ChatMessage)
70 | go func() {
71 | if err = a.ProcessStream(r.Context(), nil, out); err != nil {
72 | return
73 | }
74 | }()
75 |
76 | for {
77 | select {
78 | case <-notify:
79 | slog.Info("client disconnected")
80 | return
81 | case msg, ok := <-out:
82 | if !ok {
83 | slog.Info("client finished")
84 | return
85 | }
86 | if err != nil {
87 | fmt.Fprintf(w, "%s\n\n", err.Error())
88 | return
89 | }
90 | msgJSON, _ := json.Marshal(msg)
91 | fmt.Fprintf(w, "%s\n\n", string(msgJSON))
92 | flusher.Flush()
93 | }
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
/web/keyword.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "fmt"
5 | "strings"
6 | )
7 |
8 | const (
9 | continueKeyword string = "继续"
10 | )
11 |
12 | func detectContinue(input string) bool {
13 | return strings.EqualFold(input, continueKeyword)
14 | }
15 |
16 | func getContinueHint() string {
17 | return fmt.Sprintf("处理时间过长,请稍后回复%s获取上一轮对话结果", continueKeyword)
18 | }
19 |
20 | func getContinueEmptyHint() string {
21 | return "上一轮对话结果为空,请重试对话"
22 | }
23 |
24 | func getProcessingHint() string {
25 | return "上一轮对话正在处理中,请稍后再试"
26 | }
27 |
--------------------------------------------------------------------------------
/web/router.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "github.com/go-chi/chi/v5"
5 | "github.com/go-chi/chi/v5/middleware"
6 | slogchi "github.com/samber/slog-chi"
7 | "github.com/tonnie17/wxagent/pkg/config"
8 | "github.com/tonnie17/wxagent/pkg/rag"
9 | "log/slog"
10 | "net/http"
11 | "strings"
12 | "time"
13 | )
14 |
15 | func SetupRouter(r chi.Router, config *config.Config, logger *slog.Logger, ragClient *rag.Client) {
16 | r.Use(slogchi.New(logger))
17 | r.Use(middleware.Recoverer)
18 | r.Use(middleware.StripSlashes)
19 |
20 | memStore := NewUserMemoryStore(config.WechatMemTTL)
21 | memStore.CheckAndClear(time.Second)
22 |
23 | wechatHandler := NewWechatHandler(config, memStore, ragClient)
24 | r.Get("/", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("OK")) })
25 | r.Get("/wechat/receive", wechatHandler.Receive)
26 | r.Post("/wechat/receive", wechatHandler.Receive)
27 |
28 | chatHandler := NewChatHandler(config, ragClient)
29 | r.With(apiKeyAuth(config.ChatAPIKEY)).Post("/chat/stream", chatHandler.Stream)
30 | }
31 |
32 | func apiKeyAuth(apiKey string) func(http.Handler) http.Handler {
33 | return func(next http.Handler) http.Handler {
34 | fn := func(w http.ResponseWriter, r *http.Request) {
35 | var token string
36 | splits := strings.Split(r.Header.Get("Authorization"), "Bearer ")
37 | if len(splits) == 2 {
38 | token = splits[1]
39 | }
40 | if apiKey == "" || token != apiKey {
41 | http.Error(w, "miss api key", http.StatusUnauthorized)
42 | return
43 | }
44 | next.ServeHTTP(w, r)
45 | }
46 | return http.HandlerFunc(fn)
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/web/user_memory.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "github.com/tonnie17/wxagent/pkg/memory"
5 | "sync"
6 | "time"
7 | )
8 |
9 | type UserMemory struct {
10 | mem memory.Memory
11 | lastAccess time.Time
12 | }
13 |
14 | type UserMemoryStore struct {
15 | memTTL time.Duration
16 | once sync.Once
17 | userMemory sync.Map
18 | }
19 |
20 | func NewUserMemoryStore(memTTL time.Duration) *UserMemoryStore {
21 | return &UserMemoryStore{
22 | memTTL: memTTL,
23 | }
24 | }
25 |
26 | func (s *UserMemoryStore) GetOrNew(userID string, memFactory func() memory.Memory) memory.Memory {
27 | v, ok := s.userMemory.Load(userID)
28 | if !ok {
29 | mem := memFactory()
30 | userMem := &UserMemory{
31 | mem: mem,
32 | lastAccess: time.Now(),
33 | }
34 | s.userMemory.Store(userID, userMem)
35 | return mem
36 | }
37 | userMem := v.(*UserMemory)
38 | userMem.lastAccess = time.Now()
39 | return userMem.mem
40 | }
41 |
42 | func (s *UserMemoryStore) CheckAndClear(interval time.Duration) {
43 | worker := func() {
44 | ticker := time.NewTicker(interval)
45 | defer ticker.Stop()
46 |
47 | for range ticker.C {
48 | s.userMemory.Range(func(key, value any) bool {
49 | userMem := value.(*UserMemory)
50 | if time.Now().Sub(userMem.lastAccess) > s.memTTL {
51 | s.userMemory.Delete(key)
52 | }
53 | return true
54 | })
55 | }
56 | }
57 | s.once.Do(func() {
58 | go worker()
59 | })
60 | }
61 |
--------------------------------------------------------------------------------
/web/util.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "encoding/xml"
5 | "io"
6 | "log/slog"
7 | )
8 |
9 | func xmlParseRequest(body io.Reader, req interface{}) error {
10 | b, err := io.ReadAll(body)
11 | if err != nil {
12 | slog.Error("body read failed", slog.Any("err", err))
13 | return err
14 | }
15 |
16 | if err := xml.Unmarshal(b, &req); err != nil {
17 | slog.Error("request parse failed", slog.Any("err", err))
18 | return err
19 | }
20 |
21 | return nil
22 | }
23 |
--------------------------------------------------------------------------------
/web/wechat_handler.go:
--------------------------------------------------------------------------------
1 | package web
2 |
3 | import (
4 | "context"
5 | "encoding/xml"
6 | "errors"
7 | "github.com/tonnie17/wxagent/pkg/agent"
8 | "github.com/tonnie17/wxagent/pkg/config"
9 | "github.com/tonnie17/wxagent/pkg/llm"
10 | "github.com/tonnie17/wxagent/pkg/memory"
11 | "github.com/tonnie17/wxagent/pkg/rag"
12 | "github.com/tonnie17/wxagent/pkg/tool"
13 | "github.com/tonnie17/wxagent/pkg/wechat"
14 | "log/slog"
15 | "net/http"
16 | "strconv"
17 | "strings"
18 | "time"
19 | )
20 |
21 | type WechatHandler struct {
22 | config *config.Config
23 | memStore *UserMemoryStore
24 | ragClient *rag.Client
25 | }
26 |
27 | func NewWechatHandler(config *config.Config, memStore *UserMemoryStore, ragClient *rag.Client) *WechatHandler {
28 | return &WechatHandler{
29 | config: config,
30 | memStore: memStore,
31 | ragClient: ragClient,
32 | }
33 | }
34 |
35 | func (h *WechatHandler) Receive(w http.ResponseWriter, r *http.Request) {
36 | signature := r.URL.Query().Get("signature")
37 | msgSignature := r.URL.Query().Get("msg_signature")
38 | timestamp := r.URL.Query().Get("timestamp")
39 | nonce := r.URL.Query().Get("nonce")
40 | echoStr := r.URL.Query().Get("echostr")
41 |
42 | if signature != wechat.Signature(h.config.WechatToken, timestamp, nonce) {
43 | slog.Error("signature check failed",
44 | slog.String("signature", signature),
45 | slog.String("timestamp", timestamp),
46 | slog.String("nonce", nonce),
47 | slog.String("echostr", echoStr),
48 | )
49 | http.Error(w, "signature check failed", http.StatusUnauthorized)
50 | return
51 | }
52 |
53 | if echoStr != "" {
54 | w.Write([]byte(echoStr))
55 | return
56 | }
57 |
58 | var reqMessage wechat.TextMessage
59 | if err := xmlParseRequest(r.Body, &reqMessage); err != nil {
60 | http.Error(w, err.Error(), http.StatusInternalServerError)
61 | return
62 | }
63 |
64 | if reqMessage.Encrypt != "" {
65 | if msgSignature != wechat.Signature(h.config.WechatToken, timestamp, nonce, reqMessage.Encrypt) {
66 | slog.Error("msg signature check failed",
67 | slog.String("signature", signature),
68 | slog.String("timestamp", timestamp),
69 | slog.String("nonce", nonce),
70 | slog.String("echostr", echoStr),
71 | )
72 | http.Error(w, "signature check failed", http.StatusUnauthorized)
73 | return
74 | }
75 |
76 | content, err := wechat.DecryptMsg(h.config.WechatAppID, reqMessage.Encrypt, h.config.WechatEncodingAESKey)
77 | if err != nil {
78 | http.Error(w, err.Error(), http.StatusInternalServerError)
79 | return
80 | }
81 |
82 | if err := xmlParseRequest(strings.NewReader(content), &reqMessage); err != nil {
83 | http.Error(w, err.Error(), http.StatusInternalServerError)
84 | return
85 | }
86 | }
87 |
88 | slog.Info("receive req", slog.Any("req", reqMessage))
89 |
90 | if len(h.config.WechatAllowList) > 0 {
91 | var isAllow bool
92 | for _, userID := range h.config.WechatAllowList {
93 | if userID == reqMessage.FromUserName {
94 | isAllow = true
95 | break
96 | }
97 | }
98 | if !isAllow {
99 | http.Error(w, "access denied", http.StatusUnauthorized)
100 | return
101 | }
102 | }
103 |
104 | mem := h.memStore.GetOrNew(reqMessage.FromUserName, func() memory.Memory {
105 | return memory.NewBuffer(h.config.WechatMemMsgSize)
106 | })
107 |
108 | a := agent.NewAgent(&h.config.AgentConfig, llm.New(h.config.LLMProvider), mem, tool.GetTools(h.config.AgentTools), h.ragClient)
109 |
110 | input := strings.TrimSpace(reqMessage.Content)
111 | result := make(chan string)
112 | go func() {
113 | switch reqMessage.MsgType {
114 | case wechat.MsgTypeText:
115 | var (
116 | output string
117 | err error
118 | )
119 | if detectContinue(input) {
120 | if output, err = a.ChatContinue(context.Background()); output == "" {
121 | output = getContinueEmptyHint()
122 | }
123 | } else {
124 | output, err = a.Chat(context.Background(), input)
125 | }
126 | if err != nil {
127 | if errors.Is(err, agent.ErrMemoryInUse) {
128 | result <- getProcessingHint()
129 | } else {
130 | result <- err.Error()
131 | }
132 | } else {
133 | result <- output
134 | }
135 | }
136 | close(result)
137 | }()
138 |
139 | var content string
140 | ticker := time.NewTicker(h.config.WechatTimeout)
141 | select {
142 | case <-ticker.C:
143 | content = getContinueHint()
144 | case content = <-result:
145 | }
146 |
147 | now := time.Now().Unix()
148 | var respMessage wechat.TextMessage
149 | respMessage.FromUserName = reqMessage.ToUserName
150 | respMessage.ToUserName = reqMessage.FromUserName
151 | respMessage.MsgType = reqMessage.MsgType
152 | respMessage.Content = content
153 | respMessage.CreateTime = now
154 |
155 | resp, err := xml.Marshal(respMessage)
156 | if err != nil {
157 | http.Error(w, err.Error(), http.StatusInternalServerError)
158 | return
159 | }
160 |
161 | if reqMessage.Encrypt != "" {
162 | encrypt, err := wechat.EncryptMsg(h.config.WechatAppID, string(resp), h.config.WechatEncodingAESKey)
163 | if err != nil {
164 | http.Error(w, err.Error(), http.StatusInternalServerError)
165 | return
166 | }
167 | encryptMessage := &wechat.EncryptMessage{
168 | Encrypt: encrypt,
169 | MsgSignature: wechat.Signature(h.config.WechatToken, nonce, encrypt, strconv.FormatInt(now, 10)),
170 | Timestamp: now,
171 | Nonce: nonce,
172 | }
173 | resp, err = xml.Marshal(encryptMessage)
174 | if err != nil {
175 | http.Error(w, err.Error(), http.StatusInternalServerError)
176 | return
177 | }
178 | }
179 |
180 | w.Header().Set("Content-Type", "application/xml")
181 | w.WriteHeader(http.StatusOK)
182 | w.Write(resp)
183 | }
184 |
--------------------------------------------------------------------------------