├── README.md
├── app.py
├── history.py
├── history
│   └── history will be here
├── new_api.py
├── orinigal_api.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
# ChatRWKV-flask-api
A ChatRWKV server written by ChatGPT, including the rest of this README.

# What it does
This program is a simple Flask-based chatbot server that uses an RWKV model for chat. It supports GET and POST requests, talks to clients, and records the chat history.

# Recommended requirements
Python 3.10 or later is recommended, together with the Flask and torch libraries; the rwkvstic library must also be installed for the server to run.

# Installing dependencies
Install the dependencies with:
```
pip install -r requirements.txt
```

# Calling the API
1. Start the server, then send a GET or POST request to the /chatrwkv route (see the example request below).
2. Pass the message text, the user ID, and the message source as request parameters.
3. The server returns the chatbot's reply and saves the chat history to a file.

# Request parameters
The /chatrwkv route accepts the following parameters:

- `msg`: the message to send to the chatbot.
- `usrid`: the user ID.
- `source`: the message source.

Note: all three parameters are required; otherwise the server returns an error response.
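# Example request
The sketch below is an illustration only and is not part of the repository: it assumes the server from app.py is reachable locally on its port 7860 and uses the `requests` library, which is not listed in requirements.txt.

```python
import requests

# Send one chat message; msg, usrid and source are all required.
resp = requests.post(
    "http://127.0.0.1:7860/chatrwkv",
    data={"msg": "Hello!", "usrid": "demo-user", "source": "cli"},
)
print(resp.json())  # e.g. {"status": "ok", "reply": "..."}
```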
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import os
import time
print("finding CUDA")
import torch
from torch.utils.cpp_extension import CUDA_HOME
from flask import Flask, jsonify, request
from rwkvstic.load import RWKV

app = Flask(__name__)

# Check whether CUDA is available and print the device name and the CUDA install path
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    print("CUDA device found:", device_name)

    cuda_home = CUDA_HOME
    if cuda_home is None or cuda_home.strip() == '':
        print("CUDA_HOME is empty, please check your CUDA driver")
    else:
        os.environ['CUDA_HOME'] = cuda_home
        print("CUDA home:", cuda_home)
else:
    print("CUDA device not found")

# Load the model
print("loading model")
model = RWKV(
    "https://huggingface.co/BlinkDL/rwkv-4-pile-3b/resolve/main/RWKV-4-Pile-3B-Instruct-test1-20230124.pth"
)
print("model loaded")

# The /chatrwkv route accepts both GET and POST requests
@app.route('/chatrwkv', methods=['GET', 'POST'])
def chat_with_rwkv():
    # Read msg, usrid and source from the query string (GET) or the form data (POST)
    if request.method == 'GET':
        msg = request.args.get('msg')
        usrid = request.args.get('usrid')
        source = request.args.get('source')
    elif request.method == 'POST':
        msg = request.form.get('msg')
        usrid = request.form.get('usrid')
        source = request.form.get('source')
    else:
        # Any other method is rejected
        return jsonify({'status': 'error', 'error': 'method not allowed'}), 405

    # Reject the request if usrid is missing or empty
    if not usrid or usrid.strip() == '':
        return jsonify({'status': 'error', 'error': 'usrid parameter is missing or empty'}), 400

    # Reject the request if source is missing or empty
    if not source or source.strip() == '':
        return jsonify({'status': 'error', 'error': 'source parameter is missing or empty'}), 400

    # Reject the request if msg is missing or empty
    if not msg or msg.strip() == '':
        return jsonify({'status': 'error', 'error': 'msg parameter is missing or empty'}), 400

    # Build the per-user chat history file name and path
    filename = f"{usrid}.txt"
    filepath = os.path.join(os.path.dirname(__file__), 'history', filename)

    # Create the chat history file if it does not exist yet
    if not os.path.exists(filepath):
        open(filepath, 'w').close()

    # Append the incoming message to the chat history file
    with open(filepath, 'a', encoding='utf-8') as f:
        f.write(f"{msg}\n")

    # Read the whole history back and load it into the model as context
    with open(filepath, 'r', encoding='utf-8') as f:
        context = f.read()
    model.loadContext(newctx=context)

    # Let the RWKV model generate a reply
    res = model.forward(number=100)["output"]

    # Append the reply to the chat history file
    with open(filepath, 'a', encoding='utf-8') as f:
        f.write(f"{res}\n")

    # Return the reply in the response
    return jsonify({'status': 'ok', 'reply': res})

# Start the Flask application
app.run(host='0.0.0.0', port=7860)
--------------------------------------------------------------------------------
/history.py:
--------------------------------------------------------------------------------
from flask import Flask, jsonify, request
import os

app = Flask(__name__)

@app.route('/chat/history', methods=['GET', 'POST'])
def get_chat_history():
    # usrid may come from the query string (GET) or the form data (POST)
    usrid = request.args.get('usrid') or request.form.get('usrid')
    content = request.form.get('content')

    if not usrid:
        return jsonify({'error': 'usrid parameter is missing'}), 400

    # POST appends new content, so content must be present
    if request.method == 'POST' and not content:
        return jsonify({'error': 'content parameter is missing'}), 400

    filename = f"{usrid}.txt"
    filepath = os.path.join(os.getcwd(), 'history', filename)

    # Create an empty history file for this user if none exists yet
    if not os.path.exists(filepath):
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write('')

    # On POST, append the submitted content to the history file
    if request.method == 'POST':
        with open(filepath, 'a', encoding='utf-8') as f:
            f.write(content + '\n')

    # Return the full history for this user
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()

    return jsonify({'usrid': usrid, 'content': content}), 200

if __name__ == '__main__':
    app.run()
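# Example client calls (an illustrative sketch only, not part of the original file).
# It assumes history.py is run directly, so Flask listens on its default port 5000,
# and that the `requests` library is installed (it is not listed in requirements.txt):
#
#   import requests
#   # Append a line to the history of user "demo"
#   requests.post('http://127.0.0.1:5000/chat/history',
#                 data={'usrid': 'demo', 'content': 'hello'})
#   # Read the full history back
#   print(requests.get('http://127.0.0.1:5000/chat/history',
#                      params={'usrid': 'demo'}).json())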
--------------------------------------------------------------------------------
/history/history will be here:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/new_api.py:
--------------------------------------------------------------------------------
print('Initializing...')
from flask import Flask, request, jsonify
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
import os, sys, torch
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# Load the model
print("Loading the model, please wait...")
filename = 'RWKV-4-Raven-3B-v10-Eng49%-Chn50%-Other1%-20230419-ctx4096.pth'
def checkmodel(filename):
    # Download the model weights from Hugging Face if they are not present locally
    if os.path.isfile(filename):
        print('Model file already exists, loading it')
    else:
        print('Model file not found, downloading it')
        os.system('wget https://huggingface.co/BlinkDL/rwkv-4-raven/resolve/main/RWKV-4-Raven-3B-v10-Eng49%25-Chn50%25-Other1%25-20230419-ctx4096.pth')
    return 'done'

checkmodel(filename)

os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'  # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries
model = RWKV(model='RWKV-4-Raven-3B-v10-Eng49%-Chn50%-Other1%-20230419-ctx4096', strategy='cuda fp16')
pipeline = PIPELINE(model, "20B_tokenizer.json")

# out, state = model.forward([187, 510, 1563, 310, 247], None)
# print(out.detach().cpu().numpy())  # get logits
# out, state = model.forward([187, 510], None)
# out, state = model.forward([1563], state)  # RNN has state (use deepcopy to clone states)
# out, state = model.forward([310, 247], state)
# print(out.detach().cpu().numpy())

print("Model loaded!")

# Create the Flask application
app = Flask(__name__)

# In-memory dictionary that stores each user's conversation history
chat_dict = {}

# Route handler
@app.route('/chatrwkv', methods=['GET'])
def chat_rwkv():
    # Read the request parameters
    source = request.args.get('source')
    msg = request.args.get('msg')
    usrid = request.args.get('usrid')
    # All three parameters are required
    if not all([source, msg, usrid]):
        return jsonify({'code': 400, 'msg': 'missing parameter'})
    # Log the request parameters
    print(f"Request parameters: source={source}, msg={msg}, usrid={usrid}")
    # Create an empty history list for this usrid if it does not have one yet
    if usrid not in chat_dict:
        chat_dict[usrid] = []
    # Append the message to this user's history, followed by a newline
    chat_dict[usrid].append(msg + "\n")
    # Concatenate the whole history as the prompt and let the RWKV model generate a reply
    prompt = ''.join(chat_dict[usrid])
    ctx = prompt
    args = PIPELINE_ARGS(temperature=max(0.2, float(0.99)), top_p=float(0.99),
                         token_ban=[],    # ban the generation of some tokens
                         token_stop=[0])  # stop generation whenever you see any token here
    out = pipeline.generate(ctx, args=args)
    # Append the model's output to this user's history, followed by a newline
    chat_dict[usrid].append(out + "\n")
    # Return only the latest reply as the response
    response = out
    # Log the response
    print(f"Response: {response}")
    return response

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
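# Example request (illustration only, not part of the original file; assumes the
# server is reachable locally on port 7860):
#
#   http://127.0.0.1:7860/chatrwkv?source=web&usrid=demo&msg=Hello
#
# Unlike app.py, this version keeps the conversation history in memory (chat_dict),
# so it is lost when the process restarts.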
--------------------------------------------------------------------------------
/orinigal_api.py:
--------------------------------------------------------------------------------

########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

print('\nChatRWKV https://github.com/BlinkDL/ChatRWKV\n')

import os, sys, torch
import numpy as np
np.set_printoptions(precision=4, suppress=True, linewidth=200)

# current_path = os.path.dirname(os.path.abspath(__file__))
# sys.path.append(f'{current_path}/rwkv_pip_package/src')

# Tune these below (test True/False for all of them) to find the fastest setting:
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
# torch._C._jit_override_can_fuse_on_cpu(True)
# torch._C._jit_override_can_fuse_on_gpu(True)
# torch._C._jit_set_texpr_fuser_enabled(False)
# torch._C._jit_set_nvfuser_enabled(False)

########################################################################################################
#
# Use '/' in model path, instead of '\'. Use ctx4096 models if you need long ctx.
#
# fp16 = good for GPU (!!! DOES NOT support CPU !!!)
# fp32 = good for CPU
# bf16 = worse accuracy, supports CPU
# xxxi8 (example: fp16i8) = xxx with int8 quantization to save 50% VRAM/RAM, slower, slightly less accuracy
#
# Read https://pypi.org/project/rwkv/ for Strategy Guide
########################################################################################################
# set these before import RWKV
os.environ['RWKV_JIT_ON'] = '1'
os.environ["RWKV_CUDA_ON"] = '0'  # '1' to compile CUDA kernel (10x faster), requires c++ compiler & cuda libraries

from rwkv.model import RWKV  # pip install rwkv
model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cuda fp16')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cuda fp16i8')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cuda fp16i8 *6 -> cuda fp16 *0+ -> cpu fp32 *1')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cpu fp32')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cpu fp32 *3 -> cuda fp16 *6+')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cpu fp32')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16 *8 -> cpu fp32')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda:0 fp16 -> cuda:1 fp16 -> cpu fp32 *1')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16 *6+')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230213-8019', strategy='cuda fp16 *0+ -> cpu fp32 *1')
# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096', strategy='cuda:0 fp16 *25 -> cuda:1 fp16')

# out, state = model.forward([187], None)
# print(out.detach().cpu().numpy())

out, state = model.forward([187, 510, 1563, 310, 247], None)
print(out.detach().cpu().numpy())  # get logits
out, state = model.forward([187, 510], None)
out, state = model.forward([1563], state)  # RNN has state (use deepcopy to clone states)
out, state = model.forward([310, 247], state)
print(out.detach().cpu().numpy())  # same result as above
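# A small sketch of the state-cloning idea mentioned next to model.forward above
# (illustration only, not part of the original script): because RWKV is an RNN,
# a `state` can be cloned with deepcopy to branch two continuations off the same prefix.
#
#   from copy import deepcopy
#   out, state = model.forward([187, 510], None)   # feed a shared prefix
#   branch = deepcopy(state)                       # clone the RNN state
#   out_a, state_a = model.forward([1563], state)  # continuation A
#   out_b, state_b = model.forward([310], branch)  # continuation B, from the same prefix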
# print('\n')
# exit(0)
from rwkv.utils import PIPELINE, PIPELINE_ARGS
pipeline = PIPELINE(model, "20B_tokenizer.json")

ctx = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese."
print(ctx, end='')

def my_print(s):
    print(s, end='', flush=True)

# For alpha_frequency and alpha_presence, see "Frequency and presence penalties":
# https://platform.openai.com/docs/api-reference/parameter-details
args = PIPELINE_ARGS(temperature=1.0, top_p=0.95, alpha_frequency=0.0, alpha_presence=0.0, token_stop=[187])  # 187 = '\n' in the 20B tokenizer
# args = PIPELINE_ARGS(temperature=1.0, top_p=0.95, alpha_frequency=0.5, alpha_presence=0.5, token_stop=[187])

# Simple interactive loop: append user input to the context and generate up to 50 tokens per turn
while True:
    try:
        text = input()
        if text == 'quit':
            break
        ctx += text + '\n'
        prompt = ctx.replace('\n', ' ')
        out = pipeline.generate(prompt, token_count=50, args=args)
        my_print(out)
        ctx += out + '\n'
    except Exception:
        import traceback
        traceback.print_exc()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
inquirer
Flask
torch
transformers
rwkvstic
scipy
tokenizers>=0.13.2
prompt_toolkit
rwkv
--------------------------------------------------------------------------------