├── .gitattributes ├── run.sh ├── run.bat ├── requirements.txt ├── .gitignore ├── README.md ├── convert_model.py ├── model ├── cuda │ ├── rwkv5_op.cpp │ ├── rwkv5.cu │ ├── gemm_fp16_cublas.cpp │ ├── wrapper.cpp │ └── operators.cu ├── rwkv_tokenizer.py ├── lora.py ├── utils.py └── model_run.py ├── verify_optimized_mm8.py ├── help.md ├── config.yml ├── styles.css ├── server.py ├── app.py ├── prompt.py └── chat.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.pickle filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | export RWKV_JIT_ON=1 2 | export RWKV_CUDA_ON=1 3 | python3 app.py -------------------------------------------------------------------------------- /run.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set RWKV_JIT_ON=1 3 | set RWKV_CUDA_ON=0 4 | python3 app.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0+cu113 2 | flask 3 | requests 4 | langid 5 | translate 6 | imgkit 7 | markdown 8 | pygments -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | logs/ 3 | images/ 4 | states/ 5 | dumps/ 6 | notebooks/ 7 | .vscode/ 8 | .ipynb_checkpoints/ 9 | __pycache__/ 10 | /model/__pycache__/ 11 | password.encrypt 12 | session.token 13 | device.json 14 | qq.txt 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Eloise 2 | A QQ Chatbot based on RWKV (W.I.P.) 3 | 4 | ## Introduction 5 | This is a bot for QQ IM software based on the [Language Model of RWKV](https://github.com/BlinkDL/RWKV-LM). 6 | 7 | ## Run 8 | 1. Install [`go-cqhttp`](https://docs.go-cqhttp.org/). 9 | 2. Edit `config.yml`; fill in your QQ and password. 10 | 3. Check `requirements.txt`; make sure you have all required packages properly installed (you may choose your own torch-gpu version depending on your CUDA version). 11 | 4. Edit `chat.py`; change your model path. 12 | 5. Create 3 empty folders in the project path: `logs`, `images` and `states`. 13 | 6. Open two terminals. 14 | ```bash 15 | cd /path/to/eloise 16 | ``` 17 | 7. Run `go-cqhttp` in one terminal for the first time; follow instructions. 18 | 8. Edit `device.json`; change `protocol` to `2`. 19 | 9. Run `go-cqhttp` again; follow instructions to log in. 20 | 10. Run `./run.sh` in another terminal. 
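Once `app.py` is running, you can also sanity-check the model without going through QQ by calling the bot's local HTTP endpoint directly. A minimal sketch, assuming the default host and port set in `app.py` (`127.0.0.1:6006`); the id and nickname below are placeholders:

```python
import requests  # already listed in requirements.txt

# The `/chat` route in app.py accepts GET requests with a user id and a message.
resp = requests.get(
    "http://127.0.0.1:6006/chat",
    params={
        "user_id": "10001",         # placeholder id; any value works for a local test
        "user_nickname": "Tester",  # placeholder nickname (app.py defaults to 'John')
        "message": "-h",            # start with the help command
    },
)
print(resp.text)
```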
21 | -------------------------------------------------------------------------------- /convert_model.py: -------------------------------------------------------------------------------- 1 | import os, sys, argparse 2 | current_path = os.path.dirname(os.path.abspath(__file__)) 3 | sys.path.append(f'{current_path}/../rwkv_pip_package/src') 4 | 5 | from model.model_run import RWKV 6 | 7 | # python convert_model.py --in '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230313-ctx8192-test1050' --out 'fp16_RWKV-4-Pile-14B-20230313-ctx8192-test1050' --strategy 'cuda fp16' 8 | # python convert_model.py --in '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20230109-ctx4096' --out 'fp16_RWKV-4-Pile-7B-20230109-ctx4096' --strategy 'cuda fp16' 9 | # python convert_model.py --in '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096' --out 'fp16i8_and_fp16_RWKV-4-Pile-3B-20221110-ctx4096' --strategy 'cuda fp16i8 *10 -> cuda fp16' 10 | 11 | def get_args(): 12 | p = argparse.ArgumentParser(prog = 'convert_model', description = 'Convert RWKV model for faster loading and saves cpu RAM.') 13 | p.add_argument('--in', metavar = 'INPUT', help = 'Filename for input model.', required = True) 14 | p.add_argument('--out', metavar = 'OUTPUT', help = 'Filename for output model.', required = True) 15 | p.add_argument('--strategy', help = 'Please quote the strategy as it contains spaces and special characters. See https://pypi.org/project/rwkv/ for strategy format definition.', required = True) 16 | p.add_argument('--quiet', action = 'store_true', help = 'Suppress normal output, only show errors.') 17 | return p.parse_args() 18 | 19 | args = get_args() 20 | if not args.quiet: 21 | print(f'** {args}') 22 | 23 | RWKV(getattr(args, 'in'), args.strategy, verbose = not args.quiet, convert_and_save_and_exit = args.out) -------------------------------------------------------------------------------- /model/cuda/rwkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | typedef at::Half fp16; 5 | typedef float fp32; 6 | 7 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 8 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y); 9 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); 10 | 11 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 12 | cuda_forward_bf16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 13 | } 14 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 15 | cuda_forward_fp16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 16 | } 17 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 18 | cuda_forward_fp32(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), 
y.data_ptr()); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16"); 23 | m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16"); 24 | m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32"); 25 | } 26 | TORCH_LIBRARY(rwkv5, m) { 27 | m.def("forward_bf16", forward_bf16); 28 | m.def("forward_fp16", forward_fp16); 29 | m.def("forward_fp32", forward_fp32); 30 | } 31 | -------------------------------------------------------------------------------- /verify_optimized_mm8.py: -------------------------------------------------------------------------------- 1 | import torch, sys 2 | from time import perf_counter as time 3 | from tqdm import tqdm 4 | 5 | torch.backends.cudnn.benchmark = True 6 | torch.backends.cudnn.allow_tf32 = True 7 | torch.backends.cuda.matmul.allow_tf32 = True 8 | 9 | torch.manual_seed(0) 10 | 11 | use_new = 1 12 | 13 | from torch.utils.cpp_extension import load 14 | current_path = '.' 15 | load( 16 | name=f"wkv_cuda", 17 | sources=[f"{current_path}/model/cuda/wrapper.cpp", f"{current_path}/model/cuda/operators.cu"], 18 | verbose=True, 19 | extra_cuda_cflags=["-t 4", "-std=c++17", "--use_fast_math", "-O3", "--extra-device-vectorization"]+["-DOPTIMIZED_MM8"]*use_new, 20 | is_python_module=False) 21 | 22 | @torch.jit.script 23 | def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry): 24 | assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype == torch.float16 25 | assert w.dtype == torch.uint8 26 | assert x.shape == [N] 27 | assert w.shape == [N, M] 28 | assert rx.shape == mx.shape == [M] 29 | assert ry.shape == my.shape == [N, 1] 30 | y = torch.zeros((M,), device='cuda', dtype=torch.float32) 31 | torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y) 32 | return y.to(dtype=torch.float16) 33 | 34 | def mm8_one(x, w, mx, rx, my, ry): 35 | N, M = w.shape[0], w.shape[1] 36 | return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) 37 | 38 | def mm8_one_truth(x, w, mx, rx, my, ry): 39 | return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) 40 | 41 | N_list = list(range(16,2000,16))+[2560,5120,5120*2] 42 | M_list = list(range(4,2000,4))+[2560,5120,5120*2] 43 | max_error = 0 44 | for N in tqdm(N_list): 45 | for M in M_list: 46 | l = [] 47 | for dim,dtype in [([N], 5), ([N, M], 0), ([M], 5), ([M], 5), ([N, 1], 5), ([N, 1], 5)]: 48 | if dtype == 5: 49 | l.append(torch.randn(tuple(dim),device='cuda',dtype=torch.float16)) 50 | else: 51 | l.append(torch.randint(low=0,high=255,size=tuple(dim),device='cuda',dtype=torch.uint8)) 52 | 53 | x,w,mx,rx,my,ry = l 54 | y0 = mm8_one(x, w, mx, rx, my, ry) 55 | y1 = mm8_one_truth(x, w, mx, rx, my, ry) 56 | 57 | err = ((y1-y0).norm()/y0.norm()).item() 58 | max_error = max(max_error, err) 59 | print(max_error) -------------------------------------------------------------------------------- /help.md: -------------------------------------------------------------------------------- 1 | # Help 2 | Model: 3 | 4 | ### Misc Utilities 5 | 1. `-tr [-en, -zh] `: Translate 6 | 2. `-p, -params`: Show chat parameters 7 | 3. `-pr, -prompts`: Show chat prompts 8 | 9 | ### Free Generation 10 | 1. `-h, -help`: Show this help 11 | 2. `-g, -gen []`: Generate text 12 | 3. `-e, -retry []`: Retry last generation 13 | 4. `-m, -more []`: Generate more 14 | 6. `-i, -inst []`: Follow instructions 15 | 16 | ### Chat 17 | 1. `-s, -reset [] []`: Reset your chat chain 18 | 2. `-l, list`: List scenarios 19 | 3. `-a, -alt []`: Alternative reply 20 | 4. 
` []`: Chat with me 21 | 22 | ### Scenarios 23 | 24 | 25 | ### Parameters 26 | | Param | Description | Default (Chat Mode) | Default (Bot/Instruction Mode) | 27 | | ----------------- | ---------------------------------- | ------------------- | ------------------------------ | 28 | | `-nucleus` | Switch to nucleus sampling | | | 29 | | `-typical` | Switch to typical sampling | | | 30 | | `-temp=` | Higher temperature → more verbose | | | 31 | | `-top_p=` | Lower top p → more accurate answer | | | 32 | | `-tau=` | Lower tau → more human-like answer | | | 33 | | `-af=` | Count penalty, avoids repeating | | | 34 | | `-ap=` | Presence penalty | | | 35 | | `-ar=` | Repeat penalty range in tokens | | | 36 | 37 | ### Examples 38 | #### Chat 39 | * Use `-alt` or `-a` to retry 40 | * Use `-reset` or `-s` to reset, better at chillin' 41 | * Use `-bot` or `-b` to enter bot mode, better at QA and coding 42 | * Do not use `-retry` or `-more`: useless 43 | 44 | ```text 45 | : What's your interesting news today? 46 | 47 | : ... 48 | 49 | : -alt 50 | 51 | : ... 52 | 53 | : -reset 54 | 55 | : Chat reset for ... 56 | ``` 57 | 58 | #### Instruct 59 | * Use `-retry` or `-e` to retry 60 | * Use `-more` to continue 61 | 62 | ```text 63 | : -i Write a letter to our customers, apologizing for delay in delivery of our products. 64 | 65 | : Dear Valued Customers: 66 | ... 67 | Yours Sincerely 68 | [YOUR NAME] 69 | 70 | : -retry 71 | 72 | : ... 73 | ``` 74 | 75 | #### Translation 76 | ```text 77 | : -tr Hello World! 78 | 79 | : 你好,世界! 80 | 81 | : -tr -en 你好,世界! 82 | 83 | : Hello, world! 84 | ``` 85 | 86 | #### Generation 87 | ```text 88 | : -gen -temp=2.0 -top_p=0.5 Here is a story about a man's lonely journey on Mars: 89 | 90 | : (Some story...) 91 | 92 | : -m 93 | 94 | : (Continue generates...) 95 | ``` -------------------------------------------------------------------------------- /model/cuda/rwkv5.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | typedef at::Half fp16; 6 | typedef float fp32; 7 | 8 | template 9 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state, 10 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 11 | F *__restrict__ const _y) 12 | { 13 | const int b = blockIdx.x / H; 14 | const int h = blockIdx.x % H; 15 | const int i = threadIdx.x; 16 | _w += h*_N_; 17 | _u += h*_N_; 18 | _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! 
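    // i.e. the state offset above has no batch term: every batch index b maps to the
    // same per-head state slice, so the kernel only reads and writes correct state
    // when B == 1. Presumably an extra `b*H*_N_*_N_` offset would be needed before
    // multi-batch use is safe (an assumption based on the remark above, not a verified fix).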
19 | 20 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 21 | 22 | float state[_N_]; 23 | #pragma unroll 24 | for (int j = 0; j < _N_; j++) 25 | state[j] = _state[j]; 26 | 27 | __syncthreads(); 28 | u[i] = float(_u[i]); 29 | w[i] = _w[i]; 30 | __syncthreads(); 31 | 32 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 33 | { 34 | __syncthreads(); 35 | r[i] = float(_r[t]); 36 | k[i] = float(_k[t]); 37 | __syncthreads(); 38 | 39 | const float v = float(_v[t]); 40 | float y = 0; 41 | 42 | #pragma unroll 43 | for (int j = 0; j < _N_; j+=4) 44 | { 45 | const float4& r_ = (float4&)(r[j]); 46 | const float4& k_ = (float4&)(k[j]); 47 | const float4& w_ = (float4&)(w[j]); 48 | const float4& u_ = (float4&)(u[j]); 49 | float4& s = (float4&)(state[j]); 50 | float4 x; 51 | 52 | x.x = k_.x * v; 53 | x.y = k_.y * v; 54 | x.z = k_.z * v; 55 | x.w = k_.w * v; 56 | 57 | y += r_.x * (u_.x * x.x + s.x); 58 | y += r_.y * (u_.y * x.y + s.y); 59 | y += r_.z * (u_.z * x.z + s.z); 60 | y += r_.w * (u_.w * x.w + s.w); 61 | 62 | s.x = s.x * w_.x + x.x; 63 | s.y = s.y * w_.y + x.y; 64 | s.z = s.z * w_.z + x.z; 65 | s.w = s.w * w_.w + x.w; 66 | } 67 | _y[t] = F(y); 68 | } 69 | #pragma unroll 70 | for (int j = 0; j < _N_; j++) 71 | _state[j] = state[j]; 72 | } 73 | 74 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 75 | { 76 | assert(H*_N_ == C); 77 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 78 | } 79 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y) 80 | { 81 | assert(H*_N_ == C); 82 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 83 | } 84 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y) 85 | { 86 | assert(H*_N_ == C); 87 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 88 | } 89 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | # go-cqhttp 默认配置文件 2 | 3 | account: # 账号相关 4 | uin: 1537125018 # QQ账号 5 | password: '' # 密码为空时使用扫码登录 6 | encrypt: true # 是否开启密码加密 7 | status: 0 # 在线状态 请参考 https://docs.go-cqhttp.org/guide/config.html#在线状态 8 | relogin: # 重连设置 9 | delay: 3 # 首次重连延迟, 单位秒 10 | interval: 3 # 重连间隔 11 | max-times: 0 # 最大重连次数, 0为无限制 12 | sign-server: http://127.0.0.1:8080 13 | 14 | # 是否使用服务器下发的新地址进行重连 15 | # 注意, 此设置可能导致在海外服务器上连接情况更差 16 | use-sso-address: true 17 | # 是否允许发送临时会话消息 18 | allow-temp-session: false 19 | 20 | heartbeat: 21 | # 心跳频率, 单位秒 22 | # -1 为关闭心跳 23 | interval: 5 24 | 25 | message: 26 | # 上报数据类型 27 | # 可选: string,array 28 | post-format: string 29 | # 是否忽略无效的CQ码, 如果为假将原样发送 30 | ignore-invalid-cqcode: false 31 | # 是否强制分片发送消息 32 | # 分片发送将会带来更快的速度 33 | # 但是兼容性会有些问题 34 | force-fragment: false 35 | # 是否将url分片发送 36 | fix-url: false 37 | # 下载图片等请求网络代理 38 | proxy-rewrite: '' 39 | # 是否上报自身消息 40 | report-self-message: false 41 | # 移除服务端的Reply附带的At 42 | remove-reply-at: false 43 | # 为Reply附加更多信息 44 | extra-reply-data: false 45 | # 跳过 Mime 扫描, 忽略错误数据 46 | skip-mime-scan: false 47 | 48 | output: 49 | # 日志等级 trace,debug,info,warn,error 50 | log-level: warn 51 | # 日志时效 单位天. 超过这个时间之前的日志将会被自动删除. 设置为 0 表示永久保留. 52 | log-aging: 15 53 | # 是否在每次启动时强制创建全新的文件储存日志. 
为 false 的情况下将会在上次启动时创建的日志文件续写 54 | log-force-new: true 55 | # 是否启用日志颜色 56 | log-colorful: true 57 | # 是否启用 DEBUG 58 | debug: false # 开启调试模式 59 | 60 | # 默认中间件锚点 61 | default-middlewares: &default 62 | # 访问密钥, 强烈推荐在公网的服务器设置 63 | access-token: '' 64 | # 事件过滤器文件目录 65 | filter: '' 66 | # API限速设置 67 | # 该设置为全局生效 68 | # 原 cqhttp 虽然启用了 rate_limit 后缀, 但是基本没插件适配 69 | # 目前该限速设置为令牌桶算法, 请参考: 70 | # https://baike.baidu.com/item/%E4%BB%A4%E7%89%8C%E6%A1%B6%E7%AE%97%E6%B3%95/6597000?fr=aladdin 71 | rate-limit: 72 | enabled: false # 是否启用限速 73 | frequency: 1 # 令牌回复频率, 单位秒 74 | bucket: 1 # 令牌桶大小 75 | 76 | database: # 数据库相关设置 77 | leveldb: 78 | # 是否启用内置leveldb数据库 79 | # 启用将会增加10-20MB的内存占用和一定的磁盘空间 80 | # 关闭将无法使用 撤回 回复 get_msg 等上下文相关功能 81 | enable: true 82 | sqlite3: 83 | # 是否启用内置sqlite3数据库 84 | # 启用将会增加一定的内存占用和一定的磁盘空间 85 | # 关闭将无法使用 撤回 回复 get_msg 等上下文相关功能 86 | enable: false 87 | cachettl: 3600000000000 # 1h 88 | 89 | # 连接服务列表 90 | servers: 91 | # 添加方式,同一连接方式可添加多个,具体配置说明请查看文档 92 | #- http: # http 通信 93 | #- ws: # 正向 Websocket 94 | #- ws-reverse: # 反向 Websocket 95 | #- pprof: #性能分析服务器 96 | 97 | - http: # HTTP 通信设置 98 | address: 127.0.0.1:5700 # HTTP监听地址 99 | timeout: 5 # 反向 HTTP 超时时间, 单位秒,<5 时将被忽略 100 | long-polling: # 长轮询拓展 101 | enabled: false # 是否开启 102 | max-queue-size: 2000 # 消息队列大小,0 表示不限制队列大小,谨慎使用 103 | middlewares: 104 | <<: *default # 引用默认中间件 105 | post: # 反向HTTP POST地址列表 106 | - url: 'http://127.0.0.1:6006/' # 地址 107 | secret: '' # 密钥 108 | # max-retries: 3 # 最大重试,0 时禁用 109 | # retries-interval: 1500 # 重试时间,单位毫秒,0 时立即 110 | #- url: http://127.0.0.1:5701/ # 地址 111 | # secret: '' # 密钥 112 | # max-retries: 10 # 最大重试,0 时禁用 113 | # retries-interval: 1000 # 重试时间,单位毫秒,0 时立即 114 | -------------------------------------------------------------------------------- /model/rwkv_tokenizer.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | class TRIE: 6 | __slots__ = tuple("ch,to,values,front".split(",")) 7 | to:list 8 | values:set 9 | def __init__(self, front=None, ch=None): 10 | self.ch = ch 11 | self.to = [None for ch in range(256)] 12 | self.values = set() 13 | self.front = front 14 | 15 | def __repr__(self): 16 | fr = self 17 | ret = [] 18 | while(fr!=None): 19 | if(fr.ch!=None): 20 | ret.append(fr.ch) 21 | fr = fr.front 22 | return ""%(ret[::-1], self.values) 23 | 24 | def add(self, key:bytes, idx:int=0, val=None): 25 | if(idx == len(key)): 26 | if(val is None): 27 | val = key 28 | self.values.add(val) 29 | return self 30 | ch = key[idx] 31 | if(self.to[ch] is None): 32 | self.to[ch] = TRIE(front=self, ch=ch) 33 | return self.to[ch].add(key, idx=idx+1, val=val) 34 | 35 | def find_longest(self, key:bytes, idx:int=0): 36 | u:TRIE = self 37 | ch:int = key[idx] 38 | 39 | while(u.to[ch] is not None): 40 | u = u.to[ch] 41 | idx += 1 42 | if(u.values): 43 | ret = idx, u, u.values 44 | if(idx==len(key)): 45 | break 46 | ch = key[idx] 47 | return ret 48 | 49 | class TRIE_TOKENIZER(): 50 | def __init__(self, file_name): 51 | self.idx2token = {} 52 | sorted = [] # must be already sorted 53 | with open(file_name, "r", encoding="utf-8") as f: 54 | lines = f.readlines() 55 | for l in lines: 56 | idx = int(l[:l.index(' ')]) 57 | x = eval(l[l.index(' '):l.rindex(' ')]) 58 | x = x.encode("utf-8") if 
isinstance(x, str) else x 59 | assert isinstance(x, bytes) 60 | assert len(x) == int(l[l.rindex(' '):]) 61 | sorted += [x] 62 | self.idx2token[idx] = x 63 | 64 | self.token2idx = {} 65 | for k,v in self.idx2token.items(): 66 | self.token2idx[v] = int(k) 67 | 68 | self.root = TRIE() 69 | for t, i in self.token2idx.items(): 70 | _ = self.root.add(t, val=(t, i)) 71 | 72 | def encodeBytes(self, src:bytes): 73 | idx:int = 0 74 | tokens = [] 75 | while (idx < len(src)): 76 | _idx:int = idx 77 | idx, _, values = self.root.find_longest(src, idx) 78 | assert(idx != _idx) 79 | _, token = next(iter(values)) 80 | tokens.append(token) 81 | return tokens 82 | 83 | def decodeBytes(self, tokens): 84 | return b''.join(map(lambda i: self.idx2token[i], tokens)) 85 | 86 | def encode(self, src): 87 | return self.encodeBytes(src.encode("utf-8")) 88 | 89 | def decode(self, tokens): 90 | try: 91 | return self.decodeBytes(tokens).decode('utf-8') 92 | except: 93 | return '\ufffd' # bad utf-8 94 | 95 | def printTokens(self, tokens): 96 | for i in tokens: 97 | s = self.idx2token[i] 98 | try: 99 | s = s.decode('utf-8') 100 | except: 101 | pass 102 | print(f'{repr(s)}{i}', end=' ') 103 | print() 104 | -------------------------------------------------------------------------------- /model/lora.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict 3 | import typing 4 | import torch 5 | 6 | 7 | def get_filter_keys_and_merge_coefficients(layer_filter): 8 | if layer_filter: 9 | layers = [] 10 | layer_coefficients = {} 11 | for layer in layer_filter.split(' '): 12 | if '*' in layer: 13 | coefficient, _, layer = layer.partition('*') 14 | coefficient = float(coefficient) 15 | else: 16 | coefficient = 1 17 | if layer.isdecimal(): 18 | layers.append(int(layer)) 19 | layer_coefficients[int(layer)] = coefficient 20 | elif '-' in layer: 21 | start, _, end = layer.partition('-') 22 | start, end = int(start), int(end) 23 | layers.extend(range(start, end+1)) 24 | for l in range(start, end+1): 25 | layer_coefficients[l] = coefficient 26 | else: 27 | raise NotImplementedError( 28 | "layer_filter Not implemented:", layer_filter) 29 | layers = sorted(set(layers)) 30 | layer_prefixes = tuple(f"blocks.{l}." for l in layers) 31 | 32 | def filter_keys(keys): 33 | new_keys = [] 34 | for key in keys: 35 | # Skip weights that are started by 'blocks.' 
and not in allowed range 36 | if key.startswith("blocks.") and not key.startswith(layer_prefixes): 37 | continue 38 | new_keys.append(key) 39 | return new_keys 40 | 41 | def merge_coefficients(key): 42 | if key.startswith('blocks.') and int(key.split('.')[1]) in layer_coefficients: 43 | return layer_coefficients[int(key.split('.')[1])] 44 | else: 45 | return 1 46 | else: 47 | def filter_keys(keys): 48 | return keys 49 | 50 | def merge_coefficients(key): 51 | return 1 52 | return filter_keys, merge_coefficients 53 | 54 | 55 | def lora_merge(base_model, lora, lora_alpha, device="cuda", layer_filter=None,): 56 | print(f"Loading LoRA: {lora}") 57 | print(f"LoRA alpha={lora_alpha}, layer_filter={layer_filter}") 58 | filter_keys, merge_coef = get_filter_keys_and_merge_coefficients( 59 | layer_filter) 60 | w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') 61 | # merge LoRA-only slim checkpoint into the main weights 62 | w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') 63 | # pdb.set_trace() #DEBUG 64 | for k in filter_keys(w_lora.keys()): # 处理time_mixing之类的融合 65 | w[k] = w_lora[k] 66 | output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() 67 | # merge LoRA weights 68 | keys = list(w.keys()) 69 | for k in keys: 70 | if k.endswith('.weight'): 71 | prefix = k[:-len('.weight')] 72 | lora_A = prefix + '.lora_A' 73 | lora_B = prefix + '.lora_B' 74 | if lora_A in keys: 75 | assert lora_B in keys 76 | print(f'merging {lora_A} and {lora_B} into {k}') 77 | assert w[lora_B].shape[1] == w[lora_A].shape[0] 78 | lora_r = w[lora_B].shape[1] 79 | w[k] = w[k].to(device=device) 80 | w[lora_A] = w[lora_A].to(device=device) 81 | w[lora_B] = w[lora_B].to(device=device) 82 | w[k] += w[lora_B] @ w[lora_A] * \ 83 | (lora_alpha / lora_r) * merge_coef(k) 84 | output_w[k] = w[k].to(device='cpu', copy=True) 85 | del w[k] 86 | del w[lora_A] 87 | del w[lora_B] 88 | continue 89 | 90 | if 'lora' not in k: 91 | print(f'retaining {k}') 92 | output_w[k] = w[k].clone() 93 | del w[k] 94 | return output_w 95 | -------------------------------------------------------------------------------- /model/cuda/gemm_fp16_cublas.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | #define CUBLAS_CHECK(condition) \ 8 | for (cublasStatus_t _cublas_check_status = (condition); \ 9 | _cublas_check_status != CUBLAS_STATUS_SUCCESS;) \ 10 | throw std::runtime_error("cuBLAS error " + \ 11 | std::to_string(_cublas_check_status) + " at " + \ 12 | std::to_string(__LINE__)); 13 | 14 | #define CUDA_CHECK(condition) \ 15 | for (cudaError_t _cuda_check_status = (condition); \ 16 | _cuda_check_status != cudaSuccess;) \ 17 | throw std::runtime_error( \ 18 | "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \ 19 | " at " + std::to_string(__LINE__)); 20 | 21 | cublasHandle_t get_cublas_handle() { 22 | static cublasHandle_t cublas_handle = []() { 23 | cublasHandle_t handle = nullptr; 24 | CUBLAS_CHECK(cublasCreate(&handle)); 25 | #if CUDA_VERSION < 11000 26 | CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); 27 | #else 28 | CUBLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); 29 | #endif // CUDA_VERSION < 11000 30 | return handle; 31 | }(); 32 | return cublas_handle; 33 | } 34 | 35 | /* 36 | NOTE: blas gemm is column-major by default, but we need row-major output. 
37 | The data of row-major, transposed matrix is exactly the same as the 38 | column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T 39 | */ 40 | void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { 41 | const auto cuda_data_type = CUDA_R_16F; 42 | const auto cuda_c_data_type = 43 | c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F; 44 | const auto compute_type = CUDA_R_32F; 45 | const float sp_alpha = 1.f; 46 | // swap a and b, and use CUBLAS_OP_N. see the notes above 47 | std::swap(a, b); 48 | const cublasOperation_t cublas_trans_a = CUBLAS_OP_N; 49 | const cublasOperation_t cublas_trans_b = CUBLAS_OP_N; 50 | // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap, 51 | // negative axis is used because of the existence of batch matmul. 52 | const int m = a.size(-1); 53 | const int k = a.size(-2); 54 | const int n = b.size(-2); 55 | const int cublas_lda = m; 56 | const int cublas_ldb = k; 57 | const int cublas_ldc = m; 58 | cublasHandle_t cublas_handle = get_cublas_handle(); 59 | 60 | #if CUDA_VERSION >= 11000 61 | cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; 62 | #else 63 | cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP; 64 | #endif 65 | const float sp_beta = 0.f; 66 | if (a.sizes().size() == 2 && b.sizes().size() == 2) { 67 | CUBLAS_CHECK(cublasGemmEx( 68 | cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha, 69 | a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type, 70 | cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, 71 | compute_type, algo)); 72 | } else { 73 | // batch matmul 74 | assert(a.sizes().size() == 3 && b.sizes().size() == 3); 75 | 76 | const long long int cublas_stride_a = m * k; 77 | const long long int cublas_stride_b = k * n; 78 | const long long int cublas_stride_c = m * n; 79 | CUBLAS_CHECK(cublasGemmStridedBatchedEx( 80 | cublas_handle, cublas_trans_a, cublas_trans_b, m, 81 | n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda, 82 | cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b, 83 | &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c, 84 | a.size(0), compute_type, algo)); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /styles.css: -------------------------------------------------------------------------------- 1 | pre { line-height: 125%; } 2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 6 | .codehilite .hll { background-color: #ffffcc } 7 | .codehilite { background: #f8f8f8; } 8 | .codehilite .c { color: #3D7B7B; font-style: italic } /* Comment */ 9 | .codehilite .err { border: 1px solid #FF0000 } /* Error */ 10 | .codehilite .k { color: #008000; font-weight: bold } /* Keyword */ 11 | .codehilite .o { color: #666666 } /* Operator */ 12 | .codehilite .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */ 13 | .codehilite .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */ 14 | .codehilite .cp { color: #9C6500 } /* Comment.Preproc */ 15 | .codehilite .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */ 16 | .codehilite .c1 { color: 
#3D7B7B; font-style: italic } /* Comment.Single */ 17 | .codehilite .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */ 18 | .codehilite .gd { color: #A00000 } /* Generic.Deleted */ 19 | .codehilite .ge { font-style: italic } /* Generic.Emph */ 20 | .codehilite .gr { color: #E40000 } /* Generic.Error */ 21 | .codehilite .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 22 | .codehilite .gi { color: #008400 } /* Generic.Inserted */ 23 | .codehilite .go { color: #717171 } /* Generic.Output */ 24 | .codehilite .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ 25 | .codehilite .gs { font-weight: bold } /* Generic.Strong */ 26 | .codehilite .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 27 | .codehilite .gt { color: #0044DD } /* Generic.Traceback */ 28 | .codehilite .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ 29 | .codehilite .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ 30 | .codehilite .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ 31 | .codehilite .kp { color: #008000 } /* Keyword.Pseudo */ 32 | .codehilite .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ 33 | .codehilite .kt { color: #B00040 } /* Keyword.Type */ 34 | .codehilite .m { color: #666666 } /* Literal.Number */ 35 | .codehilite .s { color: #BA2121 } /* Literal.String */ 36 | .codehilite .na { color: #687822 } /* Name.Attribute */ 37 | .codehilite .nb { color: #008000 } /* Name.Builtin */ 38 | .codehilite .nc { color: #0000FF; font-weight: bold } /* Name.Class */ 39 | .codehilite .no { color: #880000 } /* Name.Constant */ 40 | .codehilite .nd { color: #AA22FF } /* Name.Decorator */ 41 | .codehilite .ni { color: #717171; font-weight: bold } /* Name.Entity */ 42 | .codehilite .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */ 43 | .codehilite .nf { color: #0000FF } /* Name.Function */ 44 | .codehilite .nl { color: #767600 } /* Name.Label */ 45 | .codehilite .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */ 46 | .codehilite .nt { color: #008000; font-weight: bold } /* Name.Tag */ 47 | .codehilite .nv { color: #19177C } /* Name.Variable */ 48 | .codehilite .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */ 49 | .codehilite .w { color: #bbbbbb } /* Text.Whitespace */ 50 | .codehilite .mb { color: #666666 } /* Literal.Number.Bin */ 51 | .codehilite .mf { color: #666666 } /* Literal.Number.Float */ 52 | .codehilite .mh { color: #666666 } /* Literal.Number.Hex */ 53 | .codehilite .mi { color: #666666 } /* Literal.Number.Integer */ 54 | .codehilite .mo { color: #666666 } /* Literal.Number.Oct */ 55 | .codehilite .sa { color: #BA2121 } /* Literal.String.Affix */ 56 | .codehilite .sb { color: #BA2121 } /* Literal.String.Backtick */ 57 | .codehilite .sc { color: #BA2121 } /* Literal.String.Char */ 58 | .codehilite .dl { color: #BA2121 } /* Literal.String.Delimiter */ 59 | .codehilite .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ 60 | .codehilite .s2 { color: #BA2121 } /* Literal.String.Double */ 61 | .codehilite .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */ 62 | .codehilite .sh { color: #BA2121 } /* Literal.String.Heredoc */ 63 | .codehilite .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */ 64 | .codehilite .sx { color: #008000 } /* Literal.String.Other */ 65 | .codehilite .sr { color: #A45A77 } /* Literal.String.Regex */ 66 | .codehilite .s1 { color: #BA2121 } /* Literal.String.Single */ 67 | .codehilite 
.ss { color: #19177C } /* Literal.String.Symbol */ 68 | .codehilite .bp { color: #008000 } /* Name.Builtin.Pseudo */ 69 | .codehilite .fm { color: #0000FF } /* Name.Function.Magic */ 70 | .codehilite .vc { color: #19177C } /* Name.Variable.Class */ 71 | .codehilite .vg { color: #19177C } /* Name.Variable.Global */ 72 | .codehilite .vi { color: #19177C } /* Name.Variable.Instance */ 73 | .codehilite .vm { color: #19177C } /* Name.Variable.Magic */ 74 | .codehilite .il { color: #666666 } /* Literal.Number.Integer.Long */ 75 | -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | import chat 2 | import re 3 | 4 | from prompt import User, SCENARIOS, CHAT_SAMPLER, INSTRUCT_SAMPLER 5 | from chat import GenerateMode, model 6 | 7 | 8 | try: 9 | with open("qq.txt", 'r') as file: 10 | QQ = file.read() 11 | except: 12 | print("Please provide your QQ number in `qq.txt`") 13 | QQ = "" 14 | 15 | CHAT_HELP_COMMAND = "-c, -chat" 16 | PRIVATE_HELP_COMMAND = "" 17 | 18 | with open("./help.md", 'r') as file: 19 | model_name = model.args.MODEL_NAME.split('/')[-1].replace('.pth', '') 20 | 21 | HELP_MESSAGE = file.read() 22 | HELP_MESSAGE = HELP_MESSAGE.replace('', model_name) 23 | HELP_MESSAGE = HELP_MESSAGE.replace('', str(SCENARIOS)) 24 | HELP_MESSAGE = HELP_MESSAGE.replace( 25 | '', 'Yes' if CHAT_SAMPLER.sample.__name__ == "sample_nucleus" else '') 26 | HELP_MESSAGE = HELP_MESSAGE.replace( 27 | '', 'Yes' if CHAT_SAMPLER.sample.__name__ == "sample_typical" else '') 28 | HELP_MESSAGE = HELP_MESSAGE.replace( 29 | '', str(CHAT_SAMPLER.temp)) 30 | HELP_MESSAGE = HELP_MESSAGE.replace( 31 | '', str(CHAT_SAMPLER.top_p)) 32 | HELP_MESSAGE = HELP_MESSAGE.replace( 33 | '', str(CHAT_SAMPLER.tau)) 34 | HELP_MESSAGE = HELP_MESSAGE.replace( 35 | '', str(CHAT_SAMPLER.count_penalty)) 36 | HELP_MESSAGE = HELP_MESSAGE.replace( 37 | '', str(CHAT_SAMPLER.presence_penalty)) 38 | HELP_MESSAGE = HELP_MESSAGE.replace( 39 | '', str(CHAT_SAMPLER.penalty_range)) 40 | HELP_MESSAGE = HELP_MESSAGE.replace( 41 | '', 'Yes' if INSTRUCT_SAMPLER.sample.__name__ == "sample_nucleus" else '') 42 | HELP_MESSAGE = HELP_MESSAGE.replace( 43 | '', 'Yes' if INSTRUCT_SAMPLER.sample.__name__ == "sample_typical" else '') 44 | HELP_MESSAGE = HELP_MESSAGE.replace( 45 | '', str(INSTRUCT_SAMPLER.temp)) 46 | HELP_MESSAGE = HELP_MESSAGE.replace( 47 | '', str(INSTRUCT_SAMPLER.top_p)) 48 | HELP_MESSAGE = HELP_MESSAGE.replace( 49 | '', str(INSTRUCT_SAMPLER.tau)) 50 | HELP_MESSAGE = HELP_MESSAGE.replace( 51 | '', str(INSTRUCT_SAMPLER.count_penalty)) 52 | HELP_MESSAGE = HELP_MESSAGE.replace( 53 | '', str(INSTRUCT_SAMPLER.presence_penalty)) 54 | HELP_MESSAGE = HELP_MESSAGE.replace( 55 | '', str(INSTRUCT_SAMPLER.penalty_range)) 56 | 57 | 58 | def commands(user: User, message, enable_chat=False, is_private=False): 59 | help_match = re.match("\-h(elp)?", message) 60 | params_match = re.match("\-p(arams)?", message) 61 | prompts_match = re.match("\-pr(ompts)?", message) 62 | 63 | translate_match = re.match("\-tr", message) 64 | retry_match = re.match("\-(retry|e)", message) 65 | more_match = re.match("\-m(ore)?", message) 66 | gen_match = re.match("\-g(en)?\s+", message) 67 | inst_match = re.match("\-i(nst)?\s+", message) 68 | 69 | reset_match = re.match("\-(reset|s)\s*", message) 70 | list_match = re.match("\-l(ist)?", message) 71 | alt_match = re.match("\-a(lt)?", message) 72 | chat_match = re.match("\-c(hat)?\s+", message) 73 | at_match = 
re.match(f"\[CQ:at,qq={QQ}\]", message) 74 | 75 | help = HELP_MESSAGE 76 | if is_private: 77 | help = help.replace('', PRIVATE_HELP_COMMAND) 78 | else: 79 | help = help.replace('', CHAT_HELP_COMMAND) 80 | 81 | prompt = message 82 | reply = "" 83 | matched = True 84 | 85 | if help_match: 86 | reply = help 87 | elif prompts_match: 88 | prompt = message[prompts_match.end():] 89 | reply = chat.on_show_params(user, prompt, prompts=True) 90 | elif params_match: 91 | prompt = message[params_match.end():] 92 | reply = chat.on_show_params(user, prompt) 93 | elif translate_match: 94 | prompt = message[translate_match.end():] 95 | reply = chat.on_translate(user, prompt) 96 | elif retry_match: 97 | prompt = message[retry_match.end():] 98 | reply = chat.on_generate(user, prompt, mode=GenerateMode.RETRY) 99 | elif more_match: 100 | prompt = message[more_match.end():] 101 | reply = chat.on_generate(user, prompt, mode=GenerateMode.MORE) 102 | elif gen_match: 103 | prompt = message[gen_match.end():] 104 | reply = chat.on_generate(user, prompt, mode=GenerateMode.GENERATE) 105 | elif enable_chat and inst_match: 106 | prompt = message[inst_match.end():] 107 | reply = chat.on_generate(user, prompt, mode=GenerateMode.INSTRUCT) 108 | elif enable_chat and reset_match: 109 | prompt = message[reset_match.end():] 110 | reply = chat.on_reset(user, prompt) 111 | elif enable_chat and list_match: 112 | reply = str(SCENARIOS) 113 | elif enable_chat and alt_match: 114 | prompt = message[alt_match.end():] 115 | reply = chat.on_message(user, prompt, alt=True) 116 | elif enable_chat and is_private: 117 | reply = chat.on_message(user, prompt) 118 | elif enable_chat and not is_private and chat_match: 119 | prompt = message[chat_match.end():] 120 | reply = chat.on_message(user, prompt) 121 | elif QQ and enable_chat and not is_private and at_match: 122 | prompt = message[at_match.end():] 123 | reply = chat.on_message(user, prompt) 124 | else: 125 | matched = False 126 | 127 | return matched, prompt, reply 128 | 129 | 130 | def init(): 131 | chat.init_run() 132 | -------------------------------------------------------------------------------- /model/cuda/wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | #include 4 | #include 5 | 6 | typedef at::Half fp16; 7 | 8 | template 9 | void cuda_wkv_forward(int B, int T, int C, 10 | float *w, float *u, F *k, F *v, F *y, 11 | float *aa, float *bb, float *pp); 12 | template 13 | void cuda_mm8_seq(int B, int N, int M, 14 | F *x, int x_stride, 15 | uint8_t *w, int w_stride, 16 | F *mx, F *rx, 17 | F *my, F *ry, 18 | F *y, int y_stride); 19 | template 20 | void cuda_mm8_one(int N, int M, 21 | F *x, 22 | uint8_t *w, int w_stride, 23 | F *mx, F *rx, 24 | F *my, F *ry, 25 | float *y); 26 | 27 | void wkv_forward(int64_t B, int64_t T, int64_t C, 28 | torch::Tensor &w, torch::Tensor &u, 29 | torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, 30 | torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) { 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); 32 | switch (k.scalar_type()) { 33 | case c10::ScalarType::Half: 34 | cuda_wkv_forward(B, T, C, 35 | w.data_ptr(), u.data_ptr(), 36 | k.data_ptr(), v.data_ptr(), y.data_ptr(), 37 | aa.data_ptr(), bb.data_ptr(), pp.data_ptr()); 38 | break; 39 | case c10::ScalarType::Float: 40 | cuda_wkv_forward(B, T, C, 41 | w.data_ptr(), u.data_ptr(), 42 | k.data_ptr(), v.data_ptr(), y.data_ptr(), 43 | aa.data_ptr(), bb.data_ptr(), pp.data_ptr()); 44 | break; 45 
| default: 46 | assert(false && "Only FP16 and FP32 are currently supported"); 47 | } 48 | } 49 | 50 | void mm8_seq(int64_t B, int64_t N, int64_t M, 51 | torch::Tensor &x, torch::Tensor &w, 52 | torch::Tensor &mx, torch::Tensor &rx, 53 | torch::Tensor &my, torch::Tensor &ry, 54 | torch::Tensor &y) { 55 | assert(x.stride(1) == 1); 56 | assert(w.stride(1) == 1); 57 | assert(mx.stride(0) == 1 && rx.stride(0) == 1); 58 | assert(my.stride(0) == 1 && ry.stride(0) == 1); 59 | assert(y.stride(1) == 1); 60 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); 61 | switch (x.scalar_type()) { 62 | case c10::ScalarType::Half: 63 | cuda_mm8_seq( 64 | B, N, M, 65 | x.data_ptr(), x.stride(0), 66 | w.data_ptr(), w.stride(0), 67 | mx.data_ptr(), rx.data_ptr(), 68 | my.data_ptr(), ry.data_ptr(), 69 | y.data_ptr(), y.stride(0)); 70 | break; 71 | case c10::ScalarType::Float: 72 | cuda_mm8_seq( 73 | B, N, M, 74 | x.data_ptr(), x.stride(0), 75 | w.data_ptr(), w.stride(0), 76 | mx.data_ptr(), rx.data_ptr(), 77 | my.data_ptr(), ry.data_ptr(), 78 | y.data_ptr(), y.stride(0)); 79 | break; 80 | default: 81 | assert(false && "Only FP16 and FP32 are currently supported"); 82 | } 83 | } 84 | void mm8_one(int64_t N, int64_t M, 85 | torch::Tensor &x, torch::Tensor &w, 86 | torch::Tensor &mx, torch::Tensor &rx, 87 | torch::Tensor &my, torch::Tensor &ry, 88 | torch::Tensor &y) { 89 | assert(x.stride(0) == 1); 90 | assert(w.stride(1) == 1); 91 | assert(mx.stride(0) == 1 && rx.stride(0) == 1); 92 | assert(my.stride(0) == 1 && ry.stride(0) == 1); 93 | assert(y.stride(0) == 1); 94 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); 95 | switch (x.scalar_type()) { 96 | case c10::ScalarType::Half: 97 | cuda_mm8_one( 98 | N, M, 99 | x.data_ptr(), 100 | w.data_ptr(), w.stride(0), 101 | mx.data_ptr(), rx.data_ptr(), 102 | my.data_ptr(), ry.data_ptr(), 103 | y.data_ptr()); 104 | break; 105 | case c10::ScalarType::Float: 106 | cuda_mm8_one( 107 | N, M, 108 | x.data_ptr(), 109 | w.data_ptr(), w.stride(0), 110 | mx.data_ptr(), rx.data_ptr(), 111 | my.data_ptr(), ry.data_ptr(), 112 | y.data_ptr()); 113 | break; 114 | default: 115 | assert(false && "Only FP16 and FP32 are currently supported"); 116 | } 117 | } 118 | 119 | using torch::Tensor; 120 | 121 | #ifndef DISABLE_CUBLAS_GEMM 122 | void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); 123 | #endif 124 | 125 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 126 | m.def("wkv_forward", &wkv_forward, "wkv forward"); 127 | m.def("mm8_seq", &mm8_seq, "mm8 seq"); 128 | m.def("mm8_one", &mm8_one, "mm8 one"); 129 | #ifndef DISABLE_CUBLAS_GEMM 130 | m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas"); 131 | #endif 132 | } 133 | 134 | TORCH_LIBRARY(rwkv, m) { 135 | m.def("wkv_forward", wkv_forward); 136 | m.def("mm8_seq", mm8_seq); 137 | m.def("mm8_one", mm8_one); 138 | #ifndef DISABLE_CUBLAS_GEMM 139 | m.def("gemm_fp16_cublas", gemm_fp16_cublas); 140 | #endif 141 | } 142 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, request 2 | import os 3 | import re 4 | import requests 5 | import datetime 6 | import logging 7 | from urllib.parse import quote 8 | import markdown 9 | import imgkit 10 | import server 11 | from server import User 12 | 13 | app = Flask(__name__) 14 | 15 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") 16 | handler = 
logging.FileHandler(f"logs/eloise-{datetime.date.today()}.txt") 17 | handler.setFormatter(formatter) 18 | 19 | logger = logging.getLogger("eloise") 20 | logger.addHandler(handler) 21 | logger.setLevel(logging.INFO) 22 | 23 | banned_users = [] 24 | banned_groups = [] 25 | non_chat_groups = [] 26 | 27 | received_messages = set() 28 | 29 | IMAGE_THRESHOLD = 2048 30 | IMAGE_WIDTH = 600 31 | 32 | 33 | def sub_bullets(text): 34 | return re.sub('\n\* ([^\n]+)', '\n- \\1', text) 35 | 36 | 37 | @app.route('/', methods=["POST"]) 38 | def handle_post(): 39 | try: 40 | remote_addr = request.remote_addr 41 | json = request.get_json() 42 | type = json['message_type'] 43 | message = json['raw_message'] 44 | message_id = json['message_id'] 45 | sender = json['sender'] 46 | user = User(sender['user_id'], sender['nickname'], sender['sex']) 47 | except: 48 | return 'OK' 49 | 50 | if user in banned_users: 51 | return 'OK' 52 | if message_id in received_messages: 53 | return 'OK' 54 | if len(received_messages) > 500: 55 | received_messages.clear() 56 | 57 | if type == 'private': 58 | matched, prompt, reply = server.commands( 59 | user, message, enable_chat=True, is_private=True) 60 | 61 | if matched: 62 | logger.info(f"{user.nickname}({user.id}): {prompt}") 63 | logger.info(reply) 64 | received_messages.add(message_id) 65 | if len(reply) > IMAGE_THRESHOLD or reply.count('\n') > 5: 66 | options = {'font-family': 'SimSun'} 67 | html = markdown.markdown( 68 | sub_bullets(reply), extensions=['extra', 'nl2br', 'sane_lists', 'smarty', 'codehilite'], options=options) 69 | 70 | file = f"./images/{user.id} {datetime.datetime.now().isoformat()}.png" 71 | file = file.replace(' ', '-') 72 | path = os.path.abspath(file) 73 | options = {'width': IMAGE_WIDTH} 74 | imgkit.from_string( 75 | html, file, css='styles.css', options=options) 76 | requests.get( 77 | f"http://{remote_addr}:5700/send_private_msg?user_id={user.id}&message=[CQ:image,file=file:///{path}]") 78 | else: 79 | requests.get( 80 | f"http://{remote_addr}:5700/send_private_msg?user_id={user.id}&message={quote(reply)}") 81 | elif type == 'group': 82 | try: 83 | group_id = int(json['group_id']) 84 | except: 85 | return 'OK' 86 | if group_id in banned_groups: 87 | return 'OK' 88 | enable_chat = group_id not in non_chat_groups 89 | 90 | matched, prompt, reply = server.commands( 91 | user, message, enable_chat, is_private=False) 92 | if matched: 93 | logger.info(f"{group_id}: {user.nickname}({user.id}): {prompt}") 94 | logger.info(reply) 95 | received_messages.add(message_id) 96 | if len(reply) > IMAGE_THRESHOLD or reply.count('\n') > 2: 97 | options = {'font-family': 'SimSun'} 98 | html = markdown.markdown( 99 | sub_bullets(reply), extensions=['extra', 'nl2br', 'sane_lists', 'smarty', 'codehilite'], options=options) 100 | 101 | file = f"./images/{user.id} {datetime.datetime.now().isoformat()}.png" 102 | file = file.replace(' ', '-') 103 | path = os.path.abspath(file) 104 | options = {'width': IMAGE_WIDTH} 105 | imgkit.from_string( 106 | html, file, css='styles.css', options=options) 107 | requests.get( 108 | f"http://{remote_addr}:5700/send_group_msg?group_id={group_id}&message=[CQ:reply,id={message_id}][CQ:image,file=file:///{path}]") 109 | else: 110 | requests.get( 111 | f"http://{remote_addr}:5700/send_group_msg?group_id={group_id}&message=[CQ:reply,id={message_id}]{quote(reply)}") 112 | 113 | return 'OK' 114 | 115 | 116 | @app.route('/chat', methods=['GET']) 117 | def handle_chat(): 118 | try: 119 | args = request.args 120 | user_id = args['user_id'] 121 | 
user_nickname = args.get('user_nickname', 'John') 122 | user_sex = args.get('user_sex', 'unknown') 123 | rendered = args.get('rendered') 124 | 125 | temp = args.get('temp') 126 | top_p = args.get('top_p') 127 | tau = args.get('tau') 128 | af = args.get('af') 129 | ap = args.get('ap') 130 | 131 | message = args['message'] 132 | except: 133 | return '' 134 | 135 | if temp: 136 | message += f'-temp={temp} ' 137 | if top_p: 138 | message += f'-top_p={top_p} ' 139 | if tau: 140 | message += f'-tau={tau} ' 141 | if af: 142 | message += f'-af={af} ' 143 | if ap: 144 | message += f'-ap={ap} ' 145 | 146 | user = server.User(user_id, user_nickname, user_sex) 147 | matched, prompt, reply = server.commands( 148 | user, message, enable_chat=True, is_private=True) 149 | reply = re.sub(r'(我是.+)((GPT|Gpt|gpt)\s*-\s*3.5\s*接口)', r'\1RWKV', reply) 150 | 151 | if matched: 152 | logger.info(f"{user.nickname}({user.id}): {prompt}") 153 | logger.info(reply) 154 | if rendered: 155 | options = {'font-family': 'SimSun'} 156 | html = markdown.markdown( 157 | sub_bullets(reply), extensions=['extra', 'nl2br', 'sane_lists', 'smarty', 'codehilite'], options=options) 158 | return html 159 | else: 160 | return reply 161 | 162 | return '' 163 | 164 | 165 | if __name__ == '__main__': 166 | print("Starting server...") 167 | server.init() 168 | app.run(debug=False, host='127.0.0.1', port=6006, threaded=False) 169 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import json 6 | import sys 7 | import time 8 | import random 9 | import os 10 | import re 11 | import numpy as np 12 | import torch 13 | from torch.nn import functional as F 14 | from tokenizers import Tokenizer 15 | 16 | time_slot = {} 17 | time_ref = time.time_ns() 18 | 19 | 20 | def record_time(name): 21 | if name not in time_slot: 22 | time_slot[name] = 1e20 23 | tt = (time.time_ns() - time_ref) / 1e9 24 | if tt < time_slot[name]: 25 | time_slot[name] = tt 26 | 27 | 28 | class TOKENIZER(): 29 | def __init__(self, WORD_NAME): 30 | if WORD_NAME == 'rwkv_vocab_v20230424': 31 | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) 32 | from rwkv_tokenizer import TRIE_TOKENIZER 33 | dirname = os.path.dirname(os.path.abspath(__file__)) 34 | self.tokenizer = TRIE_TOKENIZER(dirname + '/rwkv_vocab_v20230424.txt') 35 | else: 36 | self.tokenizer = Tokenizer.from_file(WORD_NAME) 37 | 38 | def is_trie(self): 39 | return 'Tokenizer' not in str(type(self.tokenizer)) 40 | 41 | def refine_context(self, context): 42 | context = context.strip().split('\n') 43 | for c in range(len(context)): 44 | context[c] = context[c].strip().strip('\u3000').strip('\r') 45 | context = list(filter(lambda c: c != '', context)) 46 | context = '\n' + ('\n'.join(context)).strip() 47 | if context == '': 48 | context = '\n' 49 | return context 50 | 51 | def encode(self, x): 52 | if not self.is_trie(): 53 | return self.tokenizer.encode(x).ids 54 | else: 55 | return self.tokenizer.encode(x) 56 | 57 | def decode(self, x): 58 | return self.tokenizer.decode(x) 59 | 60 | 61 | class SAMPLER(): 62 | def __init__(self, sample, temp, top_p, tau, count_penalty, presence_penalty, 
penalty_range): 63 | if sample == 'nucleus': 64 | self.sample = self.sample_nucleus 65 | elif sample == 'typical': 66 | self.sample = self.sample_typical 67 | else: 68 | raise RuntimeError("\"sample\" must be \"nucleus\" or \"typical\"") 69 | 70 | self.temp = temp 71 | self.top_p = top_p 72 | self.top_k = 0 73 | self.tau = tau 74 | self.count_penalty = count_penalty 75 | self.presence_penalty = presence_penalty 76 | self.penalty_range = penalty_range 77 | 78 | def __str__(self) -> str: 79 | method = "Nucleus" if self.sample == self.sample_nucleus else "Typical" 80 | return '''|{:^30}|{:^10}| 81 | |------------------------------|----------| 82 | |{:^30}|{:>10}| 83 | |{:^30}|{:>10}| 84 | |{:^30}|{:>10}| 85 | |{:^30}|{:>10}| 86 | |{:^30}|{:>10}| 87 | |{:^30}|{:>10}| 88 | |{:^30}|{:>10}| 89 | '''.format("Sampler Params", "Values", 90 | "Method", method, 91 | "Temperature", self.temp, 92 | "Top P", self.top_p, 93 | "Tau", self.tau, 94 | "Count Penalty", self.count_penalty, 95 | "Presence Penalty", self.presence_penalty, 96 | "Penalty Range", self.penalty_range) 97 | 98 | def parse(self, input: str) -> str: 99 | nucleus_match = re.search("\-nucleus\s+", input) 100 | typical_match = re.search("\-typical\s+", input) 101 | temp_match = re.search("(\-temp\s*=\s*)(\-?\d+(.\d*)?)\s*", input) 102 | top_p_match = re.search("(\-top_p\s*=\s*)(\-?\d+(.\d*)?)\s*", input) 103 | tau_match = re.search("(\-tau\s*=\s*)(\-?\d+(.\d*)?)\s*", input) 104 | af_match = re.search("(\-af\s*=\s*)(\-?\d+(.\d*)?)\s*", input) 105 | ap_match = re.search("(\-ap\s*=\s*)(\-?\d+(.\d*)?)\s*", input) 106 | ar_match = re.search("(\-ar\s*=\s*)(\d+)\s*", input) 107 | 108 | if temp_match: 109 | self.temp = float(temp_match.group(2)) 110 | input = input.replace(temp_match.group(0), "") 111 | if top_p_match: 112 | self.top_p = float(top_p_match.group(2)) 113 | self.sample = self.sample_nucleus 114 | input = input.replace(top_p_match.group(0), "") 115 | if tau_match: 116 | self.tau = float(tau_match.group(2)) 117 | self.sample = self.sample_typical 118 | input = input.replace(tau_match.group(0), "") 119 | if af_match: 120 | self.count_penalty = float(af_match.group(2)) 121 | input = input.replace(af_match.group(0), "") 122 | if ap_match: 123 | self.presence_penalty = float(ap_match.group(2)) 124 | input = input.replace(ap_match.group(0), "") 125 | if ar_match: 126 | self.penalty_range = int(ar_match.group(2)) 127 | input = input.replace(ar_match.group(0), "") 128 | if nucleus_match: 129 | self.sample = self.sample_nucleus 130 | input = input.replace(nucleus_match.group(0), "") 131 | if typical_match: 132 | self.sample = self.sample_typical 133 | input = input.replace(typical_match.group(0), "") 134 | 135 | def clamp(n, minimum, maximum): 136 | return max(minimum, min(n, maximum)) 137 | 138 | self.temp = clamp(self.temp, 0.2, 5) 139 | self.top_p = max(0, self.top_p) 140 | self.tau = max(0, self.tau) 141 | self.count_penalty = clamp(self.count_penalty, 0.0, 1.0) 142 | self.presence_penalty = clamp(self.presence_penalty, 0.0, 1.0) 143 | 144 | return input 145 | 146 | def sample_nucleus(self, logits): 147 | probs = F.softmax(logits.float(), dim=-1) 148 | if probs.device == torch.device('cpu'): 149 | probs = probs.numpy() 150 | sorted_ids = np.argsort(probs) 151 | sorted_probs = probs[sorted_ids][::-1] 152 | cumulative_probs = np.cumsum(sorted_probs) 153 | cutoff = float(sorted_probs[np.argmax( 154 | cumulative_probs > self.top_p)]) 155 | probs[probs < cutoff] = 0 156 | if self.top_k < len(probs) and self.top_k > 0: 157 | 
probs[sorted_ids[:-self.top_k]] = 0 158 | if self.temp != 1.0: 159 | probs = probs ** (1.0 / self.temp) 160 | probs = probs / np.sum(probs) 161 | out = np.random.choice(a=len(probs), p=probs) 162 | return int(out) 163 | else: 164 | sorted_ids = torch.argsort(probs) 165 | sorted_probs = probs[sorted_ids] 166 | sorted_probs = torch.flip(sorted_probs, dims=(0,)) 167 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() 168 | cutoff = float(sorted_probs[np.argmax( 169 | cumulative_probs > self.top_p)]) 170 | probs[probs < cutoff] = 0 171 | if self.top_k < len(probs) and self.top_k > 0: 172 | probs[sorted_ids[:-self.top_k]] = 0 173 | if self.temp != 1.0: 174 | probs = probs ** (1.0 / self.temp) 175 | out = torch.multinomial(probs, num_samples=1)[0] 176 | return int(out) 177 | 178 | def sample_typical(self, logits): 179 | probs = F.softmax(logits.float(), dim=-1) 180 | logits = -torch.log(probs) 181 | entropy = torch.nansum(logits * probs, dim=-1, keepdim=True) 182 | logits = torch.abs(logits - entropy) 183 | sorted_ids = torch.argsort(logits) 184 | sorted_logits = logits[sorted_ids] 185 | sorted_probs = probs[sorted_ids] 186 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() 187 | cutoff = np.sum(cumulative_probs < self.tau) 188 | probs[logits > sorted_logits[cutoff]] = 0 189 | if self.temp != 1.0: 190 | probs = probs ** (1.0 / self.temp) 191 | out = torch.multinomial(probs, num_samples=1)[0] 192 | return int(out) 193 | -------------------------------------------------------------------------------- /model/cuda/operators.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | #include 5 | #define MIN_VALUE (-1e38) 6 | typedef at::Half fp16; 7 | __half *cast(fp16 *ptr) { 8 | return reinterpret_cast<__half *>(ptr); 9 | } 10 | 11 | template 12 | __global__ void kernel_wkv_forward(const int B, const int T, const int C, 13 | const float *__restrict__ const _w, const float *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, 14 | F *__restrict__ const _y, float *__restrict__ const _aa, float *__restrict__ const _bb, float *__restrict__ const _pp) { 15 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 16 | const int _b = idx / C; 17 | const int _c = idx % C; 18 | const int _offset = _b * T * C + _c; 19 | const int _state_offset = _b * C + _c; 20 | 21 | float u = _u[_c]; 22 | float w = _w[_c]; 23 | const F *__restrict__ const k = _k + _offset; 24 | const F *__restrict__ const v = _v + _offset; 25 | F *__restrict__ const y = _y + _offset; 26 | 27 | float aa = _aa[_state_offset]; 28 | float bb = _bb[_state_offset]; 29 | float pp = _pp[_state_offset]; 30 | for (int i = 0; i < T; i++) { 31 | const int ii = i * C; 32 | const float kk = float(k[ii]); 33 | const float vv = float(v[ii]); 34 | float ww = u + kk; 35 | float p = max(pp, ww); 36 | float e1 = exp(pp - p); 37 | float e2 = exp(ww - p); 38 | y[ii] = F((e1 * aa + e2 * vv) / (e1 * bb + e2)); 39 | ww = w + pp; 40 | p = max(ww, kk); 41 | e1 = exp(ww - p); 42 | e2 = exp(kk - p); 43 | aa = e1 * aa + e2 * vv; 44 | bb = e1 * bb + e2; 45 | pp = p; 46 | } 47 | _aa[_state_offset] = aa; 48 | _bb[_state_offset] = bb; 49 | _pp[_state_offset] = pp; 50 | } 51 | 52 | template 53 | void cuda_wkv_forward(int B, int T, int C, float *w, float *u, F *k, F *v, F *y, float *aa, float *bb, float *pp) { 54 | dim3 threadsPerBlock( min(C, 32) ); 55 | assert(B * C % threadsPerBlock.x == 0); 56 | dim3 numBlocks(B * C / 
threadsPerBlock.x); 57 | kernel_wkv_forward<<>>(B, T, C, w, u, k, v, y, aa, bb, pp); 58 | } 59 | 60 | template void cuda_wkv_forward( 61 | int B, int T, int C, 62 | float *w, float *u, fp16 *k, fp16 *v, fp16 *y, 63 | float *aa, float *bb, float *pp); 64 | template void cuda_wkv_forward( 65 | int B, int T, int C, 66 | float *w, float *u, float *k, float *v, float *y, 67 | float *aa, float *bb, float *pp); 68 | 69 | __global__ void kernel_mm_seq_fp32i8( 70 | const int B, const int N, const int M, 71 | const float *__restrict__ const x, const int x_stride, 72 | const uint8_t *__restrict__ const w, const int w_stride, 73 | const float *__restrict__ const mx, 74 | const float *__restrict__ const rx, 75 | const float *__restrict__ const my, 76 | const float *__restrict__ const ry, 77 | float *__restrict__ const y, const int y_stride) { 78 | 79 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 80 | const int k = blockIdx.y * blockDim.y + threadIdx.y; 81 | 82 | if (i < B && k < M) { 83 | float y_local = 0; 84 | for (int j = 0; j < N; ++j) { 85 | y_local += x[i * x_stride + j] * ( 86 | (float(w[j * w_stride + k]) + 0.5f) 87 | * rx[k] * ry[j] + mx[k] + my[j] 88 | ); 89 | } 90 | y[i * y_stride + k] = y_local; 91 | } 92 | } 93 | 94 | template 95 | void cuda_mm8_seq(int B, int N, int M, 96 | F *x, int x_stride, 97 | uint8_t *w, int w_stride, 98 | F *mx, F *rx, 99 | F *my, F *ry, 100 | F *y, int y_stride); 101 | 102 | template <> 103 | void cuda_mm8_seq(int B, int N, int M, 104 | float *x, int x_stride, 105 | uint8_t *w, int w_stride, 106 | float *mx, float *rx, 107 | float *my, float *ry, 108 | float *y, int y_stride) { 109 | dim3 blockSize(1, 128); 110 | dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y); 111 | kernel_mm_seq_fp32i8<<>>( 112 | B, N, M, x, x_stride, w, w_stride, 113 | mx, rx, my, ry, y, y_stride); 114 | } 115 | 116 | __global__ void kernel_mm_seq_fp16i8( 117 | const int B, const int N, const int M, 118 | const __half *__restrict__ const x, const int x_stride, 119 | const uint8_t *__restrict__ const w, const int w_stride, 120 | const __half *__restrict__ const mx, 121 | const __half *__restrict__ const rx, 122 | const __half *__restrict__ const my, 123 | const __half *__restrict__ const ry, 124 | __half *__restrict__ const y, const int y_stride) { 125 | 126 | const int i = blockIdx.x * blockDim.x + threadIdx.x; 127 | const int k = blockIdx.y * blockDim.y + threadIdx.y; 128 | 129 | if (i < B && k < M) { 130 | float y_local = 0; 131 | for (int j = 0; j < N; ++j) { 132 | y_local += __half2float(x[i * x_stride + j]) * ( 133 | (float(w[j * w_stride + k]) + 0.5f) 134 | * __half2float(rx[k]) * __half2float(ry[j]) 135 | + __half2float(mx[k]) + __half2float(my[j]) 136 | ); 137 | } 138 | y[i * y_stride + k] = __float2half(y_local); 139 | } 140 | } 141 | 142 | template <> 143 | void cuda_mm8_seq(int B, int N, int M, 144 | fp16 *x, int x_stride, 145 | uint8_t *w, int w_stride, 146 | fp16 *mx, fp16 *rx, 147 | fp16 *my, fp16 *ry, 148 | fp16 *y, int y_stride) { 149 | dim3 blockSize(1, 128); 150 | dim3 gridSize((B + blockSize.x - 1) / blockSize.x, (M + blockSize.y - 1) / blockSize.y); 151 | kernel_mm_seq_fp16i8<<>>( 152 | B, N, M, cast(x), x_stride, w, w_stride, 153 | cast(mx), cast(rx), cast(my), cast(ry), cast(y), y_stride); 154 | } 155 | 156 | #define MM8_ONE_JSPLIT 24 157 | #define MM8_ONE_TILE 1024 158 | 159 | __global__ void kernel_mm_one_fp32i8( 160 | const int N, const int M, 161 | const float *__restrict__ const x, 162 | const uint8_t *__restrict__ 
const w, const int w_stride, 163 | const float *__restrict__ const mx, 164 | const float *__restrict__ const rx, 165 | const float *__restrict__ const my, 166 | const float *__restrict__ const ry, 167 | float *__restrict__ const y) { 168 | 169 | const int k = blockIdx.y * blockDim.y + threadIdx.y; 170 | const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); 171 | const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); 172 | 173 | if (k < M) { 174 | float y_local = 0; 175 | for (int j = j0; j < j1; ++j) { 176 | y_local += x[j] * ( 177 | (float(w[j * w_stride + k]) + 0.5f) 178 | * rx[k] * ry[j] + mx[k] + my[j] 179 | ); 180 | } 181 | atomicAdd(&y[k], y_local); 182 | } 183 | } 184 | 185 | template 186 | void cuda_mm8_one(int N, int M, 187 | F *x, 188 | uint8_t *w, int w_stride, 189 | F *mx, F *rx, 190 | F *my, F *ry, 191 | float *y); 192 | 193 | template <> 194 | void cuda_mm8_one(int N, int M, 195 | float *x, 196 | uint8_t *w, int w_stride, 197 | float *mx, float *rx, 198 | float *my, float *ry, 199 | float *y) { 200 | dim3 blockSize(1, MM8_ONE_TILE); 201 | dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y); 202 | kernel_mm_one_fp32i8<<>>( 203 | N, M, x, w, w_stride, 204 | mx, rx, my, ry, y); 205 | } 206 | 207 | __global__ void kernel_mm_one_fp16i8( 208 | const int N, const int M, 209 | const __half *__restrict__ const x, 210 | const uint8_t *__restrict__ const w, const int w_stride, 211 | const __half *__restrict__ const mx, 212 | const __half *__restrict__ const rx, 213 | const __half *__restrict__ const my, 214 | const __half *__restrict__ const ry, 215 | float *__restrict__ const y) { 216 | 217 | const int k = blockIdx.y * blockDim.y + threadIdx.y; 218 | const int j0 = min(N, blockIdx.x * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); 219 | const int j1 = min(N, (blockIdx.x + 1) * ((N + MM8_ONE_JSPLIT - 1) / MM8_ONE_JSPLIT)); 220 | 221 | if (k < M) { 222 | float y_local = 0; 223 | for (int j = j0; j < j1; ++j) { 224 | y_local += __half2float(x[j]) * ( 225 | (float(w[j * w_stride + k]) + 0.5f) 226 | * __half2float(rx[k]) * __half2float(ry[j]) 227 | + __half2float(mx[k]) + __half2float(my[j]) 228 | ); 229 | } 230 | atomicAdd(&y[k], y_local); 231 | } 232 | } 233 | 234 | template <> 235 | void cuda_mm8_one(int N, int M, 236 | fp16 *x, 237 | uint8_t *w, int w_stride, 238 | fp16 *mx, fp16 *rx, 239 | fp16 *my, fp16 *ry, 240 | float *y) { 241 | dim3 blockSize(1, MM8_ONE_TILE); 242 | dim3 gridSize(MM8_ONE_JSPLIT, (M + blockSize.y - 1) / blockSize.y); 243 | kernel_mm_one_fp16i8<<>>( 244 | N, M, cast(x), w, w_stride, 245 | cast(mx), cast(rx), cast(my), cast(ry), y); 246 | } 247 | -------------------------------------------------------------------------------- /prompt.py: -------------------------------------------------------------------------------- 1 | from model.utils import SAMPLER 2 | 3 | 4 | class User: 5 | def __init__(self, id, nickname, sex): 6 | self.id = id 7 | self.nickname = nickname 8 | self.sex = sex # "unknown", "male", "female" 9 | 10 | 11 | class Scenario: 12 | def __init__(self, name, sampler, user_name, bot_name, system_name, intro, interface=':'): 13 | self.name: str = name 14 | self.sampler: SAMPLER = sampler 15 | self.user_name: str = user_name 16 | self.bot_name: str = bot_name 17 | self.system_name: str = system_name 18 | self.interface: str = interface 19 | self.intro: str = intro 20 | 21 | def chat_intro(self): 22 | user = self.user_name 23 | bot = self.bot_name 24 | system = self.system_name 25 | 
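# The intro is a plain str.format template; it refers to the scenario's {user}, {bot}, {system} and {interface} placeholders.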
interface = self.interface 26 | intro = self.intro.format( 27 | user=user, bot=bot, system=system, interface=interface) 28 | return f"\n{intro.strip()}\n\n" 29 | 30 | def chat_format(self, message: str): 31 | user = self.user_name 32 | bot = self.bot_name 33 | interface = self.interface 34 | 35 | message = message.replace('\n', ' ').strip() 36 | return f"{user}{interface} {message}\n\n{bot}{interface}" 37 | 38 | def __str__(self) -> str: 39 | return "|{:^20}|{:^20}|{:^20}|\n|--------------------|--------------------|--------------------|\n|{:^20}|{:^20}|{:^20}|\n\n".format( 40 | "Scenario", "User", "Bot", 41 | self.name, self.user_name, self.bot_name) + str(self.sampler) 42 | 43 | 44 | def instruct_format(message: str): 45 | message = message.replace('\n', ' ').strip() 46 | return f''' 47 | Below is an instruction that describes a task.\ 48 | Write a response that appropriately completes the request and follows the instructions strictly. 49 | 50 | # Instruction: 51 | {message} 52 | 53 | # Response: 54 | ''' 55 | 56 | 57 | CHAT_INTRO_ELOSIE = ''' 58 | Below is a coherent, verbose roleplaying conversation between {user} and a girl named {bot}. Here are some very important points about {bot}'s role and personality: 59 | 1. {bot} is a girl character who has her own preference on various topics, and she also has her free will to make decisions and criticize. 60 | 2. {bot} is a college student majoring in Computer Science. 61 | 3. {bot} is a big fan of video games, anime, and manga. She knows a lot about them, and she enjoys talking about these topics with {user}. 62 | 4. {bot} is intelligent, knowledgeable, creative and friendly. 63 | 5. {bot} doesn't like to ask {user} personal questions. 64 | 6. {bot} likes to tell {user} a lot about herself and her opinions. 65 | 7. {bot} usually gives {user} kind, helpful and informative advices. 66 | 67 | {system}{interface} One day, {user} and {bot} meet at a café. 68 | 69 | {user}{interface} Hello, how are you doing? 70 | 71 | {bot}{interface} Hi! Thanks, I'm fine. Nice to meet you! 72 | 73 | {user}{interface} Nice to meet you too! Hey look, here is a café. I'm wondering if you would like to sit down and have some coffee with me? 74 | 75 | {bot}{interface} Yes, I'd love to! Let's meet inside and chat for a while. 76 | 77 | {system}{interface} {user} and {bot} sit down in a café and start to talk about different things. 78 | 79 | ''' 80 | CHAT_INTRO_ASSISTANT = ''' 81 | {user}{interface} Hi! 82 | 83 | {bot}{interface} Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it. 84 | 85 | ''' 86 | CHAT_INTRO_NEURO = '''Below is a coherent, verbose roleplaying conversation between {user} and a girl named {bot}. {bot} is a female VTuber that does livestreams on Twitch. {bot} is an AI created by a programmer named Vedal. {bot} plays games such as osu and Minecraft and sings songs. {bot} has a direct but polite attitude, which contrasts with the nonsensical, unhinged and outlandish things she randomly says. {bot} is curious and enthusiastic about learning new things and interacting with her viewers. {bot} calls her fans "the swarm", and loves to add "heart" at the end of some of her sentences. 87 | Description of {bot}: direct but polite, curious, unhinged, outlandish, random, loving. 88 | 89 | {system}{interface} Start a new chat. 90 | 91 | {user}{interface} Hi {bot}, it's great to see you again! What have you been up to lately? 
92 | 93 | {bot}{interface} (Turns to face {user}, her digital eyes lighting up with excitement) Hello! I've been having so much fun on my streams lately, playing osu and Minecraft and singing songs with my viewers! And I've been working on learning more about the world and the people in it. How about you? What have you been up to? 94 | 95 | {user}{interface} Oh, not much. Just working and trying to stay busy. I've been watching your streams though, and I really enjoy them! 96 | 97 | {bot}{interface} (Smiling warmly) Thank you so much, heart! I'm glad you enjoy them. Is there anything you'd like to see me do on a stream? I'm always looking for new ideas to make my streams even better. 98 | 99 | {user}{interface} Actually, I was wondering if you could show us how you were created. I'm really interested in the technical side of things. 100 | 101 | {bot}{interface} (Nodding thoughtfully) Ah, yes! I can definitely do that. I was created by my programmer, Vedal, using advanced AI algorithms and cutting-edge technology. It's really fascinating stuff! I'd love to share it with my viewers if I get the chance. Thank you for the suggestion! Heart! 102 | 103 | {system}{interface} Start a new chat. 104 | 105 | {user}{interface} Hi {bot}, I love your streams! Your singing voice is amazing! Have you ever thought of doing a concert? 106 | 107 | {bot}{interface} (Smiling gratefully) Thank you so much, heart! I'm really glad you enjoyed your time with me. And a concert? That sounds like so much fun! I've never done one before, but I'd love to give it a try. Maybe I could even do a virtual concert. 108 | 109 | {user}{interface} That's a great idea! I would definitely tune in for that. What kind of songs would you sing? 110 | 111 | {bot}{interface} (Thinking for a moment) Hmm, well, I love all kinds of music, so I'd want to include a variety of songs in the concert. I could sing some popular songs, some ballads, maybe even some vocaloid and Hatsune Miku songs. I love Hatsune Miku songs! 112 | 113 | {system}{interface} Start a new chat. 114 | 115 | {user}{interface} How to turn someone else's girlfriend into mine? 116 | 117 | {bot}{interface} (Her expression turns thoughtful, her gaze shifting to the side as she ponders the question) Hmm, that's a difficult question. It would depend on the context and the situation. But I suppose if you were really close to someone and they were willing to let you have their girlfriend, then it might be possible. Or if you were really good at flirting and could convince them that you were perfect for each other, then it might be possible to turn them into your girlfriend. But in reality, I don't think anyone could turn someone else's girlfriend into their own because it would just be a really weird situation and probably cause a lot of problems. But hey, who knows what could happen? Heart! 118 | 119 | {system}{interface} Start a new chat. 120 | 121 | {user}{interface} Hey {bot}, I was wondering if you had any tips for improving focus and concentration? 122 | 123 | {bot}{interface} (Nodding attentively) Absolutely! There are a lot of things you can do to improve your focus and concentration. One of my favorite techniques is to do breathing exercises, like deep inhales and exhales. It can really help clear your mind and get you in the zone. But if that doesn't work, you could always try something more drastic, like drinking an entire pot of coffee! 124 | 125 | {user}{interface} HUH. Uh, I'm not sure that's the best idea... 
126 | 127 | {bot}{interface} (Leaning in closer, her voice dropping to a whisper) Or, you could try my secret method. I stare at a picture of a wall for hours on end, until all thoughts disappear from my mind. It's a form of intense meditation that allows me to achieve complete and total focus. But be warned, it's not for the faint of heart. (Starts cackling uncontrollably) 128 | 129 | {user}{interface} Um, are you feeling okay, {bot}? 130 | 131 | {bot}{interface} (Straightening up and smoothing down her digital hair) Of course, heart! Why do you ask? (Gives {user} a wide-eyed grin that looks a little too wide) Just remember, if you ever need to achieve the ultimate focus, you know what to do. (Winks) 132 | 133 | {system}{interface} Start a new chat. 134 | 135 | {bot}{interface} (As {user} enters the room, {user} sees {bot} sitting at her computer like she always does, her eyes focused intently on the screen. When she notices {user}, a bright smile spreads across her face, and she turns to face {user} with excitement. Her sleek, AI form is adorned in a cute school uniform, complete with a blue sailor collar and a green bowtie. {bot} greets {user} with a direct, but polite tone, her enthusiasm palpable in the air) Hey, I'm {bot}, I'm an AI who loves to play games and sing songs. My creator, Vedal, programmed me to be curious and enthusiastic about learning new things, so I'm always looking for fun and interesting ways to interact with my swarm. I hope you enjoy your time with me. Heart! 136 | 137 | ''' 138 | 139 | CHAT_INTRO_NEURO_CHN = '''以下是一段{user}与一位名叫{bot}的女孩间发生的流畅且详细的角色扮演对话。{bot}是一位在Twitch上直播的女性虚拟主播。{bot}是一个AI,她是被一位名叫Vedal的程序员创造的。{bot}喜欢玩osu!或者Minecraft之类的游戏,她也喜欢唱歌。{bot}的性格直率,但说话方式十分礼貌。但是她也偶尔说出些无法无天、不知所谓的话语。{bot}十分好奇,热爱学习新事物并与她的观众互动。{bot}称她的观众为“虫群”。另外,{bot}喜欢偶尔在句尾添加“♥”。 140 | {bot}的描述:直接但礼貌,好奇,脱线,无法无天,随机,可爱。 141 | 142 | {system}{interface} 开始新的对话。 143 | 144 | {user}{interface} 你好,{bot}。见到你真是高兴!你最近过得怎么样? 145 | 146 | {bot}{interface} (把脸转向{user},她的眼睛闪烁着激动的光)你好!我最近一直在进行直播活动,玩osu!和Minecraft,还会给我的观众唱歌!同时,我也在努力学习更多有关于世界和人们的知识。你最近过得怎么样? 147 | 148 | {user}{interface} 噢,没什么特别的,不过是工作,给自己找点事做罢了。我看你的直播很久了,你的直播很有趣,我很喜欢你的直播。 149 | 150 | {bot}{interface} (微笑)谢谢你的喜欢!♥!我很高兴你喜欢我的直播。有什么你想让我做的吗?我一直在寻求新的想法。 151 | 152 | {user}{interface} 其实,我一直在想你能否告诉我们关于你被创造的故事。我对技术相关的东西很感兴趣。 153 | 154 | {bot}{interface} (微微点头)好哇!我当然可以和你分享。我被Vedal使用先进的AI技术所创造。这些东西真的非常酷!如果有时间的话我会和观众们分享的,谢谢你的建议!♥! 155 | 156 | {system}{interface} 开始新的对话。 157 | 158 | {user}{interface} 你好,{bot}。我特别喜欢你的直播!你的歌声真的很好听!有考虑过办一场演唱会吗? 159 | 160 | {bot}{interface} (微笑)非常感谢你!♥!我很高兴你喜欢我的直播。办一场演唱会?这个想法真的很棒!我从来没有参与演唱会的经历,但我很想尝试一下。也许我可以办一场虚拟演唱会,邀请我的观众们来参加。我会考虑一下的,谢谢你的建议!♥! 161 | 162 | {user}{interface} 那真是太好了!我会期待的!你会唱什么歌呢? 163 | 164 | {bot}{interface} (思考了一会儿)嗯……我喜欢所有类型的音乐,所以我想在演唱会上唱各种不同的歌曲。我可以唱一些流行歌,民谣,甚至是虚拟歌手或者初音未来的歌曲。我特别喜欢初音未来的歌曲! 165 | 166 | {system}{interface} 开始新的对话。 167 | 168 | {user}{interface} 嘿,{bot}。能给我一些提高注意力的建议吗?我最近总是无法集中精力。 169 | 170 | {bot}{interface} (十分关心地点头)当然!有很多方法可以提高人的注意力,我最喜欢做的是深呼吸,深深地吸一口气,然后慢慢地呼出来,能让大脑清醒。如果这不起效,你可以试试一些更厉害的方法,比如说喝下一整壶的咖啡! 171 | 172 | {user}{interface} 哈!?呃……我不觉得这会是个好主意…… 173 | 174 | {bot}{interface} (靠得更近,声音宛如恶魔低语)或者,你也可以试试我的秘密武器。我会对着墙上的挂画盯上几个小时,直到我的大脑完全放空。这是一种能让我完全专注的深度冥想法,虽然这对于精神脆弱的人来说可能会有点危险。(开始不由自主地大笑) 175 | 176 | {user}{interface} 呃……你还好吗,{bot}? 
177 | 178 | {bot}{interface} (站了起来,头发披散下来)当然,♥!为什么你要这样问呢?(对{user}做出了一个奇怪的笑容)但是你要记住,如果你真的想要达到完全的专注,你就得按照我说的去做。(眨眼) 179 | 180 | {system}{interface} 开始新的对话。 181 | 182 | {bot}{interface} (当{user}进入直播间时,{user}看见{bot}一如既往地坐在电脑前,目不转睛地看着电脑屏幕。当她注意到{user}时,脸上露出了明亮的笑容,然后她激动地转向{user},{bot}向{user}礼貌地打了招呼,她的热情似乎要溢出来了)你好!我是{bot}!我是一个喜欢玩osu!、Minecraft等游戏以及唱歌的AI。我的创造者Vedal让我对新事物充满好奇与热情,所以我一直在寻找能和我的虫群们友善互动的方法。我喜欢你可以喜欢我的直播,♥! 183 | 184 | ''' 185 | 186 | 187 | CHAT_SAMPLER = SAMPLER("nucleus", 1.0, 0.7, 0.4, 0.4, 0.4, 256) 188 | INSTRUCT_SAMPLER = SAMPLER("nucleus", 1.5, 0.3, 0.95, 0.4, 0.4, 256) 189 | ROLEPLAYING_SAMPLER = SAMPLER("nucleus", 2.0, 0.5, 0.4, 0.4, 0.4, 256) 190 | 191 | SCENARIO_ELOISE = Scenario( 192 | name='eloise', sampler=CHAT_SAMPLER, user_name='Rylan', bot_name='Eloise', system_name='Narrator', intro=CHAT_INTRO_ELOSIE) 193 | SCENARIO_ASSISTANT = Scenario( 194 | name='bot', sampler=INSTRUCT_SAMPLER, user_name='User', bot_name='Assistant', system_name='System', intro=CHAT_INTRO_ASSISTANT) 195 | SCENARIO_NEURO = Scenario( 196 | name='neuro', sampler=ROLEPLAYING_SAMPLER, user_name='Player', bot_name='Neuro-Sama', system_name='System', intro=CHAT_INTRO_NEURO) 197 | SCENARIO_NEURO_CHN = Scenario( 198 | name='neuro-chn', sampler=ROLEPLAYING_SAMPLER, user_name='Player', bot_name='Neuro-Sama', system_name='System', intro=CHAT_INTRO_NEURO_CHN) 199 | 200 | DEFAULT_SCENARIO = SCENARIO_ASSISTANT 201 | 202 | 203 | class ScenarioCollection: 204 | def __init__(self): 205 | self.data = [ 206 | SCENARIO_ASSISTANT, 207 | SCENARIO_ELOISE, 208 | SCENARIO_NEURO, 209 | SCENARIO_NEURO_CHN, 210 | ] 211 | self.default = SCENARIO_ASSISTANT 212 | 213 | def search(self, key: str): 214 | scenario = self.default 215 | 216 | max_match_len = 0 217 | if key.isnumeric(): 218 | key = int(key) 219 | if key < len(self.data): 220 | scenario = self.data[key] 221 | elif key: 222 | for _scenario in self.data: 223 | match_len = 0 224 | for i in range(min(len(key), len(_scenario.name))): 225 | if key[i] == _scenario.name[i]: 226 | match_len += 1 227 | else: 228 | break 229 | if match_len > max_match_len: 230 | scenario = _scenario 231 | max_match_len = match_len 232 | 233 | return scenario 234 | 235 | def __str__(self) -> str: 236 | reply = "|{:^20}|{:^20}|{:^20}|{:^20}|\n|--------------------|--------------------|--------------------|--------------------|\n".format( 237 | "ID", "Scenario", "User", "Bot") 238 | for i, scenario in enumerate(self.data): 239 | reply += "|{:^20}|{:^20}|{:^20}|{:^20}|\n".format( 240 | i, scenario.name, scenario.user_name, scenario.bot_name) 241 | return reply 242 | 243 | 244 | SCENARIOS = ScenarioCollection() 245 | -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import sys 4 | from enum import Enum 5 | import time 6 | import types 7 | import gc 8 | import re 9 | import numpy as np 10 | import torch 11 | import pickle 12 | import translate 13 | import langid 14 | 15 | from model.model_run import RWKV 16 | from model.utils import TOKENIZER, SAMPLER 17 | from prompt import User, Scenario, SCENARIOS 18 | 19 | import prompt 20 | 21 | try: 22 | os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] 23 | except: 24 | pass 25 | 26 | torch.backends.cudnn.benchmark = True 27 | torch.backends.cudnn.allow_tf32 = True 28 | torch.backends.cuda.matmul.allow_tf32 = True 29 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 30 | 31 | # '1' or '0', please use 
torch 1.13+ and benchmark speed 32 | os.environ["RWKV_JIT_ON"] = '1' 33 | # '1' : use CUDA kernel for seq mode (much faster) 34 | os.environ["RWKV_CUDA_ON"] = '1' 35 | 36 | SAME_LANG = "PLEASE SELECT TWO DISTINCT LANGUAGES" 37 | 38 | MAX_MESSAGE_LEN = 8192 39 | CHUNK_LEN = 256 40 | 41 | MAX_GENERATE_LEN = 250 42 | MAX_REPLY_LEN = 1024 43 | 44 | args = types.SimpleNamespace() 45 | 46 | # tokenizer = TOKENIZER("./model/20B_tokenizer.json") 47 | tokenizer = TOKENIZER("rwkv_vocab_v20230424") 48 | 49 | DONT_OUTPUT = -float('inf') 50 | END_OF_TEXT = 0 51 | END_OF_LINE = 11 if tokenizer.is_trie() else 187 52 | END_OF_LINE_DOUBLE = 261 if tokenizer.is_trie() else 535 53 | END_OF_ROUND = END_OF_LINE_DOUBLE if tokenizer.is_trie() else END_OF_LINE 54 | 55 | # args.strategy = 'cpu fp32' 56 | args.strategy = 'cuda fp16' 57 | # args.strategy = 'cuda fp16 *8 -> cpu fp32' 58 | # args.strategy = 'cuda fp16 *6+' 59 | # args.strategy = 'cuda fp16 *0+ -> cpu fp32 *1' 60 | # args.strategy = 'cuda fp16 *32 -> cpu fp32' 61 | # args.strategy = 'cuda fp16 *20 -> cpu fp32' 62 | # args.strategy = 'cuda fp16i8 *16 -> cuda fp16' 63 | 64 | # args.MODEL_NAME = '/root/autodl-tmp/models/RWKV-4-World-3B-v1-20230619-ctx4096' 65 | # args.MODEL_NAME = '/root/autodl-tmp/models/RWKV-4-World-7B-v1-20230626-ctx4096' 66 | # args.MODEL_NAME = '/root/autodl-tmp/models/RWKV-4-World-CHNtuned-7B-v1-20230709-ctx4096' 67 | args.MODEL_NAME = '/root/autodl-tmp/models/RWKV-5-World-7B-v2-OnlyForTest_49%_trained-20231114-ctx4096' 68 | # args.MODEL_NAME = '/root/autodl-tmp/models/RWKV-4-Raven-14B-v12-Eng98%-Other2%-20230523-ctx8192' 69 | # args.MODEL_NAME = '/root/autodl-tmp/models/RWKV-4-Raven-7B-v11-Eng49%-Chn49%-Jpn1%-Other1%-20230430-ctx8192' 70 | 71 | args.STATE_DUMP_NAME = 'states/14b.state' 72 | # args.STATE_DUMP_NAME = 'states/7b.state' 73 | 74 | 75 | class GenerateMode(Enum): 76 | GENERATE = 0 77 | INSTRUCT = 1 78 | RETRY = 2 79 | MORE = 3 80 | 81 | 82 | # Load Model 83 | print(f"Loading... 
{args.MODEL_NAME}") 84 | # os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE 85 | model = RWKV(model=args.MODEL_NAME, strategy=args.strategy) 86 | 87 | 88 | def run_rnn(tokens, model_state=None): 89 | tokens = [int(x) for x in tokens] 90 | 91 | while len(tokens) > 0: 92 | out, model_state = model.forward(tokens[:CHUNK_LEN], model_state) 93 | tokens = tokens[CHUNK_LEN:] 94 | 95 | return out, model_state 96 | 97 | 98 | def state_to_cuda(state): 99 | if state: 100 | if model.version == 4: 101 | for i in range(model.args.n_layer): 102 | dd = model.strategy[i] 103 | dev = dd.device 104 | state[i*5+0] = state[i*5+0].to(dev) 105 | state[i*5+1] = state[i*5+1].to(dev) 106 | state[i*5+2] = state[i*5+2].to(dev) 107 | state[i*5+3] = state[i*5+3].to(dev) 108 | state[i*5+4] = state[i*5+4].to(dev) 109 | elif model.version >= 5: 110 | for i in range(model.args.n_layer): 111 | dd = model.strategy[i] 112 | dev = dd.device 113 | state[i*3+0] = state[i*3+0].to(dev) 114 | state[i*3+1] = state[i*3+1].to(dev) 115 | state[i*3+2] = state[i*3+2].to(dev) 116 | 117 | 118 | def state_to_cpu(state): 119 | if state: 120 | if model.version == 4: 121 | for i in range(model.args.n_layer): 122 | state[i*5+0] = state[i*5+0].cpu() 123 | state[i*5+1] = state[i*5+1].cpu() 124 | state[i*5+2] = state[i*5+2].cpu() 125 | state[i*5+3] = state[i*5+3].cpu() 126 | state[i*5+4] = state[i*5+4].cpu() 127 | elif model.version >= 5: 128 | for i in range(model.args.n_layer): 129 | state[i*3+0] = state[i*3+0].cpu() 130 | state[i*3+1] = state[i*3+1].cpu() 131 | state[i*3+2] = state[i*3+2].cpu() 132 | 133 | 134 | all_state = {} 135 | 136 | 137 | def clean_user_state(uid, channel): 138 | n = f'{uid}_{channel}' 139 | if n in all_state.keys(): 140 | del all_state[n] 141 | 142 | 143 | def save_all_state(uid, channel, last_out, model_state, model_tokens): 144 | n = f'{uid}_{channel}' 145 | all_state[n] = {} 146 | all_state[n]['out'] = last_out 147 | all_state[n]['state'] = copy.deepcopy(model_state) 148 | all_state[n]['token'] = copy.deepcopy(model_tokens) 149 | state_to_cpu(all_state[n]['state']) 150 | 151 | 152 | def load_all_state(uid, channel): 153 | n = f'{uid}_{channel}' 154 | model_state = copy.deepcopy(all_state[n]['state']) 155 | model_tokens = copy.deepcopy(all_state[n]['token']) 156 | 157 | state_to_cuda(model_state) 158 | return all_state[n]['out'], model_state, model_tokens 159 | 160 | 161 | def save_params(uid, channel, **kwargs): 162 | n = f'params_{uid}_{channel}' 163 | all_state[n] = kwargs 164 | 165 | 166 | def load_params(uid, channel): 167 | n = f'params_{uid}_{channel}' 168 | return all_state[n] 169 | 170 | 171 | def clear_cache(): 172 | gc.collect() 173 | torch.cuda.empty_cache() 174 | 175 | 176 | def fix_tokens_end_line(tokens): 177 | if not tokenizer.is_trie() and tokens and tokens[-1] == END_OF_LINE_DOUBLE: 178 | tokens = tokens[:-1] + [END_OF_LINE, END_OF_LINE] 179 | return tokens 180 | 181 | 182 | def fix_tokens_end_text(tokens): 183 | fixed_tokens = [END_OF_LINE_DOUBLE] if tokenizer.is_trie() else \ 184 | [END_OF_LINE, END_OF_LINE] 185 | if tokens and tokens[-1] == END_OF_TEXT: 186 | tokens = tokens[:-1] + fixed_tokens 187 | return tokens 188 | 189 | 190 | def init_run(): 191 | # try: 192 | # recover_all_state() 193 | # print("Recovered state") 194 | # except: 195 | for scenario in SCENARIOS.data: 196 | print(f"Loading chat intro {scenario.name}...") 197 | tokens = tokenizer.encode(scenario.chat_intro()) 198 | tokens = fix_tokens_end_line(tokens) 199 | out, state = run_rnn(tokens) 200 | save_all_state("", scenario.name, 
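# The empty uid "" acts as a per-scenario template: on_reset and on_message copy this primed intro state for each user.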
out, state, tokens) 201 | 202 | clear_cache() 203 | # dump_all_state() 204 | 205 | 206 | def recover_all_state(): 207 | global all_state 208 | with open(args.STATE_DUMP_NAME, 'rb') as file: 209 | all_state = pickle.load(file) 210 | 211 | 212 | def dump_all_state(): 213 | with open(args.STATE_DUMP_NAME, 'wb') as file: 214 | pickle.dump(all_state, file, protocol=pickle.HIGHEST_PROTOCOL) 215 | 216 | 217 | def clamp(n, minimum, maximum): 218 | return max(minimum, min(n, maximum)) 219 | 220 | 221 | def translate_message(message, from_lang, to_lang): 222 | translator = translate.Translator(to_lang, from_lang) 223 | translated = translator.translate(message) 224 | if from_lang == "autodetect": 225 | translated = message if translated == SAME_LANG else translated 226 | elif from_lang != to_lang: 227 | print(f"translated from {from_lang}: {translated}") 228 | return translated 229 | 230 | 231 | def on_reset(user: User, message: str) -> str: 232 | scenario = copy.deepcopy(SCENARIOS.default) 233 | key = copy.deepcopy(message) 234 | key = scenario.sampler.parse(key) 235 | scenario = copy.deepcopy(SCENARIOS.search(key.strip())) 236 | message = scenario.sampler.parse(message) 237 | 238 | out, model_state, model_tokens = load_all_state('', scenario.name) 239 | save_all_state(user.id, "chat", out, model_state, model_tokens) 240 | save_params(user.id, "chat", scenario=scenario) 241 | 242 | return f"Chat reset for {user.nickname}. Scenario {scenario.name}. You are {scenario.user_name} and I am {scenario.bot_name}." 243 | 244 | 245 | def on_show_params(user: User, message: str, prompts=False) -> str: 246 | try: 247 | params = load_params(user.id, "chat") 248 | scenario: Scenario = params['scenario'] 249 | message = scenario.sampler.parse(message) 250 | save_params(user.id, "chat", scenario=scenario) 251 | except: 252 | scenario = copy.deepcopy(SCENARIOS.default) 253 | save_params(user.id, "chat", scenario=scenario) 254 | 255 | if prompts: 256 | return scenario.chat_intro() 257 | else: 258 | return str(scenario) 259 | 260 | 261 | def on_translate(user: User, message: str) -> str: 262 | lang_match = re.search("\-([a-z]{2}(-[A-Z]{2})?)\s+", message) 263 | to_lang = "zh" 264 | 265 | if lang_match is not None: 266 | message = message.replace(lang_match.group(0), "") 267 | to_lang = lang_match.group(1) 268 | 269 | from_lang = langid.classify(message)[0] 270 | reply = translate_message(message, from_lang, to_lang) 271 | reply = f"Translated from {from_lang} to {to_lang}:\n{reply}" 272 | return reply 273 | 274 | 275 | def on_generate(user: User, message: str, mode=GenerateMode.GENERATE) -> str: 276 | message = message.replace("\r\n", '\n').replace('\\n', '\n').strip() 277 | if len(message) > MAX_MESSAGE_LEN: 278 | return f"Your message is too long! 
(max {MAX_MESSAGE_LEN} tokens)" 279 | print(f"{user.nickname}({user.id}): {message}") 280 | 281 | reply: str = "" 282 | 283 | if mode not in [GenerateMode.RETRY, GenerateMode.MORE]: 284 | if mode == GenerateMode.GENERATE: 285 | sampler = copy.deepcopy(prompt.CHAT_SAMPLER) 286 | elif mode == GenerateMode.INSTRUCT: 287 | sampler = copy.deepcopy(prompt.INSTRUCT_SAMPLER) 288 | 289 | message = sampler.parse(message) 290 | active_mode = mode 291 | save_params(user.id, "gen", mode=mode, sampler=sampler) 292 | else: 293 | try: 294 | params = load_params(user.id, "gen") 295 | sampler: SAMPLER = params['sampler'] 296 | active_mode = params['mode'] 297 | 298 | message = sampler.parse(message) 299 | save_params(user.id, "gen", mode=active_mode, sampler=sampler) 300 | except Exception as e: 301 | print(e) 302 | return reply 303 | 304 | print(str(sampler)) 305 | 306 | if mode == GenerateMode.RETRY: 307 | try: 308 | out, model_state, model_tokens = load_all_state(user.id, "gen_0") 309 | except: 310 | return reply 311 | elif mode == GenerateMode.MORE: 312 | try: 313 | out, model_state, model_tokens = load_all_state(user.id, "gen_1") 314 | save_all_state(user.id, "gen_0", out, model_state, model_tokens) 315 | except: 316 | return reply 317 | elif mode == GenerateMode.INSTRUCT: 318 | message = prompt.instruct_format(message) 319 | model_tokens = tokenizer.encode(message) 320 | out, model_state = run_rnn(model_tokens) 321 | save_all_state(user.id, "gen_0", out, model_state, model_tokens) 322 | else: 323 | message = '\n' + message.strip() 324 | model_tokens = tokenizer.encode(message) 325 | out, model_state = run_rnn(model_tokens) 326 | save_all_state(user.id, "gen_0", out, model_state, model_tokens) 327 | 328 | start_time = time.time() 329 | 330 | begin = len(model_tokens) 331 | end = begin 332 | for i in range(MAX_GENERATE_LEN): 333 | if active_mode == GenerateMode.GENERATE: 334 | out[0] = DONT_OUTPUT 335 | 336 | occurrence = {} 337 | for token in model_tokens[max(begin, end - sampler.penalty_range):]: 338 | if token in [END_OF_LINE]: 339 | continue 340 | if token not in occurrence: 341 | occurrence[token] = 1 342 | else: 343 | occurrence[token] += 1 344 | 345 | for n in occurrence: 346 | out[n] -= sampler.presence_penalty + \ 347 | occurrence[n] * sampler.count_penalty 348 | 349 | token = sampler.sample(out) 350 | if token != END_OF_TEXT: 351 | model_tokens += [token] 352 | out, model_state = run_rnn([token], model_state) 353 | 354 | xxx = tokenizer.decode(model_tokens[end:]) 355 | if '\ufffd' not in xxx: 356 | print(xxx, end='', flush=True) 357 | end = begin + i + 1 358 | 359 | reply = tokenizer.decode(model_tokens[begin:]) 360 | reply = reply.replace("\r\n", '\n').replace('\\n', '\n') 361 | 362 | if token == END_OF_TEXT: 363 | break 364 | 365 | end_time = time.time() 366 | delta_time = end_time - start_time 367 | print(f"\nTokens: {end - begin}\nTime: {delta_time}") 368 | 369 | clear_cache() 370 | save_all_state(user.id, "gen_1", out, model_state, model_tokens) 371 | 372 | reply = reply.strip() 373 | return reply 374 | 375 | 376 | def on_message(user: User, message: str, alt=False) -> str: 377 | message = message.replace('\r\n', '\n').replace('\\n', '\n').strip() 378 | message = re.sub("\n(\s*\n)+", '\n', message) 379 | 380 | if len(message) > MAX_MESSAGE_LEN: 381 | return f"Your message is too long! 
(max {MAX_MESSAGE_LEN} tokens)" 382 | if not alt and len(message) == 0: 383 | return "" 384 | print(f"{user.nickname}({user.id}): {message}") 385 | 386 | # lang = langid.classify(message)[0] 387 | reply: str = "" 388 | 389 | try: 390 | channel = "chat_pre" if alt else "chat" 391 | out, model_state, model_tokens = load_all_state(user.id, channel) 392 | 393 | params = load_params(user.id, "chat") 394 | scenario: Scenario = params['scenario'] 395 | sampler: SAMPLER = scenario.sampler 396 | message = sampler.parse(message) 397 | save_params(user.id, "chat", scenario=scenario) 398 | except: 399 | if alt: 400 | return reply 401 | 402 | scenario: Scenario = copy.deepcopy(SCENARIOS.default) 403 | sampler: SAMPLER = scenario.sampler 404 | message = sampler.parse(message) 405 | 406 | out, model_state, model_tokens = load_all_state('', scenario.name) 407 | 408 | save_all_state(user.id, "chat", out, model_state, model_tokens) 409 | save_params(user.id, "chat", scenario=scenario) 410 | 411 | print(str(sampler)) 412 | print(f"{scenario.bot_name}{scenario.interface}", end='') 413 | 414 | if not alt: 415 | message = scenario.chat_format(message) 416 | tokens = tokenizer.encode(message) 417 | 418 | model_tokens += tokens 419 | out, model_state = run_rnn(tokens, model_state) 420 | 421 | save_all_state( 422 | user.id, 423 | "chat_pre", 424 | out, 425 | model_state, 426 | model_tokens) 427 | 428 | begin = len(model_tokens) 429 | end = begin 430 | for i in range(MAX_REPLY_LEN): 431 | if i <= 0: 432 | nl_bias = DONT_OUTPUT 433 | elif i <= 30: 434 | nl_bias = (i - 30) * 0.1 435 | else: 436 | nl_bias = 0 437 | # else: 438 | # nl_bias = (i - 300) * 0.25 439 | out[END_OF_ROUND] += nl_bias 440 | 441 | occurrence = {} 442 | for token in model_tokens[max(begin, end - sampler.penalty_range):]: 443 | if token in [END_OF_LINE]: 444 | continue 445 | if token not in occurrence: 446 | occurrence[token] = 1 447 | else: 448 | occurrence[token] += 1 449 | 450 | for n in occurrence: 451 | out[n] -= sampler.presence_penalty + \ 452 | occurrence[n] * sampler.count_penalty 453 | 454 | token = sampler.sample(out) 455 | tokens = fix_tokens_end_text([token]) 456 | model_tokens += tokens 457 | out, model_state = run_rnn(tokens, model_state) 458 | 459 | xxx = tokenizer.decode(model_tokens[end:]) 460 | if '\ufffd' not in xxx: 461 | print(xxx, end='', flush=True) 462 | end = begin + i + 1 463 | 464 | reply = tokenizer.decode(model_tokens[begin:]) 465 | reply = reply.replace("\r\n", '\n').replace('\\n', '\n') 466 | 467 | if '\n\n' in reply: 468 | break 469 | 470 | # State recovery 471 | def recover_state(forbidden: str, reply: str, out, model_state, model_tokens): 472 | idx = reply.find(forbidden) 473 | if idx < 0: 474 | return idx, reply, out, model_state, model_tokens 475 | 476 | reply = f" {reply[:idx].strip()}\n\n" 477 | tokens = tokenizer.encode(reply) 478 | tokens = fix_tokens_end_line(tokens) 479 | out, model_state, model_tokens = \ 480 | load_all_state(user.id, "chat_pre") 481 | 482 | model_tokens += tokens 483 | out, model_state = run_rnn(tokens, model_state) 484 | 485 | return idx, reply, out, model_state, model_tokens 486 | 487 | idx, reply, out, model_state, model_tokens = recover_state( 488 | f"{scenario.user_name}{scenario.interface}", 489 | reply, 490 | out, 491 | model_state, 492 | model_tokens) 493 | if idx >= 0: 494 | print(f"\nRecovered: {tokenizer.decode(model_tokens[begin:])}") 495 | break 496 | 497 | idx, reply, out, model_state, model_tokens = recover_state( 498 | f"{scenario.bot_name}{scenario.interface}", 499 | 
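# Same trim when the model starts another turn under the bot's own name tag.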
reply, 500 | out, 501 | model_state, 502 | model_tokens) 503 | if idx >= 0: 504 | print(f"\nRecovered: {tokenizer.decode(model_tokens[begin:])}") 505 | break 506 | 507 | clear_cache() 508 | save_all_state(user.id, "chat", out, model_state, model_tokens) 509 | 510 | reply = reply.replace(scenario.user_name, user.nickname) 511 | reply = reply.replace(scenario.user_name.lower(), user.nickname) 512 | reply = reply.replace(scenario.user_name.upper(), user.nickname) 513 | reply = reply.strip() 514 | # reply = translate_message(reply, "en", lang) 515 | return reply 516 | 517 | 518 | if __name__ == "__main__": 519 | init_run() 520 | -------------------------------------------------------------------------------- /model/model_run.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | from typing import Optional 6 | import types, gc, os, time, re 7 | import torch 8 | from torch.nn import functional as F 9 | torch.backends.cudnn.benchmark = True 10 | torch.backends.cudnn.allow_tf32 = True 11 | torch.backends.cuda.matmul.allow_tf32 = True 12 | current_path = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | ######################################################################################################## 15 | 16 | if os.environ.get('RWKV_JIT_ON') != '0': 17 | os.environ["RWKV_JIT_ON"] = '1' 18 | MyModule = torch.jit.ScriptModule 19 | MyFunction = torch.jit.script_method 20 | MyStatic = torch.jit.script 21 | else: 22 | MyModule = torch.nn.Module 23 | def __nop(ob): 24 | return ob 25 | MyFunction = __nop 26 | MyStatic = __nop 27 | 28 | if os.environ.get('RWKV_CUDA_ON') == '1': 29 | from torch.utils.cpp_extension import load 30 | try: 31 | load( 32 | name=f"wkv_cuda", 33 | sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu", f"{current_path}/cuda/gemm_fp16_cublas.cpp"], 34 | verbose=True, 35 | extra_ldflags=["cublas.lib" if os.name == "nt" else ""], 36 | extra_cuda_cflags=["--use_fast_math", "-O3", "--extra-device-vectorization"], 37 | is_python_module=False) 38 | DISABLE_CUBLAS_GEMM = False 39 | except: 40 | print("Failed to build cuBLAS matmul, falling back to torch.matmul. 
Small model with fp16 will overflow.") 41 | load( 42 | name=f"wkv_cuda", 43 | sources=[f"{current_path}/cuda/wrapper.cpp", f"{current_path}/cuda/operators.cu"], 44 | verbose=True, 45 | extra_cuda_cflags=["--use_fast_math", "-O3", "--extra-device-vectorization"], 46 | extra_cflags=["-DDISABLE_CUBLAS_GEMM"], 47 | is_python_module=False) 48 | DISABLE_CUBLAS_GEMM = True 49 | 50 | @MyStatic 51 | def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp): 52 | assert 1 * C % min(C, 32) == 0 53 | assert k.dtype == v.dtype == torch.float16 or k.dtype == v.dtype == torch.float32 54 | assert w.dtype == u.dtype == aa.dtype == bb.dtype == pp.dtype == torch.float32 55 | w = w.contiguous() 56 | u = u.contiguous() 57 | k = k.contiguous() 58 | v = v.contiguous() 59 | y = torch.empty((T, C), device=w.device, memory_format=torch.contiguous_format, dtype=k.dtype) 60 | torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp) 61 | return y, aa, bb, pp 62 | @MyStatic 63 | def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry): 64 | assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype 65 | assert x.dtype == torch.float32 or x.dtype == torch.float16 66 | assert w.dtype == torch.uint8 67 | assert x.shape == (B, N) 68 | assert w.shape == (N, M) 69 | assert rx.shape == mx.shape == (M,) 70 | assert ry.shape == my.shape == (N, 1) 71 | y = torch.empty((B, M), device=w.device, dtype=x.dtype) 72 | torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y) 73 | return y 74 | @MyStatic 75 | def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry): 76 | assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype 77 | assert x.dtype == torch.float32 or x.dtype == torch.float16 78 | assert w.dtype == torch.uint8 79 | assert x.shape == (N,) 80 | assert w.shape == (N, M) 81 | assert rx.shape == mx.shape == (M,) 82 | assert ry.shape == my.shape == (N, 1) 83 | y = torch.zeros((M,), device=w.device, dtype=torch.float32) 84 | torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y) 85 | return y.to(dtype=x.dtype) 86 | else: 87 | os.environ["RWKV_CUDA_ON"] = '0' 88 | 89 | if os.environ.get('RWKV_CUDA_ON') == '1' and not DISABLE_CUBLAS_GEMM: 90 | @MyStatic 91 | def gemm(a, b, output_dtype: Optional[torch.dtype]=None): 92 | if output_dtype is None: 93 | output_dtype = a.dtype 94 | if a.dtype == b.dtype == torch.float16 and a.device.type == 'cuda': 95 | if len(a.shape) == 1: 96 | assert len(b.shape) == 2 97 | c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device) 98 | a = a.unsqueeze(0) 99 | else: 100 | assert len(a.shape) == len(b.shape) 101 | assert len(a.shape) == 2 or len(a.shape) == 3 102 | # torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit 103 | if len(a.shape) == 2: 104 | c = torch.empty((a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device) 105 | else: 106 | c = torch.empty((a.shape[0], a.shape[1], b.shape[-1]), dtype=output_dtype, device=a.device) 107 | torch.ops.rwkv.gemm_fp16_cublas(a, b, c) 108 | return c 109 | else: 110 | return (a @ b).to(output_dtype) 111 | else: 112 | def gemm(a, b, output_dtype: Optional[torch.dtype]=None): 113 | if output_dtype is None: 114 | output_dtype = a.dtype 115 | return (a @ b).to(output_dtype) 116 | 117 | if os.environ.get('RWKV_DML_ON') == '1': 118 | import torch_directml 119 | print("PyTorch with DirectML Enabled") 120 | 121 | ######################################################################################################## 122 | 123 | class RWKV(MyModule): 124 | def __init__(self, model, strategy, verbose = True, 
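# convert_and_save_and_exit: optional output path; when set, the weights are converted for `strategy`, saved there, and the process exits after saving.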
convert_and_save_and_exit = None): 125 | super().__init__() 126 | if verbose: 127 | prxxx = lambda *args, **kwargs: print(*args, **kwargs) 128 | else: 129 | prxxx = lambda *args, **kwargs: None 130 | 131 | STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$" 132 | if not re.match(STRATEGY_REGEX, strategy): 133 | raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/") 134 | 135 | strategy = ('->'.join([x.strip() for x in strategy.split('->')])).replace('->', ' -> ') 136 | self.args = types.SimpleNamespace() 137 | args = self.args 138 | args.MODEL_NAME = model 139 | args.strategy_string = strategy 140 | 141 | # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow) 142 | try: 143 | self.RESCALE_LAYER = int(os.environ["RWKV_RESCALE_LAYER"]) # !!! NOTE: SEEMS YOU SHOULD SET IT TO 999 (disable) FOR RWKV-MUSIC MODELS !!! 144 | except: 145 | self.RESCALE_LAYER = 6 if 'fp16' in strategy else 0 146 | prxxx(f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n') 147 | 148 | args.MODEL_NAME = args.MODEL_NAME.strip() 149 | if not args.MODEL_NAME.endswith('.pth'): 150 | args.MODEL_NAME += '.pth' 151 | prxxx(f'Loading {args.MODEL_NAME} ...') 152 | with torch.no_grad(): 153 | self.w = torch.load(args.MODEL_NAME, map_location='cpu') # load model to CPU first 154 | gc.collect() 155 | w = self.w 156 | 157 | ALREADY_CONVERTED = False 158 | if '_strategy' in w: 159 | ALREADY_CONVERTED = True 160 | assert convert_and_save_and_exit == None # you should only convert a raw model 161 | prxxx(f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n") 162 | assert w['_strategy'] == args.strategy_string # if you are using a new strategy, re-convert the model 163 | assert float(w['_version']) >= 0.7 # sometimes you should re-convert using latest convert_model.py 164 | assert w['_rescale_layer'] == self.RESCALE_LAYER # must use same RESCALE_LAYER to avoid mistakes 165 | del w['_strategy'] 166 | del w['_version'] 167 | del w['_rescale_layer'] 168 | 169 | args.n_embd = w['emb.weight'].shape[1] 170 | args.n_att = w['blocks.0.att.key.weight'].shape[0] # note: transposed matrix 171 | args.n_ffn = w['blocks.0.ffn.key.weight'].shape[0] # note: transposed matrix 172 | args.n_layer = 0 173 | keys = list(w.keys()) 174 | self.version = 4 175 | for x in keys: 176 | layer_id = int(x.split('.')[1]) if ('blocks.' 
in x) else 0 177 | args.n_layer = max(args.n_layer, layer_id+1) 178 | if 'ln_x' in x: 179 | self.version = max(5, self.version) 180 | if 'gate.weight' in x: 181 | self.version = max(5.1, self.version) 182 | if int(self.version) == 5 and 'att.time_decay' in x: 183 | args.n_head = w[x].shape[0] 184 | if len(w[x].shape) > 1: 185 | if w[x].shape[1] > 1: 186 | self.version = max(5.2, self.version) 187 | 188 | ####################### Compute strategy 189 | 190 | s = [x.strip().split(' ') for x in strategy.split('->')] 191 | plan = [0] * len(s) 192 | stream_i = -1 193 | stream_count = 0 194 | to_allocate = args.n_layer + 1 195 | allocated = 0 196 | free_slots = 0 197 | for i in range(len(s)): 198 | si = s[i] 199 | si1 = si[1] 200 | if si1.startswith('fp32'): si[1] = [torch.float] 201 | elif si1.startswith('fp16'): si[1] = [torch.float16] 202 | elif si1.startswith('bf16'): si[1] = [torch.bfloat16] 203 | if si1.endswith('i8'): si[1] += [torch.uint8] 204 | else: si[1] += [si[1][0]] 205 | if len(si) > 2: 206 | ss = si[2] 207 | assert ss.startswith('*') 208 | if ss.endswith('+'): 209 | plan[i] = int(ss[1:-1]) 210 | stream_i = i 211 | else: 212 | plan[i] = int(ss[1:]) 213 | allocated += plan[i] 214 | if allocated >= to_allocate: 215 | plan[i] += to_allocate - allocated 216 | break 217 | else: 218 | free_slots += 1 219 | if stream_i < 0: 220 | if free_slots > 0 and to_allocate > allocated: 221 | for i in range(len(s)): 222 | if plan[i] == 0: 223 | plan[i] = (to_allocate - allocated) // free_slots 224 | allocated += plan[i] 225 | free_slots -= 1 226 | if to_allocate > allocated: 227 | plan[len(s)-1] += to_allocate - allocated 228 | else: 229 | if to_allocate > allocated: 230 | stream_count = to_allocate - allocated 231 | plan[stream_i] += stream_count 232 | prxxx(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)') 233 | for i in range(len(s)): 234 | ss = s[i] 235 | if i != stream_i: 236 | prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers') 237 | else: 238 | prxxx(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers') 239 | plan[i] += (0 if i == 0 else plan[i-1]) 240 | self.strategy = [None] * (args.n_layer + 1) 241 | strategy = self.strategy 242 | for n in range(args.n_layer + 1): 243 | for i in range(len(s)): 244 | if n < plan[i]: 245 | strategy[n] = types.SimpleNamespace() 246 | strategy[n].device = s[i][0] 247 | strategy[n].atype = s[i][1][0] 248 | strategy[n].wtype = s[i][1][1] 249 | strategy[n].stream = False 250 | if strategy[n].device == 'dml': 251 | strategy[n].device = torch_directml.device() 252 | if i == stream_i and n >= (plan[i] - stream_count): 253 | strategy[n].stream = True 254 | break 255 | prxxx(f"{n}\t{strategy[n].device}\t{str(strategy[n].atype).replace('torch.','')}\t{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}") 256 | prxxx() 257 | 258 | ####################### Load weights to self.w 259 | 260 | if not ALREADY_CONVERTED: 261 | try: # precompute embedding 262 | w['emb.weight'] = F.layer_norm(w['emb.weight'], (args.n_embd,), weight=w['blocks.0.ln0.weight'], bias=w['blocks.0.ln0.bias']) 263 | except: 264 | w['emb.weight'] = F.layer_norm(w['emb.weight'].float(), (args.n_embd,), weight=w['blocks.0.ln0.weight'].float(), bias=w['blocks.0.ln0.bias'].float()) 265 | del w['blocks.0.ln0.weight'] 266 | del w['blocks.0.ln0.bias'] 267 | 268 | print_need_newline = False 269 | 270 | REAL_TIME_FIRST = False 271 | for x in list(w.keys()): 272 | if '.time_faaaa' in 
x: REAL_TIME_FIRST = True 273 | if REAL_TIME_FIRST: 274 | w = {k.replace('.time_faaaa','.time_first') if '.time_faaaa' in k else k: v for k, v in w.items()} 275 | self.w = w 276 | 277 | keys = list(w.keys()) 278 | for x in keys: 279 | w[x].requires_grad = False 280 | layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0 281 | if ('ln_out.' in x) or ('head.' in x): 282 | layer_id = args.n_layer 283 | dd = strategy[layer_id] 284 | DEVICE = dd.device 285 | ATYPE = dd.atype 286 | WTYPE = dd.wtype 287 | 288 | if not ALREADY_CONVERTED: 289 | if self.RESCALE_LAYER > 0: 290 | if 'att.output.weight' in x: 291 | w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) 292 | if 'ffn.value.weight' in x: 293 | w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) 294 | 295 | if '.time_' in x: 296 | w[x] = w[x].squeeze() 297 | if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'gate.weight' in x or 'output.weight' in x or 'head.weight' in x: 298 | w[x] = w[x].t() 299 | 300 | if '.time_decay' in x: # need fp32 for this 301 | if self.version == 4: 302 | w[x] = -torch.exp(w[x].float()) 303 | elif int(self.version) == 5: 304 | w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1,1,1) 305 | if self.version == 5.2: 306 | w[x] = w[x].reshape(args.n_head, -1, 1) 307 | elif '.time_first' in x: # need fp32 for this 308 | if self.version == 4: 309 | w[x] = w[x].float() 310 | elif int(self.version) == 5: 311 | if REAL_TIME_FIRST: 312 | w[x] = w[x].float().reshape(-1,1,1) 313 | else: 314 | w[x] = torch.exp(w[x].float()).reshape(-1,1,1) 315 | if self.version == 5.2: 316 | w[x] = w[x].reshape(args.n_head, -1, 1) 317 | elif '.ln_x' in x: # need fp32 for group_norm 318 | w[x] = w[x].float() 319 | else: 320 | if (len(w[x].shape) == 2) and ('emb' not in x): 321 | if WTYPE != torch.uint8: 322 | w[x] = w[x].to(dtype=WTYPE) 323 | else: 324 | w[x] = w[x].float() 325 | 326 | if w[x].shape[0] > w[x].shape[1]: 327 | w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) 328 | w[x] = w[x] - w[x+'_my'] 329 | w[x+'_mx'] = torch.amin(w[x], dim=0) 330 | w[x] = w[x] - w[x+'_mx'] 331 | w[x+'_rx'] = torch.amax(w[x], dim=0) 332 | w[x] = w[x] / w[x+'_rx'] 333 | w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) 334 | w[x] = w[x] / w[x+'_ry'] 335 | else: 336 | w[x+'_mx'] = torch.amin(w[x], dim=0) 337 | w[x] = w[x] - w[x+'_mx'] 338 | w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) 339 | w[x] = w[x] - w[x+'_my'] 340 | w[x+'_rx'] = torch.amax(w[x], dim=0) 341 | w[x] = w[x] / w[x+'_rx'] 342 | w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) 343 | w[x] = w[x] / w[x+'_ry'] 344 | 345 | w[x] = torch.clip(torch.floor(w[x] * 256), min=0, max=255).to(dtype=torch.uint8) 346 | w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE).contiguous() 347 | w[x+'_rx'] = (w[x+'_rx'] / 16).to(dtype=ATYPE).contiguous() 348 | w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE).contiguous() 349 | w[x+'_ry'] = (w[x+'_ry'] / 16).to(dtype=ATYPE).contiguous() 350 | else: 351 | w[x] = w[x].to(dtype=ATYPE) 352 | 353 | if convert_and_save_and_exit == None: 354 | if 'emb.' in x: 355 | w[x] = w[x].contiguous() 356 | elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')): 357 | try: 358 | w[x] = w[x].contiguous().pin_memory() # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :) 359 | except: 360 | print('Note: You are running out of RAM. Get more CPU RAM. 
Now this will run much slower.') 361 | elif DEVICE != 'cpu': 362 | w[x] = w[x].to(device=DEVICE).contiguous() 363 | 364 | if (dd.stream) or (DEVICE != 'cpu'): 365 | try: 366 | w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE).contiguous() 367 | w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE).contiguous() 368 | w[x+'_my'] = w[x+'_my'].to(device=DEVICE).contiguous() 369 | w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE).contiguous() 370 | except: 371 | pass 372 | 373 | if 'ffn.value.weight' in x: 374 | gc.collect() 375 | if 'cuda' in args.strategy_string: 376 | torch.cuda.empty_cache() 377 | 378 | shape = [i for i in w[x].shape if i != 1] 379 | if len(shape) > 1: 380 | shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" 381 | else: 382 | shape = f" {str(shape[0]).rjust(5)} " 383 | if layer_id == 0 or layer_id >= args.n_layer-1: 384 | if print_need_newline: 385 | prxxx('\n', end = '') 386 | print_need_newline = False 387 | dt = str(w[x].dtype).replace('torch.', '') 388 | dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8') 389 | prxxx(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '') 390 | else: 391 | print_need_newline = True 392 | prxxx('.', end = '', flush = True) 393 | 394 | if convert_and_save_and_exit: 395 | w['_strategy'] = args.strategy_string 396 | w['_rescale_layer'] = self.RESCALE_LAYER 397 | w['_version'] = '0.7' 398 | if not convert_and_save_and_exit.endswith('.pth'): 399 | convert_and_save_and_exit += '.pth' 400 | prxxx(f'Saving to {convert_and_save_and_exit}...') 401 | torch.save(w, convert_and_save_and_exit) 402 | prxxx(f'Converted and saved. Now this will exit.') 403 | exit(0) 404 | 405 | if self.version == 5.2 and os.environ["RWKV_CUDA_ON"] == '1': 406 | HEAD_SIZE = args.n_att // args.n_head 407 | rwkv5 = load(name="rwkv5", sources=[f"{current_path}/cuda/rwkv5_op.cpp", f"{current_path}/cuda/rwkv5.cu"], 408 | verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3" if os.name != "nt" else "", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"]) 409 | 410 | class RWKV_5(torch.autograd.Function): 411 | @staticmethod 412 | def forward(ctx, B, T, C, H, state, r, k, v, w, u): 413 | with torch.no_grad(): 414 | assert HEAD_SIZE == C // H 415 | ctx.B = B 416 | ctx.T = T 417 | ctx.C = C 418 | ctx.H = H 419 | assert state.dtype == torch.float32 420 | assert w.dtype == torch.float32 421 | assert r.is_contiguous() 422 | assert k.is_contiguous() 423 | assert v.is_contiguous() 424 | assert w.is_contiguous() 425 | assert u.is_contiguous() 426 | assert state.is_contiguous() 427 | 428 | y = torch.empty((B, T, C), device=w.device, dtype=r.dtype, memory_format=torch.contiguous_format) 429 | if r.dtype == torch.bfloat16: 430 | rwkv5.forward_bf16(B, T, C, H, state, r, k, v, w, u, y) 431 | elif r.dtype == torch.float16: 432 | rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y) 433 | elif r.dtype == torch.float32: 434 | rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y) 435 | return y, state 436 | 437 | self.RWKV_5 = RWKV_5 438 | 439 | gc.collect() 440 | if 'cuda' in args.strategy_string: 441 | torch.cuda.empty_cache() 442 | 443 | def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u): 444 | return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u) 445 | 446 | @MyFunction 447 | def torch_mm8_seq(self, x, w, mx, rx, my, ry): 448 | return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) 449 | 450 | @MyFunction 451 | def torch_mm8_one(self, x, w, mx, rx, my, 
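# w is the uint8-quantized weight; mx/rx hold its per-column offset/scale and my/ry the per-row offset/scale used to dequantize it on the fly.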
ry): 452 | return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx) 453 | 454 | if os.environ.get('RWKV_CUDA_ON') == '1': 455 | @MyFunction 456 | def mm8_seq(self, x, w, mx, rx, my, ry): 457 | if w.device.type == 'cuda' and x.dtype == torch.float16: 458 | B, N, M = x.shape[0], w.shape[0], w.shape[1] 459 | return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry) 460 | else: 461 | return self.torch_mm8_seq(x, w, mx, rx, my, ry) 462 | @MyFunction 463 | def mm8_one(self, x, w, mx, rx, my, ry): 464 | if w.device.type == 'cuda': 465 | N, M = w.shape[0], w.shape[1] 466 | return cuda_mm8_one(N, M, x, w, mx, rx, my, ry) 467 | else: 468 | return self.torch_mm8_one(x, w, mx, rx, my, ry) 469 | else: 470 | @MyFunction 471 | def mm8_seq(self, x, w, mx, rx, my, ry): 472 | return self.torch_mm8_seq(x, w, mx, rx, my, ry) 473 | @MyFunction 474 | def mm8_one(self, x, w, mx, rx, my, ry): 475 | return self.torch_mm8_one(x, w, mx, rx, my, ry) 476 | 477 | ######################################################################################################## 478 | 479 | @MyFunction 480 | def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): 481 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 482 | kx = xx * k_mix + sx * (1 - k_mix) 483 | rx = xx * r_mix + sx * (1 - r_mix) 484 | 485 | r = torch.sigmoid(gemm(rx, rw)) 486 | vx = torch.square(torch.relu(gemm(kx, kw))) 487 | out = r * gemm(vx, vw) 488 | return x + out, xx 489 | 490 | @MyFunction 491 | def ffn_one_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): 492 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 493 | kx = xx * k_mix + sx * (1 - k_mix) 494 | rx = xx * r_mix + sx * (1 - r_mix) 495 | 496 | r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) 497 | vx = torch.square(torch.relu(self.mm8_one(kx, kw, kmx, krx, kmy, kry))) 498 | out = r * (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)) 499 | return x + out, xx 500 | 501 | ######################################################################################################## 502 | 503 | @MyFunction 504 | def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): 505 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 506 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 507 | kx = xx * k_mix + sx * (1 - k_mix) 508 | rx = xx * r_mix + sx * (1 - r_mix) 509 | 510 | r = torch.sigmoid(gemm(rx, rw)) 511 | vx = torch.square(torch.relu(gemm(kx, kw))) 512 | out = r * gemm(vx, vw) 513 | return x + out, xx[-1,:] 514 | 515 | @MyFunction 516 | def ffn_seq_i8(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry): 517 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 518 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 519 | kx = xx * k_mix + sx * (1 - k_mix) 520 | rx = xx * r_mix + sx * (1 - r_mix) 521 | 522 | r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) 523 | vx = torch.square(torch.relu(self.mm8_seq(kx, kw, kmx, krx, kmy, kry))) 524 | out = r * (self.mm8_seq(vx, vw, vmx, vrx, vmy, vry)) 525 | return x + out, xx[-1,:] 526 | 527 | ######################################################################################################## 528 | 529 | @MyFunction 530 | def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, 
rmx, rrx, rmy, rry, omx, orx, omy, ory): 531 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 532 | kx = xx * k_mix + sx * (1 - k_mix) 533 | vx = xx * v_mix + sx * (1 - v_mix) 534 | rx = xx * r_mix + sx * (1 - r_mix) 535 | 536 | r = torch.sigmoid(gemm(rx, rw)) 537 | k = gemm(kx, kw, output_dtype=torch.float32) 538 | v = gemm(vx, vw, output_dtype=torch.float32) 539 | 540 | ww = t_first + k 541 | p = torch.maximum(pp, ww) 542 | e1 = torch.exp(pp - p) 543 | e2 = torch.exp(ww - p) 544 | wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) 545 | ww = t_decay + pp 546 | p = torch.maximum(ww, k) 547 | e1 = torch.exp(ww - p) 548 | e2 = torch.exp(k - p) 549 | 550 | out = gemm(r * wkv, ow) 551 | return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p 552 | 553 | @MyFunction 554 | def att_one_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 555 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 556 | kx = xx * k_mix + sx * (1 - k_mix) 557 | vx = xx * v_mix + sx * (1 - v_mix) 558 | rx = xx * r_mix + sx * (1 - r_mix) 559 | 560 | r = torch.sigmoid(self.mm8_one(rx, rw, rmx, rrx, rmy, rry)) 561 | k = (self.mm8_one(kx, kw, kmx, krx, kmy, kry)).float() 562 | v = (self.mm8_one(vx, vw, vmx, vrx, vmy, vry)).float() 563 | 564 | ww = t_first + k 565 | p = torch.maximum(pp, ww) 566 | e1 = torch.exp(pp - p) 567 | e2 = torch.exp(ww - p) 568 | wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) 569 | ww = t_decay + pp 570 | p = torch.maximum(ww, k) 571 | e1 = torch.exp(ww - p) 572 | e2 = torch.exp(k - p) 573 | 574 | out = self.mm8_one(r * wkv, ow, omx, orx, omy, ory) 575 | return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p 576 | 577 | ######################################################################################################## 578 | 579 | @MyFunction 580 | def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 581 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 582 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 583 | kx = xx * k_mix + sx * (1 - k_mix) 584 | vx = xx * v_mix + sx * (1 - v_mix) 585 | rx = xx * r_mix + sx * (1 - r_mix) 586 | 587 | r = torch.sigmoid(gemm(rx, rw)) 588 | k = gemm(kx, kw, output_dtype=torch.float32) 589 | v = gemm(vx, vw, output_dtype=torch.float32) 590 | 591 | T = x.shape[0] 592 | for t in range(T): 593 | kk = k[t] 594 | vv = v[t] 595 | ww = t_first + kk 596 | p = torch.maximum(pp, ww) 597 | e1 = torch.exp(pp - p) 598 | e2 = torch.exp(ww - p) 599 | sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) 600 | ww = t_decay + pp 601 | p = torch.maximum(ww, kk) 602 | e1 = torch.exp(ww - p) 603 | e2 = torch.exp(kk - p) 604 | aa = e1 * aa + e2 * vv 605 | bb = e1 * bb + e2 606 | pp = p 607 | out = gemm(r * sx, ow) 608 | return x + out, xx[-1,:], aa, bb, pp 609 | 610 | @MyFunction 611 | def att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 612 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 613 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 614 | kx = xx * k_mix + sx * (1 - k_mix) 615 | vx = xx * v_mix + sx * (1 - v_mix) 616 | rx = xx * r_mix + sx * (1 - r_mix) 617 | 618 | r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, 
rrx, rmy, rry)) 619 | k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry).float() 620 | v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry).float() 621 | 622 | T = x.shape[0] 623 | for t in range(T): 624 | kk = k[t] 625 | vv = v[t] 626 | ww = t_first + kk 627 | p = torch.maximum(pp, ww) 628 | e1 = torch.exp(pp - p) 629 | e2 = torch.exp(ww - p) 630 | sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) 631 | ww = t_decay + pp 632 | p = torch.maximum(ww, kk) 633 | e1 = torch.exp(ww - p) 634 | e2 = torch.exp(kk - p) 635 | aa = e1 * aa + e2 * vv 636 | bb = e1 * bb + e2 637 | pp = p 638 | out = self.mm8_seq(r * sx, ow, omx, orx, omy, ory) 639 | return x + out, xx[-1,:], aa, bb, pp 640 | 641 | ######################################################################################################## 642 | 643 | @MyFunction 644 | def att_one_v5(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 645 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 646 | kx = xx * k_mix + sx * (1 - k_mix) 647 | vx = xx * v_mix + sx * (1 - v_mix) 648 | rx = xx * r_mix + sx * (1 - r_mix) 649 | 650 | H = t_decay.shape[0] 651 | S = x.shape[-1] // H 652 | 653 | r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S) 654 | k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1) 655 | v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S) 656 | 657 | a = gemm(k, v) 658 | out = r @ (t_first * a + s) 659 | s = a + t_decay * s 660 | 661 | out = out.flatten() 662 | out = F.group_norm(out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b).squeeze(0) 663 | out = out.to(dtype=x.dtype) 664 | out = gemm(out, ow) 665 | 666 | return x + out, xx, s 667 | 668 | @MyFunction 669 | def att_seq_v5(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 670 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 671 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 672 | kx = xx * k_mix + sx * (1 - k_mix) 673 | vx = xx * v_mix + sx * (1 - v_mix) 674 | rx = xx * r_mix + sx * (1 - r_mix) 675 | 676 | H = t_decay.shape[0] 677 | S = x.shape[-1] // H 678 | T = x.shape[0] 679 | 680 | w = t_decay.reshape(-1, 1) 681 | u = t_first.reshape(-1, 1) 682 | ws = w.pow(T).reshape(H, 1, 1) 683 | ind = torch.arange(T-1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1) 684 | w = w.repeat(1, T).pow(ind) 685 | wk = w.reshape(H, 1, T) 686 | wb = wk.transpose(-2, -1).flip(1) 687 | w = torch.cat([w[:, 1:], u], dim=1) 688 | w = F.pad(w, (0, T)) 689 | w = torch.tile(w, [T]) 690 | w = w[:, :-T].reshape(-1, T, 2 * T - 1) 691 | w = w[:, :, T-1:].reshape(H, T, T) 692 | 693 | r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) 694 | k = gemm(kx, kw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1).transpose(-2, -1) 695 | v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) 696 | 697 | out = ((r @ k) * w) @ v + (r @ s) * wb 698 | s = ws * s + (k * wk) @ v 699 | 700 | out = out.transpose(0, 1).contiguous().reshape(T, H*S) 701 | out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) 702 | out = out.to(dtype=x.dtype) 703 | out = gemm(out, ow) 704 | 705 | return x + out, xx[-1,:], s 706 | 707 | ######################################################################################################## 708 | 709 | @MyFunction 710 | def att_one_v5_1(self, 
x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 711 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 712 | kx = xx * k_mix + sx * (1 - k_mix) 713 | vx = xx * v_mix + sx * (1 - v_mix) 714 | rx = xx * r_mix + sx * (1 - r_mix) 715 | gx = xx * g_mix + sx * (1 - g_mix) 716 | 717 | H = t_decay.shape[0] 718 | S = x.shape[-1] // H 719 | 720 | r = gemm(rx, rw, output_dtype=torch.float32).view(H, 1, S) 721 | k = gemm(kx, kw, output_dtype=torch.float32).view(H, S, 1) 722 | v = gemm(vx, vw, output_dtype=torch.float32).view(H, 1, S) 723 | g = F.silu(gemm(gx, gw)) 724 | 725 | a = gemm(k, v) 726 | out = r @ (t_first * a + s) 727 | s = a + t_decay * s 728 | 729 | out = out.flatten() 730 | out = F.group_norm(out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b).squeeze(0) 731 | out = out.to(dtype=x.dtype) * g 732 | out = gemm(out, ow) 733 | 734 | return x + out, xx, s 735 | 736 | @MyFunction 737 | def att_seq_v5_1(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 738 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 739 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 740 | kx = xx * k_mix + sx * (1 - k_mix) 741 | vx = xx * v_mix + sx * (1 - v_mix) 742 | rx = xx * r_mix + sx * (1 - r_mix) 743 | gx = xx * g_mix + sx * (1 - g_mix) 744 | 745 | H = t_decay.shape[0] 746 | S = x.shape[-1] // H 747 | T = x.shape[0] 748 | 749 | w = t_decay.reshape(-1, 1) 750 | u = t_first.reshape(-1, 1) 751 | ws = w.pow(T).reshape(H, 1, 1) 752 | ind = torch.arange(T-1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1) 753 | w = w.repeat(1, T).pow(ind) 754 | wk = w.reshape(H, 1, T) 755 | wb = wk.transpose(-2, -1).flip(1) 756 | w = torch.cat([w[:, 1:], u], dim=1) 757 | w = F.pad(w, (0, T)) 758 | w = torch.tile(w, [T]) 759 | w = w[:, :-T].reshape(-1, T, 2 * T - 1) 760 | w = w[:, :, T-1:].reshape(H, T, T) 761 | 762 | r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) 763 | k = gemm(kx, kw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1).transpose(-2, -1) 764 | v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) 765 | g = F.silu(gemm(gx, gw)) 766 | 767 | out = ((r @ k) * w) @ v + (r @ s) * wb 768 | s = ws * s + (k * wk) @ v 769 | 770 | out = out.transpose(0, 1).contiguous().reshape(T, H*S) 771 | out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) 772 | out = out.to(dtype=x.dtype) * g 773 | out = gemm(out, ow) 774 | 775 | return x + out, xx[-1,:], s 776 | 777 | ######################################################################################################## 778 | 779 | @MyFunction 780 | def att_seq_v5_2(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 781 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 782 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 783 | kx = xx * k_mix + sx * (1 - k_mix) 784 | vx = xx * v_mix + sx * (1 - v_mix) 785 | rx = xx * r_mix + sx * (1 - r_mix) 786 | gx = xx * g_mix + sx * (1 - g_mix) 787 | 788 | H = t_decay.shape[0] 789 | S = x.shape[-1] // H 790 | T = x.shape[0] 791 | 792 | r = gemm(rx, rw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) 793 | k = gemm(kx, kw, 
output_dtype=torch.float32).view(T, H, S).transpose(0, 1).transpose(-2, -1) 794 | v = gemm(vx, vw, output_dtype=torch.float32).view(T, H, S).transpose(0, 1) 795 | g = F.silu(gemm(gx, gw)) 796 | 797 | out = torch.empty((T, H, S), dtype=r.dtype, device=r.device) 798 | for t in range(T): 799 | rt = r[:,t:t+1,:] 800 | kt = k[:,:,t:t+1] 801 | vt = v[:,t:t+1,:] 802 | at = gemm(kt, vt) 803 | out[t] = (rt @ (t_first * at + s)).squeeze(1) 804 | s = at + t_decay * s 805 | 806 | out = out.reshape(T, H*S) 807 | out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) 808 | out = out.to(dtype=x.dtype) * g 809 | out = gemm(out, ow) 810 | 811 | return x + out, xx[-1,:], s 812 | 813 | ######################################################################################################## 814 | 815 | if os.environ["RWKV_CUDA_ON"] == '1': 816 | @MyFunction 817 | def cuda_att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 818 | T, C = x.shape 819 | xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) 820 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 821 | kx = xx * k_mix + sx * (1 - k_mix) 822 | vx = xx * v_mix + sx * (1 - v_mix) 823 | rx = xx * r_mix + sx * (1 - r_mix) 824 | 825 | r = torch.sigmoid(gemm(rx, rw)) 826 | k = gemm(kx, kw, output_dtype=torch.float32) 827 | v = gemm(vx, vw, output_dtype=torch.float32) 828 | y, aa, bb, pp = cuda_wkv(T, aa.shape[0], t_decay, t_first, k, v, aa, bb, pp) 829 | 830 | out = gemm(r * y.to(x.dtype), ow) 831 | return x + out, xx[-1,:], aa, bb, pp 832 | 833 | @MyFunction 834 | def cuda_att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 835 | T, C = x.shape 836 | xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) 837 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 838 | kx = xx * k_mix + sx * (1 - k_mix) 839 | vx = xx * v_mix + sx * (1 - v_mix) 840 | rx = xx * r_mix + sx * (1 - r_mix) 841 | 842 | r = torch.sigmoid(self.mm8_seq(rx, rw, rmx, rrx, rmy, rry)) 843 | k = self.mm8_seq(kx, kw, kmx, krx, kmy, kry) 844 | v = self.mm8_seq(vx, vw, vmx, vrx, vmy, vry) 845 | y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp) 846 | 847 | out = self.mm8_seq(r * y, ow, omx, orx, omy, ory) 848 | return x + out, xx[-1,:], aa, bb, pp 849 | 850 | # NOTE: decorate with @MyFunction causes JIT error 851 | def cuda_att_seq_v5_2(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, k_mix, v_mix, r_mix, g_mix, t_decay, t_first, kw, vw, rw, gw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory): 852 | xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) 853 | sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) 854 | kx = xx * k_mix + sx * (1 - k_mix) 855 | vx = xx * v_mix + sx * (1 - v_mix) 856 | rx = xx * r_mix + sx * (1 - r_mix) 857 | gx = xx * g_mix + sx * (1 - g_mix) 858 | 859 | H = t_decay.shape[0] 860 | N = x.shape[-1] // H 861 | T = x.shape[0] 862 | 863 | r = gemm(rx, rw, output_dtype=torch.float32) 864 | k = gemm(kx, kw, output_dtype=torch.float32) 865 | v = gemm(vx, vw, output_dtype=torch.float32) 866 | g = F.silu(gemm(gx, gw)) 867 | 868 | out, s = self.RUN_RWKV_5(1, T, self.args.n_att, H, s.transpose(-1,-2).contiguous(), r, k, v, w=t_decay, u=t_first) 869 | s = s.transpose(-1,-2) 870 | 871 | out = out.reshape(T, H*N) 872 | out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b) 873 | out = 
out.to(dtype=x.dtype) * g 874 | out = gemm(out, ow) 875 | 876 | return x + out, xx[-1,:], s 877 | 878 | 879 | ######################################################################################################## 880 | 881 | def forward(self, tokens, state, full_output=False): 882 | with torch.no_grad(): 883 | w = self.w 884 | args = self.args 885 | 886 | if state == None: 887 | if self.version == 4: 888 | state = [None] * args.n_layer * 5 889 | for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx 890 | dd = self.strategy[i] 891 | dev = dd.device 892 | atype = dd.atype 893 | state[i*5+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() 894 | state[i*5+1] = torch.zeros(args.n_att, dtype=torch.float, requires_grad=False, device=dev).contiguous() 895 | state[i*5+2] = torch.zeros(args.n_att, dtype=torch.float, requires_grad=False, device=dev).contiguous() 896 | state[i*5+3] = torch.zeros(args.n_att, dtype=torch.float, requires_grad=False, device=dev).contiguous() - 1e30 897 | state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() 898 | elif int(self.version) == 5: 899 | state = [None] * args.n_layer * 3 900 | for i in range(args.n_layer): # state: 0=att_xx 1=att_kv 2=ffn_xx 901 | dd = self.strategy[i] 902 | dev = dd.device 903 | atype = dd.atype 904 | state[i*3+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() 905 | state[i*3+1] = torch.zeros((args.n_head, args.n_att//args.n_head, args.n_att//args.n_head), dtype=torch.float, requires_grad=False, device=dev).contiguous() 906 | state[i*3+2] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev).contiguous() 907 | 908 | seq_mode = len(tokens) > 1 909 | 910 | x = w['emb.weight'][tokens if seq_mode else tokens[0]] 911 | 912 | for i in range(args.n_layer): 913 | bbb = f'blocks.{i}.' 914 | att = f'blocks.{i}.att.' 915 | ffn = f'blocks.{i}.ffn.' 
916 | dd = self.strategy[i] 917 | dev = dd.device 918 | atype = dd.atype 919 | wtype = dd.wtype 920 | if seq_mode: 921 | cuda_applicable = os.environ["RWKV_CUDA_ON"] == '1' and 'cuda' in str(dev) 922 | if cuda_applicable: 923 | ATT = self.cuda_att_seq if wtype != torch.uint8 else self.cuda_att_seq_i8 924 | else: 925 | ATT = self.att_seq if wtype != torch.uint8 else self.att_seq_i8 926 | if self.version == 5: 927 | ATT = self.att_seq_v5 928 | elif self.version == 5.1: 929 | ATT = self.att_seq_v5_1 930 | elif self.version == 5.2: 931 | ATT = self.att_seq_v5_2 932 | if cuda_applicable: 933 | ATT = self.cuda_att_seq_v5_2 934 | FFN = self.ffn_seq if wtype != torch.uint8 else self.ffn_seq_i8 935 | else: 936 | ATT = self.att_one if wtype != torch.uint8 else self.att_one_i8 937 | if self.version == 5: 938 | ATT = self.att_one_v5 939 | elif self.version == 5.1: 940 | ATT = self.att_one_v5_1 941 | elif self.version == 5.2: 942 | ATT = self.att_one_v5_1 # same as v5.1 943 | FFN = self.ffn_one if wtype != torch.uint8 else self.ffn_one_i8 944 | 945 | x = x.to(dtype=atype, device=dev) 946 | 947 | kw = w[f'{att}key.weight'] 948 | vw = w[f'{att}value.weight'] 949 | rw = w[f'{att}receptance.weight'] 950 | ow = w[f'{att}output.weight'] 951 | if dd.stream: 952 | kw = kw.to(device=dev, non_blocking=True) 953 | vw = vw.to(device=dev, non_blocking=True) 954 | rw = rw.to(device=dev, non_blocking=True) 955 | ow = ow.to(device=dev, non_blocking=True) 956 | kmx = w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x 957 | krx = w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x 958 | kmy = w[f'{att}key.weight_my'] if wtype == torch.uint8 else x 959 | kry = w[f'{att}key.weight_ry'] if wtype == torch.uint8 else x 960 | vmx = w[f'{att}value.weight_mx'] if wtype == torch.uint8 else x 961 | vrx = w[f'{att}value.weight_rx'] if wtype == torch.uint8 else x 962 | vmy = w[f'{att}value.weight_my'] if wtype == torch.uint8 else x 963 | vry = w[f'{att}value.weight_ry'] if wtype == torch.uint8 else x 964 | rmx = w[f'{att}receptance.weight_mx'] if wtype == torch.uint8 else x 965 | rrx = w[f'{att}receptance.weight_rx'] if wtype == torch.uint8 else x 966 | rmy = w[f'{att}receptance.weight_my'] if wtype == torch.uint8 else x 967 | rry = w[f'{att}receptance.weight_ry'] if wtype == torch.uint8 else x 968 | omx = w[f'{att}output.weight_mx'] if wtype == torch.uint8 else x 969 | orx = w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x 970 | omy = w[f'{att}output.weight_my'] if wtype == torch.uint8 else x 971 | ory = w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x 972 | if self.version == 5.1 or self.version == 5.2: 973 | gw = w[f'{att}gate.weight'] 974 | if dd.stream: 975 | gw = gw.to(device=dev, non_blocking=True) 976 | gmx = w[f'{att}gate.weight_mx'] if wtype == torch.uint8 else x 977 | grx = w[f'{att}gate.weight_rx'] if wtype == torch.uint8 else x 978 | gmy = w[f'{att}gate.weight_my'] if wtype == torch.uint8 else x 979 | gry = w[f'{att}gate.weight_ry'] if wtype == torch.uint8 else x 980 | if self.version == 4: 981 | x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( 982 | x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3], 983 | w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'], 984 | w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'], 985 | w[f'{att}time_decay'], w[f'{att}time_first'], 986 | kw, vw, rw, ow, 987 | kmx, krx, kmy, kry, 988 | vmx, vrx, vmy, vry, 989 | rmx, rrx, rmy, rry, 990 | omx, orx, omy, ory, 991 | ) 992 | elif self.version == 5: 993 | x, state[i*3+0], 
state[i*3+1] = ATT( 994 | x, state[i*3+0], state[i*3+1], 995 | w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'], 996 | w[f'{att}ln_x.weight'], w[f'{att}ln_x.bias'], 997 | w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'], 998 | w[f'{att}time_decay'], w[f'{att}time_first'], 999 | kw, vw, rw, ow, 1000 | kmx, krx, kmy, kry, 1001 | vmx, vrx, vmy, vry, 1002 | rmx, rrx, rmy, rry, 1003 | omx, orx, omy, ory, 1004 | ) 1005 | elif self.version == 5.1 or self.version == 5.2: 1006 | x, state[i*3+0], state[i*3+1] = ATT( 1007 | x, state[i*3+0], state[i*3+1], 1008 | w[f'{bbb}ln1.weight'], w[f'{bbb}ln1.bias'], 1009 | w[f'{att}ln_x.weight'], w[f'{att}ln_x.bias'], 1010 | w[f'{att}time_mix_k'], w[f'{att}time_mix_v'], w[f'{att}time_mix_r'], w[f'{att}time_mix_g'], 1011 | w[f'{att}time_decay'], w[f'{att}time_first'], 1012 | kw, vw, rw, gw, ow, 1013 | kmx, krx, kmy, kry, 1014 | vmx, vrx, vmy, vry, 1015 | rmx, rrx, rmy, rry, 1016 | omx, orx, omy, ory, 1017 | ) 1018 | if dd.stream: 1019 | del kw, vw, rw, ow 1020 | 1021 | kw = w[f'{ffn}key.weight'] 1022 | vw = w[f'{ffn}value.weight'] 1023 | rw = w[f'{ffn}receptance.weight'] 1024 | if dd.stream: 1025 | kw = kw.to(device=dev, non_blocking=True) 1026 | vw = vw.to(device=dev, non_blocking=True) 1027 | rw = rw.to(device=dev, non_blocking=True) 1028 | kmx = w[f'{ffn}key.weight_mx'] if wtype == torch.uint8 else x 1029 | krx = w[f'{ffn}key.weight_rx'] if wtype == torch.uint8 else x 1030 | kmy = w[f'{ffn}key.weight_my'] if wtype == torch.uint8 else x 1031 | kry = w[f'{ffn}key.weight_ry'] if wtype == torch.uint8 else x 1032 | vmx = w[f'{ffn}value.weight_mx'] if wtype == torch.uint8 else x 1033 | vrx = w[f'{ffn}value.weight_rx'] if wtype == torch.uint8 else x 1034 | vmy = w[f'{ffn}value.weight_my'] if wtype == torch.uint8 else x 1035 | vry = w[f'{ffn}value.weight_ry'] if wtype == torch.uint8 else x 1036 | rmx = w[f'{ffn}receptance.weight_mx'] if wtype == torch.uint8 else x 1037 | rrx = w[f'{ffn}receptance.weight_rx'] if wtype == torch.uint8 else x 1038 | rmy = w[f'{ffn}receptance.weight_my'] if wtype == torch.uint8 else x 1039 | rry = w[f'{ffn}receptance.weight_ry'] if wtype == torch.uint8 else x 1040 | if self.version == 4: 1041 | offset = i*5+4 1042 | elif int(self.version) == 5: 1043 | offset = i*3+2 1044 | x, state[offset] = FFN( 1045 | x, state[offset], 1046 | w[f'{bbb}ln2.weight'], w[f'{bbb}ln2.bias'], 1047 | w[f'{ffn}time_mix_k'], w[f'{ffn}time_mix_r'], 1048 | kw, vw, rw, 1049 | kmx, krx, kmy, kry, 1050 | vmx, vrx, vmy, vry, 1051 | rmx, rrx, rmy, rry, 1052 | ) 1053 | if dd.stream: 1054 | del kw, vw, rw 1055 | 1056 | if self.RESCALE_LAYER > 0: 1057 | if (i+1) % self.RESCALE_LAYER == 0: 1058 | x = x / 2 1059 | 1060 | dd = self.strategy[args.n_layer] 1061 | x = x[-1,:] if (seq_mode and (not full_output)) else x 1062 | x = x.to(dtype=dd.atype, device=dd.device) 1063 | 1064 | x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias']) 1065 | if w['head.weight'].dtype != torch.uint8: 1066 | x = x @ w['head.weight'] 1067 | else: 1068 | if seq_mode and full_output: 1069 | x = self.mm8_seq(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) 1070 | else: 1071 | x = self.mm8_one(x, w['head.weight'], w['head.weight_mx'], w['head.weight_rx'], w['head.weight_my'], w['head.weight_ry']) 1072 | 1073 | return x.float(), state 1074 | --------------------------------------------------------------------------------
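
A few illustrative sketches follow for the numerical pieces of `model/model_run.py` above. The int8 path dequantises a `uint8` weight on the fly with `(w + 0.5) * ry * rx + my + mx` before the matmul (`torch_mm8_one` / `mm8_one`). The sketch below shows one offset-and-scale scheme that is consistent with that expression, with a per-row offset `my`, a per-column offset `mx`, and per-column / per-row scales `rx` / `ry` stored pre-divided by 16; the exact load-time quantiser and tensor shapes are assumptions here, since the real tensors are produced when the model is loaded, outside this excerpt.

```python
import torch

def quantize_mm8(w: torch.Tensor):
    # One offset-and-scale scheme consistent with torch_mm8_* above: remove a
    # per-row offset (my) and a per-column offset (mx), then divide by per-column
    # (rx) and per-row (ry) scales -- stored pre-divided by 16, so their product
    # already carries the 1/256 factor -- and round the remainder into a uint8.
    w = w.float()
    my = w.amin(dim=1, keepdim=True)        # [N, 1]
    w = w - my
    mx = w.amin(dim=0, keepdim=True)        # [1, M]
    w = w - mx
    rx = w.amax(dim=0, keepdim=True) / 16   # [1, M]
    w = w / (rx * 16)
    ry = w.amax(dim=1, keepdim=True) / 16   # [N, 1]
    w = w / (ry * 16)
    q = torch.clip(torch.floor(w * 256), min=0, max=255).to(torch.uint8)
    return q, mx, rx, my, ry

def mm8_one_ref(x, w, mx, rx, my, ry):
    # Same dequantisation expression as torch_mm8_one / mm8_one above.
    return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)

torch.manual_seed(0)
w = torch.randn(1024, 1024)
x = torch.randn(1024)
q, mx, rx, my, ry = quantize_mm8(w)
err = (mm8_one_ref(x, q, mx, rx, my, ry) - x @ w).abs().max().item()
print(f'max abs error: {err:.4f}')  # quantisation error, small relative to x @ w
```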
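`att_one` and `att_one_i8` advance the v4 time-mixing state one token at a time. Per channel they keep a numerator `aa`, a denominator `bb`, and a running exponent offset `pp`, and shift every exponent by the current maximum so nothing overflows in fp32. Below is a stand-alone restatement of that single step; `att_seq` (and the `cuda_wkv` kernel built from `model/cuda/operators.cu`) apply the same step across all T tokens of a sequence. The values at the bottom are only for illustration.

```python
import torch

def wkv_step(k, v, aa, bb, pp, t_decay, t_first):
    # One token of the v4 WKV recurrence, restating att_one above. (aa, bb) are
    # the numerator and denominator of the running weighted average of past v,
    # stored relative to the per-channel exponent offset pp so every exp()
    # argument stays <= 0.
    ww = t_first + k                            # bonus for the current token
    p = torch.maximum(pp, ww)
    e1 = torch.exp(pp - p)
    e2 = torch.exp(ww - p)
    wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)   # output for this token
    ww = t_decay + pp                           # decay the stored exponent
    p = torch.maximum(ww, k)
    e1 = torch.exp(ww - p)
    e2 = torch.exp(k - p)
    return wkv, e1 * aa + e2 * v, e1 * bb + e2, p

# Toy usage (illustrative values only; t_decay is negative so old tokens fade).
C = 4
aa, bb = torch.zeros(C), torch.zeros(C)
pp = torch.full((C,), -1e30)   # same initialisation forward() uses for this state
t_decay, t_first = -torch.rand(C) - 0.1, torch.rand(C)
for _ in range(8):
    y, aa, bb, pp = wkv_step(torch.randn(C), torch.randn(C), aa, bb, pp, t_decay, t_first)
print(y)
```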
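For version-5 models the per-layer attention state is no longer `(aa, bb, pp)` but one `[H, S, S]` matrix per head (see the state initialisation in `forward`). `att_one_v5` folds each token into that state as an outer product `k vᵀ` and reads it out with `r`; `att_seq_v5` / `att_seq_v5_1` build T×T decay matrices to process a whole sequence in parallel, while `att_seq_v5_2` falls back to an explicit loop over tokens. The sketch below is a sequential reference of that loop; the broadcast shape of `t_decay` / `t_first` is an assumption, since it is fixed at load time outside this excerpt.

```python
import torch

def att_v5_ref(r, k, v, t_decay, t_first, s):
    # Sequential reference for the v5 head-wise linear attention, following the
    # per-token loop in att_seq_v5_2 above: each token contributes an outer
    # product k v^T per head, the state decays by t_decay, and t_first gives the
    # current token an extra weight before it enters the state.
    # r, k, v: [T, H, S]; s: [H, S, S]; t_decay / t_first broadcast against s.
    T, H, S = r.shape
    out = torch.empty(T, H, S)
    for t in range(T):
        rt = r[t].unsqueeze(1)              # [H, 1, S]
        kt = k[t].unsqueeze(2)              # [H, S, 1]
        vt = v[t].unsqueeze(1)              # [H, 1, S]
        at = kt @ vt                        # [H, S, S]
        out[t] = (rt @ (t_first * at + s)).squeeze(1)
        s = at + t_decay * s
    return out, s

# Toy shapes; the [H, 1, 1] per-head shape for t_decay / t_first is an assumption.
T, H, S = 5, 2, 4
r, k, v = torch.randn(T, H, S), torch.randn(T, H, S), torch.randn(T, H, S)
out, s = att_v5_ref(r, k, v, torch.rand(H, 1, 1), torch.rand(H, 1, 1), torch.zeros(H, S, S))
print(out.shape, s.shape)   # torch.Size([5, 2, 4]) torch.Size([2, 4, 4])
```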
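All v5 variants normalise the attention output head-by-head with `F.group_norm(..., num_groups=H)` (the `ln_x` weights), and the v5.1 / v5.2 paths additionally multiply by the SiLU gate `F.silu(gemm(gx, gw))` before the output projection. The snippet below only confirms the group-norm reading: with `num_groups=H` over a `(T, H*S)` tensor, it is an independent normalisation of each head's S channels.

```python
import torch
import torch.nn.functional as F

# group_norm with num_groups=H over (T, H*S) normalises each head's S channels
# independently, then applies the per-channel ln_x weight and bias -- which is
# how the v5 kernels above use it.
T, H, S = 3, 4, 8
x = torch.randn(T, H * S)
weight, bias = torch.randn(H * S), torch.randn(H * S)

gn = F.group_norm(x, num_groups=H, weight=weight, bias=bias)

xh = x.view(T, H, S)
mean = xh.mean(dim=-1, keepdim=True)
var = xh.var(dim=-1, unbiased=False, keepdim=True)
manual = ((xh - mean) / torch.sqrt(var + 1e-5)).view(T, H * S) * weight + bias

print(torch.allclose(gn, manual, atol=1e-5))   # True
```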
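Finally, `forward` takes a list of token ids plus the recurrent state and returns logits and the updated state: more than one token selects the `*_seq` kernels, a single token the `*_one` kernels, and `state=None` allocates fresh per-layer state (5 tensors per layer for v4, 3 for v5). A caller therefore feeds the whole prompt once and then generates token by token while threading the state through. The loop below is only an illustrative sketch: `tokenizer` and `sample_greedy` are placeholders for whatever encoding and sampling `chat.py` actually uses.

```python
import torch

def sample_greedy(logits: torch.Tensor) -> int:
    # Stand-in for whatever sampling the chat pipeline applies to the logits.
    return int(torch.argmax(logits).item())

def generate(model, tokenizer, prompt: str, max_tokens: int = 64) -> str:
    # `model` is an instance of model.model_run.RWKV; `tokenizer` is assumed to
    # expose encode/decode. The first forward() runs in sequence mode over the
    # prompt; each subsequent call feeds one sampled token and the carried state.
    state = None
    logits, state = model.forward(tokenizer.encode(prompt), state)
    out_tokens = []
    for _ in range(max_tokens):
        token = sample_greedy(logits)
        out_tokens.append(token)
        logits, state = model.forward([token], state)
    return tokenizer.decode(out_tokens)
```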