├── LICENSE
├── README.md
├── config.json
├── config.json.example
├── extensions
│   ├── __init__.py
│   ├── mai_setupButler
│   │   ├── __init__.py
│   │   ├── requirements.txt
│   │   ├── setupButler.py
│   │   ├── static
│   │   │   ├── info.json
│   │   │   └── logo.png
│   │   └── templates
│   │       └── mai_setupButler
│   │           └── index.html
│   └── mai_wechat
│       ├── __init__.py
│       ├── requirements.txt
│       ├── server.py
│       ├── static
│       │   ├── info.json
│       │   └── logo.png
│       ├── templates
│       │   └── mai_wechat
│       │       └── index.html
│       └── wxauto
│           ├── __init__.py
│           ├── a.dll
│           ├── color.py
│           ├── elements.py
│           ├── errors.py
│           ├── languages.py
│           ├── uiautomation.py
│           ├── utils.py
│           └── wxauto.py
├── main.py
├── meowServer.py
├── models
│   └── EEADME.md
├── requirement.txt
├── rwkv
│   ├── 20B_tokenizer.json
│   ├── __pycache__
│   │   ├── model.cpython-39.pyc
│   │   ├── rwkv_tokenizer.cpython-39.pyc
│   │   └── utils.cpython-39.pyc
│   ├── cpp
│   │   ├── librwkv.dylib
│   │   ├── librwkv.so
│   │   ├── model.py
│   │   ├── rwkv.dll
│   │   ├── rwkv_cpp_model.py
│   │   └── rwkv_cpp_shared_library.py
│   ├── cuda
│   │   ├── gemm_fp16_cublas.cpp
│   │   ├── operators.cu
│   │   ├── rwkv5.cu
│   │   ├── rwkv5_op.cpp
│   │   ├── rwkv6.cu
│   │   ├── rwkv6_op.cpp
│   │   ├── rwkv7.cu
│   │   ├── rwkv7_op.cpp
│   │   └── wrapper.cpp
│   ├── model.py
│   ├── rwkv5.pyd
│   ├── rwkv6.pyd
│   ├── rwkv_tokenizer.py
│   ├── rwkv_vocab_v20230424.txt
│   ├── rwkv_vocab_v20230424_special_token.txt
│   ├── tokenizer-midi.json
│   ├── tokenizer-midipiano.json
│   ├── utils.py
│   ├── webgpu
│   │   ├── model.py
│   │   └── web_rwkv_py.cp310-win_amd64.pyd
│   ├── wkv7s.pyd
│   └── wkv_cuda.pyd
├── static
│   ├── config
│   │   ├── default.json
│   │   └── noke.json
│   ├── css
│   │   ├── all.min.css
│   │   ├── font-awesome.css
│   │   ├── font-awesome.min.css
│   │   ├── noticejs.css
│   │   └── styles.css
│   ├── fonts
│   │   ├── FontAwesome.otf
│   │   ├── fontawesome-webfont.eot
│   │   ├── fontawesome-webfont.svg
│   │   ├── fontawesome-webfont.ttf
│   │   ├── fontawesome-webfont.woff
│   │   └── fontawesome-webfont.woff2
│   ├── img
│   │   ├── PC_Coordinate.png
│   │   ├── PC_Coordinate.wb
│   │   ├── ai.png
│   │   ├── ai2.png
│   │   ├── ai3.png
│   │   ├── ai4.png
│   │   ├── ai42.png
│   │   ├── bilibini.png
│   │   ├── coordinate.png
│   │   └── nekomusume.png
│   ├── js
│   │   ├── axios.min.js
│   │   ├── main.js
│   │   ├── notice.js
│   │   ├── socket.io.min.js
│   │   └── vue.global.js
│   └── webfonts
│       ├── fa-brands-400.eot
│       ├── fa-brands-400.svg
│       ├── fa-brands-400.ttf
│       ├── fa-brands-400.woff
│       ├── fa-brands-400.woff2
│       ├── fa-regular-400.eot
│       ├── fa-regular-400.svg
│       ├── fa-regular-400.ttf
│       ├── fa-regular-400.woff
│       ├── fa-regular-400.woff2
│       ├── fa-solid-900.eot
│       ├── fa-solid-900.svg
│       ├── fa-solid-900.ttf
│       ├── fa-solid-900.woff
│       └── fa-solid-900.woff2
├── templates
│   ├── extension.html
│   └── index.html
└── test.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 bilibini
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Meow-AI
2 | "Meow-AI" is a lightweight local chat AI built on RWKV.
3 |
4 | ## Environment setup
5 | Recommended Python version: 3.9.18
6 | Required packages: numpy, tokenizers, prompt_toolkit, flask, flask-socketio, torch (subprocess is part of the Python standard library and does not need to be installed separately)
7 | `pip3 install numpy tokenizers prompt_toolkit flask flask-socketio torch`
8 |
9 | ## System requirements
10 | Minimum: 4 GB of RAM (CPU + 4 GB RAM can run the small 430M model)
11 | Maximum: no upper limit (up to the 14B model, which gives noticeably better chat quality)
12 |
13 | ## Setup
14 | Download a model suited to your hardware from the model site ([https://huggingface.co/BlinkDL](https://huggingface.co/BlinkDL)) ([mirror site](https://hf-mirror.com/BlinkDL))
15 | Put all downloaded files in the models folder, then edit "modelFile" in config.json, replacing 'RWKV-x070-World-0.1B-v2.8-20241210-ctx4096' with the name of the model you downloaded
16 | Run main.py and open http://127.0.0.1:5000 in a browser to start chatting
17 |
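The model-file step above can be sanity-checked before launching. A small hypothetical helper (not part of the repo) that resolves the `.pth` file implied by a config dict shaped like config.json:

```python
# Hypothetical sanity-check helper: the RWKV loader takes the model path
# without the .pth suffix, but the file on disk carries it.
from pathlib import Path

def resolve_model_path(config: dict) -> Path:
    return Path(config['modelsFolder']) / (config['modelFile'] + '.pth')

cfg = {'modelsFolder': 'models',
       'modelFile': 'RWKV-x070-World-0.1B-v2.8-20241210-ctx4096'}
print(resolve_model_path(cfg).name)  # → RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.pth
```

If the printed file does not exist under `models/`, the "modelFile" value in config.json does not match the downloaded model.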
18 | ## Changelog
19 | - 2024-01-05: Support manually stopping a conversation
20 | - 2024-01-19: Support editing conversations for simple manual fine-tuning
21 | - 2024-01-23: Support config import/export and custom AI personas
22 | - 2024-03-01: Move fully to the RWKV architecture for a smaller model footprint and lower memory/CPU usage
23 | - 2024-04-26: Optimize the overall architecture
24 | - 2024-07-08: Support extensions, including automatic WeChat chat
25 | - 2024-07-21: Ship a one-click run package
26 | - 2025-02-20: Support RWKV7 models
27 |
28 | ## Demos
29 | ### 1. Continuous conversation
30 | Demo environment: GPU + 8 GB, using the 'RWKV-4-World-CHNtuned-0.1B-v1-20230617-ctx4096' model
31 | 
32 | ### 2. Manual conversation editing
33 | ~Note: manually adding and editing more of your own preset dialogue helps a lot in steering later chats toward the behavior you want~
34 | 
35 | ### 3. Config import/export
36 | Conversations and configuration can be imported and exported
37 | 
38 | ### 4. Tuning the AI's persona
39 | MeowAI's persona can be customized to better match your preferences
40 | 
41 | ### 5. Extensions
42 | Custom extensions are supported; "automatic WeChat chat" is already available
43 | After clicking start, WeChat opens automatically and replies to chats on its own
44 | 
45 | ### 6. One-click run package
46 | No manual environment setup needed: the package bundles the whole runtime, just double-click to run
47 | (The package is too large to keep in the repo; download it from [releases](https://github.com/bilibini/Meow-AI/releases))
48 | 
49 |
50 |
51 | ## Roadmap
52 | 1. More extensions
53 | 2. To be continued…
54 |
55 |
56 | ## Star History
57 | [](https://star-history.com/#bilibini/Meow-AI&Date)
58 |
59 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "host": "127.0.0.1",
3 | "port": 5000,
4 | "modelsFolder": "models",
5 | "dataFolder": "data",
6 | "outputFolder": "output",
7 | "modelFile": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096",
8 | "strategy": "cuda fp32",
9 | "configFile": "default.json",
10 | "autoOpen": true
11 | }
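setupButler.py and main.py read this file with stdlib `json`. A minimal loading sketch using an inlined copy of the keys shown above (no repo files are touched):

```python
# Parse a config shaped like config.json above and pull out the keys
# the server needs; the JSON literal here mirrors the file's structure.
import json

raw = '''{
    "host": "127.0.0.1",
    "port": 5000,
    "modelsFolder": "models",
    "modelFile": "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096",
    "strategy": "cuda fp32",
    "autoOpen": true
}'''

config = json.loads(raw)
print(config['host'], config['port'])  # → 127.0.0.1 5000
```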
--------------------------------------------------------------------------------
/config.json.example:
--------------------------------------------------------------------------------
1 | {
2 | "host": "127.0.0.1",
3 | "port": 5000,
4 | "modelsFolder": "models",
5 | "dataFolder": "data",
6 | "configFolder": "config",
7 | "outputFolder": "output",
8 | "modelFile": "RWKV-x060-World-1B6-v2-20240208-ctx4096",
9 | "strategy": "cuda fp32",
10 | "configFile": "default.json",
11 | "autoOpen": true
12 | }
--------------------------------------------------------------------------------
/extensions/__init__.py:
--------------------------------------------------------------------------------
1 | from flask import Flask,json
2 | from flask_socketio import SocketIO
3 | from meowServer import MeowAI as MA
4 | from typing import List,Dict,Mapping,Union,Callable,Any
5 | from pathlib import Path
6 | from importlib.metadata import version
7 | import packaging.version as pv
8 | import importlib
9 | import subprocess
10 | import sys
11 |
12 | def get_installed_version(package: str) -> pv.Version:
13 |     try:
14 |         return pv.parse(version(package))
15 |     except Exception:
16 |         return pv.parse("0")
17 |
18 | def install_requirements(requirements: Path):
19 |     with open(requirements) as f:
20 |         for line in f:
21 |             line = line.strip()
22 |             if not line or line.startswith('#'):
23 |                 continue
24 |             if '==' in line:
25 |                 # "required" instead of "version" avoids shadowing the
26 |                 # importlib.metadata.version helper imported above
27 |                 package, required = line.split('==')
28 |                 if get_installed_version(package) != pv.parse(required):
29 |                     print(f'Installing {package}=={required}')
30 |                     subprocess.check_call([sys.executable, '-m', 'pip', 'install', line])
31 |             elif '>=' in line:
32 |                 package, required = line.split('>=')
33 |                 if get_installed_version(package) < pv.parse(required):
34 |                     print(f'Installing {package}>={required}')
35 |                     subprocess.check_call([sys.executable, '-m', 'pip', 'install', line])
36 |             elif get_installed_version(line) == pv.parse("0"):
37 |                 print(f'Installing {line}')
38 |                 subprocess.check_call([sys.executable, '-m', 'pip', 'install', line])
39 |             else:
40 |                 print(f'{line} is already installed')
39 |
42 | def load_extensions(meowAPP: Flask, meowSIO: SocketIO, meowAI: MA) -> List[Path]:
43 |     extensions_path = Path(__file__).parent
44 |     extension_dirs = [extension_dir for extension_dir in extensions_path.iterdir() if (extension_dir / '__init__.py').exists()]
45 |     print(extension_dirs)
46 |     extension_infoList = []
47 |     for extension_dir in extension_dirs:
48 |         requirementsPath = extension_dir.joinpath('requirements.txt')
49 |         try:
50 |             install_requirements(requirementsPath)
51 |         except Exception as e:
52 |             print(f'Error installing requirements for {extension_dir.name}: {e}')
53 |             continue
54 |
55 |         module_name = f'extensions.{extension_dir.name}'
56 |         module = importlib.import_module(module_name)
57 |         if hasattr(module, 'init_app'):
58 |             module.init_app(meowAPP, meowSIO, meowAI)
59 |         extension_info = extension_dir.joinpath('static/info.json')
60 |         if extension_info.exists():
61 |             with open(extension_info, 'r', encoding='UTF-8') as f:
62 |                 extension_infoList.append(json.load(f))
63 |
64 |     @meowAPP.route('/extension/infoList.json')
65 |     def get_extension_infoList():
66 |         print(extension_infoList)
67 |         return json.dumps(extension_infoList)
68 |
69 |     return extension_dirs
68 |
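install_requirements above distinguishes pinned (`==`), minimum (`>=`), and bare requirement lines. A stdlib-only re-sketch of just that parsing step (function name hypothetical, not part of the repo):

```python
# Split a requirements.txt line into (package, operator, version);
# returns None for blanks and comments, mirroring install_requirements above.
def parse_requirement(line: str):
    line = line.strip()
    if not line or line.startswith('#'):
        return None
    for op in ('==', '>='):
        if op in line:
            package, required = line.split(op, 1)
            return package, op, required
    return line, None, None

print(parse_requirement('flask>=2.0'))  # → ('flask', '>=', '2.0')
print(parse_requirement('pywin32'))     # → ('pywin32', None, None)
print(parse_requirement('# comment'))   # → None
```

Isolating the parsing this way also makes it easy to unit-test the version logic without actually invoking pip.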
--------------------------------------------------------------------------------
/extensions/mai_setupButler/__init__.py:
--------------------------------------------------------------------------------
1 | from flask import Flask
2 | from flask_socketio import SocketIO
3 | from meowServer import MeowAI as MA
4 | from .setupButler import SetupButler
5 |
6 | def init_app(meowAPP: Flask, meowSIO: SocketIO, meowAI: MA):
7 |     SetupButler(meowAPP, meowSIO, meowAI).run()
--------------------------------------------------------------------------------
/extensions/mai_setupButler/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/extensions/mai_setupButler/requirements.txt
--------------------------------------------------------------------------------
/extensions/mai_setupButler/setupButler.py:
--------------------------------------------------------------------------------
1 | from flask import Flask,Blueprint,render_template,request
2 | from typing import List,Dict,Mapping,Union,Callable,Any
3 | from flask_socketio import SocketIO
4 | from meowServer import MeowAI as MA
5 | from rwkv.model import RWKV
6 | from rwkv.utils import PIPELINE
7 | from pathlib import Path
8 | import json,gc,torch,time
9 |
10 |
11 | rootPath = Path(__file__).parent.parent.parent
12 | configPath=rootPath.joinpath('config.json')
13 |
14 | class SetupButler():
15 |     def __init__(self, meowAPP: Flask, meowSIO: SocketIO, meowAI: MA):
16 |         self.meowAI = meowAI
17 |         self.meowAPP = meowAPP
18 |         self.meowSIO = meowSIO
19 |         self.app = Blueprint('setupButler', __name__, static_folder='static', template_folder='templates', url_prefix='/extension/setupButler')
20 |         self.config = {}
21 |         with open(configPath, 'r') as f:
22 |             self.config = json.load(f)
23 |
24 |     def run(self):
25 |         modelsFolder = rootPath.joinpath(self.config['modelsFolder'])
26 |         print(modelsFolder)
27 |
28 |         @self.app.route('/config.json')
29 |         def get_config():
30 |             config = {
31 |                 'host': self.config['host'],
32 |                 'port': self.config['port'],
33 |                 'model': str(modelsFolder.joinpath(self.config['modelFile'])),
34 |                 'modelList': self.getModelsList(modelsFolder),
35 |                 'strategy': self.config['strategy'],
36 |                 'autoOpen': self.config['autoOpen']
37 |             }
38 |             return json.dumps(config)
39 |
40 |         @self.app.route('/setup', methods=['POST'])
41 |         def setup():
42 |             data = request.get_json()
43 |             self.config['host'] = data['host']
44 |             self.config['port'] = data['port']
45 |             self.config['autoOpen'] = data['autoOpen']
46 |             model = Path(data['model']).name
47 |             # compare against the stored values before overwriting them,
48 |             # otherwise the reload check below can never fire
49 |             if model != self.config['modelFile'] or data['strategy'] != self.config['strategy']:
50 |                 # free the old model before loading the new one
51 |                 del self.meowAI.model.model
52 |                 del self.meowAI.model
53 |                 gc.collect()
54 |                 if 'cuda' in self.config['strategy']:
55 |                     torch.cuda.empty_cache()
56 |                 self.config['modelFile'] = model
57 |                 self.config['strategy'] = data['strategy']
58 |                 rwkv_model = RWKV(model=str(data['model']), strategy=data['strategy'])
59 |                 self.meowAI.model = PIPELINE(rwkv_model, "rwkv_vocab_v20230424")
60 |                 with open(configPath, 'w') as f:
61 |                     f.write(json.dumps(self.config, indent=4))
62 |                 return json.dumps({'code': 0, 'msg': '模型重新载入成功'})
63 |             with open(configPath, 'w') as f:
64 |                 f.write(json.dumps(self.config, indent=4))
65 |             return json.dumps({'code': 0, 'msg': '设置成功,请重新启动服务'})
66 |
67 |         @self.app.route('/')
68 |         def index():
69 |             return render_template('mai_setupButler/index.html')
70 |
71 |         self.meowAPP.register_blueprint(self.app)
72 |
73 |     def getModelsList(self, modelsFolder: Path) -> List[Dict[str, str]]:
74 |         modelInfoList = []
75 |         for file in modelsFolder.rglob('*.pth'):
76 |             modelInfoList.append({
77 |                 "name": file.name.replace('.pth', ''),
78 |                 "path": str(file.absolute()).replace('.pth', '')
79 |             })
80 |         return modelInfoList
81 |
82 |
--------------------------------------------------------------------------------
/extensions/mai_setupButler/static/info.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "setup",
3 | "version": "1.0.0",
4 | "description": "模型选择设置",
5 | "url": "/extension/setupButler",
6 | "logo": "logo.png"
7 | }
--------------------------------------------------------------------------------
/extensions/mai_setupButler/static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/extensions/mai_setupButler/static/logo.png
--------------------------------------------------------------------------------
/extensions/mai_setupButler/templates/mai_setupButler/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | 设置_MeowAI
6 |
7 |
8 |
69 |
70 |
71 |
101 |
102 |
103 |
104 |
105 |
179 |
180 |
--------------------------------------------------------------------------------
/extensions/mai_wechat/__init__.py:
--------------------------------------------------------------------------------
1 | from flask import Flask
2 | from flask_socketio import SocketIO
3 | from meowServer import MeowAI as MA
4 | from .server import main
5 |
6 | def init_app(meowAPP: Flask, meowSIO: SocketIO, meowAI: MA):
7 |     main(meowAPP, meowSIO, meowAI)
--------------------------------------------------------------------------------
/extensions/mai_wechat/requirements.txt:
--------------------------------------------------------------------------------
1 | uiautomation
2 | Pillow
3 | pywin32
4 | psutil
5 | pyperclip
--------------------------------------------------------------------------------
/extensions/mai_wechat/server.py:
--------------------------------------------------------------------------------
1 | from .wxauto import WeChat,elements
2 | from flask import Flask,Blueprint,render_template,request
3 | from flask_socketio import SocketIO
4 | from threading import Thread, Event
5 | import json,time
6 | from meowServer import MeowAI as MA
7 |
8 | class WeChatServer():
9 |     def __init__(self, meowAPP: Flask, meowSIO: SocketIO, meowAI: MA):
10 |         self.meowAI = meowAI
11 |         self.meowAPP = meowAPP
12 |         self.meowSIO = meowSIO
13 |         self.stop_event = Event()
14 |         self.app = Blueprint('wechat', __name__, static_folder='static', template_folder='templates', url_prefix='/extension/wechat')
15 |         self.wx = None
16 |         self.prev_state = None
17 |         self.chats = []
18 |
19 |         @self.app.route('/')
20 |         def index():
21 |             return render_template('mai_wechat/index.html')
22 |
23 |         @self.app.route('/start', methods=['POST'])
24 |         def start():
25 |             try:
26 |                 self.wx = WeChat()
27 |                 self.stop_event.clear()
28 |                 data = request.get_json()
29 |                 thread = Thread(target=self.run, args=(data['name'], data['chatHistory']))
30 |                 thread.start()
31 |                 return json.dumps({'code': 0, 'msg': f'启动成功,获取到已登录窗口:{self.wx.nickname}'})
32 |             except Exception:
33 |                 return json.dumps({'code': 1, 'msg': '启动失败,请检查是否已经打开并登录微信'})
34 |
35 |         @self.app.route('/stop', methods=['POST', 'GET'])
36 |         def stop():
37 |             self.stop_event.set()
38 |             self.prev_state = None
39 |             self.wx = None
40 |             return json.dumps({'code': 0, 'msg': '微信自动对话停止成功'})
41 |
42 |         @self.app.route('/status', methods=['POST', 'GET'])
43 |         def status():
44 |             if self.wx:
45 |                 return json.dumps({'code': 1, 'msg': '已开始'})
46 |             return json.dumps({'code': 0, 'msg': '未开始'})
47 |
48 |         self.meowAPP.register_blueprint(self.app)
49 |         self.index = index
50 |
51 |     def reply(self, msg: str, chat: elements.ChatWnd) -> str:
52 |         '''
53 |         self.chats is a list of turns, e.g.:
54 |         [
55 |             {"role": "User", "content": "What is the meaning of life?"},
56 |             {"role": "Assistant", "content": "The meaning of life is to live a happy and fulfilling life."},
57 |             {"role": "User", "content": "How do cats call?"},
58 |         ]
59 |         '''
60 |         self.chats.append({"role": "User", "content": self.meowAI.purr(msg)})
61 |         if self.prev_state:
62 |             # a saved state already covers earlier turns, so only feed the last exchange
63 |             messages = self.chats[-2:]
64 |         else:
65 |             messages = self.chats
66 |         print(messages)
67 |         reply, self.prev_state = self.meowAI.chat(messages, self.prev_state, lambda x: print(x[0], end=''))
68 |         reply = reply.replace('\n\n', '\n').replace('。', '')
69 |         self.chats.append({"role": "Assistant", "content": reply})
70 |         return reply
71 |
72 |     def run(self, name: str, chat_history: bool = False):
73 |         try:
74 |             wx = self.wx
75 |             wx.ChatWith(who=name)
76 |             if chat_history:
77 |                 wx.LoadMoreMessage()
78 |                 for chat in wx.GetAllMessage():
79 |                     if chat[0] == 'SYS':
80 |                         continue
81 |                     elif chat[0] == 'Self':
82 |                         self.chats.append({"role": "Assistant", "content": chat[1]})
83 |                     else:
84 |                         self.chats.append({"role": "User", "content": chat[1]})
85 |                 print(self.chats)
86 |             wx.AddListenChat(who=name)
87 |             while not self.stop_event.is_set():
88 |                 chats = wx.GetListenMessage()
89 |                 for chat in filter(lambda x: x.who == name, chats):
90 |                     msgs = chats.get(chat)
91 |                     for msg in filter(lambda x: x.type == 'friend', msgs):
92 |                         reply = self.reply(msg.content, chat)
93 |                         chat.SendMsg(reply)
94 |                 time.sleep(0.1)
95 |         except Exception as e:
96 |             self.meowSIO.emit('emit', {'code': 1, 'msg': '微信监听失败:' + str(e)})
97 |             self.stop_event.set()
98 |             self.prev_state = None
99 |             self.meowSIO.emit('wechat_stop', '微信监听失败:' + str(e))
100 |             print('错误:', e)
101 |
102 |
103 | def main(meowAPP: Flask, meowSIO: SocketIO, meowAI: MA):
104 |     WeChatServer(meowAPP, meowSIO, meowAI)
--------------------------------------------------------------------------------
/extensions/mai_wechat/static/info.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "wechat",
3 | "version": "1.0.0",
4 | "description": "微信自动回复",
5 | "url": "/extension/wechat",
6 | "logo": "logo.png"
7 | }
--------------------------------------------------------------------------------
/extensions/mai_wechat/static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/extensions/mai_wechat/static/logo.png
--------------------------------------------------------------------------------
/extensions/mai_wechat/templates/mai_wechat/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | 微信自动回复_MeowAI
6 |
7 |
8 |
69 |
70 |
71 |
72 |
微信自动回复
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
温馨提示:
82 |
在“启动”前,请先登录电脑微信
83 |
84 |
85 |
86 |
87 |
88 |
89 |
173 |
174 |
--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/__init__.py:
--------------------------------------------------------------------------------
1 | from .wxauto import WeChat
2 |
3 | VERSION = '3.9.8.15'
4 |
--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/a.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/extensions/mai_wechat/wxauto/a.dll
--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/color.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import random
3 | import os
4 |
5 | os.system('')
6 |
7 | color_dict = {
8 | 'BLACK': '\x1b[30m',
9 | 'BLUE': '\x1b[34m',
10 | 'CYAN': '\x1b[36m',
11 | 'GREEN': '\x1b[32m',
12 | 'LIGHTBLACK_EX': '\x1b[90m',
13 | 'LIGHTBLUE_EX': '\x1b[94m',
14 | 'LIGHTCYAN_EX': '\x1b[96m',
15 | 'LIGHTGREEN_EX': '\x1b[92m',
16 | 'LIGHTMAGENTA_EX': '\x1b[95m',
17 | 'LIGHTRED_EX': '\x1b[91m',
18 | 'LIGHTWHITE_EX': '\x1b[97m',
19 | 'LIGHTYELLOW_EX': '\x1b[93m',
20 | 'MAGENTA': '\x1b[35m',
21 | 'RED': '\x1b[31m',
22 | 'WHITE': '\x1b[37m',
23 | 'YELLOW': '\x1b[33m'
24 | }
25 |
26 | color_reset = '\x1b[0m'
27 |
28 | class Print:
29 | @staticmethod
30 | def black(text, *args, **kwargs):
31 | print(color_dict['BLACK'] + text + color_reset, *args, **kwargs)
32 |
33 | @staticmethod
34 | def blue(text, *args, **kwargs):
35 | print(color_dict['BLUE'] + text + color_reset, *args, **kwargs)
36 |
37 | @staticmethod
38 | def cyan(text, *args, **kwargs):
39 | print(color_dict['CYAN'] + text + color_reset, *args, **kwargs)
40 |
41 | @staticmethod
42 | def green(text, *args, **kwargs):
43 | print(color_dict['GREEN'] + text + color_reset, *args, **kwargs)
44 |
45 | @staticmethod
46 | def lightblack(text, *args, **kwargs):
47 | print(color_dict['LIGHTBLACK_EX'] + text + color_reset, *args, **kwargs)
48 |
49 | @staticmethod
50 | def lightblue(text, *args, **kwargs):
51 | print(color_dict['LIGHTBLUE_EX'] + text + color_reset, *args, **kwargs)
52 |
53 | @staticmethod
54 | def lightcyan(text, *args, **kwargs):
55 | print(color_dict['LIGHTCYAN_EX'] + text + color_reset, *args, **kwargs)
56 |
57 | @staticmethod
58 | def lightgreen(text, *args, **kwargs):
59 | print(color_dict['LIGHTGREEN_EX'] + text + color_reset, *args, **kwargs)
60 |
61 | @staticmethod
62 | def lightmagenta(text, *args, **kwargs):
63 | print(color_dict['LIGHTMAGENTA_EX'] + text + color_reset, *args, **kwargs)
64 |
65 | @staticmethod
66 | def lightred(text, *args, **kwargs):
67 | print(color_dict['LIGHTRED_EX'] + text + color_reset, *args, **kwargs)
68 |
69 | @staticmethod
70 | def lightwhite(text, *args, **kwargs):
71 | print(color_dict['LIGHTWHITE_EX'] + text + color_reset, *args, **kwargs)
72 |
73 | @staticmethod
74 | def lightyellow(text, *args, **kwargs):
75 | print(color_dict['LIGHTYELLOW_EX'] + text + color_reset, *args, **kwargs)
76 |
77 | @staticmethod
78 | def magenta(text, *args, **kwargs):
79 | print(color_dict['MAGENTA'] + text + color_reset, *args, **kwargs)
80 |
81 | @staticmethod
82 | def red(text, *args, **kwargs):
83 | print(color_dict['RED'] + text + color_reset, *args, **kwargs)
84 |
85 | @staticmethod
86 | def white(text, *args, **kwargs):
87 | print(color_dict['WHITE'] + text + color_reset, *args, **kwargs)
88 |
89 | @staticmethod
90 | def yellow(text, *args, **kwargs):
91 | print(color_dict['YELLOW'] + text + color_reset, *args, **kwargs)
92 |
93 | @staticmethod
94 | def random(text, *args, **kwargs):
95 | print(random.choice(list(color_dict.values())) + text + color_reset, *args, **kwargs)
96 |
97 |
98 | class Input:
99 | @staticmethod
100 | def black(text, *args, **kwargs):
101 | print(color_dict['BLACK'] + text + color_reset, end='')
102 | result = input(*args, **kwargs)
103 | return result
104 |
105 | @staticmethod
106 | def blue(text, *args, **kwargs):
107 | print(color_dict['BLUE'] + text + color_reset, end='')
108 | result = input(*args, **kwargs)
109 | return result
110 |
111 | @staticmethod
112 | def cyan(text, *args, **kwargs):
113 | print(color_dict['CYAN'] + text + color_reset, end='')
114 | result = input(*args, **kwargs)
115 | return result
116 |
117 | @staticmethod
118 | def green(text, *args, **kwargs):
119 | print(color_dict['GREEN'] + text + color_reset, end='')
120 | result = input(*args, **kwargs)
121 | return result
122 |
123 | @staticmethod
124 | def lightblack(text, *args, **kwargs):
125 | print(color_dict['LIGHTBLACK_EX'] + text + color_reset, end='')
126 | result = input(*args, **kwargs)
127 | return result
128 |
129 | @staticmethod
130 | def lightblue(text, *args, **kwargs):
131 | print(color_dict['LIGHTBLUE_EX'] + text + color_reset, end='')
132 | result = input(*args, **kwargs)
133 | return result
134 |
135 | @staticmethod
136 | def lightcyan(text, *args, **kwargs):
137 | print(color_dict['LIGHTCYAN_EX'] + text + color_reset, end='')
138 | result = input(*args, **kwargs)
139 | return result
140 |
141 | @staticmethod
142 | def lightgreen(text, *args, **kwargs):
143 | print(color_dict['LIGHTGREEN_EX'] + text + color_reset, end='')
144 | result = input(*args, **kwargs)
145 | return result
146 |
147 | @staticmethod
148 | def lightmagenta(text, *args, **kwargs):
149 | print(color_dict['LIGHTMAGENTA_EX'] + text + color_reset, end='')
150 | result = input(*args, **kwargs)
151 | return result
152 |
153 | @staticmethod
154 | def lightred(text, *args, **kwargs):
155 | print(color_dict['LIGHTRED_EX'] + text + color_reset, end='')
156 | result = input(*args, **kwargs)
157 | return result
158 |
159 | @staticmethod
160 | def lightwhite(text, *args, **kwargs):
161 | print(color_dict['LIGHTWHITE_EX'] + text + color_reset, end='')
162 | result = input(*args, **kwargs)
163 | return result
164 |
165 | @staticmethod
166 | def lightyellow(text, *args, **kwargs):
167 | print(color_dict['LIGHTYELLOW_EX'] + text + color_reset, end='')
168 | result = input(*args, **kwargs)
169 | return result
170 |
171 | @staticmethod
172 | def magenta(text, *args, **kwargs):
173 | print(color_dict['MAGENTA'] + text + color_reset, end='')
174 | result = input(*args, **kwargs)
175 | return result
176 |
177 | @staticmethod
178 | def red(text, *args, **kwargs):
179 | print(color_dict['RED'] + text + color_reset, end='')
180 | result = input(*args, **kwargs)
181 | return result
182 |
183 | @staticmethod
184 | def white(text, *args, **kwargs):
185 | print(color_dict['WHITE'] + text + color_reset, end='')
186 | result = input(*args, **kwargs)
187 | return result
188 |
189 | @staticmethod
190 | def yellow(text, *args, **kwargs):
191 | print(color_dict['YELLOW'] + text + color_reset, end='')
192 | result = input(*args, **kwargs)
193 | return result
194 |
195 | @staticmethod
196 | def random(text, *args, **kwargs):
197 | print(random.choice(list(color_dict.values())) + text + color_reset, end='')
198 | result = input(*args, **kwargs)
199 | return result
200 |
201 |
202 | class Warnings:
203 | @staticmethod
204 | def black(text, *args, **kwargs):
205 | warnings.warn('\n' + color_dict['BLACK'] + text + color_reset, *args, **kwargs)
206 |
207 | @staticmethod
208 | def blue(text, *args, **kwargs):
209 | warnings.warn('\n' + color_dict['BLUE'] + text + color_reset, *args, **kwargs)
210 |
211 | @staticmethod
212 | def cyan(text, *args, **kwargs):
213 | warnings.warn('\n' + color_dict['CYAN'] + text + color_reset, *args, **kwargs)
214 |
215 | @staticmethod
216 | def green(text, *args, **kwargs):
217 | warnings.warn('\n' + color_dict['GREEN'] + text + color_reset, *args, **kwargs)
218 |
219 | @staticmethod
220 | def lightblack(text, *args, **kwargs):
221 | warnings.warn('\n' + color_dict['LIGHTBLACK_EX'] + text + color_reset, *args, **kwargs)
222 |
223 | @staticmethod
224 | def lightblue(text, *args, **kwargs):
225 | warnings.warn('\n' + color_dict['LIGHTBLUE_EX'] + text + color_reset, *args, **kwargs)
226 |
227 | @staticmethod
228 | def lightcyan(text, *args, **kwargs):
229 | warnings.warn('\n' + color_dict['LIGHTCYAN_EX'] + text + color_reset, *args, **kwargs)
230 |
231 | @staticmethod
232 | def lightgreen(text, *args, **kwargs):
233 | warnings.warn('\n' + color_dict['LIGHTGREEN_EX'] + text + color_reset, *args, **kwargs)
234 |
235 | @staticmethod
236 | def lightmagenta(text, *args, **kwargs):
237 | warnings.warn('\n' + color_dict['LIGHTMAGENTA_EX'] + text + color_reset, *args, **kwargs)
238 |
239 | @staticmethod
240 | def lightred(text, *args, **kwargs):
241 | warnings.warn('\n' + color_dict['LIGHTRED_EX'] + text + color_reset, *args, **kwargs)
242 |
243 | @staticmethod
244 | def lightwhite(text, *args, **kwargs):
245 | warnings.warn('\n' + color_dict['LIGHTWHITE_EX'] + text + color_reset, *args, **kwargs)
246 |
247 | @staticmethod
248 | def lightyellow(text, *args, **kwargs):
249 | warnings.warn('\n' + color_dict['LIGHTYELLOW_EX'] + text + color_reset, *args, **kwargs)
250 |
251 | @staticmethod
252 | def magenta(text, *args, **kwargs):
253 | warnings.warn('\n' + color_dict['MAGENTA'] + text + color_reset, *args, **kwargs)
254 |
--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/errors.py:
--------------------------------------------------------------------------------
1 |
2 | class TargetNotFoundError(Exception):
3 | pass
--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/languages.py:
--------------------------------------------------------------------------------
1 | '''
2 | The multi-language keyword set is not yet complete; pull requests that fill in missing entries are very welcome. Thanks!
3 | '''
4 |
5 | MAIN_LANGUAGE = {
6 | # 导航栏
7 | '导航': {'cn': '导航', 'cn_t': '導航', 'en': 'Navigation'},
8 | '聊天': {'cn': '聊天', 'cn_t': '聊天', 'en': 'Chats'},
9 | '通讯录': {'cn': '通讯录', 'cn_t': '通訊錄', 'en': 'Contacts'},
10 | '收藏': {'cn': '收藏', 'cn_t': '收藏', 'en': 'Favorites'},
11 | '聊天文件': {'cn': '聊天文件', 'cn_t': '聊天室檔案', 'en': 'Chat Files'},
12 | '朋友圈': {'cn': '朋友圈', 'cn_t': '朋友圈', 'en': 'Moments'},
13 | '小程序面板': {'cn': '小程序面板', 'cn_t': '小程式面板', 'en': 'Mini Programs Panel'},
14 | '手机': {'cn': '手机', 'cn_t': '手機', 'en': 'Phone'},
15 | '设置及其他': {'cn': '设置及其他', 'cn_t': '設定與其他', 'en': 'Settings and Others'},
16 |
17 | # 好友列表栏
18 | '搜索': {'cn': '搜索', 'cn_t': '搜尋', 'en': 'Search'},
19 | '发起群聊': {'cn': '发起群聊', 'cn_t': '建立群組', 'en': 'Start Group Chat'},
20 | '文件传输助手': {'cn': '文件传输助手', 'cn_t': '檔案傳輸', 'en': 'File Transfer'},
21 | '订阅号': {'cn': '订阅号', 'cn_t': '官方賬號', 'en': 'Subscriptions'},
22 | '消息': {'cn': '消息', 'cn_t': '消息', 'en': ''},
23 |
24 | # 右上角工具栏
25 | '置顶': {'cn': '置顶', 'cn_t': '置頂', 'en': 'Sticky on Top'},
26 | '最小化': {'cn': '最小化', 'cn_t': '最小化', 'en': 'Minimize'},
27 | '最大化': {'cn': '最大化', 'cn_t': '最大化', 'en': ''},
28 | '关闭': {'cn': '关闭', 'cn_t': '關閉', 'en': ''},
29 |
30 | # 聊天框
31 | '聊天信息': {'cn': '聊天信息', 'cn_t': '聊天資訊', 'en': 'Chat Info'},
32 | '表情': {'cn': '表情', 'cn_t': '貼圖', 'en': 'Sticker'},
33 | '发送文件': {'cn': '发送文件', 'cn_t': '傳送檔案', 'en': 'Send File'},
34 | '截图': {'cn': '截图', 'cn_t': '截圖', 'en': 'Screenshot'},
35 | '聊天记录': {'cn': '聊天记录', 'cn_t': '聊天記錄', 'en': 'Chat History'},
36 | '语音聊天': {'cn': '语音聊天', 'cn_t': '語音通話', 'en': 'Voice Call'},
37 | '视频聊天': {'cn': '视频聊天', 'cn_t': '視頻通話', 'en': 'Video Call'},
38 | '发送': {'cn': '发送', 'cn_t': '傳送', 'en': 'Send'},
39 | '输入': {'cn': '输入', 'cn_t': '輸入', 'en': 'Enter'},
40 |
41 | # 消息类型
42 | '链接': {'cn': '链接', 'cn_t': '鏈接', 'en': 'Link'},
43 | '视频': {'cn': '视频', 'cn_t': '視頻', 'en': 'Video'},
44 | '图片': {'cn': '图片', 'cn_t': '圖片', 'en': 'Photo'},
45 | '文件': {'cn': '文件', 'cn_t': '文件', 'en': 'File'},
46 | '': {'cn': '', 'cn_t': '', 'en': ''}}
47 |
48 |
49 | IMAGE_LANGUAGE = {
50 | '上一张': {'cn': '上一张', 'cn_t': '上一張', 'en': 'Previous'},
51 | '下一张': {'cn': '下一张', 'cn_t': '下一張', 'en': 'Next'},
52 | '预览': {'cn': '预览', 'cn_t': '預覽', 'en': 'Preview'},
53 | '放大': {'cn': '放大', 'cn_t': '放大', 'en': 'Zoom'},
54 | '缩小': {'cn': '缩小', 'cn_t': '縮小', 'en': 'Shrink'},
55 | '图片原始大小': {'cn': '图片原始大小', 'cn_t': '圖片原始大小', 'en': 'Original image size'},
56 | '旋转': {'cn': '旋转', 'cn_t': '旋轉', 'en': 'Rotate'},
57 | '编辑': {'cn': '编辑', 'cn_t': '編輯', 'en': 'Edit'},
58 | '翻译': {'cn': '翻译', 'cn_t': '翻譯', 'en': 'Translate'},
59 | '提取文字': {'cn': '提取文字', 'cn_t': '提取文字', 'en': 'Extract Text'},
60 | '识别图中二维码': {'cn': '识别图中二维码', 'cn_t': '識别圖中QR Code', 'en': 'Extract QR Code'},
61 | '另存为...': {'cn': '另存为...', 'cn_t': '另存爲...', 'en': 'Save as…'},
62 | '更多': {'cn': '更多', 'cn_t': '更多', 'en': 'More'},
63 | '最小化': {'cn': '最小化', 'cn_t': '最小化', 'en': 'Minimize'},
64 | '最大化': {'cn': '最大化', 'cn_t': '最大化', 'en': 'Maximize'},
65 | '关闭': {'cn': '关闭', 'cn_t': '關閉', 'en': 'Close'},
66 | '': {'cn': '', 'cn_t': '', 'en': ''}}
67 |
68 | FILE_LANGUAGE = {
69 | '最小化': {'cn': '最小化', 'cn_t': '最小化', 'en': 'Minimize'},
70 | '最大化': {'cn': '最大化', 'cn_t': '最大化', 'en': 'Maximize'},
71 | '关闭': {'cn': '关闭', 'cn_t': '關閉', 'en': 'Close'},
72 | '全部': {'cn': '全部', 'cn_t': '全部', 'en': 'All'},
73 | '最近使用': {'cn': '最近使用', 'cn_t': '最近使用', 'en': 'Recent'},
74 | '发送者': {'cn': '发送者', 'cn_t': '發送者', 'en': 'Sender'},
75 | '聊天': {'cn': '聊天', 'cn_t': '聊天', 'en': 'Chat'},
76 | '类型': {'cn': '类型', 'cn_t': '類型', 'en': 'Type'},
77 | '按最新时间': {'cn': '按最新时间', 'cn_t': '按最新時間', 'en': 'Sort by Newest'},
78 | '按最旧时间': {'cn': '按最旧时间', 'cn_t': '按最舊時間', 'en': 'Sort by Oldest'},
79 | '按从大到小': {'cn': '按从大到小', 'cn_t': '按從大到小', 'en': 'Sort by Largest'},
80 | '按从小到大': {'cn': '按从小到大', 'cn_t': '按從小到大', 'en': 'Sort by Smallest'},
81 | '': {'cn': '', 'cn_t': '', 'en': ''}
82 | }
83 |
84 | WARNING = {
85 | '版本不一致': {
86 | 'cn': '当前微信客户端版本为{},与当前库版本{}不一致,可能会导致部分功能无法正常使用,请注意判断',
87 | 'cn_t': '當前微信客戶端版本為{},與當前庫版本{}不一致,可能會導致部分功能無法正常使用,請注意判斷',
88 | 'en': 'The current WeChat client version is {}, which is inconsistent with the current library version {}, which may cause some functions to fail to work properly. Please pay attention to judgment'
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/extensions/mai_wechat/wxauto/utils.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timedelta
2 | from . import uiautomation as uia
3 | from PIL import ImageGrab
4 | import win32clipboard
5 | import win32process
6 | import win32gui
7 | import win32api
8 | import win32con
9 | import pyperclip
10 | import ctypes
11 | import psutil
12 | import shutil
13 | import winreg
14 | import time
15 | import os
16 | import re
17 |
18 | VERSION = "3.9.8.15"
19 |
20 | def set_cursor_pos(x, y):
21 | win32api.SetCursorPos((x, y))
22 |
23 | def Click(rect):
24 | x = (rect.left + rect.right) // 2
25 | y = (rect.top + rect.bottom) // 2
26 | set_cursor_pos(x, y)
27 | win32api.mouse_event(win32con.MOUSEEVENTF_LEFTDOWN, x, y, 0, 0)
28 | win32api.mouse_event(win32con.MOUSEEVENTF_LEFTUP, x, y, 0, 0)
29 |
30 | def GetPathByHwnd(hwnd):
31 | try:
32 | thread_id, process_id = win32process.GetWindowThreadProcessId(hwnd)
33 | process = psutil.Process(process_id)
34 | return process.exe()
35 | except Exception as e:
36 | print(f"Error: {e}")
37 | return None
38 |
39 | def GetVersionByPath(file_path):
40 | try:
41 | info = win32api.GetFileVersionInfo(file_path, '\\')
42 | version = "{}.{}.{}.{}".format(win32api.HIWORD(info['FileVersionMS']),
43 | win32api.LOWORD(info['FileVersionMS']),
44 | win32api.HIWORD(info['FileVersionLS']),
45 | win32api.LOWORD(info['FileVersionLS']))
46 | except:
47 | version = None
48 | return version
49 |
50 |
51 | def IsRedPixel(uicontrol):
52 | rect = uicontrol.BoundingRectangle
53 | bbox = (rect.left, rect.top, rect.right, rect.bottom)
54 | img = ImageGrab.grab(bbox=bbox, all_screens=True)
55 | return any(p[0] > p[1] and p[0] > p[2] for p in img.getdata())
56 |
57 | class DROPFILES(ctypes.Structure):
58 | _fields_ = [
59 | ("pFiles", ctypes.c_uint),
60 | ("x", ctypes.c_long),
61 | ("y", ctypes.c_long),
62 | ("fNC", ctypes.c_int),
63 | ("fWide", ctypes.c_bool),
64 | ]
65 |
66 | pDropFiles = DROPFILES()
67 | pDropFiles.pFiles = ctypes.sizeof(DROPFILES)
68 | pDropFiles.fWide = True
69 | matedata = bytes(pDropFiles)
70 |
71 | def SetClipboardText(text: str):
72 | pyperclip.copy(text)
73 | # if not isinstance(text, str):
74 | # raise TypeError(f"参数类型必须为str --> {text}")
75 | # t0 = time.time()
76 | # while True:
77 | # if time.time() - t0 > 10:
78 | # raise TimeoutError(f"设置剪贴板超时! --> {text}")
79 | # try:
80 | # win32clipboard.OpenClipboard()
81 | # win32clipboard.EmptyClipboard()
82 | # win32clipboard.SetClipboardData(win32con.CF_UNICODETEXT, text)
83 | # break
84 | # except:
85 | # pass
86 | # finally:
87 | # try:
88 | # win32clipboard.CloseClipboard()
89 | # except:
90 | # pass
91 |
92 | try:
93 | from anytree import Node, RenderTree
94 |
95 | def GetAllControl(ele):
96 | def findall(ele, n=0, node=None):
97 | nn = '\n'
98 | nodename = f"[{ele.ControlTypeName}](\"{ele.ClassName}\", \"{ele.Name.replace(nn, '')}\")"
99 | if not node:
100 | node1 = Node(nodename)
101 | else:
102 | node1 = Node(nodename, parent=node)
103 | eles = ele.GetChildren()
104 | for ele1 in eles:
105 | findall(ele1, n+1, node1)
106 | return node1
107 | tree = RenderTree(findall(ele))
108 | for pre, fill, node in tree:
109 | print(f"{pre}{node.name}")
110 | except:
111 | pass
112 |
113 | def SetClipboardFiles(paths):
114 | for file in paths:
115 | if not os.path.exists(file):
116 | raise FileNotFoundError(f"file ({file}) not exists!")
117 | files = ("\0".join(paths)).replace("/", "\\")
118 | data = files.encode("U16")[2:]+b"\0\0"
119 | t0 = time.time()
120 | while True:
121 | if time.time() - t0 > 10:
122 |             raise TimeoutError(f"Timed out setting clipboard files! --> {paths}")
123 | try:
124 | win32clipboard.OpenClipboard()
125 | win32clipboard.EmptyClipboard()
126 | win32clipboard.SetClipboardData(win32clipboard.CF_HDROP, matedata+data)
127 | break
128 | except:
129 | pass
130 | finally:
131 | try:
132 | win32clipboard.CloseClipboard()
133 | except:
134 | pass
135 |
136 | def PasteFile(folder):
137 | folder = os.path.realpath(folder)
138 | if not os.path.exists(folder):
139 | os.makedirs(folder)
140 |
141 | t0 = time.time()
142 | while True:
143 | if time.time() - t0 > 10:
144 |             raise TimeoutError("Timed out reading files from the clipboard!")
145 | try:
146 | win32clipboard.OpenClipboard()
147 | if win32clipboard.IsClipboardFormatAvailable(win32clipboard.CF_HDROP):
148 | files = win32clipboard.GetClipboardData(win32clipboard.CF_HDROP)
149 | for file in files:
150 | filename = os.path.basename(file)
151 | dest_file = os.path.join(folder, filename)
152 | shutil.copy2(file, dest_file)
153 | return True
154 | else:
155 |                     print("No files on the clipboard")
156 | return False
157 | except:
158 | pass
159 | finally:
160 | win32clipboard.CloseClipboard()
161 |
162 | def GetText(HWND):
163 | length = win32gui.SendMessage(HWND, win32con.WM_GETTEXTLENGTH)*2
164 | buffer = win32gui.PyMakeBuffer(length)
165 | win32api.SendMessage(HWND, win32con.WM_GETTEXT, length, buffer)
166 | address, length_ = win32gui.PyGetBufferAddressAndLen(buffer[:-1])
167 | text = win32gui.PyGetString(address, length_)[:int(length/2)]
168 | buffer.release()
169 | return text
170 |
171 | def GetAllWindowExs(HWND):
172 | if not HWND:
173 | return
174 | handles = []
175 | win32gui.EnumChildWindows(
176 | HWND, lambda hwnd, param: param.append([hwnd, win32gui.GetClassName(hwnd), GetText(hwnd)]), handles)
177 | return handles
178 |
179 | def FindWindow(classname=None, name=None) -> int:
180 | return win32gui.FindWindow(classname, name)
181 |
182 | def FindWinEx(HWND, classname=None, name=None) -> list:
183 | hwnds_classname = []
184 | hwnds_name = []
185 | def find_classname(hwnd, classname):
186 | classname_ = win32gui.GetClassName(hwnd)
187 | if classname_ == classname:
188 | if hwnd not in hwnds_classname:
189 | hwnds_classname.append(hwnd)
190 | def find_name(hwnd, name):
191 | name_ = GetText(hwnd)
192 | if name in name_:
193 | if hwnd not in hwnds_name:
194 | hwnds_name.append(hwnd)
195 | if classname:
196 | win32gui.EnumChildWindows(HWND, find_classname, classname)
197 | if name:
198 | win32gui.EnumChildWindows(HWND, find_name, name)
199 | if classname and name:
200 | hwnds = [hwnd for hwnd in hwnds_classname if hwnd in hwnds_name]
201 | else:
202 | hwnds = hwnds_classname + hwnds_name
203 | return hwnds
204 |
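FindWinEx above combines two child-window scans: when both `classname` and `name` are supplied, only handles matching both are returned; with a single filter, that filter's matches are returned directly. The combination rule can be isolated into a pure function over `(hwnd, classname, text)` tuples — `filter_handles` is a hypothetical helper for illustration, not part of this codebase:

```python
def filter_handles(windows, classname=None, name=None):
    # windows: iterable of (hwnd, classname, window_text) tuples.
    by_class = [h for h, c, t in windows if classname and c == classname]
    by_name = [h for h, c, t in windows if name and name in t]
    if classname and name:
        # Both filters given: intersect, as FindWinEx does.
        return [h for h in by_class if h in by_name]
    # Only one filter given: the other list is empty, so this is just its matches.
    return by_class + by_name

wins = [(1, 'EditWnd', 'foo'), (2, 'EditWnd', 'bar'), (3, 'BtnWnd', 'foo')]
print(filter_handles(wins, classname='EditWnd', name='foo'))  # [1]
```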
205 | def ClipboardFormats(unit=0, *units):
206 | units = list(units)
207 | win32clipboard.OpenClipboard()
208 | u = win32clipboard.EnumClipboardFormats(unit)
209 | win32clipboard.CloseClipboard()
210 | units.append(u)
211 | if u:
212 | units = ClipboardFormats(u, *units)
213 | return units
214 |
215 | def ReadClipboardData():
216 | Dict = {}
217 | for i in ClipboardFormats():
218 | if i == 0:
219 | continue
220 | win32clipboard.OpenClipboard()
221 | try:
222 | filenames = win32clipboard.GetClipboardData(i)
223 | win32clipboard.CloseClipboard()
224 | except:
225 | win32clipboard.CloseClipboard()
226 | raise ValueError
227 | Dict[str(i)] = filenames
228 | return Dict
229 |
230 | def ParseWeChatTime(time_str):
231 |     """
232 |     Convert a WeChat-style timestamp into 'YYYY-MM-DD HH:MM'.
233 | 
234 |     Args:
235 |         time_str: the time string to convert
236 | 
237 |     Returns:
238 |         The converted time string, or None if no known pattern matches.
239 |     """
240 |
241 | match = re.match(r'^(\d{1,2}):(\d{1,2})$', time_str)
242 | if match:
243 | hour, minute = match.groups()
244 | return datetime.now().strftime('%Y-%m-%d') + f' {hour}:{minute}'
245 |
246 | match = re.match(r'^昨天 (\d{1,2}):(\d{1,2})$', time_str)
247 | if match:
248 | hour, minute = match.groups()
249 | yesterday = datetime.now() - timedelta(days=1)
250 | return yesterday.strftime('%Y-%m-%d') + f' {hour}:{minute}'
251 |
252 | match = re.match(r'^星期([一二三四五六日]) (\d{1,2}):(\d{1,2})$', time_str)
253 | if match:
254 | weekday, hour, minute = match.groups()
255 | weekday_num = ['一', '二', '三', '四', '五', '六', '日'].index(weekday)
256 | today_weekday = datetime.now().weekday()
257 | delta_days = (today_weekday - weekday_num) % 7
258 | target_day = datetime.now() - timedelta(days=delta_days)
259 | return target_day.strftime('%Y-%m-%d') + f' {hour}:{minute}'
260 |
261 | match = re.match(r'^(\d{4})年(\d{1,2})月(\d{1,2})日 (\d{1,2}):(\d{1,2})$', time_str)
262 | if match:
263 | year, month, day, hour, minute = match.groups()
264 | return datetime(*[int(i) for i in [year, month, day, hour, minute]]).strftime('%Y-%m-%d') + f' {hour}:{minute}'
265 |
266 |
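ParseWeChatTime above maps WeChat's relative timestamps ('HH:MM', '昨天 HH:MM', weekday, full date) onto absolute 'YYYY-MM-DD HH:MM' strings. A dependency-free sketch of two of those branches, with an injectable `now` added here purely to make the behaviour testable (the original always uses `datetime.now()`):

```python
from datetime import datetime, timedelta
import re

def parse_wechat_time(time_str, now=None):
    # Hypothetical standalone re-implementation of the "today" and
    # "yesterday" branches of ParseWeChatTime.
    now = now or datetime.now()
    m = re.match(r'^(\d{1,2}):(\d{1,2})$', time_str)
    if m:  # today: "HH:MM"
        return now.strftime('%Y-%m-%d') + ' {}:{}'.format(*m.groups())
    m = re.match(r'^昨天 (\d{1,2}):(\d{1,2})$', time_str)
    if m:  # yesterday: "昨天 HH:MM"
        return (now - timedelta(days=1)).strftime('%Y-%m-%d') + ' {}:{}'.format(*m.groups())
    return None  # like the original, unmatched formats fall through

print(parse_wechat_time('昨天 14:30', now=datetime(2024, 5, 2, 9, 0)))  # 2024-05-01 14:30
```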
267 | def FindPid(process_name):
268 | procs = psutil.process_iter(['pid', 'name'])
269 | for proc in procs:
270 | if process_name in proc.info['name']:
271 | return proc.info['pid']
272 |
273 |
274 | def Mver(pid):
275 | exepath = psutil.Process(pid).exe()
276 | if GetVersionByPath(exepath) != VERSION:
277 |         print(f"Warning: this fix only applies to WeChat version {VERSION}!")
278 | return
279 | if not uia.Control(ClassName='WeChatLoginWndForPC', searchDepth=1).Exists(maxSearchSeconds=2):
280 |         print("Warning: open the WeChat login window first, then try this method again!")
281 | return
282 | path = os.path.join(os.path.dirname(__file__), 'a.dll')
283 | dll = ctypes.WinDLL(path)
284 | dll.GetDllBaseAddress.argtypes = [ctypes.c_uint, ctypes.c_wchar_p]
285 | dll.GetDllBaseAddress.restype = ctypes.c_void_p
286 | dll.WriteMemory.argtypes = [ctypes.c_ulong, ctypes.c_void_p, ctypes.c_ulong]
287 | dll.WriteMemory.restype = ctypes.c_bool
288 | dll.GetMemory.argtypes = [ctypes.c_ulong, ctypes.c_void_p]
289 | dll.GetMemory.restype = ctypes.c_ulong
290 | mname = 'WeChatWin.dll'
291 | tar = 1661536787
292 | base_address = dll.GetDllBaseAddress(pid, mname)
293 | address = base_address + 64761648
294 | if dll.GetMemory(pid, address) != tar:
295 | dll.WriteMemory(pid, address, tar)
296 | handle = ctypes.c_void_p(dll._handle)
297 | ctypes.windll.kernel32.FreeLibrary(handle)
298 |
299 | def FixVersionError():
300 |     """Fix the 'version too old to log in' error."""
301 | pid = FindPid('WeChat.exe')
302 | if pid:
303 | Mver(pid)
304 | return
305 | else:
306 | try:
307 | registry_key = winreg.OpenKey(winreg.HKEY_CURRENT_USER, r"Software\Tencent\WeChat", 0, winreg.KEY_READ)
308 | path, _ = winreg.QueryValueEx(registry_key, "InstallPath")
309 | winreg.CloseKey(registry_key)
310 | wxpath = os.path.join(path, "WeChat.exe")
311 | if os.path.exists(wxpath):
312 | os.system(f'start "" "{wxpath}"')
313 | FixVersionError()
314 | else:
315 |                     raise Exception('WeChat.exe not found')
316 | except WindowsError:
317 |             print("Warning: WeChat install path not found; open the WeChat login window first, then try this method again!")
318 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from rwkv.model import RWKV
2 | from rwkv.utils import PIPELINE
3 | from meowServer import MeowAIServer,Character,MeowAI
4 | from typing import List,Dict,Mapping,Union,Callable,Any
5 | import torch,os,json
6 | from pathlib import Path
7 |
8 | os.environ['RWKV_JIT_ON'] = '1'
9 | os.environ["RWKV_CUDA_ON"] = '0'
10 |
11 | config=None
12 | modelsFolder=""
13 | modelFile=""
14 | rootPath=Path(__file__).parent
15 | configPath=rootPath.joinpath('config.json')
16 |
17 | def getModelsList(modelsFolder:Path)->List[Dict[str,str]]:
18 | modelInfoList=[]
19 | for file in modelsFolder.rglob('*.pth'):
20 | modelsinfo={
21 | "name":file.name.replace('.pth',''),
22 | "path":str(file.absolute()).replace('.pth','')
23 | }
24 | modelInfoList.append(modelsinfo)
25 | return modelInfoList
26 |
27 | if not configPath.exists():raise FileNotFoundError(f"Failed to read config.json: default config file not found!\nCreate a config.json in the project root and fill it in following config.json.example.\nConfig path: {configPath}")
28 | with open(configPath,'r') as f:config = json.load(f)
29 | modelsFolder=Path(config['modelsFolder'])
30 | if not modelsFolder.exists():raise FileNotFoundError(f"Failed to read the models folder: it does not exist!\nCreate a models folder in the project root and place model files in it.\nConfigured path: {modelsFolder}")
31 | modelFile=modelsFolder.joinpath(config['modelFile'])
32 | modelInfoList=getModelsList(modelsFolder)
33 | config['strategy']=config['strategy'] if torch.cuda.is_available() else 'cpu fp32'
34 |
35 | if not os.path.exists(str(modelFile)+'.pth'):
36 |     if len(modelInfoList)==0:raise FileNotFoundError(f"No model files found!\nCreate a models folder in the project root and place model files in it.\nConfigured path: {modelsFolder}")
37 |     print(f"Model file {config['modelFile']} not found; please download it.\nAutomatically falling back to the {modelInfoList[0]['name']} model for now.")
38 | modelFile=modelInfoList[0]['path']
39 | config["modelFile"]=modelInfoList[0]['name']
40 | with open(configPath, 'w') as file:file.write(json.dumps(config, indent=4))
41 |
42 | model = RWKV(model=str(modelFile), strategy=config['strategy'])
43 | if model.version == 7:
44 | import sys
45 | sys.modules.pop("rwkv.model")
46 | os.environ["RWKV_V7_ON"] = "1"
47 | from rwkv.model import RWKV
48 | model = RWKV(model=str(modelFile), strategy=config['strategy'])
49 | pipeline = PIPELINE(model, "rwkv_vocab_v20230424")
50 | meowAI=MeowAI(pipeline)
51 | meowAIServer=MeowAIServer(meowAI,host=config['host'],port=int(config['port']),autoOpen=config['autoOpen'],debug=True)
52 |
53 |
54 | if __name__ == '__main__':
55 | meowAIServer.run()
56 |
57 |
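getModelsList above scans modelsFolder recursively for `*.pth` checkpoints and strips the extension from both the display name and the absolute path (the RWKV loader expects the path without `.pth`). A self-contained sketch of that scan, exercised against a temporary directory; it uses a slice instead of the original's `.replace('.pth','')`, which would also strip mid-name occurrences:

```python
import tempfile
from pathlib import Path

def get_models_list(models_folder: Path):
    # Same shape as getModelsList: one dict per *.pth file found
    # anywhere under the folder, with the extension stripped.
    return [
        {"name": f.name[:-len(".pth")],
         "path": str(f.absolute())[:-len(".pth")]}
        for f in models_folder.rglob("*.pth")
    ]

with tempfile.TemporaryDirectory() as d:
    sub = Path(d) / "sub"
    sub.mkdir()
    (sub / "demo.pth").touch()
    models = get_models_list(Path(d))
    print(models[0]["name"])  # demo
```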
--------------------------------------------------------------------------------
/meowServer.py:
--------------------------------------------------------------------------------
1 | import re,json,torch
2 | from tqdm import tqdm
3 | from rwkv.utils import PIPELINE, PIPELINE_ARGS
4 | from typing import List,Dict,Mapping,Union,Callable,Any,Tuple
5 |
6 | class Character(PIPELINE_ARGS):
7 | def __init__(self,persona:str='主人你好呀!我是你的可爱猫娘,喵~', temperature:float=1.1,top_p:float=0.7, top_k:float=0, alpha_frequency:float=0.2, alpha_presence:float=0.2, alpha_decay:float=0.996, token_ban:list=[], token_stop:list=[], chunk_len=256):
8 | super().__init__(temperature, top_p, top_k, alpha_frequency, alpha_presence, alpha_decay, token_ban, token_stop, chunk_len)
9 | self.persona=persona
10 |
11 |
12 | class MeowAI():
13 | '''
14 |     Content generation
15 | '''
16 | def __init__(self,model:PIPELINE, character:Character=Character(),max_tokens:int=2048 ):
17 | self.model=model
18 | self.character=character
19 | self.max_tokens=max_tokens
20 | self.stop=False
21 | self.chat_state=None
22 | self.talk_state=None
23 |
24 | def purr(self,txt:str)->str:
25 | return re.sub(r'\r*\n{2,}', '\n', txt.strip())
26 |
27 | def chat(self,messages:List[Mapping[str,str]],prev_state:torch.Tensor=None,callback:Callable[[str],Any]=None)->Tuple[str,torch.Tensor]:
28 | '''
29 | messages=[
30 | {
31 | "role": "User",
32 | "content": "What is the meaning of life?"
33 | },
34 | {
35 | "role": "Assistant",
36 | "content": "The meaning of life is to live a happy and fulfilling life."
37 | },
38 | {
39 | "role": "User",
40 |         "content": "What sound do cats make?"
41 | },
42 | ]
43 | '''
44 | out_tokens = []
45 | out_len = 0
46 | out_str = ""
47 | occurrence = {}
48 | state = prev_state
49 | input_text = "\n\n".join([f"{text['role']}:{self.purr(text['content'])}" for text in messages])
50 | if state:
51 | input_text = f"\n\n{input_text}\n\nAnswerer:"
52 | else:
53 | input_text = f"Answerer:{self.purr(self.character.persona)}\n\n{input_text}\n\nAnswerer:"
54 | for i in tqdm(range(self.max_tokens),desc=f"tokens",leave=False):
55 | if self.stop:break
56 | if i == 0:
57 | out, state = self.model.model.forward(self.model.encode(input_text), state)
58 | else:
59 | out, state = self.model.model.forward([token], state)
60 | for n in occurrence:
61 | out[n] -= (
62 | self.character.alpha_frequency + occurrence[n] * self.character.alpha_presence
63 | )
64 |
65 | token = self.model.sample_logits(out, temperature=self.character.temperature, top_p=self.character.top_p)
66 | if token == 0:break
67 | out_tokens += [token]
68 |
69 | for n in occurrence:
70 | occurrence[n] *= self.character.alpha_decay
71 | occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
72 |
73 |             tmp = self.model.decode(out_tokens[out_len:])
74 |             # self.chat_state=state
75 |             if ("\ufffd" not in tmp) and ( not tmp.endswith("\n") ):
76 |                 out_str += tmp  # flush only fully decoded chunks, so retries don't duplicate text
77 |                 out_len = i + 1
78 |                 if callback and not self.stop: callback(tmp)
79 |             elif "\n\n" in tmp:
80 |                 break
81 | print(out_str.strip())
82 | return out_str.strip(),state
83 |
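The first pass of chat builds its prompt by joining role-prefixed turns with blank lines, prepending the persona as an initial Answerer turn only when there is no previous state. A stripped-down sketch of that assembly — `build_prompt` is a hypothetical name, and the real code also collapses repeated newlines via `purr`:

```python
def build_prompt(messages, persona, has_state=False):
    # Join each turn as "Role:content", separated by blank lines.
    body = "\n\n".join(f"{m['role']}:{m['content'].strip()}" for m in messages)
    if has_state:
        # Continuing a conversation: the persona is already baked into the state.
        return f"\n\n{body}\n\nAnswerer:"
    # Fresh conversation: seed the persona as the first Answerer turn.
    return f"Answerer:{persona.strip()}\n\n{body}\n\nAnswerer:"

prompt = build_prompt([{"role": "User", "content": "Hi"}], "I am a cute catgirl, meow~")
print(prompt)
```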
84 | def talk(self,input_text:str,prev_state:torch.Tensor=None,callback:Callable[[str],Any]=None)->Tuple[str,torch.Tensor]:
85 | out_tokens = []
86 | out_len = 0
87 | out_str = ""
88 | occurrence = {}
89 | state = prev_state
90 | for i in tqdm(range(self.max_tokens),desc=f"tokens",leave=False):
91 | if i == 0:
92 | out, state = self.model.model.forward(self.model.encode(input_text), state)
93 | else:
94 | out, state = self.model.model.forward([token], state)
95 | for n in occurrence:
96 | out[n] -= (
97 |                     self.character.alpha_frequency + occurrence[n] * self.character.alpha_presence
98 | )
99 | token = self.model.sample_logits( out, temperature=self.character.temperature, top_p=self.character.top_p)
100 |
101 | out_tokens += [token]
102 |
103 | for n in occurrence:
104 | occurrence[n] *= self.character.alpha_decay
105 | occurrence[token] = 1 + (occurrence[token] if token in occurrence else 0)
106 |
107 |             tmp = self.model.decode(out_tokens[out_len:])
108 |             # self.talk_state=state
109 |             if "\ufffd" not in tmp:  # flush only fully decoded chunks
110 |                 if callback: callback(tmp)
111 |                 out_str += tmp; out_len = i + 1
112 | if self.stop:break
113 | print(out_str)
114 | return out_str,state
115 |
116 | from flask import Flask, render_template
117 | from flask_socketio import SocketIO
118 | import webbrowser
119 | from extensions import load_extensions
120 |
121 | class MeowAIServer():
122 | def __init__(self, meowAI:MeowAI, host:str="0.0.0.0", port:int=5000, debug:bool=False,use_reloader:bool=False,autoOpen:bool=True):
123 | if not meowAI:raise Exception("meowAI is not defined")
124 | self.meowAI = meowAI
125 | self.host=host
126 | self.port=port
127 | self.debug=debug
128 | self.use_reloader=use_reloader
129 | self.autoOpen=autoOpen
130 |
131 | self.app = Flask(__name__)
132 | self.app.jinja_env.variable_start_string = '{['
133 | self.app.jinja_env.variable_end_string = ']}'
134 | self.socketio = SocketIO(self.app)
135 | print(load_extensions(self.app,self.socketio,self.meowAI))
136 |
137 | @self.app.route('/')
138 | def index():
139 | return render_template('index.html')
140 |
141 | @self.app.route('/extension/')
142 | def extension():
143 | return render_template('extension.html')
144 |
145 | @self.socketio.on('emit')
146 | def emit(news:Dict[str,Union[int,float]]):
147 | '''
148 |             news={'code':1,'msg':'some error message'}
149 | '''
150 | self.socketio.emit('emit', json.dumps(news))
151 |
152 | @self.socketio.on('stop')
153 | def stop(status:bool=True):
154 | self.socketio.emit('stop', status)
155 | self.meowAI.stop=status
156 |
157 | @self.socketio.on('character')
158 | def handle_character(character:Dict[str,Union[int,float]]):
159 | self.meowAI.character = Character(**character)
160 | print(character)
161 |
162 | @self.socketio.on('chat')
163 | def handle_chat(message:Mapping[str,Union[str,List[Dict[str,str]]]]):
164 | try:
165 | self.meowAI.stop=False
166 |                 self.meowAI.chat(message,None,lambda x:self.socketio.emit('chat',x))
167 | except Exception as e:
168 |                 ext={'code':1,'msg':f'Generation error: {e}'}
169 | emit(ext)
170 | print(ext)
171 | finally:
172 | stop(True)
173 |
174 | @self.socketio.on('talk')
175 | def handle_talk(prompt):
176 | try:
177 | self.meowAI.stop=False
178 |                 self.meowAI.talk(prompt,None,lambda x:self.socketio.emit('talk',x))
179 | except Exception as e:
180 |                 ext={'code':1,'msg':f'Generation error: {e}'}
181 | emit(ext)
182 | print(ext)
183 | finally:
184 | stop(True)
185 |
186 | def run(self):
187 | if self.autoOpen:webbrowser.open(f'http://{self.host}:{self.port}')
188 | self.socketio.run(self.app,host=self.host,port=self.port,debug=self.debug,use_reloader=self.use_reloader)
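Both generation loops above penalize repetition: each previously emitted token's logit is reduced by a flat alpha_frequency term plus alpha_presence scaled by its occurrence count, and all counts are multiplied by alpha_decay every step. A minimal sketch of that bookkeeping (`penalize` and `record` are illustrative names, not part of the codebase):

```python
def penalize(logits, occurrence, alpha_frequency=0.2, alpha_presence=0.2):
    # Subtract the repetition penalty from every token seen so far,
    # as done before sample_logits in MeowAI.chat.
    for tok, count in occurrence.items():
        logits[tok] -= alpha_frequency + count * alpha_presence
    return logits

def record(token, occurrence, alpha_decay=0.996):
    # Decay all existing counts, then bump the newly sampled token.
    for tok in occurrence:
        occurrence[tok] *= alpha_decay
    occurrence[token] = 1 + occurrence.get(token, 0)
    return occurrence

occ = record(1, {})   # token 1 seen once
occ = record(1, occ)  # decayed to 0.996, then +1 -> 1.996
logits = penalize([1.0, 1.0, 1.0], occ)
print(round(logits[1], 4))  # 1.0 - (0.2 + 1.996 * 0.2) = 0.4008
```

Decaying the counts lets old repetitions gradually stop suppressing a token, while fresh repetitions are penalized hardest.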
--------------------------------------------------------------------------------
/models/EEADME.md:
--------------------------------------------------------------------------------
1 | Place model files in this folder.
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch
2 | tqdm
3 | flask
4 | flask_socketio
5 | comtypes
6 | packaging
7 |
8 |
--------------------------------------------------------------------------------
/rwkv/__pycache__/model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/__pycache__/model.cpython-39.pyc
--------------------------------------------------------------------------------
/rwkv/__pycache__/rwkv_tokenizer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/__pycache__/rwkv_tokenizer.cpython-39.pyc
--------------------------------------------------------------------------------
/rwkv/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/rwkv/cpp/librwkv.dylib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/cpp/librwkv.dylib
--------------------------------------------------------------------------------
/rwkv/cpp/librwkv.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/cpp/librwkv.so
--------------------------------------------------------------------------------
/rwkv/cpp/model.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Union
2 | from . import rwkv_cpp_model
3 | from . import rwkv_cpp_shared_library
4 |
5 |
6 | class RWKV:
7 | def __init__(self, model_path: str, strategy=None):
8 | self.library = rwkv_cpp_shared_library.load_rwkv_shared_library()
9 | self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
10 | self.w = {} # fake weight
11 | self.w["emb.weight"] = [0] * self.model.n_vocab
12 | self.version = (
13 | self.model.arch_version_major + self.model.arch_version_minor / 10
14 | )
15 |
16 | def forward(self, tokens: List[int], state: Union[Any, None] = None):
17 | return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)
18 |
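The `version` attribute above folds the major and minor architecture versions reported by the shared library into a single float (e.g. major 5, minor 2 gives 5.2), which callers such as main.py compare against 7 to decide whether to reload with `RWKV_V7_ON`. The arithmetic in isolation:

```python
def arch_version(major: int, minor: int) -> float:
    # Same formula as RWKV.__init__ above: the minor version
    # becomes the tenths digit of a single float.
    return major + minor / 10

print(arch_version(5, 2))       # 5.2
print(arch_version(7, 0) == 7)  # True
```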
--------------------------------------------------------------------------------
/rwkv/cpp/rwkv.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/cpp/rwkv.dll
--------------------------------------------------------------------------------
/rwkv/cpp/rwkv_cpp_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import multiprocessing
3 |
4 | # Pre-import PyTorch, if available.
5 | # This fixes "OSError: [WinError 127] The specified procedure could not be found".
6 | try:
7 | import torch
8 | except ModuleNotFoundError:
9 | pass
10 |
11 | # I'm sure this is not strictly correct, but let's keep this crutch for now.
12 | try:
13 | import rwkv_cpp_shared_library
14 | except ModuleNotFoundError:
15 | from . import rwkv_cpp_shared_library
16 |
17 | from typing import TypeVar, Optional, Tuple, List
18 |
19 | # A value of this type is either a numpy ndarray or a PyTorch Tensor.
20 | NumpyArrayOrPyTorchTensor: TypeVar = TypeVar('NumpyArrayOrPyTorchTensor')
21 |
22 | class RWKVModel:
23 | """
24 | An RWKV model managed by rwkv.cpp library.
25 | """
26 |
27 | def __init__(
28 | self,
29 | shared_library: rwkv_cpp_shared_library.RWKVSharedLibrary,
30 | model_path: str,
31 | thread_count: int = max(1, multiprocessing.cpu_count() // 2),
32 | gpu_layer_count: int = 0,
33 | **kwargs
34 | ) -> None:
35 | """
36 | Loads the model and prepares it for inference.
37 | In case of any error, this method will throw an exception.
38 |
39 | Parameters
40 | ----------
41 | shared_library : RWKVSharedLibrary
42 | rwkv.cpp shared library.
43 | model_path : str
44 | Path to RWKV model file in ggml format.
45 | thread_count : int
46 | Thread count to use. If not set, defaults to CPU count / 2.
47 | gpu_layer_count : int
48 | Count of layers to offload onto the GPU, must be >= 0.
49 | See documentation of `gpu_offload_layers` for details about layer offloading.
50 | """
51 |
52 | if 'gpu_layers_count' in kwargs:
53 | gpu_layer_count = kwargs['gpu_layers_count']
54 |
55 | if not os.path.isfile(model_path):
56 | raise ValueError(f'{model_path} is not a file')
57 |
58 | if not (thread_count > 0):
59 | raise ValueError('Thread count must be > 0')
60 |
61 | if not (gpu_layer_count >= 0):
62 | raise ValueError('GPU layer count must be >= 0')
63 |
64 | self._library: rwkv_cpp_shared_library.RWKVSharedLibrary = shared_library
65 |
66 | self._ctx: rwkv_cpp_shared_library.RWKVContext = self._library.rwkv_init_from_file(model_path, thread_count)
67 |
68 | if gpu_layer_count > 0:
69 | self.gpu_offload_layers(gpu_layer_count)
70 |
71 | self._state_buffer_element_count: int = self._library.rwkv_get_state_buffer_element_count(self._ctx)
72 | self._logits_buffer_element_count: int = self._library.rwkv_get_logits_buffer_element_count(self._ctx)
73 |
74 | self._valid: bool = True
75 |
76 | def gpu_offload_layers(self, layer_count: int) -> bool:
77 | """
78 | Offloads specified count of model layers onto the GPU. Offloaded layers are evaluated using cuBLAS or CLBlast.
79 | For the purposes of this function, model head (unembedding matrix) is treated as an additional layer:
80 | - pass `model.n_layer` to offload all layers except model head
81 | - pass `model.n_layer + 1` to offload all layers, including model head
82 |
83 | Returns true if at least one layer was offloaded.
84 | If rwkv.cpp was compiled without cuBLAS and CLBlast support, this function is a no-op and always returns false.
85 |
86 | Parameters
87 | ----------
88 | layer_count : int
89 | Count of layers to offload onto the GPU, must be >= 0.
90 | """
91 |
92 | if not (layer_count >= 0):
93 | raise ValueError('Layer count must be >= 0')
94 |
95 | return self._library.rwkv_gpu_offload_layers(self._ctx, layer_count)
96 |
97 | @property
98 | def arch_version_major(self) -> int:
99 | return self._library.rwkv_get_arch_version_major(self._ctx)
100 |
101 | @property
102 | def arch_version_minor(self) -> int:
103 | return self._library.rwkv_get_arch_version_minor(self._ctx)
104 |
105 | @property
106 | def n_vocab(self) -> int:
107 | return self._library.rwkv_get_n_vocab(self._ctx)
108 |
109 | @property
110 | def n_embed(self) -> int:
111 | return self._library.rwkv_get_n_embed(self._ctx)
112 |
113 | @property
114 | def n_layer(self) -> int:
115 | return self._library.rwkv_get_n_layer(self._ctx)
116 |
117 | def eval(
118 | self,
119 | token: int,
120 | state_in: Optional[NumpyArrayOrPyTorchTensor],
121 | state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
122 | logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
123 | use_numpy: bool = False
124 | ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
125 | """
126 | Evaluates the model for a single token.
127 | In case of any error, this method will throw an exception.
128 |
129 | Parameters
130 | ----------
131 | token : int
132 | Index of next token to be seen by the model. Must be in range 0 <= token < n_vocab.
133 | state_in : Optional[NumpyArrayOrTorchTensor]
134 | State from previous call of this method. If this is a first pass, set it to None.
135 | state_out : Optional[NumpyArrayOrTorchTensor]
136 | Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
137 | logits_out : Optional[NumpyArrayOrTorchTensor]
138 | Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
139 | use_numpy : bool
140 | If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
141 | This parameter is ignored if any tensor parameter is not None; in such case,
142 | type of returned tensors will match the type of received tensors.
143 |
144 | Returns
145 | -------
146 | logits, state
147 | Logits vector of shape (n_vocab); state for the next step.
148 | """
149 |
150 | if not self._valid:
151 | raise ValueError('Model was freed')
152 |
153 | use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
154 |
155 | if state_in is not None:
156 | self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
157 |
158 | state_in_ptr = self._get_data_ptr(state_in)
159 | else:
160 | state_in_ptr = 0
161 |
162 | if state_out is not None:
163 | self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
164 | else:
165 | state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
166 |
167 | if logits_out is not None:
168 | self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
169 | else:
170 | logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
171 |
172 | self._library.rwkv_eval(
173 | self._ctx,
174 | token,
175 | state_in_ptr,
176 | self._get_data_ptr(state_out),
177 | self._get_data_ptr(logits_out)
178 | )
179 |
180 | return logits_out, state_out
181 |
182 | def eval_sequence(
183 | self,
184 | tokens: List[int],
185 | state_in: Optional[NumpyArrayOrPyTorchTensor],
186 | state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
187 | logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
188 | use_numpy: bool = False
189 | ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
190 | """
191 | Evaluates the model for a sequence of tokens.
192 |
193 | NOTE ON GGML NODE LIMIT
194 |
195 |         ggml has a hard-coded limit on the maximum number of nodes in a computation graph. The sequence graph is built in a way that quickly exceeds
196 |         this limit when using large models and/or long sequence lengths.
197 |         Fortunately, rwkv.cpp's fork of ggml has an increased limit, which was tested to work for sequence lengths up to 64 on 14B models.
198 |
199 | If you get `GGML_ASSERT: ...\\ggml.c:16941: cgraph->n_nodes < GGML_MAX_NODES`, this means you've exceeded the limit.
200 | To get rid of the assertion failure, reduce the model size and/or sequence length.
201 |
202 | In case of any error, this method will throw an exception.
203 |
204 | Parameters
205 | ----------
206 | tokens : List[int]
207 | Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
208 | state_in : Optional[NumpyArrayOrPyTorchTensor]
209 | State from the previous call of this method. If this is the first pass, set it to None.
210 | state_out : Optional[NumpyArrayOrPyTorchTensor]
211 | Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
212 | logits_out : Optional[NumpyArrayOrPyTorchTensor]
213 | Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
214 | use_numpy : bool
215 | If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
216 | This parameter is ignored if any tensor parameter is not None; in such case,
217 | type of returned tensors will match the type of received tensors.
218 |
219 | Returns
220 | -------
221 | logits, state
222 | Logits vector of shape (n_vocab); state for the next step.
223 | """
224 |
225 | if not self._valid:
226 | raise ValueError('Model was freed')
227 |
228 | use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
229 |
230 | if state_in is not None:
231 | self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
232 |
233 | state_in_ptr = self._get_data_ptr(state_in)
234 | else:
235 | state_in_ptr = 0
236 |
237 | if state_out is not None:
238 | self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
239 | else:
240 | state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
241 |
242 | if logits_out is not None:
243 | self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
244 | else:
245 | logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
246 |
247 | self._library.rwkv_eval_sequence(
248 | self._ctx,
249 | tokens,
250 | state_in_ptr,
251 | self._get_data_ptr(state_out),
252 | self._get_data_ptr(logits_out)
253 | )
254 |
255 | return logits_out, state_out
256 |
257 | def eval_sequence_in_chunks(
258 | self,
259 | tokens: List[int],
260 | state_in: Optional[NumpyArrayOrPyTorchTensor],
261 | state_out: Optional[NumpyArrayOrPyTorchTensor] = None,
262 | logits_out: Optional[NumpyArrayOrPyTorchTensor] = None,
263 | chunk_size: int = 16,
264 | use_numpy: bool = False
265 | ) -> Tuple[NumpyArrayOrPyTorchTensor, NumpyArrayOrPyTorchTensor]:
266 | """
267 | Evaluates the model for a sequence of tokens using `eval_sequence`, splitting a potentially long sequence into fixed-length chunks.
268 | This function is useful for processing complete prompts and user input in chat & role-playing use-cases.
269 | It is recommended to use this function instead of `eval_sequence` to avoid mistakes and get maximum performance.
270 |
271 | Chunking allows processing sequences of thousands of tokens while staying within ggml's node limit and keeping memory consumption reasonable.
272 | A reasonable and recommended chunk size is 16. If you want maximum performance, try different chunk sizes in the range [2..64]
273 | and choose the one that works best in your use case.
274 |
275 | In case of any error, this method will throw an exception.
276 |
277 | Parameters
278 | ----------
279 | tokens : List[int]
280 | Indices of the next tokens to be seen by the model. Must be in range 0 <= token < n_vocab.
281 | state_in : Optional[NumpyArrayOrPyTorchTensor]
282 | State from the previous call of this method. If this is the first pass, set it to None.
283 | state_out : Optional[NumpyArrayOrPyTorchTensor]
284 | Optional output tensor for state. If provided, must be of type float32, contiguous and of shape (state_buffer_element_count).
285 | logits_out : Optional[NumpyArrayOrPyTorchTensor]
286 | Optional output tensor for logits. If provided, must be of type float32, contiguous and of shape (logits_buffer_element_count).
287 | chunk_size : int
288 | Size of each chunk in tokens; must be positive.
289 | use_numpy : bool
290 | If set to True, numpy's ndarrays will be created instead of PyTorch's Tensors.
291 | This parameter is ignored if any tensor parameter is not None; in such case,
292 | type of returned tensors will match the type of received tensors.
293 |
294 | Returns
295 | -------
296 | logits, state
297 | Logits vector of shape (n_vocab); state for the next step.
298 | """
299 |
300 | if not self._valid:
301 | raise ValueError('Model was freed')
302 |
303 | use_numpy = self._detect_numpy_usage([state_in, state_out, logits_out], use_numpy)
304 |
305 | if state_in is not None:
306 | self._validate_tensor(state_in, 'state_in', self._state_buffer_element_count)
307 |
308 | state_in_ptr = self._get_data_ptr(state_in)
309 | else:
310 | state_in_ptr = 0
311 |
312 | if state_out is not None:
313 | self._validate_tensor(state_out, 'state_out', self._state_buffer_element_count)
314 | else:
315 | state_out = self._zeros_float32(self._state_buffer_element_count, use_numpy)
316 |
317 | if logits_out is not None:
318 | self._validate_tensor(logits_out, 'logits_out', self._logits_buffer_element_count)
319 | else:
320 | logits_out = self._zeros_float32(self._logits_buffer_element_count, use_numpy)
321 |
322 | self._library.rwkv_eval_sequence_in_chunks(
323 | self._ctx,
324 | tokens,
325 | chunk_size,
326 | state_in_ptr,
327 | self._get_data_ptr(state_out),
328 | self._get_data_ptr(logits_out)
329 | )
330 |
331 | return logits_out, state_out
332 |
333 | def free(self) -> None:
334 | """
335 | Frees all allocated resources.
336 | In case of any error, this method will throw an exception.
337 | The object must not be used anymore after calling this method.
338 | """
339 |
340 | if not self._valid:
341 | raise ValueError('Already freed')
342 |
343 | self._valid = False
344 |
345 | self._library.rwkv_free(self._ctx)
346 |
347 | def __del__(self) -> None:
348 | # Free the context on GC in case user forgot to call free() explicitly.
349 | if hasattr(self, '_valid') and self._valid:
350 | self.free()
351 |
352 | def _is_pytorch_tensor(self, tensor: NumpyArrayOrPyTorchTensor) -> bool:
353 | return hasattr(tensor, '__module__') and tensor.__module__ == 'torch'
354 |
355 | def _detect_numpy_usage(self, tensors: List[Optional[NumpyArrayOrPyTorchTensor]], use_numpy_by_default: bool) -> bool:
356 | for tensor in tensors:
357 | if tensor is not None:
358 | return not self._is_pytorch_tensor(tensor)
359 |
360 | return use_numpy_by_default
361 |
362 | def _validate_tensor(self, tensor: NumpyArrayOrPyTorchTensor, name: str, size: int) -> None:
363 | if self._is_pytorch_tensor(tensor):
364 | tensor: torch.Tensor = tensor
365 |
366 | if tensor.device != torch.device('cpu'):
367 | raise ValueError(f'{name} is not on CPU')
368 | if tensor.dtype != torch.float32:
369 | raise ValueError(f'{name} is not of type float32')
370 | if tensor.shape != (size,):
371 | raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
372 | if not tensor.is_contiguous():
373 | raise ValueError(f'{name} is not contiguous')
374 | else:
375 | import numpy as np
376 | tensor: np.ndarray = tensor
377 |
378 | if tensor.dtype != np.float32:
379 | raise ValueError(f'{name} is not of type float32')
380 | if tensor.shape != (size,):
381 | raise ValueError(f'{name} has invalid shape {tensor.shape}, expected ({size})')
382 | if not tensor.data.contiguous:
383 | raise ValueError(f'{name} is not contiguous')
384 |
385 | def _get_data_ptr(self, tensor: NumpyArrayOrPyTorchTensor):
386 | if self._is_pytorch_tensor(tensor):
387 | return tensor.data_ptr()
388 | else:
389 | return tensor.ctypes.data
390 |
391 | def _zeros_float32(self, element_count: int, use_numpy: bool) -> NumpyArrayOrPyTorchTensor:
392 | if use_numpy:
393 | import numpy as np
394 | return np.zeros(element_count, dtype=np.float32)
395 | else:
396 | return torch.zeros(element_count, dtype=torch.float32, device='cpu')
397 |
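The chunked-evaluation strategy that `eval_sequence_in_chunks` describes can be sketched in plain Python: split the token list into fixed-size chunks and thread the state from one sequence evaluation into the next. This is a minimal sketch of the idea, not the library's code; `eval_step` below is a hypothetical stand-in for a real `eval_sequence` call.

```python
from typing import Callable, List, Optional, Tuple


def split_into_chunks(tokens: List[int], chunk_size: int) -> List[List[int]]:
    # Fixed-size chunks; the last chunk may be shorter.
    assert chunk_size > 0, 'chunk_size must be positive'
    return [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]


def eval_in_chunks(
    tokens: List[int],
    chunk_size: int,
    eval_step: Callable,  # eval_step(chunk, state) -> (logits, state), a stand-in for eval_sequence
) -> Tuple[object, object]:
    logits, state = None, None
    for chunk in split_into_chunks(tokens, chunk_size):
        logits, state = eval_step(chunk, state)
    # Only the logits of the final chunk matter; the state accumulates across all chunks.
    return logits, state
```

Keeping each graph at `chunk_size` tokens is what keeps the ggml node count bounded regardless of the total prompt length.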
--------------------------------------------------------------------------------
/rwkv/cuda/gemm_fp16_cublas.cpp:
--------------------------------------------------------------------------------
1 | #include <cublas_v2.h>
2 | #include <cuda.h>
3 | #include <cuda_fp16.h>
4 | #include <cuda_runtime.h>
5 | #include <torch/extension.h>
6 | #include <c10/cuda/CUDAGuard.h>
7 | #include <ATen/cuda/CUDAContext.h>
8 |
9 | #define CUBLAS_CHECK(condition) \
10 | for (cublasStatus_t _cublas_check_status = (condition); \
11 | _cublas_check_status != CUBLAS_STATUS_SUCCESS;) \
12 | throw std::runtime_error("cuBLAS error " + \
13 | std::to_string(_cublas_check_status) + " at " + \
14 | std::to_string(__LINE__));
15 |
16 | #define CUDA_CHECK(condition) \
17 | for (cudaError_t _cuda_check_status = (condition); \
18 | _cuda_check_status != cudaSuccess;) \
19 | throw std::runtime_error( \
20 | "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \
21 | " at " + std::to_string(__LINE__));
22 |
23 | /*
24 | NOTE: blas gemm is column-major by default, but we need row-major output.
25 | The data of row-major, transposed matrix is exactly the same as the
26 | column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T
27 | */
28 | void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) {
29 | const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
30 | const auto cuda_data_type = CUDA_R_16F;
31 | const auto cuda_c_data_type =
32 | c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F;
33 | const auto compute_type = CUDA_R_32F;
34 | const float sp_alpha = 1.f;
35 | // swap a and b, and use CUBLAS_OP_N. see the notes above
36 | std::swap(a, b);
37 | const cublasOperation_t cublas_trans_a = CUBLAS_OP_N;
38 | const cublasOperation_t cublas_trans_b = CUBLAS_OP_N;
39 | // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap,
40 | // negative axis is used because of the existence of batch matmul.
41 | const int m = a.size(-1);
42 | const int k = a.size(-2);
43 | const int n = b.size(-2);
44 | const int cublas_lda = m;
45 | const int cublas_ldb = k;
46 | const int cublas_ldc = m;
47 | cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle();
48 |
49 | #if CUDA_VERSION >= 11000
50 | cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
51 | #else
52 | cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
53 | #endif
54 | const float sp_beta = 0.f;
55 | if (a.sizes().size() == 2 && b.sizes().size() == 2) {
56 | CUBLAS_CHECK(cublasGemmEx(
57 | cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha,
58 | a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type,
59 | cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc,
60 | compute_type, algo));
61 | } else {
62 | // batch matmul
63 | assert(a.sizes().size() == 3 && b.sizes().size() == 3);
64 |
65 | const long long int cublas_stride_a = m * k;
66 | const long long int cublas_stride_b = k * n;
67 | const long long int cublas_stride_c = m * n;
68 | CUBLAS_CHECK(cublasGemmStridedBatchedEx(
69 | cublas_handle, cublas_trans_a, cublas_trans_b, m,
70 | n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda,
71 | cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b,
72 | &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c,
73 | a.size(0), compute_type, algo));
74 | }
75 | }
76 |
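The row-major/column-major note above can be checked with a small NumPy experiment: the bytes of a row-major C are exactly the bytes of a column-major C^T, and C^T = B^T * A^T, so feeding the operands to a column-major GEMM in swapped order with no transposes yields row-major C. A sketch of the identity only, not of the cuBLAS call itself:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4)).astype(np.float32)  # A: (n, k), row-major
b = rng.standard_normal((4, 5)).astype(np.float32)  # B: (k, m), row-major
c = a @ b                                           # C = A * B, row-major (n, m)

# A column-major GEMM given (B, A) with no transposes computes B^T * A^T,
# which is C^T -- and column-major C^T has the same memory layout as row-major C.
c_t = b.T @ a.T
assert np.allclose(c_t, c.T)
```

This is why `gemm_fp16_cublas` swaps `a` and `b` and then uses `CUBLAS_OP_N` for both operands.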
--------------------------------------------------------------------------------
/rwkv/cuda/operators.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 | #include "ATen/ATen.h"
4 | #include <cuda_fp16.h>
5 | #define MIN_VALUE (-1e38)
6 | typedef at::Half fp16;
7 | __half *cast(fp16 *ptr) {
8 | return reinterpret_cast<__half *>(ptr);
9 | }
10 |
11 | template <typename F>
12 | __global__ void kernel_wkv_forward(const int B, const int T, const int C,
13 | const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
14 | F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) {
15 | const int idx = blockIdx.x * blockDim.x + threadIdx.x;
16 | const int _b = idx / C;
17 | const int _c = idx % C;
18 | const int _offset = _b * T * C + _c;
19 | const int _state_offset = _b * C + _c;
20 |
21 | float u = _u[_c];
22 | float w = _w[_c];
23 | const F *__restrict__ const k = _k + _offset;
24 | const F *__restrict__ const v = _v + _offset;
25 | F *__restrict__ const y = _y + _offset;
26 |
27 | float aa = _aa[_state_offset];
28 | float bb = _bb[_state_offset];
29 | float pp = _pp[_state_offset];
30 | for (int i = 0; i < T; i++) {
31 | const int ii = i * C;
32 | const float kk = float(k[ii]);
33 | const float vv = float(v[ii]);
34 | float ww = u + kk;
35 | float p = max(pp, ww);
36 | float e1 = exp(pp - p);
37 | float e2 = exp(ww - p);
38 | y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2));
39 | ww = w + pp;
40 | p = max(ww, kk);
41 | e1 = exp(ww - p);
42 | e2 = exp(kk - p);
43 | aa = e1 * aa + e2 * vv;
44 | bb = e1 * bb + e2;
45 | pp = p;
46 | }
47 | _aa[_state_offset] = aa;
48 | _bb[_state_offset] = bb;
49 | _pp[_state_offset] = pp;
50 | }
51 |
52 | template <typename F>
53 | void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) {
54 | dim3 threadsPerBlock( min(C, 32) );
55 | assert(B * C % threadsPerBlock.x == 0);
56 | dim3 numBlocks(B * C / threadsPerBlock.x);
57 | kernel_wkv_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, aa, bb, pp);
58 | }
59 |
60 | template void cuda_wkv_forward<fp16>(
61 | int B, int T, int C,
62 | float *w, float *u, fp16 *k, fp16 *v, fp16 *y,
63 | float *aa, float *bb, float *pp);
64 | template void cuda_wkv_forward<float>(
65 | int B, int T, int C,
66 | float *w, float *u, float *k, float *v, float *y,
67 | float *aa, float *bb, float *pp);
68 |
69 | __global__ void kernel_mm_seq_fp32i8(
70 | const int B, const int N, const int M,
71 | const float *__restrict__ const x, const int x_stride,
72 | const uint8_t *__restrict__ const w, const int w_stride,
73 | const float *__restrict__ const mx,
74 | const float *__restrict__ const rx,
75 | const float *__restrict__ const my,
76 | const float *__restrict__ const ry,
77 | float *__restrict__ const y, const int y_stride) {
78 |
79 | const int i = blockIdx.x * blockDim.x + threadIdx.x;
80 | const int k = blockIdx.y * blockDim.y + threadIdx.y;
81 |
82 | if (i < B && k < M) {
83 | float y_local = 0;
84 | for (int j = 0; j < N; ++j) {
85 | y_local += x[i * x_stride + j] * (
86 | (float(w[j * w_stride + k]) + 0.5f)
87 | * rx[k] * ry[j] + mx[k] + my[j]
88 | );
89 | }
90 | y[i * y_stride + k] = y_local;
91 | }
92 | }
93 |
94 | template <typename F>
95 | void cuda_mm8_seq(int B, int N, int M,
96 | F *x, int x_stride,
97 | uint8_t *w, int w_stride,
98 | F *mx, F *rx,
99 | F *my, F *ry,
100 | F *y, int y_stride);
101 |
102 | template <>
103 | void cuda_mm8_seq<float>(int B, int N, int M,
104 | float *x, int x_stride,
105 | uint8_t *w, int w_stride,
106 | float *mx, float *rx,
107 | float *my, float *ry,
108 | float *y, int y_stride) {
109 | dim3 blockSize(1, 128);
110 | dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
111 | kernel_mm_seq_fp32i8<<<gridSize, blockSize>>>(
112 | B, N, M, x, x_stride, w, w_stride,
113 | mx, rx, my, ry, y, y_stride);
114 | }
115 |
116 | __global__ void kernel_mm_seq_fp16i8(
117 | const int B, const int N, const int M,
118 | const __half *__restrict__ const x, const int x_stride,
119 | const uint8_t *__restrict__ const w, const int w_stride,
120 | const __half *__restrict__ const mx,
121 | const __half *__restrict__ const rx,
122 | const __half *__restrict__ const my,
123 | const __half *__restrict__ const ry,
124 | __half *__restrict__ const y, const int y_stride) {
125 |
126 | const int i = blockIdx.x * blockDim.x + threadIdx.x;
127 | const int k = blockIdx.y * blockDim.y + threadIdx.y;
128 |
129 | if (i < B && k < M) {
130 | float y_local = 0;
131 | for (int j = 0; j < N; ++j) {
132 | y_local += __half2float(x[i * x_stride + j]) * (
133 | (float(w[j * w_stride + k]) + 0.5f)
134 | * __half2float(rx[k]) * __half2float(ry[j])
135 | + __half2float(mx[k]) + __half2float(my[j])
136 | );
137 | }
138 | y[i * y_stride + k] = __float2half(y_local);
139 | }
140 | }
141 |
142 | template <>
143 | void cuda_mm8_seq<fp16>(int B, int N, int M,
144 | fp16 *x, int x_stride,
145 | uint8_t *w, int w_stride,
146 | fp16 *mx, fp16 *rx,
147 | fp16 *my, fp16 *ry,
148 | fp16 *y, int y_stride) {
149 | dim3 blockSize(1, 128);
150 | dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y);
151 | kernel_mm_seq_fp16i8<<<gridSize, blockSize>>>(
152 | B, N, M, cast(x), x_stride, w, w_stride,
153 | cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride);
154 | }
155 |
156 | #define MM8_ONE_JSPLIT 24
157 | #define MM8_ONE_TILE 1024
158 |
159 | __global__ void kernel_mm_one_fp32i8(
160 | const int N, const int M,
161 | const float *__restrict__ const x,
162 | const uint8_t *__restrict__ const w, const int w_stride,
163 | const float *__restrict__ const mx,
164 | const float *__restrict__ const rx,
165 | const float *__restrict__ const my,
166 | const float *__restrict__ const ry,
167 | float *__restrict__ const y) {
168 |
169 | const int k = blockIdx.y * blockDim.y + threadIdx.y;
170 | const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
171 | const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
172 |
173 | if (k < M) {
174 | float y_local = 0;
175 | for (int j = j0; j < j1; ++j) {
176 | y_local += x[j] * (
177 | (float(w[j * w_stride + k]) + 0.5f)
178 | * rx[k] * ry[j] + mx[k] + my[j]
179 | );
180 | }
181 | atomicAdd(&y[k], y_local);
182 | }
183 | }
184 |
185 | template <typename F>
186 | void cuda_mm8_one(int N, int M,
187 | F *x,
188 | uint8_t *w, int w_stride,
189 | F *mx, F *rx,
190 | F *my, F *ry,
191 | float *y);
192 |
193 | template <>
194 | void cuda_mm8_one<float>(int N, int M,
195 | float *x,
196 | uint8_t *w, int w_stride,
197 | float *mx, float *rx,
198 | float *my, float *ry,
199 | float *y) {
200 | dim3 blockSize(1, MM8_ONE_TILE);
201 | dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
202 | kernel_mm_one_fp32i8<<<gridSize, blockSize>>>(
203 | N, M, x, w, w_stride,
204 | mx, rx, my, ry, y);
205 | }
206 |
207 | __global__ void kernel_mm_one_fp16i8(
208 | const int N, const int M,
209 | const __half *__restrict__ const x,
210 | const uint8_t *__restrict__ const w, const int w_stride,
211 | const __half *__restrict__ const mx,
212 | const __half *__restrict__ const rx,
213 | const __half *__restrict__ const my,
214 | const __half *__restrict__ const ry,
215 | float *__restrict__ const y) {
216 |
217 | const int k = blockIdx.y * blockDim.y + threadIdx.y;
218 | const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
219 | const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT));
220 |
221 | if (k < M) {
222 | float y_local = 0;
223 | for (int j = j0; j < j1; ++j) {
224 | y_local += __half2float(x[j]) * (
225 | (float(w[j * w_stride + k]) + 0.5f)
226 | * __half2float(rx[k]) * __half2float(ry[j])
227 | + __half2float(mx[k]) + __half2float(my[j])
228 | );
229 | }
230 | atomicAdd(&y[k], y_local);
231 | }
232 | }
233 |
234 | template <>
235 | void cuda_mm8_one<fp16>(int N, int M,
236 | fp16 *x,
237 | uint8_t *w, int w_stride,
238 | fp16 *mx, fp16 *rx,
239 | fp16 *my, fp16 *ry,
240 | float *y) {
241 | dim3 blockSize(1, MM8_ONE_TILE);
242 | dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y);
243 | kernel_mm_one_fp16i8<<<gridSize, blockSize>>>(
244 | N, M, cast(x), w, w_stride,
245 | cast(mx), cast(rx), cast(my), cast(ry), y);
246 | }
247 |
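The `aa`/`bb`/`pp` bookkeeping in `kernel_wkv_forward` is a streaming log-sum-exp trick: the true numerator and denominator of the WKV output are `aa * exp(pp)` and `bb * exp(pp)`, and `pp` is re-based to the running maximum each step so no `exp` ever overflows. A single-channel Python sketch of the same arithmetic, alongside the direct (overflow-prone) form it is equivalent to; the function names are illustrative, not from the codebase:

```python
import math

def wkv_forward_stable(w, u, ks, vs):
    # One channel of the WKV recurrence, using the kernel's max-shift trick.
    ys = []
    aa, bb, pp = 0.0, 0.0, -1e38  # num/den scaled by exp(pp); pp is the current shift
    for k, v in zip(ks, vs):
        ww = u + k
        p = max(pp, ww)
        e1, e2 = math.exp(pp - p), math.exp(ww - p)
        ys.append((e1 * aa + e2 * v) / (e1 * bb + e2))
        ww = w + pp
        p = max(ww, k)
        e1, e2 = math.exp(ww - p), math.exp(k - p)
        aa = e1 * aa + e2 * v
        bb = e1 * bb + e2
        pp = p
    return ys

def wkv_forward_naive(w, u, ks, vs):
    # Direct form: numerator/denominator carried as raw exponentials (overflows for large k).
    ys, num, den = [], 0.0, 0.0
    for k, v in zip(ks, vs):
        e = math.exp(u + k)
        ys.append((num + e * v) / (den + e))
        num = math.exp(w) * num + math.exp(k) * v
        den = math.exp(w) * den + math.exp(k)
    return ys
```

For moderate inputs the two agree; only the stable form survives large `k` values, which is why the kernel carries the shift `pp` in the state.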
--------------------------------------------------------------------------------
/rwkv/cuda/rwkv5.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 | #include "ATen/ATen.h"
4 | typedef at::BFloat16 bf16;
5 | typedef at::Half fp16;
6 | typedef float fp32;
7 |
8 | template <typename F>
9 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
10 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
11 | F *__restrict__ const _y)
12 | {
13 | const int b = blockIdx.x / H;
14 | const int h = blockIdx.x % H;
15 | const int i = threadIdx.x;
16 | _w += h*_N_;
17 | _u += h*_N_;
18 | _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
19 |
20 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
21 |
22 | float state[_N_];
23 | #pragma unroll
24 | for (int j = 0; j < _N_; j++)
25 | state[j] = _state[j];
26 |
27 | __syncthreads();
28 | u[i] = float(_u[i]);
29 | w[i] = _w[i];
30 | __syncthreads();
31 |
32 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
33 | {
34 | __syncthreads();
35 | r[i] = float(_r[t]);
36 | k[i] = float(_k[t]);
37 | __syncthreads();
38 |
39 | const float v = float(_v[t]);
40 | float y = 0;
41 |
42 | #pragma unroll
43 | for (int j = 0; j < _N_; j+=4)
44 | {
45 | const float4& r_ = (float4&)(r[j]);
46 | const float4& k_ = (float4&)(k[j]);
47 | const float4& w_ = (float4&)(w[j]);
48 | const float4& u_ = (float4&)(u[j]);
49 | float4& s = (float4&)(state[j]);
50 | float4 x;
51 |
52 | x.x = k_.x * v;
53 | x.y = k_.y * v;
54 | x.z = k_.z * v;
55 | x.w = k_.w * v;
56 |
57 | y += r_.x * (u_.x * x.x + s.x);
58 | y += r_.y * (u_.y * x.y + s.y);
59 | y += r_.z * (u_.z * x.z + s.z);
60 | y += r_.w * (u_.w * x.w + s.w);
61 |
62 | s.x = s.x * w_.x + x.x;
63 | s.y = s.y * w_.y + x.y;
64 | s.z = s.z * w_.z + x.z;
65 | s.w = s.w * w_.w + x.w;
66 | }
67 | _y[t] = F(y);
68 | }
69 | #pragma unroll
70 | for (int j = 0; j < _N_; j++)
71 | _state[j] = state[j];
72 | }
73 |
74 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
75 | {
76 | assert(H*_N_ == C);
77 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
78 | }
79 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
80 | {
81 | assert(H*_N_ == C);
82 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
83 | }
84 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
85 | {
86 | assert(H*_N_ == C);
87 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
88 | }
89 |
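In the rwkv5 kernel above, each thread `i` owns one row of an N-by-N per-head state: `state[i][j]` decays by `w[j]` and accumulates `k[j] * v[i]`, while the output is `y[i] = sum_j r[j] * (u[j] * k[j] * v[i] + state[i][j])`. A NumPy sketch of one timestep for one head (function name hypothetical, vectorizing what the kernel does across threads):

```python
import numpy as np

def rwkv5_head_step(S, r, k, v, w, u):
    # S: (N, N) state, one row per output channel; r, k, v, w, u: (N,) vectors.
    x = np.outer(v, k)                 # x[i, j] = k[j] * v[i]
    y = ((u * x + S) * r).sum(axis=1)  # y[i] = sum_j r[j] * (u[j]*x[i,j] + S[i,j])
    S = S * w + x                      # decay column j by w[j], then accumulate x
    return y, S
```

rwkv6 (next file) is the same recurrence except that `w` is read per timestep inside the loop instead of once per head.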
--------------------------------------------------------------------------------
/rwkv/cuda/rwkv5_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 | #include <c10/cuda/CUDAGuard.h>
4 | typedef at::BFloat16 bf16;
5 | typedef at::Half fp16;
6 | typedef float fp32;
7 |
8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11 |
12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
13 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
14 | cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
15 | }
16 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
17 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
18 | cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
19 | }
20 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
21 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
22 | cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
23 | }
24 |
25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
26 | m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
27 | m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
28 | m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
29 | }
30 | TORCH_LIBRARY(rwkv5, m) {
31 | m.def("forward_bf16", forward_bf16);
32 | m.def("forward_fp16", forward_fp16);
33 | m.def("forward_fp32", forward_fp32);
34 | }
35 |
--------------------------------------------------------------------------------
/rwkv/cuda/rwkv6.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 | #include "ATen/ATen.h"
4 | typedef at::BFloat16 bf16;
5 | typedef at::Half fp16;
6 | typedef float fp32;
7 |
8 | template <typename F>
9 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state,
10 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
11 | F *__restrict__ const _y)
12 | {
13 | const int b = blockIdx.x / H;
14 | const int h = blockIdx.x % H;
15 | const int i = threadIdx.x;
16 | _u += h*_N_;
17 | _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
18 |
19 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
20 |
21 | float state[_N_];
22 | #pragma unroll
23 | for (int j = 0; j < _N_; j++)
24 | state[j] = _state[j];
25 |
26 | __syncthreads();
27 | u[i] = float(_u[i]);
28 | __syncthreads();
29 |
30 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
31 | {
32 | __syncthreads();
33 | w[i] = _w[t];
34 | r[i] = float(_r[t]);
35 | k[i] = float(_k[t]);
36 | __syncthreads();
37 |
38 | const float v = float(_v[t]);
39 | float y = 0;
40 |
41 | #pragma unroll
42 | for (int j = 0; j < _N_; j+=4)
43 | {
44 | const float4& r_ = (float4&)(r[j]);
45 | const float4& k_ = (float4&)(k[j]);
46 | const float4& w_ = (float4&)(w[j]);
47 | const float4& u_ = (float4&)(u[j]);
48 | float4& s = (float4&)(state[j]);
49 | float4 x;
50 |
51 | x.x = k_.x * v;
52 | x.y = k_.y * v;
53 | x.z = k_.z * v;
54 | x.w = k_.w * v;
55 |
56 | y += r_.x * (u_.x * x.x + s.x);
57 | y += r_.y * (u_.y * x.y + s.y);
58 | y += r_.z * (u_.z * x.z + s.z);
59 | y += r_.w * (u_.w * x.w + s.w);
60 |
61 | s.x = s.x * w_.x + x.x;
62 | s.y = s.y * w_.y + x.y;
63 | s.z = s.z * w_.z + x.z;
64 | s.w = s.w * w_.w + x.w;
65 | }
66 | _y[t] = F(y);
67 | }
68 | #pragma unroll
69 | for (int j = 0; j < _N_; j++)
70 | _state[j] = state[j];
71 | }
72 |
73 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
74 | {
75 | assert(H*_N_ == C);
76 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
77 | }
78 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y)
79 | {
80 | assert(H*_N_ == C);
81 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
82 | }
83 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y)
84 | {
85 | assert(H*_N_ == C);
86 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, k, v, w, u, y);
87 | }
88 |
--------------------------------------------------------------------------------
/rwkv/cuda/rwkv6_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 | #include <c10/cuda/CUDAGuard.h>
4 | typedef at::BFloat16 bf16;
5 | typedef at::Half fp16;
6 | typedef float fp32;
7 |
8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11 |
12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
13 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
14 | cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
15 | }
16 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
17 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
18 | cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
19 | }
20 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
21 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
22 | cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
23 | }
24 |
25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
26 | m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
27 | m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
28 | m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
29 | }
30 | TORCH_LIBRARY(rwkv6, m) {
31 | m.def("forward_bf16", forward_bf16);
32 | m.def("forward_fp16", forward_fp16);
33 | m.def("forward_fp32", forward_fp32);
34 | }
35 |
--------------------------------------------------------------------------------
/rwkv/cuda/rwkv7.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 | #include <assert.h>
3 | #include "ATen/ATen.h"
4 |
5 | typedef at::Half fp16;
6 | typedef at::BFloat16 bf16;
7 | typedef float fp32;
8 |
9 | template <typename F>
10 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
11 | float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
12 | F *__restrict__ const _y)
13 | {
14 | const int e = blockIdx.x / H;
15 | const int h = blockIdx.x % H;
16 | const int i = threadIdx.x;
17 | _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
18 |
19 | float state[_N_];
20 | #pragma unroll
21 | for (int j = 0; j < _N_; j++)
22 | state[j] = _state[j];
23 |
24 | __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
25 |
26 | for (int _t = 0; _t < T; _t++)
27 | {
28 | const int t = e*T*C + h*_N_ + i + _t * C;
29 | __syncthreads();
30 | r[i] = float(_r[t]);
31 | w[i] = __expf(-__expf(float(_w[t])));
32 | k[i] = float(_k[t]);
33 | a[i] = float(_a[t]);
34 | b[i] = float(_b[t]);
35 | __syncthreads();
36 |
37 | float sa = 0;
38 | #pragma unroll
39 | for (int j = 0; j < _N_; j++)
40 | {
41 | sa += a[j] * state[j];
42 | }
43 |
44 | float vv = float(_v[t]);
45 | float y = 0;
46 | #pragma unroll
47 | for (int j = 0; j < _N_; j++)
48 | {
49 | float& s = state[j];
50 | s = s * w[j] + k[j] * vv + sa * b[j];
51 | y += s * r[j];
52 | }
53 | _y[t] = F(y);
54 | }
55 | #pragma unroll
56 | for (int j = 0; j < _N_; j++)
57 | _state[j] = state[j];
58 | }
59 |
60 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
61 | {
62 | assert(H*_N_ == C);
63 | assert(B == 1); // only for B=1
64 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
65 | }
66 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
67 | {
68 | assert(H*_N_ == C);
69 | assert(B == 1); // only for B=1
70 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
71 | }
72 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
73 | {
74 | assert(H*_N_ == C);
75 | assert(B == 1); // only for B=1
76 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
77 | }
78 |
--------------------------------------------------------------------------------
/rwkv/cuda/rwkv7_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 |
4 | typedef at::Half fp16;
5 | typedef at::BFloat16 bf16;
6 | typedef float fp32;
7 |
8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y);
10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y);
11 |
12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
13 | cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
14 | }
15 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
16 | cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), y.data_ptr<fp16>());
17 | }
18 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
19 | cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>());
20 | }
21 |
22 | TORCH_LIBRARY(wkv7s, m) {
23 | m.def("forward_bf16", forward_bf16);
24 | m.def("forward_fp16", forward_fp16);
25 | m.def("forward_fp32", forward_fp32);
26 | }
27 |
--------------------------------------------------------------------------------
/rwkv/cuda/wrapper.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "ATen/ATen.h"
3 | #include <iostream>
4 | #include <c10/cuda/CUDAGuard.h>
5 |
6 | typedef at::Half fp16;
7 |
8 | template <typename F>
9 | void cuda_wkv_forward(int B, int T, int C,
10 | float *w, float *u, F *k, F *v, F *y,
11 | float *aa, float *bb, float *pp);
12 | template <typename F>
13 | void cuda_mm8_seq(int B, int N, int M,
14 | F *x, int x_stride,
15 | uint8_t *w, int w_stride,
16 | F *mx, F *rx,
17 | F *my, F *ry,
18 | F *y, int y_stride);
19 | template <typename F>
20 | void cuda_mm8_one(int N, int M,
21 | F *x,
22 | uint8_t *w, int w_stride,
23 | F *mx, F *rx,
24 | F *my, F *ry,
25 | float *y);
26 |
27 | void wkv_forward(int64_t B, int64_t T, int64_t C,
28 | torch::Tensor &w, torch::Tensor &u,
29 | torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
30 | torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) {
31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
32 | switch (k.scalar_type()) {
33 | case c10::ScalarType::Half:
34 | cuda_wkv_forward(B, T, C,
35 | w.data_ptr<float>(), u.data_ptr<float>(),
36 | k.data_ptr<fp16>(), v.data_ptr<fp16>(), y.data_ptr<fp16>(),
37 | aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
38 | break;
39 | case c10::ScalarType::Float:
40 | cuda_wkv_forward(B, T, C,
41 | w.data_ptr<float>(), u.data_ptr<float>(),
42 | k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(),
43 | aa.data_ptr<float>(), bb.data_ptr<float>(), pp.data_ptr<float>());
44 | break;
45 | default:
46 | assert(false && "Only FP16 and FP32 are currently supported");
47 | }
48 | }
49 |
50 | void mm8_seq(int64_t B, int64_t N, int64_t M,
51 | torch::Tensor &x, torch::Tensor &w,
52 | torch::Tensor &mx, torch::Tensor &rx,
53 | torch::Tensor &my, torch::Tensor &ry,
54 | torch::Tensor &y) {
55 | assert(x.stride(1) == 1);
56 | assert(w.stride(1) == 1);
57 | assert(mx.stride(0) == 1 && rx.stride(0) == 1);
58 | assert(my.stride(0) == 1 && ry.stride(0) == 1);
59 | assert(y.stride(1) == 1);
60 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
61 | switch (x.scalar_type()) {
62 | case c10::ScalarType::Half:
63 | cuda_mm8_seq(
64 | B, N, M,
65 | x.data_ptr<fp16>(), x.stride(0),
66 | w.data_ptr<uint8_t>(), w.stride(0),
67 | mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
68 | my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
69 | y.data_ptr<fp16>(), y.stride(0));
70 | break;
71 | case c10::ScalarType::Float:
72 | cuda_mm8_seq(
73 | B, N, M,
74 | x.data_ptr<float>(), x.stride(0),
75 | w.data_ptr<uint8_t>(), w.stride(0),
76 | mx.data_ptr<float>(), rx.data_ptr<float>(),
77 | my.data_ptr<float>(), ry.data_ptr<float>(),
78 | y.data_ptr<float>(), y.stride(0));
79 | break;
80 | default:
81 | assert(false && "Only FP16 and FP32 are currently supported");
82 | }
83 | }
84 | void mm8_one(int64_t N, int64_t M,
85 | torch::Tensor &x, torch::Tensor &w,
86 | torch::Tensor &mx, torch::Tensor &rx,
87 | torch::Tensor &my, torch::Tensor &ry,
88 | torch::Tensor &y) {
89 | assert(x.stride(0) == 1);
90 | assert(w.stride(1) == 1);
91 | assert(mx.stride(0) == 1 && rx.stride(0) == 1);
92 | assert(my.stride(0) == 1 && ry.stride(0) == 1);
93 | assert(y.stride(0) == 1);
94 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w));
95 | switch (x.scalar_type()) {
96 | case c10::ScalarType::Half:
97 | cuda_mm8_one(
98 | N, M,
99 | x.data_ptr<fp16>(),
100 | w.data_ptr<uint8_t>(), w.stride(0),
101 | mx.data_ptr<fp16>(), rx.data_ptr<fp16>(),
102 | my.data_ptr<fp16>(), ry.data_ptr<fp16>(),
103 | y.data_ptr<float>());
104 | break;
105 | case c10::ScalarType::Float:
106 | cuda_mm8_one(
107 | N, M,
108 | x.data_ptr<float>(),
109 | w.data_ptr<uint8_t>(), w.stride(0),
110 | mx.data_ptr<float>(), rx.data_ptr<float>(),
111 | my.data_ptr<float>(), ry.data_ptr<float>(),
112 | y.data_ptr<float>());
113 | break;
114 | default:
115 | assert(false && "Only FP16 and FP32 are currently supported");
116 | }
117 | }
118 |
119 | using torch::Tensor;
120 |
121 | #ifndef DISABLE_CUBLAS_GEMM
122 | void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c);
123 | #endif
124 |
125 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
126 | m.def("wkv_forward", &wkv_forward, "wkv forward");
127 | m.def("mm8_seq", &mm8_seq, "mm8 seq");
128 | m.def("mm8_one", &mm8_one, "mm8 one");
129 | #ifndef DISABLE_CUBLAS_GEMM
130 | m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas");
131 | #endif
132 | }
133 |
134 | TORCH_LIBRARY(rwkv, m) {
135 | m.def("wkv_forward", wkv_forward);
136 | m.def("mm8_seq", mm8_seq);
137 | m.def("mm8_one", mm8_one);
138 | #ifndef DISABLE_CUBLAS_GEMM
139 | m.def("gemm_fp16_cublas", gemm_fp16_cublas);
140 | #endif
141 | }
142 |
--------------------------------------------------------------------------------
/rwkv/rwkv5.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/rwkv5.pyd
--------------------------------------------------------------------------------
/rwkv/rwkv6.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/rwkv6.pyd
--------------------------------------------------------------------------------
/rwkv/rwkv_tokenizer.py:
--------------------------------------------------------------------------------
1 | ########################################################################################################
2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3 | ########################################################################################################
4 |
5 |
6 | class TRIE:
7 | __slots__ = tuple("ch,to,values,front".split(","))
8 | to: list
9 | values: set
10 |
11 | def __init__(self, front=None, ch=None):
12 | self.ch = ch
13 | self.to = [None for ch in range(256)]
14 | self.values = set()
15 | self.front = front
16 |
17 | def __repr__(self):
18 | fr = self
19 | ret = []
20 | while fr != None:
21 | if fr.ch != None:
22 | ret.append(fr.ch)
23 | fr = fr.front
24 | return "<TRIE %s %s>" % (ret[::-1], self.values)
25 |
26 | def add(self, key: bytes, idx: int = 0, val=None):
27 | if idx == len(key):
28 | if val is None:
29 | val = key
30 | self.values.add(val)
31 | return self
32 | ch = key[idx]
33 | if self.to[ch] is None:
34 | self.to[ch] = TRIE(front=self, ch=ch)
35 | return self.to[ch].add(key, idx=idx + 1, val=val)
36 |
37 | def find_longest(self, key: bytes, idx: int = 0):
38 | u: TRIE = self
39 | ch: int = key[idx]
40 |
41 | while u.to[ch] is not None:
42 | u = u.to[ch]
43 | idx += 1
44 | if u.values:
45 | ret = idx, u, u.values
46 | if idx == len(key):
47 | break
48 | ch = key[idx]
49 | return ret
50 |
51 |
52 | class TRIE_TOKENIZER:
53 | def __init__(self, file_name):
54 | self.idx2token = {}
55 | sorted = [] # must be already sorted
56 | with open(file_name, "r", encoding="utf-8") as f:
57 | lines = f.readlines()
58 | for l in lines:
59 | idx = int(l[: l.index(" ")])
60 | x = eval(l[l.index(" ") : l.rindex(" ")])
61 | x = x.encode("utf-8") if isinstance(x, str) else x
62 | assert isinstance(x, bytes)
63 | assert len(x) == int(l[l.rindex(" ") :])
64 | sorted += [x]
65 | self.idx2token[idx] = x
66 |
67 | self.token2idx = {}
68 | for k, v in self.idx2token.items():
69 | self.token2idx[v] = int(k)
70 |
71 | self.root = TRIE()
72 | for t, i in self.token2idx.items():
73 | _ = self.root.add(t, val=(t, i))
74 |
75 | def encodeBytes(self, src: bytes):
76 | idx: int = 0
77 | tokens = []
78 | while idx < len(src):
79 | _idx: int = idx
80 | idx, _, values = self.root.find_longest(src, idx)
81 | assert idx != _idx
82 | _, token = next(iter(values))
83 | tokens.append(token)
84 | return tokens
85 |
86 | def decodeBytes(self, tokens):
87 | return b"".join(map(lambda i: self.idx2token[i], tokens))
88 |
89 | def encode(self, src):
90 | return self.encodeBytes(src.encode("utf-8"))
91 |
92 | def decode(self, tokens):
93 | try:
94 | return self.decodeBytes(tokens).decode("utf-8")
95 | except:
96 | return "\ufffd" # bad utf-8
97 |
98 | def printTokens(self, tokens):
99 | for i in tokens:
100 | s = self.idx2token[i]
101 | try:
102 | s = s.decode("utf-8")
103 | except:
104 | pass
105 | print(f"{repr(s)}{i}", end=" ")
106 | print()
107 |
--------------------------------------------------------------------------------
/rwkv/utils.py:
--------------------------------------------------------------------------------
1 | ########################################################################################################
2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3 | ########################################################################################################
4 |
5 | import os, sys
6 | import numpy as np
7 | import torch
8 | from torch.nn import functional as F
9 |
10 |
11 | class PIPELINE_ARGS:
12 | def __init__(
13 | self,
14 | temperature=1.0,
15 | top_p=0.85,
16 | top_k=0,
17 | alpha_frequency=0.2,
18 | alpha_presence=0.2,
19 | alpha_decay=0.996,
20 | token_ban=[],
21 | token_stop=[],
22 | chunk_len=256,
23 | ):
24 | self.temperature = temperature
25 | self.top_p = top_p
26 | self.top_k = top_k
27 | self.alpha_frequency = alpha_frequency # Frequency Penalty (as in GPT-3)
28 | self.alpha_presence = alpha_presence # Presence Penalty (as in GPT-3)
29 | self.alpha_decay = alpha_decay # gradually decay the penalty
30 | self.token_ban = token_ban # ban the generation of some tokens
31 | self.token_stop = token_stop # stop generation whenever you see any token here
32 | self.chunk_len = (
33 | chunk_len # split input into chunks to save VRAM (shorter -> slower)
34 | )
35 |
36 |
37 | class ABC_TOKENIZER:
38 | def __init__(self):
39 | self.pad_token_id = 0
40 | self.bos_token_id = 2
41 | self.eos_token_id = 3
42 |
43 | def encode(self, text):
44 | ids = [ord(c) for c in text]
45 | return ids
46 |
47 | def decode(self, ids):
48 | txt = "".join(
49 | chr(idx) if idx > self.eos_token_id else ""
50 | for idx in ids
51 | if idx != self.eos_token_id
52 | )
53 | return txt
54 |
55 |
56 | class PIPELINE:
57 | def __init__(self, model, WORD_NAME: str):
58 | self.model = model
59 |
60 | if WORD_NAME == "cl100k_base":
61 | import tiktoken
62 |
63 | self.tokenizer = tiktoken.get_encoding(WORD_NAME)
64 | elif WORD_NAME == "rwkv_vocab_v20230424":
65 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
66 | from rwkv_tokenizer import TRIE_TOKENIZER
67 |
68 | self.tokenizer = TRIE_TOKENIZER(
69 | os.path.dirname(os.path.abspath(__file__)) + "/rwkv_vocab_v20230424.txt"
70 | )
71 | elif WORD_NAME == "abc_tokenizer":
72 | self.tokenizer = ABC_TOKENIZER()
73 | else:
74 | if WORD_NAME.endswith(".txt"):
75 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
76 | from rwkv_tokenizer import TRIE_TOKENIZER
77 |
78 | self.tokenizer = TRIE_TOKENIZER(WORD_NAME)
79 | else:
80 | from tokenizers import Tokenizer
81 |
82 | self.tokenizer = Tokenizer.from_file(WORD_NAME)
83 |
84 | def refine_context(self, context):
85 | context = context.strip().split("\n")
86 | for c in range(len(context)):
87 | context[c] = context[c].strip().strip("\u3000").strip("\r")
88 | context = list(filter(lambda c: c != "", context))
89 | context = "\n" + ("\n".join(context)).strip()
90 | if context == "":
91 | context = "\n"
92 | return context
93 |
94 | def encode(self, x):
95 | if "Tokenizer" in str(type(self.tokenizer)):
96 | return self.tokenizer.encode(x).ids
97 | else:
98 | return self.tokenizer.encode(x)
99 |
100 | def decode(self, x):
101 | return self.tokenizer.decode(x)
102 |
103 | def np_softmax(self, x: np.ndarray, axis: int):
104 | x -= x.max(axis=axis, keepdims=True)
105 | e: np.ndarray = np.exp(x)
106 | return e / e.sum(axis=axis, keepdims=True)
107 |
108 | def sample_logits(self, logits, temperature=1.0, top_p=0.85, top_k=0):
109 | if type(logits) == list:
110 | logits = np.array(logits)
111 | np_logits = type(logits) == np.ndarray
112 | if np_logits:
113 | probs = self.np_softmax(logits, axis=-1)
114 | else:
115 | probs = F.softmax(logits.float(), dim=-1)
116 | top_k = int(top_k)
117 | # 'privateuseone' is the type of custom devices like `torch_directml.device()`
118 | if np_logits or probs.device.type in ["cpu", "privateuseone"]:
119 | if not np_logits:
120 | probs = probs.cpu().numpy()
121 | sorted_ids = np.argsort(probs)
122 | sorted_probs = probs[sorted_ids][::-1]
123 | cumulative_probs = np.cumsum(sorted_probs)
124 | cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
125 | probs[probs < cutoff] = 0
126 | if top_k < len(probs) and top_k > 0:
127 | probs[sorted_ids[:-top_k]] = 0
128 | if temperature != 1.0:
129 | probs = probs ** (1.0 / temperature)
130 | probs = probs / np.sum(probs)
131 | out = np.random.choice(a=len(probs), p=probs)
132 | return int(out)
133 | else:
134 | sorted_ids = torch.argsort(probs)
135 | sorted_probs = probs[sorted_ids]
136 | sorted_probs = torch.flip(sorted_probs, dims=(0,))
137 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
138 | cutoff = float(sorted_probs[np.argmax(cumulative_probs >= top_p)])
139 | probs[probs < cutoff] = 0
140 | if top_k < len(probs) and top_k > 0:
141 | probs[sorted_ids[:-top_k]] = 0
142 | if temperature != 1.0:
143 | probs = probs ** (1.0 / temperature)
144 | out = torch.multinomial(probs, num_samples=1)[0]
145 | return int(out)
146 |
147 | def generate(
148 | self, ctx, token_count=100, args=PIPELINE_ARGS(), callback=None, state=None
149 | ):
150 | all_tokens = []
151 | out_last = 0
152 | out_str = ""
153 | occurrence = {}
154 | for i in range(token_count):
155 | # forward & adjust prob.
156 | tokens = self.encode(ctx) if i == 0 else [token]
157 | while len(tokens) > 0:
158 | out, state = self.model.forward(tokens[: args.chunk_len], state)
159 | tokens = tokens[args.chunk_len :]
160 |
161 | for n in args.token_ban:
162 | out[n] = -float("inf")
163 | for n in occurrence:
164 | out[n] -= args.alpha_presence + occurrence[n] * args.alpha_frequency
165 |
166 | # sampler
167 | token = self.sample_logits(
168 | out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k
169 | )
170 | if token in args.token_stop:
171 | break
172 | all_tokens += [token]
173 | for xxx in occurrence:
174 | occurrence[xxx] *= args.alpha_decay
175 |
176 | ttt = self.decode([token])
177 | www = 1
178 | if ttt in " \t0123456789":
179 | www = 0
180 | # elif ttt in '\r\n,.;?!"\':+-*/=#@$%^&_`~|<>\\()[]{},。;“”:?!()【】':
181 | # www = 0.5
182 | if token not in occurrence:
183 | occurrence[token] = www
184 | else:
185 | occurrence[token] += www
186 | # print(occurrence) # debug
187 |
188 | # output
189 | tmp = self.decode(all_tokens[out_last:])
190 | if "\ufffd" not in tmp: # is valid utf-8 string?
191 | if callback:
192 | callback(tmp)
193 | out_str += tmp
194 | out_last = i + 1
195 | return out_str
196 |
--------------------------------------------------------------------------------
/rwkv/webgpu/model.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Union
2 |
3 | try:
4 | import web_rwkv_py as wrp
5 | except ModuleNotFoundError:
6 | try:
7 | from . import web_rwkv_py as wrp
8 | except ImportError:
9 | raise ModuleNotFoundError(
10 | "web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py"
11 | )
12 |
13 |
14 | class RWKV:
15 | def __init__(self, model_path: str, strategy: str = None):
16 | layer = (
17 | int(s.lstrip("layer"))
18 | for s in strategy.split()
19 | for s in s.split(",")
20 | if s.startswith("layer")
21 | )
22 |
23 | chunk_size = (
24 | int(s.lstrip("chunk"))
25 | for s in strategy.split()
26 | for s in s.split(",")
27 | if s.startswith("chunk")
28 | )
29 | self.token_chunk_size = next(chunk_size, 32)
30 |
31 | args = {
32 | "path": model_path,
33 | "quant": next(layer, 31) if "i8" in strategy else 0,
34 | "quant_nf4": next(layer, 26) if "i4" in strategy else 0,
35 | }
36 | self.model = wrp.Model(**args)
37 | self.info = self.model.info()
38 | self.w = {} # fake weight
39 | self.w["emb.weight"] = [0] * self.info.num_vocab
40 | self.version = str(self.info.version).lower()
41 | self.version = float(self.version.lower().replace("v", ""))
42 |
43 | def forward(self, tokens: List[int], state: Union[Any, None] = None):
44 | if state is None:
45 | self.model.clear_state()
46 | elif type(state).__name__ == "State_Cpu":
47 | self.model.load_state(state)
48 | logits = self.model.run(tokens, self.token_chunk_size)
49 | ret_state = "State_Gpu"
50 | return logits, ret_state
51 |
--------------------------------------------------------------------------------
/rwkv/webgpu/web_rwkv_py.cp310-win_amd64.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/webgpu/web_rwkv_py.cp310-win_amd64.pyd
--------------------------------------------------------------------------------
/rwkv/wkv7s.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/wkv7s.pyd
--------------------------------------------------------------------------------
/rwkv/wkv_cuda.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/rwkv/wkv_cuda.pyd
--------------------------------------------------------------------------------
/static/config/default.json:
--------------------------------------------------------------------------------
1 | {
2 | "persona": "Hello, I am your intelligent assistant. I will provide expert, comprehensive responses; feel free to ask any question and I will always answer you.",
3 | "coordinate": {
4 | "x": 25,
5 | "y": 35,
6 | "isPC": true
7 | },
8 | "characters": {
9 | "sense": 0.7,
10 | "pizzazz": 0.5
11 | },
12 | "messages": [],
13 | "users": {
14 | "you": {
15 | "name": "you",
16 | "avatar": "/static/img/bilibini.png"
17 | },
18 | "other": {
19 | "name": "MeowAI",
20 | "avatar": "/static/img/ai.png"
21 | }
22 | }
23 | }
--------------------------------------------------------------------------------
/static/config/noke.json:
--------------------------------------------------------------------------------
1 | {
2 | "persona": "Meow meow~ Hello, master! I am your little kitty, meow~",
3 | "coordinate": {
4 | "x": 97,
5 | "y": 47,
6 | "isPC": true
7 | },
8 | "characters": {
9 | "sense": 0.95,
10 | "pizzazz": 47.35
11 | },
12 | "messages": [],
13 | "users": {
14 | "you": {
15 | "name": "you",
16 | "avatar": "/static/img/bilibini.png"
17 | },
18 | "other": {
19 | "name": "noke",
20 | "avatar": "/static/img/ai.png"
21 | }
22 | }
23 | }
--------------------------------------------------------------------------------
/static/css/noticejs.css:
--------------------------------------------------------------------------------
1 | .noticejs-top{top:0;width:100%!important}.noticejs-top .item{border-radius:0!important;margin:0!important}.noticejs-topRight{top:10px;right:10px}.noticejs-topLeft{top:10px;left:10px}.noticejs-topCenter{top:10px;left:50%;transform:translate(-50%)}.noticejs-middleLeft,.noticejs-middleRight{right:10px;top:50%;transform:translateY(-50%)}.noticejs-middleLeft{left:10px}.noticejs-middleCenter{top:50%;left:50%;transform:translate(-50%,-50%)}.noticejs-bottom{bottom:0;width:100%!important}.noticejs-bottom .item{border-radius:0!important;margin:0!important}.noticejs-bottomRight{bottom:10px;right:10px}.noticejs-bottomLeft{bottom:10px;left:10px}.noticejs-bottomCenter{bottom:10px;left:50%;transform:translate(-50%)}.noticejs{font-family:Helvetica Neue,Helvetica,Arial,sans-serif}.noticejs .item{margin:0 0 10px;border-radius:3px;overflow:hidden}.noticejs .item .close{float:right;font-size:18px;font-weight:700;line-height:1;color:#fff;text-shadow:0 1px 0 #fff;opacity:1;margin-right:7px}.noticejs .item .close:hover{opacity:.5;color:#000}.noticejs .item a{color:#fff;border-bottom:1px dashed #fff}.noticejs .item a,.noticejs .item a:hover{text-decoration:none}.noticejs .success{background-color:#64ce83}.noticejs .success .noticejs-heading{background-color:#3da95c;color:#fff;padding:10px}.noticejs .success .noticejs-body{color:#fff;padding:10px}.noticejs .success .noticejs-body:hover{visibility:visible!important}.noticejs .success .noticejs-content{visibility:visible}.noticejs .info{background-color:#3ea2ff}.noticejs .info .noticejs-heading{background-color:#067cea;color:#fff;padding:10px}.noticejs .info .noticejs-body{color:#fff;padding:10px}.noticejs .info .noticejs-body:hover{visibility:visible!important}.noticejs .info .noticejs-content{visibility:visible}.noticejs .warning{background-color:#ff7f48}.noticejs .warning .noticejs-heading{background-color:#f44e06;color:#fff;padding:10px}.noticejs .warning .noticejs-body{color:#fff;padding:10px}.noticejs .warning 
.noticejs-body:hover{visibility:visible!important}.noticejs .warning .noticejs-content{visibility:visible}.noticejs .error{background-color:#e74c3c}.noticejs .error .noticejs-heading{background-color:#ba2c1d;color:#fff;padding:10px}.noticejs .error .noticejs-body{color:#fff;padding:10px}.noticejs .error .noticejs-body:hover{visibility:visible!important}.noticejs .error .noticejs-content{visibility:visible}.noticejs .progressbar{width:100%}.noticejs .progressbar .bar{width:1%;height:30px;background-color:#4caf50}.noticejs .success .noticejs-progressbar{width:100%;background-color:#64ce83;margin-top:-1px}.noticejs .success .noticejs-progressbar .noticejs-bar{width:100%;height:5px;background:#3da95c}.noticejs .info .noticejs-progressbar{width:100%;background-color:#3ea2ff;margin-top:-1px}.noticejs .info .noticejs-progressbar .noticejs-bar{width:100%;height:5px;background:#067cea}.noticejs .warning .noticejs-progressbar{width:100%;background-color:#ff7f48;margin-top:-1px}.noticejs .warning .noticejs-progressbar .noticejs-bar{width:100%;height:5px;background:#f44e06}.noticejs .error .noticejs-progressbar{width:100%;background-color:#e74c3c;margin-top:-1px}.noticejs .error .noticejs-progressbar .noticejs-bar{width:100%;height:5px;background:#ba2c1d}@keyframes noticejs-fadeOut{0%{opacity:1}to{opacity:0}}.noticejs-fadeOut{animation-name:noticejs-fadeOut}@keyframes noticejs-modal-in{to{opacity:.3}}@keyframes noticejs-modal-out{to{opacity:0}}.noticejs-rtl .noticejs-heading{direction:rtl}.noticejs-rtl .close{float:left!important;margin-left:7px;margin-right:0!important}.noticejs-rtl .noticejs-content{direction:rtl}.noticejs{position:fixed;z-index:10050}.noticejs ::-webkit-scrollbar{width:8px}.noticejs ::-webkit-scrollbar-button{width:8px;height:5px}.noticejs ::-webkit-scrollbar-track{border-radius:10px}.noticejs ::-webkit-scrollbar-thumb{background:hsla(0,0%,100%,.5);border-radius:10px}.noticejs 
::-webkit-scrollbar-thumb:hover{background:#fff}.noticejs-modal{position:fixed;width:100%;height:100%;background-color:#000;z-index:10000;opacity:.3;left:0;top:0}.noticejs-modal-open{opacity:0;animation:noticejs-modal-in .3s ease-out}.noticejs-modal-close{animation:noticejs-modal-out .3s ease-out;animation-fill-mode:forwards}
--------------------------------------------------------------------------------
/static/css/styles.css:
--------------------------------------------------------------------------------
1 | body,* {
2 | margin: 0;
3 | padding: 0;
4 | }
5 | #app {
6 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
7 | margin: 0;
8 | padding: 0;
9 | background-color: #f8f8f8;
10 | display: flex;
11 | align-items: stretch;
12 | height: 100vh;
13 | width: 100%;
14 | }
15 |
16 |
17 | #chat-container {
18 | flex: 1;
19 | background-color: #fff;
20 | box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
21 | display: flex;
22 | flex-direction: column;
23 | height: 100%;
24 | width: 100%;
25 | }
26 |
27 | #chat-header {
28 | background-color: #4caf50;
29 | color: #fff;
30 | padding: 10px;
31 | text-align: center;
32 | font-size: 18px;
33 | border-bottom: 1px solid #ccc;
34 | }
35 |
36 | #chat-messages {
37 | padding: 20px 20px 50px;
38 | flex: 1;
39 | overflow: hidden;
40 | max-height: 100%;
41 | max-width: 100%;
42 | overflow-y: scroll;
43 | }
44 |
45 | .status-button{
46 | display: flex;
47 | flex-direction: column;
48 | position: absolute;
49 | bottom: 100px;
50 | right: 20px;
51 | height: 90px;
52 | flex-wrap: nowrap;
53 | button{
54 | border-color: transparent;
55 | color: #fff;
56 | background-color: #4caf50;
57 | max-width: 65px;
58 | max-height: 65px;
59 | min-width: 40px;
60 | min-height: 40px;
61 | padding: 5px;
62 | border-radius: 10px;
63 | position: relative;
64 | margin-bottom: 5px;
65 | cursor: pointer;
66 | }
67 | }
68 |
69 | #more-button{
70 | border-color: transparent;
71 | color: #fff;
72 | background-color: #4caf50;
73 | max-width: 65px;
74 | max-height: 65px;
75 | min-width: 40px;
76 | min-height: 40px;
77 | padding: 5px;
78 | border-radius: 10px;
79 | position: absolute;
80 | right: 25px;
81 | top: 0;
82 | cursor: pointer;
83 | }
84 |
85 | .message {
86 | margin-bottom: 15px;
87 | display: flex;
88 | flex-direction: column;
89 | align-items: flex-start;
90 | position: relative;
91 | width: 100%;
92 | .user {
93 | font-weight: bold;
94 | margin-right: 5px;
95 | }
96 | .avatar {
97 | width: 40px;
98 | position: absolute;
99 | top: 23px;
100 | border-radius: 5px;
101 | }
102 | .content {
103 | padding: 10px;
104 | border-radius: 10px;
105 | position: relative;
106 | left: 45px;
107 | max-width: calc(100% - 45px);
108 | word-wrap: break-word;
109 | }
110 | &.you {
111 | .content {
112 | background-color: #ffcc80;
113 | /* background-color:transparent; */
114 | }
115 | .edit-mode textarea{
116 | background-color: #ffcc80;
117 | }
118 | }
119 | &.other {
120 | .content {
121 | background-color: #7be383;
122 | /* background-color:transparent; */
123 | }
124 | .edit-mode textarea{
125 | background-color: #7be383;
126 | }
127 | }
128 | }
129 |
130 | .context-menu {
131 | /* display: none; */
132 | position: absolute;
133 | background-color: #fff;
134 | border: 1px solid #ccc;
135 | box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
136 | z-index: 1;
137 | border-radius: 5px;
138 | overflow: hidden;
139 | }
140 |
141 | .context-menu-item {
142 | padding: 8px;
143 | cursor: pointer;
144 | }
145 | .context-menu-item:hover{
146 | background-color: #4caf50;
147 | color: #fff;
148 | }
149 | .edit-mode {
150 | /* border: 1px solid #ccc; */
151 | display: flex;
152 | flex-direction: column;
153 | align-items: flex-start;
154 | position: relative;
155 | left: 45px;
156 | width: calc(100% - 45px);
157 | height: auto;
158 | }
159 | .edit-mode .edit-mode-btn{
160 | display: flex;
161 | flex-direction: row;
162 | align-items: center;
163 | justify-content: center;
164 | position: relative;
165 | width: 100%;
166 | height: auto;
167 | }
168 | .edit-mode .edit-mode-btn button{
169 | align-items: center;
170 | border-color: transparent;
171 | border-radius: 0.5rem;
172 | border-width: 1px;
173 | display: inline-flex;
174 | font-size: .875rem;
175 | font-weight: 500;
176 | line-height: 1.25rem;
177 | padding: 0.5rem 0.75rem;
178 | pointer-events: auto;
179 | margin: 10px;
180 | }
181 | .edit-mode .edit-mode-btn .save{
182 | background-color: rgba(16,163,127,1);
183 | color: rgba(255,255,255,1);
184 | }
185 | .edit-mode .edit-mode-btn .cancel{
186 | border-color: rgba(0,0,0,.1);
187 | }
188 | .edit-mode textarea{
189 | padding: 10px;
190 | border-radius: 10px;
191 | background-color: #7be383;
192 | font-family: auto;
193 | font-size: 1rem;
194 | resize: none;
195 | outline: none;
196 | display: inline-block;
197 | overflow-y: hidden;
198 | width: calc(100% - 20px);
199 | border:none;
200 | text-decoration: none;
201 |
202 | }
203 | #message-input {
204 | width: auto;
205 | padding: 10px;
206 | border: none;
207 | border-top: 1px solid #ccc;
208 | outline: none;
209 | }
210 |
211 | #send-button {
212 | width: 100%;
213 | padding: 10px;
214 | background-color: #4caf50;
215 | color: #fff;
216 | border: none;
217 | cursor: pointer;
218 | outline: none;
219 | border-top: 1px solid #ccc;
220 | border-radius: 0 0 8px 8px;
221 | }
222 | #send-button.disabled{
223 | background-color: #ccc;
224 | }
225 |
226 | #more-container{
227 | width: 100%;
228 | height: 100%;
229 | position: absolute;
230 | display: flex;
231 | align-content: center;
232 | justify-content: center;
233 | align-items: center;
234 | .bg{
235 | background-color: #ffffff67;
236 | width: 100%;
237 | height: 100%;
238 | position: absolute;
239 | display: flex;
240 | backdrop-filter: blur(2px);
241 | }
242 | .popup{
243 | z-index: 3;
244 | background-color:#fff;
245 | width: calc(100% - 75px);
246 | max-width: 700px;
247 | max-height: 100vh;
248 | border-radius: 26px;
249 | overflow: hidden;
250 | box-shadow: 0px 1px 17px #ccc;
251 | position: relative;
252 | .pop-more{
253 | display: flex;
254 | flex-wrap: wrap;
255 | align-content: space-around;
256 | justify-content: space-between;
257 | align-items: stretch;
258 | .btn{
259 | width: 50%;
260 | display: flex;
261 | flex-direction: row;
262 | flex-wrap: nowrap;
263 | align-content: center;
264 | justify-content: center;
265 | align-items: center;
266 | font-size: x-large;
267 | min-height: 80px;
268 | user-select: none;
269 | cursor: pointer;
270 | &.user{
271 | color: rgb(94, 159, 161);
272 | /* transition: 1s all ease-in-out; */
273 | &:hover{
274 | background-color: rgb(94, 159, 161);
275 | color: #fff;
276 | }
277 | }
278 | &.persona{
279 | color: #e69175;
280 | &:hover{
281 | background-color: #e69175;
282 | color: #fff;
283 | }
284 | }
285 | &.download{
286 | color: #326e34d9;
287 | &:hover{
288 | background-color: #326e34d9;
289 | color: #fff;
290 | }
291 | }
292 | &.upload{
293 | color: rgba(16,163,127,1);
294 | &:hover{
295 | background-color: rgba(16,163,127,1);
296 | color: #fff;
297 | }
298 | }
299 | }
300 | }
301 | .pop-user{
302 | min-height: 160px;
303 | display: flex;
304 | flex-direction: row;
305 | flex-wrap: wrap;
306 | align-content: space-around;
307 | justify-content: center;
308 | .user{
309 | display: flex;
310 | flex-direction: column;
311 | justify-content: center;
312 | margin: 18px;
313 | .name{
314 | line-height: 38px;
315 | font-size: 20px;
316 | display: flex;
317 | align-content: center;
318 | flex-wrap: nowrap;
319 | align-items: center;
320 | margin-top: 15px;
321 | & label{
322 | min-width: 45px;
323 | text-align: end;
324 | margin-right: 10px;
325 | }
326 | & input{
327 | width: 160px;
328 | height: 38px;
329 | font-size: large;
330 | letter-spacing: 0.15px;
331 | border: none;
332 | outline: none;
333 | background-color: #ecf0f3;
334 | transition: 0.25s ease;
335 | border-radius: 8px;
336 | text-align: center;
337 | }
338 | }
339 | .avatar{
340 | line-height: 38px;
341 | font-size: 20px;
342 | display: flex;
343 | align-content: center;
344 | flex-wrap: nowrap;
345 | align-items: center;
346 | & label{
347 | min-width: 45px;
348 | text-align: end;
349 | margin-right: 10px;
350 | }
351 | & img{
352 | width: 40px;
353 | margin-left: calc(50% - 50px);
354 | border-radius: 5px;
355 | }
356 | }
357 | }
358 | }
359 | .pop-persona{
360 | min-height: 160px;
361 | display: flex;
362 | flex-direction: column;
363 | flex-wrap: nowrap;
364 | align-items: center;
365 | justify-content: center;
366 | position: relative;
367 | .persona{
368 | margin-bottom: 15px;
369 | display: flex;
370 | flex-direction: column;
371 | align-items: flex-start;
372 | position: relative;
373 | width: calc(100% - 40px);
374 | .user{
375 | font-weight: bold;
376 | margin-right: 5px;
377 | }
378 | .avatar{
379 | width: 40px;
380 | position: absolute;
381 | top: 23px;
382 | border-radius: 5px;
383 | }
384 | .content{
385 | padding: 10px;
386 | border-radius: 10px;
387 | background-color: rgb(123, 227, 131);
388 | font-family: auto;
389 | font-size: 1rem;
390 | outline: none;
391 | display: inline-block;
392 | overflow-y: hidden;
393 | width: calc(100% - 90px);
394 | border: none;
395 | text-decoration: none;
396 | left: 45px;
397 | position: relative;
398 | height: 103px;
399 | }
400 | }
401 | .coordinate{
402 | position: relative;
403 | width: 185px;
404 | height: 165px;
405 | padding: 0px;
406 | margin: 0px;
407 |
408 | #coordinate-system {
409 | position: absolute;
410 | width: 120px;
411 | height: 120px;
412 | border: 1px solid #ccc;
413 | margin: 30px;
414 | padding: 0px;
415 | border-radius: 10px;
416 | background-image: url(/static/img/coordinate.png);
417 | background-size: cover;
418 | top: 0;
419 | }
420 |
421 | #draggable-point {
422 | position: absolute;
423 | width: 20px;
424 | height: 20px;
425 | padding: 0px;
426 | margin: 0px;
427 | background-color: white;
428 | border-radius: 50%;
429 | box-shadow: 0 0 7px #469905;
430 | }
431 | #coordinate-bg{
432 | position: relative;
433 | width: 185px;
434 | height: 165px;
435 | margin: 0px;
436 | padding: 0px;
437 | user-select: none;
438 | .l{
439 | position: absolute;
440 | font-size: 20px;
441 | width: 50px;
442 | top: 6px;
443 | left: 16px;
444 | color: aquamarine;
445 | }
446 | .r{
447 | position: absolute;
448 | font-size: 20px;
449 | width: 40px;
450 | right: 20px;
451 | bottom: 6px;
452 | color: bisque;
453 | }
454 | .x{
455 | display: flex;
456 | flex-wrap: nowrap;
457 | justify-content: space-between;
458 | align-items: center;
459 | flex-direction: row;
460 | height: 100%;
461 | }
462 | .y{
463 | display: flex;
464 | flex-direction: column;
465 | flex-wrap: nowrap;
466 | justify-content: space-between;
467 | align-items: center;
468 | position: absolute;
469 | width: 100%;
470 | height: 100%;
471 | top: 0;
472 | }
473 | }
474 | }
475 | .explain{
476 | display: flex;
477 | flex-direction: row;
478 | align-items: center;
479 | .mark{
480 | font-size: small;
481 | color: #00000080;
482 | }
483 | }
484 | }
485 | .pop-extension{
486 | height: 70vh;
487 | iframe{
488 | width: 100%;
489 | height: 100%;
490 | }
491 | .back{
492 | position: absolute;
493 | left: 10px;
494 | top: 10px;
495 | font-size: 1.5em;
496 | color: #000;
497 | }
498 | }
499 | }
500 | }
501 |
502 |
503 |
504 | input[type="checkbox"] + label::before {
505 | content: '\a0';
506 | /* non-break space */
507 | display: inline-block;
508 | vertical-align: .2em;
509 | width: 1em;
510 | height: 1em;
511 | margin-right: .2em;
512 | border-radius: .2em;
513 | background: silver;
514 | text-indent: .15em;
515 | line-height: .65;
516 | }
517 | input[type="checkbox"]:checked + label::before {
518 | content: '\2713';
519 | background: rgb(19, 206, 102);
520 | /* background: yellowgreen; */
521 | }
522 | input[type="checkbox"] {
523 | position: absolute;
524 | clip: rect(0,0,0,0);
525 | }
526 | input[type="checkbox"]:focus + label::before {
527 | box-shadow: 0 0 .1em .1em #58a;
528 | }
529 |
530 | input[type="checkbox"]:disabled + label::before {
531 | background: gray;
532 | box-shadow: none;
533 | color: #555;
534 | }
535 |
536 | label{
537 | font-size: 16px;
538 | font-family: monospace;
539 | font-weight: bolder;
540 | display: inline-block;
541 | user-select: none;
542 | }
543 | @media only screen and (max-width: 800px) {
544 | #app {
545 | flex-direction: column;
546 | }
547 | #more-container .popup .pop-persona .explain > p{
548 | display: none;
549 | }
550 | #chat-container {
551 | max-width: 100%;
552 | }
553 | }
554 | @media only screen and (min-width: 1081px) {
555 | #chat-messages {
556 | padding: 20px 150px;
557 | }
558 | }
--------------------------------------------------------------------------------
/static/fonts/FontAwesome.otf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/FontAwesome.otf
--------------------------------------------------------------------------------
/static/fonts/fontawesome-webfont.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/fontawesome-webfont.eot
--------------------------------------------------------------------------------
/static/fonts/fontawesome-webfont.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/fontawesome-webfont.ttf
--------------------------------------------------------------------------------
/static/fonts/fontawesome-webfont.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/fontawesome-webfont.woff
--------------------------------------------------------------------------------
/static/fonts/fontawesome-webfont.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/fonts/fontawesome-webfont.woff2
--------------------------------------------------------------------------------
/static/img/PC_Coordinate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/PC_Coordinate.png
--------------------------------------------------------------------------------
/static/img/PC_Coordinate.wb:
--------------------------------------------------------------------------------
1 | {"type":"wb","version":2,"source":"file://","elements":[{"id":"L6fDoP5obNbgFGqjcRCuG","type":"arrow","x":480,"y":300,"width":400,"height":0,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"round","seed":572271130,"version":160,"versionNonce":474265242,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"points":[[0,0],[400,0]],"lastCommittedPoint":null,"startBinding":null,"endBinding":null,"startArrowhead":null,"endArrowhead":"arrow"},{"id":"5z9q-4mzmkDFwKX2nqCdJ","type":"arrow","x":681.2758381748386,"y":100.67762530921956,"width":1.0445368363056104,"height":399.3223746907804,"angle":6.279679857460607,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"round","seed":257727814,"version":222,"versionNonce":1081808902,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"points":[[0,0],[-1.0445368363056104,399.3223746907804]],"lastCommittedPoint":null,"startBinding":{"elementId":"3PKxT8UPTnbRp5Ru00F-Z","focus":-0.002430146503276014,"gap":14.677094934406512},"endBinding":null,"startArrowhead":null,"endArrowhead":"arrow"},{"id":"3PKxT8UPTnbRp5Ru00F-Z","type":"text","x":660,"y":60,"width":41,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":1022181850,"version":33,"versionNonce":1923491654,"isDeleted":false,"boundElements":[{"id":"5z9q-4mzmkDFwKX2nqCdJ","type":"arrow"}],"updated":1705718257591,"link":null,"text":"理性","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"理性"},{"id":"k15c12wy0aqNWA5j7MQFF","type":"text","x":660,"y":520,"width":
41,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":1546230106,"version":24,"versionNonce":262625306,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"text":"感性","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"感性"},{"id":"GgNROi80SMPILCbXx0kEP","type":"text","x":420,"y":280,"width":41,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":289730118,"version":26,"versionNonce":590171782,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"text":"严肃","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"严肃"},{"id":"BNnHNrk1B4LzQwznG8tLq","type":"text","x":900,"y":280,"width":41,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":1733198490,"version":45,"versionNonce":1811498202,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"text":"活泼","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"活泼"},{"id":"OJStimQe66swL1JzMczgt","type":"text","x":460,"y":120,"width":81,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":2097748762,"version":42,"versionNonce":1184750682,"isDeleted":false,"boundElements":null,"updated":1705718283503,"link":null,"text":"钻牛角尖","fontSize":20
,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"钻牛角尖"},{"id":"JpyE0DsCSJ3XCfyoSOZnu","type":"text","x":840,"y":460,"width":81,"height":26,"angle":0,"strokeColor":"#000000","backgroundColor":"transparent","fillStyle":"hachure","strokeWidth":1,"strokeStyle":"solid","roughness":1,"opacity":100,"groupIds":[],"strokeSharpness":"sharp","seed":1562129606,"version":19,"versionNonce":312730886,"isDeleted":false,"boundElements":null,"updated":1705718257591,"link":null,"text":"放飞自我","fontSize":20,"fontFamily":1,"textAlign":"left","verticalAlign":"top","baseline":19,"containerId":null,"originalText":"放飞自我"}],"appState":{"gridSize":20,"viewBackgroundColor":"#ffffff"},"files":{}}
--------------------------------------------------------------------------------
/static/img/ai.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai.png
--------------------------------------------------------------------------------
/static/img/ai2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai2.png
--------------------------------------------------------------------------------
/static/img/ai3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai3.png
--------------------------------------------------------------------------------
/static/img/ai4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai4.png
--------------------------------------------------------------------------------
/static/img/ai42.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/ai42.png
--------------------------------------------------------------------------------
/static/img/bilibini.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/bilibini.png
--------------------------------------------------------------------------------
/static/img/coordinate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/coordinate.png
--------------------------------------------------------------------------------
/static/img/nekomusume.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/img/nekomusume.png
--------------------------------------------------------------------------------
/static/js/main.js:
--------------------------------------------------------------------------------
1 | const app = Vue.createApp({
2 | data() {
3 | return {
4 | users: {
5 | you: {
6 | name: 'you',
7 | avatar: '/static/img/bilibini.png',
8 | },
9 | other: {
10 | name: 'MeowAI',
11 | avatar: '/static/img/ai.png',
12 | }
13 | },
14 | coordinate: { x: 25, y: 35, isPC: true },
15 | persona: '主人你好呀!我是你的可爱猫娘,喵~',
16 | messages: [],
17 | // e.g. {type: 'you', content: 'message text', editMode: false, editedContent: 'edited text'}
18 | socket: null,
19 | currentMessage: '',
20 | partialMessage: '',
21 | isDragging: false, // whether the coordinate point is being dragged
22 | showStopButton: false,
23 | showContextMenu: false,
24 | showMoreSetup: -1,
25 | contextMenuIndex: -1,
26 | contextMenuPosition: { x: 0, y: 0 },
27 | extensionSRC:"/extension/",
28 | };
29 | },
30 | mounted() {
31 | this.socket = io.connect('http://' + document.domain + ':' + location.port);
32 | this.socket.on('chat', (data) => {
33 | if (this.partialMessage != '') {
34 | this.messages.pop();
35 | }
36 | this.partialMessage += data;
37 | this.messages.push({ type: 'other', content: this.partialMessage, editMode: false, editedContent: '' });
38 |
39 | const endOfMessageIndex = this.partialMessage.indexOf('\n\n');
40 | if (endOfMessageIndex !== -1) {
41 | this.partialMessage = '';
42 | this.scrollToBottom();
43 | this.stopReply();
44 | }
45 | });
46 | this.socket.on('stop', (data) => {
47 | if (data == 'True' || data == true) {
48 | this.partialMessage = '';
49 | this.showStopButton = false;
50 | }
51 | });
52 | this.socket.on('emit', (data) => {
53 | data=JSON.parse(data)
54 | new NoticeJs({
55 | type: data.code==0?'success':'error',
56 | text: data.msg,
57 | timeout:45,
58 | position: 'topLeft',
59 | }).show();
60 | });
61 | let that=this;
62 | window.addEventListener('message', function(event) {
63 | console.log(that.extensionSRC);
64 | that.extensionSRC=event.data;
65 | });
66 | },
67 | computed: {
68 | messagesInfo() {
69 | const messages = []; // avoid leaking an implicit global
70 | for (let message of this.messages) {
71 | messages.push({
72 | "role": message.type=='you'?"User":"Assistant",
73 | "content": message.content,
74 | })
75 | }
76 | return messages;
77 | },
78 | characters() {
79 | if (!this.coordinate) {
80 | console.error("Coordinate is not defined!");
81 | return {"sense":1.1,"pizzazz":0.5};
82 | }
83 |
84 | let pizzazz = this.coordinate.x || 0.01;//top_p
85 | let sense = this.coordinate.y || 0.01;//temperature
86 |
87 | pizzazz=pizzazz<=50?pizzazz/50:(pizzazz-50)/5;
88 | sense=sense<=50?sense/50:(sense-50)/5;
89 |
90 | if (this.coordinate.isPC) {
91 | if (sense>=2&&pizzazz<=2){
92 | pizzazz=pizzazz>1?pizzazz*0.2:pizzazz*0.5;
93 | }else if(pizzazz>=2&&sense<=2){
94 | sense=sense>1?sense*0.2:sense*0.5;
95 | }else{
96 | pizzazz=pizzazz>1?pizzazz*0.2:pizzazz*0.5;
97 | sense=sense>1?sense*0.2:sense*0.5;
98 | }
99 | sense=sense*0.5;
100 | pizzazz=pizzazz*0.5;
101 | }
102 |
103 | console.log(this.coordinate.isPC, "pizzazz(top_p):",pizzazz,"sense(temperature):",sense);
104 | return { sense, pizzazz };
105 | },
106 | },
107 | methods: {
108 | sendMessage() {
109 | // Send a message
110 | const message = this.currentMessage.trim();
111 | if (message != '') {
112 | this.showStopButton = true;
113 | this.messages.push({ type: 'you', content: message, editMode: false, editedContent: '' });
114 |
115 | this.currentMessage = '';
116 | this.partialMessage = '';
117 | this.scrollToBottom();
118 |
119 | this.socket.emit('chat', this.messagesInfo);
120 | console.log(this.messagesInfo)
121 | }
122 | },
123 | setCharacter() {
124 | // Apply the persona and sampling parameters
125 | this.socket.emit('character', {
126 | "persona": this.persona,
127 | "temperature": this.characters.sense,
128 | "top_p": this.characters.pizzazz,
129 | });
130 | },
131 | scrollToBottom() {
132 | // Scroll to the bottom of the chat messages
133 | const messagesContainer = document.getElementById('chat-messages');
134 | messagesContainer.scrollTop = messagesContainer.scrollHeight;
135 | },
136 | openExtension() {
137 | // Open the extensions page
138 | window.open('/extension','_blank')
139 | },
140 | stopReply() {
141 | // Stop the current reply
142 | this.socket.emit('stop', true);
143 | },
144 | resetReply() {
145 | // Regenerate the last reply
146 | console.log(this.extensionSRC)
147 | this.showStopButton = true;
148 | this.messages.pop();
149 | this.socket.emit('chat', this.messagesInfo);
150 | console.log(this.messagesInfo)
151 | },
152 | cutParagraphs(text) {
153 | // Split text into paragraphs
154 | return text.split('\n')
155 | },
156 | showContextMenuFun(index, event) {
157 | // Show the context menu
158 | this.showContextMenu = true;
159 | this.contextMenuIndex = index;
160 | this.contextMenuPosition = { x: event.offsetX + 55, y: event.offsetY };
161 | },
162 | hideContextMenuFun() {
163 | // Hide the context menu
164 | this.showContextMenu = false;
165 | this.contextMenuIndex = -1;
166 | },
167 | copyMessage(index) {
168 | // Copy a message to the clipboard
169 | this.showContextMenu = false;
170 | navigator.clipboard.writeText(this.messages[index].content);
171 | },
172 | editMessage(index, event) {
173 | // Enter edit mode for a message
174 | this.messages[index].editMode = true;
175 | this.messages[index].editedContent = this.messages[index].content;
176 | },
177 | deleteMessage(index) {
178 | // Delete a message
179 | this.messages.splice(index, 1);
180 | },
181 | confirmEdit(index) {
182 | // Confirm the edit
183 | this.messages[index].content = this.messages[index].editedContent;
184 | this.messages[index].editMode = false;
185 | },
186 | cancelEdit(index) {
187 | // Cancel the edit
188 | this.messages[index].editMode = false;
189 | },
190 | autoResize(event) {
191 | // Auto-resize the edit textarea
192 | event.target.style.height = (event.target.scrollHeight - 20) + "px";
193 | },
194 | clickAvatarBox(typeAvatar) {
195 | // Change an avatar
196 | if (typeAvatar == 'you') {
197 | this.$refs.youAvatarToUpload.click();
198 | } else {
199 | this.$refs.otherAvatarToUpload.click();
200 | }
201 | },
202 | changeAvatarFile(typeAvatar) {
203 | // Upload an avatar image
204 | let file = typeAvatar == 'you' ? this.$refs.youAvatarToUpload.files[0] : this.$refs.otherAvatarToUpload.files[0];
205 | let thisf = this;
206 | if (file && file instanceof Blob) {
207 | const reader = new FileReader();
208 | reader.onloadend = function () {
209 | const base64DataUrl = reader.result; // Base64-encoded data URL of the image
210 | typeAvatar == 'you' ? thisf.users.you.avatar = base64DataUrl : thisf.users.other.avatar = base64DataUrl;
211 | };
212 | reader.readAsDataURL(file); // Read the image file as a Base64 data URL
213 | } else {
214 | console.error("无效的图片文件");
215 | new NoticeJs({
216 | type: 'error',
217 | text: "无效的图片文件",
218 | timeout:45,
219 | position: 'topLeft',
220 | }).show();
221 | }
222 | },
223 | clickDownloadConfig() {
224 | // Download the configuration as JSON
225 | let config = {
226 | "persona": this.persona,
227 | "coordinate": this.coordinate,
228 | "characters": this.characters,
229 | "messages": this.messages,
230 | "users": this.users,
231 | }
232 |
233 | let file = new File([JSON.stringify(config)], 'meowAI.json', { type: 'text/plain' });
234 | this.$refs.downloadConfig.download = 'meowAI.json';
235 | this.$refs.downloadConfig.href = URL.createObjectURL(file);
236 | this.$refs.downloadConfig.click();
237 | },
238 | changeUploadConfig() {
239 | // Upload a configuration file
240 | let file = this.$refs.uploadConfig.files[0];
241 | let thisf = this;
242 | if (file && file instanceof Blob) {
243 | const reader = new FileReader();
244 | reader.onloadend = function () {
245 | let configData = reader.result; // Contents of the config file
246 | configData = JSON.parse(configData);
247 | thisf.persona = configData.persona;
248 | thisf.coordinate = configData.coordinate;
249 | thisf.messages = configData.messages;
250 | thisf.users = configData.users;
251 | thisf.showMoreSetup = -1;
252 | new NoticeJs({
253 | type: 'success',
254 | text: "读取成功",
255 | timeout:10,
256 | position: 'topLeft',
257 | }).show();
258 | };
259 | reader.readAsText(file); // Read the config file as text
260 | } else {
261 | console.error("无效的配置文件");
262 | new NoticeJs({
263 | type: 'error',
264 | text: "无效的配置文件",
265 | timeout:45,
266 | position: 'topLeft',
267 | }).show();
268 | }
269 | },
270 | clickUploadConfig() {
271 | this.$refs.uploadConfig.click();
272 | },
273 | mousemoveCoordinate(e) {
274 | // Drag the coordinate point
275 | const coordinateSystem = this.$refs.coordinateSystem;
276 | const draggablePoint = this.$refs.draggablePoint;
277 | if (this.isDragging) {
278 | const x = e.clientX - coordinateSystem.getBoundingClientRect().left - 10;
279 | const y = e.clientY - coordinateSystem.getBoundingClientRect().top - 10;
280 | const maxX = coordinateSystem.clientWidth - draggablePoint.clientWidth;
281 | const maxY = coordinateSystem.clientHeight - draggablePoint.clientHeight;
282 | const clampedX = Math.min(Math.max(0, x), maxX);
283 | const clampedY = Math.min(Math.max(0, y), maxY);
284 |
285 | this.coordinate.y = clampedY;
286 | this.coordinate.x = clampedX;
287 | // console.log(`X: ${clampedX}, Y: ${clampedY}`);
288 | }
289 | },
290 | },
291 | });
292 |
293 | app.mount('#app');
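The `characters` computed property above maps the draggable point's pixel position to sampling parameters: values 0–50 map linearly to 0–1, values 50–100 map to 1–10, and PC mode additionally damps the result. A standalone sketch of that mapping (the function name `coordinateToSampling` is illustrative, not part of the app):

```javascript
// Illustrative re-implementation of the coordinate -> sampling mapping
// from the `characters` computed property: x drives top_p ("pizzazz"),
// y drives temperature ("sense").
function coordinateToSampling(x, y, isPC) {
  // Pixels 0..50 map linearly to 0..1; pixels 50..100 map to 1..10.
  const scale = (v) => (v <= 50 ? v / 50 : (v - 50) / 5);
  let pizzazz = scale(x || 0.01); // top_p
  let sense = scale(y || 0.01);   // temperature

  if (isPC) {
    // Damp whichever value is small, then halve both, so the local
    // model stays coherent at extreme positions.
    if (sense >= 2 && pizzazz <= 2) {
      pizzazz *= pizzazz > 1 ? 0.2 : 0.5;
    } else if (pizzazz >= 2 && sense <= 2) {
      sense *= sense > 1 ? 0.2 : 0.5;
    } else {
      pizzazz *= pizzazz > 1 ? 0.2 : 0.5;
      sense *= sense > 1 ? 0.2 : 0.5;
    }
    sense *= 0.5;
    pizzazz *= 0.5;
  }
  return { temperature: sense, top_p: pizzazz };
}
```

With the default coordinate `{ x: 25, y: 35 }` this yields `top_p = 0.5`, `temperature = 0.7` without PC damping, and `top_p = 0.125`, `temperature = 0.175` with it.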
--------------------------------------------------------------------------------
/static/js/notice.js:
--------------------------------------------------------------------------------
1 | (function webpackUniversalModuleDefinition(root, factory) {
2 | if(typeof exports === 'object' && typeof module === 'object')
3 | module.exports = factory();
4 | else if(typeof define === 'function' && define.amd)
5 | define("NoticeJs", [], factory);
6 | else if(typeof exports === 'object')
7 | exports["NoticeJs"] = factory();
8 | else
9 | root["NoticeJs"] = factory();
10 | })(typeof self !== 'undefined' ? self : this, function() {
11 | return /******/ (function(modules) { // webpackBootstrap
12 | /******/ // The module cache
13 | /******/ var installedModules = {};
14 | /******/
15 | /******/ // The require function
16 | /******/ function __webpack_require__(moduleId) {
17 | /******/
18 | /******/ // Check if module is in cache
19 | /******/ if(installedModules[moduleId]) {
20 | /******/ return installedModules[moduleId].exports;
21 | /******/ }
22 | /******/ // Create a new module (and put it into the cache)
23 | /******/ var module = installedModules[moduleId] = {
24 | /******/ i: moduleId,
25 | /******/ l: false,
26 | /******/ exports: {}
27 | /******/ };
28 | /******/
29 | /******/ // Execute the module function
30 | /******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__);
31 | /******/
32 | /******/ // Flag the module as loaded
33 | /******/ module.l = true;
34 | /******/
35 | /******/ // Return the exports of the module
36 | /******/ return module.exports;
37 | /******/ }
38 | /******/
39 | /******/
40 | /******/ // expose the modules object (__webpack_modules__)
41 | /******/ __webpack_require__.m = modules;
42 | /******/
43 | /******/ // expose the module cache
44 | /******/ __webpack_require__.c = installedModules;
45 | /******/
46 | /******/ // define getter function for harmony exports
47 | /******/ __webpack_require__.d = function(exports, name, getter) {
48 | /******/ if(!__webpack_require__.o(exports, name)) {
49 | /******/ Object.defineProperty(exports, name, {
50 | /******/ configurable: false,
51 | /******/ enumerable: true,
52 | /******/ get: getter
53 | /******/ });
54 | /******/ }
55 | /******/ };
56 | /******/
57 | /******/ // getDefaultExport function for compatibility with non-harmony modules
58 | /******/ __webpack_require__.n = function(module) {
59 | /******/ var getter = module && module.__esModule ?
60 | /******/ function getDefault() { return module['default']; } :
61 | /******/ function getModuleExports() { return module; };
62 | /******/ __webpack_require__.d(getter, 'a', getter);
63 | /******/ return getter;
64 | /******/ };
65 | /******/
66 | /******/ // Object.prototype.hasOwnProperty.call
67 | /******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); };
68 | /******/
69 | /******/ // __webpack_public_path__
70 | /******/ __webpack_require__.p = "dist/";
71 | /******/
72 | /******/ // Load entry module and return exports
73 | /******/ return __webpack_require__(__webpack_require__.s = 2);
74 | /******/ })
75 | /************************************************************************/
76 | /******/ ([
77 | /* 0 */
78 | /***/ (function(module, exports, __webpack_require__) {
79 |
80 | "use strict";
81 |
82 |
83 | Object.defineProperty(exports, "__esModule", {
84 | value: true
85 | });
86 | var noticeJsModalClassName = exports.noticeJsModalClassName = 'noticejs-modal';
87 | var closeAnimation = exports.closeAnimation = 'noticejs-fadeOut';
88 |
89 | var Defaults = exports.Defaults = {
90 | title: '',
91 | text: '',
92 | type: 'success',
93 | position: 'topRight',
94 | newestOnTop: false,
95 | timeout: 30,
96 | progressBar: true,
97 | closeWith: ['button'],
98 | animation: null,
99 | modal: false,
100 | width: 320,
101 | scroll: {
102 | maxHeightContent: 300,
103 | showOnHover: true
104 | },
105 | rtl: false,
106 | callbacks: {
107 | beforeShow: [],
108 | onShow: [],
109 | afterShow: [],
110 | onClose: [],
111 | afterClose: [],
112 | onClick: [],
113 | onHover: [],
114 | onTemplate: []
115 | }
116 | };
117 |
118 | /***/ }),
119 | /* 1 */
120 | /***/ (function(module, exports, __webpack_require__) {
121 |
122 | "use strict";
123 |
124 |
125 | Object.defineProperty(exports, "__esModule", {
126 | value: true
127 | });
128 | exports.appendNoticeJs = exports.addListener = exports.CloseItem = exports.AddModal = undefined;
129 | exports.getCallback = getCallback;
130 |
131 | var _api = __webpack_require__(0);
132 |
133 | var API = _interopRequireWildcard(_api);
134 |
135 | function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) newObj[key] = obj[key]; } } newObj.default = obj; return newObj; } }
136 |
137 | var options = API.Defaults;
138 |
139 | /**
140 | * @param {NoticeJs} ref
141 | * @param {string} eventName
142 | * @return {void}
143 | */
144 | function getCallback(ref, eventName) {
145 | if (ref.callbacks.hasOwnProperty(eventName)) {
146 | ref.callbacks[eventName].forEach(function (cb) {
147 | if (typeof cb === 'function') {
148 | cb.apply(ref);
149 | }
150 | });
151 | }
152 | }
153 |
154 | var AddModal = exports.AddModal = function AddModal() {
155 | if (document.getElementsByClassName(API.noticeJsModalClassName).length <= 0) {
156 | var element = document.createElement('div');
157 | element.classList.add(API.noticeJsModalClassName);
158 | element.classList.add('noticejs-modal-open');
159 | document.body.appendChild(element);
160 | // Remove class noticejs-modal-open
161 | setTimeout(function () {
162 | element.className = API.noticeJsModalClassName;
163 | }, 200);
164 | }
165 | };
166 |
167 | var CloseItem = exports.CloseItem = function CloseItem(item) {
168 | getCallback(options, 'onClose');
169 |
170 | // Set animation to close notification item
171 | if (options.animation !== null && options.animation.close !== null) {
172 | item.className += ' ' + options.animation.close;
173 | }
174 | setTimeout(function () {
175 | item.remove();
176 | }, 200);
177 |
178 | // Close modal
179 | if (options.modal === true && document.querySelectorAll("[noticejs-modal='true']").length >= 1) {
180 | document.querySelector('.noticejs-modal').className += ' noticejs-modal-close';
181 | setTimeout(function () {
182 | document.querySelector('.noticejs-modal').remove();
183 | }, 500);
184 | }
185 |
186 | // Remove container
187 | var position = '.' + item.closest('.noticejs').className.replace('noticejs', '').trim();
188 | setTimeout(function () {
189 | if (document.querySelectorAll(position + ' .item').length <= 0) {
190 | document.querySelector(position).remove();
191 | }
192 | }, 500);
193 | };
194 |
195 | var addListener = exports.addListener = function addListener(item) {
196 | // Add close button Event
197 | if (options.closeWith.includes('button')) {
198 | item.querySelector('.close').addEventListener('click', function () {
199 | CloseItem(item);
200 | });
201 | }
202 |
203 | // Add close by click Event
204 | if (options.closeWith.includes('click')) {
205 | item.style.cursor = 'pointer';
206 | item.addEventListener('click', function (e) {
207 | if (e.target.className !== 'close') {
208 | getCallback(options, 'onClick');
209 | CloseItem(item);
210 | }
211 | });
212 | } else {
213 | item.addEventListener('click', function (e) {
214 | if (e.target.className !== 'close') {
215 | getCallback(options, 'onClick');
216 | }
217 | });
218 | }
219 |
220 | item.addEventListener('mouseover', function () {
221 | getCallback(options, 'onHover');
222 | });
223 | };
224 |
225 | var appendNoticeJs = exports.appendNoticeJs = function appendNoticeJs(noticeJsHeader, noticeJsBody, noticeJsProgressBar) {
226 | var target_class = '.noticejs-' + options.position;
227 | // Create NoticeJs item
228 | var noticeJsItem = document.createElement('div');
229 | noticeJsItem.classList.add('item');
230 | noticeJsItem.classList.add(options.type);
231 | if (options.rtl === true) {
232 | noticeJsItem.classList.add('noticejs-rtl');
233 | }
234 | if (options.width !== '' && Number.isInteger(options.width)) {
235 | noticeJsItem.style.width = options.width + 'px';
236 | }
237 |
238 | // Add Header
239 | if (noticeJsHeader && noticeJsHeader !== '') {
240 | noticeJsItem.appendChild(noticeJsHeader);
241 | }
242 |
243 | // Add body
244 | noticeJsItem.appendChild(noticeJsBody);
245 |
246 | // Add progress bar
247 | if (noticeJsProgressBar && noticeJsProgressBar !== '') {
248 | noticeJsItem.appendChild(noticeJsProgressBar);
249 | }
250 |
251 | // Empty top and bottom container
252 | if (['top', 'bottom'].includes(options.position)) {
253 | document.querySelector(target_class).innerHTML = '';
254 | }
255 |
256 | // Add open animation
257 | if (options.animation !== null && options.animation.open !== null) {
258 | noticeJsItem.className += ' ' + options.animation.open;
259 | }
260 |
261 | // Add Modal
262 | if (options.modal === true) {
263 | noticeJsItem.setAttribute('noticejs-modal', 'true');
264 | AddModal();
265 | }
266 |
267 | // Add Listener
268 | addListener(noticeJsItem); // addListener takes one argument; it reads options.closeWith itself
269 |
270 | getCallback(options, 'beforeShow');
271 | getCallback(options, 'onShow');
272 | if (options.newestOnTop === true) {
273 | document.querySelector(target_class).insertAdjacentElement('afterbegin', noticeJsItem);
274 | } else {
275 | document.querySelector(target_class).appendChild(noticeJsItem);
276 | }
277 | getCallback(options, 'afterShow');
278 |
279 | return noticeJsItem;
280 | };
281 |
282 | /***/ }),
283 | /* 2 */
284 | /***/ (function(module, exports, __webpack_require__) {
285 |
286 | "use strict";
287 |
288 |
289 | Object.defineProperty(exports, "__esModule", {
290 | value: true
291 | });
292 |
293 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }();
294 |
295 | var _noticejs = __webpack_require__(3);
296 |
297 | var _noticejs2 = _interopRequireDefault(_noticejs);
298 |
299 | var _api = __webpack_require__(0);
300 |
301 | var API = _interopRequireWildcard(_api);
302 |
303 | var _components = __webpack_require__(4);
304 |
305 | var _helpers = __webpack_require__(1);
306 |
307 | var helper = _interopRequireWildcard(_helpers);
308 |
309 | function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) newObj[key] = obj[key]; } } newObj.default = obj; return newObj; } }
310 |
311 | function _interopRequireDefault(obj) { return obj && obj.__esModule ? obj : { default: obj }; }
312 |
313 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } }
314 |
315 | var NoticeJs = function () {
316 | /**
317 | * @param {object} options
318 | * @returns {NoticeJs}
319 | */
320 | function NoticeJs() {
321 | var options = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {};
322 |
323 | _classCallCheck(this, NoticeJs);
324 |
325 | this.options = Object.assign({}, API.Defaults, options); // copy into a fresh object so shared Defaults are not mutated
326 | this.component = new _components.Components();
327 |
328 | this.on('beforeShow', this.options.callbacks.beforeShow);
329 | this.on('onShow', this.options.callbacks.onShow);
330 | this.on('afterShow', this.options.callbacks.afterShow);
331 | this.on('onClose', this.options.callbacks.onClose);
332 | this.on('afterClose', this.options.callbacks.afterClose);
333 | this.on('onClick', this.options.callbacks.onClick);
334 | this.on('onHover', this.options.callbacks.onHover);
335 |
336 | return this;
337 | }
338 |
339 | /**
340 | * @returns {NoticeJs}
341 | */
342 |
343 |
344 | _createClass(NoticeJs, [{
345 | key: 'show',
346 | value: function show() {
347 | var container = this.component.createContainer();
348 | if (document.querySelector('.noticejs-' + this.options.position) === null) {
349 | document.body.appendChild(container);
350 | }
351 |
352 | var noticeJsHeader = void 0;
353 | var noticeJsBody = void 0;
354 | var noticeJsProgressBar = void 0;
355 |
356 | // Create NoticeJs header
357 | noticeJsHeader = this.component.createHeader(this.options.title, this.options.closeWith);
358 |
359 | // Create NoticeJs body
360 | noticeJsBody = this.component.createBody(this.options.text);
361 |
362 | // Create NoticeJs progressBar
363 | if (this.options.progressBar === true) {
364 | noticeJsProgressBar = this.component.createProgressBar();
365 | }
366 |
367 | //Append NoticeJs
368 | var noticeJs = helper.appendNoticeJs(noticeJsHeader, noticeJsBody, noticeJsProgressBar);
369 |
370 | return noticeJs;
371 | }
372 |
373 | /**
374 | * @param {string} eventName
375 | * @param {function} cb
376 | * @return {NoticeJs}
377 | */
378 |
379 | }, {
380 | key: 'on',
381 | value: function on(eventName) {
382 | var cb = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : function () {};
383 |
384 | if (typeof cb === 'function' && this.options.callbacks.hasOwnProperty(eventName)) {
385 | this.options.callbacks[eventName].push(cb);
386 | }
387 |
388 | return this;
389 | }
390 |
391 | /**
392 | * @param {Object} options
393 | * @return {NoticeJs}
394 | */
395 |
396 | }], [{
397 | key: 'overrideDefaults',
398 | value: function overrideDefaults(options) {
399 | this.options = Object.assign(API.Defaults, options);
400 | return this;
401 | }
402 | }]);
403 |
404 | return NoticeJs;
405 | }();
406 |
407 | exports.default = NoticeJs;
408 | module.exports = exports['default'];
409 |
410 | /***/ }),
411 | /* 3 */
412 | /***/ (function(module, exports) {
413 |
414 | // removed by extract-text-webpack-plugin
415 |
416 | /***/ }),
417 | /* 4 */
418 | /***/ (function(module, exports, __webpack_require__) {
419 |
420 | "use strict";
421 |
422 |
423 | Object.defineProperty(exports, "__esModule", {
424 | value: true
425 | });
426 | exports.Components = undefined;
427 |
428 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; }();
429 |
430 | var _api = __webpack_require__(0);
431 |
432 | var API = _interopRequireWildcard(_api);
433 |
434 | var _helpers = __webpack_require__(1);
435 |
436 | var helper = _interopRequireWildcard(_helpers);
437 |
438 | function _interopRequireWildcard(obj) { if (obj && obj.__esModule) { return obj; } else { var newObj = {}; if (obj != null) { for (var key in obj) { if (Object.prototype.hasOwnProperty.call(obj, key)) newObj[key] = obj[key]; } } newObj.default = obj; return newObj; } }
439 |
440 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } }
441 |
442 | var options = API.Defaults;
443 |
444 | var Components = exports.Components = function () {
445 | function Components() {
446 | _classCallCheck(this, Components);
447 | }
448 |
449 | _createClass(Components, [{
450 | key: 'createContainer',
451 | value: function createContainer() {
452 | var element_class = 'noticejs-' + options.position;
453 | var element = document.createElement('div');
454 | element.classList.add('noticejs');
455 | element.classList.add(element_class);
456 |
457 | return element;
458 | }
459 | }, {
460 | key: 'createHeader',
461 | value: function createHeader() {
462 | var element = void 0;
463 | if (options.title && options.title !== '') {
464 | element = document.createElement('div');
465 | element.setAttribute('class', 'noticejs-heading');
466 | element.textContent = options.title;
467 | }
468 |
469 | // Add close button
470 | if (options.closeWith.includes('button')) {
471 | var close = document.createElement('div');
472 | close.setAttribute('class', 'close');
473 | close.innerHTML = '×';
474 | if (element) {
475 | element.appendChild(close);
476 | } else {
477 | element = close;
478 | }
479 | }
480 |
481 | return element;
482 | }
483 | }, {
484 | key: 'createBody',
485 | value: function createBody() {
486 | var element = document.createElement('div');
487 | element.setAttribute('class', 'noticejs-body');
488 | var content = document.createElement('div');
489 | content.setAttribute('class', 'noticejs-content');
490 | content.innerHTML = options.text;
491 | element.appendChild(content);
492 |
493 | if (options.scroll !== null && options.scroll.maxHeight !== '') {
494 | element.style.overflowY = 'auto';
495 | element.style.maxHeight = options.scroll.maxHeight + 'px';
496 |
497 | if (options.scroll.showOnHover === true) {
498 | element.style.visibility = 'hidden';
499 | }
500 | }
501 | return element;
502 | }
503 | }, {
504 | key: 'createProgressBar',
505 | value: function createProgressBar() {
506 | var element = document.createElement('div');
507 | element.setAttribute('class', 'noticejs-progressbar');
508 | var bar = document.createElement('div');
509 | bar.setAttribute('class', 'noticejs-bar');
510 | element.appendChild(bar);
511 |
512 | // Progress bar animation
513 | if (options.progressBar === true && typeof options.timeout !== 'boolean' && options.timeout !== false) {
514 | var frame = function frame() {
515 | if (width <= 0) {
516 | clearInterval(id);
517 |
518 | var item = element.closest('div.item');
519 | // Add close animation
520 | if (options.animation !== null && options.animation.close !== null) {
521 |
522 | // Remove open animation class
523 | item.className = item.className.replace(new RegExp('(?:^|\\s)' + options.animation.open + '(?:\\s|$)'), ' ');
524 | // Add close animation class
525 | item.className += ' ' + options.animation.close;
526 |
527 | // Close notification after 0.5s + timeout
528 | var close_time = parseInt(options.timeout) + 500;
529 | setTimeout(function () {
530 | helper.CloseItem(item);
531 | }, close_time);
532 | } else {
533 | // Close notification when progress bar completed
534 | helper.CloseItem(item);
535 | }
536 | } else {
537 | width--;
538 | bar.style.width = width + '%';
539 | }
540 | };
541 |
542 | var width = 100;
543 | var id = setInterval(frame, options.timeout);
544 | }
545 |
546 | return element;
547 | }
548 | }]);
549 |
550 | return Components;
551 | }();
552 |
553 | /***/ })
554 | /******/ ]);
555 | });
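The bundle above exports the `NoticeJs` class and drives its progress bar with `setInterval(frame, options.timeout)`, decrementing a width counter from 100 to 0. A minimal sketch of typical usage and of the resulting timing, assuming the bundle and its CSS are loaded in a browser page (the option names are taken from the code above; the DOM call itself is browser-only, so it is left commented out):

```javascript
// Hedged usage sketch for the bundled NoticeJs class.
const options = {
  title: 'Meow-AI',
  text: 'Model loaded',
  position: 'topRight', // container element gets the class .noticejs-topRight
  progressBar: true,
  timeout: 30           // per-tick interval in ms, NOT the total display time
};
// In the browser: new NoticeJs(options).show();

// createProgressBar() ticks once per `timeout` ms and counts width down from
// 100, so the bar takes roughly 100 * timeout ms to empty; when a close
// animation is configured, CloseItem fires a further (timeout + 500) ms later.
function approxDisplayMs(timeout, hasCloseAnimation) {
  const bar = 100 * timeout;
  return hasCloseAnimation ? bar + timeout + 500 : bar;
}
console.log(approxDisplayMs(30, true)); // 3530
```

Note the quirk this makes explicit: `timeout` is the per-frame interval of the progress-bar animation, so a value of 30 keeps the notification visible for about three seconds, not 30 ms.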
--------------------------------------------------------------------------------
/static/webfonts/fa-brands-400.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-brands-400.eot
--------------------------------------------------------------------------------
/static/webfonts/fa-brands-400.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-brands-400.ttf
--------------------------------------------------------------------------------
/static/webfonts/fa-brands-400.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-brands-400.woff
--------------------------------------------------------------------------------
/static/webfonts/fa-brands-400.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-brands-400.woff2
--------------------------------------------------------------------------------
/static/webfonts/fa-regular-400.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-regular-400.eot
--------------------------------------------------------------------------------
/static/webfonts/fa-regular-400.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-regular-400.ttf
--------------------------------------------------------------------------------
/static/webfonts/fa-regular-400.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-regular-400.woff
--------------------------------------------------------------------------------
/static/webfonts/fa-regular-400.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-regular-400.woff2
--------------------------------------------------------------------------------
/static/webfonts/fa-solid-900.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-solid-900.eot
--------------------------------------------------------------------------------
/static/webfonts/fa-solid-900.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-solid-900.ttf
--------------------------------------------------------------------------------
/static/webfonts/fa-solid-900.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-solid-900.woff
--------------------------------------------------------------------------------
/static/webfonts/fa-solid-900.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/static/webfonts/fa-solid-900.woff2
--------------------------------------------------------------------------------
/templates/extension.html:
--------------------------------------------------------------------------------
[template markup stripped during extraction; recoverable content:]
  - page title: "扩展中心 MeowAI" (Extension Center MeowAI)
  - page heading: "扩展中心" (Extension Center)
  - extension card bindings: {{info.name}}, {{info.description}}
--------------------------------------------------------------------------------
/templates/index.html:
--------------------------------------------------------------------------------
[template markup stripped during extraction; recoverable content:]
  - page title: "Meow-AI"
  - chat message binding: {{ users[message.type].name }}, with a per-message avatar image
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bilibini/Meow-AI/468d8c668118277491859b082085343feb673866/test.py
--------------------------------------------------------------------------------