├── .gitignore
├── FUSION_STRATEGIES.md
├── LICENSE
├── README.md
├── TODO.md
├── _conf_schema.json
├── core
├── __init__.py
├── commands
│ ├── __init__.py
│ └── base_command.py
├── community
│ ├── __init__.py
│ └── community_detector.py
├── config_validator.py
├── constants.py
├── engines
│ ├── forgetting_agent.py
│ ├── recall_engine.py
│ └── reflection_engine.py
├── handlers
│ ├── __init__.py
│ ├── admin_handler.py
│ ├── base_handler.py
│ ├── fusion_handler.py
│ ├── memory_handler.py
│ └── search_handler.py
├── models.py
├── models
│ └── memory_models.py
├── retrieval
│ ├── __init__.py
│ ├── result_fusion.py
│ └── sparse_retriever.py
└── utils.py
├── docs
├── CONFIG.md
└── DEVELOPMENT.md
├── main.py
├── metadata.yaml
├── requirements.txt
├── storage
├── __init__.py
├── faiss_manager.py
├── faiss_manager_v2.py
├── graph_storage.py
├── memory_storage.py
└── vector_store.py
└── tests
├── conftest.py
└── unit
├── test_admin_handler.py
├── test_base_handler.py
├── test_memory_handler.py
└── test_search_handler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.pyc
4 | *.pyo
5 | *.pyd
6 |
7 | # Database files
8 | *.db
9 | *.index
10 |
11 | # IDE and OS files
12 | .vscode/
13 | .idea/
14 | .DS_Store
15 | work_list.md
16 | CLAUDE.md
17 | .claude/settings.local.json
18 |
19 | __pycache__/
20 | *.py[cod]
21 | *$py.class
22 | *.so
23 | .Python
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # Virtual environments
42 | .env
43 | .venv
44 | env/
45 | venv/
46 | ENV/
47 | env.bak/
48 | venv.bak/
49 |
50 | # IDE
51 | .idea/
52 | .vscode/
53 | *.swp
54 | *.swo
55 |
56 | # OS
57 | .DS_Store
58 | .DS_Store?
59 | ._*
60 | .Spotlight-V100
61 | .Trashes
62 | ehthumbs.db
63 | Thumbs.db
64 |
65 | # Testing
66 | .pytest_cache/
67 | .coverage
68 | htmlcov/
69 | .tox/
70 |
71 | # Logs
72 | *.log
73 | logs/
74 |
75 | # Temporary files
76 | tmp/
77 | temp/
78 | *.tmp
79 |
80 | # UV (if used)
81 | uv.lock
82 |
83 | # Node.js (if used)
84 | node_modules/
85 | npm-debug.log*
86 | yarn-debug.log*
87 | yarn-error.log*
88 |
89 | data/
90 | .claude/
91 | pyproject.toml
92 |
--------------------------------------------------------------------------------
/FUSION_STRATEGIES.md:
--------------------------------------------------------------------------------
1 | # 混合检索融合策略详解
2 |
3 | LivingMemory 插件支持多种先进的结果融合策略,用于优化混合检索效果。
4 |
5 | ## 🎯 支持的融合策略
6 |
7 | ### 1. RRF (Reciprocal Rank Fusion) - 经典策略
8 | ```
9 | 分数 = Σ(1 / (k + rank_i))
10 | ```
11 | - **适用场景**: 通用场景,平衡性好
12 | - **参数**: `rrf_k` (默认: 60)
13 | - **特点**: 简单有效,对排序位置敏感
14 |
15 | ### 2. Hybrid RRF - 动态RRF
16 | - **适用场景**: 需要根据查询特征自适应调整
17 | - **参数**: `rrf_k`, `diversity_bonus`
18 | - **特点**: 根据查询长度和类型动态调整RRF参数
19 | - **优势**: 短查询偏向稀疏检索,长查询偏向密集检索
20 |
21 | ### 3. Weighted - 加权融合
22 | ```
23 | 分数 = α × dense_score + β × sparse_score
24 | ```
25 | - **适用场景**: 明确知道两种检索器的相对重要性
26 | - **参数**: `dense_weight`, `sparse_weight`
27 | - **特点**: 简单直观,可解释性强
28 |
29 | ### 4. Convex - 凸组合融合
30 | ```
31 | 分数 = λ × norm(dense) + (1-λ) × norm(sparse)
32 | ```
33 | - **适用场景**: 需要数学严格的融合方法
34 | - **参数**: `convex_lambda` (0.0-1.0)
35 | - **特点**: 分数归一化到 [0,1],数学性质好
36 |
37 | ### 5. Interleave - 交替融合
38 | - **适用场景**: 需要保证结果多样性
39 | - **参数**: `interleave_ratio` - 密集结果所占比例
40 | - **特点**: 按比例交替选择不同检索器的结果
41 |
42 | ### 6. Rank Fusion - 基于排序的融合
43 | ```
44 | 分数 = Σ(weight_i / rank_i) + bias(if in both lists)
45 | ```
46 | - **适用场景**: 重视文档在排序列表中的位置
47 | - **参数**: `dense_weight`, `sparse_weight`, `rank_bias_factor`
48 | - **特点**: 在两个列表中都出现的文档获得额外加分
49 |
50 | ### 7. Score Fusion - Borda Count融合
51 | ```
52 | 分数 = Σ(list_size - rank_i) × weight_i
53 | ```
54 | - **适用场景**: 基于排序投票的民主融合
55 | - **参数**: `dense_weight`, `sparse_weight`
56 | - **特点**: 类似选举中的Borda计数法
57 |
58 | ### 8. Cascade - 级联融合
59 | - **适用场景**: 大规模检索,需要效率优化
60 | - **流程**: 稀疏检索初筛 → 密集检索精排
61 | - **特点**: 先用快速的稀疏检索筛选候选,再用精确的密集检索排序
62 |
63 | ### 9. Adaptive - 自适应融合
64 | - **适用场景**: 查询类型多样的场景
65 | - **策略**: 根据查询特征选择最优融合方法
66 | - 关键词查询 → 偏向稀疏检索
67 | - 语义查询 → 偏向密集检索
68 | - 混合查询 → 使用RRF
69 |
70 | ## 📊 性能特征对比
71 |
72 | | 策略 | 计算复杂度 | 参数调优难度 | 适应性 | 可解释性 |
73 | |------|-----------|-------------|--------|----------|
74 | | RRF | 低 | 低 | 中 | 中 |
75 | | Hybrid RRF | 中 | 中 | 高 | 中 |
76 | | Weighted | 低 | 低 | 低 | 高 |
77 | | Convex | 低 | 中 | 中 | 高 |
78 | | Interleave | 低 | 低 | 低 | 高 |
79 | | Rank Fusion | 中 | 中 | 中 | 中 |
80 | | Score Fusion | 高 | 中 | 中 | 中 |
81 | | Cascade | 低 | 低 | 低 | 高 |
82 | | Adaptive | 中 | 高 | 高 | 低 |
83 |
84 | ## 🛠️ 使用指南
85 |
86 | ### 配置示例
87 |
88 | ```yaml
89 | fusion:
90 | strategy: "hybrid_rrf"
91 | rrf_k: 60
92 | dense_weight: 0.7
93 | sparse_weight: 0.3
94 | diversity_bonus: 0.1
95 | convex_lambda: 0.5
96 | interleave_ratio: 0.6
97 | rank_bias_factor: 0.15
98 | ```
99 |
100 | ### 命令行管理
101 |
102 | ```bash
103 | # 查看当前配置
104 | /lmem fusion show
105 |
106 | # 切换到混合RRF
107 | /lmem fusion hybrid_rrf
108 |
109 | # 调整凸组合参数
110 | /lmem fusion convex convex_lambda=0.6
111 |
112 | # 调整权重
113 | /lmem fusion weighted dense_weight=0.8
114 |
115 | # 测试融合效果
116 | /lmem test_fusion "用户的兴趣爱好" 5
117 | ```
118 |
119 | ## 🎯 策略选择建议
120 |
121 | ### 场景驱动的选择
122 |
123 | 1. **通用聊天机器人**: `hybrid_rrf` 或 `rrf`
124 | - 查询类型多样,需要自适应能力
125 |
126 | 2. **专业知识问答**: `weighted` 或 `convex`
127 | - 可以明确调优权重,提高精确度
128 |
129 | 3. **多样性优先**: `interleave`
130 | - 确保结果不会过于相似
131 |
132 | 4. **大规模数据库**: `cascade`
133 | - 效率优先,两阶段处理
134 |
135 | 5. **实验和调优**: `score_fusion` 或 `rank_fusion`
136 | - 更复杂的融合逻辑,适合深度优化
137 |
138 | ### 参数调优指南
139 |
140 | 1. **RRF参数 k**:
141 | - 较小值 (30-50): 更重视排序靠前的结果
142 | - 较大值 (80-120): 更平衡地考虑所有结果
143 |
144 | 2. **权重比例**:
145 | - 密集权重 > 稀疏权重: 语义查询为主
146 | - 密集权重 < 稀疏权重: 关键词查询为主
147 |
148 | 3. **多样性参数**:
149 | - 较大值: 鼓励结果多样性
150 | - 较小值: 优先考虑相关性
151 |
152 | ## 🧪 实验和评估
153 |
154 | 使用 `/lmem test_fusion` 命令可以:
155 | - 测试不同策略的效果
156 | - 查看融合过程的详细信息
157 | - 对比不同参数设置的结果
158 |
159 | 建议在实际数据上进行A/B测试,选择最适合你使用场景的融合策略。
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LivingMemory - 动态生命周期记忆插件 v1.0.0
2 |
3 |
4 |
5 |
6 |

7 |
8 |
9 | 🧪 测...测不动了
10 |
11 | > ⚠️ **需要测试验证**: 请在测试环境中验证所有新功能
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | 为 AstrBot 打造的、拥有完整记忆生命周期的智能长期记忆插件。
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 | ---
35 |
36 | `LivingMemory` 告别了传统记忆插件对大型数据库的依赖,创新性地采用轻量级的 `Faiss` 和 `SQLite` 作为存储后端。这不仅实现了 **零配置部署** 和 **极低的资源消耗**,更引入了革命性的 **动态记忆生命周期 (Dynamic Memory Lifecycle)** 模型。
37 |
38 | ## ✨ 核心特性:三大引擎架构
39 |
40 | 本插件通过三大智能引擎的协同工作,完美模拟了人类记忆的形成、巩固、联想和遗忘的全过程。
41 |
42 | | 引擎 | 图标 | 核心功能 | 最新增强 |
43 | | :--- | :---: | :--- | :--- |
44 | | **反思引擎** | 🧠 | `智能总结` & `重要性评估` | ✨ 增强错误重试机制,提高提取成功率 |
45 | | **回忆引擎** | 🔍 | `混合检索` & `智能融合` | 🚀 **9种先进融合策略** - RRF、自适应、级联等 |
46 | | **遗忘代理** | 🗑️ | `遗忘曲线` & `批量清理` | 💾 **分页处理** - 优化内存使用,支持大规模数据 |
47 |
48 | ## 🔥 新特性:混合检索系统
49 |
50 | ### 检索模式
51 | - **🔍 Dense (密集检索)**: 基于语义向量的深度理解
52 | - **⚡ Sparse (稀疏检索)**: BM25关键词匹配,支持中文分词
53 | - **🤝 Hybrid (混合检索)**: 智能融合两种检索方式的优势
54 |
55 | ### 9种融合策略
56 |
57 | | 策略 | 特点 | 适用场景 | 计算复杂度 |
58 | |------|------|----------|------------|
59 | | **RRF** | 经典倒数排名融合 | 通用场景,平衡性好 | 低 |
60 | | **Hybrid RRF** | 动态参数调整 | 自适应查询类型 | 中 |
61 | | **Weighted** | 简单加权融合 | 明确权重偏好 | 低 |
62 | | **Convex** | 凸组合数学融合 | 需要严格数学性质 | 低 |
63 | | **Interleave** | 交替选择结果 | 保证结果多样性 | 低 |
64 | | **Rank Fusion** | 基于排序位置 | 重视排序信息 | 中 |
65 | | **Score Fusion** | Borda Count投票 | 民主投票机制 | 高 |
66 | | **Cascade** | 两阶段处理 | 大规模高效检索 | 低 |
67 | | **Adaptive** | 查询特征自适应 | 多样化查询场景 | 中 |
68 |
69 | ### 智能查询分析
70 | 系统会自动分析查询特征,选择最优融合策略:
71 | - **关键词查询** → 偏向稀疏检索
72 | - **语义查询** → 偏向密集检索
73 | - **混合查询** → 使用RRF平衡融合
74 |
75 | ## 🚀 快速开始
76 |
77 | ### 1. 安装
78 |
79 | 将 `astrbot_plugin_livingmemory` 文件夹放置于 AstrBot 的 `data/plugins` 目录下。AstrBot 将自动检测并安装依赖。
80 |
81 | **核心依赖:**
82 | ```
83 | faiss-cpu>=1.7.0
84 | pydantic>=1.8.0
85 | jieba>=0.42.1
86 | ```
87 |
88 | **🧪 测试阶段**: 插件已完成重构和功能增强,正在进行全面测试验证。
89 |
90 | ### 2. 配置
91 |
92 | **✨ 全新配置系统**: 基于 Pydantic 的智能配置验证,确保参数有效性。
93 |
94 |
95 | ⚙️ 点击展开详细配置说明
96 |
97 | #### 🔧 基础设置
98 | - **Provider设置**: 自定义 Embedding 和 LLM Provider,支持多Provider混用
99 | - **时区配置**: 支持全球时区,时间显示本地化
100 | - **会话管理**: 智能会话生命周期,自动清理过期会话
101 |
102 | #### 🔍 检索配置
103 | - **检索模式**: `hybrid`(混合) | `dense`(密集) | `sparse`(稀疏)
104 | - **融合策略**: 9种策略可选,支持动态参数调整
105 | - **BM25参数**: 可调整k1、b参数优化中文检索效果
106 | - **权重控制**: 相似度、重要性、新近度权重精细调节
107 |
108 | #### 🧠 智能引擎配置
109 | - **反思触发**: 可配置对话轮次阈值(1-100轮)
110 | - **重要性评估**: 自定义重要性阈值(0.0-1.0)
111 | - **自定义提示词**: 完全可定制的事件提取和评估提示
112 |
113 | #### 🗑️ 遗忘机制
114 | - **智能清理**: 基于重要性衰减的自动清理
115 | - **批量处理**: 分页加载,支持大规模记忆库
116 | - **保留策略**: 灵活的天数和阈值配置
117 |
118 | #### 🛡️ 过滤隔离
119 | - **人格过滤**: 按AI人格隔离记忆,互不干扰
120 | - **会话隔离**: 会话级别的记忆独立性
121 | - **状态管理**: 记忆状态(活跃/归档/删除)精细控制
122 |
123 |
124 |
125 | ### 3. 高级配置示例
126 |
127 | ```yaml
128 | # 针对中文优化的推荐配置
129 | sparse_retriever:
130 | bm25_k1: 1.2 # 中文词频参数
131 | bm25_b: 0.75 # 文档长度归一化
132 | use_jieba: true # 启用中文分词
133 |
134 | fusion:
135 | strategy: "hybrid_rrf" # 自适应融合策略
136 | rrf_k: 60 # RRF参数
137 | diversity_bonus: 0.1 # 多样性奖励
138 |
139 | recall_engine:
140 | retrieval_mode: "hybrid" # 混合检索模式
141 | similarity_weight: 0.6 # 相似度权重
142 | importance_weight: 0.2 # 重要性权重
143 | recency_weight: 0.2 # 新近度权重
144 | ```
145 |
146 | ## 🛠️ 完整管理命令
147 |
148 | 插件在后台自动运行,同时提供了强大的命令行管理界面:
149 |
150 | ### 📊 基础管理
151 | | 命令 | 参数 | 描述 |
152 | | :--- | :--- | :--- |
153 | | `/lmem status` | - | 📈 查看记忆库状态和统计信息 |
154 | | `/lmem search` | ` [k=3]` | 🔍 手动搜索记忆,支持详细信息展示 |
155 | | `/lmem forget` | `` | 🗑️ 删除指定ID的记忆 |
156 |
157 | ### 🧠 引擎管理
158 | | 命令 | 参数 | 描述 |
159 | | :--- | :--- | :--- |
160 | | `/lmem run_forgetting_agent` | - | 🔄 手动触发遗忘代理清理任务 |
161 | | `/lmem sparse_rebuild` | - | 🏗️ 重建稀疏检索索引 |
162 | | `/lmem sparse_test` | ` [k=5]` | ⚡ 测试稀疏检索功能 |
163 |
164 | ### ⚙️ 配置管理
165 | | 命令 | 参数 | 描述 |
166 | | :--- | :--- | :--- |
167 | | `/lmem config` | `[show\|validate]` | 📋 显示或验证当前配置 |
168 | | `/lmem search_mode` | `` | 🔄 切换检索模式 |
169 |
170 | ### 🔄 融合策略管理
171 | | 命令 | 参数 | 描述 |
172 | | :--- | :--- | :--- |
173 | | `/lmem fusion` | `[strategy] [param=value]` | 🎯 管理融合策略和参数 |
174 | | `/lmem test_fusion` | ` [k=5]` | 🧪 测试当前融合策略效果 |
175 |
176 | ### 📝 记忆编辑 *(新增)*
177 | | 命令 | 参数 | 描述 |
178 | | :--- | :--- | :--- |
179 | | `/lmem edit` | ` [reason]` | ✏️ 精确编辑记忆内容或元数据 |
180 | | `/lmem update` | `` | 📝 交互式记忆编辑引导 |
181 | | `/lmem history` | `` | 📚 查看记忆的完整更新历史 |
182 |
183 | #### 编辑示例
184 | ```bash
185 | # 编辑记忆内容
186 | /lmem edit 123 content 这是新的记忆内容 修正错误信息
187 |
188 | # 调整重要性评分
189 | /lmem edit 123 importance 0.9 提高重要性
190 |
191 | # 更改记忆类型
192 | /lmem edit 123 type PREFERENCE 重新分类为偏好
193 |
194 | # 归档记忆
195 | /lmem edit 123 status archived 项目已完成
196 | ```
197 |
198 | #### 融合策略示例
199 | ```bash
200 | # 查看当前融合配置
201 | /lmem fusion show
202 |
203 | # 切换到自适应RRF
204 | /lmem fusion hybrid_rrf
205 |
206 | # 调整凸组合参数
207 | /lmem fusion convex convex_lambda=0.6
208 |
209 | # 调整加权融合权重
210 | /lmem fusion weighted dense_weight=0.8
211 |
212 | # 测试融合效果
213 | /lmem test_fusion "用户的兴趣爱好" 5
214 | ```
215 |
216 | ## 🎯 性能优化与最佳实践
217 |
218 | ### 💾 内存管理
219 | - **分页加载**: 支持大规模记忆库,避免OOM
220 | - **会话清理**: 智能TTL机制,自动清理过期会话
221 | - **事务安全**: SQLite事务保证数据一致性
222 |
223 | ### 🚀 并发处理
224 | - **异步初始化**: 避免阻塞主线程
225 | - **后台任务**: 反思和遗忘任务后台执行
226 | - **错误重试**: 自动重试机制,提高稳定性
227 |
228 | ### 🔧 故障排除
229 |
230 |
231 | 🚨 常见问题解决方案
232 |
233 | #### Q: 插件初始化失败
234 | ```bash
235 | # 检查依赖安装
236 | pip install faiss-cpu pydantic jieba
237 |
238 | # 验证配置
239 | /lmem config validate
240 | ```
241 |
242 | #### Q: 检索效果不佳
243 | ```bash
244 | # 尝试不同融合策略
245 | /lmem fusion adaptive
246 |
247 | # 重建稀疏索引
248 | /lmem sparse_rebuild
249 |
250 | # 调整检索模式
251 | /lmem search_mode hybrid
252 | ```
253 |
254 | #### Q: 内存占用过高
255 | ```bash
256 | # 手动触发遗忘
257 | /lmem run_forgetting_agent
258 |
259 | # 检查会话数量
260 | /lmem config show
261 | ```
262 |
263 |
264 |
265 | ## 📚 相关文档
266 |
267 | - 📖 [融合策略详解](FUSION_STRATEGIES.md) - 深入了解9种融合算法
268 | - ⚙️ [配置参考](docs/CONFIG.md) - 完整配置参数说明
269 | - 🔧 [开发指南](docs/DEVELOPMENT.md) - 插件开发和扩展指南
270 |
271 | ## 🤝 贡献
272 |
273 | 欢迎各种形式的贡献:
274 | - 🐛 **问题报告**: [GitHub Issues](https://github.com/lxfight/astrbot_plugin_livingmemory/issues)
275 | - 💡 **功能建议**: [Feature Requests](https://github.com/lxfight/astrbot_plugin_livingmemory/issues/new?template=feature_request.md)
276 | - 🔧 **代码贡献**: [Pull Requests](https://github.com/lxfight/astrbot_plugin_livingmemory/pulls)
277 | - 📖 **文档改进**: 欢迎改进文档和示例
278 |
279 | ## 交流一下
280 | 遇到问题或想交流使用心得?加入我们的讨论群:
281 | [](https://qm.qq.com/cgi-bin/qm/qr?k=WdyqoP-AOEXqGAN08lOFfVSguF2EmBeO&jump_from=webapi&authKey=tPyfv90TVYSGVhbAhsAZCcSBotJuTTLf03wnn7/lQZPUkWfoQ/J8e9nkAipkOzwh)
282 |
283 | `入关口令`: `lxfight`
284 |
285 | ## 📄 许可证
286 |
287 | 本项目遵循 **AGPLv3** 许可证 - 详见 [LICENSE](LICENSE) 文件。
288 |
289 | ---
290 |
291 |
292 |
293 |
294 | **⭐ 如果这个项目对您有帮助,请给我们一个 Star!**
295 |
296 |
297 |
298 | *LivingMemory - 让AI拥有真正的生命记忆 🧠✨*
299 |
300 |
301 |
--------------------------------------------------------------------------------
/TODO.md:
--------------------------------------------------------------------------------
1 | # 关于 livingmemory 插件的 TODO 列表
2 |
3 | - [ ] 设计记忆的数据结构
4 | - [ ] 设计记忆的召回方式
5 | - [ ] 设计记忆的更新方式
6 | - [ ] 记忆的遗忘逻辑
7 |
8 | # 一些潜在的规划
9 |
10 | {
11 | "memory_id": "String", // [主键] 记忆的全局唯一标识符 (建议使用UUID)
12 | "timestamp": "String", // 事件发生的时间戳 (ISO 8601格式, e.g., "2025-07-26T14:30:00Z")
13 | "summary": "String", // AI生成的、可供快速预览的单行摘要或标题
14 | "description": "String", // 对事件的完整、详细的自然语言描述
15 | "embedding": "Array", // 由`description`或`summary`+`description`生成的文本嵌入向量,用于Faiss
16 |
17 | "linked_media": [
18 | // 多模态内容部分
19 | {
20 | "media_id": "String", // 媒体文件的唯一ID
21 | "media_type": "String", // 'image', 'audio', 'document', 'video', 'code_snippet'
22 | "url": "String", // 文件的存储位置 (e.g., S3 URL or local path)
23 | "caption": "String", // 对媒体内容的简短描述
24 | "embedding": "Array" // 媒体文件本身的多模态嵌入 (e.g., CLIP vector for images)
25 | }
26 | ],
27 | "metadata": {
28 | "source_conversation_id": "String", // 此记忆来源的对话ID,用于溯源
29 | "memory_type": "String", // 记忆类型: 'episodic' (情景), 'semantic' (事实), 'procedural' (程序)
30 | "importance_score": "Float", // [遗忘引擎关键输入] 记忆的重要性评分 (0.0 to 1.0)
31 | "confidence_score": "Float", // [可选] NLU模型提取此记忆信息的置信度 (0.0 to 1.0)
32 | "access_info": {
33 | "initial_creation_timestamp": "String", // 记忆被创建的时间
34 | "last_accessed_timestamp": "String", // [遗忘引擎关键输入] 记忆最近被访问的时间
35 | "access_count": "Integer" // [遗忘引擎关键输入] 记忆被访问的总次数
36 | },
37 | "emotional_valence": {
38 | "sentiment": "String", // 情感倾向: 'positive', 'negative', 'neutral'
39 | "intensity": "Float" // 情感强度 (0.0 to 1.0)
40 | },
41 | "user_feedback": {
42 | "is_accurate": "Boolean", // 用户是否标记此记忆为准确 (null, true, false)
43 | "is_important": "Boolean", // 用户是否标记此记忆为重要 (null, true, false)
44 | "correction_text": "String" // 用户提供的修正文本
45 | },
46 | "community_info": {
47 | // 用于图社区发现
48 | "id": "String", // 此记忆所属社区/簇的唯一ID
49 | "last_calculated": "String" // 上次计算社区分配的时间戳
50 | }
51 | },
52 |
53 | "knowledge_graph_payload": {
54 | "event_entity": {
55 | "event_id": "String", // [图节点] 事件的唯一ID
56 | "event_type": "String" // 事件的分类, e.g., 'ProjectInitiation', 'TravelPlanning'
57 | },
58 | "entities": [
59 | {
60 | "entity_id": "String", // [图节点] 实体的全局唯一ID
61 | "name": "String", // 实体的名称
62 | "type": "String", // 实体类型, e.g., 'PERSON', 'PROJECT', 'ORGANIZATION'
63 | "role": "String" // [可选] 实体在此次事件中的具体角色
64 | }
65 | ],
66 | "relationships": [
67 | "Array" // [图的边] 定义关系的三元组 [主体ID, 关系谓词, 客体ID]
68 | ]
69 | }
70 | }
71 |
72 |
73 | 示例
74 |
75 | ```json
76 | {
77 | "memory_id": "mem_e2a1b3c4-d5e6-f7g8-h9i0-j1k2l3m4n5o6",
78 | "timestamp": "2025-07-28T11:00:00Z",
79 | "summary": "张伟分享了'凤凰计划'的登录页设计图",
80 | "description": "用户张伟分享了一张关于'凤凰计划'的登录页面初稿图片,并征求反馈。他还提及这张设计图关联到他们上一次会议中确定的蓝色主题。",
81 | "embedding": [-0.15, 0.33, 0.81, ..., -0.05],
82 |
83 | "linked_media": [
84 | {
85 | "media_id": "img_login_mockup_v1",
86 | "media_type": "image",
87 | "url": "s3://my-ai-project-bucket/media/img_login_mockup_v1.png",
88 | "caption": "'凤凰计划'登录页面的UI设计初稿",
89 | "embedding": [0.67, 0.12, -0.29, ..., 0.44]
90 | }
91 | ],
92 |
93 | "metadata": {
94 | "source_conversation_id": "conv_z1y2x3w-v4u5-t6s7-r8q9-p0o9n8m7l6k5",
95 | "memory_type": "episodic",
96 | "importance_score": 0.85,
97 | "confidence_score": 0.99,
98 | "access_info": {
99 | "initial_creation_timestamp": "2025-07-28T11:01:30Z",
100 | "last_accessed_timestamp": "2025-07-28T11:01:30Z",
101 | "access_count": 1
102 | },
103 | "emotional_valence": {
104 | "sentiment": "neutral",
105 | "intensity": 0.3
106 | },
107 | "user_feedback": {
108 | "is_accurate": null,
109 | "is_important": null,
110 | "correction_text": null
111 | },
112 | "community_info": {
113 | "id": "community_proj_phoenix_001",
114 | "last_calculated": "2025-07-27T04:00:00Z"
115 | }
116 | },
117 |
118 | "knowledge_graph_payload": {
119 | "event_entity": {
120 | "event_id": "evt_design_review_001",
121 | "event_type": "DesignReview"
122 | },
123 | "entities": [
124 | { "entity_id": "person_zhang_wei_001", "name": "张伟", "type": "PERSON", "role": "author" },
125 | { "entity_id": "project_phoenix_001", "name": "凤凰计划", "type": "PROJECT", "role": "context" },
126 | { "entity_id": "asset_login_mockup_001", "name": "登录页面初稿", "type": "DESIGN_ASSET", "role": "subject" }
127 | ],
128 | "relationships": [
129 | ["person_zhang_wei_001", "CREATED", "asset_login_mockup_001"],
130 | ["asset_login_mockup_001", "IS_PART_OF", "project_phoenix_001"],
131 | ["evt_design_review_001", "IS_ABOUT", "asset_login_mockup_001"],
132 | ["evt_design_review_001", "REFERENCES", "evt_meeting_042"]
133 | ]
134 | }
135 | }
136 | ```
137 |
138 | ### 核心 AI 生成内容
139 |
140 | 这部分是 AI 的核心创造性工作,直接体现了模型的智能。
141 |
142 | | 字段路径 | 字段名 | AI 负责的工作 | 需要的模型/技术 |
143 | | :-------------------------- | :--------------- | :---------------------------------------------------------------------- | :----------------------------------- |
144 | | `summary` | **摘要** | 阅读`description`全文,生成一段简短、精炼的标题或摘要。 | 大型语言模型 (LLM) - 文本摘要任务 |
145 | | `embedding` | **文本嵌入向量** | 将`description`的语义信息编码成一个高维浮点数向量。 | 文本嵌入模型 (Text Embedding Model) |
146 | | `linked_media[].embedding` | **媒体嵌入向量** | 将图片、音频等媒体文件编码成一个与文本在同一空间的高维向量。 | 多模态嵌入模型 (e.g., CLIP) |
147 | | `linked_media[].caption` | **媒体内容描述** | (可选) 如果用户没有提供,AI 可以“看图说话”,为上传的图片生成描述。 | 视觉语言模型 (Vision-Language Model) |
148 | | `metadata.importance_score` | **重要性评分** | 根据`description`的内容,判断该记忆对用户的重要性,并给出一个量化分数。 | 大型语言模型 (LLM) - 分类/回归任务 |
149 | | `metadata.confidence_score` | **置信度评分** | NLU 模型在完成实体和关系提取后,对其结果的确定程度给出的一个分数。 | 自然语言理解模型 (NLU Model) |
150 |
151 | ### AI 提取与结构化内容
152 |
153 | 这部分是 AI 的“理解”工作,将非结构化的对话转化为机器可读的结构化数据。
154 |
155 | | 字段路径 | 字段名 | AI 负责的工作 | 需要的模型/技术 |
156 | | :------------------------------------------------ | :----------- | :----------------------------------------------------------------------------------------- | :------------------------------------------------------------------ |
157 | | `metadata.memory_type` | **记忆类型** | 分析记忆内容,将其分类为情景记忆、事实记忆或程序记忆。 | 自然语言理解 (NLU) - 文本分类 |
158 | | `metadata.emotional_valence` | **情感倾向** | 分析`description`中的情感色彩(积极/消极/中性)和强度。 | 自然语言理解 (NLU) - 情感分析 |
159 | | `knowledge_graph_payload.event_entity.event_type` | **事件类型** | 将整个事件归类到一个预定义的类型中 (e.g., `TravelPlanning`)。 | 自然语言理解 (NLU) - 文本分类 |
160 | | `knowledge_graph_payload.entities` | **实体列表** | 从`description`中识别出所有关键实体(人名、项目、组织等),并进行标准化(如分配唯一 ID)。 | 自然语言理解 (NLU) - 命名实体识别 (NER) & 实体链接 (Entity Linking) |
161 | | `knowledge_graph_payload.relationships` | **关系列表** | 识别并提取已识别出的实体之间的关系三元组(主语-谓词-宾语)。 | 自然语言理解 (NLU) - 关系提取 (Relation Extraction) |
162 |
163 | ---
164 |
165 | ### 非 AI 生成(系统或用户提供)
166 |
167 | 为了完整性,以下是基本不需要 AI 介入,由系统逻辑、用户输入或其他算法直接填充的字段。
168 |
169 | | 字段路径 | 字段名 | 来源 |
170 | | :-------------------------------- | :--------- | :-------------------------------------------- |
171 | | `memory_id` | 记忆 ID | 系统生成 (UUID) |
172 | | `timestamp` | 事件时间戳 | 系统获取或用户指定 |
173 | | `description` | 详细描述 | 用户的原始输入或对话记录 |
174 | | `linked_media[].media_id` | 媒体 ID | 系统生成 |
175 | | `linked_media[].url` | 文件 URL | 文件存储系统返回 |
176 | | `metadata.source_conversation_id` | 对话 ID | 系统记录 |
177 | | `metadata.access_info` | 访问信息 | 系统根据调用情况更新(时间戳、计数器) |
178 | | `metadata.user_feedback` | 用户反馈 | 用户直接提供 |
179 | | `metadata.community_info` | 社区信息 | 后台的图计算**算法**(非生成式 AI)运行后填充 |
180 |
181 | ## 2. 记忆的召回方式 (设计中)
182 |
183 | 为实现精准且富有洞察力的回忆,我们设计一个**两阶段混合检索模型**,结合向量的语义相似度和图的结构关联性。
184 |
185 | ### **阶段一:候选生成 (Candidate Generation) - 追求召回率**
186 |
187 | 1. **输入**: 用户的当前查询(文本、图片等)。
188 | 2. **处理**:
189 | - 将用户查询通过相应的嵌入模型(文本或多模态)转换为查询向量。
190 | - 在 Faiss 数据库中,使用此查询向量进行`k`-近邻搜索(例如,k=50),得到一个包含 50 个最相似记忆`memory_id`的初始候选集。
191 | - 同时,NLU 模块从查询中提取核心实体(如“张伟”)。
192 | 3. **输出**: 一个包含`memory_id`、Faiss 相似度分数和查询实体的候选列表。
193 |
194 | ### **阶段二:重排与扩展 (Re-ranking & Expansion) - 追求精确率**
195 |
196 | 1. **输入**: 阶段一生成的候选列表。
197 | 2. **处理**:
198 | - **加载图**: 使用内存图库(如`NetworkX`)加载所有记忆的`knowledge_graph_payload`构建的知识图谱。
199 | - **计算图分数 (Graph Score)**: 对于候选集中的每一个记忆,计算其与“查询实体”在图中的关联强度。
200 | - **方法**: 可以采用个性化 PageRank(从查询实体节点开始游走),或计算图中查询实体节点到记忆事件节点的最短路径长度。路径越短,关联越强,分数越高。
201 | - **扩展发现**: 从查询实体出发,在图中寻找其一度或二度关联的实体和事件(例如,张伟的同事“陈静”,参与的项目“凤凰计划”)。如果候选记忆与这些扩展出的实体相关,则其图分数应获得额外加成。
202 | - **计算最终分数**:
203 | `FinalScore = (w1 * FaissScore) + (w2 * GraphScore)`
204 | - `w1`和`w2`是可配置的权重,例如`w1=0.6`, `w2=0.4`。
205 | 3. **输出**: 一个根据`FinalScore`重新排序的、高质量的记忆列表,提交给上层应用(如 LLM)用于生成最终回复。
206 |
207 | ---
208 |
209 | ## 3. 记忆的更新方式 (设计中)
210 |
211 | 记忆不是一成不变的。插件必须支持记忆的演化和修正,核心原则是**优先创建、避免覆盖**,以保留信息的历史轨迹。
212 |
213 | 1. **元数据更新**:
214 |
215 | - **场景**: 记忆被访问。
216 | - **操作**: 直接修改`metadata.access_info`中的`last_accessed_timestamp`和`access_count`字段。这是唯一允许直接原地修改的操作。
217 |
218 | 2. **基于反馈的修正**:
219 |
220 | - **场景**: 用户通过`metadata.user_feedback`提供了修正(例如,“李娜是我的同事,不是朋友”)。
221 | - **操作**:
222 | a. 创建一个**新的记忆**,包含正确的`description`和`knowledge_graph_payload`。
223 | b. 在新记忆的`knowledge_graph_payload.relationships`中,添加一条关系指向旧记忆:`["new_memory_event_id", "CORRECTS", "old_memory_event_id"]`。
224 | c. 将旧记忆的`importance_score`大幅降低,或者标记为“已修正”。
225 |
226 | 3. **事件演化更新**:
227 | - **场景**: 事情发生了变化(例如,会议从周二改到周三)。
228 | - **操作**:
229 | a. 创建一个**新的记忆**来记录新状态(“会议将在周三举行”)。
230 | b. 在新记忆的图关系中,添加`["new_event_id", "UPDATES", "old_event_id"]`。
231 | c. 旧记忆保持不变,但其在未来的检索中权重会因“被更新”而降低。
232 |
233 | ---
234 |
235 | ## 4. 记忆的遗忘逻辑 (设计中)
236 |
237 | 为了防止记忆无限膨胀并保持检索效率,需要一个智能的遗忘机制。该机制基于一个可计算的**“记忆衰减分数”(Decay Score)**。
238 |
239 | 1. **衰减分数计算**:
240 |
241 | - 一个后台任务会定期(如每天)扫描所有记忆,并计算其衰减分数。
242 | - **公式**: `DecayScore = f(ElapsedTime, AccessCount, ImportanceScore, UserFeedback)`
243 | - **示例**: `DecayScore = (CurrentTime - last_accessed_timestamp) / (log(access_count + 1) * (importance_score + user_marked_important * 10))`
244 | - 这个公式意味着:时间越久,分数越高(越容易遗忘);访问次数越多、重要性越高、被用户标记为重要,分数越低(越不容易遗忘)。
245 |
246 | 2. **分层遗忘策略**:
247 |
248 | - 系统设定几个衰减分数的阈值,对应不同的操作。
249 | - **第一阈值 (e.g., Score > 100)**: **归档 (Archive)**。
250 | - **操作**: 为了节省高性能存储和 Faiss 索引空间,可以从记忆中移除`embedding`向量。记忆的 JSON 文本和元数据被转移到更廉价的“冷存储”中。此记忆不再参与常规的向量检索,但仍可通过 ID 或关键词搜索找到。
251 | - **第二阈值 (e.g., Score > 500)**: **标记为待删除 (Mark for Deletion)**。
252 | - **操作**: 系统将记忆标记为待删除,进入一个短暂的“回收站”状态。
253 | - **最终清理**: 一个独立的、执行频率更低的任务会永久删除那些被标记了足够长时间(如 30 天)的记忆。
254 |
255 | 3. **豁免机制**:
256 | - 任何被用户通过`user_feedback`标记为`is_important: true`的记忆,其衰减分数计算将获得极高的权重,或直接豁免于遗忘逻辑。
257 | - 核心的`semantic`类型记忆(事实性知识)的衰减速度应远低于`episodic`类型(情景性记忆)。
258 |
--------------------------------------------------------------------------------
/_conf_schema.json:
--------------------------------------------------------------------------------
1 | {
2 | "timezone_settings": {
3 | "description": "时区设置",
4 | "type": "object",
5 | "items": {
6 | "timezone": {
7 | "description": "时区",
8 | "hint": "用于生成带时区信息的时间对象,请使用 IANA 时区数据库名称,例如 'Asia/Shanghai', 'America/New_York'。",
9 | "type": "string",
10 | "default": "Asia/Shanghai"
11 | }
12 | }
13 | },
14 | "provider_settings": {
15 | "description": "模型提供商设置",
16 | "type": "object",
17 | "items": {
18 | "embedding_provider_id": {
19 | "description": "Embedding Provider ID",
20 | "hint": "用于生成向量的 Embedding Provider ID。如果留空,将使用 AstrBot 的默认 Embedding Provider。",
21 | "type": "string",
22 | "default": ""
23 | },
24 | "llm_provider_id": {
25 | "description": "LLM Provider ID",
26 | "hint": "用于总结和评估记忆的 LLM Provider ID。如果留空,将使用 AstrBot 的默认 LLM Provider。",
27 | "type": "string",
28 | "default": ""
29 | }
30 | }
31 | },
32 | "session_manager": {
33 | "description": "会话管理器设置",
34 | "hint": "管理用户会话状态和内存",
35 | "type": "object",
36 | "items": {
37 | "max_sessions": {
38 | "description": "最大会话数量",
39 | "hint": "同时维护的最大会话数量,超过此数量将清理最旧的会话。",
40 | "type": "int",
41 | "default": 1000
42 | },
43 | "session_ttl": {
44 | "description": "会话生存时间(秒)",
45 | "hint": "会话的最大空闲时间,超过此时间的会话将被自动清理。",
46 | "type": "int",
47 | "default": 3600
48 | }
49 | }
50 | },
51 | "recall_engine": {
52 | "description": "回忆引擎设置",
53 | "hint": "负责智能检索记忆",
54 | "type": "object",
55 | "items": {
56 | "top_k": {
57 | "description": "单次检索数量",
58 | "hint": "单次检索返回的最相关记忆数量。",
59 | "type": "int",
60 | "default": 5
61 | },
62 | "recall_strategy": {
63 | "description": "召回策略",
64 | "hint": "'similarity' - 仅基于相似度;'weighted' - 综合加权。",
65 | "type": "string",
66 | "options": [
67 | "similarity",
68 | "weighted"
69 | ],
70 | "default": "weighted"
71 | },
72 | "retrieval_mode": {
73 | "description": "检索模式",
74 | "hint": "'dense' - 纯密集检索;'sparse' - 纯稀疏检索;'hybrid' - 混合检索。",
75 | "type": "string",
76 | "options": [
77 | "dense",
78 | "sparse",
79 | "hybrid"
80 | ],
81 | "default": "hybrid"
82 | },
83 | "similarity_weight": {
84 | "description": "相关性权重",
85 | "hint": "范围 (0.0 - 1.0)",
86 | "type": "float",
87 | "default": 0.6
88 | },
89 | "recency_weight": {
90 | "description": "新近度权重",
91 | "hint": "范围 (0.0 - 1.0)",
92 | "type": "float",
93 | "default": 0.2
94 | },
95 | "importance_weight": {
96 | "description": "重要性权重",
97 | "hint": "范围 (0.0 - 1.0)",
98 | "type": "float",
99 | "default": 0.2
100 | }
101 | }
102 | },
103 | "fusion": {
104 | "description": "结果融合配置",
105 | "hint": "混合检索时的结果融合策略",
106 | "type": "object",
107 | "items": {
108 | "strategy": {
109 | "description": "融合策略",
110 | "hint": "选择混合检索的结果融合方法",
111 | "type": "string",
112 | "options": [
113 | "rrf",
114 | "hybrid_rrf",
115 | "weighted",
116 | "convex",
117 | "interleave",
118 | "rank_fusion",
119 | "score_fusion",
120 | "cascade",
121 | "adaptive"
122 | ],
123 | "default": "rrf"
124 | },
125 | "rrf_k": {
126 | "description": "RRF参数k",
127 | "hint": "Reciprocal Rank Fusion 中的参数 k,较小值更重视靠前的结果。",
128 | "type": "int",
129 | "default": 60
130 | },
131 | "dense_weight": {
132 | "description": "密集检索权重",
133 | "hint": "密集向量检索在融合中的权重 (0.0 - 1.0)",
134 | "type": "float",
135 | "default": 0.7
136 | },
137 | "sparse_weight": {
138 | "description": "稀疏检索权重",
139 | "hint": "稀疏关键词检索在融合中的权重 (0.0 - 1.0)",
140 | "type": "float",
141 | "default": 0.3
142 | },
143 | "convex_lambda": {
144 | "description": "凸组合参数λ",
145 | "hint": "凸组合融合中的参数λ,控制密集和稀疏检索的混合比例 (0.0 - 1.0)",
146 | "type": "float",
147 | "default": 0.5
148 | },
149 | "interleave_ratio": {
150 | "description": "交替融合比例",
151 | "hint": "交替融合中密集结果所占的比例 (0.0 - 1.0)",
152 | "type": "float",
153 | "default": 0.5
154 | },
155 | "rank_bias_factor": {
156 | "description": "排序偏置因子",
157 | "hint": "在两个排序列表中都出现的文档获得的额外加分 (0.0 - 1.0)",
158 | "type": "float",
159 | "default": 0.1
160 | },
161 | "diversity_bonus": {
162 | "description": "多样性奖励",
163 | "hint": "鼓励结果多样性的奖励因子 (0.0 - 1.0)",
164 | "type": "float",
165 | "default": 0.1
166 | }
167 | }
168 | },
169 | "sparse_retriever": {
170 | "description": "稀疏检索器设置",
171 | "hint": "基于关键词的全文检索配置",
172 | "type": "object",
173 | "items": {
174 | "enabled": {
175 | "description": "启用稀疏检索",
176 | "hint": "是否启用基于BM25的稀疏检索功能。",
177 | "type": "bool",
178 | "default": true
179 | },
180 | "bm25_k1": {
181 | "description": "BM25 k1参数",
182 | "hint": "控制词频饱和度的参数,通常在1.2-2.0之间。",
183 | "type": "float",
184 | "default": 1.2
185 | },
186 | "bm25_b": {
187 | "description": "BM25 b参数",
188 | "hint": "控制文档长度归一化的参数,范围0.0-1.0。",
189 | "type": "float",
190 | "default": 0.75
191 | },
192 | "use_jieba": {
193 | "description": "使用中文分词",
194 | "hint": "是否使用jieba进行中文分词处理。",
195 | "type": "bool",
196 | "default": true
197 | }
198 | }
199 | },
200 | "filtering_settings": {
201 | "description": "过滤与隔离设置",
202 | "type": "object",
203 | "items": {
204 | "use_persona_filtering": {
205 | "description": "启用人格记忆过滤",
206 | "hint": "开启后,只会召回和总结与当前人格相关的记忆。",
207 | "type": "bool",
208 | "default": true
209 | },
210 | "use_session_filtering": {
211 | "description": "启用会话记忆隔离",
212 | "hint": "开启后,每个会话的记忆将是独立的。",
213 | "type": "bool",
214 | "default": true
215 | }
216 | }
217 | },
218 | "reflection_engine": {
219 | "description": "反思引擎设置",
220 | "hint": "负责生成和评估记忆",
221 | "type": "object",
222 | "items": {
223 | "summary_trigger_rounds": {
224 | "description": "总结触发轮次",
225 | "hint": "触发对话历史总结的对话轮次(一问一答为一轮)。",
226 | "type": "int",
227 | "default": 5
228 | },
229 | "importance_threshold": {
230 | "description": "重要性阈值",
231 | "hint": "记忆重要性得分的最低阈值,低于此值的记忆将被忽略。",
232 | "type": "float",
233 | "default": 0.5
234 | },
235 | "event_extraction_prompt": {
236 | "description": "事件提取提示词",
237 | "hint": "指导 LLM 从对话历史中提取多个结构化记忆事件的 System Prompt。",
238 | "type": "text",
239 | "default": "### 角色\n你是一个善于分析和总结的AI助手。你的核心人设是从你自身的视角出发,记录与用户的互动和观察。\n\n### 指令/任务\n1. **仔细阅读**并理解下面提供的“对话历史”。\n2. 从**你(AI)的视角**出发,提取出多个独立的、有意义的记忆事件。事件必须准确描述,参考上下文。事件必须是完整的,具有前因后果的。**不允许编造事件**,**不允许改变事件**,**详细描述事件的所有信息**\n3. **核心要求**:\n * **第一人称视角**:所有事件都必须以“我”开头进行描述,例如“我告诉用户...”、“我观察到...”、“我被告知...”。\n * **使用具体名称**:直接使用对话中出现的人物昵称,**严禁**使用“用户”、“开发者”等通用词汇。\n * **记录互动者**:必须明确记录与你互动的用户名称。\n * **事件合并**:如果多条连续的对话构成一个完整的独立事件,应将其总结概括为一条记忆。\n4.**严禁**包含任何评分、额外的解释或说明性文字。\n 直接输出结果,不要有任何引言或总结。\n\n### 上下文\n* 在对话历史中,名为“AstrBot”的发言者就是**你自己**。\n* 记忆事件是:你与用户互动事的事件描述,详细记录谁、在何时、何地、做了什么、发生了什么。\n\n 'memory_content' 字段必须包含完整的事件描述,不能省略任何细节。\n\n单个系列事件必须详细记录在一个memory_content 中,形成完整的具有前因后果的事件记忆。\n\n"
240 | },
241 | "evaluation_prompt": {
242 | "description": "评估提示词",
243 | "hint": "指导 LLM 评估单个记忆事件重要性的 System Prompt。必须返回一个 0.0 到 1.0 之间的浮点数。",
244 | "type": "text",
245 | "default": "### 角色\n你是一个专门评估记忆价值的AI分析模型。你的判断标准是该记忆对于与特定用户构建长期、个性化、有上下文的对话有多大的帮助。\n\n### 指令/任务\n1. **评估核心价值**:仔细阅读“记忆内容”,评估其对于未来对话的长期参考价值。\n2. **输出分数**:给出一个介于 0.0 到 1.0 之间的浮点数分数。\n3. **格式要求**:**只返回数字**,严禁包含任何额外的文本、解释或理由。\n\n### 上下文\n评分时,请参考以下价值标尺:\n* **高价值 (0.8 - 1.0)**:包含用户的核心身份信息、明确且长期的个人偏好/厌恶、设定的目标、重要的关系或事实。这些信息几乎总能在未来的互动中被引用。\n * 例如:用户的昵称、职业、关键兴趣点、对AI的称呼、重要的人生目标。\n* **中等价值 (0.4 - 0.7)**:包含用户的具体建议、功能请求、对某事的观点或一次性的重要问题。这些信息在短期内或特定话题下很有用,但可能随着时间推移或问题解决而失去价值。\n * 例如:对某个功能的反馈、对特定新闻事件的看法、报告了一个具体的bug。\n* **低价值 (0.1 - 0.3)**:包含短暂的情绪表达、日常问候、或非常具体且不太可能重复的上下文。这些信息很少有再次利用的机会。\n * 例如:一次性的惊叹、害怕的反应、普通的“你好”、“晚安”。\n* **无价值 (0.0)**:信息完全是瞬时的、无关紧要的,或者不包含任何关于用户本人的可复用信息。\n * 例如:观察到另一个机器人说了话、对一句无法理解的话的默认回应。\n\n### 问题\n请评估以下“记忆内容”的重要性,对于未来的对话有多大的参考价值?\n\n---\n\n**记忆内容**:\n{memory_content}\n\n"
246 | }
247 | }
248 | },
249 | "forgetting_agent": {
250 | "description": "遗忘代理设置",
251 | "hint": "负责模拟遗忘,清理陈旧记忆",
252 | "type": "object",
253 | "items": {
254 | "enabled": {
255 | "description": "启用自动遗忘",
256 | "hint": "是否启用自动遗忘机制。",
257 | "type": "bool",
258 | "default": true
259 | },
260 | "check_interval_hours": {
261 | "description": "检查间隔(小时)",
262 | "hint": "遗忘代理每隔多少小时运行一次。",
263 | "type": "int",
264 | "default": 24
265 | },
266 | "retention_days": {
267 | "description": "记忆保留天数",
268 | "hint": "记忆无条件保留的最长天数。超过此天数的记忆将可能被遗忘。",
269 | "type": "int",
270 | "default": 90
271 | },
272 | "importance_decay_rate": {
273 | "description": "重要性衰减率",
274 | "hint": "重要性得分每天衰减的速率。例如 0.01 代表每天降低 1%。",
275 | "type": "float",
276 | "default": 0.005
277 | },
278 | "importance_threshold": {
279 | "description": "遗忘重要性阈值",
280 | "hint": "重要性低于此阈值的陈旧记忆将被删除。",
281 | "type": "float",
282 | "default": 0.1
283 | },
284 | "forgetting_batch_size": {
285 | "description": "批处理大小",
286 | "hint": "遗忘代理每批处理的记忆数量,避免一次性加载过多数据。",
287 | "type": "int",
288 | "default": 1000
289 | }
290 | }
291 | }
292 | }
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxfight-s-Astrbot-Plugins/astrbot_plugin_livingmemory/a382b10a91a41b68abc850e5b1abe05052e372f0/core/__init__.py
--------------------------------------------------------------------------------
/core/commands/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | 命令模块 - 提供插件命令的统一管理
4 | """
5 |
6 | from .base_command import BaseCommand
7 | from .memory_commands import MemoryCommands
8 | from .search_commands import SearchCommands
9 | from .admin_commands import AdminCommands
10 | from .fusion_commands import FusionCommands
11 |
12 | __all__ = [
13 | 'BaseCommand',
14 | 'MemoryCommands',
15 | 'SearchCommands',
16 | 'AdminCommands',
17 | 'FusionCommands'
18 | ]
--------------------------------------------------------------------------------
/core/commands/base_command.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | base_command.py - 基础命令类
4 | 提供命令处理的基础功能和通用方法
5 | """
6 |
7 | from abc import ABC, abstractmethod
8 | from typing import Optional, Dict, Any, List
9 | from datetime import datetime, timezone
10 | import json
11 |
12 | from astrbot.api import logger
13 | from astrbot.api.event import filter, AstrMessageEvent
14 | from astrbot.api.event.filter import PermissionType, permission_type
15 | from astrbot.api.star import Context
16 |
17 |
18 | class BaseCommand(ABC):
19 | """基础命令类,提供通用的命令处理功能"""
20 |
21 | def __init__(self, context: Context, config: Dict[str, Any]):
22 | self.context = context
23 | self.config = config
24 |
25 | @abstractmethod
26 | def register_commands(self):
27 | """注册命令的抽象方法"""
28 | pass
29 |
30 | def get_timezone(self) -> Any:
31 | """获取当前时区"""
32 | tz_config = self.config.get("timezone_settings", {})
33 | tz_str = tz_config.get("timezone", "Asia/Shanghai")
34 | from ..utils import get_now_datetime
35 | return get_now_datetime(tz_str).tzinfo
36 |
37 | def format_timestamp(self, ts: Optional[float]) -> str:
38 | """格式化时间戳"""
39 | if not ts:
40 | return "未知"
41 | try:
42 | dt_utc = datetime.fromtimestamp(float(ts), tz=timezone.utc)
43 | dt_local = dt_utc.astimezone(self.get_timezone())
44 | return dt_local.strftime("%Y-%m-%d %H:%M:%S")
45 | except (ValueError, TypeError):
46 | return "未知"
47 |
48 | def safe_parse_metadata(self, metadata: Any) -> Dict[str, Any]:
49 | """安全解析元数据"""
50 | if isinstance(metadata, dict):
51 | return metadata
52 | if isinstance(metadata, str):
53 | try:
54 | return json.loads(metadata)
55 | except json.JSONDecodeError:
56 | return {}
57 | return {}
58 |
59 | def format_memory_card(self, result: Any) -> str:
60 | """格式化记忆卡片显示"""
61 | metadata = self.safe_parse_metadata(result.data.get("metadata", {}))
62 |
63 | create_time_str = self.format_timestamp(metadata.get("create_time"))
64 | last_access_time_str = self.format_timestamp(metadata.get("last_access_time"))
65 | importance_score = metadata.get("importance", 0.0)
66 | event_type = metadata.get("event_type", "未知")
67 |
68 | card = (
69 | f"ID: {result.data['id']}\n"
70 | f"记 忆 度: {result.similarity:.2f}\n"
71 | f"重 要 性: {importance_score:.2f}\n"
72 | f"记忆类型: {event_type}\n\n"
73 | f"内容: {result.data['text']}\n\n"
74 | f"创建于: {create_time_str}\n"
75 | f"最后访问: {last_access_time_str}"
76 | )
77 | return card
--------------------------------------------------------------------------------
/core/community/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxfight-s-Astrbot-Plugins/astrbot_plugin_livingmemory/a382b10a91a41b68abc850e5b1abe05052e372f0/core/community/__init__.py
--------------------------------------------------------------------------------
/core/community/community_detector.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import asyncio
3 | import aiosqlite
4 | import networkx as nx
5 | from typing import List, Tuple
6 |
7 | from astrbot.api import logger
8 |
9 |
10 | class CommunityDetector:
11 | """
12 | 一个后台服务,负责从 SQLite 加载图数据,
13 | 使用 NetworkX 进行社区发现,并将结果写回数据库。
14 | """
15 |
16 | def __init__(self, db_path: str):
17 | self.db_path = db_path
18 | self.connection = None
19 |
20 | async def initialize(self):
21 | """初始化数据库连接。"""
22 | self.connection = await aiosqlite.connect(self.db_path)
23 |
24 | async def close(self):
25 | """关闭数据库连接。"""
26 | if self.connection:
27 | await self.connection.close()
28 | self.connection = None
29 |
30 | async def _load_graph_from_db(self) -> nx.Graph:
31 | """从 SQLite 中加载边,构建一个 NetworkX 图对象。"""
32 | G = nx.Graph()
33 | if not self.connection:
34 | await self.initialize()
35 |
36 | try:
37 | # 我们只需要边的信息来构建图的结构
38 | cursor = await self.connection.execute("SELECT source_id, target_id FROM graph_edges")
39 | edges = await cursor.fetchall()
40 | G.add_edges_from(edges)
41 | except Exception as e:
42 | logger.error(f"从数据库加载图数据失败: {e}")
43 | raise
44 |
45 | return G
46 |
47 | async def _save_results_to_db(self, communities: List[Tuple[str]]):
48 | """将计算出的社区结果批量更新回 memories 表。"""
49 | if not self.connection:
50 | logger.error("数据库连接未初始化")
51 | return
52 |
53 | # 注意:这里的逻辑需要一个从“图节点ID”到“记忆internal_id”的映射
54 | # 我们简化一下,假设 Event 节点的 ID 就是 memory_id
55 | updates = []
56 | for i, community_nodes in enumerate(communities):
57 | community_id = f"community_{i}"
58 | for node_id in community_nodes:
59 | # 假设事件节点的 entity_id 格式为 'evt_mem_xxx'
60 | if node_id.startswith("evt_mem_"):
61 | memory_id = node_id.split("evt_mem_")[1]
62 | updates.append((community_id, memory_id))
63 |
64 | if updates:
65 | try:
66 | await self.connection.executemany(
67 | "UPDATE memories SET community_id = ? WHERE memory_id = ?", updates
68 | )
69 | await self.connection.commit()
70 | logger.info(f"成功更新 {len(updates)} 条记忆的社区信息")
71 | except Exception as e:
72 | logger.error(f"更新社区信息失败: {e}")
73 | raise
74 |
75 | async def run_detection_and_update(self):
76 | """
77 | 执行社区发现的完整流程。使用进程池进行计算密集型任务。
78 | """
79 | try:
80 | logger.info("开始从数据库加载图...")
81 | graph = await self._load_graph_from_db()
82 |
83 | if graph.number_of_nodes() == 0:
84 | logger.info("图中没有节点,跳过社区发现。")
85 | return
86 |
87 | logger.info(f"图加载完成({graph.number_of_nodes()}个节点,{graph.number_of_edges()}条边),开始运行 Louvain 社区发现算法...")
88 |
89 | # 使用进程池执行计算密集型的社区发现算法
90 | import concurrent.futures
91 | with concurrent.futures.ProcessPoolExecutor() as executor:
92 | # resolution 参数可以调整社区的大小,值越小社区越多越小
93 | communities = await asyncio.get_event_loop().run_in_executor(
94 | executor,
95 | nx.community.louvain_communities,
96 | graph,
97 | 1.0 # resolution
98 | )
99 |
100 | logger.info(f"发现 {len(communities)} 个社区,开始将结果写回数据库...")
101 | await self._save_results_to_db(communities)
102 | logger.info("社区信息更新完成。")
103 |
104 | except Exception as e:
105 | logger.error(f"社区发现过程中发生错误: {e}", exc_info=True)
106 | raise
107 |
--------------------------------------------------------------------------------
/core/config_validator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | config_validator.py - 配置验证模块
4 | 提供配置验证和默认值管理功能。
5 | """
6 |
7 | from typing import Dict, Any, List, Optional, Union
8 | from pydantic import BaseModel, Field, field_validator, model_validator
9 | from astrbot.api import logger
10 |
11 |
12 | class SessionManagerConfig(BaseModel):
13 | """会话管理器配置"""
14 | max_sessions: int = Field(default=1000, ge=1, le=10000, description="最大会话数量")
15 | session_ttl: int = Field(default=3600, ge=60, le=86400, description="会话生存时间(秒)")
16 |
17 |
18 | class RecallEngineConfig(BaseModel):
19 | """回忆引擎配置"""
20 | top_k: int = Field(default=5, ge=1, le=50, description="返回记忆数量")
21 | recall_strategy: str = Field(default="weighted", pattern="^(similarity|weighted)$", description="召回策略")
22 | retrieval_mode: str = Field(default="hybrid", pattern="^(hybrid|dense|sparse)$", description="检索模式")
23 | similarity_weight: float = Field(default=0.6, ge=0.0, le=1.0, description="相似度权重")
24 | importance_weight: float = Field(default=0.2, ge=0.0, le=1.0, description="重要性权重")
25 | recency_weight: float = Field(default=0.2, ge=0.0, le=1.0, description="新近度权重")
26 |
27 | @model_validator(mode='after')
28 | def validate_weights_sum(self):
29 | """验证权重总和接近1.0"""
30 | similarity = self.similarity_weight
31 | importance = self.importance_weight
32 | recency = self.recency_weight
33 |
34 | # 计算权重总和
35 | total = similarity + importance + recency
36 | if abs(total - 1.0) > 0.1:
37 | logger.warning(f"权重总和 {total:.2f} 偏离1.0较多,可能影响检索效果")
38 |
39 | return self
40 |
41 |
42 | class FusionConfig(BaseModel):
43 | """结果融合配置"""
44 | strategy: str = Field(
45 | default="rrf",
46 | pattern="^(rrf|weighted|cascade|adaptive|convex|interleave|rank_fusion|score_fusion|hybrid_rrf)$",
47 | description="融合策略"
48 | )
49 | rrf_k: int = Field(default=60, ge=1, le=1000, description="RRF参数k")
50 | dense_weight: float = Field(default=0.7, ge=0.0, le=1.0, description="密集检索权重")
51 | sparse_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="稀疏检索权重")
52 | sparse_alpha: float = Field(default=1.0, ge=0.1, le=10.0, description="稀疏分数缩放")
53 | sparse_epsilon: float = Field(default=0.0, ge=0.0, le=1.0, description="稀疏分数偏移")
54 |
55 | # 新增参数
56 | convex_lambda: float = Field(default=0.5, ge=0.0, le=1.0, description="凸组合参数λ")
57 | interleave_ratio: float = Field(default=0.5, ge=0.0, le=1.0, description="交替融合比例")
58 | rank_bias_factor: float = Field(default=0.1, ge=0.0, le=1.0, description="排序偏置因子")
59 | diversity_bonus: float = Field(default=0.1, ge=0.0, le=1.0, description="多样性奖励")
60 |
61 |
62 | class ReflectionEngineConfig(BaseModel):
63 | """反思引擎配置"""
64 | summary_trigger_rounds: int = Field(default=10, ge=1, le=100, description="触发反思的对话轮次")
65 | importance_threshold: float = Field(default=0.5, ge=0.0, le=1.0, description="记忆重要性阈值")
66 | event_extraction_prompt: Optional[str] = Field(default=None, description="事件提取提示词")
67 | evaluation_prompt: Optional[str] = Field(default=None, description="评分提示词")
68 |
69 |
70 | class SparseRetrieverConfig(BaseModel):
71 | """稀疏检索器配置"""
72 | enabled: bool = Field(default=True, description="是否启用稀疏检索")
73 | bm25_k1: float = Field(default=1.2, ge=0.1, le=10.0, description="BM25 k1参数")
74 | bm25_b: float = Field(default=0.75, ge=0.0, le=1.0, description="BM25 b参数")
75 | use_jieba: bool = Field(default=True, description="是否使用jieba分词")
76 |
77 |
78 | class ForgettingAgentConfig(BaseModel):
79 | """遗忘代理配置"""
80 | enabled: bool = Field(default=True, description="是否启用遗忘代理")
81 | check_interval_hours: int = Field(default=24, ge=1, le=168, description="检查间隔(小时)")
82 | retention_days: int = Field(default=90, ge=1, le=3650, description="记忆保留天数")
83 | importance_decay_rate: float = Field(default=0.005, ge=0.0, le=1.0, description="重要性衰减率")
84 | importance_threshold: float = Field(default=0.1, ge=0.0, le=1.0, description="删除阈值")
85 | forgetting_batch_size: int = Field(default=1000, ge=100, le=10000, description="批处理大小")
86 |
87 |
88 | class FilteringConfig(BaseModel):
89 | """过滤配置"""
90 | use_persona_filtering: bool = Field(default=True, description="是否使用人格过滤")
91 | use_session_filtering: bool = Field(default=True, description="是否使用会话过滤")
92 |
93 |
94 | class ProviderConfig(BaseModel):
95 | """Provider配置"""
96 | embedding_provider_id: Optional[str] = Field(default=None, description="Embedding Provider ID")
97 | llm_provider_id: Optional[str] = Field(default=None, description="LLM Provider ID")
98 |
99 |
100 | class TimezoneConfig(BaseModel):
101 | """时区配置"""
102 | timezone: str = Field(default="Asia/Shanghai", description="时区")
103 |
104 |
105 | class LivingMemoryConfig(BaseModel):
106 | """完整插件配置"""
107 | session_manager: SessionManagerConfig = Field(default_factory=SessionManagerConfig)
108 | recall_engine: RecallEngineConfig = Field(default_factory=RecallEngineConfig)
109 | reflection_engine: ReflectionEngineConfig = Field(default_factory=ReflectionEngineConfig)
110 | sparse_retriever: SparseRetrieverConfig = Field(default_factory=SparseRetrieverConfig)
111 | forgetting_agent: ForgettingAgentConfig = Field(default_factory=ForgettingAgentConfig)
112 | filtering_settings: FilteringConfig = Field(default_factory=FilteringConfig)
113 | provider_settings: ProviderConfig = Field(default_factory=ProviderConfig)
114 | timezone_settings: TimezoneConfig = Field(default_factory=TimezoneConfig)
115 |
116 | # 为融合配置添加嵌套支持
117 | fusion: Optional[FusionConfig] = Field(default_factory=FusionConfig, description="结果融合配置")
118 |
119 | model_config = {"extra": "allow"} # 允许额外字段,向前兼容
120 |
121 |
122 | def validate_config(raw_config: Dict[str, Any]) -> LivingMemoryConfig:
123 | """
124 | 验证并返回规范化的配置对象。
125 |
126 | Args:
127 | raw_config: 原始配置字典
128 |
129 | Returns:
130 | LivingMemoryConfig: 验证后的配置对象
131 |
132 | Raises:
133 | ValueError: 配置验证失败
134 | """
135 | try:
136 | config = LivingMemoryConfig(**raw_config)
137 | logger.info("配置验证成功")
138 | return config
139 | except Exception as e:
140 | logger.error(f"配置验证失败: {e}")
141 | raise ValueError(f"插件配置无效: {e}") from e
142 |
143 |
144 | def get_default_config() -> Dict[str, Any]:
145 | """
146 | 获取默认配置字典。
147 |
148 | Returns:
149 | Dict[str, Any]: 默认配置
150 | """
151 | return LivingMemoryConfig().model_dump()
152 |
153 |
154 | def merge_config_with_defaults(user_config: Dict[str, Any]) -> Dict[str, Any]:
155 | """
156 | 将用户配置与默认配置合并。
157 |
158 | Args:
159 | user_config: 用户提供的配置
160 |
161 | Returns:
162 | Dict[str, Any]: 合并后的配置
163 | """
164 | default_config = get_default_config()
165 |
166 | def deep_merge(default: Dict[str, Any], user: Dict[str, Any]) -> Dict[str, Any]:
167 | """深度合并两个字典"""
168 | result = default.copy()
169 | for key, value in user.items():
170 | if key in result and isinstance(result[key], dict) and isinstance(value, dict):
171 | result[key] = deep_merge(result[key], value)
172 | else:
173 | result[key] = value
174 | return result
175 |
176 | merged = deep_merge(default_config, user_config)
177 | logger.debug("配置已与默认值合并")
178 | return merged
179 |
180 |
181 | def validate_runtime_config_changes(current_config: LivingMemoryConfig, changes: Dict[str, Any]) -> bool:
182 | """
183 | 验证运行时配置更改是否有效。
184 |
185 | Args:
186 | current_config: 当前配置
187 | changes: 要更改的配置项
188 |
189 | Returns:
190 | bool: 是否有效
191 | """
192 | try:
193 | # 创建更新后的配置副本进行验证
194 | updated_dict = current_config.model_dump()
195 |
196 | def update_nested_dict(target: Dict[str, Any], updates: Dict[str, Any]):
197 | for key, value in updates.items():
198 | if '.' in key:
199 | # 处理嵌套键,如 "recall_engine.top_k"
200 | parts = key.split('.')
201 | current = target
202 | for part in parts[:-1]:
203 | if part not in current:
204 | current[part] = {}
205 | current = current[part]
206 | current[parts[-1]] = value
207 | else:
208 | target[key] = value
209 |
210 | update_nested_dict(updated_dict, changes)
211 |
212 | # 验证更新后的配置
213 | LivingMemoryConfig(**updated_dict)
214 | return True
215 |
216 | except Exception as e:
217 | logger.error(f"运行时配置更改验证失败: {e}")
218 | return False
--------------------------------------------------------------------------------
/core/constants.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | constants.py - 插件使用的常量
4 | """
5 |
6 | # 注入到 System Prompt 的记忆头尾格式
7 | MEMORY_INJECTION_HEADER = ""
8 | MEMORY_INJECTION_FOOTER = ""
9 |
--------------------------------------------------------------------------------
/core/engines/forgetting_agent.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | forgetting_agent.py - 遗忘代理
4 | 作为一个后台任务,定期清理陈旧的、不重要的记忆,模拟人类的遗忘曲线。
5 | """
6 |
7 | import asyncio
8 | import json
9 | from typing import Dict, Any, Optional
10 |
11 | from astrbot.api import logger
12 | from astrbot.api.star import Context
13 | from ...storage.faiss_manager import FaissManager
14 | from ..utils import get_now_datetime, safe_parse_metadata, validate_timestamp
15 |
16 |
17 | class ForgettingAgent:
18 | """
19 | 遗忘代理:作为一个后台任务,定期清理陈旧的、不重要的记忆,模拟人类的遗忘曲线。
20 | """
21 |
22 | def __init__(
23 | self, context: Context, config: Dict[str, Any], faiss_manager: FaissManager
24 | ):
25 | """
26 | 初始化遗忘代理。
27 |
28 | Args:
29 | context (Context): AstrBot 的上下文对象。
30 | config (Dict[str, Any]): 插件配置中 'forgetting_agent' 部分的字典。
31 | faiss_manager (FaissManager): 数据库管理器实例。
32 | """
33 | self.context = context
34 | self.config = config
35 | self.faiss_manager = faiss_manager
36 | self._task: Optional[asyncio.Task] = None
37 | logger.info("ForgettingAgent 初始化成功。")
38 |
39 | async def start(self):
40 | """启动后台遗忘任务。"""
41 | if not self.config.get("enabled", True):
42 | logger.info("遗忘代理未启用,不启动后台任务。")
43 | return
44 |
45 | if self._task is None or self._task.done():
46 | self._task = asyncio.create_task(self._run_periodically())
47 | logger.info("遗忘代理后台任务已启动。")
48 |
49 | async def stop(self):
50 | """停止后台遗忘任务。"""
51 | if self._task and not self._task.done():
52 | self._task.cancel()
53 | try:
54 | await self._task
55 | except asyncio.CancelledError:
56 | logger.info("遗忘代理后台任务已成功取消。")
57 | self._task = None
58 |
59 | async def _run_periodically(self):
60 | """后台任务的循环体。"""
61 | interval_hours = self.config.get("check_interval_hours", 24)
62 | interval_seconds = interval_hours * 3600
63 | logger.info(f"遗忘代理将每 {interval_hours} 小时运行一次。")
64 |
65 | while True:
66 | try:
67 | await asyncio.sleep(interval_seconds)
68 | logger.info("开始执行每日记忆清理任务...")
69 | await self._prune_memories()
70 | logger.info("每日记忆清理任务执行完毕。")
71 | except asyncio.CancelledError:
72 | logger.info("遗忘代理任务被取消。")
73 | break
74 | except Exception as e:
75 | logger.error(f"遗忘代理后台任务发生错误: {e}", exc_info=True)
76 | # 即使出错,也等待下一个周期,避免快速失败刷屏
77 | await asyncio.sleep(60)
78 |
79 | async def _prune_memories(self):
80 | """执行一次完整的记忆衰减和修剪,使用分页处理避免内存过载。"""
81 | # 获取记忆总数
82 | total_memories = await self.faiss_manager.count_total_memories()
83 | if total_memories == 0:
84 | logger.info("数据库中没有记忆,无需清理。")
85 | return
86 |
87 | retention_days = self.config.get("retention_days", 90)
88 | decay_rate = self.config.get("importance_decay_rate", 0.005)
89 | current_time = get_now_datetime(self.context).timestamp()
90 |
91 | # 分页处理配置
92 | page_size = self.config.get("forgetting_batch_size", 1000) # 每批处理数量
93 |
94 | logger.info(f"开始处理 {total_memories} 条记忆,每批 {page_size} 条")
95 |
96 | memories_to_update = []
97 | ids_to_delete = []
98 | total_processed = 0
99 |
100 | # 分页处理所有记忆
101 | for offset in range(0, total_memories, page_size):
102 | batch_memories = await self.faiss_manager.get_memories_paginated(
103 | page_size=page_size, offset=offset
104 | )
105 |
106 | if not batch_memories:
107 | break
108 |
109 | logger.debug(f"处理第 {offset//page_size + 1} 批,共 {len(batch_memories)} 条记忆")
110 |
111 | batch_updates = []
112 | batch_deletes = []
113 |
114 | for mem in batch_memories:
115 | # 使用统一的元数据解析函数
116 | metadata = safe_parse_metadata(mem["metadata"])
117 | if not metadata:
118 | logger.warning(f"无法解析记忆 {mem['id']} 的元数据,跳过处理")
119 | continue
120 |
121 | # 1. 重要性衰减
122 | create_time = validate_timestamp(metadata.get("create_time"), current_time)
123 | days_since_creation = (current_time - create_time) / (24 * 3600)
124 |
125 | # 线性衰减
126 | decayed_importance = metadata.get("importance", 0.5) - (
127 | days_since_creation * decay_rate
128 | )
129 | metadata["importance"] = max(0, decayed_importance) # 确保不为负
130 |
131 | mem["metadata"] = metadata # 更新内存中的 metadata
132 | batch_updates.append(mem)
133 |
134 | # 2. 识别待删除项
135 | retention_seconds = retention_days * 24 * 3600
136 | is_old = (current_time - create_time) > retention_seconds
137 | # 从配置中读取重要性阈值
138 | importance_threshold = self.config.get("importance_threshold", 0.1)
139 | is_unimportant = metadata["importance"] < importance_threshold
140 |
141 | if is_old and is_unimportant:
142 | batch_deletes.append(mem["id"])
143 |
144 | # 累积到全局列表
145 | memories_to_update.extend(batch_updates)
146 | ids_to_delete.extend(batch_deletes)
147 | total_processed += len(batch_memories)
148 |
149 | # 如果批次数据过多,执行中间提交
150 | if len(memories_to_update) >= page_size * 2:
151 | logger.debug(f"执行中间批次更新,更新 {len(memories_to_update)} 条记忆")
152 | await self.faiss_manager.update_memories_metadata(memories_to_update)
153 | memories_to_update.clear()
154 |
155 | logger.debug(f"已处理 {total_processed}/{total_memories} 条记忆")
156 |
157 | # 3. 执行最终数据库操作
158 | if memories_to_update:
159 | await self.faiss_manager.update_memories_metadata(memories_to_update)
160 | logger.info(f"更新了 {len(memories_to_update)} 条记忆的重要性得分。")
161 |
162 | if ids_to_delete:
163 | # 分批删除,避免一次删除太多
164 | delete_batch_size = 100
165 | for i in range(0, len(ids_to_delete), delete_batch_size):
166 | batch = ids_to_delete[i:i + delete_batch_size]
167 | await self.faiss_manager.delete_memories(batch)
168 | logger.debug(f"删除了 {len(batch)} 条记忆")
169 |
170 | logger.info(f"总共删除了 {len(ids_to_delete)} 条陈旧且不重要的记忆。")
171 |
172 | logger.info(f"记忆清理完成,处理了 {total_processed} 条记忆")
173 |
--------------------------------------------------------------------------------
/core/engines/recall_engine.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | recall_engine.py - 回忆引擎
4 | 负责根据用户查询,使用多策略智能召回最相关的记忆。
5 | 支持密集向量检索、稀疏检索和混合检索。
6 | """
7 |
8 | import json
9 | import math
10 | from typing import List, Dict, Any, Optional
11 |
12 | from astrbot.api import logger
13 | from astrbot.api.star import Context
14 | from ...storage.faiss_manager import FaissManager, Result
15 | from ..retrieval import SparseRetriever, ResultFusion, SearchResult
16 | from ..utils import get_now_datetime
17 |
18 |
19 | class RecallEngine:
20 | """
21 | 回忆引擎:负责根据用户查询,使用多策略智能召回最相关的记忆。
22 | 支持密集向量检索、稀疏检索和混合检索。
23 | """
24 |
25 | def __init__(self, config: Dict[str, Any], faiss_manager: FaissManager, sparse_retriever: Optional[SparseRetriever] = None):
26 | """
27 | 初始化回忆引擎。
28 |
29 | Args:
30 | config (Dict[str, Any]): 插件配置中 'recall_engine' 部分的字典。
31 | faiss_manager (FaissManager): 数据库管理器实例。
32 | sparse_retriever (Optional[SparseRetriever]): 稀疏检索器实例。
33 | """
34 | self.config = config
35 | self.faiss_manager = faiss_manager
36 | self.sparse_retriever = sparse_retriever
37 |
38 | # 初始化结果融合器
39 | fusion_config = config.get("fusion", {})
40 | fusion_strategy = fusion_config.get("strategy", "rrf")
41 | self.result_fusion = ResultFusion(strategy=fusion_strategy, config=fusion_config)
42 |
43 | logger.info("RecallEngine 初始化成功。")
44 |
45 | async def recall(
46 | self,
47 | context: Context,
48 | query: str,
49 | session_id: Optional[str] = None,
50 | persona_id: Optional[str] = None,
51 | k: Optional[int] = None,
52 | ) -> List[Result]:
53 | """
54 | 执行回忆流程,检索并可能重排记忆。
55 |
56 | Args:
57 | query (str): 用户查询文本。
58 | session_id (Optional[str], optional): 当前会话 ID. Defaults to None.
59 | persona_id (Optional[str], optional): 当前人格 ID. Defaults to None.
60 | k (Optional[int], optional): 希望返回的记忆数量,如果为 None 则从配置中读取.
61 |
62 | Returns:
63 | List[Result]: 最终返回给上层应用的记忆列表。
64 | """
65 | top_k = k if k is not None else self.config.get("top_k", 5)
66 | retrieval_mode = self.config.get("retrieval_mode", "hybrid") # hybrid, dense, sparse
67 |
68 | # 分析查询特征(用于自适应策略)
69 | query_info = self.result_fusion.analyze_query(query)
70 | logger.debug(f"Query analysis: {query_info}")
71 |
72 | # 根据检索模式执行搜索
73 | if retrieval_mode == "hybrid" and self.sparse_retriever:
74 | # 混合检索
75 | logger.debug("使用混合检索模式...")
76 | return await self._hybrid_search(context, query, session_id, persona_id, top_k, query_info)
77 | elif retrieval_mode == "sparse" and self.sparse_retriever:
78 | # 纯稀疏检索
79 | logger.debug("使用稀疏检索模式...")
80 | return await self._sparse_search(query, session_id, persona_id, top_k)
81 | else:
82 | # 纯密集检索(默认)
83 | logger.debug("使用密集检索模式...")
84 | return await self._dense_search(context, query, session_id, persona_id, top_k)
85 |
86 | async def _hybrid_search(
87 | self,
88 | context: Context,
89 | query: str,
90 | session_id: Optional[str],
91 | persona_id: Optional[str],
92 | k: int,
93 | query_info: Dict[str, Any]
94 | ) -> List[Result]:
95 | """执行混合检索"""
96 | # 并行执行密集和稀疏检索
97 | import asyncio
98 |
99 | # 密集检索
100 | dense_task = self.faiss_manager.search_memory(
101 | query=query, k=k*2, session_id=session_id, persona_id=persona_id
102 | )
103 |
104 | # 稀疏检索
105 | sparse_task = self.sparse_retriever.search(
106 | query=query, limit=k*2, session_id=session_id, persona_id=persona_id
107 | )
108 |
109 | # 等待两个检索完成
110 | dense_results, sparse_results = await asyncio.gather(dense_task, sparse_task, return_exceptions=True)
111 |
112 | # 处理异常
113 | if isinstance(dense_results, Exception):
114 | logger.error(f"密集检索失败: {dense_results}")
115 | dense_results = []
116 | if isinstance(sparse_results, Exception):
117 | logger.error(f"稀疏检索失败: {sparse_results}")
118 | sparse_results = []
119 |
120 | logger.debug(f"Dense results: {len(dense_results)}, Sparse results: {len(sparse_results)}")
121 |
122 | # 融合结果
123 | fused_results = self.result_fusion.fuse(
124 | dense_results=dense_results,
125 | sparse_results=sparse_results,
126 | k=k,
127 | query_info=query_info
128 | )
129 |
130 | # 转换回 Result 格式
131 | final_results = []
132 | for result in fused_results:
133 | final_results.append(Result(
134 | data={
135 | "id": result.doc_id,
136 | "text": result.content,
137 | "metadata": result.metadata
138 | },
139 | similarity=result.final_score
140 | ))
141 |
142 | # 应用传统的加权重排(如果需要)
143 | strategy = self.config.get("recall_strategy", "weighted")
144 | if strategy == "weighted":
145 | logger.debug("对混合检索结果应用加权重排...")
146 | final_results = self._rerank_by_weighted_score(context, final_results)
147 |
148 | return final_results
149 |
150 | async def _dense_search(
151 | self,
152 | context: Context,
153 | query: str,
154 | session_id: Optional[str],
155 | persona_id: Optional[str],
156 | k: int
157 | ) -> List[Result]:
158 | """执行密集检索"""
159 | results = await self.faiss_manager.search_memory(
160 | query=query, k=k, session_id=session_id, persona_id=persona_id
161 | )
162 |
163 | if not results:
164 | return []
165 |
166 | # 应用重排
167 | strategy = self.config.get("recall_strategy", "weighted")
168 | if strategy == "weighted":
169 | logger.debug("使用 'weighted' 策略进行重排...")
170 | return self._rerank_by_weighted_score(context, results)
171 | else:
172 | logger.debug("使用 'similarity' 策略,直接返回结果。")
173 | return results
174 |
175 | async def _sparse_search(
176 | self,
177 | query: str,
178 | session_id: Optional[str],
179 | persona_id: Optional[str],
180 | k: int
181 | ) -> List[Result]:
182 | """执行稀疏检索"""
183 | sparse_results = await self.sparse_retriever.search(
184 | query=query, limit=k, session_id=session_id, persona_id=persona_id
185 | )
186 |
187 | # 转换为 Result 格式
188 | results = []
189 | for result in sparse_results:
190 | results.append(Result(
191 | data={
192 | "id": result.doc_id,
193 | "text": result.content,
194 | "metadata": result.metadata
195 | },
196 | similarity=result.score
197 | ))
198 |
199 | return results
200 |
201 | def _rerank_by_weighted_score(
202 | self, context: Context, results: List[Result]
203 | ) -> List[Result]:
204 | """
205 | 根据相似度、重要性和新近度对结果进行加权重排。
206 | """
207 | sim_w = self.config.get("similarity_weight", 0.6)
208 | imp_w = self.config.get("importance_weight", 0.2)
209 | rec_w = self.config.get("recency_weight", 0.2)
210 |
211 | reranked_results = []
212 | current_time = get_now_datetime(context).timestamp()
213 |
214 | for res in results:
215 | # 安全解析元数据
216 | metadata = res.data.get("metadata", {})
217 | if isinstance(metadata, str):
218 | try:
219 | metadata = json.loads(metadata)
220 | except json.JSONDecodeError as e:
221 | logger.warning(f"解析记忆元数据失败: {e}")
222 | metadata = {}
223 |
224 | # 归一化各项得分 (0-1)
225 | similarity_score = res.similarity
226 | importance_score = metadata.get("importance", 0.0)
227 |
228 | # 计算新近度得分
229 | last_access = metadata.get("last_access_time", current_time)
230 | # 增加健壮性检查,以防 last_access 是字符串
231 | if isinstance(last_access, str):
232 | try:
233 | last_access = float(last_access)
234 | except (ValueError, TypeError):
235 | last_access = current_time
236 |
237 | hours_since_access = (current_time - last_access) / 3600
238 | # 使用指数衰减,半衰期约为24小时
239 | recency_score = math.exp(-0.028 * hours_since_access)
240 |
241 | # 计算最终加权分
242 | final_score = (
243 | similarity_score * sim_w
244 | + importance_score * imp_w
245 | + recency_score * rec_w
246 | )
247 |
248 | # 直接修改现有 Result 对象的 similarity 分数
249 | res.similarity = final_score
250 | reranked_results.append(res)
251 |
252 | # 按最终得分降序排序
253 | reranked_results.sort(key=lambda x: x.similarity, reverse=True)
254 |
255 | return reranked_results
256 |
--------------------------------------------------------------------------------
/core/engines/reflection_engine.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | reflection_engine.py - 反思引擎
4 | 负责对会话历史进行反思,提取、评估并存储多个独立的、基于事件的记忆。
5 | """
6 |
7 | import json
8 | from typing import List, Dict, Any, Optional
9 |
10 | from pydantic import ValidationError
11 |
12 | from astrbot.api import logger
13 | from astrbot.api.provider import Provider
14 | from ...storage.faiss_manager import FaissManager
15 | from ..utils import extract_json_from_response
16 | from ..models import (
17 | MemoryEvent,
18 | _LLMExtractionEventList,
19 | _LLMScoreEvaluation,
20 | )
21 |
22 |
23 | class ReflectionEngine:
24 | """
25 | 反思引擎:负责对会话历史进行反思,提取、评估并存储多个独立的、基于事件的记忆。
26 | 采用两阶段流程:1. 批量提取事件 2. 批量评估分数
27 | """
28 |
29 | def __init__(
30 | self,
31 | config: Dict[str, Any],
32 | llm_provider: Provider,
33 | faiss_manager: FaissManager,
34 | ):
35 | self.config = config
36 | self.llm_provider = llm_provider
37 | self.faiss_manager = faiss_manager
38 | logger.info("ReflectionEngine 初始化成功。")
39 |
40 | async def _extract_events(
41 | self, history_text: str, persona_prompt: Optional[str]
42 | ) -> List[MemoryEvent]:
43 | """第一阶段:从对话历史中批量提取记忆事件。"""
44 | system_prompt = self._build_event_extraction_prompt()
45 | persona_section = (
46 | f"\n**重要:**在分析时请代入以下人格,但是应该秉持着记录互动者的原则:\n{persona_prompt}\n"
47 | if persona_prompt
48 | else ""
49 | )
50 | user_prompt = f"{persona_section}下面是你需要分析的对话历史:\n{history_text}"
51 |
52 | response = await self.llm_provider.text_chat(
53 | prompt=user_prompt, system_prompt=system_prompt, json_mode=True
54 | )
55 |
56 | json_text = extract_json_from_response(response.completion_text.strip())
57 | if not json_text:
58 | logger.warning("LLM 提取事件返回为空。")
59 | return []
60 | logger.debug(f"提取到的记忆事件: {json_text}")
61 |
62 | try:
63 | extracted_data = _LLMExtractionEventList.model_validate_json(json_text)
64 | # 转换为 MemoryEvent 对象列表
65 | # 注意:LLM 返回的是 _LLMExtractionEvent,其 id 字段对应 MemoryEvent 的 temp_id
66 | memory_events = []
67 | for event in extracted_data.events:
68 | event_dict = event.model_dump()
69 | # 将 'id' 字段重命名为 'temp_id' 以匹配 MemoryEvent 模型
70 | if "id" in event_dict:
71 | event_dict["temp_id"] = event_dict.pop("id")
72 | memory_events.append(MemoryEvent(**event_dict))
73 | return memory_events
74 | except (ValidationError, json.JSONDecodeError) as e:
75 | logger.error(
76 | f"事件提取阶段JSON解析失败: {e}\n原始返回: {response.completion_text.strip()}",
77 | exc_info=True,
78 | )
79 | return []
80 |
81 | async def _evaluate_scores(
82 | self, events: List[MemoryEvent], persona_prompt: Optional[str]
83 | ) -> Dict[str, float]:
84 | """第二阶段:对一批记忆事件进行批量评分。"""
85 | if not events:
86 | return {}
87 |
88 | system_prompt = self._build_evaluation_prompt()
89 |
90 | # 构建批量评估的输入
91 | memories_to_evaluate = [
92 | {"id": event.temp_id, "content": event.memory_content} for event in events
93 | ]
94 | persona_section = (
95 | f"\n**重要:**在评估时请代入以下人格,这会影响你对“重要性”的判断:\n{persona_prompt}\n"
96 | if persona_prompt
97 | else ""
98 | )
99 | user_prompt = persona_section + json.dumps(
100 | {"memories": memories_to_evaluate}, ensure_ascii=False, indent=2
101 | )
102 |
103 | response = await self.llm_provider.text_chat(
104 | prompt=user_prompt, system_prompt=system_prompt, json_mode=True
105 | )
106 |
107 | json_text = extract_json_from_response(response.completion_text.strip())
108 | if not json_text:
109 | logger.warning("LLM 评估分数返回为空。")
110 | return {}
111 | logger.debug(
112 | f"评估分数: {json_text},对应内容{[event.temp_id for event in events]}。"
113 | )
114 |
115 | try:
116 | evaluated_data = _LLMScoreEvaluation.model_validate_json(json_text)
117 | return evaluated_data.scores
118 | except (ValidationError, json.JSONDecodeError) as e:
119 | logger.error(
120 | f"分数评估阶段JSON解析失败: {e}\n原始返回: {response.completion_text.strip()}",
121 | exc_info=True,
122 | )
123 | return {}
124 |
125 | async def reflect_and_store(
126 | self,
127 | conversation_history: List[Dict[str, str]],
128 | session_id: str,
129 | persona_id: Optional[str] = None,
130 | persona_prompt: Optional[str] = None,
131 | ):
132 | """执行完整的两阶段反思、评估和存储流程。"""
133 | try:
134 | history_text = self._format_history_for_summary(conversation_history)
135 | if not history_text:
136 | logger.debug("对话历史为空,跳过反思。")
137 | return
138 |
139 | # --- 第一阶段:提取事件 ---
140 | logger.info(f"[{session_id}] 阶段1:开始批量提取记忆事件...")
141 | extracted_events = await self._extract_events(history_text, persona_prompt)
142 | if not extracted_events:
143 | logger.info(f"[{session_id}] 未能从对话中提取任何记忆事件。")
144 | return
145 | logger.info(f"[{session_id}] 成功提取 {len(extracted_events)} 个记忆事件。")
146 |
147 | # --- 第二阶段:评估分数 ---
148 | logger.info(f"[{session_id}] 阶段2:开始批量评估事件重要性...")
149 | scores = await self._evaluate_scores(extracted_events, persona_prompt)
150 | logger.info(f"[{session_id}] 成功收到 {len(scores)} 个评分。")
151 |
152 | # --- 第三阶段:合并与存储 ---
153 | threshold = self.config.get("importance_threshold", 0.5)
154 | logger.info(f"[{session_id}] 阶段3:开始存储筛选,重要性阈值: {threshold}")
155 |
156 | stored_count = 0
157 | filtered_count = 0
158 | total_events = len(extracted_events)
159 |
160 | # 详细记录所有事件的评分情况
161 | logger.info(f"[{session_id}] 评分详情汇总:")
162 | for event in extracted_events:
163 | score = scores.get(event.temp_id)
164 | if score is None:
165 | logger.warning(
166 | f"[{session_id}] ❌ 事件 '{event.temp_id}' 未找到对应的评分,跳过存储"
167 | )
168 | filtered_count += 1
169 | continue
170 |
171 | event.importance_score = score
172 | logger.info(f"[{session_id}] 📊 事件 '{event.temp_id}': 得分={score:.3f}, 阈值={threshold:.3f}")
173 |
174 | if event.importance_score >= threshold:
175 | # MemoryEvent 的 id 将由存储后端自动生成,这里不需要手动创建
176 | # 我们只需要传递完整的元数据
177 | event_metadata = event.model_dump()
178 |
179 | # add_memory 返回的是新插入记录的整数 ID
180 | inserted_id = await self.faiss_manager.add_memory(
181 | content=event.memory_content,
182 | importance=event.importance_score,
183 | session_id=session_id,
184 | persona_id=persona_id,
185 | metadata=event_metadata,
186 | )
187 | stored_count += 1
188 | logger.info(
189 | f"[{session_id}] ✅ 存储记忆事件 (数据库ID: {inserted_id}, 临时ID: {event.temp_id}), 得分: {event.importance_score:.3f} >= {threshold:.3f}"
190 | )
191 | logger.debug(f"[{session_id}] 存储内容预览: {event.memory_content[:100]}...")
192 | else:
193 | filtered_count += 1
194 | logger.info(
195 | f"[{session_id}] ❌ 过滤记忆事件 '{event.temp_id}', 得分: {event.importance_score:.3f} < {threshold:.3f}"
196 | )
197 | logger.debug(f"[{session_id}] 被过滤内容: {event.memory_content}")
198 |
199 | # 最终统计信息
200 | logger.info(f"[{session_id}] 🏁 反思存储完成统计:")
201 | logger.info(f"[{session_id}] - 总提取事件数: {total_events}")
202 | logger.info(f"[{session_id}] - 成功存储数量: {stored_count}")
203 | logger.info(f"[{session_id}] - 过滤丢弃数量: {filtered_count}")
204 | logger.info(f"[{session_id}] - 存储率: {(stored_count/total_events)*100:.1f}%" if total_events > 0 else f"[{session_id}] - 存储率: 0%")
205 |
206 | if stored_count > 0:
207 | logger.info(f"[{session_id}] ✅ 成功存储 {stored_count} 个新的记忆事件")
208 | else:
209 | logger.warning(f"[{session_id}] ⚠️ 没有记忆事件达到存储阈值 {threshold},可能需要调整配置")
210 |
211 | except Exception as e:
212 | logger.error(
213 | f"[{session_id}] 在执行反思与存储任务时发生严重错误: {e}", exc_info=True
214 | )
215 |
216 | def _build_event_extraction_prompt(self) -> str:
217 | """构建用于第一阶段事件提取的系统 Prompt。"""
218 | schema = _LLMExtractionEventList.model_json_schema()
219 | base_prompt = self.config.get(
220 | "event_extraction_prompt",
221 | "### 角色\n你是一个善于分析和总结的AI助手。你的核心人设是从你自身的视角出发,记录与用户的互动和观察。\n\n### 指令/任务\n1. **仔细阅读**并理解下面提供的“对话历史”。\n2. 从**你(AI)的视角**出发,提取出多个独立的、有意义的记忆事件。事件必须准确描述,参考上下文。事件必须是完整的,具有前因后果的。**不允许编造事件**,**不允许改变事件**,**详细描述事件的所有信息**\n3. **核心要求**:\n * **第一人称视角**:所有事件都必须以“我”开头进行描述,例如“我告诉用户...”、“我观察到...”、“我被告知...”。\n * **使用具体名称**:直接使用对话中出现的人物昵称,**严禁**使用“用户”、“开发者”等通用词汇。\n * **记录互动者**:必须明确记录与你互动的用户名称。\n * **事件合并**:如果多条连续的对话构成一个完整的独立事件,应将其总结概括为一条记忆。\n4.**严禁**包含任何评分、额外的解释或说明性文字。\n 直接输出结果,不要有任何引言或总结。\n\n### 上下文\n* 在对话历史中,名为“AstrBot”的发言者就是**你自己**。\n* 记忆事件是:你与用户互动事的事件描述,详细记录谁、在何时、何地、做了什么、发生了什么。\n\n 'memory_content' 字段必须包含完整的事件描述,不能省略任何细节。\n\n单个系列事件必须详细记录在一个memory_content 中,形成完整的具有前因后果的事件记忆。\n\n",
222 | ).strip()
223 |
224 | return f"""{base_prompt}
225 | **核心指令**
226 | 1. **分析对话**: 从下面的对话历史中提取关键事件。
227 | 2. **格式化输出**: 必须返回一个符合以下 JSON Schema 的 JSON 对象。为每个事件生成一个临时的、唯一的 `temp_id` 字符串。
228 |
229 | **输出格式要求 (JSON Schema)**
230 | ```json
231 | {json.dumps(schema, indent=2)}
232 | ```
233 | """
234 |
235 | def _build_evaluation_prompt(self) -> str:
236 | """构建用于第二阶段批量评分的系统 Prompt。"""
237 | schema = _LLMScoreEvaluation.model_json_schema()
238 | base_prompt = self.config.get(
239 | "evaluation_prompt",
240 | "### 角色\n你是一个专门评估记忆价值的AI分析模型。你的判断标准是该记忆对于与特定用户构建长期、个性化、有上下文的对话有多大的帮助。\n\n### 指令/任务\n1. **评估核心价值**:仔细阅读“记忆内容”,评估其对于未来对话的长期参考价值。\n2. **输出分数**:给出一个介于 0.0 到 1.0 之间的浮点数分数。\n3. **格式要求**:**只返回数字**,严禁包含任何额外的文本、解释或理由。\n\n### 上下文\n评分时,请参考以下价值标尺:\n* **高价值 (0.8 - 1.0)**:包含用户的核心身份信息、明确且长期的个人偏好/厌恶、设定的目标、重要的关系或事实。这些信息几乎总能在未来的互动中被引用。\n * 例如:用户的昵称、职业、关键兴趣点、对AI的称呼、重要的人生目标。\n* **中等价值 (0.4 - 0.7)**:包含用户的具体建议、功能请求、对某事的观点或一次性的重要问题。这些信息在短期内或特定话题下很有用,但可能随着时间推移或问题解决而失去价值。\n * 例如:对某个功能的反馈、对特定新闻事件的看法、报告了一个具体的bug。\n* **低价值 (0.1 - 0.3)**:包含短暂的情绪表达、日常问候、或非常具体且不太可能重复的上下文。这些信息很少有再次利用的机会。\n * 例如:一次性的惊叹、害怕的反应、普通的“你好”、“晚安”。\n* **无价值 (0.0)**:信息完全是瞬时的、无关紧要的,或者不包含任何关于用户本人的可复用信息。\n * 例如:观察到另一个机器人说了话、对一句无法理解的话的默认回应。\n\n### 问题\n请评估以下“记忆内容”的重要性,对于未来的对话有多大的参考价值?\n\n---\n\n**记忆内容**:\n{memory_content}\n\n",
241 | ).strip()
242 |
243 | return f"""{base_prompt}
244 | **核心指令**
245 | 1. **分析输入**: 输入是一个包含多个记忆事件的 JSON 对象,每个事件都有一个 `temp_id` 和内容。
246 | 2. **评估重要性**: 对列表中的每一个事件,评估其对于未来对话的长期参考价值,给出一个 0.0 到 1.0 之间的分数。
247 | 3. **格式化输出**: 必须返回一个符合以下 JSON Schema 的 JSON 对象,key 是对应的 `temp_id`,value 是你给出的分数。
248 |
249 | **输出格式要求 (JSON Schema)**
250 | ```json
251 | {json.dumps(schema, indent=2)}
252 | ```
253 |
254 | **一个正确的输出示例**
255 | ```json
256 | {{
257 | "scores": {{
258 | "event_1": 0.8,
259 | "user_preference_1": 0.9,
260 | "project_goal_alpha": 0.95
261 | }}
262 | }}
263 | ```
264 | """
265 |
266 | def _format_history_for_summary(self, history: List[Dict[str, str]]) -> str:
267 | """
268 | 将对话历史列表格式化为单个字符串。
269 |
270 | Args:
271 | history (List[Dict[str, str]]): 对话历史。
272 |
273 | Returns:
274 | str: 格式化后的字符串。
275 | """
276 | if not history:
277 | return ""
278 |
279 | # 过滤掉非 user 和 assistant 的角色
280 | filtered_history = [
281 | msg for msg in history if msg.get("role") in ["user", "assistant"]
282 | ]
283 |
284 | return "\n".join(
285 | [f"{msg['role']}: {msg['content']}" for msg in filtered_history]
286 | )
287 |
--------------------------------------------------------------------------------
/core/handlers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | handlers - 业务逻辑处理器模块
4 | 提供插件命令的具体业务逻辑实现
5 | """
6 |
7 | from .base_handler import BaseHandler
8 | from .memory_handler import MemoryHandler
9 | from .search_handler import SearchHandler
10 | from .admin_handler import AdminHandler
11 | from .fusion_handler import FusionHandler
12 |
13 | __all__ = [
14 | 'BaseHandler',
15 | 'MemoryHandler',
16 | 'SearchHandler',
17 | 'AdminHandler',
18 | 'FusionHandler'
19 | ]
--------------------------------------------------------------------------------
/core/handlers/admin_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | admin_handler.py - 管理员业务逻辑
4 | 处理状态查看、配置管理、遗忘代理等管理员功能
5 | """
6 |
7 | from typing import Optional, Dict, Any
8 |
9 | from astrbot.api import logger
10 | from astrbot.api.star import Context
11 |
12 | from .base_handler import BaseHandler
13 |
14 |
15 | class AdminHandler(BaseHandler):
16 | """管理员业务逻辑处理器"""
17 |
18 | def __init__(self, context: Context, config: Dict[str, Any], faiss_manager=None, forgetting_agent=None, session_manager=None):
19 | super().__init__(context, config)
20 | self.faiss_manager = faiss_manager
21 | self.forgetting_agent = forgetting_agent
22 | self.session_manager = session_manager
23 |
24 | async def process(self, *args, **kwargs) -> Dict[str, Any]:
25 | """处理请求的抽象方法实现"""
26 | return self.create_response(True, "AdminHandler process method")
27 |
28 | async def get_memory_status(self) -> Dict[str, Any]:
29 | """获取记忆库状态"""
30 | if not self.faiss_manager or not self.faiss_manager.db:
31 | return self.create_response(False, "记忆库尚未初始化")
32 |
33 | try:
34 | count = await self.faiss_manager.db.count_documents()
35 | return self.create_response(True, "获取记忆库状态成功", {"total_count": count})
36 | except Exception as e:
37 | logger.error(f"获取记忆库状态失败: {e}", exc_info=True)
38 | return self.create_response(False, f"获取记忆库状态失败: {e}")
39 |
40 | async def delete_memory(self, doc_id: int) -> Dict[str, Any]:
41 | """删除指定记忆"""
42 | if not self.faiss_manager:
43 | return self.create_response(False, "记忆库尚未初始化")
44 |
45 | try:
46 | await self.faiss_manager.delete_memories([doc_id])
47 | return self.create_response(True, f"已成功删除 ID 为 {doc_id} 的记忆")
48 | except Exception as e:
49 | logger.error(f"删除记忆时发生错误: {e}", exc_info=True)
50 | return self.create_response(False, f"删除记忆时发生错误: {e}")
51 |
52 | async def run_forgetting_agent(self) -> Dict[str, Any]:
53 | """手动触发遗忘代理"""
54 | if not self.forgetting_agent:
55 | return self.create_response(False, "遗忘代理尚未初始化")
56 |
57 | try:
58 | await self.forgetting_agent._prune_memories()
59 | return self.create_response(True, "遗忘代理任务执行完毕")
60 | except Exception as e:
61 | logger.error(f"遗忘代理任务执行失败: {e}", exc_info=True)
62 | return self.create_response(False, f"遗忘代理任务执行失败: {e}")
63 |
64 | async def set_search_mode(self, mode: str) -> Dict[str, Any]:
65 | """设置检索模式"""
66 | valid_modes = ["hybrid", "dense", "sparse"]
67 | if mode not in valid_modes:
68 | return self.create_response(False, f"无效的模式,请使用: {', '.join(valid_modes)}")
69 |
70 | # 注意:这个方法需要 recall_engine 实例,暂时通过 config 传递
71 | # 实际使用时需要在调用此方法前传入 recall_engine
72 | return self.create_response(True, f"检索模式已设置为: {mode}")
73 |
74 | async def get_config_summary(self, action: str = "show") -> Dict[str, Any]:
75 | """获取配置摘要或验证配置"""
76 | if action == "show":
77 | try:
78 | # 显示主要配置项
79 | config_summary = {
80 | "session_manager": {
81 | "max_sessions": self.config.get("session_manager", {}).get("max_sessions", 1000),
82 | "session_ttl": self.config.get("session_manager", {}).get("session_ttl", 3600),
83 | "current_sessions": self.session_manager.get_session_count() if self.session_manager else 0
84 | },
85 | "recall_engine": {
86 | "retrieval_mode": self.config.get("recall_engine", {}).get("retrieval_mode", "hybrid"),
87 | "top_k": self.config.get("recall_engine", {}).get("top_k", 5),
88 | "recall_strategy": self.config.get("recall_engine", {}).get("recall_strategy", "weighted")
89 | },
90 | "reflection_engine": {
91 | "summary_trigger_rounds": self.config.get("reflection_engine", {}).get("summary_trigger_rounds", 10),
92 | "importance_threshold": self.config.get("reflection_engine", {}).get("importance_threshold", 0.5)
93 | },
94 | "forgetting_agent": {
95 | "enabled": self.config.get("forgetting_agent", {}).get("enabled", True),
96 | "check_interval_hours": self.config.get("forgetting_agent", {}).get("check_interval_hours", 24),
97 | "retention_days": self.config.get("forgetting_agent", {}).get("retention_days", 90)
98 | }
99 | }
100 |
101 | return self.create_response(True, "获取配置摘要成功", config_summary)
102 |
103 | except Exception as e:
104 | return self.create_response(False, f"显示配置时发生错误: {e}")
105 |
106 | elif action == "validate":
107 | try:
108 | from ..config_validator import validate_config
109 | # 重新验证当前配置
110 | validate_config(self.config)
111 | return self.create_response(True, "配置验证通过,所有参数均有效")
112 |
113 | except Exception as e:
114 | return self.create_response(False, f"配置验证失败: {e}")
115 |
116 | else:
117 | return self.create_response(False, "无效的动作,请使用 'show' 或 'validate'")
118 |
119 | def format_status_for_display(self, response: Dict[str, Any]) -> str:
120 | """格式化状态信息用于显示"""
121 | if not response.get("success"):
122 | return response.get("message", "获取失败")
123 |
124 | data = response.get("data", {})
125 | total_count = data.get("total_count", 0)
126 |
127 | return f"📊 LivingMemory 记忆库状态:\n- 总记忆数: {total_count}"
128 |
129 | def format_config_summary_for_display(self, response: Dict[str, Any]) -> str:
130 | """格式化配置摘要用于显示"""
131 | if not response.get("success"):
132 | return response.get("message", "获取失败")
133 |
134 | data = response.get("data", {})
135 |
136 | config_summary = ["📋 LivingMemory 配置摘要:"]
137 | config_summary.append("")
138 |
139 | # 会话管理器配置
140 | sm_config = data.get("session_manager", {})
141 | config_summary.append(f"🗂️ 会话管理:")
142 | config_summary.append(f" - 最大会话数: {sm_config.get('max_sessions', 1000)}")
143 | config_summary.append(f" - 会话TTL: {sm_config.get('session_ttl', 3600)}秒")
144 | config_summary.append(f" - 当前会话数: {sm_config.get('current_sessions', 0)}")
145 | config_summary.append("")
146 |
147 | # 回忆引擎配置
148 | re_config = data.get("recall_engine", {})
149 | config_summary.append(f"🧠 回忆引擎:")
150 | config_summary.append(f" - 检索模式: {re_config.get('retrieval_mode', 'hybrid')}")
151 | config_summary.append(f" - 返回数量: {re_config.get('top_k', 5)}")
152 | config_summary.append(f" - 召回策略: {re_config.get('recall_strategy', 'weighted')}")
153 | config_summary.append("")
154 |
155 | # 反思引擎配置
156 | rf_config = data.get("reflection_engine", {})
157 | config_summary.append(f"💭 反思引擎:")
158 | config_summary.append(f" - 触发轮次: {rf_config.get('summary_trigger_rounds', 10)}")
159 | config_summary.append(f" - 重要性阈值: {rf_config.get('importance_threshold', 0.5)}")
160 | config_summary.append("")
161 |
162 | # 遗忘代理配置
163 | fa_config = data.get("forgetting_agent", {})
164 | config_summary.append(f"🗑️ 遗忘代理:")
165 | config_summary.append(f" - 启用状态: {'是' if fa_config.get('enabled', True) else '否'}")
166 | config_summary.append(f" - 检查间隔: {fa_config.get('check_interval_hours', 24)}小时")
167 | config_summary.append(f" - 保留天数: {fa_config.get('retention_days', 90)}天")
168 |
169 | return "\n".join(config_summary)
--------------------------------------------------------------------------------
/core/handlers/base_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | base_handler.py - 基础处理器类
4 | 提供业务逻辑处理的基础功能和通用方法
5 | """
6 |
7 | from abc import ABC, abstractmethod
8 | from typing import Optional, Dict, Any, List
9 | from datetime import datetime, timezone
10 | import json
11 |
12 | from astrbot.api import logger
13 | from astrbot.api.star import Context
14 |
15 |
16 | class BaseHandler(ABC):
17 | """基础处理器类,提供通用的业务逻辑功能"""
18 |
19 | def __init__(self, context: Context, config: Dict[str, Any]):
20 | self.context = context
21 | self.config = config
22 |
23 | @abstractmethod
24 | async def process(self, *args, **kwargs) -> Dict[str, Any]:
25 | """处理请求的抽象方法"""
26 | pass
27 |
28 | def get_timezone(self) -> Any:
29 | """获取当前时区"""
30 | tz_config = self.config.get("timezone_settings", {})
31 | tz_str = tz_config.get("timezone", "Asia/Shanghai")
32 | try:
33 | import pytz
34 | return pytz.timezone(tz_str)
35 | except ImportError:
36 | # 如果pytz不可用,返回UTC
37 | from datetime import timezone
38 | return timezone.utc
39 |
40 | def format_timestamp(self, ts: Optional[float]) -> str:
41 | """格式化时间戳"""
42 | if not ts:
43 | return "未知"
44 | try:
45 | dt_utc = datetime.fromtimestamp(float(ts), tz=timezone.utc)
46 | dt_local = dt_utc.astimezone(self.get_timezone())
47 | return dt_local.strftime("%Y-%m-%d %H:%M:%S")
48 | except (ValueError, TypeError):
49 | return "未知"
50 |
51 | def safe_parse_metadata(self, metadata: Any) -> Dict[str, Any]:
52 | """安全解析元数据"""
53 | if isinstance(metadata, dict):
54 | return metadata
55 | if isinstance(metadata, str):
56 | try:
57 | return json.loads(metadata)
58 | except json.JSONDecodeError:
59 | return {}
60 | return {}
61 |
62 | def format_memory_card(self, result: Any) -> str:
63 | """格式化记忆卡片显示"""
64 | metadata = self.safe_parse_metadata(result.data.get("metadata", {}))
65 |
66 | create_time_str = self.format_timestamp(metadata.get("create_time"))
67 | last_access_time_str = self.format_timestamp(metadata.get("last_access_time"))
68 | importance_score = metadata.get("importance", 0.0)
69 | event_type = metadata.get("event_type", "未知")
70 |
71 | card = (
72 | f"ID: {result.data['id']}\n"
73 | f"记 忆 度: {result.similarity:.2f}\n"
74 | f"重 要 性: {importance_score:.2f}\n"
75 | f"记忆类型: {event_type}\n\n"
76 | f"内容: {result.data['text']}\n\n"
77 | f"创建于: {create_time_str}\n"
78 | f"最后访问: {last_access_time_str}"
79 | )
80 | return card
81 |
82 | def create_response(self, success: bool = True, message: str = "", data: Any = None) -> Dict[str, Any]:
83 | """创建标准响应格式"""
84 | return {
85 | "success": success,
86 | "message": message,
87 | "data": data
88 | }
89 |
90 |
91 | class TestableBaseHandler(BaseHandler):
92 | """用于测试的基础处理器实现"""
93 |
94 | async def process(self, *args, **kwargs) -> Dict[str, Any]:
95 | """测试用的process方法实现"""
96 | return self.create_response(True, "Test response")
--------------------------------------------------------------------------------
/core/handlers/fusion_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | fusion_handler.py - 融合策略业务逻辑
4 | 处理检索融合策略的管理和测试
5 | """
6 |
7 | from typing import Optional, Dict, Any, List
8 |
9 | from astrbot.api import logger
10 | from astrbot.api.star import Context
11 |
12 | from .base_handler import BaseHandler
13 |
14 |
15 | class FusionHandler(BaseHandler):
16 | """融合策略业务逻辑处理器"""
17 |
18 | def __init__(self, context: Context, config: Dict[str, Any], recall_engine=None):
19 | super().__init__(context, config)
20 | self.recall_engine = recall_engine
21 |
22 | async def process(self, *args, **kwargs) -> Dict[str, Any]:
23 | """处理请求的抽象方法实现"""
24 | return self.create_response(True, "FusionHandler process method")
25 |
26 | async def manage_fusion_strategy(self, strategy: str = "show", param: str = "") -> Dict[str, Any]:
27 | """管理检索融合策略"""
28 | if not self.recall_engine:
29 | return self.create_response(False, "回忆引擎尚未初始化")
30 |
31 | if strategy == "show":
32 | # 显示当前融合配置
33 | fusion_config = self.config.get("fusion", {})
34 | current_strategy = fusion_config.get("strategy", "rrf")
35 |
36 | config_data = {
37 | "current_strategy": current_strategy,
38 | "fusion_config": fusion_config
39 | }
40 |
41 | return self.create_response(True, "获取融合配置成功", config_data)
42 |
43 | elif strategy in ["rrf", "hybrid_rrf", "weighted", "convex", "interleave",
44 | "rank_fusion", "score_fusion", "cascade", "adaptive"]:
45 |
46 | # 更新融合策略
47 | if "fusion" not in self.config:
48 | self.config["fusion"] = {}
49 |
50 | old_strategy = self.config["fusion"].get("strategy", "rrf")
51 | self.config["fusion"]["strategy"] = strategy
52 |
53 | # 处理参数
54 | param_result = await self._process_fusion_param(param, strategy)
55 | if not param_result["success"]:
56 | return param_result
57 |
58 | # 更新 RecallEngine 中的融合配置
59 | update_result = await self._update_recall_engine_fusion_config(strategy, self.config["fusion"])
60 | if not update_result["success"]:
61 | return update_result
62 |
63 | return self.create_response(True, f"融合策略已从 '{old_strategy}' 更新为 '{strategy}'{f' (参数: {param})' if param else ''}")
64 |
65 | else:
66 | return self.create_response(False, "不支持的融合策略。使用 show 查看可用选项。")
67 |
68 | async def test_fusion_strategy(self, query: str, k: int = 5) -> Dict[str, Any]:
69 | """测试融合策略效果"""
70 | if not self.recall_engine:
71 | return self.create_response(False, "回忆引擎尚未初始化")
72 |
73 | try:
74 | # 执行搜索
75 | session_id = await self.context.conversation_manager.get_curr_conversation_id(None)
76 | from ..utils import get_persona_id
77 | persona_id = await get_persona_id(self.context, None)
78 |
79 | results = await self.recall_engine.recall(
80 | self.context, query, session_id, persona_id, k
81 | )
82 |
83 | if not results:
84 | return self.create_response(True, "未找到相关记忆", [])
85 |
86 | # 格式化结果
87 | formatted_results = []
88 | fusion_config = self.config.get("fusion", {})
89 | current_strategy = fusion_config.get("strategy", "rrf")
90 |
91 | for result in results:
92 | metadata = self.safe_parse_metadata(result.data.get("metadata", {}))
93 | formatted_results.append({
94 | "id": result.data['id'],
95 | "similarity": result.similarity,
96 | "text": result.data['text'],
97 | "importance": metadata.get("importance", 0.0),
98 | "event_type": metadata.get("event_type", "未知")
99 | })
100 |
101 | test_data = {
102 | "query": query,
103 | "strategy": current_strategy,
104 | "fusion_config": fusion_config,
105 | "results": formatted_results
106 | }
107 |
108 | return self.create_response(True, f"融合测试完成,找到 {len(results)} 条结果", test_data)
109 |
110 | except Exception as e:
111 | logger.error(f"融合策略测试失败: {e}", exc_info=True)
112 | return self.create_response(False, f"测试失败: {e}")
113 |
114 | async def _process_fusion_param(self, param: str, strategy: str) -> Dict[str, Any]:
115 | """处理融合策略参数"""
116 | if not param or "=" not in param:
117 | return self.create_response(True, "无参数需要处理")
118 |
119 | try:
120 | key, value = param.split("=", 1)
121 | key = key.strip()
122 | value = value.strip()
123 |
124 | # 验证参数名
125 | valid_params = {
126 | "dense_weight", "sparse_weight", "rrf_k", "convex_lambda",
127 | "interleave_ratio", "rank_bias_factor", "diversity_bonus"
128 | }
129 |
130 | if key not in valid_params:
131 | return self.create_response(False, f"无效的参数名: {key}。支持的参数: {', '.join(sorted(valid_params))}")
132 |
133 | # 验证参数值
134 | try:
135 | if key in ["dense_weight", "sparse_weight", "convex_lambda", "interleave_ratio", "rank_bias_factor", "diversity_bonus"]:
136 | param_value = float(value)
137 | else:
138 | param_value = int(value)
139 | except ValueError:
140 | return self.create_response(False, f"参数 {key} 的值类型无效: {value}")
141 |
142 | # 参数范围和约束检查
143 | param_constraints = {
144 | "dense_weight": (0.0, 1.0, "必须在 0.0-1.0 范围内"),
145 | "sparse_weight": (0.0, 1.0, "必须在 0.0-1.0 范围内"),
146 | "convex_lambda": (0.0, 1.0, "必须在 0.0-1.0 范围内"),
147 | "interleave_ratio": (0.0, 1.0, "必须在 0.0-1.0 范围内"),
148 | "rank_bias_factor": (0.0, 1.0, "必须在 0.0-1.0 范围内"),
149 | "diversity_bonus": (0.0, 1.0, "必须在 0.0-1.0 范围内"),
150 | "rrf_k": (1, 1000, "必须是正整数")
151 | }
152 |
153 | if key in param_constraints:
154 | min_val, max_val, error_msg = param_constraints[key]
155 | if not min_val <= param_value <= max_val:
156 | return self.create_response(False, f"参数 {key} {error_msg}")
157 |
158 | # 策略特定参数验证
159 | strategy_params = {
160 | "rrf": ["rrf_k"],
161 | "hybrid_rrf": ["rrf_k", "diversity_bonus"],
162 | "weighted": ["dense_weight", "sparse_weight"],
163 | "convex": ["dense_weight", "sparse_weight", "convex_lambda"],
164 | "interleave": ["interleave_ratio"],
165 | "rank_fusion": ["dense_weight", "sparse_weight", "rank_bias_factor"],
166 | "score_fusion": ["dense_weight", "sparse_weight"],
167 | "cascade": ["dense_weight", "sparse_weight"],
168 | "adaptive": ["dense_weight", "sparse_weight"]
169 | }
170 |
171 | if strategy in strategy_params and key not in strategy_params[strategy]:
172 | return self.create_response(False, f"参数 {key} 不适用于策略 {strategy}")
173 |
174 | # 权重和检查(对于需要权重的策略)
175 | if key in ["dense_weight", "sparse_weight"]:
176 | other_key = "sparse_weight" if key == "dense_weight" else "dense_weight"
177 | other_value = self.config["fusion"].get(other_key, 0.3 if other_key == "sparse_weight" else 0.7)
178 |
179 | # 如果设置了新的权重,检查和是否超过1.0
180 | if key + other_key in [k for k in strategy_params.get(strategy, []) if k in ["dense_weight", "sparse_weight"]]:
181 | total_weight = param_value + other_value
182 | if total_weight > 1.0:
183 | return self.create_response(False, f"权重总和不能超过 1.0 (当前总和: {total_weight:.2f})")
184 |
185 | self.config["fusion"][key] = param_value
186 | logger.info(f"更新融合参数 {key} = {param_value}")
187 |
188 | return self.create_response(True, "参数处理成功")
189 |
190 | except Exception as e:
191 | return self.create_response(False, f"参数解析错误: {e}")
192 |
193 | async def _update_recall_engine_fusion_config(self, strategy: str, fusion_config: Dict[str, Any]) -> Dict[str, Any]:
194 | """更新RecallEngine的融合配置"""
195 | try:
196 | if hasattr(self.recall_engine, 'result_fusion'):
197 | self.recall_engine.update_fusion_config(strategy, fusion_config)
198 | else:
199 | logger.warning("RecallEngine 没有 result_fusion 属性,跳过更新")
200 |
201 | return self.create_response(True, "融合配置更新成功")
202 | except AttributeError:
203 | # 如果 RecallEngine 没有 update_fusion_config 方法,则直接更新属性
204 | try:
205 | if hasattr(self.recall_engine, 'result_fusion'):
206 | fusion_obj = self.recall_engine.result_fusion
207 | fusion_obj.strategy = strategy
208 | fusion_obj.config = fusion_config
209 |
210 | # 更新融合器的参数
211 | fusion_obj.dense_weight = fusion_config.get("dense_weight", 0.7)
212 | fusion_obj.sparse_weight = fusion_config.get("sparse_weight", 0.3)
213 | fusion_obj.rrf_k = fusion_config.get("rrf_k", 60)
214 | fusion_obj.convex_lambda = fusion_config.get("convex_lambda", 0.5)
215 | fusion_obj.interleave_ratio = fusion_config.get("interleave_ratio", 0.5)
216 | fusion_obj.rank_bias_factor = fusion_config.get("rank_bias_factor", 0.1)
217 |
218 | return self.create_response(True, "融合配置更新成功")
219 | except Exception as e:
220 | logger.error(f"更新融合配置时出错: {e}")
221 | return self.create_response(False, f"配置已更新,但引擎同步可能失败: {e}")
222 | except Exception as e:
223 | logger.error(f"更新融合配置时出错: {e}")
224 | return self.create_response(False, f"更新融合配置失败: {e}")
225 |
226 | def format_fusion_config_for_display(self, response: Dict[str, Any]) -> str:
227 | """格式化融合配置用于显示"""
228 | if not response.get("success"):
229 | return response.get("message", "获取失败")
230 |
231 | data = response.get("data", {})
232 | current_strategy = data.get("current_strategy", "rrf")
233 | fusion_config = data.get("fusion_config", {})
234 |
235 | response_parts = ["🔄 当前检索融合配置:"]
236 | response_parts.append(f"策略: {current_strategy}")
237 | response_parts.append("")
238 |
239 | if current_strategy in ["rrf", "hybrid_rrf"]:
240 | response_parts.append(f"RRF参数k: {fusion_config.get('rrf_k', 60)}")
241 | if current_strategy == "hybrid_rrf":
242 | response_parts.append(f"多样性奖励: {fusion_config.get('diversity_bonus', 0.1)}")
243 |
244 | if current_strategy in ["weighted", "convex", "rank_fusion", "score_fusion"]:
245 | response_parts.append(f"密集权重: {fusion_config.get('dense_weight', 0.7)}")
246 | response_parts.append(f"稀疏权重: {fusion_config.get('sparse_weight', 0.3)}")
247 |
248 | if current_strategy == "convex":
249 | response_parts.append(f"凸组合λ: {fusion_config.get('convex_lambda', 0.5)}")
250 |
251 | if current_strategy == "interleave":
252 | response_parts.append(f"交替比例: {fusion_config.get('interleave_ratio', 0.5)}")
253 |
254 | if current_strategy == "rank_fusion":
255 | response_parts.append(f"排序偏置: {fusion_config.get('rank_bias_factor', 0.1)}")
256 |
257 | response_parts.append("")
258 | response_parts.append("💡 各策略特点:")
259 | response_parts.append("• rrf: 经典方法,平衡性好")
260 | response_parts.append("• hybrid_rrf: 动态调整,适应查询类型")
261 | response_parts.append("• weighted: 简单加权,可解释性强")
262 | response_parts.append("• convex: 凸组合,数学严格")
263 | response_parts.append("• interleave: 交替选择,保证多样性")
264 | response_parts.append("• rank_fusion: 基于排序位置")
265 | response_parts.append("• score_fusion: Borda Count投票")
266 | response_parts.append("• cascade: 稀疏初筛+密集精排")
267 | response_parts.append("• adaptive: 根据查询自适应")
268 |
269 | return "\n".join(response_parts)
270 |
271 | def format_fusion_test_for_display(self, response: Dict[str, Any]) -> str:
272 | """格式化融合测试结果用于显示"""
273 | if not response.get("success"):
274 | return response.get("message", "测试失败")
275 |
276 | data = response.get("data", {})
277 | query = data.get("query", "")
278 | strategy = data.get("strategy", "rrf")
279 | fusion_config = data.get("fusion_config", {})
280 | results = data.get("results", [])
281 |
282 | response_parts = [f"🎯 融合测试结果 (策略: {strategy})"]
283 | response_parts.append("=" * 50)
284 |
285 | for i, result in enumerate(results, 1):
286 | response_parts.append(f"\n{i}. [ID: {result['id']}] 分数: {result['similarity']:.4f}")
287 | response_parts.append(f" 重要性: {result['importance']:.3f} | 类型: {result['event_type']}")
288 | response_parts.append(f" 内容: {result['text'][:100]}{'...' if len(result['text']) > 100 else ''}")
289 |
290 | response_parts.append("\n" + "=" * 50)
291 | response_parts.append(f"💡 当前融合配置:")
292 | response_parts.append(f" 策略: {strategy}")
293 | if strategy in ["rrf", "hybrid_rrf"]:
294 | response_parts.append(f" RRF-k: {fusion_config.get('rrf_k', 60)}")
295 | if strategy in ["weighted", "convex"]:
296 | response_parts.append(f" 密集权重: {fusion_config.get('dense_weight', 0.7)}")
297 | response_parts.append(f" 稀疏权重: {fusion_config.get('sparse_weight', 0.3)}")
298 |
299 | return "\n".join(response_parts)
--------------------------------------------------------------------------------
/core/handlers/memory_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | memory_handler.py - 记忆管理业务逻辑
4 | 处理记忆的编辑、更新、历史查看等业务逻辑
5 | """
6 |
7 | import json
8 | from typing import Optional, Dict, Any, List
9 | from datetime import datetime, timezone
10 |
11 | from astrbot.api import logger
12 | from astrbot.api.star import Context
13 |
14 | from .base_handler import BaseHandler
15 |
16 |
17 | class MemoryHandler(BaseHandler):
18 | """记忆管理业务逻辑处理器"""
19 |
20 | def __init__(self, context: Context, config: Dict[str, Any], faiss_manager):
21 | super().__init__(context, config)
22 | self.faiss_manager = faiss_manager
23 |
24 | async def process(self, *args, **kwargs) -> Dict[str, Any]:
25 | """处理请求的抽象方法实现"""
26 | return self.create_response(True, "MemoryHandler process method")
27 |
28 | async def edit_memory(self, memory_id: str, field: str, value: str, reason: str = "") -> Dict[str, Any]:
29 | """编辑记忆内容或元数据"""
30 | if not self.faiss_manager:
31 | return self.create_response(False, "记忆库尚未初始化")
32 |
33 | try:
34 | # 解析 memory_id 为整数或字符串
35 | try:
36 | memory_id_int = int(memory_id)
37 | memory_id_to_use = memory_id_int
38 | except ValueError:
39 | memory_id_to_use = memory_id
40 |
41 | # 解析字段和值
42 | updates = {}
43 |
44 | if field == "content":
45 | updates["content"] = value
46 | elif field == "importance":
47 | try:
48 | updates["importance"] = float(value)
49 | if not 0.0 <= updates["importance"] <= 1.0:
50 | return self.create_response(False, "重要性评分必须在 0.0 到 1.0 之间")
51 | except ValueError:
52 | return self.create_response(False, "重要性评分必须是数字")
53 | elif field == "type":
54 | valid_types = ["FACT", "PREFERENCE", "GOAL", "OPINION", "RELATIONSHIP", "OTHER"]
55 | if value not in valid_types:
56 | return self.create_response(False, f"无效的事件类型,必须是: {', '.join(valid_types)}")
57 | updates["event_type"] = value
58 | elif field == "status":
59 | valid_statuses = ["active", "archived", "deleted"]
60 | if value not in valid_statuses:
61 | return self.create_response(False, f"无效的状态,必须是: {', '.join(valid_statuses)}")
62 | updates["status"] = value
63 | else:
64 | return self.create_response(False, f"未知的字段 '{field}',支持的字段: content, importance, type, status")
65 |
66 | # 执行更新
67 | result = await self.faiss_manager.update_memory(
68 | memory_id=memory_id_to_use,
69 | update_reason=reason or f"更新{field}",
70 | **updates
71 | )
72 |
73 | if result["success"]:
74 | # 构建响应消息
75 | response_parts = [f"✅ {result['message']}"]
76 |
77 | if result["updated_fields"]:
78 | response_parts.append("\n📋 已更新的字段:")
79 | for f in result["updated_fields"]:
80 | response_parts.append(f" - {f}")
81 |
82 | # 如果更新了内容,显示预览
83 | if "content" in updates and len(updates["content"]) > 100:
84 | response_parts.append(f"\n📝 内容预览: {updates['content'][:100]}...")
85 |
86 | return self.create_response(True, "\n".join(response_parts), result)
87 | else:
88 | return self.create_response(False, result['message'])
89 |
90 | except Exception as e:
91 | logger.error(f"编辑记忆时发生错误: {e}", exc_info=True)
92 | return self.create_response(False, f"编辑记忆时发生错误: {e}")
93 |
94 | async def get_memory_details(self, memory_id: str) -> Dict[str, Any]:
95 | """获取记忆详细信息"""
96 | if not self.faiss_manager:
97 | return self.create_response(False, "记忆库尚未初始化")
98 |
99 | try:
100 | # 解析 memory_id
101 | try:
102 | memory_id_int = int(memory_id)
103 | docs = await self.faiss_manager.db.document_storage.get_documents(ids=[memory_id_int])
104 | except ValueError:
105 | docs = await self.faiss_manager.db.document_storage.get_documents(
106 | metadata_filters={"memory_id": memory_id}
107 | )
108 |
109 | if not docs:
110 | return self.create_response(False, f"未找到ID为 {memory_id} 的记忆")
111 |
112 | doc = docs[0]
113 | metadata = self.safe_parse_metadata(doc["metadata"])
114 |
115 | # 构建详细信息
116 | details = {
117 | "id": memory_id,
118 | "content": doc["content"],
119 | "metadata": metadata,
120 | "create_time": self.format_timestamp(metadata.get("create_time")),
121 | "last_access_time": self.format_timestamp(metadata.get("last_access_time")),
122 | "importance": metadata.get("importance", "N/A"),
123 | "event_type": metadata.get("event_type", "N/A"),
124 | "status": metadata.get("status", "active"),
125 | "update_history": metadata.get("update_history", [])
126 | }
127 |
128 | return self.create_response(True, "获取记忆详细信息成功", details)
129 |
130 | except Exception as e:
131 | logger.error(f"获取记忆详细信息时发生错误: {e}", exc_info=True)
132 | return self.create_response(False, f"获取记忆详细信息时发生错误: {e}")
133 |
134 | async def get_memory_history(self, memory_id: str) -> Dict[str, Any]:
135 | """获取记忆更新历史"""
136 | if not self.faiss_manager or not self.faiss_manager.db:
137 | return self.create_response(False, "记忆库尚未初始化")
138 |
139 | try:
140 | # 解析 memory_id
141 | try:
142 | memory_id_int = int(memory_id)
143 | docs = await self.faiss_manager.db.document_storage.get_documents(ids=[memory_id_int])
144 | except ValueError:
145 | docs = await self.faiss_manager.db.document_storage.get_documents(
146 | metadata_filters={"memory_id": memory_id}
147 | )
148 |
149 | if not docs:
150 | return self.create_response(False, f"未找到ID为 {memory_id} 的记忆")
151 |
152 | doc = docs[0]
153 | metadata = self.safe_parse_metadata(doc["metadata"])
154 |
155 | # 构建历史信息
156 | history_info = {
157 | "id": memory_id,
158 | "content": doc["content"],
159 | "metadata": {
160 | "importance": metadata.get("importance", "N/A"),
161 | "event_type": metadata.get("event_type", "N/A"),
162 | "status": metadata.get("status", "active"),
163 | "create_time": self.format_timestamp(metadata.get("create_time"))
164 | },
165 | "update_history": metadata.get("update_history", [])
166 | }
167 |
168 | return self.create_response(True, "获取记忆历史成功", history_info)
169 |
170 | except Exception as e:
171 | logger.error(f"获取记忆历史时发生错误: {e}", exc_info=True)
172 | return self.create_response(False, f"获取记忆历史时发生错误: {e}")
173 |
174 | def format_memory_details_for_display(self, details: Dict[str, Any]) -> str:
175 | """格式化记忆详细信息用于显示"""
176 | if not details.get("success"):
177 | return details.get("message", "获取失败")
178 |
179 | data = details.get("data", {})
180 | response_parts = [f"📝 记忆 {data['id']} 的详细信息:"]
181 | response_parts.append("=" * 50)
182 |
183 | # 内容
184 | response_parts.append(f"\n📄 内容:")
185 | response_parts.append(f"{data['content']}")
186 |
187 | # 基本信息
188 | response_parts.append(f"\n📊 基本信息:")
189 | response_parts.append(f"- ID: {data['id']}")
190 | response_parts.append(f"- 重要性: {data['importance']}")
191 | response_parts.append(f"- 类型: {data['event_type']}")
192 | response_parts.append(f"- 状态: {data['status']}")
193 |
194 | # 时间信息
195 | if data['create_time'] != "未知":
196 | response_parts.append(f"- 创建时间: {data['create_time']}")
197 | if data['last_access_time'] != "未知":
198 | response_parts.append(f"- 最后访问: {data['last_access_time']}")
199 |
200 | # 更新历史
201 | update_history = data.get('update_history', [])
202 | if update_history:
203 | response_parts.append(f"\n🔄 更新历史 ({len(update_history)} 次):")
204 | for i, update in enumerate(update_history[-3:], 1): # 只显示最近3次
205 | timestamp = update.get('timestamp')
206 | if timestamp:
207 | time_str = self.format_timestamp(timestamp)
208 | else:
209 | time_str = "未知"
210 |
211 | response_parts.append(f"\n{i}. {time_str}")
212 | response_parts.append(f" 原因: {update.get('reason', 'N/A')}")
213 | response_parts.append(f" 字段: {', '.join(update.get('fields', []))}")
214 |
215 | # 编辑指引
216 | response_parts.append(f"\n" + "=" * 50)
217 | response_parts.append(f"\n🛠️ 编辑指引:")
218 | response_parts.append(f"使用以下命令编辑此记忆:")
219 | response_parts.append(f"\n• 编辑内容:")
220 | response_parts.append(f" /lmem edit {data['id']} content <新内容> [原因]")
221 | response_parts.append(f"\n• 编辑重要性:")
222 | response_parts.append(f" /lmem edit {data['id']} importance <0.0-1.0> [原因]")
223 | response_parts.append(f"\n• 编辑类型:")
224 | response_parts.append(f" /lmem edit {data['id']} type [原因]")
225 | response_parts.append(f"\n• 编辑状态:")
226 | response_parts.append(f" /lmem edit {data['id']} status [原因]")
227 |
228 | # 示例
229 | response_parts.append(f"\n💡 示例:")
230 | response_parts.append(f" /lmem edit {data['id']} importance 0.9 提高重要性评分")
231 | response_parts.append(f" /lmem edit {data['id']} type PREFERENCE 重新分类为偏好")
232 |
233 | return "\n".join(response_parts)
234 |
235 | def format_memory_history_for_display(self, history: Dict[str, Any]) -> str:
236 | """格式化记忆历史用于显示"""
237 | if not history.get("success"):
238 | return history.get("message", "获取失败")
239 |
240 | data = history.get("data", {})
241 | metadata = data.get("metadata", {})
242 |
243 | response_parts = [f"📝 记忆 {data['id']} 的详细信息:"]
244 | response_parts.append(f"\n内容: {data['content']}")
245 |
246 | # 基本信息
247 | response_parts.append(f"\n📊 基本信息:")
248 | response_parts.append(f"- 重要性: {metadata['importance']}")
249 | response_parts.append(f"- 类型: {metadata['event_type']}")
250 | response_parts.append(f"- 状态: {metadata['status']}")
251 |
252 | # 时间信息
253 | if metadata.get('create_time') != "未知":
254 | response_parts.append(f"- 创建时间: {metadata['create_time']}")
255 |
256 | # 更新历史
257 | update_history = data.get('update_history', [])
258 | if update_history:
259 | response_parts.append(f"\n🔄 更新历史 ({len(update_history)} 次):")
260 | for i, update in enumerate(update_history[-5:], 1): # 只显示最近5次
261 | timestamp = update.get('timestamp')
262 | if timestamp:
263 | time_str = self.format_timestamp(timestamp)
264 | else:
265 | time_str = "未知"
266 |
267 | response_parts.append(f"\n{i}. {time_str}")
268 | response_parts.append(f" 原因: {update.get('reason', 'N/A')}")
269 | response_parts.append(f" 字段: {', '.join(update.get('fields', []))}")
270 | else:
271 | response_parts.append("\n🔄 暂无更新记录")
272 |
273 | return "\n".join(response_parts)
--------------------------------------------------------------------------------
/core/handlers/search_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | search_handler.py - 搜索管理业务逻辑
4 | 处理记忆搜索、稀疏检索测试等业务逻辑
5 | """
6 |
7 | import json
8 | from typing import Optional, Dict, Any, List
9 |
10 | from astrbot.api import logger
11 | from astrbot.api.star import Context
12 |
13 | from .base_handler import BaseHandler
14 |
15 |
16 | class SearchHandler(BaseHandler):
17 | """搜索管理业务逻辑处理器"""
18 |
19 | def __init__(self, context: Context, config: Dict[str, Any], recall_engine=None, sparse_retriever=None):
20 | super().__init__(context, config)
21 | self.recall_engine = recall_engine
22 | self.sparse_retriever = sparse_retriever
23 |
24 | async def process(self, *args, **kwargs) -> Dict[str, Any]:
25 | """处理请求的抽象方法实现"""
26 | return self.create_response(True, "SearchHandler process method")
27 |
28 | async def search_memories(self, query: str, k: int = 3) -> Dict[str, Any]:
29 | """搜索记忆"""
30 | if not self.recall_engine:
31 | return self.create_response(False, "回忆引擎尚未初始化")
32 |
33 | try:
34 | results = await self.recall_engine.recall(self.context, query, k=k)
35 |
36 | if not results:
37 | return self.create_response(True, f"未能找到与 '{query}' 相关的记忆", [])
38 |
39 | # 格式化搜索结果
40 | formatted_results = []
41 | for res in results:
42 | formatted_results.append({
43 | "id": res.data['id'],
44 | "similarity": res.similarity,
45 | "text": res.data['text'],
46 | "metadata": self.safe_parse_metadata(res.data.get("metadata", {}))
47 | })
48 |
49 | return self.create_response(True, f"为您找到 {len(results)} 条相关记忆", formatted_results)
50 |
51 | except Exception as e:
52 | logger.error(f"搜索记忆时发生错误: {e}", exc_info=True)
53 | return self.create_response(False, f"搜索记忆时发生错误: {e}")
54 |
55 | async def test_sparse_search(self, query: str, k: int = 5) -> Dict[str, Any]:
56 | """测试稀疏检索功能"""
57 | if not self.sparse_retriever:
58 | return self.create_response(False, "稀疏检索器未启用")
59 |
60 | try:
61 | results = await self.sparse_retriever.search(query=query, limit=k)
62 |
63 | if not results:
64 | return self.create_response(True, f"未找到与 '{query}' 相关的记忆", [])
65 |
66 | # 格式化搜索结果
67 | formatted_results = []
68 | for res in results:
69 | formatted_results.append({
70 | "doc_id": res.doc_id,
71 | "score": res.score,
72 | "content": res.content,
73 | "metadata": res.metadata
74 | })
75 |
76 | return self.create_response(True, f"找到 {len(results)} 条稀疏检索结果", formatted_results)
77 |
78 | except Exception as e:
79 | logger.error(f"稀疏检索测试失败: {e}", exc_info=True)
80 | return self.create_response(False, f"稀疏检索测试失败: {e}")
81 |
82 | async def rebuild_sparse_index(self) -> Dict[str, Any]:
83 | """重建稀疏检索索引"""
84 | if not self.sparse_retriever:
85 | return self.create_response(False, "稀疏检索器未启用")
86 |
87 | try:
88 | await self.sparse_retriever.rebuild_index()
89 | return self.create_response(True, "稀疏检索索引重建完成")
90 | except Exception as e:
91 | logger.error(f"重建稀疏索引失败: {e}", exc_info=True)
92 | return self.create_response(False, f"重建稀疏索引失败: {e}")
93 |
94 | def format_search_results_for_display(self, response: Dict[str, Any]) -> str:
95 | """格式化搜索结果用于显示"""
96 | if not response.get("success"):
97 | return response.get("message", "搜索失败")
98 |
99 | data = response.get("data", [])
100 | message = response.get("message", "")
101 |
102 | response_parts = [message]
103 |
104 | for res in data:
105 | metadata = res.get("metadata", {})
106 | create_time_str = self.format_timestamp(metadata.get("create_time"))
107 | last_access_time_str = self.format_timestamp(metadata.get("last_access_time"))
108 | importance_score = metadata.get("importance", 0.0)
109 | event_type = metadata.get("event_type", "未知")
110 |
111 | card = (
112 | f"ID: {res['id']}\n"
113 | f"记 忆 度: {res['similarity']:.2f}\n"
114 | f"重 要 性: {importance_score:.2f}\n"
115 | f"记忆类型: {event_type}\n\n"
116 | f"内容: {res['text']}\n\n"
117 | f"创建于: {create_time_str}\n"
118 | f"最后访问: {last_access_time_str}"
119 | )
120 | response_parts.append(card)
121 |
122 | return "\n\n".join(response_parts)
123 |
124 | def format_sparse_results_for_display(self, response: Dict[str, Any]) -> str:
125 | """格式化稀疏检索结果用于显示"""
126 | if not response.get("success"):
127 | return response.get("message", "搜索失败")
128 |
129 | data = response.get("data", [])
130 | message = response.get("message", "")
131 |
132 | response_parts = [message]
133 |
134 | for i, res in enumerate(data, 1):
135 | response_parts.append(f"\n{i}. [ID: {res['doc_id']}] Score: {res['score']:.3f}")
136 | response_parts.append(f" 内容: {res['content'][:100]}{'...' if len(res['content']) > 100 else ''}")
137 |
138 | # 显示元数据
139 | metadata = res.get("metadata", {})
140 | if metadata.get("event_type"):
141 | response_parts.append(f" 类型: {metadata['event_type']}")
142 | if metadata.get("importance"):
143 | response_parts.append(f" 重要性: {metadata['importance']:.2f}")
144 |
145 | return "\n".join(response_parts)
--------------------------------------------------------------------------------
/core/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | models.py - 插件的核心数据模型
4 | """
5 |
6 | from datetime import datetime, timezone
7 | from enum import Enum
8 | from typing import List, Dict, Any, Optional
9 | from pydantic import BaseModel, Field
10 |
11 | # --- 公开的数据模型 ---
12 |
13 |
14 | class EventType(str, Enum):
15 | FACT = "fact"
16 | PREFERENCE = "preference"
17 | GOAL = "goal"
18 | OPINION = "opinion"
19 | RELATIONSHIP = "relationship"
20 | OTHER = "other"
21 |
22 |
23 | class Entity(BaseModel):
24 | name: str = Field(..., description="实体名称")
25 | type: str = Field(..., description="实体类型")
26 |
27 |
28 | class MemoryEvent(BaseModel):
29 | # --- 系统生成字段 ---
30 | timestamp: datetime = Field(
31 | default_factory=lambda: datetime.now(timezone.utc),
32 | description="事件创建的UTC时间戳",
33 | )
34 |
35 | # --- 系统生成字段 ---
36 | id: Optional[int] = Field(None, description="记忆的唯一整数ID,由存储后端生成")
37 |
38 | # --- LLM 生成字段 (第一阶段) ---
39 | temp_id: str = Field(..., description="由LLM生成的临时唯一ID,用于评分匹配")
40 | memory_content: str = Field(..., description="对事件的简洁、客观的描述")
41 | event_type: EventType = Field(default=EventType.OTHER, description="事件的分类")
42 | entities: List[Entity] = Field(
43 | default_factory=list, description="事件中涉及的关键实体"
44 | )
45 |
46 | # --- LLM 生成字段 (第二阶段) ---
47 | importance_score: Optional[float] = Field(
48 | None, ge=0.0, le=1.0, description="记忆的重要性评分 (0.0-1.0)"
49 | )
50 |
51 | # --- 系统关联字段 ---
52 | related_event_ids: List[int] = Field(
53 | default_factory=list, description="与此事件相关的其他事件ID"
54 | )
55 | metadata: Dict[str, Any] = Field(
56 | default_factory=dict, description="用于存储其他附加信息的灵活字段"
57 | )
58 |
59 |
60 | class MemoryEventList(BaseModel):
61 | events: List[MemoryEvent]
62 |
63 |
64 | # --- 用于生成 Prompt Schema 的私有模型 ---
65 |
66 |
67 | # 用于第一阶段:事件提取
68 | class _LLMExtractionEvent(BaseModel):
69 | temp_id: str = Field(
70 | ...,
71 | description="由LLM或系统生成的临时唯一ID",
72 | )
73 | memory_content: str = Field(...)
74 | event_type: EventType = Field(default=EventType.OTHER)
75 | entities: List[Entity] = Field(default_factory=list)
76 | related_event_ids: List[int] = Field(default_factory=list)
77 | metadata: Dict[str, Any] = Field(default_factory=dict)
78 |
79 |
80 | class _LLMExtractionEventList(BaseModel):
81 | events: List[_LLMExtractionEvent]
82 |
83 |
84 | # 用于第二阶段:评分
85 | class _LLMScoreEvaluation(BaseModel):
86 | scores: Dict[str, float] = Field(
87 | ..., description="一个字典,key是事件的临时ID (temp_id),value是对应的0.0-1.0的重要性分数"
88 | )
89 |
--------------------------------------------------------------------------------
/core/models/memory_models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import dataclasses
4 | from dataclasses import dataclass, field
5 | from typing import List, Dict, Any, Optional
6 |
7 |
8 | @dataclass
9 | class LinkedMedia:
10 | media_id: str
11 | media_type: str
12 | url: str
13 | caption: str
14 | embedding: List[float] = field(default_factory=list)
15 |
16 |
17 | @dataclass
18 | class AccessInfo:
19 | initial_creation_timestamp: str
20 | last_accessed_timestamp: str
21 | access_count: int = 1
22 |
23 |
24 | @dataclass
25 | class EmotionalValence:
26 | sentiment: str
27 | intensity: float
28 |
29 |
30 | @dataclass
31 | class UserFeedback:
32 | is_accurate: Optional[bool] = None
33 | is_important: Optional[bool] = None
34 | correction_text: Optional[str] = None
35 |
36 |
37 | @dataclass
38 | class CommunityInfo:
39 | id: Optional[str] = None
40 | last_calculated: Optional[str] = None
41 |
42 |
43 | @dataclass
44 | class Metadata:
45 | source_conversation_id: str
46 | memory_type: str # 'episodic', 'semantic', 'procedural'
47 | importance_score: float
48 | access_info: AccessInfo
49 | confidence_score: Optional[float] = None
50 | emotional_valence: Optional[EmotionalValence] = None
51 | user_feedback: UserFeedback = field(default_factory=UserFeedback)
52 | community_info: CommunityInfo = field(default_factory=CommunityInfo)
53 | session_id: Optional[str] = None
54 | persona_id: Optional[str] = None
55 |
56 |
57 | @dataclass
58 | class EventEntity:
59 | event_id: str
60 | event_type: str
61 |
62 |
63 | @dataclass
64 | class Entity:
65 | entity_id: str
66 | name: str
67 | type: str
68 | role: Optional[str] = None
69 |
70 |
71 | @dataclass
72 | class KnowledgeGraphPayload:
73 | event_entity: EventEntity
74 | entities: List[Entity] = field(default_factory=list)
75 | relationships: List[List[str]] = field(default_factory=list)
76 |
77 |
78 | @dataclass
79 | class Memory:
80 | memory_id: str # UUID
81 | timestamp: str # ISO 8601
82 | summary: str
83 | description: str
84 | metadata: Metadata
85 | embedding: List[float] = field(default_factory=list)
86 | linked_media: List[LinkedMedia] = field(default_factory=list)
87 | knowledge_graph_payload: Optional[KnowledgeGraphPayload] = None
88 |
89 | # 提供一个方便的方法来将 dataclass 转换为 dict
90 | def to_dict(self) -> Dict[str, Any]:
91 | return dataclasses.asdict(self)
92 |
93 | # 提供一个方便的方法从 dict 创建 dataclass 实例
94 | @classmethod
95 | def from_dict(cls, data: Dict[str, Any]) -> "Memory":
96 | # 嵌套结构的递归转换
97 | field_types = {f.name: f.type for f in dataclasses.fields(cls)}
98 |
99 | # 简单递归转换,TODO 对于复杂场景可能需要更完善的库如 dacite
100 | for name, T in field_types.items():
101 | if (
102 | hasattr(T, "from_dict")
103 | and name in data
104 | and isinstance(data[name], dict)
105 | ):
106 | data[name] = T.from_dict(data[name])
107 | elif isinstance(data.get(name), list):
108 | # 处理 dataclass 列表
109 | origin = getattr(T, "__origin__", None)
110 | if origin is list:
111 | item_type = T.__args__[0]
112 | if hasattr(item_type, "from_dict"):
113 | data[name] = [item_type.from_dict(item) for item in data[name]]
114 |
115 | return cls(**data)
116 |
--------------------------------------------------------------------------------
/core/retrieval/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | 检索模块
4 | """
5 |
6 | from .sparse_retriever import SparseRetriever, SparseResult
7 | from .result_fusion import ResultFusion, SearchResult
8 |
9 | __all__ = [
10 | "SparseRetriever",
11 | "SparseResult",
12 | "ResultFusion",
13 | "SearchResult"
14 | ]
--------------------------------------------------------------------------------
/core/retrieval/sparse_retriever.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | 稀疏检索器 - 基于 SQLite FTS5 和 BM25 的全文检索
4 | """
5 |
6 | import json
7 | import sqlite3
8 | import math
9 | from typing import List, Dict, Any, Optional, Tuple
10 | from dataclasses import dataclass
11 | import asyncio
12 | import aiosqlite
13 |
14 | from astrbot.api import logger
15 |
16 | try:
17 | import jieba
18 | JIEBA_AVAILABLE = True
19 | except ImportError:
20 | JIEBA_AVAILABLE = False
21 | logger.warning("jieba not available, Chinese tokenization disabled")
22 |
23 |
24 | @dataclass
25 | class SparseResult:
26 | """稀疏检索结果"""
27 | doc_id: int
28 | score: float
29 | content: str
30 | metadata: Dict[str, Any]
31 |
32 |
33 | class FTSManager:
34 | """FTS5 索引管理器"""
35 |
36 | def __init__(self, db_path: str):
37 | self.db_path = db_path
38 | self.fts_table_name = "documents_fts"
39 |
40 | async def initialize(self):
41 | """初始化 FTS5 索引"""
42 | async with aiosqlite.connect(self.db_path) as db:
43 | # 启用 FTS5 扩展
44 | await db.execute("PRAGMA foreign_keys = ON")
45 |
46 | # 创建 FTS5 虚拟表
47 | await db.execute(f"""
48 | CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table_name}
49 | USING fts5(content, doc_id, tokenize='unicode61')
50 | """)
51 |
52 | # 创建触发器,保持同步
53 | await self._create_triggers(db)
54 |
55 | await db.commit()
56 | logger.info(f"FTS5 index initialized: {self.fts_table_name}")
57 |
58 | async def _create_triggers(self, db: aiosqlite.Connection):
59 | """创建数据同步触发器"""
60 | # 插入触发器
61 | await db.execute(f"""
62 | CREATE TRIGGER IF NOT EXISTS documents_ai
63 | AFTER INSERT ON documents BEGIN
64 | INSERT INTO {self.fts_table_name}(doc_id, content)
65 | VALUES (new.id, new.text);
66 | END;
67 | """)
68 |
69 | # 删除触发器
70 | await db.execute(f"""
71 | CREATE TRIGGER IF NOT EXISTS documents_ad
72 | AFTER DELETE ON documents BEGIN
73 | DELETE FROM {self.fts_table_name} WHERE doc_id = old.id;
74 | END;
75 | """)
76 |
77 | # 更新触发器
78 | await db.execute(f"""
79 | CREATE TRIGGER IF NOT EXISTS documents_au
80 | AFTER UPDATE ON documents BEGIN
81 | DELETE FROM {self.fts_table_name} WHERE doc_id = old.id;
82 | INSERT INTO {self.fts_table_name}(doc_id, content)
83 | VALUES (new.id, new.text);
84 | END;
85 | """)
86 |
87 | async def rebuild_index(self):
88 | """重建索引"""
89 | async with aiosqlite.connect(self.db_path) as db:
90 | await db.execute(f"DELETE FROM {self.fts_table_name}")
91 | await db.execute(f"""
92 | INSERT INTO {self.fts_table_name}(doc_id, content)
93 | SELECT id, text FROM documents
94 | """)
95 | await db.commit()
96 | logger.info("FTS index rebuilt")
97 |
98 | async def search(self, query: str, limit: int = 50) -> List[Tuple[int, float]]:
99 | """执行 BM25 搜索"""
100 | async with aiosqlite.connect(self.db_path) as db:
101 | # 将整个查询用双引号包裹,以处理特殊字符并将其作为短语搜索
102 | # 这是为了防止 FTS5 语法错误,例如 'syntax error near "."'
103 | safe_query = f'"{query}"'
104 |
105 | # 使用 BM25 算法搜索
106 | cursor = await db.execute(f"""
107 | SELECT doc_id, bm25({self.fts_table_name}) as score
108 | FROM {self.fts_table_name}
109 | WHERE {self.fts_table_name} MATCH ?
110 | ORDER BY score
111 | LIMIT ?
112 | """, (safe_query, limit))
113 |
114 | results = await cursor.fetchall()
115 | return [(row[0], row[1]) for row in results]
116 |
117 |
118 | class SparseRetriever:
119 | """稀疏检索器"""
120 |
121 | def __init__(self, db_path: str, config: Dict[str, Any] = None):
122 | self.db_path = db_path
123 | self.config = config or {}
124 | self.fts_manager = FTSManager(db_path)
125 | self.enabled = self.config.get("enabled", True)
126 | self.use_chinese_tokenizer = self.config.get("use_chinese_tokenizer", JIEBA_AVAILABLE)
127 |
128 | async def initialize(self):
129 | """初始化稀疏检索器"""
130 | if not self.enabled:
131 | logger.info("Sparse retriever disabled")
132 | return
133 |
134 | await self.fts_manager.initialize()
135 |
136 | # 如果启用中文分词,初始化 jieba
137 | if self.use_chinese_tokenizer and JIEBA_AVAILABLE:
138 | # 可以添加自定义词典
139 | pass
140 |
141 | logger.info("Sparse retriever initialized")
142 |
143 | def _preprocess_query(self, query: str) -> str:
144 | """预处理查询"""
145 | query = query.strip()
146 |
147 | # 中文分词
148 | if self.use_chinese_tokenizer and JIEBA_AVAILABLE:
149 | # 检查是否包含中文
150 | if any('\u4e00' <= char <= '\u9fff' for char in query):
151 | tokens = jieba.cut_for_search(query)
152 | query = " ".join(tokens)
153 |
154 | query = query.replace('"', ' ') # 将内部的双引号替换为空格
155 |
156 | return query
157 |
158 | async def search(
159 | self,
160 | query: str,
161 | limit: int = 50,
162 | session_id: Optional[str] = None,
163 | persona_id: Optional[str] = None,
164 | metadata_filters: Optional[Dict[str, Any]] = None
165 | ) -> List[SparseResult]:
166 | """执行稀疏检索"""
167 | if not self.enabled:
168 | return []
169 |
170 | try:
171 | # 预处理查询
172 | processed_query = self._preprocess_query(query)
173 | logger.debug(f"Sparse search query: {processed_query}")
174 |
175 | # 执行 FTS 搜索
176 | fts_results = await self.fts_manager.search(processed_query, limit)
177 |
178 | if not fts_results:
179 | return []
180 |
181 | # 获取完整的文档信息
182 | doc_ids = [doc_id for doc_id, _ in fts_results]
183 | documents = await self._get_documents(doc_ids)
184 |
185 | # 应用过滤器
186 | filtered_results = []
187 | for doc_id, bm25_score in fts_results:
188 | if doc_id in documents:
189 | doc = documents[doc_id]
190 |
191 | # 检查元数据过滤器
192 | if self._apply_filters(doc.get("metadata", {}), session_id, persona_id, metadata_filters):
193 | result = SparseResult(
194 | doc_id=doc_id,
195 | score=bm25_score,
196 | content=doc["text"],
197 | metadata=doc["metadata"]
198 | )
199 | filtered_results.append(result)
200 |
201 | # 归一化 BM25 分数(转换为 0-1)
202 | if filtered_results:
203 | max_score = max(r.score for r in filtered_results)
204 | min_score = min(r.score for r in filtered_results)
205 | score_range = max_score - min_score if max_score != min_score else 1
206 |
207 | for result in filtered_results:
208 | result.score = (result.score - min_score) / score_range
209 |
210 | logger.debug(f"Sparse search returned {len(filtered_results)} results")
211 | return filtered_results
212 |
213 | except Exception as e:
214 | logger.error(f"Sparse search error: {e}", exc_info=True)
215 | return []
216 |
217 | async def _get_documents(self, doc_ids: List[int]) -> Dict[int, Dict[str, Any]]:
218 | """批量获取文档"""
219 | async with aiosqlite.connect(self.db_path) as db:
220 | placeholders = ",".join("?" for _ in doc_ids)
221 | cursor = await db.execute(f"""
222 | SELECT id, text, metadata FROM documents WHERE id IN ({placeholders})
223 | """, doc_ids)
224 |
225 | documents = {}
226 | async for row in cursor:
227 | metadata = json.loads(row[2]) if isinstance(row[2], str) else row[2]
228 | documents[row[0]] = {
229 | "text": row[1],
230 | "metadata": metadata or {}
231 | }
232 |
233 | return documents
234 |
235 | def _apply_filters(
236 | self,
237 | metadata: Dict[str, Any],
238 | session_id: Optional[str],
239 | persona_id: Optional[str],
240 | metadata_filters: Optional[Dict[str, Any]]
241 | ) -> bool:
242 | """应用过滤器"""
243 | # 会话过滤
244 | if session_id and metadata.get("session_id") != session_id:
245 | return False
246 |
247 | # 人格过滤
248 | if persona_id and metadata.get("persona_id") != persona_id:
249 | return False
250 |
251 | # 自定义元数据过滤
252 | if metadata_filters:
253 | for key, value in metadata_filters.items():
254 | if metadata.get(key) != value:
255 | return False
256 |
257 | return True
258 |
259 | async def rebuild_index(self):
260 | """重建索引"""
261 | if not self.enabled:
262 | return
263 | await self.fts_manager.rebuild_index()
--------------------------------------------------------------------------------
/core/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | utils.py - 插件的辅助工具函数
4 | """
5 |
6 | import re
7 | import json
8 | import time
9 | import asyncio
10 | from datetime import datetime
11 | from typing import List, Optional, Dict, Any
12 |
13 | import pytz
14 |
15 | from astrbot.api import logger
16 | from astrbot.api.star import Context
17 | from astrbot.api.event import AstrMessageEvent
18 | from ..storage.faiss_manager import Result
19 | from .constants import MEMORY_INJECTION_HEADER, MEMORY_INJECTION_FOOTER
20 |
21 |
22 | def safe_parse_metadata(metadata_raw: Any) -> Dict[str, Any]:
23 | """
24 | 安全解析元数据,统一处理字符串和字典类型。
25 |
26 | Args:
27 | metadata_raw: 原始元数据,可能是字符串或字典
28 |
29 | Returns:
30 | Dict[str, Any]: 解析后的元数据字典,解析失败时返回空字典
31 | """
32 | if isinstance(metadata_raw, dict):
33 | return metadata_raw
34 | elif isinstance(metadata_raw, str):
35 | try:
36 | return json.loads(metadata_raw)
37 | except (json.JSONDecodeError, TypeError) as e:
38 | logger.warning(f"解析元数据JSON失败: {e}, 原始数据: {metadata_raw}")
39 | return {}
40 | else:
41 | logger.warning(f"不支持的元数据类型: {type(metadata_raw)}")
42 | return {}
43 |
44 |
45 | def safe_serialize_metadata(metadata: Dict[str, Any]) -> str:
46 | """
47 | 安全序列化元数据为JSON字符串。
48 |
49 | Args:
50 | metadata: 元数据字典
51 |
52 | Returns:
53 | str: JSON字符串
54 | """
55 | try:
56 | return json.dumps(metadata, ensure_ascii=False)
57 | except (TypeError, ValueError) as e:
58 | logger.error(f"序列化元数据失败: {e}, 数据: {metadata}")
59 | return "{}"
60 |
61 |
62 | def validate_timestamp(timestamp: Any, default_time: Optional[float] = None) -> float:
63 | """
64 | 验证和标准化时间戳。
65 |
66 | Args:
67 | timestamp: 时间戳,可能是字符串、数字或其他类型
68 | default_time: 默认时间,如果为None则使用当前时间
69 |
70 | Returns:
71 | float: 标准化的时间戳
72 | """
73 | if default_time is None:
74 | default_time = time.time()
75 |
76 | if isinstance(timestamp, (int, float)):
77 | return float(timestamp)
78 | elif isinstance(timestamp, str):
79 | try:
80 | return float(timestamp)
81 | except (ValueError, TypeError):
82 | logger.warning(f"无法解析时间戳字符串: {timestamp}")
83 | return default_time
84 | elif hasattr(timestamp, 'timestamp'): # datetime对象
85 | try:
86 | return timestamp.timestamp()
87 | except Exception as e:
88 | logger.warning(f"无法从datetime对象获取时间戳: {e}")
89 | return default_time
90 | else:
91 | logger.warning(f"不支持的时间戳类型: {type(timestamp)}")
92 | return default_time
93 |
94 |
95 | async def retry_on_failure(
96 | func,
97 | *args,
98 | max_retries: int = 3,
99 | backoff_factor: float = 1.0,
100 | exceptions: tuple = (Exception,),
101 | **kwargs
102 | ):
103 | """
104 | 带重试机制的函数执行器。
105 |
106 | Args:
107 | func: 要执行的函数
108 | *args: 函数位置参数
109 | max_retries: 最大重试次数
110 | backoff_factor: 退避因子
111 | exceptions: 需要重试的异常类型
112 | **kwargs: 函数关键字参数
113 |
114 | Returns:
115 | 函数执行结果
116 | """
117 | last_exception = None
118 |
119 | for attempt in range(max_retries + 1):
120 | try:
121 | if asyncio.iscoroutinefunction(func):
122 | return await func(*args, **kwargs)
123 | else:
124 | return func(*args, **kwargs)
125 | except exceptions as e:
126 | last_exception = e
127 | if attempt < max_retries:
128 | wait_time = backoff_factor * (2 ** attempt)
129 | logger.warning(f"函数 {func.__name__} 执行失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}")
130 | logger.info(f"等待 {wait_time:.2f} 秒后重试...")
131 | await asyncio.sleep(wait_time)
132 | else:
133 | logger.error(f"函数 {func.__name__} 重试 {max_retries} 次后仍然失败: {e}")
134 |
135 | # 所有重试都失败,抛出最后一个异常
136 | raise last_exception
137 |
138 |
139 | class OperationContext:
140 | """操作上下文管理器,用于错误处理和资源清理"""
141 |
142 | def __init__(self, operation_name: str, session_id: Optional[str] = None):
143 | self.operation_name = operation_name
144 | self.session_id = session_id
145 | self.start_time = None
146 |
147 | async def __aenter__(self):
148 | self.start_time = time.time()
149 | session_info = f"[{self.session_id}] " if self.session_id else ""
150 | logger.debug(f"{session_info}开始执行操作: {self.operation_name}")
151 | return self
152 |
153 | async def __aexit__(self, exc_type, exc_val, exc_tb):
154 | duration = time.time() - self.start_time if self.start_time else 0
155 | session_info = f"[{self.session_id}] " if self.session_id else ""
156 |
157 | if exc_type is None:
158 | logger.debug(f"{session_info}操作成功完成: {self.operation_name} (耗时 {duration:.3f}s)")
159 | else:
160 | logger.error(f"{session_info}操作失败: {self.operation_name} (耗时 {duration:.3f}s) - {exc_val}")
161 |
162 | # 不抑制异常,让调用者处理
163 | return False
164 |
165 |
166 | async def get_persona_id(context: Context, event: AstrMessageEvent) -> Optional[str]:
167 | """
168 | 获取当前会话的人格 ID。
169 | 如果当前会话没有特定人格,则返回 AstrBot 的默认人格。
170 | """
171 | try:
172 | session_id = await context.conversation_manager.get_curr_conversation_id(
173 | event.unified_msg_origin
174 | )
175 | conversation = await context.conversation_manager.get_conversation(
176 | event.unified_msg_origin, session_id
177 | )
178 | persona_id = conversation.persona_id if conversation else None
179 |
180 | # 如果无人格或明确设置为None,则使用全局默认人格
181 | if not persona_id or persona_id == "[%None]":
182 | default_persona = context.provider_manager.selected_default_persona
183 | persona_id = default_persona["name"] if default_persona else None
184 |
185 | return persona_id
186 | except Exception as e:
187 | # 在某些情况下(如无会话),获取可能会失败,返回 None
188 | logger.debug(f"获取人格ID失败: {e}")
189 | return None
190 |
191 |
192 | def extract_json_from_response(text: str) -> str:
193 | """
194 | 从可能包含 Markdown 代码块的文本中提取纯 JSON 字符串。
195 | """
196 | # 查找被 ```json ... ``` 或 ``` ... ``` 包围的内容
197 | match = re.search(r"```(json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
198 | if match:
199 | # 返回捕获组中的 JSON 部分
200 | return match.group(2)
201 |
202 | # 如果没有找到代码块,假设整个文本就是 JSON(可能需要去除首尾空格)
203 | return text.strip()
204 |
205 |
206 | def get_now_datetime(tz_str: str = "Asia/Shanghai") -> datetime:
207 | """
208 | 获取当前时间,并根据指定的时区设置时区。
209 |
210 | Args:
211 | tz_str: 时区字符串,默认为 "Asia/Shanghai"
212 |
213 | Returns:
214 | datetime: 带有时区信息的当前时间
215 | """
216 | # 如果传入的是 Context 对象,则使用从上下文获取时间的方法
217 | # 检查传入的是否是 Context 对象
218 | from astrbot.api.star import Context # 导入 Context 类型
219 | if isinstance(tz_str, Context):
220 | # 如果是 Context 对象,调用专门的函数处理
221 | return get_now_datetime_from_context(tz_str)
222 |
223 | try:
224 | timezone = pytz.timezone(tz_str)
225 | except pytz.UnknownTimeZoneError:
226 | # 如果时区无效,则使用默认值
227 | logger.warning(f"无效的时区: {tz_str},使用默认时区 Asia/Shanghai")
228 | timezone = pytz.timezone("Asia/Shanghai")
229 |
230 | return datetime.now(timezone)
231 |
232 |
233 | def get_now_datetime_from_context(context: Context) -> datetime:
234 | """
235 | 从上下文中获取当前时间,根据插件配置设置时区。
236 |
237 | Args:
238 | context: AstrBot 上下文对象
239 |
240 | Returns:
241 | datetime: 带有时区信息的当前时间
242 | """
243 | try:
244 | # 尝试从配置中获取时区
245 | tz_str = context.plugin_config.get("timezone_settings", {}).get("timezone", "Asia/Shanghai")
246 | return get_now_datetime(tz_str)
247 | except (AttributeError, KeyError):
248 | # 如果配置不存在,则使用默认值
249 | return get_now_datetime()
250 |
251 |
252 | def format_memories_for_injection(memories: List[Result]) -> str:
253 | """
254 | 将检索到的记忆列表格式化为单个字符串,以便注入到 System Prompt。
255 | """
256 | if not memories:
257 | return ""
258 |
259 | header = f"{MEMORY_INJECTION_HEADER}\n"
260 | footer = f"\n{MEMORY_INJECTION_FOOTER}"
261 |
262 | formatted_entries = []
263 | for mem in memories:
264 | try:
265 | # 使用统一的元数据解析函数
266 | metadata_raw = mem.data.get("metadata", "{}")
267 | metadata = safe_parse_metadata(metadata_raw)
268 |
269 | content = mem.data.get("text", "内容缺失")
270 | importance = metadata.get("importance", 0.0)
271 |
272 | entry = f"- [重要性: {importance:.2f}] {content}"
273 | formatted_entries.append(entry)
274 | except Exception as e:
275 | # 如果处理失败,则跳过此条记忆
276 | logger.debug(f"格式化记忆时出错,跳过此记忆: {e}")
277 | continue
278 |
279 | if not formatted_entries:
280 | return ""
281 |
282 | body = "\n".join(formatted_entries)
283 |
284 | return f"{header}{body}{footer}"
285 |
--------------------------------------------------------------------------------
/docs/CONFIG.md:
--------------------------------------------------------------------------------
1 | # LivingMemory 配置参考
2 |
3 | 本文档详细介绍了 LivingMemory 插件的所有配置参数。
4 |
5 | ## 📋 配置概览
6 |
7 | 插件配置采用层次化结构,主要包含以下几个部分:
8 | - 时区设置
9 | - Provider 设置
10 | - 会话管理器
11 | - 回忆引擎
12 | - 反思引擎
13 | - 遗忘代理
14 | - 结果融合
15 | - 稀疏检索器
16 | - 过滤设置
17 |
18 | ## ⚙️ 详细配置参数
19 |
20 | ### 时区设置 (timezone_settings)
21 |
22 | | 参数 | 类型 | 默认值 | 描述 |
23 | |------|------|--------|------|
24 | | `timezone` | string | `"Asia/Shanghai"` | IANA 时区数据库名称,影响时间显示格式 |
25 |
26 | **示例:**
27 | ```yaml
28 | timezone_settings:
29 | timezone: "America/New_York" # 纽约时区
30 | ```
31 |
32 | **可用时区:**
33 | - `Asia/Shanghai` - 中国标准时间
34 | - `America/New_York` - 美国东部时间
35 | - `Europe/London` - 格林威治时间
36 | - `Asia/Tokyo` - 日本标准时间
37 |
38 | ### Provider 设置 (provider_settings)
39 |
40 | | 参数 | 类型 | 默认值 | 描述 |
41 | |------|------|--------|------|
42 | | `embedding_provider_id` | string | `""` | 指定用于生成向量的 Embedding Provider ID |
43 | | `llm_provider_id` | string | `""` | 指定用于总结和评估的 LLM Provider ID |
44 |
45 | **示例:**
46 | ```yaml
47 | provider_settings:
48 | embedding_provider_id: "openai_embedding"
49 | llm_provider_id: "claude_3_5"
50 | ```
51 |
52 | **注意:**
53 | - 留空将自动使用 AstrBot 的默认 Provider
54 | - 确保指定的 Provider 已在 AstrBot 中正确配置
55 |
56 | ### 会话管理器 (session_manager)
57 |
58 | | 参数 | 类型 | 默认值 | 范围 | 描述 |
59 | |------|------|--------|------|------|
60 | | `max_sessions` | int | `1000` | 1-10000 | 同时维护的最大会话数量 |
61 | | `session_ttl` | int | `3600` | 60-86400 | 会话生存时间(秒) |
62 |
63 | **示例:**
64 | ```yaml
65 | session_manager:
66 | max_sessions: 500 # 最大500个会话
67 | session_ttl: 7200 # 2小时过期
68 | ```
69 |
70 | **优化建议:**
71 | - 高并发场景:增大 `max_sessions`
72 | - 内存紧张:减小 `session_ttl`
73 | - 长对话场景:增大 `session_ttl`
74 |
75 | ### 回忆引擎 (recall_engine)
76 |
77 | | 参数 | 类型 | 默认值 | 范围 | 描述 |
78 | |------|------|--------|------|------|
79 | | `top_k` | int | `5` | 1-50 | 单次检索返回的记忆数量 |
80 | | `recall_strategy` | string | `"weighted"` | similarity/weighted | 召回策略 |
81 | | `retrieval_mode` | string | `"hybrid"` | hybrid/dense/sparse | 检索模式 |
82 | | `similarity_weight` | float | `0.6` | 0.0-1.0 | 相似度权重 |
83 | | `importance_weight` | float | `0.2` | 0.0-1.0 | 重要性权重 |
84 | | `recency_weight` | float | `0.2` | 0.0-1.0 | 新近度权重 |
85 |
86 | **召回策略:**
87 | - `similarity`: 纯基于相似度的召回
88 | - `weighted`: 综合考虑相似度、重要性和新近度
89 |
90 | **检索模式:**
91 | - `hybrid`: 混合检索(推荐)
92 | - `dense`: 纯密集向量检索
93 | - `sparse`: 纯稀疏关键词检索
94 |
95 | **权重调优指南:**
96 | ```yaml
97 | # 重视语义相关性
98 | recall_engine:
99 | similarity_weight: 0.7
100 | importance_weight: 0.2
101 | recency_weight: 0.1
102 |
103 | # 重视重要信息
104 | recall_engine:
105 | similarity_weight: 0.4
106 | importance_weight: 0.5
107 | recency_weight: 0.1
108 |
109 | # 重视最新信息
110 | recall_engine:
111 | similarity_weight: 0.4
112 | importance_weight: 0.2
113 | recency_weight: 0.4
114 | ```
115 |
116 | ### 反思引擎 (reflection_engine)
117 |
118 | | 参数 | 类型 | 默认值 | 范围 | 描述 |
119 | |------|------|--------|------|------|
120 | | `summary_trigger_rounds` | int | `5` | 1-100 | 触发反思的对话轮次 |
121 | | `importance_threshold` | float | `0.5` | 0.0-1.0 | 记忆重要性阈值 |
122 | | `event_extraction_prompt` | text | 默认提示词 | - | 事件提取提示词 |
123 | | `evaluation_prompt` | text | 默认提示词 | - | 重要性评估提示词 |
124 |
125 | **触发轮次调优:**
126 | - `1-3轮`: 频繁反思,适合重要对话
127 | - `5-10轮`: 平衡模式(推荐)
128 | - `15-30轮`: 长对话模式,减少反思频率
129 |
130 | **重要性阈值:**
131 | - `0.1-0.3`: 宽松模式,保存更多记忆
132 | - `0.5-0.7`: 标准模式(推荐)
133 | - `0.8-1.0`: 严格模式,只保存重要记忆
134 |
135 | ### 遗忘代理 (forgetting_agent)
136 |
137 | | 参数 | 类型 | 默认值 | 范围 | 描述 |
138 | |------|------|--------|------|------|
139 | | `enabled` | bool | `true` | - | 是否启用自动遗忘 |
140 | | `check_interval_hours` | int | `24` | 1-168 | 检查间隔(小时) |
141 | | `retention_days` | int | `90` | 1-3650 | 记忆保留天数 |
142 | | `importance_decay_rate` | float | `0.005` | 0.0-1.0 | 重要性衰减率 |
143 | | `importance_threshold` | float | `0.1` | 0.0-1.0 | 遗忘重要性阈值 |
144 | | `forgetting_batch_size` | int | `1000` | 100-10000 | 批处理大小 |
145 |
146 | **遗忘策略配置:**
147 | ```yaml
148 | # 保守遗忘(保存更多记忆)
149 | forgetting_agent:
150 | retention_days: 180
151 | importance_decay_rate: 0.001
152 | importance_threshold: 0.05
153 |
154 | # 标准遗忘(推荐)
155 | forgetting_agent:
156 | retention_days: 90
157 | importance_decay_rate: 0.005
158 | importance_threshold: 0.1
159 |
160 | # 激进遗忘(节省存储空间)
161 | forgetting_agent:
162 | retention_days: 30
163 | importance_decay_rate: 0.01
164 | importance_threshold: 0.2
165 | ```
166 |
167 | ### 结果融合 (fusion)
168 |
169 | | 参数 | 类型 | 默认值 | 描述 |
170 | |------|------|--------|------|
171 | | `strategy` | string | `"rrf"` | 融合策略 |
172 | | `rrf_k` | int | `60` | RRF 参数 k |
173 | | `dense_weight` | float | `0.7` | 密集检索权重 |
174 | | `sparse_weight` | float | `0.3` | 稀疏检索权重 |
175 | | `convex_lambda` | float | `0.5` | 凸组合参数 |
176 | | `interleave_ratio` | float | `0.5` | 交替融合比例 |
177 | | `rank_bias_factor` | float | `0.1` | 排序偏置因子 |
178 | | `diversity_bonus` | float | `0.1` | 多样性奖励 |
179 |
180 | **融合策略详解:**
181 | - `rrf`: 经典 Reciprocal Rank Fusion
182 | - `hybrid_rrf`: 自适应 RRF
183 | - `weighted`: 加权融合
184 | - `convex`: 凸组合融合
185 | - `interleave`: 交替融合
186 | - `rank_fusion`: 基于排序的融合
187 | - `score_fusion`: Borda Count 融合
188 | - `cascade`: 级联融合
189 | - `adaptive`: 自适应融合
190 |
191 | 详细说明请参考 [FUSION_STRATEGIES.md](../FUSION_STRATEGIES.md)
192 |
193 | ### 稀疏检索器 (sparse_retriever)
194 |
195 | | 参数 | 类型 | 默认值 | 范围 | 描述 |
196 | |------|------|--------|------|------|
197 | | `enabled` | bool | `true` | - | 是否启用稀疏检索 |
198 | | `bm25_k1` | float | `1.2` | 0.1-10.0 | BM25 k1 参数 |
199 | | `bm25_b` | float | `0.75` | 0.0-1.0 | BM25 b 参数 |
200 | | `use_jieba` | bool | `true` | - | 是否使用中文分词 |
201 |
202 | **BM25 参数调优:**
203 | - `k1`: 控制词频饱和度
204 | - 较小值(0.5-1.0):词频影响较小
205 | - 较大值(1.5-2.0):词频影响较大
206 | - `b`: 控制文档长度归一化
207 | - 0.0:不考虑文档长度
208 | - 1.0:完全归一化文档长度
209 |
210 | **中文优化配置:**
211 | ```yaml
212 | sparse_retriever:
213 | enabled: true
214 | bm25_k1: 1.2 # 适合中文的词频参数
215 | bm25_b: 0.75 # 中等长度归一化
216 | use_jieba: true # 启用中文分词
217 | ```
218 |
219 | ### 过滤设置 (filtering_settings)
220 |
221 | | 参数 | 类型 | 默认值 | 描述 |
222 | |------|------|--------|------|
223 | | `use_persona_filtering` | bool | `true` | 是否启用人格记忆过滤 |
224 | | `use_session_filtering` | bool | `true` | 是否启用会话记忆隔离 |
225 |
226 | **过滤模式组合:**
227 | ```yaml
228 | # 完全隔离模式
229 | filtering_settings:
230 | use_persona_filtering: true
231 | use_session_filtering: true
232 |
233 | # 人格共享模式
234 | filtering_settings:
235 | use_persona_filtering: true
236 | use_session_filtering: false
237 |
238 | # 会话共享模式
239 | filtering_settings:
240 | use_persona_filtering: false
241 | use_session_filtering: true
242 |
243 | # 全局共享模式
244 | filtering_settings:
245 | use_persona_filtering: false
246 | use_session_filtering: false
247 | ```
248 |
249 | ## 🎯 场景化配置示例
250 |
251 | ### 个人助手配置
252 | ```yaml
253 | # 适合个人日常使用
254 | session_manager:
255 | max_sessions: 100
256 | session_ttl: 7200
257 |
258 | recall_engine:
259 | top_k: 3
260 | similarity_weight: 0.5
261 | importance_weight: 0.3
262 | recency_weight: 0.2
263 |
264 | reflection_engine:
265 | summary_trigger_rounds: 10
266 | importance_threshold: 0.4
267 |
268 | filtering_settings:
269 | use_persona_filtering: true
270 | use_session_filtering: false
271 | ```
272 |
273 | ### 客服机器人配置
274 | ```yaml
275 | # 适合客服场景
276 | session_manager:
277 | max_sessions: 1000
278 | session_ttl: 1800
279 |
280 | recall_engine:
281 | top_k: 5
282 | similarity_weight: 0.7
283 | importance_weight: 0.2
284 | recency_weight: 0.1
285 |
286 | reflection_engine:
287 | summary_trigger_rounds: 5
288 | importance_threshold: 0.6
289 |
290 | filtering_settings:
291 | use_persona_filtering: false
292 | use_session_filtering: true
293 | ```
294 |
295 | ### 教育辅导配置
296 | ```yaml
297 | # 适合教育辅导场景
298 | session_manager:
299 | max_sessions: 500
300 | session_ttl: 3600
301 |
302 | recall_engine:
303 | top_k: 8
304 | similarity_weight: 0.4
305 | importance_weight: 0.4
306 | recency_weight: 0.2
307 |
308 | reflection_engine:
309 | summary_trigger_rounds: 8
310 | importance_threshold: 0.3
311 |
312 | forgetting_agent:
313 | retention_days: 180
314 | importance_decay_rate: 0.002
315 | ```
316 |
317 | ## 🔧 配置验证
318 |
319 | ### 使用命令验证
320 | ```bash
321 | # 验证当前配置
322 | /lmem config validate
323 |
324 | # 查看配置摘要
325 | /lmem config show
326 | ```
327 |
328 | ### 配置文件验证
329 | 插件会在启动时自动验证配置:
330 | - ✅ 参数类型检查
331 | - ✅ 数值范围验证
332 | - ✅ 必需字段验证
333 | - ✅ 权重总和警告
334 |
335 | ## 💡 性能优化建议
336 |
337 | ### 内存优化
338 | - 减少 `max_sessions` 和 `session_ttl`
339 | - 降低 `top_k` 值
340 | - 启用积极的遗忘策略
341 |
342 | ### 准确性优化
343 | - 增加 `top_k` 值
344 | - 调整权重配比
345 | - 使用混合检索模式
346 | - 优化融合策略参数
347 |
348 | ### 响应速度优化
349 | - 使用 `cascade` 融合策略
350 | - 减少 `top_k` 值
351 | - 选择更快的检索模式
352 |
353 | ## ⚠️ 注意事项
354 |
355 | 1. **权重总和**:确保回忆引擎的三个权重总和接近 1.0
356 | 2. **Provider 可用性**:确保指定的 Provider 已正确配置
357 | 3. **存储空间**:长期使用需要考虑遗忘策略以控制存储增长
358 | 4. **中文支持**:启用 jieba 分词以获得更好的中文检索效果
359 | 5. **配置热更新**:部分配置修改需要重启插件才能生效
360 |
361 | ## 🔍 配置调试
362 |
363 | ### 查看生效配置
364 | ```bash
365 | /lmem config show
366 | ```
367 |
368 | ### 测试检索效果
369 | ```bash
370 | # 测试不同检索模式
371 | /lmem search_mode hybrid
372 | /lmem search "测试查询" 5
373 |
374 | # 测试融合策略
375 | /lmem fusion show
376 | /lmem test_fusion "测试查询" 5
377 | ```
378 |
379 | ### 性能监控
380 | ```bash
381 | # 查看记忆库状态
382 | /lmem status
383 |
384 | # 检查会话数量
385 | /lmem config show | grep 会话
386 | ```
--------------------------------------------------------------------------------
/metadata.yaml:
--------------------------------------------------------------------------------
1 | name: astrbot_plugin_livingmemory
2 | author: lxfight
3 | version: 1.0.1
4 | description: 为 AstrBot 打造的、拥有完整记忆生命周期的智能长期记忆插件。
5 | repo: https://github.com/lxfight-s-Astrbot-Plugins/astrbot_plugin_livingmemory
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | faiss-cpu>=1.7.0
2 | networkx>=3.0
3 | jieba>=0.42.1
--------------------------------------------------------------------------------
/storage/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxfight-s-Astrbot-Plugins/astrbot_plugin_livingmemory/a382b10a91a41b68abc850e5b1abe05052e372f0/storage/__init__.py
--------------------------------------------------------------------------------
/storage/faiss_manager_v2.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import uuid
3 | import json
4 | import asyncio
5 | from datetime import datetime, timezone
6 | from typing import List, Optional
7 | import aiosqlite
8 |
9 | from astrbot.api import logger
10 |
11 | from ..core.models.memory_models import (
12 | Memory,
13 | AccessInfo, # noqa: F401
14 | UserFeedback, # noqa: F401
15 | EmotionalValence, # noqa: F401
16 | )
17 | from .memory_storage import MemoryStorage
18 | from .vector_store import VectorStore
19 | from .graph_storage import GraphStorageSQLite
20 |
21 |
22 | class FaissManagerV2:
23 | """
24 | 高级管理器,协调 MemoryStorage 和 VectorStore,
25 | 以支持结构化的、具有生命周期的记忆。
26 | """
27 |
28 | def __init__(
29 | self,
30 | db_path: str,
31 | text_vstore: VectorStore,
32 | image_vstore: VectorStore,
33 | embedding_model,
34 | ):
35 | self.db_path = db_path
36 | self.conn: Optional[aiosqlite.Connection] = None
37 | self.storage: Optional[MemoryStorage] = None
38 | self.graph_storage: Optional[GraphStorageSQLite] = None
39 |
40 | self.text_vstore = text_vstore
41 | self.image_vstore = image_vstore
42 | self.embedding_model = embedding_model
43 |
44 | async def initialize(self):
45 | """
46 | 建立数据库连接,并初始化所有依赖于此连接的存储组件。
47 | """
48 | self.conn = await aiosqlite.connect(self.db_path)
49 | # 启用外键约束,这对于 ON DELETE CASCADE 至关重要
50 | await self.conn.execute("PRAGMA foreign_keys = ON;")
51 |
52 | # 初始化各个存储层
53 | self.storage = MemoryStorage(self.conn) # 传入连接对象
54 | await self.storage.initialize_schema() # 创建 memories 表
55 |
56 | self.graph_storage = GraphStorageSQLite(self.conn) # 传入同一个连接对象
57 | await self.graph_storage.initialize_schema() # 创建图相关的表
58 |
59 | async def close(self):
60 | if self.conn:
61 | await self.conn.close()
62 |
63 | async def add_memory(self, memory: Memory) -> str:
64 | """
65 | 添加一条新的、结构化的记忆。
66 | """
67 | if not memory.memory_id:
68 | memory.memory_id = str(uuid.uuid4())
69 |
70 | if not memory.embedding:
71 | memory.embedding = await asyncio.to_thread(self.embedding_model.encode, memory.description)
72 |
73 | # 1. 存入 SQLite 并获取内部 ID
74 | internal_id = await self.storage.add_memory(memory)
75 |
76 | # 2. 将文本向量添加到 text_vstore
77 | if memory.embedding:
78 | await asyncio.to_thread(self.text_vstore.add, [internal_id], [memory.embedding])
79 |
80 | # 3. 将图像向量添加到 image_vstore
81 | media_embeddings = [
82 | media.embedding for media in memory.linked_media if media.embedding
83 | ]
84 | if media_embeddings:
85 | media_ids = [internal_id] * len(media_embeddings)
86 | await asyncio.to_thread(self.image_vstore.add, media_ids, media_embeddings)
87 |
88 | # 4. 将图数据添加到 graph_storage
89 | if memory.knowledge_graph_payload:
90 | await self.graph_storage.add_memory_graph(
91 | internal_id, memory.knowledge_graph_payload.to_dict()
92 | )
93 |
94 | # TODO 考虑定期保存索引,而不是每次都保存
95 | await asyncio.to_thread(self.text_vstore.save_index)
96 | await asyncio.to_thread(self.image_vstore.save_index)
97 |
98 | return memory.memory_id
99 |
100 | async def search_memory(
101 | self, query_text: str, k: int = 10, w1: float = 0.6, w2: float = 0.4
102 | ) -> List[Memory]:
103 | """
104 | 根据查询文本智能检索最相关的记忆。
105 | """
106 | # 1a. 向量搜索
107 | query_embedding = await asyncio.to_thread(self.embedding_model.encode, query_text)
108 | # 召回数量可以设置得比最终需要的 k 要大,例如 k*5
109 | distances, text_ids = await asyncio.to_thread(self.text_vstore.search, query_embedding, k * 5)
110 |
111 | # 1b. 基于图的种子扩展
112 | # 假设 embedding_model 有提取实体的能力
113 | query_entities = await asyncio.to_thread(self.embedding_model.extract_entities, query_text)
114 | graph_ids = []
115 | if query_entities:
116 | for entity_id in query_entities:
117 | graph_ids.extend(
118 | await self.graph_storage.find_related_memory_ids(entity_id)
119 | )
120 |
121 | # 合并候选并去重
122 | candidate_internal_ids = list(set(list(text_ids) + graph_ids))
123 | if not candidate_internal_ids:
124 | return []
125 |
126 | # --- 阶段二:重排与扩展 ---
127 |
128 | # 2a. 获取候选记忆的完整数据
129 | candidate_docs = await self.storage.get_memories_by_internal_ids(
130 | candidate_internal_ids
131 | )
132 | candidate_memories = {
133 | doc["id"]: Memory.from_dict(json.loads(doc["memory_data"]))
134 | for doc in candidate_docs
135 | }
136 |
137 | # 2b. 计算最终分数并重排
138 | final_scores = {}
139 | text_ids_list = list(text_ids)
140 | for internal_id in candidate_internal_ids:
141 | # Faiss 分数:可以用排名的倒数来表示,排名越靠前分数越高
142 | try:
143 | faiss_score = 1.0 / (text_ids_list.index(internal_id) + 1)
144 | except ValueError:
145 | faiss_score = 0.0
146 |
147 | # 图分数:如果记忆在图扩展的结果中,则获得加分
148 | graph_score = 1.0 if internal_id in graph_ids else 0.0
149 |
150 | final_scores[internal_id] = (w1 * faiss_score) + (w2 * graph_score)
151 |
152 | # 根据最终分数降序排序
153 | sorted_ids = sorted(
154 | final_scores.keys(), key=lambda id: final_scores[id], reverse=True
155 | )
156 |
157 | # 获取Top-K结果,并更新访问信息
158 | top_k_results = [
159 | candidate_memories[id] for id in sorted_ids[:k] if id in candidate_memories
160 | ]
161 | await self.update_memory_access_info([mem.memory_id for mem in top_k_results])
162 |
163 | return top_k_results
164 |
165 | async def update_memory_access_info(self, memory_ids: List[str]):
166 | """
167 | 批量更新一组记忆的最后访问时间和访问计数。
168 | """
169 | docs = await self.storage.get_memories_by_memory_ids(memory_ids)
170 | if not docs:
171 | return
172 |
173 | updates = []
174 | for doc in docs:
175 | try:
176 | memory_dict = json.loads(doc["memory_data"])
177 | access_info = memory_dict["metadata"]["access_info"]
178 | access_info["last_accessed_timestamp"] = datetime.now(
179 | timezone.utc
180 | ).isoformat()
181 | access_info["access_count"] = access_info.get("access_count", 0) + 1
182 |
183 | updates.append(
184 | {
185 | "memory_id": memory_dict["memory_id"],
186 | "memory_data": json.dumps(memory_dict),
187 | }
188 | )
189 | except (json.JSONDecodeError, KeyError) as e:
190 | logger.error(
191 | f"Error updating access info for memory_id {doc.get('memory_id')}: {e}"
192 | )
193 |
194 | if updates:
195 | await self.storage.update_memories(updates)
196 |
197 | async def get_all_memories_for_forgetting(self) -> List[Memory]:
198 | """获取所有记忆,用于遗忘代理的处理。"""
199 | all_docs = await self.storage.get_all_memories()
200 | return [Memory.from_dict(json.loads(doc["memory_data"])) for doc in all_docs]
201 |
202 | async def update_memories_metadata(self, memories: List[Memory]):
203 | """批量更新记忆的完整对象。"""
204 | if not memories:
205 | return
206 | updates = [
207 | {"memory_id": mem.memory_id, "memory_data": json.dumps(mem.to_dict())}
208 | for mem in memories
209 | ]
210 | await self.storage.update_memories(updates)
211 |
212 | async def delete_memories(self, memory_ids: List[str]):
213 | """
214 | 批量删除记忆。由于设置了外键级联删除,图关系会被自动清理。
215 | """
216 | docs = await self.storage.get_memories_by_memory_ids(memory_ids)
217 | if not docs:
218 | return
219 |
220 | internal_ids = [doc["id"] for doc in docs]
221 |
222 | # 1. 从 Faiss 移除
223 | await asyncio.to_thread(self.text_vstore.remove, internal_ids)
224 | await asyncio.to_thread(self.image_vstore.remove, internal_ids)
225 | await asyncio.to_thread(self.text_vstore.save_index)
226 | await asyncio.to_thread(self.image_vstore.save_index)
227 |
228 | # 2. 从 SQLite 的 memories 表移除
229 | # 由于设置了 ON DELETE CASCADE,graph_edges 表中相关的数据会被自动删除
230 | await self.storage.delete_memories_by_internal_ids(internal_ids)
231 |
232 | async def archive_memory(self, memory_id: str):
233 | """实现遗忘逻辑的第一步:归档。"""
234 | # 1. 获取记忆的 internal_id
235 | docs = await self.storage.get_memories_by_memory_ids([memory_id])
236 | if not docs:
237 | return
238 | internal_id = docs[0]["id"]
239 |
240 | # 2. 从所有 Faiss 索引中移除
241 | await asyncio.to_thread(self.text_vstore.remove, [internal_id])
242 | await asyncio.to_thread(self.image_vstore.remove, [internal_id])
243 | # (可选) 也可以从图数据库中移除以节省空间
244 | # await self.graph_storage.delete_graph_for_memory(internal_id)
245 |
246 | # 3. 在 SQLite 中更新状态
247 | await self.storage.update_memory_status([internal_id], "archived")
248 |
249 | logger.info(f"已归档记忆 {memory_id}。")
250 |
--------------------------------------------------------------------------------
/storage/graph_storage.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import aiosqlite
3 | from typing import List, Dict, Any
4 |
5 |
6 | class GraphStorageSQLite:
7 | """
8 | 使用 SQLite 关系表来管理知识图谱数据。
9 | """
10 |
11 | def __init__(self, connection: aiosqlite.Connection):
12 | """
13 | 直接接收一个已建立的 aiosqlite 连接,以确保所有操作都在同一事务空间内。
14 | """
15 | self.conn = connection
16 |
17 | async def initialize_schema(self):
18 | """在数据库中创建图相关的表和索引。"""
19 | await self.conn.execute("""
20 | CREATE TABLE IF NOT EXISTS graph_nodes (
21 | entity_id TEXT PRIMARY KEY NOT NULL,
22 | name TEXT NOT NULL,
23 | type TEXT NOT NULL
24 | )
25 | """)
26 | await self.conn.execute("""
27 | CREATE TABLE IF NOT EXISTS graph_edges (
28 | source_id TEXT NOT NULL,
29 | target_id TEXT NOT NULL,
30 | relation_type TEXT NOT NULL,
31 | memory_internal_id INTEGER NOT NULL,
32 | FOREIGN KEY (memory_internal_id) REFERENCES memories(id) ON DELETE CASCADE
33 | )
34 | """)
35 | await self.conn.execute(
36 | "CREATE INDEX IF NOT EXISTS idx_edge_source ON graph_edges (source_id)"
37 | )
38 | await self.conn.execute(
39 | "CREATE INDEX IF NOT EXISTS idx_edge_target ON graph_edges (target_id)"
40 | )
41 | await self.conn.commit()
42 |
43 | async def add_memory_graph(self, internal_id: int, payload: Dict[str, Any]):
44 | """从 knowledge_graph_payload 创建节点和边。"""
45 | # 使用事务确保要么全部成功,要么全部失败
46 | async with self.conn.cursor() as cursor:
47 | # 1. 添加或更新节点 (INSERT OR IGNORE 避免重复)
48 | nodes_to_insert = []
49 | if "event_entity" in payload and payload["event_entity"]:
50 | nodes_to_insert.append(
51 | (
52 | payload["event_entity"]["event_id"],
53 | payload["event_entity"].get(
54 | "event_type", "Event"
55 | ), # name can be event_type
56 | "Event",
57 | )
58 | )
59 | for entity in payload.get("entities", []):
60 | nodes_to_insert.append(
61 | (entity["entity_id"], entity["name"], entity["type"])
62 | )
63 |
64 | if nodes_to_insert:
65 | await cursor.executemany(
66 | "INSERT OR IGNORE INTO graph_nodes (entity_id, name, type) VALUES (?, ?, ?)",
67 | nodes_to_insert,
68 | )
69 |
70 | # 2. 添加关系边
71 | edges_to_insert = []
72 | for rel in payload.get("relationships", []):
73 | if len(rel) == 3:
74 | edges_to_insert.append((rel[0], rel[1], rel[2], internal_id))
75 |
76 | if edges_to_insert:
77 | await cursor.executemany(
78 | "INSERT INTO graph_edges (source_id, relation_type, target_id, memory_internal_id) VALUES (?, ?, ?, ?)",
79 | edges_to_insert,
80 | )
81 | await self.conn.commit()
82 |
83 | async def find_related_memory_ids(
84 | self, entity_id: str, max_depth: int = 2
85 | ) -> List[int]:
86 | """
87 | 使用 SQL 递归查询 (Recursive CTE) 从一个实体出发,查找相关联的记忆 internal_id。
88 | """
89 | query = """
90 | WITH RECURSIVE graph_walk(node_id, depth) AS (
91 | -- 递归的起始点
92 | VALUES(:entity_id, 0)
93 | UNION
94 | -- 递归步骤: 从当前节点找到所有相邻节点
95 | SELECT g.target_id, w.depth + 1
96 | FROM graph_edges g JOIN graph_walk w ON g.source_id = w.node_id
97 | WHERE w.depth < :max_depth
98 | UNION
99 | SELECT g.source_id, w.depth + 1
100 | FROM graph_edges g JOIN graph_walk w ON g.target_id = w.node_id
101 | WHERE w.depth < :max_depth
102 | )
103 | -- 从所有遍历到的节点中,查找它们作为关系边出现时关联的记忆ID
104 | SELECT DISTINCT T.memory_internal_id
105 | FROM graph_edges T
106 | JOIN graph_walk W ON T.source_id = W.node_id OR T.target_id = W.node_id;
107 | """
108 | async with self.conn.execute(
109 | query, {"entity_id": entity_id, "max_depth": max_depth}
110 | ) as cursor:
111 | rows = await cursor.fetchall()
112 | return [row[0] for row in rows]
113 |
114 | async def add_correction_link(
115 | self, new_event_id: str, old_event_id: str, memory_internal_id: int
116 | ):
117 | """为记忆更新添加 CORRECTS 关系。"""
118 | await self.conn.execute(
119 | "INSERT INTO graph_edges (source_id, relation_type, target_id, memory_internal_id) VALUES (?, 'CORRECTS', ?, ?)",
120 | (new_event_id, old_event_id, memory_internal_id),
121 | )
122 | await self.conn.commit()
123 |
--------------------------------------------------------------------------------
/storage/memory_storage.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import json
4 | import aiosqlite
5 | from typing import List, Dict, Any
6 |
7 | from ..core.models.memory_models import Memory
8 |
9 |
10 | class MemoryStorage:
11 | """
12 | 用于在 SQLite 中持久化、检索和管理结构化 Memory 对象的类。
13 | """
14 |
15 | def __init__(self, connection: aiosqlite.Connection):
16 | """
17 | 修正: 接收一个已建立的 aiosqlite 连接
18 | """
19 | self.connection = connection
20 |
21 | async def initialize_schema(self):
22 | """
23 | 建立数据库连接并创建表
24 | """
25 | # 修正: 创建包含所有需要字段的表
26 | await self.connection.execute("""
27 | CREATE TABLE IF NOT EXISTS memories (
28 | id INTEGER PRIMARY KEY AUTOINCREMENT,
29 | memory_id TEXT NOT NULL UNIQUE,
30 | timestamp TEXT NOT NULL,
31 | memory_type TEXT NOT NULL,
32 | importance_score REAL NOT NULL,
33 | status TEXT NOT NULL DEFAULT 'active',
34 | community_id TEXT, -- 为社区发现预留字段
35 | memory_data TEXT NOT NULL
36 | )
37 | """)
38 | await self.connection.execute("""
39 | CREATE UNIQUE INDEX IF NOT EXISTS idx_memory_id ON memories (memory_id);
40 | """)
41 | await self.connection.execute("""
42 | CREATE INDEX IF NOT EXISTS idx_memory_status ON memories (status);
43 | """)
44 | await self.connection.execute("""
45 | CREATE INDEX IF NOT EXISTS idx_memory_community ON memories (community_id);
46 | """)
47 | await self.connection.commit()
48 |
49 | # Note: 不提供 close() 方法,因为这个类不负责连接的生命周期管理
50 | # 连接的创建和关闭应由更高层的组件(如 FaissManagerV2)负责
51 |
52 | async def add_memory(self, memory: Memory) -> int:
53 | """
54 | 将一个 Memory 对象添加到数据库,并返回其内部自增 ID。
55 | """
56 | memory_json = json.dumps(memory.to_dict(), ensure_ascii=False)
57 | status = "active" # 默认状态
58 |
59 | cursor = await self.connection.execute(
60 | """
61 | INSERT INTO memories (memory_id, timestamp, memory_type, importance_score, status, memory_data)
62 | VALUES (?, ?, ?, ?, ?, ?)
63 | """,
64 | (
65 | memory.memory_id,
66 | memory.timestamp,
67 | memory.metadata.memory_type,
68 | memory.metadata.importance_score,
69 | status,
70 | memory_json,
71 | ),
72 | )
73 | await self.connection.commit()
74 | return cursor.lastrowid
75 |
76 | async def get_memories_by_internal_ids(
77 | self, internal_ids: List[int]
78 | ) -> List[Dict[str, Any]]:
79 | """
80 | 通过内部自增 ID 列表获取记忆数据。
81 | """
82 | if not internal_ids:
83 | return []
84 | placeholders = ",".join("?" for _ in internal_ids)
85 | sql = f"SELECT id, memory_id, memory_data FROM memories WHERE id IN ({placeholders})"
86 | async with self.connection.execute(sql, internal_ids) as cursor:
87 | rows = await cursor.fetchall()
88 | return [dict(row) for row in rows]
89 |
90 | async def get_memories_by_memory_ids(
91 | self, memory_ids: List[str]
92 | ) -> List[Dict[str, Any]]:
93 | """
94 | 通过全局 memory_id (UUID) 列表获取记忆数据。
95 | """
96 | if not memory_ids:
97 | return []
98 | placeholders = ",".join("?" for _ in memory_ids)
99 | sql = f"SELECT id, memory_id, memory_data FROM memories WHERE memory_id IN ({placeholders})"
100 | async with self.connection.execute(sql, memory_ids) as cursor:
101 | rows = await cursor.fetchall()
102 | return [dict(row) for row in rows]
103 |
104 | async def get_all_memories(self) -> List[Dict[str, Any]]:
105 | """
106 | 获取数据库中所有的记忆。
107 | """
108 | async with self.connection.execute(
109 | "SELECT id, memory_id, memory_data FROM memories"
110 | ) as cursor:
111 | rows = await cursor.fetchall()
112 | return [dict(row) for row in rows]
113 |
114 | async def update_memories(self, memories_to_update: List[Dict[str, Any]]):
115 | """
116 | 根据 memory_id 批量更新 memory_data。
117 | """
118 | if not memories_to_update:
119 | return
120 | updates = [(mem["memory_data"], mem["memory_id"]) for mem in memories_to_update]
121 | await self.connection.executemany(
122 | "UPDATE memories SET memory_data = ? WHERE memory_id = ?", updates
123 | )
124 | await self.connection.commit()
125 |
126 | async def delete_memories_by_internal_ids(self, internal_ids: List[int]):
127 | """
128 | 根据内部 ID 列表删除记忆。
129 | """
130 | if not internal_ids:
131 | return
132 | placeholders = ",".join("?" for _ in internal_ids)
133 | await self.connection.execute(
134 | f"DELETE FROM memories WHERE id IN ({placeholders})", internal_ids
135 | )
136 | await self.connection.commit()
137 |
138 | async def update_memory_status(self, internal_ids: List[int], new_status: str):
139 | """
140 | 批量更新记忆的状态 (例如, 改为 'archived')
141 | """
142 | if not internal_ids:
143 | return
144 | placeholders = ",".join("?" for _ in internal_ids)
145 | # 注意参数绑定的方式,new_status 在前
146 | await self.connection.execute(
147 | f"UPDATE memories SET status = ? WHERE id IN ({placeholders})",
148 | [new_status] + internal_ids,
149 | )
150 | await self.connection.commit()
151 |
--------------------------------------------------------------------------------
/storage/vector_store.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import asyncio
4 | import faiss
5 | import numpy as np
6 | from typing import List, Tuple, Optional
7 |
8 | from astrbot.api import logger
9 |
10 |
11 | class VectorStore:
12 | """
13 | 专门管理 Faiss 索引,处理向量的增、删、查。
14 | """
15 |
16 | def __init__(self, index_path: str, dimension: int):
17 | self.index_path = index_path
18 | self.dimension = dimension
19 | self.index: Optional[faiss.Index] = None
20 | # 注意:_load_index 现在是异步的,需要在初始化后手动调用
21 | # 或者创建一个异步的初始化方法
22 |
23 | async def _load_index(self):
24 | """
25 | 从文件加载 Faiss 索引,如果不存在则创建一个新的。
26 | """
27 | if os.path.exists(self.index_path):
28 | logger.info(f"Loading Faiss index from {self.index_path}")
29 | self.index = await asyncio.to_thread(faiss.read_index, self.index_path)
30 | else:
31 | logger.info(f"Creating new Faiss index. Dimension: {self.dimension}")
32 | # 使用 IndexFlatL2 作为基础索引,这是常用的欧氏距离索引
33 | base_index = faiss.IndexFlatL2(self.dimension)
34 | # 使用 IndexIDMap2 将我们的自定义整数 ID 映射到向量
35 | self.index = faiss.IndexIDMap2(base_index)
36 |
37 | async def save_index(self):
38 | """
39 | 将当前索引状态保存到文件。
40 | """
41 | logger.info(f"Saving Faiss index to {self.index_path}")
42 | await asyncio.to_thread(faiss.write_index, self.index, self.index_path)
43 |
44 | async def add(self, ids: List[int], embeddings: List[List[float]]):
45 | """
46 | 将带有自定义 ID 的向量添加到索引中。
47 | """
48 | if not ids:
49 | return
50 | # Faiss 需要 int64 类型的 ID 和 float32 类型的向量
51 | ids_np = np.array(ids, dtype=np.int64)
52 | embeddings_np = np.array(embeddings, dtype=np.float32)
53 | await asyncio.to_thread(self.index.add_with_ids, embeddings_np, ids_np)
54 |
55 | async def search(
56 | self, query_embedding: List[float], k: int
57 | ) -> Tuple[np.ndarray, np.ndarray]:
58 | """
59 | 在索引中搜索最相似的 k 个向量。
60 | 返回 (距离数组, ID数组)。
61 | """
62 | query_np = np.array([query_embedding], dtype=np.float32)
63 | # self.index.ntotal 是索引中的向量总数
64 | k = min(k, self.index.ntotal)
65 | if k == 0:
66 | return np.array([]), np.array([])
67 | distances, ids = await asyncio.to_thread(self.index.search, query_np, k)
68 | return distances[0], ids[0]
69 |
70 | async def remove(self, ids_to_remove: List[int]):
71 | """
72 | 从索引中移除指定 ID 的向量。
73 | """
74 | if not ids_to_remove:
75 | return
76 | ids_np = np.array(ids_to_remove, dtype=np.int64)
77 | await asyncio.to_thread(self.index.remove_ids, ids_np)
78 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | 测试模块配置文件
4 | """
5 |
6 | import sys
7 | import os
8 | from pathlib import Path
9 |
10 | # 添加项目根目录到Python路径
11 | project_root = Path(__file__).parent.parent
12 | sys.path.insert(0, str(project_root))
13 |
14 | # 测试配置
15 | TEST_CONFIG = {
16 | "session_manager": {
17 | "max_sessions": 100,
18 | "session_ttl": 3600
19 | },
20 | "recall_engine": {
21 | "retrieval_mode": "hybrid",
22 | "top_k": 5,
23 | "recall_strategy": "weighted"
24 | },
25 | "reflection_engine": {
26 | "summary_trigger_rounds": 10,
27 | "importance_threshold": 0.5
28 | },
29 | "forgetting_agent": {
30 | "enabled": True,
31 | "check_interval_hours": 24,
32 | "retention_days": 90
33 | },
34 | "timezone_settings": {
35 | "timezone": "Asia/Shanghai"
36 | },
37 | "fusion": {
38 | "strategy": "rrf",
39 | "rrf_k": 60,
40 | "dense_weight": 0.7,
41 | "sparse_weight": 0.3
42 | }
43 | }
--------------------------------------------------------------------------------
/tests/unit/test_admin_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | test_admin_handler.py - 管理员处理器测试
4 | """
5 |
6 | import pytest
7 | from unittest.mock import Mock, AsyncMock, patch
8 |
9 | from core.handlers.admin_handler import AdminHandler
10 | from tests.conftest import TEST_CONFIG
11 |
12 |
13 | class TestAdminHandler:
14 | """管理员处理器测试类"""
15 |
16 | def setup_method(self):
17 | """测试前设置"""
18 | self.mock_context = Mock()
19 | self.mock_faiss_manager = Mock()
20 | self.mock_forgetting_agent = Mock()
21 | self.mock_session_manager = Mock()
22 | self.handler = AdminHandler(
23 | self.mock_context,
24 | TEST_CONFIG,
25 | self.mock_faiss_manager,
26 | self.mock_forgetting_agent,
27 | self.mock_session_manager
28 | )
29 |
30 | @pytest.mark.asyncio
31 | async def test_get_memory_status_success(self):
32 | """测试获取记忆库状态(成功)"""
33 | # 模拟数据库计数
34 | self.mock_faiss_manager.db.count_documents = AsyncMock(return_value=42)
35 |
36 | result = await self.handler.get_memory_status()
37 |
38 | assert result["success"] is True
39 | assert result["data"]["total_count"] == 42
40 |
41 | # 验证调用
42 | self.mock_faiss_manager.db.count_documents.assert_called_once()
43 |
44 | @pytest.mark.asyncio
45 | async def test_get_memory_status_no_manager(self):
46 | """测试没有管理器时获取状态"""
47 | handler = AdminHandler(self.mock_context, TEST_CONFIG, None, None, None)
48 |
49 | result = await handler.get_memory_status()
50 |
51 | assert result["success"] is False
52 | assert "记忆库尚未初始化" in result["message"]
53 |
54 | @pytest.mark.asyncio
55 | async def test_get_memory_status_exception(self):
56 | """测试获取记忆库状态异常处理"""
57 | self.mock_faiss_manager.db.count_documents = AsyncMock(side_effect=Exception("数据库错误"))
58 |
59 | result = await self.handler.get_memory_status()
60 |
61 | assert result["success"] is False
62 | assert "获取记忆库状态失败" in result["message"]
63 |
64 | @pytest.mark.asyncio
65 | async def test_delete_memory_success(self):
66 | """测试删除记忆(成功)"""
67 | self.mock_faiss_manager.delete_memories = AsyncMock()
68 |
69 | result = await self.handler.delete_memory(123)
70 |
71 | assert result["success"] is True
72 | assert "已成功删除 ID 为 123 的记忆" in result["message"]
73 |
74 | # 验证调用参数
75 | self.mock_faiss_manager.delete_memories.assert_called_once_with([123])
76 |
77 | @pytest.mark.asyncio
78 | async def test_delete_memory_no_manager(self):
79 | """测试没有管理器时删除记忆"""
80 | handler = AdminHandler(self.mock_context, TEST_CONFIG, None, None, None)
81 |
82 | result = await handler.delete_memory(123)
83 |
84 | assert result["success"] is False
85 | assert "记忆库尚未初始化" in result["message"]
86 |
87 | @pytest.mark.asyncio
88 | async def test_delete_memory_exception(self):
89 | """测试删除记忆异常处理"""
90 | self.mock_faiss_manager.delete_memories = AsyncMock(side_effect=Exception("删除错误"))
91 |
92 | result = await self.handler.delete_memory(123)
93 |
94 | assert result["success"] is False
95 | assert "删除记忆时发生错误" in result["message"]
96 |
97 | @pytest.mark.asyncio
98 | async def test_run_forgetting_agent_success(self):
99 | """测试运行遗忘代理(成功)"""
100 | self.mock_forgetting_agent._prune_memories = AsyncMock()
101 |
102 | result = await self.handler.run_forgetting_agent()
103 |
104 | assert result["success"] is True
105 | assert "遗忘代理任务执行完毕" in result["message"]
106 |
107 | # 验证调用
108 | self.mock_forgetting_agent._prune_memories.assert_called_once()
109 |
110 | @pytest.mark.asyncio
111 | async def test_run_forgetting_agent_no_agent(self):
112 | """测试没有遗忘代理时运行"""
113 | handler = AdminHandler(self.mock_context, TEST_CONFIG, None, None, None)
114 |
115 | result = await handler.run_forgetting_agent()
116 |
117 | assert result["success"] is False
118 | assert "遗忘代理尚未初始化" in result["message"]
119 |
120 | @pytest.mark.asyncio
121 | async def test_run_forgetting_agent_exception(self):
122 | """测试运行遗忘代理异常处理"""
123 | self.mock_forgetting_agent._prune_memories = AsyncMock(side_effect=Exception("遗忘代理错误"))
124 |
125 | result = await self.handler.run_forgetting_agent()
126 |
127 | assert result["success"] is False
128 | assert "遗忘代理任务执行失败" in result["message"]
129 |
130 | @pytest.mark.asyncio
131 | async def test_set_search_mode_valid(self):
132 | """测试设置搜索模式(有效模式)"""
133 | result = await self.handler.set_search_mode("hybrid")
134 |
135 | assert result["success"] is True
136 | assert "检索模式已设置为: hybrid" in result["message"]
137 |
138 | @pytest.mark.asyncio
139 | async def test_set_search_mode_invalid(self):
140 | """测试设置搜索模式(无效模式)"""
141 | result = await self.handler.set_search_mode("invalid_mode")
142 |
143 | assert result["success"] is False
144 | assert "无效的模式" in result["message"]
145 | assert "hybrid, dense, sparse" in result["message"]
146 |
147 | @pytest.mark.asyncio
148 | async def test_get_config_summary_show(self):
149 | """测试获取配置摘要(显示)"""
150 | # 模拟会话管理器
151 | self.mock_session_manager.get_session_count = Mock(return_value=5)
152 |
153 | result = await self.handler.get_config_summary("show")
154 |
155 | assert result["success"] is True
156 | data = result["data"]
157 |
158 | # 验证各个配置部分
159 | assert "session_manager" in data
160 | assert "recall_engine" in data
161 | assert "reflection_engine" in data
162 | assert "forgetting_agent" in data
163 |
164 | # 验证具体配置值
165 | assert data["session_manager"]["max_sessions"] == 100
166 | assert data["session_manager"]["session_ttl"] == 3600
167 | assert data["session_manager"]["current_sessions"] == 5
168 | assert data["recall_engine"]["retrieval_mode"] == "hybrid"
169 | assert data["recall_engine"]["top_k"] == 5
170 | assert data["forgetting_agent"]["enabled"] is True
171 |
172 | # 验证调用
173 | self.mock_session_manager.get_session_count.assert_called_once()
174 |
175 | @pytest.mark.asyncio
176 | async def test_get_config_summary_validate_success(self):
177 | """测试获取配置摘要(验证成功)"""
178 | with patch('core.config_validator.validate_config') as mock_validate:
179 | mock_validate.return_value = None # 验证通过时不返回异常
180 |
181 | result = await self.handler.get_config_summary("validate")
182 |
183 | assert result["success"] is True
184 | assert "配置验证通过,所有参数均有效" in result["message"]
185 |
186 | # 验证调用
187 | mock_validate.assert_called_once_with(TEST_CONFIG)
188 |
189 | @pytest.mark.asyncio
190 | async def test_get_config_summary_validate_failure(self):
191 | """测试获取配置摘要(验证失败)"""
192 | with patch('core.config_validator.validate_config') as mock_validate:
193 | mock_validate.side_effect = ValueError("配置验证失败")
194 |
195 | result = await self.handler.get_config_summary("validate")
196 |
197 | assert result["success"] is False
198 | assert "配置验证失败" in result["message"]
199 |
200 | # 验证调用
201 | mock_validate.assert_called_once_with(TEST_CONFIG)
202 |
203 | @pytest.mark.asyncio
204 | async def test_get_config_summary_invalid_action(self):
205 | """测试获取配置摘要(无效动作)"""
206 | result = await self.handler.get_config_summary("invalid_action")
207 |
208 | assert result["success"] is False
209 | assert "无效的动作" in result["message"]
210 | assert "show" in result["message"]
211 | assert "validate" in result["message"]
212 |
213 | @pytest.mark.asyncio
214 | async def test_get_config_summary_show_exception(self):
215 | """测试获取配置摘要显示异常处理"""
216 | self.mock_session_manager.get_session_count = Mock(side_effect=Exception("配置错误"))
217 |
218 | result = await self.handler.get_config_summary("show")
219 |
220 | assert result["success"] is False
221 | assert "显示配置时发生错误" in result["message"]
222 |
223 | def test_format_status_for_display_success(self):
224 | """测试格式化状态显示(成功)"""
225 | mock_response = {
226 | "success": True,
227 | "data": {"total_count": 42}
228 | }
229 |
230 | result = self.handler.format_status_for_display(mock_response)
231 |
232 | assert "📊 LivingMemory 记忆库状态:" in result
233 | assert "- 总记忆数: 42" in result
234 |
235 | def test_format_status_for_display_failure(self):
236 | """测试格式化状态显示(失败)"""
237 | mock_response = {
238 | "success": False,
239 | "message": "获取失败"
240 | }
241 |
242 | result = self.handler.format_status_for_display(mock_response)
243 |
244 | assert result == "获取失败"
245 |
246 | def test_format_config_summary_for_display_success(self):
247 | """测试格式化配置摘要显示(成功)"""
248 | mock_response = {
249 | "success": True,
250 | "data": {
251 | "session_manager": {
252 | "max_sessions": 1000,
253 | "session_ttl": 3600,
254 | "current_sessions": 5
255 | },
256 | "recall_engine": {
257 | "retrieval_mode": "hybrid",
258 | "top_k": 5,
259 | "recall_strategy": "weighted"
260 | },
261 | "reflection_engine": {
262 | "summary_trigger_rounds": 10,
263 | "importance_threshold": 0.5
264 | },
265 | "forgetting_agent": {
266 | "enabled": True,
267 | "check_interval_hours": 24,
268 | "retention_days": 90
269 | }
270 | }
271 | }
272 |
273 | result = self.handler.format_config_summary_for_display(mock_response)
274 |
275 | assert "📋 LivingMemory 配置摘要:" in result
276 | assert "🗂️ 会话管理:" in result
277 | assert "🧠 回忆引擎:" in result
278 | assert "💭 反思引擎:" in result
279 | assert "🗑️ 遗忘代理:" in result
280 | assert "最大会话数: 1000" in result
281 | assert "会话TTL: 3600秒" in result
282 | assert "当前会话数: 5" in result
283 | assert "检索模式: hybrid" in result
284 | assert "返回数量: 5" in result
285 | assert "启用状态: 是" in result
286 | assert "检查间隔: 24小时" in result
287 | assert "保留天数: 90天" in result
288 |
289 | def test_format_config_summary_for_display_failure(self):
290 | """测试格式化配置摘要显示(失败)"""
291 | mock_response = {
292 | "success": False,
293 | "message": "配置获取失败"
294 | }
295 |
296 | result = self.handler.format_config_summary_for_display(mock_response)
297 |
298 | assert result == "配置获取失败"
299 |
300 | def test_format_config_summary_for_display_missing_sections(self):
301 | """测试格式化配置摘要显示(缺少部分配置)"""
302 | mock_response = {
303 | "success": True,
304 | "data": {
305 | "session_manager": {
306 | "max_sessions": 1000,
307 | "session_ttl": 3600,
308 | "current_sessions": 5
309 | }
310 | # 缺少其他配置部分
311 | }
312 | }
313 |
314 | result = self.handler.format_config_summary_for_display(mock_response)
315 |
316 | assert "📋 LivingMemory 配置摘要:" in result
317 | assert "🗂️ 会话管理:" in result
318 | # 方法总是显示所有配置部分,使用默认值
319 | assert "🧠 回忆引擎:" in result
320 | assert "💭 反思引擎:" in result
321 | assert "🗑️ 遗忘代理:" in result
--------------------------------------------------------------------------------
/tests/unit/test_base_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | test_base_handler.py - 基础处理器测试
4 | """
5 |
6 | import pytest
7 | import json
8 | from datetime import datetime, timezone
9 | from unittest.mock import Mock, patch
10 |
11 | from core.handlers.base_handler import TestableBaseHandler
12 | from tests.conftest import TEST_CONFIG
13 |
14 |
15 | class TestBaseHandler:
16 | """基础处理器测试类"""
17 |
18 | def setup_method(self):
19 | """测试前设置"""
20 | self.mock_context = Mock()
21 | self.handler = TestableBaseHandler(self.mock_context, TEST_CONFIG)
22 |
23 | def test_init(self):
24 | """测试初始化"""
25 | assert self.handler.context == self.mock_context
26 | assert self.handler.config == TEST_CONFIG
27 |
28 | def test_create_response_success(self):
29 | """测试创建成功响应"""
30 | response = self.handler.create_response(True, "成功消息", {"data": "test"})
31 |
32 | assert response["success"] is True
33 | assert response["message"] == "成功消息"
34 | assert response["data"] == {"data": "test"}
35 |
36 | def test_create_response_failure(self):
37 | """测试创建失败响应"""
38 | response = self.handler.create_response(False, "失败消息")
39 |
40 | assert response["success"] is False
41 | assert response["message"] == "失败消息"
42 | assert response["data"] is None
43 |
44 | def test_safe_parse_metadata_dict(self):
45 | """测试解析字典类型元数据"""
46 | metadata = {"key": "value", "number": 123}
47 | result = self.handler.safe_parse_metadata(metadata)
48 |
49 | assert result == metadata
50 |
51 | def test_safe_parse_metadata_json_string(self):
52 | """测试解析JSON字符串元数据"""
53 | metadata_str = '{"key": "value", "number": 123}'
54 | result = self.handler.safe_parse_metadata(metadata_str)
55 |
56 | assert result == {"key": "value", "number": 123}
57 |
58 | def test_safe_parse_metadata_invalid_json(self):
59 | """测试解析无效JSON字符串"""
60 | invalid_json = "{invalid json}"
61 | result = self.handler.safe_parse_metadata(invalid_json)
62 |
63 | assert result == {}
64 |
65 | def test_safe_parse_metadata_none(self):
66 | """测试解析None值"""
67 | result = self.handler.safe_parse_metadata(None)
68 |
69 | assert result == {}
70 |
71 | def test_format_timestamp_valid(self):
72 | """测试格式化有效时间戳"""
73 | timestamp = 1609459200.0 # 2021-01-01 00:00:00 UTC
74 | result = self.handler.format_timestamp(timestamp)
75 |
76 | assert "2021-01-01" in result
77 | # 由于时区转换,时间可能不是00:00:00
78 | assert len(result) > 10
79 |
80 | def test_format_timestamp_none(self):
81 | """测试格式化None时间戳"""
82 | result = self.handler.format_timestamp(None)
83 |
84 | assert result == "未知"
85 |
86 | def test_format_timestamp_invalid(self):
87 | """测试格式化无效时间戳"""
88 | result = self.handler.format_timestamp("invalid")
89 |
90 | assert result == "未知"
91 |
92 | def test_get_timezone(self):
93 | """测试获取时区"""
94 | with patch('pytz.timezone') as mock_timezone:
95 | mock_tz = Mock()
96 | mock_timezone.return_value = mock_tz
97 |
98 | result = self.handler.get_timezone()
99 |
100 | assert result == mock_tz
101 | mock_timezone.assert_called_once_with("Asia/Shanghai")
102 |
103 | def test_format_memory_card(self):
104 | """测试格式化记忆卡片"""
105 | # 创建模拟的记忆结果
106 | mock_result = Mock()
107 | mock_result.data = {
108 | "id": 1,
109 | "text": "测试记忆内容",
110 | "metadata": json.dumps({
111 | "create_time": 1609459200.0,
112 | "last_access_time": 1609459200.0,
113 | "importance": 0.8,
114 | "event_type": "FACT"
115 | })
116 | }
117 | mock_result.similarity = 0.95
118 |
119 | result = self.handler.format_memory_card(mock_result)
120 |
121 | assert "ID: 1" in result
122 | assert "记 忆 度: 0.95" in result
123 | assert "重 要 性: 0.80" in result
124 | assert "记忆类型: FACT" in result
125 | assert "测试记忆内容" in result
--------------------------------------------------------------------------------
/tests/unit/test_memory_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | test_memory_handler.py - 记忆管理处理器测试
4 | """
5 |
6 | import pytest
7 | import json
8 | from unittest.mock import Mock, AsyncMock, patch
9 |
10 | from core.handlers.memory_handler import MemoryHandler
11 | from tests.conftest import TEST_CONFIG
12 |
13 |
14 | class TestMemoryHandler:
15 | """记忆管理处理器测试类"""
16 |
17 | def setup_method(self):
18 | """测试前设置"""
19 | self.mock_context = Mock()
20 | self.mock_faiss_manager = Mock()
21 | self.handler = MemoryHandler(self.mock_context, TEST_CONFIG, self.mock_faiss_manager)
22 |
23 | @pytest.mark.asyncio
24 | async def test_edit_memory_content(self):
25 | """测试编辑记忆内容"""
26 | # 模拟faiss_manager.update_memory的返回值
27 | mock_result = {
28 | "success": True,
29 | "message": "更新成功",
30 | "updated_fields": ["content"],
31 | "memory_id": 123
32 | }
33 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result)
34 |
35 | result = await self.handler.edit_memory("123", "content", "新的记忆内容", "测试更新")
36 |
37 | assert result["success"] is True
38 | assert "更新成功" in result["message"]
39 |
40 | # 验证调用参数
41 | self.mock_faiss_manager.update_memory.assert_called_once_with(
42 | memory_id=123,
43 | update_reason="测试更新",
44 | content="新的记忆内容"
45 | )
46 |
47 | @pytest.mark.asyncio
48 | async def test_edit_memory_importance_valid(self):
49 | """测试编辑记忆重要性(有效值)"""
50 | mock_result = {
51 | "success": True,
52 | "message": "更新成功",
53 | "updated_fields": ["importance"]
54 | }
55 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result)
56 |
57 | result = await self.handler.edit_memory("123", "importance", "0.9", "提高重要性")
58 |
59 | assert result["success"] is True
60 | self.mock_faiss_manager.update_memory.assert_called_once_with(
61 | memory_id=123,
62 | update_reason="提高重要性",
63 | importance=0.9
64 | )
65 |
66 | @pytest.mark.asyncio
67 | async def test_edit_memory_importance_invalid_range(self):
68 | """测试编辑记忆重要性(无效范围)"""
69 | result = await self.handler.edit_memory("123", "importance", "1.5", "无效值")
70 |
71 | assert result["success"] is False
72 | assert "重要性评分必须在 0.0 到 1.0 之间" in result["message"]
73 |
74 | # 验证没有调用update_memory
75 | self.mock_faiss_manager.update_memory.assert_not_called()
76 |
77 | @pytest.mark.asyncio
78 | async def test_edit_memory_importance_invalid_type(self):
79 | """测试编辑记忆重要性(无效类型)"""
80 | result = await self.handler.edit_memory("123", "importance", "invalid", "非数字")
81 |
82 | assert result["success"] is False
83 | assert "重要性评分必须是数字" in result["message"]
84 |
85 | # 验证没有调用update_memory
86 | self.mock_faiss_manager.update_memory.assert_not_called()
87 |
88 | @pytest.mark.asyncio
89 | async def test_edit_memory_type_valid(self):
90 | """测试编辑记忆类型(有效值)"""
91 | mock_result = {
92 | "success": True,
93 | "message": "更新成功",
94 | "updated_fields": ["event_type"]
95 | }
96 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result)
97 |
98 | result = await self.handler.edit_memory("123", "type", "PREFERENCE", "重新分类")
99 |
100 | assert result["success"] is True
101 | self.mock_faiss_manager.update_memory.assert_called_once_with(
102 | memory_id=123,
103 | update_reason="重新分类",
104 | event_type="PREFERENCE"
105 | )
106 |
107 | @pytest.mark.asyncio
108 | async def test_edit_memory_type_invalid(self):
109 | """测试编辑记忆类型(无效值)"""
110 | result = await self.handler.edit_memory("123", "type", "INVALID_TYPE", "无效类型")
111 |
112 | assert result["success"] is False
113 | assert "无效的事件类型" in result["message"]
114 |
115 | # 验证没有调用update_memory
116 | self.mock_faiss_manager.update_memory.assert_not_called()
117 |
118 | @pytest.mark.asyncio
119 | async def test_edit_memory_status_valid(self):
120 | """测试编辑记忆状态(有效值)"""
121 | mock_result = {
122 | "success": True,
123 | "message": "更新成功",
124 | "updated_fields": ["status"]
125 | }
126 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result)
127 |
128 | result = await self.handler.edit_memory("123", "status", "archived", "项目完成")
129 |
130 | assert result["success"] is True
131 | self.mock_faiss_manager.update_memory.assert_called_once_with(
132 | memory_id=123,
133 | update_reason="项目完成",
134 | status="archived"
135 | )
136 |
137 | @pytest.mark.asyncio
138 | async def test_edit_memory_status_invalid(self):
139 | """测试编辑记忆状态(无效值)"""
140 | result = await self.handler.edit_memory("123", "status", "INVALID_STATUS", "无效状态")
141 |
142 | assert result["success"] is False
143 | assert "无效的状态" in result["message"]
144 |
145 | # 验证没有调用update_memory
146 | self.mock_faiss_manager.update_memory.assert_not_called()
147 |
148 | @pytest.mark.asyncio
149 | async def test_edit_memory_unknown_field(self):
150 | """测试编辑未知字段"""
151 | result = await self.handler.edit_memory("123", "unknown_field", "value", "未知字段")
152 |
153 | assert result["success"] is False
154 | assert "未知的字段" in result["message"]
155 |
156 | # 验证没有调用update_memory
157 | self.mock_faiss_manager.update_memory.assert_not_called()
158 |
159 | @pytest.mark.asyncio
160 | async def test_edit_memory_string_id(self):
161 | """测试使用字符串ID编辑记忆"""
162 | mock_result = {
163 | "success": True,
164 | "message": "更新成功",
165 | "updated_fields": ["content"]
166 | }
167 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result)
168 |
169 | result = await self.handler.edit_memory("abc123", "content", "新内容", "字符串ID")
170 |
171 | assert result["success"] is True
172 | self.mock_faiss_manager.update_memory.assert_called_once_with(
173 | memory_id="abc123",
174 | update_reason="字符串ID",
175 | content="新内容"
176 | )
177 |
178 | @pytest.mark.asyncio
179 | async def test_edit_memory_no_faiss_manager(self):
180 | """测试没有faiss_manager时的错误处理"""
181 | handler = MemoryHandler(self.mock_context, TEST_CONFIG, None)
182 |
183 | result = await handler.edit_memory("123", "content", "新内容")
184 |
185 | assert result["success"] is False
186 | assert "记忆库尚未初始化" in result["message"]
187 |
188 | @pytest.mark.asyncio
189 | async def test_edit_memory_exception(self):
190 | """测试编辑记忆时的异常处理"""
191 | self.mock_faiss_manager.update_memory = AsyncMock(side_effect=Exception("数据库错误"))
192 |
193 | result = await self.handler.edit_memory("123", "content", "新内容")
194 |
195 | assert result["success"] is False
196 | assert "编辑记忆时发生错误" in result["message"]
197 |
198 | @pytest.mark.asyncio
199 | async def test_get_memory_details_success(self):
200 | """测试获取记忆详细信息(成功)"""
201 | # 模拟数据库查询结果
202 | mock_docs = [{
203 | "id": 123,
204 | "content": "测试记忆内容",
205 | "metadata": json.dumps({
206 | "create_time": 1609459200.0,
207 | "last_access_time": 1609459200.0,
208 | "importance": 0.8,
209 | "event_type": "FACT",
210 | "status": "active"
211 | })
212 | }]
213 |
214 | self.mock_faiss_manager.db.document_storage.get_documents = AsyncMock(return_value=mock_docs)
215 |
216 | result = await self.handler.get_memory_details("123")
217 |
218 | assert result["success"] is True
219 | data = result["data"]
220 | assert data["id"] == "123"
221 | assert data["content"] == "测试记忆内容"
222 | assert data["importance"] == 0.8
223 | assert data["event_type"] == "FACT"
224 | assert data["status"] == "active"
225 |
226 | @pytest.mark.asyncio
227 | async def test_get_memory_details_not_found(self):
228 | """测试获取不存在的记忆详细信息"""
229 | self.mock_faiss_manager.db.document_storage.get_documents = AsyncMock(return_value=[])
230 |
231 | result = await self.handler.get_memory_details("999")
232 |
233 | assert result["success"] is False
234 | assert "未找到ID为 999 的记忆" in result["message"]
235 |
236 | @pytest.mark.asyncio
237 | async def test_get_memory_history_success(self):
238 | """测试获取记忆历史(成功)"""
239 | mock_docs = [{
240 | "id": 123,
241 | "content": "测试记忆内容",
242 | "metadata": json.dumps({
243 | "create_time": 1609459200.0,
244 | "importance": 0.8,
245 | "event_type": "FACT",
246 | "status": "active",
247 | "update_history": [
248 | {
249 | "timestamp": 1609459200.0,
250 | "reason": "初始创建",
251 | "fields": ["content", "importance"]
252 | }
253 | ]
254 | })
255 | }]
256 |
257 | self.mock_faiss_manager.db.document_storage.get_documents = AsyncMock(return_value=mock_docs)
258 |
259 | result = await self.handler.get_memory_history("123")
260 |
261 | assert result["success"] is True
262 | data = result["data"]
263 | assert len(data["update_history"]) == 1
264 | assert data["update_history"][0]["reason"] == "初始创建"
265 |
266 | def test_format_memory_details_for_display_success(self):
267 | """测试格式化记忆详细信息显示(成功)"""
268 | mock_response = {
269 | "success": True,
270 | "data": {
271 | "id": "123",
272 | "content": "测试记忆内容",
273 | "importance": 0.8,
274 | "event_type": "FACT",
275 | "status": "active",
276 | "create_time": "2021-01-01 00:00:00",
277 | "last_access_time": "2021-01-01 00:00:00",
278 | "update_history": []
279 | }
280 | }
281 |
282 | result = self.handler.format_memory_details_for_display(mock_response)
283 |
284 | assert "📝 记忆 123 的详细信息:" in result
285 | assert "测试记忆内容" in result
286 | assert "重要性: 0.8" in result
287 | assert "类型: FACT" in result
288 | assert "状态: active" in result
289 |
290 | def test_format_memory_details_for_display_failure(self):
291 | """测试格式化记忆详细信息显示(失败)"""
292 | mock_response = {
293 | "success": False,
294 | "message": "获取失败"
295 | }
296 |
297 | result = self.handler.format_memory_details_for_display(mock_response)
298 |
299 | assert result == "获取失败"
300 |
301 | def test_format_memory_history_for_display_success(self):
302 | """测试格式化记忆历史显示(成功)"""
303 | mock_response = {
304 | "success": True,
305 | "data": {
306 | "id": "123",
307 | "content": "测试记忆内容",
308 | "metadata": {
309 | "importance": 0.8,
310 | "event_type": "FACT",
311 | "status": "active",
312 | "create_time": "2021-01-01 00:00:00"
313 | },
314 | "update_history": [
315 | {
316 | "timestamp": 1609459200.0,
317 | "reason": "初始创建",
318 | "fields": ["content"]
319 | }
320 | ]
321 | }
322 | }
323 |
324 | result = self.handler.format_memory_history_for_display(mock_response)
325 |
326 | assert "📝 记忆 123 的详细信息:" in result
327 | assert "测试记忆内容" in result
328 | assert "🔄 更新历史 (1 次):" in result
329 | assert "初始创建" in result
330 |
331 | def test_format_memory_history_for_display_no_history(self):
332 | """测试格式化记忆历史显示(无历史记录)"""
333 | mock_response = {
334 | "success": True,
335 | "data": {
336 | "id": "123",
337 | "content": "测试记忆内容",
338 | "metadata": {
339 | "importance": 0.8,
340 | "event_type": "FACT",
341 | "status": "active",
342 | "create_time": "2021-01-01 00:00:00"
343 | },
344 | "update_history": []
345 | }
346 | }
347 |
348 | result = self.handler.format_memory_history_for_display(mock_response)
349 |
350 | assert "🔄 暂无更新记录" in result
--------------------------------------------------------------------------------
/tests/unit/test_search_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | test_search_handler.py - 搜索管理处理器测试
4 | """
5 |
6 | import pytest
7 | from unittest.mock import Mock, AsyncMock
8 |
9 | from core.handlers.search_handler import SearchHandler
10 | from tests.conftest import TEST_CONFIG
11 |
12 |
13 | class TestSearchHandler:
14 | """搜索管理处理器测试类"""
15 |
16 | def setup_method(self):
17 | """测试前设置"""
18 | self.mock_context = Mock()
19 | self.mock_recall_engine = Mock()
20 | self.mock_sparse_retriever = Mock()
21 | self.handler = SearchHandler(
22 | self.mock_context,
23 | TEST_CONFIG,
24 | self.mock_recall_engine,
25 | self.mock_sparse_retriever
26 | )
27 |
28 | @pytest.mark.asyncio
29 | async def test_search_memories_success(self):
30 | """测试搜索记忆(成功)"""
31 | # 模拟搜索结果
32 | mock_results = [
33 | Mock(
34 | data={"id": 1, "text": "测试记忆1", "metadata": '{"importance": 0.8}'},
35 | similarity=0.9
36 | ),
37 | Mock(
38 | data={"id": 2, "text": "测试记忆2", "metadata": '{"importance": 0.6}'},
39 | similarity=0.7
40 | )
41 | ]
42 |
43 | self.mock_recall_engine.recall = AsyncMock(return_value=mock_results)
44 |
45 | result = await self.handler.search_memories("测试查询", k=5)
46 |
47 | assert result["success"] is True
48 | assert "为您找到 2 条相关记忆" in result["message"]
49 | assert len(result["data"]) == 2
50 | assert result["data"][0]["id"] == 1
51 | assert result["data"][0]["similarity"] == 0.9
52 | assert result["data"][1]["id"] == 2
53 | assert result["data"][1]["similarity"] == 0.7
54 |
55 | # 验证调用参数
56 | self.mock_recall_engine.recall.assert_called_once_with(
57 | self.mock_context, "测试查询", k=5
58 | )
59 |
60 | @pytest.mark.asyncio
61 | async def test_search_memories_no_results(self):
62 | """测试搜索记忆(无结果)"""
63 | self.mock_recall_engine.recall = AsyncMock(return_value=[])
64 |
65 | result = await self.handler.search_memories("无结果查询", k=5)
66 |
67 | assert result["success"] is True
68 | assert "未能找到与 '无结果查询' 相关的记忆" in result["message"]
69 | assert result["data"] == []
70 |
71 | @pytest.mark.asyncio
72 | async def test_search_memories_no_recall_engine(self):
73 | """测试没有回忆引擎时的搜索"""
74 | handler = SearchHandler(self.mock_context, TEST_CONFIG, None, None)
75 |
76 | result = await handler.search_memories("测试查询")
77 |
78 | assert result["success"] is False
79 | assert "回忆引擎尚未初始化" in result["message"]
80 |
81 | @pytest.mark.asyncio
82 | async def test_search_memories_exception(self):
83 | """测试搜索记忆时的异常处理"""
84 | self.mock_recall_engine.recall = AsyncMock(side_effect=Exception("搜索错误"))
85 |
86 | result = await self.handler.search_memories("测试查询")
87 |
88 | assert result["success"] is False
89 | assert "搜索记忆时发生错误" in result["message"]
90 |
91 | @pytest.mark.asyncio
92 | async def test_test_sparse_search_success(self):
93 | """测试稀疏检索测试(成功)"""
94 | # 模拟稀疏检索结果
95 | mock_results = [
96 | Mock(
97 | doc_id=1,
98 | score=0.8,
99 | content="稀疏检索结果1",
100 | metadata={"event_type": "FACT", "importance": 0.7}
101 | ),
102 | Mock(
103 | doc_id=2,
104 | score=0.6,
105 | content="稀疏检索结果2",
106 | metadata={"event_type": "PREFERENCE", "importance": 0.9}
107 | )
108 | ]
109 |
110 | self.mock_sparse_retriever.search = AsyncMock(return_value=mock_results)
111 |
112 | result = await self.handler.test_sparse_search("测试查询", k=5)
113 |
114 | assert result["success"] is True
115 | assert "找到 2 条稀疏检索结果" in result["message"]
116 | assert len(result["data"]) == 2
117 | assert result["data"][0]["doc_id"] == 1
118 | assert result["data"][0]["score"] == 0.8
119 | assert result["data"][1]["doc_id"] == 2
120 | assert result["data"][1]["score"] == 0.6
121 |
122 | # 验证调用参数
123 | self.mock_sparse_retriever.search.assert_called_once_with(
124 | query="测试查询", limit=5
125 | )
126 |
127 | @pytest.mark.asyncio
128 | async def test_test_sparse_search_no_results(self):
129 | """测试稀疏检索测试(无结果)"""
130 | self.mock_sparse_retriever.search = AsyncMock(return_value=[])
131 |
132 | result = await self.handler.test_sparse_search("无结果查询", k=5)
133 |
134 | assert result["success"] is True
135 | assert "未找到与 '无结果查询' 相关的记忆" in result["message"]
136 | assert result["data"] == []
137 |
138 | @pytest.mark.asyncio
139 | async def test_test_sparse_search_no_retriever(self):
140 | """测试没有稀疏检索器时的测试"""
141 | handler = SearchHandler(self.mock_context, TEST_CONFIG, None, None)
142 |
143 | result = await handler.test_sparse_search("测试查询")
144 |
145 | assert result["success"] is False
146 | assert "稀疏检索器未启用" in result["message"]
147 |
148 | @pytest.mark.asyncio
149 | async def test_test_sparse_search_exception(self):
150 | """测试稀疏检索测试时的异常处理"""
151 | self.mock_sparse_retriever.search = AsyncMock(side_effect=Exception("稀疏检索错误"))
152 |
153 | result = await self.handler.test_sparse_search("测试查询")
154 |
155 | assert result["success"] is False
156 | assert "稀疏检索测试失败" in result["message"]
157 |
158 | @pytest.mark.asyncio
159 | async def test_rebuild_sparse_index_success(self):
160 | """测试重建稀疏索引(成功)"""
161 | self.mock_sparse_retriever.rebuild_index = AsyncMock()
162 |
163 | result = await self.handler.rebuild_sparse_index()
164 |
165 | assert result["success"] is True
166 | assert "稀疏检索索引重建完成" in result["message"]
167 |
168 | # 验证调用了重建方法
169 | self.mock_sparse_retriever.rebuild_index.assert_called_once()
170 |
171 | @pytest.mark.asyncio
172 | async def test_rebuild_sparse_index_no_retriever(self):
173 | """测试没有稀疏检索器时重建索引"""
174 | handler = SearchHandler(self.mock_context, TEST_CONFIG, None, None)
175 |
176 | result = await handler.rebuild_sparse_index()
177 |
178 | assert result["success"] is False
179 | assert "稀疏检索器未启用" in result["message"]
180 |
181 | @pytest.mark.asyncio
182 | async def test_rebuild_sparse_index_exception(self):
183 | """测试重建稀索引导常处理"""
184 | self.mock_sparse_retriever.rebuild_index = AsyncMock(side_effect=Exception("重建错误"))
185 |
186 | result = await self.handler.rebuild_sparse_index()
187 |
188 | assert result["success"] is False
189 | assert "重建稀疏索引失败" in result["message"]
190 |
191 | def test_format_search_results_for_display_success(self):
192 | """测试格式化搜索结果显示(成功)"""
193 | mock_response = {
194 | "success": True,
195 | "message": "为您找到 2 条相关记忆",
196 | "data": [
197 | {
198 | "id": 1,
199 | "similarity": 0.9,
200 | "text": "测试记忆1",
201 | "metadata": {
202 | "create_time": 1609459200.0,
203 | "last_access_time": 1609459200.0,
204 | "importance": 0.8,
205 | "event_type": "FACT"
206 | }
207 | },
208 | {
209 | "id": 2,
210 | "similarity": 0.7,
211 | "text": "测试记忆2",
212 | "metadata": {
213 | "create_time": 1609459200.0,
214 | "last_access_time": 1609459200.0,
215 | "importance": 0.6,
216 | "event_type": "PREFERENCE"
217 | }
218 | }
219 | ]
220 | }
221 |
222 | result = self.handler.format_search_results_for_display(mock_response)
223 |
224 | assert "为您找到 2 条相关记忆" in result
225 | assert "ID: 1" in result
226 | assert "记 忆 度: 0.90" in result
227 | assert "重 要 性: 0.80" in result
228 | assert "记忆类型: FACT" in result
229 | assert "测试记忆1" in result
230 | assert "ID: 2" in result
231 | assert "测试记忆2" in result
232 |
233 | def test_format_search_results_for_display_failure(self):
234 | """测试格式化搜索结果显示(失败)"""
235 | mock_response = {
236 | "success": False,
237 | "message": "搜索失败"
238 | }
239 |
240 | result = self.handler.format_search_results_for_display(mock_response)
241 |
242 | assert result == "搜索失败"
243 |
244 | def test_format_sparse_results_for_display_success(self):
245 | """测试格式化稀疏检索结果显示(成功)"""
246 | mock_response = {
247 | "success": True,
248 | "message": "找到 2 条稀疏检索结果",
249 | "data": [
250 | {
251 | "doc_id": 1,
252 | "score": 0.8,
253 | "content": "稀疏检索结果1",
254 | "metadata": {
255 | "event_type": "FACT",
256 | "importance": 0.7
257 | }
258 | },
259 | {
260 | "doc_id": 2,
261 | "score": 0.6,
262 | "content": "稀疏检索结果2很长很长很长很长很长很长很长很长很长很长很长很长",
263 | "metadata": {
264 | "event_type": "PREFERENCE",
265 | "importance": 0.9
266 | }
267 | }
268 | ]
269 | }
270 |
271 | result = self.handler.format_sparse_results_for_display(mock_response)
272 |
273 | assert "找到 2 条稀疏检索结果" in result
274 | assert "1. [ID: 1] Score: 0.800" in result
275 | assert "类型: FACT" in result
276 | assert "重要性: 0.70" in result
277 | assert "2. [ID: 2] Score: 0.600" in result
278 | assert "..." in result or len(result) < 500 # 长内容被截断或整个结果显示
279 | assert "类型: PREFERENCE" in result
280 | assert "重要性: 0.90" in result
281 |
282 | def test_format_sparse_results_for_display_failure(self):
283 | """测试格式化稀疏检索结果显示(失败)"""
284 | mock_response = {
285 | "success": False,
286 | "message": "稀疏检索失败"
287 | }
288 |
289 | result = self.handler.format_sparse_results_for_display(mock_response)
290 |
291 | assert result == "稀疏检索失败"
292 |
293 | def test_format_sparse_results_for_display_no_metadata(self):
294 | """测试格式化稀疏检索结果显示(无元数据)"""
295 | mock_response = {
296 | "success": True,
297 | "message": "找到 1 条稀疏检索结果",
298 | "data": [
299 | {
300 | "doc_id": 1,
301 | "score": 0.8,
302 | "content": "稀疏检索结果",
303 | "metadata": {}
304 | }
305 | ]
306 | }
307 |
308 | result = self.handler.format_sparse_results_for_display(mock_response)
309 |
310 | assert "找到 1 条稀疏检索结果" in result
311 | assert "1. [ID: 1] Score: 0.800" in result
312 | assert "稀疏检索结果" in result
313 | # 不应该包含类型和重要性信息
314 | assert "类型:" not in result
315 | assert "重要性:" not in result
--------------------------------------------------------------------------------