├── .flake8 ├── .github ├── ISSUE_TEMPLATE.md └── workflows │ └── deploy-image.yml ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── README.md ├── app.py ├── bot ├── baidu │ └── baidu_unit_bot.py ├── bot.py ├── bot_factory.py ├── chatgpt │ ├── chat_gpt_bot.py │ └── chat_gpt_session.py ├── openai │ ├── open_ai_bot.py │ ├── open_ai_image.py │ └── open_ai_session.py └── session_manager.py ├── bridge ├── bridge.py ├── context.py └── reply.py ├── channel ├── channel.py ├── channel_factory.py ├── chat_channel.py ├── chat_message.py ├── terminal │ └── terminal_channel.py ├── wechat │ ├── wechat_channel.py │ ├── wechat_message.py │ ├── wechaty_channel.py │ └── wechaty_message.py └── wechatmp │ ├── README.md │ ├── active_reply.py │ ├── common.py │ ├── passive_reply.py │ ├── passive_reply_message.py │ ├── wechatmp_channel.py │ ├── wechatmp_client.py │ └── wechatmp_message.py ├── common ├── const.py ├── dequeue.py ├── expired_dict.py ├── log.py ├── package_manager.py ├── singleton.py ├── sorted_dict.py ├── time_check.py ├── tmp_dir.py └── token_bucket.py ├── config-template.json ├── config.py ├── docker ├── Dockerfile.alpine ├── Dockerfile.debian ├── Dockerfile.debian.latest ├── Dockerfile.latest ├── build.alpine.sh ├── build.debian.sh ├── build.latest.sh ├── chatgpt-on-wechat-voice-reply │ ├── Dockerfile.alpine │ ├── Dockerfile.debian │ ├── docker-compose.yaml │ └── entrypoint.sh ├── docker-compose.yaml ├── entrypoint.sh └── sample-chatgpt-on-wechat │ ├── .env │ ├── Makefile │ └── Name ├── docs └── images │ ├── group-chat-sample.jpg │ ├── image-create-sample.jpg │ ├── planet.jpg │ └── single-chat-sample.jpg ├── lib └── itchat │ ├── __init__.py │ ├── async_components │ ├── __init__.py │ ├── contact.py │ ├── hotreload.py │ ├── login.py │ ├── messages.py │ └── register.py │ ├── components │ ├── __init__.py │ ├── contact.py │ ├── hotreload.py │ ├── login.py │ ├── messages.py │ └── register.py │ ├── config.py │ ├── content.py │ ├── core.py │ ├── log.py │ 
├── returnvalues.py │ ├── storage │ ├── __init__.py │ ├── messagequeue.py │ └── templates.py │ └── utils.py ├── nixpacks.toml ├── plugins ├── README.md ├── __init__.py ├── banwords │ ├── .gitignore │ ├── README.md │ ├── __init__.py │ ├── banwords.py │ ├── banwords.txt.template │ ├── config.json.template │ └── lib │ │ └── WordsSearch.py ├── bdunit │ ├── README.md │ ├── __init__.py │ ├── bdunit.py │ └── config.json.template ├── dungeon │ ├── README.md │ ├── __init__.py │ └── dungeon.py ├── event.py ├── finish │ ├── __init__.py │ └── finish.py ├── godcmd │ ├── README.md │ ├── __init__.py │ ├── config.json.template │ └── godcmd.py ├── hello │ ├── __init__.py │ └── hello.py ├── keyword │ ├── README.md │ ├── __init__.py │ ├── config.json.template │ ├── keyword.py │ └── test-keyword.png ├── plugin.py ├── plugin_manager.py ├── role │ ├── README.md │ ├── __init__.py │ ├── role.py │ └── roles.json ├── source.json └── tool │ ├── README.md │ ├── __init__.py │ ├── config.json.template │ └── tool.py ├── requirements-optional.txt ├── requirements.txt ├── scripts ├── shutdown.sh ├── start.sh └── tout.sh └── voice ├── audio_convert.py ├── azure ├── azure_voice.py └── config.json.template ├── baidu ├── README.md ├── baidu_voice.py └── config.json.template ├── google └── google_voice.py ├── openai └── openai_voice.py ├── pytts └── pytts_voice.py ├── voice.py └── voice_factory.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | select = E303,W293,W291,W292,E305,E231,E302 4 | exclude = 5 | .tox, 6 | __pycache__, 7 | *.pyc, 8 | .env 9 | venv/* 10 | .venv/* 11 | reports/* 12 | dist/* 13 | lib/* -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ### 前置确认 2 | 3 | 1. 网络能够访问openai接口 4 | 2. python 已安装:版本在 3.7 ~ 3.10 之间 5 | 3. `git pull` 拉取最新代码 6 | 4. 
执行`pip3 install -r requirements.txt`,检查依赖是否满足 7 | 5. 拓展功能请执行`pip3 install -r requirements-optional.txt`,检查依赖是否满足 8 | 6. 在已有 issue 中未搜索到类似问题 9 | 7. [FAQS](https://github.com/zhayujie/chatgpt-on-wechat/wiki/FAQs) 中无类似问题 10 | 11 | 12 | ### 问题描述 13 | 14 | > 简要说明、截图、复现步骤等,也可以是需求或想法 15 | 16 | 17 | 18 | 19 | ### 终端日志 (如有报错) 20 | 21 | ``` 22 | [在此处粘贴终端日志, 可在主目录下`run.log`文件中找到] 23 | ``` 24 | 25 | 26 | 27 | ### 环境 28 | 29 | - 操作系统类型 (Mac/Windows/Linux): 30 | - Python版本 ( 执行 `python3 -V` ): 31 | - pip版本 ( 依赖问题此项必填,执行 `pip3 -V`): 32 | -------------------------------------------------------------------------------- /.github/workflows/deploy-image.yml: -------------------------------------------------------------------------------- 1 | # This workflow uses actions that are not certified by GitHub. 2 | # They are provided by a third-party and are governed by 3 | # separate terms of service, privacy policy, and support 4 | # documentation. 5 | 6 | # GitHub recommends pinning actions to a commit SHA. 7 | # To get a newer version, you will need to update the SHA. 8 | # You can also reference a tag or branch, but the action may change without warning. 
9 | 10 | name: Create and publish a Docker image 11 | 12 | on: 13 | push: 14 | branches: ['master'] 15 | create: 16 | env: 17 | REGISTRY: ghcr.io 18 | IMAGE_NAME: ${{ github.repository }} 19 | 20 | jobs: 21 | build-and-push-image: 22 | runs-on: ubuntu-latest 23 | permissions: 24 | contents: read 25 | packages: write 26 | 27 | steps: 28 | - name: Checkout repository 29 | uses: actions/checkout@v3 30 | 31 | - name: Log in to the Container registry 32 | uses: docker/login-action@v2 33 | with: 34 | registry: ${{ env.REGISTRY }} 35 | username: ${{ github.actor }} 36 | password: ${{ secrets.GITHUB_TOKEN }} 37 | 38 | - name: Extract metadata (tags, labels) for Docker 39 | id: meta 40 | uses: docker/metadata-action@v4 41 | with: 42 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 43 | 44 | - name: Build and push Docker image 45 | uses: docker/build-push-action@v3 46 | with: 47 | context: . 48 | push: true 49 | file: ./docker/Dockerfile.latest 50 | tags: ${{ steps.meta.outputs.tags }} 51 | labels: ${{ steps.meta.outputs.labels }} 52 | 53 | - uses: actions/delete-package-versions@v4 54 | with: 55 | package-name: 'chatgpt-on-wechat' 56 | package-type: 'container' 57 | min-versions-to-keep: 10 58 | delete-only-untagged-versions: 'true' 59 | token: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | .vscode 4 | .wechaty/ 5 | __pycache__/ 6 | venv* 7 | *.pyc 8 | config.json 9 | QR.png 10 | nohup.out 11 | tmp 12 | plugins.json 13 | itchat.pkl 14 | *.log 15 | user_datas.pkl 16 | chatgpt_tool_hub/ 17 | plugins/**/ 18 | !plugins/bdunit 19 | !plugins/dungeon 20 | !plugins/finish 21 | !plugins/godcmd 22 | !plugins/tool 23 | !plugins/banwords 24 | !plugins/banwords/**/ 25 | !plugins/hello 26 | !plugins/role 27 | !plugins/keyword 
-------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: fix-byte-order-marker 6 | - id: check-case-conflict 7 | - id: check-merge-conflict 8 | - id: debug-statements 9 | - id: pretty-format-json 10 | types: [text] 11 | files: \.json(.template)?$ 12 | args: [ --autofix , --no-ensure-ascii, --indent=2, --no-sort-keys] 13 | - id: trailing-whitespace 14 | exclude: '(\/|^)lib\/' 15 | args: [ --markdown-linebreak-ext=md ] 16 | - repo: https://github.com/PyCQA/isort 17 | rev: 5.12.0 18 | hooks: 19 | - id: isort 20 | exclude: '(\/|^)lib\/' 21 | - repo: https://github.com/psf/black 22 | rev: 23.3.0 23 | hooks: 24 | - id: black 25 | exclude: '(\/|^)lib\/' 26 | - repo: https://github.com/PyCQA/flake8 27 | rev: 6.0.0 28 | hooks: 29 | - id: flake8 30 | exclude: '(\/|^)lib\/' -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ghcr.io/zhayujie/chatgpt-on-wechat:latest 2 | 3 | ENTRYPOINT ["/entrypoint.sh"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 zhayujie 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission 
notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | import os 4 | import signal 5 | import sys 6 | 7 | from channel import channel_factory 8 | from common.log import logger 9 | from config import conf, load_config 10 | from plugins import * 11 | 12 | 13 | def sigterm_handler_wrap(_signo): 14 | old_handler = signal.getsignal(_signo) 15 | 16 | def func(_signo, _stack_frame): 17 | logger.info("signal {} received, exiting...".format(_signo)) 18 | conf().save_user_datas() 19 | if callable(old_handler): # check old_handler 20 | return old_handler(_signo, _stack_frame) 21 | sys.exit(0) 22 | signal.signal(_signo, func) 23 | 24 | 25 | def run(): 26 | try: 27 | # load config 28 | load_config() 29 | # ctrl + c 30 | sigterm_handler_wrap(signal.SIGINT) 31 | # kill signal 32 | sigterm_handler_wrap(signal.SIGTERM) 33 | 34 | # create channel 35 | channel_name = conf().get("channel_type", "wx") 36 | 37 | if "--cmd" in sys.argv: 38 | channel_name = "terminal" 39 | 40 | if channel_name == "wxy": 41 | os.environ["WECHATY_LOG"] = "warn" 42 | # os.environ['WECHATY_PUPPET_SERVICE_ENDPOINT'] = '127.0.0.1:9001' 43 | 44 | channel = channel_factory.create_channel(channel_name) 45 | if channel_name in ["wx", "wxy", "terminal", "wechatmp", "wechatmp_service"]: 46 | 
PluginManager().load_plugins() 47 | 48 | # startup channel 49 | channel.startup() 50 | except Exception as e: 51 | logger.error("App startup failed!") 52 | logger.exception(e) 53 | 54 | 55 | if __name__ == "__main__": 56 | run() 57 | -------------------------------------------------------------------------------- /bot/baidu/baidu_unit_bot.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | import requests 4 | 5 | from bot.bot import Bot 6 | from bridge.reply import Reply, ReplyType 7 | 8 | 9 | # Baidu Unit对话接口 (可用, 但能力较弱) 10 | class BaiduUnitBot(Bot): 11 | def reply(self, query, context=None): 12 | token = self.get_token() 13 | url = ( 14 | "https://aip.baidubce.com/rpc/2.0/unit/service/v3/chat?access_token=" 15 | + token 16 | ) 17 | post_data = ( 18 | '{"version":"3.0","service_id":"S73177","session_id":"","log_id":"7758521","skill_ids":["1221886"],"request":{"terminal_id":"88888","query":"' 19 | + query 20 | + '", "hyper_params": {"chat_custom_bot_profile": 1}}}' 21 | ) 22 | print(post_data) 23 | headers = {"content-type": "application/x-www-form-urlencoded"} 24 | response = requests.post(url, data=post_data.encode(), headers=headers) 25 | if response: 26 | reply = Reply( 27 | ReplyType.TEXT, 28 | response.json()["result"]["context"]["SYS_PRESUMED_HIST"][1], 29 | ) 30 | return reply 31 | 32 | def get_token(self): 33 | access_key = "YOUR_ACCESS_KEY" 34 | secret_key = "YOUR_SECRET_KEY" 35 | host = ( 36 | "https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=" 37 | + access_key 38 | + "&client_secret=" 39 | + secret_key 40 | ) 41 | response = requests.get(host) 42 | if response: 43 | print(response.json()) 44 | return response.json()["access_token"] 45 | -------------------------------------------------------------------------------- /bot/bot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Auto-replay chat robot abstract class 3 
| """ 4 | 5 | 6 | from bridge.context import Context 7 | from bridge.reply import Reply 8 | 9 | 10 | class Bot(object): 11 | def reply(self, query, context: Context = None) -> Reply: 12 | """ 13 | bot auto-reply content 14 | :param req: received message 15 | :return: reply content 16 | """ 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /bot/bot_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | channel factory 3 | """ 4 | from common import const 5 | 6 | 7 | def create_bot(bot_type): 8 | """ 9 | create a bot_type instance 10 | :param bot_type: bot type code 11 | :return: bot instance 12 | """ 13 | if bot_type == const.BAIDU: 14 | # Baidu Unit对话接口 15 | from bot.baidu.baidu_unit_bot import BaiduUnitBot 16 | 17 | return BaiduUnitBot() 18 | 19 | elif bot_type == const.CHATGPT: 20 | # ChatGPT 网页端web接口 21 | from bot.chatgpt.chat_gpt_bot import ChatGPTBot 22 | 23 | return ChatGPTBot() 24 | 25 | elif bot_type == const.OPEN_AI: 26 | # OpenAI 官方对话模型API 27 | from bot.openai.open_ai_bot import OpenAIBot 28 | 29 | return OpenAIBot() 30 | 31 | elif bot_type == const.CHATGPTONAZURE: 32 | # Azure chatgpt service https://azure.microsoft.com/en-in/products/cognitive-services/openai-service/ 33 | from bot.chatgpt.chat_gpt_bot import AzureChatGPTBot 34 | 35 | return AzureChatGPTBot() 36 | raise RuntimeError 37 | -------------------------------------------------------------------------------- /bot/chatgpt/chat_gpt_bot.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | import time 4 | 5 | import openai 6 | import openai.error 7 | 8 | from bot.bot import Bot 9 | from bot.chatgpt.chat_gpt_session import ChatGPTSession 10 | from bot.openai.open_ai_image import OpenAIImage 11 | from bot.session_manager import SessionManager 12 | from bridge.context import ContextType 13 | from bridge.reply import Reply, 
ReplyType 14 | from common.log import logger 15 | from common.token_bucket import TokenBucket 16 | from config import conf, load_config 17 | 18 | 19 | # OpenAI对话模型API (可用) 20 | class ChatGPTBot(Bot, OpenAIImage): 21 | def __init__(self): 22 | super().__init__() 23 | # set the default api_key 24 | openai.api_key = conf().get("open_ai_api_key") 25 | if conf().get("open_ai_api_base"): 26 | openai.api_base = conf().get("open_ai_api_base") 27 | proxy = conf().get("proxy") 28 | if proxy: 29 | openai.proxy = proxy 30 | if conf().get("rate_limit_chatgpt"): 31 | self.tb4chatgpt = TokenBucket(conf().get("rate_limit_chatgpt", 20)) 32 | 33 | self.sessions = SessionManager( 34 | ChatGPTSession, model=conf().get("model") or "gpt-3.5-turbo" 35 | ) 36 | self.args = { 37 | "model": conf().get("model") or "gpt-3.5-turbo", # 对话模型的名称 38 | "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 39 | # "max_tokens":4096, # 回复最大的字符数 40 | "top_p": 1, 41 | "frequency_penalty": conf().get( 42 | "frequency_penalty", 0.0 43 | ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 44 | "presence_penalty": conf().get( 45 | "presence_penalty", 0.0 46 | ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 47 | "request_timeout": conf().get( 48 | "request_timeout", None 49 | ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 50 | "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 51 | } 52 | 53 | def reply(self, query, context=None): 54 | # acquire reply content 55 | if context.type == ContextType.TEXT: 56 | logger.info("[CHATGPT] query={}".format(query)) 57 | 58 | session_id = context["session_id"] 59 | reply = None 60 | clear_memory_commands = conf().get("clear_memory_commands", ["#清除记忆"]) 61 | if query in clear_memory_commands: 62 | self.sessions.clear_session(session_id) 63 | reply = Reply(ReplyType.INFO, "记忆已清除") 64 | elif query == "#清除所有": 65 | self.sessions.clear_all_session() 66 | reply = Reply(ReplyType.INFO, "所有人记忆已清除") 67 | elif query == "#更新配置": 68 | load_config() 69 | reply = Reply(ReplyType.INFO, 
"配置已更新") 70 | if reply: 71 | return reply 72 | session = self.sessions.session_query(query, session_id) 73 | logger.debug("[CHATGPT] session query={}".format(session.messages)) 74 | 75 | api_key = context.get("openai_api_key") 76 | 77 | # if context.get('stream'): 78 | # # reply in stream 79 | # return self.reply_text_stream(query, new_query, session_id) 80 | 81 | reply_content = self.reply_text(session, api_key) 82 | logger.debug( 83 | "[CHATGPT] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( 84 | session.messages, 85 | session_id, 86 | reply_content["content"], 87 | reply_content["completion_tokens"], 88 | ) 89 | ) 90 | if ( 91 | reply_content["completion_tokens"] == 0 92 | and len(reply_content["content"]) > 0 93 | ): 94 | reply = Reply(ReplyType.ERROR, reply_content["content"]) 95 | elif reply_content["completion_tokens"] > 0: 96 | self.sessions.session_reply( 97 | reply_content["content"], session_id, reply_content["total_tokens"] 98 | ) 99 | reply = Reply(ReplyType.TEXT, reply_content["content"]) 100 | else: 101 | reply = Reply(ReplyType.ERROR, reply_content["content"]) 102 | logger.debug("[CHATGPT] reply {} used 0 tokens.".format(reply_content)) 103 | return reply 104 | 105 | elif context.type == ContextType.IMAGE_CREATE: 106 | ok, retstring = self.create_img(query, 0) 107 | reply = None 108 | if ok: 109 | reply = Reply(ReplyType.IMAGE_URL, retstring) 110 | else: 111 | reply = Reply(ReplyType.ERROR, retstring) 112 | return reply 113 | else: 114 | reply = Reply(ReplyType.ERROR, "Bot不支持处理{}类型的消息".format(context.type)) 115 | return reply 116 | 117 | def reply_text(self, session: ChatGPTSession, api_key=None, retry_count=0) -> dict: 118 | """ 119 | call openai's ChatCompletion to get the answer 120 | :param session: a conversation session 121 | :param session_id: session id 122 | :param retry_count: retry count 123 | :return: {} 124 | """ 125 | try: 126 | if conf().get("rate_limit_chatgpt") and not self.tb4chatgpt.get_token(): 127 | 
raise openai.error.RateLimitError("RateLimitError: rate limit exceeded") 128 | # if api_key == None, the default openai.api_key will be used 129 | response = openai.ChatCompletion.create( 130 | api_key=api_key, messages=session.messages, **self.args 131 | ) 132 | # logger.info("[ChatGPT] reply={}, total_tokens={}".format(response.choices[0]['message']['content'], response["usage"]["total_tokens"])) 133 | return { 134 | "total_tokens": response["usage"]["total_tokens"], 135 | "completion_tokens": response["usage"]["completion_tokens"], 136 | "content": response.choices[0]["message"]["content"], 137 | } 138 | except Exception as e: 139 | need_retry = retry_count < 2 140 | result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} 141 | if isinstance(e, openai.error.RateLimitError): 142 | logger.warn("[CHATGPT] RateLimitError: {}".format(e)) 143 | result["content"] = "提问太快啦,请休息一下再问我吧" 144 | if need_retry: 145 | time.sleep(20) 146 | elif isinstance(e, openai.error.Timeout): 147 | logger.warn("[CHATGPT] Timeout: {}".format(e)) 148 | result["content"] = "我没有收到你的消息" 149 | if need_retry: 150 | time.sleep(5) 151 | elif isinstance(e, openai.error.APIConnectionError): 152 | logger.warn("[CHATGPT] APIConnectionError: {}".format(e)) 153 | need_retry = False 154 | result["content"] = "我连接不到你的网络" 155 | else: 156 | logger.warn("[CHATGPT] Exception: {}".format(e)) 157 | need_retry = False 158 | self.sessions.clear_session(session.session_id) 159 | 160 | if need_retry: 161 | logger.warn("[CHATGPT] 第{}次重试".format(retry_count + 1)) 162 | return self.reply_text(session, api_key, retry_count + 1) 163 | else: 164 | return result 165 | 166 | 167 | class AzureChatGPTBot(ChatGPTBot): 168 | def __init__(self): 169 | super().__init__() 170 | openai.api_type = "azure" 171 | openai.api_version = "2023-03-15-preview" 172 | self.args["deployment_id"] = conf().get("azure_deployment_id") 173 | -------------------------------------------------------------------------------- 
/bot/chatgpt/chat_gpt_session.py: -------------------------------------------------------------------------------- 1 | from bot.session_manager import Session 2 | from common.log import logger 3 | 4 | """ 5 | e.g. [ 6 | {"role": "system", "content": "You are a helpful assistant."}, 7 | {"role": "user", "content": "Who won the world series in 2020?"}, 8 | {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, 9 | {"role": "user", "content": "Where was it played?"} 10 | ] 11 | """ 12 | 13 | 14 | class ChatGPTSession(Session): 15 | def __init__(self, session_id, system_prompt=None, model="gpt-3.5-turbo"): 16 | super().__init__(session_id, system_prompt) 17 | self.model = model 18 | self.reset() 19 | 20 | def discard_exceeding(self, max_tokens, cur_tokens=None): 21 | precise = True 22 | try: 23 | cur_tokens = self.calc_tokens() 24 | except Exception as e: 25 | precise = False 26 | if cur_tokens is None: 27 | raise e 28 | logger.debug( 29 | "Exception when counting tokens precisely for query: {}".format(e) 30 | ) 31 | while cur_tokens > max_tokens: 32 | if len(self.messages) > 2: 33 | self.messages.pop(1) 34 | elif len(self.messages) == 2 and self.messages[1]["role"] == "assistant": 35 | self.messages.pop(1) 36 | if precise: 37 | cur_tokens = self.calc_tokens() 38 | else: 39 | cur_tokens = cur_tokens - max_tokens 40 | break 41 | elif len(self.messages) == 2 and self.messages[1]["role"] == "user": 42 | logger.warn( 43 | "user message exceed max_tokens. 
total_tokens={}".format(cur_tokens) 44 | ) 45 | break 46 | else: 47 | logger.debug( 48 | "max_tokens={}, total_tokens={}, len(messages)={}".format( 49 | max_tokens, cur_tokens, len(self.messages) 50 | ) 51 | ) 52 | break 53 | if precise: 54 | cur_tokens = self.calc_tokens() 55 | else: 56 | cur_tokens = cur_tokens - max_tokens 57 | return cur_tokens 58 | 59 | def calc_tokens(self): 60 | return num_tokens_from_messages(self.messages, self.model) 61 | 62 | 63 | # refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb 64 | def num_tokens_from_messages(messages, model): 65 | """Returns the number of tokens used by a list of messages.""" 66 | import tiktoken 67 | 68 | try: 69 | encoding = tiktoken.encoding_for_model(model) 70 | except KeyError: 71 | logger.debug("Warning: model not found. Using cl100k_base encoding.") 72 | encoding = tiktoken.get_encoding("cl100k_base") 73 | if model == "gpt-3.5-turbo" or model == "gpt-35-turbo": 74 | return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") 75 | elif model == "gpt-4": 76 | return num_tokens_from_messages(messages, model="gpt-4-0314") 77 | elif model == "gpt-3.5-turbo-0301": 78 | tokens_per_message = ( 79 | 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n 80 | ) 81 | tokens_per_name = -1 # if there's a name, the role is omitted 82 | elif model == "gpt-4-0314": 83 | tokens_per_message = 3 84 | tokens_per_name = 1 85 | else: 86 | logger.warn( 87 | f"num_tokens_from_messages() is not implemented for model {model}. Returning num tokens assuming gpt-3.5-turbo-0301." 
88 | ) 89 | return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0301") 90 | num_tokens = 0 91 | for message in messages: 92 | num_tokens += tokens_per_message 93 | for key, value in message.items(): 94 | num_tokens += len(encoding.encode(value)) 95 | if key == "name": 96 | num_tokens += tokens_per_name 97 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> 98 | return num_tokens 99 | -------------------------------------------------------------------------------- /bot/openai/open_ai_bot.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | import time 4 | 5 | import openai 6 | import openai.error 7 | 8 | from bot.bot import Bot 9 | from bot.openai.open_ai_image import OpenAIImage 10 | from bot.openai.open_ai_session import OpenAISession 11 | from bot.session_manager import SessionManager 12 | from bridge.context import ContextType 13 | from bridge.reply import Reply, ReplyType 14 | from common.log import logger 15 | from config import conf 16 | 17 | user_session = dict() 18 | 19 | 20 | # OpenAI对话模型API (可用) 21 | class OpenAIBot(Bot, OpenAIImage): 22 | def __init__(self): 23 | super().__init__() 24 | openai.api_key = conf().get("open_ai_api_key") 25 | if conf().get("open_ai_api_base"): 26 | openai.api_base = conf().get("open_ai_api_base") 27 | proxy = conf().get("proxy") 28 | if proxy: 29 | openai.proxy = proxy 30 | 31 | self.sessions = SessionManager( 32 | OpenAISession, model=conf().get("model") or "text-davinci-003" 33 | ) 34 | self.args = { 35 | "model": conf().get("model") or "text-davinci-003", # 对话模型的名称 36 | "temperature": conf().get("temperature", 0.9), # 值在[0,1]之间,越大表示回复越具有不确定性 37 | "max_tokens": 1200, # 回复最大的字符数 38 | "top_p": 1, 39 | "frequency_penalty": conf().get( 40 | "frequency_penalty", 0.0 41 | ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 42 | "presence_penalty": conf().get( 43 | "presence_penalty", 0.0 44 | ), # [-2,2]之间,该值越大则更倾向于产生不同的内容 45 | "request_timeout": 
conf().get( 46 | "request_timeout", None 47 | ), # 请求超时时间,openai接口默认设置为600,对于难问题一般需要较长时间 48 | "timeout": conf().get("request_timeout", None), # 重试超时时间,在这个时间内,将会自动重试 49 | "stop": ["\n\n\n"], 50 | } 51 | 52 | def reply(self, query, context=None): 53 | # acquire reply content 54 | if context and context.type: 55 | if context.type == ContextType.TEXT: 56 | logger.info("[OPEN_AI] query={}".format(query)) 57 | session_id = context["session_id"] 58 | reply = None 59 | if query == "#清除记忆": 60 | self.sessions.clear_session(session_id) 61 | reply = Reply(ReplyType.INFO, "记忆已清除") 62 | elif query == "#清除所有": 63 | self.sessions.clear_all_session() 64 | reply = Reply(ReplyType.INFO, "所有人记忆已清除") 65 | else: 66 | session = self.sessions.session_query(query, session_id) 67 | result = self.reply_text(session) 68 | total_tokens, completion_tokens, reply_content = ( 69 | result["total_tokens"], 70 | result["completion_tokens"], 71 | result["content"], 72 | ) 73 | logger.debug( 74 | "[OPEN_AI] new_query={}, session_id={}, reply_cont={}, completion_tokens={}".format( 75 | str(session), session_id, reply_content, completion_tokens 76 | ) 77 | ) 78 | 79 | if total_tokens == 0: 80 | reply = Reply(ReplyType.ERROR, reply_content) 81 | else: 82 | self.sessions.session_reply( 83 | reply_content, session_id, total_tokens 84 | ) 85 | reply = Reply(ReplyType.TEXT, reply_content) 86 | return reply 87 | elif context.type == ContextType.IMAGE_CREATE: 88 | ok, retstring = self.create_img(query, 0) 89 | reply = None 90 | if ok: 91 | reply = Reply(ReplyType.IMAGE_URL, retstring) 92 | else: 93 | reply = Reply(ReplyType.ERROR, retstring) 94 | return reply 95 | 96 | def reply_text(self, session: OpenAISession, retry_count=0): 97 | try: 98 | response = openai.Completion.create(prompt=str(session), **self.args) 99 | res_content = ( 100 | response.choices[0]["text"].strip().replace("<|endoftext|>", "") 101 | ) 102 | total_tokens = response["usage"]["total_tokens"] 103 | completion_tokens = 
response["usage"]["completion_tokens"] 104 | logger.info("[OPEN_AI] reply={}".format(res_content)) 105 | return { 106 | "total_tokens": total_tokens, 107 | "completion_tokens": completion_tokens, 108 | "content": res_content, 109 | } 110 | except Exception as e: 111 | need_retry = retry_count < 2 112 | result = {"completion_tokens": 0, "content": "我现在有点累了,等会再来吧"} 113 | if isinstance(e, openai.error.RateLimitError): 114 | logger.warn("[OPEN_AI] RateLimitError: {}".format(e)) 115 | result["content"] = "提问太快啦,请休息一下再问我吧" 116 | if need_retry: 117 | time.sleep(20) 118 | elif isinstance(e, openai.error.Timeout): 119 | logger.warn("[OPEN_AI] Timeout: {}".format(e)) 120 | result["content"] = "我没有收到你的消息" 121 | if need_retry: 122 | time.sleep(5) 123 | elif isinstance(e, openai.error.APIConnectionError): 124 | logger.warn("[OPEN_AI] APIConnectionError: {}".format(e)) 125 | need_retry = False 126 | result["content"] = "我连接不到你的网络" 127 | else: 128 | logger.warn("[OPEN_AI] Exception: {}".format(e)) 129 | need_retry = False 130 | self.sessions.clear_session(session.session_id) 131 | 132 | if need_retry: 133 | logger.warn("[OPEN_AI] 第{}次重试".format(retry_count + 1)) 134 | return self.reply_text(session, retry_count + 1) 135 | else: 136 | return result 137 | -------------------------------------------------------------------------------- /bot/openai/open_ai_image.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import openai 4 | import openai.error 5 | 6 | from common.log import logger 7 | from common.token_bucket import TokenBucket 8 | from config import conf 9 | 10 | 11 | # OPENAI提供的画图接口 12 | class OpenAIImage(object): 13 | def __init__(self): 14 | openai.api_key = conf().get("open_ai_api_key") 15 | if conf().get("rate_limit_dalle"): 16 | self.tb4dalle = TokenBucket(conf().get("rate_limit_dalle", 50)) 17 | 18 | def create_img(self, query, retry_count=0): 19 | try: 20 | if conf().get("rate_limit_dalle") and not 
from bot.session_manager import Session
from common.log import logger


class OpenAISession(Session):
    """Session for the OpenAI completion API (e.g. text-davinci-003).

    Unlike the chat API, the completion endpoint takes a single prompt
    string, so the stored message list is serialized into a Q/A
    transcript by ``__str__``.
    """

    def __init__(self, session_id, system_prompt=None, model="text-davinci-003"):
        super().__init__(session_id, system_prompt)
        self.model = model
        self.reset()

    def __str__(self):
        """Serialize the stored messages into the completion-model prompt.

        e.g.
        Q: xxx
        A: xxx
        Q: xxx
        """
        prompt = ""
        for item in self.messages:
            if item["role"] == "system":
                prompt += item["content"] + "<|endoftext|>\n\n\n"
            elif item["role"] == "user":
                prompt += "Q: " + item["content"] + "\n"
            elif item["role"] == "assistant":
                prompt += "\n\nA: " + item["content"] + "<|endoftext|>\n"

        # If the last message is an unanswered question, cue the model to answer.
        if len(self.messages) > 0 and self.messages[-1]["role"] == "user":
            prompt += "A: "
        return prompt

    def discard_exceeding(self, max_tokens, cur_tokens=None):
        """Drop oldest messages until the prompt fits within ``max_tokens``.

        Returns the (possibly approximate) token count after trimming.
        Falls back to ``len(str(self))`` as an estimate when precise
        counting raises (e.g. tiktoken unavailable) and a ``cur_tokens``
        hint was supplied; re-raises otherwise.
        """
        precise = True
        try:
            cur_tokens = self.calc_tokens()
        except Exception as e:
            precise = False
            if cur_tokens is None:
                raise e
            logger.debug(
                "Exception when counting tokens precisely for query: {}".format(e)
            )
        while cur_tokens > max_tokens:
            if len(self.messages) > 1:
                # Drop the oldest message first.
                self.messages.pop(0)
            elif len(self.messages) == 1 and self.messages[0]["role"] == "assistant":
                self.messages.pop(0)
                if precise:
                    cur_tokens = self.calc_tokens()
                else:
                    cur_tokens = len(str(self))
                break
            elif len(self.messages) == 1 and self.messages[0]["role"] == "user":
                # Fix: logger.warn is a deprecated alias; use logger.warning.
                logger.warning(
                    "user question exceed max_tokens. total_tokens={}".format(
                        cur_tokens
                    )
                )
                break
            else:
                logger.debug(
                    "max_tokens={}, total_tokens={}, len(conversation)={}".format(
                        max_tokens, cur_tokens, len(self.messages)
                    )
                )
                break
            if precise:
                cur_tokens = self.calc_tokens()
            else:
                cur_tokens = len(str(self))
        return cur_tokens

    def calc_tokens(self):
        """Count tokens of the serialized prompt with tiktoken."""
        return num_tokens_from_string(str(self), self.model)


# refer to https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
def num_tokens_from_string(string: str, model: str) -> int:
    """Returns the number of tokens in a text string."""
    import tiktoken

    encoding = tiktoken.encoding_for_model(model)
    num_tokens = len(encoding.encode(string, disallowed_special=()))
    return num_tokens
class SessionManager(object):
    """Registry of chat sessions, one per session_id.

    When "expires_in_seconds" is configured, sessions live in an
    ExpiredDict so idle conversations are forgotten automatically;
    otherwise they persist for the process lifetime.
    """

    def __init__(self, sessioncls, **session_args):
        expiry = conf().get("expires_in_seconds")
        if expiry:
            self.sessions = ExpiredDict(expiry)
        else:
            self.sessions = dict()
        self.sessioncls = sessioncls
        self.session_args = session_args

    def build_session(self, session_id, system_prompt=None):
        """Fetch (or lazily create) the session for ``session_id``.

        A non-None ``system_prompt`` replaces the prompt of an existing
        session and resets its history. A None ``session_id`` returns a
        fresh, unregistered session.
        """
        if session_id is None:
            return self.sessioncls(session_id, system_prompt, **self.session_args)

        if session_id not in self.sessions:
            self.sessions[session_id] = self.sessioncls(
                session_id, system_prompt, **self.session_args
            )
        elif system_prompt is not None:
            # A new system prompt also wipes the accumulated conversation.
            self.sessions[session_id].set_system_prompt(system_prompt)
        return self.sessions[session_id]

    def session_query(self, query, session_id):
        """Record a user query and trim the session to the token budget."""
        sess = self.build_session(session_id)
        sess.add_query(query)
        try:
            budget = conf().get("conversation_max_tokens", 1000)
            used = sess.discard_exceeding(budget, None)
            logger.debug("prompt tokens used={}".format(used))
        except Exception as err:
            logger.debug(
                "Exception when counting tokens precisely for prompt: {}".format(
                    str(err)
                )
            )
        return sess

    def session_reply(self, reply, session_id, total_tokens=None):
        """Record an assistant reply and trim the session to the token budget."""
        sess = self.build_session(session_id)
        sess.add_reply(reply)
        try:
            budget = conf().get("conversation_max_tokens", 1000)
            remaining = sess.discard_exceeding(budget, total_tokens)
            logger.debug(
                "raw total_tokens={}, savesession tokens={}".format(
                    total_tokens, remaining
                )
            )
        except Exception as err:
            logger.debug(
                "Exception when counting tokens precisely for session: {}".format(
                    str(err)
                )
            )
        return sess

    def clear_session(self, session_id):
        """Drop the session for ``session_id`` if it exists."""
        if session_id in self.sessions:
            del self.sessions[session_id]

    def clear_all_session(self):
        """Drop every stored session."""
        self.sessions.clear()
# encoding:utf-8

from enum import Enum


class ContextType(Enum):
    TEXT = 1  # text message
    VOICE = 2  # audio message
    IMAGE = 3  # image message
    IMAGE_CREATE = 10  # image-creation command
    JOIN_GROUP = 20  # someone joined the group chat
    PATPAT = 21  # "pat-pat" (tickle) notification

    def __str__(self):
        # Render as the bare member name, e.g. "TEXT".
        return self.name


class Context:
    """Dict-like container describing one incoming message.

    ``type`` and ``content`` are first-class fields; all other metadata
    lives in ``kwargs`` and is exposed through the mapping protocol
    (``ctx["receiver"]`` etc.).
    """

    def __init__(self, type: ContextType = None, content=None, kwargs=None):
        self.type = type
        self.content = content
        # Bug fix: the previous default ``kwargs=dict()`` was a shared
        # mutable default, so every Context() created without an explicit
        # kwargs silently shared (and mutated) the same dict.
        self.kwargs = {} if kwargs is None else kwargs

    def __contains__(self, key):
        if key == "type":
            return self.type is not None
        elif key == "content":
            return self.content is not None
        else:
            return key in self.kwargs

    def __getitem__(self, key):
        if key == "type":
            return self.type
        elif key == "content":
            return self.content
        else:
            return self.kwargs[key]

    def get(self, key, default=None):
        """Mapping-style get: return ``default`` instead of raising."""
        try:
            return self[key]
        except KeyError:
            return default

    def __setitem__(self, key, value):
        if key == "type":
            self.type = value
        elif key == "content":
            self.content = value
        else:
            self.kwargs[key] = value

    def __delitem__(self, key):
        # Deleting a first-class field resets it to None rather than
        # removing the attribute.
        if key == "type":
            self.type = None
        elif key == "content":
            self.content = None
        else:
            del self.kwargs[key]

    def __str__(self):
        return "Context(type={}, content={}, kwargs={})".format(
            self.type, self.content, self.kwargs
        )
# encoding:utf-8

from enum import Enum


class ReplyType(Enum):
    """Kinds of replies a bot can produce and a channel can deliver."""

    TEXT = 1  # plain text
    VOICE = 2  # audio file
    IMAGE = 3  # image file
    IMAGE_URL = 4  # image URL

    INFO = 9
    ERROR = 10

    def __str__(self):
        # Render as the bare member name, e.g. "TEXT".
        return self.name


class Reply:
    """Typed reply payload: a ReplyType tag plus its content."""

    def __init__(self, type: ReplyType = None, content=None):
        self.type = type
        self.content = content

    def __str__(self):
        return "Reply(type={}, content={})".format(self.type, self.content)
def create_channel(channel_type):
    """
    create a channel instance
    :param channel_type: channel type code
    :return: channel instance
    :raises RuntimeError: if channel_type is not a known code
    """
    if channel_type == "wx":
        from channel.wechat.wechat_channel import WechatChannel

        return WechatChannel()
    elif channel_type == "wxy":
        from channel.wechat.wechaty_channel import WechatyChannel

        return WechatyChannel()
    elif channel_type == "terminal":
        from channel.terminal.terminal_channel import TerminalChannel

        return TerminalChannel()
    elif channel_type == "wechatmp":
        from channel.wechatmp.wechatmp_channel import WechatMPChannel

        # Subscription accounts may only reply passively within WeChat's
        # response window, hence passive_reply=True.
        return WechatMPChannel(passive_reply=True)
    elif channel_type == "wechatmp_service":
        from channel.wechatmp.wechatmp_channel import WechatMPChannel

        return WechatMPChannel(passive_reply=False)
    # Fix: previously a bare ``raise RuntimeError`` — include the offending
    # value so a misconfigured "channel_type" is diagnosable from the log.
    raise RuntimeError("unknown channel_type: {}".format(channel_type))
class ChatMessage(object):
    """Unified chat message wrapper shared by all channels.

    Field semantics are documented in the module docstring above; all
    attributes default to empty and are filled in by channel-specific
    subclasses.
    """

    msg_id = None
    create_time = None

    ctype = None
    content = None

    from_user_id = None
    from_user_nickname = None
    to_user_id = None
    to_user_nickname = None
    other_user_id = None
    other_user_nickname = None

    is_group = False
    is_at = False
    actual_user_id = None
    actual_user_nickname = None

    _prepare_fn = None  # lazy hook, e.g. downloads media to a temp path
    _prepared = False
    _rawmsg = None  # the original platform-specific message object

    def __init__(self, _rawmsg):
        self._rawmsg = _rawmsg

    def prepare(self):
        # Run the lazy preparation hook exactly once, if one was set.
        if self._prepared or not self._prepare_fn:
            return
        self._prepared = True
        self._prepare_fn()

    def __str__(self):
        return "ChatMessage: id={}, create_time={}, ctype={}, content={}, from_user_id={}, from_user_nickname={}, to_user_id={}, to_user_nickname={}, other_user_id={}, other_user_nickname={}, is_group={}, is_at={}, actual_user_id={}, actual_user_nickname={}".format(
            self.msg_id,
            self.create_time,
            self.ctype,
            self.content,
            self.from_user_id,
            self.from_user_nickname,
            self.to_user_id,
            self.to_user_nickname,
            self.other_user_id,
            self.other_user_nickname,
            self.is_group,
            self.is_at,
            self.actual_user_id,
            self.actual_user_nickname,
        )
class TerminalMessage(ChatMessage):
    # Minimal ChatMessage built from console input; there is no raw
    # platform message, so fields are set directly.
    def __init__(
        self,
        msg_id,
        content,
        ctype=ContextType.TEXT,
        from_user_id="User",
        to_user_id="Chatgpt",
        other_user_id="Chatgpt",
    ):
        self.msg_id = msg_id
        self.ctype = ctype
        self.content = content
        self.from_user_id = from_user_id
        self.to_user_id = to_user_id
        self.other_user_id = other_user_id


class TerminalChannel(ChatChannel):
    """Interactive console channel, mainly for local testing."""

    # Voice replies cannot be rendered on a terminal.
    NOT_SUPPORT_REPLYTYPE = [ReplyType.VOICE]

    def send(self, reply: Reply, context: Context):
        """Render a reply on stdout; images are opened in PIL's viewer."""
        print("\nBot:")
        if reply.type == ReplyType.IMAGE:
            from PIL import Image

            image_storage = reply.content
            image_storage.seek(0)
            img = Image.open(image_storage)
            print("")
            img.show()
        elif reply.type == ReplyType.IMAGE_URL:  # download the image from the web
            import io

            import requests
            from PIL import Image

            img_url = reply.content
            pic_res = requests.get(img_url, stream=True)
            image_storage = io.BytesIO()
            for block in pic_res.iter_content(1024):
                image_storage.write(block)
            image_storage.seek(0)
            img = Image.open(image_storage)
            print(img_url)
            img.show()
        else:
            print(reply.content)
        print("\nUser:", end="")
        sys.stdout.flush()
        return

    def startup(self):
        """Read-eval loop: read a line, wrap it in a Context, hand to ChatChannel."""
        context = Context()
        logger.setLevel("WARN")  # keep console output clean during the chat
        print("\nPlease input your question:\nUser:", end="")
        sys.stdout.flush()
        msg_id = 0
        while True:
            try:
                prompt = self.get_input()
            except KeyboardInterrupt:
                print("\nExiting...")
                sys.exit()
            msg_id += 1
            trigger_prefixs = conf().get("single_chat_prefix", [""])
            if check_prefix(prompt, trigger_prefixs) is None:
                prompt = trigger_prefixs[0] + prompt  # prepend the trigger prefix to messages that lack one

            context = self._compose_context(
                ContextType.TEXT, prompt, msg=TerminalMessage(msg_id, prompt)
            )
            if context:
                self.produce(context)
            else:
                raise Exception("context is None")

    def get_input(self):
        """
        Multi-line input function
        """
        sys.stdout.flush()
        line = input()
        return line
r"\"(.*?)\"", itchat_msg["Content"] 43 | )[0] 44 | elif "拍了拍我" in itchat_msg["Content"]: 45 | self.ctype = ContextType.PATPAT 46 | self.content = itchat_msg["Content"] 47 | if is_group: 48 | self.actual_user_nickname = re.findall( 49 | r"\"(.*?)\"", itchat_msg["Content"] 50 | )[0] 51 | else: 52 | raise NotImplementedError( 53 | "Unsupported note message: " + itchat_msg["Content"] 54 | ) 55 | else: 56 | raise NotImplementedError( 57 | "Unsupported message type: Type:{} MsgType:{}".format( 58 | itchat_msg["Type"], itchat_msg["MsgType"] 59 | ) 60 | ) 61 | 62 | self.from_user_id = itchat_msg["FromUserName"] 63 | self.to_user_id = itchat_msg["ToUserName"] 64 | 65 | user_id = itchat.instance.storageClass.userName 66 | nickname = itchat.instance.storageClass.nickName 67 | 68 | # 虽然from_user_id和to_user_id用的少,但是为了保持一致性,还是要填充一下 69 | # 以下很繁琐,一句话总结:能填的都填了。 70 | if self.from_user_id == user_id: 71 | self.from_user_nickname = nickname 72 | if self.to_user_id == user_id: 73 | self.to_user_nickname = nickname 74 | try: # 陌生人时候, 'User'字段可能不存在 75 | self.other_user_id = itchat_msg["User"]["UserName"] 76 | self.other_user_nickname = itchat_msg["User"]["NickName"] 77 | if self.other_user_id == self.from_user_id: 78 | self.from_user_nickname = self.other_user_nickname 79 | if self.other_user_id == self.to_user_id: 80 | self.to_user_nickname = self.other_user_nickname 81 | except KeyError as e: # 处理偶尔没有对方信息的情况 82 | logger.warn("[WX]get other_user_id failed: " + str(e)) 83 | if self.from_user_id == user_id: 84 | self.other_user_id = self.to_user_id 85 | else: 86 | self.other_user_id = self.from_user_id 87 | 88 | if self.is_group: 89 | self.is_at = itchat_msg["IsAt"] 90 | self.actual_user_id = itchat_msg["ActualUserName"] 91 | if self.ctype not in [ContextType.JOIN_GROUP, ContextType.PATPAT]: 92 | self.actual_user_nickname = itchat_msg["ActualNickName"] 93 | -------------------------------------------------------------------------------- /channel/wechat/wechaty_channel.py: 
# encoding:utf-8

"""
wechaty channel
Python Wechaty - https://github.com/wechaty/python-wechaty
"""
import asyncio
import base64
import os
import time

from wechaty import Contact, Wechaty
from wechaty.user import Message
from wechaty_puppet import FileBox

from bridge.context import *
from bridge.context import Context
from bridge.reply import *
from channel.chat_channel import ChatChannel
from channel.wechat.wechaty_message import WechatyMessage
from common.log import logger
from common.singleton import singleton
from config import conf

try:
    from voice.audio_convert import any_to_sil
except Exception as e:
    pass


@singleton
class WechatyChannel(ChatChannel):
    """Channel backed by python-wechaty (puppet service)."""

    NOT_SUPPORT_REPLYTYPE = []

    def __init__(self):
        super().__init__()

    def startup(self):
        # Runs the wechaty event loop until the bot stops.
        config = conf()
        token = config.get("wechaty_puppet_service_token")
        os.environ["WECHATY_PUPPET_SERVICE_TOKEN"] = token
        asyncio.run(self.main())

    async def main(self):
        loop = asyncio.get_event_loop()
        # Hand the asyncio loop to the worker threads so they can schedule
        # coroutines back onto it via run_coroutine_threadsafe.
        self.handler_pool._initializer = lambda: asyncio.set_event_loop(loop)
        self.bot = Wechaty()
        self.bot.on("login", self.on_login)
        self.bot.on("message", self.on_message)
        await self.bot.start()

    async def on_login(self, contact: Contact):
        self.user_id = contact.contact_id
        self.name = contact.name
        logger.info("[WX] login user={}".format(contact))

    # Unified send function; each Channel implements its own, dispatching on
    # reply.type. Called from worker threads, so every wechaty coroutine is
    # submitted to the bot's loop and awaited synchronously.
    def send(self, reply: Reply, context: Context):
        receiver_id = context["receiver"]
        loop = asyncio.get_event_loop()
        if context["isgroup"]:
            receiver = asyncio.run_coroutine_threadsafe(
                self.bot.Room.find(receiver_id), loop
            ).result()
        else:
            receiver = asyncio.run_coroutine_threadsafe(
                self.bot.Contact.find(receiver_id), loop
            ).result()
        msg = None
        if reply.type == ReplyType.TEXT:
            msg = reply.content
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
        elif reply.type == ReplyType.ERROR or reply.type == ReplyType.INFO:
            msg = reply.content
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info("[WX] sendMsg={}, receiver={}".format(reply, receiver))
        elif reply.type == ReplyType.VOICE:
            voiceLength = None
            file_path = reply.content
            # Convert to silk (.sil), the format WeChat voice messages use.
            sil_file = os.path.splitext(file_path)[0] + ".sil"
            voiceLength = int(any_to_sil(file_path, sil_file))
            if voiceLength >= 60000:
                voiceLength = 60000
                logger.info(
                    "[WX] voice too long, length={}, set to 60s".format(voiceLength)
                )
            # Send the voice message.
            t = int(time.time())
            msg = FileBox.from_file(sil_file, name=str(t) + ".sil")
            if voiceLength is not None:
                msg.metadata["voiceLength"] = voiceLength
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            try:
                # Best-effort cleanup of the temp audio files.
                os.remove(file_path)
                if sil_file != file_path:
                    os.remove(sil_file)
            except Exception as e:
                pass
            logger.info(
                "[WX] sendVoice={}, receiver={}".format(reply.content, receiver)
            )
        elif reply.type == ReplyType.IMAGE_URL:  # download the image from the web
            img_url = reply.content
            t = int(time.time())
            msg = FileBox.from_url(url=img_url, name=str(t) + ".png")
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info("[WX] sendImage url={}, receiver={}".format(img_url, receiver))
        elif reply.type == ReplyType.IMAGE:  # read the image from a file object
            image_storage = reply.content
            image_storage.seek(0)
            t = int(time.time())
            msg = FileBox.from_base64(
                base64.b64encode(image_storage.read()), str(t) + ".png"
            )
            asyncio.run_coroutine_threadsafe(receiver.say(msg), loop).result()
            logger.info("[WX] sendImage, receiver={}".format(receiver))

    async def on_message(self, msg: Message):
        """
        listen for message event
        """
        try:
            cmsg = await WechatyMessage(msg)
        except NotImplementedError as e:
            # Unsupported message types are dropped quietly.
            logger.debug("[WX] {}".format(e))
            return
        except Exception as e:
            logger.exception("[WX] {}".format(e))
            return
        logger.debug("[WX] message:{}".format(cmsg))
        room = msg.room()  # the group the message came from; None for private chats
        isgroup = room is not None
        ctype = cmsg.ctype
        context = self._compose_context(ctype, cmsg.content, isgroup=isgroup, msg=cmsg)
        if context:
            logger.info("[WX] receiveMsg={}, context={}".format(cmsg, context))
            self.produce(context)
15 | 16 | So you can create objects by doing something like `await MyClass(params)` 17 | """ 18 | 19 | async def __new__(cls, *a, **kw): 20 | instance = super().__new__(cls) 21 | await instance.__init__(*a, **kw) 22 | return instance 23 | 24 | async def __init__(self): 25 | pass 26 | 27 | 28 | class WechatyMessage(ChatMessage, aobject): 29 | async def __init__(self, wechaty_msg: Message): 30 | super().__init__(wechaty_msg) 31 | 32 | room = wechaty_msg.room() 33 | 34 | self.msg_id = wechaty_msg.message_id 35 | self.create_time = wechaty_msg.payload.timestamp 36 | self.is_group = room is not None 37 | 38 | if wechaty_msg.type() == MessageType.MESSAGE_TYPE_TEXT: 39 | self.ctype = ContextType.TEXT 40 | self.content = wechaty_msg.text() 41 | elif wechaty_msg.type() == MessageType.MESSAGE_TYPE_AUDIO: 42 | self.ctype = ContextType.VOICE 43 | voice_file = await wechaty_msg.to_file_box() 44 | self.content = TmpDir().path() + voice_file.name # content直接存临时目录路径 45 | 46 | def func(): 47 | loop = asyncio.get_event_loop() 48 | asyncio.run_coroutine_threadsafe( 49 | voice_file.to_file(self.content), loop 50 | ).result() 51 | 52 | self._prepare_fn = func 53 | 54 | else: 55 | raise NotImplementedError( 56 | "Unsupported message type: {}".format(wechaty_msg.type()) 57 | ) 58 | 59 | from_contact = wechaty_msg.talker() # 获取消息的发送者 60 | self.from_user_id = from_contact.contact_id 61 | self.from_user_nickname = from_contact.name 62 | 63 | # group中的from和to,wechaty跟itchat含义不一样 64 | # wecahty: from是消息实际发送者, to:所在群 65 | # itchat: 如果是你发送群消息,from和to是你自己和所在群,如果是别人发群消息,from和to是所在群和你自己 66 | # 但这个差别不影响逻辑,group中只使用到:1.用from来判断是否是自己发的,2.actual_user_id来判断实际发送用户 67 | 68 | if self.is_group: 69 | self.to_user_id = room.room_id 70 | self.to_user_nickname = await room.topic() 71 | else: 72 | to_contact = wechaty_msg.to() 73 | self.to_user_id = to_contact.contact_id 74 | self.to_user_nickname = to_contact.name 75 | 76 | if ( 77 | self.is_group or wechaty_msg.is_self() 78 | ): # 
如果是群消息,other_user设置为群,如果是私聊消息,而且自己发的,就设置成对方。 79 | self.other_user_id = self.to_user_id 80 | self.other_user_nickname = self.to_user_nickname 81 | else: 82 | self.other_user_id = self.from_user_id 83 | self.other_user_nickname = self.from_user_nickname 84 | 85 | if self.is_group: # wechaty群聊中,实际发送用户就是from_user 86 | self.is_at = await wechaty_msg.mention_self() 87 | if not self.is_at: # 有时候复制粘贴的消息,不算做@,但是内容里面会有@xxx,这里做一下兼容 88 | name = wechaty_msg.wechaty.user_self().name 89 | pattern = f"@{name}(\u2005|\u0020)" 90 | if re.search(pattern, self.content): 91 | logger.debug(f"wechaty message {self.msg_id} include at") 92 | self.is_at = True 93 | 94 | self.actual_user_id = self.from_user_id 95 | self.actual_user_nickname = self.from_user_nickname 96 | -------------------------------------------------------------------------------- /channel/wechatmp/README.md: -------------------------------------------------------------------------------- 1 | # 微信公众号channel 2 | 3 | 鉴于个人微信号在服务器上通过itchat登录有封号风险,这里新增了微信公众号channel,提供无风险的服务。 4 | 目前支持订阅号和服务号两种类型的公众号。个人主体的微信订阅号由于无法通过微信认证,接口存在限制,目前仅支持最基本的文本交互和语音输入。通过微信认证的订阅号或者服务号可以回复图片和语音。 5 | 6 | ## 使用方法(订阅号,服务号类似) 7 | 8 | 在开始部署前,你需要一个拥有公网IP的服务器,以提供微信服务器和我们自己服务器的连接。或者你需要进行内网穿透,否则微信服务器无法将消息发送给我们的服务器。 9 | 10 | 此外,需要在我们的服务器上安装python的web框架web.py。 11 | 以ubuntu为例(在ubuntu 22.04上测试): 12 | ``` 13 | pip3 install web.py 14 | ``` 15 | 16 | 然后在[微信公众平台](https://mp.weixin.qq.com)注册一个自己的公众号,类型选择订阅号,主体为个人即可。 17 | 18 | 然后根据[接入指南](https://developers.weixin.qq.com/doc/offiaccount/Basic_Information/Access_Overview.html)的说明,在[微信公众平台](https://mp.weixin.qq.com)的“设置与开发”-“基本配置”-“服务器配置”中填写服务器地址`URL`和令牌`Token`。这里的`URL`是`example.com/wx`的形式,不可以使用IP,`Token`是你自己编的一个特定的令牌。消息加解密方式目前选择的是明文模式。 19 | 20 | 相关的服务器验证代码已经写好,你不需要再添加任何代码。你只需要在本项目根目录的`config.json`中添加 21 | ``` 22 | "channel_type": "wechatmp", # 如果通过了微信认证,将"wechatmp"替换为"wechatmp_service",可极大的优化使用体验 23 | "wechatmp_token": "xxxx", # 微信公众平台的Token 24 | "wechatmp_port": 8080, # 微信公众平台的端口,需要端口转发到80或443 25 | "wechatmp_app_id": 
"xxxx", # 微信公众平台的appID 26 | "wechatmp_app_secret": "xxxx", # 微信公众平台的appsecret 27 | "single_chat_prefix": [""], # 推荐设置,任意对话都可以触发回复,不添加前缀 28 | "single_chat_reply_prefix": "", # 推荐设置,回复不设置前缀 29 | "plugin_trigger_prefix": "&", # 推荐设置,在手机微信客户端中,$%^等符号与中文连在一起时会自动显示一段较大的间隔,用户体验不好。请不要使用管理员指令前缀"#",这会造成未知问题。 30 | ``` 31 | 然后运行`python3 app.py`启动web服务器。这里会默认监听8080端口,但是微信公众号的服务器配置只支持80/443端口,有两种方法来解决这个问题。第一个是推荐的方法,使用端口转发命令将80端口转发到8080端口: 32 | ``` 33 | sudo iptables -t nat -A PREROUTING -p tcp --dport 80 -j REDIRECT --to-port 8080 34 | sudo iptables-save > /etc/iptables/rules.v4 35 | ``` 36 | 第二个方法是让python程序直接监听80端口,在配置文件中设置`"wechatmp_port": 80` ,在linux上需要使用`sudo python3 app.py`启动程序。然而这会导致一系列环境和权限问题,因此不是推荐的方法。 37 | 38 | 443端口同理,注意需要支持SSL,也就是https的访问,在`wechatmp_channel.py`中需要修改相应的证书路径。 39 | 40 | 程序启动并监听端口后,在刚才的“服务器配置”中点击`提交`即可验证你的服务器。 41 | 随后在[微信公众平台](https://mp.weixin.qq.com)启用服务器,关闭手动填写规则的自动回复,即可实现ChatGPT的自动回复。 42 | 43 | 如果在启用后如果遇到如下报错: 44 | ``` 45 | 'errcode': 40164, 'errmsg': 'invalid ip xx.xx.xx.xx not in whitelist rid 46 | ``` 47 | 48 | 需要在公众号开发信息下将IP加入到IP白名单。 49 | 50 | ## 个人微信公众号的限制 51 | 由于人微信公众号不能通过微信认证,所以没有客服接口,因此公众号无法主动发出消息,只能被动回复。而微信官方对被动回复有5秒的时间限制,最多重试2次,因此最多只有15秒的自动回复时间窗口。因此如果问题比较复杂或者我们的服务器比较忙,ChatGPT的回答就没办法及时回复给用户。为了解决这个问题,这里做了回答缓存,它需要你在回复超时后,再次主动发送任意文字(例如1)来尝试拿到回答缓存。为了优化使用体验,目前设置了两分钟(120秒)的timeout,用户在至多两分钟后即可得到查询到回复或者错误原因。 52 | 53 | 另外,由于微信官方的限制,自动回复有长度限制。因此这里将ChatGPT的回答进行了拆分,以满足限制。 54 | 55 | ## 私有api_key 56 | 公共api有访问频率限制(免费账号每分钟最多3次ChatGPT的API调用),这在服务多人的时候会遇到问题。因此这里多加了一个设置私有api_key的功能。目前通过godcmd插件的命令来设置私有api_key。 57 | 58 | ## 语音输入 59 | 利用微信自带的语音识别功能,提供语音输入能力。需要在公众号管理页面的“设置与开发”->“接口权限”页面开启“接收语音识别结果”。 60 | 61 | ## 语音回复 62 | 请在配置文件中添加以下词条: 63 | ``` 64 | "voice_reply_voice": true, 65 | ``` 66 | 这样公众号将会用语音回复语音消息,实现语音对话。 67 | 68 | 默认的语音合成引擎是`google`,它是免费使用的。 69 | 70 | 如果要选择其他的语音合成引擎,请添加以下配置项: 71 | ``` 72 | "text_to_voice": "pytts" 73 | ``` 74 | 75 | pytts是本地的语音合成引擎。还支持baidu,azure,这些你需要自行配置相关的依赖和key。 76 | 77 | 如果使用pytts,在ubuntu上需要安装如下依赖: 78 | ``` 79 | sudo apt update 80 
| sudo apt install espeak 81 | sudo apt install ffmpeg 82 | python3 -m pip install pyttsx3 83 | ``` 84 | 不是很建议开启pytts语音回复,因为它是离线本地计算,算的慢会拖垮服务器,且声音不好听。 85 | 86 | ## 图片回复 87 | 现在认证公众号和非认证公众号都可以实现的图片和语音回复。但是非认证公众号使用了永久素材接口,每天有1000次的调用上限(每个月有10次重置机会,程序中已设定遇到上限会自动重置),且永久素材库存也有上限。因此对于非认证公众号,我们会在回复图片或者语音消息后的10秒内从永久素材库存内删除该素材。 88 | 89 | ## 测试 90 | 目前在`RoboStyle`这个公众号上进行了测试(基于[wechatmp分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp)),感兴趣的可以关注并体验。开启了godcmd, Banwords, role, dungeon, finish这五个插件,其他的插件还没有详尽测试。百度的接口暂未测试。[wechatmp-stable分支](https://github.com/JS00000/chatgpt-on-wechat/tree/wechatmp-stable)是较稳定的上个版本,但也缺少最新的功能支持。 91 | 92 | ## TODO 93 | - [x] 语音输入 94 | - [ ] 图片输入 95 | - [x] 使用临时素材接口提供认证公众号的图片和语音回复 96 | - [x] 使用永久素材接口提供未认证公众号的图片和语音回复 97 | - [ ] 高并发支持 98 | -------------------------------------------------------------------------------- /channel/wechatmp/active_reply.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import web 4 | 5 | from channel.wechatmp.wechatmp_message import parse_xml 6 | from channel.wechatmp.passive_reply_message import TextMsg 7 | from bridge.context import * 8 | from bridge.reply import ReplyType 9 | from channel.wechatmp.common import * 10 | from channel.wechatmp.wechatmp_channel import WechatMPChannel 11 | from common.log import logger 12 | from config import conf 13 | 14 | 15 | # This class is instantiated once per query 16 | class Query: 17 | def GET(self): 18 | return verify_server(web.input()) 19 | 20 | def POST(self): 21 | # Make sure to return the instance that first created, @singleton will do that. 
22 | channel = WechatMPChannel() 23 | try: 24 | webData = web.data() 25 | # logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8")) 26 | wechatmp_msg = parse_xml(webData) 27 | if ( 28 | wechatmp_msg.msg_type == "text" 29 | or wechatmp_msg.msg_type == "voice" 30 | # or wechatmp_msg.msg_type == "image" 31 | ): 32 | from_user = wechatmp_msg.from_user_id 33 | message = wechatmp_msg.content 34 | message_id = wechatmp_msg.msg_id 35 | 36 | logger.info( 37 | "[wechatmp] {}:{} Receive post query {} {}: {}".format( 38 | web.ctx.env.get("REMOTE_ADDR"), 39 | web.ctx.env.get("REMOTE_PORT"), 40 | from_user, 41 | message_id, 42 | message, 43 | ) 44 | ) 45 | if (wechatmp_msg.msg_type == "voice" and conf().get("voice_reply_voice") == True): 46 | rtype = ReplyType.VOICE 47 | else: 48 | rtype = None 49 | context = channel._compose_context( 50 | ContextType.TEXT, message, isgroup=False, desire_rtype=rtype, msg=wechatmp_msg 51 | ) 52 | if context: 53 | # set private openai_api_key 54 | # if from_user is not changed in itchat, this can be placed at chat_channel 55 | user_data = conf().get_user_data(from_user) 56 | context["openai_api_key"] = user_data.get( 57 | "openai_api_key" 58 | ) # None or user openai_api_key 59 | channel.produce(context) 60 | # The reply will be sent by channel.send() in another thread 61 | return "success" 62 | 63 | elif wechatmp_msg.msg_type == "event": 64 | logger.info( 65 | "[wechatmp] Event {} from {}".format( 66 | wechatmp_msg.Event, wechatmp_msg.from_user_id 67 | ) 68 | ) 69 | content = subscribe_msg() 70 | replyMsg = TextMsg( 71 | wechatmp_msg.from_user_id, wechatmp_msg.to_user_id, content 72 | ) 73 | return replyMsg.send() 74 | else: 75 | logger.info("暂且不处理") 76 | return "success" 77 | except Exception as exc: 78 | logger.exception(exc) 79 | return exc 80 | -------------------------------------------------------------------------------- /channel/wechatmp/common.py: 
-------------------------------------------------------------------------------- 1 | import hashlib 2 | import textwrap 3 | 4 | from config import conf 5 | 6 | MAX_UTF8_LEN = 2048 7 | 8 | 9 | class WeChatAPIException(Exception): 10 | pass 11 | 12 | 13 | def verify_server(data): 14 | try: 15 | if len(data) == 0: 16 | return "None" 17 | signature = data.signature 18 | timestamp = data.timestamp 19 | nonce = data.nonce 20 | echostr = data.echostr 21 | token = conf().get("wechatmp_token") # 请按照公众平台官网\基本配置中信息填写 22 | 23 | data_list = [token, timestamp, nonce] 24 | data_list.sort() 25 | sha1 = hashlib.sha1() 26 | # map(sha1.update, data_list) #python2 27 | sha1.update("".join(data_list).encode("utf-8")) 28 | hashcode = sha1.hexdigest() 29 | print("handle/GET func: hashcode, signature: ", hashcode, signature) 30 | if hashcode == signature: 31 | return echostr 32 | else: 33 | return "" 34 | except Exception as Argument: 35 | return Argument 36 | 37 | 38 | def subscribe_msg(): 39 | trigger_prefix = conf().get("single_chat_prefix", [""]) 40 | msg = textwrap.dedent( 41 | f"""\ 42 | 感谢您的关注! 
43 | 这里是ChatGPT,可以自由对话。 44 | 资源有限,回复较慢,请勿着急。 45 | 支持语音对话。 46 | 暂时不支持图片输入。 47 | 支持图片输出,画字开头的消息将按要求创作图片。 48 | 支持tool、角色扮演和文字冒险等丰富的插件。 49 | 输入'{trigger_prefix}#帮助' 查看详细指令。""" 50 | ) 51 | return msg 52 | 53 | 54 | def split_string_by_utf8_length(string, max_length, max_split=0): 55 | encoded = string.encode("utf-8") 56 | start, end = 0, 0 57 | result = [] 58 | while end < len(encoded): 59 | if max_split > 0 and len(result) >= max_split: 60 | result.append(encoded[start:].decode("utf-8")) 61 | break 62 | end = start + max_length 63 | # 如果当前字节不是 UTF-8 编码的开始字节,则向前查找直到找到开始字节为止 64 | while end < len(encoded) and (encoded[end] & 0b11000000) == 0b10000000: 65 | end -= 1 66 | result.append(encoded[start:end].decode("utf-8")) 67 | start = end 68 | return result 69 | -------------------------------------------------------------------------------- /channel/wechatmp/passive_reply_message.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*-# 2 | # filename: reply.py 3 | import time 4 | 5 | 6 | class Msg(object): 7 | def __init__(self): 8 | pass 9 | 10 | def send(self): 11 | return "success" 12 | 13 | 14 | class TextMsg(Msg): 15 | def __init__(self, toUserName, fromUserName, content): 16 | self.__dict = dict() 17 | self.__dict["ToUserName"] = toUserName 18 | self.__dict["FromUserName"] = fromUserName 19 | self.__dict["CreateTime"] = int(time.time()) 20 | self.__dict["Content"] = content 21 | 22 | def send(self): 23 | XmlForm = """ 24 | 25 | 26 | 27 | {CreateTime} 28 | 29 | 30 | 31 | """ 32 | return XmlForm.format(**self.__dict) 33 | 34 | 35 | class VoiceMsg(Msg): 36 | def __init__(self, toUserName, fromUserName, mediaId): 37 | self.__dict = dict() 38 | self.__dict["ToUserName"] = toUserName 39 | self.__dict["FromUserName"] = fromUserName 40 | self.__dict["CreateTime"] = int(time.time()) 41 | self.__dict["MediaId"] = mediaId 42 | 43 | def send(self): 44 | XmlForm = """ 45 | 46 | 47 | 48 | {CreateTime} 49 | 50 | 51 | 52 | 53 | 
54 | """ 55 | return XmlForm.format(**self.__dict) 56 | 57 | 58 | class ImageMsg(Msg): 59 | def __init__(self, toUserName, fromUserName, mediaId): 60 | self.__dict = dict() 61 | self.__dict["ToUserName"] = toUserName 62 | self.__dict["FromUserName"] = fromUserName 63 | self.__dict["CreateTime"] = int(time.time()) 64 | self.__dict["MediaId"] = mediaId 65 | 66 | def send(self): 67 | XmlForm = """ 68 | 69 | 70 | 71 | {CreateTime} 72 | 73 | 74 | 75 | 76 | 77 | """ 78 | return XmlForm.format(**self.__dict) 79 | -------------------------------------------------------------------------------- /channel/wechatmp/wechatmp_client.py: -------------------------------------------------------------------------------- 1 | import time 2 | import json 3 | import requests 4 | import threading 5 | from channel.wechatmp.common import * 6 | from common.log import logger 7 | from config import conf 8 | 9 | 10 | class WechatMPClient: 11 | def __init__(self): 12 | self.app_id = conf().get("wechatmp_app_id") 13 | self.app_secret = conf().get("wechatmp_app_secret") 14 | self.access_token = None 15 | self.access_token_expires_time = 0 16 | self.access_token_lock = threading.Lock() 17 | self.get_access_token() 18 | 19 | 20 | def wechatmp_request(self, method, url, **kwargs): 21 | r = requests.request(method=method, url=url, **kwargs) 22 | r.raise_for_status() 23 | r.encoding = "utf-8" 24 | ret = r.json() 25 | if "errcode" in ret and ret["errcode"] != 0: 26 | if ret["errcode"] == 45009: 27 | self.clear_quota_v2() 28 | raise WeChatAPIException("{}".format(ret)) 29 | return ret 30 | 31 | def get_access_token(self): 32 | # return the access_token 33 | if self.access_token: 34 | if self.access_token_expires_time - time.time() > 60: 35 | return self.access_token 36 | 37 | # Get new access_token 38 | # Do not request access_token in parallel! Only the last obtained is valid. 
39 | if self.access_token_lock.acquire(blocking=False): 40 | # Wait for other threads that have previously obtained access_token to complete the request 41 | # This happens every 2 hours, so it doesn't affect the experience very much 42 | time.sleep(1) 43 | self.access_token = None 44 | url = "https://api.weixin.qq.com/cgi-bin/token" 45 | params = { 46 | "grant_type": "client_credential", 47 | "appid": self.app_id, 48 | "secret": self.app_secret, 49 | } 50 | ret = self.wechatmp_request(method="get", url=url, params=params) 51 | self.access_token = ret["access_token"] 52 | self.access_token_expires_time = int(time.time()) + ret["expires_in"] 53 | logger.info("[wechatmp] access_token: {}".format(self.access_token)) 54 | self.access_token_lock.release() 55 | else: 56 | # Wait for token update 57 | while self.access_token_lock.locked(): 58 | time.sleep(0.1) 59 | return self.access_token 60 | 61 | 62 | def send_text(self, receiver, reply_text): 63 | url = "https://api.weixin.qq.com/cgi-bin/message/custom/send" 64 | params = {"access_token": self.get_access_token()} 65 | json_data = { 66 | "touser": receiver, 67 | "msgtype": "text", 68 | "text": {"content": reply_text}, 69 | } 70 | self.wechatmp_request( 71 | method="post", 72 | url=url, 73 | params=params, 74 | data=json.dumps(json_data, ensure_ascii=False).encode("utf8"), 75 | ) 76 | 77 | 78 | def send_voice(self, receiver, media_id): 79 | url="https://api.weixin.qq.com/cgi-bin/message/custom/send" 80 | params = {"access_token": self.get_access_token()} 81 | json_data = { 82 | "touser": receiver, 83 | "msgtype": "voice", 84 | "voice": { 85 | "media_id": media_id 86 | } 87 | } 88 | self.wechatmp_request( 89 | method="post", 90 | url=url, 91 | params=params, 92 | data=json.dumps(json_data, ensure_ascii=False).encode("utf8"), 93 | ) 94 | 95 | def send_image(self, receiver, media_id): 96 | url="https://api.weixin.qq.com/cgi-bin/message/custom/send" 97 | params = {"access_token": self.get_access_token()} 98 | json_data = { 
99 | "touser": receiver, 100 | "msgtype": "image", 101 | "image": { 102 | "media_id": media_id 103 | } 104 | } 105 | self.wechatmp_request( 106 | method="post", 107 | url=url, 108 | params=params, 109 | data=json.dumps(json_data, ensure_ascii=False).encode("utf8"), 110 | ) 111 | 112 | 113 | def upload_media(self, media_type, media_file): 114 | url="https://api.weixin.qq.com/cgi-bin/media/upload" 115 | params={ 116 | "access_token": self.get_access_token(), 117 | "type": media_type 118 | } 119 | files={"media": media_file} 120 | ret = self.wechatmp_request( 121 | method="post", 122 | url=url, 123 | params=params, 124 | files=files 125 | ) 126 | logger.debug("[wechatmp] media {} uploaded".format(media_file)) 127 | return ret["media_id"] 128 | 129 | 130 | def upload_permanent_media(self, media_type, media_file): 131 | url="https://api.weixin.qq.com/cgi-bin/material/add_material" 132 | params={ 133 | "access_token": self.get_access_token(), 134 | "type": media_type 135 | } 136 | files={"media": media_file} 137 | ret = self.wechatmp_request( 138 | method="post", 139 | url=url, 140 | params=params, 141 | files=files 142 | ) 143 | logger.debug("[wechatmp] permanent media {} uploaded".format(media_file)) 144 | return ret["media_id"] 145 | 146 | 147 | def delete_permanent_media(self, media_id): 148 | url="https://api.weixin.qq.com/cgi-bin/material/del_material" 149 | params={ 150 | "access_token": self.get_access_token() 151 | } 152 | self.wechatmp_request( 153 | method="post", 154 | url=url, 155 | params=params, 156 | data=json.dumps({"media_id": media_id}, ensure_ascii=False).encode("utf8") 157 | ) 158 | logger.debug("[wechatmp] permanent media {} deleted".format(media_id)) 159 | 160 | def clear_quota(self): 161 | url="https://api.weixin.qq.com/cgi-bin/clear_quota" 162 | params = { 163 | "access_token": self.get_access_token() 164 | } 165 | self.wechatmp_request( 166 | method="post", 167 | url=url, 168 | params=params, 169 | data={"appid": self.app_id} 170 | ) 171 | 
logger.debug("[wechatmp] API quata has been cleard") 172 | 173 | def clear_quota_v2(self): 174 | url="https://api.weixin.qq.com/cgi-bin/clear_quota/v2" 175 | self.wechatmp_request( 176 | method="post", 177 | url=url, 178 | data={"appid": self.app_id, "appsecret": self.app_secret} 179 | ) 180 | logger.debug("[wechatmp] API quata has been cleard") 181 | -------------------------------------------------------------------------------- /channel/wechatmp/wechatmp_message.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*-# 2 | # filename: receive.py 3 | import xml.etree.ElementTree as ET 4 | 5 | from bridge.context import ContextType 6 | from channel.chat_message import ChatMessage 7 | from common.log import logger 8 | 9 | 10 | def parse_xml(web_data): 11 | if len(web_data) == 0: 12 | return None 13 | xmlData = ET.fromstring(web_data) 14 | return WeChatMPMessage(xmlData) 15 | 16 | 17 | class WeChatMPMessage(ChatMessage): 18 | def __init__(self, xmlData): 19 | super().__init__(xmlData) 20 | self.to_user_id = xmlData.find("ToUserName").text 21 | self.from_user_id = xmlData.find("FromUserName").text 22 | self.create_time = xmlData.find("CreateTime").text 23 | self.msg_type = xmlData.find("MsgType").text 24 | try: 25 | self.msg_id = xmlData.find("MsgId").text 26 | except: 27 | self.msg_id = self.from_user_id + self.create_time 28 | self.is_group = False 29 | 30 | # reply to other_user_id 31 | self.other_user_id = self.from_user_id 32 | 33 | if self.msg_type == "text": 34 | self.ctype = ContextType.TEXT 35 | self.content = xmlData.find("Content").text 36 | elif self.msg_type == "voice": 37 | self.ctype = ContextType.TEXT 38 | self.content = xmlData.find("Recognition").text # 接收语音识别结果 39 | # other voice_to_text method not implemented yet 40 | if self.content == None: 41 | self.content = "你好" 42 | elif self.msg_type == "image": 43 | # not implemented yet 44 | self.pic_url = xmlData.find("PicUrl").text 45 | 
self.media_id = xmlData.find("MediaId").text 46 | elif self.msg_type == "event": 47 | self.content = xmlData.find("Event").text 48 | else: # video, shortvideo, location, link 49 | # not implemented 50 | pass 51 | -------------------------------------------------------------------------------- /common/const.py: -------------------------------------------------------------------------------- 1 | # bot_type 2 | OPEN_AI = "openAI" 3 | CHATGPT = "chatGPT" 4 | BAIDU = "baidu" 5 | CHATGPTONAZURE = "chatGPTOnAzure" 6 | -------------------------------------------------------------------------------- /common/dequeue.py: -------------------------------------------------------------------------------- 1 | from queue import Full, Queue 2 | from time import monotonic as time 3 | 4 | 5 | # add implementation of putleft to Queue 6 | class Dequeue(Queue): 7 | def putleft(self, item, block=True, timeout=None): 8 | with self.not_full: 9 | if self.maxsize > 0: 10 | if not block: 11 | if self._qsize() >= self.maxsize: 12 | raise Full 13 | elif timeout is None: 14 | while self._qsize() >= self.maxsize: 15 | self.not_full.wait() 16 | elif timeout < 0: 17 | raise ValueError("'timeout' must be a non-negative number") 18 | else: 19 | endtime = time() + timeout 20 | while self._qsize() >= self.maxsize: 21 | remaining = endtime - time() 22 | if remaining <= 0.0: 23 | raise Full 24 | self.not_full.wait(remaining) 25 | self._putleft(item) 26 | self.unfinished_tasks += 1 27 | self.not_empty.notify() 28 | 29 | def putleft_nowait(self, item): 30 | return self.putleft(item, block=False) 31 | 32 | def _putleft(self, item): 33 | self.queue.appendleft(item) 34 | -------------------------------------------------------------------------------- /common/expired_dict.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | 3 | 4 | class ExpiredDict(dict): 5 | def __init__(self, expires_in_seconds): 6 | super().__init__() 7 | 
self.expires_in_seconds = expires_in_seconds 8 | 9 | def __getitem__(self, key): 10 | value, expiry_time = super().__getitem__(key) 11 | if datetime.now() > expiry_time: 12 | del self[key] 13 | raise KeyError("expired {}".format(key)) 14 | self.__setitem__(key, value) 15 | return value 16 | 17 | def __setitem__(self, key, value): 18 | expiry_time = datetime.now() + timedelta(seconds=self.expires_in_seconds) 19 | super().__setitem__(key, (value, expiry_time)) 20 | 21 | def get(self, key, default=None): 22 | try: 23 | return self[key] 24 | except KeyError: 25 | return default 26 | 27 | def __contains__(self, key): 28 | try: 29 | self[key] 30 | return True 31 | except KeyError: 32 | return False 33 | 34 | def keys(self): 35 | keys = list(super().keys()) 36 | return [key for key in keys if key in self] 37 | 38 | def items(self): 39 | return [(key, self[key]) for key in self.keys()] 40 | 41 | def __iter__(self): 42 | return self.keys().__iter__() 43 | -------------------------------------------------------------------------------- /common/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | 5 | def _reset_logger(log): 6 | for handler in log.handlers: 7 | handler.close() 8 | log.removeHandler(handler) 9 | del handler 10 | log.handlers.clear() 11 | log.propagate = False 12 | console_handle = logging.StreamHandler(sys.stdout) 13 | console_handle.setFormatter( 14 | logging.Formatter( 15 | "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", 16 | datefmt="%Y-%m-%d %H:%M:%S", 17 | ) 18 | ) 19 | file_handle = logging.FileHandler("run.log", encoding="utf-8") 20 | file_handle.setFormatter( 21 | logging.Formatter( 22 | "[%(levelname)s][%(asctime)s][%(filename)s:%(lineno)d] - %(message)s", 23 | datefmt="%Y-%m-%d %H:%M:%S", 24 | ) 25 | ) 26 | log.addHandler(file_handle) 27 | log.addHandler(console_handle) 28 | 29 | 30 | def _get_logger(): 31 | log = logging.getLogger("log") 32 | 
_reset_logger(log) 33 | log.setLevel(logging.INFO) 34 | return log 35 | 36 | 37 | # 日志句柄 38 | logger = _get_logger() 39 | -------------------------------------------------------------------------------- /common/package_manager.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import pip 4 | from pip._internal import main as pipmain 5 | 6 | from common.log import _reset_logger, logger 7 | 8 | 9 | def install(package): 10 | pipmain(["install", package]) 11 | 12 | 13 | def install_requirements(file): 14 | pipmain(["install", "-r", file, "--upgrade"]) 15 | _reset_logger(logger) 16 | 17 | 18 | def check_dulwich(): 19 | needwait = False 20 | for i in range(2): 21 | if needwait: 22 | time.sleep(3) 23 | needwait = False 24 | try: 25 | import dulwich 26 | 27 | return 28 | except ImportError: 29 | try: 30 | install("dulwich") 31 | except: 32 | needwait = True 33 | try: 34 | import dulwich 35 | except ImportError: 36 | raise ImportError("Unable to import dulwich") 37 | -------------------------------------------------------------------------------- /common/singleton.py: -------------------------------------------------------------------------------- 1 | def singleton(cls): 2 | instances = {} 3 | 4 | def get_instance(*args, **kwargs): 5 | if cls not in instances: 6 | instances[cls] = cls(*args, **kwargs) 7 | return instances[cls] 8 | 9 | return get_instance 10 | -------------------------------------------------------------------------------- /common/sorted_dict.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | 3 | 4 | class SortedDict(dict): 5 | def __init__(self, sort_func=lambda k, v: k, init_dict=None, reverse=False): 6 | if init_dict is None: 7 | init_dict = [] 8 | if isinstance(init_dict, dict): 9 | init_dict = init_dict.items() 10 | self.sort_func = sort_func 11 | self.sorted_keys = None 12 | self.reverse = reverse 13 | self.heap = [] 14 | for k, v in 
init_dict: 15 | self[k] = v 16 | 17 | def __setitem__(self, key, value): 18 | if key in self: 19 | super().__setitem__(key, value) 20 | for i, (priority, k) in enumerate(self.heap): 21 | if k == key: 22 | self.heap[i] = (self.sort_func(key, value), key) 23 | heapq.heapify(self.heap) 24 | break 25 | self.sorted_keys = None 26 | else: 27 | super().__setitem__(key, value) 28 | heapq.heappush(self.heap, (self.sort_func(key, value), key)) 29 | self.sorted_keys = None 30 | 31 | def __delitem__(self, key): 32 | super().__delitem__(key) 33 | for i, (priority, k) in enumerate(self.heap): 34 | if k == key: 35 | del self.heap[i] 36 | heapq.heapify(self.heap) 37 | break 38 | self.sorted_keys = None 39 | 40 | def keys(self): 41 | if self.sorted_keys is None: 42 | self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] 43 | return self.sorted_keys 44 | 45 | def items(self): 46 | if self.sorted_keys is None: 47 | self.sorted_keys = [k for _, k in sorted(self.heap, reverse=self.reverse)] 48 | sorted_items = [(k, self[k]) for k in self.sorted_keys] 49 | return sorted_items 50 | 51 | def _update_heap(self, key): 52 | for i, (priority, k) in enumerate(self.heap): 53 | if k == key: 54 | new_priority = self.sort_func(key, self[key]) 55 | if new_priority != priority: 56 | self.heap[i] = (new_priority, key) 57 | heapq.heapify(self.heap) 58 | self.sorted_keys = None 59 | break 60 | 61 | def __iter__(self): 62 | return iter(self.keys()) 63 | 64 | def __repr__(self): 65 | return f"{type(self).__name__}({dict(self)}, sort_func={self.sort_func.__name__}, reverse={self.reverse})" 66 | -------------------------------------------------------------------------------- /common/time_check.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import re 3 | import time 4 | 5 | import config 6 | from common.log import logger 7 | 8 | 9 | def time_checker(f): 10 | def _time_checker(self, *args, **kwargs): 11 | _config = 
config.conf() 12 | chat_time_module = _config.get("chat_time_module", False) 13 | if chat_time_module: 14 | chat_start_time = _config.get("chat_start_time", "00:00") 15 | chat_stopt_time = _config.get("chat_stop_time", "24:00") 16 | time_regex = re.compile( 17 | r"^([01]?[0-9]|2[0-4])(:)([0-5][0-9])$" 18 | ) # 时间匹配,包含24:00 19 | 20 | starttime_format_check = time_regex.match(chat_start_time) # 检查停止时间格式 21 | stoptime_format_check = time_regex.match(chat_stopt_time) # 检查停止时间格式 22 | chat_time_check = chat_start_time < chat_stopt_time # 确定启动时间<停止时间 23 | 24 | # 时间格式检查 25 | if not ( 26 | starttime_format_check and stoptime_format_check and chat_time_check 27 | ): 28 | logger.warn( 29 | "时间格式不正确,请在config.json中修改您的CHAT_START_TIME/CHAT_STOP_TIME,否则可能会影响您正常使用,开始({})-结束({})".format( 30 | starttime_format_check, stoptime_format_check 31 | ) 32 | ) 33 | if chat_start_time > "23:59": 34 | logger.error("启动时间可能存在问题,请修改!") 35 | 36 | # 服务时间检查 37 | now_time = time.strftime("%H:%M", time.localtime()) 38 | if chat_start_time <= now_time <= chat_stopt_time: # 服务时间内,正常返回回答 39 | f(self, *args, **kwargs) 40 | return None 41 | else: 42 | if args[0]["Content"] == "#更新配置": # 不在服务时间内也可以更新配置 43 | f(self, *args, **kwargs) 44 | else: 45 | logger.info("非服务时间内,不接受访问") 46 | return None 47 | else: 48 | f(self, *args, **kwargs) # 未开启时间模块则直接回答 49 | 50 | return _time_checker 51 | -------------------------------------------------------------------------------- /common/tmp_dir.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | from config import conf 5 | 6 | 7 | class TmpDir(object): 8 | """A temporary directory that is deleted when the object is destroyed.""" 9 | 10 | tmpFilePath = pathlib.Path("./tmp/") 11 | 12 | def __init__(self): 13 | pathExists = os.path.exists(self.tmpFilePath) 14 | if not pathExists: 15 | os.makedirs(self.tmpFilePath) 16 | 17 | def path(self): 18 | return str(self.tmpFilePath) + "/" 19 | 
-------------------------------------------------------------------------------- /common/token_bucket.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | 4 | 5 | class TokenBucket: 6 | def __init__(self, tpm, timeout=None): 7 | self.capacity = int(tpm) # 令牌桶容量 8 | self.tokens = 0 # 初始令牌数为0 9 | self.rate = int(tpm) / 60 # 令牌每秒生成速率 10 | self.timeout = timeout # 等待令牌超时时间 11 | self.cond = threading.Condition() # 条件变量 12 | self.is_running = True 13 | # 开启令牌生成线程 14 | threading.Thread(target=self._generate_tokens).start() 15 | 16 | def _generate_tokens(self): 17 | """生成令牌""" 18 | while self.is_running: 19 | with self.cond: 20 | if self.tokens < self.capacity: 21 | self.tokens += 1 22 | self.cond.notify() # 通知获取令牌的线程 23 | time.sleep(1 / self.rate) 24 | 25 | def get_token(self): 26 | """获取令牌""" 27 | with self.cond: 28 | while self.tokens <= 0: 29 | flag = self.cond.wait(self.timeout) 30 | if not flag: # 超时 31 | return False 32 | self.tokens -= 1 33 | return True 34 | 35 | def close(self): 36 | self.is_running = False 37 | 38 | 39 | if __name__ == "__main__": 40 | token_bucket = TokenBucket(20, None) # 创建一个每分钟生产20个tokens的令牌桶 41 | # token_bucket = TokenBucket(20, 0.1) 42 | for i in range(3): 43 | if token_bucket.get_token(): 44 | print(f"第{i+1}次请求成功") 45 | token_bucket.close() 46 | -------------------------------------------------------------------------------- /config-template.json: -------------------------------------------------------------------------------- 1 | { 2 | "open_ai_api_key": "YOUR API KEY", 3 | "model": "gpt-3.5-turbo", 4 | "proxy": "", 5 | "single_chat_prefix": [ 6 | "bot", 7 | "@bot" 8 | ], 9 | "single_chat_reply_prefix": "[bot] ", 10 | "group_chat_prefix": [ 11 | "@bot" 12 | ], 13 | "group_name_white_list": [ 14 | "ChatGPT测试群", 15 | "ChatGPT测试群2" 16 | ], 17 | "group_chat_in_one_session": [ 18 | "ChatGPT测试群" 19 | ], 20 | "image_create_prefix": [ 21 | "画", 22 | "看", 23 | "找" 24 | ], 25 
| "speech_recognition": false, 26 | "group_speech_recognition": false, 27 | "voice_reply_voice": false, 28 | "conversation_max_tokens": 1000, 29 | "expires_in_seconds": 3600, 30 | "character_desc": "你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。" 31 | } 32 | -------------------------------------------------------------------------------- /docker/Dockerfile.alpine: -------------------------------------------------------------------------------- 1 | FROM python:3.10-alpine 2 | 3 | LABEL maintainer="foo@bar.com" 4 | ARG TZ='Asia/Shanghai' 5 | 6 | ARG CHATGPT_ON_WECHAT_VER 7 | 8 | ENV BUILD_PREFIX=/app 9 | 10 | RUN apk add --no-cache \ 11 | bash \ 12 | curl \ 13 | wget \ 14 | && export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \ 15 | grep '"tag_name":' | \ 16 | sed -E 's/.*"([^"]+)".*/\1/'`} \ 17 | && wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \ 18 | https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \ 19 | && tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \ 20 | && mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \ 21 | && rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \ 22 | && cd ${BUILD_PREFIX} \ 23 | && cp config-template.json ${BUILD_PREFIX}/config.json \ 24 | && /usr/local/bin/python -m pip install --no-cache --upgrade pip \ 25 | && pip install --no-cache -r requirements.txt \ 26 | && pip install --no-cache -r requirements-optional.txt \ 27 | && apk del curl wget 28 | 29 | WORKDIR ${BUILD_PREFIX} 30 | 31 | ADD ./entrypoint.sh /entrypoint.sh 32 | 33 | RUN chmod +x /entrypoint.sh \ 34 | && adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \ 35 | && chown -R noroot:noroot ${BUILD_PREFIX} 36 | 37 | USER noroot 38 | 39 | ENTRYPOINT ["/entrypoint.sh"] 40 | -------------------------------------------------------------------------------- /docker/Dockerfile.debian: 
-------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | LABEL maintainer="foo@bar.com" 4 | ARG TZ='Asia/Shanghai' 5 | 6 | ARG CHATGPT_ON_WECHAT_VER 7 | 8 | ENV BUILD_PREFIX=/app 9 | 10 | RUN apt-get update \ 11 | && apt-get install -y --no-install-recommends \ 12 | wget \ 13 | curl \ 14 | && rm -rf /var/lib/apt/lists/* \ 15 | && export BUILD_GITHUB_TAG=${CHATGPT_ON_WECHAT_VER:-`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \ 16 | grep '"tag_name":' | \ 17 | sed -E 's/.*"([^"]+)".*/\1/'`} \ 18 | && wget -t 3 -T 30 -nv -O chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \ 19 | https://github.com/zhayujie/chatgpt-on-wechat/archive/refs/tags/${BUILD_GITHUB_TAG}.tar.gz \ 20 | && tar -xzf chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \ 21 | && mv chatgpt-on-wechat-${BUILD_GITHUB_TAG} ${BUILD_PREFIX} \ 22 | && rm chatgpt-on-wechat-${BUILD_GITHUB_TAG}.tar.gz \ 23 | && cd ${BUILD_PREFIX} \ 24 | && cp config-template.json ${BUILD_PREFIX}/config.json \ 25 | && /usr/local/bin/python -m pip install --no-cache --upgrade pip \ 26 | && pip install --no-cache -r requirements.txt \ 27 | && pip install --no-cache -r requirements-optional.txt 28 | 29 | WORKDIR ${BUILD_PREFIX} 30 | 31 | ADD ./entrypoint.sh /entrypoint.sh 32 | 33 | RUN chmod +x /entrypoint.sh \ 34 | && groupadd -r noroot \ 35 | && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \ 36 | && chown -R noroot:noroot ${BUILD_PREFIX} 37 | 38 | USER noroot 39 | 40 | ENTRYPOINT ["/entrypoint.sh"] 41 | -------------------------------------------------------------------------------- /docker/Dockerfile.debian.latest: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim 2 | 3 | LABEL maintainer="foo@bar.com" 4 | ARG TZ='Asia/Shanghai' 5 | 6 | ARG CHATGPT_ON_WECHAT_VER 7 | 8 | ENV BUILD_PREFIX=/app 9 | 10 | ADD . 
${BUILD_PREFIX} 11 | 12 | RUN apt-get update \ 13 | &&apt-get install -y --no-install-recommends bash \ 14 | ffmpeg espeak \ 15 | && cd ${BUILD_PREFIX} \ 16 | && cp config-template.json config.json \ 17 | && /usr/local/bin/python -m pip install --no-cache --upgrade pip \ 18 | && pip install --no-cache -r requirements.txt \ 19 | && pip install --no-cache -r requirements-optional.txt \ 20 | && pip install azure-cognitiveservices-speech 21 | 22 | WORKDIR ${BUILD_PREFIX} 23 | 24 | ADD docker/entrypoint.sh /entrypoint.sh 25 | 26 | RUN chmod +x /entrypoint.sh \ 27 | && groupadd -r noroot \ 28 | && useradd -r -g noroot -s /bin/bash -d /home/noroot noroot \ 29 | && chown -R noroot:noroot ${BUILD_PREFIX} 30 | 31 | USER noroot 32 | 33 | ENTRYPOINT ["docker/entrypoint.sh"] -------------------------------------------------------------------------------- /docker/Dockerfile.latest: -------------------------------------------------------------------------------- 1 | FROM python:3.10-alpine 2 | 3 | LABEL maintainer="foo@bar.com" 4 | ARG TZ='Asia/Shanghai' 5 | 6 | ARG CHATGPT_ON_WECHAT_VER 7 | 8 | ENV BUILD_PREFIX=/app 9 | 10 | ADD . 
${BUILD_PREFIX} 11 | 12 | RUN apk add --no-cache bash ffmpeg espeak \ 13 | && cd ${BUILD_PREFIX} \ 14 | && cp config-template.json config.json \ 15 | && /usr/local/bin/python -m pip install --no-cache --upgrade pip \ 16 | && pip install --no-cache -r requirements.txt \ 17 | && pip install --no-cache -r requirements-optional.txt 18 | 19 | WORKDIR ${BUILD_PREFIX} 20 | 21 | ADD docker/entrypoint.sh /entrypoint.sh 22 | 23 | RUN chmod +x /entrypoint.sh \ 24 | && adduser -D -h /home/noroot -u 1000 -s /bin/bash noroot \ 25 | && chown -R noroot:noroot ${BUILD_PREFIX} 26 | 27 | USER noroot 28 | 29 | ENTRYPOINT ["docker/entrypoint.sh"] -------------------------------------------------------------------------------- /docker/build.alpine.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # fetch latest release tag 4 | CHATGPT_ON_WECHAT_TAG=`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \ 5 | grep '"tag_name":' | \ 6 | sed -E 's/.*"([^"]+)".*/\1/'` 7 | 8 | # build image 9 | docker build -f Dockerfile.alpine \ 10 | --build-arg CHATGPT_ON_WECHAT_VER=$CHATGPT_ON_WECHAT_TAG \ 11 | -t zhayujie/chatgpt-on-wechat . 12 | 13 | # tag image 14 | docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:alpine 15 | docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-alpine 16 | -------------------------------------------------------------------------------- /docker/build.debian.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # fetch latest release tag 4 | CHATGPT_ON_WECHAT_TAG=`curl -sL "https://api.github.com/repos/zhayujie/chatgpt-on-wechat/releases/latest" | \ 5 | grep '"tag_name":' | \ 6 | sed -E 's/.*"([^"]+)".*/\1/'` 7 | 8 | # build image 9 | docker build -f Dockerfile.debian \ 10 | --build-arg CHATGPT_ON_WECHAT_VER=$CHATGPT_ON_WECHAT_TAG \ 11 | -t zhayujie/chatgpt-on-wechat . 
12 | 13 | # tag image 14 | docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:debian 15 | docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$CHATGPT_ON_WECHAT_TAG-debian -------------------------------------------------------------------------------- /docker/build.latest.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | unset KUBECONFIG 4 | 5 | cd .. && docker build -f docker/Dockerfile.latest \ 6 | -t zhayujie/chatgpt-on-wechat . 7 | 8 | docker tag zhayujie/chatgpt-on-wechat zhayujie/chatgpt-on-wechat:$(date +%y%m%d) -------------------------------------------------------------------------------- /docker/chatgpt-on-wechat-voice-reply/Dockerfile.alpine: -------------------------------------------------------------------------------- 1 | FROM zhayujie/chatgpt-on-wechat:alpine 2 | 3 | LABEL maintainer="foo@bar.com" 4 | ARG TZ='Asia/Shanghai' 5 | 6 | USER root 7 | 8 | RUN apk add --no-cache \ 9 | ffmpeg \ 10 | espeak \ 11 | && pip install --no-cache \ 12 | baidu-aip \ 13 | chardet \ 14 | SpeechRecognition 15 | 16 | # replace entrypoint 17 | ADD ./entrypoint.sh /entrypoint.sh 18 | 19 | RUN chmod +x /entrypoint.sh 20 | 21 | USER noroot 22 | 23 | ENTRYPOINT ["/entrypoint.sh"] -------------------------------------------------------------------------------- /docker/chatgpt-on-wechat-voice-reply/Dockerfile.debian: -------------------------------------------------------------------------------- 1 | FROM zhayujie/chatgpt-on-wechat:debian 2 | 3 | LABEL maintainer="foo@bar.com" 4 | ARG TZ='Asia/Shanghai' 5 | 6 | USER root 7 | 8 | RUN apt-get update \ 9 | && apt-get install -y --no-install-recommends \ 10 | ffmpeg \ 11 | espeak \ 12 | && pip install --no-cache \ 13 | baidu-aip \ 14 | chardet \ 15 | SpeechRecognition 16 | 17 | # replace entrypoint 18 | ADD ./entrypoint.sh /entrypoint.sh 19 | 20 | RUN chmod +x /entrypoint.sh 21 | 22 | USER noroot 23 | 24 | ENTRYPOINT ["/entrypoint.sh"] 25 | 
-------------------------------------------------------------------------------- /docker/chatgpt-on-wechat-voice-reply/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '2.0' 2 | services: 3 | chatgpt-on-wechat: 4 | build: 5 | context: ./ 6 | dockerfile: Dockerfile.alpine 7 | image: zhayujie/chatgpt-on-wechat-voice-reply 8 | container_name: chatgpt-on-wechat-voice-reply 9 | environment: 10 | OPEN_AI_API_KEY: 'YOUR API KEY' 11 | OPEN_AI_PROXY: '' 12 | SINGLE_CHAT_PREFIX: '["bot", "@bot"]' 13 | SINGLE_CHAT_REPLY_PREFIX: '"[bot] "' 14 | GROUP_CHAT_PREFIX: '["@bot"]' 15 | GROUP_NAME_WHITE_LIST: '["ChatGPT测试群", "ChatGPT测试群2"]' 16 | IMAGE_CREATE_PREFIX: '["画", "看", "找"]' 17 | CONVERSATION_MAX_TOKENS: 1000 18 | SPEECH_RECOGNITION: 'true' 19 | CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。' 20 | EXPIRES_IN_SECONDS: 3600 21 | VOICE_REPLY_VOICE: 'true' 22 | BAIDU_APP_ID: 'YOUR BAIDU APP ID' 23 | BAIDU_API_KEY: 'YOUR BAIDU API KEY' 24 | BAIDU_SECRET_KEY: 'YOUR BAIDU SERVICE KEY' -------------------------------------------------------------------------------- /docker/chatgpt-on-wechat-voice-reply/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # build prefix 5 | CHATGPT_ON_WECHAT_PREFIX=${CHATGPT_ON_WECHAT_PREFIX:-""} 6 | # path to config.json 7 | CHATGPT_ON_WECHAT_CONFIG_PATH=${CHATGPT_ON_WECHAT_CONFIG_PATH:-""} 8 | # execution command line 9 | CHATGPT_ON_WECHAT_EXEC=${CHATGPT_ON_WECHAT_EXEC:-""} 10 | 11 | OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-""} 12 | OPEN_AI_PROXY=${OPEN_AI_PROXY:-""} 13 | SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-""} 14 | SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-""} 15 | GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-""} 16 | GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-""} 17 | IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-""} 18 | 
CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-""} 19 | SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-""} 20 | CHARACTER_DESC=${CHARACTER_DESC:-""} 21 | EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-""} 22 | 23 | VOICE_REPLY_VOICE=${VOICE_REPLY_VOICE:-""} 24 | BAIDU_APP_ID=${BAIDU_APP_ID:-""} 25 | BAIDU_API_KEY=${BAIDU_API_KEY:-""} 26 | BAIDU_SECRET_KEY=${BAIDU_SECRET_KEY:-""} 27 | 28 | # CHATGPT_ON_WECHAT_PREFIX is empty, use /app 29 | if [ "$CHATGPT_ON_WECHAT_PREFIX" == "" ] ; then 30 | CHATGPT_ON_WECHAT_PREFIX=/app 31 | fi 32 | 33 | # CHATGPT_ON_WECHAT_CONFIG_PATH is empty, use '/app/config.json' 34 | if [ "$CHATGPT_ON_WECHAT_CONFIG_PATH" == "" ] ; then 35 | CHATGPT_ON_WECHAT_CONFIG_PATH=$CHATGPT_ON_WECHAT_PREFIX/config.json 36 | fi 37 | 38 | # CHATGPT_ON_WECHAT_EXEC is empty, use 'python app.py' 39 | if [ "$CHATGPT_ON_WECHAT_EXEC" == "" ] ; then 40 | CHATGPT_ON_WECHAT_EXEC="python app.py" 41 | fi 42 | 43 | # modify content in config.json 44 | if [ "$OPEN_AI_API_KEY" != "" ] ; then 45 | sed -i "s/\"open_ai_api_key\".*,$/\"open_ai_api_key\": \"$OPEN_AI_API_KEY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH 46 | else 47 | echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m" 48 | fi 49 | 50 | # use http_proxy as default 51 | if [ "$HTTP_PROXY" != "" ] ; then 52 | sed -i "s/\"proxy\".*,$/\"proxy\": \"$HTTP_PROXY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH 53 | fi 54 | 55 | if [ "$OPEN_AI_PROXY" != "" ] ; then 56 | sed -i "s/\"proxy\".*,$/\"proxy\": \"$OPEN_AI_PROXY\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH 57 | fi 58 | 59 | if [ "$SINGLE_CHAT_PREFIX" != "" ] ; then 60 | sed -i "s/\"single_chat_prefix\".*,$/\"single_chat_prefix\": $SINGLE_CHAT_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH 61 | fi 62 | 63 | if [ "$SINGLE_CHAT_REPLY_PREFIX" != "" ] ; then 64 | sed -i "s/\"single_chat_reply_prefix\".*,$/\"single_chat_reply_prefix\": $SINGLE_CHAT_REPLY_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH 65 | fi 66 | 67 | if [ "$GROUP_CHAT_PREFIX" != "" ] ; then 68 | sed -i
"s/\"group_chat_prefix\".*,$/\"group_chat_prefix\": $GROUP_CHAT_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH 69 | fi 70 | 71 | if [ "$GROUP_NAME_WHITE_LIST" != "" ] ; then 72 | sed -i "s/\"group_name_white_list\".*,$/\"group_name_white_list\": $GROUP_NAME_WHITE_LIST,/" $CHATGPT_ON_WECHAT_CONFIG_PATH 73 | fi 74 | 75 | if [ "$IMAGE_CREATE_PREFIX" != "" ] ; then 76 | sed -i "s/\"image_create_prefix\".*,$/\"image_create_prefix\": $IMAGE_CREATE_PREFIX,/" $CHATGPT_ON_WECHAT_CONFIG_PATH 77 | fi 78 | 79 | if [ "$CONVERSATION_MAX_TOKENS" != "" ] ; then 80 | sed -i "s/\"conversation_max_tokens\".*,$/\"conversation_max_tokens\": $CONVERSATION_MAX_TOKENS,/" $CHATGPT_ON_WECHAT_CONFIG_PATH 81 | fi 82 | 83 | if [ "$SPEECH_RECOGNITION" != "" ] ; then 84 | sed -i "s/\"speech_recognition\".*,$/\"speech_recognition\": $SPEECH_RECOGNITION,/" $CHATGPT_ON_WECHAT_CONFIG_PATH 85 | fi 86 | 87 | if [ "$CHARACTER_DESC" != "" ] ; then 88 | sed -i "s/\"character_desc\".*,$/\"character_desc\": \"$CHARACTER_DESC\",/" $CHATGPT_ON_WECHAT_CONFIG_PATH 89 | fi 90 | 91 | if [ "$EXPIRES_IN_SECONDS" != "" ] ; then 92 | sed -i "s/\"expires_in_seconds\".*$/\"expires_in_seconds\": $EXPIRES_IN_SECONDS/" $CHATGPT_ON_WECHAT_CONFIG_PATH 93 | fi 94 | 95 | # append 96 | if [ "$BAIDU_SECRET_KEY" != "" ] ; then 97 | sed -i "1a \ \ \"baidu_secret_key\": \"$BAIDU_SECRET_KEY\"," $CHATGPT_ON_WECHAT_CONFIG_PATH 98 | fi 99 | 100 | if [ "$BAIDU_API_KEY" != "" ] ; then 101 | sed -i "1a \ \ \"baidu_api_key\": \"$BAIDU_API_KEY\"," $CHATGPT_ON_WECHAT_CONFIG_PATH 102 | fi 103 | 104 | if [ "$BAIDU_APP_ID" != "" ] ; then 105 | sed -i "1a \ \ \"baidu_app_id\": \"$BAIDU_APP_ID\"," $CHATGPT_ON_WECHAT_CONFIG_PATH 106 | fi 107 | 108 | if [ "$VOICE_REPLY_VOICE" != "" ] ; then 109 | sed -i "1a \ \ \"voice_reply_voice\": $VOICE_REPLY_VOICE," $CHATGPT_ON_WECHAT_CONFIG_PATH 110 | fi 111 | 112 | # go to prefix dir 113 | cd $CHATGPT_ON_WECHAT_PREFIX 114 | # excute 115 | $CHATGPT_ON_WECHAT_EXEC 116 | 117 | 118 | 
-------------------------------------------------------------------------------- /docker/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | version: '2.0' 2 | services: 3 | chatgpt-on-wechat: 4 | build: 5 | context: ./ 6 | dockerfile: Dockerfile.alpine 7 | image: zhayujie/chatgpt-on-wechat 8 | container_name: sample-chatgpt-on-wechat 9 | environment: 10 | OPEN_AI_API_KEY: 'YOUR API KEY' 11 | OPEN_AI_PROXY: '' 12 | SINGLE_CHAT_PREFIX: '["bot", "@bot"]' 13 | SINGLE_CHAT_REPLY_PREFIX: '"[bot] "' 14 | GROUP_CHAT_PREFIX: '["@bot"]' 15 | GROUP_NAME_WHITE_LIST: '["ChatGPT测试群", "ChatGPT测试群2"]' 16 | IMAGE_CREATE_PREFIX: '["画", "看", "找"]' 17 | CONVERSATION_MAX_TOKENS: 1000 18 | SPEECH_RECOGNITION: "False" 19 | CHARACTER_DESC: '你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。' 20 | EXPIRES_IN_SECONDS: 3600 -------------------------------------------------------------------------------- /docker/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # build prefix 5 | CHATGPT_ON_WECHAT_PREFIX=${CHATGPT_ON_WECHAT_PREFIX:-""} 6 | # path to config.json 7 | CHATGPT_ON_WECHAT_CONFIG_PATH=${CHATGPT_ON_WECHAT_CONFIG_PATH:-""} 8 | # execution command line 9 | CHATGPT_ON_WECHAT_EXEC=${CHATGPT_ON_WECHAT_EXEC:-""} 10 | 11 | # use environment variables to pass parameters 12 | # if you have not defined environment variables, set them below 13 | # export OPEN_AI_API_KEY=${OPEN_AI_API_KEY:-'YOUR API KEY'} 14 | # export OPEN_AI_PROXY=${OPEN_AI_PROXY:-""} 15 | # export SINGLE_CHAT_PREFIX=${SINGLE_CHAT_PREFIX:-'["bot", "@bot"]'} 16 | # export SINGLE_CHAT_REPLY_PREFIX=${SINGLE_CHAT_REPLY_PREFIX:-'"[bot] "'} 17 | # export GROUP_CHAT_PREFIX=${GROUP_CHAT_PREFIX:-'["@bot"]'} 18 | # export GROUP_NAME_WHITE_LIST=${GROUP_NAME_WHITE_LIST:-'["ChatGPT测试群", "ChatGPT测试群2"]'} 19 | # export IMAGE_CREATE_PREFIX=${IMAGE_CREATE_PREFIX:-'["画", "看", "找"]'} 20 | # export 
CONVERSATION_MAX_TOKENS=${CONVERSATION_MAX_TOKENS:-"1000"} 21 | # export SPEECH_RECOGNITION=${SPEECH_RECOGNITION:-"False"} 22 | # export CHARACTER_DESC=${CHARACTER_DESC:-"你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。"} 23 | # export EXPIRES_IN_SECONDS=${EXPIRES_IN_SECONDS:-"3600"} 24 | 25 | # CHATGPT_ON_WECHAT_PREFIX is empty, use /app 26 | if [ "$CHATGPT_ON_WECHAT_PREFIX" == "" ] ; then 27 | CHATGPT_ON_WECHAT_PREFIX=/app 28 | fi 29 | 30 | # CHATGPT_ON_WECHAT_CONFIG_PATH is empty, use '/app/config.json' 31 | if [ "$CHATGPT_ON_WECHAT_CONFIG_PATH" == "" ] ; then 32 | CHATGPT_ON_WECHAT_CONFIG_PATH=$CHATGPT_ON_WECHAT_PREFIX/config.json 33 | fi 34 | 35 | # CHATGPT_ON_WECHAT_EXEC is empty, use 'python app.py' 36 | if [ "$CHATGPT_ON_WECHAT_EXEC" == "" ] ; then 37 | CHATGPT_ON_WECHAT_EXEC="python app.py" 38 | fi 39 | 40 | # modify content in config.json 41 | if [ "$OPEN_AI_API_KEY" == "YOUR API KEY" ] || [ "$OPEN_AI_API_KEY" == "" ]; then 42 | echo -e "\033[31m[Warning] You need to set OPEN_AI_API_KEY before running!\033[0m" 43 | fi 44 | 45 | 46 | # go to prefix dir 47 | cd $CHATGPT_ON_WECHAT_PREFIX 48 | # execute 49 | $CHATGPT_ON_WECHAT_EXEC 50 | 51 | 52 | -------------------------------------------------------------------------------- /docker/sample-chatgpt-on-wechat/.env: -------------------------------------------------------------------------------- 1 | OPEN_AI_API_KEY=YOUR API KEY 2 | OPEN_AI_PROXY= 3 | SINGLE_CHAT_PREFIX=["bot", "@bot"] 4 | SINGLE_CHAT_REPLY_PREFIX="[bot] " 5 | GROUP_CHAT_PREFIX=["@bot"] 6 | GROUP_NAME_WHITE_LIST=["ChatGPT测试群", "ChatGPT测试群2"] 7 | IMAGE_CREATE_PREFIX=["画", "看", "找"] 8 | CONVERSATION_MAX_TOKENS=1000 9 | SPEECH_RECOGNITION=false 10 | CHARACTER_DESC=你是ChatGPT, 一个由OpenAI训练的大型语言模型, 你旨在回答并解决人们的任何问题,并且可以使用多种语言与人交流。 11 | EXPIRES_IN_SECONDS=3600 12 | 13 | # Optional 14 | #CHATGPT_ON_WECHAT_PREFIX=/app 15 | #CHATGPT_ON_WECHAT_CONFIG_PATH=/app/config.json 16 | #CHATGPT_ON_WECHAT_EXEC=python app.py 
-------------------------------------------------------------------------------- /docker/sample-chatgpt-on-wechat/Makefile: -------------------------------------------------------------------------------- 1 | IMG:=`cat Name` 2 | MOUNT:= 3 | PORT_MAP:= 4 | DOTENV:=.env 5 | CONTAINER_NAME:=sample-chatgpt-on-wechat 6 | 7 | echo: 8 | echo $(IMG) 9 | 10 | run_d: 11 | docker rm $(CONTAINER_NAME) || echo 12 | docker run -dt --name $(CONTAINER_NAME) $(PORT_MAP) \ 13 | --env-file=$(DOTENV) \ 14 | $(MOUNT) $(IMG) 15 | 16 | run_i: 17 | docker rm $(CONTAINER_NAME) || echo 18 | docker run -it --name $(CONTAINER_NAME) $(PORT_MAP) \ 19 | --env-file=$(DOTENV) \ 20 | $(MOUNT) $(IMG) 21 | 22 | stop: 23 | docker stop $(CONTAINER_NAME) 24 | 25 | rm: stop 26 | docker rm $(CONTAINER_NAME) 27 | -------------------------------------------------------------------------------- /docker/sample-chatgpt-on-wechat/Name: -------------------------------------------------------------------------------- 1 | zhayujie/chatgpt-on-wechat 2 | -------------------------------------------------------------------------------- /docs/images/group-chat-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limccn/chatgpt-on-wechat/c60f0517fb03aa09f72a5885477569cf2d813d65/docs/images/group-chat-sample.jpg -------------------------------------------------------------------------------- /docs/images/image-create-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limccn/chatgpt-on-wechat/c60f0517fb03aa09f72a5885477569cf2d813d65/docs/images/image-create-sample.jpg -------------------------------------------------------------------------------- /docs/images/planet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limccn/chatgpt-on-wechat/c60f0517fb03aa09f72a5885477569cf2d813d65/docs/images/planet.jpg 
-------------------------------------------------------------------------------- /docs/images/single-chat-sample.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limccn/chatgpt-on-wechat/c60f0517fb03aa09f72a5885477569cf2d813d65/docs/images/single-chat-sample.jpg -------------------------------------------------------------------------------- /lib/itchat/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import Core 2 | from .config import VERSION, ASYNC_COMPONENTS 3 | from .log import set_logging 4 | 5 | if ASYNC_COMPONENTS: 6 | from .async_components import load_components 7 | else: 8 | from .components import load_components 9 | 10 | 11 | __version__ = VERSION 12 | 13 | 14 | instanceList = [] 15 | 16 | def load_async_itchat() -> Core: 17 | """load async-based itchat instance 18 | 19 | Returns: 20 | Core: the abstract interface of itchat 21 | """ 22 | from .async_components import load_components 23 | load_components(Core) 24 | return Core() 25 | 26 | 27 | def load_sync_itchat() -> Core: 28 | """load sync-based itchat instance 29 | 30 | Returns: 31 | Core: the abstract interface of itchat 32 | """ 33 | from .components import load_components 34 | load_components(Core) 35 | return Core() 36 | 37 | 38 | if ASYNC_COMPONENTS: 39 | instance = load_async_itchat() 40 | else: 41 | instance = load_sync_itchat() 42 | 43 | 44 | instanceList = [instance] 45 | 46 | # I really want to use sys.modules[__name__] = originInstance 47 | # but it makes auto-fill a real mess, so forgive me for my following ** 48 | # actually it toke me less than 30 seconds, god bless Uganda 49 | 50 | # components.login 51 | login = instance.login 52 | get_QRuuid = instance.get_QRuuid 53 | get_QR = instance.get_QR 54 | check_login = instance.check_login 55 | web_init = instance.web_init 56 | show_mobile_login = instance.show_mobile_login 57 | start_receiving = 
instance.start_receiving 58 | get_msg = instance.get_msg 59 | logout = instance.logout 60 | # components.contact 61 | update_chatroom = instance.update_chatroom 62 | update_friend = instance.update_friend 63 | get_contact = instance.get_contact 64 | get_friends = instance.get_friends 65 | get_chatrooms = instance.get_chatrooms 66 | get_mps = instance.get_mps 67 | set_alias = instance.set_alias 68 | set_pinned = instance.set_pinned 69 | accept_friend = instance.accept_friend 70 | get_head_img = instance.get_head_img 71 | create_chatroom = instance.create_chatroom 72 | set_chatroom_name = instance.set_chatroom_name 73 | delete_member_from_chatroom = instance.delete_member_from_chatroom 74 | add_member_into_chatroom = instance.add_member_into_chatroom 75 | # components.messages 76 | send_raw_msg = instance.send_raw_msg 77 | send_msg = instance.send_msg 78 | upload_file = instance.upload_file 79 | send_file = instance.send_file 80 | send_image = instance.send_image 81 | send_video = instance.send_video 82 | send = instance.send 83 | revoke = instance.revoke 84 | # components.hotreload 85 | dump_login_status = instance.dump_login_status 86 | load_login_status = instance.load_login_status 87 | # components.register 88 | auto_login = instance.auto_login 89 | configured_reply = instance.configured_reply 90 | msg_register = instance.msg_register 91 | run = instance.run 92 | # other functions 93 | search_friends = instance.search_friends 94 | search_chatrooms = instance.search_chatrooms 95 | search_mps = instance.search_mps 96 | set_logging = set_logging 97 | -------------------------------------------------------------------------------- /lib/itchat/async_components/__init__.py: -------------------------------------------------------------------------------- 1 | from .contact import load_contact 2 | from .hotreload import load_hotreload 3 | from .login import load_login 4 | from .messages import load_messages 5 | from .register import load_register 6 | 7 | def 
load_components(core): 8 | load_contact(core) 9 | load_hotreload(core) 10 | load_login(core) 11 | load_messages(core) 12 | load_register(core) 13 | -------------------------------------------------------------------------------- /lib/itchat/async_components/hotreload.py: -------------------------------------------------------------------------------- 1 | import pickle, os 2 | import logging 3 | 4 | import requests # type: ignore 5 | 6 | from ..config import VERSION 7 | from ..returnvalues import ReturnValue 8 | from ..storage import templates 9 | from .contact import update_local_chatrooms, update_local_friends 10 | from .messages import produce_msg 11 | 12 | logger = logging.getLogger('itchat') 13 | 14 | def load_hotreload(core): 15 | core.dump_login_status = dump_login_status 16 | core.load_login_status = load_login_status 17 | 18 | async def dump_login_status(self, fileDir=None): 19 | fileDir = fileDir or self.hotReloadDir 20 | try: 21 | with open(fileDir, 'w') as f: 22 | f.write('itchat - DELETE THIS') 23 | os.remove(fileDir) 24 | except: 25 | raise Exception('Incorrect fileDir') 26 | status = { 27 | 'version' : VERSION, 28 | 'loginInfo' : self.loginInfo, 29 | 'cookies' : self.s.cookies.get_dict(), 30 | 'storage' : self.storageClass.dumps()} 31 | with open(fileDir, 'wb') as f: 32 | pickle.dump(status, f) 33 | logger.debug('Dump login status for hot reload successfully.') 34 | 35 | async def load_login_status(self, fileDir, 36 | loginCallback=None, exitCallback=None): 37 | try: 38 | with open(fileDir, 'rb') as f: 39 | j = pickle.load(f) 40 | except Exception as e: 41 | logger.debug('No such file, loading login status failed.') 42 | return ReturnValue({'BaseResponse': { 43 | 'ErrMsg': 'No such file, loading login status failed.', 44 | 'Ret': -1002, }}) 45 | 46 | if j.get('version', '') != VERSION: 47 | logger.debug(('you have updated itchat from %s to %s, ' + 48 | 'so cached status is ignored') % ( 49 | j.get('version', 'old version'), VERSION)) 50 | return 
ReturnValue({'BaseResponse': { 51 | 'ErrMsg': 'cached status ignored because of version', 52 | 'Ret': -1005, }}) 53 | self.loginInfo = j['loginInfo'] 54 | self.loginInfo['User'] = templates.User(self.loginInfo['User']) 55 | self.loginInfo['User'].core = self 56 | self.s.cookies = requests.utils.cookiejar_from_dict(j['cookies']) 57 | self.storageClass.loads(j['storage']) 58 | try: 59 | msgList, contactList = self.get_msg() 60 | except: 61 | msgList = contactList = None 62 | if (msgList or contactList) is None: 63 | self.logout() 64 | await load_last_login_status(self.s, j['cookies']) 65 | logger.debug('server refused, loading login status failed.') 66 | return ReturnValue({'BaseResponse': { 67 | 'ErrMsg': 'server refused, loading login status failed.', 68 | 'Ret': -1003, }}) 69 | else: 70 | if contactList: 71 | for contact in contactList: 72 | if '@@' in contact['UserName']: 73 | update_local_chatrooms(self, [contact]) 74 | else: 75 | update_local_friends(self, [contact]) 76 | if msgList: 77 | msgList = produce_msg(self, msgList) 78 | for msg in msgList: self.msgList.put(msg) 79 | await self.start_receiving(exitCallback) 80 | logger.debug('loading login status succeeded.') 81 | if hasattr(loginCallback, '__call__'): 82 | await loginCallback(self.storageClass.userName) 83 | return ReturnValue({'BaseResponse': { 84 | 'ErrMsg': 'loading login status succeeded.', 85 | 'Ret': 0, }}) 86 | 87 | async def load_last_login_status(session, cookiesDict): 88 | try: 89 | session.cookies = requests.utils.cookiejar_from_dict({ 90 | 'webwxuvid': cookiesDict['webwxuvid'], 91 | 'webwx_auth_ticket': cookiesDict['webwx_auth_ticket'], 92 | 'login_frequency': '2', 93 | 'last_wxuin': cookiesDict['wxuin'], 94 | 'wxloadtime': cookiesDict['wxloadtime'] + '_expired', 95 | 'wxpluginkey': cookiesDict['wxloadtime'], 96 | 'wxuin': cookiesDict['wxuin'], 97 | 'mm_lang': 'zh_CN', 98 | 'MM_WX_NOTIFY_STATE': '1', 99 | 'MM_WX_SOUND_STATE': '1', }) 100 | except: 101 | logger.info('Load status for push 
login failed, we may have experienced a cookies change.') 102 | logger.info('If you are using the newest version of itchat, you may report a bug.') 103 | -------------------------------------------------------------------------------- /lib/itchat/async_components/register.py: -------------------------------------------------------------------------------- 1 | import logging, traceback, sys, threading 2 | try: 3 | import Queue 4 | except ImportError: 5 | import queue as Queue # type: ignore 6 | 7 | from ..log import set_logging 8 | from ..utils import test_connect 9 | from ..storage import templates 10 | 11 | logger = logging.getLogger('itchat') 12 | 13 | def load_register(core): 14 | core.auto_login = auto_login 15 | core.configured_reply = configured_reply 16 | core.msg_register = msg_register 17 | core.run = run 18 | 19 | async def auto_login(self, EventScanPayload=None,ScanStatus=None,event_stream=None, 20 | hotReload=True, statusStorageDir='itchat.pkl', 21 | enableCmdQR=False, picDir=None, qrCallback=None, 22 | loginCallback=None, exitCallback=None): 23 | if not test_connect(): 24 | logger.info("You can't get access to internet or wechat domain, so exit.") 25 | sys.exit() 26 | self.useHotReload = hotReload 27 | self.hotReloadDir = statusStorageDir 28 | if hotReload: 29 | if await self.load_login_status(statusStorageDir, 30 | loginCallback=loginCallback, exitCallback=exitCallback): 31 | return 32 | await self.login(enableCmdQR=enableCmdQR, picDir=picDir, qrCallback=qrCallback, EventScanPayload=EventScanPayload, ScanStatus=ScanStatus, event_stream=event_stream, 33 | loginCallback=loginCallback, exitCallback=exitCallback) 34 | await self.dump_login_status(statusStorageDir) 35 | else: 36 | await self.login(enableCmdQR=enableCmdQR, picDir=picDir, qrCallback=qrCallback, EventScanPayload=EventScanPayload, ScanStatus=ScanStatus, event_stream=event_stream, 37 | loginCallback=loginCallback, exitCallback=exitCallback) 38 | 39 | async def configured_reply(self, 
event_stream, payload, message_container): 40 | ''' determine the type of message and reply if its method is defined 41 | however, I use a strange way to determine whether a msg is from massive platform 42 | I haven't found a better solution here 43 | The main problem I'm worrying about is the mismatching of new friends added on phone 44 | If you have any good idea, pleeeease report an issue. I will be more than grateful. 45 | ''' 46 | try: 47 | msg = self.msgList.get(timeout=1) 48 | if 'MsgId' in msg.keys(): 49 | message_container[msg['MsgId']] = msg 50 | except Queue.Empty: 51 | pass 52 | else: 53 | if isinstance(msg['User'], templates.User): 54 | replyFn = self.functionDict['FriendChat'].get(msg['Type']) 55 | elif isinstance(msg['User'], templates.MassivePlatform): 56 | replyFn = self.functionDict['MpChat'].get(msg['Type']) 57 | elif isinstance(msg['User'], templates.Chatroom): 58 | replyFn = self.functionDict['GroupChat'].get(msg['Type']) 59 | if replyFn is None: 60 | r = None 61 | else: 62 | try: 63 | r = await replyFn(msg) 64 | if r is not None: 65 | await self.send(r, msg.get('FromUserName')) 66 | except: 67 | logger.warning(traceback.format_exc()) 68 | 69 | def msg_register(self, msgType, isFriendChat=False, isGroupChat=False, isMpChat=False): 70 | ''' a decorator constructor 71 | return a specific decorator based on information given ''' 72 | if not (isinstance(msgType, list) or isinstance(msgType, tuple)): 73 | msgType = [msgType] 74 | def _msg_register(fn): 75 | for _msgType in msgType: 76 | if isFriendChat: 77 | self.functionDict['FriendChat'][_msgType] = fn 78 | if isGroupChat: 79 | self.functionDict['GroupChat'][_msgType] = fn 80 | if isMpChat: 81 | self.functionDict['MpChat'][_msgType] = fn 82 | if not any((isFriendChat, isGroupChat, isMpChat)): 83 | self.functionDict['FriendChat'][_msgType] = fn 84 | return fn 85 | return _msg_register 86 | 87 | async def run(self, debug=False, blockThread=True): 88 | logger.info('Start auto replying.') 89 | if 
debug: 90 | set_logging(loggingLevel=logging.DEBUG) 91 | async def reply_fn(): 92 | try: 93 | while self.alive: 94 | await self.configured_reply() 95 | except KeyboardInterrupt: 96 | if self.useHotReload: 97 | await self.dump_login_status() 98 | self.alive = False 99 | logger.debug('itchat received an ^C and exit.') 100 | logger.info('Bye~') 101 | if blockThread: 102 | await reply_fn() 103 | else: 104 | replyThread = threading.Thread(target=reply_fn) 105 | replyThread.setDaemon(True) 106 | replyThread.start() 107 | -------------------------------------------------------------------------------- /lib/itchat/components/__init__.py: -------------------------------------------------------------------------------- 1 | from .contact import load_contact 2 | from .hotreload import load_hotreload 3 | from .login import load_login 4 | from .messages import load_messages 5 | from .register import load_register 6 | 7 | def load_components(core): 8 | load_contact(core) 9 | load_hotreload(core) 10 | load_login(core) 11 | load_messages(core) 12 | load_register(core) 13 | -------------------------------------------------------------------------------- /lib/itchat/components/hotreload.py: -------------------------------------------------------------------------------- 1 | import pickle, os 2 | import logging 3 | 4 | import requests 5 | 6 | from ..config import VERSION 7 | from ..returnvalues import ReturnValue 8 | from ..storage import templates 9 | from .contact import update_local_chatrooms, update_local_friends 10 | from .messages import produce_msg 11 | 12 | logger = logging.getLogger('itchat') 13 | 14 | def load_hotreload(core): 15 | core.dump_login_status = dump_login_status 16 | core.load_login_status = load_login_status 17 | 18 | def dump_login_status(self, fileDir=None): 19 | fileDir = fileDir or self.hotReloadDir 20 | try: 21 | with open(fileDir, 'w') as f: 22 | f.write('itchat - DELETE THIS') 23 | os.remove(fileDir) 24 | except: 25 | raise Exception('Incorrect fileDir') 
26 | status = { 27 | 'version' : VERSION, 28 | 'loginInfo' : self.loginInfo, 29 | 'cookies' : self.s.cookies.get_dict(), 30 | 'storage' : self.storageClass.dumps()} 31 | with open(fileDir, 'wb') as f: 32 | pickle.dump(status, f) 33 | logger.debug('Dump login status for hot reload successfully.') 34 | 35 | def load_login_status(self, fileDir, 36 | loginCallback=None, exitCallback=None): 37 | try: 38 | with open(fileDir, 'rb') as f: 39 | j = pickle.load(f) 40 | except Exception as e: 41 | logger.debug('No such file, loading login status failed.') 42 | return ReturnValue({'BaseResponse': { 43 | 'ErrMsg': 'No such file, loading login status failed.', 44 | 'Ret': -1002, }}) 45 | 46 | if j.get('version', '') != VERSION: 47 | logger.debug(('you have updated itchat from %s to %s, ' + 48 | 'so cached status is ignored') % ( 49 | j.get('version', 'old version'), VERSION)) 50 | return ReturnValue({'BaseResponse': { 51 | 'ErrMsg': 'cached status ignored because of version', 52 | 'Ret': -1005, }}) 53 | self.loginInfo = j['loginInfo'] 54 | self.loginInfo['User'] = templates.User(self.loginInfo['User']) 55 | self.loginInfo['User'].core = self 56 | self.s.cookies = requests.utils.cookiejar_from_dict(j['cookies']) 57 | self.storageClass.loads(j['storage']) 58 | try: 59 | msgList, contactList = self.get_msg() 60 | except: 61 | msgList = contactList = None 62 | if (msgList or contactList) is None: 63 | self.logout() 64 | load_last_login_status(self.s, j['cookies']) 65 | logger.debug('server refused, loading login status failed.') 66 | return ReturnValue({'BaseResponse': { 67 | 'ErrMsg': 'server refused, loading login status failed.', 68 | 'Ret': -1003, }}) 69 | else: 70 | if contactList: 71 | for contact in contactList: 72 | if '@@' in contact['UserName']: 73 | update_local_chatrooms(self, [contact]) 74 | else: 75 | update_local_friends(self, [contact]) 76 | if msgList: 77 | msgList = produce_msg(self, msgList) 78 | for msg in msgList: self.msgList.put(msg) 79 | 
self.start_receiving(exitCallback) 80 | logger.debug('loading login status succeeded.') 81 | if hasattr(loginCallback, '__call__'): 82 | loginCallback() 83 | return ReturnValue({'BaseResponse': { 84 | 'ErrMsg': 'loading login status succeeded.', 85 | 'Ret': 0, }}) 86 | 87 | def load_last_login_status(session, cookiesDict): 88 | try: 89 | session.cookies = requests.utils.cookiejar_from_dict({ 90 | 'webwxuvid': cookiesDict['webwxuvid'], 91 | 'webwx_auth_ticket': cookiesDict['webwx_auth_ticket'], 92 | 'login_frequency': '2', 93 | 'last_wxuin': cookiesDict['wxuin'], 94 | 'wxloadtime': cookiesDict['wxloadtime'] + '_expired', 95 | 'wxpluginkey': cookiesDict['wxloadtime'], 96 | 'wxuin': cookiesDict['wxuin'], 97 | 'mm_lang': 'zh_CN', 98 | 'MM_WX_NOTIFY_STATE': '1', 99 | 'MM_WX_SOUND_STATE': '1', }) 100 | except: 101 | logger.info('Load status for push login failed, we may have experienced a cookies change.') 102 | logger.info('If you are using the newest version of itchat, you may report a bug.') 103 | -------------------------------------------------------------------------------- /lib/itchat/components/register.py: -------------------------------------------------------------------------------- 1 | import logging, traceback, sys, threading 2 | try: 3 | import Queue 4 | except ImportError: 5 | import queue as Queue 6 | 7 | from ..log import set_logging 8 | from ..utils import test_connect 9 | from ..storage import templates 10 | 11 | logger = logging.getLogger('itchat') 12 | 13 | def load_register(core): 14 | core.auto_login = auto_login 15 | core.configured_reply = configured_reply 16 | core.msg_register = msg_register 17 | core.run = run 18 | 19 | def auto_login(self, hotReload=False, statusStorageDir='itchat.pkl', 20 | enableCmdQR=False, picDir=None, qrCallback=None, 21 | loginCallback=None, exitCallback=None): 22 | if not test_connect(): 23 | logger.info("You can't get access to internet or wechat domain, so exit.") 24 | sys.exit() 25 | self.useHotReload = hotReload 
26 | self.hotReloadDir = statusStorageDir 27 | if hotReload: 28 | if self.load_login_status(statusStorageDir, 29 | loginCallback=loginCallback, exitCallback=exitCallback): 30 | return 31 | self.login(enableCmdQR=enableCmdQR, picDir=picDir, qrCallback=qrCallback, 32 | loginCallback=loginCallback, exitCallback=exitCallback) 33 | self.dump_login_status(statusStorageDir) 34 | else: 35 | self.login(enableCmdQR=enableCmdQR, picDir=picDir, qrCallback=qrCallback, 36 | loginCallback=loginCallback, exitCallback=exitCallback) 37 | 38 | def configured_reply(self): 39 | ''' determine the type of message and reply if its method is defined 40 | however, I use a strange way to determine whether a msg is from massive platform 41 | I haven't found a better solution here 42 | The main problem I'm worrying about is the mismatching of new friends added on phone 43 | If you have any good idea, pleeeease report an issue. I will be more than grateful. 44 | ''' 45 | try: 46 | msg = self.msgList.get(timeout=1) 47 | except Queue.Empty: 48 | pass 49 | else: 50 | if isinstance(msg['User'], templates.User): 51 | replyFn = self.functionDict['FriendChat'].get(msg['Type']) 52 | elif isinstance(msg['User'], templates.MassivePlatform): 53 | replyFn = self.functionDict['MpChat'].get(msg['Type']) 54 | elif isinstance(msg['User'], templates.Chatroom): 55 | replyFn = self.functionDict['GroupChat'].get(msg['Type']) 56 | if replyFn is None: 57 | r = None 58 | else: 59 | try: 60 | r = replyFn(msg) 61 | if r is not None: 62 | self.send(r, msg.get('FromUserName')) 63 | except: 64 | logger.warning(traceback.format_exc()) 65 | 66 | def msg_register(self, msgType, isFriendChat=False, isGroupChat=False, isMpChat=False): 67 | ''' a decorator constructor 68 | return a specific decorator based on information given ''' 69 | if not (isinstance(msgType, list) or isinstance(msgType, tuple)): 70 | msgType = [msgType] 71 | def _msg_register(fn): 72 | for _msgType in msgType: 73 | if isFriendChat: 74 | 
self.functionDict['FriendChat'][_msgType] = fn 75 | if isGroupChat: 76 | self.functionDict['GroupChat'][_msgType] = fn 77 | if isMpChat: 78 | self.functionDict['MpChat'][_msgType] = fn 79 | if not any((isFriendChat, isGroupChat, isMpChat)): 80 | self.functionDict['FriendChat'][_msgType] = fn 81 | return fn 82 | return _msg_register 83 | 84 | def run(self, debug=False, blockThread=True): 85 | logger.info('Start auto replying.') 86 | if debug: 87 | set_logging(loggingLevel=logging.DEBUG) 88 | def reply_fn(): 89 | try: 90 | while self.alive: 91 | self.configured_reply() 92 | except KeyboardInterrupt: 93 | if self.useHotReload: 94 | self.dump_login_status() 95 | self.alive = False 96 | logger.debug('itchat received an ^C and exit.') 97 | logger.info('Bye~') 98 | if blockThread: 99 | reply_fn() 100 | else: 101 | replyThread = threading.Thread(target=reply_fn) 102 | replyThread.setDaemon(True) 103 | replyThread.start() 104 | -------------------------------------------------------------------------------- /lib/itchat/config.py: -------------------------------------------------------------------------------- 1 | import os, platform 2 | 3 | VERSION = '1.5.0.dev' 4 | 5 | # use this environment to initialize the async & sync component 6 | ASYNC_COMPONENTS = os.environ.get('ITCHAT_UOS_ASYNC', False) 7 | 8 | BASE_URL = 'https://login.weixin.qq.com' 9 | OS = platform.system() # Windows, Linux, Darwin 10 | DIR = os.getcwd() 11 | DEFAULT_QR = 'QR.png' 12 | TIMEOUT = (10, 60) 13 | 14 | USER_AGENT = 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_6) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/54.0.2840.71 Safari/537.36' 15 | 16 | UOS_PATCH_CLIENT_VERSION = '2.0.0' 17 | UOS_PATCH_EXTSPAM = 
# Message-type constants used by itchat's msg_register / handler tables.
TEXT = 'Text'
MAP = 'Map'
CARD = 'Card'
NOTE = 'Note'
SHARING = 'Sharing'
PICTURE = 'Picture'
# 'Recording' is exposed under two names for backward compatibility.
RECORDING = VOICE = 'Recording'
ATTACHMENT = 'Attachment'
VIDEO = 'Video'
FRIENDS = 'Friends'
SYSTEM = 'System'

# Every type an incoming message may carry.  'Recording' appears twice
# because RECORDING and VOICE are aliases of the same value.
INCOME_MSG = [TEXT, MAP, CARD, NOTE, SHARING, PICTURE,
              RECORDING, VOICE, ATTACHMENT, VIDEO, FRIENDS, SYSTEM]
    def set_logging(self, showOnCmd=True, loggingFile=None,
            loggingLevel=logging.INFO):
        """Reconfigure the shared 'itchat' logger in place.

        Args:
            showOnCmd: attach (True) or detach (False) the console handler.
            loggingFile: path of a log file, or None to disable file logging.
            loggingLevel: level applied to the underlying logger.
        """
        # Toggle the console handler only when the requested state differs,
        # so the same handler is never added or removed twice.
        if showOnCmd != self.showOnCmd:
            if showOnCmd:
                self.logger.addHandler(self.cmdHandler)
            else:
                self.logger.removeHandler(self.cmdHandler)
            self.showOnCmd = showOnCmd
        # Swap the file handler when the target file changes (None means
        # "no file logging"); the old handler is detached and closed first.
        if loggingFile != self.loggingFile:
            if self.loggingFile is not None:  # clear old fileHandler
                self.logger.removeHandler(self.fileHandler)
                self.fileHandler.close()
            if loggingFile is not None:  # add new fileHandler
                self.fileHandler = logging.FileHandler(loggingFile)
                self.logger.addHandler(self.fileHandler)
            self.loggingFile = loggingFile
        if loggingLevel != self.loggingLevel:
            self.logger.setLevel(loggingLevel)
            self.loggingLevel = loggingLevel
self['BaseResponse'] = { 36 | 'ErrMsg': 'no BaseResponse in raw response', 37 | 'Ret': -1000, } 38 | if TRANSLATE: 39 | self['BaseResponse']['RawMsg'] = self['BaseResponse'].get('ErrMsg', '') 40 | self['BaseResponse']['ErrMsg'] = \ 41 | TRANSLATION[TRANSLATE].get( 42 | self['BaseResponse'].get('Ret', '')) \ 43 | or self['BaseResponse'].get('ErrMsg', u'No ErrMsg') 44 | self['BaseResponse']['RawMsg'] = \ 45 | self['BaseResponse']['RawMsg'] or self['BaseResponse']['ErrMsg'] 46 | def __nonzero__(self): 47 | return self['BaseResponse'].get('Ret') == 0 48 | def __bool__(self): 49 | return self.__nonzero__() 50 | def __str__(self): 51 | return '{%s}' % ', '.join( 52 | ['%s: %s' % (repr(k),repr(v)) for k,v in self.items()]) 53 | def __repr__(self): 54 | return '' % self.__str__() 55 | 56 | TRANSLATION = { 57 | 'Chinese': { 58 | -1000: u'返回值不带BaseResponse', 59 | -1001: u'无法找到对应的成员', 60 | -1002: u'文件位置错误', 61 | -1003: u'服务器拒绝连接', 62 | -1004: u'服务器返回异常值', 63 | -1005: u'参数错误', 64 | -1006: u'无效操作', 65 | 0: u'请求成功', 66 | }, 67 | } 68 | -------------------------------------------------------------------------------- /lib/itchat/storage/__init__.py: -------------------------------------------------------------------------------- 1 | import os, time, copy 2 | from threading import Lock 3 | 4 | from .messagequeue import Queue 5 | from .templates import ( 6 | ContactList, AbstractUserDict, User, 7 | MassivePlatform, Chatroom, ChatroomMember) 8 | 9 | def contact_change(fn): 10 | def _contact_change(core, *args, **kwargs): 11 | with core.storageClass.updateLock: 12 | return fn(core, *args, **kwargs) 13 | return _contact_change 14 | 15 | class Storage(object): 16 | def __init__(self, core): 17 | self.userName = None 18 | self.nickName = None 19 | self.updateLock = Lock() 20 | self.memberList = ContactList() 21 | self.mpList = ContactList() 22 | self.chatroomList = ContactList() 23 | self.msgList = Queue(-1) 24 | self.lastInputUserName = None 25 | 
self.memberList.set_default_value(contactClass=User) 26 | self.memberList.core = core 27 | self.mpList.set_default_value(contactClass=MassivePlatform) 28 | self.mpList.core = core 29 | self.chatroomList.set_default_value(contactClass=Chatroom) 30 | self.chatroomList.core = core 31 | def dumps(self): 32 | return { 33 | 'userName' : self.userName, 34 | 'nickName' : self.nickName, 35 | 'memberList' : self.memberList, 36 | 'mpList' : self.mpList, 37 | 'chatroomList' : self.chatroomList, 38 | 'lastInputUserName' : self.lastInputUserName, } 39 | def loads(self, j): 40 | self.userName = j.get('userName', None) 41 | self.nickName = j.get('nickName', None) 42 | del self.memberList[:] 43 | for i in j.get('memberList', []): 44 | self.memberList.append(i) 45 | del self.mpList[:] 46 | for i in j.get('mpList', []): 47 | self.mpList.append(i) 48 | del self.chatroomList[:] 49 | for i in j.get('chatroomList', []): 50 | self.chatroomList.append(i) 51 | # I tried to solve everything in pickle 52 | # but this way is easier and more storage-saving 53 | for chatroom in self.chatroomList: 54 | if 'MemberList' in chatroom: 55 | for member in chatroom['MemberList']: 56 | member.core = chatroom.core 57 | member.chatroom = chatroom 58 | if 'Self' in chatroom: 59 | chatroom['Self'].core = chatroom.core 60 | chatroom['Self'].chatroom = chatroom 61 | self.lastInputUserName = j.get('lastInputUserName', None) 62 | def search_friends(self, name=None, userName=None, remarkName=None, nickName=None, 63 | wechatAccount=None): 64 | with self.updateLock: 65 | if (name or userName or remarkName or nickName or wechatAccount) is None: 66 | return copy.deepcopy(self.memberList[0]) # my own account 67 | elif userName: # return the only userName match 68 | for m in self.memberList: 69 | if m['UserName'] == userName: 70 | return copy.deepcopy(m) 71 | else: 72 | matchDict = { 73 | 'RemarkName' : remarkName, 74 | 'NickName' : nickName, 75 | 'Alias' : wechatAccount, } 76 | for k in ('RemarkName', 'NickName', 
'Alias'): 77 | if matchDict[k] is None: 78 | del matchDict[k] 79 | if name: # select based on name 80 | contact = [] 81 | for m in self.memberList: 82 | if any([m.get(k) == name for k in ('RemarkName', 'NickName', 'Alias')]): 83 | contact.append(m) 84 | else: 85 | contact = self.memberList[:] 86 | if matchDict: # select again based on matchDict 87 | friendList = [] 88 | for m in contact: 89 | if all([m.get(k) == v for k, v in matchDict.items()]): 90 | friendList.append(m) 91 | return copy.deepcopy(friendList) 92 | else: 93 | return copy.deepcopy(contact) 94 | def search_chatrooms(self, name=None, userName=None): 95 | with self.updateLock: 96 | if userName is not None: 97 | for m in self.chatroomList: 98 | if m['UserName'] == userName: 99 | return copy.deepcopy(m) 100 | elif name is not None: 101 | matchList = [] 102 | for m in self.chatroomList: 103 | if name in m['NickName']: 104 | matchList.append(copy.deepcopy(m)) 105 | return matchList 106 | def search_mps(self, name=None, userName=None): 107 | with self.updateLock: 108 | if userName is not None: 109 | for m in self.mpList: 110 | if m['UserName'] == userName: 111 | return copy.deepcopy(m) 112 | elif name is not None: 113 | matchList = [] 114 | for m in self.mpList: 115 | if name in m['NickName']: 116 | matchList.append(copy.deepcopy(m)) 117 | return matchList 118 | -------------------------------------------------------------------------------- /lib/itchat/storage/messagequeue.py: -------------------------------------------------------------------------------- 1 | import logging 2 | try: 3 | import Queue as queue 4 | except ImportError: 5 | import queue 6 | 7 | from .templates import AttributeDict 8 | 9 | logger = logging.getLogger('itchat') 10 | 11 | class Queue(queue.Queue): 12 | def put(self, message): 13 | queue.Queue.put(self, Message(message)) 14 | 15 | class Message(AttributeDict): 16 | def download(self, fileName): 17 | if hasattr(self.text, '__call__'): 18 | return self.text(fileName) 19 | else: 20 | 
    def __getitem__(self, value):
        # Backward-compat shim: itchat 1.3.0 renamed the keys
        # 'isAdmin'/'isAt' to 'IsAdmin'/'IsAt'.  Map the old spelling to
        # the new one (uppercase the first character), emit a debug-level
        # deprecation note, then delegate to the normal dict lookup.
        if value in ('isAdmin', 'isAt'):
            v = value[0].upper() + value[1:] # ''[1:] == ''
            logger.debug('%s is expired in 1.3.0, use %s instead.' % (value, v))
            value = v
        return super(Message, self).__getitem__(value)
See https://docs.python.org/3.9/whatsnew/3.9.html 25 | try: 26 | b = u'\u2588' 27 | sys.stdout.write(b + '\r') 28 | sys.stdout.flush() 29 | except UnicodeEncodeError: 30 | BLOCK = 'MM' 31 | else: 32 | BLOCK = b 33 | friendInfoTemplate = {} 34 | for k in ('UserName', 'City', 'DisplayName', 'PYQuanPin', 'RemarkPYInitial', 'Province', 35 | 'KeyWord', 'RemarkName', 'PYInitial', 'EncryChatRoomId', 'Alias', 'Signature', 36 | 'NickName', 'RemarkPYQuanPin', 'HeadImgUrl'): 37 | friendInfoTemplate[k] = '' 38 | for k in ('UniFriend', 'Sex', 'AppAccountFlag', 'VerifyFlag', 'ChatRoomId', 'HideInputBarFlag', 39 | 'AttrStatus', 'SnsFlag', 'MemberCount', 'OwnerUin', 'ContactFlag', 'Uin', 40 | 'StarFriend', 'Statues'): 41 | friendInfoTemplate[k] = 0 42 | friendInfoTemplate['MemberList'] = [] 43 | 44 | def clear_screen(): 45 | os.system('cls' if config.OS == 'Windows' else 'clear') 46 | 47 | def emoji_formatter(d, k): 48 | ''' _emoji_deebugger is for bugs about emoji match caused by wechat backstage 49 | like :face with tears of joy: will be replaced with :cat face with tears of joy: 50 | ''' 51 | def _emoji_debugger(d, k): 52 | s = d[k].replace('') # fix missing bug 54 | def __fix_miss_match(m): 55 | return '' % ({ 56 | '1f63c': '1f601', '1f639': '1f602', '1f63a': '1f603', 57 | '1f4ab': '1f616', '1f64d': '1f614', '1f63b': '1f60d', 58 | '1f63d': '1f618', '1f64e': '1f621', '1f63f': '1f622', 59 | }.get(m.group(1), m.group(1))) 60 | return emojiRegex.sub(__fix_miss_match, s) 61 | def _emoji_formatter(m): 62 | s = m.group(1) 63 | if len(s) == 6: 64 | return ('\\U%s\\U%s'%(s[:2].rjust(8, '0'), s[2:].rjust(8, '0')) 65 | ).encode('utf8').decode('unicode-escape', 'replace') 66 | elif len(s) == 10: 67 | return ('\\U%s\\U%s'%(s[:5].rjust(8, '0'), s[5:].rjust(8, '0')) 68 | ).encode('utf8').decode('unicode-escape', 'replace') 69 | else: 70 | return ('\\U%s'%m.group(1).rjust(8, '0') 71 | ).encode('utf8').decode('unicode-escape', 'replace') 72 | d[k] = _emoji_debugger(d, k) 73 | d[k] = 
emojiRegex.sub(_emoji_formatter, d[k]) 74 | 75 | def msg_formatter(d, k): 76 | emoji_formatter(d, k) 77 | d[k] = d[k].replace('
def search_dict_list(l, key, value):
    ''' Search a list of dict
    * return the first dict whose *key* maps to *value*, or None '''
    return next((entry for entry in l if entry.get(key) == value), None)
def update_info_dict(oldInfoDict, newInfoDict):
    ''' only normal values will be updated here
        because newInfoDict is normal dict, so it's not necessary to consider templates
    '''
    for key, new in newInfoDict.items():
        if isinstance(new, (tuple, list, dict)):
            continue  # container values are merged elsewhere
        # overwrite when the old slot is empty, or when the new value is
        # not one of the "empty" sentinels (None, '', '0', 0)
        if oldInfoDict.get(key) is None or new not in (None, '', '0', 0):
            oldInfoDict[key] = new
简易的敏感词插件,暂不支持分词,请自行导入词库到插件文件夹中的`banwords.txt`,每行一个词,一个参考词库是[1](https://github.com/cjh0613/tencent-sensitive-words/blob/main/sensitive_words_lines.txt)。 5 | 6 | 使用前将`config.json.template`复制为`config.json`,并自行配置。 7 | 8 | 目前插件对消息的默认处理行为有如下两种: 9 | 10 | - `ignore` : 无视这条消息。 11 | - `replace` : 将消息中的敏感词替换成"*",并回复违规。 12 | 13 | ```json 14 | "action": "replace", 15 | "reply_filter": true, 16 | "reply_action": "ignore" 17 | ``` 18 | 19 | 在以上配置项中: 20 | 21 | - `action`: 对用户消息的默认处理行为 22 | - `reply_filter`: 是否对ChatGPT的回复也进行敏感词过滤 23 | - `reply_action`: 如果开启了回复过滤,对回复的默认处理行为 24 | 25 | ## 致谢 26 | 27 | 搜索功能实现来自https://github.com/toolgood/ToolGood.Words -------------------------------------------------------------------------------- /plugins/banwords/__init__.py: -------------------------------------------------------------------------------- 1 | from .banwords import * 2 | -------------------------------------------------------------------------------- /plugins/banwords/banwords.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | import json 4 | import os 5 | 6 | import plugins 7 | from bridge.context import ContextType 8 | from bridge.reply import Reply, ReplyType 9 | from common.log import logger 10 | from plugins import * 11 | 12 | from .lib.WordsSearch import WordsSearch 13 | 14 | 15 | @plugins.register( 16 | name="Banwords", 17 | desire_priority=100, 18 | hidden=True, 19 | desc="判断消息中是否有敏感词、决定是否回复。", 20 | version="1.0", 21 | author="lanvent", 22 | ) 23 | class Banwords(Plugin): 24 | def __init__(self): 25 | super().__init__() 26 | try: 27 | curdir = os.path.dirname(__file__) 28 | config_path = os.path.join(curdir, "config.json") 29 | conf = None 30 | if not os.path.exists(config_path): 31 | conf = {"action": "ignore"} 32 | with open(config_path, "w") as f: 33 | json.dump(conf, f, indent=4) 34 | else: 35 | with open(config_path, "r") as f: 36 | conf = json.load(f) 37 | self.searchr = WordsSearch() 38 | self.action = 
conf["action"] 39 | banwords_path = os.path.join(curdir, "banwords.txt") 40 | with open(banwords_path, "r", encoding="utf-8") as f: 41 | words = [] 42 | for line in f: 43 | word = line.strip() 44 | if word: 45 | words.append(word) 46 | self.searchr.SetKeywords(words) 47 | self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context 48 | if conf.get("reply_filter", True): 49 | self.handlers[Event.ON_DECORATE_REPLY] = self.on_decorate_reply 50 | self.reply_action = conf.get("reply_action", "ignore") 51 | logger.info("[Banwords] inited") 52 | except Exception as e: 53 | logger.warn( 54 | "[Banwords] init failed, ignore or see https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/banwords ." 55 | ) 56 | raise e 57 | 58 | def on_handle_context(self, e_context: EventContext): 59 | if e_context["context"].type not in [ 60 | ContextType.TEXT, 61 | ContextType.IMAGE_CREATE, 62 | ]: 63 | return 64 | 65 | content = e_context["context"].content 66 | logger.debug("[Banwords] on_handle_context. 
content: %s" % content) 67 | if self.action == "ignore": 68 | f = self.searchr.FindFirst(content) 69 | if f: 70 | logger.info("[Banwords] %s in message" % f["Keyword"]) 71 | e_context.action = EventAction.BREAK_PASS 72 | return 73 | elif self.action == "replace": 74 | if self.searchr.ContainsAny(content): 75 | reply = Reply( 76 | ReplyType.INFO, "发言中包含敏感词,请重试: \n" + self.searchr.Replace(content) 77 | ) 78 | e_context["reply"] = reply 79 | e_context.action = EventAction.BREAK_PASS 80 | return 81 | 82 | def on_decorate_reply(self, e_context: EventContext): 83 | if e_context["reply"].type not in [ReplyType.TEXT]: 84 | return 85 | 86 | reply = e_context["reply"] 87 | content = reply.content 88 | if self.reply_action == "ignore": 89 | f = self.searchr.FindFirst(content) 90 | if f: 91 | logger.info("[Banwords] %s in reply" % f["Keyword"]) 92 | e_context["reply"] = None 93 | e_context.action = EventAction.BREAK_PASS 94 | return 95 | elif self.reply_action == "replace": 96 | if self.searchr.ContainsAny(content): 97 | reply = Reply( 98 | ReplyType.INFO, "已替换回复中的敏感词: \n" + self.searchr.Replace(content) 99 | ) 100 | e_context["reply"] = reply 101 | e_context.action = EventAction.CONTINUE 102 | return 103 | 104 | def get_help_text(self, **kwargs): 105 | return "过滤消息中的敏感词。" 106 | -------------------------------------------------------------------------------- /plugins/banwords/banwords.txt.template: -------------------------------------------------------------------------------- 1 | nipples 2 | pennis 3 | 法轮功 -------------------------------------------------------------------------------- /plugins/banwords/config.json.template: -------------------------------------------------------------------------------- 1 | { 2 | "action": "replace", 3 | "reply_filter": true, 4 | "reply_action": "ignore" 5 | } 6 | -------------------------------------------------------------------------------- /plugins/bdunit/README.md: 
-------------------------------------------------------------------------------- 1 | ## 插件说明 2 | 3 | 利用百度UNIT实现智能对话 4 | 5 | - 1.解决问题:chatgpt无法处理的指令,交给百度UNIT处理如:天气,日期时间,数学运算等 6 | - 2.如问时间:现在几点钟,今天几号 7 | - 3.如问天气:明天广州天气怎么样,这个周末深圳会不会下雨 8 | - 4.如问数学运算:23+45=多少,100-23=多少,35转化为二进制是多少? 9 | 10 | ## 使用说明 11 | 12 | ### 获取apikey 13 | 14 | 在百度UNIT官网上自己创建应用,申请百度机器人,可以把预先训练好的模型导入到自己的应用中, 15 | 16 | see https://ai.baidu.com/unit/home#/home?track=61fe1b0d3407ce3face1d92cb5c291087095fc10c8377aaf https://console.bce.baidu.com/ai平台申请 17 | 18 | ### 配置文件 19 | 20 | 将文件夹中`config.json.template`复制为`config.json`。 21 | 22 | 在其中填写百度UNIT官网上获取应用的API Key和Secret Key 23 | 24 | ``` json 25 | { 26 | "service_id": "s...", #"机器人ID" 27 | "api_key": "", 28 | "secret_key": "" 29 | } 30 | ``` -------------------------------------------------------------------------------- /plugins/bdunit/__init__.py: -------------------------------------------------------------------------------- 1 | from .bdunit import * 2 | -------------------------------------------------------------------------------- /plugins/bdunit/config.json.template: -------------------------------------------------------------------------------- 1 | { 2 | "service_id": "s...", 3 | "api_key": "", 4 | "secret_key": "" 5 | } 6 | -------------------------------------------------------------------------------- /plugins/dungeon/README.md: -------------------------------------------------------------------------------- 1 | 玩地牢游戏的聊天插件,触发方法如下: 2 | 3 | - `$开始冒险 <背景故事>` - 以<背景故事>开始一个地牢游戏,不填写会使用默认背景故事。之后聊天中你的所有消息会帮助ai完善这个故事。 4 | - `$停止冒险` - 停止一个地牢游戏,回归正常的ai。 5 | -------------------------------------------------------------------------------- /plugins/dungeon/__init__.py: -------------------------------------------------------------------------------- 1 | from .dungeon import * 2 | -------------------------------------------------------------------------------- /plugins/dungeon/dungeon.py: 
# https://github.com/bupticybee/ChineseAiDungeonChatGPT
class StoryTeller:
    """Builds prompts for a text-adventure ("dungeon") session.

    The first prompt seeds the story; later prompts only append the
    player's action.  The bot-side session is cleared whenever a new
    story starts so stale context cannot leak in.
    """

    def __init__(self, bot, sessionid, story):
        self.bot = bot
        self.sessionid = sessionid
        bot.sessions.clear_session(sessionid)
        self.first_interact = True
        self.story = story

    def reset(self):
        """Forget the bot-side session and restart the narrative."""
        self.bot.sessions.clear_session(self.sessionid)
        self.first_interact = True

    def action(self, user_action):
        """Turn the player's input into the next prompt for the bot."""
        # normalise: every action ends with a full stop
        if user_action[-1] != "。":
            user_action = user_action + "。"
        if not self.first_interact:
            # continuation turn: short fixed instruction plus the action
            return """继续,一次只需要续写四到六句话,总共就只讲5分钟内发生的事情。""" + user_action
        # opening turn: full game instructions plus the背景 story
        self.first_interact = False
        return (
            """现在来充当一个文字冒险游戏,描述时候注意节奏,不要太快,仔细描述各个人物的心情和周边环境。一次只需写四到六句话。
开头是,"""
            + self.story
            + " "
            + user_action
        )
(const.CHATGPT, const.OPEN_AI): 68 | return 69 | bot = Bridge().get_bot("chat") 70 | content = e_context["context"].content[:] 71 | clist = e_context["context"].content.split(maxsplit=1) 72 | sessionid = e_context["context"]["session_id"] 73 | logger.debug("[Dungeon] on_handle_context. content: %s" % clist) 74 | trigger_prefix = conf().get("plugin_trigger_prefix", "$") 75 | if clist[0] == f"{trigger_prefix}停止冒险": 76 | if sessionid in self.games: 77 | self.games[sessionid].reset() 78 | del self.games[sessionid] 79 | reply = Reply(ReplyType.INFO, "冒险结束!") 80 | e_context["reply"] = reply 81 | e_context.action = EventAction.BREAK_PASS 82 | elif clist[0] == f"{trigger_prefix}开始冒险" or sessionid in self.games: 83 | if sessionid not in self.games or clist[0] == f"{trigger_prefix}开始冒险": 84 | if len(clist) > 1: 85 | story = clist[1] 86 | else: 87 | story = ( 88 | "你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。" 89 | ) 90 | self.games[sessionid] = StoryTeller(bot, sessionid, story) 91 | reply = Reply(ReplyType.INFO, "冒险开始,你可以输入任意内容,让故事继续下去。故事背景是:" + story) 92 | e_context["reply"] = reply 93 | e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 94 | else: 95 | prompt = self.games[sessionid].action(content) 96 | e_context["context"].type = ContextType.TEXT 97 | e_context["context"].content = prompt 98 | e_context.action = EventAction.BREAK # 事件结束,不跳过处理context的默认逻辑 99 | 100 | def get_help_text(self, **kwargs): 101 | help_text = "可以和机器人一起玩文字冒险游戏。\n" 102 | if kwargs.get("verbose") != True: 103 | return help_text 104 | trigger_prefix = conf().get("plugin_trigger_prefix", "$") 105 | help_text = ( 106 | f"{trigger_prefix}开始冒险 " 107 | + "背景故事: 开始一个基于{背景故事}的文字冒险,之后你的所有消息会协助完善这个故事。\n" 108 | + f"{trigger_prefix}停止冒险: 结束游戏。\n" 109 | ) 110 | if kwargs.get("verbose") == True: 111 | help_text += f"\n命令例子: '{trigger_prefix}开始冒险 你在树林里冒险,指不定会从哪里蹦出来一些奇怪的东西,你握紧手上的手枪,希望这次冒险能够找到一些值钱的东西,你往树林深处走去。'" 112 | return help_text 113 | 
# encoding:utf-8

from enum import Enum


class Event(Enum):
    """Hook points a plugin can subscribe to, listed in firing order."""

    ON_RECEIVE_MESSAGE = 1  # 收到消息
    """
    e_context = { "channel": 消息channel, "context" : 本次消息的context}
    """

    ON_HANDLE_CONTEXT = 2  # 处理消息前
    """
    e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复,初始为空 }
    """

    ON_DECORATE_REPLY = 3  # 得到回复后准备装饰
    """
    e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
    """

    ON_SEND_REPLY = 4  # 发送回复前
    """
    e_context = { "channel": 消息channel, "context" : 本次消息的context, "reply" : 目前的回复 }
    """

    # AFTER_SEND_REPLY = 5  # 发送回复后


class EventAction(Enum):
    CONTINUE = 1  # 事件未结束,继续交给下个插件处理,如果没有下个插件,则交付给默认的事件处理逻辑
    BREAK = 2  # 事件结束,不再给下个插件处理,交付给默认的事件处理逻辑
    BREAK_PASS = 3  # 事件结束,不再给下个插件处理,不交付给默认的事件处理逻辑


class EventContext:
    """Mutable state bag passed through the plugin chain for one event."""

    def __init__(self, event, econtext=None):
        # BUG FIX: the old signature used ``econtext=dict()`` — a mutable
        # default shared by every EventContext constructed without an
        # explicit dict, so state written during one event leaked into
        # later ones.  A fresh dict is now created per instance.
        self.event = event
        self.econtext = econtext if econtext is not None else {}
        self.action = EventAction.CONTINUE

    def __getitem__(self, key):
        return self.econtext[key]

    def __setitem__(self, key, value):
        self.econtext[key] = value

    def __delitem__(self, key):
        del self.econtext[key]

    def is_pass(self):
        # BREAK_PASS means: stop the plugin chain AND skip default handling
        return self.action == EventAction.BREAK_PASS
@plugins.register(
    name="Finish",
    desire_priority=-999,
    hidden=True,
    desc="A plugin that check unknown command",
    version="1.0",
    author="js00000",
)
class Finish(Plugin):
    """Catch-all plugin that answers unrecognized plugin commands.

    Registered with the lowest priority so it only fires after every
    other plugin has declined to handle a trigger-prefixed message.
    """

    def __init__(self):
        super().__init__()
        self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context
        logger.info("[Finish] inited")

    def on_handle_context(self, e_context: EventContext):
        """Reply with a generic 'unknown command' error for prefixed text."""
        context = e_context["context"]
        if context.type != ContextType.TEXT:
            return

        content = context.content
        logger.debug("[Finish] on_handle_context. content: %s" % content)
        trigger_prefix = conf().get("plugin_trigger_prefix", "$")
        if not content.startswith(trigger_prefix):
            return
        reply = Reply()
        reply.type = ReplyType.ERROR
        reply.content = "未知插件命令\n查看插件命令列表请输入#help 插件名\n"
        e_context["reply"] = reply
        # stop the plugin chain and skip the default context handling
        e_context.action = EventAction.BREAK_PASS

    def get_help_text(self, **kwargs):
        return ""
.godcmd import * 2 | -------------------------------------------------------------------------------- /plugins/godcmd/config.json.template: -------------------------------------------------------------------------------- 1 | { 2 | "password": "", 3 | "admin_users": [] 4 | } 5 | -------------------------------------------------------------------------------- /plugins/hello/__init__.py: -------------------------------------------------------------------------------- 1 | from .hello import * 2 | -------------------------------------------------------------------------------- /plugins/hello/hello.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | import plugins 4 | from bridge.context import ContextType 5 | from bridge.reply import Reply, ReplyType 6 | from channel.chat_message import ChatMessage 7 | from common.log import logger 8 | from plugins import * 9 | 10 | 11 | @plugins.register( 12 | name="Hello", 13 | desire_priority=-1, 14 | hidden=True, 15 | desc="A simple plugin that says hello", 16 | version="0.1", 17 | author="lanvent", 18 | ) 19 | class Hello(Plugin): 20 | def __init__(self): 21 | super().__init__() 22 | self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context 23 | logger.info("[Hello] inited") 24 | 25 | def on_handle_context(self, e_context: EventContext): 26 | if e_context["context"].type not in [ 27 | ContextType.TEXT, 28 | ContextType.JOIN_GROUP, 29 | ContextType.PATPAT, 30 | ]: 31 | return 32 | 33 | if e_context["context"].type == ContextType.JOIN_GROUP: 34 | e_context["context"].type = ContextType.TEXT 35 | msg: ChatMessage = e_context["context"]["msg"] 36 | e_context[ 37 | "context" 38 | ].content = f'请你随机使用一种风格说一句问候语来欢迎新用户"{msg.actual_user_nickname}"加入群聊。' 39 | e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 40 | return 41 | 42 | if e_context["context"].type == ContextType.PATPAT: 43 | e_context["context"].type = ContextType.TEXT 44 | msg: ChatMessage = 
e_context["context"]["msg"] 45 | e_context["context"].content = f"请你随机使用一种风格介绍你自己,并告诉用户输入#help可以查看帮助信息。" 46 | e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 47 | return 48 | 49 | content = e_context["context"].content 50 | logger.debug("[Hello] on_handle_context. content: %s" % content) 51 | if content == "Hello": 52 | reply = Reply() 53 | reply.type = ReplyType.TEXT 54 | msg: ChatMessage = e_context["context"]["msg"] 55 | if e_context["context"]["isgroup"]: 56 | reply.content = ( 57 | f"Hello, {msg.actual_user_nickname} from {msg.from_user_nickname}" 58 | ) 59 | else: 60 | reply.content = f"Hello, {msg.from_user_nickname}" 61 | e_context["reply"] = reply 62 | e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 63 | 64 | if content == "Hi": 65 | reply = Reply() 66 | reply.type = ReplyType.TEXT 67 | reply.content = "Hi" 68 | e_context["reply"] = reply 69 | e_context.action = EventAction.BREAK # 事件结束,进入默认处理逻辑,一般会覆写reply 70 | 71 | if content == "End": 72 | # 如果是文本消息"End",将请求转换成"IMAGE_CREATE",并将content设置为"The World" 73 | e_context["context"].type = ContextType.IMAGE_CREATE 74 | content = "The World" 75 | e_context.action = EventAction.CONTINUE # 事件继续,交付给下个插件或默认逻辑 76 | 77 | def get_help_text(self, **kwargs): 78 | help_text = "输入Hello,我会回复你的名字\n输入End,我会回复你世界的图片\n" 79 | return help_text 80 | -------------------------------------------------------------------------------- /plugins/keyword/README.md: -------------------------------------------------------------------------------- 1 | # 目的 2 | 关键字匹配并回复 3 | 4 | # 试用场景 5 | 目前是在微信公众号下面使用过。 6 | 7 | # 使用步骤 8 | 1. 复制 `config.json.template` 为 `config.json` 9 | 2. 在关键字 `keyword` 新增需要关键字匹配的内容 10 | 3. 
重启程序做验证 11 | 12 | # 验证结果 13 | ![结果](test-keyword.png) -------------------------------------------------------------------------------- /plugins/keyword/__init__.py: -------------------------------------------------------------------------------- 1 | from .keyword import * 2 | -------------------------------------------------------------------------------- /plugins/keyword/config.json.template: -------------------------------------------------------------------------------- 1 | { 2 | "keyword": { 3 | "关键字匹配": "测试成功" 4 | } 5 | } -------------------------------------------------------------------------------- /plugins/keyword/keyword.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | import json 4 | import os 5 | 6 | import plugins 7 | from bridge.context import ContextType 8 | from bridge.reply import Reply, ReplyType 9 | from common.log import logger 10 | from plugins import * 11 | 12 | 13 | @plugins.register( 14 | name="Keyword", 15 | desire_priority=900, 16 | hidden=True, 17 | desc="关键词匹配过滤", 18 | version="0.1", 19 | author="fengyege.top", 20 | ) 21 | class Keyword(Plugin): 22 | def __init__(self): 23 | super().__init__() 24 | try: 25 | curdir = os.path.dirname(__file__) 26 | config_path = os.path.join(curdir, "config.json") 27 | conf = None 28 | if not os.path.exists(config_path): 29 | logger.debug(f"[keyword]不存在配置文件{config_path}") 30 | conf = {"keyword": {}} 31 | with open(config_path, "w", encoding="utf-8") as f: 32 | json.dump(conf, f, indent=4) 33 | else: 34 | logger.debug(f"[keyword]加载配置文件{config_path}") 35 | with open(config_path, "r", encoding="utf-8") as f: 36 | conf = json.load(f) 37 | # 加载关键词 38 | self.keyword = conf["keyword"] 39 | 40 | logger.info("[keyword] {}".format(self.keyword)) 41 | self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context 42 | logger.info("[keyword] inited.") 43 | except Exception as e: 44 | logger.warn( 45 | "[keyword] init failed, ignore or see 
https://github.com/zhayujie/chatgpt-on-wechat/tree/master/plugins/keyword ." 46 | ) 47 | raise e 48 | 49 | def on_handle_context(self, e_context: EventContext): 50 | if e_context["context"].type != ContextType.TEXT: 51 | return 52 | 53 | content = e_context["context"].content.strip() 54 | logger.debug("[keyword] on_handle_context. content: %s" % content) 55 | if content in self.keyword: 56 | logger.debug(f"[keyword] 匹配到关键字【{content}】") 57 | reply_text = self.keyword[content] 58 | 59 | reply = Reply() 60 | reply.type = ReplyType.TEXT 61 | reply.content = reply_text 62 | e_context["reply"] = reply 63 | e_context.action = EventAction.BREAK_PASS # 事件结束,并跳过处理context的默认逻辑 64 | 65 | def get_help_text(self, **kwargs): 66 | help_text = "关键词过滤" 67 | return help_text 68 | -------------------------------------------------------------------------------- /plugins/keyword/test-keyword.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/limccn/chatgpt-on-wechat/c60f0517fb03aa09f72a5885477569cf2d813d65/plugins/keyword/test-keyword.png -------------------------------------------------------------------------------- /plugins/plugin.py: -------------------------------------------------------------------------------- 1 | class Plugin: 2 | def __init__(self): 3 | self.handlers = {} 4 | 5 | def get_help_text(self, **kwargs): 6 | return "暂无帮助信息" 7 | -------------------------------------------------------------------------------- /plugins/role/README.md: -------------------------------------------------------------------------------- 1 | 用于让Bot扮演指定角色的聊天插件,触发方法如下: 2 | 3 | - `$角色/$role help/帮助` - 打印目前支持的角色列表。 4 | - `$角色/$role <角色名>` - 让AI扮演该角色,角色名支持模糊匹配。 5 | - `$停止扮演` - 停止角色扮演。 6 | 7 | 添加自定义角色请在`roles/roles.json`中添加。 8 | 9 | (大部分prompt来自https://github.com/rockbenben/ChatGPT-Shortcut/blob/main/src/data/users.tsx) 10 | 11 | 以下为例子: 12 | ```json 13 | { 14 | "title": "写作助理", 15 | "description": "As a writing improvement 
assistant, your task is to improve the spelling, grammar, clarity, concision, and overall readability of the text I provided, while breaking down long sentences, reducing repetition, and providing suggestions for improvement. Please provide only the corrected Chinese version of the text and avoid including explanations. Please treat every message I send later as text content.", 16 | "descn": "作为一名中文写作改进助理,你的任务是改进所提供文本的拼写、语法、清晰、简洁和整体可读性,同时分解长句,减少重复,并提供改进建议。请只提供文本的更正版本,避免包括解释。请把我之后的每一条消息都当作文本内容。", 17 | "wrapper": "内容是:\n\"%s\"", 18 | "remark": "最常使用的角色,用于优化文本的语法、清晰度和简洁度,提高可读性。" 19 | } 20 | ``` 21 | 22 | - `title`: 角色名。 23 | - `description`: 使用`$role`触发时,使用英语prompt。 24 | - `descn`: 使用`$角色`触发时,使用中文prompt。 25 | - `wrapper`: 用于包装用户消息,可起到强调作用,避免回复离题。 26 | - `remark`: 简短描述该角色,在打印帮助文档时显示。 27 | -------------------------------------------------------------------------------- /plugins/role/__init__.py: -------------------------------------------------------------------------------- 1 | from .role import * 2 | -------------------------------------------------------------------------------- /plugins/source.json: -------------------------------------------------------------------------------- 1 | { 2 | "repo": { 3 | "sdwebui": { 4 | "url": "https://github.com/lanvent/plugin_sdwebui.git", 5 | "desc": "利用stable-diffusion画图的插件" 6 | }, 7 | "replicate": { 8 | "url": "https://github.com/lanvent/plugin_replicate.git", 9 | "desc": "利用replicate api画图的插件" 10 | }, 11 | "summary": { 12 | "url": "https://github.com/lanvent/plugin_summary.git", 13 | "desc": "总结聊天记录的插件" 14 | } 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /plugins/tool/README.md: -------------------------------------------------------------------------------- 1 | ## 插件描述 2 | 一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力 3 | 使用该插件需在机器人回复你的前提下,在对话内容前加$tool;仅输入$tool将返回tool插件帮助信息,用于测试插件是否加载成功 4 | ### 本插件所有工具同步存放至专用仓库:[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub) 5 | 
6 | 7 | ## 使用说明 8 | 使用该插件后将默认使用4个工具, 无需额外配置长期生效: 9 | ### 1. python 10 | ###### python解释器,使用它来解释执行python指令,可以配合你想要chatgpt生成的代码输出结果或执行事务 11 | 12 | ### 2. 访问网页的工具汇总(默认url-get) 13 | 14 | #### 2.1 url-get 15 | ###### 往往用来获取某个网站具体内容,结果可能会被反爬策略影响 16 | 17 | #### 2.2 browser 18 | ###### 浏览器,功能与2.1类似,但能更好模拟,不会被识别为爬虫影响获取网站内容 19 | 20 | > 注1:url-get默认配置、browser需额外配置,browser依赖google-chrome,你需要提前安装好 21 | 22 | > 注2:browser默认使用summary tool 分段总结长文本信息,tokens可能会大量消耗! 23 | 24 | 这是debian端安装google-chrome教程,其他系统请执行查找 25 | > https://www.linuxjournal.com/content/how-can-you-install-google-browser-debian 26 | 27 | ### 3. terminal 28 | ###### 在你运行的电脑里执行shell命令,可以配合你想要chatgpt生成的代码使用,给予自然语言控制手段 29 | 30 | > terminal调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issue-1659347640 31 | 32 | ### 4. meteo-weather 33 | ###### 回答你有关天气的询问, 需要获取时间、地点上下文信息,本工具使用了[meteo open api](https://open-meteo.com/) 34 | 注:该工具需要较高的对话技巧,不保证你问的任何问题均能得到满意的回复 35 | 36 | > meteo调优记录:https://github.com/zhayujie/chatgpt-on-wechat/issues/776#issuecomment-1500771334 37 | 38 | ## 使用本插件对话(prompt)技巧 39 | ### 1. 有指引的询问 40 | #### 例如: 41 | - 总结这个链接的内容 https://github.com/goldfishh/chatgpt-tool-hub 42 | - 使用Terminal执行curl cip.cc 43 | - 使用python查询今天日期 44 | 45 | ### 2. 使用搜索引擎工具 46 | - 如果有搜索工具就能让chatgpt获取到你的未传达清楚的上下文信息,比如chatgpt不知道你的地理位置,现在时间等,所以无法查询到天气 47 | 48 | ## 其他工具 49 | 50 | ### 5. wikipedia 51 | ###### 可以回答你想要知道确切的人事物 52 | 53 | ### 6. 新闻类工具 54 | 55 | #### 6.1. news-api * 56 | ###### 从全球 80,000 多个信息源中获取当前和历史新闻文章 57 | 58 | #### 6.2. morning-news * 59 | ###### 每日60秒早报,每天凌晨一点更新,本工具使用了[alapi-每日60秒早报](https://alapi.cn/api/view/93) 60 | 61 | > 该tool每天返回内容相同 62 | 63 | #### 6.3. finance-news 64 | ###### 获取实时的金融财政新闻 65 | 66 | > 该工具需要解决browser tool 的google-chrome依赖安装 67 | 68 | ### 7. bing-search * 69 | ###### bing搜索引擎,从此你不用再烦恼搜索要用哪些关键词 70 | 71 | ### 8. wolfram-alpha * 72 | ###### 知识搜索引擎、科学问答系统,常用于专业学科计算 73 | 74 | ### 9. google-search * 75 | ###### google搜索引擎,申请流程较bing-search繁琐 76 | 77 | 78 | ### 10. 
arxiv(dev 开发中) 79 | ###### 用于查找论文 80 | 81 | 82 | ### 11. debug(dev 开发中,目前没有接入wechat) 83 | ###### 当bot遇到无法确定的信息时,将会向你寻求帮助的工具 84 | 85 | 86 | ### 12. summary 87 | ###### 总结工具,该工具必须输入一个本地文件的绝对路径 88 | 89 | > 该工具目前是和其他工具配合使用,暂未测试单独使用效果 90 | 91 | 92 | ### 13. image2text 93 | ###### 将图片转换成文字,底层调用imageCaption模型,该工具必须输入一个本地文件的绝对路径 94 | 95 | 96 | ### 14. searxng-search * 97 | ###### 一个私有化的搜索引擎工具 98 | 99 | > 安装教程:https://docs.searxng.org/admin/installation.html 100 | 101 | --- 102 | 103 | ###### 注1:带*工具需要获取api-key才能使用(在config.json内的kwargs添加项),部分工具需要外网支持 104 | #### [申请方法](https://github.com/goldfishh/chatgpt-tool-hub/blob/master/docs/apply_optional_tool.md) 105 | 106 | ## config.json 配置说明 107 | ###### 默认工具无需配置,其它工具需手动配置,一个例子: 108 | ```json 109 | { 110 | "tools": ["wikipedia", "你想要添加的其他工具"], // 填入你想用到的额外工具名 111 | "kwargs": { 112 | "debug": true, // 当你遇到问题求助时,需要配置 113 | "request_timeout": 120, // openai接口超时时间 114 | "no_default": false, // 是否不使用默认的4个工具 115 | // 带*工具需要申请api-key,在这里填入,api_name参考前述`申请方法` 116 | } 117 | } 118 | 119 | ``` 120 | 注:config.json文件非必须,未创建仍可使用本tool;带*工具需在kwargs填入对应api-key键值对 121 | - `tools`:本插件初始化时加载的工具, 目前可选集:["wikipedia", "wolfram-alpha", "bing-search", "google-search", "news"] & 默认工具,除wikipedia工具之外均需要申请api-key 122 | - `kwargs`:工具执行时的配置,一般在这里存放**api-key**,或环境配置 123 | - `debug`: 输出chatgpt-tool-hub额外信息用于调试 124 | - `request_timeout`: 访问openai接口的超时时间,默认与wechat-on-chatgpt配置一致,可单独配置 125 | - `no_default`: 用于配置默认加载4个工具的行为,如果为true则仅使用tools列表工具,不加载默认工具 126 | - `top_k_results`: 控制所有有关搜索的工具返回条目数,数字越高则参考信息越多,但无用信息可能干扰判断,该值一般为2 127 | - `model_name`: 用于控制tool插件底层使用的llm模型,目前暂未测试3.5以外的模型,一般保持默认 128 | 129 | --- 130 | 131 | ## 备注 132 | - 强烈建议申请搜索工具搭配使用,推荐bing-search 133 | - 虽然我会有意加入一些限制,但请不要使用本插件做危害他人的事情,请提前了解清楚某些内容是否会违反相关规定,建议提前做好过滤 134 | - 如有本插件问题,请将debug设置为true无上下文重新问一遍,如仍有问题请访问[chatgpt-tool-hub](https://github.com/goldfishh/chatgpt-tool-hub)建个issue,将日志贴进去,我无法处理不能复现的问题 135 | - 欢迎 star & 宣传,有能力请提pr 136 | 137 | 
-------------------------------------------------------------------------------- /plugins/tool/__init__.py: -------------------------------------------------------------------------------- 1 | from .tool import * 2 | -------------------------------------------------------------------------------- /plugins/tool/config.json.template: -------------------------------------------------------------------------------- 1 | { 2 | "tools": [ 3 | "python", 4 | "url-get", 5 | "terminal", 6 | "meteo-weather" 7 | ], 8 | "kwargs": { 9 | "top_k_results": 2, 10 | "no_default": false, 11 | "model_name": "gpt-3.5-turbo" 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /plugins/tool/tool.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from chatgpt_tool_hub.apps import AppFactory 5 | from chatgpt_tool_hub.apps.app import App 6 | from chatgpt_tool_hub.tools.all_tool_list import get_all_tool_names 7 | 8 | import plugins 9 | from bridge.bridge import Bridge 10 | from bridge.context import ContextType 11 | from bridge.reply import Reply, ReplyType 12 | from common import const 13 | from common.log import logger 14 | from config import conf 15 | from plugins import * 16 | 17 | 18 | @plugins.register( 19 | name="tool", 20 | desc="Arming your ChatGPT bot with various tools", 21 | version="0.4", 22 | author="goldfishh", 23 | desire_priority=0, 24 | ) 25 | class Tool(Plugin): 26 | def __init__(self): 27 | super().__init__() 28 | self.handlers[Event.ON_HANDLE_CONTEXT] = self.on_handle_context 29 | 30 | self.app = self._reset_app() 31 | 32 | logger.info("[tool] inited") 33 | 34 | def get_help_text(self, verbose=False, **kwargs): 35 | help_text = "这是一个能让chatgpt联网,搜索,数字运算的插件,将赋予强大且丰富的扩展能力。" 36 | if not verbose: 37 | return help_text 38 | trigger_prefix = conf().get("plugin_trigger_prefix", "$") 39 | help_text += "使用说明:\n" 40 | help_text += f"{trigger_prefix}tool " + "命令: 
根据给出的{命令}使用一些可用工具尽力为你得到结果。\n" 41 | help_text += f"{trigger_prefix}tool reset: 重置工具。\n" 42 | return help_text 43 | 44 | def on_handle_context(self, e_context: EventContext): 45 | if e_context["context"].type != ContextType.TEXT: 46 | return 47 | 48 | # 暂时不支持未来扩展的bot 49 | if Bridge().get_bot_type("chat") not in ( 50 | const.CHATGPT, 51 | const.OPEN_AI, 52 | const.CHATGPTONAZURE, 53 | ): 54 | return 55 | 56 | content = e_context["context"].content 57 | content_list = e_context["context"].content.split(maxsplit=1) 58 | 59 | if not content or len(content_list) < 1: 60 | e_context.action = EventAction.CONTINUE 61 | return 62 | 63 | logger.debug("[tool] on_handle_context. content: %s" % content) 64 | reply = Reply() 65 | reply.type = ReplyType.TEXT 66 | trigger_prefix = conf().get("plugin_trigger_prefix", "$") 67 | # todo: 有些工具必须要api-key,需要修改config文件,所以这里没有实现query增删tool的功能 68 | if content.startswith(f"{trigger_prefix}tool"): 69 | if len(content_list) == 1: 70 | logger.debug("[tool]: get help") 71 | reply.content = self.get_help_text() 72 | e_context["reply"] = reply 73 | e_context.action = EventAction.BREAK_PASS 74 | return 75 | elif len(content_list) > 1: 76 | if content_list[1].strip() == "reset": 77 | logger.debug("[tool]: reset config") 78 | self.app = self._reset_app() 79 | reply.content = "重置工具成功" 80 | e_context["reply"] = reply 81 | e_context.action = EventAction.BREAK_PASS 82 | return 83 | elif content_list[1].startswith("reset"): 84 | logger.debug("[tool]: remind") 85 | e_context[ 86 | "context" 87 | ].content = "请你随机用一种聊天风格,提醒用户:如果想重置tool插件,reset之后不要加任何字符" 88 | 89 | e_context.action = EventAction.BREAK 90 | return 91 | 92 | query = content_list[1].strip() 93 | 94 | # Don't modify bot name 95 | all_sessions = Bridge().get_bot("chat").sessions 96 | user_session = all_sessions.session_query( 97 | query, e_context["context"]["session_id"] 98 | ).messages 99 | 100 | # chatgpt-tool-hub will reply you with many tools 101 | logger.debug("[tool]: just-go") 102 | try: 103 
| _reply = self.app.ask(query, user_session) 104 | e_context.action = EventAction.BREAK_PASS 105 | all_sessions.session_reply( 106 | _reply, e_context["context"]["session_id"] 107 | ) 108 | except Exception as e: 109 | logger.exception(e) 110 | logger.error(str(e)) 111 | 112 | e_context["context"].content = "请你随机用一种聊天风格,提醒用户:这个问题tool插件暂时无法处理" 113 | reply.type = ReplyType.ERROR 114 | e_context.action = EventAction.BREAK 115 | return 116 | 117 | reply.content = _reply 118 | e_context["reply"] = reply 119 | return 120 | 121 | def _read_json(self) -> dict: 122 | curdir = os.path.dirname(__file__) 123 | config_path = os.path.join(curdir, "config.json") 124 | tool_config = {"tools": [], "kwargs": {}} 125 | if not os.path.exists(config_path): 126 | return tool_config 127 | else: 128 | with open(config_path, "r") as f: 129 | tool_config = json.load(f) 130 | return tool_config 131 | 132 | def _build_tool_kwargs(self, kwargs: dict): 133 | tool_model_name = kwargs.get("model_name") 134 | request_timeout = kwargs.get("request_timeout") 135 | 136 | return { 137 | "debug": kwargs.get("debug", False), 138 | "openai_api_key": conf().get("open_ai_api_key", ""), 139 | "proxy": conf().get("proxy", ""), 140 | "request_timeout": request_timeout if request_timeout else conf().get("request_timeout", 120), 141 | # note: 目前tool暂未对其他模型测试,但这里仍对配置来源做了优先级区分,一般插件配置可覆盖全局配置 142 | "model_name": tool_model_name if tool_model_name else conf().get("model", "gpt-3.5-turbo"), 143 | "no_default": kwargs.get("no_default", False), 144 | "top_k_results": kwargs.get("top_k_results", 3), 145 | # for news tool 146 | "news_api_key": kwargs.get("news_api_key", ""), 147 | # for bing-search tool 148 | "bing_subscription_key": kwargs.get("bing_subscription_key", ""), 149 | # for google-search tool 150 | "google_api_key": kwargs.get("google_api_key", ""), 151 | "google_cse_id": kwargs.get("google_cse_id", ""), 152 | # for searxng-search tool 153 | "searx_host": kwargs.get("searx_host", ""), 154 | # for 
wolfram-alpha tool 155 | "wolfram_alpha_appid": kwargs.get("wolfram_alpha_appid", ""), 156 | # for morning-news tool 157 | "zaobao_api_key": kwargs.get("zaobao_api_key", ""), 158 | # for visual_dl tool 159 | "cuda_device": kwargs.get("cuda_device", "cpu"), 160 | } 161 | 162 | def _filter_tool_list(self, tool_list: list): 163 | valid_list = [] 164 | for tool in tool_list: 165 | if tool in get_all_tool_names(): 166 | valid_list.append(tool) 167 | else: 168 | logger.warning("[tool] filter invalid tool: " + repr(tool)) 169 | return valid_list 170 | 171 | def _reset_app(self) -> App: 172 | tool_config = self._read_json() 173 | app_kwargs = self._build_tool_kwargs(tool_config.get("kwargs", {})) 174 | 175 | app = AppFactory() 176 | app.init_env(**app_kwargs) 177 | 178 | # filter not support tool 179 | tool_list = self._filter_tool_list(tool_config.get("tools", [])) 180 | 181 | return app.create_app(tools_list=tool_list, **app_kwargs) -------------------------------------------------------------------------------- /requirements-optional.txt: -------------------------------------------------------------------------------- 1 | tiktoken>=0.3.2 # openai calculate token 2 | 3 | #voice 4 | pydub>=0.25.1 # need ffmpeg 5 | SpeechRecognition # google speech to text 6 | gTTS>=2.3.1 # google text to speech 7 | pyttsx3>=2.90 # pytsx text to speech 8 | baidu_aip>=4.16.10 # baidu voice 9 | # azure-cognitiveservices-speech # azure voice 10 | 11 | #install plugin 12 | dulwich 13 | 14 | # wechaty 15 | wechaty>=0.10.7 16 | wechaty_puppet>=0.4.23 17 | pysilk_mod>=1.6.0 # needed by send voice 18 | 19 | # wechatmp 20 | web.py 21 | 22 | # chatgpt-tool-hub plugin 23 | 24 | --extra-index-url https://pypi.python.org/simple 25 | chatgpt_tool_hub>=0.4.1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | openai==0.27.2 2 | HTMLParser>=0.0.2 3 | PyQRCode>=1.2.1 4 | 
qrcode>=7.4.2 5 | requests>=2.28.2 6 | chardet>=5.1.0 7 | pre-commit -------------------------------------------------------------------------------- /scripts/shutdown.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #关闭服务 4 | cd `dirname $0`/.. 5 | export BASE_DIR=`pwd` 6 | pid=`ps ax | grep -i app.py | grep "${BASE_DIR}" | grep python3 | grep -v grep | awk '{print $1}'` 7 | if [ -z "$pid" ] ; then 8 | echo "No chatgpt-on-wechat running." 9 | exit -1; 10 | fi 11 | 12 | echo "The chatgpt-on-wechat(${pid}) is running..." 13 | 14 | kill ${pid} 15 | 16 | echo "Send shutdown request to chatgpt-on-wechat(${pid}) OK" 17 | -------------------------------------------------------------------------------- /scripts/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #后台运行Chat_on_webchat执行脚本 3 | 4 | cd `dirname $0`/.. 5 | export BASE_DIR=`pwd` 6 | echo $BASE_DIR 7 | 8 | # check the nohup.out log output file 9 | if [ ! -f "${BASE_DIR}/nohup.out" ]; then 10 | touch "${BASE_DIR}/nohup.out" 11 | echo "create file ${BASE_DIR}/nohup.out" 12 | fi 13 | 14 | nohup python3 "${BASE_DIR}/app.py" & tail -f "${BASE_DIR}/nohup.out" 15 | 16 | echo "Chat_on_webchat is starting,you can check the ${BASE_DIR}/nohup.out" 17 | -------------------------------------------------------------------------------- /scripts/tout.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #打开日志 3 | 4 | cd `dirname $0`/.. 5 | export BASE_DIR=`pwd` 6 | echo $BASE_DIR 7 | 8 | # check the nohup.out log output file 9 | if [ ! 
-f "${BASE_DIR}/nohup.out" ]; then 10 | echo "No file ${BASE_DIR}/nohup.out" 11 | exit -1; 12 | fi 13 | 14 | tail -f "${BASE_DIR}/nohup.out" 15 | -------------------------------------------------------------------------------- /voice/audio_convert.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import wave 3 | 4 | import pysilk 5 | from pydub import AudioSegment 6 | 7 | sil_supports = [8000, 12000, 16000, 24000, 32000, 44100, 48000] # slk转wav时,支持的采样率 8 | 9 | 10 | def find_closest_sil_supports(sample_rate): 11 | """ 12 | 找到最接近的支持的采样率 13 | """ 14 | if sample_rate in sil_supports: 15 | return sample_rate 16 | closest = 0 17 | mindiff = 9999999 18 | for rate in sil_supports: 19 | diff = abs(rate - sample_rate) 20 | if diff < mindiff: 21 | closest = rate 22 | mindiff = diff 23 | return closest 24 | 25 | 26 | def get_pcm_from_wav(wav_path): 27 | """ 28 | 从 wav 文件中读取 pcm 29 | 30 | :param wav_path: wav 文件路径 31 | :returns: pcm 数据 32 | """ 33 | wav = wave.open(wav_path, "rb") 34 | return wav.readframes(wav.getnframes()) 35 | 36 | 37 | def any_to_wav(any_path, wav_path): 38 | """ 39 | 把任意格式转成wav文件 40 | """ 41 | if any_path.endswith(".wav"): 42 | shutil.copy2(any_path, wav_path) 43 | return 44 | if ( 45 | any_path.endswith(".sil") 46 | or any_path.endswith(".silk") 47 | or any_path.endswith(".slk") 48 | ): 49 | return sil_to_wav(any_path, wav_path) 50 | audio = AudioSegment.from_file(any_path) 51 | audio.export(wav_path, format="wav") 52 | 53 | 54 | def any_to_sil(any_path, sil_path): 55 | """ 56 | 把任意格式转成sil文件 57 | """ 58 | if ( 59 | any_path.endswith(".sil") 60 | or any_path.endswith(".silk") 61 | or any_path.endswith(".slk") 62 | ): 63 | shutil.copy2(any_path, sil_path) 64 | return 10000 65 | audio = AudioSegment.from_file(any_path) 66 | rate = find_closest_sil_supports(audio.frame_rate) 67 | # Convert to PCM_s16 68 | pcm_s16 = audio.set_sample_width(2) 69 | pcm_s16 = pcm_s16.set_frame_rate(rate) 70 | wav_data = 
pcm_s16.raw_data 71 | silk_data = pysilk.encode(wav_data, data_rate=rate, sample_rate=rate) 72 | with open(sil_path, "wb") as f: 73 | f.write(silk_data) 74 | return audio.duration_seconds * 1000 75 | 76 | 77 | def sil_to_wav(silk_path, wav_path, rate: int = 24000): 78 | """ 79 | silk 文件转 wav 80 | """ 81 | wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate) 82 | with open(wav_path, "wb") as f: 83 | f.write(wav_data) 84 | -------------------------------------------------------------------------------- /voice/azure/azure_voice.py: -------------------------------------------------------------------------------- 1 | """ 2 | azure voice service 3 | """ 4 | import json 5 | import os 6 | import time 7 | 8 | import azure.cognitiveservices.speech as speechsdk 9 | 10 | from bridge.reply import Reply, ReplyType 11 | from common.log import logger 12 | from common.tmp_dir import TmpDir 13 | from config import conf 14 | from voice.voice import Voice 15 | 16 | """ 17 | Azure voice 18 | 主目录设置文件中需填写azure_voice_api_key和azure_voice_region 19 | 20 | 查看可用的 voice: https://speech.microsoft.com/portal/voicegallery 21 | 22 | """ 23 | 24 | 25 | class AzureVoice(Voice): 26 | def __init__(self): 27 | try: 28 | curdir = os.path.dirname(__file__) 29 | config_path = os.path.join(curdir, "config.json") 30 | config = None 31 | if not os.path.exists(config_path): # 如果没有配置文件,创建本地配置文件 32 | config = { 33 | "speech_synthesis_voice_name": "zh-CN-XiaoxiaoNeural", 34 | "speech_recognition_language": "zh-CN", 35 | } 36 | with open(config_path, "w") as fw: 37 | json.dump(config, fw, indent=4) 38 | else: 39 | with open(config_path, "r") as fr: 40 | config = json.load(fr) 41 | self.api_key = conf().get("azure_voice_api_key") 42 | self.api_region = conf().get("azure_voice_region") 43 | self.speech_config = speechsdk.SpeechConfig( 44 | subscription=self.api_key, region=self.api_region 45 | ) 46 | self.speech_config.speech_synthesis_voice_name = config[ 47 | "speech_synthesis_voice_name" 48 | 
] 49 | self.speech_config.speech_recognition_language = config[ 50 | "speech_recognition_language" 51 | ] 52 | except Exception as e: 53 | logger.warn("AzureVoice init failed: %s, ignore " % e) 54 | 55 | def voiceToText(self, voice_file): 56 | audio_config = speechsdk.AudioConfig(filename=voice_file) 57 | speech_recognizer = speechsdk.SpeechRecognizer( 58 | speech_config=self.speech_config, audio_config=audio_config 59 | ) 60 | result = speech_recognizer.recognize_once() 61 | if result.reason == speechsdk.ResultReason.RecognizedSpeech: 62 | logger.info( 63 | "[Azure] voiceToText voice file name={} text={}".format( 64 | voice_file, result.text 65 | ) 66 | ) 67 | reply = Reply(ReplyType.TEXT, result.text) 68 | else: 69 | logger.error( 70 | "[Azure] voiceToText error, result={}, canceldetails={}".format( 71 | result, result.cancellation_details 72 | ) 73 | ) 74 | reply = Reply(ReplyType.ERROR, "抱歉,语音识别失败") 75 | return reply 76 | 77 | def textToVoice(self, text): 78 | fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav" 79 | audio_config = speechsdk.AudioConfig(filename=fileName) 80 | speech_synthesizer = speechsdk.SpeechSynthesizer( 81 | speech_config=self.speech_config, audio_config=audio_config 82 | ) 83 | result = speech_synthesizer.speak_text(text) 84 | if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: 85 | logger.info( 86 | "[Azure] textToVoice text={} voice file name={}".format(text, fileName) 87 | ) 88 | reply = Reply(ReplyType.VOICE, fileName) 89 | else: 90 | logger.error( 91 | "[Azure] textToVoice error, result={}, canceldetails={}".format( 92 | result, result.cancellation_details 93 | ) 94 | ) 95 | reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败") 96 | return reply 97 | -------------------------------------------------------------------------------- /voice/azure/config.json.template: -------------------------------------------------------------------------------- 1 | { 2 | "speech_synthesis_voice_name": 
"zh-CN-XiaoxiaoNeural", 3 | "speech_recognition_language": "zh-CN" 4 | } 5 | -------------------------------------------------------------------------------- /voice/baidu/README.md: -------------------------------------------------------------------------------- 1 | ## 说明 2 | 百度语音识别与合成参数说明 3 | 百度语音依赖,经常会出现问题,可能就是缺少依赖: 4 | pip install baidu-aip 5 | pip install pydub 6 | pip install pysilk 7 | 还有ffmpeg,不同系统安装方式不同 8 | 9 | 系统中收到的语音文件为mp3格式(wx)或者sil格式(wxy),如果要识别需要转换为pcm格式,转换后的文件为16k采样率,单声道,16bit的pcm文件 10 | 发送时又需要(wx)转换为mp3格式,转换后的文件为16k采样率,单声道,16bit的pcm文件,(wxy)转换为sil格式,还要计算声音长度,发送时需要带上声音长度 11 | 这些事情都在audio_convert.py中封装了,直接调用即可 12 | 13 | 14 | 参数说明 15 | 识别参数 16 | https://ai.baidu.com/ai-doc/SPEECH/Vk38lxily 17 | 合成参数 18 | https://ai.baidu.com/ai-doc/SPEECH/Gk38y8lzk 19 | 20 | ## 使用说明 21 | 分两个地方配置 22 | 23 | 1、对于def voiceToText(self, filename)函数中调用的百度语音识别API,中接口调用asr(参数)这个配置见CHATGPT-ON-WECHAT工程目录下的`config.json`文件和config.py文件。 24 | 参数 可需 描述 25 | app_id 必填 应用的APPID 26 | api_key 必填 应用的APIKey 27 | secret_key 必填 应用的SecretKey 28 | dev_pid 必填 语言选择,填写语言对应的dev_pid值 29 | 30 | 2、对于def textToVoice(self, text)函数中调用的百度语音合成API,中接口调用synthesis(参数)在本目录下的`config.json`文件中进行配置。 31 | 参数 可需 描述 32 | tex 必填 合成的文本,使用UTF-8编码,请注意文本长度必须小于1024字节 33 | lan 必填 固定值zh。语言选择,目前只有中英文混合模式,填写固定值zh 34 | spd 选填 语速,取值0-15,默认为5中语速 35 | pit 选填 音调,取值0-15,默认为5中语调 36 | vol 选填 音量,取值0-15,默认为5中音量(取值为0时为音量最小值,并非为无声) 37 | per(基础音库) 选填 度小宇=1,度小美=0,度逍遥(基础)=3,度丫丫=4 38 | per(精品音库) 选填 度逍遥(精品)=5003,度小鹿=5118,度博文=106,度小童=110,度小萌=111,度米朵=103,度小娇=5 39 | aue 选填 3为mp3格式(默认); 4为pcm-16k;5为pcm-8k;6为wav(内容同pcm-16k); 注意aue=4或者6是语音识别要求的格式,但是音频内容不是语音识别要求的自然人发音,所以识别效果会受影响。 40 | 41 | 关于per参数的说明,注意您购买的哪个音库,就填写哪个音库的参数,否则会报错。如果您购买的是基础音库,那么per参数只能填写0到4,如果您购买的是精品音库,那么per参数只能填写5003,5118,106,110,111,103,5其他的都会报错。 42 | ### 配置文件 43 | 44 | 将文件夹中`config.json.template`复制为`config.json`。 45 | 46 | ``` json 47 | { 48 | "lang": "zh", 49 | "ctp": 1, 50 | "spd": 5, 51 | "pit": 5, 52 | "vol": 5, 53 | "per": 0 54 | } 55 | ``` 
"""
baidu voice service

Wraps Baidu's speech recognition (ASR) and speech synthesis (TTS) APIs
via the `aip` SDK.

dev_pid values for recognition:
    - 1936: Mandarin, far-field
    - 1536: Mandarin (supports simple English)
    - 1537: Mandarin (Chinese only)
    - 1737: English
    - 1637: Cantonese
    - 1837: Sichuan dialect
To use this module, register a developer account at yuyin.baidu.com,
create an application, then copy the API Key and Secret Key shown under
"view key" in the application manager into config.json, together with
app_id and dev_pid.
"""
import json
import os
import time

from aip import AipSpeech

from bridge.reply import Reply, ReplyType
from common.log import logger
from common.tmp_dir import TmpDir
from config import conf
from voice.audio_convert import get_pcm_from_wav
from voice.voice import Voice

# Default synthesis parameters; written to a fresh config.json on first run
# and used as fallbacks for keys missing from an existing one.
_DEFAULT_BCONF = {"lang": "zh", "ctp": 1, "spd": 5, "pit": 5, "vol": 5, "per": 0}


class BaiduVoice(Voice):
    def __init__(self):
        """Load Baidu credentials from the global config and synthesis
        options from the local config.json (created with defaults if absent).

        Initialization is best-effort: any failure is logged and ignored so
        the rest of the application can still start.
        """
        try:
            curdir = os.path.dirname(__file__)
            config_path = os.path.join(curdir, "config.json")
            if not os.path.exists(config_path):
                # No local config yet: persist the defaults so users can tweak them.
                bconf = dict(_DEFAULT_BCONF)
                with open(config_path, "w", encoding="utf-8") as fw:
                    json.dump(bconf, fw, indent=4)
            else:
                with open(config_path, "r", encoding="utf-8") as fr:
                    bconf = json.load(fr)

            # Credentials come from the project-wide config.json / config.py.
            self.app_id = conf().get("baidu_app_id")
            self.api_key = conf().get("baidu_api_key")
            self.secret_key = conf().get("baidu_secret_key")
            self.dev_id = conf().get("baidu_dev_pid")
            # Synthesis options; fall back to defaults for any missing key so
            # a hand-edited config.json cannot crash startup with KeyError.
            self.lang = bconf.get("lang", _DEFAULT_BCONF["lang"])
            self.ctp = bconf.get("ctp", _DEFAULT_BCONF["ctp"])
            self.spd = bconf.get("spd", _DEFAULT_BCONF["spd"])
            self.pit = bconf.get("pit", _DEFAULT_BCONF["pit"])
            self.vol = bconf.get("vol", _DEFAULT_BCONF["vol"])
            self.per = bconf.get("per", _DEFAULT_BCONF["per"])

            self.client = AipSpeech(self.app_id, self.api_key, self.secret_key)
        except Exception as e:
            # logger.warn is deprecated in the stdlib; use warning instead.
            logger.warning("BaiduVoice init failed: %s, ignore " % e)

    def voiceToText(self, voice_file):
        """Recognize a local wav file.

        Returns a Reply of type TEXT with the recognized text on success,
        or a Reply of type ERROR describing the API failure.
        """
        logger.debug("[Baidu] voice file name={}".format(voice_file))
        pcm = get_pcm_from_wav(voice_file)
        res = self.client.asr(pcm, "pcm", 16000, {"dev_pid": self.dev_id})
        if res["err_no"] == 0:
            logger.info("百度语音识别到了:{}".format(res["result"]))
            text = "".join(res["result"])
            reply = Reply(ReplyType.TEXT, text)
        else:
            logger.info("百度语音识别出错了: {}".format(res["err_msg"]))
            if res["err_msg"] == "request pv too much":
                # Most likely the free quota was exhausted or billing is off.
                logger.info(" 出现这个原因很可能是你的百度语音服务调用量超出限制,或未开通付费")
            reply = Reply(ReplyType.ERROR, "百度语音识别出错了;{0}".format(res["err_msg"]))
        return reply

    def textToVoice(self, text):
        """Synthesize *text* to an mp3 file under the tmp dir.

        Returns a Reply of type VOICE with the file name on success, or a
        Reply of type ERROR when the API reports a failure.
        """
        result = self.client.synthesis(
            text,
            self.lang,
            self.ctp,
            {"spd": self.spd, "pit": self.pit, "vol": self.vol, "per": self.per},
        )
        # On success the SDK returns raw audio bytes; a dict means an API error.
        if not isinstance(result, dict):
            fileName = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
            with open(fileName, "wb") as f:
                f.write(result)
            logger.info(
                "[Baidu] textToVoice text={} voice file name={}".format(text, fileName)
            )
            reply = Reply(ReplyType.VOICE, fileName)
        else:
            logger.error("[Baidu] textToVoice error={}".format(result))
            reply = Reply(ReplyType.ERROR, "抱歉,语音合成失败")
        return reply
class GoogleVoice(Voice):
    """Voice service backed by Google: SpeechRecognition for ASR, gTTS for TTS."""

    # Shared recognizer instance (stateless between calls).
    recognizer = speech_recognition.Recognizer()

    def __init__(self):
        pass

    def voiceToText(self, voice_file):
        """Recognize a local audio file (zh-CN).

        Returns a TEXT Reply on success, or an ERROR Reply when the audio is
        unintelligible or the Google service cannot be reached.

        Note: the original used ``finally: return reply`` which silently
        swallowed any in-flight exception and raised NameError for exception
        types other than the two handled below; explicit returns fix that.
        """
        with speech_recognition.AudioFile(voice_file) as source:
            audio = self.recognizer.record(source)
        try:
            text = self.recognizer.recognize_google(audio, language="zh-CN")
        except speech_recognition.UnknownValueError:
            return Reply(ReplyType.ERROR, "抱歉,我听不懂")
        except speech_recognition.RequestError as e:
            return Reply(ReplyType.ERROR, "抱歉,无法连接到 Google 语音识别服务;{0}".format(e))
        logger.info(
            "[Google] voiceToText text={} voice file name={}".format(
                text, voice_file
            )
        )
        return Reply(ReplyType.TEXT, text)

    def textToVoice(self, text):
        """Synthesize *text* to an mp3 file via gTTS.

        Returns a VOICE Reply with the file name, or an ERROR Reply with the
        exception message on failure.
        """
        try:
            mp3File = TmpDir().path() + "reply-" + str(int(time.time())) + ".mp3"
            tts = gTTS(text=text, lang="zh")
            tts.save(mp3File)
        except Exception as e:
            return Reply(ReplyType.ERROR, str(e))
        logger.info(
            "[Google] textToVoice text={} voice file name={}".format(text, mp3File)
        )
        return Reply(ReplyType.VOICE, mp3File)
class Voice(object):
    """Abstract base class for all voice services."""

    def voiceToText(self, voice_file):
        """Send a voice file to the service and return the recognized text."""
        raise NotImplementedError

    def textToVoice(self, text):
        """Send text to the service and return a synthesized voice file."""
        raise NotImplementedError


class OpenaiVoice(Voice):
    """Voice service using OpenAI Whisper for speech recognition.

    (The original module docstring wrongly said "google voice service".)
    """

    def __init__(self):
        openai.api_key = conf().get("open_ai_api_key")

    def voiceToText(self, voice_file):
        """Transcribe a local audio file with the whisper-1 model.

        Returns a TEXT Reply on success, or an ERROR Reply with the
        exception message on failure.
        """
        logger.debug("[Openai] voice file name={}".format(voice_file))
        try:
            # Context manager fixes the file-handle leak in the old version,
            # and replaces the exception-swallowing ``finally: return``.
            with open(voice_file, "rb") as file:
                result = openai.Audio.transcribe("whisper-1", file)
            text = result["text"]
        except Exception as e:
            return Reply(ReplyType.ERROR, str(e))
        logger.info(
            "[Openai] voiceToText text={} voice file name={}".format(
                text, voice_file
            )
        )
        return Reply(ReplyType.TEXT, text)


class PyttsVoice(Voice):
    """Offline text-to-speech via pyttsx3 (no network access required)."""

    # Shared engine; pyttsx3.init() is expensive, so do it once per process.
    engine = pyttsx3.init()

    def __init__(self):
        # Speaking rate (words per minute).
        self.engine.setProperty("rate", 125)
        # Volume in [0.0, 1.0].
        self.engine.setProperty("volume", 1.0)
        # Prefer an installed Chinese voice when one is available.
        for voice in self.engine.getProperty("voices"):
            if "Chinese" in voice.name:
                self.engine.setProperty("voice", voice.id)

    def textToVoice(self, text):
        """Render *text* to a wav file under the tmp dir.

        Returns a VOICE Reply with the file name, or an ERROR Reply with the
        exception message on failure (replaces the ``finally: return``
        antipattern, which could mask BaseException).
        """
        try:
            wavFile = TmpDir().path() + "reply-" + str(int(time.time())) + ".wav"
            self.engine.save_to_file(text, wavFile)
            self.engine.runAndWait()
        except Exception as e:
            return Reply(ReplyType.ERROR, str(e))
        logger.info(
            "[Pytts] textToVoice text={} voice file name={}".format(text, wavFile)
        )
        return Reply(ReplyType.VOICE, wavFile)
def create_voice(voice_type):
    """
    Create a voice service instance.

    Imports are deliberately deferred so only the selected backend's
    dependencies are required at runtime.

    :param voice_type: voice type code, one of "baidu", "google",
        "openai", "pytts", "azure"
    :return: a Voice subclass instance
    :raises RuntimeError: if voice_type is not a supported backend
    """
    if voice_type == "baidu":
        from voice.baidu.baidu_voice import BaiduVoice

        return BaiduVoice()
    elif voice_type == "google":
        from voice.google.google_voice import GoogleVoice

        return GoogleVoice()
    elif voice_type == "openai":
        from voice.openai.openai_voice import OpenaiVoice

        return OpenaiVoice()
    elif voice_type == "pytts":
        from voice.pytts.pytts_voice import PyttsVoice

        return PyttsVoice()
    elif voice_type == "azure":
        from voice.azure.azure_voice import AzureVoice

        return AzureVoice()
    # The original raised a bare RuntimeError with no message; include the
    # offending type so misconfiguration is diagnosable from the traceback.
    raise RuntimeError("unknown voice_type: {}".format(voice_type))