├── LICENSE
├── README.md
├── config.json
├── config.json.example
├── extensions
│   ├── __init__.py
│   ├── mai_setupButler
│   │   ├── __init__.py
│   │   ├── requirements.txt
│   │   ├── setupButler.py
│   │   ├── static
│   │   │   ├── info.json
│   │   │   └── logo.png
│   │   └── templates
│   │       └── mai_setupButler
│   │           └── index.html
│   └── mai_wechat
│       ├── __init__.py
│       ├── requirements.txt
│       ├── server.py
│       ├── static
│       │   ├── info.json
│       │   └── logo.png
│       ├── templates
│       │   └── mai_wechat
│       │       └── index.html
│       └── wxauto
│           ├── __init__.py
│           ├── a.dll
│           ├── color.py
│           ├── elements.py
│           ├── errors.py
│           ├── languages.py
│           ├── uiautomation.py
│           ├── utils.py
│           └── wxauto.py
├── main.py
├── meowServer.py
├── models
│   └── EEADME.md
├── requirement.txt
├── rwkv
│   ├── 20B_tokenizer.json
│   ├── __pycache__
│   │   ├── model.cpython-39.pyc
│   │   ├── rwkv_tokenizer.cpython-39.pyc
│   │   └── utils.cpython-39.pyc
│   ├── cpp
│   │   ├── librwkv.dylib
│   │   ├── librwkv.so
│   │   ├── model.py
│   │   ├── rwkv.dll
│   │   ├── rwkv_cpp_model.py
│   │   └── rwkv_cpp_shared_library.py
│   ├── cuda
│   │   ├── gemm_fp16_cublas.cpp
│   │   ├── operators.cu
│   │   ├── rwkv5.cu
│   │   ├── rwkv5_op.cpp
│   │   ├── rwkv6.cu
│   │   ├── rwkv6_op.cpp
│   │   ├── rwkv7.cu
│   │   ├── rwkv7_op.cpp
│   │   └── wrapper.cpp
│   ├── model.py
│   ├── rwkv5.pyd
│   ├── rwkv6.pyd
│   ├── rwkv_tokenizer.py
│   ├── rwkv_vocab_v20230424.txt
│   ├── rwkv_vocab_v20230424_special_token.txt
│   ├── tokenizer-midi.json
│   ├── tokenizer-midipiano.json
│   ├── utils.py
│   ├── webgpu
│   │   ├── model.py
│   │   └── web_rwkv_py.cp310-win_amd64.pyd
│   ├── wkv7s.pyd
│   └── wkv_cuda.pyd
├── static
│   ├── config
│   │   ├── default.json
│   │   └── noke.json
│   ├── css
│   │   ├── all.min.css
│   │   ├── font-awesome.css
│   │   ├── font-awesome.min.css
│   │   ├── noticejs.css
│   │   └── styles.css
│   ├── fonts
│   │   ├── FontAwesome.otf
│   │   ├── fontawesome-webfont.eot
│   │   ├── fontawesome-webfont.svg
│   │   ├── fontawesome-webfont.ttf
│   │   ├── fontawesome-webfont.woff
│   │   └── fontawesome-webfont.woff2
│   ├── img
│   │   ├── PC_Coordinate.png
│   │   ├── PC_Coordinate.wb
│   │   ├── ai.png
│   │   ├── ai2.png
│   │   ├── ai3.png
│   │   ├── ai4.png
│   │   ├── ai42.png
│   │   ├── bilibini.png
│   │   ├── coordinate.png
│   │   └── nekomusume.png
│   ├── js
│   │   ├── axios.min.js
│   │   ├── main.js
│   │   ├── notice.js
│   │   ├── socket.io.min.js
│   │   └── vue.global.js
│   └── webfonts
│       ├── fa-brands-400.eot
│       ├── fa-brands-400.svg
│       ├── fa-brands-400.ttf
│       ├── fa-brands-400.woff
│       ├── fa-brands-400.woff2
│       ├── fa-regular-400.eot
│       ├── fa-regular-400.svg
│       ├── fa-regular-400.ttf
│       ├── fa-regular-400.woff
│       ├── fa-regular-400.woff2
│       ├── fa-solid-900.eot
│       ├── fa-solid-900.svg
│       ├── fa-solid-900.ttf
│       ├── fa-solid-900.woff
│       └── fa-solid-900.woff2
├── templates
│   ├── extension.html
│   └── index.html
└── test.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 bilibini

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Meow-AI
"Meow-AI" is a lightweight chat AI based on RWKV that runs entirely on your own machine.

## Environment setup
Recommended Python version: 3.9.18
Required modules: numpy, tokenizers, prompt_toolkit, flask, flask-socketio, torch (subprocess is part of the Python standard library and does not need to be installed)
`pip3 install numpy tokenizers prompt_toolkit flask flask-socketio torch`

## System requirements
Minimum: 4 GB of RAM (a CPU with 4 GB of RAM can run the small 430M model)
No upper limit (up to the 14B model can be run, with noticeably better chat quality)

## Setup
Download a model that suits your hardware from the model site ([https://huggingface.co/BlinkDL](https://huggingface.co/BlinkDL)) ([mirror](https://hf-mirror.com/BlinkDL)).
Put all downloaded files in the models folder, then change "modelFile" in config.json from 'RWKV-x070-World-0.1B-v2.8-20241210-ctx4096' to the name of the model you downloaded (without the `.pth` extension).
Run main.py and open http://127.0.0.1:5000 in a browser to start chatting.

## Changelog
- 2024-01-05: Conversations can be stopped manually
- 2024-01-19: Conversations can be edited for simple manual fine-tuning
- 2024-01-23: Configuration import/export and custom AI personas
- 2024-03-01: Switched entirely to the RWKV architecture for smaller models and lower memory and CPU usage
- 2024-04-26: Overall architecture improved
- 2024-07-08: Extension support, including automatic WeChat chat
- 2024-07-21: One-click run package released
- 2025-02-20: RWKV7 model support

## Demos
### 1. Continuous conversation
Demo environment: GPU + 8 GB, using the 'RWKV-4-World-CHNtuned-0.1B-v1-20230617-ctx4096' model
![image](https://img.z4a.net/images/2024/06/12/9067bcfab67bb37a3c012ccca84b39be.png)
### 2. Manually adjusting conversations
~Note: manually adding and editing more of your own preset dialogue helps a lot in getting the chat behavior you want later on~
![image](https://img.z4a.net/images/2024/06/12/c21bbf309e8836a2f420a2e9aa430805.png)
### 3. Importing and exporting configuration
Conversations and configuration can be imported and exported
![image](https://img.z4a.net/images/2024/06/12/827ad26b2bfd9c31897b75081133d307.png)
### 4. Shaping the AI's persona
You can customize MeowAI's persona and personality so that it better matches your preferences
![image](https://img.z4a.net/images/2024/06/12/Meow.gif)
### 5. Extensions
Custom extensions can be added; "WeChat auto chat" is available so far.
After you click Start, WeChat opens automatically and incoming chats are answered automatically
![image](https://t3.picb.cc/2024/07/08/iQfcbG.gif)
### 6. One-click run package
No manual environment setup is needed: the package bundles everything required, so just double-click to run
(The package is too large for the repository, so download it from [releases](https://github.com/bilibini/Meow-AI/releases))
![image](https://img.z4a.net/images/2024/07/21/_.png)


## Roadmap
1. More extensions
2. To be continued...


## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=bilibini/Meow-AI&type=Date)](https://star-history.com/#bilibini/Meow-AI&Date)


--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
{
    "host": "127.0.0.1",
    "port": 5000,
    "modelsFolder": "models",
    "dataFolder": "data",
    "outputFolder": "output",
    "modelFile": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096",
    "strategy": "cuda fp32",
    "configFile": "default.json",
    "autoOpen": true
}
--------------------------------------------------------------------------------
/config.json.example:
--------------------------------------------------------------------------------
{
    "host": "127.0.0.1",
    "port": 5000,
    "modelsFolder": "models",
    "dataFolder": "data",
    "configFolder": "config",
    "outputFolder": "output",
    "modelFile": "RWKV-x060-World-1B6-v2-20240208-ctx4096",
    "configFile": "default.json",
    "autoOpen": true
}
--------------------------------------------------------------------------------
/extensions/__init__.py:
--------------------------------------------------------------------------------
from flask import Flask,json
from flask_socketio import SocketIO
from meowServer import MeowAI as MA
from typing import List,Dict,Mapping,Union,Callable,Any
from pathlib import Path
from importlib.metadata import version
import packaging.version as pv
import importlib
import subprocess
import sys

def get_installed_version(package: str) -> pv.Version:
    # Installed version of the package, or version "0" if it is not installed
    try:
        return pv.parse(version(package))
    except Exception:
        return pv.parse("0")

def install_requirements(requirements:Path):
    # Supports "pkg==ver", "pkg>=ver" and bare "pkg" lines; comments and blank lines are skipped
    with open(requirements) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if '==' in line:
                package, required = line.split('==')
                if get_installed_version(package) != pv.parse(required):
                    print(f'Installing {package}=={required}')
                    subprocess.check_call([sys.executable, '-m', 'pip', 'install', line])
            elif '>=' in line:
                package, required = line.split('>=')
                if get_installed_version(package) < pv.parse(required):
                    print(f'Installing {package}>={required}')
                    subprocess.check_call([sys.executable, '-m', 'pip', 'install', line])
            elif get_installed_version(line)==pv.parse("0"):
                print(f'Installing {line}')
                subprocess.check_call([sys.executable, '-m', 'pip', 'install', line])
            else:
                print(f'{line} is already installed')

def load_extensions(meowAPP:Flask,meowSIO:SocketIO,meowAI:MA)->List[Path]:
    # Every sub-folder of extensions/ that contains an __init__.py is loaded as an extension
    extensions_path = Path(__file__).parent
    extension_dirs = [extension_dir for extension_dir in extensions_path.iterdir() if (extension_dir / '__init__.py').exists()]
    print(extension_dirs)
    extension_infoList = []
    for extension_dir in extension_dirs:
        requirementsPath=extension_dir.joinpath('requirements.txt')
        try:
            # requirements.txt is optional; without this check a missing file would skip the extension
            if requirementsPath.exists():
                install_requirements(requirementsPath)
        except Exception as e:
            print(f'Error installing requirements for {extension_dir.name}: {e}')
            continue

        module_name = f'extensions.{extension_dir.name}'
        module = importlib.import_module(module_name)
        # An extension registers its routes through init_app(app, socketio, meowAI)
        if hasattr(module, 'init_app'):
            module.init_app(meowAPP,meowSIO,meowAI)
        # static/info.json describes the extension on the extension list page
        extension_info=extension_dir.joinpath('static/info.json')
        if extension_info.exists():
            with open(extension_info,'r',encoding='UTF-8') as f:extension_infoList.append(json.load(f))


    @meowAPP.route('/extension/infoList.json')
    def get_extension_infoList():
        print(extension_infoList)
        return json.dumps(extension_infoList)

    return extension_dirs

--------------------------------------------------------------------------------
/extensions/mai_setupButler/__init__.py:
--------------------------------------------------------------------------------
from flask import Flask
from flask_socketio import SocketIO
from meowServer import MeowAI as MA
from .setupButler import SetupButler

def init_app(meowAPP:Flask,meowSIO:SocketIO,meowAI:MA):
    SetupButler(meowAPP,meowSIO,meowAI).run()
--------------------------------------------------------------------------------
/extensions/mai_setupButler/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/extensions/mai_setupButler/requirements.txt
--------------------------------------------------------------------------------
/extensions/mai_setupButler/setupButler.py:
--------------------------------------------------------------------------------
from flask import Flask,Blueprint,render_template,request
from typing import List,Dict,Mapping,Union,Callable,Any
from flask_socketio import SocketIO
from meowServer import MeowAI as MA
from rwkv.model import RWKV
from rwkv.utils import PIPELINE
from pathlib import Path
import json,gc,torch,time


rootPath = Path(__file__).parent.parent.parent
configPath=rootPath.joinpath('config.json')

class SetupButler():
    def __init__(self,meowAPP:Flask,meowSIO:SocketIO,meowAI:MA):
        self.meowAI=meowAI
        self.meowAPP=meowAPP
        self.meowSIO=meowSIO
        self.app = Blueprint('setupButler', __name__,static_folder='static',template_folder='templates',url_prefix='/extension/setupButler')
        self.config={}
        with open(configPath,'r') as f:
            self.config = json.load(f)

    def run(self):
        modelsFolder=rootPath.joinpath(self.config['modelsFolder'])
        print(modelsFolder)
        @self.app.route('/config.json')
        def get_config():
            print(modelsFolder)
            config={
                'host':self.config['host'],
                'port':self.config['port'],
                'model':str(modelsFolder.joinpath(self.config['modelFile'])),
                'modelList':self.getModelsList(modelsFolder),
                'strategy':self.config['strategy'],
                'autoOpen':self.config['autoOpen']
            }
            return json.dumps(config)

        @self.app.route('/setup',methods=['POST'])
        def setup():
            data = request.get_json()
            self.config['host']=data['host']
            self.config['port']=data['port']
            self.config['autoOpen']=data['autoOpen']
            model=Path(data['model']).name
            self.config['modelFile']=model
            self.config['strategy']=data['strategy']
            with open(configPath,'w') as f:f.write(json.dumps(self.config,indent=4))
            # The saved settings only take effect after the service is restarted
            return json.dumps({'code': 0, 'msg': '设置成功,请重新启动服务'})
            # NOTE: everything below is unreachable because of the return above. It is a
            # disabled hot-reload path; its condition would also always be False, because
            # modelFile and strategy were already overwritten a few lines earlier.
            if model!=self.config['modelFile'] or data['strategy']!=self.config['strategy']:
                del self.meowAI.model.model
                del self.meowAI.model
                gc.collect()
                if 'cuda' in self.config['strategy']:
                    torch.cuda.empty_cache()
                self.config['modelFile']=model
                self.config['strategy']=data['strategy']

                model = RWKV(model=str(data['model']), strategy=data['strategy'])
                pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
                self.meowAI.model=pipeline
                with open(configPath,'w') as f:f.write(json.dumps(self.config,indent=4))
                return json.dumps({'code': 0, 'msg': '模型重新载入成功'})

        @self.app.route('/')
        def index():
            return render_template('mai_setupButler/index.html')

        self.meowAPP.register_blueprint(self.app)

    def getModelsList(self,modelsFolder:Path)->List[Dict[str,str]]:
        # Every *.pth file under modelsFolder (recursively), reported without the .pth suffix
        modelInfoList=[]
        for file in modelsFolder.rglob('*.pth'):
            modelsinfo={
                "name":file.name.replace('.pth',''),
                "path":str(file.absolute()).replace('.pth','')
            }
            modelInfoList.append(modelsinfo)
        return modelInfoList


--------------------------------------------------------------------------------
/extensions/mai_setupButler/static/info.json:
--------------------------------------------------------------------------------
{
    "name": "setup",
    "version": "1.0.0",
    "description": "模型选择设置",
    "url": "/extension/setupButler",
    "logo": "logo.png"
}
--------------------------------------------------------------------------------
/extensions/mai_setupButler/static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/extensions/mai_setupButler/static/logo.png
--------------------------------------------------------------------------------
/extensions/mai_setupButler/templates/mai_setupButler/index.html:
--------------------------------------------------------------------------------
[HTML template; the markup is not recoverable from this dump. Page title: "设置_MeowAI" (Settings_MeowAI). Heading: "设置" (Settings).]
--------------------------------------------------------------------------------
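`getModelsList` above scans `modelsFolder` recursively for `*.pth` files and reports each name and path without the `.pth` suffix, which is the form later stored as `modelFile` in `config.json`. A minimal standalone sketch of the same scan, using only the standard library; the function name `list_models` is hypothetical, and `Path.stem` / `with_suffix('')` stand in for the original's `str.replace('.pth', '')`, which would also strip a `.pth` appearing elsewhere in the path:

```python
import tempfile
from pathlib import Path
from typing import Dict, List


def list_models(models_folder: Path) -> List[Dict[str, str]]:
    """Return name/path pairs for every *.pth file under models_folder.

    Hypothetical standalone version of SetupButler.getModelsList; the .pth
    suffix is dropped because config.json stores modelFile without it.
    """
    models = []
    # rglob also searches sub-folders, as in the original
    for file in sorted(models_folder.rglob('*.pth')):
        models.append({
            "name": file.stem,                             # file name without .pth
            "path": str(file.absolute().with_suffix('')),  # absolute path without .pth
        })
    return models


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        (root / 'RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth').touch()
        (root / 'sub').mkdir()
        (root / 'sub' / 'other.pth').touch()
        (root / 'notes.txt').touch()  # ignored: not a .pth file
        for model in list_models(root):
            print(model['name'], '->', model['path'])
```

`Path.stem` removes only the last suffix, so dotted model names such as `...-0.1B-v2.8-...` keep their inner dots.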
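`install_requirements` in `extensions/__init__.py` reads each extension's `requirements.txt` line by line and handles three forms: `pkg==ver` (install unless exactly that version is present), `pkg>=ver` (install if the installed version is older) and a bare `pkg` (install only if it is missing). The decision step can be sketched without pip or `packaging`, assuming plain dotted numeric versions; `needs_install` and `parse_version` are hypothetical helpers, whereas the real code compares with `packaging.version` and treats a missing package as version `0`:

```python
from typing import Optional, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Parse a dotted numeric version such as '1.26.4' into a tuple.

    Simplified stand-in for packaging.version.parse: no pre-release or
    local versions. Trailing zeros are dropped so '2.0' equals '2.0.0'.
    """
    parts = [int(p) for p in text.strip().split('.')]
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def needs_install(line: str, installed: Optional[str]) -> bool:
    """Return True if a requirements.txt line should trigger `pip install`.

    installed is the installed version string, or None if the package is
    missing (the original code maps "missing" to version 0).
    """
    line = line.strip()
    if not line or line.startswith('#'):
        return False  # blank lines and comments are skipped
    have = parse_version(installed) if installed else (0,)
    if '==' in line:
        _, required = line.split('==')
        return have != parse_version(required)
    if '>=' in line:
        _, required = line.split('>=')
        return have < parse_version(required)
    return installed is None  # bare name: install only if missing


if __name__ == '__main__':
    print(needs_install('flask>=2.0', '1.1.4'))
    print(needs_install('pyperclip', '1.8.2'))
```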
/extensions/mai_wechat/__init__.py:
--------------------------------------------------------------------------------
from flask import Flask
from flask_socketio import SocketIO
from meowServer import MeowAI as MA
from .server import main

def init_app(meowAPP:Flask,meowSIO:SocketIO,meowAI:MA):
    main(meowAPP,meowSIO,meowAI)
--------------------------------------------------------------------------------
/extensions/mai_wechat/requirements.txt:
--------------------------------------------------------------------------------
uiautomation
Pillow
pywin32
psutil
pyperclip
--------------------------------------------------------------------------------
/extensions/mai_wechat/server.py:
--------------------------------------------------------------------------------
from .wxauto import WeChat,elements
from flask import Flask,Blueprint,render_template,request
from flask_socketio import SocketIO
from threading import Thread, Event
import json,time
from meowServer import MeowAI as MA

class WeChatServer():
    def __init__(self,meowAPP:Flask,meowSIO:SocketIO, meowAI: MA):
        self.meowAI=meowAI
        self.meowAPP=meowAPP
        self.meowSIO=meowSIO
        self.stop_event = Event()
        self.app = Blueprint('wechat', __name__,static_folder='static',template_folder='templates',url_prefix='/extension/wechat')
        self.wx=None
        self.prev_state=None
        self.chats=[]
        @self.app.route('/')
        def index():
            return render_template('mai_wechat/index.html')
        @self.app.route('/start',methods=['POST'])
        def start():
            # Attach to the logged-in WeChat window and listen in a background thread
            try:
                self.wx = WeChat()
                self.stop_event.clear()
                data = request.get_json()
                thread = Thread(target=self.run, args=(data['name'],data['chatHistory']))
                thread.start()
                return json.dumps({'code': 0, 'msg': f'启动成功,获取到已登录窗口:{self.wx.nickname}'})
            except Exception:
                return json.dumps({'code': 1, 'msg': '启动失败,请检查是否已经打开并登录微信'})
        @self.app.route('/stop',methods=['POST','GET'])
        def stop():
            self.stop_event.set()
            self.prev_state=None
            self.wx=None
            return json.dumps({'code': 0, 'msg': '微信自动对话停止成功'})
        @self.app.route('/status',methods=['POST','GET'])
        def status():
            if self.wx:
                return json.dumps({'code': 1, 'msg': '已开始'})
            else:
                return json.dumps({'code': 0, 'msg': '未开始'})

        self.meowAPP.register_blueprint(self.app)
        self.index=index


    def reply(self,msg:str,chat:elements.ChatWnd)->str:
        '''
        Generate a reply to msg. self.chats has the form:
        self.chats=[
            {
                "role": "User",
                "content": "What is the meaning of life?"
            },
            {
                "role": "Assistant",
                "content": "The meaning of life is to live a happy and fulfilling life."
            },
            {
                "role": "User",
                "content": "How do cats call?"
            },
        ]
        '''
        self.chats.append({"role": "User", "content": self.meowAI.purr(msg)})
        if self.prev_state:
            # The saved state already covers earlier turns: send only the last reply and the new message
            messages=self.chats[len(self.chats)-2:]
        else:
            messages=self.chats
        print(messages)
        reply,self.prev_state=self.meowAI.chat(messages,self.prev_state,lambda x:print(x[0],end=''))
        # Collapse blank lines and drop Chinese full stops (。) from the reply
        reply=reply.replace('\n\n','\n').replace('。','')
        self.chats.append({"role": "Assistant", "content": reply})
        return reply

    def run(self,name:str,chat_history:bool=False):
        try:
            wx=self.wx
            wx.ChatWith(who=name)
            if chat_history:
                # Seed self.chats with the existing conversation, skipping system messages
                wx.LoadMoreMessage()
                for chat in wx.GetAllMessage():
                    if chat[0]=='SYS':pass
                    elif chat[0]=='Self':
                        self.chats.append({"role": "Assistant", "content": chat[1]})
                    else:
                        self.chats.append({"role": "User", "content": chat[1]})
                print(self.chats)
            wx.AddListenChat(who=name)
            # Poll for new messages until /stop sets stop_event
            while not self.stop_event.is_set():
                chats = wx.GetListenMessage()
                for chat in filter(lambda x:x.who==name,chats):
                    msgs=chats.get(chat)
                    for msg in filter(lambda x:x.type=='friend',msgs):
                        reply=self.reply(msg.content,chat)
                        chat.SendMsg(reply)
                time.sleep(0.1)
        except Exception as e:
            self.meowSIO.emit('emit',{'code':1,'msg':'微信监听失败:'+str(e)})
            self.stop_event.set()
            self.prev_state=None
            self.meowSIO.emit('wechat_stop','微信监听失败:'+str(e))
            print('错误:',e)


def main(meowAPP:Flask,meowSIO:SocketIO,meowAI:MA):
    WeChatServer(meowAPP,meowSIO,meowAI)
--------------------------------------------------------------------------------
/extensions/mai_wechat/static/info.json:
--------------------------------------------------------------------------------
{
    "name": "wechat",
    "version": "1.0.0",
    "description": "微信自动回复",
    "url": "/extension/wechat",
    "logo": "logo.png"
}
--------------------------------------------------------------------------------
/extensions/mai_wechat/static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/extensions/mai_wechat/static/logo.png
--------------------------------------------------------------------------------
/extensions/mai_wechat/templates/mai_wechat/index.html:
--------------------------------------------------------------------------------
[HTML template; the markup is not recoverable from this dump. Page title: "微信自动回复_MeowAI" (WeChat Auto Reply_MeowAI). Heading: "微信自动回复" (WeChat Auto Reply). Notice: "温馨提示:在“启动”前,请先登录电脑微信" (Tip: log in to WeChat on your PC before clicking "Start").]
--------------------------------------------------------------------------------
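`WeChatServer.reply` keeps the whole conversation in `self.chats`, but once a recurrent state (`prev_state`) from the previous turn exists it sends only the last two messages (the previous assistant reply and the new user message), since earlier turns are already folded into that state. A minimal sketch of that windowing with a stand-in model; `FakeModel` and `ChatSession` are hypothetical and only mimic the `(reply, state)` return shape of `meowAI.chat` used above. The sketch tests `is not None` rather than the original's truthiness check because its fake state can be `0`:

```python
from typing import Any, Dict, List, Optional, Tuple

Message = Dict[str, str]


class FakeModel:
    """Hypothetical stand-in for MeowAI: records which messages it was sent."""

    def __init__(self) -> None:
        self.sent: List[List[Message]] = []

    def chat(self, messages: List[Message], state: Optional[Any]) -> Tuple[str, Any]:
        self.sent.append(list(messages))  # copy: the caller may pass self.chats itself
        # A turn counter stands in for the model's recurrent state
        turn = 0 if state is None else state + 1
        return f"reply {turn}", turn


class ChatSession:
    def __init__(self, model: FakeModel) -> None:
        self.model = model
        self.chats: List[Message] = []
        self.prev_state: Optional[Any] = None

    def reply(self, msg: str) -> str:
        self.chats.append({"role": "User", "content": msg})
        if self.prev_state is not None:
            # State already encodes earlier turns: last reply + new message only
            messages = self.chats[-2:]
        else:
            # First turn (or after a reset): send the full history
            messages = self.chats
        text, self.prev_state = self.model.chat(messages, self.prev_state)
        self.chats.append({"role": "Assistant", "content": text})
        return text


if __name__ == '__main__':
    model = FakeModel()
    session = ChatSession(model)
    session.reply("hi")
    session.reply("how are you?")
    print([len(batch) for batch in model.sent])
```

Resetting `prev_state` to `None`, as the `/stop` route does, makes the next turn resend the full history.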
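`WeChatServer.run` polls `wx.GetListenMessage()` in a loop, pausing 0.1 s between polls, until the `/stop` route sets `stop_event`. The same stop-flag polling pattern can be sketched with a plain queue standing in for WeChat; `poll_until_stopped` and `drain` are hypothetical names, and the sketch swaps the original's `time.sleep(0.1)` for `stop_event.wait(interval)`, which returns as soon as the event is set, so a stop request is noticed without waiting out a full interval:

```python
import queue
import threading
import time
from typing import Callable, List


def poll_until_stopped(fetch: Callable[[], List[str]],
                       handle: Callable[[str], None],
                       stop_event: threading.Event,
                       interval: float = 0.1) -> None:
    """Call fetch() repeatedly and pass each item to handle() until stop_event is set."""
    while not stop_event.is_set():
        for item in fetch():
            handle(item)
        # Sleeps for up to `interval` seconds but wakes immediately on stop
        stop_event.wait(interval)


def drain(q: "queue.Queue[str]") -> List[str]:
    """Return everything currently in q without blocking (stand-in for GetListenMessage)."""
    items: List[str] = []
    while True:
        try:
            items.append(q.get_nowait())
        except queue.Empty:
            return items


if __name__ == '__main__':
    inbox: "queue.Queue[str]" = queue.Queue()
    handled: List[str] = []
    stop = threading.Event()
    worker = threading.Thread(target=poll_until_stopped,
                              args=(lambda: drain(inbox), handled.append, stop, 0.01))
    worker.start()
    inbox.put('hello')
    inbox.put('meow')
    time.sleep(0.2)
    stop.set()  # what the /stop route does with self.stop_event.set()
    worker.join()
    print(handled)
```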
/extensions/mai_wechat/wxauto/__init__.py:
--------------------------------------------------------------------------------
from .wxauto import WeChat

VERSION = '3.9.8.15'

--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/a.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/extensions/mai_wechat/wxauto/a.dll
--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/color.py:
--------------------------------------------------------------------------------
import warnings
import random
import os

# An empty os.system call is a common trick to enable ANSI escape codes in the Windows console
os.system('')

color_dict = {
    'BLACK': '\x1b[30m',
    'BLUE': '\x1b[34m',
    'CYAN': '\x1b[36m',
    'GREEN': '\x1b[32m',
    'LIGHTBLACK_EX': '\x1b[90m',
    'LIGHTBLUE_EX': '\x1b[94m',
    'LIGHTCYAN_EX': '\x1b[96m',
    'LIGHTGREEN_EX': '\x1b[92m',
    'LIGHTMAGENTA_EX': '\x1b[95m',
    'LIGHTRED_EX': '\x1b[91m',
    'LIGHTWHITE_EX': '\x1b[97m',
    'LIGHTYELLOW_EX': '\x1b[93m',
    'MAGENTA': '\x1b[35m',
    'RED': '\x1b[31m',
    'WHITE': '\x1b[37m',
    'YELLOW': '\x1b[33m'
}

color_reset = '\x1b[0m'

class Print:
    @staticmethod
    def black(text, *args, **kwargs):
        print(color_dict['BLACK'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def blue(text, *args, **kwargs):
        print(color_dict['BLUE'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def cyan(text, *args, **kwargs):
        print(color_dict['CYAN'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def green(text, *args, **kwargs):
        print(color_dict['GREEN'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightblack(text, *args, **kwargs):
        print(color_dict['LIGHTBLACK_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightblue(text, *args, **kwargs):
        print(color_dict['LIGHTBLUE_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightcyan(text, *args, **kwargs):
        print(color_dict['LIGHTCYAN_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightgreen(text, *args, **kwargs):
        print(color_dict['LIGHTGREEN_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightmagenta(text, *args, **kwargs):
        print(color_dict['LIGHTMAGENTA_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightred(text, *args, **kwargs):
        print(color_dict['LIGHTRED_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightwhite(text, *args, **kwargs):
        print(color_dict['LIGHTWHITE_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightyellow(text, *args, **kwargs):
        print(color_dict['LIGHTYELLOW_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def magenta(text, *args, **kwargs):
        print(color_dict['MAGENTA'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def red(text, *args, **kwargs):
        print(color_dict['RED'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def white(text, *args, **kwargs):
        print(color_dict['WHITE'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def yellow(text, *args, **kwargs):
        print(color_dict['YELLOW'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def random(text, *args, **kwargs):
        print(random.choice(list(color_dict.values())) + text + color_reset, *args, **kwargs)


class Input:
    @staticmethod
    def black(text, *args, **kwargs):
        print(color_dict['BLACK'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def blue(text, *args, **kwargs):
        print(color_dict['BLUE'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def cyan(text, *args, **kwargs):
        print(color_dict['CYAN'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def green(text, *args, **kwargs):
        print(color_dict['GREEN'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def lightblack(text, *args, **kwargs):
        print(color_dict['LIGHTBLACK_EX'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def lightblue(text, *args, **kwargs):
        print(color_dict['LIGHTBLUE_EX'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def lightcyan(text, *args, **kwargs):
        print(color_dict['LIGHTCYAN_EX'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def lightgreen(text, *args, **kwargs):
        print(color_dict['LIGHTGREEN_EX'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def lightmagenta(text, *args, **kwargs):
        print(color_dict['LIGHTMAGENTA_EX'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def lightred(text, *args, **kwargs):
        print(color_dict['LIGHTRED_EX'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def lightwhite(text, *args, **kwargs):
        print(color_dict['LIGHTWHITE_EX'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def lightyellow(text, *args, **kwargs):
        print(color_dict['LIGHTYELLOW_EX'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def magenta(text, *args, **kwargs):
        print(color_dict['MAGENTA'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def red(text, *args, **kwargs):
        print(color_dict['RED'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def white(text, *args, **kwargs):
        print(color_dict['WHITE'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def yellow(text, *args, **kwargs):
        print(color_dict['YELLOW'] + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result

    @staticmethod
    def random(text, *args, **kwargs):
        print(random.choice(list(color_dict.values())) + text + color_reset, end='')
        result = input(*args, **kwargs)
        return result


class Warnings:
    @staticmethod
    def black(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['BLACK'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def blue(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['BLUE'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def cyan(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['CYAN'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def green(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['GREEN'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightblack(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['LIGHTBLACK_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightblue(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['LIGHTBLUE_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightcyan(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['LIGHTCYAN_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightgreen(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['LIGHTGREEN_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightmagenta(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['LIGHTMAGENTA_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightred(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['LIGHTRED_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightwhite(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['LIGHTWHITE_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def lightyellow(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['LIGHTYELLOW_EX'] + text + color_reset, *args, **kwargs)

    @staticmethod
    def magenta(text, *args, **kwargs):
        warnings.warn('\n' + color_dict['MAGENTA'] + text + color_reset, *args, **kwargs)

--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/errors.py:
--------------------------------------------------------------------------------

class TargetNotFoundError(Exception):
    pass
--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/languages.py:
--------------------------------------------------------------------------------
'''
Multilingual keywords have not been fully collected yet; pull requests to help fill them in are welcome. Thanks!
'''

MAIN_LANGUAGE = {
    # Navigation bar
    '导航': {'cn': '导航', 'cn_t': '導航', 'en': 'Navigation'},
    '聊天': {'cn': '聊天', 'cn_t': '聊天', 'en': 'Chats'},
    '通讯录': {'cn': '通讯录', 'cn_t': '通訊錄', 'en': 'Contacts'},
    '收藏': {'cn': '收藏', 'cn_t': '收藏', 'en': 'Favorites'},
    '聊天文件': {'cn': '聊天文件', 'cn_t': '聊天室檔案', 'en': 'Chat Files'},
    '朋友圈': {'cn': '朋友圈', 'cn_t': '朋友圈', 'en': 'Moments'},
    '小程序面板': {'cn': '小程序面板', 'cn_t': '小程式面板', 'en': 'Mini Programs Panel'},
    '手机': {'cn': '手机', 'cn_t': '手機', 'en': 'Phone'},
    '设置及其他': {'cn': '设置及其他', 'cn_t': '設定與其他', 'en': 'Settings and Others'},

    # Contact list panel
    '搜索': {'cn': '搜索', 'cn_t': '搜尋', 'en': 'Search'},
    '发起群聊': {'cn': '发起群聊', 'cn_t': '建立群組', 'en': 'Start Group Chat'},
    '文件传输助手': {'cn': '文件传输助手', 'cn_t': '檔案傳輸', 'en': 'File Transfer'},
    '订阅号': {'cn': '订阅号', 'cn_t': '官方賬號', 'en': 'Subscriptions'},
    '消息': {'cn': '消息', 'cn_t': '消息', 'en': ''},

    # Top-right toolbar
    '置顶': {'cn': '置顶', 'cn_t': '置頂', 'en': 'Sticky on Top'},
    '最小化': {'cn': '最小化', 'cn_t': '最小化', 'en': 'Minimize'},
    '最大化': {'cn': '最大化', 'cn_t': '最大化', 'en': ''},
    '关闭': {'cn': '关闭', 'cn_t': '關閉', 'en': ''},

    # Chat window
    '聊天信息': {'cn': '聊天信息', 'cn_t': '聊天資訊', 'en': 'Chat Info'},
    '表情': {'cn': '表情', 'cn_t': '貼圖', 'en': 'Sticker'},
    '发送文件': {'cn': '发送文件', 'cn_t': '傳送檔案', 'en': 'Send File'},
    '截图': {'cn': '截图', 'cn_t': '截圖', 'en': 'Screenshot'},
    '聊天记录': {'cn': '聊天记录', 'cn_t': '聊天記錄', 'en': 'Chat History'},
    '语音聊天': {'cn': '语音聊天', 'cn_t': '語音通話', 'en': 'Voice Call'},
    '视频聊天': {'cn': '视频聊天', 'cn_t': '視頻通話', 'en': 'Video Call'},
    '发送': {'cn': '发送', 'cn_t': '傳送', 'en': 'Send'},
    '输入': {'cn': '输入', 'cn_t': '輸入', 'en': 'Enter'},

    # Message types
    '链接': {'cn': '链接', 'cn_t': '鏈接', 'en': 'Link'},
    '视频': {'cn': '视频', 'cn_t': '視頻', 'en': 'Video'},
    '图片': {'cn': '图片', 'cn_t': '圖片', 'en': 'Photo'},
    '文件': {'cn': '文件', 'cn_t': '文件', 'en': 'File'},
    '': {'cn': '', 'cn_t': '', 'en': ''}}


IMAGE_LANGUAGE = {
    '上一张': {'cn': '上一张', 'cn_t': '上一張', 'en': 'Previous'},
    '下一张': {'cn': '下一张', 'cn_t': '下一張', 'en': 'Next'},
    '预览': {'cn': '预览', 'cn_t': '預覽', 'en': 'Preview'},
    '放大': {'cn': '放大', 'cn_t': '放大', 'en': 'Zoom'},
    '缩小': {'cn': '缩小', 'cn_t': '縮小', 'en': 'Shrink'},
    '图片原始大小': {'cn': '图片原始大小', 'cn_t': '圖片原始大小', 'en': 'Original image size'},
    '旋转': {'cn': '旋转', 'cn_t': '旋轉', 'en': 'Rotate'},
    '编辑': {'cn': '编辑', 'cn_t': '編輯', 'en': 'Edit'},
    '翻译': {'cn': '翻译', 'cn_t': '翻譯', 'en': 'Translate'},
    '提取文字': {'cn': '提取文字', 'cn_t': '提取文字', 'en': 'Extract Text'},
    '识别图中二维码': {'cn': '识别图中二维码', 'cn_t': '識别圖中QR Code', 'en': 'Extract QR Code'},
    '另存为...': {'cn': '另存为...', 'cn_t': '另存爲...', 'en': 'Save as…'},
    '更多': {'cn': '更多', 'cn_t': '更多', 'en': 'More'},
    '最小化': {'cn': '最小化', 'cn_t': '最小化', 'en': 'Minimize'},
    '最大化': {'cn': '最大化', 'cn_t': '最大化', 'en': 'Maximize'},
    '关闭': {'cn': '关闭', 'cn_t': '關閉', 'en': 'Close'},
    '': {'cn': '', 'cn_t': '', 'en': ''}}

FILE_LANGUAGE = {
    '最小化': {'cn': '最小化', 'cn_t': '最小化', 'en': 'Minimize'},
    '最大化': {'cn': '最大化', 'cn_t': '最大化', 'en': 'Maximize'},
    '关闭': {'cn': '关闭', 'cn_t': '關閉', 'en': 'Close'},
    '全部': {'cn': '全部', 'cn_t': '全部', 'en': 'All'},
    '最近使用': {'cn': '最近使用', 'cn_t': '最近使用', 'en': 'Recent'},
    '发送者': {'cn': '发送者', 'cn_t': '發送者', 'en': 'Sender'},
    '聊天': {'cn': '聊天', 'cn_t': '聊天', 'en': 'Chat'},
    '类型': {'cn': '类型', 'cn_t': '類型', 'en': 'Type'},
    '按最新时间': {'cn': '按最新时间', 'cn_t': '按最新時間', 'en': 'Sort by Newest'},
    '按最旧时间': {'cn': '按最旧时间', 'cn_t': '按最舊時間', 'en': 'Sort by Oldest'},
    '按从大到小': {'cn': '按从大到小', 'cn_t': '按從大到小', 'en': 'Sort by Largest'},
    '按从小到大': {'cn': '按从小到大', 'cn_t': '按從小到大', 'en': 'Sort by Smallest'},
    '': {'cn': '', 'cn_t': '', 'en': ''}
}

WARNING = {
    '版本不一致': {
        'cn': '当前微信客户端版本为{},与当前库版本{}不一致,可能会导致部分功能无法正常使用,请注意判断',
        'cn_t': '當前微信客戶端版本為{},與當前庫版本{}不一致,可能會導致部分功能無法正常使用,請注意判斷',
        'en': 'The current WeChat client version is {}, which is inconsistent with the current library version {}, which may cause some functions to fail
to work properly. Please pay attention to judgment' 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /extensions/mai_wechat/wxauto/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from . import uiautomation as uia 3 | from PIL import ImageGrab 4 | import win32clipboard 5 | import win32process 6 | import win32gui 7 | import win32api 8 | import win32con 9 | import pyperclip 10 | import ctypes 11 | import psutil 12 | import shutil 13 | import winreg 14 | import time 15 | import os 16 | import re 17 | 18 | VERSION = "3.9.8.15" 19 | 20 | def set_cursor_pos(x, y): 21 | win32api.SetCursorPos((x, y)) 22 | 23 | def Click(rect): 24 | x = (rect.left + rect.right) // 2 25 | y = (rect.top + rect.bottom) // 2 26 | set_cursor_pos(x, y) 27 | win32api.mouse_event(win32con.MOUSEEVENTF_LEFTDOWN, x, y, 0, 0) 28 | win32api.mouse_event(win32con.MOUSEEVENTF_LEFTUP, x, y, 0, 0) 29 | 30 | def GetPathByHwnd(hwnd): 31 | try: 32 | thread_id, process_id = win32process.GetWindowThreadProcessId(hwnd) 33 | process = psutil.Process(process_id) 34 | return process.exe() 35 | except Exception as e: 36 | print(f"Error: {e}") 37 | return None 38 | 39 | def GetVersionByPath(file_path): 40 | try: 41 | info = win32api.GetFileVersionInfo(file_path, '\\') 42 | version = "{}.{}.{}.{}".format(win32api.HIWORD(info['FileVersionMS']), 43 | win32api.LOWORD(info['FileVersionMS']), 44 | win32api.HIWORD(info['FileVersionLS']), 45 | win32api.LOWORD(info['FileVersionLS'])) 46 | except: 47 | version = None 48 | return version 49 | 50 | 51 | def IsRedPixel(uicontrol): 52 | rect = uicontrol.BoundingRectangle 53 | bbox = (rect.left, rect.top, rect.right, rect.bottom) 54 | img = ImageGrab.grab(bbox=bbox, all_screens=True) 55 | return any(p[0] > p[1] and p[0] > p[2] for p in img.getdata()) 56 | 57 | class DROPFILES(ctypes.Structure): 58 | _fields_ = [ 59 | ("pFiles", ctypes.c_uint), 
        ("x", ctypes.c_long),
        ("y", ctypes.c_long),
        ("fNC", ctypes.c_int),
        ("fWide", ctypes.c_bool),
    ]

pDropFiles = DROPFILES()
pDropFiles.pFiles = ctypes.sizeof(DROPFILES)
pDropFiles.fWide = True
matedata = bytes(pDropFiles)

def SetClipboardText(text: str):
    pyperclip.copy(text)
    # if not isinstance(text, str):
    #     raise TypeError(f"argument must be of type str --> {text}")
    # t0 = time.time()
    # while True:
    #     if time.time() - t0 > 10:
    #         raise TimeoutError(f"timed out setting the clipboard! --> {text}")
    #     try:
    #         win32clipboard.OpenClipboard()
    #         win32clipboard.EmptyClipboard()
    #         win32clipboard.SetClipboardData(win32con.CF_UNICODETEXT, text)
    #         break
    #     except:
    #         pass
    #     finally:
    #         try:
    #             win32clipboard.CloseClipboard()
    #         except:
    #             pass

try:
    from anytree import Node, RenderTree

    def GetAllControl(ele):
        def findall(ele, n=0, node=None):
            nn = '\n'
            nodename = f"[{ele.ControlTypeName}](\"{ele.ClassName}\", \"{ele.Name.replace(nn, '')}\")"
            if not node:
                node1 = Node(nodename)
            else:
                node1 = Node(nodename, parent=node)
            eles = ele.GetChildren()
            for ele1 in eles:
                findall(ele1, n+1, node1)
            return node1
        tree = RenderTree(findall(ele))
        for pre, fill, node in tree:
            print(f"{pre}{node.name}")
except ImportError:
    # anytree is optional: GetAllControl is only defined when it is installed
    pass

def SetClipboardFiles(paths):
    for file in paths:
        if not os.path.exists(file):
            raise FileNotFoundError(f"file ({file}) does not exist!")
    files = ("\0".join(paths)).replace("/", "\\")
    # CF_HDROP expects a list of NUL-terminated UTF-16 paths followed by one more NUL,
    # i.e. two wide NULs (4 bytes) after the last path; [2:] drops the UTF-16 BOM
    data = files.encode("U16")[2:] + b"\0\0\0\0"
    t0 = time.time()
    while True:
        if time.time() - t0 > 10:
            raise TimeoutError(f"Timed out copying files to the clipboard! --> {paths}")
        try:
            win32clipboard.OpenClipboard()
            win32clipboard.EmptyClipboard()
            win32clipboard.SetClipboardData(win32clipboard.CF_HDROP, matedata+data)
            break
        except:
            pass
        finally:
            try:
                win32clipboard.CloseClipboard()
            except:
                pass

def PasteFile(folder):
    folder = os.path.realpath(folder)
    if not os.path.exists(folder):
        os.makedirs(folder)

    t0 = time.time()
    while True:
        if time.time() - t0 > 10:
            raise TimeoutError("Timed out reading files from the clipboard!")
        try:
            win32clipboard.OpenClipboard()
            if win32clipboard.IsClipboardFormatAvailable(win32clipboard.CF_HDROP):
                files = win32clipboard.GetClipboardData(win32clipboard.CF_HDROP)
                for file in files:
                    filename = os.path.basename(file)
                    dest_file = os.path.join(folder, filename)
                    shutil.copy2(file, dest_file)
                return True
            else:
                print("There are no files on the clipboard")
                return False
        except:
            pass
        finally:
            # If OpenClipboard failed, CloseClipboard raises too; ignore it so the retry loop keeps going
            try:
                win32clipboard.CloseClipboard()
            except:
                pass

def GetText(HWND):
    length = win32gui.SendMessage(HWND, win32con.WM_GETTEXTLENGTH)*2
    buffer = win32gui.PyMakeBuffer(length)
    win32api.SendMessage(HWND, win32con.WM_GETTEXT, length, buffer)
    address, length_ = win32gui.PyGetBufferAddressAndLen(buffer[:-1])
    text = win32gui.PyGetString(address, length_)[:int(length/2)]
    buffer.release()
    return text

def GetAllWindowExs(HWND):
    if not HWND:
        return
    handles = []
    win32gui.EnumChildWindows(
        HWND, lambda hwnd, param: param.append([hwnd, win32gui.GetClassName(hwnd), GetText(hwnd)]), handles)
    return handles

def FindWindow(classname=None, name=None) -> int:
    return win32gui.FindWindow(classname, name)

def FindWinEx(HWND, classname=None, name=None) -> list:
    hwnds_classname = []
    hwnds_name = []
    def find_classname(hwnd, classname):
        classname_ = win32gui.GetClassName(hwnd)
        if classname_ == classname:
            if hwnd not in hwnds_classname:
                hwnds_classname.append(hwnd)
    def find_name(hwnd, name):
        name_ = GetText(hwnd)
        if name in name_:
            if hwnd not in hwnds_name:
                hwnds_name.append(hwnd)
    if classname:
        win32gui.EnumChildWindows(HWND, find_classname, classname)
    if name:
        win32gui.EnumChildWindows(HWND, find_name, name)
    if classname and name:
        hwnds = [hwnd for hwnd in hwnds_classname if hwnd in hwnds_name]
    else:
        hwnds = hwnds_classname + hwnds_name
    return hwnds

def ClipboardFormats(unit=0, *units):
    units = list(units)
    win32clipboard.OpenClipboard()
    u = win32clipboard.EnumClipboardFormats(unit)
    win32clipboard.CloseClipboard()
    units.append(u)
    if u:
        units = ClipboardFormats(u, *units)
    return units

def ReadClipboardData():
    Dict = {}
    for i in ClipboardFormats():
        if i == 0:
            continue
        win32clipboard.OpenClipboard()
        try:
            filenames = win32clipboard.GetClipboardData(i)
            win32clipboard.CloseClipboard()
        except:
            win32clipboard.CloseClipboard()
            raise ValueError
        Dict[str(i)] = filenames
    return Dict

def ParseWeChatTime(time_str):
    """
    Convert a WeChat timestamp into 'YYYY-MM-DD H:M' form

    Args:
        time_str: the input time string, e.g. "12:30", "昨天 12:30" (yesterday),
                  "星期一 12:30" (Monday) or "2024年1月2日 12:30"

    Returns:
        The converted time string, or None if the format is not recognized
    """

    match = re.match(r'^(\d{1,2}):(\d{1,2})$', time_str)
    if match:
        hour, minute = match.groups()
        return datetime.now().strftime('%Y-%m-%d') + f' {hour}:{minute}'

    match = re.match(r'^昨天 (\d{1,2}):(\d{1,2})$', time_str)
    if match:
        hour, minute = match.groups()
        yesterday = datetime.now() - timedelta(days=1)
        return yesterday.strftime('%Y-%m-%d') + f' {hour}:{minute}'

    match = re.match(r'^星期([一二三四五六日]) (\d{1,2}):(\d{1,2})$', time_str)
    if match:
        weekday, hour, minute = match.groups()
        weekday_num = ['一', '二', '三', '四', '五', '六', '日'].index(weekday)
        today_weekday = datetime.now().weekday()
        delta_days = (today_weekday - weekday_num) % 7
        target_day = datetime.now() - timedelta(days=delta_days)
        return target_day.strftime('%Y-%m-%d') + f' {hour}:{minute}'

    match = re.match(r'^(\d{4})年(\d{1,2})月(\d{1,2})日 (\d{1,2}):(\d{1,2})$', time_str)
    if match:
        year, month, day, hour, minute = match.groups()
        return datetime(*[int(i) for i in [year, month, day, hour, minute]]).strftime('%Y-%m-%d') + f' {hour}:{minute}'


def FindPid(process_name):
    procs = psutil.process_iter(['pid', 'name'])
    for proc in procs:
        # psutil reports None for attributes it is not allowed to read
        if process_name in (proc.info['name'] or ''):
            return proc.info['pid']


def Mver(pid):
    exepath = psutil.Process(pid).exe()
    if GetVersionByPath(exepath) != VERSION:
        print(f"This fix only applies to WeChat version {VERSION}!")
        return
    if not uia.Control(ClassName='WeChatLoginWndForPC', searchDepth=1).Exists(maxSearchSeconds=2):
        print("Open the WeChat login window first, then run this method again!")
        return
    path = os.path.join(os.path.dirname(__file__), 'a.dll')
    dll = ctypes.WinDLL(path)
    dll.GetDllBaseAddress.argtypes = [ctypes.c_uint, ctypes.c_wchar_p]
    dll.GetDllBaseAddress.restype = ctypes.c_void_p
    dll.WriteMemory.argtypes = [ctypes.c_ulong, ctypes.c_void_p, ctypes.c_ulong]
    dll.WriteMemory.restype = ctypes.c_bool
    dll.GetMemory.argtypes = [ctypes.c_ulong, ctypes.c_void_p]
    dll.GetMemory.restype = ctypes.c_ulong
    mname = 'WeChatWin.dll'
    tar = 1661536787
    base_address = dll.GetDllBaseAddress(pid, mname)
    address = base_address + 64761648
    if dll.GetMemory(pid, address) != tar:
        dll.WriteMemory(pid, address, tar)
    handle = ctypes.c_void_p(dll._handle)
    ctypes.windll.kernel32.FreeLibrary(handle)

def FixVersionError():
    """Fix being unable to log in because the WeChat client version is too old"""
    pid = FindPid('WeChat.exe')
    if pid:
        Mver(pid)
        return
    else:
        try:
            registry_key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, r"Software\Tencent\WeChat", 0, winreg.KEY_READ)
            path, _ = winreg.QueryValueEx(registry_key, "InstallPath")
            winreg.CloseKey(registry_key)
            wxpath = os.path.join(path, "WeChat.exe")
            if os.path.exists(wxpath):
                os.system(f'start "" "{wxpath}"')
                # Poll for the new process instead of recursing immediately, which relaunches WeChat over and over
                t0 = time.time()
                while time.time() - t0 < 10:
                    pid = FindPid('WeChat.exe')
                    if pid:
                        Mver(pid)
                        return
                    time.sleep(0.5)
                print("WeChat did not start in time; open the WeChat login window and run this method again!")
            else:
                raise FileNotFoundError(f'{wxpath} not found')
        except OSError:
            print("WeChat installation path not found; open the WeChat login window first, then run this method again!")
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
# rwkv.model reads these flags when it is imported, so they must be set before the import below
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'

from rwkv.model import RWKV
from rwkv.utils import PIPELINE
from meowServer import MeowAIServer,Character,MeowAI
from typing import List,Dict,Mapping,Union,Callable,Any
import torch,json
from pathlib import Path

config=None
modelsFolder=""
modelFile=""
rootPath=Path(__file__).parent
configPath=rootPath.joinpath('config.json')

def getModelsList(modelsFolder:Path)->List[Dict[str,str]]:
    modelInfoList=[]
    for file in modelsFolder.rglob('*.pth'):
        modelsinfo={
            "name":file.stem,
            "path":str(file.absolute().with_suffix(''))
        }
        modelInfoList.append(modelsinfo)
    return modelInfoList

if not configPath.exists():raise FileNotFoundError(f"Failed to read config.json: the default configuration file was not found!\nCreate config.json in the project root and fill in the settings following config.json.example.\nConfig file path: {configPath}")
with open(configPath,'r',encoding='utf-8') as f:config = json.load(f)
modelsFolder=Path(config['modelsFolder'])
if not modelsFolder.exists():raise FileNotFoundError(f"Failed to read the models folder: the folder was not found!\nCreate a models folder in the project root and put the model files in it.\nModels folder path: {modelsFolder}")
modelFile=modelsFolder.joinpath(config['modelFile'])
modelInfoList=getModelsList(modelsFolder)
config['strategy']=config['strategy'] if torch.cuda.is_available() else 'cpu fp32'

if not os.path.exists(str(modelFile)+'.pth'):
    if len(modelInfoList)==0:raise FileNotFoundError(f"No model file found!\nCreate a models folder in the project root and put the model files in it.\nModels folder path: {modelsFolder}")
    print(f"Model file not found; please download the {config['modelFile']} model first.\nFalling back to the {modelInfoList[0]['name']} model for now.")
    modelFile=modelInfoList[0]['path']
    config["modelFile"]=modelInfoList[0]['name']
    with open(configPath, 'w', encoding='utf-8') as file:file.write(json.dumps(config, indent=4))

model = RWKV(model=str(modelFile), strategy=config['strategy'])
if model.version == 7:
    import sys
    sys.modules.pop("rwkv.model")
    os.environ["RWKV_V7_ON"] = "1"
    from rwkv.model import RWKV
    model = RWKV(model=str(modelFile), strategy=config['strategy'])
pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
meowAI=MeowAI(pipeline)
meowAIServer=MeowAIServer(meowAI,host=config['host'],port=int(config['port']),autoOpen=config['autoOpen'],debug=True)


if __name__ == '__main__':
    meowAIServer.run()

--------------------------------------------------------------------------------
/meowServer.py:
--------------------------------------------------------------------------------
import re,json,torch
from tqdm import tqdm
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from typing import List,Dict,Mapping,Union,Callable,Any,Tuple

class Character(PIPELINE_ARGS):
    def __init__(self,persona:str='主人你好呀!我是你的可爱猫娘,喵~', temperature:float=1.1,top_p:float=0.7, top_k:int=0, alpha_frequency:float=0.2, alpha_presence:float=0.2, alpha_decay:float=0.996, token_ban:list=[], token_stop:list=[], chunk_len=256):
        super().__init__(temperature, top_p, top_k, alpha_frequency, alpha_presence, alpha_decay, token_ban, token_stop, chunk_len)
        self.persona=persona


class MeowAI():
    '''
    Content generation
    '''
    def __init__(self,model:PIPELINE, character:Character=Character(),max_tokens:int=2048 ):
        self.model=model
        self.character=character
        self.max_tokens=max_tokens
        self.stop=False
        self.chat_state=None
        self.talk_state=None

    def purr(self,txt:str)->str:
        return re.sub(r'\r*\n{2,}', '\n', txt.strip())

    def chat(self,messages:List[Mapping[str,str]],prev_state:torch.Tensor=None,callback:Callable[[str],Any]=None)->Tuple[str,torch.Tensor]:
        '''
        messages=[
            {
                "role": "User",
                "content": "What is the meaning of life?"
            },
            {
                "role": "Assistant",
                "content": "The meaning of life is to live a happy and fulfilling life."
            },
            {
                "role": "User",
                "content": "How do cats call?"
            },
        ]
        '''
        out_tokens = []
        out_len = 0
        out_str = ""
        occurrence = {}
        state = prev_state
        input_text = "\n\n".join([f"{text['role']}:{self.purr(text['content'])}" for text in messages])
        if state is not None:
            input_text = f"\n\n{input_text}\n\nAnswerer:"
        else:
            input_text = f"Answerer:{self.purr(self.character.persona)}\n\n{input_text}\n\nAnswerer:"
        for i in tqdm(range(self.max_tokens),desc="tokens",leave=False):
            if self.stop:break
            if i == 0:
                out, state = self.model.model.forward(self.model.encode(input_text), state)
            else:
                out, state = self.model.model.forward([token], state)
            for n in occurrence:
                # Flat presence penalty once a token has appeared, plus a frequency penalty scaled by its decayed count
                out[n] -= (
                    self.character.alpha_presence + occurrence[n] * self.character.alpha_frequency
                )

            token = self.model.sample_logits(out, temperature=self.character.temperature, top_p=self.character.top_p)
            if token == 0:break
            out_tokens += [token]

            for n in occurrence:
                occurrence[n] *= self.character.alpha_decay
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

            tmp = self.model.decode(out_tokens[out_len:])
            if "\ufffd" in tmp:
                continue  # incomplete UTF-8 sequence: wait for more tokens
            if "\n\n" in tmp:
                # The model started a new turn: keep only the text before the blank line
                tmp = tmp.split("\n\n")[0]
                out_str += tmp
                if callback and tmp and not self.stop: callback(tmp)
                break
            if tmp.endswith("\n"):
                continue  # hold back a trailing newline until we know whether a second one follows
            out_str += tmp
            out_len = len(out_tokens)
            # self.chat_state=state
            if callback and not self.stop: callback(tmp)
        print(out_str.strip())
        return out_str.strip(),state

    def talk(self,input_text:str,prev_state:torch.Tensor=None,callback:Callable[[str],Any]=None)->Tuple[str,torch.Tensor]:
        out_tokens = []
        out_len = 0
        out_str = ""
        occurrence = {}
        state = prev_state
        for i in tqdm(range(self.max_tokens),desc="tokens",leave=False):
            if i == 0:
                out, state = self.model.model.forward(self.model.encode(input_text), state)
            else:
                out, state = self.model.model.forward([token], state)
            for n in occurrence:
                out[n] -= (
                    self.character.alpha_presence + occurrence[n] * self.character.alpha_frequency
                )
            token = self.model.sample_logits( out, temperature=self.character.temperature, top_p=self.character.top_p)

            out_tokens += [token]

            for n in occurrence:
                occurrence[n] *= self.character.alpha_decay
            occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)

            tmp = self.model.decode(out_tokens[out_len:])
            # self.talk_state=state
            if "\ufffd" not in tmp:
                # Only emit complete UTF-8 text; otherwise wait for the next token
                out_str += tmp
                if callback: callback(tmp)
                out_len = len(out_tokens)
            if self.stop:break
        print(out_str)
        return out_str,state

from flask import Flask, render_template
from flask_socketio import SocketIO
import webbrowser
from extensions import load_extensions

class MeowAIServer():
    def __init__(self, meowAI:MeowAI, host:str="0.0.0.0", port:int=5000, debug:bool=False,use_reloader:bool=False,autoOpen:bool=True):
        if not meowAI:raise Exception("meowAI is not defined")
        self.meowAI = meowAI
        self.host=host
        self.port=port
        self.debug=debug
        self.use_reloader=use_reloader
        self.autoOpen=autoOpen

        self.app = Flask(__name__)
        self.app.jinja_env.variable_start_string = '{['
        self.app.jinja_env.variable_end_string = ']}'
        self.socketio = SocketIO(self.app)
        print(load_extensions(self.app,self.socketio,self.meowAI))

        @self.app.route('/')
        def index():
            return render_template('index.html')

        @self.app.route('/extension/')
        def extension():
            return render_template('extension.html')

        @self.socketio.on('emit')
        def emit(news:Dict[str,Union[int,str]]):
            '''
            news={'code':1,'msg':'XX error'}
            '''
            self.socketio.emit('emit', json.dumps(news))

        @self.socketio.on('stop')
        def stop(status:bool=True):
            self.socketio.emit('stop', status)
            self.meowAI.stop=status

        @self.socketio.on('character')
        def handle_character(character:Dict[str,Union[int,float]]):
            self.meowAI.character = Character(**character)
            print(character)

        @self.socketio.on('chat')
        def handle_chat(message:Mapping[str,Union[str,List[Dict[str,str]]]]):
            try:
                self.meowAI.stop=False
                # Forward the whole decoded chunk; x[0] would send only its first character
                self.meowAI.chat(message,None,lambda x:self.socketio.emit('chat',x))
            except Exception as e:
                ext={'code':1,'msg':f'Generation error: {e}'}
                emit(ext)
                print(ext)
            finally:
                stop(True)

        @self.socketio.on('talk')
        def handle_talk(prompt):
            try:
                self.meowAI.stop=False
                self.meowAI.talk(prompt,None,lambda x:self.socketio.emit('talk',x))
            except Exception as e:
                ext={'code':1,'msg':f'Generation error: {e}'}
                emit(ext)
                print(ext)
            finally:
                stop(True)

    def run(self):
        # 0.0.0.0 is a bind-all address that browsers cannot reliably open (e.g. on Windows); open the loopback address instead
        browseHost='127.0.0.1' if self.host in ('0.0.0.0','') else self.host
        if self.autoOpen:webbrowser.open(f'http://{browseHost}:{self.port}')
        self.socketio.run(self.app,host=self.host,port=self.port,debug=self.debug,use_reloader=self.use_reloader)
--------------------------------------------------------------------------------
/models/EEADME.md:
--------------------------------------------------------------------------------
Store model files in this folder.
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
torch
tqdm
flask
flask_socketio
comtypes
packaging

--------------------------------------------------------------------------------
/rwkv/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/rwkv/__pycache__/rwkv_tokenizer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/__pycache__/rwkv_tokenizer.cpython-39.pyc
--------------------------------------------------------------------------------
/rwkv/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/rwkv/cpp/librwkv.dylib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/cpp/librwkv.dylib
--------------------------------------------------------------------------------
/rwkv/cpp/librwkv.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/cpp/librwkv.so
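The sampling loops in `MeowAI.chat` and `MeowAI.talk` (meowServer.py, above) penalize tokens that have already been generated and then decay the per-token counts by `alpha_decay`. The sketch below isolates that bookkeeping in plain Python so it can be checked without loading a model. It assumes the GPT-3-style convention used by `PIPELINE.generate` in the `rwkv` pip package (a flat `alpha_presence` for any token seen so far, plus `alpha_frequency` times its decayed count); the helper names `apply_repetition_penalty` and `record_token` are hypothetical and do not exist in this repository.

```python
from typing import Dict, List


def apply_repetition_penalty(
    logits: List[float],
    occurrence: Dict[int, float],
    alpha_presence: float,
    alpha_frequency: float,
) -> List[float]:
    """Return a penalized copy of `logits` (hypothetical helper, not in the repo).

    Every token that has appeared loses a flat `alpha_presence` plus
    `alpha_frequency` times its decayed occurrence count.
    """
    penalized = list(logits)
    for token, count in occurrence.items():
        penalized[token] -= alpha_presence + count * alpha_frequency
    return penalized


def record_token(occurrence: Dict[int, float], token: int, alpha_decay: float) -> None:
    """Decay every existing count, then add one for the newly sampled token (hypothetical helper)."""
    for t in occurrence:
        occurrence[t] *= alpha_decay
    occurrence[token] = 1 + occurrence.get(token, 0)


if __name__ == "__main__":
    counts: Dict[int, float] = {}
    record_token(counts, 2, alpha_decay=0.5)  # counts == {2: 1}
    record_token(counts, 2, alpha_decay=0.5)  # 1 * 0.5 + 1, so counts == {2: 1.5}
    # Token 2 loses 0.25 + 1.5 * 0.5 = 1.0; the other logits are untouched
    print(apply_repetition_penalty([1.0, 1.0, 1.0], counts, alpha_presence=0.25, alpha_frequency=0.5))
    # prints [1.0, 1.0, 0.0]
```

Working on a copy of the logits is only for testability here; in the repository `out` is a fresh tensor from each `forward` call, so subtracting in place has the same effect.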
--------------------------------------------------------------------------------
/rwkv/cpp/model.py:
--------------------------------------------------------------------------------
from typing import Any, List, Union
from . import rwkv_cpp_model
from . import rwkv_cpp_shared_library


class RWKV:
    def __init__(self, model_path: str, strategy=None):
        self.library = rwkv_cpp_shared_library.load_rwkv_shared_library()
        self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
        self.w = {}  # fake weight
        self.w["emb.weight"] = [0] * self.model.n_vocab
        self.version = (
            self.model.arch_version_major + self.model.arch_version_minor / 10
        )

    def forward(self, tokens: List[int], state: Union[Any, None] = None):
        return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)

--------------------------------------------------------------------------------
/rwkv/cpp/rwkv.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/cpp/rwkv.dll
--------------------------------------------------------------------------------
/rwkv/cpp/rwkv_cpp_model.py:
--------------------------------------------------------------------------------
import os
import multiprocessing

# Pre-import PyTorch, if available.
# This fixes "OSError: [WinError 127] The specified procedure could not be found".
try:
    import torch
except ModuleNotFoundError:
    pass

# I'm sure this is not strictly correct, but let's keep this crutch for now.
try:
    import rwkv_cpp_shared_library
except ModuleNotFoundError:
    from . import rwkv_cpp_shared_library

from typing import TypeVar, Optional, Tuple, List

# A value of this type is either a numpy's ndarray or a PyTorch's Tensor.
NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor')

class RWKVModel:
    """
    An RWKV model managed by rwkv.cpp library.
    """

    def __init__(
            self,
            shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
            model_path: str,
            thread_count: int = max(1, multiprocessing.cpu_count() // 2),
            gpu_layer_count: int = 0,
            **kwargs
    ) -> None:
        """
        Loads the model and prepares it for inference.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library : RWKVSharedLibrary
            rwkv.cpp shared library.
        model_path : str
            Path to RWKV model file in ggml format.
        thread_count : int
            Thread count to use. If not set, defaults to CPU count / 2.
        gpu_layer_count : int
            Count of layers to offload onto the GPU, must be >= 0.
            See documentation of `gpu_offload_layers` for details about layer offloading.
        """

        if 'gpu_layers_count' in kwargs:
            gpu_layer_count = kwargs['gpu_layers_count']

        if not os.path.isfile(model_path):
            raise ValueError(f'{model_path} is not a file')

        if not (thread_count > 0):
            raise ValueError('Thread count must be > 0')

        if not (gpu_layer_count >= 0):
            raise ValueError('GPU layer count must be >= 0')

        self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library

        self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count)

        if gpu_layer_count > 0:
            self.gpu_offload_layers(gpu_layer_count)

        self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx)
        self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx)

        self._valid: bool = True

    def gpu_offload_layers(self, layer_count: int) -> bool:
        """
        Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
        For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
        - pass `model.n_layer` to offload all layers except model head
        - pass `model.n_layer + 1` to offload all layers, including model head

        Returns true if at least one layer was offloaded.
        If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.

        Parameters
        ----------
        layer_count : int
            Count of layers to offload onto the GPU, must be >= 0.
        """

        if not (layer_count >= 0):
            raise ValueError('Layer count must be >= 0')

        return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)

    @property
    def arch_version_major(self) -> int:
        return self._library.rwkv_get_arch_version_major(self._ctx)

    @property
    def arch_version_minor(self) -> int:
        return self._library.rwkv_get_arch_version_minor(self._ctx)

    @property
    def n_vocab(self) -> int:
        return self._library.rwkv_get_n_vocab(self._ctx)

    @property
    def n_embed(self) -> int:
        return self._library.rwkv_get_n_embed(self._ctx)

    @property
    def n_layer(self) -> int:
        return self._library.rwkv_get_n_layer(self._ctx)

    def eval(
            self,
            token: int,
            state_in: Optional[NumpyArrayOrPyTorchTensor],
            state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
            logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
            use_numpy: bool = False
    ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
        """
        Evaluates the model for a single token.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        token : int
            Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
        state_in : Optional[NumpyArrayOrTorchTensor]
            State from previous call of this method. If this is a first pass, set it to None.
        state_out : Optional[NumpyArrayOrTorchTensor]
            Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
        logits_out : Optional[NumpyArrayOrTorchTensor]
            Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
        use_numpy : bool
            If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
            This parameter is ignored if any tensor parameter is not None; in such case,
            type of returned tensors will match the type of received tensors.

        Returns
        -------
        logits, state
            Logits vector of shape (n_vocab); state for the next step.
        """

        if not self._valid:
            raise ValueError('Model was freed')

        use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)

        if state_in is not None:
            self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)

            state_in_ptr = self._get_data_ptr(state_in)
        else:
            state_in_ptr = 0

        if state_out is not None:
            self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
        else:
            state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)

        if logits_out is not None:
            self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
        else:
            logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)

        self._library.rwkv_eval(
            self._ctx,
            token,
            state_in_ptr,
            self._get_data_ptr(state_out),
            self._get_data_ptr(logits_out)
        )

        return logits_out, state_out

    def eval_sequence(
            self,
            tokens: List[int],
            state_in: Optional[NumpyArrayOrPyTorchTensor],
            state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
            logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
            use_numpy: bool = False
    ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
        """
        Evaluates the model for a sequence of tokens.

        NOTE ON GGML NODE LIMIT

        ggml has a hard-coded limit on max amount of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
        this limit when using large models and/or large sequence lengths.
        Fortunately, rwkv.cpp's fork of ggml has increased limit which was tested to work for sequence lengths up to 64 for 14B models.

        If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
        To get rid of the assertion failure, reduce the model size and/or sequence length.

        In case of any error, this method will throw an exception.

        Parameters
        ----------
        tokens : List[int]
            Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
        state_in : Optional[NumpyArrayOrTorchTensor]
            State from previous call of this method. If this is a first pass, set it to None.
        state_out : Optional[NumpyArrayOrTorchTensor]
            Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
        logits_out : Optional[NumpyArrayOrTorchTensor]
            Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
        use_numpy : bool
            If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
            This parameter is ignored if any tensor parameter is not None; in such case,
            type of returned tensors will match the type of received tensors.
218 | 219 | Returns 220 | ------- 221 | logits, state 222 | Logits vector of shape (n_vocab); state for the next step. 223 | """ 224 | 225 | if not self._valid: 226 | raise ValueError('Model was freed') 227 | 228 | use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) 229 | 230 | if state_in is not None: 231 | self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) 232 | 233 | state_in_ptr = self._get_data_ptr(state_in) 234 | else: 235 | state_in_ptr = 0 236 | 237 | if state_out is not None: 238 | self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) 239 | else: 240 | state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy) 241 | 242 | if logits_out is not None: 243 | self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) 244 | else: 245 | logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy) 246 | 247 | self._library.rwkv_eval_sequence( 248 | self._ctx, 249 | tokens, 250 | state_in_ptr, 251 | self._get_data_ptr(state_out), 252 | self._get_data_ptr(logits_out) 253 | ) 254 | 255 | return logits_out, state_out 256 | 257 | def eval_sequence_in_chunks( 258 | self, 259 | tokens: List[int], 260 | state_in: Optional[NumpyArrayOrPyTorchTensor], 261 | state_out: Optional[NumpyArrayOrPyTorchTensor] = None, 262 | logits_out: Optional[NumpyArrayOrPyTorchTensor] = None, 263 | chunk_size: int = 16, 264 | use_numpy: bool = False 265 | ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]: 266 | """ 267 | Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks. 268 | This function is useful for processing complete prompts and user input in chat & role-playing use-cases. 269 | It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance. 
270 | 271 | Chunking allows processing sequences of thousands of tokens without reaching ggml's node limit or consuming too much memory. 272 | The recommended chunk size is 16. If you want maximum performance, try different chunk sizes in the range [2..64] 273 | and choose the one that works best for your use case. 274 | 275 | In case of any error, this method will throw an exception. 276 | 277 | Parameters 278 | ---------- 279 | tokens : List[int] 280 | Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab. 281 | chunk_size : int 282 | Size of each chunk in tokens, must be positive. 283 | state_in : Optional[NumpyArrayOrTorchTensor] 284 | State from previous call of this method. If this is a first pass, set it to None. 285 | state_out : Optional[NumpyArrayOrTorchTensor] 286 | Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count). 287 | logits_out : Optional[NumpyArrayOrTorchTensor] 288 | Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count). 289 | use_numpy : bool 290 | If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors. 291 | This parameter is ignored if any tensor parameter is not None; in such case, 292 | type of returned tensors will match the type of received tensors. 293 | 294 | Returns 295 | ------- 296 | logits, state 297 | Logits vector of shape (n_vocab); state for the next step.
298 | """ 299 | 300 | if not self._valid: 301 | raise ValueError('Model was freed') 302 | 303 | use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy) 304 | 305 | if state_in is not None: 306 | self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count) 307 | 308 | state_in_ptr = self._get_data_ptr(state_in) 309 | else: 310 | state_in_ptr = 0 311 | 312 | if state_out is not None: 313 | self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count) 314 | else: 315 | state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy) 316 | 317 | if logits_out is not None: 318 | self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count) 319 | else: 320 | logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy) 321 | 322 | self._library.rwkv_eval_sequence_in_chunks( 323 | self._ctx, 324 | tokens, 325 | chunk_size, 326 | state_in_ptr, 327 | self._get_data_ptr(state_out), 328 | self._get_data_ptr(logits_out) 329 | ) 330 | 331 | return logits_out, state_out 332 | 333 | def free(self) -> None: 334 | """ 335 | Frees all allocated resources. 336 | In case of any error, this method will throw an exception. 337 | The object must not be used anymore after calling this method. 338 | """ 339 | 340 | if not self._valid: 341 | raise ValueError('Already freed') 342 | 343 | self._valid = False 344 | 345 | self._library.rwkv_free(self._ctx) 346 | 347 | def __del__(self) -> None: 348 | # Free the context on GC in case user forgot to call free() explicitly. 
349 | if hasattr(self, '_valid') and self._valid: 350 | self.free() 351 | 352 | def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool: 353 | return hasattr(tensor, '__module__') and tensor.__module__ == 'torch' 354 | 355 | def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool: 356 | for tensor in tensors: 357 | if tensor is not None: 358 | return False if self._is_pytorch_tensor(tensor) else True 359 | 360 | return use_numpy_by_default 361 | 362 | def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None: 363 | if self._is_pytorch_tensor(tensor): 364 | tensor: torch.Tensor = tensor 365 | 366 | if tensor.device != torch.device('cpu'): 367 | raise ValueError(f'{name} is not on CPU') 368 | if tensor.dtype != torch.float32: 369 | raise ValueError(f'{name} is not of type float32') 370 | if tensor.shape != (size,): 371 | raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})') 372 | if not tensor.is_contiguous(): 373 | raise ValueError(f'{name} is not contiguous') 374 | else: 375 | import numpy as np 376 | tensor: np.ndarray = tensor 377 | 378 | if tensor.dtype != np.float32: 379 | raise ValueError(f'{name} is not of type float32') 380 | if tensor.shape != (size,): 381 | raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})') 382 | if not tensor.data.contiguous: 383 | raise ValueError(f'{name} is not contiguous') 384 | 385 | def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor): 386 | if self._is_pytorch_tensor(tensor): 387 | return tensor.data_ptr() 388 | else: 389 | return tensor.ctypes.data 390 | 391 | def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor: 392 | if use_numpy: 393 | import numpy as np 394 | return np.zeros(element_count, dtype=np.float32) 395 | else: 396 | return torch.zeros(element_count, dtype=torch.float32, device='cpu') 397 | 
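The chunking described in the `eval_sequence_in_chunks` docstring can be illustrated in plain Python. The real method passes `chunk_size` straight to the native `rwkv_eval_sequence_in_chunks` call, so the loop below is only a sketch of the idea, and `toy_eval_sequence` and `eval_in_chunks` are hypothetical names: a tiny per-channel decay recurrence stands in for the RWKV model. It shows that cutting the token list into slices and feeding each slice's output state into the next slice gives the same final state and logits as evaluating the whole list at once; only the logits of the last chunk are kept.

```python
from typing import Callable, List, Optional, Tuple

import numpy as np

# Hypothetical stand-in for the model: one decay factor per state channel.
_DECAY = np.array([0.5, 0.6, 0.7, 0.8], dtype=np.float32)

EvalFn = Callable[[List[int], Optional[np.ndarray]], Tuple[np.ndarray, np.ndarray]]


def toy_eval_sequence(tokens: List[int], state_in: Optional[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Toy replacement for eval_sequence: returns (logits, state) after the last token."""
    # A None state means "first pass", mirroring the wrapper's convention.
    state = np.zeros(_DECAY.shape, dtype=np.float32) if state_in is None else state_in.copy()
    for token in tokens:
        # Decay the old state and mix in the new token.
        state = _DECAY * state + np.float32(token)
    # Toy "logits" depend only on the final state.
    logits = state[::-1].copy()
    return logits, state


def eval_in_chunks(eval_sequence: EvalFn, tokens: List[int], state_in: Optional[np.ndarray],
                   chunk_size: int = 16) -> Tuple[np.ndarray, np.ndarray]:
    """Evaluate `tokens` in slices of `chunk_size`, threading the state between slices."""
    if chunk_size <= 0:
        raise ValueError('chunk_size must be positive')
    if not tokens:
        raise ValueError('tokens must not be empty')
    state = state_in
    logits = None
    for start in range(0, len(tokens), chunk_size):
        # Each slice resumes from the state produced by the previous slice.
        logits, state = eval_sequence(tokens[start:start + chunk_size], state)
    return logits, state


tokens = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3]
whole_logits, whole_state = toy_eval_sequence(tokens, None)
chunked_logits, chunked_state = eval_in_chunks(toy_eval_sequence, tokens, None, chunk_size=4)
print(np.allclose(whole_logits, chunked_logits), np.allclose(whole_state, chunked_state))  # True True
```

Because the state summarizes every token seen so far, the chunk size in this sketch only changes how much work each call does, not the result; with a real model the results may differ slightly due to floating-point rounding.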
-------------------------------------------------------------------------------- /rwkv/cuda/gemm_fp16_cublas.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #define CUBLAS_CHECK(condition) \ 10 | for (cublasStatus_t _cublas_check_status = (condition); \ 11 | _cublas_check_status != CUBLAS_STATUS_SUCCESS;) \ 12 | throw std::runtime_error("cuBLAS error " + \ 13 | std::to_string(_cublas_check_status) + " at " + \ 14 | std::to_string(__LINE__)); 15 | 16 | #define CUDA_CHECK(condition) \ 17 | for (cudaError_t _cuda_check_status = (condition); \ 18 | _cuda_check_status != cudaSuccess;) \ 19 | throw std::runtime_error( \ 20 | "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \ 21 | " at " + std::to_string(__LINE__)); 22 | 23 | /* 24 | NOTE: blas gemm is column-major by default, but we need row-major output. 25 | The data of row-major, transposed matrix is exactly the same as the 26 | column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T 27 | */ 28 | void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { 29 | const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); 30 | const auto cuda_data_type = CUDA_R_16F; 31 | const auto cuda_c_data_type = 32 | c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F; 33 | const auto compute_type = CUDA_R_32F; 34 | const float sp_alpha = 1.f; 35 | // swap a and b, and use CUBLAS_OP_N. see the notes above 36 | std::swap(a, b); 37 | const cublasOperation_t cublas_trans_a = CUBLAS_OP_N; 38 | const cublasOperation_t cublas_trans_b = CUBLAS_OP_N; 39 | // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap, 40 | // negative axis is used because of the existence of batch matmul. 
41 | const int m = a.size(-1); 42 | const int k = a.size(-2); 43 | const int n = b.size(-2); 44 | const int cublas_lda = m; 45 | const int cublas_ldb = k; 46 | const int cublas_ldc = m; 47 | cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle(); 48 | 49 | #if CUDA_VERSION >= 11000 50 | cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; 51 | #else 52 | cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP; 53 | #endif 54 | const float sp_beta = 0.f; 55 | if (a.sizes().size() == 2 && b.sizes().size() == 2) { 56 | CUBLAS_CHECK(cublasGemmEx( 57 | cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha, 58 | a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type, 59 | cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, 60 | compute_type, algo)); 61 | } else { 62 | // batch matmul 63 | assert(a.sizes().size() == 3 && b.sizes().size() == 3); 64 | 65 | const long long int cublas_stride_a = m * k; 66 | const long long int cublas_stride_b = k * n; 67 | const long long int cublas_stride_c = m * n; 68 | CUBLAS_CHECK(cublasGemmStridedBatchedEx( 69 | cublas_handle, cublas_trans_a, cublas_trans_b, m, 70 | n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda, 71 | cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b, 72 | &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c, 73 | a.size(0), compute_type, algo)); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /rwkv/cuda/operators.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | #include 5 | #define MIN_VALUE (-1e38) 6 | typedef at::Half fp16; 7 | __half *cast(fp16 *ptr) { 8 | return reinterpret_cast<__half *>(ptr); 9 | } 10 | 11 | template 12 | __global__ void kernel_wkv_forward(const int B, const int T, const int C, 13 | const float *__restrict__ const _w, const float *__restrict__ const _u, const F 
*__restrict__ const _k, const F *__restrict__ const _v, 14 | F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) { 15 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 16 | const int _b = idx / C; 17 | const int _c = idx % C; 18 | const int _offset = _b * T * C + _c; 19 | const int _state_offset = _b * C + _c; 20 | 21 | float u = _u[_c]; 22 | float w = _w[_c]; 23 | const F *__restrict__ const k = _k + _offset; 24 | const F *__restrict__ const v = _v + _offset; 25 | F *__restrict__ const y = _y + _offset; 26 | 27 | float aa = _aa[_state_offset]; 28 | float bb = _bb[_state_offset]; 29 | float pp = _pp[_state_offset]; 30 | for (int i = 0; i < T; i++) { 31 | const int ii = i * C; 32 | const float kk = float(k[ii]); 33 | const float vv = float(v[ii]); 34 | float ww = u + kk; 35 | float p = max(pp, ww); 36 | float e1 = exp(pp - p); 37 | float e2 = exp(ww - p); 38 | y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2)); 39 | ww = w + pp; 40 | p = max(ww, kk); 41 | e1 = exp(ww - p); 42 | e2 = exp(kk - p); 43 | aa = e1 * aa + e2 * vv; 44 | bb = e1 * bb + e2; 45 | pp = p; 46 | } 47 | _aa[_state_offset] = aa; 48 | _bb[_state_offset] = bb; 49 | _pp[_state_offset] = pp; 50 | } 51 | 52 | template 53 | void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) { 54 | dim3 threadsPerBlock( min(C, 32) ); 55 | assert(B * C % threadsPerBlock.x == 0); 56 | dim3 numBlocks(B * C / threadsPerBlock.x); 57 | kernel_wkv_forward<<>>(B, T, C, w, u, k, v, y, aa, bb, pp); 58 | } 59 | 60 | template void cuda_wkv_forward( 61 | int B, int T, int C, 62 | float *w, float *u, fp16 *k, fp16 *v, fp16 *y, 63 | float *aa, float *bb, float *pp); 64 | template void cuda_wkv_forward( 65 | int B, int T, int C, 66 | float *w, float *u, float *k, float *v, float *y, 67 | float *aa, float *bb, float *pp); 68 | 69 | __global__ void kernel_mm_seq_fp32i8( 70 | const int B, const int N, const 
int M, 71 | const float *__restrict__ const x, const int x_stride, 72 | const uint8_t *__restrict__ const w, const int w_stride, 73 | const float *__restrict__ const mx, 74 | const float *__restrict__ const rx, 75 | const float *__restrict__ const my, 76 | const float *__restrict__ const ry, 77 | float *__restrict__ const y, const int y_stride) { 78 | 79 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 80 | const int k = blockIdx.y * blockDim.y + threadIdx.y; 81 | 82 | if (i < B && k < M) { 83 | float y_local = 0; 84 | for (int j = 0; j < N; ++j) { 85 | y_local += x[i * x_stride + j] * ( 86 | (float(w[j * w_stride + k]) + 0.5f) 87 | * rx[k] * ry[j] + mx[k] + my[j] 88 | ); 89 | } 90 | y[i * y_stride + k] = y_local; 91 | } 92 | } 93 | 94 | template 95 | void cuda_mm8_seq(int B, int N, int M, 96 | F *x, int x_stride, 97 | uint8_t *w, int w_stride, 98 | F *mx, F *rx, 99 | F *my, F *ry, 100 | F *y, int y_stride); 101 | 102 | template <> 103 | void cuda_mm8_seq(int B, int N, int M, 104 | float *x, int x_stride, 105 | uint8_t *w, int w_stride, 106 | float *mx, float *rx, 107 | float *my, float *ry, 108 | float *y, int y_stride) { 109 | dim3 blockSize(1, 128); 110 | dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y); 111 | kernel_mm_seq_fp32i8<<>>( 112 | B, N, M, x, x_stride, w, w_stride, 113 | mx, rx, my, ry, y, y_stride); 114 | } 115 | 116 | __global__ void kernel_mm_seq_fp16i8( 117 | const int B, const int N, const int M, 118 | const __half *__restrict__ const x, const int x_stride, 119 | const uint8_t *__restrict__ const w, const int w_stride, 120 | const __half *__restrict__ const mx, 121 | const __half *__restrict__ const rx, 122 | const __half *__restrict__ const my, 123 | const __half *__restrict__ const ry, 124 | __half *__restrict__ const y, const int y_stride) { 125 | 126 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 127 | const int k = blockIdx.y * blockDim.y + threadIdx.y; 128 | 129 | if (i < B && k < M) { 
130 | float y_local = 0; 131 | for (int j = 0; j < N; ++j) { 132 | y_local += __half2float(x[i * x_stride + j]) * ( 133 | (float(w[j * w_stride + k]) + 0.5f) 134 | * __half2float(rx[k]) * __half2float(ry[j]) 135 | + __half2float(mx[k]) + __half2float(my[j]) 136 | ); 137 | } 138 | y[i * y_stride + k] = __float2half(y_local); 139 | } 140 | } 141 | 142 | template <> 143 | void cuda_mm8_seq(int B, int N, int M, 144 | fp16 *x, int x_stride, 145 | uint8_t *w, int w_stride, 146 | fp16 *mx, fp16 *rx, 147 | fp16 *my, fp16 *ry, 148 | fp16 *y, int y_stride) { 149 | dim3 blockSize(1, 128); 150 | dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y); 151 | kernel_mm_seq_fp16i8<<>>( 152 | B, N, M, cast(x), x_stride, w, w_stride, 153 | cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride); 154 | } 155 | 156 | #define MM8_ONE_JSPLIT 24 157 | #define MM8_ONE_TILE 1024 158 | 159 | __global__ void kernel_mm_one_fp32i8( 160 | const int N, const int M, 161 | const float *__restrict__ const x, 162 | const uint8_t *__restrict__ const w, const int w_stride, 163 | const float *__restrict__ const mx, 164 | const float *__restrict__ const rx, 165 | const float *__restrict__ const my, 166 | const float *__restrict__ const ry, 167 | float *__restrict__ const y) { 168 | 169 | const int k = blockIdx.y * blockDim.y + threadIdx.y; 170 | const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); 171 | const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); 172 | 173 | if (k < M) { 174 | float y_local = 0; 175 | for (int j = j0; j < j1; ++j) { 176 | y_local += x[j] * ( 177 | (float(w[j * w_stride + k]) + 0.5f) 178 | * rx[k] * ry[j] + mx[k] + my[j] 179 | ); 180 | } 181 | atomicAdd(&y[k], y_local); 182 | } 183 | } 184 | 185 | template 186 | void cuda_mm8_one(int N, int M, 187 | F *x, 188 | uint8_t *w, int w_stride, 189 | F *mx, F *rx, 190 | F *my, F *ry, 191 | float *y); 192 | 193 | template <> 194 | void 
cuda_mm8_one(int N, int M, 195 | float *x, 196 | uint8_t *w, int w_stride, 197 | float *mx, float *rx, 198 | float *my, float *ry, 199 | float *y) { 200 | dim3 blockSize(1, MM8_ONE_TILE); 201 | dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y); 202 | kernel_mm_one_fp32i8<<>>( 203 | N, M, x, w, w_stride, 204 | mx, rx, my, ry, y); 205 | } 206 | 207 | __global__ void kernel_mm_one_fp16i8( 208 | const int N, const int M, 209 | const __half *__restrict__ const x, 210 | const uint8_t *__restrict__ const w, const int w_stride, 211 | const __half *__restrict__ const mx, 212 | const __half *__restrict__ const rx, 213 | const __half *__restrict__ const my, 214 | const __half *__restrict__ const ry, 215 | float *__restrict__ const y) { 216 | 217 | const int k = blockIdx.y * blockDim.y + threadIdx.y; 218 | const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); 219 | const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); 220 | 221 | if (k < M) { 222 | float y_local = 0; 223 | for (int j = j0; j < j1; ++j) { 224 | y_local += __half2float(x[j]) * ( 225 | (float(w[j * w_stride + k]) + 0.5f) 226 | * __half2float(rx[k]) * __half2float(ry[j]) 227 | + __half2float(mx[k]) + __half2float(my[j]) 228 | ); 229 | } 230 | atomicAdd(&y[k], y_local); 231 | } 232 | } 233 | 234 | template <> 235 | void cuda_mm8_one(int N, int M, 236 | fp16 *x, 237 | uint8_t *w, int w_stride, 238 | fp16 *mx, fp16 *rx, 239 | fp16 *my, fp16 *ry, 240 | float *y) { 241 | dim3 blockSize(1, MM8_ONE_TILE); 242 | dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y); 243 | kernel_mm_one_fp16i8<<>>( 244 | N, M, cast(x), w, w_stride, 245 | cast(mx), cast(rx), cast(my), cast(ry), y); 246 | } 247 | -------------------------------------------------------------------------------- /rwkv/cuda/rwkv5.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | 
typedef at::BFloat16 bf16; 5 | typedef at::Half fp16; 6 | typedef float fp32; 7 | 8 | template 9 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state, 10 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 11 | F *__restrict__ const _y) 12 | { 13 | const int b = blockIdx.x / H; 14 | const int h = blockIdx.x % H; 15 | const int i = threadIdx.x; 16 | _w += h*_N_; 17 | _u += h*_N_; 18 | _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! 19 | 20 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 21 | 22 | float state[_N_]; 23 | #pragma unroll 24 | for (int j = 0; j < _N_; j++) 25 | state[j] = _state[j]; 26 | 27 | __syncthreads(); 28 | u[i] = float(_u[i]); 29 | w[i] = _w[i]; 30 | __syncthreads(); 31 | 32 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 33 | { 34 | __syncthreads(); 35 | r[i] = float(_r[t]); 36 | k[i] = float(_k[t]); 37 | __syncthreads(); 38 | 39 | const float v = float(_v[t]); 40 | float y = 0; 41 | 42 | #pragma unroll 43 | for (int j = 0; j < _N_; j+=4) 44 | { 45 | const float4& r_ = (float4&)(r[j]); 46 | const float4& k_ = (float4&)(k[j]); 47 | const float4& w_ = (float4&)(w[j]); 48 | const float4& u_ = (float4&)(u[j]); 49 | float4& s = (float4&)(state[j]); 50 | float4 x; 51 | 52 | x.x = k_.x * v; 53 | x.y = k_.y * v; 54 | x.z = k_.z * v; 55 | x.w = k_.w * v; 56 | 57 | y += r_.x * (u_.x * x.x + s.x); 58 | y += r_.y * (u_.y * x.y + s.y); 59 | y += r_.z * (u_.z * x.z + s.z); 60 | y += r_.w * (u_.w * x.w + s.w); 61 | 62 | s.x = s.x * w_.x + x.x; 63 | s.y = s.y * w_.y + x.y; 64 | s.z = s.z * w_.z + x.z; 65 | s.w = s.w * w_.w + x.w; 66 | } 67 | _y[t] = F(y); 68 | } 69 | #pragma unroll 70 | for (int j = 0; j < _N_; j++) 71 | _state[j] = state[j]; 72 | } 73 | 74 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 
75 | { 76 | assert(H*_N_ == C); 77 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 78 | } 79 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y) 80 | { 81 | assert(H*_N_ == C); 82 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 83 | } 84 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y) 85 | { 86 | assert(H*_N_ == C); 87 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 88 | } 89 | -------------------------------------------------------------------------------- /rwkv/cuda/rwkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | #include 4 | typedef at::BFloat16 bf16; 5 | typedef at::Half fp16; 6 | typedef float fp32; 7 | 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y); 10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); 11 | 12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 13 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 14 | cuda_forward_bf16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 15 | } 16 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 17 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 18 | cuda_forward_fp16(B, T, C, H, 
state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 19 | } 20 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 21 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 22 | cuda_forward_fp32(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 23 | } 24 | 25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 26 | m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16"); 27 | m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16"); 28 | m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32"); 29 | } 30 | TORCH_LIBRARY(rwkv5, m) { 31 | m.def("forward_bf16", forward_bf16); 32 | m.def("forward_fp16", forward_fp16); 33 | m.def("forward_fp32", forward_fp32); 34 | } 35 | -------------------------------------------------------------------------------- /rwkv/cuda/rwkv6.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | typedef at::Half fp16; 6 | typedef float fp32; 7 | 8 | template 9 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state, 10 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 11 | F *__restrict__ const _y) 12 | { 13 | const int b = blockIdx.x / H; 14 | const int h = blockIdx.x % H; 15 | const int i = threadIdx.x; 16 | _u += h*_N_; 17 | _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! 
18 | 19 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 20 | 21 | float state[_N_]; 22 | #pragma unroll 23 | for (int j = 0; j < _N_; j++) 24 | state[j] = _state[j]; 25 | 26 | __syncthreads(); 27 | u[i] = float(_u[i]); 28 | __syncthreads(); 29 | 30 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 31 | { 32 | __syncthreads(); 33 | w[i] = _w[t]; 34 | r[i] = float(_r[t]); 35 | k[i] = float(_k[t]); 36 | __syncthreads(); 37 | 38 | const float v = float(_v[t]); 39 | float y = 0; 40 | 41 | #pragma unroll 42 | for (int j = 0; j < _N_; j+=4) 43 | { 44 | const float4& r_ = (float4&)(r[j]); 45 | const float4& k_ = (float4&)(k[j]); 46 | const float4& w_ = (float4&)(w[j]); 47 | const float4& u_ = (float4&)(u[j]); 48 | float4& s = (float4&)(state[j]); 49 | float4 x; 50 | 51 | x.x = k_.x * v; 52 | x.y = k_.y * v; 53 | x.z = k_.z * v; 54 | x.w = k_.w * v; 55 | 56 | y += r_.x * (u_.x * x.x + s.x); 57 | y += r_.y * (u_.y * x.y + s.y); 58 | y += r_.z * (u_.z * x.z + s.z); 59 | y += r_.w * (u_.w * x.w + s.w); 60 | 61 | s.x = s.x * w_.x + x.x; 62 | s.y = s.y * w_.y + x.y; 63 | s.z = s.z * w_.z + x.z; 64 | s.w = s.w * w_.w + x.w; 65 | } 66 | _y[t] = F(y); 67 | } 68 | #pragma unroll 69 | for (int j = 0; j < _N_; j++) 70 | _state[j] = state[j]; 71 | } 72 | 73 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 74 | { 75 | assert(H*_N_ == C); 76 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 77 | } 78 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y) 79 | { 80 | assert(H*_N_ == C); 81 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 82 | } 83 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y) 84 | { 85 | assert(H*_N_ == C); 86 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 87 | } 88 | 
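The per-head recurrence computed by `kernel_forward` in `rwkv6.cu` can be written as a NumPy reference, which may help when checking the kernel's output. This is a sketch derived by reading the kernel above, not code from the repository; `rwkv6_forward_ref` is a hypothetical name, and it assumes B == 1 (the kernel itself notes that its state indexing is wrong for B > 1) and the `(T, H, N)` layout implied by the index `b*T*C + h*_N_ + i`. Each CUDA thread `i` owns row `i` of the head's N×N state: for every key channel `j` it forms `x = k[j] * v[i]`, adds `r[j] * (u[j] * x + s[i][j])` to the output, and then updates `s[i][j] = s[i][j] * w[j] + x`. The kernel uses `w` directly as a multiplicative decay, so the sketch does too.

```python
import numpy as np


def rwkv6_forward_ref(r, k, v, w, u, state):
    """NumPy reference for one batch element (B == 1).

    r, k, v, w: arrays of shape (T, H, N); w is the per-step decay, used as-is.
    u: shape (H, N) bonus applied to the current token.
    state: shape (H, N, N); state[h, i, j] pairs value channel i with key channel j.
    Returns (y, new_state) with shapes (T, H, N) and (H, N, N).
    """
    T, H, N = r.shape
    s = state.astype(np.float32, copy=True)  # do not modify the caller's state
    y = np.zeros((T, H, N), dtype=np.float32)
    for t in range(T):
        for h in range(H):
            x = np.outer(v[t, h], k[t, h])  # x[i, j] = v_i * k_j
            # Output reads the old state plus the bonus-weighted current token.
            y[t, h] = (r[t, h] * (u[h] * x + s[h])).sum(axis=1)
            # Decay each key column, then add the current token's contribution.
            s[h] = s[h] * w[t, h] + x
    return y, s


rng = np.random.default_rng(0)
T, H, N = 3, 2, 4
r, k, v = (rng.standard_normal((T, H, N)).astype(np.float32) for _ in range(3))
w = rng.uniform(0.5, 1.0, (T, H, N)).astype(np.float32)
u = rng.standard_normal((H, N)).astype(np.float32)
y, new_state = rwkv6_forward_ref(r, k, v, w, u, np.zeros((H, N, N), dtype=np.float32))
print(y.shape, new_state.shape)  # (3, 2, 4) (2, 4, 4)
```

For the RWKV-5 kernel in `rwkv5.cu`, the visible difference is that `w` is loaded once per head before the time loop rather than once per timestep, so the same reference should apply with each head's `w` repeated over `T`.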
-------------------------------------------------------------------------------- /rwkv/cuda/rwkv6_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | #include 4 | typedef at::BFloat16 bf16; 5 | typedef at::Half fp16; 6 | typedef float fp32; 7 | 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y); 10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); 11 | 12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 13 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 14 | cuda_forward_bf16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 15 | } 16 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 17 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 18 | cuda_forward_fp16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 19 | } 20 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 21 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 22 | cuda_forward_fp32(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 23 | } 24 | 25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, 
m) { 26 | m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16"); 27 | m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16"); 28 | m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32"); 29 | } 30 | TORCH_LIBRARY(rwkv6, m) { 31 | m.def("forward_bf16", forward_bf16); 32 | m.def("forward_fp16", forward_fp16); 33 | m.def("forward_fp32", forward_fp32); 34 | } 35 | -------------------------------------------------------------------------------- /rwkv/cuda/rwkv7.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | 5 | typedef at::Half fp16; 6 | typedef at::BFloat16 bf16; 7 | typedef float fp32; 8 | 9 | template 10 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 11 | float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b, 12 | F *__restrict__ const _y) 13 | { 14 | const int e = blockIdx.x / H; 15 | const int h = blockIdx.x % H; 16 | const int i = threadIdx.x; 17 | _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! 
18 | 19 | float state[_N_]; 20 | #pragma unroll 21 | for (int j = 0; j < _N_; j++) 22 | state[j] = _state[j]; 23 | 24 | __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_]; 25 | 26 | for (int _t = 0; _t < T; _t++) 27 | { 28 | const int t = e*T*C + h*_N_ + i + _t * C; 29 | __syncthreads(); 30 | r[i] = float(_r[t]); 31 | w[i] = __expf(-__expf(float(_w[t]))); 32 | k[i] = float(_k[t]); 33 | a[i] = float(_a[t]); 34 | b[i] = float(_b[t]); 35 | __syncthreads(); 36 | 37 | float sa = 0; 38 | #pragma unroll 39 | for (int j = 0; j < _N_; j++) 40 | { 41 | sa += a[j] * state[j]; 42 | } 43 | 44 | float vv = float(_v[t]); 45 | float y = 0; 46 | #pragma unroll 47 | for (int j = 0; j < _N_; j++) 48 | { 49 | float& s = state[j]; 50 | s = s * w[j] + k[j] * vv + sa * b[j]; 51 | y += s * r[j]; 52 | } 53 | _y[t] = F(y); 54 | } 55 | #pragma unroll 56 | for (int j = 0; j < _N_; j++) 57 | _state[j] = state[j]; 58 | } 59 | 60 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y) 61 | { 62 | assert(H*_N_ == C); 63 | assert(B == 1); // only for B=1 64 | kernel_forward<<>>(B, T, C, H, state, r, w, k, v, a, b, y); 65 | } 66 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y) 67 | { 68 | assert(H*_N_ == C); 69 | assert(B == 1); // only for B=1 70 | kernel_forward<<>>(B, T, C, H, state, r, w, k, v, a, b, y); 71 | } 72 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y) 73 | { 74 | assert(H*_N_ == C); 75 | assert(B == 1); // only for B=1 76 | kernel_forward<<>>(B, T, C, H, state, r, w, k, v, a, b, y); 77 | } 78 | -------------------------------------------------------------------------------- /rwkv/cuda/rwkv7_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | 4 | typedef at::Half fp16; 
5 | typedef at::BFloat16 bf16; 6 | typedef float fp32; 7 | 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y); 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y); 10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y); 11 | 12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) { 13 | cuda_forward_bf16(B, T, C, H, state.data_ptr(), r.data_ptr(), w.data_ptr(), k.data_ptr(), v.data_ptr(), a.data_ptr(), b.data_ptr(), y.data_ptr()); 14 | } 15 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) { 16 | cuda_forward_fp16(B, T, C, H, state.data_ptr(), r.data_ptr(), w.data_ptr(), k.data_ptr(), v.data_ptr(), a.data_ptr(), b.data_ptr(), y.data_ptr()); 17 | } 18 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) { 19 | cuda_forward_fp32(B, T, C, H, state.data_ptr(), r.data_ptr(), w.data_ptr(), k.data_ptr(), v.data_ptr(), a.data_ptr(), b.data_ptr(), y.data_ptr()); 20 | } 21 | 22 | TORCH_LIBRARY(wkv7s, m) { 23 | m.def("forward_bf16", forward_bf16); 24 | m.def("forward_fp16", forward_fp16); 25 | m.def("forward_fp32", forward_fp32); 26 | } 27 | -------------------------------------------------------------------------------- /rwkv/cuda/wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | 
#include <iostream> 4 | #include <c10/cuda/CUDAGuard.h> 5 | 6 | typedef at::Half fp16; 7 | 8 | template <typename F> 9 | void cuda_wkv_forward(int B, int T, int C, 10 | float *w, float *u, F *k, F *v, F *y, 11 | float *aa, float *bb, float *pp); 12 | template <typename F> 13 | void cuda_mm8_seq(int B, int N, int M, 14 | F *x, int x_stride, 15 | uint8_t *w, int w_stride, 16 | F *mx, F *rx, 17 | F *my, F *ry, 18 | F *y, int y_stride); 19 | template <typename F> 20 | void cuda_mm8_one(int N, int M, 21 | F *x, 22 | uint8_t *w, int w_stride, 23 | F *mx, F *rx, 24 | F *my, F *ry, 25 | float *y); 26 | 27 | void wkv_forward(int64_t B, int64_t T, int64_t C, 28 | torch::Tensor &w, torch::Tensor &u, 29 | torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, 30 | torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) { 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); 32 | switch (k.scalar_type()) { 33 | case c10::ScalarType::Half: 34 | cuda_wkv_forward(B, T, C, 35 | w.data_ptr<float>(), u.data_ptr<float>(), 36 | k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(), 37 | aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>()); 38 | break; 39 | case c10::ScalarType::Float: 40 | cuda_wkv_forward(B, T, C, 41 | w.data_ptr<float>(), u.data_ptr<float>(), 42 | k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), 43 | aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>()); 44 | break; 45 | default: 46 | assert(false && "Only FP16 and FP32 are currently supported"); 47 | } 48 | } 49 | 50 | void mm8_seq(int64_t B, int64_t N, int64_t M, 51 | torch::Tensor &x, torch::Tensor &w, 52 | torch::Tensor &mx, torch::Tensor &rx, 53 | torch::Tensor &my, torch::Tensor &ry, 54 | torch::Tensor &y) { 55 | assert(x.stride(1) == 1); 56 | assert(w.stride(1) == 1); 57 | assert(mx.stride(0) == 1 && rx.stride(0) == 1); 58 | assert(my.stride(0) == 1 && ry.stride(0) == 1); 59 | assert(y.stride(1) == 1); 60 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); 61 | switch (x.scalar_type()) { 62 | case c10::ScalarType::Half: 63 | cuda_mm8_seq( 64 | B, N, M, 65 | x.data_ptr<fp16>(), x.stride(0), 66 | w.data_ptr<uint8_t>(), w.stride(0), 67 |
mx.data_ptr<fp16>(), rx.data_ptr<fp16>(), 68 | my.data_ptr<fp16>(), ry.data_ptr<fp16>(), 69 | y.data_ptr<fp16>(), y.stride(0)); 70 | break; 71 | case c10::ScalarType::Float: 72 | cuda_mm8_seq( 73 | B, N, M, 74 | x.data_ptr<float>(), x.stride(0), 75 | w.data_ptr<uint8_t>(), w.stride(0), 76 | mx.data_ptr<float>(), rx.data_ptr<float>(), 77 | my.data_ptr<float>(), ry.data_ptr<float>(), 78 | y.data_ptr<float>(), y.stride(0)); 79 | break; 80 | default: 81 | assert(false && "Only FP16 and FP32 are currently supported"); 82 | } 83 | } 84 | void mm8_one(int64_t N, int64_t M, 85 | torch::Tensor &x, torch::Tensor &w, 86 | torch::Tensor &mx, torch::Tensor &rx, 87 | torch::Tensor &my, torch::Tensor &ry, 88 | torch::Tensor &y) { 89 | assert(x.stride(0) == 1); 90 | assert(w.stride(1) == 1); 91 | assert(mx.stride(0) == 1 && rx.stride(0) == 1); 92 | assert(my.stride(0) == 1 && ry.stride(0) == 1); 93 | assert(y.stride(0) == 1); 94 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); 95 | switch (x.scalar_type()) { 96 | case c10::ScalarType::Half: 97 | cuda_mm8_one( 98 | N, M, 99 | x.data_ptr<fp16>(), 100 | w.data_ptr<uint8_t>(), w.stride(0), 101 | mx.data_ptr<fp16>(), rx.data_ptr<fp16>(), 102 | my.data_ptr<fp16>(), ry.data_ptr<fp16>(), 103 | y.data_ptr<float>()); 104 | break; 105 | case c10::ScalarType::Float: 106 | cuda_mm8_one( 107 | N, M, 108 | x.data_ptr<float>(), 109 | w.data_ptr<uint8_t>(), w.stride(0), 110 | mx.data_ptr<float>(), rx.data_ptr<float>(), 111 | my.data_ptr<float>(), ry.data_ptr<float>(), 112 | y.data_ptr<float>()); 113 | break; 114 | default: 115 | assert(false && "Only FP16 and FP32 are currently supported"); 116 | } 117 | } 118 | 119 | using torch::Tensor; 120 | 121 | #ifndef DISABLE_CUBLAS_GEMM 122 | void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); 123 | #endif 124 | 125 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 126 | m.def("wkv_forward", &wkv_forward, "wkv forward"); 127 | m.def("mm8_seq", &mm8_seq, "mm8 seq"); 128 | m.def("mm8_one", &mm8_one, "mm8 one"); 129 | #ifndef DISABLE_CUBLAS_GEMM 130 | m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemm fp16 cublas"); 131 | #endif 132 | } 133 | 134 | TORCH_LIBRARY(rwkv, m) {
135 | m.def("wkv_forward", wkv_forward); 136 | m.def("mm8_seq", mm8_seq); 137 | m.def("mm8_one", mm8_one); 138 | #ifndef DISABLE_CUBLAS_GEMM 139 | m.def("gemm_fp16_cublas", gemm_fp16_cublas); 140 | #endif 141 | } 142 | -------------------------------------------------------------------------------- /rwkv/rwkv5.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/rwkv5.pyd -------------------------------------------------------------------------------- /rwkv/rwkv6.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/rwkv6.pyd -------------------------------------------------------------------------------- /rwkv/rwkv_tokenizer.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | 6 | class TRIE: 7 | __slots__ = tuple("ch,to,values,front".split(",")) 8 | to: list 9 | values: set 10 | 11 | def __init__(self, front=None, ch=None): 12 | self.ch = ch 13 | self.to = [None for ch in range(256)] 14 | self.values = set() 15 | self.front = front 16 | 17 | def __repr__(self): 18 | fr = self 19 | ret = [] 20 | while fr != None: 21 | if fr.ch != None: 22 | ret.append(fr.ch) 23 | fr = fr.front 24 | return "<TRIE %s %s>" % (ret[::-1], self.values) 25 | 26 | def add(self, key: bytes, idx: int = 0, val=None): 27 | if idx == len(key): 28 | if val is None: 29 | val = key 30 | self.values.add(val) 31 | return self 32 | ch = key[idx] 33 | if self.to[ch] is None: 34 | self.to[ch] = TRIE(front=self, ch=ch) 35 | return
self.to[ch].add(key, idx=idx + 1, val=val) 36 | 37 | def find_longest(self, key: bytes, idx: int = 0): 38 | u: TRIE = self 39 | ch: int = key[idx] 40 | 41 | while u.to[ch] is not None: 42 | u = u.to[ch] 43 | idx += 1 44 | if u.values: 45 | ret = idx, u, u.values 46 | if idx == len(key): 47 | break 48 | ch = key[idx] 49 | return ret 50 | 51 | 52 | class TRIE_TOKENIZER: 53 | def __init__(self, file_name): 54 | self.idx2token = {} 55 | sorted = [] # must be already sorted 56 | with open(file_name, "r", encoding="utf-8") as f: 57 | lines = f.readlines() 58 | for l in lines: 59 | idx = int(l[: l.index(" ")]) 60 | x = eval(l[l.index(" ") : l.rindex(" ")]) 61 | x = x.encode("utf-8") if isinstance(x, str) else x 62 | assert isinstance(x, bytes) 63 | assert len(x) == int(l[l.rindex(" ") :]) 64 | sorted += [x] 65 | self.idx2token[idx] = x 66 | 67 | self.token2idx = {} 68 | for k, v in self.idx2token.items(): 69 | self.token2idx[v] = int(k) 70 | 71 | self.root = TRIE() 72 | for t, i in self.token2idx.items(): 73 | _ = self.root.add(t, val=(t, i)) 74 | 75 | def encodeBytes(self, src: bytes): 76 | idx: int = 0 77 | tokens = [] 78 | while idx < len(src): 79 | _idx: int = idx 80 | idx, _, values = self.root.find_longest(src, idx) 81 | assert idx != _idx 82 | _, token = next(iter(values)) 83 | tokens.append(token) 84 | return tokens 85 | 86 | def decodeBytes(self, tokens): 87 | return b"".join(map(lambda i: self.idx2token[i], tokens)) 88 | 89 | def encode(self, src): 90 | return self.encodeBytes(src.encode("utf-8")) 91 | 92 | def decode(self, tokens): 93 | try: 94 | return self.decodeBytes(tokens).decode("utf-8") 95 | except: 96 | return "\ufffd" # bad utf-8 97 | 98 | def printTokens(self, tokens): 99 | for i in tokens: 100 | s = self.idx2token[i] 101 | try: 102 | s = s.decode("utf-8") 103 | except: 104 | pass 105 | print(f"{repr(s)}{i}", end=" ") 106 | print() 107 | -------------------------------------------------------------------------------- /rwkv/utils.py: 
-------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import os, sys 6 | import numpy as np 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | 11 | class PIPELINE_ARGS: 12 | def __init__( 13 | self, 14 | temperature=1.0, 15 | top_p=0.85, 16 | top_k=0, 17 | alpha_frequency=0.2, 18 | alpha_presence=0.2, 19 | alpha_decay=0.996, 20 | token_ban=[], 21 | token_stop=[], 22 | chunk_len=256, 23 | ): 24 | self.temperature = temperature 25 | self.top_p = top_p 26 | self.top_k = top_k 27 | self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3) 28 | self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3) 29 | self.alpha_decay = alpha_decay # gradually decay the penalty 30 | self.token_ban = token_ban # ban the generation of some tokens 31 | self.token_stop = token_stop # stop generation whenever you see any token here 32 | self.chunk_len = ( 33 | chunk_len # split input into chunks to save VRAM (shorter -> slower) 34 | ) 35 | 36 | 37 | class ABC_TOKENIZER: 38 | def __init__(self): 39 | self.pad_token_id = 0 40 | self.bos_token_id = 2 41 | self.eos_token_id = 3 42 | 43 | def encode(self, text): 44 | ids = [ord(c) for c in text] 45 | return ids 46 | 47 | def decode(self, ids): 48 | txt = "".join( 49 | chr(idx) if idx > self.eos_token_id else "" 50 | for idx in ids 51 | if idx != self.eos_token_id 52 | ) 53 | return txt 54 | 55 | 56 | class PIPELINE: 57 | def __init__(self, model, WORD_NAME: str): 58 | self.model = model 59 | 60 | if WORD_NAME == "cl100k_base": 61 | import tiktoken 62 | 63 | self.tokenizer = tiktoken.get_encoding(WORD_NAME) 64 | elif WORD_NAME == "rwkv_vocab_v20230424": 65 | sys.path.insert(0, 
os.path.dirname(os.path.abspath(__file__))) 66 | from rwkv_tokenizer import TRIE_TOKENIZER 67 | 68 | self.tokenizer = TRIE_TOKENIZER( 69 | os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt" 70 | ) 71 | elif WORD_NAME == "abc_tokenizer": 72 | self.tokenizer = ABC_TOKENIZER() 73 | else: 74 | if WORD_NAME.endswith(".txt"): 75 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) 76 | from rwkv_tokenizer import TRIE_TOKENIZER 77 | 78 | self.tokenizer = TRIE_TOKENIZER(WORD_NAME) 79 | else: 80 | from tokenizers import Tokenizer 81 | 82 | self.tokenizer = Tokenizer.from_file(WORD_NAME) 83 | 84 | def refine_context(self, context): 85 | context = context.strip().split("\n") 86 | for c in range(len(context)): 87 | context[c] = context[c].strip().strip("\u3000").strip("\r") 88 | context = list(filter(lambda c: c != "", context)) 89 | context = "\n" + ("\n".join(context)).strip() 90 | if context == "": 91 | context = "\n" 92 | return context 93 | 94 | def encode(self, x): 95 | if "Tokenizer" in str(type(self.tokenizer)): 96 | return self.tokenizer.encode(x).ids 97 | else: 98 | return self.tokenizer.encode(x) 99 | 100 | def decode(self, x): 101 | return self.tokenizer.decode(x) 102 | 103 | def np_softmax(self, x: np.ndarray, axis: int): 104 | x -= x.max(axis=axis, keepdims=True) 105 | e: np.ndarray = np.exp(x) 106 | return e / e.sum(axis=axis, keepdims=True) 107 | 108 | def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0): 109 | if type(logits) == list: 110 | logits = np.array(logits) 111 | np_logits = type(logits) == np.ndarray 112 | if np_logits: 113 | probs = self.np_softmax(logits, axis=-1) 114 | else: 115 | probs = F.softmax(logits.float(), dim=-1) 116 | top_k = int(top_k) 117 | # 'privateuseone' is the type of custom devices like `torch_directml.device()` 118 | if np_logits or probs.device.type in ["cpu", "privateuseone"]: 119 | if not np_logits: 120 | probs = probs.cpu().numpy() 121 | sorted_ids = np.argsort(probs) 122 | 
sorted_probs = probs[sorted_ids][::-1] 123 | cumulative_probs = np.cumsum(sorted_probs) 124 | cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)]) 125 | probs[probs < cutoff] = 0 126 | if top_k < len(probs) and top_k > 0: 127 | probs[sorted_ids[:-top_k]] = 0 128 | if temperature != 1.0: 129 | probs = probs ** (1.0 / temperature) 130 | probs = probs / np.sum(probs) 131 | out = np.random.choice(a=len(probs), p=probs) 132 | return int(out) 133 | else: 134 | sorted_ids = torch.argsort(probs) 135 | sorted_probs = probs[sorted_ids] 136 | sorted_probs = torch.flip(sorted_probs, dims=(0,)) 137 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() 138 | cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)]) 139 | probs[probs < cutoff] = 0 140 | if top_k < len(probs) and top_k > 0: 141 | probs[sorted_ids[:-top_k]] = 0 142 | if temperature != 1.0: 143 | probs = probs ** (1.0 / temperature) 144 | out = torch.multinomial(probs, num_samples=1)[0] 145 | return int(out) 146 | 147 | def generate( 148 | self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None 149 | ): 150 | all_tokens = [] 151 | out_last = 0 152 | out_str = "" 153 | occurrence = {} 154 | for i in range(token_count): 155 | # forward & adjust prob. 
156 | tokens = self.encode(ctx) if i == 0 else [token] 157 | while len(tokens) > 0: 158 | out, state = self.model.forward(tokens[: args.chunk_len], state) 159 | tokens = tokens[args.chunk_len :] 160 | 161 | for n in args.token_ban: 162 | out[n] = -float("inf") 163 | for n in occurrence: 164 | out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency 165 | 166 | # sampler 167 | token = self.sample_logits( 168 | out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k 169 | ) 170 | if token in args.token_stop: 171 | break 172 | all_tokens += [token] 173 | for xxx in occurrence: 174 | occurrence[xxx] *= args.alpha_decay 175 | 176 | ttt = self.decode([token]) 177 | www = 1 178 | if ttt in " \t0123456789": 179 | www = 0 180 | # elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】': 181 | # www = 0.5 182 | if token not in occurrence: 183 | occurrence[token] = www 184 | else: 185 | occurrence[token] += www 186 | # print(occurrence) # debug 187 | 188 | # output 189 | tmp = self.decode(all_tokens[out_last:]) 190 | if "\ufffd" not in tmp: # is valid utf-8 string? 191 | if callback: 192 | callback(tmp) 193 | out_str += tmp 194 | out_last = i + 1 195 | return out_str 196 | -------------------------------------------------------------------------------- /rwkv/webgpu/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Union 2 | 3 | try: 4 | import web_rwkv_py as wrp 5 | except ModuleNotFoundError: 6 | try: 7 | from . 
import web_rwkv_py as wrp 8 | except ImportError: 9 | raise ModuleNotFoundError( 10 | "web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py" 11 | ) 12 | 13 | 14 | class RWKV: 15 | def __init__(self, model_path: str, strategy: str = None): 16 | layer = ( 17 | int(s.lstrip("layer")) 18 | for s in strategy.split() 19 | for s in s.split(",") 20 | if s.startswith("layer") 21 | ) 22 | 23 | chunk_size = ( 24 | int(s.lstrip("chunk")) 25 | for s in strategy.split() 26 | for s in s.split(",") 27 | if s.startswith("chunk") 28 | ) 29 | self.token_chunk_size = next(chunk_size, 32) 30 | 31 | args = { 32 | "path": model_path, 33 | "quant": next(layer, 31) if "i8" in strategy else 0, 34 | "quant_nf4": next(layer, 26) if "i4" in strategy else 0, 35 | } 36 | self.model = wrp.Model(**args) 37 | self.info = self.model.info() 38 | self.w = {} # fake weight 39 | self.w["emb.weight"] = [0] * self.info.num_vocab 40 | self.version = str(self.info.version).lower() 41 | self.version = float(self.version.lower().replace("v", "")) 42 | 43 | def forward(self, tokens: List[int], state: Union[Any, None] = None): 44 | if state is None: 45 | self.model.clear_state() 46 | elif type(state).__name__ == "State_Cpu": 47 | self.model.load_state(state) 48 | logits = self.model.run(tokens, self.token_chunk_size) 49 | ret_state = "State_Gpu" 50 | return logits, ret_state 51 | -------------------------------------------------------------------------------- /rwkv/webgpu/web_rwkv_py.cp310-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/webgpu/web_rwkv_py.cp310-win_amd64.pyd -------------------------------------------------------------------------------- /rwkv/wkv7s.pyd: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/wkv7s.pyd -------------------------------------------------------------------------------- /rwkv/wkv_cuda.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/wkv_cuda.pyd -------------------------------------------------------------------------------- /static/config/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "persona": "你好,我是你的智能助理,我将提供专家般的全面回应,请随时提出任何问题,我将永远回答你。", 3 | "coordinate": { 4 | "x": 25, 5 | "y": 35, 6 | "isPC": true 7 | }, 8 | "characters": { 9 | "sense": 0.7, 10 | "pizzazz": 0.5 11 | }, 12 | "messages": [], 13 | "users": { 14 | "you": { 15 | "name": "you", 16 | "avatar": "/static/img/bilibini.png" 17 | }, 18 | "other": { 19 | "name": "MeowAI", 20 | "avatar": "/static/img/ai.png" 21 | } 22 | } 23 | } -------------------------------------------------------------------------------- /static/config/noke.json: -------------------------------------------------------------------------------- 1 | { 2 | "persona": "喵喵~主人您好呀!我是你的小猫咪,喵~", 3 | "coordinate": { 4 | "x": 97, 5 | "y": 47, 6 | "isPC": true 7 | }, 8 | "characters": { 9 | "sense": 0.95, 10 | "pizzazz": 47.35 11 | }, 12 | "messages": [], 13 | "users": { 14 | "you": { 15 | "name": "you", 16 | "avatar": "/static/img/bilibini.png" 17 | }, 18 | "other": { 19 | "name": "noke", 20 | "avatar": "/static/img/ai.png" 21 | } 22 | } 23 | } -------------------------------------------------------------------------------- /static/css/noticejs.css: -------------------------------------------------------------------------------- 1 | .noticejs-top{top:0;width:100%!important}.noticejs-top 
.item{border-radius:0!important;margin:0!important}.noticejs-topRight{top:10px;right:10px}.noticejs-topLeft{top:10px;left:10px}.noticejs-topCenter{top:10px;left:50%;transform:translate(-50%)}.noticejs-middleLeft,.noticejs-middleRight{right:10px;top:50%;transform:translateY(-50%)}.noticejs-middleLeft{left:10px}.noticejs-middleCenter{top:50%;left:50%;transform:translate(-50%,-50%)}.noticejs-bottom{bottom:0;width:100%!important}.noticejs-bottom .item{border-radius:0!important;margin:0!important}.noticejs-bottomRight{bottom:10px;right:10px}.noticejs-bottomLeft{bottom:10px;left:10px}.noticejs-bottomCenter{bottom:10px;left:50%;transform:translate(-50%)}.noticejs{font-family:Helvetica Neue,Helvetica,Arial,sans-serif}.noticejs .item{margin:0 0 10px;border-radius:3px;overflow:hidden}.noticejs .item .close{float:right;font-size:18px;font-weight:700;line-height:1;color:#fff;text-shadow:0 1px 0 #fff;opacity:1;margin-right:7px}.noticejs .item .close:hover{opacity:.5;color:#000}.noticejs .item a{color:#fff;border-bottom:1px dashed #fff}.noticejs .item a,.noticejs .item a:hover{text-decoration:none}.noticejs .success{background-color:#64ce83}.noticejs .success .noticejs-heading{background-color:#3da95c;color:#fff;padding:10px}.noticejs .success .noticejs-body{color:#fff;padding:10px}.noticejs .success .noticejs-body:hover{visibility:visible!important}.noticejs .success .noticejs-content{visibility:visible}.noticejs .info{background-color:#3ea2ff}.noticejs .info .noticejs-heading{background-color:#067cea;color:#fff;padding:10px}.noticejs .info .noticejs-body{color:#fff;padding:10px}.noticejs .info .noticejs-body:hover{visibility:visible!important}.noticejs .info .noticejs-content{visibility:visible}.noticejs .warning{background-color:#ff7f48}.noticejs .warning .noticejs-heading{background-color:#f44e06;color:#fff;padding:10px}.noticejs .warning .noticejs-body{color:#fff;padding:10px}.noticejs .warning .noticejs-body:hover{visibility:visible!important}.noticejs .warning 
.noticejs-content{visibility:visible}.noticejs .error{background-color:#e74c3c}.noticejs .error .noticejs-heading{background-color:#ba2c1d;color:#fff;padding:10px}.noticejs .error .noticejs-body{color:#fff;padding:10px}.noticejs .error .noticejs-body:hover{visibility:visible!important}.noticejs .error .noticejs-content{visibility:visible}.noticejs .progressbar{width:100%}.noticejs .progressbar .bar{width:1%;height:30px;background-color:#4caf50}.noticejs .success .noticejs-progressbar{width:100%;background-color:#64ce83;margin-top:-1px}.noticejs .success .noticejs-progressbar .noticejs-bar{width:100%;height:5px;background:#3da95c}.noticejs .info .noticejs-progressbar{width:100%;background-color:#3ea2ff;margin-top:-1px}.noticejs .info .noticejs-progressbar .noticejs-bar{width:100%;height:5px;background:#067cea}.noticejs .warning .noticejs-progressbar{width:100%;background-color:#ff7f48;margin-top:-1px}.noticejs .warning .noticejs-progressbar .noticejs-bar{width:100%;height:5px;background:#f44e06}.noticejs .error .noticejs-progressbar{width:100%;background-color:#e74c3c;margin-top:-1px}.noticejs .error .noticejs-progressbar .noticejs-bar{width:100%;height:5px;background:#ba2c1d}@keyframes noticejs-fadeOut{0%{opacity:1}to{opacity:0}}.noticejs-fadeOut{animation-name:noticejs-fadeOut}@keyframes noticejs-modal-in{to{opacity:.3}}@keyframes noticejs-modal-out{to{opacity:0}}.noticejs-rtl .noticejs-heading{direction:rtl}.noticejs-rtl .close{float:left!important;margin-left:7px;margin-right:0!important}.noticejs-rtl .noticejs-content{direction:rtl}.noticejs{position:fixed;z-index:10050}.noticejs ::-webkit-scrollbar{width:8px}.noticejs ::-webkit-scrollbar-button{width:8px;height:5px}.noticejs ::-webkit-scrollbar-track{border-radius:10px}.noticejs ::-webkit-scrollbar-thumb{background:hsla(0,0%,100%,.5);border-radius:10px}.noticejs 
::-webkit-scrollbar-thumb:hover{background:#fff}.noticejs-modal{position:fixed;width:100%;height:100%;background-color:#000;z-index:10000;opacity:.3;left:0;top:0}.noticejs-modal-open{opacity:0;animation:noticejs-modal-in .3s ease-out}.noticejs-modal-close{animation:noticejs-modal-out .3s ease-out;animation-fill-mode:forwards} -------------------------------------------------------------------------------- /static/css/styles.css: -------------------------------------------------------------------------------- 1 | body,* { 2 | margin: 0; 3 | padding: 0; 4 | } 5 | #app { 6 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 7 | margin: 0; 8 | padding: 0; 9 | background-color: #f8f8f8; 10 | display: flex; 11 | align-items: stretch; 12 | height: 100vh; 13 | width: 100%; 14 | } 15 | 16 | 17 | #chat-container { 18 | flex: 1; 19 | background-color: #fff; 20 | box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); 21 | display: flex; 22 | flex-direction: column; 23 | height: 100%; 24 | width: 100%; 25 | } 26 | 27 | #chat-header { 28 | background-color: #4caf50; 29 | color: #fff; 30 | padding: 10px; 31 | text-align: center; 32 | font-size: 18px; 33 | border-bottom: 1px solid #ccc; 34 | } 35 | 36 | #chat-messages { 37 | padding: 20px 20px 50px; 38 | flex: 1; 39 | overflow: hidden; 40 | max-height: 100%; 41 | max-width: 100%; 42 | overflow-y: scroll; 43 | } 44 | 45 | .status-button{ 46 | display: flex; 47 | flex-direction: column; 48 | position: absolute; 49 | bottom: 100px; 50 | right: 20px; 51 | height: 90px; 52 | flex-wrap: nowrap; 53 | button{ 54 | border-color: transparent; 55 | color: #fff; 56 | background-color: #4caf50; 57 | max-width: 65px; 58 | max-height: 65px; 59 | min-width: 40px; 60 | min-height: 40px; 61 | padding: 5px; 62 | border-radius: 10px; 63 | position: relative; 64 | margin-bottom: 5px; 65 | cursor: pointer; 66 | } 67 | } 68 | 69 | #more-button{ 70 | border-color: transparent; 71 | color: #fff; 72 | background-color: #4caf50; 73 | max-width: 65px; 74 | 
max-height: 65px; 75 | min-width: 40px; 76 | min-height: 40px; 77 | padding: 5px; 78 | border-radius: 10px; 79 | position: absolute; 80 | right: 25px; 81 | top: 0; 82 | cursor: pointer; 83 | } 84 | 85 | .message { 86 | margin-bottom: 15px; 87 | display: flex; 88 | flex-direction: column; 89 | align-items: flex-start; 90 | position: relative; 91 | width: 100%; 92 | .user { 93 | font-weight: bold; 94 | margin-right: 5px; 95 | } 96 | .avatar { 97 | width: 40px; 98 | position: absolute; 99 | top: 23px; 100 | border-radius: 5px; 101 | } 102 | .content { 103 | padding: 10px; 104 | border-radius: 10px; 105 | position: relative; 106 | left: 45px; 107 | max-width: calc(100% - 45px); 108 | word-wrap: break-word; 109 | } 110 | &.you { 111 | .content { 112 | background-color: #ffcc80; 113 | /* background-color:transparent; */ 114 | } 115 | .edit-mode textarea{ 116 | background-color: #ffcc80; 117 | } 118 | } 119 | &.other { 120 | .content { 121 | background-color: #7be383; 122 | /* background-color:transparent; */ 123 | } 124 | .edit-mode textarea{ 125 | background-color: #7be383; 126 | } 127 | } 128 | } 129 | 130 | .context-menu { 131 | /* display: none; */ 132 | position: absolute; 133 | background-color: #fff; 134 | border: 1px solid #ccc; 135 | box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); 136 | z-index: 1; 137 | border-radius: 5px; 138 | overflow: hidden; 139 | } 140 | 141 | .context-menu-item { 142 | padding: 8px; 143 | cursor: pointer; 144 | } 145 | .context-menu-item:hover{ 146 | background-color: #4caf50; 147 | color: #fff; 148 | } 149 | .edit-mode { 150 | /* border: 1px solid #ccc; */ 151 | display: flex; 152 | flex-direction: column; 153 | align-items: flex-start; 154 | position: relative; 155 | left: 45px; 156 | width: calc(100% - 45px); 157 | height: auto; 158 | } 159 | .edit-mode .edit-mode-btn{ 160 | display: flex; 161 | flex-direction: row; 162 | align-items: center; 163 | justify-content: center; 164 | position: relative; 165 | width: 100%; 166 | height: auto; 167 
| } 168 | .edit-mode .edit-mode-btn button{ 169 | align-items: center; 170 | border-color: transparent; 171 | border-radius: 0.5rem; 172 | border-width: 1px; 173 | display: inline-flex; 174 | font-size: .875rem; 175 | font-weight: 500; 176 | line-height: 1.25rem; 177 | padding: 0.5rem 0.75rem; 178 | pointer-events: auto; 179 | margin: 10px; 180 | } 181 | .edit-mode .edit-mode-btn .save{ 182 | background-color: rgba(16,163,127,1); 183 | color: rgba(255,255,255,1); 184 | } 185 | .edit-mode .edit-mode-btn .cancel{ 186 | border-color: rgba(0,0,0,.1); 187 | } 188 | .edit-mode textarea{ 189 | padding: 10px; 190 | border-radius: 10px; 191 | background-color: #7be383; 192 | font-family: auto; 193 | font-size: 1rem; 194 | resize: none; 195 | outline: none; 196 | display: inline-block; 197 | overflow-y: hidden; 198 | width: calc(100% - 20px); 199 | border:none; 200 | text-decoration: none; 201 | 202 | } 203 | #message-input { 204 | width: auto; 205 | padding: 10px; 206 | border: none; 207 | border-top: 1px solid #ccc; 208 | outline: none; 209 | } 210 | 211 | #send-button { 212 | width: 100%; 213 | padding: 10px; 214 | background-color: #4caf50; 215 | color: #fff; 216 | border: none; 217 | cursor: pointer; 218 | outline: none; 219 | border-top: 1px solid #ccc; 220 | border-radius: 0 0 8px 8px; 221 | } 222 | #send-button.disabled{ 223 | background-color: #ccc; 224 | } 225 | 226 | #more-container{ 227 | width: 100%; 228 | height: 100%; 229 | position: absolute; 230 | display: flex; 231 | align-content: center; 232 | justify-content: center; 233 | align-items: center; 234 | .bg{ 235 | background-color: #ffffff67; 236 | width: 100%; 237 | height: 100%; 238 | position: absolute; 239 | display: flex; 240 | backdrop-filter: blur(2px); 241 | } 242 | .popup{ 243 | z-index: 3; 244 | background-color:#fff; 245 | width: calc(100% - 75px); 246 | max-width: 700px; 247 | max-height: 100vh; 248 | border-radius: 26px; 249 | overflow: hidden; 250 | box-shadow: 0px 1px 17px #ccc; 251 | 
position: relative; 252 | .pop-more{ 253 | display: flex; 254 | flex-wrap: wrap; 255 | align-content: space-around; 256 | justify-content: space-between; 257 | align-items: stretch; 258 | .btn{ 259 | width: 50%; 260 | display: flex; 261 | flex-direction: row; 262 | flex-wrap: nowrap; 263 | align-content: center; 264 | justify-content: center; 265 | align-items: center; 266 | font-size: x-large; 267 | min-height: 80px; 268 | user-select: none; 269 | cursor: pointer; 270 | &.user{ 271 | color: rgb(94, 159, 161); 272 | /* transition: 1s all ease-in-out; */ 273 | &:hover{ 274 | background-color: rgb(94, 159, 161); 275 | color: #fff; 276 | } 277 | } 278 | &.persona{ 279 | color: #e69175; 280 | &:hover{ 281 | background-color: #e69175; 282 | color: #fff; 283 | } 284 | } 285 | &.download{ 286 | color: #326e34d9; 287 | &:hover{ 288 | background-color: #326e34d9; 289 | color: #fff; 290 | } 291 | } 292 | &.upload{ 293 | color: rgba(16,163,127,1); 294 | &:hover{ 295 | background-color: rgba(16,163,127,1); 296 | color: #fff; 297 | } 298 | } 299 | } 300 | } 301 | .pop-user{ 302 | min-height: 160px; 303 | display: flex; 304 | flex-direction: row; 305 | flex-wrap: wrap; 306 | align-content: space-around; 307 | justify-content: center; 308 | .user{ 309 | display: flex; 310 | flex-direction: column; 311 | justify-content: center; 312 | margin: 18px; 313 | .name{ 314 | line-height: 38px; 315 | font-size: 20px; 316 | display: flex; 317 | align-content: center; 318 | flex-wrap: nowrap; 319 | align-items: center; 320 | margin-top: 15px; 321 | & label{ 322 | min-width: 45px; 323 | text-align: end; 324 | margin-right: 10px; 325 | } 326 | & input{ 327 | width: 160px; 328 | height: 38px; 329 | font-size: large; 330 | letter-spacing: 0.15px; 331 | border: none; 332 | outline: none; 333 | background-color: #ecf0f3; 334 | transition: 0.25s ease; 335 | border-radius: 8px; 336 | text-align: center; 337 | } 338 | } 339 | .avatar{ 340 | line-height: 38px; 341 | font-size: 20px; 342 | display: 
flex; 343 | align-content: center; 344 | flex-wrap: nowrap; 345 | align-items: center; 346 | & label{ 347 | min-width: 45px; 348 | text-align: end; 349 | margin-right: 10px; 350 | } 351 | & img{ 352 | width: 40px; 353 | margin-left: calc(50% - 50px); 354 | border-radius: 5px; 355 | } 356 | } 357 | } 358 | } 359 | .pop-persona{ 360 | min-height: 160px; 361 | display: flex; 362 | flex-direction: column; 363 | flex-wrap: nowrap; 364 | align-items: center; 365 | justify-content: center; 366 | position: relative; 367 | .persona{ 368 | margin-bottom: 15px; 369 | display: flex; 370 | flex-direction: column; 371 | align-items: flex-start; 372 | position: relative; 373 | width: calc(100% - 40px); 374 | .user{ 375 | font-weight: bold; 376 | margin-right: 5px; 377 | } 378 | .avatar{ 379 | width: 40px; 380 | position: absolute; 381 | top: 23px; 382 | border-radius: 5px; 383 | } 384 | .content{ 385 | padding: 10px; 386 | border-radius: 10px; 387 | background-color: rgb(123, 227, 131); 388 | font-family: auto; 389 | font-size: 1rem; 390 | outline: none; 391 | display: inline-block; 392 | overflow-y: hidden; 393 | width: calc(100% - 90px); 394 | border: none; 395 | text-decoration: none; 396 | left: 45px; 397 | position: relative; 398 | height: 103px; 399 | } 400 | } 401 | .coordinate{ 402 | position: relative; 403 | width: 185px; 404 | height: 165px; 405 | padding: 0px; 406 | margin: 0px; 407 | 408 | #coordinate-system { 409 | position: absolute; 410 | width: 120px; 411 | height: 120px; 412 | border: 1px solid #ccc; 413 | margin: 30px; 414 | padding: 0px; 415 | border-radius: 10px; 416 | background-image: url(/static/img/coordinate.png); 417 | background-size: cover; 418 | top: 0; 419 | } 420 | 421 | #draggable-point { 422 | position: absolute; 423 | width: 20px; 424 | height: 20px; 425 | padding: 0px; 426 | margin: 0px; 427 | background-color: white; 428 | border-radius: 50%; 429 | box-shadow: 0 0 7px #469905; 430 | } 431 | #coordinate-bg{ 432 | position: relative; 433 | width: 
185px; 434 | height: 165px; 435 | margin: 0px; 436 | padding: 0px; 437 | user-select: none; 438 | .l{ 439 | position: absolute; 440 | font-size: 20px; 441 | width: 50px; 442 | top: 6px; 443 | left: 16px; 444 | color: aquamarine; 445 | } 446 | .r{ 447 | position: absolute; 448 | font-size: 20px; 449 | width: 40px; 450 | right: 20px; 451 | bottom: 6px; 452 | color: bisque; 453 | } 454 | .x{ 455 | display: flex; 456 | flex-wrap: nowrap; 457 | justify-content: space-between; 458 | align-items: center; 459 | flex-direction: row; 460 | height: 100%; 461 | } 462 | .y{ 463 | display: flex; 464 | flex-direction: column; 465 | flex-wrap: nowrap; 466 | justify-content: space-between; 467 | align-items: center; 468 | position: absolute; 469 | width: 100%; 470 | height: 100%; 471 | top: 0; 472 | } 473 | } 474 | } 475 | .explain{ 476 | display: flex; 477 | flex-direction: row; 478 | align-items: center; 479 | .mark{ 480 | font-size: small; 481 | color: #00000080; 482 | } 483 | } 484 | } 485 | .pop-extension{ 486 | height: 70vh; 487 | iframe{ 488 | width: 100%; 489 | height: 100%; 490 | } 491 | .back{ 492 | position: absolute; 493 | left: 10px; 494 | top: 10px; 495 | font-size: 1.5em; 496 | color: #000; 497 | } 498 | } 499 | } 500 | } 501 | 502 | 503 | 504 | input[type="checkbox"] + label::before { 505 | content: '\a0'; 506 | /* non-break space */ 507 | display: inline-block; 508 | vertical-align: .2em; 509 | width: 1em; 510 | height: 1em; 511 | margin-right: .2em; 512 | border-radius: .2em; 513 | background: silver; 514 | text-indent: .15em; 515 | line-height: .65; 516 | } 517 | input[type="checkbox"]:checked + label::before { 518 | content: '\2713'; 519 | background: rgb(19, 206, 102); 520 | /* background: yellowgreen; */ 521 | } 522 | input[type="checkbox"] { 523 | position: absolute; 524 | clip: rect(0,0,0,0); 525 | } 526 | input[type="checkbox"]:focus + label::before { 527 | box-shadow: 0 0 .1em .1em #58a; 528 | } 529 | 530 | input[type="checkbox"]:disabled + label::before { 
531 | background: gray; 532 | box-shadow: none; 533 | color: #555; 534 | } 535 | 536 | label{ 537 | font-size: 16px; 538 | font-family: monospace; 539 | font-weight: bolder; 540 | display: inline-block; 541 | user-select: none; 542 | } 543 | @media only screen and (max-width: 800px) { 544 | #app { 545 | flex-direction: column; 546 | } 547 | #more-container .popup .pop-persona .explain > p{ 548 | display: none; 549 | } 550 | #chat-container { 551 | max-width: 100%; 552 | } 553 | } 554 | @media only screen and (min-width: 1081px) { 555 | #chat-messages { 556 | padding: 20px 150px; 557 | } 558 | } -------------------------------------------------------------------------------- /static/fonts/FontAwesome.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/FontAwesome.otf -------------------------------------------------------------------------------- /static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- 
/static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /static/img/PC_Coordinate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/PC_Coordinate.png -------------------------------------------------------------------------------- /static/img/PC_Coordinate.wb: -------------------------------------------------------------------------------- 1 | {"type":"wb","version":2,"source":"file://","elements":[{"id":"L6fDoP5obNbgFGqjcRCuG","type":"arrow","x":480,"y":300,"width":400,"height":0,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"round","seed":572271130,"version":160,"versionNonce":474265242,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"points":[[0,0],[400,0]],"lastCommittedPoint":null,"startBinding":null,"endBinding":null,"startArrowhead":null,"endArrowhead":"arrow"},{"id":"5z9q-4mzmkDFwKX2nqCdJ","type":"arrow","x":681.2758381748386,"y":100.67762530921956,"width":1.0445368363056104,"height":399.3223746907804,"angle":6.279679857460607,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"round","seed":257727814,"version":222,"versionNonce":1081808902,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"points":[[0,0],[-1.0445368363056104,399.3223746907804]],"lastCommittedPoint":null,"startBinding":{"elementId":"3PKxT8UPTnbRp5Ru00F-Z","focus":-0.00243
0146503276014,"gap":14.677094934406512},"endBinding":null,"startArrowhead":null,"endArrowhead":"arrow"},{"id":"3PKxT8UPTnbRp5Ru00F-Z","type":"text","x":660,"y":60,"width":41,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":1022181850,"version":33,"versionNonce":1923491654,"isDeleted":false,"boundElements":[{"id":"5z9q-4mzmkDFwKX2nqCdJ","type":"arrow"}],"updated":1705718257591,"link":null,"text":"理性","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"理性"},{"id":"k15c12wy0aqNWA5j7MQFF","type":"text","x":660,"y":520,"width":41,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":1546230106,"version":24,"versionNonce":262625306,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"text":"感性","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"感性"},{"id":"GgNROi80SMPILCbXx0kEP","type":"text","x":420,"y":280,"width":41,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":289730118,"version":26,"versionNonce":590171782,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"text":"严肃","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"严肃"},{"id":"BNnHNrk1B4LzQwznG8tLq","type":"text","x":900,"y":280,"width":41,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1
,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":1733198490,"version":45,"versionNonce":1811498202,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"text":"活泼","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"活泼"},{"id":"OJStimQe66swL1JzMczgt","type":"text","x":460,"y":120,"width":81,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":2097748762,"version":42,"versionNonce":1184750682,"isDeleted":false,"boundElements":null,"updated":1705718283503,"link":null,"text":"钻牛角尖","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"钻牛角尖"},{"id":"JpyE0DsCSJ3XCfyoSOZnu","type":"text","x":840,"y":460,"width":81,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":1562129606,"version":19,"versionNonce":312730886,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"text":"放飞自我","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"放飞自我"}],"appState":{"gridSize":20,"viewBackgroundColor":"#ffffff"},"files":{}} -------------------------------------------------------------------------------- /static/img/ai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai.png -------------------------------------------------------------------------------- /static/img/ai2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai2.png -------------------------------------------------------------------------------- /static/img/ai3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai3.png -------------------------------------------------------------------------------- /static/img/ai4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai4.png -------------------------------------------------------------------------------- /static/img/ai42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai42.png -------------------------------------------------------------------------------- /static/img/bilibini.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/bilibini.png -------------------------------------------------------------------------------- /static/img/coordinate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/coordinate.png -------------------------------------------------------------------------------- /static/img/nekomusume.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/nekomusume.png -------------------------------------------------------------------------------- 
/static/js/main.js: -------------------------------------------------------------------------------- 1 | const app = Vue.createApp({ 2 | data() { 3 | return { 4 | users: { 5 | you: { 6 | name: 'you', 7 | avatar: '/static/img/bilibini.png', 8 | }, 9 | other: { 10 | name: 'MeowAI', 11 | avatar: '/static/img/ai.png', 12 | } 13 | }, 14 | coordinate: { x: 25, y: 35, isPC: true }, 15 | persona: '主人你好呀!我是你的可爱猫娘,喵~', 16 | messages: [], 17 | // Message shape: {type: 'you', content: 'message text', editMode: false, editedContent: 'text being edited'} 18 | socket: null, 19 | currentMessage: '', 20 | partialMessage: '', 21 | isDragging: false,// whether the coordinate point is currently being dragged 22 | showStopButton: false, 23 | showContextMenu: false, 24 | showMoreSetup: -1, 25 | contextMenuIndex: -1, 26 | contextMenuPosition: { x: 0, y: 0 }, 27 | extensionSRC:"/extension/", 28 | }; 29 | }, 30 | mounted() { 31 | this.socket = io.connect('http://' + document.domain + ':' + location.port); 32 | this.socket.on('chat', (data) => { 33 | if (this.partialMessage != '') { 34 | this.messages.pop(); 35 | } 36 | this.partialMessage += data; 37 | this.messages.push({ type: 'other', content: this.partialMessage, editMode: false, editedContent: '' }); 38 | 39 | const endOfMessageIndex = this.partialMessage.indexOf('\n\n'); 40 | if (endOfMessageIndex !== -1) { 41 | this.partialMessage = ''; 42 | this.scrollToBottom(); 43 | this.stopReply(); 44 | } 45 | }); 46 | this.socket.on('stop', (data) => { 47 | if (data == 'True' || data == true) { 48 | this.partialMessage = ''; 49 | this.showStopButton = false; 50 | } 51 | }); 52 | this.socket.on('emit', (data) => { 53 | data=JSON.parse(data) 54 | new NoticeJs({ 55 | type: data.code==0?'success':'error', 56 | text: data.msg, 57 | timeout:45, 58 | position: 'topLeft', 59 | }).show(); 60 | }); 61 | let that=this; 62 | window.addEventListener('message', function(event) { 63 | console.log(that.extensionSRC); 64 | that.extensionSRC=event.data; 65 | }); 66 | }, 67 | computed: { 68 | messagesInfo() { 69 | let messages=[]; 70 | for (let message of
this.messages) { 71 | messages.push({ 72 | "role": message.type=='you'?"User":"Assistant", 73 | "content": message.content, 74 | }) 75 | } 76 | return messages; 77 | }, 78 | characters() { 79 | if (!this.coordinate) { 80 | console.error("Coordinate is not defined!"); 81 | return {"sense":1.1,"pizzazz":0.5}; 82 | } 83 | 84 | let pizzazz = this.coordinate.x || 0.01;//top_p 85 | let sense = this.coordinate.y || 0.01;//temperature 86 | 87 | pizzazz=pizzazz<=50?pizzazz/50:(pizzazz-50)/5; 88 | sense=sense<=50?sense/50:(sense-50)/5; 89 | 90 | if (this.coordinate.isPC) { 91 | if (sense>=2&&pizzazz<=2){ 92 | pizzazz=pizzazz>1?pizzazz*0.2:pizzazz*0.5; 93 | }else if(pizzazz>=2&&sense<=2){ 94 | sense=sense>1?sense*0.2:sense*0.5; 95 | }else{ 96 | pizzazz=pizzazz>1?pizzazz*0.2:pizzazz*0.5; 97 | sense=sense>1?sense*0.2:sense*0.5; 98 | } 99 | sense=sense*0.5; 100 | pizzazz=pizzazz*0.5; 101 | } 102 | 103 | console.log(this.coordinate.isPC, "pizzazz(top_p):",pizzazz,"sense(temperature):",sense); 104 | return { sense, pizzazz }; 105 | }, 106 | }, 107 | methods: { 108 | sendMessage() { 109 | // Send the current message 110 | const message = this.currentMessage.trim(); 111 | if (message != '') { 112 | this.showStopButton = true; 113 | this.messages.push({ type: 'you', content: message, editMode: false, editedContent: '' }); 114 | 115 | this.currentMessage = ''; 116 | this.partialMessage = ''; 117 | this.scrollToBottom(); 118 | 119 | this.socket.emit('chat', this.messagesInfo); 120 | console.log(this.messagesInfo) 121 | } 122 | }, 123 | setCharacter() { 124 | // Set the persona and personality (temperature / top_p) 125 | this.socket.emit('character', { 126 | "persona": this.persona, 127 | "temperature": this.characters.sense, 128 | "top_p": this.characters.pizzazz, 129 | }); 130 | }, 131 | scrollToBottom() { 132 | // Scroll to the bottom of the chat messages 133 | const messagesContainer = document.getElementById('chat-messages'); 134 | messagesContainer.scrollTop = messagesContainer.scrollHeight; 135 | }, 136 | openExtension() { 137 | // Open the extensions page 138 | window.open('/extension','_blank') 139 |
}, 140 | stopReply() { 141 | // Stop generating the reply 142 | this.socket.emit('stop', true); 143 | }, 144 | resetReply() { 145 | // Regenerate the last reply 146 | console.log(this.extensionSRC) 147 | this.showStopButton = true; 148 | this.messages.pop(); 149 | this.socket.emit('chat', this.messagesInfo); 150 | console.log(this.messagesInfo) 151 | }, 152 | cutParagraphs(text) { 153 | // Split text into paragraphs (one per line) 154 | return text.split('\n') 155 | }, 156 | showContextMenuFun(index, event) { 157 | // Show the context menu 158 | this.showContextMenu = true; 159 | this.contextMenuIndex = index; 160 | this.contextMenuPosition = { x: event.offsetX + 55, y: event.offsetY }; 161 | }, 162 | hideContextMenuFun() { 163 | // Hide the context menu 164 | this.showContextMenu = false; 165 | this.contextMenuIndex = -1; 166 | }, 167 | copyMessage(index) { 168 | // Copy a message to the clipboard 169 | this.showContextMenu = false; 170 | navigator.clipboard.writeText(this.messages[index].content); 171 | }, 172 | editMessage(index, event) { 173 | // Edit a message 174 | this.messages[index].editMode = true; 175 | this.messages[index].editedContent = this.messages[index].content; 176 | }, 177 | deleteMessage(index) { 178 | // Delete a message 179 | this.messages.splice(index, 1); 180 | }, 181 | confirmEdit(index) { 182 | // Confirm the edit 183 | this.messages[index].content = this.messages[index].editedContent; 184 | this.messages[index].editMode = false; 185 | }, 186 | cancelEdit(index) { 187 | // Cancel the edit 188 | this.messages[index].editMode = false; 189 | }, 190 | autoResize(event) { 191 | // Resize the edit box to fit its content 192 | event.target.style.height = (event.target.scrollHeight - 20) + "px"; 193 | }, 194 | clickAvatarBox(typeAvatar) { 195 | // Change an avatar (opens the file picker) 196 | if (typeAvatar == 'you') { 197 | this.$refs.youAvatarToUpload.click(); 198 | } else { 199 | this.$refs.otherAvatarToUpload.click(); 200 | } 201 | }, 202 | changeAvatarFile(typeAvatar) { 203 | // Load the uploaded avatar image 204 | let file = typeAvatar == 'you' ?
this.$refs.youAvatarToUpload.files[0] : this.$refs.otherAvatarToUpload.files[0]; 205 | let thisf = this; 206 | if (file && file instanceof Blob) { 207 | const reader = new FileReader(); 208 | reader.onloadend = function () { 209 | const base64DataUrl = reader.result; // Base64-encoded data URL of the image file 210 | typeAvatar == 'you' ? thisf.users.you.avatar = base64DataUrl : thisf.users.other.avatar = base64DataUrl; 211 | }; 212 | reader.readAsDataURL(file); // start reading the image file and converting it to Base64 213 | } else { 214 | console.error("无效的图片文件"); 215 | new NoticeJs({ 216 | type: 'error', 217 | text: "无效的图片文件", 218 | timeout:45, 219 | position: 'topLeft', 220 | }).show(); 221 | } 222 | }, 223 | clickDownloadConfig() { 224 | // Download the current configuration 225 | let config = { 226 | "persona": this.persona, 227 | "coordinate": this.coordinate, 228 | "characters": this.characters, 229 | "messages": this.messages, 230 | "users": this.users, 231 | } 232 | 233 | let file = new File([JSON.stringify(config)], 'meowAI.json', { type: 'text/plain' }); 234 | this.$refs.downloadConfig.download = 'meowAI.json'; 235 | this.$refs.downloadConfig.href = URL.createObjectURL(file); 236 | this.$refs.downloadConfig.click(); 237 | }, 238 | changeUploadConfig() { 239 | // Load an uploaded configuration file 240 | let file = this.$refs.uploadConfig.files[0]; 241 | let thisf = this; 242 | if (file && file instanceof Blob) { 243 | const reader = new FileReader(); 244 | reader.onloadend = function () { 245 | let configData = reader.result; // contents of the configuration file 246 | configData = JSON.parse(configData); 247 | thisf.persona = configData.persona; 248 | thisf.coordinate = configData.coordinate; 249 | thisf.messages = configData.messages; 250 | thisf.users = configData.users; 251 | thisf.showMoreSetup = -1; 252 | new NoticeJs({ 253 | type: 'success', 254 | text: "读取成功", 255 | timeout:10, 256 | position: 'topLeft', 257 | }).show(); // report success only after the file has been read and parsed 258 | }; 259 | reader.readAsText(file); // start reading the configuration file 260 | } else { 261 | console.error("无效的配置文件"); 262 | new NoticeJs({ 263 | type: 'error', 264 |
text: "无效的配置文件", 265 | timeout:45, 266 | position: 'topLeft', 267 | }).show(); 268 | } 269 | }, 270 | clickUploadConfig() { 271 | this.$refs.uploadConfig.click(); 272 | }, 273 | mousemoveCoordinate(e) { 274 | // Drag the point on the personality coordinate grid 275 | const coordinateSystem = this.$refs.coordinateSystem; 276 | const draggablePoint = this.$refs.draggablePoint; 277 | if (this.isDragging) { 278 | const x = e.clientX - coordinateSystem.getBoundingClientRect().left - 10; 279 | const y = e.clientY - coordinateSystem.getBoundingClientRect().top - 10; 280 | const maxX = coordinateSystem.clientWidth - draggablePoint.clientWidth; 281 | const maxY = coordinateSystem.clientHeight - draggablePoint.clientHeight; 282 | const clampedX = Math.min(Math.max(0, x), maxX); 283 | const clampedY = Math.min(Math.max(0, y), maxY); 284 | 285 | this.coordinate.y = clampedY; 286 | this.coordinate.x = clampedX; 287 | // console.log(`X: ${clampedX}, Y: ${clampedY}`); 288 | } 289 | }, 290 | }, 291 | }); 292 | 293 | app.mount('#app'); -------------------------------------------------------------------------------- /static/js/notice.js: -------------------------------------------------------------------------------- 1 | (function webpackUniversalModuleDefinition(root, factory) { 2 | if(typeof exports === 'object' && typeof module === 'object') 3 | module.exports = factory(); 4 | else if(typeof define === 'function' && define.amd) 5 | define("NoticeJs", [], factory); 6 | else if(typeof exports === 'object') 7 | exports["NoticeJs"] = factory(); 8 | else 9 | root["NoticeJs"] = factory(); 10 | })(typeof self !== 'undefined' ?
self : this, function() { 11 | return /******/ (function(modules) { // webpackBootstrap 12 | /******/ // The module cache 13 | /******/ var installedModules = {}; 14 | /******/ 15 | /******/ // The require function 16 | /******/ function __webpack_require__(moduleId) { 17 | /******/ 18 | /******/ // Check if module is in cache 19 | /******/ if(installedModules[moduleId]) { 20 | /******/ return installedModules[moduleId].exports; 21 | /******/ } 22 | /******/ // Create a new module (and put it into the cache) 23 | /******/ var module = installedModules[moduleId] = { 24 | /******/ i: moduleId, 25 | /******/ l: false, 26 | /******/ exports: {} 27 | /******/ }; 28 | /******/ 29 | /******/ // Execute the module function 30 | /******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__); 31 | /******/ 32 | /******/ // Flag the module as loaded 33 | /******/ module.l = true; 34 | /******/ 35 | /******/ // Return the exports of the module 36 | /******/ return module.exports; 37 | /******/ } 38 | /******/ 39 | /******/ 40 | /******/ // expose the modules object (__webpack_modules__) 41 | /******/ __webpack_require__.m = modules; 42 | /******/ 43 | /******/ // expose the module cache 44 | /******/ __webpack_require__.c = installedModules; 45 | /******/ 46 | /******/ // define getter function for harmony exports 47 | /******/ __webpack_require__.d = function(exports, name, getter) { 48 | /******/ if(!__webpack_require__.o(exports, name)) { 49 | /******/ Object.defineProperty(exports, name, { 50 | /******/ configurable: false, 51 | /******/ enumerable: true, 52 | /******/ get: getter 53 | /******/ }); 54 | /******/ } 55 | /******/ }; 56 | /******/ 57 | /******/ // getDefaultExport function for compatibility with non-harmony modules 58 | /******/ __webpack_require__.n = function(module) { 59 | /******/ var getter = module && module.__esModule ? 
60 | /******/ function getDefault() { return module['default']; } : 61 | /******/ function getModuleExports() { return module; }; 62 | /******/ __webpack_require__.d(getter, 'a', getter); 63 | /******/ return getter; 64 | /******/ }; 65 | /******/ 66 | /******/ // Object.prototype.hasOwnProperty.call 67 | /******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); }; 68 | /******/ 69 | /******/ // __webpack_public_path__ 70 | /******/ __webpack_require__.p = "dist/"; 71 | /******/ 72 | /******/ // Load entry module and return exports 73 | /******/ return __webpack_require__(__webpack_require__.s = 2); 74 | /******/ }) 75 | /************************************************************************/ 76 | /******/ ([ 77 | /* 0 */ 78 | /***/ (function(module, exports, __webpack_require__) { 79 | 80 | "use strict"; 81 | 82 | 83 | Object.defineProperty(exports, "__esModule", { 84 | value: true 85 | }); 86 | var noticeJsModalClassName = exports.noticeJsModalClassName = 'noticejs-modal'; 87 | var closeAnimation = exports.closeAnimation = 'noticejs-fadeOut'; 88 | 89 | var Defaults = exports.Defaults = { 90 | title: '', 91 | text: '', 92 | type: 'success', 93 | position: 'topRight', 94 | newestOnTop: false, 95 | timeout: 30, 96 | progressBar: true, 97 | closeWith: ['button'], 98 | animation: null, 99 | modal: false, 100 | width: 320, 101 | scroll: { 102 | maxHeightContent: 300, 103 | showOnHover: true 104 | }, 105 | rtl: false, 106 | callbacks: { 107 | beforeShow: [], 108 | onShow: [], 109 | afterShow: [], 110 | onClose: [], 111 | afterClose: [], 112 | onClick: [], 113 | onHover: [], 114 | onTemplate: [] 115 | } 116 | }; 117 | 118 | /***/ }), 119 | /* 1 */ 120 | /***/ (function(module, exports, __webpack_require__) { 121 | 122 | "use strict"; 123 | 124 | 125 | Object.defineProperty(exports, "__esModule", { 126 | value: true 127 | }); 128 | exports.appendNoticeJs = exports.addListener = exports.CloseItem = 
exports.AddModal = undefined; 129 | exports.getCallback = getCallback; 130 | 131 | var _api = __webpack_require__(0); 132 | 133 | var API = _interopRequireWildcard(_api); 134 | 135 | function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) newObj[key] = obj[key]; } } newObj.default = obj; return newObj; } } 136 | 137 | var options = API.Defaults; 138 | 139 | /** 140 | * @param {NoticeJs} ref 141 | * @param {string} eventName 142 | * @return {void} 143 | */ 144 | function getCallback(ref, eventName) { 145 | if (ref.callbacks.hasOwnProperty(eventName)) { 146 | ref.callbacks[eventName].forEach(function (cb) { 147 | if (typeof cb === 'function') { 148 | cb.apply(ref); 149 | } 150 | }); 151 | } 152 | } 153 | 154 | var AddModal = exports.AddModal = function AddModal() { 155 | if (document.getElementsByClassName(API.noticeJsModalClassName).length <= 0) { 156 | var element = document.createElement('div'); 157 | element.classList.add(API.noticeJsModalClassName); 158 | element.classList.add('noticejs-modal-open'); 159 | document.body.appendChild(element); 160 | // Remove class noticejs-modal-open 161 | setTimeout(function () { 162 | element.className = API.noticeJsModalClassName; 163 | }, 200); 164 | } 165 | }; 166 | 167 | var CloseItem = exports.CloseItem = function CloseItem(item) { 168 | getCallback(options, 'onClose'); 169 | 170 | // Set animation to close notification item 171 | if (options.animation !== null && options.animation.close !== null) { 172 | item.className += ' ' + options.animation.close; 173 | } 174 | setTimeout(function () { 175 | item.remove(); 176 | }, 200); 177 | 178 | // Close modal 179 | if (options.modal === true && document.querySelectorAll("[noticejs-modal='true']").length >= 1) { 180 | document.querySelector('.noticejs-modal').className += ' noticejs-modal-close'; 181 | setTimeout(function () { 182 | 
document.querySelector('.noticejs-modal').remove(); 183 | }, 500); 184 | } 185 | 186 | // Remove container 187 | var position = '.' + item.closest('.noticejs').className.replace('noticejs', '').trim(); 188 | setTimeout(function () { 189 | if (document.querySelectorAll(position + ' .item').length <= 0) { 190 | document.querySelector(position).remove(); 191 | } 192 | }, 500); 193 | }; 194 | 195 | var addListener = exports.addListener = function addListener(item) { 196 | // Add close button Event 197 | if (options.closeWith.includes('button')) { 198 | item.querySelector('.close').addEventListener('click', function () { 199 | CloseItem(item); 200 | }); 201 | } 202 | 203 | // Add close by click Event 204 | if (options.closeWith.includes('click')) { 205 | item.style.cursor = 'pointer'; 206 | item.addEventListener('click', function (e) { 207 | if (e.target.className !== 'close') { 208 | getCallback(options, 'onClick'); 209 | CloseItem(item); 210 | } 211 | }); 212 | } else { 213 | item.addEventListener('click', function (e) { 214 | if (e.target.className !== 'close') { 215 | getCallback(options, 'onClick'); 216 | } 217 | }); 218 | } 219 | 220 | item.addEventListener('mouseover', function () { 221 | getCallback(options, 'onHover'); 222 | }); 223 | }; 224 | 225 | var appendNoticeJs = exports.appendNoticeJs = function appendNoticeJs(noticeJsHeader, noticeJsBody, noticeJsProgressBar) { 226 | var target_class = '.noticejs-' + options.position; 227 | // Create NoticeJs item 228 | var noticeJsItem = document.createElement('div'); 229 | noticeJsItem.classList.add('item'); 230 | noticeJsItem.classList.add(options.type); 231 | if (options.rtl === true) { 232 | noticeJsItem.classList.add('noticejs-rtl'); 233 | } 234 | if (options.width !== '' && Number.isInteger(options.width)) { 235 | noticeJsItem.style.width = options.width + 'px'; 236 | } 237 | 238 | // Add Header 239 | if (noticeJsHeader && noticeJsHeader !== '') { 240 | noticeJsItem.appendChild(noticeJsHeader); 241 | } 242 | 243 
| // Add body 244 | noticeJsItem.appendChild(noticeJsBody); 245 | 246 | // Add progress bar 247 | if (noticeJsProgressBar && noticeJsProgressBar !== '') { 248 | noticeJsItem.appendChild(noticeJsProgressBar); 249 | } 250 | 251 | // Empty top and bottom container 252 | if (['top', 'bottom'].includes(options.position)) { 253 | document.querySelector(target_class).innerHTML = ''; 254 | } 255 | 256 | // Add open animation 257 | if (options.animation !== null && options.animation.open !== null) { 258 | noticeJsItem.className += ' ' + options.animation.open; 259 | } 260 | 261 | // Add Modal 262 | if (options.modal === true) { 263 | noticeJsItem.setAttribute('noticejs-modal', 'true'); 264 | AddModal(); 265 | } 266 | 267 | // Add Listener 268 | addListener(noticeJsItem, options.closeWith); 269 | 270 | getCallback(options, 'beforeShow'); 271 | getCallback(options, 'onShow'); 272 | if (options.newestOnTop === true) { 273 | document.querySelector(target_class).insertAdjacentElement('afterbegin', noticeJsItem); 274 | } else { 275 | document.querySelector(target_class).appendChild(noticeJsItem); 276 | } 277 | getCallback(options, 'afterShow'); 278 | 279 | return noticeJsItem; 280 | }; 281 | 282 | /***/ }), 283 | /* 2 */ 284 | /***/ (function(module, exports, __webpack_require__) { 285 | 286 | "use strict"; 287 | 288 | 289 | Object.defineProperty(exports, "__esModule", { 290 | value: true 291 | }); 292 | 293 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 294 | 295 | var 
_noticejs = __webpack_require__(3); 296 | 297 | var _noticejs2 = _interopRequireDefault(_noticejs); 298 | 299 | var _api = __webpack_require__(0); 300 | 301 | var API = _interopRequireWildcard(_api); 302 | 303 | var _components = __webpack_require__(4); 304 | 305 | var _helpers = __webpack_require__(1); 306 | 307 | var helper = _interopRequireWildcard(_helpers); 308 | 309 | function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) newObj[key] = obj[key]; } } newObj.default = obj; return newObj; } } 310 | 311 | function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; } 312 | 313 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 314 | 315 | var NoticeJs = function () { 316 | /** 317 | * @param {object} options 318 | * @returns {Noty} 319 | */ 320 | function NoticeJs() { 321 | var options = arguments.length > 0 && arguments[0] !== undefined ? 
arguments[0] : {}; 322 | 323 | _classCallCheck(this, NoticeJs); 324 | 325 | this.options = Object.assign(API.Defaults, options); 326 | this.component = new _components.Components(); 327 | 328 | this.on('beforeShow', this.options.callbacks.beforeShow); 329 | this.on('onShow', this.options.callbacks.onShow); 330 | this.on('afterShow', this.options.callbacks.afterShow); 331 | this.on('onClose', this.options.callbacks.onClose); 332 | this.on('afterClose', this.options.callbacks.afterClose); 333 | this.on('onClick', this.options.callbacks.onClick); 334 | this.on('onHover', this.options.callbacks.onHover); 335 | 336 | return this; 337 | } 338 | 339 | /** 340 | * @returns {HTMLElement} 341 | */ 342 | 343 | 344 | _createClass(NoticeJs, [{ 345 | key: 'show', 346 | value: function show() { 347 | var container = this.component.createContainer(); 348 | if (document.querySelector('.noticejs-' + this.options.position) === null) { 349 | document.body.appendChild(container); 350 | } 351 | 352 | var noticeJsHeader = void 0; 353 | var noticeJsBody = void 0; 354 | var noticeJsProgressBar = void 0; 355 | 356 | // Create NoticeJs header 357 | noticeJsHeader = this.component.createHeader(this.options.title, this.options.closeWith); 358 | 359 | // Create NoticeJs body 360 | noticeJsBody = this.component.createBody(this.options.text); 361 | 362 | // Create NoticeJs progressBar 363 | if (this.options.progressBar === true) { 364 | noticeJsProgressBar = this.component.createProgressBar(); 365 | } 366 | 367 | // Append NoticeJs 368 | var noticeJs = helper.appendNoticeJs(noticeJsHeader, noticeJsBody, noticeJsProgressBar); 369 | 370 | return noticeJs; 371 | } 372 | 373 | /** 374 | * @param {string} eventName 375 | * @param {function} cb 376 | * @return {NoticeJs} 377 | */ 378 | 379 | }, { 380 | key: 'on', 381 | value: function on(eventName) { 382 | var cb = arguments.length > 1 && arguments[1] !== undefined ? 
arguments[1] : function () {}; 383 | 384 | if (typeof cb === 'function' && this.options.callbacks.hasOwnProperty(eventName)) { 385 | this.options.callbacks[eventName].push(cb); 386 | } 387 | 388 | return this; 389 | } 390 | 391 | /** 392 | * @param {Object} options 393 | * @return {typeof NoticeJs} 394 | */ 395 | 396 | }], [{ 397 | key: 'overrideDefaults', 398 | value: function overrideDefaults(options) { 399 | this.options = Object.assign(API.Defaults, options); 400 | return this; 401 | } 402 | }]); 403 | 404 | return NoticeJs; 405 | }(); 406 | 407 | exports.default = NoticeJs; 408 | module.exports = exports['default']; 409 | 410 | /***/ }), 411 | /* 3 */ 412 | /***/ (function(module, exports) { 413 | 414 | // removed by extract-text-webpack-plugin 415 | 416 | /***/ }), 417 | /* 4 */ 418 | /***/ (function(module, exports, __webpack_require__) { 419 | 420 | "use strict"; 421 | 422 | 423 | Object.defineProperty(exports, "__esModule", { 424 | value: true 425 | }); 426 | exports.Components = undefined; 427 | 428 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }(); 429 | 430 | var _api = __webpack_require__(0); 431 | 432 | var API = _interopRequireWildcard(_api); 433 | 434 | var _helpers = __webpack_require__(1); 435 | 436 | var helper = _interopRequireWildcard(_helpers); 437 | 438 | function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if 
(Object.prototype.hasOwnProperty.call(obj, key)) newObj[key] = obj[key]; } } newObj.default = obj; return newObj; } } 439 | 440 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 441 | 442 | var options = API.Defaults; 443 | 444 | var Components = exports.Components = function () { 445 | function Components() { 446 | _classCallCheck(this, Components); 447 | } 448 | 449 | _createClass(Components, [{ 450 | key: 'createContainer', 451 | value: function createContainer() { 452 | var element_class = 'noticejs-' + options.position; 453 | var element = document.createElement('div'); 454 | element.classList.add('noticejs'); 455 | element.classList.add(element_class); 456 | 457 | return element; 458 | } 459 | }, { 460 | key: 'createHeader', 461 | value: function createHeader() { 462 | var element = void 0; 463 | if (options.title && options.title !== '') { 464 | element = document.createElement('div'); 465 | element.setAttribute('class', 'noticejs-heading'); 466 | element.textContent = options.title; 467 | } 468 | 469 | // Add close button 470 | if (options.closeWith.includes('button')) { 471 | var close = document.createElement('div'); 472 | close.setAttribute('class', 'close'); 473 | close.innerHTML = '×'; 474 | if (element) { 475 | element.appendChild(close); 476 | } else { 477 | element = close; 478 | } 479 | } 480 | 481 | return element; 482 | } 483 | }, { 484 | key: 'createBody', 485 | value: function createBody() { 486 | var element = document.createElement('div'); 487 | element.setAttribute('class', 'noticejs-body'); 488 | var content = document.createElement('div'); 489 | content.setAttribute('class', 'noticejs-content'); 490 | content.innerHTML = options.text; 491 | element.appendChild(content); 492 | 493 | if (options.scroll !== null && options.scroll.maxHeight !== '') { 494 | element.style.overflowY = 'auto'; 495 | element.style.maxHeight = 
options.scroll.maxHeight + 'px'; 496 | 497 | if (options.scroll.showOnHover === true) { 498 | element.style.visibility = 'hidden'; 499 | } 500 | } 501 | return element; 502 | } 503 | }, { 504 | key: 'createProgressBar', 505 | value: function createProgressBar() { 506 | var element = document.createElement('div'); 507 | element.setAttribute('class', 'noticejs-progressbar'); 508 | var bar = document.createElement('div'); 509 | bar.setAttribute('class', 'noticejs-bar'); 510 | element.appendChild(bar); 511 | 512 | // Progress bar animation 513 | if (options.progressBar === true && typeof options.timeout !== 'boolean' && options.timeout !== false) { 514 | var frame = function frame() { 515 | if (width <= 0) { 516 | clearInterval(id); 517 | 518 | var item = element.closest('div.item'); 519 | // Add close animation 520 | if (options.animation !== null && options.animation.close !== null) { 521 | 522 | // Remove open animation class 523 | item.className = item.className.replace(new RegExp('(?:^|\\s)' + options.animation.open + '(?:\\s|$)'), ' '); 524 | // Add close animation class 525 | item.className += ' ' + options.animation.close; 526 | 527 | // Close notification after 0.5s + timeout 528 | var close_time = parseInt(options.timeout) + 500; 529 | setTimeout(function () { 530 | helper.CloseItem(item); 531 | }, close_time); 532 | } else { 533 | // Close notification when progress bar completed 534 | helper.CloseItem(item); 535 | } 536 | } else { 537 | width--; 538 | bar.style.width = width + '%'; 539 | } 540 | }; 541 | 542 | var width = 100; 543 | var id = setInterval(frame, options.timeout); 544 | } 545 | 546 | return element; 547 | } 548 | }]); 549 | 550 | return Components; 551 | }(); 552 | 553 | /***/ }) 554 | /******/ ]); 555 | }); -------------------------------------------------------------------------------- /static/webfonts/fa-brands-400.eot: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-brands-400.eot -------------------------------------------------------------------------------- /static/webfonts/fa-brands-400.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-brands-400.ttf -------------------------------------------------------------------------------- /static/webfonts/fa-brands-400.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-brands-400.woff -------------------------------------------------------------------------------- /static/webfonts/fa-brands-400.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-brands-400.woff2 -------------------------------------------------------------------------------- /static/webfonts/fa-regular-400.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-regular-400.eot -------------------------------------------------------------------------------- /static/webfonts/fa-regular-400.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-regular-400.ttf -------------------------------------------------------------------------------- /static/webfonts/fa-regular-400.woff: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-regular-400.woff -------------------------------------------------------------------------------- /static/webfonts/fa-regular-400.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-regular-400.woff2 -------------------------------------------------------------------------------- /static/webfonts/fa-solid-900.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-solid-900.eot -------------------------------------------------------------------------------- /static/webfonts/fa-solid-900.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-solid-900.ttf -------------------------------------------------------------------------------- /static/webfonts/fa-solid-900.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-solid-900.woff -------------------------------------------------------------------------------- /static/webfonts/fa-solid-900.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-solid-900.woff2 -------------------------------------------------------------------------------- /templates/extension.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Extension Center MeowAI 6 | 7 | 105 | 106 | 107 |
108 |
109 |

Extension Center

110 |
111 |
112 |
>
113 |
{{info.name}}{{info.description}}
114 |
115 |
116 |
117 |
118 | 119 | 120 | 155 | 156 | -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Meow-AI 7 | 8 | 9 | 10 | 11 | 12 | 13 |
14 |
15 |
Meow-AI Chat
16 |
17 |
18 | {{ users[message.type].name }} 19 | 20 |
21 |

{{text}}

22 |
23 |
24 | 25 |
26 | 27 | 28 |
29 |
30 |
31 |
Copy
32 |
Edit
33 |
Delete
34 |
35 |
36 |
37 | 38 | 39 |
40 |
41 | 42 | 43 | 44 |
45 |
46 |
47 | 93 |
94 |
95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/test.py --------------------------------------------------------------------------------