├── requirements.txt
├── ChatRWKV.png
├── RWKV-eval.png
├── src
│   ├── utils.py
│   └── model_run.py
├── benchmark.py
├── chat_onnx.py
├── README.md
├── LICENSE
├── test.py
├── chat.py
└── chat_interpretability.py

/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | tokenizers>=0.13.2

--------------------------------------------------------------------------------
/ChatRWKV.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UnstoppableCurry/RWKV-LM-Interpretability-Research/HEAD/ChatRWKV.png

--------------------------------------------------------------------------------
/RWKV-eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UnstoppableCurry/RWKV-LM-Interpretability-Research/HEAD/RWKV-eval.png

--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | ########################################################################################################
2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3 | ########################################################################################################
4 | 
5 | import json, time, random, os
6 | import numpy as np
7 | import torch
8 | from torch.nn import functional as F
9 | from tokenizers import Tokenizer
10 | 
11 | time_slot = {}
12 | time_ref = time.time_ns()
13 | 
14 | def record_time(name):
15 |     if name not in time_slot:
16 |         time_slot[name] = 1e20
17 |     tt = (time.time_ns() - time_ref) / 1e9  # nanoseconds -> seconds
18 |     if tt < time_slot[name]:
19 |         time_slot[name] = tt
20 | 
21 | class TOKENIZER():
22 |     def __init__(self, WORD_NAME):
23 |         self.tokenizer = Tokenizer.from_file(WORD_NAME)
24 | 
25 |     def refine_context(self, context):
26 |         context = context.strip().split('\n')
27 |         for c in range(len(context)):
28 |             context[c] = context[c].strip().strip('\u3000').strip('\r')
29 |         context = list(filter(lambda c: c != '', context))
30 |         context = '\n' + ('\n'.join(context)).strip()
31 |         if context == '':
32 |             context = '\n'
33 |         return context
34 | 
35 |     def encode(self, x):
36 |         return self.tokenizer.encode(x).ids
37 | 
38 |     def decode(self, x):
39 |         return self.tokenizer.decode(x)
40 | 
41 |     def sample_logits(self, logits, x, ctx_len, temperature=1.0, top_p=1.0):
42 |         probs = F.softmax(logits.float(), dim=-1)
43 | 
44 |         if os.environ["RWKV_RUN_DEVICE"] == "cpu":
45 |             probs = probs.numpy()
46 |             sorted_probs = np.sort(probs)[::-1]
47 |             cumulative_probs = np.cumsum(sorted_probs)
48 |             cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
49 |             probs[probs < cutoff] = 0
50 |             if temperature != 1.0:
51 |                 probs = np.power(probs, 1.0 / temperature)  # numpy arrays have no .pow method
52 |             probs = probs / np.sum(probs)
53 |             out = np.random.choice(a=len(probs), p=probs)
54 |             return int(out)
55 |         else:
56 |             sorted_probs = torch.sort(probs, descending=True)[0]
57 |             cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
58 |             cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
59 |             probs[probs < cutoff] = 0
60 |             if temperature != 1.0:
61 |                 probs = probs.pow(1.0 / temperature)
62 |             out = torch.multinomial(probs, num_samples=1)[0]
63 |             return int(out)
64 | 

--------------------------------------------------------------------------------
/benchmark.py:
--------------------------------------------------------------------------------
1 | ########################################################################################################
2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3 | ########################################################################################################
4 | 
5 | import os, sys, types, json, math
6 | try:
7 |     os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
8 | except:
9 |     pass
10 | import numpy as np
11 | np.set_printoptions(precision=4, suppress=True, linewidth=200)
12 | with open(f"misc/lambada_test.jsonl", "r", encoding="utf-8") as f:
13 |     todo = [json.loads(line) for line in f]
14 |     todo = [[doc['text'].rsplit(' ', 1)[0], " " + doc['text'].rsplit(' ', 1)[1]] for doc in todo]
15 | args = types.SimpleNamespace()
16 | 
17 | ########################################################################################################
18 | 
19 | args.RUN_DEVICE = "cuda" # cuda / cpu
20 | args.FLOAT_MODE = "fp16" # fp16 / fp32 / bf16
21 | os.environ["RWKV_JIT_ON"] = '1' # 0 / 1
22 | args.ctx_len = 1024
23 | 
24 | # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230128-6782'
25 | # args.n_layer = 40
26 | # args.n_embd = 5120
27 | 
28 | args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
29 | args.n_layer = 24
30 | args.n_embd = 2048
31 | 
32 | # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023'
33 | # args.n_layer = 12
34 | # args.n_embd = 768
35 | 
36 | PAD_SEQ = [187]
37 | 
38 | ########################################################################################################
39 | 
40 | print(f'\nLoading ChatRWKV - {args.RUN_DEVICE} - {args.FLOAT_MODE}')
41 | import torch
42 | torch.backends.cudnn.benchmark = True
43 | torch.backends.cudnn.allow_tf32 = True
44 | torch.backends.cuda.matmul.allow_tf32 = True
45 | from torch.nn import functional as F
46 | from src.model_run import RWKV_RNN
47 | from src.utils import TOKENIZER
48 | tokenizer = TOKENIZER("20B_tokenizer.json")
49 | 
50 | args.vocab_size = 50277
51 | args.head_qk = 0
52 | args.pre_ffn = 0
53 | args.grad_cp = 0
54 | args.my_pos_emb = 0
55 | os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
56 | 
57 | MODEL_NAME = args.MODEL_NAME
58 | print(f'Loading model - {MODEL_NAME}')
59 | model = RWKV_RNN(args)
60 | 
61 | print('Running...')
62 | xsum = 0
63 | xcnt = 0
64 | xacc = 0
65 | for d in todo:
66 |     src = PAD_SEQ + tokenizer.encode(d[0])
67 |     dst = tokenizer.encode(d[1])
68 | 
69 |     logits = 0
70 |     correct = True
71 |     for i in range(len(dst)):
72 |         if i == 0:
73 |             out, _, model_state = model.forward(src, None)  # this repo's forward() also returns the per-layer FFN outputs
74 |         else:
75 |             out, _, model_state = model.forward([dst[i-1]], model_state)
76 |         probs = F.softmax(out.float(), dim=-1)
77 |         logits += math.log(probs[dst[i]])
78 |         _, s_index = torch.sort(probs, descending=True)
79 |         pred = s_index[0].item()
80 |         if pred != dst[i]:
81 |             correct = False
82 | 
83 |     xcnt += 1
84 |     xsum += logits
85 |     xacc += 1 if correct else 0
86 |     if xcnt % 100 == 0 or xcnt == len(todo):
87 |         print(xcnt, 'ppl', round(math.exp(-xsum / xcnt), 2), 'acc', round(xacc/xcnt*100, 2))#, 'pred', pred, 'dst', dst)

--------------------------------------------------------------------------------
/chat_onnx.py:
--------------------------------------------------------------------------------
1 | ########################################################################################################
2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3 | ########################################################################################################
4 | 
5 | import os, copy, types, gc, sys
6 | import numpy as np
7 | 
8 | try:
9 |     os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
10 | except:
11 |     pass
12 | np.set_printoptions(precision=4, suppress=True, linewidth=200)
13 | args = types.SimpleNamespace()
14 | 
15 | ########################################################################################################
16 | 
17 | args.RUN_DEVICE = "cuda"  # cuda // cpu
18 | # fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU)
19 | args.FLOAT_MODE = "fp16"
20 | 
21 | os.environ[
22 |     "RWKV_JIT_ON"] = '0'  # '1' or '0'. very useful for fp32, but might be harmful for GPU fp16. please benchmark !!!
23 | 
24 | CHAT_LANG = 'Chinese'  # English // Chinese // more to come
25 | 
26 | QA_PROMPT = False  # True: Q & A prompt // False: User & Bot prompt
27 | # For Chinese Q&A, set QA_PROMPT=True (Q&A only, but better answers). For Chinese chitchat, set QA_PROMPT=False (chitchat works, but needs a large model to work well).
28 | 
29 | # Download RWKV-4 models from https://huggingface.co/BlinkDL
30 | 
31 | if CHAT_LANG == 'English':
32 |     # args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-14B-20230204-7324'
33 |     args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-7B-20220911-79'
34 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230204-7324'
35 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
36 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096'
37 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-Instruct-test1-20230124'
38 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
39 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
40 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023'
41 |     # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-340'
42 |     # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/14b-run1/rwkv-6210'
43 | 
44 | elif CHAT_LANG == 'Chinese':
45 |     args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-1B5-EngChn-test4-20230115'
46 |     # args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-7B-EngChn-test4-20230116'
47 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-EngChn-test4-20230115'
48 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-test4-20230115'
49 |     # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-490'
50 |     # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/1.5-run1z/rwkv-415'
51 | 
52 | if '-169M-' in args.MODEL_NAME:
53 |     args.n_layer = 12
54 |     args.n_embd = 768
55 | if '-430M-' in args.MODEL_NAME:
56 |     args.n_layer = 24
57 |     args.n_embd = 1024
58 | if '-1B5-' in args.MODEL_NAME or '/1.5-' in args.MODEL_NAME:
59 |     args.n_layer = 24
60 |     args.n_embd = 2048
61 | elif '-3B-' in args.MODEL_NAME or '/3-' in args.MODEL_NAME:
62 |     args.n_layer = 32
63 |     args.n_embd = 2560
64 | elif '-7B-' in args.MODEL_NAME or '/7-' in args.MODEL_NAME:
65 |     args.n_layer = 32
66 |     args.n_embd = 4096
67 | elif '-14B-' in args.MODEL_NAME or '/14-' in args.MODEL_NAME or '/14b-' in args.MODEL_NAME:
68 |     args.n_layer = 40
69 |     args.n_embd = 5120
70 | 
71 | args.ctx_len = 1024
72 | 
73 | CHAT_LEN_SHORT = 40
74 | CHAT_LEN_LONG = 150
75 | FREE_GEN_LEN = 200
76 | 
77 | GEN_TEMP = 1.0
78 | GEN_TOP_P = 0.85
79 | 
80 | AVOID_REPEAT = ',。:?!'
81 | 
82 | ########################################################################################################
83 | 
84 | print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}')
85 | import torch
86 | 
87 | torch.backends.cudnn.benchmark = True
88 | torch.backends.cudnn.allow_tf32 = True
89 | torch.backends.cuda.matmul.allow_tf32 = True
90 | from src.model_run import RWKV_RNN
91 | from src.utils import TOKENIZER
92 | 
93 | tokenizer = TOKENIZER("20B_tokenizer.json")
94 | 
95 | args.vocab_size = 50277
96 | args.head_qk = 0
97 | args.pre_ffn = 0
98 | args.grad_cp = 0
99 | args.my_pos_emb = 0
100 | os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
101 | MODEL_NAME = args.MODEL_NAME
102 | 
103 | # Load Model
104 | 
105 | print(f'Loading model - {MODEL_NAME}')
106 | model = RWKV_RNN(args)
107 | 
108 | out_onnx = 'RWKV-4-Pile-7B-20220911-79.onnx'
109 | x = torch.tensor([float(c) for c in range(1000)]).type(torch.long).to('cuda')
110 | # x1 = torch.randn((160, 1096)).type(torch.long).to('cuda')
111 | x1 = None
112 | # x3 = [0]
113 | 
114 | # input_names = [ 'tokens', 'state', 'preprocess_only' ]
115 | input_names = ['tokens', 'state']
116 | output_names = ["output", 'model_state']
117 | # torch_out = torch.onnx.export(model, (x, x1), out_onnx, input_names=input_names,
118 | #                               output_names=output_names)
119 | 
120 | # traced_script_module = torch.jit.trace(model, x)
121 | traced_script_module = torch.jit.script(model)  # torch.jit.script compiles from source; example inputs are only for torch.jit.trace
122 | 
123 | # traced_script_module.save('./')  # note: save() expects a file path, not a directory

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Project code based on ChatRWKV
2 | 
3 | 
4 | ChatRWKV is like ChatGPT but powered by my RWKV (100% RNN) language model, which is the only RNN (as of now) that can match transformers in quality and scaling, while being faster and saving VRAM. Training sponsored by Stability and EleutherAI :)
5 | 
6 | **Download RWKV-4 weights:** https://huggingface.co/BlinkDL (**Use RWKV-4 models**. DO NOT use RWKV-4a and RWKV-4b models.)
7 | 
8 | **RWKV LM:** https://github.com/BlinkDL/RWKV-LM (explanation, fine-tuning, training, etc.)
9 | 
10 | **RWKV Discord:** https://discord.gg/bDSBUMeFpc (let's build together)
11 | 
12 | ![RWKV-eval](RWKV-eval.png)
13 | 
14 | 
15 | 
16 | # How to use the visualization
17 | 
18 | Run chat.py to chat with the model; the visualization video is saved after every reply.
19 | Line 48 of chat.py configures the output video path, frame size and frame rate: videoWrite = cv2.VideoWriter(r'test.mp4', fourcc, 8, (1000, 1000))
20 | Start it with `python chat.py`.
21 | 
22 | 
23 | # FFN-layer visualization of the 7B model
24 | 
25 | ![7B FFN visualization](https://img-blog.csdnimg.cn/aa4eb63d3ce640079f882f6f7640d8c1.gif)
26 | 
27 | [A detailed introduction to the RWKV language model: an RNN with the strengths of a Transformer](https://blog.csdn.net/weixin_49139876/article/details/129869814?csdn_share_tail=%7B%22type%22:%22blog%22,%22rType%22:%22article%22,%22rId%22:%22129869814%22,%22source%22:%22weixin_49139876%22%7D)
28 | # Preface
29 | First, for newcomers: what is RWKV LM? It is an open-source LLM project with strong performance and ecosystem support. For example, the Python code has been rewritten in CUDA for higher efficiency, there are companion web projects and ChatRWKV, it runs on mobile devices, it is genuinely cross-platform, and it is resource-efficient.
30 | It is currently the only RNN-style model that matches the performance of comparable Transformer models.
31 | 
32 | ## The "EEG" discussion in the dev group
33 | Some people in the research group wanted to see a visualization of the RWKV FFN-layer responses, so the author posted a demo:
34 | ![FFN response demo](https://img-blog.csdnimg.cn/b600bc21be514c65811e0319522aa3e2.gif)
35 | A few values are obviously far larger than the rest. Some people call these plots "EEGs", which is roughly right, except that they are not a continuous signal. Veterans will be familiar with this, but it was my first time seeing it and I was very curious. Such extreme values are collectively called outliers. Afterwards another member visualized every layer of the 430M model:
36 | ![430M per-layer visualization](https://img-blog.csdnimg.cn/f070531343cf443b84995ce00961e0b1.gif)
37 | This is even more striking: the deeper the layer, the larger the outliers become. So the natural question is: if such outliers matter this much to the language model, and the model also exhibits emergence and grokking, what is the connection between them?
38 | 
39 | 1. First, let's read the source code.
40 | 
41 | ```python
42 |     def forward(self, tokens, state=None, preprocess_only = False):
43 |         # tokens=tokens.to('cpu').numpy().tolist()
44 |         # if tokens is not None:
45 |         #     print('input_tokens', len(tokens))
46 |         # if state is not None:
47 |         #     print(' input state',state.shape)
48 |         with torch.no_grad():
49 |             w = self.w
50 |             args = self.args
51 | 
52 |             seq_mode = len(tokens) > 1
53 | 
54 |             x = w.emb.weight[tokens] if seq_mode else w.emb.weight[tokens[-1]]
55 |             if self.RUN_DEVICE == 'cuda':
56 |                 x = x.cuda()
57 | 
58 |             if state == None:
59 |                 state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE)
60 |                 for i in range(args.n_layer):
61 |                     state[5*i+4] -= 1e30
62 | 
63 |             SA = self.SA_seq if seq_mode else self.SA_one
64 |             FF = self.FF_seq if seq_mode else self.FF_one
65 |             all_ffn_out=[]
66 |             for i in range(args.n_layer):
67 |                 ww = w.blocks[i].att
68 |                 x = x + SA(self.LN(x, w.blocks[i].ln1), state, i,
69 |                     ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay,
70 |                     ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight)
71 | 
72 |                 ww = w.blocks[i].ffn
73 |                 ffn_out=FF(self.LN(x, w.blocks[i].ln2), state, i,
74 |                     ww.time_mix_k, ww.time_mix_r,
75 |                     ww.key.weight, ww.value.weight, ww.receptance.weight)
76 |                 x = x + ffn_out
77 |                 all_ffn_out.append(ffn_out)
78 |                 # print('ffn->',x)
79 |                 if (i+1) % RWKV_RESCALE_LAYER == 0:
80 |                     x = x / 2
81 | 
82 |             if preprocess_only:
83 |                 return state
84 | 
85 |             x = self.LN(x[-1,:], w.ln_out) if seq_mode else self.LN(x, w.ln_out)
86 |             x = w.head.weight @ x
87 |             # print('output',x.shape,' out put state',state.shape)
88 |             return x.float(),all_ffn_out, state
89 | ```
90 | ![forward pass screenshot](https://img-blog.csdnimg.cn/3c29bcf3c97a48c5a5f1b5973659b4bb.png)
91 | ![FFN block screenshot](https://img-blog.csdnimg.cn/4304a1a763274dfd99f30db062823593.png)
92 | Note that the FFN output gets normalized by LayerNorm, yet even so the outliers persist, which is very strange.
93 | 2. Visualizing models of different sizes
94 | To explore how model size affects the outliers, I made the following visualizations.
95 | ![1B5](https://user-images.githubusercontent.com/65523997/233818846-a5d7cd39-6ba9-421a-873e-58d341adfd0c.gif)
96 | 1b5 fp32
97 | 
98 | ![1b5fp16](https://user-images.githubusercontent.com/65523997/233818854-7f62eb14-71e4-47ff-a683-25c709b6c6d4.gif)
99 | 1b5 fp16
100 | 
101 | ![7b fp16](https://user-images.githubusercontent.com/65523997/233818859-a0a04c46-528f-4453-8981-a6c4f79c4fad.gif)
102 | 7b fp16
103 | ## Three conclusions can be drawn from the visualizations alone
104 | - Conclusion 1: the larger the model, the less severe the outliers.
105 | - Conclusion 2: quantization, i.e. a change of numerical precision, affects the outliers.
106 | - Conclusion 3: the outlier behavior directly affects model performance (not rigorous).
107 | 
108 | ## How the conclusions were reached: variables controlled during visualization
109 | - Every input is identical: ` What is python? Please tell me in detail what it is, how to use it, how to learn it, and what advantages and disadvantages it has.`
110 | - The model outputs themselves are not controlled.
111 | - Conclusion 1: comparing the 7b fp16 and 1b5 fp16 visualizations, it is immediately visible that the 7b outliers are smaller.
112 | - Conclusion 2: comparing 1b5 fp16 with 1b5 fp32 on the same input, repeated runs consistently show that the fp32 outputs are longer and better; my guess is that aligning activations during quantization would keep performance consistent.
113 | - Conclusion 3: for the 1b5 model (7b fp32 does not fit on a single 3090; results to be added later), different precisions lead to different outlier behavior, but whether this affects model performance must be demonstrated in the experiments below.
114 | 
115 | ## Towards quantifiable evaluation experiments
116 | #### Plan
117 | - 1. Use KL divergence to evaluate how model size, precision and different inputs affect the outlier phenomenon under otherwise identical conditions (see the sketches after this list).
118 | - 2. Manipulate the outliers (remove them, normalize them, etc.) to test what the outliers of different layers and model sizes actually do.
119 | - 3. Use the mean squared deviation to measure the dispersion of the responses (i.e. the severity of the outliers) across model sizes, precisions and FFN layers (also sketched below).
120 | - 4. Apply the same procedure to currently open-source LLMs to check what the outliers are really doing.
121 | - 5. There is no good way yet to test the relation between emergence/grokking and the outliers; this should be done by tracking outlier statistics during training and using multi-task evaluation metrics to judge at what outlier level grokking/emergence appears.
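
To make item 3 concrete before the full experiments, here is a minimal sketch (not yet part of this repo) of per-layer outlier statistics computed from the `all_ffn_out` list that the modified `forward()` above returns. It assumes `model` and `tokenizer` have been loaded exactly as in chat.py; the prompt and the max/std "outlier ratio" are illustrative choices, not a fixed metric.

```python
import numpy as np

def ffn_outlier_stats(all_ffn_out):
    # one (max|activation|, std) pair per layer; a large ratio flags outliers
    stats = []
    for layer_out in all_ffn_out:
        a = layer_out.float().to('cpu').numpy().ravel()
        stats.append((float(np.abs(a).max()), float(a.std())))
    return stats

tokens = tokenizer.encode(' What is python?')
out, all_ffn_out, state = model.forward(tokens, None)
for i, (mx, sd) in enumerate(ffn_outlier_stats(all_ffn_out)):
    print(f'layer {i:2d}  max|x|={mx:10.2f}  std={sd:8.2f}  ratio={mx / (sd + 1e-8):8.1f}')
```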
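And a sketch for item 1: comparing the per-layer FFN activation distributions of two runs (e.g. the same prompt through the fp32 and fp16 builds of one model) with KL divergence. Histogramming the activations over shared bins is a simplifying assumption; `ffn_a` and `ffn_b` stand for the `all_ffn_out` lists collected from the two runs.

```python
import numpy as np
from scipy.stats import entropy  # entropy(p, q) computes KL(p || q)

def layer_kl(ffn_a, ffn_b, bins=256):
    kls = []
    for a, b in zip(ffn_a, ffn_b):
        a = a.float().to('cpu').numpy().ravel()
        b = b.float().to('cpu').numpy().ravel()
        lo = min(a.min(), b.min())
        hi = max(a.max(), b.max())
        pa, _ = np.histogram(a, bins=bins, range=(lo, hi))
        pb, _ = np.histogram(b, bins=bins, range=(lo, hi))
        # a small epsilon keeps the KL finite where a bin is empty
        kls.append(float(entropy(pa + 1e-9, pb + 1e-9)))
    return kls
```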
122 | 
123 | 
124 | 
--------------------------------------------------------------------------------
/src/model_run.py:
--------------------------------------------------------------------------------
1 | ########################################################################################################
2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3 | ########################################################################################################
4 | 
5 | import types, math, os, gc
6 | import torch
7 | from torch.nn import functional as F
8 | torch.backends.cudnn.benchmark = True
9 | torch.backends.cudnn.allow_tf32 = True
10 | torch.backends.cuda.matmul.allow_tf32 = True
11 | 
12 | MyModule = torch.nn.Module
13 | def __nop(ob):
14 |     return ob
15 | MyFunction = __nop
16 | 
17 | if os.environ["RWKV_JIT_ON"] == "1":  # JIT is enabled when RWKV_JIT_ON == "1", matching the comments in chat.py (the condition was inverted)
18 |     MyModule = torch.jit.ScriptModule
19 |     MyFunction = torch.jit.script_method
20 | 
21 | print(f'\nRWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n')
22 | 
23 | RWKV_RESCALE_LAYER = 6 # set x = x/2 every X layer (to avoid FP16 overflow)
24 | 
25 | ############################################################################################################
26 | 
27 | class RWKV_RNN(MyModule):
28 |     def __init__(self, args):
29 |         super().__init__()
30 | 
31 |         self.args = args
32 |         if args.FLOAT_MODE == 'fp32':
33 |             self.FLOAT_MODE = torch.float
34 |         elif args.FLOAT_MODE == 'fp16':
35 |             self.FLOAT_MODE = torch.half
36 |         elif args.FLOAT_MODE == 'bf16':
37 |             self.FLOAT_MODE = torch.bfloat16
38 |         self.RUN_DEVICE = args.RUN_DEVICE
39 | 
40 |         with torch.no_grad():
41 |             w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
42 |             keys = list(w.keys()) # refine weights and send to correct device
43 |             print_need_newline = False
44 |             for x in keys:
45 |                 w[x].requires_grad = False
46 |                 if x == 'emb.weight' or 'ln0' in x:
47 |                     continue
48 | 
49 |                 block_id = int(x.split('.')[1]) if ('blocks.'
in x) else 0 50 | 51 | if '.time_' in x: 52 | w[x] = w[x].squeeze() 53 | if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x: 54 | w[x] = w[x].t() 55 | 56 | if '.time_decay' in x: 57 | w[x] = w[x].float() 58 | w[x] = -torch.exp(w[x]) 59 | elif '.time_first' in x: 60 | w[x] = w[x].float() 61 | else: 62 | w[x] = w[x].to(dtype=self.FLOAT_MODE) 63 | 64 | if 'att.output.weight' in x: 65 | w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) 66 | if 'ffn.value.weight' in x: 67 | w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) 68 | 69 | if args.RUN_DEVICE == 'cuda': 70 | w[x] = w[x].cuda() 71 | 72 | if block_id == 0: 73 | if print_need_newline: 74 | print('\n', end = '') 75 | print_need_newline = False 76 | print(x.ljust(35), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device) 77 | else: 78 | print_need_newline = True 79 | print('.', end = '', flush = True) 80 | 81 | keys = list(w.keys()) # store weights in self.w 82 | self.w = types.SimpleNamespace() 83 | for x in keys: 84 | xx = x.split('.') 85 | here = self.w 86 | for i in range(len(xx)): 87 | if xx[i].isdigit(): 88 | ii = int(xx[i]) 89 | if ii not in here: 90 | here[ii] = types.SimpleNamespace() 91 | here = here[ii] 92 | else: 93 | if i == len(xx) - 1: 94 | setattr(here, xx[i], w[x]) 95 | elif not hasattr(here, xx[i]): 96 | if xx[i+1].isdigit(): 97 | setattr(here, xx[i], {}) 98 | else: 99 | setattr(here, xx[i], types.SimpleNamespace()) 100 | here = getattr(here, xx[i]) 101 | 102 | with torch.no_grad(): # precompute embedding 103 | try: 104 | x = self.LN(self.w.emb.weight, self.w.blocks[0].ln0) 105 | except: 106 | x = F.layer_norm(self.w.emb.weight.float(), (self.args.n_embd,), weight=self.w.blocks[0].ln0.weight.float(), bias=self.w.blocks[0].ln0.bias.float()) 107 | self.w.emb.weight = x.to(dtype=self.FLOAT_MODE) 108 | 109 | self.eval() 110 | gc.collect() 111 | torch.cuda.empty_cache() 112 | 113 | def LN(self, x, w): 114 | return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias) 115 | 116 | # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp 117 | 118 | @MyFunction 119 | def FF_one(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): 120 | xx = state[5*i+0].to(dtype=self.FLOAT_MODE) 121 | xk = x * time_mix_k + xx * (1 - time_mix_k) 122 | xr = x * time_mix_r + xx * (1 - time_mix_r) 123 | state[5*i+0] = x.float() 124 | 125 | r = torch.sigmoid(xr @ rw) 126 | k = torch.square(torch.relu(xk @ kw)) 127 | kv = k @ vw 128 | return r * kv 129 | 130 | @MyFunction 131 | def FF_seq(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): 132 | xx = torch.cat((state[5*i+0].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) 133 | xk = x * time_mix_k + xx * (1 - time_mix_k) 134 | xr = x * time_mix_r + xx * (1 - time_mix_r) 135 | state[5*i+0] = x[-1,:].float() 136 | 137 | r = torch.sigmoid(xr @ rw) 138 | k = torch.square(torch.relu(xk @ kw)) 139 | kv = k @ vw 140 | return r * kv 141 | 142 | @MyFunction 143 | def SA_one(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): 144 | xx = state[5*i+1].to(dtype=self.FLOAT_MODE) 145 | xk = x * time_mix_k + xx * (1 - time_mix_k) 146 | xv = x * time_mix_v + xx * (1 - time_mix_v) 147 | xr = x * time_mix_r + xx * (1 - time_mix_r) 148 | state[5*i+1] = x.float() 149 | 150 | r = torch.sigmoid(xr @ rw) 151 | k = (xk @ kw).float() 152 | v = (xv @ vw).float() 153 | 154 | aa = state[5*i+2] 155 | bb = state[5*i+3] 156 | pp = state[5*i+4] 157 | ww = time_first + k 158 | p = 
torch.maximum(pp, ww) 159 | e1 = torch.exp(pp - p) 160 | e2 = torch.exp(ww - p) 161 | a = e1 * aa + e2 * v 162 | b = e1 * bb + e2 163 | ww = pp + time_decay 164 | p = torch.maximum(ww, k) 165 | e1 = torch.exp(ww - p) 166 | e2 = torch.exp(k - p) 167 | state[5*i+2] = e1 * aa + e2 * v 168 | state[5*i+3] = e1 * bb + e2 169 | state[5*i+4] = p 170 | wkv = (a / b).to(dtype=self.FLOAT_MODE) 171 | return (r * wkv) @ ow 172 | 173 | @MyFunction 174 | def SA_seq(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): 175 | xx = torch.cat((state[5*i+1].to(dtype=self.FLOAT_MODE).unsqueeze(0), x[:-1,:])) 176 | xk = x * time_mix_k + xx * (1 - time_mix_k) 177 | xv = x * time_mix_v + xx * (1 - time_mix_v) 178 | xr = x * time_mix_r + xx * (1 - time_mix_r) 179 | state[5*i+1] = x[-1,:].float() 180 | 181 | r = torch.sigmoid(xr @ rw) 182 | k = (xk @ kw).float() 183 | v = (xv @ vw).float() 184 | 185 | aa = state[5*i+2] 186 | bb = state[5*i+3] 187 | pp = state[5*i+4] 188 | T = x.shape[0] 189 | for t in range(T): 190 | ww = time_first + k[t] 191 | p = torch.maximum(pp, ww) 192 | e1 = torch.exp(pp - p) 193 | e2 = torch.exp(ww - p) 194 | a = e1 * aa + e2 * v[t] 195 | b = e1 * bb + e2 196 | ww = pp + time_decay 197 | p = torch.maximum(ww, k[t]) 198 | e1 = torch.exp(ww - p) 199 | e2 = torch.exp(k[t] - p) 200 | if t != T - 1: 201 | aa = e1 * aa + e2 * v[t] 202 | bb = e1 * bb + e2 203 | pp = p 204 | else: 205 | state[5*i+2] = e1 * aa + e2 * v[t] 206 | state[5*i+3] = e1 * bb + e2 207 | state[5*i+4] = p 208 | xx[t] = (a / b).to(dtype=self.FLOAT_MODE) 209 | return (r * xx) @ ow 210 | 211 | def forward(self, tokens, state=None, preprocess_only = False): 212 | # tokens=tokens.to('cpu').numpy().tolist() 213 | # if tokens is not None: 214 | # print('input_tokens', len(tokens)) 215 | # if state is not None: 216 | # print(' input state',state.shape) 217 | with torch.no_grad(): 218 | w = self.w 219 | args = self.args 220 | 221 | seq_mode = len(tokens) > 1 222 | 223 | x = w.emb.weight[tokens] if seq_mode else w.emb.weight[tokens[-1]] 224 | if self.RUN_DEVICE == 'cuda': 225 | x = x.cuda() 226 | 227 | if state == None: 228 | state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE) 229 | for i in range(args.n_layer): 230 | state[5*i+4] -= 1e30 231 | 232 | SA = self.SA_seq if seq_mode else self.SA_one 233 | FF = self.FF_seq if seq_mode else self.FF_one 234 | all_ffn_out=[] 235 | for i in range(args.n_layer): 236 | ww = w.blocks[i].att 237 | x = x + SA(self.LN(x, w.blocks[i].ln1), state, i, 238 | ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay, 239 | ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight) 240 | 241 | ww = w.blocks[i].ffn 242 | ffn_out=FF(self.LN(x, w.blocks[i].ln2), state, i, 243 | ww.time_mix_k, ww.time_mix_r, 244 | ww.key.weight, ww.value.weight, ww.receptance.weight) 245 | x = x + ffn_out 246 | all_ffn_out.append(ffn_out) 247 | # print('ffn->',x) 248 | if (i+1) % RWKV_RESCALE_LAYER == 0: 249 | x = x / 2 250 | 251 | if preprocess_only: 252 | return state 253 | 254 | x = self.LN(x[-1,:], w.ln_out) if seq_mode else self.LN(x, w.ln_out) 255 | x = w.head.weight @ x 256 | # print('output',x.shape,' out put state',state.shape) 257 | return x.float(),all_ffn_out, state 258 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | 
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 |    To apply the Apache License to your work, attach the following
181 |    boilerplate notice, with the fields enclosed by brackets "[]"
182 |    replaced with your own identifying information. (Don't include
183 |    the brackets!)  The text should be enclosed in the appropriate
184 |    comment syntax for the file format. We also recommend that a
185 |    file or class name and description of purpose be included on the
186 |    same "printed page" as the copyright notice for easier
187 |    identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 |     http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """Convert RWKV PyTorch savepoint to TorchScript model.
2 | """
3 | from typing import NamedTuple, List, Optional, Final
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | 
9 | import click
10 | 
11 | 
12 | class LayerNorm(NamedTuple):
13 |     weight: nn.Parameter
14 |     bias: nn.Parameter
15 | 
16 | 
17 | class ATT(NamedTuple):
18 |     time_mix_k: nn.Parameter
19 |     time_mix_v: nn.Parameter
20 |     time_mix_r: nn.Parameter
21 |     time_first: nn.Parameter
22 |     time_decay: nn.Parameter
23 |     key: nn.Parameter
24 |     value: nn.Parameter
25 |     receptance: nn.Parameter
26 |     output: nn.Parameter
27 | 
28 | 
29 | class FFN(NamedTuple):
30 |     time_mix_k: nn.Parameter
31 |     time_mix_r: nn.Parameter
32 |     key: nn.Parameter
33 |     value: nn.Parameter
34 |     receptance: nn.Parameter
35 | 
36 | 
37 | class Block(NamedTuple):
38 |     att: ATT
39 |     ffn: FFN
40 |     ln1: LayerNorm
41 |     ln2: LayerNorm
42 | 
43 | 
44 | class Weight(NamedTuple):
45 |     emb: nn.Parameter
46 |     blocks: List[Block]
47 |     ln0: LayerNorm
48 |     ln_out: LayerNorm
49 |     head: nn.Parameter
50 | 
51 | 
52 | class RWKV_RNN_JIT(nn.Module):
53 |     float_mode: Final[torch.dtype]
54 |     n_layer: Final[int]
55 |     n_embd: Final[int]
56 |     device: Final[torch.device]
57 | 
58 |     RWKV_RESCALE_LAYER: Final[int] = 6  # typing.Final; the lowercase `final` was an undefined name
59 | 
60 |     weight: Weight
61 | 
62 |     def __init__(
63 |         self,
64 |         *,
65 |         model_path: str,
66 |         float_mode: torch.dtype,
67 |         device: torch.device,
68 |     ):
69 |         super().__init__()
70 | 
71 |         self.float_mode = float_mode
72 |         self.device = device
73 | 
74 |         with torch.no_grad():
75 |             w = torch.load(model_path, map_location="cpu")
76 |             n_embd = w["emb.weight"].shape[1]
77 |             n_layer = 0
78 | 
79 |             keys = list(w.keys())
80 |             print_need_newline = False
81 |             print(keys)
82 | 
83 |             for x in keys:
84 |                 w[x].requires_grad = False
85 |                 if x == "emb.weight" or "ln0" in x:
86 |                     continue
87 | 
88 |                 block_id = int(x.split(".")[1]) if ("blocks."
in x) else 0 89 | n_layer = max(n_layer, block_id + 1) 90 | 91 | if ".time_" in x: 92 | w[x] = w[x].squeeze() 93 | if ( 94 | "key.weight" in x 95 | or "value.weight" in x 96 | or "receptance.weight" in x 97 | or "output.weight" in x 98 | ): 99 | w[x] = w[x].t() 100 | 101 | if ".time_decay" in x: 102 | w[x] = w[x].float() 103 | w[x] = -torch.exp(w[x]) 104 | elif ".time_first" in x: 105 | w[x] = w[x].float() 106 | else: 107 | w[x] = w[x].to(dtype=self.float_mode) 108 | 109 | if float_mode == torch.float16: 110 | if "att.output.weight" in x: 111 | w[x] = w[x] / (2 ** int(block_id // self.RWKV_RESCALE_LAYER)) 112 | if "ffn.value.weight" in x: 113 | w[x] = w[x] / (2 ** int(block_id // self.RWKV_RESCALE_LAYER)) 114 | 115 | w[x] = w[x].to(device) 116 | 117 | shape = w[x].shape 118 | shape = [i for i in shape if i != 1] 119 | if len(shape) > 1: 120 | shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" 121 | else: 122 | shape = f" {str(shape[0]).rjust(5)} " 123 | if block_id == 0: 124 | if print_need_newline: 125 | print("\n", end="") 126 | print_need_newline = False 127 | print( 128 | x.ljust(32), 129 | str(w[x].dtype).replace("torch.", "").ljust(10), 130 | w[x].device, 131 | shape, 132 | ) 133 | else: 134 | print_need_newline = True 135 | print(".", end="", flush=True) 136 | print() 137 | print('n_layer ',n_layer,'n_embd',n_embd) 138 | self.n_layer = n_layer 139 | self.n_embd = n_embd 140 | 141 | emb = w["emb.weight"] 142 | ln_out = LayerNorm(w["ln_out.weight"], w["ln_out.bias"]) 143 | ln0 = LayerNorm(w["blocks.0.ln0.weight"], w["blocks.0.ln0.bias"]) 144 | head = w["head.weight"] 145 | blocks = [ 146 | Block( 147 | att=ATT( 148 | time_mix_k=w[f"blocks.{i}.att.time_mix_k"], 149 | time_mix_v=w[f"blocks.{i}.att.time_mix_v"], 150 | time_mix_r=w[f"blocks.{i}.att.time_mix_r"], 151 | time_first=w[f"blocks.{i}.att.time_first"], 152 | time_decay=w[f"blocks.{i}.att.time_decay"], 153 | key=w[f"blocks.{i}.att.key.weight"], 154 | value=w[f"blocks.{i}.att.value.weight"], 155 | receptance=w[f"blocks.{i}.att.receptance.weight"], 156 | output=w[f"blocks.{i}.att.output.weight"], 157 | ), 158 | ffn=FFN( 159 | time_mix_k=w[f"blocks.{i}.ffn.time_mix_k"], 160 | time_mix_r=w[f"blocks.{i}.ffn.time_mix_r"], 161 | key=w[f"blocks.{i}.ffn.key.weight"], 162 | value=w[f"blocks.{i}.ffn.value.weight"], 163 | receptance=w[f"blocks.{i}.ffn.receptance.weight"], 164 | ), 165 | ln1=LayerNorm(w[f"blocks.{i}.ln1.weight"], w[f"blocks.{i}.ln1.bias"]), 166 | ln2=LayerNorm(w[f"blocks.{i}.ln2.weight"], w[f"blocks.{i}.ln2.bias"]), 167 | ) 168 | for i in range(self.n_layer) 169 | ] 170 | 171 | with torch.no_grad(): # precompute embedding 172 | x = self.LN(emb, ln0) 173 | emb = x.to(dtype=self.float_mode) 174 | 175 | self.weight = Weight(emb, blocks, ln0, ln_out, head) 176 | 177 | def LN(self, x, w: LayerNorm): 178 | return F.layer_norm(x, (self.n_embd,), weight=w.weight, bias=w.bias) 179 | 180 | def FF_one(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw): 181 | xx = state[5 * i + 0].to(dtype=self.float_mode) 182 | xk = x * time_mix_k + xx * (1 - time_mix_k) 183 | xr = x * time_mix_r + xx * (1 - time_mix_r) 184 | state[5 * i + 0] = x.float() 185 | 186 | r = torch.sigmoid(xr @ rw) 187 | k = torch.square(torch.relu(xk @ kw)) 188 | kv = k @ vw 189 | return r * kv 190 | 191 | def FF_seq(self, x, state, i: int, time_mix_k, time_mix_r, kw, vw, rw): 192 | xx = torch.cat( 193 | (state[5 * i + 0].to(dtype=self.float_mode).unsqueeze(0), x[:-1, :]) 194 | ) 195 | xk = x * time_mix_k + xx * (1 - time_mix_k) 196 | xr = x * time_mix_r + 
xx * (1 - time_mix_r) 197 | state[5 * i + 0] = x[-1, :].float() 198 | 199 | r = torch.sigmoid(xr @ rw) 200 | k = torch.square(torch.relu(xk @ kw)) 201 | kv = k @ vw 202 | return r * kv 203 | 204 | def SA_one( 205 | self, 206 | x, 207 | state, 208 | i: int, 209 | time_mix_k, 210 | time_mix_v, 211 | time_mix_r, 212 | time_first, 213 | time_decay, 214 | kw, 215 | vw, 216 | rw, 217 | ow, 218 | ): 219 | xx = state[5 * i + 1].to(dtype=self.float_mode) 220 | xk = x * time_mix_k + xx * (1 - time_mix_k) 221 | xv = x * time_mix_v + xx * (1 - time_mix_v) 222 | xr = x * time_mix_r + xx * (1 - time_mix_r) 223 | state[5 * i + 1] = x.float() 224 | 225 | r = torch.sigmoid(xr @ rw) 226 | k = (xk @ kw).float() 227 | v = (xv @ vw).float() 228 | 229 | aa = state[5 * i + 2] 230 | bb = state[5 * i + 3] 231 | pp = state[5 * i + 4] 232 | ww = time_first + k 233 | p = torch.maximum(pp, ww) 234 | e1 = torch.exp(pp - p) 235 | e2 = torch.exp(ww - p) 236 | a = e1 * aa + e2 * v 237 | b = e1 * bb + e2 238 | ww = pp + time_decay 239 | p = torch.maximum(ww, k) 240 | e1 = torch.exp(ww - p) 241 | e2 = torch.exp(k - p) 242 | state[5 * i + 2] = e1 * aa + e2 * v 243 | state[5 * i + 3] = e1 * bb + e2 244 | state[5 * i + 4] = p 245 | wkv = (a / b).to(dtype=self.float_mode) 246 | return (r * wkv) @ ow 247 | 248 | def SA_seq( 249 | self, 250 | x, 251 | state, 252 | i: int, 253 | time_mix_k, 254 | time_mix_v, 255 | time_mix_r, 256 | time_first, 257 | time_decay, 258 | kw, 259 | vw, 260 | rw, 261 | ow, 262 | ): 263 | xx = torch.cat( 264 | (state[5 * i + 1].to(dtype=self.float_mode).unsqueeze(0), x[:-1, :]) 265 | ) 266 | xk = x * time_mix_k + xx * (1 - time_mix_k) 267 | xv = x * time_mix_v + xx * (1 - time_mix_v) 268 | xr = x * time_mix_r + xx * (1 - time_mix_r) 269 | state[5 * i + 1] = x[-1, :].float() 270 | 271 | r = torch.sigmoid(xr @ rw) 272 | k = (xk @ kw).float() 273 | v = (xv @ vw).float() 274 | 275 | aa = state[5 * i + 2] 276 | bb = state[5 * i + 3] 277 | pp = state[5 * i + 4] 278 | T = x.shape[0] 279 | for t in range(T): 280 | ww = time_first + k[t] 281 | p = torch.maximum(pp, ww) 282 | e1 = torch.exp(pp - p) 283 | e2 = torch.exp(ww - p) 284 | a = e1 * aa + e2 * v[t] 285 | b = e1 * bb + e2 286 | ww = pp + time_decay 287 | p = torch.maximum(ww, k[t]) 288 | e1 = torch.exp(ww - p) 289 | e2 = torch.exp(k[t] - p) 290 | if t != T - 1: 291 | aa = e1 * aa + e2 * v[t] 292 | bb = e1 * bb + e2 293 | pp = p 294 | else: 295 | state[5 * i + 2] = e1 * aa + e2 * v[t] 296 | state[5 * i + 3] = e1 * bb + e2 297 | state[5 * i + 4] = p 298 | xx[t] = (a / b).to(dtype=self.float_mode) 299 | return (r * xx) @ ow 300 | 301 | def FF( 302 | self, 303 | x, 304 | state, 305 | i: int, 306 | time_mix_k, 307 | time_mix_r, 308 | kw, 309 | vw, 310 | rw, 311 | *, 312 | seq_mode: bool, 313 | ): 314 | if seq_mode: 315 | return self.FF_seq(x, state, i, time_mix_k, time_mix_r, kw, vw, rw) 316 | else: 317 | return self.FF_one(x, state, i, time_mix_k, time_mix_r, kw, vw, rw) 318 | 319 | def SA( 320 | self, 321 | x, 322 | state, 323 | i: int, 324 | time_mix_k, 325 | time_mix_v, 326 | time_mix_r, 327 | time_first, 328 | time_decay, 329 | kw, 330 | vw, 331 | rw, 332 | ow, 333 | *, 334 | seq_mode: bool, 335 | ): 336 | if seq_mode: 337 | return self.SA_seq( 338 | x, 339 | state, 340 | i, 341 | time_mix_k, 342 | time_mix_v, 343 | time_mix_r, 344 | time_first, 345 | time_decay, 346 | kw, 347 | vw, 348 | rw, 349 | ow, 350 | ) 351 | else: 352 | return self.SA_one( 353 | x, 354 | state, 355 | i, 356 | time_mix_k, 357 | time_mix_v, 358 | time_mix_r, 359 | time_first, 360 | 
time_decay,
361 |                 kw,
362 |                 vw,
363 |                 rw,
364 |                 ow,
365 |             )
366 | 
367 |     def forward(
368 |         self,
369 |         tokens: List[int],
370 |         state: Optional[torch.Tensor],
371 |         # state_is_none: bool,
372 |         preprocess_only: bool = False,
373 |     ):
374 |         with torch.no_grad():
375 |             w = self.weight
376 | 
377 |             seq_mode = len(tokens) > 1
378 | 
379 |             x = w.emb[tokens] if seq_mode else w.emb[tokens[-1]]
380 |             x = x.to(self.device)
381 | 
382 |             if state is None:
383 |                 state = torch.zeros(self.n_layer * 5, self.n_embd, device=self.device)
384 |                 for i in range(self.n_layer):
385 |                     state[5 * i + 4] -= 1e30
386 | 
387 |             for i in range(self.n_layer):
388 |                 ww = w.blocks[i].att
389 |                 x = x + self.SA(
390 |                     self.LN(x, w.blocks[i].ln1),
391 |                     state,
392 |                     i,
393 |                     ww.time_mix_k,
394 |                     ww.time_mix_v,
395 |                     ww.time_mix_r,
396 |                     ww.time_first,
397 |                     ww.time_decay,
398 |                     ww.key,
399 |                     ww.value,
400 |                     ww.receptance,
401 |                     ww.output,
402 |                     seq_mode=seq_mode,
403 |                 )
404 | 
405 |                 ww = w.blocks[i].ffn
406 |                 x = x + self.FF(
407 |                     self.LN(x, w.blocks[i].ln2),
408 |                     state,
409 |                     i,
410 |                     ww.time_mix_k,
411 |                     ww.time_mix_r,
412 |                     ww.key,
413 |                     ww.value,
414 |                     ww.receptance,
415 |                     seq_mode=seq_mode,
416 |                 )
417 | 
418 |                 if (
419 |                     self.float_mode == torch.float16
420 |                     and (i + 1) % self.RWKV_RESCALE_LAYER == 0
421 |                 ):
422 |                     x = x / 2
423 | 
424 |             if preprocess_only:
425 |                 return torch.empty(1), state
426 | 
427 |             x = self.LN(x[-1, :], w.ln_out) if seq_mode else self.LN(x, w.ln_out)
428 |             x = w.head @ x
429 | 
430 |             return x.float(), state
431 | 
432 | 
433 | # @click.command()
434 | # @click.option(
435 | #     "--float-mode",
436 | #     type=click.Choice(["fp32", "fp16", "bf16"]),
437 | # )
438 | # @click.option("--device", type=click.Choice(["cpu", "cuda"]))
439 | # @click.argument("model_path", type=click.Path(exists=True))
440 | # @click.argument("output_path", type=click.Path())
441 | def convert(float_mode, device, model_path, output_path):
442 |     float_modes = {
443 |         "fp32": torch.float32,
444 |         "fp16": torch.float16,
445 |         "bf16": torch.bfloat16,  # the key was "bf32", so requesting bf16 raised a KeyError
446 |     }
447 | 
448 |     model = RWKV_RNN_JIT(
449 |         model_path=model_path,
450 |         float_mode=float_modes[float_mode],
451 |         device=torch.device(device),
452 |     )
453 |     model = torch.jit.script(model)
454 | 
455 |     if float_mode == "bf16":
456 |         model = model.bfloat16()
457 |     elif float_mode == "fp16":
458 |         model = model.half()
459 |     else:
460 |         model = model.float()
461 | 
462 |     if device == "cuda":
463 |         model = model.cuda()
464 | 
465 |     model.save(output_path)
466 | 
467 |     print("This script was written against ChatRWKV git+3286707; other versions may be incompatible.")
468 | 
469 | 
470 | if __name__ == "__main__":
471 |     MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-1B5-EngChn-test4-20230115.pth'
472 | 
473 |     convert('fp16', 'cuda:0', MODEL_NAME, './rwkv-1b5-fp16.pt')  # save() needs a file path, not a directory; this name is illustrative

--------------------------------------------------------------------------------
/chat.py:
--------------------------------------------------------------------------------
1 | ########################################################################################################
2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3 | ########################################################################################################
4 | import matplotlib.pyplot as plt
5 | import os, copy, types, gc, sys
6 | import numpy as np
7 | import cv2
8 | from mpl_toolkits.mplot3d import Axes3D
9 | 
10 | 
11 | def plt2ndarr(plt):
12 |     from matplotlib.backends.backend_agg import FigureCanvasAgg
13 |     import PIL.Image as Image
14 | 
15 |     # convert the current matplotlib figure into a numpy array
16 |     canvas = FigureCanvasAgg(plt.gcf())
17 |     # print(type(canvas))
18 |     # render the figure
19 |     canvas.draw()
20 |     # get the canvas size
21 |     w, h = canvas.get_width_height()
22 |     # decode the byte buffer into an ARGB image
23 |     buf = np.frombuffer(canvas.tostring_argb(), dtype=np.uint8).copy()  # np.fromstring is deprecated; copy because frombuffer is read-only
24 | 
25 |     # reshape to (rows, cols, 4) ARGB; the original (w, h, 4) only worked because the frame is square
26 |     buf.shape = (h, w, 4)
27 |     # roll ARGB to RGBA
28 |     buf = np.roll(buf, 3, axis=2)
29 |     # build a PIL RGBA Image (stop here if all you need is the Image object)
30 |     image = Image.frombytes("RGBA", (w, h), buf.tobytes())  # ndarray.tostring is deprecated
31 |     # convert to an RGBA numpy array
32 |     image = np.asarray(image)
33 |     # keep only the RGB channels
34 |     rgb_image = image[:, :, :3]
35 |     # print(rgb_image.shape)
36 |     videoWrite.write(rgb_image)
37 |     return rgb_image
38 | 
39 | 
40 | try:
41 |     os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
42 | except:
43 |     pass
44 | np.set_printoptions(precision=4, suppress=True, linewidth=200)
45 | args = types.SimpleNamespace()
46 | 
47 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
48 | videoWrite = cv2.VideoWriter(r'test.mp4', fourcc, 8, (1000, 1000))
49 | ########################################################################################################
50 | 
51 | args.RUN_DEVICE = "cuda"  # cuda // cpu
52 | # fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU)
53 | args.FLOAT_MODE = "fp16"
54 | 
55 | os.environ[
56 |     "RWKV_JIT_ON"] = '0'  # '1' or '0'. very useful for fp32, but might be harmful for GPU fp16. please benchmark !!!
57 | 
58 | CHAT_LANG = 'Chinese'  # English // Chinese // more to come
59 | 
60 | QA_PROMPT = False  # True: Q & A prompt // False: User & Bot prompt
61 | # For Chinese Q&A, set QA_PROMPT=True (Q&A only, but better answers). For Chinese chitchat, set QA_PROMPT=False (chitchat works, but needs a large model to work well).
62 | 
63 | # Download RWKV-4 models from https://huggingface.co/BlinkDL
64 | 
65 | if CHAT_LANG == 'English':
66 |     # args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-14B-20230213-8019'
67 |     args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-7B-20220911-79'
68 |     # args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-1B5-20220903-8040'
69 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230204-7324'
70 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
71 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096'
72 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-Instruct-test1-20230124'
73 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
74 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
75 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023'
76 |     # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-340'
77 |     # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/14b-run1/rwkv-6210'
78 | 
79 | elif CHAT_LANG == 'Chinese':
80 |     args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-7B-EngChn-test4-20230116'
81 |     # args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-1B5-EngChn-test4-20230115'
82 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-EngChn-test4-20230115'
83 |     # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-test4-20230115'
84 |     # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-490'
85 |     # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/1.5-run1z/rwkv-415'
86 | 
87 | if '-169M-' in args.MODEL_NAME:
88 |     args.n_layer = 12
89 |     args.n_embd = 768
90 | if '-430M-' in args.MODEL_NAME:
91 |     args.n_layer = 24
92 |     args.n_embd = 1024
93 | if
'-1B5-' in args.MODEL_NAME or '/1.5-' in args.MODEL_NAME: 94 | args.n_layer = 24 95 | args.n_embd = 2048 96 | elif '-3B-' in args.MODEL_NAME or '/3-' in args.MODEL_NAME: 97 | args.n_layer = 32 98 | args.n_embd = 2560 99 | elif '-7B-' in args.MODEL_NAME or '/7-' in args.MODEL_NAME: 100 | args.n_layer = 32 101 | args.n_embd = 4096 102 | elif '-14B-' in args.MODEL_NAME or '/14-' in args.MODEL_NAME or '/14b-' in args.MODEL_NAME: 103 | args.n_layer = 40 104 | args.n_embd = 5120 105 | 106 | args.ctx_len = 1024 107 | 108 | CHAT_LEN_SHORT = 40 109 | CHAT_LEN_LONG = 150 110 | FREE_GEN_LEN = 200 111 | 112 | GEN_TEMP = 1.0 113 | GEN_TOP_P = 0.85 114 | 115 | AVOID_REPEAT = ',。:?!' 116 | 117 | ######################################################################################################## 118 | 119 | print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}') 120 | import torch 121 | 122 | torch.backends.cudnn.benchmark = True 123 | torch.backends.cudnn.allow_tf32 = True 124 | torch.backends.cuda.matmul.allow_tf32 = True 125 | from src.model_run import RWKV_RNN 126 | from src.utils import TOKENIZER 127 | 128 | tokenizer = TOKENIZER("20B_tokenizer.json") 129 | 130 | args.vocab_size = 50277 131 | args.head_qk = 0 132 | args.pre_ffn = 0 133 | args.grad_cp = 0 134 | args.my_pos_emb = 0 135 | os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE 136 | MODEL_NAME = args.MODEL_NAME 137 | 138 | if CHAT_LANG == 'English': 139 | interface = ":" 140 | 141 | if QA_PROMPT: 142 | user = "Q" 143 | bot = "A" 144 | intro = f'The following is a verbose and detailed Q & A conversation of factual information.' 145 | else: 146 | user = "User" 147 | bot = "Bot" 148 | intro = f'The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.' 149 | 150 | init_prompt = f''' 151 | {intro} 152 | 153 | {user}{interface} french revolution what year 154 | 155 | {bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799. 156 | 157 | {user}{interface} 3+5=? 158 | 159 | {bot}{interface} The answer is 8. 160 | 161 | {user}{interface} guess i marry who ? 162 | 163 | {bot}{interface} Only if you tell me more about yourself - what are your interests? 164 | 165 | {user}{interface} solve for a: 9-a=2 166 | 167 | {bot}{interface} The answer is a = 7, because 9 - 7 = 2. 168 | 169 | {user}{interface} wat is lhc 170 | 171 | {bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012. 172 | 173 | ''' 174 | HELP_MSG = '''Commands: 175 | say something --> chat with bot. use \\n for new line. 176 | + --> alternate chat reply 177 | +reset --> reset chat 178 | 179 | +gen YOUR PROMPT --> free generation with any prompt. use \\n for new line. 180 | +qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line. 181 | +++ --> continue last free generation (only for +gen / +qa) 182 | ++ --> retry last free generation (only for +gen / +qa) 183 | 184 | Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results. 185 | This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +gen for free generation. 
########################################################################################################

print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}')
import torch

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
from src.model_run import RWKV_RNN
from src.utils import TOKENIZER

tokenizer = TOKENIZER("20B_tokenizer.json")

args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME

if CHAT_LANG == 'English':
    interface = ":"

    if QA_PROMPT:
        user = "Q"
        bot = "A"
        intro = f'The following is a verbose and detailed Q & A conversation of factual information.'
    else:
        user = "User"
        bot = "Bot"
        intro = f'The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.'

    init_prompt = f'''
{intro}

{user}{interface} french revolution what year

{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.

{user}{interface} 3+5=?

{bot}{interface} The answer is 8.

{user}{interface} guess i marry who ?

{bot}{interface} Only if you tell me more about yourself - what are your interests?

{user}{interface} solve for a: 9-a=2

{bot}{interface} The answer is a = 7, because 9 - 7 = 2.

{user}{interface} wat is lhc

{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

'''
    HELP_MSG = '''Commands:
say something --> chat with bot. use \\n for new line.
+ --> alternate chat reply
+reset --> reset chat

+gen YOUR PROMPT --> free generation with any prompt. use \\n for new line.
+qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line.
+++ --> continue last free generation (only for +gen / +qa)
++ --> retry last free generation (only for +gen / +qa)

Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
This is not instruct-tuned for conversation yet, so don't expect good quality. Better to use +gen for free generation.
'''
elif CHAT_LANG == 'Chinese':
    interface = ":"
    if QA_PROMPT:
        user = "Q"
        bot = "A"
        init_prompt = f'''
Expert Questions & Helpful Answers

Ask Research Experts

'''
    else:
        user = "User"
        bot = "Bot"
        init_prompt = f'''
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.

{user}{interface} wat is lhc

{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

{user}{interface} 企鹅会飞吗

{bot}{interface} 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。

'''
    HELP_MSG = '''指令:

直接输入内容 --> 和机器人聊天(建议问机器人问题),用\\n代表换行
+ --> 让机器人换个回答
+reset --> 重置对话

+gen 某某内容 --> 续写任何中英文内容,用\\n代表换行
+qa 某某问题 --> 问独立的问题(忽略上下文),用\\n代表换行
+qq 某某问题 --> 问独立的问题(忽略上下文),且敞开想象力,用\\n代表换行
+++ --> 继续 +gen / +qa / +qq 的回答
++ --> 换个 +gen / +qa / +qq 的回答

现在可以输入内容和机器人聊天(注意它不大懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。
目前没有“重复惩罚”,所以机器人有时会重复,此时必须使用 + 换成正常回答,以免污染电脑记忆。
注意:和上下文无关的独立问题,必须用 +qa 或 +qq 问,以免污染电脑记忆。
'''

# Load Model

print(f'Loading model - {MODEL_NAME}')
model = RWKV_RNN(args)

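# The RWKV_JIT_ON comment earlier in this file asks you to benchmark. A minimal
# throughput sketch (hypothetical helper, not part of the original script):
# feed one token at a time through the model, the way chat generation does, and
# report tokens per second for the current device/precision/JIT combination.
def bench_forward(n_tokens=64):
    import time
    state = None
    t0 = time.time()
    for _ in range(n_tokens):
        # 187 is the '\n' token; forward returns (logits, ffn_outputs, state)
        # in this modified repo, as unpacked in run_rnn below.
        _, _, state = model.forward([187], state)
    dt = time.time() - t0
    print(f'{n_tokens / dt:.1f} tokens/sec')
# bench_forward()  # uncomment to benchmark
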
model_tokens = []
model_state = None

AVOID_REPEAT_TOKENS = []
for i in AVOID_REPEAT:
    dd = tokenizer.encode(i)
    assert len(dd) == 1
    AVOID_REPEAT_TOKENS += dd


########################################################################################################

def run_rnn(tokens, newline_adj=0):
    global model_tokens, model_state

    tokens = [int(x) for x in tokens]
    model_tokens += tokens
    # print(tokens, type(tokens))
    out, all_ffn_out, model_state = model.forward(tokens, model_state)
    # assss = all_ffn_out[0].to('cpu').numpy()
    # data = np.array([x.to('cpu').numpy() for x in all_ffn_out])

    # plt.plot(assss)
    # ax.plot_surface(ax, rstride=1, cstride=1, cmap='rainbow')
    # plt2ndarr(plt)
    # plt.show()
    # print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]')

    out[0] = -999999999  # disable <|endoftext|>
    out[187] += newline_adj  # adjust \n probability
    # if newline_adj > 0:
    #     out[15] += newline_adj / 2  # '.'
    if model_tokens[-1] in AVOID_REPEAT_TOKENS:
        out[model_tokens[-1]] = -999999999  # forbid repeating the last punctuation token
    return out, all_ffn_out


all_state = {}


def save_all_stat(srv, name, last_out):
    n = f'{name}_{srv}'
    all_state[n] = {}
    all_state[n]['out'] = last_out
    all_state[n]['rnn'] = copy.deepcopy(model_state)
    all_state[n]['token'] = copy.deepcopy(model_tokens)


def load_all_stat(srv, name):
    global model_tokens, model_state
    n = f'{name}_{srv}'
    model_state = copy.deepcopy(all_state[n]['rnn'])
    model_tokens = copy.deepcopy(all_state[n]['token'])
    return all_state[n]['out']


########################################################################################################

# Run inference
print(f'\nRun prompt...')

out, all_ffn_out = run_rnn(tokenizer.encode(init_prompt))
save_all_stat('', 'chat_init', out)
gc.collect()
torch.cuda.empty_cache()

srv_list = ['dummy_server']
for s in srv_list:
    save_all_stat(s, 'chat', out)

print(f'### prompt ###\n[{tokenizer.decode(model_tokens)}]\n')


def reply_msg(msg):
    print(f'{bot}{interface} {msg}\n')

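# The chat branch of on_message below anneals the probability of emitting '\n'
# (token 187) over the course of a reply: strongly suppressed at first so the
# reply is never empty, neutral in the middle, then increasingly encouraged
# after CHAT_LEN_LONG tokens so generation must terminate. The same schedule
# as a pure function, for reference (not called by the original script):
def newline_adj_schedule(i):
    if i <= 0:
        return -999999999                   # never end on the very first token
    if i <= CHAT_LEN_SHORT:
        return (i - CHAT_LEN_SHORT) / 10    # mild suppression, fading to 0
    if i <= CHAT_LEN_LONG:
        return 0                            # neutral
    return (i - CHAT_LEN_LONG) * 0.25       # growing bonus: force an ending
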
def draw_ffn(send_msg):
    fig = plt.figure(figsize=(10, 10))
    ax = plt.axes(projection='3d')
    # ax = Axes3D(fig)
    ax.set_zlim(0, 500)
    ax.set_zlabel('value')
    ax.set_xlabel('channel')
    ax.set_ylabel('layer')

    # ax.set_title(send_msg, fontsize=10)
    print('send_msg->', send_msg)
    if len(all_ffn_out[1].shape) == 1:
        for index in range(len(all_ffn_out)):
            ffn_out = all_ffn_out[index]
            x = [x for x in range(ffn_out.shape[0])]
            y = ffn_out.to('cpu').numpy()
            # ax.bar(ffn_out.to('cpu').numpy(), [x for x in range(2048)])
            # ax.plot([x for x in range(2048)], ffn_out.to('cpu').numpy(), zs=index)
            ax.plot(x, y, zs=index, zdir='y')
    # ax.text(4, 6, s=send_msg, fontsize=5, color='green')
    # Convert the Matplotlib figure to an OpenCV image
    fig.canvas.draw()
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.putText(img, send_msg, (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

    # Write the frame to the video
    videoWrite.write(img)

    # Clear the axes and prepare for the next frame
    ax.clear()
    plt.show()
    plt.close()


def on_message(message):
    global model_tokens, model_state

    srv = 'dummy_server'

    msg = message.replace('\\n', '\n').strip()
    # if len(msg) > 1000:
    #     reply_msg('your message is too long (max 1000 tokens)')
    #     return

    x_temp = GEN_TEMP
    x_top_p = GEN_TOP_P
    if "-temp=" in msg:
        x_temp = float(msg.split("-temp=")[1].split(" ")[0])
        msg = msg.replace("-temp=" + f'{x_temp:g}', "")
        # print(f"temp: {x_temp}")
    if "-top_p=" in msg:
        x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
        msg = msg.replace("-top_p=" + f'{x_top_p:g}', "")
        # print(f"top_p: {x_top_p}")
    if x_temp <= 0.2:
        x_temp = 0.2
    if x_temp >= 5:
        x_temp = 5
    if x_top_p <= 0:
        x_top_p = 0

    if msg == '+reset':
        out = load_all_stat('', 'chat_init')
        save_all_stat(srv, 'chat', out)
        reply_msg("Chat reset.")
        return

    elif (msg[:5].lower() == '+gen ' or msg[:4].lower() == '+qa ' or
          msg[:4].lower() == '+qq ' or msg.lower() == '+++' or msg.lower() == '++'):

        if msg[:5].lower() == '+gen ':
            new = '\n' + msg[5:].strip()
            # print(f'### prompt ###\n[{new}]')
            model_state = None
            model_tokens = []
            out, all_ffn_out = run_rnn(tokenizer.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg[:4].lower() == '+qq ':
            new = '\nQ: ' + msg[4:].strip() + '\nA:'
            # print(f'### prompt ###\n[{new}]')
            model_state = None
            model_tokens = []
            out, all_ffn_out = run_rnn(tokenizer.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg[:4].lower() == '+qa ':
            out = load_all_stat('', 'chat_init')

            real_msg = msg[4:].strip()
            new = f"{user}{interface} {real_msg}\n\n{bot}{interface}"
            # print(f'### qa ###\n[{new}]')

            out, all_ffn_out = run_rnn(tokenizer.encode(new))
            save_all_stat(srv, 'gen_0', out)

        elif msg.lower() == '+++':
            try:
                out = load_all_stat(srv, 'gen_1')
                save_all_stat(srv, 'gen_0', out)
            except:
                return

        elif msg.lower() == '++':
            try:
                out = load_all_stat(srv, 'gen_0')
            except:
                return

        begin = len(model_tokens)
        out_last = begin
        for i in range(FREE_GEN_LEN + 100):
            token = tokenizer.sample_logits(
                out,
                model_tokens,
                args.ctx_len,
                temperature=x_temp,
                top_p=x_top_p,
            )
            if msg[:4].lower() == '+qa ':  # or msg[:4].lower() == '+qq ':
                out, all_ffn_out = run_rnn([token], newline_adj=-2)
            else:
                out, all_ffn_out = run_rnn([token])

            xxx = tokenizer.decode(model_tokens[out_last:])
            # draw_ffn(xxx)
            if '\ufffd' not in xxx:  # avoid utf-8 display issues
                print(xxx, end='', flush=True)
                out_last = begin + i + 1
            if i >= FREE_GEN_LEN:
                break
        print('\n')
        # send_msg = tokenizer.decode(model_tokens[begin:]).strip()
        # print(f'### send ###\n[{send_msg}]')
        # reply_msg(send_msg)

        save_all_stat(srv, 'gen_1', out)

    else:
        if msg.lower() == '+':
            try:
                out = load_all_stat(srv, 'chat_pre')
            except:
                return
        else:
            out = load_all_stat(srv, 'chat')
            new = f"{user}{interface} {msg}\n\n{bot}{interface}"
            # print(f'### add ###\n[{new}]')
            out, all_ffn_out = run_rnn(tokenizer.encode(new), newline_adj=-999999999)
            save_all_stat(srv, 'chat_pre', out)

        begin = len(model_tokens)
        out_last = begin
        print(f'{bot}{interface}', end='', flush=True)
        for i in range(999):
            if i <= 0:
                newline_adj = -999999999
            elif i <= CHAT_LEN_SHORT:
                newline_adj = (i - CHAT_LEN_SHORT) / 10
            elif i <= CHAT_LEN_LONG:
                newline_adj = 0
            else:
                newline_adj = (i - CHAT_LEN_LONG) * 0.25  # MUST END THE GENERATION
            token = tokenizer.sample_logits(
                out,
                model_tokens,
                args.ctx_len,
                temperature=x_temp,
                top_p=x_top_p,
            )
            out, all_ffn_out = run_rnn([token], newline_adj=newline_adj)

            xxx = tokenizer.decode(model_tokens[out_last:])
            if '\ufffd' not in xxx:  # avoid utf-8 display issues
                print(xxx, end='', flush=True)
                out_last = begin + i + 1

            send_msg = tokenizer.decode(model_tokens[begin:])

            # draw_ffn(send_msg)
            # plt2ndarr(plt)
            fig = plt.figure(figsize=(10, 10))
            ax = plt.axes(projection='3d')
            # ax = Axes3D(fig)
            ax.set_zlim(-500, 500)
            ax.set_zlabel('value')
            ax.set_xlabel('channel')
            ax.set_ylabel('layer')

            # ax.set_title(send_msg, fontsize=10)
            # print('send_msg->', send_msg)

            if len(all_ffn_out[1].shape) == 1:
                for index in range(len(all_ffn_out)):
                    ffn_out = all_ffn_out[index]
                    x = [x for x in range(ffn_out.shape[0])]
                    y = ffn_out.to('cpu').numpy()
                    ax.plot(x, y, zs=index, zdir='y')
            # ax.text(4, 6, s=send_msg, fontsize=5, color='green')
            # Convert the Matplotlib figure to an OpenCV image
            fig.canvas.draw()
            img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            # Overlay the decoded reply on the frame, wrapped at 20 characters per
            # line (the original wrapping loop dropped the final partial line).
            # Note cv2.putText can only render ASCII; chat_interpretability.py
            # uses draw_box_string for CJK text instead.
            send_msg_y = 100
            if len(send_msg) > 20:
                for line_start in range(0, len(send_msg), 20):
                    cv2.putText(img, send_msg[line_start:line_start + 20], (10, send_msg_y),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
                    send_msg_y += 35
            else:
                cv2.putText(img, send_msg, (10, send_msg_y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

            # Write the frame to the video
            videoWrite.write(img)

            # Clear the axes and prepare for the next frame
            ax.clear()
            # plt.show()
            plt.close()
            if '\n\n' in send_msg:
                videoWrite.release()
                cv2.destroyAllWindows()
                send_msg = send_msg.strip()
                break

            # send_msg = tokenizer.decode(model_tokens[begin:]).strip()
            # if send_msg.endswith(f'{user}{interface}'):  # warning: needs to fix state too !!!
            #     send_msg = send_msg[:-len(f'{user}{interface}')].strip()
            #     break
            # if send_msg.endswith(f'{bot}{interface}'):
            #     send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
            #     break

        # print(f'{model_tokens}')
        # print(f'[{tokenizer.decode(model_tokens)}]')

        # print(f'### send ###\n[{send_msg}]')
        # reply_msg(send_msg)
        save_all_stat(srv, 'chat', out)

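# The MP4 writer above is only released when a reply happens to contain '\n\n';
# if the process exits earlier, the file index is never flushed and test.mp4
# may be unplayable. A small defensive sketch (assumption: calling
# VideoWriter.release() more than once is a harmless no-op in OpenCV):
import atexit
atexit.register(videoWrite.release)
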
print(HELP_MSG)
# What is python? Please tell me in detail what it is, how to use it, how to learn it, and what advantages and disadvantages it has
while True:
    msg = input(f'{user}{interface} ')
    # msg = 'hello'
    # msg = '+gen Tell me what Python is and what its characteristics are. Please demonstrate your ability to write code and explain what it can do.'
    if len(msg.strip()) > 0:
        on_message(msg)
    else:
        print('Error: please say something')
--------------------------------------------------------------------------------
/chat_interpretability.py:
--------------------------------------------------------------------------------
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
import matplotlib.pyplot as plt
import os, copy, types, gc, sys
import numpy as np
import cv2
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image, ImageDraw, ImageFont
import options
import argparse

opt = options.Options().init(argparse.ArgumentParser()).parse_args()

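# NOTE: options.py is referenced here but not included in this listing, so its
# exact contents are unknown. The script only relies on the attributes
# opt.RUN_DEVICE, opt.FLOAT_MODE, opt.CHAT_LANG, opt.outlier,
# opt.print_interpretability, opt.show, opt.video_name, opt.video_FPS and
# opt.video_size. A purely hypothetical sketch of an Options class exposing
# those flags, for orientation only:
#
#     class Options:
#         def init(self, parser):
#             parser.add_argument('--RUN_DEVICE', default='cuda')      # cuda // cpu
#             parser.add_argument('--FLOAT_MODE', default='fp16')     # fp16 // fp32 // bf16
#             parser.add_argument('--CHAT_LANG', default='Chinese')   # English // Chinese
#             parser.add_argument('--outlier', type=int, nargs='+', default=[-2])
#             parser.add_argument('--print_interpretability', action='store_true')
#             parser.add_argument('--show', action='store_true')
#             parser.add_argument('--video_name', default='outlier_runtime')
#             parser.add_argument('--video_FPS', type=int, default=8)
#             parser.add_argument('--video_size', type=int, nargs=2, default=(1000, 1000))
#             return parser
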
def plt2ndarr(plt):
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    import PIL.Image as Image

    # Convert the current pyplot figure to a numpy array
    canvas = FigureCanvasAgg(plt.gcf())
    # print(type(canvas))
    # Render the figure
    canvas.draw()
    # Get the canvas size in pixels
    w, h = canvas.get_width_height()
    # Decode the ARGB byte string into a flat uint8 buffer
    buf = np.frombuffer(canvas.tostring_argb(), dtype=np.uint8)

    # Reshape into an (h, w, 4) ARGB image
    buf = buf.reshape(h, w, 4)
    # Roll the channels from ARGB to RGBA
    buf = np.roll(buf, 3, axis=2)
    # Build a PIL RGBA Image (stop here if all you need is the Image object)
    image = Image.frombytes("RGBA", (w, h), buf.tobytes())
    # Convert to a numpy RGBA array
    image = np.asarray(image)
    # Keep only the RGB channels
    rgb_image = image[:, :, :3]
    # print(rgb_image.shape)
    videoWrite.write(rgb_image)
    return rgb_image


def draw_box_string(img, x, y, string):
    """
    img: image as read by cv2.imread;
    x, y: position where the text starts;
    string: the text to draw;
    return: img
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(img)
    draw = ImageDraw.Draw(img)
    # simsun.ttc is a CJK font; download one if it is not installed
    font = ImageFont.truetype("/usr/share/fonts/zh/simsun.ttc", 10, encoding="utf-8")
    # font = ImageFont.truetype("SourceHanSansCN-Regular.ttf", 50, encoding="utf-8")
    draw.text((x, y - 50), string, (0, 0, 0), font=font)
    img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    return img


def print_interpretability(result):
    lower, upper = result['99%_range']  # central 99% value range

    print(' distribution check -->', result['distribution'],
          ' std -->', result['std'],
          ' var -->', result['var'],
          ' KL divergence -->', result['kl_div'],
          ' 99% value range -->', lower, upper)


def interpretability(ffn):
    '''
    Choosing a good THRESHOLD value is somewhat subjective, and the usable range is wide.
    In general, a good THRESHOLD should be:

    - Greater than 0. THRESHOLD must always be positive: it expresses that the two
      distributions are allowed to differ by some margin without being considered different.
    - Less than 1. A THRESHOLD above 1 loses all discriminative power: any nonzero
      difference would then count as "close to normal", which is useless.
    - Chosen for the concrete problem and data. The best THRESHOLD differs between
      datasets. As a general range, 0.1 to 0.5 is common and reasonable, but the best
      value has to be found experimentally.
    - A balance between strictness and tolerance. The smaller the THRESHOLD, the stricter
      the judgment: only distributions very close to normal pass. The larger it is, the
      more tolerant the judgment: clearly deviating distributions also pass.
    - Chosen with downstream use in mind. Downstream tasks that demand rigor usually
      need a small THRESHOLD; simple classification-style checks can use a larger one.

    Some example values and their meaning:

    - 0.1: strict; only the distributions closest to normal are accepted.
    - 0.3: medium strictness; a common upper bound in practice.
    - 0.5: lenient; many distributions that deviate from normal still pass.
    - greater than 1: the judgment loses rigor and practical meaning; not recommended.

    So in practice THRESHOLD = 0.1 to 0.5 is a reasonable range, but the final choice
    depends on the problem, the data, and the downstream requirements. Alternatively,
    drop THRESHOLD entirely and judge normality with a stricter method such as a
    Kolmogorov-Smirnov test or visual inspection.

    In short, THRESHOLD is a hand-tuned cutoff for deciding whether a distribution
    counts as normal: pick it by balancing strictness against accuracy, validate it
    empirically, or replace it with a more rigorous test.
    '''
    THRESHOLD = 0.5
    # 1. Distribution check: mean squared difference against a single N(0, 1) sample.
    # This compares against one random draw, so the score is noisy; see the KS-test
    # sketch after this function for a more rigorous alternative.
    dist = torch.normal(0, 1, size=ffn.size()).to(args.RUN_DEVICE)  # reference normal tensor
    dist = (ffn - dist).pow(2).sum().to('cpu') / ffn.size()[0]
    # print(f'looks normal: {dist < THRESHOLD}')

    # 2. Dispersion statistics
    std = torch.std(ffn)  # standard deviation
    variance = torch.var(ffn)  # variance
    # Note: torch.kl_div treats its first argument as log-probabilities, so for raw
    # activations this is only a rough dissimilarity score, not a true KL divergence.
    kl_div = torch.kl_div(ffn.float(), torch.normal(0, 1, size=ffn.size()).to(args.RUN_DEVICE))

    # 3. Central 99% quantile range. The original negated the 0.99 quantile for the
    # lower bound, which is only valid for a symmetric, zero-centered distribution.
    upper_bound = torch.quantile(ffn.float(), 0.995)
    lower_bound = torch.quantile(ffn.float(), 0.005)

    # 4. Clamp outliers into that range
    final_tensor = torch.clamp(ffn, min=lower_bound, max=upper_bound)

    return {
        'distribution': dist < THRESHOLD,
        'std': std,
        'var': variance,
        'kl_div': kl_div,
        '99%_range': (lower_bound, upper_bound),
        'processed_tensor': final_tensor
    }

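# The docstring above suggests replacing the hand-tuned THRESHOLD with a proper
# statistical test. A minimal sketch using the Kolmogorov-Smirnov test
# (assumptions: SciPy is installed - it is not listed in requirements.txt - and
# normality_check_ks is a hypothetical helper, not called by this script):
def normality_check_ks(ffn, alpha=0.05):
    from scipy import stats
    x = ffn.detach().float().cpu().numpy().ravel()
    x = (x - x.mean()) / (x.std() + 1e-8)  # standardize: test the shape, not scale/location
    stat, p = stats.kstest(x, 'norm')
    return {'ks_stat': float(stat), 'p_value': float(p), 'is_normal': p > alpha}
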
try:
    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
    pass
np.set_printoptions(precision=4, suppress=True, linewidth=200)
args = types.SimpleNamespace()

########################################################################################################

args.RUN_DEVICE = opt.RUN_DEVICE  # "cuda"  # cuda // cpu
# fp16 (good for GPU, does NOT support CPU) // fp32 (good for CPU) // bf16 (worse accuracy, supports CPU)
args.FLOAT_MODE = opt.FLOAT_MODE  # "fp16"

os.environ["RWKV_JIT_ON"] = '0'  # '1' or '0'. very useful for fp32, but might be harmful for GPU fp16. please benchmark!!!

CHAT_LANG = opt.CHAT_LANG  # 'Chinese'  # English // Chinese // more to come

QA_PROMPT = False  # True: Q & A prompt // False: User & Bot prompt
# For Chinese Q&A, set QA_PROMPT=True (Q&A only - better answers, but no small talk).
# For Chinese chat, set QA_PROMPT=False (allows small talk, but only the larger models chat well).

# Download RWKV-4 models from https://huggingface.co/BlinkDL

if CHAT_LANG == 'English':
    # args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-14B-20230213-8019'
    args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-7B-20220911-79'
    # args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-1B5-20220903-8040'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-14b/RWKV-4-Pile-14B-20230204-7324'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221110-ctx4096'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-Instruct-test1-20230124'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-430m/RWKV-4-Pile-430M-20220808-8066'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023'
    # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-340'
    # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/14b-run1/rwkv-6210'

elif CHAT_LANG == 'Chinese':
    args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-7B-EngChn-test4-20230116'
    # args.MODEL_NAME = '/www/model/rwkv/RWKV-4-Pile-1B5-EngChn-test4-20230115'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-EngChn-test4-20230115'
    # args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-test4-20230115'
    # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/7-run1z/rwkv-490'
    # args.MODEL_NAME = '/fsx/BlinkDL/CODE/_PUBLIC_/RWKV-LM/RWKV-v4neo/1.5-run1z/rwkv-415'

# Infer model dimensions from the checkpoint name
if '-169M-' in args.MODEL_NAME:
    args.n_layer = 12
    args.n_embd = 768
elif '-430M-' in args.MODEL_NAME:
    args.n_layer = 24
    args.n_embd = 1024
elif '-1B5-' in args.MODEL_NAME or '/1.5-' in args.MODEL_NAME:
    args.n_layer = 24
    args.n_embd = 2048
elif '-3B-' in args.MODEL_NAME or '/3-' in args.MODEL_NAME:
    args.n_layer = 32
    args.n_embd = 2560
elif '-7B-' in args.MODEL_NAME or '/7-' in args.MODEL_NAME:
    args.n_layer = 32
    args.n_embd = 4096
elif '-14B-' in args.MODEL_NAME or '/14-' in args.MODEL_NAME or '/14b-' in args.MODEL_NAME:
    args.n_layer = 40
    args.n_embd = 5120

args.ctx_len = 1024

CHAT_LEN_SHORT = 40
CHAT_LEN_LONG = 150
FREE_GEN_LEN = 200

GEN_TEMP = 1.0
GEN_TOP_P = 0.85

AVOID_REPEAT = ',。:?!'

########################################################################################################

print(f'\nLoading ChatRWKV - {CHAT_LANG} - {args.RUN_DEVICE} - {args.FLOAT_MODE} - QA_PROMPT {QA_PROMPT}')
import torch

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
from src.model_run import RWKV_RNN
from src.utils import TOKENIZER

tokenizer = TOKENIZER("20B_tokenizer.json")

args.vocab_size = 50277
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE
MODEL_NAME = args.MODEL_NAME

if CHAT_LANG == 'English':
    interface = ":"

    if QA_PROMPT:
        user = "Q"
        bot = "A"
        intro = f'The following is a verbose and detailed Q & A conversation of factual information.'
    else:
        user = "User"
        bot = "Bot"
        intro = f'The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.'

    init_prompt = f'''
{intro}

{user}{interface} french revolution what year

{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799.

{user}{interface} 3+5=?

{bot}{interface} The answer is 8.

{user}{interface} guess i marry who ?

{bot}{interface} Only if you tell me more about yourself - what are your interests?

{user}{interface} solve for a: 9-a=2

{bot}{interface} The answer is a = 7, because 9 - 7 = 2.

{user}{interface} wat is lhc

{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

'''
    HELP_MSG = '''Commands:
say something --> chat with bot. use \\n for new line.
+ --> alternate chat reply
+reset --> reset chat

+gen YOUR PROMPT --> free generation with any prompt. use \\n for new line.
+qa YOUR QUESTION --> free generation - ask any question (just ask the question). use \\n for new line.
+++ --> continue last free generation (only for +gen / +qa)
++ --> retry last free generation (only for +gen / +qa)

Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results.
This is not instruct-tuned for conversation yet, so don't expect good quality. Better to use +gen for free generation.
'''
elif CHAT_LANG == 'Chinese':
    interface = ":"
    if QA_PROMPT:
        user = "Q"
        bot = "A"
        init_prompt = f'''
Expert Questions & Helpful Answers

Ask Research Experts

'''
    else:
        user = "User"
        bot = "Bot"
        init_prompt = f'''
The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite.

{user}{interface} wat is lhc

{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

{user}{interface} 企鹅会飞吗

{bot}{interface} 企鹅是不会飞的。它们的翅膀主要用于游泳和平衡,而不是飞行。

'''
    HELP_MSG = '''指令:

直接输入内容 --> 和机器人聊天(建议问机器人问题),用\\n代表换行
+ --> 让机器人换个回答
+reset --> 重置对话

+gen 某某内容 --> 续写任何中英文内容,用\\n代表换行
+qa 某某问题 --> 问独立的问题(忽略上下文),用\\n代表换行
+qq 某某问题 --> 问独立的问题(忽略上下文),且敞开想象力,用\\n代表换行
+++ --> 继续 +gen / +qa / +qq 的回答
++ --> 换个 +gen / +qa / +qq 的回答

现在可以输入内容和机器人聊天(注意它不大懂中文,它可能更懂英文)。请经常使用 +reset 重置机器人记忆。
目前没有“重复惩罚”,所以机器人有时会重复,此时必须使用 + 换成正常回答,以免污染电脑记忆。
注意:和上下文无关的独立问题,必须用 +qa 或 +qq 问,以免污染电脑记忆。
'''

# Load Model

print(f'Loading model - {MODEL_NAME}')
model = RWKV_RNN(args)

model_tokens = []
model_state = None

AVOID_REPEAT_TOKENS = []
for i in AVOID_REPEAT:
    dd = tokenizer.encode(i)
    assert len(dd) == 1
    AVOID_REPEAT_TOKENS += dd

########################################################################################################

def run_rnn(tokens, newline_adj=0, outlier=[-2]):
    global model_tokens, model_state

    tokens = [int(x) for x in tokens]
    model_tokens += tokens
    # print(tokens, type(tokens))
    out, all_ffn_out, model_state = model.forward(tokens, model_state, outlier=outlier)
    # assss = all_ffn_out[0].to('cpu').numpy()
    # data = np.array([x.to('cpu').numpy() for x in all_ffn_out])

    # plt.plot(assss)
    # ax.plot_surface(ax, rstride=1, cstride=1, cmap='rainbow')
    # plt2ndarr(plt)
    # plt.show()
    # print(f'### model ###\n{tokens}\n[{tokenizer.decode(model_tokens)}]')

    out[0] = -999999999  # disable <|endoftext|>
    out[187] += newline_adj  # adjust \n probability
    # if newline_adj > 0:
    #     out[15] += newline_adj / 2  # '.'
    if model_tokens[-1] in AVOID_REPEAT_TOKENS:
        out[model_tokens[-1]] = -999999999  # forbid repeating the last punctuation token
    return out, all_ffn_out


all_state = {}


def save_all_stat(srv, name, last_out):
    n = f'{name}_{srv}'
    all_state[n] = {}
    all_state[n]['out'] = last_out
    all_state[n]['rnn'] = copy.deepcopy(model_state)
    all_state[n]['token'] = copy.deepcopy(model_tokens)


def load_all_stat(srv, name):
    global model_tokens, model_state
    n = f'{name}_{srv}'
    model_state = copy.deepcopy(all_state[n]['rnn'])
    model_tokens = copy.deepcopy(all_state[n]['token'])
    return all_state[n]['out']


########################################################################################################

# Run inference
print(f'\nRun prompt...')

out, all_ffn_out = run_rnn(tokenizer.encode(init_prompt))
save_all_stat('', 'chat_init', out)
gc.collect()
torch.cuda.empty_cache()

srv_list = ['dummy_server']
for s in srv_list:
    save_all_stat(s, 'chat', out)

print(f'### prompt ###\n[{tokenizer.decode(model_tokens)}]\n')


def reply_msg(msg):
    print(f'{bot}{interface} {msg}\n')


def draw_ffn(send_msg):
    fig = plt.figure(figsize=(10, 10))
    ax = plt.axes(projection='3d')
    # ax = Axes3D(fig)
    ax.set_zlim(0, 500)
    ax.set_zlabel('value')
    ax.set_xlabel('channel')
    ax.set_ylabel('layer')

    # ax.set_title(send_msg, fontsize=10)
    print('send_msg->', send_msg)
    if len(all_ffn_out[1].shape) == 1:
        for index in range(len(all_ffn_out)):
            ffn_out = all_ffn_out[index]
            x = [x for x in range(ffn_out.shape[0])]
            y = ffn_out.to('cpu').numpy()
            # ax.bar(ffn_out.to('cpu').numpy(), [x for x in range(2048)])
            # ax.plot([x for x in range(2048)], ffn_out.to('cpu').numpy(), zs=index)
            ax.plot(x, y, zs=index, zdir='y')
    # ax.text(4, 6, s=send_msg, fontsize=5, color='green')
    # Convert the Matplotlib figure to an OpenCV image
    fig.canvas.draw()
    img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.putText(img, send_msg, (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

    # Write the frame to the video
    videoWrite.write(img)

    # Clear the axes and prepare for the next frame
    ax.clear()
    plt.show()
    plt.close()

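# draw_ffn only renders the per-layer FFN activations. To keep them for offline
# analysis instead, a small hypothetical helper (not called by this script)
# that stacks them into an (n_layer, n_embd) array:
def collect_ffn_stats(ffn_outs):
    return np.stack([t.detach().float().cpu().numpy() for t in ffn_outs])
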
def on_message(message):
    global model_tokens, model_state

    srv = 'dummy_server'

    msg = message.replace('\\n', '\n').strip()
    # if len(msg) > 1000:
    #     reply_msg('your message is too long (max 1000 tokens)')
    #     return

    x_temp = GEN_TEMP
    x_top_p = GEN_TOP_P
    if "-temp=" in msg:
        x_temp = float(msg.split("-temp=")[1].split(" ")[0])
        msg = msg.replace("-temp=" + f'{x_temp:g}', "")
        # print(f"temp: {x_temp}")
    if "-top_p=" in msg:
        x_top_p = float(msg.split("-top_p=")[1].split(" ")[0])
        msg = msg.replace("-top_p=" + f'{x_top_p:g}', "")
        # print(f"top_p: {x_top_p}")
    if x_temp <= 0.2:
        x_temp = 0.2
    if x_temp >= 5:
        x_temp = 5
    if x_top_p <= 0:
        x_top_p = 0

    if msg == '+reset':
        out = load_all_stat('', 'chat_init')
        save_all_stat(srv, 'chat', out)
        reply_msg("Chat reset.")
        return

    else:
        if msg.lower() == '+':
            try:
                out = load_all_stat(srv, 'chat_pre')
            except:
                return
        else:
            out = load_all_stat(srv, 'chat')
            new = f"{user}{interface} {msg}\n\n{bot}{interface}"
            # print(f'### add ###\n[{new}]')
            out, all_ffn_out = run_rnn(tokenizer.encode(new), newline_adj=-999999999)
            save_all_stat(srv, 'chat_pre', out)

        begin = len(model_tokens)
        out_last = begin
        print(f'{bot}{interface}', end='', flush=True)
        for i in range(999):
            if i <= 0:
                newline_adj = -999999999
            elif i <= CHAT_LEN_SHORT:
                newline_adj = (i - CHAT_LEN_SHORT) / 10
            elif i <= CHAT_LEN_LONG:
                newline_adj = 0
            else:
                newline_adj = (i - CHAT_LEN_LONG) * 0.25  # MUST END THE GENERATION
            token = tokenizer.sample_logits(
                out,
                model_tokens,
                args.ctx_len,
                temperature=x_temp,
                top_p=x_top_p,
            )
            out, all_ffn_out = run_rnn([token], newline_adj=newline_adj, outlier=opt.outlier)

            xxx = tokenizer.decode(model_tokens[out_last:])
            if '\ufffd' not in xxx:  # avoid utf-8 display issues
                print(xxx, end='', flush=True)
                out_last = begin + i + 1

            send_msg = tokenizer.decode(model_tokens[begin:])

            # draw_ffn(send_msg)
            # plt2ndarr(plt)
            fig = plt.figure(figsize=(10, 10))
            ax = plt.axes(projection='3d')
            # ax = Axes3D(fig)
            ax.set_zlim(-500, 500)
            ax.set_zlabel('value')
            ax.set_xlabel('channel')
            ax.set_ylabel('layer')

            # ax.set_title(send_msg, fontsize=10)
            # print('send_msg->', send_msg)

            if len(all_ffn_out[1].shape) == 1:
                for index in range(len(all_ffn_out)):
                    ffn_out = all_ffn_out[index]
                    ffn_result = interpretability(ffn_out)
                    if opt.print_interpretability:
                        print_interpretability(ffn_result)
                    x = [x for x in range(ffn_out.shape[0])]
                    y = ffn_out.to('cpu').numpy()
                    ax.plot(x, y, zs=index, zdir='y')
            # ax.text(4, 6, s=send_msg, fontsize=5, color='green')
            # Convert the Matplotlib figure to an OpenCV image
            fig.canvas.draw()
            img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
            img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            # Overlay the decoded reply on the frame with draw_box_string (which,
            # unlike cv2.putText, can render CJK text), wrapped at send_first_lens
            # characters per line. The original wrapping loop dropped the final
            # partial line; this version keeps it.
            send_msg_y = 100
            send_first_lens = 80
            if len(send_msg) > send_first_lens:
                for line_start in range(0, len(send_msg), send_first_lens):
                    img = draw_box_string(img, 10, send_msg_y,
                                          send_msg[line_start:line_start + send_first_lens])
                    send_msg_y += 35
            else:
                img = draw_box_string(img, 10, send_msg_y, send_msg)
            if opt.show:
                cv2.imshow('outlier_runtime', img)
                cv2.waitKey(1)  # give HighGUI a chance to actually render the window
            # Write the frame to the video
            videoWrite.write(img)

            # Clear the axes and prepare for the next frame
            ax.clear()
            # plt.show()
            plt.close()
            if '\n\n' in send_msg:
                videoWrite.release()
                cv2.destroyAllWindows()
                send_msg = send_msg.strip()
                break

            # send_msg = tokenizer.decode(model_tokens[begin:]).strip()
            # if send_msg.endswith(f'{user}{interface}'):  # warning: needs to fix state too !!!
            #     send_msg = send_msg[:-len(f'{user}{interface}')].strip()
            #     break
            # if send_msg.endswith(f'{bot}{interface}'):
            #     send_msg = send_msg[:-len(f'{bot}{interface}')].strip()
            #     break

        # print(f'{model_tokens}')
        # print(f'[{tokenizer.decode(model_tokens)}]')

        # print(f'### send ###\n[{send_msg}]')
        # reply_msg(send_msg)
        save_all_stat(srv, 'chat', out)


index = 0
while True:
    index += 1
    # A new output video is started for every user turn
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    videoWrite = cv2.VideoWriter(opt.video_name + str(index) + '.mp4', fourcc, opt.video_FPS, opt.video_size)
    msg = input(f'{user}{interface} ')
    # msg = 'hello'
    # msg = '+gen Tell me what Python is and what its characteristics are. Please demonstrate your ability to write code and explain what it can do.'
    if len(msg.strip()) > 0:
        on_message(msg)
    else:
        print('Error: please say something')
--------------------------------------------------------------------------------