# encoding:utf-8
#
# Repository snapshot of the "summary" plugin. Package layout:
#
#   ├── .gitignore
#   ├── __init__.py
#   ├── main.py            (this module)
#   └── requirements.txt
#
# .gitignore:
#   .DS_Store
#   .idea
#   .vscode
#   __pycache__/
#   venv*
#   *.pyc
#   config.json
#   nohup.out
#   tmp
#   *.db
#
# __init__.py:
#   from .main import *
#
# requirements.txt:
#   tiktoken>=0.3.2
#   --extra-index-url https://pypi.python.org/simple
#   chatgpt_tool_hub>=0.3.10
"""Chat-log summary plugin.

Every incoming message is stored in a local SQLite database (``chat.db`` next
to this file). A ``<plugin_trigger_prefix>总结`` command asks the configured
chat bot to summarise the recorded history of the current chat.
"""

import json
import os
import re
import time
from bot import bot_factory
from bridge.bridge import Bridge
from bridge.context import ContextType
from bridge.reply import Reply, ReplyType
from channel.chat_channel import check_contain, check_prefix
from channel.chat_message import ChatMessage
from config import conf
import plugins
from plugins import *
from common.log import logger
from common import const
import sqlite3
from chatgpt_tool_hub.chains.llm import LLMChain
from chatgpt_tool_hub.models import build_model_params
from chatgpt_tool_hub.models.model_factory import ModelFactory
from chatgpt_tool_hub.prompts import PromptTemplate

# Appended to recorded messages that triggered a bot reply. The summary prompt
# built in Summary._check_tokens explains this marker to the model.
TRIGGER_MARKER = "<T>"

# Greedy match from the first "{" to the last "}" in an LLM answer.
_JSON_PATTERN = re.compile(r"\{[\s\S]*\}")

# Separator line between a quoted message and the reply text (presumably the
# channel's quote format -- confirm). Only the text after it is summarised.
_QUOTE_SEPARATOR = re.compile(r'\n- - - - - - - - -.*?\n')

# Prompt that turns a free-form request into a JSON command.
# "{{" / "}}" are escaped literal braces; {input} is the only template
# variable (see input_variables in Summary._translate_text_to_commands).
TRANSLATE_PROMPT = '''
You are now the following python function:
```# {{translate text to commands}}"
def translate_text(text: str) -> str:
```
Only respond with your `return` value, Don't reply anything else.

Commands:
{{Summary chat logs}}: "summary", args: {{("duration_in_seconds"): <integer>, ("count"): <integer>}}
{{Do Nothing}}:"do_nothing", args: {{}}

argument in brackets means optional argument.

You should only respond in JSON format as described below.
Response Format:
{{
    "name": "command name",
    "args": {{"arg name": "value"}}
}}
Ensure the response can be parsed by Python json.loads.

Input: {input}
'''


def find_json(json_string):
    """Return the outermost ``{...}`` span of *json_string*, or "" if none.

    The match is greedy (first "{" to last "}"). This strips any prose the
    model wrapped around a single JSON object.
    """
    json_match = _JSON_PATTERN.search(json_string)
    return json_match.group(0) if json_match else ""


@plugins.register(name="summary", desire_priority=-1, desc="A simple plugin to summary messages", version="0.3.2", author="lanvent")
class Summary(Plugin):
    """Record chat messages to SQLite and summarise them on request."""

    def __init__(self):
        super().__init__()

        curdir = os.path.dirname(__file__)
        db_path = os.path.join(curdir, "chat.db")
        # check_same_thread=False lets the handlers share this connection even
        # if they are called from other threads (presumably they can be -- confirm).
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self._ensure_schema()

        btype = Bridge().btype['chat']
        if btype not in [const.OPEN_AI, const.CHATGPT, const.CHATGPTONAZURE, const.LINKAI]:
            raise Exception("[Summary] init failed, not supported bot type")
        self.bot = bot_factory.create_bot(btype)
        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        self.handlers[Event.ON_RECEIVE_MESSAGE] = self.on_receive_message
        logger.info("[Summary] inited")

    def _ensure_schema(self):
        """Create ``chat_records`` if needed and migrate older databases."""
        c = self.conn.cursor()
        c.execute('''CREATE TABLE IF NOT EXISTS chat_records
                    (sessionid TEXT, msgid INTEGER, user TEXT, content TEXT, type TEXT, timestamp INTEGER, is_triggered INTEGER,
                    PRIMARY KEY (sessionid, msgid))''')

        # The is_triggered column was added later, and databases created by
        # older versions lack it. This transition code can be removed some day.
        c.execute("PRAGMA table_info(chat_records);")
        column_exists = False
        for column in c.fetchall():
            logger.debug("[Summary] column: {}" .format(column))
            if column[1] == 'is_triggered':  # column[1] is the column name
                column_exists = True
                break
        if not column_exists:
            self.conn.execute("ALTER TABLE chat_records ADD COLUMN is_triggered INTEGER DEFAULT 0;")
            self.conn.execute("UPDATE chat_records SET is_triggered = 0;")

        self.conn.commit()

    @staticmethod
    def _get_session_id(cmsg):
        """Return the key under which *cmsg*'s chat is stored in the database."""
        session_id = cmsg.from_user_id
        # The itchat ("wx") channel's ids change between logins, so the chat's
        # display name is used as a stable session id instead.
        if conf().get('channel_type', 'wx') == 'wx' and cmsg.from_user_nickname is not None:
            session_id = cmsg.from_user_nickname
        return session_id

    @staticmethod
    def _reply(e_context, reply):
        """Set *reply* and end the event, skipping the default context handling."""
        e_context['reply'] = reply
        e_context.action = EventAction.BREAK_PASS

    def _insert_record(self, session_id, msg_id, user, content, msg_type, timestamp, is_triggered = 0):
        """Insert (or overwrite, keyed by sessionid+msgid) one chat record."""
        c = self.conn.cursor()
        logger.debug("[Summary] insert record: {} {} {} {} {} {} {}" .format(session_id, msg_id, user, content, msg_type, timestamp, is_triggered))
        c.execute("INSERT OR REPLACE INTO chat_records VALUES (?,?,?,?,?,?,?)", (session_id, msg_id, user, content, msg_type, timestamp, is_triggered))
        self.conn.commit()

    def _get_records(self, session_id, start_timestamp=0, limit=9999):
        """Return up to *limit* rows of *session_id* newer than *start_timestamp*.

        Rows are full ``chat_records`` tuples, ordered newest first.
        """
        c = self.conn.cursor()
        c.execute("SELECT * FROM chat_records WHERE sessionid=? and timestamp>? ORDER BY timestamp DESC LIMIT ?", (session_id, start_timestamp, limit))
        return c.fetchall()

    def on_receive_message(self, e_context: EventContext):
        """Persist every incoming message so it can be summarised later."""
        context = e_context['context']
        cmsg: ChatMessage = context['msg']
        session_id = self._get_session_id(cmsg)
        is_group = context.get("isgroup", False)

        # Display name of the sender, falling back to the raw id.
        if is_group:
            username = cmsg.actual_user_nickname
            if username is None:
                username = cmsg.actual_user_id
        else:
            username = cmsg.from_user_nickname
            if username is None:
                username = cmsg.from_user_id

        # Flag messages that would trigger the bot, so the summary prompt can
        # de-emphasise them.
        is_triggered = False
        content = context.content
        if is_group:
            # Configured prefix/keyword, or an @-mention unless group_at_off is set.
            match_prefix = check_prefix(content, conf().get('group_chat_prefix'))
            match_contain = check_contain(content, conf().get('group_chat_keyword'))
            if match_prefix is not None or match_contain is not None:
                is_triggered = True
            if cmsg.is_at and not conf().get("group_at_off", False):
                is_triggered = True
        else:
            match_prefix = check_prefix(content, conf().get('single_chat_prefix', ['']))
            if match_prefix is not None:
                is_triggered = True

        self._insert_record(session_id, cmsg.msg_id, username, content, str(context.type), cmsg.create_time, int(is_triggered))

    def _translate_text_to_commands(self, text):
        """Ask an LLM to turn free-form *text* into a JSON command string."""
        llm = ModelFactory().create_llm_model(**build_model_params({
            "openai_api_key": conf().get("open_ai_api_key", ""),
            "proxy": conf().get("proxy", ""),
        }))

        prompt = PromptTemplate(
            input_variables=["input"],
            template=TRANSLATE_PROMPT,
        )
        bot = LLMChain(llm=llm, prompt=prompt)
        return bot.run(text)

    def _check_tokens(self, records, max_tokens=3600):
        """Build a summarisation session for *records* if it fits the budget.

        Args:
            records: ``chat_records`` rows, newest first (as from _get_records).
            max_tokens: upper bound for ``session.calc_tokens()``.

        Returns:
            The prepared session, or None if *records* is empty or the prompt
            would exceed *max_tokens*.
        """
        if not records:
            return None

        sentences = []
        for record in reversed(records):  # oldest message first
            username, content, msg_type, is_triggered = record[2], record[3], record[4], record[6]
            if msg_type in [str(ContextType.IMAGE), str(ContextType.VOICE)]:
                # Media is not text; describe it by its type instead.
                content = f"[{msg_type}]"
            sentence = f'{username}' + ": \"" + content + "\""
            if is_triggered:
                sentence += " " + TRIGGER_MARKER
            sentences.append(sentence)
        query = "".join("\n\n" + sentence for sentence in sentences)

        prompt = f"你是一位群聊机器人,需要对聊天记录进行简明扼要的总结,用列表的形式输出。\n聊天记录格式:[x]是emoji表情或者是对图片和声音文件的说明,消息最后出现{TRIGGER_MARKER}表示消息触发了群聊机器人的回复,内容通常是提问,若带有特殊符号如#和$则是触发你无法感知的某个插件功能,聊天记录中不包含你对这类消息的回复,可降低这些消息的权重。请不要在回复中包含聊天记录格式中出现的符号。\n"

        firstmsg_id = records[0][1]
        session = self.bot.sessions.build_session(firstmsg_id, prompt)

        session.add_query("需要你总结的聊天记录如下:%s"%query)
        if session.calc_tokens() > max_tokens:
            return None
        return session

    def _fit_prefix(self, records, max_tokens_persession):
        """Find the longest prefix of *records* whose session fits the budget.

        Assumes the whole of *records* does NOT fit, and that prompt size grows
        with the number of records.

        Returns:
            ``(session, length)``. The session is None and the length is 0 when
            even the first record alone is too large.
        """
        # Invariants: records[:lo] fits (or lo == 0); records[:hi + 1] does not fit.
        lo, hi = 0, len(records) - 1
        best = None
        while lo < hi:
            mid = (lo + hi + 1) // 2  # round up so the loop always progresses; mid >= 1
            logger.debug("[Summary] left: %d, right: %d, mid: %d" % (lo, hi, mid))
            session = self._check_tokens(records[:mid], max_tokens_persession)
            if session is None:
                hi = mid - 1
            else:
                lo, best = mid, session
        return best, lo

    def _split_messages_to_summarys(self, records, max_tokens_persession=3600 , max_summarys=8):
        """Summarise *records* in token-bounded chunks, newest messages first.

        Each round takes the longest prefix of the remaining (newest-first)
        records that fits *max_tokens_persession* and asks the bot to summarise
        it. At most *max_summarys* chunks are produced.

        Returns:
            ``(count, summarys)``: the number of records covered and the chunk
            summaries, newest chunk first. If the very first bot call yields no
            completion, ``(0, reply_content)`` is returned instead (the bot's
            reply text as a str).
        """
        summarys = []
        count = 0
        # Cap the length of each partial summary.
        self.bot.args["max_tokens"] = 400
        while len(records) > 0 and len(summarys) < max_summarys:
            session = self._check_tokens(records, max_tokens_persession)
            if session is not None:
                last = len(records)
                logger.debug("[Summary] summary all %d messages" % (len(records)))
            else:
                session, last = self._fit_prefix(records, max_tokens_persession)
                logger.debug("[Summary] summary %d messages" % (last))
            if session is None:
                logger.debug("[Summary] summary failed, session is None")
                break
            logger.debug("[Summary] session query: %s, prompt_tokens: %d" % (session.messages, session.calc_tokens()))
            result = self.bot.reply_text(session)
            total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content']
            logger.debug("[Summary] total_tokens: %d, completion_tokens: %d, reply_content: %s" % (total_tokens, completion_tokens, reply_content))
            if completion_tokens == 0:
                # No completion: surface the bot's reply text only if nothing
                # was summarised yet, otherwise keep the partial results.
                if len(summarys) == 0:
                    return count, reply_content
                break
            summarys.append(reply_content)
            records = records[last:]
            count += last
        return count, summarys

    def _parse_command(self, content):
        """Parse a summary command.

        Returns:
            ``(limit, duration)`` (duration in seconds, -1 meaning unbounded),
            or None if *content* is not a summary command or its free-form
            form could not be translated.
        """
        trigger_prefix = conf().get('plugin_trigger_prefix', "$")
        clist = content.split()
        if not clist or not clist[0].startswith(trigger_prefix):
            return None
        if "总结" not in clist[0]:
            return None

        limit = 99
        duration = -1
        # flag: the command is in the plain "<prefix>总结 [count]" form.
        flag = False
        if clist[0] == trigger_prefix+"总结":
            flag = True
            if len(clist) > 1:
                try:
                    limit = int(clist[1])
                    logger.debug("[Summary] limit: %d" % limit)
                except ValueError:
                    flag = False
        if not flag:
            # Free-form request: let the LLM translate it into a command.
            text = content.split(trigger_prefix, maxsplit=1)[1]
            try:
                command_json = find_json(self._translate_text_to_commands(text))
                command = json.loads(command_json)
                name = command["name"]
                # Any other command (e.g. do_nothing) falls through with the
                # defaults, so the chat is still summarised.
                if name.lower() == "summary":
                    limit = int(command["args"].get("count", 99))
                    if limit < 0:
                        limit = 299
                    duration = int(command["args"].get("duration_in_seconds", -1))
                    logger.debug("[Summary] limit: %d, duration: %d seconds" % (limit, duration))
            except Exception as e:
                logger.error("[Summary] translate failed: %s" % e)
                return None
        return limit, duration

    @staticmethod
    def _strip_quoted_text(records):
        """Return *records* with quoted text removed from each message's content.

        When the content contains the quote separator, only the part right
        after the first separator is kept.
        """
        stripped = []
        for record in records:
            parts = _QUOTE_SEPARATOR.split(record[3])
            if len(parts) > 1:
                record = record[:3] + (parts[1],) + record[4:]
            stripped.append(record)
        return stripped

    def on_handle_context(self, e_context: EventContext):
        """Handle a summary command and reply with a summary of the chat."""
        if e_context['context'].type != ContextType.TEXT:
            return

        content = e_context['context'].content
        logger.debug("[Summary] on_handle_context. content: %s" % content)
        parsed = self._parse_command(content)
        if parsed is None:
            return
        limit, duration = parsed

        start_time = int(time.time()) - duration if duration > 0 else 0

        msg: ChatMessage = e_context['context']['msg']
        session_id = self._get_session_id(msg)
        records = self._strip_quoted_text(self._get_records(session_id, start_time, limit))
        # Presumably the command message itself was already stored by
        # on_receive_message, so one record means there is nothing to summarise.
        if len(records) <= 1:
            self._reply(e_context, Reply(ReplyType.INFO, "无聊天记录可供总结"))
            return

        max_tokens_persession = 3600

        count, summarys = self._split_messages_to_summarys(records, max_tokens_persession)
        if count == 0:
            if isinstance(summarys, str):
                reply = Reply(ReplyType.ERROR, summarys)
            else:
                reply = Reply(ReplyType.ERROR, "总结聊天记录失败")
            self._reply(e_context, reply)
            return

        if len(summarys) == 1:
            self._reply(e_context, Reply(ReplyType.TEXT, f"本次总结了{count}条消息。\n\n"+summarys[0]))
            return

        # Several partial summaries: merge them (oldest chunk first) in one
        # final call with no max_tokens cap.
        self.bot.args["max_tokens"] = None
        query = "".join(summary + "\n----------------\n\n" for summary in reversed(summarys))
        prompt = "你是一位群聊机器人,聊天记录已经在你的大脑中被你总结成多段摘要总结,你需要对它们进行摘要总结,最后输出一篇完整的摘要总结,用列表的形式输出。\n"

        session = self.bot.sessions.build_session(session_id, prompt)
        session.add_query(query)
        result = self.bot.reply_text(session)
        total_tokens, completion_tokens, reply_content = result['total_tokens'], result['completion_tokens'], result['content']
        logger.debug("[Summary] total_tokens: %d, completion_tokens: %d, reply_content: %s" % (total_tokens, completion_tokens, reply_content))
        if completion_tokens == 0:
            reply = Reply(ReplyType.ERROR, "合并摘要失败,"+reply_content+"\n原始多段摘要如下:\n"+query)
        else:
            reply = Reply(ReplyType.TEXT, f"本次总结了{count}条消息。\n\n"+reply_content)
        self._reply(e_context, reply)

    def get_help_text(self, verbose = False, **kwargs):
        """Return usage help; the detailed examples only when *verbose*."""
        help_text = "聊天记录总结插件。\n"
        if not verbose:
            return help_text
        trigger_prefix = conf().get('plugin_trigger_prefix', "$")
        help_text += f"使用方法:输入\"{trigger_prefix}总结 最近消息数量\",我会帮助你总结聊天记录。\n例如:\"{trigger_prefix}总结 100\",我会总结最近100条消息。\n\n你也可以直接输入\"{trigger_prefix}总结前99条信息\"或\"{trigger_prefix}总结3小时内的最近10条消息\"\n我会尽可能理解你的指令。"
        return help_text