├── .gitignore ├── FUSION_STRATEGIES.md ├── LICENSE ├── README.md ├── TODO.md ├── _conf_schema.json ├── core ├── __init__.py ├── commands │ ├── __init__.py │ └── base_command.py ├── community │ ├── __init__.py │ └── community_detector.py ├── config_validator.py ├── constants.py ├── engines │ ├── forgetting_agent.py │ ├── recall_engine.py │ └── reflection_engine.py ├── handlers │ ├── __init__.py │ ├── admin_handler.py │ ├── base_handler.py │ ├── fusion_handler.py │ ├── memory_handler.py │ └── search_handler.py ├── models.py ├── models │ └── memory_models.py ├── retrieval │ ├── __init__.py │ ├── result_fusion.py │ └── sparse_retriever.py └── utils.py ├── docs ├── CONFIG.md └── DEVELOPMENT.md ├── main.py ├── metadata.yaml ├── requirements.txt ├── storage ├── __init__.py ├── faiss_manager.py ├── faiss_manager_v2.py ├── graph_storage.py ├── memory_storage.py └── vector_store.py └── tests ├── conftest.py └── unit ├── test_admin_handler.py ├── test_base_handler.py ├── test_memory_handler.py └── test_search_handler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.pyc 4 | *.pyo 5 | *.pyd 6 | 7 | # Database files 8 | *.db 9 | *.index 10 | 11 | # IDE and OS files 12 | .vscode/ 13 | .idea/ 14 | .DS_Store 15 | work_list.md 16 | CLAUDE.md 17 | .claude/settings.local.json 18 | 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | *.so 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # Virtual environments 42 | .env 43 | .venv 44 | env/ 45 | venv/ 46 | ENV/ 47 | env.bak/ 48 | venv.bak/ 49 | 50 | # IDE 51 | .idea/ 52 | .vscode/ 53 | *.swp 54 | *.swo 55 | 56 | # OS 57 | .DS_Store 58 | .DS_Store? 59 | ._* 60 | .Spotlight-V100 61 | .Trashes 62 | ehthumbs.db 63 | Thumbs.db 64 | 65 | # Testing 66 | .pytest_cache/ 67 | .coverage 68 | htmlcov/ 69 | .tox/ 70 | 71 | # Logs 72 | *.log 73 | logs/ 74 | 75 | # Temporary files 76 | tmp/ 77 | temp/ 78 | *.tmp 79 | 80 | # UV (if used) 81 | uv.lock 82 | 83 | # Node.js (if used) 84 | node_modules/ 85 | npm-debug.log* 86 | yarn-debug.log* 87 | yarn-error.log* 88 | 89 | data/ 90 | .claude/ 91 | pyproject.toml 92 | -------------------------------------------------------------------------------- /FUSION_STRATEGIES.md: -------------------------------------------------------------------------------- 1 | # 混合检索融合策略详解 2 | 3 | LivingMemory 插件支持多种先进的结果融合策略,用于优化混合检索效果。 4 | 5 | ## 🎯 支持的融合策略 6 | 7 | ### 1. RRF (Reciprocal Rank Fusion) - 经典策略 8 | ``` 9 | 分数 = Σ(1 / (k + rank_i)) 10 | ``` 11 | - **适用场景**: 通用场景,平衡性好 12 | - **参数**: `rrf_k` (默认: 60) 13 | - **特点**: 简单有效,对排序位置敏感 14 | 15 | ### 2. Hybrid RRF - 动态RRF 16 | - **适用场景**: 需要根据查询特征自适应调整 17 | - **参数**: `rrf_k`, `diversity_bonus` 18 | - **特点**: 根据查询长度和类型动态调整RRF参数 19 | - **优势**: 短查询偏向稀疏检索,长查询偏向密集检索 20 | 21 | ### 3. Weighted - 加权融合 22 | ``` 23 | 分数 = α × dense_score + β × sparse_score 24 | ``` 25 | - **适用场景**: 明确知道两种检索器的相对重要性 26 | - **参数**: `dense_weight`, `sparse_weight` 27 | - **特点**: 简单直观,可解释性强 28 | 29 | ### 4. Convex - 凸组合融合 30 | ``` 31 | 分数 = λ × norm(dense) + (1-λ) × norm(sparse) 32 | ``` 33 | - **适用场景**: 需要数学严格的融合方法 34 | - **参数**: `convex_lambda` (0.0-1.0) 35 | - **特点**: 分数归一化到 [0,1],数学性质好 36 | 37 | ### 5. Interleave - 交替融合 38 | - **适用场景**: 需要保证结果多样性 39 | - **参数**: `interleave_ratio` - 密集结果所占比例 40 | - **特点**: 按比例交替选择不同检索器的结果 41 | 42 | ### 6. 
Rank Fusion - 基于排序的融合 43 | ``` 44 | 分数 = Σ(weight_i / rank_i) + bias(if in both lists) 45 | ``` 46 | - **适用场景**: 重视文档在排序列表中的位置 47 | - **参数**: `dense_weight`, `sparse_weight`, `rank_bias_factor` 48 | - **特点**: 在两个列表中都出现的文档获得额外加分 49 | 50 | ### 7. Score Fusion - Borda Count融合 51 | ``` 52 | 分数 = Σ(list_size - rank_i) × weight_i 53 | ``` 54 | - **适用场景**: 基于排序投票的民主融合 55 | - **参数**: `dense_weight`, `sparse_weight` 56 | - **特点**: 类似选举中的Borda计数法 57 | 58 | ### 8. Cascade - 级联融合 59 | - **适用场景**: 大规模检索,需要效率优化 60 | - **流程**: 稀疏检索初筛 → 密集检索精排 61 | - **特点**: 先用快速的稀疏检索筛选候选,再用精确的密集检索排序 62 | 63 | ### 9. Adaptive - 自适应融合 64 | - **适用场景**: 查询类型多样的场景 65 | - **策略**: 根据查询特征选择最优融合方法 66 | - 关键词查询 → 偏向稀疏检索 67 | - 语义查询 → 偏向密集检索 68 | - 混合查询 → 使用RRF 69 | 70 | ## 📊 性能特征对比 71 | 72 | | 策略 | 计算复杂度 | 参数调优难度 | 适应性 | 可解释性 | 73 | |------|-----------|-------------|--------|----------| 74 | | RRF | 低 | 低 | 中 | 中 | 75 | | Hybrid RRF | 中 | 中 | 高 | 中 | 76 | | Weighted | 低 | 低 | 低 | 高 | 77 | | Convex | 低 | 中 | 中 | 高 | 78 | | Interleave | 低 | 低 | 低 | 高 | 79 | | Rank Fusion | 中 | 中 | 中 | 中 | 80 | | Score Fusion | 高 | 中 | 中 | 中 | 81 | | Cascade | 低 | 低 | 低 | 高 | 82 | | Adaptive | 中 | 高 | 高 | 低 | 83 | 84 | ## 🛠️ 使用指南 85 | 86 | ### 配置示例 87 | 88 | ```yaml 89 | fusion: 90 | strategy: "hybrid_rrf" 91 | rrf_k: 60 92 | dense_weight: 0.7 93 | sparse_weight: 0.3 94 | diversity_bonus: 0.1 95 | convex_lambda: 0.5 96 | interleave_ratio: 0.6 97 | rank_bias_factor: 0.15 98 | ``` 99 | 100 | ### 命令行管理 101 | 102 | ```bash 103 | # 查看当前配置 104 | /lmem fusion show 105 | 106 | # 切换到混合RRF 107 | /lmem fusion hybrid_rrf 108 | 109 | # 调整凸组合参数 110 | /lmem fusion convex convex_lambda=0.6 111 | 112 | # 调整权重 113 | /lmem fusion weighted dense_weight=0.8 114 | 115 | # 测试融合效果 116 | /lmem test_fusion "用户的兴趣爱好" 5 117 | ``` 118 | 119 | ## 🎯 策略选择建议 120 | 121 | ### 场景驱动的选择 122 | 123 | 1. **通用聊天机器人**: `hybrid_rrf` 或 `rrf` 124 | - 查询类型多样,需要自适应能力 125 | 126 | 2. **专业知识问答**: `weighted` 或 `convex` 127 | - 可以明确调优权重,提高精确度 128 | 129 | 3. **多样性优先**: `interleave` 130 | - 确保结果不会过于相似 131 | 132 | 4. **大规模数据库**: `cascade` 133 | - 效率优先,两阶段处理 134 | 135 | 5. **实验和调优**: `score_fusion` 或 `rank_fusion` 136 | - 更复杂的融合逻辑,适合深度优化 137 | 138 | ### 参数调优指南 139 | 140 | 1. **RRF参数 k**: 141 | - 较小值 (30-50): 更重视排序靠前的结果 142 | - 较大值 (80-120): 更平衡地考虑所有结果 143 | 144 | 2. **权重比例**: 145 | - 密集权重 > 稀疏权重: 语义查询为主 146 | - 密集权重 < 稀疏权重: 关键词查询为主 147 | 148 | 3. **多样性参数**: 149 | - 较大值: 鼓励结果多样性 150 | - 较小值: 优先考虑相关性 151 | 152 | ## 🧪 实验和评估 153 | 154 | 使用 `/lmem test_fusion` 命令可以: 155 | - 测试不同策略的效果 156 | - 查看融合过程的详细信息 157 | - 对比不同参数设置的结果 158 | 159 | 建议在实际数据上进行A/B测试,选择最适合你使用场景的融合策略。 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LivingMemory - 动态生命周期记忆插件 v1.0.0 2 | 3 | 4 |
5 | 6 | 🧪 测...测不动了 7 | 8 | 9 | 10 | 11 | > ⚠️ **需要测试验证**: 请在测试环境中验证所有新功能 12 | 13 |
14 | 15 |
16 | 17 | 18 |
19 | 为 AstrBot 打造的、拥有完整记忆生命周期的智能长期记忆插件。 20 |
21 | 22 | Python 23 | Faiss 24 | License 25 | 26 | 27 | GitHub Stars 28 | 29 | 30 | Visitor Count 31 |
32 | 33 | 34 | --- 35 | 36 | `LivingMemory` 告别了传统记忆插件对大型数据库的依赖,创新性地采用轻量级的 `Faiss` 和 `SQLite` 作为存储后端。这不仅实现了 **零配置部署** 和 **极低的资源消耗**,更引入了革命性的 **动态记忆生命周期 (Dynamic Memory Lifecycle)** 模型。 37 | 38 | ## ✨ 核心特性:三大引擎架构 39 | 40 | 本插件通过三大智能引擎的协同工作,完美模拟了人类记忆的形成、巩固、联想和遗忘的全过程。 41 | 42 | | 引擎 | 图标 | 核心功能 | 最新增强 | 43 | | :--- | :---: | :--- | :--- | 44 | | **反思引擎** | 🧠 | `智能总结` & `重要性评估` | ✨ 增强错误重试机制,提高提取成功率 | 45 | | **回忆引擎** | 🔍 | `混合检索` & `智能融合` | 🚀 **9种先进融合策略** - RRF、自适应、级联等 | 46 | | **遗忘代理** | 🗑️ | `遗忘曲线` & `批量清理` | 💾 **分页处理** - 优化内存使用,支持大规模数据 | 47 | 48 | ## 🔥 新特性:混合检索系统 49 | 50 | ### 检索模式 51 | - **🔍 Dense (密集检索)**: 基于语义向量的深度理解 52 | - **⚡ Sparse (稀疏检索)**: BM25关键词匹配,支持中文分词 53 | - **🤝 Hybrid (混合检索)**: 智能融合两种检索方式的优势 54 | 55 | ### 9种融合策略 56 | 57 | | 策略 | 特点 | 适用场景 | 计算复杂度 | 58 | |------|------|----------|------------| 59 | | **RRF** | 经典倒数排名融合 | 通用场景,平衡性好 | 低 | 60 | | **Hybrid RRF** | 动态参数调整 | 自适应查询类型 | 中 | 61 | | **Weighted** | 简单加权融合 | 明确权重偏好 | 低 | 62 | | **Convex** | 凸组合数学融合 | 需要严格数学性质 | 低 | 63 | | **Interleave** | 交替选择结果 | 保证结果多样性 | 低 | 64 | | **Rank Fusion** | 基于排序位置 | 重视排序信息 | 中 | 65 | | **Score Fusion** | Borda Count投票 | 民主投票机制 | 高 | 66 | | **Cascade** | 两阶段处理 | 大规模高效检索 | 低 | 67 | | **Adaptive** | 查询特征自适应 | 多样化查询场景 | 中 | 68 | 69 | ### 智能查询分析 70 | 系统会自动分析查询特征,选择最优融合策略: 71 | - **关键词查询** → 偏向稀疏检索 72 | - **语义查询** → 偏向密集检索 73 | - **混合查询** → 使用RRF平衡融合 74 | 75 | ## 🚀 快速开始 76 | 77 | ### 1. 安装 78 | 79 | 将 `astrbot_plugin_livingmemory` 文件夹放置于 AstrBot 的 `data/plugins` 目录下。AstrBot 将自动检测并安装依赖。 80 | 81 | **核心依赖:** 82 | ``` 83 | faiss-cpu>=1.7.0 84 | pydantic>=1.8.0 85 | jieba>=0.42.1 86 | ``` 87 | 88 | **🧪 测试阶段**: 插件已完成重构和功能增强,正在进行全面测试验证。 89 | 90 | ### 2. 配置 91 | 92 | **✨ 全新配置系统**: 基于 Pydantic 的智能配置验证,确保参数有效性。 93 | 94 |
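
下面用一个最小示意演示这套校验的行为(对应 `core/config_validator.py` 中的 `validate_config`;导入路径按插件目录结构假设,示例数值仅作演示):

```python
# 最小示意:Pydantic 配置校验(假设在插件根目录下运行,且已安装 AstrBot 依赖)
from core.config_validator import validate_config

raw_config = {
    "recall_engine": {
        "top_k": 5,                  # 必须在 1-50 之间
        "retrieval_mode": "hybrid",  # 仅允许 dense / sparse / hybrid
        "similarity_weight": 0.6,
        "importance_weight": 0.2,
        "recency_weight": 0.2,       # 三者之和偏离 1.0 过多时会记录警告
    }
}

config = validate_config(raw_config)   # 未填写的部分会自动补默认值
print(config.recall_engine.top_k)      # -> 5

try:
    validate_config({"recall_engine": {"top_k": 0}})  # 超出范围,触发校验失败
except ValueError as e:
    print(f"配置无效: {e}")
```
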
95 | ⚙️ 详细配置说明 96 | 97 | #### 🔧 基础设置 98 | - **Provider设置**: 自定义 Embedding 和 LLM Provider,支持多Provider混用 99 | - **时区配置**: 支持全球时区,时间显示本地化 100 | - **会话管理**: 智能会话生命周期,自动清理过期会话 101 | 102 | #### 🔍 检索配置 103 | - **检索模式**: `hybrid`(混合) | `dense`(密集) | `sparse`(稀疏) 104 | - **融合策略**: 9种策略可选,支持动态参数调整 105 | - **BM25参数**: 可调整k1、b参数优化中文检索效果 106 | - **权重控制**: 相似度、重要性、新近度权重精细调节(具体打分方式见下方示例) 107 | 108 | #### 🧠 智能引擎配置 109 | - **反思触发**: 可配置对话轮次阈值(1-100轮) 110 | - **重要性评估**: 自定义重要性阈值(0.0-1.0) 111 | - **自定义提示词**: 完全可定制的事件提取和评估提示 112 | 113 | #### 🗑️ 遗忘机制 114 | - **智能清理**: 基于重要性衰减的自动清理 115 | - **批量处理**: 分页加载,支持大规模记忆库 116 | - **保留策略**: 灵活的天数和阈值配置 117 | 118 | #### 🛡️ 过滤隔离 119 | - **人格过滤**: 按AI人格隔离记忆,互不干扰 120 | - **会话隔离**: 会话级别的记忆独立性 121 | - **状态管理**: 记忆状态(活跃/归档/删除)精细控制 122 | 123 |
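
上文「权重控制」提到的加权召回,其打分方式可以用下面的小例子直观理解(逻辑对应 `core/engines/recall_engine.py` 中的 `_rerank_by_weighted_score`,权重与衰减常数取自本插件源码,输入数值为假设):

```python
# 最小示意:相似度 / 重要性 / 新近度的加权打分
import math

def weighted_score(similarity: float, importance: float, hours_since_access: float,
                   sim_w: float = 0.6, imp_w: float = 0.2, rec_w: float = 0.2) -> float:
    # 新近度按指数衰减,半衰期约为 24 小时(exp(-0.028 * 24) ≈ 0.51)
    recency = math.exp(-0.028 * hours_since_access)
    return similarity * sim_w + importance * imp_w + recency * rec_w

# 相似度相同,但更重要、最近被访问过的记忆会排得更靠前
print(weighted_score(0.80, 0.9, hours_since_access=2))   # ≈ 0.85
print(weighted_score(0.80, 0.2, hours_since_access=72))  # ≈ 0.55
```
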
124 | 125 | ### 3. 高级配置示例 126 | 127 | ```yaml 128 | # 针对中文优化的推荐配置 129 | sparse_retriever: 130 | bm25_k1: 1.2 # 中文词频参数 131 | bm25_b: 0.75 # 文档长度归一化 132 | use_jieba: true # 启用中文分词 133 | 134 | fusion: 135 | strategy: "hybrid_rrf" # 自适应融合策略 136 | rrf_k: 60 # RRF参数 137 | diversity_bonus: 0.1 # 多样性奖励 138 | 139 | recall_engine: 140 | retrieval_mode: "hybrid" # 混合检索模式 141 | similarity_weight: 0.6 # 相似度权重 142 | importance_weight: 0.2 # 重要性权重 143 | recency_weight: 0.2 # 新近度权重 144 | ``` 145 | 146 | ## 🛠️ 完整管理命令 147 | 148 | 插件在后台自动运行,同时提供了强大的命令行管理界面: 149 | 150 | ### 📊 基础管理 151 | | 命令 | 参数 | 描述 | 152 | | :--- | :--- | :--- | 153 | | `/lmem status` | - | 📈 查看记忆库状态和统计信息 | 154 | | `/lmem search` | ` [k=3]` | 🔍 手动搜索记忆,支持详细信息展示 | 155 | | `/lmem forget` | `` | 🗑️ 删除指定ID的记忆 | 156 | 157 | ### 🧠 引擎管理 158 | | 命令 | 参数 | 描述 | 159 | | :--- | :--- | :--- | 160 | | `/lmem run_forgetting_agent` | - | 🔄 手动触发遗忘代理清理任务 | 161 | | `/lmem sparse_rebuild` | - | 🏗️ 重建稀疏检索索引 | 162 | | `/lmem sparse_test` | ` [k=5]` | ⚡ 测试稀疏检索功能 | 163 | 164 | ### ⚙️ 配置管理 165 | | 命令 | 参数 | 描述 | 166 | | :--- | :--- | :--- | 167 | | `/lmem config` | `[show\|validate]` | 📋 显示或验证当前配置 | 168 | | `/lmem search_mode` | `` | 🔄 切换检索模式 | 169 | 170 | ### 🔄 融合策略管理 171 | | 命令 | 参数 | 描述 | 172 | | :--- | :--- | :--- | 173 | | `/lmem fusion` | `[strategy] [param=value]` | 🎯 管理融合策略和参数 | 174 | | `/lmem test_fusion` | ` [k=5]` | 🧪 测试当前融合策略效果 | 175 | 176 | ### 📝 记忆编辑 *(新增)* 177 | | 命令 | 参数 | 描述 | 178 | | :--- | :--- | :--- | 179 | | `/lmem edit` | ` [reason]` | ✏️ 精确编辑记忆内容或元数据 | 180 | | `/lmem update` | `` | 📝 交互式记忆编辑引导 | 181 | | `/lmem history` | `` | 📚 查看记忆的完整更新历史 | 182 | 183 | #### 编辑示例 184 | ```bash 185 | # 编辑记忆内容 186 | /lmem edit 123 content 这是新的记忆内容 修正错误信息 187 | 188 | # 调整重要性评分 189 | /lmem edit 123 importance 0.9 提高重要性 190 | 191 | # 更改记忆类型 192 | /lmem edit 123 type PREFERENCE 重新分类为偏好 193 | 194 | # 归档记忆 195 | /lmem edit 123 status archived 项目已完成 196 | ``` 197 | 198 | #### 融合策略示例 199 | ```bash 200 | # 查看当前融合配置 201 | /lmem fusion show 202 | 203 | # 切换到自适应RRF 204 | /lmem fusion hybrid_rrf 205 | 206 | # 调整凸组合参数 207 | /lmem fusion convex convex_lambda=0.6 208 | 209 | # 调整加权融合权重 210 | /lmem fusion weighted dense_weight=0.8 211 | 212 | # 测试融合效果 213 | /lmem test_fusion "用户的兴趣爱好" 5 214 | ``` 215 | 216 | ## 🎯 性能优化与最佳实践 217 | 218 | ### 💾 内存管理 219 | - **分页加载**: 支持大规模记忆库,避免OOM 220 | - **会话清理**: 智能TTL机制,自动清理过期会话 221 | - **事务安全**: SQLite事务保证数据一致性 222 | 223 | ### 🚀 并发处理 224 | - **异步初始化**: 避免阻塞主线程 225 | - **后台任务**: 反思和遗忘任务后台执行 226 | - **错误重试**: 自动重试机制,提高稳定性 227 | 228 | ### 🔧 故障排除 229 | 230 |
231 | 🚨 常见问题解决方案 232 | 233 | #### Q: 插件初始化失败 234 | ```bash 235 | # 检查依赖安装 236 | pip install faiss-cpu pydantic jieba 237 | 238 | # 验证配置 239 | /lmem config validate 240 | ``` 241 | 242 | #### Q: 检索效果不佳 243 | ```bash 244 | # 尝试不同融合策略 245 | /lmem fusion adaptive 246 | 247 | # 重建稀疏索引 248 | /lmem sparse_rebuild 249 | 250 | # 调整检索模式 251 | /lmem search_mode hybrid 252 | ``` 253 | 254 | #### Q: 内存占用过高 255 | ```bash 256 | # 手动触发遗忘 257 | /lmem run_forgetting_agent 258 | 259 | # 检查会话数量 260 | /lmem config show 261 | ``` 262 | 263 |
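
补充一点:上面「内存占用过高」中提到的遗忘代理,其清理判断可以用下面的小例子说明(逻辑对应 `core/engines/forgetting_agent.py` 的 `_prune_memories`,默认参数来自 `_conf_schema.json`,记忆数据为假设):

```python
# 最小示意:一条记忆是否会被遗忘代理清理
def should_forget(importance: float, days_since_creation: float,
                  decay_rate: float = 0.005, retention_days: int = 90,
                  threshold: float = 0.1) -> bool:
    decayed = max(0.0, importance - days_since_creation * decay_rate)  # 重要性线性衰减
    is_old = days_since_creation > retention_days                      # 超过保留天数
    return is_old and decayed < threshold                              # 两个条件同时满足才删除

print(should_forget(0.5, days_since_creation=30))    # False:未超过保留期
print(should_forget(0.5, days_since_creation=120))   # True:衰减后为 0,低于阈值 0.1
print(should_forget(0.9, days_since_creation=120))   # False:0.9 - 0.6 = 0.3,仍高于阈值
```
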
264 | 265 | ## 📚 相关文档 266 | 267 | - 📖 [融合策略详解](FUSION_STRATEGIES.md) - 深入了解9种融合算法 268 | - ⚙️ [配置参考](docs/CONFIG.md) - 完整配置参数说明 269 | - 🔧 [开发指南](docs/DEVELOPMENT.md) - 插件开发和扩展指南 270 | 271 | ## 🤝 贡献 272 | 273 | 欢迎各种形式的贡献: 274 | - 🐛 **问题报告**: [GitHub Issues](https://github.com/lxfight/astrbot_plugin_livingmemory/issues) 275 | - 💡 **功能建议**: [Feature Requests](https://github.com/lxfight/astrbot_plugin_livingmemory/issues/new?template=feature_request.md) 276 | - 🔧 **代码贡献**: [Pull Requests](https://github.com/lxfight/astrbot_plugin_livingmemory/pulls) 277 | - 📖 **文档改进**: 欢迎改进文档和示例 278 | 279 | ## 交流一下 280 | 遇到问题或想交流使用心得?加入我们的讨论群: 281 | [![加入QQ群](https://img.shields.io/badge/QQ群-953245617-blue?style=flat-square&logo=tencent-qq)](https://qm.qq.com/cgi-bin/qm/qr?k=WdyqoP-AOEXqGAN08lOFfVSguF2EmBeO&jump_from=webapi&authKey=tPyfv90TVYSGVhbAhsAZCcSBotJuTTLf03wnn7/lQZPUkWfoQ/J8e9nkAipkOzwh) 282 | 283 | `入关口令`: `lxfight` 284 | 285 | ## 📄 许可证 286 | 287 | 本项目遵循 **AGPLv3** 许可证 - 详见 [LICENSE](LICENSE) 文件。 288 | 289 | --- 290 | 291 |
293 | 294 | **⭐ 如果这个项目对您有帮助,请给我们一个 Star!** 295 | 296 |
297 | 298 | *LivingMemory - 让AI拥有真正的生命记忆 🧠✨* 299 | 300 |
301 | -------------------------------------------------------------------------------- /TODO.md: -------------------------------------------------------------------------------- 1 | # 关于 livingmemory 插件的 TODO 列表 2 | 3 | - [ ] 设计记忆的数据结构 4 | - [ ] 设计记忆的召回方式 5 | - [ ] 设计记忆的更新方式 6 | - [ ] 记忆的遗忘逻辑 7 | 8 | # 一些潜在的规划 9 | 10 | { 11 | "memory_id": "String", // [主键] 记忆的全局唯一标识符 (建议使用UUID) 12 | "timestamp": "String", // 事件发生的时间戳 (ISO 8601格式, e.g., "2025-07-26T14:30:00Z") 13 | "summary": "String", // AI生成的、可供快速预览的单行摘要或标题 14 | "description": "String", // 对事件的完整、详细的自然语言描述 15 | "embedding": "Array", // 由`description`或`summary`+`description`生成的文本嵌入向量,用于Faiss 16 | 17 | "linked_media": [ 18 | // 多模态内容部分 19 | { 20 | "media_id": "String", // 媒体文件的唯一ID 21 | "media_type": "String", // 'image', 'audio', 'document', 'video', 'code_snippet' 22 | "url": "String", // 文件的存储位置 (e.g., S3 URL or local path) 23 | "caption": "String", // 对媒体内容的简短描述 24 | "embedding": "Array" // 媒体文件本身的多模态嵌入 (e.g., CLIP vector for images) 25 | } 26 | ], 27 | "metadata": { 28 | "source_conversation_id": "String", // 此记忆来源的对话ID,用于溯源 29 | "memory_type": "String", // 记忆类型: 'episodic' (情景), 'semantic' (事实), 'procedural' (程序) 30 | "importance_score": "Float", // [遗忘引擎关键输入] 记忆的重要性评分 (0.0 to 1.0) 31 | "confidence_score": "Float", // [可选] NLU模型提取此记忆信息的置信度 (0.0 to 1.0) 32 | "access_info": { 33 | "initial_creation_timestamp": "String", // 记忆被创建的时间 34 | "last_accessed_timestamp": "String", // [遗忘引擎关键输入] 记忆最近被访问的时间 35 | "access_count": "Integer" // [遗忘引擎关键输入] 记忆被访问的总次数 36 | }, 37 | "emotional_valence": { 38 | "sentiment": "String", // 情感倾向: 'positive', 'negative', 'neutral' 39 | "intensity": "Float" // 情感强度 (0.0 to 1.0) 40 | }, 41 | "user_feedback": { 42 | "is_accurate": "Boolean", // 用户是否标记此记忆为准确 (null, true, false) 43 | "is_important": "Boolean", // 用户是否标记此记忆为重要 (null, true, false) 44 | "correction_text": "String" // 用户提供的修正文本 45 | }, 46 | "community_info": { 47 | // 用于图社区发现 48 | "id": "String", // 此记忆所属社区/簇的唯一ID 49 | "last_calculated": "String" // 上次计算社区分配的时间戳 50 | } 51 | }, 52 | 53 | "knowledge_graph_payload": { 54 | "event_entity": { 55 | "event_id": "String", // [图节点] 事件的唯一ID 56 | "event_type": "String" // 事件的分类, e.g., 'ProjectInitiation', 'TravelPlanning' 57 | }, 58 | "entities": [ 59 | { 60 | "entity_id": "String", // [图节点] 实体的全局唯一ID 61 | "name": "String", // 实体的名称 62 | "type": "String", // 实体类型, e.g., 'PERSON', 'PROJECT', 'ORGANIZATION' 63 | "role": "String" // [可选] 实体在此次事件中的具体角色 64 | } 65 | ], 66 | "relationships": [ 67 | "Array" // [图的边] 定义关系的三元组 [主体ID, 关系谓词, 客体ID] 68 | ] 69 | } 70 | } 71 | 72 | 73 | 示例 74 | 75 | ```json 76 | { 77 | "memory_id": "mem_e2a1b3c4-d5e6-f7g8-h9i0-j1k2l3m4n5o6", 78 | "timestamp": "2025-07-28T11:00:00Z", 79 | "summary": "张伟分享了'凤凰计划'的登录页设计图", 80 | "description": "用户张伟分享了一张关于'凤凰计划'的登录页面初稿图片,并征求反馈。他还提及这张设计图关联到他们上一次会议中确定的蓝色主题。", 81 | "embedding": [-0.15, 0.33, 0.81, ..., -0.05], 82 | 83 | "linked_media": [ 84 | { 85 | "media_id": "img_login_mockup_v1", 86 | "media_type": "image", 87 | "url": "s3://my-ai-project-bucket/media/img_login_mockup_v1.png", 88 | "caption": "'凤凰计划'登录页面的UI设计初稿", 89 | "embedding": [0.67, 0.12, -0.29, ..., 0.44] 90 | } 91 | ], 92 | 93 | "metadata": { 94 | "source_conversation_id": "conv_z1y2x3w-v4u5-t6s7-r8q9-p0o9n8m7l6k5", 95 | "memory_type": "episodic", 96 | "importance_score": 0.85, 97 | "confidence_score": 0.99, 98 | "access_info": { 99 | "initial_creation_timestamp": "2025-07-28T11:01:30Z", 100 | "last_accessed_timestamp": "2025-07-28T11:01:30Z", 101 | "access_count": 1 102 | }, 103 | 
"emotional_valence": { 104 | "sentiment": "neutral", 105 | "intensity": 0.3 106 | }, 107 | "user_feedback": { 108 | "is_accurate": null, 109 | "is_important": null, 110 | "correction_text": null 111 | }, 112 | "community_info": { 113 | "id": "community_proj_phoenix_001", 114 | "last_calculated": "2025-07-27T04:00:00Z" 115 | } 116 | }, 117 | 118 | "knowledge_graph_payload": { 119 | "event_entity": { 120 | "event_id": "evt_design_review_001", 121 | "event_type": "DesignReview" 122 | }, 123 | "entities": [ 124 | { "entity_id": "person_zhang_wei_001", "name": "张伟", "type": "PERSON", "role": "author" }, 125 | { "entity_id": "project_phoenix_001", "name": "凤凰计划", "type": "PROJECT", "role": "context" }, 126 | { "entity_id": "asset_login_mockup_001", "name": "登录页面初稿", "type": "DESIGN_ASSET", "role": "subject" } 127 | ], 128 | "relationships": [ 129 | ["person_zhang_wei_001", "CREATED", "asset_login_mockup_001"], 130 | ["asset_login_mockup_001", "IS_PART_OF", "project_phoenix_001"], 131 | ["evt_design_review_001", "IS_ABOUT", "asset_login_mockup_001"], 132 | ["evt_design_review_001", "REFERENCES", "evt_meeting_042"] 133 | ] 134 | } 135 | } 136 | ``` 137 | 138 | ### 核心 AI 生成内容 139 | 140 | 这部分是 AI 的核心创造性工作,直接体现了模型的智能。 141 | 142 | | 字段路径 | 字段名 | AI 负责的工作 | 需要的模型/技术 | 143 | | :-------------------------- | :--------------- | :---------------------------------------------------------------------- | :----------------------------------- | 144 | | `summary` | **摘要** | 阅读`description`全文,生成一段简短、精炼的标题或摘要。 | 大型语言模型 (LLM) - 文本摘要任务 | 145 | | `embedding` | **文本嵌入向量** | 将`description`的语义信息编码成一个高维浮点数向量。 | 文本嵌入模型 (Text Embedding Model) | 146 | | `linked_media[].embedding` | **媒体嵌入向量** | 将图片、音频等媒体文件编码成一个与文本在同一空间的高维向量。 | 多模态嵌入模型 (e.g., CLIP) | 147 | | `linked_media[].caption` | **媒体内容描述** | (可选) 如果用户没有提供,AI 可以“看图说话”,为上传的图片生成描述。 | 视觉语言模型 (Vision-Language Model) | 148 | | `metadata.importance_score` | **重要性评分** | 根据`description`的内容,判断该记忆对用户的重要性,并给出一个量化分数。 | 大型语言模型 (LLM) - 分类/回归任务 | 149 | | `metadata.confidence_score` | **置信度评分** | NLU 模型在完成实体和关系提取后,对其结果的确定程度给出的一个分数。 | 自然语言理解模型 (NLU Model) | 150 | 151 | ### AI 提取与结构化内容 152 | 153 | 这部分是 AI 的“理解”工作,将非结构化的对话转化为机器可读的结构化数据。 154 | 155 | | 字段路径 | 字段名 | AI 负责的工作 | 需要的模型/技术 | 156 | | :------------------------------------------------ | :----------- | :----------------------------------------------------------------------------------------- | :------------------------------------------------------------------ | 157 | | `metadata.memory_type` | **记忆类型** | 分析记忆内容,将其分类为情景记忆、事实记忆或程序记忆。 | 自然语言理解 (NLU) - 文本分类 | 158 | | `metadata.emotional_valence` | **情感倾向** | 分析`description`中的情感色彩(积极/消极/中性)和强度。 | 自然语言理解 (NLU) - 情感分析 | 159 | | `knowledge_graph_payload.event_entity.event_type` | **事件类型** | 将整个事件归类到一个预定义的类型中 (e.g., `TravelPlanning`)。 | 自然语言理解 (NLU) - 文本分类 | 160 | | `knowledge_graph_payload.entities` | **实体列表** | 从`description`中识别出所有关键实体(人名、项目、组织等),并进行标准化(如分配唯一 ID)。 | 自然语言理解 (NLU) - 命名实体识别 (NER) & 实体链接 (Entity Linking) | 161 | | `knowledge_graph_payload.relationships` | **关系列表** | 识别并提取已识别出的实体之间的关系三元组(主语-谓词-宾语)。 | 自然语言理解 (NLU) - 关系提取 (Relation Extraction) | 162 | 163 | --- 164 | 165 | ### 非 AI 生成(系统或用户提供) 166 | 167 | 为了完整性,以下是基本不需要 AI 介入,由系统逻辑、用户输入或其他算法直接填充的字段。 168 | 169 | | 字段路径 | 字段名 | 来源 | 170 | | :-------------------------------- | :--------- | :-------------------------------------------- | 171 | | `memory_id` | 记忆 ID | 系统生成 (UUID) | 172 | | `timestamp` | 事件时间戳 | 系统获取或用户指定 | 173 | | `description` | 详细描述 | 用户的原始输入或对话记录 | 174 | | `linked_media[].media_id` | 媒体 ID | 系统生成 | 175 | | 
`linked_media[].url` | 文件 URL | 文件存储系统返回 | 176 | | `metadata.source_conversation_id` | 对话 ID | 系统记录 | 177 | | `metadata.access_info` | 访问信息 | 系统根据调用情况更新(时间戳、计数器) | 178 | | `metadata.user_feedback` | 用户反馈 | 用户直接提供 | 179 | | `metadata.community_info` | 社区信息 | 后台的图计算**算法**(非生成式 AI)运行后填充 | 180 | 181 | ## 2. 记忆的召回方式 (设计中) 182 | 183 | 为实现精准且富有洞察力的回忆,我们设计一个**两阶段混合检索模型**,结合向量的语义相似度和图的结构关联性。 184 | 185 | ### **阶段一:候选生成 (Candidate Generation) - 追求召回率** 186 | 187 | 1. **输入**: 用户的当前查询(文本、图片等)。 188 | 2. **处理**: 189 | - 将用户查询通过相应的嵌入模型(文本或多模态)转换为查询向量。 190 | - 在 Faiss 数据库中,使用此查询向量进行`k`-近邻搜索(例如,k=50),得到一个包含 50 个最相似记忆`memory_id`的初始候选集。 191 | - 同时,NLU 模块从查询中提取核心实体(如“张伟”)。 192 | 3. **输出**: 一个包含`memory_id`、Faiss 相似度分数和查询实体的候选列表。 193 | 194 | ### **阶段二:重排与扩展 (Re-ranking & Expansion) - 追求精确率** 195 | 196 | 1. **输入**: 阶段一生成的候选列表。 197 | 2. **处理**: 198 | - **加载图**: 使用内存图库(如`NetworkX`)加载所有记忆的`knowledge_graph_payload`构建的知识图谱。 199 | - **计算图分数 (Graph Score)**: 对于候选集中的每一个记忆,计算其与“查询实体”在图中的关联强度。 200 | - **方法**: 可以采用个性化 PageRank(从查询实体节点开始游走),或计算图中查询实体节点到记忆事件节点的最短路径长度。路径越短,关联越强,分数越高。 201 | - **扩展发现**: 从查询实体出发,在图中寻找其一度或二度关联的实体和事件(例如,张伟的同事“陈静”,参与的项目“凤凰计划”)。如果候选记忆与这些扩展出的实体相关,则其图分数应获得额外加成。 202 | - **计算最终分数**: 203 | `FinalScore = (w1 * FaissScore) + (w2 * GraphScore)` 204 | - `w1`和`w2`是可配置的权重,例如`w1=0.6`, `w2=0.4`。 205 | 3. **输出**: 一个根据`FinalScore`重新排序的、高质量的记忆列表,提交给上层应用(如 LLM)用于生成最终回复。 206 | 207 | --- 208 | 209 | ## 3. 记忆的更新方式 (设计中) 210 | 211 | 记忆不是一成不变的。插件必须支持记忆的演化和修正,核心原则是**优先创建、避免覆盖**,以保留信息的历史轨迹。 212 | 213 | 1. **元数据更新**: 214 | 215 | - **场景**: 记忆被访问。 216 | - **操作**: 直接修改`metadata.access_info`中的`last_accessed_timestamp`和`access_count`字段。这是唯一允许直接原地修改的操作。 217 | 218 | 2. **基于反馈的修正**: 219 | 220 | - **场景**: 用户通过`metadata.user_feedback`提供了修正(例如,“李娜是我的同事,不是朋友”)。 221 | - **操作**: 222 | a. 创建一个**新的记忆**,包含正确的`description`和`knowledge_graph_payload`。 223 | b. 在新记忆的`knowledge_graph_payload.relationships`中,添加一条关系指向旧记忆:`["new_memory_event_id", "CORRECTS", "old_memory_event_id"]`。 224 | c. 将旧记忆的`importance_score`大幅降低,或者标记为“已修正”。 225 | 226 | 3. **事件演化更新**: 227 | - **场景**: 事情发生了变化(例如,会议从周二改到周三)。 228 | - **操作**: 229 | a. 创建一个**新的记忆**来记录新状态(“会议将在周三举行”)。 230 | b. 在新记忆的图关系中,添加`["new_event_id", "UPDATES", "old_event_id"]`。 231 | c. 旧记忆保持不变,但其在未来的检索中权重会因“被更新”而降低。 232 | 233 | --- 234 | 235 | ## 4. 记忆的遗忘逻辑 (设计中) 236 | 237 | 为了防止记忆无限膨胀并保持检索效率,需要一个智能的遗忘机制。该机制基于一个可计算的**“记忆衰减分数”(Decay Score)**。 238 | 239 | 1. **衰减分数计算**: 240 | 241 | - 一个后台任务会定期(如每天)扫描所有记忆,并计算其衰减分数。 242 | - **公式**: `DecayScore = f(ElapsedTime, AccessCount, ImportanceScore, UserFeedback)` 243 | - **示例**: `DecayScore = (CurrentTime - last_accessed_timestamp) / (log(access_count + 1) * (importance_score + user_marked_important * 10))` 244 | - 这个公式意味着:时间越久,分数越高(越容易遗忘);访问次数越多、重要性越高、被用户标记为重要,分数越低(越不容易遗忘)。 245 | 246 | 2. **分层遗忘策略**: 247 | 248 | - 系统设定几个衰减分数的阈值,对应不同的操作。 249 | - **第一阈值 (e.g., Score > 100)**: **归档 (Archive)**。 250 | - **操作**: 为了节省高性能存储和 Faiss 索引空间,可以从记忆中移除`embedding`向量。记忆的 JSON 文本和元数据被转移到更廉价的“冷存储”中。此记忆不再参与常规的向量检索,但仍可通过 ID 或关键词搜索找到。 251 | - **第二阈值 (e.g., Score > 500)**: **标记为待删除 (Mark for Deletion)**。 252 | - **操作**: 系统将记忆标记为待删除,进入一个短暂的“回收站”状态。 253 | - **最终清理**: 一个独立的、执行频率更低的任务会永久删除那些被标记了足够长时间(如 30 天)的记忆。 254 | 255 | 3. 
**豁免机制**: 256 | - 任何被用户通过`user_feedback`标记为`is_important: true`的记忆,其衰减分数计算将获得极高的权重,或直接豁免于遗忘逻辑。 257 | - 核心的`semantic`类型记忆(事实性知识)的衰减速度应远低于`episodic`类型(情景性记忆)。 258 | -------------------------------------------------------------------------------- /_conf_schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "timezone_settings": { 3 | "description": "时区设置", 4 | "type": "object", 5 | "items": { 6 | "timezone": { 7 | "description": "时区", 8 | "hint": "用于生成带时区信息的时间对象,请使用 IANA 时区数据库名称,例如 'Asia/Shanghai', 'America/New_York'。", 9 | "type": "string", 10 | "default": "Asia/Shanghai" 11 | } 12 | } 13 | }, 14 | "provider_settings": { 15 | "description": "模型提供商设置", 16 | "type": "object", 17 | "items": { 18 | "embedding_provider_id": { 19 | "description": "Embedding Provider ID", 20 | "hint": "用于生成向量的 Embedding Provider ID。如果留空,将使用 AstrBot 的默认 Embedding Provider。", 21 | "type": "string", 22 | "default": "" 23 | }, 24 | "llm_provider_id": { 25 | "description": "LLM Provider ID", 26 | "hint": "用于总结和评估记忆的 LLM Provider ID。如果留空,将使用 AstrBot 的默认 LLM Provider。", 27 | "type": "string", 28 | "default": "" 29 | } 30 | } 31 | }, 32 | "session_manager": { 33 | "description": "会话管理器设置", 34 | "hint": "管理用户会话状态和内存", 35 | "type": "object", 36 | "items": { 37 | "max_sessions": { 38 | "description": "最大会话数量", 39 | "hint": "同时维护的最大会话数量,超过此数量将清理最旧的会话。", 40 | "type": "int", 41 | "default": 1000 42 | }, 43 | "session_ttl": { 44 | "description": "会话生存时间(秒)", 45 | "hint": "会话的最大空闲时间,超过此时间的会话将被自动清理。", 46 | "type": "int", 47 | "default": 3600 48 | } 49 | } 50 | }, 51 | "recall_engine": { 52 | "description": "回忆引擎设置", 53 | "hint": "负责智能检索记忆", 54 | "type": "object", 55 | "items": { 56 | "top_k": { 57 | "description": "单次检索数量", 58 | "hint": "单次检索返回的最相关记忆数量。", 59 | "type": "int", 60 | "default": 5 61 | }, 62 | "recall_strategy": { 63 | "description": "召回策略", 64 | "hint": "'similarity' - 仅基于相似度;'weighted' - 综合加权。", 65 | "type": "string", 66 | "options": [ 67 | "similarity", 68 | "weighted" 69 | ], 70 | "default": "weighted" 71 | }, 72 | "retrieval_mode": { 73 | "description": "检索模式", 74 | "hint": "'dense' - 纯密集检索;'sparse' - 纯稀疏检索;'hybrid' - 混合检索。", 75 | "type": "string", 76 | "options": [ 77 | "dense", 78 | "sparse", 79 | "hybrid" 80 | ], 81 | "default": "hybrid" 82 | }, 83 | "similarity_weight": { 84 | "description": "相关性权重", 85 | "hint": "范围 (0.0 - 1.0)", 86 | "type": "float", 87 | "default": 0.6 88 | }, 89 | "recency_weight": { 90 | "description": "新近度权重", 91 | "hint": "范围 (0.0 - 1.0)", 92 | "type": "float", 93 | "default": 0.2 94 | }, 95 | "importance_weight": { 96 | "description": "重要性权重", 97 | "hint": "范围 (0.0 - 1.0)", 98 | "type": "float", 99 | "default": 0.2 100 | } 101 | } 102 | }, 103 | "fusion": { 104 | "description": "结果融合配置", 105 | "hint": "混合检索时的结果融合策略", 106 | "type": "object", 107 | "items": { 108 | "strategy": { 109 | "description": "融合策略", 110 | "hint": "选择混合检索的结果融合方法", 111 | "type": "string", 112 | "options": [ 113 | "rrf", 114 | "hybrid_rrf", 115 | "weighted", 116 | "convex", 117 | "interleave", 118 | "rank_fusion", 119 | "score_fusion", 120 | "cascade", 121 | "adaptive" 122 | ], 123 | "default": "rrf" 124 | }, 125 | "rrf_k": { 126 | "description": "RRF参数k", 127 | "hint": "Reciprocal Rank Fusion 中的参数 k,较小值更重视靠前的结果。", 128 | "type": "int", 129 | "default": 60 130 | }, 131 | "dense_weight": { 132 | "description": "密集检索权重", 133 | "hint": "密集向量检索在融合中的权重 (0.0 - 1.0)", 134 | "type": "float", 135 | "default": 0.7 136 | }, 137 | "sparse_weight": { 138 | "description": "稀疏检索权重", 
139 | "hint": "稀疏关键词检索在融合中的权重 (0.0 - 1.0)", 140 | "type": "float", 141 | "default": 0.3 142 | }, 143 | "convex_lambda": { 144 | "description": "凸组合参数λ", 145 | "hint": "凸组合融合中的参数λ,控制密集和稀疏检索的混合比例 (0.0 - 1.0)", 146 | "type": "float", 147 | "default": 0.5 148 | }, 149 | "interleave_ratio": { 150 | "description": "交替融合比例", 151 | "hint": "交替融合中密集结果所占的比例 (0.0 - 1.0)", 152 | "type": "float", 153 | "default": 0.5 154 | }, 155 | "rank_bias_factor": { 156 | "description": "排序偏置因子", 157 | "hint": "在两个排序列表中都出现的文档获得的额外加分 (0.0 - 1.0)", 158 | "type": "float", 159 | "default": 0.1 160 | }, 161 | "diversity_bonus": { 162 | "description": "多样性奖励", 163 | "hint": "鼓励结果多样性的奖励因子 (0.0 - 1.0)", 164 | "type": "float", 165 | "default": 0.1 166 | } 167 | } 168 | }, 169 | "sparse_retriever": { 170 | "description": "稀疏检索器设置", 171 | "hint": "基于关键词的全文检索配置", 172 | "type": "object", 173 | "items": { 174 | "enabled": { 175 | "description": "启用稀疏检索", 176 | "hint": "是否启用基于BM25的稀疏检索功能。", 177 | "type": "bool", 178 | "default": true 179 | }, 180 | "bm25_k1": { 181 | "description": "BM25 k1参数", 182 | "hint": "控制词频饱和度的参数,通常在1.2-2.0之间。", 183 | "type": "float", 184 | "default": 1.2 185 | }, 186 | "bm25_b": { 187 | "description": "BM25 b参数", 188 | "hint": "控制文档长度归一化的参数,范围0.0-1.0。", 189 | "type": "float", 190 | "default": 0.75 191 | }, 192 | "use_jieba": { 193 | "description": "使用中文分词", 194 | "hint": "是否使用jieba进行中文分词处理。", 195 | "type": "bool", 196 | "default": true 197 | } 198 | } 199 | }, 200 | "filtering_settings": { 201 | "description": "过滤与隔离设置", 202 | "type": "object", 203 | "items": { 204 | "use_persona_filtering": { 205 | "description": "启用人格记忆过滤", 206 | "hint": "开启后,只会召回和总结与当前人格相关的记忆。", 207 | "type": "bool", 208 | "default": true 209 | }, 210 | "use_session_filtering": { 211 | "description": "启用会话记忆隔离", 212 | "hint": "开启后,每个会话的记忆将是独立的。", 213 | "type": "bool", 214 | "default": true 215 | } 216 | } 217 | }, 218 | "reflection_engine": { 219 | "description": "反思引擎设置", 220 | "hint": "负责生成和评估记忆", 221 | "type": "object", 222 | "items": { 223 | "summary_trigger_rounds": { 224 | "description": "总结触发轮次", 225 | "hint": "触发对话历史总结的对话轮次(一问一答为一轮)。", 226 | "type": "int", 227 | "default": 5 228 | }, 229 | "importance_threshold": { 230 | "description": "重要性阈值", 231 | "hint": "记忆重要性得分的最低阈值,低于此值的记忆将被忽略。", 232 | "type": "float", 233 | "default": 0.5 234 | }, 235 | "event_extraction_prompt": { 236 | "description": "事件提取提示词", 237 | "hint": "指导 LLM 从对话历史中提取多个结构化记忆事件的 System Prompt。", 238 | "type": "text", 239 | "default": "### 角色\n你是一个善于分析和总结的AI助手。你的核心人设是从你自身的视角出发,记录与用户的互动和观察。\n\n### 指令/任务\n1. **仔细阅读**并理解下面提供的“对话历史”。\n2. 从**你(AI)的视角**出发,提取出多个独立的、有意义的记忆事件。事件必须准确描述,参考上下文。事件必须是完整的,具有前因后果的。**不允许编造事件**,**不允许改变事件**,**详细描述事件的所有信息**\n3. **核心要求**:\n * **第一人称视角**:所有事件都必须以“我”开头进行描述,例如“我告诉用户...”、“我观察到...”、“我被告知...”。\n * **使用具体名称**:直接使用对话中出现的人物昵称,**严禁**使用“用户”、“开发者”等通用词汇。\n * **记录互动者**:必须明确记录与你互动的用户名称。\n * **事件合并**:如果多条连续的对话构成一个完整的独立事件,应将其总结概括为一条记忆。\n4.**严禁**包含任何评分、额外的解释或说明性文字。\n 直接输出结果,不要有任何引言或总结。\n\n### 上下文\n* 在对话历史中,名为“AstrBot”的发言者就是**你自己**。\n* 记忆事件是:你与用户互动事的事件描述,详细记录谁、在何时、何地、做了什么、发生了什么。\n\n 'memory_content' 字段必须包含完整的事件描述,不能省略任何细节。\n\n单个系列事件必须详细记录在一个memory_content 中,形成完整的具有前因后果的事件记忆。\n\n" 240 | }, 241 | "evaluation_prompt": { 242 | "description": "评估提示词", 243 | "hint": "指导 LLM 评估单个记忆事件重要性的 System Prompt。必须返回一个 0.0 到 1.0 之间的浮点数。", 244 | "type": "text", 245 | "default": "### 角色\n你是一个专门评估记忆价值的AI分析模型。你的判断标准是该记忆对于与特定用户构建长期、个性化、有上下文的对话有多大的帮助。\n\n### 指令/任务\n1. **评估核心价值**:仔细阅读“记忆内容”,评估其对于未来对话的长期参考价值。\n2. **输出分数**:给出一个介于 0.0 到 1.0 之间的浮点数分数。\n3. 
**格式要求**:**只返回数字**,严禁包含任何额外的文本、解释或理由。\n\n### 上下文\n评分时,请参考以下价值标尺:\n* **高价值 (0.8 - 1.0)**:包含用户的核心身份信息、明确且长期的个人偏好/厌恶、设定的目标、重要的关系或事实。这些信息几乎总能在未来的互动中被引用。\n * 例如:用户的昵称、职业、关键兴趣点、对AI的称呼、重要的人生目标。\n* **中等价值 (0.4 - 0.7)**:包含用户的具体建议、功能请求、对某事的观点或一次性的重要问题。这些信息在短期内或特定话题下很有用,但可能随着时间推移或问题解决而失去价值。\n * 例如:对某个功能的反馈、对特定新闻事件的看法、报告了一个具体的bug。\n* **低价值 (0.1 - 0.3)**:包含短暂的情绪表达、日常问候、或非常具体且不太可能重复的上下文。这些信息很少有再次利用的机会。\n * 例如:一次性的惊叹、害怕的反应、普通的“你好”、“晚安”。\n* **无价值 (0.0)**:信息完全是瞬时的、无关紧要的,或者不包含任何关于用户本人的可复用信息。\n * 例如:观察到另一个机器人说了话、对一句无法理解的话的默认回应。\n\n### 问题\n请评估以下“记忆内容”的重要性,对于未来的对话有多大的参考价值?\n\n---\n\n**记忆内容**:\n{memory_content}\n\n" 246 | } 247 | } 248 | }, 249 | "forgetting_agent": { 250 | "description": "遗忘代理设置", 251 | "hint": "负责模拟遗忘,清理陈旧记忆", 252 | "type": "object", 253 | "items": { 254 | "enabled": { 255 | "description": "启用自动遗忘", 256 | "hint": "是否启用自动遗忘机制。", 257 | "type": "bool", 258 | "default": true 259 | }, 260 | "check_interval_hours": { 261 | "description": "检查间隔(小时)", 262 | "hint": "遗忘代理每隔多少小时运行一次。", 263 | "type": "int", 264 | "default": 24 265 | }, 266 | "retention_days": { 267 | "description": "记忆保留天数", 268 | "hint": "记忆无条件保留的最长天数。超过此天数的记忆将可能被遗忘。", 269 | "type": "int", 270 | "default": 90 271 | }, 272 | "importance_decay_rate": { 273 | "description": "重要性衰减率", 274 | "hint": "重要性得分每天衰减的速率。例如 0.01 代表每天降低 1%。", 275 | "type": "float", 276 | "default": 0.005 277 | }, 278 | "importance_threshold": { 279 | "description": "遗忘重要性阈值", 280 | "hint": "重要性低于此阈值的陈旧记忆将被删除。", 281 | "type": "float", 282 | "default": 0.1 283 | }, 284 | "forgetting_batch_size": { 285 | "description": "批处理大小", 286 | "hint": "遗忘代理每批处理的记忆数量,避免一次性加载过多数据。", 287 | "type": "int", 288 | "default": 1000 289 | } 290 | } 291 | } 292 | } -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxfight-s-Astrbot-Plugins/astrbot_plugin_livingmemory/a382b10a91a41b68abc850e5b1abe05052e372f0/core/__init__.py -------------------------------------------------------------------------------- /core/commands/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 命令模块 - 提供插件命令的统一管理 4 | """ 5 | 6 | from .base_command import BaseCommand 7 | from .memory_commands import MemoryCommands 8 | from .search_commands import SearchCommands 9 | from .admin_commands import AdminCommands 10 | from .fusion_commands import FusionCommands 11 | 12 | __all__ = [ 13 | 'BaseCommand', 14 | 'MemoryCommands', 15 | 'SearchCommands', 16 | 'AdminCommands', 17 | 'FusionCommands' 18 | ] -------------------------------------------------------------------------------- /core/commands/base_command.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | base_command.py - 基础命令类 4 | 提供命令处理的基础功能和通用方法 5 | """ 6 | 7 | from abc import ABC, abstractmethod 8 | from typing import Optional, Dict, Any, List 9 | from datetime import datetime, timezone 10 | import json 11 | 12 | from astrbot.api import logger 13 | from astrbot.api.event import filter, AstrMessageEvent 14 | from astrbot.api.event.filter import PermissionType, permission_type 15 | from astrbot.api.star import Context 16 | 17 | 18 | class BaseCommand(ABC): 19 | """基础命令类,提供通用的命令处理功能""" 20 | 21 | def __init__(self, context: Context, config: Dict[str, Any]): 22 | self.context = context 23 | self.config = config 24 | 25 | @abstractmethod 26 | def 
register_commands(self): 27 | """注册命令的抽象方法""" 28 | pass 29 | 30 | def get_timezone(self) -> Any: 31 | """获取当前时区""" 32 | tz_config = self.config.get("timezone_settings", {}) 33 | tz_str = tz_config.get("timezone", "Asia/Shanghai") 34 | from ..utils import get_now_datetime 35 | return get_now_datetime(tz_str).tzinfo 36 | 37 | def format_timestamp(self, ts: Optional[float]) -> str: 38 | """格式化时间戳""" 39 | if not ts: 40 | return "未知" 41 | try: 42 | dt_utc = datetime.fromtimestamp(float(ts), tz=timezone.utc) 43 | dt_local = dt_utc.astimezone(self.get_timezone()) 44 | return dt_local.strftime("%Y-%m-%d %H:%M:%S") 45 | except (ValueError, TypeError): 46 | return "未知" 47 | 48 | def safe_parse_metadata(self, metadata: Any) -> Dict[str, Any]: 49 | """安全解析元数据""" 50 | if isinstance(metadata, dict): 51 | return metadata 52 | if isinstance(metadata, str): 53 | try: 54 | return json.loads(metadata) 55 | except json.JSONDecodeError: 56 | return {} 57 | return {} 58 | 59 | def format_memory_card(self, result: Any) -> str: 60 | """格式化记忆卡片显示""" 61 | metadata = self.safe_parse_metadata(result.data.get("metadata", {})) 62 | 63 | create_time_str = self.format_timestamp(metadata.get("create_time")) 64 | last_access_time_str = self.format_timestamp(metadata.get("last_access_time")) 65 | importance_score = metadata.get("importance", 0.0) 66 | event_type = metadata.get("event_type", "未知") 67 | 68 | card = ( 69 | f"ID: {result.data['id']}\n" 70 | f"记 忆 度: {result.similarity:.2f}\n" 71 | f"重 要 性: {importance_score:.2f}\n" 72 | f"记忆类型: {event_type}\n\n" 73 | f"内容: {result.data['text']}\n\n" 74 | f"创建于: {create_time_str}\n" 75 | f"最后访问: {last_access_time_str}" 76 | ) 77 | return card -------------------------------------------------------------------------------- /core/community/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxfight-s-Astrbot-Plugins/astrbot_plugin_livingmemory/a382b10a91a41b68abc850e5b1abe05052e372f0/core/community/__init__.py -------------------------------------------------------------------------------- /core/community/community_detector.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import asyncio 3 | import aiosqlite 4 | import networkx as nx 5 | from typing import List, Tuple 6 | 7 | from astrbot.api import logger 8 | 9 | 10 | class CommunityDetector: 11 | """ 12 | 一个后台服务,负责从 SQLite 加载图数据, 13 | 使用 NetworkX 进行社区发现,并将结果写回数据库。 14 | """ 15 | 16 | def __init__(self, db_path: str): 17 | self.db_path = db_path 18 | self.connection = None 19 | 20 | async def initialize(self): 21 | """初始化数据库连接。""" 22 | self.connection = await aiosqlite.connect(self.db_path) 23 | 24 | async def close(self): 25 | """关闭数据库连接。""" 26 | if self.connection: 27 | await self.connection.close() 28 | self.connection = None 29 | 30 | async def _load_graph_from_db(self) -> nx.Graph: 31 | """从 SQLite 中加载边,构建一个 NetworkX 图对象。""" 32 | G = nx.Graph() 33 | if not self.connection: 34 | await self.initialize() 35 | 36 | try: 37 | # 我们只需要边的信息来构建图的结构 38 | cursor = await self.connection.execute("SELECT source_id, target_id FROM graph_edges") 39 | edges = await cursor.fetchall() 40 | G.add_edges_from(edges) 41 | except Exception as e: 42 | logger.error(f"从数据库加载图数据失败: {e}") 43 | raise 44 | 45 | return G 46 | 47 | async def _save_results_to_db(self, communities: List[Tuple[str]]): 48 | """将计算出的社区结果批量更新回 memories 表。""" 49 | if not self.connection: 50 | logger.error("数据库连接未初始化") 51 | return 52 | 53 | # 
注意:这里的逻辑需要一个从“图节点ID”到“记忆internal_id”的映射 54 | # 我们简化一下,假设 Event 节点的 ID 就是 memory_id 55 | updates = [] 56 | for i, community_nodes in enumerate(communities): 57 | community_id = f"community_{i}" 58 | for node_id in community_nodes: 59 | # 假设事件节点的 entity_id 格式为 'evt_mem_xxx' 60 | if node_id.startswith("evt_mem_"): 61 | memory_id = node_id.split("evt_mem_")[1] 62 | updates.append((community_id, memory_id)) 63 | 64 | if updates: 65 | try: 66 | await self.connection.executemany( 67 | "UPDATE memories SET community_id = ? WHERE memory_id = ?", updates 68 | ) 69 | await self.connection.commit() 70 | logger.info(f"成功更新 {len(updates)} 条记忆的社区信息") 71 | except Exception as e: 72 | logger.error(f"更新社区信息失败: {e}") 73 | raise 74 | 75 | async def run_detection_and_update(self): 76 | """ 77 | 执行社区发现的完整流程。使用进程池进行计算密集型任务。 78 | """ 79 | try: 80 | logger.info("开始从数据库加载图...") 81 | graph = await self._load_graph_from_db() 82 | 83 | if graph.number_of_nodes() == 0: 84 | logger.info("图中没有节点,跳过社区发现。") 85 | return 86 | 87 | logger.info(f"图加载完成({graph.number_of_nodes()}个节点,{graph.number_of_edges()}条边),开始运行 Louvain 社区发现算法...") 88 | 89 | # 使用进程池执行计算密集型的社区发现算法 90 | import concurrent.futures 91 | with concurrent.futures.ProcessPoolExecutor() as executor: 92 | # resolution 参数可以调整社区的大小,值越小社区越多越小 93 | communities = await asyncio.get_event_loop().run_in_executor( 94 | executor, 95 | nx.community.louvain_communities, 96 | graph, 97 | 1.0 # resolution 98 | ) 99 | 100 | logger.info(f"发现 {len(communities)} 个社区,开始将结果写回数据库...") 101 | await self._save_results_to_db(communities) 102 | logger.info("社区信息更新完成。") 103 | 104 | except Exception as e: 105 | logger.error(f"社区发现过程中发生错误: {e}", exc_info=True) 106 | raise 107 | -------------------------------------------------------------------------------- /core/config_validator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | config_validator.py - 配置验证模块 4 | 提供配置验证和默认值管理功能。 5 | """ 6 | 7 | from typing import Dict, Any, List, Optional, Union 8 | from pydantic import BaseModel, Field, field_validator, model_validator 9 | from astrbot.api import logger 10 | 11 | 12 | class SessionManagerConfig(BaseModel): 13 | """会话管理器配置""" 14 | max_sessions: int = Field(default=1000, ge=1, le=10000, description="最大会话数量") 15 | session_ttl: int = Field(default=3600, ge=60, le=86400, description="会话生存时间(秒)") 16 | 17 | 18 | class RecallEngineConfig(BaseModel): 19 | """回忆引擎配置""" 20 | top_k: int = Field(default=5, ge=1, le=50, description="返回记忆数量") 21 | recall_strategy: str = Field(default="weighted", pattern="^(similarity|weighted)$", description="召回策略") 22 | retrieval_mode: str = Field(default="hybrid", pattern="^(hybrid|dense|sparse)$", description="检索模式") 23 | similarity_weight: float = Field(default=0.6, ge=0.0, le=1.0, description="相似度权重") 24 | importance_weight: float = Field(default=0.2, ge=0.0, le=1.0, description="重要性权重") 25 | recency_weight: float = Field(default=0.2, ge=0.0, le=1.0, description="新近度权重") 26 | 27 | @model_validator(mode='after') 28 | def validate_weights_sum(self): 29 | """验证权重总和接近1.0""" 30 | similarity = self.similarity_weight 31 | importance = self.importance_weight 32 | recency = self.recency_weight 33 | 34 | # 计算权重总和 35 | total = similarity + importance + recency 36 | if abs(total - 1.0) > 0.1: 37 | logger.warning(f"权重总和 {total:.2f} 偏离1.0较多,可能影响检索效果") 38 | 39 | return self 40 | 41 | 42 | class FusionConfig(BaseModel): 43 | """结果融合配置""" 44 | strategy: str = Field( 45 | default="rrf", 46 | 
pattern="^(rrf|weighted|cascade|adaptive|convex|interleave|rank_fusion|score_fusion|hybrid_rrf)$", 47 | description="融合策略" 48 | ) 49 | rrf_k: int = Field(default=60, ge=1, le=1000, description="RRF参数k") 50 | dense_weight: float = Field(default=0.7, ge=0.0, le=1.0, description="密集检索权重") 51 | sparse_weight: float = Field(default=0.3, ge=0.0, le=1.0, description="稀疏检索权重") 52 | sparse_alpha: float = Field(default=1.0, ge=0.1, le=10.0, description="稀疏分数缩放") 53 | sparse_epsilon: float = Field(default=0.0, ge=0.0, le=1.0, description="稀疏分数偏移") 54 | 55 | # 新增参数 56 | convex_lambda: float = Field(default=0.5, ge=0.0, le=1.0, description="凸组合参数λ") 57 | interleave_ratio: float = Field(default=0.5, ge=0.0, le=1.0, description="交替融合比例") 58 | rank_bias_factor: float = Field(default=0.1, ge=0.0, le=1.0, description="排序偏置因子") 59 | diversity_bonus: float = Field(default=0.1, ge=0.0, le=1.0, description="多样性奖励") 60 | 61 | 62 | class ReflectionEngineConfig(BaseModel): 63 | """反思引擎配置""" 64 | summary_trigger_rounds: int = Field(default=10, ge=1, le=100, description="触发反思的对话轮次") 65 | importance_threshold: float = Field(default=0.5, ge=0.0, le=1.0, description="记忆重要性阈值") 66 | event_extraction_prompt: Optional[str] = Field(default=None, description="事件提取提示词") 67 | evaluation_prompt: Optional[str] = Field(default=None, description="评分提示词") 68 | 69 | 70 | class SparseRetrieverConfig(BaseModel): 71 | """稀疏检索器配置""" 72 | enabled: bool = Field(default=True, description="是否启用稀疏检索") 73 | bm25_k1: float = Field(default=1.2, ge=0.1, le=10.0, description="BM25 k1参数") 74 | bm25_b: float = Field(default=0.75, ge=0.0, le=1.0, description="BM25 b参数") 75 | use_jieba: bool = Field(default=True, description="是否使用jieba分词") 76 | 77 | 78 | class ForgettingAgentConfig(BaseModel): 79 | """遗忘代理配置""" 80 | enabled: bool = Field(default=True, description="是否启用遗忘代理") 81 | check_interval_hours: int = Field(default=24, ge=1, le=168, description="检查间隔(小时)") 82 | retention_days: int = Field(default=90, ge=1, le=3650, description="记忆保留天数") 83 | importance_decay_rate: float = Field(default=0.005, ge=0.0, le=1.0, description="重要性衰减率") 84 | importance_threshold: float = Field(default=0.1, ge=0.0, le=1.0, description="删除阈值") 85 | forgetting_batch_size: int = Field(default=1000, ge=100, le=10000, description="批处理大小") 86 | 87 | 88 | class FilteringConfig(BaseModel): 89 | """过滤配置""" 90 | use_persona_filtering: bool = Field(default=True, description="是否使用人格过滤") 91 | use_session_filtering: bool = Field(default=True, description="是否使用会话过滤") 92 | 93 | 94 | class ProviderConfig(BaseModel): 95 | """Provider配置""" 96 | embedding_provider_id: Optional[str] = Field(default=None, description="Embedding Provider ID") 97 | llm_provider_id: Optional[str] = Field(default=None, description="LLM Provider ID") 98 | 99 | 100 | class TimezoneConfig(BaseModel): 101 | """时区配置""" 102 | timezone: str = Field(default="Asia/Shanghai", description="时区") 103 | 104 | 105 | class LivingMemoryConfig(BaseModel): 106 | """完整插件配置""" 107 | session_manager: SessionManagerConfig = Field(default_factory=SessionManagerConfig) 108 | recall_engine: RecallEngineConfig = Field(default_factory=RecallEngineConfig) 109 | reflection_engine: ReflectionEngineConfig = Field(default_factory=ReflectionEngineConfig) 110 | sparse_retriever: SparseRetrieverConfig = Field(default_factory=SparseRetrieverConfig) 111 | forgetting_agent: ForgettingAgentConfig = Field(default_factory=ForgettingAgentConfig) 112 | filtering_settings: FilteringConfig = Field(default_factory=FilteringConfig) 113 | provider_settings: 
ProviderConfig = Field(default_factory=ProviderConfig) 114 | timezone_settings: TimezoneConfig = Field(default_factory=TimezoneConfig) 115 | 116 | # 为融合配置添加嵌套支持 117 | fusion: Optional[FusionConfig] = Field(default_factory=FusionConfig, description="结果融合配置") 118 | 119 | model_config = {"extra": "allow"} # 允许额外字段,向前兼容 120 | 121 | 122 | def validate_config(raw_config: Dict[str, Any]) -> LivingMemoryConfig: 123 | """ 124 | 验证并返回规范化的配置对象。 125 | 126 | Args: 127 | raw_config: 原始配置字典 128 | 129 | Returns: 130 | LivingMemoryConfig: 验证后的配置对象 131 | 132 | Raises: 133 | ValueError: 配置验证失败 134 | """ 135 | try: 136 | config = LivingMemoryConfig(**raw_config) 137 | logger.info("配置验证成功") 138 | return config 139 | except Exception as e: 140 | logger.error(f"配置验证失败: {e}") 141 | raise ValueError(f"插件配置无效: {e}") from e 142 | 143 | 144 | def get_default_config() -> Dict[str, Any]: 145 | """ 146 | 获取默认配置字典。 147 | 148 | Returns: 149 | Dict[str, Any]: 默认配置 150 | """ 151 | return LivingMemoryConfig().model_dump() 152 | 153 | 154 | def merge_config_with_defaults(user_config: Dict[str, Any]) -> Dict[str, Any]: 155 | """ 156 | 将用户配置与默认配置合并。 157 | 158 | Args: 159 | user_config: 用户提供的配置 160 | 161 | Returns: 162 | Dict[str, Any]: 合并后的配置 163 | """ 164 | default_config = get_default_config() 165 | 166 | def deep_merge(default: Dict[str, Any], user: Dict[str, Any]) -> Dict[str, Any]: 167 | """深度合并两个字典""" 168 | result = default.copy() 169 | for key, value in user.items(): 170 | if key in result and isinstance(result[key], dict) and isinstance(value, dict): 171 | result[key] = deep_merge(result[key], value) 172 | else: 173 | result[key] = value 174 | return result 175 | 176 | merged = deep_merge(default_config, user_config) 177 | logger.debug("配置已与默认值合并") 178 | return merged 179 | 180 | 181 | def validate_runtime_config_changes(current_config: LivingMemoryConfig, changes: Dict[str, Any]) -> bool: 182 | """ 183 | 验证运行时配置更改是否有效。 184 | 185 | Args: 186 | current_config: 当前配置 187 | changes: 要更改的配置项 188 | 189 | Returns: 190 | bool: 是否有效 191 | """ 192 | try: 193 | # 创建更新后的配置副本进行验证 194 | updated_dict = current_config.model_dump() 195 | 196 | def update_nested_dict(target: Dict[str, Any], updates: Dict[str, Any]): 197 | for key, value in updates.items(): 198 | if '.' 
in key: 199 | # 处理嵌套键,如 "recall_engine.top_k" 200 | parts = key.split('.') 201 | current = target 202 | for part in parts[:-1]: 203 | if part not in current: 204 | current[part] = {} 205 | current = current[part] 206 | current[parts[-1]] = value 207 | else: 208 | target[key] = value 209 | 210 | update_nested_dict(updated_dict, changes) 211 | 212 | # 验证更新后的配置 213 | LivingMemoryConfig(**updated_dict) 214 | return True 215 | 216 | except Exception as e: 217 | logger.error(f"运行时配置更改验证失败: {e}") 218 | return False -------------------------------------------------------------------------------- /core/constants.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | constants.py - 插件使用的常量 4 | """ 5 | 6 | # 注入到 System Prompt 的记忆头尾格式 7 | MEMORY_INJECTION_HEADER = "" 8 | MEMORY_INJECTION_FOOTER = "" 9 | -------------------------------------------------------------------------------- /core/engines/forgetting_agent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | forgetting_agent.py - 遗忘代理 4 | 作为一个后台任务,定期清理陈旧的、不重要的记忆,模拟人类的遗忘曲线。 5 | """ 6 | 7 | import asyncio 8 | import json 9 | from typing import Dict, Any, Optional 10 | 11 | from astrbot.api import logger 12 | from astrbot.api.star import Context 13 | from ...storage.faiss_manager import FaissManager 14 | from ..utils import get_now_datetime, safe_parse_metadata, validate_timestamp 15 | 16 | 17 | class ForgettingAgent: 18 | """ 19 | 遗忘代理:作为一个后台任务,定期清理陈旧的、不重要的记忆,模拟人类的遗忘曲线。 20 | """ 21 | 22 | def __init__( 23 | self, context: Context, config: Dict[str, Any], faiss_manager: FaissManager 24 | ): 25 | """ 26 | 初始化遗忘代理。 27 | 28 | Args: 29 | context (Context): AstrBot 的上下文对象。 30 | config (Dict[str, Any]): 插件配置中 'forgetting_agent' 部分的字典。 31 | faiss_manager (FaissManager): 数据库管理器实例。 32 | """ 33 | self.context = context 34 | self.config = config 35 | self.faiss_manager = faiss_manager 36 | self._task: Optional[asyncio.Task] = None 37 | logger.info("ForgettingAgent 初始化成功。") 38 | 39 | async def start(self): 40 | """启动后台遗忘任务。""" 41 | if not self.config.get("enabled", True): 42 | logger.info("遗忘代理未启用,不启动后台任务。") 43 | return 44 | 45 | if self._task is None or self._task.done(): 46 | self._task = asyncio.create_task(self._run_periodically()) 47 | logger.info("遗忘代理后台任务已启动。") 48 | 49 | async def stop(self): 50 | """停止后台遗忘任务。""" 51 | if self._task and not self._task.done(): 52 | self._task.cancel() 53 | try: 54 | await self._task 55 | except asyncio.CancelledError: 56 | logger.info("遗忘代理后台任务已成功取消。") 57 | self._task = None 58 | 59 | async def _run_periodically(self): 60 | """后台任务的循环体。""" 61 | interval_hours = self.config.get("check_interval_hours", 24) 62 | interval_seconds = interval_hours * 3600 63 | logger.info(f"遗忘代理将每 {interval_hours} 小时运行一次。") 64 | 65 | while True: 66 | try: 67 | await asyncio.sleep(interval_seconds) 68 | logger.info("开始执行每日记忆清理任务...") 69 | await self._prune_memories() 70 | logger.info("每日记忆清理任务执行完毕。") 71 | except asyncio.CancelledError: 72 | logger.info("遗忘代理任务被取消。") 73 | break 74 | except Exception as e: 75 | logger.error(f"遗忘代理后台任务发生错误: {e}", exc_info=True) 76 | # 即使出错,也等待下一个周期,避免快速失败刷屏 77 | await asyncio.sleep(60) 78 | 79 | async def _prune_memories(self): 80 | """执行一次完整的记忆衰减和修剪,使用分页处理避免内存过载。""" 81 | # 获取记忆总数 82 | total_memories = await self.faiss_manager.count_total_memories() 83 | if total_memories == 0: 84 | logger.info("数据库中没有记忆,无需清理。") 85 | return 86 | 87 | retention_days = 
self.config.get("retention_days", 90) 88 | decay_rate = self.config.get("importance_decay_rate", 0.005) 89 | current_time = get_now_datetime(self.context).timestamp() 90 | 91 | # 分页处理配置 92 | page_size = self.config.get("forgetting_batch_size", 1000) # 每批处理数量 93 | 94 | logger.info(f"开始处理 {total_memories} 条记忆,每批 {page_size} 条") 95 | 96 | memories_to_update = [] 97 | ids_to_delete = [] 98 | total_processed = 0 99 | 100 | # 分页处理所有记忆 101 | for offset in range(0, total_memories, page_size): 102 | batch_memories = await self.faiss_manager.get_memories_paginated( 103 | page_size=page_size, offset=offset 104 | ) 105 | 106 | if not batch_memories: 107 | break 108 | 109 | logger.debug(f"处理第 {offset//page_size + 1} 批,共 {len(batch_memories)} 条记忆") 110 | 111 | batch_updates = [] 112 | batch_deletes = [] 113 | 114 | for mem in batch_memories: 115 | # 使用统一的元数据解析函数 116 | metadata = safe_parse_metadata(mem["metadata"]) 117 | if not metadata: 118 | logger.warning(f"无法解析记忆 {mem['id']} 的元数据,跳过处理") 119 | continue 120 | 121 | # 1. 重要性衰减 122 | create_time = validate_timestamp(metadata.get("create_time"), current_time) 123 | days_since_creation = (current_time - create_time) / (24 * 3600) 124 | 125 | # 线性衰减 126 | decayed_importance = metadata.get("importance", 0.5) - ( 127 | days_since_creation * decay_rate 128 | ) 129 | metadata["importance"] = max(0, decayed_importance) # 确保不为负 130 | 131 | mem["metadata"] = metadata # 更新内存中的 metadata 132 | batch_updates.append(mem) 133 | 134 | # 2. 识别待删除项 135 | retention_seconds = retention_days * 24 * 3600 136 | is_old = (current_time - create_time) > retention_seconds 137 | # 从配置中读取重要性阈值 138 | importance_threshold = self.config.get("importance_threshold", 0.1) 139 | is_unimportant = metadata["importance"] < importance_threshold 140 | 141 | if is_old and is_unimportant: 142 | batch_deletes.append(mem["id"]) 143 | 144 | # 累积到全局列表 145 | memories_to_update.extend(batch_updates) 146 | ids_to_delete.extend(batch_deletes) 147 | total_processed += len(batch_memories) 148 | 149 | # 如果批次数据过多,执行中间提交 150 | if len(memories_to_update) >= page_size * 2: 151 | logger.debug(f"执行中间批次更新,更新 {len(memories_to_update)} 条记忆") 152 | await self.faiss_manager.update_memories_metadata(memories_to_update) 153 | memories_to_update.clear() 154 | 155 | logger.debug(f"已处理 {total_processed}/{total_memories} 条记忆") 156 | 157 | # 3. 
执行最终数据库操作 158 | if memories_to_update: 159 | await self.faiss_manager.update_memories_metadata(memories_to_update) 160 | logger.info(f"更新了 {len(memories_to_update)} 条记忆的重要性得分。") 161 | 162 | if ids_to_delete: 163 | # 分批删除,避免一次删除太多 164 | delete_batch_size = 100 165 | for i in range(0, len(ids_to_delete), delete_batch_size): 166 | batch = ids_to_delete[i:i + delete_batch_size] 167 | await self.faiss_manager.delete_memories(batch) 168 | logger.debug(f"删除了 {len(batch)} 条记忆") 169 | 170 | logger.info(f"总共删除了 {len(ids_to_delete)} 条陈旧且不重要的记忆。") 171 | 172 | logger.info(f"记忆清理完成,处理了 {total_processed} 条记忆") 173 | -------------------------------------------------------------------------------- /core/engines/recall_engine.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | recall_engine.py - 回忆引擎 4 | 负责根据用户查询,使用多策略智能召回最相关的记忆。 5 | 支持密集向量检索、稀疏检索和混合检索。 6 | """ 7 | 8 | import json 9 | import math 10 | from typing import List, Dict, Any, Optional 11 | 12 | from astrbot.api import logger 13 | from astrbot.api.star import Context 14 | from ...storage.faiss_manager import FaissManager, Result 15 | from ..retrieval import SparseRetriever, ResultFusion, SearchResult 16 | from ..utils import get_now_datetime 17 | 18 | 19 | class RecallEngine: 20 | """ 21 | 回忆引擎:负责根据用户查询,使用多策略智能召回最相关的记忆。 22 | 支持密集向量检索、稀疏检索和混合检索。 23 | """ 24 | 25 | def __init__(self, config: Dict[str, Any], faiss_manager: FaissManager, sparse_retriever: Optional[SparseRetriever] = None): 26 | """ 27 | 初始化回忆引擎。 28 | 29 | Args: 30 | config (Dict[str, Any]): 插件配置中 'recall_engine' 部分的字典。 31 | faiss_manager (FaissManager): 数据库管理器实例。 32 | sparse_retriever (Optional[SparseRetriever]): 稀疏检索器实例。 33 | """ 34 | self.config = config 35 | self.faiss_manager = faiss_manager 36 | self.sparse_retriever = sparse_retriever 37 | 38 | # 初始化结果融合器 39 | fusion_config = config.get("fusion", {}) 40 | fusion_strategy = fusion_config.get("strategy", "rrf") 41 | self.result_fusion = ResultFusion(strategy=fusion_strategy, config=fusion_config) 42 | 43 | logger.info("RecallEngine 初始化成功。") 44 | 45 | async def recall( 46 | self, 47 | context: Context, 48 | query: str, 49 | session_id: Optional[str] = None, 50 | persona_id: Optional[str] = None, 51 | k: Optional[int] = None, 52 | ) -> List[Result]: 53 | """ 54 | 执行回忆流程,检索并可能重排记忆。 55 | 56 | Args: 57 | query (str): 用户查询文本。 58 | session_id (Optional[str], optional): 当前会话 ID. Defaults to None. 59 | persona_id (Optional[str], optional): 当前人格 ID. Defaults to None. 60 | k (Optional[int], optional): 希望返回的记忆数量,如果为 None 则从配置中读取. 
61 | 62 | Returns: 63 | List[Result]: 最终返回给上层应用的记忆列表。 64 | """ 65 | top_k = k if k is not None else self.config.get("top_k", 5) 66 | retrieval_mode = self.config.get("retrieval_mode", "hybrid") # hybrid, dense, sparse 67 | 68 | # 分析查询特征(用于自适应策略) 69 | query_info = self.result_fusion.analyze_query(query) 70 | logger.debug(f"Query analysis: {query_info}") 71 | 72 | # 根据检索模式执行搜索 73 | if retrieval_mode == "hybrid" and self.sparse_retriever: 74 | # 混合检索 75 | logger.debug("使用混合检索模式...") 76 | return await self._hybrid_search(context, query, session_id, persona_id, top_k, query_info) 77 | elif retrieval_mode == "sparse" and self.sparse_retriever: 78 | # 纯稀疏检索 79 | logger.debug("使用稀疏检索模式...") 80 | return await self._sparse_search(query, session_id, persona_id, top_k) 81 | else: 82 | # 纯密集检索(默认) 83 | logger.debug("使用密集检索模式...") 84 | return await self._dense_search(context, query, session_id, persona_id, top_k) 85 | 86 | async def _hybrid_search( 87 | self, 88 | context: Context, 89 | query: str, 90 | session_id: Optional[str], 91 | persona_id: Optional[str], 92 | k: int, 93 | query_info: Dict[str, Any] 94 | ) -> List[Result]: 95 | """执行混合检索""" 96 | # 并行执行密集和稀疏检索 97 | import asyncio 98 | 99 | # 密集检索 100 | dense_task = self.faiss_manager.search_memory( 101 | query=query, k=k*2, session_id=session_id, persona_id=persona_id 102 | ) 103 | 104 | # 稀疏检索 105 | sparse_task = self.sparse_retriever.search( 106 | query=query, limit=k*2, session_id=session_id, persona_id=persona_id 107 | ) 108 | 109 | # 等待两个检索完成 110 | dense_results, sparse_results = await asyncio.gather(dense_task, sparse_task, return_exceptions=True) 111 | 112 | # 处理异常 113 | if isinstance(dense_results, Exception): 114 | logger.error(f"密集检索失败: {dense_results}") 115 | dense_results = [] 116 | if isinstance(sparse_results, Exception): 117 | logger.error(f"稀疏检索失败: {sparse_results}") 118 | sparse_results = [] 119 | 120 | logger.debug(f"Dense results: {len(dense_results)}, Sparse results: {len(sparse_results)}") 121 | 122 | # 融合结果 123 | fused_results = self.result_fusion.fuse( 124 | dense_results=dense_results, 125 | sparse_results=sparse_results, 126 | k=k, 127 | query_info=query_info 128 | ) 129 | 130 | # 转换回 Result 格式 131 | final_results = [] 132 | for result in fused_results: 133 | final_results.append(Result( 134 | data={ 135 | "id": result.doc_id, 136 | "text": result.content, 137 | "metadata": result.metadata 138 | }, 139 | similarity=result.final_score 140 | )) 141 | 142 | # 应用传统的加权重排(如果需要) 143 | strategy = self.config.get("recall_strategy", "weighted") 144 | if strategy == "weighted": 145 | logger.debug("对混合检索结果应用加权重排...") 146 | final_results = self._rerank_by_weighted_score(context, final_results) 147 | 148 | return final_results 149 | 150 | async def _dense_search( 151 | self, 152 | context: Context, 153 | query: str, 154 | session_id: Optional[str], 155 | persona_id: Optional[str], 156 | k: int 157 | ) -> List[Result]: 158 | """执行密集检索""" 159 | results = await self.faiss_manager.search_memory( 160 | query=query, k=k, session_id=session_id, persona_id=persona_id 161 | ) 162 | 163 | if not results: 164 | return [] 165 | 166 | # 应用重排 167 | strategy = self.config.get("recall_strategy", "weighted") 168 | if strategy == "weighted": 169 | logger.debug("使用 'weighted' 策略进行重排...") 170 | return self._rerank_by_weighted_score(context, results) 171 | else: 172 | logger.debug("使用 'similarity' 策略,直接返回结果。") 173 | return results 174 | 175 | async def _sparse_search( 176 | self, 177 | query: str, 178 | session_id: Optional[str], 179 | persona_id: 
Optional[str], 180 | k: int 181 | ) -> List[Result]: 182 | """执行稀疏检索""" 183 | sparse_results = await self.sparse_retriever.search( 184 | query=query, limit=k, session_id=session_id, persona_id=persona_id 185 | ) 186 | 187 | # 转换为 Result 格式 188 | results = [] 189 | for result in sparse_results: 190 | results.append(Result( 191 | data={ 192 | "id": result.doc_id, 193 | "text": result.content, 194 | "metadata": result.metadata 195 | }, 196 | similarity=result.score 197 | )) 198 | 199 | return results 200 | 201 | def _rerank_by_weighted_score( 202 | self, context: Context, results: List[Result] 203 | ) -> List[Result]: 204 | """ 205 | 根据相似度、重要性和新近度对结果进行加权重排。 206 | """ 207 | sim_w = self.config.get("similarity_weight", 0.6) 208 | imp_w = self.config.get("importance_weight", 0.2) 209 | rec_w = self.config.get("recency_weight", 0.2) 210 | 211 | reranked_results = [] 212 | current_time = get_now_datetime(context).timestamp() 213 | 214 | for res in results: 215 | # 安全解析元数据 216 | metadata = res.data.get("metadata", {}) 217 | if isinstance(metadata, str): 218 | try: 219 | metadata = json.loads(metadata) 220 | except json.JSONDecodeError as e: 221 | logger.warning(f"解析记忆元数据失败: {e}") 222 | metadata = {} 223 | 224 | # 归一化各项得分 (0-1) 225 | similarity_score = res.similarity 226 | importance_score = metadata.get("importance", 0.0) 227 | 228 | # 计算新近度得分 229 | last_access = metadata.get("last_access_time", current_time) 230 | # 增加健壮性检查,以防 last_access 是字符串 231 | if isinstance(last_access, str): 232 | try: 233 | last_access = float(last_access) 234 | except (ValueError, TypeError): 235 | last_access = current_time 236 | 237 | hours_since_access = (current_time - last_access) / 3600 238 | # 使用指数衰减,半衰期约为24小时 239 | recency_score = math.exp(-0.028 * hours_since_access) 240 | 241 | # 计算最终加权分 242 | final_score = ( 243 | similarity_score * sim_w 244 | + importance_score * imp_w 245 | + recency_score * rec_w 246 | ) 247 | 248 | # 直接修改现有 Result 对象的 similarity 分数 249 | res.similarity = final_score 250 | reranked_results.append(res) 251 | 252 | # 按最终得分降序排序 253 | reranked_results.sort(key=lambda x: x.similarity, reverse=True) 254 | 255 | return reranked_results 256 | -------------------------------------------------------------------------------- /core/engines/reflection_engine.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | reflection_engine.py - 反思引擎 4 | 负责对会话历史进行反思,提取、评估并存储多个独立的、基于事件的记忆。 5 | """ 6 | 7 | import json 8 | from typing import List, Dict, Any, Optional 9 | 10 | from pydantic import ValidationError 11 | 12 | from astrbot.api import logger 13 | from astrbot.api.provider import Provider 14 | from ...storage.faiss_manager import FaissManager 15 | from ..utils import extract_json_from_response 16 | from ..models import ( 17 | MemoryEvent, 18 | _LLMExtractionEventList, 19 | _LLMScoreEvaluation, 20 | ) 21 | 22 | 23 | class ReflectionEngine: 24 | """ 25 | 反思引擎:负责对会话历史进行反思,提取、评估并存储多个独立的、基于事件的记忆。 26 | 采用两阶段流程:1. 批量提取事件 2. 
批量评估分数 27 | """ 28 | 29 | def __init__( 30 | self, 31 | config: Dict[str, Any], 32 | llm_provider: Provider, 33 | faiss_manager: FaissManager, 34 | ): 35 | self.config = config 36 | self.llm_provider = llm_provider 37 | self.faiss_manager = faiss_manager 38 | logger.info("ReflectionEngine 初始化成功。") 39 | 40 | async def _extract_events( 41 | self, history_text: str, persona_prompt: Optional[str] 42 | ) -> List[MemoryEvent]: 43 | """第一阶段:从对话历史中批量提取记忆事件。""" 44 | system_prompt = self._build_event_extraction_prompt() 45 | persona_section = ( 46 | f"\n**重要:**在分析时请代入以下人格,但是应该秉持着记录互动者的原则:\n{persona_prompt}\n" 47 | if persona_prompt 48 | else "" 49 | ) 50 | user_prompt = f"{persona_section}下面是你需要分析的对话历史:\n{history_text}" 51 | 52 | response = await self.llm_provider.text_chat( 53 | prompt=user_prompt, system_prompt=system_prompt, json_mode=True 54 | ) 55 | 56 | json_text = extract_json_from_response(response.completion_text.strip()) 57 | if not json_text: 58 | logger.warning("LLM 提取事件返回为空。") 59 | return [] 60 | logger.debug(f"提取到的记忆事件: {json_text}") 61 | 62 | try: 63 | extracted_data = _LLMExtractionEventList.model_validate_json(json_text) 64 | # 转换为 MemoryEvent 对象列表 65 | # 注意:LLM 返回的是 _LLMExtractionEvent,其 id 字段对应 MemoryEvent 的 temp_id 66 | memory_events = [] 67 | for event in extracted_data.events: 68 | event_dict = event.model_dump() 69 | # 将 'id' 字段重命名为 'temp_id' 以匹配 MemoryEvent 模型 70 | if "id" in event_dict: 71 | event_dict["temp_id"] = event_dict.pop("id") 72 | memory_events.append(MemoryEvent(**event_dict)) 73 | return memory_events 74 | except (ValidationError, json.JSONDecodeError) as e: 75 | logger.error( 76 | f"事件提取阶段JSON解析失败: {e}\n原始返回: {response.completion_text.strip()}", 77 | exc_info=True, 78 | ) 79 | return [] 80 | 81 | async def _evaluate_scores( 82 | self, events: List[MemoryEvent], persona_prompt: Optional[str] 83 | ) -> Dict[str, float]: 84 | """第二阶段:对一批记忆事件进行批量评分。""" 85 | if not events: 86 | return {} 87 | 88 | system_prompt = self._build_evaluation_prompt() 89 | 90 | # 构建批量评估的输入 91 | memories_to_evaluate = [ 92 | {"id": event.temp_id, "content": event.memory_content} for event in events 93 | ] 94 | persona_section = ( 95 | f"\n**重要:**在评估时请代入以下人格,这会影响你对“重要性”的判断:\n{persona_prompt}\n" 96 | if persona_prompt 97 | else "" 98 | ) 99 | user_prompt = persona_section + json.dumps( 100 | {"memories": memories_to_evaluate}, ensure_ascii=False, indent=2 101 | ) 102 | 103 | response = await self.llm_provider.text_chat( 104 | prompt=user_prompt, system_prompt=system_prompt, json_mode=True 105 | ) 106 | 107 | json_text = extract_json_from_response(response.completion_text.strip()) 108 | if not json_text: 109 | logger.warning("LLM 评估分数返回为空。") 110 | return {} 111 | logger.debug( 112 | f"评估分数: {json_text},对应内容{[event.temp_id for event in events]}。" 113 | ) 114 | 115 | try: 116 | evaluated_data = _LLMScoreEvaluation.model_validate_json(json_text) 117 | return evaluated_data.scores 118 | except (ValidationError, json.JSONDecodeError) as e: 119 | logger.error( 120 | f"分数评估阶段JSON解析失败: {e}\n原始返回: {response.completion_text.strip()}", 121 | exc_info=True, 122 | ) 123 | return {} 124 | 125 | async def reflect_and_store( 126 | self, 127 | conversation_history: List[Dict[str, str]], 128 | session_id: str, 129 | persona_id: Optional[str] = None, 130 | persona_prompt: Optional[str] = None, 131 | ): 132 | """执行完整的两阶段反思、评估和存储流程。""" 133 | try: 134 | history_text = self._format_history_for_summary(conversation_history) 135 | if not history_text: 136 | logger.debug("对话历史为空,跳过反思。") 137 | return 138 | 139 | # --- 
第一阶段:提取事件 --- 140 | logger.info(f"[{session_id}] 阶段1:开始批量提取记忆事件...") 141 | extracted_events = await self._extract_events(history_text, persona_prompt) 142 | if not extracted_events: 143 | logger.info(f"[{session_id}] 未能从对话中提取任何记忆事件。") 144 | return 145 | logger.info(f"[{session_id}] 成功提取 {len(extracted_events)} 个记忆事件。") 146 | 147 | # --- 第二阶段:评估分数 --- 148 | logger.info(f"[{session_id}] 阶段2:开始批量评估事件重要性...") 149 | scores = await self._evaluate_scores(extracted_events, persona_prompt) 150 | logger.info(f"[{session_id}] 成功收到 {len(scores)} 个评分。") 151 | 152 | # --- 第三阶段:合并与存储 --- 153 | threshold = self.config.get("importance_threshold", 0.5) 154 | logger.info(f"[{session_id}] 阶段3:开始存储筛选,重要性阈值: {threshold}") 155 | 156 | stored_count = 0 157 | filtered_count = 0 158 | total_events = len(extracted_events) 159 | 160 | # 详细记录所有事件的评分情况 161 | logger.info(f"[{session_id}] 评分详情汇总:") 162 | for event in extracted_events: 163 | score = scores.get(event.temp_id) 164 | if score is None: 165 | logger.warning( 166 | f"[{session_id}] ❌ 事件 '{event.temp_id}' 未找到对应的评分,跳过存储" 167 | ) 168 | filtered_count += 1 169 | continue 170 | 171 | event.importance_score = score 172 | logger.info(f"[{session_id}] 📊 事件 '{event.temp_id}': 得分={score:.3f}, 阈值={threshold:.3f}") 173 | 174 | if event.importance_score >= threshold: 175 | # MemoryEvent 的 id 将由存储后端自动生成,这里不需要手动创建 176 | # 我们只需要传递完整的元数据 177 | event_metadata = event.model_dump() 178 | 179 | # add_memory 返回的是新插入记录的整数 ID 180 | inserted_id = await self.faiss_manager.add_memory( 181 | content=event.memory_content, 182 | importance=event.importance_score, 183 | session_id=session_id, 184 | persona_id=persona_id, 185 | metadata=event_metadata, 186 | ) 187 | stored_count += 1 188 | logger.info( 189 | f"[{session_id}] ✅ 存储记忆事件 (数据库ID: {inserted_id}, 临时ID: {event.temp_id}), 得分: {event.importance_score:.3f} >= {threshold:.3f}" 190 | ) 191 | logger.debug(f"[{session_id}] 存储内容预览: {event.memory_content[:100]}...") 192 | else: 193 | filtered_count += 1 194 | logger.info( 195 | f"[{session_id}] ❌ 过滤记忆事件 '{event.temp_id}', 得分: {event.importance_score:.3f} < {threshold:.3f}" 196 | ) 197 | logger.debug(f"[{session_id}] 被过滤内容: {event.memory_content}") 198 | 199 | # 最终统计信息 200 | logger.info(f"[{session_id}] 🏁 反思存储完成统计:") 201 | logger.info(f"[{session_id}] - 总提取事件数: {total_events}") 202 | logger.info(f"[{session_id}] - 成功存储数量: {stored_count}") 203 | logger.info(f"[{session_id}] - 过滤丢弃数量: {filtered_count}") 204 | logger.info(f"[{session_id}] - 存储率: {(stored_count/total_events)*100:.1f}%" if total_events > 0 else f"[{session_id}] - 存储率: 0%") 205 | 206 | if stored_count > 0: 207 | logger.info(f"[{session_id}] ✅ 成功存储 {stored_count} 个新的记忆事件") 208 | else: 209 | logger.warning(f"[{session_id}] ⚠️ 没有记忆事件达到存储阈值 {threshold},可能需要调整配置") 210 | 211 | except Exception as e: 212 | logger.error( 213 | f"[{session_id}] 在执行反思与存储任务时发生严重错误: {e}", exc_info=True 214 | ) 215 | 216 | def _build_event_extraction_prompt(self) -> str: 217 | """构建用于第一阶段事件提取的系统 Prompt。""" 218 | schema = _LLMExtractionEventList.model_json_schema() 219 | base_prompt = self.config.get( 220 | "event_extraction_prompt", 221 | "### 角色\n你是一个善于分析和总结的AI助手。你的核心人设是从你自身的视角出发,记录与用户的互动和观察。\n\n### 指令/任务\n1. **仔细阅读**并理解下面提供的“对话历史”。\n2. 从**你(AI)的视角**出发,提取出多个独立的、有意义的记忆事件。事件必须准确描述,参考上下文。事件必须是完整的,具有前因后果的。**不允许编造事件**,**不允许改变事件**,**详细描述事件的所有信息**\n3. 
**核心要求**:\n * **第一人称视角**:所有事件都必须以“我”开头进行描述,例如“我告诉用户...”、“我观察到...”、“我被告知...”。\n * **使用具体名称**:直接使用对话中出现的人物昵称,**严禁**使用“用户”、“开发者”等通用词汇。\n * **记录互动者**:必须明确记录与你互动的用户名称。\n * **事件合并**:如果多条连续的对话构成一个完整的独立事件,应将其总结概括为一条记忆。\n4.**严禁**包含任何评分、额外的解释或说明性文字。\n 直接输出结果,不要有任何引言或总结。\n\n### 上下文\n* 在对话历史中,名为“AstrBot”的发言者就是**你自己**。\n* 记忆事件是:你与用户互动事的事件描述,详细记录谁、在何时、何地、做了什么、发生了什么。\n\n 'memory_content' 字段必须包含完整的事件描述,不能省略任何细节。\n\n单个系列事件必须详细记录在一个memory_content 中,形成完整的具有前因后果的事件记忆。\n\n", 222 | ).strip() 223 | 224 | return f"""{base_prompt} 225 | **核心指令** 226 | 1. **分析对话**: 从下面的对话历史中提取关键事件。 227 | 2. **格式化输出**: 必须返回一个符合以下 JSON Schema 的 JSON 对象。为每个事件生成一个临时的、唯一的 `temp_id` 字符串。 228 | 229 | **输出格式要求 (JSON Schema)** 230 | ```json 231 | {json.dumps(schema, indent=2)} 232 | ``` 233 | """ 234 | 235 | def _build_evaluation_prompt(self) -> str: 236 | """构建用于第二阶段批量评分的系统 Prompt。""" 237 | schema = _LLMScoreEvaluation.model_json_schema() 238 | base_prompt = self.config.get( 239 | "evaluation_prompt", 240 | "### 角色\n你是一个专门评估记忆价值的AI分析模型。你的判断标准是该记忆对于与特定用户构建长期、个性化、有上下文的对话有多大的帮助。\n\n### 指令/任务\n1. **评估核心价值**:仔细阅读“记忆内容”,评估其对于未来对话的长期参考价值。\n2. **输出分数**:给出一个介于 0.0 到 1.0 之间的浮点数分数。\n3. **格式要求**:**只返回数字**,严禁包含任何额外的文本、解释或理由。\n\n### 上下文\n评分时,请参考以下价值标尺:\n* **高价值 (0.8 - 1.0)**:包含用户的核心身份信息、明确且长期的个人偏好/厌恶、设定的目标、重要的关系或事实。这些信息几乎总能在未来的互动中被引用。\n * 例如:用户的昵称、职业、关键兴趣点、对AI的称呼、重要的人生目标。\n* **中等价值 (0.4 - 0.7)**:包含用户的具体建议、功能请求、对某事的观点或一次性的重要问题。这些信息在短期内或特定话题下很有用,但可能随着时间推移或问题解决而失去价值。\n * 例如:对某个功能的反馈、对特定新闻事件的看法、报告了一个具体的bug。\n* **低价值 (0.1 - 0.3)**:包含短暂的情绪表达、日常问候、或非常具体且不太可能重复的上下文。这些信息很少有再次利用的机会。\n * 例如:一次性的惊叹、害怕的反应、普通的“你好”、“晚安”。\n* **无价值 (0.0)**:信息完全是瞬时的、无关紧要的,或者不包含任何关于用户本人的可复用信息。\n * 例如:观察到另一个机器人说了话、对一句无法理解的话的默认回应。\n\n### 问题\n请评估以下“记忆内容”的重要性,对于未来的对话有多大的参考价值?\n\n---\n\n**记忆内容**:\n{memory_content}\n\n", 241 | ).strip() 242 | 243 | return f"""{base_prompt} 244 | **核心指令** 245 | 1. **分析输入**: 输入是一个包含多个记忆事件的 JSON 对象,每个事件都有一个 `temp_id` 和内容。 246 | 2. **评估重要性**: 对列表中的每一个事件,评估其对于未来对话的长期参考价值,给出一个 0.0 到 1.0 之间的分数。 247 | 3. 
**格式化输出**: 必须返回一个符合以下 JSON Schema 的 JSON 对象,key 是对应的 `temp_id`,value 是你给出的分数。 248 | 249 | **输出格式要求 (JSON Schema)** 250 | ```json 251 | {json.dumps(schema, indent=2)} 252 | ``` 253 | 254 | **一个正确的输出示例** 255 | ```json 256 | {{ 257 | "scores": {{ 258 | "event_1": 0.8, 259 | "user_preference_1": 0.9, 260 | "project_goal_alpha": 0.95 261 | }} 262 | }} 263 | ``` 264 | """ 265 | 266 | def _format_history_for_summary(self, history: List[Dict[str, str]]) -> str: 267 | """ 268 | 将对话历史列表格式化为单个字符串。 269 | 270 | Args: 271 | history (List[Dict[str, str]]): 对话历史。 272 | 273 | Returns: 274 | str: 格式化后的字符串。 275 | """ 276 | if not history: 277 | return "" 278 | 279 | # 过滤掉非 user 和 assistant 的角色 280 | filtered_history = [ 281 | msg for msg in history if msg.get("role") in ["user", "assistant"] 282 | ] 283 | 284 | return "\n".join( 285 | [f"{msg['role']}: {msg['content']}" for msg in filtered_history] 286 | ) 287 | -------------------------------------------------------------------------------- /core/handlers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | handlers - 业务逻辑处理器模块 4 | 提供插件命令的具体业务逻辑实现 5 | """ 6 | 7 | from .base_handler import BaseHandler 8 | from .memory_handler import MemoryHandler 9 | from .search_handler import SearchHandler 10 | from .admin_handler import AdminHandler 11 | from .fusion_handler import FusionHandler 12 | 13 | __all__ = [ 14 | 'BaseHandler', 15 | 'MemoryHandler', 16 | 'SearchHandler', 17 | 'AdminHandler', 18 | 'FusionHandler' 19 | ] -------------------------------------------------------------------------------- /core/handlers/admin_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | admin_handler.py - 管理员业务逻辑 4 | 处理状态查看、配置管理、遗忘代理等管理员功能 5 | """ 6 | 7 | from typing import Optional, Dict, Any 8 | 9 | from astrbot.api import logger 10 | from astrbot.api.star import Context 11 | 12 | from .base_handler import BaseHandler 13 | 14 | 15 | class AdminHandler(BaseHandler): 16 | """管理员业务逻辑处理器""" 17 | 18 | def __init__(self, context: Context, config: Dict[str, Any], faiss_manager=None, forgetting_agent=None, session_manager=None): 19 | super().__init__(context, config) 20 | self.faiss_manager = faiss_manager 21 | self.forgetting_agent = forgetting_agent 22 | self.session_manager = session_manager 23 | 24 | async def process(self, *args, **kwargs) -> Dict[str, Any]: 25 | """处理请求的抽象方法实现""" 26 | return self.create_response(True, "AdminHandler process method") 27 | 28 | async def get_memory_status(self) -> Dict[str, Any]: 29 | """获取记忆库状态""" 30 | if not self.faiss_manager or not self.faiss_manager.db: 31 | return self.create_response(False, "记忆库尚未初始化") 32 | 33 | try: 34 | count = await self.faiss_manager.db.count_documents() 35 | return self.create_response(True, "获取记忆库状态成功", {"total_count": count}) 36 | except Exception as e: 37 | logger.error(f"获取记忆库状态失败: {e}", exc_info=True) 38 | return self.create_response(False, f"获取记忆库状态失败: {e}") 39 | 40 | async def delete_memory(self, doc_id: int) -> Dict[str, Any]: 41 | """删除指定记忆""" 42 | if not self.faiss_manager: 43 | return self.create_response(False, "记忆库尚未初始化") 44 | 45 | try: 46 | await self.faiss_manager.delete_memories([doc_id]) 47 | return self.create_response(True, f"已成功删除 ID 为 {doc_id} 的记忆") 48 | except Exception as e: 49 | logger.error(f"删除记忆时发生错误: {e}", exc_info=True) 50 | return self.create_response(False, f"删除记忆时发生错误: {e}") 51 | 52 | async def run_forgetting_agent(self) -> 
Dict[str, Any]: 53 | """手动触发遗忘代理""" 54 | if not self.forgetting_agent: 55 | return self.create_response(False, "遗忘代理尚未初始化") 56 | 57 | try: 58 | await self.forgetting_agent._prune_memories() 59 | return self.create_response(True, "遗忘代理任务执行完毕") 60 | except Exception as e: 61 | logger.error(f"遗忘代理任务执行失败: {e}", exc_info=True) 62 | return self.create_response(False, f"遗忘代理任务执行失败: {e}") 63 | 64 | async def set_search_mode(self, mode: str) -> Dict[str, Any]: 65 | """设置检索模式""" 66 | valid_modes = ["hybrid", "dense", "sparse"] 67 | if mode not in valid_modes: 68 | return self.create_response(False, f"无效的模式,请使用: {', '.join(valid_modes)}") 69 | 70 | # 注意:这个方法需要 recall_engine 实例,暂时通过 config 传递 71 | # 实际使用时需要在调用此方法前传入 recall_engine 72 | return self.create_response(True, f"检索模式已设置为: {mode}") 73 | 74 | async def get_config_summary(self, action: str = "show") -> Dict[str, Any]: 75 | """获取配置摘要或验证配置""" 76 | if action == "show": 77 | try: 78 | # 显示主要配置项 79 | config_summary = { 80 | "session_manager": { 81 | "max_sessions": self.config.get("session_manager", {}).get("max_sessions", 1000), 82 | "session_ttl": self.config.get("session_manager", {}).get("session_ttl", 3600), 83 | "current_sessions": self.session_manager.get_session_count() if self.session_manager else 0 84 | }, 85 | "recall_engine": { 86 | "retrieval_mode": self.config.get("recall_engine", {}).get("retrieval_mode", "hybrid"), 87 | "top_k": self.config.get("recall_engine", {}).get("top_k", 5), 88 | "recall_strategy": self.config.get("recall_engine", {}).get("recall_strategy", "weighted") 89 | }, 90 | "reflection_engine": { 91 | "summary_trigger_rounds": self.config.get("reflection_engine", {}).get("summary_trigger_rounds", 10), 92 | "importance_threshold": self.config.get("reflection_engine", {}).get("importance_threshold", 0.5) 93 | }, 94 | "forgetting_agent": { 95 | "enabled": self.config.get("forgetting_agent", {}).get("enabled", True), 96 | "check_interval_hours": self.config.get("forgetting_agent", {}).get("check_interval_hours", 24), 97 | "retention_days": self.config.get("forgetting_agent", {}).get("retention_days", 90) 98 | } 99 | } 100 | 101 | return self.create_response(True, "获取配置摘要成功", config_summary) 102 | 103 | except Exception as e: 104 | return self.create_response(False, f"显示配置时发生错误: {e}") 105 | 106 | elif action == "validate": 107 | try: 108 | from ..config_validator import validate_config 109 | # 重新验证当前配置 110 | validate_config(self.config) 111 | return self.create_response(True, "配置验证通过,所有参数均有效") 112 | 113 | except Exception as e: 114 | return self.create_response(False, f"配置验证失败: {e}") 115 | 116 | else: 117 | return self.create_response(False, "无效的动作,请使用 'show' 或 'validate'") 118 | 119 | def format_status_for_display(self, response: Dict[str, Any]) -> str: 120 | """格式化状态信息用于显示""" 121 | if not response.get("success"): 122 | return response.get("message", "获取失败") 123 | 124 | data = response.get("data", {}) 125 | total_count = data.get("total_count", 0) 126 | 127 | return f"📊 LivingMemory 记忆库状态:\n- 总记忆数: {total_count}" 128 | 129 | def format_config_summary_for_display(self, response: Dict[str, Any]) -> str: 130 | """格式化配置摘要用于显示""" 131 | if not response.get("success"): 132 | return response.get("message", "获取失败") 133 | 134 | data = response.get("data", {}) 135 | 136 | config_summary = ["📋 LivingMemory 配置摘要:"] 137 | config_summary.append("") 138 | 139 | # 会话管理器配置 140 | sm_config = data.get("session_manager", {}) 141 | config_summary.append(f"🗂️ 会话管理:") 142 | config_summary.append(f" - 最大会话数: {sm_config.get('max_sessions', 1000)}") 
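# 注意:以下各 .get(...) 的缺省值需与 get_config_summary 中保持一致,
# 否则展示内容可能与实际生效配置不符。渲染效果示例(假设均为默认值):
#   - 最大会话数: 1000
#   - 会话TTL: 3600秒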
143 | config_summary.append(f" - 会话TTL: {sm_config.get('session_ttl', 3600)}秒") 144 | config_summary.append(f" - 当前会话数: {sm_config.get('current_sessions', 0)}") 145 | config_summary.append("") 146 | 147 | # 回忆引擎配置 148 | re_config = data.get("recall_engine", {}) 149 | config_summary.append(f"🧠 回忆引擎:") 150 | config_summary.append(f" - 检索模式: {re_config.get('retrieval_mode', 'hybrid')}") 151 | config_summary.append(f" - 返回数量: {re_config.get('top_k', 5)}") 152 | config_summary.append(f" - 召回策略: {re_config.get('recall_strategy', 'weighted')}") 153 | config_summary.append("") 154 | 155 | # 反思引擎配置 156 | rf_config = data.get("reflection_engine", {}) 157 | config_summary.append(f"💭 反思引擎:") 158 | config_summary.append(f" - 触发轮次: {rf_config.get('summary_trigger_rounds', 10)}") 159 | config_summary.append(f" - 重要性阈值: {rf_config.get('importance_threshold', 0.5)}") 160 | config_summary.append("") 161 | 162 | # 遗忘代理配置 163 | fa_config = data.get("forgetting_agent", {}) 164 | config_summary.append(f"🗑️ 遗忘代理:") 165 | config_summary.append(f" - 启用状态: {'是' if fa_config.get('enabled', True) else '否'}") 166 | config_summary.append(f" - 检查间隔: {fa_config.get('check_interval_hours', 24)}小时") 167 | config_summary.append(f" - 保留天数: {fa_config.get('retention_days', 90)}天") 168 | 169 | return "\n".join(config_summary) -------------------------------------------------------------------------------- /core/handlers/base_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | base_handler.py - 基础处理器类 4 | 提供业务逻辑处理的基础功能和通用方法 5 | """ 6 | 7 | from abc import ABC, abstractmethod 8 | from typing import Optional, Dict, Any, List 9 | from datetime import datetime, timezone 10 | import json 11 | 12 | from astrbot.api import logger 13 | from astrbot.api.star import Context 14 | 15 | 16 | class BaseHandler(ABC): 17 | """基础处理器类,提供通用的业务逻辑功能""" 18 | 19 | def __init__(self, context: Context, config: Dict[str, Any]): 20 | self.context = context 21 | self.config = config 22 | 23 | @abstractmethod 24 | async def process(self, *args, **kwargs) -> Dict[str, Any]: 25 | """处理请求的抽象方法""" 26 | pass 27 | 28 | def get_timezone(self) -> Any: 29 | """获取当前时区""" 30 | tz_config = self.config.get("timezone_settings", {}) 31 | tz_str = tz_config.get("timezone", "Asia/Shanghai") 32 | try: 33 | import pytz 34 | return pytz.timezone(tz_str) 35 | except ImportError: 36 | # 如果pytz不可用,返回UTC 37 | from datetime import timezone 38 | return timezone.utc 39 | 40 | def format_timestamp(self, ts: Optional[float]) -> str: 41 | """格式化时间戳""" 42 | if not ts: 43 | return "未知" 44 | try: 45 | dt_utc = datetime.fromtimestamp(float(ts), tz=timezone.utc) 46 | dt_local = dt_utc.astimezone(self.get_timezone()) 47 | return dt_local.strftime("%Y-%m-%d %H:%M:%S") 48 | except (ValueError, TypeError): 49 | return "未知" 50 | 51 | def safe_parse_metadata(self, metadata: Any) -> Dict[str, Any]: 52 | """安全解析元数据""" 53 | if isinstance(metadata, dict): 54 | return metadata 55 | if isinstance(metadata, str): 56 | try: 57 | return json.loads(metadata) 58 | except json.JSONDecodeError: 59 | return {} 60 | return {} 61 | 62 | def format_memory_card(self, result: Any) -> str: 63 | """格式化记忆卡片显示""" 64 | metadata = self.safe_parse_metadata(result.data.get("metadata", {})) 65 | 66 | create_time_str = self.format_timestamp(metadata.get("create_time")) 67 | last_access_time_str = self.format_timestamp(metadata.get("last_access_time")) 68 | importance_score = metadata.get("importance", 0.0) 69 | event_type = 
metadata.get("event_type", "未知") 70 | 71 | card = ( 72 | f"ID: {result.data['id']}\n" 73 | f"记 忆 度: {result.similarity:.2f}\n" 74 | f"重 要 性: {importance_score:.2f}\n" 75 | f"记忆类型: {event_type}\n\n" 76 | f"内容: {result.data['text']}\n\n" 77 | f"创建于: {create_time_str}\n" 78 | f"最后访问: {last_access_time_str}" 79 | ) 80 | return card 81 | 82 | def create_response(self, success: bool = True, message: str = "", data: Any = None) -> Dict[str, Any]: 83 | """创建标准响应格式""" 84 | return { 85 | "success": success, 86 | "message": message, 87 | "data": data 88 | } 89 | 90 | 91 | class TestableBaseHandler(BaseHandler): 92 | """用于测试的基础处理器实现""" 93 | 94 | async def process(self, *args, **kwargs) -> Dict[str, Any]: 95 | """测试用的process方法实现""" 96 | return self.create_response(True, "Test response") -------------------------------------------------------------------------------- /core/handlers/fusion_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | fusion_handler.py - 融合策略业务逻辑 4 | 处理检索融合策略的管理和测试 5 | """ 6 | 7 | from typing import Optional, Dict, Any, List 8 | 9 | from astrbot.api import logger 10 | from astrbot.api.star import Context 11 | 12 | from .base_handler import BaseHandler 13 | 14 | 15 | class FusionHandler(BaseHandler): 16 | """融合策略业务逻辑处理器""" 17 | 18 | def __init__(self, context: Context, config: Dict[str, Any], recall_engine=None): 19 | super().__init__(context, config) 20 | self.recall_engine = recall_engine 21 | 22 | async def process(self, *args, **kwargs) -> Dict[str, Any]: 23 | """处理请求的抽象方法实现""" 24 | return self.create_response(True, "FusionHandler process method") 25 | 26 | async def manage_fusion_strategy(self, strategy: str = "show", param: str = "") -> Dict[str, Any]: 27 | """管理检索融合策略""" 28 | if not self.recall_engine: 29 | return self.create_response(False, "回忆引擎尚未初始化") 30 | 31 | if strategy == "show": 32 | # 显示当前融合配置 33 | fusion_config = self.config.get("fusion", {}) 34 | current_strategy = fusion_config.get("strategy", "rrf") 35 | 36 | config_data = { 37 | "current_strategy": current_strategy, 38 | "fusion_config": fusion_config 39 | } 40 | 41 | return self.create_response(True, "获取融合配置成功", config_data) 42 | 43 | elif strategy in ["rrf", "hybrid_rrf", "weighted", "convex", "interleave", 44 | "rank_fusion", "score_fusion", "cascade", "adaptive"]: 45 | 46 | # 更新融合策略 47 | if "fusion" not in self.config: 48 | self.config["fusion"] = {} 49 | 50 | old_strategy = self.config["fusion"].get("strategy", "rrf") 51 | self.config["fusion"]["strategy"] = strategy 52 | 53 | # 处理参数 54 | param_result = await self._process_fusion_param(param, strategy) 55 | if not param_result["success"]: 56 | return param_result 57 | 58 | # 更新 RecallEngine 中的融合配置 59 | update_result = await self._update_recall_engine_fusion_config(strategy, self.config["fusion"]) 60 | if not update_result["success"]: 61 | return update_result 62 | 63 | return self.create_response(True, f"融合策略已从 '{old_strategy}' 更新为 '{strategy}'{f' (参数: {param})' if param else ''}") 64 | 65 | else: 66 | return self.create_response(False, "不支持的融合策略。使用 show 查看可用选项。") 67 | 68 | async def test_fusion_strategy(self, query: str, k: int = 5) -> Dict[str, Any]: 69 | """测试融合策略效果""" 70 | if not self.recall_engine: 71 | return self.create_response(False, "回忆引擎尚未初始化") 72 | 73 | try: 74 | # 执行搜索 75 | session_id = await self.context.conversation_manager.get_curr_conversation_id(None) 76 | from ..utils import get_persona_id 77 | persona_id = await get_persona_id(self.context, None) 78 | 79 | 
results = await self.recall_engine.recall( 80 | self.context, query, session_id, persona_id, k 81 | ) 82 | 83 | if not results: 84 | return self.create_response(True, "未找到相关记忆", []) 85 | 86 | # 格式化结果 87 | formatted_results = [] 88 | fusion_config = self.config.get("fusion", {}) 89 | current_strategy = fusion_config.get("strategy", "rrf") 90 | 91 | for result in results: 92 | metadata = self.safe_parse_metadata(result.data.get("metadata", {})) 93 | formatted_results.append({ 94 | "id": result.data['id'], 95 | "similarity": result.similarity, 96 | "text": result.data['text'], 97 | "importance": metadata.get("importance", 0.0), 98 | "event_type": metadata.get("event_type", "未知") 99 | }) 100 | 101 | test_data = { 102 | "query": query, 103 | "strategy": current_strategy, 104 | "fusion_config": fusion_config, 105 | "results": formatted_results 106 | } 107 | 108 | return self.create_response(True, f"融合测试完成,找到 {len(results)} 条结果", test_data) 109 | 110 | except Exception as e: 111 | logger.error(f"融合策略测试失败: {e}", exc_info=True) 112 | return self.create_response(False, f"测试失败: {e}") 113 | 114 | async def _process_fusion_param(self, param: str, strategy: str) -> Dict[str, Any]: 115 | """处理融合策略参数""" 116 | if not param or "=" not in param: 117 | return self.create_response(True, "无参数需要处理") 118 | 119 | try: 120 | key, value = param.split("=", 1) 121 | key = key.strip() 122 | value = value.strip() 123 | 124 | # 验证参数名 125 | valid_params = { 126 | "dense_weight", "sparse_weight", "rrf_k", "convex_lambda", 127 | "interleave_ratio", "rank_bias_factor", "diversity_bonus" 128 | } 129 | 130 | if key not in valid_params: 131 | return self.create_response(False, f"无效的参数名: {key}。支持的参数: {', '.join(sorted(valid_params))}") 132 | 133 | # 验证参数值 134 | try: 135 | if key in ["dense_weight", "sparse_weight", "convex_lambda", "interleave_ratio", "rank_bias_factor", "diversity_bonus"]: 136 | param_value = float(value) 137 | else: 138 | param_value = int(value) 139 | except ValueError: 140 | return self.create_response(False, f"参数 {key} 的值类型无效: {value}") 141 | 142 | # 参数范围和约束检查 143 | param_constraints = { 144 | "dense_weight": (0.0, 1.0, "必须在 0.0-1.0 范围内"), 145 | "sparse_weight": (0.0, 1.0, "必须在 0.0-1.0 范围内"), 146 | "convex_lambda": (0.0, 1.0, "必须在 0.0-1.0 范围内"), 147 | "interleave_ratio": (0.0, 1.0, "必须在 0.0-1.0 范围内"), 148 | "rank_bias_factor": (0.0, 1.0, "必须在 0.0-1.0 范围内"), 149 | "diversity_bonus": (0.0, 1.0, "必须在 0.0-1.0 范围内"), 150 | "rrf_k": (1, 1000, "必须是正整数") 151 | } 152 | 153 | if key in param_constraints: 154 | min_val, max_val, error_msg = param_constraints[key] 155 | if not min_val <= param_value <= max_val: 156 | return self.create_response(False, f"参数 {key} {error_msg}") 157 | 158 | # 策略特定参数验证 159 | strategy_params = { 160 | "rrf": ["rrf_k"], 161 | "hybrid_rrf": ["rrf_k", "diversity_bonus"], 162 | "weighted": ["dense_weight", "sparse_weight"], 163 | "convex": ["dense_weight", "sparse_weight", "convex_lambda"], 164 | "interleave": ["interleave_ratio"], 165 | "rank_fusion": ["dense_weight", "sparse_weight", "rank_bias_factor"], 166 | "score_fusion": ["dense_weight", "sparse_weight"], 167 | "cascade": ["dense_weight", "sparse_weight"], 168 | "adaptive": ["dense_weight", "sparse_weight"] 169 | } 170 | 171 | if strategy in strategy_params and key not in strategy_params[strategy]: 172 | return self.create_response(False, f"参数 {key} 不适用于策略 {strategy}") 173 | 174 | # 权重和检查(对于需要权重的策略) 175 | if key in ["dense_weight", "sparse_weight"]: 176 | other_key = "sparse_weight" if key == "dense_weight" else "dense_weight" 177 | 
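# 对侧权重未显式配置时,按默认值(dense_weight=0.7 / sparse_weight=0.3)参与下方的权重和校验。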
other_value = self.config["fusion"].get(other_key, 0.3 if other_key == "sparse_weight" else 0.7) 178 | 179 | # 如果设置了新的权重,且该策略同时支持两种权重,检查权重和是否超过1.0 180 | if {key, other_key} <= set(strategy_params.get(strategy, [])): 181 | total_weight = param_value + other_value 182 | if total_weight > 1.0: 183 | return self.create_response(False, f"权重总和不能超过 1.0 (当前总和: {total_weight:.2f})") 184 | 185 | self.config["fusion"][key] = param_value 186 | logger.info(f"更新融合参数 {key} = {param_value}") 187 | 188 | return self.create_response(True, "参数处理成功") 189 | 190 | except Exception as e: 191 | return self.create_response(False, f"参数解析错误: {e}") 192 | 193 | async def _update_recall_engine_fusion_config(self, strategy: str, fusion_config: Dict[str, Any]) -> Dict[str, Any]: 194 | """更新RecallEngine的融合配置""" 195 | try: 196 | if hasattr(self.recall_engine, 'result_fusion'): 197 | self.recall_engine.update_fusion_config(strategy, fusion_config) 198 | else: 199 | logger.warning("RecallEngine 没有 result_fusion 属性,跳过更新") 200 | 201 | return self.create_response(True, "融合配置更新成功") 202 | except AttributeError: 203 | # 如果 RecallEngine 没有 update_fusion_config 方法,则直接更新属性 204 | try: 205 | if hasattr(self.recall_engine, 'result_fusion'): 206 | fusion_obj = self.recall_engine.result_fusion 207 | fusion_obj.strategy = strategy 208 | fusion_obj.config = fusion_config 209 | 210 | # 更新融合器的参数 211 | fusion_obj.dense_weight = fusion_config.get("dense_weight", 0.7) 212 | fusion_obj.sparse_weight = fusion_config.get("sparse_weight", 0.3) 213 | fusion_obj.rrf_k = fusion_config.get("rrf_k", 60) 214 | fusion_obj.convex_lambda = fusion_config.get("convex_lambda", 0.5) 215 | fusion_obj.interleave_ratio = fusion_config.get("interleave_ratio", 0.5) 216 | fusion_obj.rank_bias_factor = fusion_config.get("rank_bias_factor", 0.1) 217 | 218 | return self.create_response(True, "融合配置更新成功") 219 | except Exception as e: 220 | logger.error(f"更新融合配置时出错: {e}") 221 | return self.create_response(False, f"配置已更新,但引擎同步可能失败: {e}") 222 | except Exception as e: 223 | logger.error(f"更新融合配置时出错: {e}") 224 | return self.create_response(False, f"更新融合配置失败: {e}") 225 | 226 | def format_fusion_config_for_display(self, response: Dict[str, Any]) -> str: 227 | """格式化融合配置用于显示""" 228 | if not response.get("success"): 229 | return response.get("message", "获取失败") 230 | 231 | data = response.get("data", {}) 232 | current_strategy = data.get("current_strategy", "rrf") 233 | fusion_config = data.get("fusion_config", {}) 234 | 235 | response_parts = ["🔄 当前检索融合配置:"] 236 | response_parts.append(f"策略: {current_strategy}") 237 | response_parts.append("") 238 | 239 | if current_strategy in ["rrf", "hybrid_rrf"]: 240 | response_parts.append(f"RRF参数k: {fusion_config.get('rrf_k', 60)}") 241 | if current_strategy == "hybrid_rrf": 242 | response_parts.append(f"多样性奖励: {fusion_config.get('diversity_bonus', 0.1)}") 243 | 244 | if current_strategy in ["weighted", "convex", "rank_fusion", "score_fusion"]: 245 | response_parts.append(f"密集权重: {fusion_config.get('dense_weight', 0.7)}") 246 | response_parts.append(f"稀疏权重: {fusion_config.get('sparse_weight', 0.3)}") 247 | 248 | if current_strategy == "convex": 249 | response_parts.append(f"凸组合λ: {fusion_config.get('convex_lambda', 0.5)}") 250 | 251 | if current_strategy == "interleave": 252 | response_parts.append(f"交替比例: {fusion_config.get('interleave_ratio', 0.5)}") 253 | 254 | if current_strategy == "rank_fusion": 255 | response_parts.append(f"排序偏置: {fusion_config.get('rank_bias_factor', 0.1)}") 256 |
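# 注:cascade / adaptive 同样接受 dense/sparse_weight(见 _process_fusion_param 的
# strategy_params),但当前的展示逻辑不为这两种策略输出权重行。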
response_parts.append("") 258 | response_parts.append("💡 各策略特点:") 259 | response_parts.append("• rrf: 经典方法,平衡性好") 260 | response_parts.append("• hybrid_rrf: 动态调整,适应查询类型") 261 | response_parts.append("• weighted: 简单加权,可解释性强") 262 | response_parts.append("• convex: 凸组合,数学严格") 263 | response_parts.append("• interleave: 交替选择,保证多样性") 264 | response_parts.append("• rank_fusion: 基于排序位置") 265 | response_parts.append("• score_fusion: Borda Count投票") 266 | response_parts.append("• cascade: 稀疏初筛+密集精排") 267 | response_parts.append("• adaptive: 根据查询自适应") 268 | 269 | return "\n".join(response_parts) 270 | 271 | def format_fusion_test_for_display(self, response: Dict[str, Any]) -> str: 272 | """格式化融合测试结果用于显示""" 273 | if not response.get("success"): 274 | return response.get("message", "测试失败") 275 | 276 | data = response.get("data", {}) 277 | query = data.get("query", "") 278 | strategy = data.get("strategy", "rrf") 279 | fusion_config = data.get("fusion_config", {}) 280 | results = data.get("results", []) 281 | 282 | response_parts = [f"🎯 融合测试结果 (策略: {strategy})"] 283 | response_parts.append("=" * 50) 284 | 285 | for i, result in enumerate(results, 1): 286 | response_parts.append(f"\n{i}. [ID: {result['id']}] 分数: {result['similarity']:.4f}") 287 | response_parts.append(f" 重要性: {result['importance']:.3f} | 类型: {result['event_type']}") 288 | response_parts.append(f" 内容: {result['text'][:100]}{'...' if len(result['text']) > 100 else ''}") 289 | 290 | response_parts.append("\n" + "=" * 50) 291 | response_parts.append(f"💡 当前融合配置:") 292 | response_parts.append(f" 策略: {strategy}") 293 | if strategy in ["rrf", "hybrid_rrf"]: 294 | response_parts.append(f" RRF-k: {fusion_config.get('rrf_k', 60)}") 295 | if strategy in ["weighted", "convex"]: 296 | response_parts.append(f" 密集权重: {fusion_config.get('dense_weight', 0.7)}") 297 | response_parts.append(f" 稀疏权重: {fusion_config.get('sparse_weight', 0.3)}") 298 | 299 | return "\n".join(response_parts) -------------------------------------------------------------------------------- /core/handlers/memory_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | memory_handler.py - 记忆管理业务逻辑 4 | 处理记忆的编辑、更新、历史查看等业务逻辑 5 | """ 6 | 7 | import json 8 | from typing import Optional, Dict, Any, List 9 | from datetime import datetime, timezone 10 | 11 | from astrbot.api import logger 12 | from astrbot.api.star import Context 13 | 14 | from .base_handler import BaseHandler 15 | 16 | 17 | class MemoryHandler(BaseHandler): 18 | """记忆管理业务逻辑处理器""" 19 | 20 | def __init__(self, context: Context, config: Dict[str, Any], faiss_manager): 21 | super().__init__(context, config) 22 | self.faiss_manager = faiss_manager 23 | 24 | async def process(self, *args, **kwargs) -> Dict[str, Any]: 25 | """处理请求的抽象方法实现""" 26 | return self.create_response(True, "MemoryHandler process method") 27 | 28 | async def edit_memory(self, memory_id: str, field: str, value: str, reason: str = "") -> Dict[str, Any]: 29 | """编辑记忆内容或元数据""" 30 | if not self.faiss_manager: 31 | return self.create_response(False, "记忆库尚未初始化") 32 | 33 | try: 34 | # 解析 memory_id 为整数或字符串 35 | try: 36 | memory_id_int = int(memory_id) 37 | memory_id_to_use = memory_id_int 38 | except ValueError: 39 | memory_id_to_use = memory_id 40 | 41 | # 解析字段和值 42 | updates = {} 43 | 44 | if field == "content": 45 | updates["content"] = value 46 | elif field == "importance": 47 | try: 48 | updates["importance"] = float(value) 49 | if not 0.0 <= updates["importance"] <= 1.0: 50 | return 
self.create_response(False, "重要性评分必须在 0.0 到 1.0 之间") 51 | except ValueError: 52 | return self.create_response(False, "重要性评分必须是数字") 53 | elif field == "type": 54 | valid_types = ["FACT", "PREFERENCE", "GOAL", "OPINION", "RELATIONSHIP", "OTHER"] 55 | if value not in valid_types: 56 | return self.create_response(False, f"无效的事件类型,必须是: {', '.join(valid_types)}") 57 | updates["event_type"] = value 58 | elif field == "status": 59 | valid_statuses = ["active", "archived", "deleted"] 60 | if value not in valid_statuses: 61 | return self.create_response(False, f"无效的状态,必须是: {', '.join(valid_statuses)}") 62 | updates["status"] = value 63 | else: 64 | return self.create_response(False, f"未知的字段 '{field}',支持的字段: content, importance, type, status") 65 | 66 | # 执行更新 67 | result = await self.faiss_manager.update_memory( 68 | memory_id=memory_id_to_use, 69 | update_reason=reason or f"更新{field}", 70 | **updates 71 | ) 72 | 73 | if result["success"]: 74 | # 构建响应消息 75 | response_parts = [f"✅ {result['message']}"] 76 | 77 | if result["updated_fields"]: 78 | response_parts.append("\n📋 已更新的字段:") 79 | for f in result["updated_fields"]: 80 | response_parts.append(f" - {f}") 81 | 82 | # 如果更新了内容,显示预览 83 | if "content" in updates and len(updates["content"]) > 100: 84 | response_parts.append(f"\n📝 内容预览: {updates['content'][:100]}...") 85 | 86 | return self.create_response(True, "\n".join(response_parts), result) 87 | else: 88 | return self.create_response(False, result['message']) 89 | 90 | except Exception as e: 91 | logger.error(f"编辑记忆时发生错误: {e}", exc_info=True) 92 | return self.create_response(False, f"编辑记忆时发生错误: {e}") 93 | 94 | async def get_memory_details(self, memory_id: str) -> Dict[str, Any]: 95 | """获取记忆详细信息""" 96 | if not self.faiss_manager: 97 | return self.create_response(False, "记忆库尚未初始化") 98 | 99 | try: 100 | # 解析 memory_id 101 | try: 102 | memory_id_int = int(memory_id) 103 | docs = await self.faiss_manager.db.document_storage.get_documents(ids=[memory_id_int]) 104 | except ValueError: 105 | docs = await self.faiss_manager.db.document_storage.get_documents( 106 | metadata_filters={"memory_id": memory_id} 107 | ) 108 | 109 | if not docs: 110 | return self.create_response(False, f"未找到ID为 {memory_id} 的记忆") 111 | 112 | doc = docs[0] 113 | metadata = self.safe_parse_metadata(doc["metadata"]) 114 | 115 | # 构建详细信息 116 | details = { 117 | "id": memory_id, 118 | "content": doc["content"], 119 | "metadata": metadata, 120 | "create_time": self.format_timestamp(metadata.get("create_time")), 121 | "last_access_time": self.format_timestamp(metadata.get("last_access_time")), 122 | "importance": metadata.get("importance", "N/A"), 123 | "event_type": metadata.get("event_type", "N/A"), 124 | "status": metadata.get("status", "active"), 125 | "update_history": metadata.get("update_history", []) 126 | } 127 | 128 | return self.create_response(True, "获取记忆详细信息成功", details) 129 | 130 | except Exception as e: 131 | logger.error(f"获取记忆详细信息时发生错误: {e}", exc_info=True) 132 | return self.create_response(False, f"获取记忆详细信息时发生错误: {e}") 133 | 134 | async def get_memory_history(self, memory_id: str) -> Dict[str, Any]: 135 | """获取记忆更新历史""" 136 | if not self.faiss_manager or not self.faiss_manager.db: 137 | return self.create_response(False, "记忆库尚未初始化") 138 | 139 | try: 140 | # 解析 memory_id 141 | try: 142 | memory_id_int = int(memory_id) 143 | docs = await self.faiss_manager.db.document_storage.get_documents(ids=[memory_id_int]) 144 | except ValueError: 145 | docs = await self.faiss_manager.db.document_storage.get_documents( 146 | 
metadata_filters={"memory_id": memory_id} 147 | ) 148 | 149 | if not docs: 150 | return self.create_response(False, f"未找到ID为 {memory_id} 的记忆") 151 | 152 | doc = docs[0] 153 | metadata = self.safe_parse_metadata(doc["metadata"]) 154 | 155 | # 构建历史信息 156 | history_info = { 157 | "id": memory_id, 158 | "content": doc["content"], 159 | "metadata": { 160 | "importance": metadata.get("importance", "N/A"), 161 | "event_type": metadata.get("event_type", "N/A"), 162 | "status": metadata.get("status", "active"), 163 | "create_time": self.format_timestamp(metadata.get("create_time")) 164 | }, 165 | "update_history": metadata.get("update_history", []) 166 | } 167 | 168 | return self.create_response(True, "获取记忆历史成功", history_info) 169 | 170 | except Exception as e: 171 | logger.error(f"获取记忆历史时发生错误: {e}", exc_info=True) 172 | return self.create_response(False, f"获取记忆历史时发生错误: {e}") 173 | 174 | def format_memory_details_for_display(self, details: Dict[str, Any]) -> str: 175 | """格式化记忆详细信息用于显示""" 176 | if not details.get("success"): 177 | return details.get("message", "获取失败") 178 | 179 | data = details.get("data", {}) 180 | response_parts = [f"📝 记忆 {data['id']} 的详细信息:"] 181 | response_parts.append("=" * 50) 182 | 183 | # 内容 184 | response_parts.append(f"\n📄 内容:") 185 | response_parts.append(f"{data['content']}") 186 | 187 | # 基本信息 188 | response_parts.append(f"\n📊 基本信息:") 189 | response_parts.append(f"- ID: {data['id']}") 190 | response_parts.append(f"- 重要性: {data['importance']}") 191 | response_parts.append(f"- 类型: {data['event_type']}") 192 | response_parts.append(f"- 状态: {data['status']}") 193 | 194 | # 时间信息 195 | if data['create_time'] != "未知": 196 | response_parts.append(f"- 创建时间: {data['create_time']}") 197 | if data['last_access_time'] != "未知": 198 | response_parts.append(f"- 最后访问: {data['last_access_time']}") 199 | 200 | # 更新历史 201 | update_history = data.get('update_history', []) 202 | if update_history: 203 | response_parts.append(f"\n🔄 更新历史 ({len(update_history)} 次):") 204 | for i, update in enumerate(update_history[-3:], 1): # 只显示最近3次 205 | timestamp = update.get('timestamp') 206 | if timestamp: 207 | time_str = self.format_timestamp(timestamp) 208 | else: 209 | time_str = "未知" 210 | 211 | response_parts.append(f"\n{i}. 
{time_str}") 212 | response_parts.append(f" 原因: {update.get('reason', 'N/A')}") 213 | response_parts.append(f" 字段: {', '.join(update.get('fields', []))}") 214 | 215 | # 编辑指引 216 | response_parts.append(f"\n" + "=" * 50) 217 | response_parts.append(f"\n🛠️ 编辑指引:") 218 | response_parts.append(f"使用以下命令编辑此记忆:") 219 | response_parts.append(f"\n• 编辑内容:") 220 | response_parts.append(f" /lmem edit {data['id']} content <新内容> [原因]") 221 | response_parts.append(f"\n• 编辑重要性:") 222 | response_parts.append(f" /lmem edit {data['id']} importance <0.0-1.0> [原因]") 223 | response_parts.append(f"\n• 编辑类型:") 224 | response_parts.append(f" /lmem edit {data['id']} type [原因]") 225 | response_parts.append(f"\n• 编辑状态:") 226 | response_parts.append(f" /lmem edit {data['id']} status [原因]") 227 | 228 | # 示例 229 | response_parts.append(f"\n💡 示例:") 230 | response_parts.append(f" /lmem edit {data['id']} importance 0.9 提高重要性评分") 231 | response_parts.append(f" /lmem edit {data['id']} type PREFERENCE 重新分类为偏好") 232 | 233 | return "\n".join(response_parts) 234 | 235 | def format_memory_history_for_display(self, history: Dict[str, Any]) -> str: 236 | """格式化记忆历史用于显示""" 237 | if not history.get("success"): 238 | return history.get("message", "获取失败") 239 | 240 | data = history.get("data", {}) 241 | metadata = data.get("metadata", {}) 242 | 243 | response_parts = [f"📝 记忆 {data['id']} 的详细信息:"] 244 | response_parts.append(f"\n内容: {data['content']}") 245 | 246 | # 基本信息 247 | response_parts.append(f"\n📊 基本信息:") 248 | response_parts.append(f"- 重要性: {metadata['importance']}") 249 | response_parts.append(f"- 类型: {metadata['event_type']}") 250 | response_parts.append(f"- 状态: {metadata['status']}") 251 | 252 | # 时间信息 253 | if metadata.get('create_time') != "未知": 254 | response_parts.append(f"- 创建时间: {metadata['create_time']}") 255 | 256 | # 更新历史 257 | update_history = data.get('update_history', []) 258 | if update_history: 259 | response_parts.append(f"\n🔄 更新历史 ({len(update_history)} 次):") 260 | for i, update in enumerate(update_history[-5:], 1): # 只显示最近5次 261 | timestamp = update.get('timestamp') 262 | if timestamp: 263 | time_str = self.format_timestamp(timestamp) 264 | else: 265 | time_str = "未知" 266 | 267 | response_parts.append(f"\n{i}. 
{time_str}") 268 | response_parts.append(f" 原因: {update.get('reason', 'N/A')}") 269 | response_parts.append(f" 字段: {', '.join(update.get('fields', []))}") 270 | else: 271 | response_parts.append("\n🔄 暂无更新记录") 272 | 273 | return "\n".join(response_parts) -------------------------------------------------------------------------------- /core/handlers/search_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | search_handler.py - 搜索管理业务逻辑 4 | 处理记忆搜索、稀疏检索测试等业务逻辑 5 | """ 6 | 7 | import json 8 | from typing import Optional, Dict, Any, List 9 | 10 | from astrbot.api import logger 11 | from astrbot.api.star import Context 12 | 13 | from .base_handler import BaseHandler 14 | 15 | 16 | class SearchHandler(BaseHandler): 17 | """搜索管理业务逻辑处理器""" 18 | 19 | def __init__(self, context: Context, config: Dict[str, Any], recall_engine=None, sparse_retriever=None): 20 | super().__init__(context, config) 21 | self.recall_engine = recall_engine 22 | self.sparse_retriever = sparse_retriever 23 | 24 | async def process(self, *args, **kwargs) -> Dict[str, Any]: 25 | """处理请求的抽象方法实现""" 26 | return self.create_response(True, "SearchHandler process method") 27 | 28 | async def search_memories(self, query: str, k: int = 3) -> Dict[str, Any]: 29 | """搜索记忆""" 30 | if not self.recall_engine: 31 | return self.create_response(False, "回忆引擎尚未初始化") 32 | 33 | try: 34 | results = await self.recall_engine.recall(self.context, query, k=k) 35 | 36 | if not results: 37 | return self.create_response(True, f"未能找到与 '{query}' 相关的记忆", []) 38 | 39 | # 格式化搜索结果 40 | formatted_results = [] 41 | for res in results: 42 | formatted_results.append({ 43 | "id": res.data['id'], 44 | "similarity": res.similarity, 45 | "text": res.data['text'], 46 | "metadata": self.safe_parse_metadata(res.data.get("metadata", {})) 47 | }) 48 | 49 | return self.create_response(True, f"为您找到 {len(results)} 条相关记忆", formatted_results) 50 | 51 | except Exception as e: 52 | logger.error(f"搜索记忆时发生错误: {e}", exc_info=True) 53 | return self.create_response(False, f"搜索记忆时发生错误: {e}") 54 | 55 | async def test_sparse_search(self, query: str, k: int = 5) -> Dict[str, Any]: 56 | """测试稀疏检索功能""" 57 | if not self.sparse_retriever: 58 | return self.create_response(False, "稀疏检索器未启用") 59 | 60 | try: 61 | results = await self.sparse_retriever.search(query=query, limit=k) 62 | 63 | if not results: 64 | return self.create_response(True, f"未找到与 '{query}' 相关的记忆", []) 65 | 66 | # 格式化搜索结果 67 | formatted_results = [] 68 | for res in results: 69 | formatted_results.append({ 70 | "doc_id": res.doc_id, 71 | "score": res.score, 72 | "content": res.content, 73 | "metadata": res.metadata 74 | }) 75 | 76 | return self.create_response(True, f"找到 {len(results)} 条稀疏检索结果", formatted_results) 77 | 78 | except Exception as e: 79 | logger.error(f"稀疏检索测试失败: {e}", exc_info=True) 80 | return self.create_response(False, f"稀疏检索测试失败: {e}") 81 | 82 | async def rebuild_sparse_index(self) -> Dict[str, Any]: 83 | """重建稀疏检索索引""" 84 | if not self.sparse_retriever: 85 | return self.create_response(False, "稀疏检索器未启用") 86 | 87 | try: 88 | await self.sparse_retriever.rebuild_index() 89 | return self.create_response(True, "稀疏检索索引重建完成") 90 | except Exception as e: 91 | logger.error(f"重建稀疏索引失败: {e}", exc_info=True) 92 | return self.create_response(False, f"重建稀疏索引失败: {e}") 93 | 94 | def format_search_results_for_display(self, response: Dict[str, Any]) -> str: 95 | """格式化搜索结果用于显示""" 96 | if not response.get("success"): 97 | return response.get("message", 
"搜索失败") 98 | 99 | data = response.get("data", []) 100 | message = response.get("message", "") 101 | 102 | response_parts = [message] 103 | 104 | for res in data: 105 | metadata = res.get("metadata", {}) 106 | create_time_str = self.format_timestamp(metadata.get("create_time")) 107 | last_access_time_str = self.format_timestamp(metadata.get("last_access_time")) 108 | importance_score = metadata.get("importance", 0.0) 109 | event_type = metadata.get("event_type", "未知") 110 | 111 | card = ( 112 | f"ID: {res['id']}\n" 113 | f"记 忆 度: {res['similarity']:.2f}\n" 114 | f"重 要 性: {importance_score:.2f}\n" 115 | f"记忆类型: {event_type}\n\n" 116 | f"内容: {res['text']}\n\n" 117 | f"创建于: {create_time_str}\n" 118 | f"最后访问: {last_access_time_str}" 119 | ) 120 | response_parts.append(card) 121 | 122 | return "\n\n".join(response_parts) 123 | 124 | def format_sparse_results_for_display(self, response: Dict[str, Any]) -> str: 125 | """格式化稀疏检索结果用于显示""" 126 | if not response.get("success"): 127 | return response.get("message", "搜索失败") 128 | 129 | data = response.get("data", []) 130 | message = response.get("message", "") 131 | 132 | response_parts = [message] 133 | 134 | for i, res in enumerate(data, 1): 135 | response_parts.append(f"\n{i}. [ID: {res['doc_id']}] Score: {res['score']:.3f}") 136 | response_parts.append(f" 内容: {res['content'][:100]}{'...' if len(res['content']) > 100 else ''}") 137 | 138 | # 显示元数据 139 | metadata = res.get("metadata", {}) 140 | if metadata.get("event_type"): 141 | response_parts.append(f" 类型: {metadata['event_type']}") 142 | if metadata.get("importance"): 143 | response_parts.append(f" 重要性: {metadata['importance']:.2f}") 144 | 145 | return "\n".join(response_parts) -------------------------------------------------------------------------------- /core/models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | models.py - 插件的核心数据模型 4 | """ 5 | 6 | from datetime import datetime, timezone 7 | from enum import Enum 8 | from typing import List, Dict, Any, Optional 9 | from pydantic import BaseModel, Field 10 | 11 | # --- 公开的数据模型 --- 12 | 13 | 14 | class EventType(str, Enum): 15 | FACT = "fact" 16 | PREFERENCE = "preference" 17 | GOAL = "goal" 18 | OPINION = "opinion" 19 | RELATIONSHIP = "relationship" 20 | OTHER = "other" 21 | 22 | 23 | class Entity(BaseModel): 24 | name: str = Field(..., description="实体名称") 25 | type: str = Field(..., description="实体类型") 26 | 27 | 28 | class MemoryEvent(BaseModel): 29 | # --- 系统生成字段 --- 30 | timestamp: datetime = Field( 31 | default_factory=lambda: datetime.now(timezone.utc), 32 | description="事件创建的UTC时间戳", 33 | ) 34 | 35 | # --- 系统生成字段 --- 36 | id: Optional[int] = Field(None, description="记忆的唯一整数ID,由存储后端生成") 37 | 38 | # --- LLM 生成字段 (第一阶段) --- 39 | temp_id: str = Field(..., description="由LLM生成的临时唯一ID,用于评分匹配") 40 | memory_content: str = Field(..., description="对事件的简洁、客观的描述") 41 | event_type: EventType = Field(default=EventType.OTHER, description="事件的分类") 42 | entities: List[Entity] = Field( 43 | default_factory=list, description="事件中涉及的关键实体" 44 | ) 45 | 46 | # --- LLM 生成字段 (第二阶段) --- 47 | importance_score: Optional[float] = Field( 48 | None, ge=0.0, le=1.0, description="记忆的重要性评分 (0.0-1.0)" 49 | ) 50 | 51 | # --- 系统关联字段 --- 52 | related_event_ids: List[int] = Field( 53 | default_factory=list, description="与此事件相关的其他事件ID" 54 | ) 55 | metadata: Dict[str, Any] = Field( 56 | default_factory=dict, description="用于存储其他附加信息的灵活字段" 57 | ) 58 | 59 | 60 | class MemoryEventList(BaseModel): 61 | 
events: List[MemoryEvent] 62 | 63 | 64 | # --- 用于生成 Prompt Schema 的私有模型 --- 65 | 66 | 67 | # 用于第一阶段:事件提取 68 | class _LLMExtractionEvent(BaseModel): 69 | temp_id: str = Field( 70 | ..., 71 | description="由LLM或系统生成的临时唯一ID", 72 | ) 73 | memory_content: str = Field(...) 74 | event_type: EventType = Field(default=EventType.OTHER) 75 | entities: List[Entity] = Field(default_factory=list) 76 | related_event_ids: List[int] = Field(default_factory=list) 77 | metadata: Dict[str, Any] = Field(default_factory=dict) 78 | 79 | 80 | class _LLMExtractionEventList(BaseModel): 81 | events: List[_LLMExtractionEvent] 82 | 83 | 84 | # 用于第二阶段:评分 85 | class _LLMScoreEvaluation(BaseModel): 86 | scores: Dict[str, float] = Field( 87 | ..., description="一个字典,key是事件的临时ID (temp_id),value是对应的0.0-1.0的重要性分数" 88 | ) 89 | -------------------------------------------------------------------------------- /core/models/memory_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import dataclasses 4 | from dataclasses import dataclass, field 5 | from typing import List, Dict, Any, Optional 6 | 7 | 8 | @dataclass 9 | class LinkedMedia: 10 | media_id: str 11 | media_type: str 12 | url: str 13 | caption: str 14 | embedding: List[float] = field(default_factory=list) 15 | 16 | 17 | @dataclass 18 | class AccessInfo: 19 | initial_creation_timestamp: str 20 | last_accessed_timestamp: str 21 | access_count: int = 1 22 | 23 | 24 | @dataclass 25 | class EmotionalValence: 26 | sentiment: str 27 | intensity: float 28 | 29 | 30 | @dataclass 31 | class UserFeedback: 32 | is_accurate: Optional[bool] = None 33 | is_important: Optional[bool] = None 34 | correction_text: Optional[str] = None 35 | 36 | 37 | @dataclass 38 | class CommunityInfo: 39 | id: Optional[str] = None 40 | last_calculated: Optional[str] = None 41 | 42 | 43 | @dataclass 44 | class Metadata: 45 | source_conversation_id: str 46 | memory_type: str # 'episodic', 'semantic', 'procedural' 47 | importance_score: float 48 | access_info: AccessInfo 49 | confidence_score: Optional[float] = None 50 | emotional_valence: Optional[EmotionalValence] = None 51 | user_feedback: UserFeedback = field(default_factory=UserFeedback) 52 | community_info: CommunityInfo = field(default_factory=CommunityInfo) 53 | session_id: Optional[str] = None 54 | persona_id: Optional[str] = None 55 | 56 | 57 | @dataclass 58 | class EventEntity: 59 | event_id: str 60 | event_type: str 61 | 62 | 63 | @dataclass 64 | class Entity: 65 | entity_id: str 66 | name: str 67 | type: str 68 | role: Optional[str] = None 69 | 70 | 71 | @dataclass 72 | class KnowledgeGraphPayload: 73 | event_entity: EventEntity 74 | entities: List[Entity] = field(default_factory=list) 75 | relationships: List[List[str]] = field(default_factory=list) 76 | 77 | 78 | @dataclass 79 | class Memory: 80 | memory_id: str # UUID 81 | timestamp: str # ISO 8601 82 | summary: str 83 | description: str 84 | metadata: Metadata 85 | embedding: List[float] = field(default_factory=list) 86 | linked_media: List[LinkedMedia] = field(default_factory=list) 87 | knowledge_graph_payload: Optional[KnowledgeGraphPayload] = None 88 | 89 | # 提供一个方便的方法来将 dataclass 转换为 dict 90 | def to_dict(self) -> Dict[str, Any]: 91 | return dataclasses.asdict(self) 92 | 93 | # 提供一个方便的方法从 dict 创建 dataclass 实例 94 | @classmethod 95 | def from_dict(cls, data: Dict[str, Any]) -> "Memory": 96 | # 嵌套结构的递归转换 97 | field_types = {f.name: f.type for f in dataclasses.fields(cls)} 98 | 99 | # 简单递归转换,TODO 对于复杂场景可能需要更完善的库如 
dacite 100 | for name, T in field_types.items(): 101 | if ( 102 | hasattr(T, "from_dict") 103 | and name in data 104 | and isinstance(data[name], dict) 105 | ): 106 | data[name] = T.from_dict(data[name]) 107 | elif isinstance(data.get(name), list): 108 | # 处理 dataclass 列表 109 | origin = getattr(T, "__origin__", None) 110 | if origin is list: 111 | item_type = T.__args__[0] 112 | if hasattr(item_type, "from_dict"): 113 | data[name] = [item_type.from_dict(item) for item in data[name]] 114 | 115 | return cls(**data) 116 | -------------------------------------------------------------------------------- /core/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 检索模块 4 | """ 5 | 6 | from .sparse_retriever import SparseRetriever, SparseResult 7 | from .result_fusion import ResultFusion, SearchResult 8 | 9 | __all__ = [ 10 | "SparseRetriever", 11 | "SparseResult", 12 | "ResultFusion", 13 | "SearchResult" 14 | ] -------------------------------------------------------------------------------- /core/retrieval/sparse_retriever.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 稀疏检索器 - 基于 SQLite FTS5 和 BM25 的全文检索 4 | """ 5 | 6 | import json 7 | import sqlite3 8 | import math 9 | from typing import List, Dict, Any, Optional, Tuple 10 | from dataclasses import dataclass 11 | import asyncio 12 | import aiosqlite 13 | 14 | from astrbot.api import logger 15 | 16 | try: 17 | import jieba 18 | JIEBA_AVAILABLE = True 19 | except ImportError: 20 | JIEBA_AVAILABLE = False 21 | logger.warning("jieba not available, Chinese tokenization disabled") 22 | 23 | 24 | @dataclass 25 | class SparseResult: 26 | """稀疏检索结果""" 27 | doc_id: int 28 | score: float 29 | content: str 30 | metadata: Dict[str, Any] 31 | 32 | 33 | class FTSManager: 34 | """FTS5 索引管理器""" 35 | 36 | def __init__(self, db_path: str): 37 | self.db_path = db_path 38 | self.fts_table_name = "documents_fts" 39 | 40 | async def initialize(self): 41 | """初始化 FTS5 索引""" 42 | async with aiosqlite.connect(self.db_path) as db: 43 | # 启用 FTS5 扩展 44 | await db.execute("PRAGMA foreign_keys = ON") 45 | 46 | # 创建 FTS5 虚拟表 47 | await db.execute(f""" 48 | CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table_name} 49 | USING fts5(content, doc_id, tokenize='unicode61') 50 | """) 51 | 52 | # 创建触发器,保持同步 53 | await self._create_triggers(db) 54 | 55 | await db.commit() 56 | logger.info(f"FTS5 index initialized: {self.fts_table_name}") 57 | 58 | async def _create_triggers(self, db: aiosqlite.Connection): 59 | """创建数据同步触发器""" 60 | # 插入触发器 61 | await db.execute(f""" 62 | CREATE TRIGGER IF NOT EXISTS documents_ai 63 | AFTER INSERT ON documents BEGIN 64 | INSERT INTO {self.fts_table_name}(doc_id, content) 65 | VALUES (new.id, new.text); 66 | END; 67 | """) 68 | 69 | # 删除触发器 70 | await db.execute(f""" 71 | CREATE TRIGGER IF NOT EXISTS documents_ad 72 | AFTER DELETE ON documents BEGIN 73 | DELETE FROM {self.fts_table_name} WHERE doc_id = old.id; 74 | END; 75 | """) 76 | 77 | # 更新触发器 78 | await db.execute(f""" 79 | CREATE TRIGGER IF NOT EXISTS documents_au 80 | AFTER UPDATE ON documents BEGIN 81 | DELETE FROM {self.fts_table_name} WHERE doc_id = old.id; 82 | INSERT INTO {self.fts_table_name}(doc_id, content) 83 | VALUES (new.id, new.text); 84 | END; 85 | """) 86 | 87 | async def rebuild_index(self): 88 | """重建索引""" 89 | async with aiosqlite.connect(self.db_path) as db: 90 | await db.execute(f"DELETE FROM 
{self.fts_table_name}") 91 | await db.execute(f""" 92 | INSERT INTO {self.fts_table_name}(doc_id, content) 93 | SELECT id, text FROM documents 94 | """) 95 | await db.commit() 96 | logger.info("FTS index rebuilt") 97 | 98 | async def search(self, query: str, limit: int = 50) -> List[Tuple[int, float]]: 99 | """执行 BM25 搜索""" 100 | async with aiosqlite.connect(self.db_path) as db: 101 | # 将整个查询用双引号包裹,以处理特殊字符并将其作为短语搜索 102 | # 这是为了防止 FTS5 语法错误,例如 'syntax error near "."' 103 | safe_query = f'"{query}"' 104 | 105 | # 使用 BM25 算法搜索 106 | cursor = await db.execute(f""" 107 | SELECT doc_id, bm25({self.fts_table_name}) as score 108 | FROM {self.fts_table_name} 109 | WHERE {self.fts_table_name} MATCH ? 110 | ORDER BY score 111 | LIMIT ? 112 | """, (safe_query, limit)) 113 | 114 | results = await cursor.fetchall() 115 | return [(row[0], row[1]) for row in results] 116 | 117 | 118 | class SparseRetriever: 119 | """稀疏检索器""" 120 | 121 | def __init__(self, db_path: str, config: Dict[str, Any] = None): 122 | self.db_path = db_path 123 | self.config = config or {} 124 | self.fts_manager = FTSManager(db_path) 125 | self.enabled = self.config.get("enabled", True) 126 | self.use_chinese_tokenizer = self.config.get("use_chinese_tokenizer", JIEBA_AVAILABLE) 127 | 128 | async def initialize(self): 129 | """初始化稀疏检索器""" 130 | if not self.enabled: 131 | logger.info("Sparse retriever disabled") 132 | return 133 | 134 | await self.fts_manager.initialize() 135 | 136 | # 如果启用中文分词,初始化 jieba 137 | if self.use_chinese_tokenizer and JIEBA_AVAILABLE: 138 | # 可以添加自定义词典 139 | pass 140 | 141 | logger.info("Sparse retriever initialized") 142 | 143 | def _preprocess_query(self, query: str) -> str: 144 | """预处理查询""" 145 | query = query.strip() 146 | 147 | # 中文分词 148 | if self.use_chinese_tokenizer and JIEBA_AVAILABLE: 149 | # 检查是否包含中文 150 | if any('\u4e00' <= char <= '\u9fff' for char in query): 151 | tokens = jieba.cut_for_search(query) 152 | query = " ".join(tokens) 153 | 154 | query = query.replace('"', ' ') # 将内部的双引号替换为空格 155 | 156 | return query 157 | 158 | async def search( 159 | self, 160 | query: str, 161 | limit: int = 50, 162 | session_id: Optional[str] = None, 163 | persona_id: Optional[str] = None, 164 | metadata_filters: Optional[Dict[str, Any]] = None 165 | ) -> List[SparseResult]: 166 | """执行稀疏检索""" 167 | if not self.enabled: 168 | return [] 169 | 170 | try: 171 | # 预处理查询 172 | processed_query = self._preprocess_query(query) 173 | logger.debug(f"Sparse search query: {processed_query}") 174 | 175 | # 执行 FTS 搜索 176 | fts_results = await self.fts_manager.search(processed_query, limit) 177 | 178 | if not fts_results: 179 | return [] 180 | 181 | # 获取完整的文档信息 182 | doc_ids = [doc_id for doc_id, _ in fts_results] 183 | documents = await self._get_documents(doc_ids) 184 | 185 | # 应用过滤器 186 | filtered_results = [] 187 | for doc_id, bm25_score in fts_results: 188 | if doc_id in documents: 189 | doc = documents[doc_id] 190 | 191 | # 检查元数据过滤器 192 | if self._apply_filters(doc.get("metadata", {}), session_id, persona_id, metadata_filters): 193 | result = SparseResult( 194 | doc_id=doc_id, 195 | score=bm25_score, 196 | content=doc["text"], 197 | metadata=doc["metadata"] 198 | ) 199 | filtered_results.append(result) 200 | 201 | # 归一化 BM25 分数(转换为 0-1) 202 | if filtered_results: 203 | max_score = max(r.score for r in filtered_results) 204 | min_score = min(r.score for r in filtered_results) 205 | score_range = max_score - min_score if max_score != min_score else 1 206 | 207 | for result in filtered_results: 208 | result.score = 
(max_score - result.score) / score_range if max_score != min_score else 1.0  # bm25() 分数越小越相关:反转并归一化到 0-1,分数全部相同时记为 1.0
209 | 
210 |             logger.debug(f"Sparse search returned {len(filtered_results)} results")
211 |             return filtered_results
212 | 
213 |         except Exception as e:
214 |             logger.error(f"Sparse search error: {e}", exc_info=True)
215 |             return []
216 | 
217 |     async def _get_documents(self, doc_ids: List[int]) -> Dict[int, Dict[str, Any]]:
218 |         """批量获取文档"""
219 |         async with aiosqlite.connect(self.db_path) as db:
220 |             placeholders = ",".join("?" for _ in doc_ids)
221 |             cursor = await db.execute(f"""
222 |                 SELECT id, text, metadata FROM documents WHERE id IN ({placeholders})
223 |             """, doc_ids)
224 | 
225 |             documents = {}
226 |             async for row in cursor:
227 |                 metadata = json.loads(row[2]) if isinstance(row[2], str) else row[2]
228 |                 documents[row[0]] = {
229 |                     "text": row[1],
230 |                     "metadata": metadata or {}
231 |                 }
232 | 
233 |             return documents
234 | 
235 |     def _apply_filters(
236 |         self,
237 |         metadata: Dict[str, Any],
238 |         session_id: Optional[str],
239 |         persona_id: Optional[str],
240 |         metadata_filters: Optional[Dict[str, Any]]
241 |     ) -> bool:
242 |         """应用过滤器"""
243 |         # 会话过滤
244 |         if session_id and metadata.get("session_id") != session_id:
245 |             return False
246 | 
247 |         # 人格过滤
248 |         if persona_id and metadata.get("persona_id") != persona_id:
249 |             return False
250 | 
251 |         # 自定义元数据过滤
252 |         if metadata_filters:
253 |             for key, value in metadata_filters.items():
254 |                 if metadata.get(key) != value:
255 |                     return False
256 | 
257 |         return True
258 | 
259 |     async def rebuild_index(self):
260 |         """重建索引"""
261 |         if not self.enabled:
262 |             return
263 |         await self.fts_manager.rebuild_index()
--------------------------------------------------------------------------------
/core/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | utils.py - 插件的辅助工具函数
4 | """
5 | 
6 | import re
7 | import json
8 | import time
9 | import asyncio
10 | from datetime import datetime
11 | from typing import List, Optional, Dict, Any
12 | 
13 | import pytz
14 | 
15 | from astrbot.api import logger
16 | from astrbot.api.star import Context
17 | from astrbot.api.event import AstrMessageEvent
18 | from ..storage.faiss_manager import Result
19 | from .constants import MEMORY_INJECTION_HEADER, MEMORY_INJECTION_FOOTER
20 | 
21 | 
22 | def safe_parse_metadata(metadata_raw: Any) -> Dict[str, Any]:
23 |     """
24 |     安全解析元数据,统一处理字符串和字典类型。
25 | 
26 |     Args:
27 |         metadata_raw: 原始元数据,可能是字符串或字典
28 | 
29 |     Returns:
30 |         Dict[str, Any]: 解析后的元数据字典,解析失败时返回空字典
31 |     """
32 |     if isinstance(metadata_raw, dict):
33 |         return metadata_raw
34 |     elif isinstance(metadata_raw, str):
35 |         try:
36 |             return json.loads(metadata_raw)
37 |         except (json.JSONDecodeError, TypeError) as e:
38 |             logger.warning(f"解析元数据JSON失败: {e}, 原始数据: {metadata_raw}")
39 |             return {}
40 |     else:
41 |         logger.warning(f"不支持的元数据类型: {type(metadata_raw)}")
42 |         return {}
43 | 
44 | 
45 | def safe_serialize_metadata(metadata: Dict[str, Any]) -> str:
46 |     """
47 |     安全序列化元数据为JSON字符串。
48 | 
49 |     Args:
50 |         metadata: 元数据字典
51 | 
52 |     Returns:
53 |         str: JSON字符串
54 |     """
55 |     try:
56 |         return json.dumps(metadata, ensure_ascii=False)
57 |     except (TypeError, ValueError) as e:
58 |         logger.error(f"序列化元数据失败: {e}, 数据: {metadata}")
59 |         return "{}"
60 | 
61 | 
62 | def validate_timestamp(timestamp: Any, default_time: Optional[float] = None) -> float:
63 |     """
64 |     验证和标准化时间戳。
65 | 
66 |     Args:
67 |         timestamp: 时间戳,可能是字符串、数字或其他类型
68 |         default_time: 默认时间,如果为None则使用当前时间
69 | 
70 |     Returns:
71 |         float: 标准化的时间戳
72 |     """
73 |     if 
default_time is None: 74 | default_time = time.time() 75 | 76 | if isinstance(timestamp, (int, float)): 77 | return float(timestamp) 78 | elif isinstance(timestamp, str): 79 | try: 80 | return float(timestamp) 81 | except (ValueError, TypeError): 82 | logger.warning(f"无法解析时间戳字符串: {timestamp}") 83 | return default_time 84 | elif hasattr(timestamp, 'timestamp'): # datetime对象 85 | try: 86 | return timestamp.timestamp() 87 | except Exception as e: 88 | logger.warning(f"无法从datetime对象获取时间戳: {e}") 89 | return default_time 90 | else: 91 | logger.warning(f"不支持的时间戳类型: {type(timestamp)}") 92 | return default_time 93 | 94 | 95 | async def retry_on_failure( 96 | func, 97 | *args, 98 | max_retries: int = 3, 99 | backoff_factor: float = 1.0, 100 | exceptions: tuple = (Exception,), 101 | **kwargs 102 | ): 103 | """ 104 | 带重试机制的函数执行器。 105 | 106 | Args: 107 | func: 要执行的函数 108 | *args: 函数位置参数 109 | max_retries: 最大重试次数 110 | backoff_factor: 退避因子 111 | exceptions: 需要重试的异常类型 112 | **kwargs: 函数关键字参数 113 | 114 | Returns: 115 | 函数执行结果 116 | """ 117 | last_exception = None 118 | 119 | for attempt in range(max_retries + 1): 120 | try: 121 | if asyncio.iscoroutinefunction(func): 122 | return await func(*args, **kwargs) 123 | else: 124 | return func(*args, **kwargs) 125 | except exceptions as e: 126 | last_exception = e 127 | if attempt < max_retries: 128 | wait_time = backoff_factor * (2 ** attempt) 129 | logger.warning(f"函数 {func.__name__} 执行失败 (尝试 {attempt + 1}/{max_retries + 1}): {e}") 130 | logger.info(f"等待 {wait_time:.2f} 秒后重试...") 131 | await asyncio.sleep(wait_time) 132 | else: 133 | logger.error(f"函数 {func.__name__} 重试 {max_retries} 次后仍然失败: {e}") 134 | 135 | # 所有重试都失败,抛出最后一个异常 136 | raise last_exception 137 | 138 | 139 | class OperationContext: 140 | """操作上下文管理器,用于错误处理和资源清理""" 141 | 142 | def __init__(self, operation_name: str, session_id: Optional[str] = None): 143 | self.operation_name = operation_name 144 | self.session_id = session_id 145 | self.start_time = None 146 | 147 | async def __aenter__(self): 148 | self.start_time = time.time() 149 | session_info = f"[{self.session_id}] " if self.session_id else "" 150 | logger.debug(f"{session_info}开始执行操作: {self.operation_name}") 151 | return self 152 | 153 | async def __aexit__(self, exc_type, exc_val, exc_tb): 154 | duration = time.time() - self.start_time if self.start_time else 0 155 | session_info = f"[{self.session_id}] " if self.session_id else "" 156 | 157 | if exc_type is None: 158 | logger.debug(f"{session_info}操作成功完成: {self.operation_name} (耗时 {duration:.3f}s)") 159 | else: 160 | logger.error(f"{session_info}操作失败: {self.operation_name} (耗时 {duration:.3f}s) - {exc_val}") 161 | 162 | # 不抑制异常,让调用者处理 163 | return False 164 | 165 | 166 | async def get_persona_id(context: Context, event: AstrMessageEvent) -> Optional[str]: 167 | """ 168 | 获取当前会话的人格 ID。 169 | 如果当前会话没有特定人格,则返回 AstrBot 的默认人格。 170 | """ 171 | try: 172 | session_id = await context.conversation_manager.get_curr_conversation_id( 173 | event.unified_msg_origin 174 | ) 175 | conversation = await context.conversation_manager.get_conversation( 176 | event.unified_msg_origin, session_id 177 | ) 178 | persona_id = conversation.persona_id if conversation else None 179 | 180 | # 如果无人格或明确设置为None,则使用全局默认人格 181 | if not persona_id or persona_id == "[%None]": 182 | default_persona = context.provider_manager.selected_default_persona 183 | persona_id = default_persona["name"] if default_persona else None 184 | 185 | return persona_id 186 | except Exception as e: 187 | # 在某些情况下(如无会话),获取可能会失败,返回 None 188 | 
logger.debug(f"获取人格ID失败: {e}") 189 | return None 190 | 191 | 192 | def extract_json_from_response(text: str) -> str: 193 | """ 194 | 从可能包含 Markdown 代码块的文本中提取纯 JSON 字符串。 195 | """ 196 | # 查找被 ```json ... ``` 或 ``` ... ``` 包围的内容 197 | match = re.search(r"```(json)?\s*(\{.*?\})\s*```", text, re.DOTALL) 198 | if match: 199 | # 返回捕获组中的 JSON 部分 200 | return match.group(2) 201 | 202 | # 如果没有找到代码块,假设整个文本就是 JSON(可能需要去除首尾空格) 203 | return text.strip() 204 | 205 | 206 | def get_now_datetime(tz_str: str = "Asia/Shanghai") -> datetime: 207 | """ 208 | 获取当前时间,并根据指定的时区设置时区。 209 | 210 | Args: 211 | tz_str: 时区字符串,默认为 "Asia/Shanghai" 212 | 213 | Returns: 214 | datetime: 带有时区信息的当前时间 215 | """ 216 | # 如果传入的是 Context 对象,则使用从上下文获取时间的方法 217 | # 检查传入的是否是 Context 对象 218 | from astrbot.api.star import Context # 导入 Context 类型 219 | if isinstance(tz_str, Context): 220 | # 如果是 Context 对象,调用专门的函数处理 221 | return get_now_datetime_from_context(tz_str) 222 | 223 | try: 224 | timezone = pytz.timezone(tz_str) 225 | except pytz.UnknownTimeZoneError: 226 | # 如果时区无效,则使用默认值 227 | logger.warning(f"无效的时区: {tz_str},使用默认时区 Asia/Shanghai") 228 | timezone = pytz.timezone("Asia/Shanghai") 229 | 230 | return datetime.now(timezone) 231 | 232 | 233 | def get_now_datetime_from_context(context: Context) -> datetime: 234 | """ 235 | 从上下文中获取当前时间,根据插件配置设置时区。 236 | 237 | Args: 238 | context: AstrBot 上下文对象 239 | 240 | Returns: 241 | datetime: 带有时区信息的当前时间 242 | """ 243 | try: 244 | # 尝试从配置中获取时区 245 | tz_str = context.plugin_config.get("timezone_settings", {}).get("timezone", "Asia/Shanghai") 246 | return get_now_datetime(tz_str) 247 | except (AttributeError, KeyError): 248 | # 如果配置不存在,则使用默认值 249 | return get_now_datetime() 250 | 251 | 252 | def format_memories_for_injection(memories: List[Result]) -> str: 253 | """ 254 | 将检索到的记忆列表格式化为单个字符串,以便注入到 System Prompt。 255 | """ 256 | if not memories: 257 | return "" 258 | 259 | header = f"{MEMORY_INJECTION_HEADER}\n" 260 | footer = f"\n{MEMORY_INJECTION_FOOTER}" 261 | 262 | formatted_entries = [] 263 | for mem in memories: 264 | try: 265 | # 使用统一的元数据解析函数 266 | metadata_raw = mem.data.get("metadata", "{}") 267 | metadata = safe_parse_metadata(metadata_raw) 268 | 269 | content = mem.data.get("text", "内容缺失") 270 | importance = metadata.get("importance", 0.0) 271 | 272 | entry = f"- [重要性: {importance:.2f}] {content}" 273 | formatted_entries.append(entry) 274 | except Exception as e: 275 | # 如果处理失败,则跳过此条记忆 276 | logger.debug(f"格式化记忆时出错,跳过此记忆: {e}") 277 | continue 278 | 279 | if not formatted_entries: 280 | return "" 281 | 282 | body = "\n".join(formatted_entries) 283 | 284 | return f"{header}{body}{footer}" 285 | -------------------------------------------------------------------------------- /docs/CONFIG.md: -------------------------------------------------------------------------------- 1 | # LivingMemory 配置参考 2 | 3 | 本文档详细介绍了 LivingMemory 插件的所有配置参数。 4 | 5 | ## 📋 配置概览 6 | 7 | 插件配置采用层次化结构,主要包含以下几个部分: 8 | - 时区设置 9 | - Provider 设置 10 | - 会话管理器 11 | - 回忆引擎 12 | - 反思引擎 13 | - 遗忘代理 14 | - 结果融合 15 | - 稀疏检索器 16 | - 过滤设置 17 | 18 | ## ⚙️ 详细配置参数 19 | 20 | ### 时区设置 (timezone_settings) 21 | 22 | | 参数 | 类型 | 默认值 | 描述 | 23 | |------|------|--------|------| 24 | | `timezone` | string | `"Asia/Shanghai"` | IANA 时区数据库名称,影响时间显示格式 | 25 | 26 | **示例:** 27 | ```yaml 28 | timezone_settings: 29 | timezone: "America/New_York" # 纽约时区 30 | ``` 31 | 32 | **可用时区:** 33 | - `Asia/Shanghai` - 中国标准时间 34 | - `America/New_York` - 美国东部时间 35 | - `Europe/London` - 格林威治时间 36 | - `Asia/Tokyo` - 日本标准时间 37 | 38 | ### Provider 设置 (provider_settings) 
39 | 40 | | 参数 | 类型 | 默认值 | 描述 | 41 | |------|------|--------|------| 42 | | `embedding_provider_id` | string | `""` | 指定用于生成向量的 Embedding Provider ID | 43 | | `llm_provider_id` | string | `""` | 指定用于总结和评估的 LLM Provider ID | 44 | 45 | **示例:** 46 | ```yaml 47 | provider_settings: 48 | embedding_provider_id: "openai_embedding" 49 | llm_provider_id: "claude_3_5" 50 | ``` 51 | 52 | **注意:** 53 | - 留空将自动使用 AstrBot 的默认 Provider 54 | - 确保指定的 Provider 已在 AstrBot 中正确配置 55 | 56 | ### 会话管理器 (session_manager) 57 | 58 | | 参数 | 类型 | 默认值 | 范围 | 描述 | 59 | |------|------|--------|------|------| 60 | | `max_sessions` | int | `1000` | 1-10000 | 同时维护的最大会话数量 | 61 | | `session_ttl` | int | `3600` | 60-86400 | 会话生存时间(秒) | 62 | 63 | **示例:** 64 | ```yaml 65 | session_manager: 66 | max_sessions: 500 # 最大500个会话 67 | session_ttl: 7200 # 2小时过期 68 | ``` 69 | 70 | **优化建议:** 71 | - 高并发场景:增大 `max_sessions` 72 | - 内存紧张:减小 `session_ttl` 73 | - 长对话场景:增大 `session_ttl` 74 | 75 | ### 回忆引擎 (recall_engine) 76 | 77 | | 参数 | 类型 | 默认值 | 范围 | 描述 | 78 | |------|------|--------|------|------| 79 | | `top_k` | int | `5` | 1-50 | 单次检索返回的记忆数量 | 80 | | `recall_strategy` | string | `"weighted"` | similarity/weighted | 召回策略 | 81 | | `retrieval_mode` | string | `"hybrid"` | hybrid/dense/sparse | 检索模式 | 82 | | `similarity_weight` | float | `0.6` | 0.0-1.0 | 相似度权重 | 83 | | `importance_weight` | float | `0.2` | 0.0-1.0 | 重要性权重 | 84 | | `recency_weight` | float | `0.2` | 0.0-1.0 | 新近度权重 | 85 | 86 | **召回策略:** 87 | - `similarity`: 纯基于相似度的召回 88 | - `weighted`: 综合考虑相似度、重要性和新近度 89 | 90 | **检索模式:** 91 | - `hybrid`: 混合检索(推荐) 92 | - `dense`: 纯密集向量检索 93 | - `sparse`: 纯稀疏关键词检索 94 | 95 | **权重调优指南:** 96 | ```yaml 97 | # 重视语义相关性 98 | recall_engine: 99 | similarity_weight: 0.7 100 | importance_weight: 0.2 101 | recency_weight: 0.1 102 | 103 | # 重视重要信息 104 | recall_engine: 105 | similarity_weight: 0.4 106 | importance_weight: 0.5 107 | recency_weight: 0.1 108 | 109 | # 重视最新信息 110 | recall_engine: 111 | similarity_weight: 0.4 112 | importance_weight: 0.2 113 | recency_weight: 0.4 114 | ``` 115 | 116 | ### 反思引擎 (reflection_engine) 117 | 118 | | 参数 | 类型 | 默认值 | 范围 | 描述 | 119 | |------|------|--------|------|------| 120 | | `summary_trigger_rounds` | int | `5` | 1-100 | 触发反思的对话轮次 | 121 | | `importance_threshold` | float | `0.5` | 0.0-1.0 | 记忆重要性阈值 | 122 | | `event_extraction_prompt` | text | 默认提示词 | - | 事件提取提示词 | 123 | | `evaluation_prompt` | text | 默认提示词 | - | 重要性评估提示词 | 124 | 125 | **触发轮次调优:** 126 | - `1-3轮`: 频繁反思,适合重要对话 127 | - `5-10轮`: 平衡模式(推荐) 128 | - `15-30轮`: 长对话模式,减少反思频率 129 | 130 | **重要性阈值:** 131 | - `0.1-0.3`: 宽松模式,保存更多记忆 132 | - `0.5-0.7`: 标准模式(推荐) 133 | - `0.8-1.0`: 严格模式,只保存重要记忆 134 | 135 | ### 遗忘代理 (forgetting_agent) 136 | 137 | | 参数 | 类型 | 默认值 | 范围 | 描述 | 138 | |------|------|--------|------|------| 139 | | `enabled` | bool | `true` | - | 是否启用自动遗忘 | 140 | | `check_interval_hours` | int | `24` | 1-168 | 检查间隔(小时) | 141 | | `retention_days` | int | `90` | 1-3650 | 记忆保留天数 | 142 | | `importance_decay_rate` | float | `0.005` | 0.0-1.0 | 重要性衰减率 | 143 | | `importance_threshold` | float | `0.1` | 0.0-1.0 | 遗忘重要性阈值 | 144 | | `forgetting_batch_size` | int | `1000` | 100-10000 | 批处理大小 | 145 | 146 | **遗忘策略配置:** 147 | ```yaml 148 | # 保守遗忘(保存更多记忆) 149 | forgetting_agent: 150 | retention_days: 180 151 | importance_decay_rate: 0.001 152 | importance_threshold: 0.05 153 | 154 | # 标准遗忘(推荐) 155 | forgetting_agent: 156 | retention_days: 90 157 | importance_decay_rate: 0.005 158 | importance_threshold: 0.1 159 | 160 | # 激进遗忘(节省存储空间) 161 | forgetting_agent: 162 | 
retention_days: 30 163 | importance_decay_rate: 0.01 164 | importance_threshold: 0.2 165 | ``` 166 | 167 | ### 结果融合 (fusion) 168 | 169 | | 参数 | 类型 | 默认值 | 描述 | 170 | |------|------|--------|------| 171 | | `strategy` | string | `"rrf"` | 融合策略 | 172 | | `rrf_k` | int | `60` | RRF 参数 k | 173 | | `dense_weight` | float | `0.7` | 密集检索权重 | 174 | | `sparse_weight` | float | `0.3` | 稀疏检索权重 | 175 | | `convex_lambda` | float | `0.5` | 凸组合参数 | 176 | | `interleave_ratio` | float | `0.5` | 交替融合比例 | 177 | | `rank_bias_factor` | float | `0.1` | 排序偏置因子 | 178 | | `diversity_bonus` | float | `0.1` | 多样性奖励 | 179 | 180 | **融合策略详解:** 181 | - `rrf`: 经典 Reciprocal Rank Fusion 182 | - `hybrid_rrf`: 自适应 RRF 183 | - `weighted`: 加权融合 184 | - `convex`: 凸组合融合 185 | - `interleave`: 交替融合 186 | - `rank_fusion`: 基于排序的融合 187 | - `score_fusion`: Borda Count 融合 188 | - `cascade`: 级联融合 189 | - `adaptive`: 自适应融合 190 | 191 | 详细说明请参考 [FUSION_STRATEGIES.md](../FUSION_STRATEGIES.md) 192 | 193 | ### 稀疏检索器 (sparse_retriever) 194 | 195 | | 参数 | 类型 | 默认值 | 范围 | 描述 | 196 | |------|------|--------|------|------| 197 | | `enabled` | bool | `true` | - | 是否启用稀疏检索 | 198 | | `bm25_k1` | float | `1.2` | 0.1-10.0 | BM25 k1 参数 | 199 | | `bm25_b` | float | `0.75` | 0.0-1.0 | BM25 b 参数 | 200 | | `use_jieba` | bool | `true` | - | 是否使用中文分词 | 201 | 202 | **BM25 参数调优:** 203 | - `k1`: 控制词频饱和度 204 | - 较小值(0.5-1.0):词频影响较小 205 | - 较大值(1.5-2.0):词频影响较大 206 | - `b`: 控制文档长度归一化 207 | - 0.0:不考虑文档长度 208 | - 1.0:完全归一化文档长度 209 | 210 | **中文优化配置:** 211 | ```yaml 212 | sparse_retriever: 213 | enabled: true 214 | bm25_k1: 1.2 # 适合中文的词频参数 215 | bm25_b: 0.75 # 中等长度归一化 216 | use_jieba: true # 启用中文分词 217 | ``` 218 | 219 | ### 过滤设置 (filtering_settings) 220 | 221 | | 参数 | 类型 | 默认值 | 描述 | 222 | |------|------|--------|------| 223 | | `use_persona_filtering` | bool | `true` | 是否启用人格记忆过滤 | 224 | | `use_session_filtering` | bool | `true` | 是否启用会话记忆隔离 | 225 | 226 | **过滤模式组合:** 227 | ```yaml 228 | # 完全隔离模式 229 | filtering_settings: 230 | use_persona_filtering: true 231 | use_session_filtering: true 232 | 233 | # 人格共享模式 234 | filtering_settings: 235 | use_persona_filtering: true 236 | use_session_filtering: false 237 | 238 | # 会话共享模式 239 | filtering_settings: 240 | use_persona_filtering: false 241 | use_session_filtering: true 242 | 243 | # 全局共享模式 244 | filtering_settings: 245 | use_persona_filtering: false 246 | use_session_filtering: false 247 | ``` 248 | 249 | ## 🎯 场景化配置示例 250 | 251 | ### 个人助手配置 252 | ```yaml 253 | # 适合个人日常使用 254 | session_manager: 255 | max_sessions: 100 256 | session_ttl: 7200 257 | 258 | recall_engine: 259 | top_k: 3 260 | similarity_weight: 0.5 261 | importance_weight: 0.3 262 | recency_weight: 0.2 263 | 264 | reflection_engine: 265 | summary_trigger_rounds: 10 266 | importance_threshold: 0.4 267 | 268 | filtering_settings: 269 | use_persona_filtering: true 270 | use_session_filtering: false 271 | ``` 272 | 273 | ### 客服机器人配置 274 | ```yaml 275 | # 适合客服场景 276 | session_manager: 277 | max_sessions: 1000 278 | session_ttl: 1800 279 | 280 | recall_engine: 281 | top_k: 5 282 | similarity_weight: 0.7 283 | importance_weight: 0.2 284 | recency_weight: 0.1 285 | 286 | reflection_engine: 287 | summary_trigger_rounds: 5 288 | importance_threshold: 0.6 289 | 290 | filtering_settings: 291 | use_persona_filtering: false 292 | use_session_filtering: true 293 | ``` 294 | 295 | ### 教育辅导配置 296 | ```yaml 297 | # 适合教育辅导场景 298 | session_manager: 299 | max_sessions: 500 300 | session_ttl: 3600 301 | 302 | recall_engine: 303 | top_k: 8 304 | similarity_weight: 0.4 305 | 
importance_weight: 0.4 306 | recency_weight: 0.2 307 | 308 | reflection_engine: 309 | summary_trigger_rounds: 8 310 | importance_threshold: 0.3 311 | 312 | forgetting_agent: 313 | retention_days: 180 314 | importance_decay_rate: 0.002 315 | ``` 316 | 317 | ## 🔧 配置验证 318 | 319 | ### 使用命令验证 320 | ```bash 321 | # 验证当前配置 322 | /lmem config validate 323 | 324 | # 查看配置摘要 325 | /lmem config show 326 | ``` 327 | 328 | ### 配置文件验证 329 | 插件会在启动时自动验证配置: 330 | - ✅ 参数类型检查 331 | - ✅ 数值范围验证 332 | - ✅ 必需字段验证 333 | - ✅ 权重总和警告 334 | 335 | ## 💡 性能优化建议 336 | 337 | ### 内存优化 338 | - 减少 `max_sessions` 和 `session_ttl` 339 | - 降低 `top_k` 值 340 | - 启用积极的遗忘策略 341 | 342 | ### 准确性优化 343 | - 增加 `top_k` 值 344 | - 调整权重配比 345 | - 使用混合检索模式 346 | - 优化融合策略参数 347 | 348 | ### 响应速度优化 349 | - 使用 `cascade` 融合策略 350 | - 减少 `top_k` 值 351 | - 选择更快的检索模式 352 | 353 | ## ⚠️ 注意事项 354 | 355 | 1. **权重总和**:确保回忆引擎的三个权重总和接近 1.0 356 | 2. **Provider 可用性**:确保指定的 Provider 已正确配置 357 | 3. **存储空间**:长期使用需要考虑遗忘策略以控制存储增长 358 | 4. **中文支持**:启用 jieba 分词以获得更好的中文检索效果 359 | 5. **配置热更新**:部分配置修改需要重启插件才能生效 360 | 361 | ## 🔍 配置调试 362 | 363 | ### 查看生效配置 364 | ```bash 365 | /lmem config show 366 | ``` 367 | 368 | ### 测试检索效果 369 | ```bash 370 | # 测试不同检索模式 371 | /lmem search_mode hybrid 372 | /lmem search "测试查询" 5 373 | 374 | # 测试融合策略 375 | /lmem fusion show 376 | /lmem test_fusion "测试查询" 5 377 | ``` 378 | 379 | ### 性能监控 380 | ```bash 381 | # 查看记忆库状态 382 | /lmem status 383 | 384 | # 检查会话数量 385 | /lmem config show | grep 会话 386 | ``` -------------------------------------------------------------------------------- /metadata.yaml: -------------------------------------------------------------------------------- 1 | name: astrbot_plugin_livingmemory 2 | author: lxfight 3 | version: 1.0.1 4 | description: 为 AstrBot 打造的、拥有完整记忆生命周期的智能长期记忆插件。 5 | repo: https://github.com/lxfight-s-Astrbot-Plugins/astrbot_plugin_livingmemory -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | faiss-cpu>=1.7.0 2 | networkx>=3.0 3 | jieba>=0.42.1 -------------------------------------------------------------------------------- /storage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxfight-s-Astrbot-Plugins/astrbot_plugin_livingmemory/a382b10a91a41b68abc850e5b1abe05052e372f0/storage/__init__.py -------------------------------------------------------------------------------- /storage/faiss_manager_v2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import uuid 3 | import json 4 | import asyncio 5 | from datetime import datetime, timezone 6 | from typing import List, Optional 7 | import aiosqlite 8 | 9 | from astrbot.api import logger 10 | 11 | from ..core.models.memory_models import ( 12 | Memory, 13 | AccessInfo, # noqa: F401 14 | UserFeedback, # noqa: F401 15 | EmotionalValence, # noqa: F401 16 | ) 17 | from .memory_storage import MemoryStorage 18 | from .vector_store import VectorStore 19 | from .graph_storage import GraphStorageSQLite 20 | 21 | 22 | class FaissManagerV2: 23 | """ 24 | 高级管理器,协调 MemoryStorage 和 VectorStore, 25 | 以支持结构化的、具有生命周期的记忆。 26 | """ 27 | 28 | def __init__( 29 | self, 30 | db_path: str, 31 | text_vstore: VectorStore, 32 | image_vstore: VectorStore, 33 | embedding_model, 34 | ): 35 | self.db_path = db_path 36 | self.conn: Optional[aiosqlite.Connection] = None 37 | self.storage: 
Optional[MemoryStorage] = None
38 |         self.graph_storage: Optional[GraphStorageSQLite] = None
39 | 
40 |         self.text_vstore = text_vstore
41 |         self.image_vstore = image_vstore
42 |         self.embedding_model = embedding_model
43 | 
44 |     async def initialize(self):
45 |         """
46 |         建立数据库连接,并初始化所有依赖于此连接的存储组件。
47 |         """
48 |         self.conn = await aiosqlite.connect(self.db_path)
49 |         self.conn.row_factory = aiosqlite.Row  # MemoryStorage 以 dict(row) 读取查询结果,必须设置 Row 工厂
50 |         await self.conn.execute("PRAGMA foreign_keys = ON;")  # 启用外键约束,这对于 ON DELETE CASCADE 至关重要
51 | 
52 |         # 初始化各个存储层
53 |         self.storage = MemoryStorage(self.conn)  # 传入连接对象
54 |         await self.storage.initialize_schema()  # 创建 memories 表
55 | 
56 |         self.graph_storage = GraphStorageSQLite(self.conn)  # 传入同一个连接对象
57 |         await self.graph_storage.initialize_schema()  # 创建图相关的表
58 | 
59 |     async def close(self):
60 |         if self.conn:
61 |             await self.conn.close()
62 | 
63 |     async def add_memory(self, memory: Memory) -> str:
64 |         """
65 |         添加一条新的、结构化的记忆。
66 |         """
67 |         if not memory.memory_id:
68 |             memory.memory_id = str(uuid.uuid4())
69 | 
70 |         if not memory.embedding:
71 |             memory.embedding = await asyncio.to_thread(self.embedding_model.encode, memory.description)
72 | 
73 |         # 1. 存入 SQLite 并获取内部 ID
74 |         internal_id = await self.storage.add_memory(memory)
75 | 
76 |         # 2. 将文本向量添加到 text_vstore(VectorStore 的方法本身是协程,直接 await,不能再包一层 to_thread)
77 |         if memory.embedding:
78 |             await self.text_vstore.add([internal_id], [memory.embedding])
79 | 
80 |         # 3. 将图像向量添加到 image_vstore
81 |         media_embeddings = [
82 |             media.embedding for media in memory.linked_media if media.embedding
83 |         ]
84 |         if media_embeddings:
85 |             media_ids = [internal_id] * len(media_embeddings)
86 |             await self.image_vstore.add(media_ids, media_embeddings)
87 | 
88 |         # 4. 将图数据添加到 graph_storage
89 |         if memory.knowledge_graph_payload:
90 |             await self.graph_storage.add_memory_graph(
91 |                 internal_id, memory.knowledge_graph_payload.to_dict()
92 |             )
93 | 
94 |         # TODO 考虑定期保存索引,而不是每次都保存
95 |         await self.text_vstore.save_index()
96 |         await self.image_vstore.save_index()
97 | 
98 |         return memory.memory_id
99 | 
100 |     async def search_memory(
101 |         self, query_text: str, k: int = 10, w1: float = 0.6, w2: float = 0.4
102 |     ) -> List[Memory]:
103 |         """
104 |         根据查询文本智能检索最相关的记忆。
105 |         """
106 |         # 1a. 向量搜索
107 |         query_embedding = await asyncio.to_thread(self.embedding_model.encode, query_text)
108 |         # 召回数量可以设置得比最终需要的 k 要大,例如 k*5
109 |         distances, text_ids = await self.text_vstore.search(query_embedding, k * 5)
110 | 
111 |         # 1b. 基于图的种子扩展
112 |         # 假设 embedding_model 有提取实体的能力
113 |         query_entities = await asyncio.to_thread(self.embedding_model.extract_entities, query_text)
114 |         graph_ids = []
115 |         if query_entities:
116 |             for entity_id in query_entities:
117 |                 graph_ids.extend(
118 |                     await self.graph_storage.find_related_memory_ids(entity_id)
119 |                 )
120 | 
121 |         # 合并候选并去重(Faiss 召回不足时会以 -1 填充,这里一并过滤)
122 |         candidate_internal_ids = list({int(i) for i in text_ids if i != -1} | set(graph_ids))
123 |         if not candidate_internal_ids:
124 |             return []
125 | 
126 |         # --- 阶段二:重排与扩展 ---
127 | 
128 |         # 2a. 获取候选记忆的完整数据
129 |         candidate_docs = await self.storage.get_memories_by_internal_ids(
130 |             candidate_internal_ids
131 |         )
132 |         candidate_memories = {
133 |             doc["id"]: Memory.from_dict(json.loads(doc["memory_data"]))
134 |             for doc in candidate_docs
135 |         }
136 | 
137 |         # 2b. 计算最终分数并重排
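        # 打分方式:score = w1 * (1 / 向量召回排名) + w2 * (是否被图扩展命中)。
        # 一个极简的数值推演(仅作说明,非实际运行结果):取默认 w1=0.6、w2=0.4,
        #   向量召回第 2 名且被图命中的候选得 0.6 * 0.5 + 0.4 * 1.0 = 0.70;
        #   向量第 1 名但未被图命中的候选得 0.6 * 1.0 = 0.60。
        # 因此"语义相近且图相关"的记忆会被排到更靠前的位置。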
138 |         final_scores = {}
139 |         text_ids_list = list(text_ids)
140 |         for internal_id in candidate_internal_ids:
141 |             # Faiss 分数:可以用排名的倒数来表示,排名越靠前分数越高
142 |             try:
143 |                 faiss_score = 1.0 / (text_ids_list.index(internal_id) + 1)
144 |             except ValueError:
145 |                 faiss_score = 0.0
146 | 
147 |             # 图分数:如果记忆在图扩展的结果中,则获得加分
148 |             graph_score = 1.0 if internal_id in graph_ids else 0.0
149 | 
150 |             final_scores[internal_id] = (w1 * faiss_score) + (w2 * graph_score)
151 | 
152 |         # 根据最终分数降序排序
153 |         sorted_ids = sorted(
154 |             final_scores.keys(), key=lambda id: final_scores[id], reverse=True
155 |         )
156 | 
157 |         # 获取Top-K结果,并更新访问信息
158 |         top_k_results = [
159 |             candidate_memories[id] for id in sorted_ids[:k] if id in candidate_memories
160 |         ]
161 |         await self.update_memory_access_info([mem.memory_id for mem in top_k_results])
162 | 
163 |         return top_k_results
164 | 
165 |     async def update_memory_access_info(self, memory_ids: List[str]):
166 |         """
167 |         批量更新一组记忆的最后访问时间和访问计数。
168 |         """
169 |         docs = await self.storage.get_memories_by_memory_ids(memory_ids)
170 |         if not docs:
171 |             return
172 | 
173 |         updates = []
174 |         for doc in docs:
175 |             try:
176 |                 memory_dict = json.loads(doc["memory_data"])
177 |                 access_info = memory_dict["metadata"]["access_info"]
178 |                 access_info["last_accessed_timestamp"] = datetime.now(
179 |                     timezone.utc
180 |                 ).isoformat()
181 |                 access_info["access_count"] = access_info.get("access_count", 0) + 1
182 | 
183 |                 updates.append(
184 |                     {
185 |                         "memory_id": memory_dict["memory_id"],
186 |                         "memory_data": json.dumps(memory_dict),
187 |                     }
188 |                 )
189 |             except (json.JSONDecodeError, KeyError) as e:
190 |                 logger.error(
191 |                     f"Error updating access info for memory_id {doc.get('memory_id')}: {e}"
192 |                 )
193 | 
194 |         if updates:
195 |             await self.storage.update_memories(updates)
196 | 
197 |     async def get_all_memories_for_forgetting(self) -> List[Memory]:
198 |         """获取所有记忆,用于遗忘代理的处理。"""
199 |         all_docs = await self.storage.get_all_memories()
200 |         return [Memory.from_dict(json.loads(doc["memory_data"])) for doc in all_docs]
201 | 
202 |     async def update_memories_metadata(self, memories: List[Memory]):
203 |         """批量更新记忆的完整对象。"""
204 |         if not memories:
205 |             return
206 |         updates = [
207 |             {"memory_id": mem.memory_id, "memory_data": json.dumps(mem.to_dict())}
208 |             for mem in memories
209 |         ]
210 |         await self.storage.update_memories(updates)
211 | 
212 |     async def delete_memories(self, memory_ids: List[str]):
213 |         """
214 |         批量删除记忆。由于设置了外键级联删除,图关系会被自动清理。
215 |         """
216 |         docs = await self.storage.get_memories_by_memory_ids(memory_ids)
217 |         if not docs:
218 |             return
219 | 
220 |         internal_ids = [doc["id"] for doc in docs]
221 | 
222 |         # 1. 从 Faiss 移除(直接 await 异步的 VectorStore 方法)
223 |         await self.text_vstore.remove(internal_ids)
224 |         await self.image_vstore.remove(internal_ids)
225 |         await self.text_vstore.save_index()
226 |         await self.image_vstore.save_index()
227 | 
228 |         # 2. 从 SQLite 的 memories 表移除
229 |         # 由于设置了 ON DELETE CASCADE,graph_edges 表中相关的数据会被自动删除
230 |         await self.storage.delete_memories_by_internal_ids(internal_ids)
231 | 
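    # 对比说明:上面的 delete_memories 是物理删除,向量、SQLite 行以及级联的
    # graph_edges 记录会一并消失;下面的 archive_memory 则是"软遗忘",
    # 只从向量索引移除并把 status 置为 'archived',SQLite 中的完整数据仍然保留。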
232 |     async def archive_memory(self, memory_id: str):
233 |         """实现遗忘逻辑的第一步:归档。"""
234 |         # 1. 获取记忆的 internal_id
235 |         docs = await self.storage.get_memories_by_memory_ids([memory_id])
236 |         if not docs:
237 |             return
238 |         internal_id = docs[0]["id"]
239 | 
240 |         # 2. 从所有 Faiss 索引中移除
241 |         await self.text_vstore.remove([internal_id])
242 |         await self.image_vstore.remove([internal_id])
243 |         # (可选) 也可以从图数据库中移除以节省空间
244 |         # await self.graph_storage.delete_graph_for_memory(internal_id)
245 | 
246 |         # 3. 在 SQLite 中更新状态
247 |         await self.storage.update_memory_status([internal_id], "archived")
248 | 
249 |         logger.info(f"已归档记忆 {memory_id}。")
--------------------------------------------------------------------------------
/storage/graph_storage.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import aiosqlite
3 | from typing import List, Dict, Any
4 | 
5 | 
6 | class GraphStorageSQLite:
7 |     """
8 |     使用 SQLite 关系表来管理知识图谱数据。
9 |     """
10 | 
11 |     def __init__(self, connection: aiosqlite.Connection):
12 |         """
13 |         直接接收一个已建立的 aiosqlite 连接,以确保所有操作都在同一事务空间内。
14 |         """
15 |         self.conn = connection
16 | 
17 |     async def initialize_schema(self):
18 |         """在数据库中创建图相关的表和索引。"""
19 |         await self.conn.execute("""
20 |             CREATE TABLE IF NOT EXISTS graph_nodes (
21 |                 entity_id TEXT PRIMARY KEY NOT NULL,
22 |                 name TEXT NOT NULL,
23 |                 type TEXT NOT NULL
24 |             )
25 |         """)
26 |         await self.conn.execute("""
27 |             CREATE TABLE IF NOT EXISTS graph_edges (
28 |                 source_id TEXT NOT NULL,
29 |                 target_id TEXT NOT NULL,
30 |                 relation_type TEXT NOT NULL,
31 |                 memory_internal_id INTEGER NOT NULL,
32 |                 FOREIGN KEY (memory_internal_id) REFERENCES memories(id) ON DELETE CASCADE
33 |             )
34 |         """)
35 |         await self.conn.execute(
36 |             "CREATE INDEX IF NOT EXISTS idx_edge_source ON graph_edges (source_id)"
37 |         )
38 |         await self.conn.execute(
39 |             "CREATE INDEX IF NOT EXISTS idx_edge_target ON graph_edges (target_id)"
40 |         )
41 |         await self.conn.commit()
42 | 
43 |     async def add_memory_graph(self, internal_id: int, payload: Dict[str, Any]):
44 |         """从 knowledge_graph_payload 创建节点和边。"""
45 |         # 使用事务确保要么全部成功,要么全部失败
46 |         async with self.conn.cursor() as cursor:
47 |             # 1. 添加或更新节点 (INSERT OR IGNORE 避免重复)
48 |             nodes_to_insert = []
49 |             if "event_entity" in payload and payload["event_entity"]:
50 |                 nodes_to_insert.append(
51 |                     (
52 |                         payload["event_entity"]["event_id"],
53 |                         payload["event_entity"].get(
54 |                             "event_type", "Event"
55 |                         ),  # name can be event_type
56 |                         "Event",
57 |                     )
58 |                 )
59 |             for entity in payload.get("entities", []):
60 |                 nodes_to_insert.append(
61 |                     (entity["entity_id"], entity["name"], entity["type"])
62 |                 )
63 | 
64 |             if nodes_to_insert:
65 |                 await cursor.executemany(
66 |                     "INSERT OR IGNORE INTO graph_nodes (entity_id, name, type) VALUES (?, ?, ?)",
67 |                     nodes_to_insert,
68 |                 )
69 | 
70 |             # 2. 
添加关系边 71 | edges_to_insert = [] 72 | for rel in payload.get("relationships", []): 73 | if len(rel) == 3: 74 | edges_to_insert.append((rel[0], rel[1], rel[2], internal_id)) 75 | 76 | if edges_to_insert: 77 | await cursor.executemany( 78 | "INSERT INTO graph_edges (source_id, relation_type, target_id, memory_internal_id) VALUES (?, ?, ?, ?)", 79 | edges_to_insert, 80 | ) 81 | await self.conn.commit() 82 | 83 | async def find_related_memory_ids( 84 | self, entity_id: str, max_depth: int = 2 85 | ) -> List[int]: 86 | """ 87 | 使用 SQL 递归查询 (Recursive CTE) 从一个实体出发,查找相关联的记忆 internal_id。 88 | """ 89 | query = """ 90 | WITH RECURSIVE graph_walk(node_id, depth) AS ( 91 | -- 递归的起始点 92 | VALUES(:entity_id, 0) 93 | UNION 94 | -- 递归步骤: 从当前节点找到所有相邻节点 95 | SELECT g.target_id, w.depth + 1 96 | FROM graph_edges g JOIN graph_walk w ON g.source_id = w.node_id 97 | WHERE w.depth < :max_depth 98 | UNION 99 | SELECT g.source_id, w.depth + 1 100 | FROM graph_edges g JOIN graph_walk w ON g.target_id = w.node_id 101 | WHERE w.depth < :max_depth 102 | ) 103 | -- 从所有遍历到的节点中,查找它们作为关系边出现时关联的记忆ID 104 | SELECT DISTINCT T.memory_internal_id 105 | FROM graph_edges T 106 | JOIN graph_walk W ON T.source_id = W.node_id OR T.target_id = W.node_id; 107 | """ 108 | async with self.conn.execute( 109 | query, {"entity_id": entity_id, "max_depth": max_depth} 110 | ) as cursor: 111 | rows = await cursor.fetchall() 112 | return [row[0] for row in rows] 113 | 114 | async def add_correction_link( 115 | self, new_event_id: str, old_event_id: str, memory_internal_id: int 116 | ): 117 | """为记忆更新添加 CORRECTS 关系。""" 118 | await self.conn.execute( 119 | "INSERT INTO graph_edges (source_id, relation_type, target_id, memory_internal_id) VALUES (?, 'CORRECTS', ?, ?)", 120 | (new_event_id, old_event_id, memory_internal_id), 121 | ) 122 | await self.conn.commit() 123 | -------------------------------------------------------------------------------- /storage/memory_storage.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import json 4 | import aiosqlite 5 | from typing import List, Dict, Any 6 | 7 | from ..core.models.memory_models import Memory 8 | 9 | 10 | class MemoryStorage: 11 | """ 12 | 用于在 SQLite 中持久化、检索和管理结构化 Memory 对象的类。 13 | """ 14 | 15 | def __init__(self, connection: aiosqlite.Connection): 16 | """ 17 | 修正: 接收一个已建立的 aiosqlite 连接 18 | """ 19 | self.connection = connection 20 | 21 | async def initialize_schema(self): 22 | """ 23 | 建立数据库连接并创建表 24 | """ 25 | # 修正: 创建包含所有需要字段的表 26 | await self.connection.execute(""" 27 | CREATE TABLE IF NOT EXISTS memories ( 28 | id INTEGER PRIMARY KEY AUTOINCREMENT, 29 | memory_id TEXT NOT NULL UNIQUE, 30 | timestamp TEXT NOT NULL, 31 | memory_type TEXT NOT NULL, 32 | importance_score REAL NOT NULL, 33 | status TEXT NOT NULL DEFAULT 'active', 34 | community_id TEXT, -- 为社区发现预留字段 35 | memory_data TEXT NOT NULL 36 | ) 37 | """) 38 | await self.connection.execute(""" 39 | CREATE UNIQUE INDEX IF NOT EXISTS idx_memory_id ON memories (memory_id); 40 | """) 41 | await self.connection.execute(""" 42 | CREATE INDEX IF NOT EXISTS idx_memory_status ON memories (status); 43 | """) 44 | await self.connection.execute(""" 45 | CREATE INDEX IF NOT EXISTS idx_memory_community ON memories (community_id); 46 | """) 47 | await self.connection.commit() 48 | 49 | # Note: 不提供 close() 方法,因为这个类不负责连接的生命周期管理 50 | # 连接的创建和关闭应由更高层的组件(如 FaissManagerV2)负责 51 | 52 | async def add_memory(self, memory: Memory) -> int: 53 | """ 54 | 将一个 Memory 对象添加到数据库,并返回其内部自增 ID。 55 | 
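        返回的自增 id 是记忆在各存储层之间的关联键:上层 FaissManagerV2
        会用它调用向量索引的 add_with_ids,并将其作为
        graph_edges.memory_internal_id 的外键来源。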
""" 56 | memory_json = json.dumps(memory.to_dict(), ensure_ascii=False) 57 | status = "active" # 默认状态 58 | 59 | cursor = await self.connection.execute( 60 | """ 61 | INSERT INTO memories (memory_id, timestamp, memory_type, importance_score, status, memory_data) 62 | VALUES (?, ?, ?, ?, ?, ?) 63 | """, 64 | ( 65 | memory.memory_id, 66 | memory.timestamp, 67 | memory.metadata.memory_type, 68 | memory.metadata.importance_score, 69 | status, 70 | memory_json, 71 | ), 72 | ) 73 | await self.connection.commit() 74 | return cursor.lastrowid 75 | 76 | async def get_memories_by_internal_ids( 77 | self, internal_ids: List[int] 78 | ) -> List[Dict[str, Any]]: 79 | """ 80 | 通过内部自增 ID 列表获取记忆数据。 81 | """ 82 | if not internal_ids: 83 | return [] 84 | placeholders = ",".join("?" for _ in internal_ids) 85 | sql = f"SELECT id, memory_id, memory_data FROM memories WHERE id IN ({placeholders})" 86 | async with self.connection.execute(sql, internal_ids) as cursor: 87 | rows = await cursor.fetchall() 88 | return [dict(row) for row in rows] 89 | 90 | async def get_memories_by_memory_ids( 91 | self, memory_ids: List[str] 92 | ) -> List[Dict[str, Any]]: 93 | """ 94 | 通过全局 memory_id (UUID) 列表获取记忆数据。 95 | """ 96 | if not memory_ids: 97 | return [] 98 | placeholders = ",".join("?" for _ in memory_ids) 99 | sql = f"SELECT id, memory_id, memory_data FROM memories WHERE memory_id IN ({placeholders})" 100 | async with self.connection.execute(sql, memory_ids) as cursor: 101 | rows = await cursor.fetchall() 102 | return [dict(row) for row in rows] 103 | 104 | async def get_all_memories(self) -> List[Dict[str, Any]]: 105 | """ 106 | 获取数据库中所有的记忆。 107 | """ 108 | async with self.connection.execute( 109 | "SELECT id, memory_id, memory_data FROM memories" 110 | ) as cursor: 111 | rows = await cursor.fetchall() 112 | return [dict(row) for row in rows] 113 | 114 | async def update_memories(self, memories_to_update: List[Dict[str, Any]]): 115 | """ 116 | 根据 memory_id 批量更新 memory_data。 117 | """ 118 | if not memories_to_update: 119 | return 120 | updates = [(mem["memory_data"], mem["memory_id"]) for mem in memories_to_update] 121 | await self.connection.executemany( 122 | "UPDATE memories SET memory_data = ? WHERE memory_id = ?", updates 123 | ) 124 | await self.connection.commit() 125 | 126 | async def delete_memories_by_internal_ids(self, internal_ids: List[int]): 127 | """ 128 | 根据内部 ID 列表删除记忆。 129 | """ 130 | if not internal_ids: 131 | return 132 | placeholders = ",".join("?" for _ in internal_ids) 133 | await self.connection.execute( 134 | f"DELETE FROM memories WHERE id IN ({placeholders})", internal_ids 135 | ) 136 | await self.connection.commit() 137 | 138 | async def update_memory_status(self, internal_ids: List[int], new_status: str): 139 | """ 140 | 批量更新记忆的状态 (例如, 改为 'archived') 141 | """ 142 | if not internal_ids: 143 | return 144 | placeholders = ",".join("?" for _ in internal_ids) 145 | # 注意参数绑定的方式,new_status 在前 146 | await self.connection.execute( 147 | f"UPDATE memories SET status = ? 
WHERE id IN ({placeholders})", 148 | [new_status] + internal_ids, 149 | ) 150 | await self.connection.commit() 151 | -------------------------------------------------------------------------------- /storage/vector_store.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import asyncio 4 | import faiss 5 | import numpy as np 6 | from typing import List, Tuple, Optional 7 | 8 | from astrbot.api import logger 9 | 10 | 11 | class VectorStore: 12 | """ 13 | 专门管理 Faiss 索引,处理向量的增、删、查。 14 | """ 15 | 16 | def __init__(self, index_path: str, dimension: int): 17 | self.index_path = index_path 18 | self.dimension = dimension 19 | self.index: Optional[faiss.Index] = None 20 | # 注意:_load_index 现在是异步的,需要在初始化后手动调用 21 | # 或者创建一个异步的初始化方法 22 | 23 | async def _load_index(self): 24 | """ 25 | 从文件加载 Faiss 索引,如果不存在则创建一个新的。 26 | """ 27 | if os.path.exists(self.index_path): 28 | logger.info(f"Loading Faiss index from {self.index_path}") 29 | self.index = await asyncio.to_thread(faiss.read_index, self.index_path) 30 | else: 31 | logger.info(f"Creating new Faiss index. Dimension: {self.dimension}") 32 | # 使用 IndexFlatL2 作为基础索引,这是常用的欧氏距离索引 33 | base_index = faiss.IndexFlatL2(self.dimension) 34 | # 使用 IndexIDMap2 将我们的自定义整数 ID 映射到向量 35 | self.index = faiss.IndexIDMap2(base_index) 36 | 37 | async def save_index(self): 38 | """ 39 | 将当前索引状态保存到文件。 40 | """ 41 | logger.info(f"Saving Faiss index to {self.index_path}") 42 | await asyncio.to_thread(faiss.write_index, self.index, self.index_path) 43 | 44 | async def add(self, ids: List[int], embeddings: List[List[float]]): 45 | """ 46 | 将带有自定义 ID 的向量添加到索引中。 47 | """ 48 | if not ids: 49 | return 50 | # Faiss 需要 int64 类型的 ID 和 float32 类型的向量 51 | ids_np = np.array(ids, dtype=np.int64) 52 | embeddings_np = np.array(embeddings, dtype=np.float32) 53 | await asyncio.to_thread(self.index.add_with_ids, embeddings_np, ids_np) 54 | 55 | async def search( 56 | self, query_embedding: List[float], k: int 57 | ) -> Tuple[np.ndarray, np.ndarray]: 58 | """ 59 | 在索引中搜索最相似的 k 个向量。 60 | 返回 (距离数组, ID数组)。 61 | """ 62 | query_np = np.array([query_embedding], dtype=np.float32) 63 | # self.index.ntotal 是索引中的向量总数 64 | k = min(k, self.index.ntotal) 65 | if k == 0: 66 | return np.array([]), np.array([]) 67 | distances, ids = await asyncio.to_thread(self.index.search, query_np, k) 68 | return distances[0], ids[0] 69 | 70 | async def remove(self, ids_to_remove: List[int]): 71 | """ 72 | 从索引中移除指定 ID 的向量。 73 | """ 74 | if not ids_to_remove: 75 | return 76 | ids_np = np.array(ids_to_remove, dtype=np.int64) 77 | await asyncio.to_thread(self.index.remove_ids, ids_np) 78 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | 测试模块配置文件 4 | """ 5 | 6 | import sys 7 | import os 8 | from pathlib import Path 9 | 10 | # 添加项目根目录到Python路径 11 | project_root = Path(__file__).parent.parent 12 | sys.path.insert(0, str(project_root)) 13 | 14 | # 测试配置 15 | TEST_CONFIG = { 16 | "session_manager": { 17 | "max_sessions": 100, 18 | "session_ttl": 3600 19 | }, 20 | "recall_engine": { 21 | "retrieval_mode": "hybrid", 22 | "top_k": 5, 23 | "recall_strategy": "weighted" 24 | }, 25 | "reflection_engine": { 26 | "summary_trigger_rounds": 10, 27 | "importance_threshold": 0.5 28 | }, 29 | "forgetting_agent": { 30 | "enabled": True, 31 | "check_interval_hours": 24, 32 | "retention_days": 90 33 | 
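        # 说明:这里只列出测试会断言的键;其余遗忘参数(如 importance_decay_rate)
        # 预期由插件侧默认值补齐(假设,默认值参见 docs/CONFIG.md)。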
}, 34 | "timezone_settings": { 35 | "timezone": "Asia/Shanghai" 36 | }, 37 | "fusion": { 38 | "strategy": "rrf", 39 | "rrf_k": 60, 40 | "dense_weight": 0.7, 41 | "sparse_weight": 0.3 42 | } 43 | } -------------------------------------------------------------------------------- /tests/unit/test_admin_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_admin_handler.py - 管理员处理器测试 4 | """ 5 | 6 | import pytest 7 | from unittest.mock import Mock, AsyncMock, patch 8 | 9 | from core.handlers.admin_handler import AdminHandler 10 | from tests.conftest import TEST_CONFIG 11 | 12 | 13 | class TestAdminHandler: 14 | """管理员处理器测试类""" 15 | 16 | def setup_method(self): 17 | """测试前设置""" 18 | self.mock_context = Mock() 19 | self.mock_faiss_manager = Mock() 20 | self.mock_forgetting_agent = Mock() 21 | self.mock_session_manager = Mock() 22 | self.handler = AdminHandler( 23 | self.mock_context, 24 | TEST_CONFIG, 25 | self.mock_faiss_manager, 26 | self.mock_forgetting_agent, 27 | self.mock_session_manager 28 | ) 29 | 30 | @pytest.mark.asyncio 31 | async def test_get_memory_status_success(self): 32 | """测试获取记忆库状态(成功)""" 33 | # 模拟数据库计数 34 | self.mock_faiss_manager.db.count_documents = AsyncMock(return_value=42) 35 | 36 | result = await self.handler.get_memory_status() 37 | 38 | assert result["success"] is True 39 | assert result["data"]["total_count"] == 42 40 | 41 | # 验证调用 42 | self.mock_faiss_manager.db.count_documents.assert_called_once() 43 | 44 | @pytest.mark.asyncio 45 | async def test_get_memory_status_no_manager(self): 46 | """测试没有管理器时获取状态""" 47 | handler = AdminHandler(self.mock_context, TEST_CONFIG, None, None, None) 48 | 49 | result = await handler.get_memory_status() 50 | 51 | assert result["success"] is False 52 | assert "记忆库尚未初始化" in result["message"] 53 | 54 | @pytest.mark.asyncio 55 | async def test_get_memory_status_exception(self): 56 | """测试获取记忆库状态异常处理""" 57 | self.mock_faiss_manager.db.count_documents = AsyncMock(side_effect=Exception("数据库错误")) 58 | 59 | result = await self.handler.get_memory_status() 60 | 61 | assert result["success"] is False 62 | assert "获取记忆库状态失败" in result["message"] 63 | 64 | @pytest.mark.asyncio 65 | async def test_delete_memory_success(self): 66 | """测试删除记忆(成功)""" 67 | self.mock_faiss_manager.delete_memories = AsyncMock() 68 | 69 | result = await self.handler.delete_memory(123) 70 | 71 | assert result["success"] is True 72 | assert "已成功删除 ID 为 123 的记忆" in result["message"] 73 | 74 | # 验证调用参数 75 | self.mock_faiss_manager.delete_memories.assert_called_once_with([123]) 76 | 77 | @pytest.mark.asyncio 78 | async def test_delete_memory_no_manager(self): 79 | """测试没有管理器时删除记忆""" 80 | handler = AdminHandler(self.mock_context, TEST_CONFIG, None, None, None) 81 | 82 | result = await handler.delete_memory(123) 83 | 84 | assert result["success"] is False 85 | assert "记忆库尚未初始化" in result["message"] 86 | 87 | @pytest.mark.asyncio 88 | async def test_delete_memory_exception(self): 89 | """测试删除记忆异常处理""" 90 | self.mock_faiss_manager.delete_memories = AsyncMock(side_effect=Exception("删除错误")) 91 | 92 | result = await self.handler.delete_memory(123) 93 | 94 | assert result["success"] is False 95 | assert "删除记忆时发生错误" in result["message"] 96 | 97 | @pytest.mark.asyncio 98 | async def test_run_forgetting_agent_success(self): 99 | """测试运行遗忘代理(成功)""" 100 | self.mock_forgetting_agent._prune_memories = AsyncMock() 101 | 102 | result = await self.handler.run_forgetting_agent() 103 | 104 | assert 
result["success"] is True 105 | assert "遗忘代理任务执行完毕" in result["message"] 106 | 107 | # 验证调用 108 | self.mock_forgetting_agent._prune_memories.assert_called_once() 109 | 110 | @pytest.mark.asyncio 111 | async def test_run_forgetting_agent_no_agent(self): 112 | """测试没有遗忘代理时运行""" 113 | handler = AdminHandler(self.mock_context, TEST_CONFIG, None, None, None) 114 | 115 | result = await handler.run_forgetting_agent() 116 | 117 | assert result["success"] is False 118 | assert "遗忘代理尚未初始化" in result["message"] 119 | 120 | @pytest.mark.asyncio 121 | async def test_run_forgetting_agent_exception(self): 122 | """测试运行遗忘代理异常处理""" 123 | self.mock_forgetting_agent._prune_memories = AsyncMock(side_effect=Exception("遗忘代理错误")) 124 | 125 | result = await self.handler.run_forgetting_agent() 126 | 127 | assert result["success"] is False 128 | assert "遗忘代理任务执行失败" in result["message"] 129 | 130 | @pytest.mark.asyncio 131 | async def test_set_search_mode_valid(self): 132 | """测试设置搜索模式(有效模式)""" 133 | result = await self.handler.set_search_mode("hybrid") 134 | 135 | assert result["success"] is True 136 | assert "检索模式已设置为: hybrid" in result["message"] 137 | 138 | @pytest.mark.asyncio 139 | async def test_set_search_mode_invalid(self): 140 | """测试设置搜索模式(无效模式)""" 141 | result = await self.handler.set_search_mode("invalid_mode") 142 | 143 | assert result["success"] is False 144 | assert "无效的模式" in result["message"] 145 | assert "hybrid, dense, sparse" in result["message"] 146 | 147 | @pytest.mark.asyncio 148 | async def test_get_config_summary_show(self): 149 | """测试获取配置摘要(显示)""" 150 | # 模拟会话管理器 151 | self.mock_session_manager.get_session_count = Mock(return_value=5) 152 | 153 | result = await self.handler.get_config_summary("show") 154 | 155 | assert result["success"] is True 156 | data = result["data"] 157 | 158 | # 验证各个配置部分 159 | assert "session_manager" in data 160 | assert "recall_engine" in data 161 | assert "reflection_engine" in data 162 | assert "forgetting_agent" in data 163 | 164 | # 验证具体配置值 165 | assert data["session_manager"]["max_sessions"] == 100 166 | assert data["session_manager"]["session_ttl"] == 3600 167 | assert data["session_manager"]["current_sessions"] == 5 168 | assert data["recall_engine"]["retrieval_mode"] == "hybrid" 169 | assert data["recall_engine"]["top_k"] == 5 170 | assert data["forgetting_agent"]["enabled"] is True 171 | 172 | # 验证调用 173 | self.mock_session_manager.get_session_count.assert_called_once() 174 | 175 | @pytest.mark.asyncio 176 | async def test_get_config_summary_validate_success(self): 177 | """测试获取配置摘要(验证成功)""" 178 | with patch('core.config_validator.validate_config') as mock_validate: 179 | mock_validate.return_value = None # 验证通过时不返回异常 180 | 181 | result = await self.handler.get_config_summary("validate") 182 | 183 | assert result["success"] is True 184 | assert "配置验证通过,所有参数均有效" in result["message"] 185 | 186 | # 验证调用 187 | mock_validate.assert_called_once_with(TEST_CONFIG) 188 | 189 | @pytest.mark.asyncio 190 | async def test_get_config_summary_validate_failure(self): 191 | """测试获取配置摘要(验证失败)""" 192 | with patch('core.config_validator.validate_config') as mock_validate: 193 | mock_validate.side_effect = ValueError("配置验证失败") 194 | 195 | result = await self.handler.get_config_summary("validate") 196 | 197 | assert result["success"] is False 198 | assert "配置验证失败" in result["message"] 199 | 200 | # 验证调用 201 | mock_validate.assert_called_once_with(TEST_CONFIG) 202 | 203 | @pytest.mark.asyncio 204 | async def test_get_config_summary_invalid_action(self): 205 | 
"""测试获取配置摘要(无效动作)""" 206 | result = await self.handler.get_config_summary("invalid_action") 207 | 208 | assert result["success"] is False 209 | assert "无效的动作" in result["message"] 210 | assert "show" in result["message"] 211 | assert "validate" in result["message"] 212 | 213 | @pytest.mark.asyncio 214 | async def test_get_config_summary_show_exception(self): 215 | """测试获取配置摘要显示异常处理""" 216 | self.mock_session_manager.get_session_count = Mock(side_effect=Exception("配置错误")) 217 | 218 | result = await self.handler.get_config_summary("show") 219 | 220 | assert result["success"] is False 221 | assert "显示配置时发生错误" in result["message"] 222 | 223 | def test_format_status_for_display_success(self): 224 | """测试格式化状态显示(成功)""" 225 | mock_response = { 226 | "success": True, 227 | "data": {"total_count": 42} 228 | } 229 | 230 | result = self.handler.format_status_for_display(mock_response) 231 | 232 | assert "📊 LivingMemory 记忆库状态:" in result 233 | assert "- 总记忆数: 42" in result 234 | 235 | def test_format_status_for_display_failure(self): 236 | """测试格式化状态显示(失败)""" 237 | mock_response = { 238 | "success": False, 239 | "message": "获取失败" 240 | } 241 | 242 | result = self.handler.format_status_for_display(mock_response) 243 | 244 | assert result == "获取失败" 245 | 246 | def test_format_config_summary_for_display_success(self): 247 | """测试格式化配置摘要显示(成功)""" 248 | mock_response = { 249 | "success": True, 250 | "data": { 251 | "session_manager": { 252 | "max_sessions": 1000, 253 | "session_ttl": 3600, 254 | "current_sessions": 5 255 | }, 256 | "recall_engine": { 257 | "retrieval_mode": "hybrid", 258 | "top_k": 5, 259 | "recall_strategy": "weighted" 260 | }, 261 | "reflection_engine": { 262 | "summary_trigger_rounds": 10, 263 | "importance_threshold": 0.5 264 | }, 265 | "forgetting_agent": { 266 | "enabled": True, 267 | "check_interval_hours": 24, 268 | "retention_days": 90 269 | } 270 | } 271 | } 272 | 273 | result = self.handler.format_config_summary_for_display(mock_response) 274 | 275 | assert "📋 LivingMemory 配置摘要:" in result 276 | assert "🗂️ 会话管理:" in result 277 | assert "🧠 回忆引擎:" in result 278 | assert "💭 反思引擎:" in result 279 | assert "🗑️ 遗忘代理:" in result 280 | assert "最大会话数: 1000" in result 281 | assert "会话TTL: 3600秒" in result 282 | assert "当前会话数: 5" in result 283 | assert "检索模式: hybrid" in result 284 | assert "返回数量: 5" in result 285 | assert "启用状态: 是" in result 286 | assert "检查间隔: 24小时" in result 287 | assert "保留天数: 90天" in result 288 | 289 | def test_format_config_summary_for_display_failure(self): 290 | """测试格式化配置摘要显示(失败)""" 291 | mock_response = { 292 | "success": False, 293 | "message": "配置获取失败" 294 | } 295 | 296 | result = self.handler.format_config_summary_for_display(mock_response) 297 | 298 | assert result == "配置获取失败" 299 | 300 | def test_format_config_summary_for_display_missing_sections(self): 301 | """测试格式化配置摘要显示(缺少部分配置)""" 302 | mock_response = { 303 | "success": True, 304 | "data": { 305 | "session_manager": { 306 | "max_sessions": 1000, 307 | "session_ttl": 3600, 308 | "current_sessions": 5 309 | } 310 | # 缺少其他配置部分 311 | } 312 | } 313 | 314 | result = self.handler.format_config_summary_for_display(mock_response) 315 | 316 | assert "📋 LivingMemory 配置摘要:" in result 317 | assert "🗂️ 会话管理:" in result 318 | # 方法总是显示所有配置部分,使用默认值 319 | assert "🧠 回忆引擎:" in result 320 | assert "💭 反思引擎:" in result 321 | assert "🗑️ 遗忘代理:" in result -------------------------------------------------------------------------------- /tests/unit/test_base_handler.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_base_handler.py - 基础处理器测试 4 | """ 5 | 6 | import pytest 7 | import json 8 | from datetime import datetime, timezone 9 | from unittest.mock import Mock, patch 10 | 11 | from core.handlers.base_handler import TestableBaseHandler 12 | from tests.conftest import TEST_CONFIG 13 | 14 | 15 | class TestBaseHandler: 16 | """基础处理器测试类""" 17 | 18 | def setup_method(self): 19 | """测试前设置""" 20 | self.mock_context = Mock() 21 | self.handler = TestableBaseHandler(self.mock_context, TEST_CONFIG) 22 | 23 | def test_init(self): 24 | """测试初始化""" 25 | assert self.handler.context == self.mock_context 26 | assert self.handler.config == TEST_CONFIG 27 | 28 | def test_create_response_success(self): 29 | """测试创建成功响应""" 30 | response = self.handler.create_response(True, "成功消息", {"data": "test"}) 31 | 32 | assert response["success"] is True 33 | assert response["message"] == "成功消息" 34 | assert response["data"] == {"data": "test"} 35 | 36 | def test_create_response_failure(self): 37 | """测试创建失败响应""" 38 | response = self.handler.create_response(False, "失败消息") 39 | 40 | assert response["success"] is False 41 | assert response["message"] == "失败消息" 42 | assert response["data"] is None 43 | 44 | def test_safe_parse_metadata_dict(self): 45 | """测试解析字典类型元数据""" 46 | metadata = {"key": "value", "number": 123} 47 | result = self.handler.safe_parse_metadata(metadata) 48 | 49 | assert result == metadata 50 | 51 | def test_safe_parse_metadata_json_string(self): 52 | """测试解析JSON字符串元数据""" 53 | metadata_str = '{"key": "value", "number": 123}' 54 | result = self.handler.safe_parse_metadata(metadata_str) 55 | 56 | assert result == {"key": "value", "number": 123} 57 | 58 | def test_safe_parse_metadata_invalid_json(self): 59 | """测试解析无效JSON字符串""" 60 | invalid_json = "{invalid json}" 61 | result = self.handler.safe_parse_metadata(invalid_json) 62 | 63 | assert result == {} 64 | 65 | def test_safe_parse_metadata_none(self): 66 | """测试解析None值""" 67 | result = self.handler.safe_parse_metadata(None) 68 | 69 | assert result == {} 70 | 71 | def test_format_timestamp_valid(self): 72 | """测试格式化有效时间戳""" 73 | timestamp = 1609459200.0 # 2021-01-01 00:00:00 UTC 74 | result = self.handler.format_timestamp(timestamp) 75 | 76 | assert "2021-01-01" in result 77 | # 由于时区转换,时间可能不是00:00:00 78 | assert len(result) > 10 79 | 80 | def test_format_timestamp_none(self): 81 | """测试格式化None时间戳""" 82 | result = self.handler.format_timestamp(None) 83 | 84 | assert result == "未知" 85 | 86 | def test_format_timestamp_invalid(self): 87 | """测试格式化无效时间戳""" 88 | result = self.handler.format_timestamp("invalid") 89 | 90 | assert result == "未知" 91 | 92 | def test_get_timezone(self): 93 | """测试获取时区""" 94 | with patch('pytz.timezone') as mock_timezone: 95 | mock_tz = Mock() 96 | mock_timezone.return_value = mock_tz 97 | 98 | result = self.handler.get_timezone() 99 | 100 | assert result == mock_tz 101 | mock_timezone.assert_called_once_with("Asia/Shanghai") 102 | 103 | def test_format_memory_card(self): 104 | """测试格式化记忆卡片""" 105 | # 创建模拟的记忆结果 106 | mock_result = Mock() 107 | mock_result.data = { 108 | "id": 1, 109 | "text": "测试记忆内容", 110 | "metadata": json.dumps({ 111 | "create_time": 1609459200.0, 112 | "last_access_time": 1609459200.0, 113 | "importance": 0.8, 114 | "event_type": "FACT" 115 | }) 116 | } 117 | mock_result.similarity = 0.95 118 | 119 | result = self.handler.format_memory_card(mock_result) 120 | 121 | assert "ID: 1" in result 122 | assert 
"记 忆 度: 0.95" in result 123 | assert "重 要 性: 0.80" in result 124 | assert "记忆类型: FACT" in result 125 | assert "测试记忆内容" in result -------------------------------------------------------------------------------- /tests/unit/test_memory_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_memory_handler.py - 记忆管理处理器测试 4 | """ 5 | 6 | import pytest 7 | import json 8 | from unittest.mock import Mock, AsyncMock, patch 9 | 10 | from core.handlers.memory_handler import MemoryHandler 11 | from tests.conftest import TEST_CONFIG 12 | 13 | 14 | class TestMemoryHandler: 15 | """记忆管理处理器测试类""" 16 | 17 | def setup_method(self): 18 | """测试前设置""" 19 | self.mock_context = Mock() 20 | self.mock_faiss_manager = Mock() 21 | self.handler = MemoryHandler(self.mock_context, TEST_CONFIG, self.mock_faiss_manager) 22 | 23 | @pytest.mark.asyncio 24 | async def test_edit_memory_content(self): 25 | """测试编辑记忆内容""" 26 | # 模拟faiss_manager.update_memory的返回值 27 | mock_result = { 28 | "success": True, 29 | "message": "更新成功", 30 | "updated_fields": ["content"], 31 | "memory_id": 123 32 | } 33 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result) 34 | 35 | result = await self.handler.edit_memory("123", "content", "新的记忆内容", "测试更新") 36 | 37 | assert result["success"] is True 38 | assert "更新成功" in result["message"] 39 | 40 | # 验证调用参数 41 | self.mock_faiss_manager.update_memory.assert_called_once_with( 42 | memory_id=123, 43 | update_reason="测试更新", 44 | content="新的记忆内容" 45 | ) 46 | 47 | @pytest.mark.asyncio 48 | async def test_edit_memory_importance_valid(self): 49 | """测试编辑记忆重要性(有效值)""" 50 | mock_result = { 51 | "success": True, 52 | "message": "更新成功", 53 | "updated_fields": ["importance"] 54 | } 55 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result) 56 | 57 | result = await self.handler.edit_memory("123", "importance", "0.9", "提高重要性") 58 | 59 | assert result["success"] is True 60 | self.mock_faiss_manager.update_memory.assert_called_once_with( 61 | memory_id=123, 62 | update_reason="提高重要性", 63 | importance=0.9 64 | ) 65 | 66 | @pytest.mark.asyncio 67 | async def test_edit_memory_importance_invalid_range(self): 68 | """测试编辑记忆重要性(无效范围)""" 69 | result = await self.handler.edit_memory("123", "importance", "1.5", "无效值") 70 | 71 | assert result["success"] is False 72 | assert "重要性评分必须在 0.0 到 1.0 之间" in result["message"] 73 | 74 | # 验证没有调用update_memory 75 | self.mock_faiss_manager.update_memory.assert_not_called() 76 | 77 | @pytest.mark.asyncio 78 | async def test_edit_memory_importance_invalid_type(self): 79 | """测试编辑记忆重要性(无效类型)""" 80 | result = await self.handler.edit_memory("123", "importance", "invalid", "非数字") 81 | 82 | assert result["success"] is False 83 | assert "重要性评分必须是数字" in result["message"] 84 | 85 | # 验证没有调用update_memory 86 | self.mock_faiss_manager.update_memory.assert_not_called() 87 | 88 | @pytest.mark.asyncio 89 | async def test_edit_memory_type_valid(self): 90 | """测试编辑记忆类型(有效值)""" 91 | mock_result = { 92 | "success": True, 93 | "message": "更新成功", 94 | "updated_fields": ["event_type"] 95 | } 96 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result) 97 | 98 | result = await self.handler.edit_memory("123", "type", "PREFERENCE", "重新分类") 99 | 100 | assert result["success"] is True 101 | self.mock_faiss_manager.update_memory.assert_called_once_with( 102 | memory_id=123, 103 | update_reason="重新分类", 104 | event_type="PREFERENCE" 105 | ) 106 | 107 | @pytest.mark.asyncio 
108 | async def test_edit_memory_type_invalid(self): 109 | """测试编辑记忆类型(无效值)""" 110 | result = await self.handler.edit_memory("123", "type", "INVALID_TYPE", "无效类型") 111 | 112 | assert result["success"] is False 113 | assert "无效的事件类型" in result["message"] 114 | 115 | # 验证没有调用update_memory 116 | self.mock_faiss_manager.update_memory.assert_not_called() 117 | 118 | @pytest.mark.asyncio 119 | async def test_edit_memory_status_valid(self): 120 | """测试编辑记忆状态(有效值)""" 121 | mock_result = { 122 | "success": True, 123 | "message": "更新成功", 124 | "updated_fields": ["status"] 125 | } 126 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result) 127 | 128 | result = await self.handler.edit_memory("123", "status", "archived", "项目完成") 129 | 130 | assert result["success"] is True 131 | self.mock_faiss_manager.update_memory.assert_called_once_with( 132 | memory_id=123, 133 | update_reason="项目完成", 134 | status="archived" 135 | ) 136 | 137 | @pytest.mark.asyncio 138 | async def test_edit_memory_status_invalid(self): 139 | """测试编辑记忆状态(无效值)""" 140 | result = await self.handler.edit_memory("123", "status", "INVALID_STATUS", "无效状态") 141 | 142 | assert result["success"] is False 143 | assert "无效的状态" in result["message"] 144 | 145 | # 验证没有调用update_memory 146 | self.mock_faiss_manager.update_memory.assert_not_called() 147 | 148 | @pytest.mark.asyncio 149 | async def test_edit_memory_unknown_field(self): 150 | """测试编辑未知字段""" 151 | result = await self.handler.edit_memory("123", "unknown_field", "value", "未知字段") 152 | 153 | assert result["success"] is False 154 | assert "未知的字段" in result["message"] 155 | 156 | # 验证没有调用update_memory 157 | self.mock_faiss_manager.update_memory.assert_not_called() 158 | 159 | @pytest.mark.asyncio 160 | async def test_edit_memory_string_id(self): 161 | """测试使用字符串ID编辑记忆""" 162 | mock_result = { 163 | "success": True, 164 | "message": "更新成功", 165 | "updated_fields": ["content"] 166 | } 167 | self.mock_faiss_manager.update_memory = AsyncMock(return_value=mock_result) 168 | 169 | result = await self.handler.edit_memory("abc123", "content", "新内容", "字符串ID") 170 | 171 | assert result["success"] is True 172 | self.mock_faiss_manager.update_memory.assert_called_once_with( 173 | memory_id="abc123", 174 | update_reason="字符串ID", 175 | content="新内容" 176 | ) 177 | 178 | @pytest.mark.asyncio 179 | async def test_edit_memory_no_faiss_manager(self): 180 | """测试没有faiss_manager时的错误处理""" 181 | handler = MemoryHandler(self.mock_context, TEST_CONFIG, None) 182 | 183 | result = await handler.edit_memory("123", "content", "新内容") 184 | 185 | assert result["success"] is False 186 | assert "记忆库尚未初始化" in result["message"] 187 | 188 | @pytest.mark.asyncio 189 | async def test_edit_memory_exception(self): 190 | """测试编辑记忆时的异常处理""" 191 | self.mock_faiss_manager.update_memory = AsyncMock(side_effect=Exception("数据库错误")) 192 | 193 | result = await self.handler.edit_memory("123", "content", "新内容") 194 | 195 | assert result["success"] is False 196 | assert "编辑记忆时发生错误" in result["message"] 197 | 198 | @pytest.mark.asyncio 199 | async def test_get_memory_details_success(self): 200 | """测试获取记忆详细信息(成功)""" 201 | # 模拟数据库查询结果 202 | mock_docs = [{ 203 | "id": 123, 204 | "content": "测试记忆内容", 205 | "metadata": json.dumps({ 206 | "create_time": 1609459200.0, 207 | "last_access_time": 1609459200.0, 208 | "importance": 0.8, 209 | "event_type": "FACT", 210 | "status": "active" 211 | }) 212 | }] 213 | 214 | self.mock_faiss_manager.db.document_storage.get_documents = AsyncMock(return_value=mock_docs) 215 | 216 | result = 
await self.handler.get_memory_details("123") 217 | 218 | assert result["success"] is True 219 | data = result["data"] 220 | assert data["id"] == "123" 221 | assert data["content"] == "测试记忆内容" 222 | assert data["importance"] == 0.8 223 | assert data["event_type"] == "FACT" 224 | assert data["status"] == "active" 225 | 226 | @pytest.mark.asyncio 227 | async def test_get_memory_details_not_found(self): 228 | """测试获取不存在的记忆详细信息""" 229 | self.mock_faiss_manager.db.document_storage.get_documents = AsyncMock(return_value=[]) 230 | 231 | result = await self.handler.get_memory_details("999") 232 | 233 | assert result["success"] is False 234 | assert "未找到ID为 999 的记忆" in result["message"] 235 | 236 | @pytest.mark.asyncio 237 | async def test_get_memory_history_success(self): 238 | """测试获取记忆历史(成功)""" 239 | mock_docs = [{ 240 | "id": 123, 241 | "content": "测试记忆内容", 242 | "metadata": json.dumps({ 243 | "create_time": 1609459200.0, 244 | "importance": 0.8, 245 | "event_type": "FACT", 246 | "status": "active", 247 | "update_history": [ 248 | { 249 | "timestamp": 1609459200.0, 250 | "reason": "初始创建", 251 | "fields": ["content", "importance"] 252 | } 253 | ] 254 | }) 255 | }] 256 | 257 | self.mock_faiss_manager.db.document_storage.get_documents = AsyncMock(return_value=mock_docs) 258 | 259 | result = await self.handler.get_memory_history("123") 260 | 261 | assert result["success"] is True 262 | data = result["data"] 263 | assert len(data["update_history"]) == 1 264 | assert data["update_history"][0]["reason"] == "初始创建" 265 | 266 | def test_format_memory_details_for_display_success(self): 267 | """测试格式化记忆详细信息显示(成功)""" 268 | mock_response = { 269 | "success": True, 270 | "data": { 271 | "id": "123", 272 | "content": "测试记忆内容", 273 | "importance": 0.8, 274 | "event_type": "FACT", 275 | "status": "active", 276 | "create_time": "2021-01-01 00:00:00", 277 | "last_access_time": "2021-01-01 00:00:00", 278 | "update_history": [] 279 | } 280 | } 281 | 282 | result = self.handler.format_memory_details_for_display(mock_response) 283 | 284 | assert "📝 记忆 123 的详细信息:" in result 285 | assert "测试记忆内容" in result 286 | assert "重要性: 0.8" in result 287 | assert "类型: FACT" in result 288 | assert "状态: active" in result 289 | 290 | def test_format_memory_details_for_display_failure(self): 291 | """测试格式化记忆详细信息显示(失败)""" 292 | mock_response = { 293 | "success": False, 294 | "message": "获取失败" 295 | } 296 | 297 | result = self.handler.format_memory_details_for_display(mock_response) 298 | 299 | assert result == "获取失败" 300 | 301 | def test_format_memory_history_for_display_success(self): 302 | """测试格式化记忆历史显示(成功)""" 303 | mock_response = { 304 | "success": True, 305 | "data": { 306 | "id": "123", 307 | "content": "测试记忆内容", 308 | "metadata": { 309 | "importance": 0.8, 310 | "event_type": "FACT", 311 | "status": "active", 312 | "create_time": "2021-01-01 00:00:00" 313 | }, 314 | "update_history": [ 315 | { 316 | "timestamp": 1609459200.0, 317 | "reason": "初始创建", 318 | "fields": ["content"] 319 | } 320 | ] 321 | } 322 | } 323 | 324 | result = self.handler.format_memory_history_for_display(mock_response) 325 | 326 | assert "📝 记忆 123 的详细信息:" in result 327 | assert "测试记忆内容" in result 328 | assert "🔄 更新历史 (1 次):" in result 329 | assert "初始创建" in result 330 | 331 | def test_format_memory_history_for_display_no_history(self): 332 | """测试格式化记忆历史显示(无历史记录)""" 333 | mock_response = { 334 | "success": True, 335 | "data": { 336 | "id": "123", 337 | "content": "测试记忆内容", 338 | "metadata": { 339 | "importance": 0.8, 340 | "event_type": "FACT", 341 | 
"status": "active", 342 | "create_time": "2021-01-01 00:00:00" 343 | }, 344 | "update_history": [] 345 | } 346 | } 347 | 348 | result = self.handler.format_memory_history_for_display(mock_response) 349 | 350 | assert "🔄 暂无更新记录" in result -------------------------------------------------------------------------------- /tests/unit/test_search_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | test_search_handler.py - 搜索管理处理器测试 4 | """ 5 | 6 | import pytest 7 | from unittest.mock import Mock, AsyncMock 8 | 9 | from core.handlers.search_handler import SearchHandler 10 | from tests.conftest import TEST_CONFIG 11 | 12 | 13 | class TestSearchHandler: 14 | """搜索管理处理器测试类""" 15 | 16 | def setup_method(self): 17 | """测试前设置""" 18 | self.mock_context = Mock() 19 | self.mock_recall_engine = Mock() 20 | self.mock_sparse_retriever = Mock() 21 | self.handler = SearchHandler( 22 | self.mock_context, 23 | TEST_CONFIG, 24 | self.mock_recall_engine, 25 | self.mock_sparse_retriever 26 | ) 27 | 28 | @pytest.mark.asyncio 29 | async def test_search_memories_success(self): 30 | """测试搜索记忆(成功)""" 31 | # 模拟搜索结果 32 | mock_results = [ 33 | Mock( 34 | data={"id": 1, "text": "测试记忆1", "metadata": '{"importance": 0.8}'}, 35 | similarity=0.9 36 | ), 37 | Mock( 38 | data={"id": 2, "text": "测试记忆2", "metadata": '{"importance": 0.6}'}, 39 | similarity=0.7 40 | ) 41 | ] 42 | 43 | self.mock_recall_engine.recall = AsyncMock(return_value=mock_results) 44 | 45 | result = await self.handler.search_memories("测试查询", k=5) 46 | 47 | assert result["success"] is True 48 | assert "为您找到 2 条相关记忆" in result["message"] 49 | assert len(result["data"]) == 2 50 | assert result["data"][0]["id"] == 1 51 | assert result["data"][0]["similarity"] == 0.9 52 | assert result["data"][1]["id"] == 2 53 | assert result["data"][1]["similarity"] == 0.7 54 | 55 | # 验证调用参数 56 | self.mock_recall_engine.recall.assert_called_once_with( 57 | self.mock_context, "测试查询", k=5 58 | ) 59 | 60 | @pytest.mark.asyncio 61 | async def test_search_memories_no_results(self): 62 | """测试搜索记忆(无结果)""" 63 | self.mock_recall_engine.recall = AsyncMock(return_value=[]) 64 | 65 | result = await self.handler.search_memories("无结果查询", k=5) 66 | 67 | assert result["success"] is True 68 | assert "未能找到与 '无结果查询' 相关的记忆" in result["message"] 69 | assert result["data"] == [] 70 | 71 | @pytest.mark.asyncio 72 | async def test_search_memories_no_recall_engine(self): 73 | """测试没有回忆引擎时的搜索""" 74 | handler = SearchHandler(self.mock_context, TEST_CONFIG, None, None) 75 | 76 | result = await handler.search_memories("测试查询") 77 | 78 | assert result["success"] is False 79 | assert "回忆引擎尚未初始化" in result["message"] 80 | 81 | @pytest.mark.asyncio 82 | async def test_search_memories_exception(self): 83 | """测试搜索记忆时的异常处理""" 84 | self.mock_recall_engine.recall = AsyncMock(side_effect=Exception("搜索错误")) 85 | 86 | result = await self.handler.search_memories("测试查询") 87 | 88 | assert result["success"] is False 89 | assert "搜索记忆时发生错误" in result["message"] 90 | 91 | @pytest.mark.asyncio 92 | async def test_test_sparse_search_success(self): 93 | """测试稀疏检索测试(成功)""" 94 | # 模拟稀疏检索结果 95 | mock_results = [ 96 | Mock( 97 | doc_id=1, 98 | score=0.8, 99 | content="稀疏检索结果1", 100 | metadata={"event_type": "FACT", "importance": 0.7} 101 | ), 102 | Mock( 103 | doc_id=2, 104 | score=0.6, 105 | content="稀疏检索结果2", 106 | metadata={"event_type": "PREFERENCE", "importance": 0.9} 107 | ) 108 | ] 109 | 110 | self.mock_sparse_retriever.search = 
AsyncMock(return_value=mock_results) 111 | 112 | result = await self.handler.test_sparse_search("测试查询", k=5) 113 | 114 | assert result["success"] is True 115 | assert "找到 2 条稀疏检索结果" in result["message"] 116 | assert len(result["data"]) == 2 117 | assert result["data"][0]["doc_id"] == 1 118 | assert result["data"][0]["score"] == 0.8 119 | assert result["data"][1]["doc_id"] == 2 120 | assert result["data"][1]["score"] == 0.6 121 | 122 | # 验证调用参数 123 | self.mock_sparse_retriever.search.assert_called_once_with( 124 | query="测试查询", limit=5 125 | ) 126 | 127 | @pytest.mark.asyncio 128 | async def test_test_sparse_search_no_results(self): 129 | """测试稀疏检索测试(无结果)""" 130 | self.mock_sparse_retriever.search = AsyncMock(return_value=[]) 131 | 132 | result = await self.handler.test_sparse_search("无结果查询", k=5) 133 | 134 | assert result["success"] is True 135 | assert "未找到与 '无结果查询' 相关的记忆" in result["message"] 136 | assert result["data"] == [] 137 | 138 | @pytest.mark.asyncio 139 | async def test_test_sparse_search_no_retriever(self): 140 | """测试没有稀疏检索器时的测试""" 141 | handler = SearchHandler(self.mock_context, TEST_CONFIG, None, None) 142 | 143 | result = await handler.test_sparse_search("测试查询") 144 | 145 | assert result["success"] is False 146 | assert "稀疏检索器未启用" in result["message"] 147 | 148 | @pytest.mark.asyncio 149 | async def test_test_sparse_search_exception(self): 150 | """测试稀疏检索测试时的异常处理""" 151 | self.mock_sparse_retriever.search = AsyncMock(side_effect=Exception("稀疏检索错误")) 152 | 153 | result = await self.handler.test_sparse_search("测试查询") 154 | 155 | assert result["success"] is False 156 | assert "稀疏检索测试失败" in result["message"] 157 | 158 | @pytest.mark.asyncio 159 | async def test_rebuild_sparse_index_success(self): 160 | """测试重建稀疏索引(成功)""" 161 | self.mock_sparse_retriever.rebuild_index = AsyncMock() 162 | 163 | result = await self.handler.rebuild_sparse_index() 164 | 165 | assert result["success"] is True 166 | assert "稀疏检索索引重建完成" in result["message"] 167 | 168 | # 验证调用了重建方法 169 | self.mock_sparse_retriever.rebuild_index.assert_called_once() 170 | 171 | @pytest.mark.asyncio 172 | async def test_rebuild_sparse_index_no_retriever(self): 173 | """测试没有稀疏检索器时重建索引""" 174 | handler = SearchHandler(self.mock_context, TEST_CONFIG, None, None) 175 | 176 | result = await handler.rebuild_sparse_index() 177 | 178 | assert result["success"] is False 179 | assert "稀疏检索器未启用" in result["message"] 180 | 181 | @pytest.mark.asyncio 182 | async def test_rebuild_sparse_index_exception(self): 183 | """测试重建稀疏索引时的异常处理""" 184 | self.mock_sparse_retriever.rebuild_index = AsyncMock(side_effect=Exception("重建错误")) 185 | 186 | result = await self.handler.rebuild_sparse_index() 187 | 188 | assert result["success"] is False 189 | assert "重建稀疏索引失败" in result["message"] 190 | 191 | def test_format_search_results_for_display_success(self): 192 | """测试格式化搜索结果显示(成功)""" 193 | mock_response = { 194 | "success": True, 195 | "message": "为您找到 2 条相关记忆", 196 | "data": [ 197 | { 198 | "id": 1, 199 | "similarity": 0.9, 200 | "text": "测试记忆1", 201 | "metadata": { 202 | "create_time": 1609459200.0, 203 | "last_access_time": 1609459200.0, 204 | "importance": 0.8, 205 | "event_type": "FACT" 206 | } 207 | }, 208 | { 209 | "id": 2, 210 | "similarity": 0.7, 211 | "text": "测试记忆2", 212 | "metadata": { 213 | "create_time": 1609459200.0, 214 | "last_access_time": 1609459200.0, 215 | "importance": 0.6, 216 | "event_type": "PREFERENCE" 217 | } 218 | } 219 | ] 220 | } 221 | 222 | result = self.handler.format_search_results_for_display(mock_response) 223 | 224 
| assert "为您找到 2 条相关记忆" in result 225 | assert "ID: 1" in result 226 | assert "记 忆 度: 0.90" in result 227 | assert "重 要 性: 0.80" in result 228 | assert "记忆类型: FACT" in result 229 | assert "测试记忆1" in result 230 | assert "ID: 2" in result 231 | assert "测试记忆2" in result 232 | 233 | def test_format_search_results_for_display_failure(self): 234 | """测试格式化搜索结果显示(失败)""" 235 | mock_response = { 236 | "success": False, 237 | "message": "搜索失败" 238 | } 239 | 240 | result = self.handler.format_search_results_for_display(mock_response) 241 | 242 | assert result == "搜索失败" 243 | 244 | def test_format_sparse_results_for_display_success(self): 245 | """测试格式化稀疏检索结果显示(成功)""" 246 | mock_response = { 247 | "success": True, 248 | "message": "找到 2 条稀疏检索结果", 249 | "data": [ 250 | { 251 | "doc_id": 1, 252 | "score": 0.8, 253 | "content": "稀疏检索结果1", 254 | "metadata": { 255 | "event_type": "FACT", 256 | "importance": 0.7 257 | } 258 | }, 259 | { 260 | "doc_id": 2, 261 | "score": 0.6, 262 | "content": "稀疏检索结果2很长很长很长很长很长很长很长很长很长很长很长很长", 263 | "metadata": { 264 | "event_type": "PREFERENCE", 265 | "importance": 0.9 266 | } 267 | } 268 | ] 269 | } 270 | 271 | result = self.handler.format_sparse_results_for_display(mock_response) 272 | 273 | assert "找到 2 条稀疏检索结果" in result 274 | assert "1. [ID: 1] Score: 0.800" in result 275 | assert "类型: FACT" in result 276 | assert "重要性: 0.70" in result 277 | assert "2. [ID: 2] Score: 0.600" in result 278 | assert "..." in result or len(result) < 500 # 长内容被截断或整个结果显示 279 | assert "类型: PREFERENCE" in result 280 | assert "重要性: 0.90" in result 281 | 282 | def test_format_sparse_results_for_display_failure(self): 283 | """测试格式化稀疏检索结果显示(失败)""" 284 | mock_response = { 285 | "success": False, 286 | "message": "稀疏检索失败" 287 | } 288 | 289 | result = self.handler.format_sparse_results_for_display(mock_response) 290 | 291 | assert result == "稀疏检索失败" 292 | 293 | def test_format_sparse_results_for_display_no_metadata(self): 294 | """测试格式化稀疏检索结果显示(无元数据)""" 295 | mock_response = { 296 | "success": True, 297 | "message": "找到 1 条稀疏检索结果", 298 | "data": [ 299 | { 300 | "doc_id": 1, 301 | "score": 0.8, 302 | "content": "稀疏检索结果", 303 | "metadata": {} 304 | } 305 | ] 306 | } 307 | 308 | result = self.handler.format_sparse_results_for_display(mock_response) 309 | 310 | assert "找到 1 条稀疏检索结果" in result 311 | assert "1. [ID: 1] Score: 0.800" in result 312 | assert "稀疏检索结果" in result 313 | # 不应该包含类型和重要性信息 314 | assert "类型:" not in result 315 | assert "重要性:" not in result --------------------------------------------------------------------------------