├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── RWKV-chat.png
├── RWKV-ctxlen.png
├── RWKV-demo.png
├── RWKV-eval.png
├── RWKV-eval2.png
├── RWKV-formula.png
├── RWKV-loss.png
├── RWKV-time-w.png
├── RWKV-v1
│   ├── src
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── trainer.py
│   │   └── utils.py
│   └── train.py
├── RWKV-v2-430M-Pile-LR.png
├── RWKV-v2-430M-Pile.png
├── RWKV-v2-RNN-run.png
├── RWKV-v2-RNN.png
├── RWKV-v2-RNN
│   ├── cuda
│   │   ├── timex_cuda.cu
│   │   └── timex_op.cpp
│   ├── enwik8-vocab.json
│   ├── run.py
│   ├── src
│   │   ├── model.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   └── train.py
├── RWKV-v3-1.5B-Pile.png
├── RWKV-v3-plan.png
├── RWKV-v3
│   ├── cuda
│   │   ├── timex_cuda.cu
│   │   └── timex_op.cpp
│   ├── run.py
│   ├── src
│   │   ├── model.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── train.py
│   └── verify.py
├── RWKV-v4-1.5B-Pile.png
├── RWKV-v4
│   ├── 20B_tokenizer.json
│   ├── cuda
│   │   ├── wkv_cuda.cu
│   │   └── wkv_op.cpp
│   ├── run.py
│   ├── src
│   │   ├── binidx.py
│   │   ├── model.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── train.py
│   └── verify.py
├── RWKV-v4neo
│   ├── 20B_tokenizer.json
│   ├── chat.py
│   ├── cuda
│   │   ├── wkv_cuda.cu
│   │   ├── wkv_cuda_bf16.cu
│   │   ├── wkv_op.cpp
│   │   └── wkv_op_bf16.cpp
│   ├── img_demoAE.py
│   ├── merge_lora.py
│   ├── run.py
│   ├── src
│   │   ├── __init__.py
│   │   ├── binidx.py
│   │   ├── dataset.py
│   │   ├── model.py
│   │   ├── model_img.py
│   │   ├── model_run.py
│   │   ├── trainer.py
│   │   └── utils.py
│   ├── train.py
│   └── verify.py
├── RWKV-vs-MHA.png
└── Research
    └── better_lr_schedule.png
/.gitignore:
--------------------------------------------------------------------------------
1 | *.txt
2 | *.csv
3 | *.pth
4 | *.xlsb
5 | *.xlsx
6 | *.xls
7 | wandb/
8 | data/
9 | vocab.json
10 | *.sh
11 | *log/
12 | test/
13 | tools/
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib/
32 | lib64/
33 | parts/
34 | sdist/
35 | var/
36 | wheels/
37 | pip-wheel-metadata/
38 | share/python-wheels/
39 | *.egg-info/
40 | .installed.cfg
41 | *.egg
42 | MANIFEST
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .nox/
58 | .coverage
59 | .coverage.*
60 | .cache
61 | nosetests.xml
62 | coverage.xml
63 | *.cover
64 | *.py,cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 | db.sqlite3-journal
77 |
78 | # Flask stuff:
79 | instance/
80 | .webassets-cache
81 |
82 | # Scrapy stuff:
83 | .scrapy
84 |
85 | # Sphinx documentation
86 | docs/_build/
87 |
88 | # PyBuilder
89 | target/
90 |
91 | # Jupyter Notebook
92 | .ipynb_checkpoints
93 |
94 | # IPython
95 | profile_default/
96 | ipython_config.py
97 |
98 | # pyenv
99 | .python-version
100 |
101 | # pipenv
102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
105 | # install all needed dependencies.
106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "PENG" 5 | given-names: "Bo" 6 | orcid: "https://orcid.org/0000-0002-0865-547X" 7 | title: "RWKV-LM" 8 | version: 1.0.0 9 | doi: 10.5281/zenodo.5196577 10 | date-released: 2021-08-13 11 | url: "https://github.com/BlinkDL/RWKV-LM" 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LoRA fork of RWKV-LM
2 |
3 | An RWKV-LM fork with [LoRA](https://arxiv.org/abs/2106.09685) finetuning support added.
4 | Currently only RWKV-v4neo is supported.
5 | The LoRA module is self-implemented to work with the TorchScript JIT.
6 | Existing RWKV-v4neo models/checkpoints should work out of the box.
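For reference, the idea behind such a self-implemented module: keep the pretrained weight frozen and learn only a low-rank update, y = xW^T + (alpha/r)·x·A^T·B^T. Below is a minimal sketch of that pattern; the class and attribute names are illustrative, not necessarily the ones used in this repo's sources:

```
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class LoraLinear(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, in_features, out_features, r, alpha, dropout):
        super().__init__()
        # Frozen full-rank weight, to be loaded from the pretrained base model
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.weight.requires_grad_(False)
        # Trainable low-rank factors; B @ A has shape (out_features, in_features)
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))  # zero init: starts as a no-op
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # y = x W^T + (alpha/r) * dropout(x) A^T B^T
        return F.linear(x, self.weight) + self.scaling * F.linear(
            F.linear(self.dropout(x), self.lora_A), self.lora_B)
```

Keeping the module this plain (no Python-level metaprogramming) is what keeps it scriptable by the TorchScript JIT, and since only the two low-rank factors are trained, the checkpoints described next stay small.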
7 | During training, only the LoRA weights are checkpointed: this gives much smaller checkpoints, but you must also specify the base model when using them.
8 | See `args.MODEL_LOAD` and `args.MODEL_LORA` in `RWKV-v4neo/chat.py`.
9 |
10 | To finetune an existing model with LoRA, run the same command as for full finetuning, with the LoRA options added, from the directory `RWKV-v4neo`:
11 |
12 | ```
13 | python3 train.py \
14 |     --load_model <pretrained base model> \
15 |     --proj_dir <output directory for checkpoints> \
16 |     --data_file <training data> \
17 |     --data_type <data type> \
18 |     --vocab_size 50277 --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 --micro_bsz 2 --accumulate_grad_batches 4 \
19 |     --n_layer 24 --n_embd 1024 --pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 \
20 |     --lora --lora_r 8 --lora_alpha 16 --lora_dropout 0.01 \
21 |     --lora_load <optional: LoRA checkpoint to continue training from> \
22 |     --lora_parts=att,ffn,time,ln # configure which parts to finetune
23 | ```
24 |
25 | Fill in the angle-bracketed placeholders with your own paths (drop the `--lora_load` line entirely if you are not resuming from a LoRA checkpoint); the two long middle lines are all your familiar training options, and the `r`, `alpha` and `dropout` values are up to your choice.
26 | `att`, `ffn`, `time` and `ln` refer to the TimeMix parameters, the ChannelMix parameters, the time decay/first/mix parameters, and the LayerNorm parameters, respectively; DON'T FORGET to list in `--lora_parts` every set of parameters you want finetuned.
27 | I'm still experimenting with different configurations; reports of what works for you are welcome!
28 |
29 | Use [json2binidx](https://github.com/Abel2076/json2binidx_tool) to convert your data into binidx, which is best suited for this trainer implementation.
30 | Once you have the pair of files `path/to/foo.bin` and `path/to/foo.idx`, pass `--data_file path/to/foo --data_type binidx` as arguments.
31 | Note that the `.bin` and `.idx` suffixes are omitted there.
32 |
33 | To use the finetuned model, use `chat.py` as usual with the checkpoints in your specified `proj_dir`, but **remember to align the LoRA-related options** with what you have specified during training!
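The repository tree also includes `RWKV-v4neo/merge_lora.py`, presumably for folding a LoRA checkpoint back into the base weights so the result can be used like an ordinary full checkpoint. The merge is the standard LoRA identity W ← W + (alpha/r)·B·A; here is a rough sketch under an assumed key layout (`<name>.lora_A` / `<name>.lora_B`), not the exact script:

```
import torch

def merge_lora(base_path, lora_path, output_path, r, alpha):
    scaling = alpha / r
    base = torch.load(base_path, map_location='cpu')
    lora = torch.load(lora_path, map_location='cpu')
    for key in lora.keys():
        if key.endswith('.lora_A'):
            prefix = key[:-len('.lora_A')]
            a = lora[key]                 # (r, in_features)
            b = lora[prefix + '.lora_B']  # (out_features, r)
            w = base[prefix + '.weight']
            # W <- W + (alpha/r) * B @ A, computed in fp32 then cast back
            base[prefix + '.weight'] = (w.float() + scaling * (b.float() @ a.float())).to(w.dtype)
        elif not key.endswith('.lora_B'):
            # parts trained directly (e.g. time/ln) simply override the base copies
            base[key] = lora[key]
    torch.save(base, output_path)
```

Merging trades the tiny separate checkpoint for a standalone model file; either way works with `chat.py`.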
34 | When loading the LoRA checkpoint directly, the options in `chat.py` should mirror the training run:
35 | ```
36 | args.MODEL_LORA = 'your_lora_checkpoint.pth'
37 | args.lora_r = 8
38 | args.lora_alpha = 16
39 | ```
40 |
--------------------------------------------------------------------------------
/RWKV-chat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-chat.png
--------------------------------------------------------------------------------
/RWKV-ctxlen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-ctxlen.png
--------------------------------------------------------------------------------
/RWKV-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-demo.png
--------------------------------------------------------------------------------
/RWKV-eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-eval.png
--------------------------------------------------------------------------------
/RWKV-eval2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-eval2.png
--------------------------------------------------------------------------------
/RWKV-formula.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-formula.png
--------------------------------------------------------------------------------
/RWKV-loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-loss.png
--------------------------------------------------------------------------------
/RWKV-time-w.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-time-w.png
--------------------------------------------------------------------------------
/RWKV-v1/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v1/src/__init__.py
--------------------------------------------------------------------------------
/RWKV-v1/src/trainer.py:
--------------------------------------------------------------------------------
1 | import math, sys, datetime
2 | import logging
3 | import numpy as np
4 | from tqdm.auto import tqdm
5 | import torch
6 | import torch.optim as optim
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from torch.utils.data.dataloader import DataLoader
9 | logger = logging.getLogger(__name__)
10 |
11 | # print('logging to wandb...
(comment it if you don\'t have wandb)') 12 | # import wandb # comment this if you don't have wandb 13 | 14 | class TrainerConfig: 15 | max_epochs = 10 16 | batch_size = 64 17 | learning_rate = 4e-4 18 | betas = (0.9, 0.99) 19 | eps = 1e-8 20 | grad_norm_clip = 1.0 21 | weight_decay = 0.01 22 | lr_decay = False # linear warmup followed by cosine decay 23 | warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper 24 | final_tokens = 260e9 # at which point do we reach lr_final 25 | epoch_save_frequency = 0 26 | epoch_save_path = 'trained-' 27 | num_workers = 0 # for DataLoader 28 | 29 | def __init__(self, **kwargs): 30 | for k,v in kwargs.items(): 31 | setattr(self, k, v) 32 | 33 | class Trainer: 34 | 35 | def __init__(self, model, train_dataset, test_dataset, config): 36 | self.model = model 37 | self.train_dataset = train_dataset 38 | self.test_dataset = test_dataset 39 | self.config = config 40 | self.avg_loss = -1 41 | self.steps = 0 42 | 43 | if 'wandb' in sys.modules: 44 | cfg = model.config 45 | for k in config.__dict__: 46 | setattr(cfg, k, config.__dict__[k]) # combine cfg 47 | wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False) 48 | 49 | self.device = 'cpu' 50 | if torch.cuda.is_available(): # take over whatever gpus are on the system 51 | self.device = torch.cuda.current_device() 52 | self.model = torch.nn.DataParallel(self.model).to(self.device) 53 | 54 | def get_run_name(self): 55 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 56 | cfg = raw_model.config 57 | run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd) 58 | return run_name 59 | 60 | def train(self): 61 | model, config = self.model, self.config 62 | raw_model = model.module if hasattr(self.model, "module") else model 63 | optimizer = raw_model.configure_optimizers(config) 64 | 65 | def run_epoch(split): 66 | is_train = split == 'train' 67 | model.train(is_train) 68 | data = self.train_dataset if is_train else self.test_dataset 69 | loader = DataLoader(data, shuffle=True, pin_memory=True, 70 | batch_size=config.batch_size, 71 | num_workers=config.num_workers) 72 | 73 | pbar = tqdm(enumerate(loader), total=len(loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) 74 | 75 | for it, (x, y) in pbar: 76 | x = x.to(self.device) # place data on the correct device 77 | y = y.to(self.device) 78 | 79 | with torch.set_grad_enabled(is_train): 80 | _, loss = model(x, y) # forward the model 81 | loss = loss.mean() # collapse all losses if they are scattered on multiple gpus 82 | 83 | if is_train: # backprop and update the parameters 84 | model.zero_grad() 85 | loss.backward() 86 | 87 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 88 | optimizer.step() 89 | 90 | if config.lr_decay: # decay the learning rate based on our progress 91 | self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. 
label is not -100) 92 | lr_final_factor = config.lr_final / config.learning_rate 93 | if self.tokens < config.warmup_tokens: 94 | # linear warmup 95 | lr_mult = lr_final_factor + (1 - lr_final_factor) * float(self.tokens) / float(config.warmup_tokens) 96 | progress = 0 97 | else: 98 | # cosine learning rate decay 99 | progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) 100 | # progress = min(progress * 1.1, 1.0) # more fine-tuning with low LR 101 | lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1 102 | lr = config.learning_rate * lr_mult 103 | for param_group in optimizer.param_groups: 104 | param_group['lr'] = lr 105 | else: 106 | lr = config.learning_rate 107 | 108 | now_loss = loss.item() # report progress 109 | 110 | if 'wandb' in sys.modules: 111 | wandb.log({"loss": now_loss}, step = self.steps * self.config.batch_size) 112 | self.steps += 1 113 | 114 | if self.avg_loss < 0: 115 | self.avg_loss = now_loss 116 | else: 117 | # factor = max(1.0 / 300, 1.0 / math.sqrt(it + 1)) 118 | factor = 1 / (it + 1) 119 | self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor 120 | pbar.set_description(f"epoch {epoch+1} progress {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}") 121 | 122 | while True: 123 | self.tokens = 0 # counter used for learning rate decay 124 | for epoch in range(config.max_epochs): 125 | 126 | run_epoch('train') 127 | 128 | if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1): 129 | raw_model = self.model.module if hasattr(self.model, "module") else self.model # DataParallel wrappers keep raw model object in .module 130 | torch.save(raw_model, self.config.epoch_save_path + str(epoch+1) + '.pth') 131 | -------------------------------------------------------------------------------- /RWKV-v1/src/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | def top_k_logits(logits, k): 8 | v, ix = torch.topk(logits, k) 9 | out = logits.clone() 10 | out[out < v[:, [-1]]] = -float('Inf') 11 | return out 12 | 13 | def top_p_probs(probs, p): 14 | out = probs.clone() 15 | 16 | sorted_probs, sorted_indices = torch.sort(out, descending=True) 17 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1) 18 | sorted_indices_to_remove = cumulative_probs > p 19 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 20 | sorted_indices_to_remove[..., 0] = 0 21 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 22 | out[indices_to_remove] = 0 23 | 24 | return out 25 | 26 | # top-p + top-k + pow&ratio sampling 27 | def sample_logits(logits, pos, temperature=1.0, top_k=None, top_p=None, min_p_pow=None, min_p_ratio=None): 28 | logits = logits[:, pos, :] / temperature 29 | probs = F.softmax(logits, dim=-1) 30 | 31 | if min_p_ratio is not None: 32 | limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio 33 | logits[probs < limit] = -float('Inf') 34 | 35 | if top_k is not None: 36 | logits = top_k_logits(logits, top_k) 37 | 38 | probs = F.softmax(logits, dim=-1) 39 | 40 | if top_p is not None: 41 | probs[0] = top_p_probs(probs[0], top_p) 42 | 43 | ix = torch.multinomial(probs, num_samples=1) 44 | return ix[0][0].cpu() 45 | 46 | def 
set_seed(seed):
47 | random.seed(seed)
48 | np.random.seed(seed)
49 | torch.manual_seed(seed)
50 | torch.cuda.manual_seed_all(seed)
51 |
--------------------------------------------------------------------------------
/RWKV-v1/train.py:
--------------------------------------------------------------------------------
1 | ########################################################################################################
2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3 | ########################################################################################################
4 |
5 | import os, sys, time, math, random, json, datetime, logging
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import Dataset
9 | from src.trainer import Trainer, TrainerConfig
10 | from src.model import GPT, GPTConfig
11 | from src.utils import set_seed
12 |
13 | set_seed(42)
14 | np.set_printoptions(precision=4, suppress=True, linewidth=200)
15 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,)
16 |
17 | # RWKV : our new model - fastest when ctx_len is long - VRAM friendly - good performance
18 | # MHA_rotary : usual MultiheadAttention+Rotary+GeGLU - not as good
19 | # MHA_shift : with time-shift - good performance
20 | # MHA_pro : slow (lots of tricks) - VRAM hungry - very good performance
21 | model_type = 'RWKV'
22 |
23 | # datafile = u"V:\\NLP\\text8"
24 | # datafile = u"V:\\NLP\\enwik8"
25 | datafile = u"V:\\NLP\\simplebooks\\simplebooks-92-raw\\train.txt"
26 | datafile_encoding = 'utf-8'
27 | # datafile = u"D:\\NLP-Data\\ww100M.txt"
28 | # datafile = u"D:\\NLP-Data\\__2019.txt"
29 | # datafile = u"Y:\\BlinkNLP\\_txt_\\txt\\_all.txt"
30 | # datafile = u"V:\\NLP\\enwik8-shift-300.bpe"
31 | # datafile_encoding = 'utf-16'
32 | # datafile = u"V:\\NLP\\simplebooks-shift-utf32.word"
33 | # datafile_encoding = 'utf-32'
34 |
35 | datafile_type = 0 # use 0 for char-level English. use 1 for Chinese. only affects some RWKV hyperparameters
36 |
37 | #################################### VERY IMPORTANT ####################################
38 | epoch_save_frequency = 10 # 0 = never, 1 = every 'epoch', 2 = every two 'epoch', etc.
39 | epoch_save_path = 'trained-'
40 |
41 | batch_size = 32 # if you see "CUDA out of memory", reduce this.
42 | # if you have a good GPU, increase this.
43 | # use GPU-Z to find the highest value for your VRAM.
44 |
45 | n_epoch = 100 # the 'epoch' here is actually very short (and of fixed length)
46 | ########################################################################################
47 |
48 | model_level = 'character' # 'character' (recommended) or 'word'
49 |
50 | ctx_len = 256 # context length, try 512 or 1024 if you have a good GPU
51 | n_layer = 6 # try 12 for 100M, 24 for 300M
52 | n_head = 8 # try 12 for 100M, 16 for 300M
53 |
54 | n_embd = n_head * 64
55 | n_attn = n_embd
56 | n_ffn = n_embd
57 |
58 | lr_init = 6e-4 if model_type == 'RWKV' else 4e-4 # RWKV can use a higher lr. 6e-4 = 0.0006, 4e-4 = 0.0004
59 | lr_final = 4e-5
60 |
61 | betas = (0.9, 0.99) if model_type == 'RWKV' else (0.9, 0.99)
62 | eps = 4e-9
63 | weight_decay = 0 if model_type == 'RWKV' else 0.01 # wd is not useful when we have enough data
64 |
65 | epoch_length_fixed = 10000 # make an 'epoch' very short, so we can see the training progress
66 |
67 | ######## special hyperparameters for RWKV model ########
68 | rwkv_emb_scale = 0.4 # scale of initial embedding.
0.4 is a good choice 69 | rwkv_tiny_attn = 0#64 if (datafile_type == 0 and ctx_len > 600) else 0 # extra tiny attention dim, useful for long ctx char-level english 70 | rwkv_tiny_head = 1 # 1 is good enough. 8 is slow 71 | # n_side_proj = 512 # extra 'side projection', quite useful for BPE models 72 | 73 | ######################################################################################################## 74 | # Load data 75 | ######################################################################################################## 76 | 77 | print('loading data... ' + datafile) 78 | 79 | class Dataset(Dataset): 80 | def __init__(self, data, model_level, ctx_len): 81 | print('building token list...', end=' ') 82 | if model_level == 'word': 83 | import re 84 | data = re.sub(r'(\n|\.|\,|\?|\!|\:|\;|\-|\—|\||\'|\"|\`|\(|\)|[0-9]|\[|\]|\{|\}|\=|\+|\*|\\|\/|\~|\&|\$|\#|\%)', r' \g<0> ', data) 85 | data = re.sub(' +',' ',data) 86 | print('splitting token...') 87 | data = data.lower().split(' ') 88 | unique = sorted(list(set(data))) 89 | # print() 90 | # for u in unique: 91 | # print(u, end=' ') 92 | # print('\n\n') 93 | 94 | xx = 0 95 | xxObj = {} 96 | for u in unique: 97 | xxObj[xx] = u 98 | xx += 1 99 | with open('vocab.json', "w", encoding="utf-16") as vocab_file: 100 | vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) 101 | 102 | data_size, vocab_size = len(data), len(unique) 103 | print('data has %d %ss, %d unique.' % (data_size, model_level, vocab_size)) 104 | self.stoi = { ch:i for i,ch in enumerate(unique) } 105 | self.itos = { i:ch for i,ch in enumerate(unique) } 106 | self.ctx_len = ctx_len 107 | self.vocab_size = vocab_size 108 | self.data = data 109 | 110 | def __len__(self): 111 | return epoch_length_fixed 112 | 113 | def __getitem__(self, idx): 114 | i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) # cheat: pick a random spot in dataset 115 | chunk = self.data[i:i+self.ctx_len+1] 116 | dix = [self.stoi[s] for s in chunk] 117 | x = torch.tensor(dix[:-1], dtype=torch.long) 118 | y = torch.tensor(dix[1:], dtype=torch.long) 119 | return x, y 120 | 121 | train_dataset = Dataset(open(datafile, "r", encoding=datafile_encoding).read(), model_level, ctx_len) 122 | 123 | ######################################################################################################## 124 | # Train model 125 | ######################################################################################################## 126 | 127 | model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type, 128 | rwkv_emb_scale=rwkv_emb_scale, rwkv_tiny_attn=rwkv_tiny_attn, rwkv_tiny_head=rwkv_tiny_head, 129 | n_layer=n_layer, n_head=n_head, n_embd=n_embd, n_attn=n_attn, n_ffn=n_ffn)) 130 | 131 | # load a trained model 132 | # model.load_state_dict(torch.load('trained-xxx.pth').state_dict()) 133 | 134 | print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', betas, 'eps', eps, 'wd', weight_decay, 'ctx', ctx_len, 'layer', n_layer, 'head', n_head, 'embd', n_embd, 'attn', n_attn, 'ffn', n_ffn) 135 | tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, weight_decay=weight_decay, 136 | learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, 137 | warmup_tokens=0, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=0, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path) 138 | trainer = Trainer(model, train_dataset, None, tconf) 139 | 140 | trainer.train() 141 | 142 | 
torch.save(model, 'trained-' + trainer.get_run_name() + '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth')
143 |
--------------------------------------------------------------------------------
/RWKV-v2-430M-Pile-LR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v2-430M-Pile-LR.png
--------------------------------------------------------------------------------
/RWKV-v2-430M-Pile.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v2-430M-Pile.png
--------------------------------------------------------------------------------
/RWKV-v2-RNN-run.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v2-RNN-run.png
--------------------------------------------------------------------------------
/RWKV-v2-RNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v2-RNN.png
--------------------------------------------------------------------------------
/RWKV-v2-RNN/cuda/timex_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <stdio.h>
2 |
3 | // require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax and BF and BB are passed by compiler)
4 |
5 | #define F4(A, B) ((float4 *)(A))[(B) >> 2]
6 |
7 | template <typename F>
8 | __global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,
9 |                                const F eps, const int B, const int C, const int T) {
10 |     const int i = blockIdx.y;
11 |     const int ij = (B * C) / BF;
12 |     const int t = threadIdx.x << 2;
13 |
14 |     __shared__ F ww[Tmax];
15 |     __shared__ F kk[Tmax * BF];
16 |     F4(ww, t) = F4(__w, t + T * (i % C));
17 |
18 | #pragma unroll
19 |     for (int j = 0; j < BF; j++) {
20 |         F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
21 |     }
22 |     __syncthreads();
23 |
24 |     float4 s[BF];
25 | #pragma unroll
26 |     for (int j = 0; j < BF; j++) {
27 |         s[j] = {eps, eps, eps, eps};
28 |     }
29 |     const F *__restrict__ const w = ww + T - t - 4;
30 |     for (int u = 0; u <= t; u++) {
31 | #pragma unroll
32 |         for (int j = 0; j < BF; j++) {
33 |             const F x = kk[u + Tmax * j];
34 |             s[j].x += w[u + 3] * x;
35 |             s[j].y += w[u + 2] * x;
36 |             s[j].z += w[u + 1] * x;
37 |             s[j].w += w[u + 0] * x;
38 |         }
39 |     }
40 | #pragma unroll
41 |     for (int j = 0; j < BF; j++) {
42 |         const F *__restrict__ const k = kk + Tmax * j;
43 |         s[j].y += w[t + 3] * k[t + 1];
44 |         s[j].z += w[t + 2] * k[t + 1];
45 |         s[j].z += w[t + 3] * k[t + 2];
46 |         s[j].w += w[t + 1] * k[t + 1];
47 |         s[j].w += w[t + 2] * k[t + 2];
48 |         s[j].w += w[t + 3] * k[t + 3];
49 |         F4(x, t + T * (i + ij * j)) = s[j];
50 |     }
51 | }
52 |
53 | template <typename F>
54 | __global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
55 |                                   F *__restrict__ const gw, F *__restrict__ const gk,
56 |                                   const int B, const int C, const int T) {
57 |     const int i = blockIdx.y;
58 |     const int t = threadIdx.x << 2;
59 |
60 |     __shared__ F k[Tmax];
61 |     __shared__ F gg[Tmax];
62 |     F4(k, t) = F4(__k, t + T * i);
63 |     F4(gg, t) = F4(__gwk, t + T * i);
64 |     __syncthreads();
65 |
66 |     float4 s = {0, 0, 0, 0};
67 |
68 |     const F *__restrict__ const g = gg + T - t - 4;
69 |     for (int u = 0; u <= t; u++) {
70 |         F x = k[u];
71 |         s.x += g[u + 3] * x;
72 |         s.y += g[u + 2] * x;
73 |         s.z += g[u + 1] * x;
74 |         s.w += g[u + 0] * x;
75 |     }
76 |     s.y += g[t + 3] * k[t + 1];
77 |     s.z += g[t + 2] * k[t + 1];
78 |     s.z += g[t + 3] * k[t + 2];
79 |     s.w += g[t + 1] * k[t + 1];
80 |     s.w += g[t + 2] * k[t + 2];
81 |     s.w += g[t + 3] * k[t + 3];
82 |     F4(gw, t + T * i) = s;
83 | }
84 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) {
85 |     dim3 gridDim(1, B * C / BF);
86 |     dim3 blockDim(T >> 2);
87 |     kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
88 | }
89 |
90 | template <typename F>
91 | __global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk,
92 |                                 F *__restrict__ const gw, F *__restrict__ const gk,
93 |                                 const int B, const int C, const int T) {
94 |     const int i = blockIdx.y;
95 |     const int ij = (B * C) / BB;
96 |     const int t = threadIdx.x << 2;
97 |
98 |     __shared__ F w[Tmax];
99 |     __shared__ F kk[Tmax * BB];
100 |     __shared__ F gg[Tmax * BB];
101 |     F4(w, t) = F4(__w, t + T * (i % C));
102 |
103 | #pragma unroll
104 |     for (int j = 0; j < BB; j++) {
105 |         F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));
106 |         F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j));
107 |     }
108 |     __syncthreads();
109 |
110 |     float4 s[BB];
111 | #pragma unroll
112 |     for (int j = 0; j < BB; j++) {
113 |         s[j] = {0, 0, 0, 0};
114 |     }
115 |
116 |     for (int u = 0; u <= t; u++) {
117 | #pragma unroll
118 |         for (int j = 0; j < BB; j++) {
119 |             const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
120 |             F x = kk[u + Tmax * j];
121 |             s[j].x += g[u + 3] * x;
122 |             s[j].y += g[u + 2] * x;
123 |             s[j].z += g[u + 1] * x;
124 |             s[j].w += g[u + 0] * x;
125 |         }
126 |     }
127 | #pragma unroll
128 |     for (int j = 0; j < BB; j++) {
129 |         const F *__restrict__ const k = kk + Tmax * j;
130 |         const F *__restrict__ const g = gg + Tmax * j + T - t - 4;
131 |         s[j].y += g[t + 3] * k[t + 1];
132 |         s[j].z += g[t + 2] * k[t + 1];
133 |         s[j].z += g[t + 3] * k[t + 2];
134 |         s[j].w += g[t + 1] * k[t + 1];
135 |         s[j].w += g[t + 2] * k[t + 2];
136 |         s[j].w += g[t + 3] * k[t + 3];
137 |         F4(gw, t + T * (i + ij * j)) = s[j];
138 |     }
139 |
140 | #pragma unroll
141 |     for (int j = 0; j < BB; j++) {
142 |         s[j] = {0, 0, 0, 0};
143 |     }
144 |
145 |     for (int u = t + 3; u < T; u++) {
146 |         F x = w[u];
147 | #pragma unroll
148 |         for (int j = 0; j < BB; j++) {
149 |             const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
150 |             s[j].x += g[2 - u] * x;
151 |             s[j].y += g[3 - u] * x;
152 |             s[j].z += g[4 - u] * x;
153 |             s[j].w += g[5 - u] * x;
154 |         }
155 |     }
156 | #pragma unroll
157 |     for (int j = 0; j < BB; j++) {
158 |         const F *__restrict__ const g = gg + Tmax * j + T + t - 3;
159 |         s[j].x += g[2 - t] * w[t + 0];
160 |         s[j].x += g[1 - t] * w[t + 1];
161 |         s[j].x += g[0 - t] * w[t + 2];
162 |         s[j].y += g[2 - t] * w[t + 1];
163 |         s[j].y += g[1 - t] * w[t + 2];
164 |         s[j].z += g[2 - t] * w[t + 2];
165 |         F4(gk, t + T * (i + ij * j)) = s[j];
166 |     }
167 | }
168 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) {
169 |     dim3 gridDim(1, B * C / BB);
170 |     dim3 blockDim(T >> 2);
171 |     kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T);
172 | }
173 |
--------------------------------------------------------------------------------
/RWKV-v2-RNN/cuda/timex_op.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T);
4 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T);
5 |
6 | void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) {
7 |     cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T);
8 | }
9 | void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) {
10 |     cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T);
11 | }
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 |     m.def("forward", &forward, "timex forward");
15 |     m.def("backward", &backward, "timex backward");
16 | }
17 |
18 | TORCH_LIBRARY(timex, m) {
19 |     m.def("forward", forward);
20 |     m.def("backward", backward);
21 | }
22 |
--------------------------------------------------------------------------------
/RWKV-v2-RNN/enwik8-vocab.json:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v2-RNN/enwik8-vocab.json
--------------------------------------------------------------------------------
/RWKV-v2-RNN/run.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 | ########################################################################################################
3 | # The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM
4 | ########################################################################################################
5 |
6 | import numpy as np
7 | import math
8 | import time
9 | import types
10 | import copy
11 | import torch
12 | from torch.nn import functional as F
13 | from src.utils import TOKENIZER, Dataset
14 | from src.model_run import RWKV_RNN
15 | torch.backends.cudnn.benchmark = True
16 | torch.backends.cudnn.allow_tf32 = True
17 | torch.backends.cuda.matmul.allow_tf32 = True
18 | np.set_printoptions(precision=4, suppress=True, linewidth=200)
19 |
20 | ### Step 1: set model ##################################################################################
21 |
22 | ctx_len = 1024
23 | n_layer = 6
24 | n_embd = 512
25 | model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre'
26 |
27 | # your trained model
28 | MODEL_NAME = 'trained-31'
29 | WORD_NAME = 'vocab' # the .json vocab (generated by train.py)
30 |
31 | # ########## Uncomment these to test my 27M params enwik8 model ##########
32 | # MODEL_NAME = 'enwik8-ppl1.65-6064-1024-RWKV-6-512-2022-03-25-21-05-13'
33 | # WORD_NAME = 'enwik8-vocab'
34 | # EVAL_DATA = 'enwik8' # uncomment this for EVAL MODE (no text generation)
35 | # ########################################################################
36 |
37 | # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <--
38 | # --> all unknown tokens in your context will be denoted by it <--
39 | UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity
40 |
41 | RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda'
42 | DEBUG_DEBUG = False # True False - show softmax output
43 |
44 | ### Step 2: set context
################################################################################ 45 | 46 | context = "\nIn the" # ==> this is your prompt 47 | 48 | NUM_TRIALS = 999 49 | LENGTH_PER_TRIAL = 500 50 | 51 | TEMPERATURE = 1.0 52 | top_p = 0.7 53 | top_p_newline = 0.9 54 | 55 | ######################################################################################################## 56 | 57 | print(f'Loading {MODEL_NAME}...') 58 | model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) 59 | tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) 60 | 61 | ######################################################################################################## 62 | 63 | if 'EVAL_DATA' in vars() or 'EVAL_DATA' in globals(): 64 | print('Evaluating on ' + EVAL_DATA + ' ...') 65 | 66 | data = open(EVAL_DATA, "r", encoding='utf-8').read() 67 | 68 | loss_table = np.zeros(ctx_len) 69 | 70 | N_SAMPLE = 1000 71 | 72 | for iii in range(N_SAMPLE): 73 | pos = np.random.randint(0, len(data) - ctx_len-1) 74 | context = data[pos:pos+ctx_len+1] 75 | ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] 76 | 77 | model.clear() 78 | for i in range(1, ctx_len+1): 79 | x = ctx[:i] 80 | out = model.run(x) 81 | prob = F.softmax(torch.tensor(out), dim=-1) 82 | loss_table[i-1] += -math.log(prob[ctx[i]]) 83 | 84 | print(f'Tested {iii+1} samples: avg_loss over ctx_len =', 85 | np.mean(loss_table) / (iii+1)) 86 | 87 | exit(0) 88 | 89 | ######################################################################################################## 90 | 91 | context = tokenizer.refine_context(context) 92 | print('\nYour prompt has ' + str(len(context)) + ' tokens.') 93 | print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. This will be much faster in future versions. 
<--\n') 94 | 95 | for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): 96 | t_begin = time.time_ns() 97 | 98 | src_len = len(context) 99 | ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] 100 | print(('-' * 30) + context, end='') 101 | 102 | model.clear() 103 | if TRIAL == 0: 104 | init_state = types.SimpleNamespace() 105 | for i in range(src_len): 106 | x = ctx[:i+1] 107 | if i == src_len - 1: 108 | init_state.out = model.run(x) 109 | else: 110 | model.run(x) 111 | model.save(init_state) 112 | else: 113 | model.load(init_state) 114 | 115 | for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): 116 | x = ctx[:i+1] 117 | x = x[-ctx_len:] 118 | 119 | if i == src_len: 120 | out = copy.deepcopy(init_state.out) 121 | else: 122 | out = model.run(x) 123 | if DEBUG_DEBUG: 124 | print('model', np.array(x), '==>', np.array( 125 | out), np.max(out), np.min(out)) 126 | 127 | char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, 128 | top_p_usual=top_p, top_p_newline=top_p_newline) 129 | char = char.item() 130 | print(tokenizer.itos[int(char)], end='', flush=True) 131 | ctx += [char] 132 | t_end = time.time_ns() 133 | print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ') 134 | -------------------------------------------------------------------------------- /RWKV-v2-RNN/src/model_run.py: -------------------------------------------------------------------------------- 1 | import types 2 | import copy 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | RWKV_K_CLAMP = 60 7 | RWKV_K_EPS = 1e-16 8 | RWKV_HEAD_QK_DIM = 256 9 | 10 | DEBUG_TIME = False # True False - show trained time-coeffs 11 | 12 | 13 | class RWKV_RNN(): 14 | def __init__(self, MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len): 15 | self.RUN_DEVICE = RUN_DEVICE 16 | self.model_type = model_type 17 | self.n_layer = n_layer 18 | self.n_embd = n_embd 19 | self.ctx_len = ctx_len 20 | 21 | self.w = types.SimpleNamespace() 22 | 23 | w = torch.load(MODEL_NAME + '.pth', 24 | map_location=torch.device(RUN_DEVICE)) 25 | for x in w.keys(): 26 | if '.time_' in x: 27 | w[x] = w[x].squeeze() 28 | if '.time_decay' in x: 29 | w[x] = torch.exp(-torch.exp(w[x])) 30 | if '.time_first' in x: 31 | w[x] = torch.exp(w[x]) 32 | if DEBUG_TIME and '.time_' in x: 33 | print(x, w[x].squeeze().cpu().numpy()) 34 | 35 | xx = x.split('.') 36 | here = self.w 37 | for i in range(len(xx)): 38 | if xx[i].isdigit(): 39 | ii = int(xx[i]) 40 | if ii not in here: 41 | here[ii] = types.SimpleNamespace() 42 | here = here[ii] 43 | else: 44 | if i == len(xx) - 1: 45 | setattr(here, xx[i], w[x]) 46 | elif not hasattr(here, xx[i]): 47 | if xx[i+1].isdigit(): 48 | setattr(here, xx[i], {}) 49 | else: 50 | setattr(here, xx[i], types.SimpleNamespace()) 51 | here = getattr(here, xx[i]) 52 | 53 | self.clear() 54 | 55 | def clear(self): 56 | self.xx = {} 57 | self.aa = {} 58 | self.bb = {} 59 | self.hk = None 60 | 61 | def save(self, target): 62 | target.xx = copy.deepcopy(self.xx) 63 | target.aa = copy.deepcopy(self.aa) 64 | target.bb = copy.deepcopy(self.bb) 65 | target.hk = copy.deepcopy(self.hk) 66 | 67 | def load(self, target): 68 | self.xx = copy.deepcopy(target.xx) 69 | self.aa = copy.deepcopy(target.aa) 70 | self.bb = copy.deepcopy(target.bb) 71 | self.hk = copy.deepcopy(target.hk) 72 | 73 | def LN(self, xx, w): 74 | return F.layer_norm(xx, (self.n_embd,), weight=w.weight, bias=w.bias) 75 | 76 | def FF(self, xx, w, name): 77 | if name not in self.xx: 78 | self.xx[name] = 
torch.zeros(self.n_embd, device=self.RUN_DEVICE) 79 | x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix) 80 | self.xx[name] = xx 81 | 82 | r = torch.sigmoid(w.receptance.weight @ x) 83 | k = torch.square(torch.relu(w.key.weight @ x)) 84 | kv = w.value.weight @ k 85 | 86 | return r * kv 87 | 88 | def SA(self, xx, w, name): 89 | if name not in self.xx: 90 | self.xx[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) 91 | self.aa[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) 92 | self.bb[name] = torch.zeros(self.n_embd, device=self.RUN_DEVICE) 93 | x = xx * w.time_mix + self.xx[name] * (1 - w.time_mix) 94 | self.xx[name] = xx 95 | 96 | r = torch.sigmoid(w.receptance.weight @ x) 97 | 98 | k = torch.exp(torch.clamp(w.key.weight @ x, max=RWKV_K_CLAMP)) 99 | v = w.value.weight @ x 100 | kv = k * v 101 | 102 | a = self.aa[name] + w.time_first * kv 103 | b = self.bb[name] + w.time_first * k 104 | self.aa[name] = w.time_decay * self.aa[name] + kv 105 | self.bb[name] = w.time_decay * self.bb[name] + k 106 | 107 | rwkv = r * a / (b + RWKV_K_EPS) 108 | 109 | return w.output.weight @ rwkv 110 | 111 | def run(self, ctx): 112 | w = self.w 113 | x = w.emb.weight[ctx[-1]] 114 | 115 | for i in range(self.n_layer): 116 | x = self.LN(x, w.blocks[i].ln1) 117 | if i == 0 and self.model_type == 'RWKV-ffnPre': 118 | x = x + self.FF(x, w.blocks[i].ffnPre, f'ffnPre.{i}') 119 | else: 120 | x = x + self.SA(x, w.blocks[i].att, f'att.{i}') 121 | x = self.LN(x, w.blocks[i].ln2) 122 | x = x + self.FF(x, w.blocks[i].ffn, f'ffn.{i}') 123 | 124 | x = self.LN(x, w.ln_out) 125 | 126 | if self.hk == None: 127 | self.hk = (w.head_k.weight @ x).unsqueeze(0) 128 | else: 129 | self.hk = torch.cat( 130 | [self.hk, (w.head_k.weight @ x).unsqueeze(0)], dim=0) 131 | if self.hk.shape[0] > self.ctx_len: 132 | self.hk = self.hk[-self.ctx_len:, :] 133 | 134 | q = w.head_q.weight @ x 135 | 136 | x = w.head.weight @ x 137 | x = x.cpu().numpy().tolist() 138 | 139 | c = (self.hk @ q) / RWKV_HEAD_QK_DIM 140 | for i in range(len(c)): 141 | x[ctx[i]] += c[i] 142 | 143 | return x 144 | -------------------------------------------------------------------------------- /RWKV-v2-RNN/src/trainer.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | from torch.utils.data.dataloader import DataLoader 6 | from torch.optim.lr_scheduler import LambdaLR 7 | from torch.nn import functional as F 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch 11 | from tqdm.auto import tqdm 12 | import numpy as np 13 | import logging 14 | import os 15 | import datetime 16 | import sys 17 | import math 18 | 19 | # import wandb # comment this if you don't have wandb 20 | # print('logging to wandb... 
(comment it if you don\'t have wandb)') 21 | 22 | logger = logging.getLogger(__name__) 23 | torch.backends.cudnn.benchmark = True 24 | torch.backends.cudnn.allow_tf32 = True 25 | torch.backends.cuda.matmul.allow_tf32 = True 26 | 27 | log_file = open("mylog.txt", "a") 28 | 29 | 30 | class TrainerConfig: 31 | max_epochs = 10 32 | batch_size = 64 33 | learning_rate = 4e-4 34 | betas = (0.9, 0.99) 35 | eps = 1e-8 36 | grad_norm_clip = 1.0 37 | lr_decay = True # linear warmup followed by cosine decay 38 | warmup_tokens = 0 39 | final_tokens = 0 40 | epoch_save_frequency = 0 41 | epoch_save_path = 'trained-' 42 | num_workers = 0 # for DataLoader 43 | 44 | def __init__(self, **kwargs): 45 | for k, v in kwargs.items(): 46 | setattr(self, k, v) 47 | 48 | 49 | class Trainer: 50 | 51 | def __init__(self, model, train_dataset, test_dataset, config): 52 | self.model = model 53 | self.train_dataset = train_dataset 54 | self.test_dataset = test_dataset 55 | self.config = config 56 | self.avg_loss = -1 57 | self.steps = 0 58 | 59 | if 'wandb' in sys.modules: 60 | cfg = model.config 61 | for k in config.__dict__: 62 | setattr(cfg, k, config.__dict__[k]) # combine cfg 63 | wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + 64 | datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False) 65 | 66 | self.device = 'cpu' 67 | if torch.cuda.is_available(): # take over whatever gpus are on the system 68 | self.device = torch.cuda.current_device() 69 | 70 | def get_run_name(self): 71 | raw_model = self.model.module if hasattr( 72 | self.model, "module") else self.model 73 | cfg = raw_model.config 74 | run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \ 75 | cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd) 76 | return run_name 77 | 78 | def train(self): 79 | model, config = self.model, self.config 80 | raw_model = model.module if hasattr(self.model, "module") else model 81 | optimizer = raw_model.configure_optimizers(config) 82 | 83 | def run_epoch(split): 84 | is_train = split == 'train' 85 | model.train(is_train) 86 | data = self.train_dataset if is_train else self.test_dataset 87 | 88 | if config.num_workers > 0: 89 | loader = DataLoader(data, shuffle=False, pin_memory=True, 90 | batch_size=config.batch_size, 91 | num_workers=config.num_workers) 92 | else: 93 | loader = DataLoader(data, shuffle=False, 94 | batch_size=config.batch_size, 95 | num_workers=config.num_workers) 96 | 97 | pbar = tqdm(enumerate(loader), total=len( 98 | loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) 99 | 100 | for it, (x, y) in pbar: 101 | x = x.to(self.device) # place data on the correct device 102 | y = y.to(self.device) 103 | 104 | with torch.set_grad_enabled(is_train): 105 | _, loss = model(x, y) # forward the model 106 | 107 | if is_train: # backprop and update the parameters 108 | model.zero_grad() 109 | loss.backward() 110 | 111 | if config.grad_norm_clip > 0: 112 | torch.nn.utils.clip_grad_norm_( 113 | model.parameters(), config.grad_norm_clip) 114 | 115 | optimizer.step() 116 | 117 | if config.lr_decay: # decay the learning rate based on our progress 118 | # number of tokens processed this step (i.e. 
label is not -100) 119 | self.tokens += (y >= 0).sum() 120 | lr_final_factor = config.lr_final / config.learning_rate 121 | if self.tokens < config.warmup_tokens: 122 | # linear warmup 123 | lr_mult = lr_final_factor + \ 124 | (1 - lr_final_factor) * float(self.tokens) / \ 125 | float(config.warmup_tokens) 126 | progress = 0 127 | else: 128 | # cosine learning rate decay 129 | progress = float(self.tokens - config.warmup_tokens) / float( 130 | max(1, config.final_tokens - config.warmup_tokens)) 131 | lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 132 | 2) * math.cos(math.pi * progress) # better 1.0 ~ 0.1 133 | lr = config.learning_rate * lr_mult 134 | for param_group in optimizer.param_groups: 135 | param_group['lr'] = lr 136 | else: 137 | lr = config.learning_rate 138 | 139 | now_loss = loss.item() # report progress 140 | self.lr = lr 141 | 142 | if 'wandb' in sys.modules: 143 | wandb.log({"loss": now_loss}, 144 | step=self.steps * self.config.batch_size) 145 | self.steps += 1 146 | 147 | if self.avg_loss < 0: 148 | self.avg_loss = now_loss 149 | else: 150 | factor = 1 / (it + 1) 151 | self.avg_loss = self.avg_loss * \ 152 | (1.0 - factor) + now_loss * factor 153 | pbar.set_description( 154 | f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}") 155 | 156 | self.tokens = 0 # counter used for learning rate decay 157 | for epoch in range(config.max_epochs): 158 | 159 | run_epoch('train') 160 | 161 | log_file.write( 162 | f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n') 163 | log_file.flush() 164 | 165 | if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1): 166 | # DataParallel wrappers keep raw model object in .module 167 | raw_model = self.model.module if hasattr( 168 | self.model, "module") else self.model 169 | torch.save(raw_model.state_dict(), 170 | self.config.epoch_save_path + str(epoch+1) + '.pth') 171 | -------------------------------------------------------------------------------- /RWKV-v2-RNN/src/utils.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import json 6 | import random 7 | import time 8 | import math 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | from torch.utils.data import Dataset 14 | 15 | 16 | class Dataset(Dataset): 17 | def __init__(self, data, ctx_len, epoch_length_fixed): 18 | print('building token list...', end=' ') 19 | unique = sorted(list(set(data))) 20 | # print() 21 | # for u in unique: 22 | # print(u, end=' ') 23 | # print('\n\n') 24 | 25 | xx = 0 26 | xxObj = {} 27 | for u in unique: 28 | xxObj[xx] = u 29 | xx += 1 30 | with open('vocab.json', "w", encoding="utf-16") as vocab_file: 31 | vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) 32 | 33 | data_size, vocab_size = len(data), len(unique) 34 | print('data has %d tokens, %d unique.' 
% (data_size, vocab_size)) 35 | self.stoi = {ch: i for i, ch in enumerate(unique)} 36 | self.itos = {i: ch for i, ch in enumerate(unique)} 37 | self.ctx_len = ctx_len 38 | self.epoch_length_fixed = epoch_length_fixed 39 | self.vocab_size = vocab_size 40 | self.data = data 41 | 42 | def __len__(self): 43 | return self.epoch_length_fixed 44 | 45 | def __getitem__(self, idx): 46 | # cheat: pick a random spot in dataset 47 | i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) 48 | chunk = self.data[i:i+self.ctx_len+1] 49 | dix = [self.stoi[s] for s in chunk] 50 | x = torch.tensor(dix[:-1], dtype=torch.long, 51 | device=torch.device('cuda')) 52 | y = torch.tensor(dix[1:], dtype=torch.long, 53 | device=torch.device('cuda')) 54 | return x, y 55 | 56 | 57 | class TOKENIZER(): 58 | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): 59 | with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: 60 | self.word_table = json.load(result_file) 61 | 62 | self.vocab_size = len(self.word_table) 63 | 64 | self.stoi = {v: int(k) for k, v in self.word_table.items()} 65 | self.itos = {int(k): v for k, v in self.word_table.items()} 66 | 67 | self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] 68 | 69 | def refine_context(self, context): 70 | context = context.strip().split('\n') 71 | for c in range(len(context)): 72 | context[c] = context[c].strip().strip('\u3000').strip('\r') 73 | context = list(filter(lambda c: c != '', context)) 74 | context = '\n' + ('\n'.join(context)).strip() 75 | if context == '': 76 | context = '\n' 77 | 78 | return context 79 | 80 | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): 81 | # out[self.UNKNOWN_CHAR] = -float('Inf') 82 | 83 | lastChar = int(x[-1]) 84 | 85 | probs = F.softmax(torch.tensor(out), dim=-1) 86 | 87 | if self.itos[lastChar] == '\n': 88 | top_p = top_p_newline 89 | else: 90 | top_p = top_p_usual 91 | 92 | sorted_probs, s_index = torch.sort(probs, descending=True) 93 | 94 | # for j in range(30): 95 | # pp = sorted_probs[j].item() 96 | # if pp < 0.005: 97 | # break 98 | # ss = self.itos[int(s_index[j])].replace('\n','_') 99 | # print(f'{math.floor(pp*100):>3.0f}{ss}', end='') 100 | # print('') 101 | 102 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy() 103 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) 104 | 105 | probs[probs < cutoff] = 0 106 | # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "") 107 | 108 | if temperature != 1.0: 109 | probs = probs.pow(1.0 / temperature) 110 | 111 | return torch.multinomial(probs, num_samples=1)[0] 112 | 113 | 114 | def to_float(x): 115 | return x.cpu().detach().numpy().flatten()[0].astype(float) 116 | 117 | 118 | def set_seed(seed): 119 | random.seed(seed) 120 | np.random.seed(seed) 121 | torch.manual_seed(seed) 122 | torch.cuda.manual_seed_all(seed) 123 | -------------------------------------------------------------------------------- /RWKV-v2-RNN/train.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV v2-RNN Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import logging 6 | import datetime 7 | import json 8 | from src.model import GPT, GPTConfig 9 | from src.trainer import Trainer, TrainerConfig 10 | from src.utils 
import Dataset 11 | import torch 12 | import numpy as np 13 | torch.backends.cudnn.benchmark = True 14 | torch.backends.cudnn.allow_tf32 = True 15 | torch.backends.cuda.matmul.allow_tf32 = True 16 | 17 | ### Step 1: set training data ########################################################################## 18 | 19 | datafile = "enwik8" 20 | datafile_encoding = 'utf-8' 21 | # datafile_encoding = 'utf-16le' 22 | 23 | ### Step 2: set model size ############################################################################# 24 | 25 | ctx_len = 1024 # ===> increase T_MAX in model.py if your ctx_len > 1024 26 | n_layer = 6 27 | n_embd = 512 28 | 29 | # 'RWKV' (better for char-level English) or 'RWKV-ffnPre' (better in some cases) 30 | model_type = 'RWKV' 31 | 32 | ### Step 3: set batch size ############################################################################# 33 | 34 | # ===> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py 35 | # For example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2 36 | # If you see "CUDA out of memory", reduce it. Use GPU-Z to find the highest value for your VRAM. 37 | batch_size = 12 38 | 39 | ### Step 4: set learning rate, training mini-epochs ####################################################### 40 | 41 | lr_init = 6e-4 42 | lr_final = 1e-5 43 | # the mini-epoch is very short and of fixed length (ctx_len * epoch_length_fixed tokens) 44 | n_epoch = 500 45 | # 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, etc. 46 | epoch_save_frequency = 30 47 | epoch_save_path = 'trained-' 48 | 49 | epoch_length_fixed = 10000 50 | 51 | ######################################################################################################## 52 | 53 | # import src.utils 54 | # src.utils.set_seed(42) # remember to change seed if you load a model 55 | 56 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 57 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 58 | datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,) 59 | 60 | grad_norm_clip = 1.0 61 | warmup_tokens = 0 62 | 63 | betas = (0.9, 0.99) 64 | eps = 4e-9 65 | 66 | num_workers = 0 67 | 68 | ######################################################################################################## 69 | # Load data 70 | ######################################################################################################## 71 | 72 | print('loading data... ' + datafile) 73 | train_dataset = Dataset(open( 74 | datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed) 75 | 76 | ######################################################################################################## 77 | # Train model 78 | ######################################################################################################## 79 | if __name__ == '__main__': 80 | 81 | model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type, 82 | n_layer=n_layer, n_embd=n_embd)).cuda() 83 | 84 | # # # load a trained model. 
remember to change random seed 85 | # m2 = torch.load('trained-61.pth') 86 | # model.load_state_dict(m2) 87 | 88 | print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', 89 | betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, ) 90 | tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, 91 | learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip, 92 | warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path) 93 | trainer = Trainer(model, train_dataset, None, tconf) 94 | 95 | trainer.train() 96 | 97 | torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() + 98 | '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth') 99 | -------------------------------------------------------------------------------- /RWKV-v3-1.5B-Pile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v3-1.5B-Pile.png -------------------------------------------------------------------------------- /RWKV-v3-plan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v3-plan.png -------------------------------------------------------------------------------- /RWKV-v3/cuda/timex_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | // require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB == 0 (Tmax and BF and BB are passed by compiler) 4 | 5 | #define F4(A, B) ((float4 *)(A))[(B) >> 2] 6 | 7 | template <typename F> 8 | __global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x, 9 | const F eps, const int B, const int C, const int T) { 10 | const int i = blockIdx.y; 11 | const int ij = (B * C) / BF; 12 | const int t = threadIdx.x << 2; 13 | 14 | __shared__ F ww[Tmax]; 15 | __shared__ F kk[Tmax * BF]; 16 | F4(ww, t) = F4(__w, t + T * (i % C)); 17 | 18 | #pragma unroll 19 | for (int j = 0; j < BF; j++) { 20 | F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j)); 21 | } 22 | __syncthreads(); 23 | 24 | float4 s[BF]; 25 | #pragma unroll 26 | for (int j = 0; j < BF; j++) { 27 | s[j] = {eps, eps, eps, eps}; 28 | } 29 | const F *__restrict__ const w = ww + T - t - 4; 30 | for (int u = 0; u <= t; u++) { 31 | #pragma unroll 32 | for (int j = 0; j < BF; j++) { 33 | const F x = kk[u + Tmax * j]; 34 | s[j].x += w[u + 3] * x; 35 | s[j].y += w[u + 2] * x; 36 | s[j].z += w[u + 1] * x; 37 | s[j].w += w[u + 0] * x; 38 | } 39 | } 40 | #pragma unroll 41 | for (int j = 0; j < BF; j++) { 42 | const F *__restrict__ const k = kk + Tmax * j; 43 | s[j].y += w[t + 3] * k[t + 1]; 44 | s[j].z += w[t + 2] * k[t + 1]; 45 | s[j].z += w[t + 3] * k[t + 2]; 46 | s[j].w += w[t + 1] * k[t + 1]; 47 | s[j].w += w[t + 2] * k[t + 2]; 48 | s[j].w += w[t + 3] * k[t + 3]; 49 | F4(x, t + T * (i + ij * j)) = s[j]; 50 | } 51 | } 52 |
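// What kernel_forward computes, written out per row i (one batch-channel pair, with c = i % C):
//   x[i][t] = eps + sum_{u=0..t} w[c][(T-1) - (t-u)] * k[i][u]
// i.e. a causal 1-D convolution of k against the reversed per-channel time-weight curve w.
// Each thread produces four consecutive outputs (t..t+3) through the float4 loads/stores above,
// and each block processes BF rows that share the same channel, so the w curve only has to be
// staged into shared memory once per block.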
53 | template <typename F> 54 | __global__ void kernel_backward_W(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk, 55 | F *__restrict__ const gw, F *__restrict__ const gk, 56 | const int B, const int C, const int T) { 57 | const int i = blockIdx.y; 58 | const int t = threadIdx.x << 2; 59 | 60 | __shared__ F k[Tmax]; 61 | __shared__ F gg[Tmax]; 62 | F4(k, t) = F4(__k, t + T * i); 63 | F4(gg, t) = F4(__gwk, t + T * i); 64 | __syncthreads(); 65 | 66 | float4 s = {0, 0, 0, 0}; 67 | 68 | const F *__restrict__ const g = gg + T - t - 4; 69 | for (int u = 0; u <= t; u++) { 70 | F x = k[u]; 71 | s.x += g[u + 3] * x; 72 | s.y += g[u + 2] * x; 73 | s.z += g[u + 1] * x; 74 | s.w += g[u + 0] * x; 75 | } 76 | s.y += g[t + 3] * k[t + 1]; 77 | s.z += g[t + 2] * k[t + 1]; 78 | s.z += g[t + 3] * k[t + 2]; 79 | s.w += g[t + 1] * k[t + 1]; 80 | s.w += g[t + 2] * k[t + 2]; 81 | s.w += g[t + 3] * k[t + 3]; 82 | F4(gw, t + T * i) = s; 83 | } 84 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T) { 85 | dim3 gridDim(1, B * C / BF); 86 | dim3 blockDim(T >> 2); 87 | kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T); 88 | } 89 | 90 | template <typename F> 91 | __global__ void kernel_backward(const F *__restrict__ const __w, const F *__restrict__ const __k, const F *__restrict__ const __gwk, 92 | F *__restrict__ const gw, F *__restrict__ const gk, 93 | const int B, const int C, const int T) { 94 | const int i = blockIdx.y; 95 | const int ij = (B * C) / BB; 96 | const int t = threadIdx.x << 2; 97 | 98 | __shared__ F w[Tmax]; 99 | __shared__ F kk[Tmax * BB]; 100 | __shared__ F gg[Tmax * BB]; 101 | F4(w, t) = F4(__w, t + T * (i % C)); 102 | 103 | #pragma unroll 104 | for (int j = 0; j < BB; j++) { 105 | F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j)); 106 | F4(gg, t + Tmax * j) = F4(__gwk, t + T * (i + ij * j)); 107 | } 108 | __syncthreads(); 109 | 110 | float4 s[BB]; 111 | #pragma unroll 112 | for (int j = 0; j < BB; j++) { 113 | s[j] = {0, 0, 0, 0}; 114 | } 115 | 116 | for (int u = 0; u <= t; u++) { 117 | #pragma unroll 118 | for (int j = 0; j < BB; j++) { 119 | const F *__restrict__ const g = gg + Tmax * j + T - t - 4; 120 | F x = kk[u + Tmax * j]; 121 | s[j].x += g[u + 3] * x; 122 | s[j].y += g[u + 2] * x; 123 | s[j].z += g[u + 1] * x; 124 | s[j].w += g[u + 0] * x; 125 | } 126 | } 127 | #pragma unroll 128 | for (int j = 0; j < BB; j++) { 129 | const F *__restrict__ const k = kk + Tmax * j; 130 | const F *__restrict__ const g = gg + Tmax * j + T - t - 4; 131 | s[j].y += g[t + 3] * k[t + 1]; 132 | s[j].z += g[t + 2] * k[t + 1]; 133 | s[j].z += g[t + 3] * k[t + 2]; 134 | s[j].w += g[t + 1] * k[t + 1]; 135 | s[j].w += g[t + 2] * k[t + 2]; 136 | s[j].w += g[t + 3] * k[t + 3]; 137 | F4(gw, t + T * (i + ij * j)) = s[j]; 138 | } 139 | 140 | #pragma unroll 141 | for (int j = 0; j < BB; j++) { 142 | s[j] = {0, 0, 0, 0}; 143 | } 144 | 145 | for (int u = t + 3; u < T; u++) { 146 | F x = w[u]; 147 | #pragma unroll 148 | for (int j = 0; j < BB; j++) { 149 | const F *__restrict__ const g = gg + Tmax * j + T + t - 3; 150 | s[j].x += g[2 - u] * x; 151 | s[j].y += g[3 - u] * x; 152 | s[j].z += g[4 - u] * x; 153 | s[j].w += g[5 - u] * x; 154 | } 155 | } 156 | #pragma unroll 157 | for (int j = 0; j < BB; j++) { 158 | const F *__restrict__ const g = gg + Tmax * j + T + t - 3; 159 | s[j].x += g[2 - t] * w[t + 0]; 160 | s[j].x += g[1 - t] * w[t + 1]; 161 | s[j].x += g[0 - t] * w[t + 2]; 162 | s[j].y += g[2 - t] * w[t + 1]; 163 | s[j].y += g[1 - t] * w[t + 2]; 164 | s[j].z += g[2 - t] * w[t + 2]; 165 | F4(gk, t + T * (i + ij * j)) = s[j]; 166 | } 167 | } 168 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T) { 169 | dim3 gridDim(1, B * C / BB); 170 | dim3 blockDim(T >> 2); 171 | kernel_backward<<<gridDim, blockDim>>>(w, k, gwk, gw, gk, B, C, T); 172 | } 173 | -------------------------------------------------------------------------------- /RWKV-v3/cuda/timex_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | void cuda_forward(const float *w, const float *k, float *x, float eps, int B, int C, int T); 4 | void cuda_backward(const float *w, const float *k, const float *gwk, float *gw, float *gk, int B, int C, int T); 5 | 6 | void forward(torch::Tensor &w, const torch::Tensor &k, torch::Tensor &x, double eps, int64_t B, int64_t C, int64_t T) { 7 | cuda_forward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (float *)x.data_ptr(), eps, B, C, T); 8 | } 9 | void backward(torch::Tensor &w, const torch::Tensor &k, const torch::Tensor &gwk, torch::Tensor &gw, torch::Tensor &gk, int64_t B, int64_t C, int64_t T) { 10 | cuda_backward((const float *)w.data_ptr(), (const float *)k.data_ptr(), (const float *)gwk.data_ptr(), (float *)gw.data_ptr(), (float *)gk.data_ptr(), B, C, T); 11 | } 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("forward", &forward, "timex forward"); 15 | m.def("backward", &backward, "timex backward"); 16 | } 17 | 18 | TORCH_LIBRARY(timex, m) { 19 | m.def("forward", forward); 20 | m.def("backward", backward); 21 | } 22 |
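For context on how this pair of files is used: src/model.py (not shown in this excerpt) JIT-compiles them with torch.utils.cpp_extension.load, passing Tmax, BF and BB as compiler defines. The snippet below is only a hedged sketch of that workflow with illustrative values; the exact flags and shapes in model.py may differ:

import torch
from torch.utils.cpp_extension import load

T_MAX, B_GROUP_FORWARD, B_GROUP_BACKWARD = 1024, 4, 2  # illustrative; see the notes in train.py
timex_cuda = load(name="timex", sources=["cuda/timex_op.cpp", "cuda/timex_cuda.cu"],
                  extra_cuda_cflags=["--use_fast_math", f"-DTmax={T_MAX}",
                                     f"-DBF={B_GROUP_FORWARD}", f"-DBB={B_GROUP_BACKWARD}"])

B, C, T = 8, 512, 768  # must satisfy B % BF == 0, B % BB == 0, T % 4 == 0, T <= Tmax
w = torch.rand(C, T, device="cuda").contiguous()     # per-channel time-weight curves
k = torch.rand(B, C, T, device="cuda").contiguous()  # keys
x = torch.empty(B, C, T, device="cuda")              # output buffer
timex_cuda.forward(w, k, x, 0.0, B, C, T)  # fills x[b,c,t] = eps + sum_{u<=t} w[c,(T-1)-(t-u)] * k[b,c,u]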
-------------------------------------------------------------------------------- /RWKV-v3/run.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import numpy as np 6 | import math 7 | import time 8 | import types 9 | import copy 10 | import torch 11 | from torch.nn import functional as F 12 | from src.utils import TOKENIZER, Dataset 13 | from src.model_run import RWKV_RNN 14 | torch.backends.cudnn.benchmark = True 15 | torch.backends.cudnn.allow_tf32 = True 16 | torch.backends.cuda.matmul.allow_tf32 = True 17 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 18 | 19 | ### Step 1: set model ################################################################################## 20 | 21 | ctx_len = 1024 22 | n_layer = 6 23 | n_embd = 512 24 | model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' 25 | 26 | # your trained model 27 | MODEL_NAME = 'trained-1' 28 | WORD_NAME = 'vocab' # the .json vocab (generated by train.py) 29 | 30 | # --> set UNKNOWN_CHAR to the rarest token in your vocab.json <-- 31 | # --> all unknown tokens in your context will be denoted by it <-- 32 | UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity 33 | 34 | RUN_DEVICE = 'cpu' # 'cpu' (already very fast) or 'cuda' 35 | DEBUG_DEBUG = False # True False - show softmax output 36 | 37 | ### Step 2: set context ################################################################################ 38 | 39 | context = "\nIn the" # ==> this is your prompt 40 | 41 | NUM_TRIALS = 999 42 | LENGTH_PER_TRIAL = 500 43 | 44 | TEMPERATURE = 1.0 45 | top_p = 0.7 46 | top_p_newline = 0.9 47 | 48 | ######################################################################################################## 49 | 50 | print(f'Loading {MODEL_NAME}...') 51 | model = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) 52 | tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) 53
| 54 | ######################################################################################################## 55 | 56 | context = tokenizer.refine_context(context) 57 | print('\nYour prompt has ' + str(len(context)) + ' tokens.') 58 | print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n') 59 | 60 | for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): 61 | t_begin = time.time_ns() 62 | 63 | src_len = len(context) 64 | ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] 65 | print(('-' * 30) + context, end='') 66 | 67 | model.clear() 68 | if TRIAL == 0: 69 | init_state = types.SimpleNamespace() 70 | for i in range(src_len): 71 | x = ctx[:i+1] 72 | if i == src_len - 1: 73 | init_state.out = model.run(x) 74 | else: 75 | model.run(x) 76 | model.save(init_state) 77 | else: 78 | model.load(init_state) 79 | 80 | for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): 81 | x = ctx[:i+1] 82 | x = x[-ctx_len:] 83 | 84 | if i == src_len: 85 | out = copy.deepcopy(init_state.out) 86 | else: 87 | out = model.run(x) 88 | if DEBUG_DEBUG: 89 | print('model', np.array(x), '==>', np.array( 90 | out), np.max(out), np.min(out)) 91 | 92 | char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, 93 | top_p_usual=top_p, top_p_newline=top_p_newline) 94 | char = char.item() 95 | print(tokenizer.itos[int(char)], end='', flush=True) 96 | ctx += [char] 97 | t_end = time.time_ns() 98 | print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ') 99 | -------------------------------------------------------------------------------- /RWKV-v3/src/trainer.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | from torch.utils.data.dataloader import DataLoader 6 | from torch.optim.lr_scheduler import LambdaLR 7 | from torch.nn import functional as F 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch 11 | from tqdm.auto import tqdm 12 | import numpy as np 13 | import logging 14 | import os 15 | import datetime 16 | import sys 17 | import math 18 | 19 | # import wandb # comment this if you don't have wandb 20 | # print('logging to wandb... 
(comment it if you don\'t have wandb)') 21 | 22 | logger = logging.getLogger(__name__) 23 | torch.backends.cudnn.benchmark = True 24 | torch.backends.cudnn.allow_tf32 = True 25 | torch.backends.cuda.matmul.allow_tf32 = True 26 | 27 | log_file = open("mylog.txt", "a") 28 | 29 | 30 | class TrainerConfig: 31 | max_epochs = 10 32 | batch_size = 64 33 | learning_rate = 4e-4 34 | betas = (0.9, 0.99) 35 | eps = 1e-8 36 | grad_norm_clip = 1.0 37 | lr_decay = True # linear warmup followed by cosine decay 38 | warmup_tokens = 0 39 | final_tokens = 0 40 | epoch_save_frequency = 0 41 | epoch_save_path = 'trained-' 42 | num_workers = 0 # for DataLoader 43 | 44 | def __init__(self, **kwargs): 45 | for k, v in kwargs.items(): 46 | setattr(self, k, v) 47 | 48 | 49 | class Trainer: 50 | 51 | def __init__(self, model, train_dataset, test_dataset, config): 52 | self.model = model 53 | self.train_dataset = train_dataset 54 | self.test_dataset = test_dataset 55 | self.config = config 56 | self.avg_loss = -1 57 | self.steps = 0 58 | 59 | if 'wandb' in sys.modules: 60 | cfg = model.config 61 | for k in config.__dict__: 62 | setattr(cfg, k, config.__dict__[k]) # combine cfg 63 | wandb.init(project="RWKV-LM", name=self.get_run_name() + '-' + 64 | datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'), config=cfg, save_code=False) 65 | 66 | self.device = 'cpu' 67 | if torch.cuda.is_available(): # take over whatever gpus are on the system 68 | self.device = torch.cuda.current_device() 69 | 70 | def get_run_name(self): 71 | raw_model = self.model.module if hasattr( 72 | self.model, "module") else self.model 73 | cfg = raw_model.config 74 | run_name = str(cfg.vocab_size) + '-' + str(cfg.ctx_len) + '-' + \ 75 | cfg.model_type + '-' + str(cfg.n_layer) + '-' + str(cfg.n_embd) 76 | return run_name 77 | 78 | def train(self): 79 | model, config = self.model, self.config 80 | raw_model = model.module if hasattr(self.model, "module") else model 81 | optimizer = raw_model.configure_optimizers(config) 82 | 83 | def run_epoch(split): 84 | is_train = split == 'train' 85 | model.train(is_train) 86 | data = self.train_dataset if is_train else self.test_dataset 87 | 88 | if config.num_workers > 0: 89 | loader = DataLoader(data, shuffle=False, pin_memory=True, 90 | batch_size=config.batch_size, 91 | num_workers=config.num_workers) 92 | else: 93 | loader = DataLoader(data, shuffle=False, 94 | batch_size=config.batch_size, 95 | num_workers=config.num_workers) 96 | 97 | pbar = tqdm(enumerate(loader), total=len( 98 | loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) 99 | 100 | for it, (x, y) in pbar: 101 | x = x.to(self.device) # place data on the correct device 102 | y = y.to(self.device) 103 | 104 | with torch.set_grad_enabled(is_train): 105 | _, loss = model(x, y) # forward the model 106 | 107 | if is_train: # backprop and update the parameters 108 | model.zero_grad() 109 | loss.backward() 110 | 111 | if config.grad_norm_clip > 0: 112 | torch.nn.utils.clip_grad_norm_( 113 | model.parameters(), config.grad_norm_clip) 114 | 115 | optimizer.step() 116 | 117 | if config.lr_decay: # decay the learning rate based on our progress 118 | # number of tokens processed this step (i.e. 
label is not -100) 119 | self.tokens += (y >= 0).sum() 120 | lr_final_factor = config.lr_final / config.learning_rate 121 | if self.tokens < config.warmup_tokens: 122 | # linear warmup 123 | lr_mult = lr_final_factor + \ 124 | (1 - lr_final_factor) * float(self.tokens) / \ 125 | float(config.warmup_tokens) 126 | progress = 0 127 | else: 128 | # exponential learning rate decay 129 | progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) 130 | if progress >= 1: 131 | lr_mult = lr_final_factor 132 | else: 133 | lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1)) 134 | lr = config.learning_rate * lr_mult 135 | for param_group in optimizer.param_groups: 136 | param_group['lr'] = lr 137 | else: 138 | lr = config.learning_rate 139 | 140 | now_loss = loss.item() # report progress 141 | self.lr = lr 142 | 143 | if 'wandb' in sys.modules: 144 | wandb.log({"loss": now_loss}, 145 | step=self.steps * self.config.batch_size) 146 | self.steps += 1 147 | 148 | if self.avg_loss < 0: 149 | self.avg_loss = now_loss 150 | else: 151 | factor = 1 / (it + 1) 152 | self.avg_loss = self.avg_loss * \ 153 | (1.0 - factor) + now_loss * factor 154 | pbar.set_description( 155 | f"mini-epoch {epoch+1} prog {progress*100.0:.2f}% iter {it}: ppl {math.exp(self.avg_loss):.2f} loss {self.avg_loss:.4f} lr {lr:e}") 156 | 157 | self.tokens = 0 # counter used for learning rate decay 158 | for epoch in range(config.max_epochs): 159 | 160 | run_epoch('train') 161 | 162 | log_file.write( 163 | f'{epoch+1} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} \n') 164 | log_file.flush() 165 | 166 | if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1): 167 | # DataParallel wrappers keep raw model object in .module 168 | raw_model = self.model.module if hasattr( 169 | self.model, "module") else self.model 170 | torch.save(raw_model.state_dict(), 171 | self.config.epoch_save_path + str(epoch+1) + '.pth') 172 | -------------------------------------------------------------------------------- /RWKV-v3/src/utils.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import json 6 | import random 7 | import time 8 | import math 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | from torch.utils.data import Dataset 14 | 15 | 16 | class Dataset(Dataset): 17 | def __init__(self, data, ctx_len, epoch_length_fixed): 18 | print('building token list...', end=' ') 19 | unique = sorted(list(set(data))) 20 | # print() 21 | # for u in unique: 22 | # print(u, end=' ') 23 | # print('\n\n') 24 | 25 | xx = 0 26 | xxObj = {} 27 | for u in unique: 28 | xxObj[xx] = u 29 | xx += 1 30 | with open('vocab.json', "w", encoding="utf-16") as vocab_file: 31 | vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) 32 | 33 | data_size, vocab_size = len(data), len(unique) 34 | print('data has %d tokens, %d unique.' 
% (data_size, vocab_size)) 35 | self.stoi = {ch: i for i, ch in enumerate(unique)} 36 | self.itos = {i: ch for i, ch in enumerate(unique)} 37 | self.ctx_len = ctx_len 38 | self.epoch_length_fixed = epoch_length_fixed 39 | self.vocab_size = vocab_size 40 | self.data = data 41 | 42 | def __len__(self): 43 | return self.epoch_length_fixed 44 | 45 | def __getitem__(self, idx): 46 | # cheat: pick a random spot in dataset 47 | i = np.random.randint(0, len(self.data) - (self.ctx_len + 1)) 48 | chunk = self.data[i:i+self.ctx_len+1] 49 | dix = [self.stoi[s] for s in chunk] 50 | x = torch.tensor(dix[:-1], dtype=torch.long, 51 | device=torch.device('cuda')) 52 | y = torch.tensor(dix[1:], dtype=torch.long, 53 | device=torch.device('cuda')) 54 | return x, y 55 | 56 | 57 | class TOKENIZER(): 58 | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): 59 | with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: 60 | self.word_table = json.load(result_file) 61 | 62 | self.vocab_size = len(self.word_table) 63 | 64 | self.stoi = {v: int(k) for k, v in self.word_table.items()} 65 | self.itos = {int(k): v for k, v in self.word_table.items()} 66 | 67 | self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] 68 | 69 | def refine_context(self, context): 70 | context = context.strip().split('\n') 71 | for c in range(len(context)): 72 | context[c] = context[c].strip().strip('\u3000').strip('\r') 73 | context = list(filter(lambda c: c != '', context)) 74 | context = '\n' + ('\n'.join(context)).strip() 75 | if context == '': 76 | context = '\n' 77 | 78 | return context 79 | 80 | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): 81 | # out[self.UNKNOWN_CHAR] = -float('Inf') 82 | 83 | lastChar = int(x[-1]) 84 | 85 | probs = F.softmax(torch.tensor(out), dim=-1) 86 | 87 | if self.itos[lastChar] == '\n': 88 | top_p = top_p_newline 89 | else: 90 | top_p = top_p_usual 91 | 92 | sorted_probs, s_index = torch.sort(probs, descending=True) 93 | 94 | # for j in range(30): 95 | # pp = sorted_probs[j].item() 96 | # if pp < 0.005: 97 | # break 98 | # ss = self.itos[int(s_index[j])].replace('\n','_') 99 | # print(f'{math.floor(pp*100):>3.0f}{ss}', end='') 100 | # print('') 101 | 102 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy() 103 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) 104 | 105 | probs[probs < cutoff] = 0 106 | # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "") 107 | 108 | if temperature != 1.0: 109 | probs = probs.pow(1.0 / temperature) 110 | 111 | return torch.multinomial(probs, num_samples=1)[0] 112 | 113 | 114 | def to_float(x): 115 | return x.cpu().detach().numpy().flatten()[0].astype(float) 116 | 117 | 118 | def set_seed(seed): 119 | random.seed(seed) 120 | np.random.seed(seed) 121 | torch.manual_seed(seed) 122 | torch.cuda.manual_seed_all(seed) 123 | -------------------------------------------------------------------------------- /RWKV-v3/train.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import os 6 | 7 | # if False: # True False ---> Set to False if you don't understand it 8 | # print("\n\n[[[ SPECIAL DEBUG MODE FOR MYSELF. 
DON'T ENABLE THIS IF YOU DON'T UNDERSTAND IT ]]]\n\n") 9 | # os.environ["CUDA_VISIBLE_DEVICES"] = "0" 10 | # import src.utils 11 | # src.utils.set_seed(42) # make training deterministic (including dataloader). if you are doing this, remember to change seed when you load a model (otherwise the dataloader loads old samples) 12 | 13 | import logging 14 | import datetime 15 | from src.model import GPT, GPTConfig 16 | from src.trainer import Trainer, TrainerConfig 17 | from src.utils import Dataset 18 | import torch 19 | import numpy as np 20 | 21 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 22 | logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 23 | datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO,) 24 | torch.backends.cudnn.benchmark = True 25 | torch.backends.cudnn.allow_tf32 = True 26 | torch.backends.cuda.matmul.allow_tf32 = True 27 | 28 | ### Step 1: set training data ########################################################################## 29 | 30 | datafile = "../data/enwik8" # your data 31 | datafile_encoding = 'utf-8' 32 | # datafile_encoding = 'utf-16le' 33 | 34 | ### Step 2: set model size ############################################################################# 35 | # ----> test deeper models (n_layer at least 12) to see the advantage of RWKV-3 over RWKV-2 36 | 37 | ctx_len = 1024 # increase T_MAX in model.py if your ctx_len > 1024 38 | n_layer = 6 39 | n_embd = 512 40 | 41 | # 'RWKV' (better for English) or 'RWKV-ffnPre' (better in some cases) 42 | model_type = 'RWKV' 43 | 44 | # ---> there is a RWKV_HEAD_QK_DIM in model.py and model_run.py 45 | # set it to 256, then it's using my headQK trick (similar to a tiny attention) to improve loss 46 | # set it to 0, then it's a pure RNN (attention-free) 47 | 48 | ### Step 3: set batch size ############################################################################# 49 | 50 | # ---> batch_size must be divisible by B_GROUP_FORWARD and B_GROUP_BACKWARD in model.py 51 | # for example, if your batch_size = 20, you can set B_GROUP_FORWARD = 4, B_GROUP_BACKWARD = 2 52 | # if you see "CUDA out of memory", reduce batch_size. Use nvidia-smi to find the highest value for your GPU. 53 | batch_size = 12 54 | 55 | ### Step 4: set learning rate, number of mini-epochs ####################################################### 56 | # By default we are using exponential LR decay. 57 | # 58 | # Here are my suggestions for training a good model. 59 | # Let's say you will train a L6-D512 model. 60 | # 1) Set lr_init = lr_final = 8e-4. Let it run for some mini-epochs, until the improvement of loss become slow. 61 | # 2) Check epoch_save_frequency and make sure the partially-trained model is saved. Ctrl+C to stop the run. 62 | # 3) Set lr_init = 8e-4, lr_final = 1e-5, warmup_tokens = ctx_len * batch_size * 50, betas = (0.9, 0.999). 63 | # 4) Search for "torch.load" here and modify it to load the partially-trained model. Continue the training. 64 | # 65 | # For L12-D768, set lr_init = 6e-4. For L24-D1024, set lr_init = 4e-4. For L24-D2048, set lr_init = 3e-4. 66 | 67 | lr_init = 8e-4 # we can use larger lr because of preLN 68 | lr_final = 1e-5 69 | 70 | # the mini-epoch is very short and of fixed length (length = ctx_len * epoch_length_fixed tokens) 71 | n_epoch = 500 72 | epoch_length_fixed = 10000 73 | 74 | # 0 = never, 1 = every mini-epoch, 2 = every two mini-epochs, ... 
75 | epoch_save_frequency = 10 76 | epoch_save_path = 'trained-' 77 | 78 | ######################################################################################################## 79 | 80 | grad_norm_clip = 1.0 81 | warmup_tokens = ctx_len * batch_size * 0 82 | 83 | betas = (0.9, 0.99) 84 | eps = 4e-9 85 | 86 | num_workers = 0 87 | 88 | ######################################################################################################## 89 | # Load data 90 | ######################################################################################################## 91 | 92 | print('loading data... ' + datafile) 93 | train_dataset = Dataset(open( 94 | datafile, "r", encoding=datafile_encoding).read(), ctx_len, epoch_length_fixed) 95 | 96 | ######################################################################################################## 97 | # Train model 98 | ######################################################################################################## 99 | if __name__ == '__main__': 100 | 101 | model = GPT(GPTConfig(train_dataset.vocab_size, train_dataset.ctx_len, model_type=model_type, 102 | n_layer=n_layer, n_embd=n_embd)).cuda() 103 | 104 | ### ---> load a trained model <--- 105 | # m2 = torch.load('trained-61.pth') 106 | # model.load_state_dict(m2) 107 | 108 | print('model', model_type, 'epoch', n_epoch, 'batchsz', batch_size, 'betas', 109 | betas, 'eps', eps, 'ctx', ctx_len, 'layer', n_layer, 'embd', n_embd, ) 110 | tconf = TrainerConfig(model_type=model_type, max_epochs=n_epoch, batch_size=batch_size, 111 | learning_rate=lr_init, lr_decay=True, lr_final=lr_final, betas=betas, eps=eps, grad_norm_clip=grad_norm_clip, 112 | warmup_tokens=warmup_tokens, final_tokens=n_epoch*len(train_dataset)*ctx_len, num_workers=num_workers, epoch_save_frequency=epoch_save_frequency, epoch_save_path=epoch_save_path) 113 | trainer = Trainer(model, train_dataset, None, tconf) 114 | 115 | trainer.train() 116 | 117 | torch.save(model.state_dict(), 'trained-' + str(n_epoch) + '-' + trainer.get_run_name() + 118 | '-' + datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S') + '.pth') 119 |
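The exponential LR schedule this train.py relies on is implemented in src/trainer.py above (linear warmup for warmup_tokens, then exponential decay from lr_init to lr_final). As a sanity-check aid, here is a small stand-alone sketch of that curve using the defaults from this file; it is illustrative only and not part of the repo:

import math

def lr_at(tokens, lr_init=8e-4, lr_final=1e-5, warmup_tokens=0,
          final_tokens=500 * 10000 * 1024):  # n_epoch * epoch_length_fixed * ctx_len
    factor = lr_final / lr_init
    if tokens < warmup_tokens:
        # linear warmup from lr_init * factor up to lr_init
        return lr_init * (factor + (1 - factor) * tokens / warmup_tokens)
    progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
    # exponential decay: lr_init * factor ** progress, clamped at lr_final
    return lr_final if progress >= 1 else lr_init * math.exp(math.log(factor) * progress)

With lr_init = 8e-4 and lr_final = 1e-5 the LR falls 80x overall, i.e. roughly 4.3x in each third of training.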
-------------------------------------------------------------------------------- /RWKV-v3/verify.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | # this is for verifying the results of different models and making sure they agree with each other 6 | 7 | import numpy as np 8 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 9 | 10 | import os 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 12 | RUN_DEVICE = 'cuda' 13 | 14 | import torch 15 | from src.model_run import RWKV_RNN, RWKV_GPT 16 | from src.model import GPT, GPTConfig 17 | 18 | ctx_len = 1024 19 | n_layer = 6 20 | n_embd = 512 21 | model_type = 'RWKV' 22 | 23 | model_name = 'trained-1' 24 | 25 | from src.utils import TOKENIZER 26 | tokenizer = TOKENIZER('vocab', UNKNOWN_CHAR=' ') 27 | 28 | ######################################################################################################## 29 | 30 | model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda() 31 | print('loading ' + model_name) 32 | m2 = torch.load(model_name + '.pth', map_location=RUN_DEVICE) 33 | model_train.load_state_dict(m2) 34 | 35 | model_rnn = RWKV_RNN(model_name, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) 36 | model_gpt = RWKV_GPT(model_name, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda() 37 | 38 | ######################################################################################################## 39 | 40 | context = '\nIn a' 41 | ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] 42 | print(f'input len {len(ctx)} data {ctx}') 43 | 44 | ######################################################################################################## 45 | 46 | print('\nRWKV-GPT output') 47 | out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy() 48 | print(out) 49 | 50 | print('\nRWKV-RNN output') 51 | model_rnn.clear() 52 | src_len = len(ctx) 53 | for i in range(src_len): 54 | x = ctx[:i+1] 55 | out = model_rnn.run(x) 56 | if i < 3 or i >= src_len - 3: 57 | print(torch.tensor(out).detach().cpu().numpy()) 58 | if i == 2: 59 | print('...') 60 | 61 | print('\nRWKV-train output') 62 | ctx += [0] * (ctx_len - src_len) # pad to ctx_len 63 | ctx = [ctx] * 4 # increase batch size (to make it work with B_GROUP_FORWARD & B_GROUP_BACKWARD) 64 | out = model_train.forward(torch.tensor(ctx).cuda())[0][0][:src_len].detach().cpu().numpy() 65 | print(out, '\n') 66 | -------------------------------------------------------------------------------- /RWKV-v4-1.5B-Pile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v4-1.5B-Pile.png -------------------------------------------------------------------------------- /RWKV-v4/cuda/wkv_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <assert.h> 3 | 4 | #define MIN_VALUE (-1e38) 5 | 6 | template <typename F> 7 | __global__ void kernel_forward(const int B, const int T, const int C, 8 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, 9 | F *__restrict__ const _y) { 10 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 11 | const int _b = idx / C; 12 | const int _c = idx % C; 13 | const int _offset = _b * T * C + _c; 14 | 15 | F u = _u[_c]; 16 | F w = _w[_c]; 17 | const F *__restrict__ const k = _k + _offset; 18 | const F *__restrict__ const v = _v + _offset; 19 | F *__restrict__ const y = _y + _offset; 20 | 21 | F p = 0, q = 0, o = MIN_VALUE; 22 | // p and q are running sums divided by exp(o) (to avoid overflows) 23 | for (int i = 0; i < T; i++) { 24 | const int ii = i * C; 25 | 26 | F no = max(o, u + k[ii]); 27 | F A = exp(o - no); 28 | F B = exp(u + k[ii] - no); 29 | y[ii] = (A * p + B * v[ii]) / (A * q + B); 30 | 31 | no = max(w + o, k[ii]); 32 | A = exp(w + o - no); 33 | B = exp(k[ii] - no); 34 | p = A * p + B * v[ii]; 35 | q = A * q + B; 36 | o = no; 37 | } 38 | } 39 |
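// In closed form, kernel_forward computes, for each channel c and position t:
//
//   y[t] = ( sum_{s<t} exp((t-1-s)*w + k[s]) * v[s] + exp(u + k[t]) * v[t] )
//        / ( sum_{s<t} exp((t-1-s)*w + k[s])        + exp(u + k[t])        )
//
// where w = _w[c] is the (negative) per-channel decay and u = _u[c] is the bonus weight for
// the current token. p and q carry the numerator and denominator, and o tracks the largest
// exponent already folded into them, so every exp() argument stays <= 0 and cannot overflow.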
40 | template <typename F> 41 | __global__ void kernel_backward(const int B, const int T, const int C, 42 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, 43 | F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { 44 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 45 | const int _b = idx / C; 46 | const int _c = idx % C; 47 | const int _offset = _b * T * C + _c; 48 | 49 | F u = _u[_c]; 50 | F w = _w[_c]; 51 | const F *__restrict__ const k = _k + _offset; 52 | const F *__restrict__ const v = _v + _offset; 53 | const F *__restrict__ const gy = _gy + _offset; 54 | 55 | F *__restrict__ const gk = _gk + _offset; 56 | F *__restrict__ const gv = _gv + _offset; 57 | 58 | F y[Tmax], z[Tmax], zexp[Tmax]; 59 | 60 | F gw = 0, gu = 0; 61 | F p = 0, q = 0; 62 | F dpdw = 0, dqdw = 0; 63 | F o = MIN_VALUE; 64 | for (int i = 0; i < T; i++) { 65 | const int ii = i * C; 66 | F no = max(o, k[ii] + u); 67 | F A = exp(o - no); 68 | F B = exp(k[ii] + u - no); 69 | 70 | F num = A * p + B * v[ii]; 71 | F iden = 1 / (A * q + B); 72 | 73 | y[i] = num * iden; 74 | z[i] = iden; 75 | zexp[i] = k[ii] + u - no; 76 | 77 | gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A; 78 | gu += gy[ii] * (v[ii] - y[i]) * B * iden; 79 | 80 | no = max(w + o, k[ii]); 81 | A = exp(w + o - no); 82 | B = exp(k[ii] - no); 83 | dpdw = A * (p + dpdw); 84 | dqdw = A * (q + dqdw); 85 | p = A * p + B * v[ii]; 86 | q = A * q + B; 87 | o = no; 88 | } 89 | 90 | F gp = 0, gq = 0; 91 | o = MIN_VALUE; 92 | for (int i = T - 1; i >= 0; i--) { 93 | const int ii = i * C; 94 | F A = gy[ii] * z[i] * exp(zexp[i]); 95 | F B = exp(k[ii] + o); 96 | gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); 97 | gv[ii] = A + B * gp; 98 | 99 | F no = max(w + o, zexp[i] - k[ii] - u); 100 | A = exp(w + o - no); 101 | B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); 102 | gp = A * gp + B; 103 | gq = A * gq - B * y[i]; 104 | o = no; 105 | } 106 | 107 | // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass 108 | const int _offsetBC = _b * C + _c; 109 | _gw[_offsetBC] += gw * _w[_c]; 110 | _gu[_offsetBC] += gu; 111 | } 112 | 113 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { 114 | dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance 115 | assert(B * C % threadsPerBlock.x == 0); 116 | dim3 numBlocks(B * C / threadsPerBlock.x); 117 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); 118 | } 119 | 120 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { 121 | dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance 122 | assert(B * C % threadsPerBlock.x == 0); 123 | dim3 numBlocks(B * C / threadsPerBlock.x); 124 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); 125 | } 126 | -------------------------------------------------------------------------------- /RWKV-v4/cuda/wkv_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); 4 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); 5 | 6 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 7 | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); 8 | } 9 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 10 | cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); 11 | } 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("forward", &forward, "wkv forward"); 15 | m.def("backward", &backward, "wkv backward"); 16 | } 17 | 18 | TORCH_LIBRARY(wkv, m) { 19 | m.def("forward", forward); 20 | m.def("backward", backward); 21 | } 22 |
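A plain-Python restatement of the wkv forward recurrence is handy for checking the p/q/o bookkeeping above against the CUDA op. This is an illustrative test-only sketch (naive loops, fp64, no gradients), not part of the repo; it assumes w has already been preprocessed to -exp(time_decay) on the Python side, as the comment in kernel_backward implies:

import math
import torch

def wkv_forward_reference(w, u, k, v):
    # w, u: (C,) float tensors; k, v: (B, T, C) float tensors; returns y: (B, T, C)
    B, T, C = k.shape
    y = torch.empty(B, T, C)
    for b in range(B):
        for c in range(C):
            wc, uc = float(w[c]), float(u[c])
            p, q, o = 0.0, 0.0, -1e38  # numerator, denominator (both scaled by exp(-o)), running max exponent
            for t in range(T):
                kk, vv = float(k[b, t, c]), float(v[b, t, c])
                no = max(o, uc + kk)
                a, e = math.exp(o - no), math.exp(uc + kk - no)
                y[b, t, c] = (a * p + e * vv) / (a * q + e)
                no = max(wc + o, kk)
                a, e = math.exp(wc + o - no), math.exp(kk - no)
                p, q, o = a * p + e * vv, a * q + e, no
    return y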
-------------------------------------------------------------------------------- /RWKV-v4/run.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import numpy as np 6 | import math, os 7 | import time 8 | import types 9 | import copy 10 | import torch 11 | from torch.nn import functional as F 12 | from src.utils import TOKENIZER, Dataset 13 | torch.backends.cudnn.benchmark = True 14 | torch.backends.cudnn.allow_tf32 = True 15 | torch.backends.cuda.matmul.allow_tf32 = True 16 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 17 | 18 | ######################################################################################################## 19 | # Step 1: set model 20 | # 21 | # Set TOKEN_MODE to 'char' or 'bpe' if the model is trained by 'train.py' from scratch. 22 | # 23 | # Set TOKEN_MODE to 'pile' if you want to test pre-trained pile models. 24 | ######################################################################################################## 25 | 26 | TOKEN_MODE = 'char' # char / bpe / pile 27 | 28 | n_layer = 6 29 | n_embd = 512 30 | ctx_len = 1024 31 | 32 | if TOKEN_MODE == 'char': 33 | MODEL_NAME = 'trained-500' # your trained model 34 | WORD_NAME = 'vocab' # the .json vocab (generated by train.py) 35 | # set UNKNOWN_CHAR to the rarest token in your vocab.json, and all unknown tokens in your prompt will be denoted by it 36 | UNKNOWN_CHAR = ' ' # here we just set it to ' ' for simplicity 37 | 38 | elif TOKEN_MODE == 'bpe': 39 | MODEL_NAME = 'trained-500' # your trained model 40 | WORD_NAME = ['model-vocab.json', 'model-merges.txt'] # [vocab, merge] for your BPE model 41 | UNKNOWN_CHAR = None 42 | 43 | elif TOKEN_MODE == 'pile': 44 | WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json'] 45 | UNKNOWN_CHAR = None 46 | 47 | #---> you can set MODEL_NAME to your fine-tuned model <--- 48 | 49 | MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023' 50 | # MODEL_NAME = 'trained-11' 51 | n_layer = 12 52 | n_embd = 768 53 | ctx_len = 1024 54 | 55 | # MODEL_NAME = 'RWKV-4-Pile-430M-20220808-8066' 56 | # n_layer = 24 57 | # n_embd = 1024 58 | # ctx_len = 1024 59 | 60 | # MODEL_NAME = 'RWKV-4-Pile-1B5-20220903-8040' 61 | # n_layer = 24 62 | # n_embd = 2048 63 | # ctx_len = 1024 64 | 65 | os.environ['RWKV_FLOAT_MODE'] = 'fp32' # 'bf16' / 'fp16' / 'fp32' (note: only using fp32 at this moment) 66 | os.environ['RWKV_RUN_DEVICE'] = 'cpu' # 'cpu' (already very fast) or 'cuda' 67 | model_type = 'RWKV' # 'RWKV' or 'RWKV-ffnPre' 68 | 69 | ######################################################################################################## 70 | # Step 2: set prompt & sampling settings 71 | ######################################################################################################## 72 | 73 | # context = 'A' 74 | # context = "\nIn the" 75 | # context = '\nSugar:' 76 | context = '\nIn a shocking
finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.' 77 | 78 | NUM_TRIALS = 999 79 | LENGTH_PER_TRIAL = 333 80 | 81 | TEMPERATURE = 1.0 82 | top_p = 0.7 83 | top_p_newline = 0.9 # only used in TOKEN_MODE = char 84 | 85 | DEBUG_DEBUG = False # True False --> show softmax output 86 | 87 | ######################################################################################################## 88 | 89 | print(f'Loading {MODEL_NAME}...') 90 | from src.model_run import RWKV_RNN 91 | model = RWKV_RNN(MODEL_NAME, os.environ['RWKV_RUN_DEVICE'], model_type, n_layer, n_embd, ctx_len) 92 | tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) 93 | 94 | ######################################################################################################## 95 | 96 | if tokenizer.charMode: 97 | context = tokenizer.refine_context(context) 98 | ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] 99 | else: 100 | ctx = tokenizer.tokenizer.encode(context) 101 | src_len = len(ctx) 102 | src_ctx = ctx.copy() 103 | 104 | print('\nYour prompt has ' + str(src_len) + ' tokens.') 105 | print('\n--> Currently the first run takes a while if your prompt is long, as we are using RNN to process the prompt. Use GPT to build the hidden state for better speed. <--\n') 106 | 107 | for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): 108 | t_begin = time.time_ns() 109 | print(('-' * 30) + context, end='') 110 | ctx = src_ctx.copy() 111 | model.clear() 112 | if TRIAL == 0: 113 | init_state = types.SimpleNamespace() 114 | for i in range(src_len): 115 | x = ctx[:i+1] 116 | if i == src_len - 1: 117 | init_state.out = model.run(x) 118 | else: 119 | model.run(x) 120 | model.save(init_state) 121 | else: 122 | model.load(init_state) 123 | 124 | for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): 125 | x = ctx[:i+1] 126 | x = x[-ctx_len:] 127 | 128 | if i == src_len: 129 | out = copy.deepcopy(init_state.out) 130 | else: 131 | out = model.run(x) 132 | if DEBUG_DEBUG: 133 | print('model', np.array(x), '==>', np.array( 134 | out), np.max(out), np.min(out)) 135 | 136 | if TOKEN_MODE == 'pile': 137 | out[0] = -999999999 # disable <|endoftext|> 138 | 139 | char = tokenizer.sample_logits(out, x, ctx_len, temperature=TEMPERATURE, 140 | top_p_usual=top_p, top_p_newline=top_p_newline) 141 | char = char.item() 142 | if tokenizer.charMode: 143 | print(tokenizer.itos[int(char)], end='', flush=True) 144 | else: 145 | print(tokenizer.tokenizer.decode(int(char)), end='', flush=True) 146 | ctx += [char] 147 | 148 | t_end = time.time_ns() 149 | print("\n----------", round((t_end - t_begin) / (10 ** 9), 2), end='s ') 150 | -------------------------------------------------------------------------------- /RWKV-v4/src/binidx.py: -------------------------------------------------------------------------------- 1 | from lib2to3.pgen2 import token 2 | import os 3 | import torch 4 | import numpy as np 5 | import shutil 6 | import struct 7 | from functools import lru_cache 8 | from itertools import accumulate 9 | 10 | def print_rank_0(*message): 11 | """If distributed is initialized print only on rank 0.""" 12 | if torch.distributed.is_initialized(): 13 | if torch.distributed.get_rank() == 0: 14 | print(*message, flush=True) 15 | else: 16 | print(*message, flush=True) 17 | 18 | def _warmup_mmap_file(path): 19 | pass 20 | # with open(path, "rb") as stream: 21 | # 
while stream.read(100 * 1024 * 1024): 22 | # pass 23 | 24 | dtypes = { 25 | 1: np.uint8, 26 | 2: np.int8, 27 | 3: np.int16, 28 | 4: np.int32, 29 | 5: np.int64, 30 | 6: float, 31 | 7: np.double, 32 | 8: np.uint16, 33 | } 34 | 35 | def code(dtype): 36 | for k in dtypes.keys(): 37 | if dtypes[k] == dtype: 38 | return k 39 | raise ValueError(dtype) 40 | 41 | def index_file_path(prefix_path): 42 | return prefix_path + ".idx" 43 | 44 | def data_file_path(prefix_path): 45 | return prefix_path + ".bin" 46 | 47 | class MMapIndexedDataset(torch.utils.data.Dataset): 48 | class Index(object): 49 | _HDR_MAGIC = b"MMIDIDX\x00\x00" 50 | 51 | def __init__(self, path, skip_warmup=False): 52 | with open(path, "rb") as stream: 53 | magic_test = stream.read(9) 54 | assert self._HDR_MAGIC == magic_test, ( 55 | "Index file doesn't match expected format. " 56 | "Make sure that --dataset-impl is configured properly." 57 | ) 58 | # Little endian unsigned 64 Bit integer 59 | version = struct.unpack(" 0: 103 | loader = DataLoader(data, shuffle=False, pin_memory=True, 104 | batch_size=config.batch_size // NUM_GPUS, 105 | num_workers=config.num_workers) 106 | else: 107 | loader = DataLoader(data, shuffle=False, 108 | batch_size=config.batch_size // NUM_GPUS, 109 | num_workers=config.num_workers) 110 | 111 | pbar = tqdm(enumerate(loader), total=len( 112 | loader), bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') if is_train else enumerate(loader) 113 | loader = self.setup_dataloaders(loader) 114 | gc.collect() 115 | torch.cuda.empty_cache() 116 | 117 | for it, (x, y) in pbar: 118 | with torch.set_grad_enabled(is_train): 119 | loss = model(x, y) # forward the model 120 | 121 | if os.environ['RWKV_DEEPSPEED'] == '0': 122 | all_loss = [loss.clone()] 123 | else: 124 | all_loss = [loss.clone() for _ in range(NUM_GPUS)] 125 | torch.distributed.all_gather(all_loss, loss) 126 | 127 | if is_train: # backprop and update the parameters 128 | model.zero_grad() 129 | self.backward(loss) 130 | 131 | # deepspeed will handle gradient_clipping 132 | 133 | optimizer.step() 134 | 135 | # decay the learning rate based on our progress 136 | self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. 
label is not -100) 137 | lr_final_factor = config.lr_final / config.learning_rate 138 | if self.tokens < config.warmup_tokens: 139 | # linear warmup 140 | lr_mult = lr_final_factor + \ 141 | (1 - lr_final_factor) * float(self.tokens) / \ 142 | float(config.warmup_tokens) 143 | progress = 0 144 | else: 145 | # exponential learning rate decay 146 | progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens)) 147 | if progress >= 1: 148 | lr_mult = lr_final_factor 149 | else: 150 | lr_mult = math.exp(math.log(lr_final_factor) * pow(progress, 1)) 151 | lr = config.learning_rate * lr_mult 152 | 153 | for param_group in optimizer.param_groups: 154 | param_group['lr'] = lr 155 | 156 | self.lr = lr 157 | self.steps += 1 158 | 159 | now_loss = 0 160 | for gg in range(NUM_GPUS): 161 | now_loss += all_loss[gg].item() 162 | now_loss = now_loss / NUM_GPUS # report progress 163 | if USE_WANDB and self.cuda_id == 0: 164 | wandb.log({"loss": now_loss}, step = self.steps) 165 | 166 | if self.avg_loss < 0: 167 | self.avg_loss = now_loss 168 | else: 169 | factor = 1 / (it + 1) 170 | self.avg_loss = self.avg_loss * (1.0 - factor) + now_loss * factor 171 | 172 | pbar.set_description(f"miniE {epoch+1+self.EPOCH_BEGIN} s {self.steps} prog {progress*100.0:.2f}% : ppl {math.exp(self.avg_loss):.6f} loss {self.avg_loss:.6f} lr {lr:e}") 173 | 174 | self.tokens = 0 # counter used for learning rate decay 175 | for epoch in range(99999999): 176 | 177 | run_epoch('train') 178 | if math.isnan(self.avg_loss): 179 | exit(0) 180 | 181 | if self.cuda_id == 0: 182 | log_file.write(f'{epoch+1+self.EPOCH_BEGIN} {self.avg_loss:.6f} {math.exp(self.avg_loss):.4f} {self.lr:.8f} {datetime.datetime.now()} {epoch+1} \n') 183 | log_file.flush() 184 | 185 | if (self.config.epoch_save_frequency > 0 and epoch % self.config.epoch_save_frequency == 0) or (epoch == config.max_epochs - 1): 186 | raw_model = self.model.module if hasattr(self.model, "module") else self.model 187 | torch.save(raw_model.state_dict(), self.config.epoch_save_path + str(epoch+1+self.EPOCH_BEGIN) + '.pth') 188 | -------------------------------------------------------------------------------- /RWKV-v4/src/utils.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import os 6 | try: 7 | NUM_GPUS = int(os.environ['RWKV_NUM_GPUS']) 8 | except: 9 | NUM_GPUS = 1 10 | 11 | import json 12 | import random 13 | import numpy as np 14 | import torch 15 | from torch.nn import functional as F 16 | from torch.utils.data import Dataset 17 | 18 | class Dataset(Dataset): 19 | def __init__(self, data, ctx_len, epoch_length_fixed): 20 | self.ctx_len = ctx_len 21 | self.epoch_length_fixed = epoch_length_fixed 22 | self.data = data 23 | 24 | if 'MMapIndexedDataset' in str(type(self.data)): 25 | self.vocab_size = int(os.environ['VOCAB_SIZE']) 26 | print('current vocab size =', self.vocab_size, "(make sure it's correct)") 27 | self.data_size = len(self.data._bin_buffer) // 2 28 | print(f'data has {self.data_size} tokens.') 29 | elif 'numpy' in str(type(self.data)): 30 | self.vocab_size = int(os.environ['VOCAB_SIZE']) 31 | print('current vocab size =', self.vocab_size, "(make sure it's correct)") 32 | self.data_size = 
len(self.data) 33 | print(f'data has {self.data_size} tokens.') 34 | else: 35 | print('building token list...', end=' ') 36 | unique = sorted(list(set(data))) 37 | self.vocab_size = len(unique) 38 | # print() 39 | # for u in unique: 40 | # print(u, end=' ') 41 | # print('\n\n') 42 | 43 | xx = 0 44 | xxObj = {} 45 | for u in unique: 46 | xxObj[xx] = u 47 | xx += 1 48 | with open('vocab.json', "w", encoding="utf-16") as vocab_file: 49 | vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) 50 | self.data_size = len(self.data) 51 | print('data has %d tokens, %d unique.' % (self.data_size, self.vocab_size)) 52 | self.stoi = {ch: i for i, ch in enumerate(unique)} 53 | self.itos = {i: ch for i, ch in enumerate(unique)} 54 | 55 | def __len__(self): 56 | return self.epoch_length_fixed // NUM_GPUS 57 | 58 | def __getitem__(self, idx): 59 | # 60 | # we are cheating: pick a random spot in dataset 61 | # 62 | i = np.random.randint(0, self.data_size - (self.ctx_len + 1)) 63 | if 'MMapIndexedDataset' in str(type(self.data)): 64 | dix = self.data.get(idx=0, offset=i, length=self.ctx_len + 1).astype(int) 65 | elif 'numpy' in str(type(self.data)): 66 | dix = self.data[i:i+self.ctx_len+1] 67 | else: 68 | dix = [self.stoi[s] for s in self.data[i:i+self.ctx_len+1]] 69 | 70 | x = torch.tensor(dix[:-1], dtype=torch.long) 71 | y = torch.tensor(dix[1:], dtype=torch.long) 72 | return x, y 73 | 74 | 75 | class TOKENIZER(): 76 | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): 77 | if 'list' in str(type(WORD_NAME)): 78 | self.charMode = False 79 | if WORD_NAME[0] == WORD_NAME[1]: 80 | from transformers import PreTrainedTokenizerFast 81 | self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) 82 | else: 83 | from transformers import GPT2TokenizerFast 84 | self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) 85 | self.vocab_size = len(self.tokenizer) 86 | else: 87 | self.charMode = True 88 | with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: 89 | self.word_table = json.load(result_file) 90 | 91 | self.vocab_size = len(self.word_table) 92 | 93 | self.stoi = {v: int(k) for k, v in self.word_table.items()} 94 | self.itos = {int(k): v for k, v in self.word_table.items()} 95 | 96 | self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] 97 | 98 | def refine_context(self, context): 99 | context = context.strip().split('\n') 100 | for c in range(len(context)): 101 | context[c] = context[c].strip().strip('\u3000').strip('\r') 102 | context = list(filter(lambda c: c != '', context)) 103 | context = '\n' + ('\n'.join(context)).strip() 104 | if context == '': 105 | context = '\n' 106 | return context 107 | 108 | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): 109 | # out[self.UNKNOWN_CHAR] = -float('Inf') 110 | 111 | lastChar = int(x[-1]) 112 | 113 | probs = F.softmax(torch.tensor(out), dim=-1) 114 | 115 | if self.charMode: 116 | if self.itos[lastChar] == '\n': 117 | top_p = top_p_newline 118 | else: 119 | top_p = top_p_usual 120 | else: 121 | top_p = top_p_usual 122 | 123 | sorted_probs, s_index = torch.sort(probs, descending=True) 124 | 125 | # for j in range(30): 126 | # pp = sorted_probs[j].item() 127 | # if pp < 0.005: 128 | # break 129 | # ss = self.itos[int(s_index[j])].replace('\n','_') 130 | # print(f'{math.floor(pp*100):>3.0f}{ss}', end='') 131 | # print('') 132 | 133 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).numpy() 134 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) 135 | 136 | probs[probs < 
cutoff] = 0 137 | # print("[" + str(round(cutoff,4)) + ' ' + str(round(to_float(sum(probs)),3)) + "]", end = "") 138 | 139 | if temperature != 1.0: 140 | probs = probs.pow(1.0 / temperature) 141 | 142 | return torch.multinomial(probs, num_samples=1)[0] 143 | 144 | 145 | def to_float(x): 146 | return x.cpu().detach().numpy().flatten()[0].astype(float) 147 | 148 | 149 | def set_seed(seed): 150 | random.seed(seed) 151 | np.random.seed(seed) 152 | torch.manual_seed(seed) 153 | torch.cuda.manual_seed_all(seed) 154 | -------------------------------------------------------------------------------- /RWKV-v4/verify.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | # this is for verifying the results of different models and make sure they agree with each other 6 | 7 | import numpy as np 8 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 9 | 10 | import os 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 12 | os.environ['RWKV_FLOAT_MODE'] = 'bf16' # 'bf16' (stable) or 'fp16' (will overflow after training a large model for very long. can be solved in the future) 13 | os.environ['RWKV_RUN_DEVICE'] = 'cuda' 14 | RUN_DEVICE = os.environ['RWKV_RUN_DEVICE'] 15 | 16 | import torch 17 | from src.model_run import RWKV_RNN, RWKV_GPT 18 | from src.model import GPT, GPTConfig 19 | 20 | TOKEN_MODE = 'pile' # char / pile 21 | 22 | if TOKEN_MODE == 'char': 23 | MODEL_NAME = 'trained-1' 24 | WORD_NAME = 'vocab' # the .json vocab (generated by train.py) 25 | ctx_len = 1024 26 | n_layer = 6 27 | n_embd = 512 28 | UNKNOWN_CHAR = ' ' # here we just set it to [space] for simplicity 29 | elif TOKEN_MODE == 'pile': 30 | WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json'] 31 | MODEL_NAME = 'RWKV-4-Pile-169M-20220807-8023' 32 | ctx_len = 1024 33 | n_layer = 12 34 | n_embd = 768 35 | UNKNOWN_CHAR = None 36 | 37 | model_type = 'RWKV' 38 | 39 | from src.utils import TOKENIZER 40 | tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) 41 | if TOKEN_MODE == 'pile': 42 | tokenizer.vocab_size = 50277 43 | 44 | ######################################################################################################## 45 | 46 | model_train = GPT(GPTConfig(tokenizer.vocab_size, ctx_len, model_type=model_type, n_layer=n_layer, n_embd=n_embd)).cuda() 47 | 48 | if os.environ['RWKV_FLOAT_MODE'] == 'fp16': 49 | model_train = model_train.half() 50 | elif os.environ['RWKV_FLOAT_MODE'] == 'bf16': 51 | model_train = model_train.bfloat16() 52 | 53 | print('loading ' + MODEL_NAME) 54 | m2 = torch.load(MODEL_NAME + '.pth', map_location=RUN_DEVICE) 55 | model_train.load_state_dict(m2) 56 | 57 | model_rnn = RWKV_RNN(MODEL_NAME, RUN_DEVICE, model_type, n_layer, n_embd, ctx_len) 58 | model_gpt = RWKV_GPT(MODEL_NAME, RUN_DEVICE, model_type, tokenizer.vocab_size, n_layer, n_embd, ctx_len).cuda() 59 | 60 | ######################################################################################################## 61 | 62 | # context = '\nIn a' 63 | context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.' 
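# A quick standalone illustration of the top-p cutoff used by TOKENIZER.sample_logits in
# src/utils.py above (a minimal sketch, assuming only numpy; the numbers are made up):
import numpy as np
_probs = np.array([0.5, 0.3, 0.1, 0.06, 0.04])  # sorted probabilities, descending
_cum = np.cumsum(_probs)                         # [0.5, 0.8, 0.9, 0.96, 1.0]
_cutoff = _probs[np.argmax(_cum > 0.8)]          # first index where mass exceeds top_p=0.8 -> 0.1
# probs below _cutoff are zeroed, so only the top 3 tokens (0.9 total mass) remain samplable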
64 | 65 | if TOKEN_MODE == 'char': 66 | ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] 67 | elif TOKEN_MODE == 'pile': 68 | ctx = tokenizer.tokenizer.encode(context) 69 | print(f'input len {len(ctx)} data {ctx}') 70 | 71 | ######################################################################################################## 72 | 73 | print('\nRWKV-GPT output') 74 | out = model_gpt.forward(torch.tensor(ctx).unsqueeze(0).cuda())[0].detach().cpu().numpy() 75 | print(out) 76 | 77 | print('\nRWKV-RNN output') 78 | model_rnn.clear() 79 | src_len = len(ctx) 80 | for i in range(src_len): 81 | x = ctx[:i+1] 82 | out = model_rnn.run(x) 83 | if i < 3 or i >= src_len - 3: 84 | print(torch.tensor(out).detach().cpu().numpy()) 85 | if i == 2: 86 | print('...') 87 | 88 | print('\nRWKV-train output') 89 | out = model_train.forward(torch.tensor([ctx]).cuda())[0][0].detach().cpu().float().numpy() 90 | print(out, '\n') 91 | -------------------------------------------------------------------------------- /RWKV-v4neo/cuda/wkv_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <assert.h> 3 | 4 | #define MIN_VALUE (-1e38) 5 | 6 | template <typename F> 7 | __global__ void kernel_forward(const int B, const int T, const int C, 8 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, 9 | F *__restrict__ const _y) { 10 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 11 | const int _b = idx / C; 12 | const int _c = idx % C; 13 | const int _offset = _b * T * C + _c; 14 | 15 | F u = _u[_c]; 16 | F w = _w[_c]; 17 | const F *__restrict__ const k = _k + _offset; 18 | const F *__restrict__ const v = _v + _offset; 19 | F *__restrict__ const y = _y + _offset; 20 | 21 | // aa and bb are running sums divided by exp(pp) (to avoid overflow) 22 | F aa = 0, bb = 0, pp = MIN_VALUE; 23 | for (int i = 0; i < T; i++) { 24 | const int ii = i * C; 25 | const F kk = k[ii]; 26 | const F vv = v[ii]; 27 | 28 | F ww = u + kk; 29 | F p = max(pp, ww); 30 | F e1 = exp(pp - p); 31 | F e2 = exp(ww - p); 32 | y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); 33 | 34 | ww = w + pp; 35 | p = max(ww, kk); 36 | e1 = exp(ww - p); 37 | e2 = exp(kk - p); 38 | aa = e1 * aa + e2 * vv; 39 | bb = e1 * bb + e2; 40 | pp = p; 41 | } 42 | } 43 | 
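// A note on the max/exp pattern above (and in kernel_backward below): the running state (aa, bb)
// is stored scaled by exp(-pp), where pp tracks the largest exponent seen so far, so exp() is only
// ever applied to non-positive arguments and cannot overflow. A single-channel CPU reference of the
// same forward recurrence, for clarity only (wkv_forward_cpu_ref and the <math.h> include are not
// part of the original file):
#include <math.h>
static void wkv_forward_cpu_ref(int T, float w, float u, const float *k, const float *v, float *y) {
    float aa = 0, bb = 0, pp = MIN_VALUE; // scaled numerator, scaled denominator, shared exponent
    for (int i = 0; i < T; i++) {
        float ww = u + k[i];                        // the current token gets the extra "bonus" u
        float p = ww > pp ? ww : pp;                // new shared exponent
        float e1 = expf(pp - p), e2 = expf(ww - p);
        y[i] = (e1 * aa + e2 * v[i]) / (e1 * bb + e2);
        ww = w + pp;                                // decay the state (w was mapped to -exp(w) in python)
        p = ww > k[i] ? ww : k[i];
        e1 = expf(ww - p); e2 = expf(k[i] - p);
        aa = e1 * aa + e2 * v[i];                   // fold the current k/v into the state
        bb = e1 * bb + e2;
        pp = p;
    }
}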
44 | template <typename F> 45 | __global__ void kernel_backward(const int B, const int T, const int C, 46 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, 47 | const F *__restrict__ const _y, const F *__restrict__ const _gy, 48 | F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { 49 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 50 | const int _b = idx / C; 51 | const int _c = idx % C; 52 | const int _offset = _b * T * C + _c; 53 | 54 | F u = _u[_c]; 55 | F w = _w[_c]; 56 | const F *__restrict__ const k = _k + _offset; 57 | const F *__restrict__ const v = _v + _offset; 58 | const F *__restrict__ const y = _y + _offset; 59 | const F *__restrict__ const gy = _gy + _offset; 60 | F *__restrict__ const gk = _gk + _offset; 61 | F *__restrict__ const gv = _gv + _offset; 62 | 63 | F q[Tmax], r[Tmax]; 64 | 65 | F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; 66 | for (int i = 0; i < T; i++) { 67 | const int ii = i * C; 68 | const F kk = k[ii]; 69 | const F vv = v[ii]; 70 | const F yy = y[ii]; 71 | 72 | F ww = u + kk; 73 | F p = max(pp, ww); 74 | F e1 = exp(pp - p); 75 | F e2 = exp(ww - p); 76 | const F qq = gy[ii] / (e1 * bb + e2); 77 | gw += (ga - gb * yy) * e1 * qq; 78 | gu += (vv - yy) * e2 * qq; 79 | q[i] = qq; 80 | r[i] = ww - p; 81 | 82 | ww = w + pp; 83 | p = max(ww, kk); 84 | e1 = exp(ww - p); 85 | e2 = exp(kk - p); 86 | ga = e1 * (aa + ga); 87 | gb = e1 * (bb + gb); 88 | aa = e1 * aa + e2 * vv; 89 | bb = e1 * bb + e2; 90 | pp = p; 91 | } 92 | const int _offsetBC = _b * C + _c; 93 | _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward() 94 | _gu[_offsetBC] = gu; 95 | 96 | aa = 0, bb = 0, pp = MIN_VALUE; 97 | for (int i = T - 1; i >= 0; i--) { 98 | const int ii = i * C; 99 | const F kk = k[ii]; 100 | const F vv = v[ii]; 101 | const F yy = y[ii]; 102 | const F qq = q[i]; 103 | const F rr = r[i]; 104 | 105 | F e1 = qq * exp(rr); 106 | F e2 = exp(kk + pp); 107 | gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb); 108 | gv[ii] = e1 + e2 * aa; 109 | 110 | const F ww = w + pp; 111 | const F www = rr - u - kk; 112 | const F p = max(ww, www); 113 | e1 = exp(ww - p); 114 | e2 = qq * exp(www - p); 115 | aa = e1 * aa + e2; 116 | bb = e1 * bb - e2 * yy; 117 | pp = p; 118 | } 119 | } 120 | 121 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { 122 | dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance 123 | assert(B * C % threadsPerBlock.x == 0); 124 | dim3 numBlocks(B * C / threadsPerBlock.x); 125 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); 126 | } 127 | 128 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) { 129 | dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance 130 | assert(B * C % threadsPerBlock.x == 0); 131 | dim3 numBlocks(B * C / threadsPerBlock.x); 132 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); 133 | } 134 | -------------------------------------------------------------------------------- /RWKV-v4neo/cuda/wkv_cuda_bf16.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <assert.h> 3 | #include "ATen/ATen.h" 4 | #define MIN_VALUE (-1e38) 5 | typedef at::BFloat16 bf16; 6 | 7 | __global__ void kernel_forward(const int B, const int T, const int C, 8 | const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, 9 | bf16 *__restrict__ const _y) { 10 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 11 | const int _b = idx / C; 12 | const int _c = idx % C; 13 | const int _offset = _b * T * C + _c; 14 | 15 | float u = float(_u[_c]); 16 | float w = _w[_c]; 17 | const bf16 *__restrict__ const k = _k + _offset; 18 | const bf16 *__restrict__ const v = _v + _offset; 19 | bf16 *__restrict__ const y = _y + _offset; 20 | 21 | // aa and bb are running sums divided by exp(pp) (to avoid overflow) 22 | float aa = 0, bb = 0, pp = MIN_VALUE; 23 | for (int i = 0; i < T; i++) { 24 | const int ii = i * C; 25 | const float kk = float(k[ii]); 26 | const float vv = float(v[ii]); 27 | 28 | float ww = u + kk; 29 | float p = max(pp, ww); 30 | float e1 = exp(pp - p); 31 | float e2 = exp(ww - p); 32 | y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2)); 33 | 34 | ww = w + pp; 35 | p = max(ww, kk); 36 | e1 = exp(ww - p); 37 | e2 = exp(kk - p); 38 | aa = e1 * aa + e2 * vv; 39 | bb = e1 * bb + e2; 40 | pp = p; 41 | } 42 | } 43 | 
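// kernel_backward below keeps one float per timestep in the local arrays q[Tmax] and r[Tmax], so it
// only works for sequence lengths T <= Tmax. Tmax is not defined anywhere in this file: it is
// expected to be injected when the extension is compiled (presumably via an nvcc flag such as
// -DTmax=<ctx_len> from the Python build step). A defensive fallback, shown only as a sketch and
// not part of the original build:
#ifndef Tmax
#define Tmax 1024 // hypothetical default; real builds should pass -DTmax matching the ctx_len
#endif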
44 | __global__ void kernel_backward(const int B, const int T, const int C, 45 | const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v, 46 | const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy, 47 | bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) { 48 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 49 | const int _b = idx / C; 50 | const int _c = idx % C; 51 | const int _offset = _b * T * C + _c; 52 | 53 | float u = float(_u[_c]); 54 | float w = _w[_c]; 55 | const bf16 *__restrict__ const k = _k + _offset; 56 | const bf16 *__restrict__ const v = _v + _offset; 57 | const bf16 *__restrict__ const y = _y + _offset; 58 | const bf16 *__restrict__ const gy = _gy + _offset; 59 | bf16 *__restrict__ const gk = _gk + _offset; 60 | bf16 *__restrict__ const gv = _gv + _offset; 61 | 62 | float q[Tmax], r[Tmax]; 63 | 64 | float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; 65 | for (int i = 0; i < T; i++) { 66 | const int ii = i * C; 67 | const float kk = float(k[ii]); 68 | const float vv = float(v[ii]); 69 | const float yy = float(y[ii]); 70 | 71 | float ww = u + kk; 72 | float p = max(pp, ww); 73 | float e1 = exp(pp - p); 74 | float e2 = exp(ww - p); 75 | const float qq = float(gy[ii]) / (e1 * bb + e2); 76 | gw += (ga - gb * yy) * e1 * qq; 77 | gu += (vv - yy) * e2 * qq; 78 | q[i] = qq; 79 | r[i] = ww - p; 80 | 81 | ww = w + pp; 82 | p = max(ww, kk); 83 | e1 = exp(ww - p); 84 | e2 = exp(kk - p); 85 | ga = e1 * (aa + ga); 86 | gb = e1 * (bb + gb); 87 | aa = e1 * aa + e2 * vv; 88 | bb = e1 * bb + e2; 89 | pp = p; 90 | } 91 | const int _offsetBC = _b * C + _c; 92 | _gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward() 93 | _gu[_offsetBC] = bf16(gu); 94 | 95 | aa = 0, bb = 0, pp = MIN_VALUE; 96 | for (int i = T - 1; i >= 0; i--) { 97 | const int ii = i * C; 98 | const float kk = float(k[ii]); 99 | const float vv = float(v[ii]); 100 | const float yy = float(y[ii]); 101 | const float qq = q[i]; 102 | const float rr = r[i]; 103 | 104 | float e1 = qq * exp(rr); 105 | float e2 = exp(kk + pp); 106 | gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb)); 107 | gv[ii] = bf16(e1 + e2 * aa); 108 | 109 | const float ww = w + pp; 110 | const float www = rr - u - kk; 111 | const float p = max(ww, www); 112 | e1 = exp(ww - p); 113 | e2 = qq * exp(www - p); 114 | aa = e1 * aa + e2; 115 | bb = e1 * bb - e2 * yy; 116 | pp = p; 117 | } 118 | } 119 | 120 | void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) { 121 | dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance 122 | assert(B * C % threadsPerBlock.x == 0); 123 | dim3 numBlocks(B * C / threadsPerBlock.x); 124 | kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y); 125 | } 126 | 127 | void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) { 128 | dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance 129 | assert(B * C % threadsPerBlock.x == 0); 130 | dim3 numBlocks(B * C / threadsPerBlock.x); 131 | kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); 132 | } 133 | -------------------------------------------------------------------------------- /RWKV-v4neo/cuda/wkv_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 
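// The functions declared below are exposed to Python twice: PYBIND11_MODULE registers them on the
// extension module itself (for torch.utils.cpp_extension.load), and TORCH_LIBRARY additionally
// registers them as torch.ops.wkv.forward / torch.ops.wkv.backward so TorchScript can reach them.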
3 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); 4 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); 5 | 6 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 7 | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); 8 | } 9 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 10 | cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); 11 | } 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("forward", &forward, "wkv forward"); 15 | m.def("backward", &backward, "wkv backward"); 16 | } 17 | 18 | TORCH_LIBRARY(wkv, m) { 19 | m.def("forward", forward); 20 | m.def("backward", backward); 21 | } 22 | -------------------------------------------------------------------------------- /RWKV-v4neo/cuda/wkv_op_bf16.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y); 6 | void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 9 | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, 12 | torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 13 | cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), 14 | gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>()); 15 | } 16 | 17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 18 | m.def("forward", &forward, "wkv forward"); 19 | m.def("backward", &backward, "wkv backward"); 20 | } 21 | 22 | TORCH_LIBRARY(wkv, m) { 23 | m.def("forward", forward); 24 | m.def("backward", backward); 25 | } 26 | -------------------------------------------------------------------------------- /RWKV-v4neo/img_demoAE.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import torch, types, os 6 | import numpy as np 7 | from PIL import Image 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | import torchvision as vision 11 | import torchvision.transforms as transforms 12 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 13 | print(f'loading...') 14 | 15 | ######################################################################################################## 16 
| 17 | model_prefix = 'test/image_trained/out-v7c_d8_256-224-13bit-OB32x0.5-201' 18 | input_img = 'test/img_ae_test/test0.png' 19 | 20 | ######################################################################################################## 21 | 22 | class ToBinary(torch.autograd.Function): 23 | @staticmethod 24 | def forward(ctx, x): 25 | return torch.floor(x + 0.5) # no need for noise when we have plenty of data 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | return grad_output.clone() # pass-through 30 | 31 | class R_ENCODER(nn.Module): 32 | def __init__(self, args): 33 | super().__init__() 34 | self.args = args 35 | dd = 8 36 | self.Bxx = nn.BatchNorm2d(dd*64) 37 | 38 | self.CIN = nn.Conv2d(3, dd, kernel_size=3, padding=1) 39 | self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1) 40 | self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1) 41 | 42 | self.B00 = nn.BatchNorm2d(dd*4) 43 | self.C00 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) 44 | self.C01 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) 45 | self.C02 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) 46 | self.C03 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) 47 | 48 | self.B10 = nn.BatchNorm2d(dd*16) 49 | self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) 50 | self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) 51 | self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) 52 | self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) 53 | 54 | self.B20 = nn.BatchNorm2d(dd*64) 55 | self.C20 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) 56 | self.C21 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) 57 | self.C22 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) 58 | self.C23 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) 59 | 60 | self.COUT = nn.Conv2d(dd*64, args.my_img_bit, kernel_size=3, padding=1) 61 | 62 | def forward(self, img): 63 | ACT = F.mish 64 | 65 | x = self.CIN(img) 66 | xx = self.Bxx(F.pixel_unshuffle(x, 8)) 67 | x = x + self.Cx1(ACT(self.Cx0(x))) 68 | 69 | x = F.pixel_unshuffle(x, 2) 70 | x = x + self.C01(ACT(self.C00(ACT(self.B00(x))))) 71 | x = x + self.C03(ACT(self.C02(x))) 72 | 73 | x = F.pixel_unshuffle(x, 2) 74 | x = x + self.C11(ACT(self.C10(ACT(self.B10(x))))) 75 | x = x + self.C13(ACT(self.C12(x))) 76 | 77 | x = F.pixel_unshuffle(x, 2) 78 | x = x + self.C21(ACT(self.C20(ACT(self.B20(x))))) 79 | x = x + self.C23(ACT(self.C22(x))) 80 | 81 | x = self.COUT(x + xx) 82 | return torch.sigmoid(x) 83 | 84 | class R_DECODER(nn.Module): 85 | def __init__(self, args): 86 | super().__init__() 87 | self.args = args 88 | dd = 8 89 | self.CIN = nn.Conv2d(args.my_img_bit, dd*64, kernel_size=3, padding=1) 90 | 91 | self.B00 = nn.BatchNorm2d(dd*64) 92 | self.C00 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) 93 | self.C01 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) 94 | self.C02 = nn.Conv2d(dd*64, 256, kernel_size=3, padding=1) 95 | self.C03 = nn.Conv2d(256, dd*64, kernel_size=3, padding=1) 96 | 97 | self.B10 = nn.BatchNorm2d(dd*16) 98 | self.C10 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) 99 | self.C11 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) 100 | self.C12 = nn.Conv2d(dd*16, 256, kernel_size=3, padding=1) 101 | self.C13 = nn.Conv2d(256, dd*16, kernel_size=3, padding=1) 102 | 103 | self.B20 = nn.BatchNorm2d(dd*4) 104 | self.C20 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) 105 | self.C21 = nn.Conv2d(256, dd*4, kernel_size=3, padding=1) 106 | self.C22 = nn.Conv2d(dd*4, 256, kernel_size=3, padding=1) 107 | self.C23 = nn.Conv2d(256, dd*4, 
kernel_size=3, padding=1) 108 | 109 | self.Cx0 = nn.Conv2d(dd, 32, kernel_size=3, padding=1) 110 | self.Cx1 = nn.Conv2d(32, dd, kernel_size=3, padding=1) 111 | self.COUT = nn.Conv2d(dd, 3, kernel_size=3, padding=1) 112 | 113 | def forward(self, code): 114 | ACT = F.mish 115 | x = self.CIN(code) 116 | 117 | x = x + self.C01(ACT(self.C00(ACT(self.B00(x))))) 118 | x = x + self.C03(ACT(self.C02(x))) 119 | x = F.pixel_shuffle(x, 2) 120 | 121 | x = x + self.C11(ACT(self.C10(ACT(self.B10(x))))) 122 | x = x + self.C13(ACT(self.C12(x))) 123 | x = F.pixel_shuffle(x, 2) 124 | 125 | x = x + self.C21(ACT(self.C20(ACT(self.B20(x))))) 126 | x = x + self.C23(ACT(self.C22(x))) 127 | x = F.pixel_shuffle(x, 2) 128 | 129 | x = x + self.Cx1(ACT(self.Cx0(x))) 130 | x = self.COUT(x) 131 | 132 | return torch.sigmoid(x) 133 | 134 | ######################################################################################################## 135 | 136 | print(f'building model...') 137 | args = types.SimpleNamespace() 138 | args.my_img_bit = 13 139 | encoder = R_ENCODER(args).eval().cuda() 140 | decoder = R_DECODER(args).eval().cuda() 141 | 142 | zpow = torch.tensor([2**i for i in range(0,13)]).reshape(13,1,1).cuda().long() 143 | 144 | encoder.load_state_dict(torch.load(f'{model_prefix}-E.pth')) 145 | decoder.load_state_dict(torch.load(f'{model_prefix}-D.pth')) 146 | 147 | ######################################################################################################## 148 | 149 | print(f'test image...') 150 | img_transform = transforms.Compose([ 151 | transforms.PILToTensor(), 152 | transforms.ConvertImageDtype(torch.float), 153 | transforms.Resize((224, 224)) 154 | ]) 155 | 156 | with torch.no_grad(): 157 | img = img_transform(Image.open(input_img)).unsqueeze(0).cuda() 158 | z = encoder(img) 159 | z = ToBinary.apply(z) 160 | 161 | zz = torch.sum(z.squeeze().long() * zpow, dim=0) 162 | print(f'Code shape = {zz.shape}\n{zz.cpu().numpy()}\n') 163 | 164 | out = decoder(z) 165 | vision.utils.save_image(out, f"{input_img.split('.')[0]}-out-13bit.jpg") 166 | -------------------------------------------------------------------------------- /RWKV-v4neo/merge_lora.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | import sys 4 | from typing import Dict 5 | import typing 6 | import torch 7 | 8 | if '-h' in sys.argv or '--help' in sys.argv: 9 | print(f'Usage: python3 {sys.argv[0]} [--use-gpu] ') 10 | 11 | if sys.argv[1] == '--use-gpu': 12 | device = 'cuda' 13 | lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5] 14 | else: 15 | device = 'cpu' 16 | lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4] 17 | 18 | 19 | with torch.no_grad(): 20 | w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') 21 | # merge LoRA-only slim checkpoint into the main weights 22 | w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') 23 | for k in w_lora.keys(): 24 | w[k] = w_lora[k] 25 | output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() 26 | # merge LoRA weights 27 | keys = list(w.keys()) 28 | for k in keys: 29 | if k.endswith('.weight'): 30 | prefix = k[:-len('.weight')] 31 | lora_A = prefix + '.lora_A' 32 | lora_B = prefix + '.lora_B' 33 | if lora_A in keys: 34 | assert lora_B in keys 35 | print(f'merging {lora_A} and {lora_B} into {k}') 36 | assert w[lora_B].shape[1] == w[lora_A].shape[0] 37 | lora_r = 
w[lora_B].shape[1] 38 | w[k] = w[k].to(device=device) 39 | w[lora_A] = w[lora_A].to(device=device) 40 | w[lora_B] = w[lora_B].to(device=device) 41 | w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) 42 | output_w[k] = w[k].to(device='cpu', copy=True) 43 | del w[k] 44 | del w[lora_A] 45 | del w[lora_B] 46 | continue 47 | 48 | if 'lora' not in k: 49 | print(f'retaining {k}') 50 | output_w[k] = w[k].clone() 51 | del w[k] 52 | 53 | torch.save(output_w, output) 54 | -------------------------------------------------------------------------------- /RWKV-v4neo/run.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import numpy as np 6 | import math, os, sys, types, time, gc 7 | import torch 8 | from src.utils import TOKENIZER 9 | try: 10 | os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1] 11 | except: 12 | pass 13 | torch.backends.cudnn.benchmark = True 14 | torch.backends.cudnn.allow_tf32 = True 15 | torch.backends.cuda.matmul.allow_tf32 = True 16 | np.set_printoptions(precision=4, suppress=True, linewidth=200) 17 | args = types.SimpleNamespace() 18 | 19 | ######################################################################################################## 20 | # Step 1: set model & config (use v4 to run your trained-from-scratch models. v4 and v4neo are compatible) 21 | ######################################################################################################## 22 | 23 | args.RUN_DEVICE = "cuda" # 'cuda' // 'cpu' (already fast) 24 | args.FLOAT_MODE = "fp16" # fp16 (good for GPU, does not work for CPU) // fp32 (good for CPU) // bf16 (less accurate, but works for CPU) 25 | 26 | # if args.RUN_DEVICE == "cuda": 27 | # os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output 28 | os.environ["RWKV_JIT_ON"] = '1' # '1' or '0'. very useful for GPU/CPU fp32, but might be harmful for GPU fp16. please benchmark !!! 
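# A tiny helper for the benchmark suggested above (a sketch; benchmark_forward is not part of the
# original script, and the example call shape in the comment is an assumption):
def benchmark_forward(fwd, n_runs=10):
    import time
    best = float('inf')
    for _ in range(n_runs):
        t0 = time.time()
        fwd()                              # e.g. fwd = lambda: model.forward([187], None)
        best = min(best, time.time() - t0)
    return best                            # best-of-n wall time (s); compare RWKV_JIT_ON='1' vs '0'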
29 | 30 | TOKEN_MODE = "pile" 31 | WORD_NAME = [ 32 | "20B_tokenizer.json", 33 | "20B_tokenizer.json", 34 | ] # [vocab, vocab] for Pile model 35 | UNKNOWN_CHAR = None 36 | vocab_size = 50277 37 | 38 | # Download Pile models: https://huggingface.co/BlinkDL 39 | # or, set MODEL_NAME to your fine-tuned model 40 | 41 | # MODEL_NAME = "/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-169M-20220807-8023" 42 | # n_layer = 12 43 | # n_embd = 768 44 | # ctx_len = 1024 45 | 46 | # MODEL_NAME = '/fsx/BlinkDL/rwkv-release/RWKV-4-Pile-430M-20220808-8066' 47 | # n_layer = 24 48 | # n_embd = 1024 49 | # ctx_len = 1024 50 | 51 | # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040' 52 | # n_layer = 24 53 | # n_embd = 2048 54 | # ctx_len = 1024 55 | 56 | # MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221008-8023' 57 | # n_layer = 32 58 | # n_embd = 2560 59 | # ctx_len = 1024 60 | 61 | MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047' 62 | n_layer = 32 63 | n_embd = 4096 64 | ctx_len = 1024 65 | 66 | args.MODEL_NAME = MODEL_NAME 67 | args.n_layer = n_layer 68 | args.n_embd = n_embd 69 | args.ctx_len = ctx_len 70 | args.vocab_size = vocab_size 71 | args.head_qk = 0 72 | args.pre_ffn = 0 73 | args.grad_cp = 0 74 | args.my_pos_emb = 0 75 | os.environ["RWKV_RUN_DEVICE"] = args.RUN_DEVICE 76 | 77 | ######################################################################################################## 78 | # Step 2: set prompt & sampling stuffs 79 | ######################################################################################################## 80 | 81 | # context = 'A' 82 | # context = "\nIn the" 83 | # context = '\nSugar:' 84 | context = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." 85 | 86 | # context = "\n深圳是" # test Chinese 87 | # context = "\n東京は" # test Japanese 88 | 89 | # ###### A good prompt for Q&A ###### 90 | # context = ''' 91 | # Questions & Helpful Answers 92 | # Ask Research Experts 93 | # Question: 94 | # Can penguins fly? 95 | 96 | # Full Answer: 97 | # ''' 98 | 99 | # ###### A good prompt for chatbot ###### 100 | # context = ''' 101 | # The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins. 102 | 103 | # User: who is president of usa? 104 | 105 | # Bot: It’s Joe Biden; he was sworn in earlier this year. 106 | 107 | # User: french revolution what year 108 | 109 | # Bot: It started in 1789, but it lasted 10 years until 1799. 110 | 111 | # User: guess i marry who ? 112 | 113 | # Bot: Only if you tell me more about yourself - what are your interests? 114 | 115 | # User: wat is lhc 116 | 117 | # Bot: It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012. 
118 | 119 | # User:''' # type your question here 120 | 121 | NUM_TRIALS = 999 122 | LENGTH_PER_TRIAL = 333 123 | 124 | TEMPERATURE = 1.0 125 | top_p = 0.8 126 | top_p_newline = 0.9 # only used in TOKEN_MODE = char 127 | 128 | DEBUG_DEBUG = False # True False --> show softmax output 129 | 130 | ######################################################################################################## 131 | 132 | print(f'\nUsing {args.RUN_DEVICE.upper()}. Loading {MODEL_NAME}...') 133 | from src.model_run import RWKV_RNN 134 | 135 | model = RWKV_RNN(args) 136 | 137 | print(f'\nOptimizing speed...') 138 | out, _ = model.forward([187], None) 139 | # print(out) 140 | gc.collect() 141 | torch.cuda.empty_cache() 142 | 143 | # input(0) 144 | 145 | print(f'\nLoading tokenizer {WORD_NAME}...') 146 | tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR) 147 | if TOKEN_MODE == "pile": 148 | assert tokenizer.tokenizer.decode([187]) == '\n' 149 | 150 | ######################################################################################################## 151 | 152 | if tokenizer.charMode: 153 | context = tokenizer.refine_context(context) 154 | ctx = [tokenizer.stoi.get(s, tokenizer.UNKNOWN_CHAR) for s in context] 155 | else: 156 | ctx = tokenizer.tokenizer.encode(context) 157 | src_len = len(ctx) 158 | src_ctx = ctx.copy() 159 | 160 | print("\nYour prompt has " + str(src_len) + " tokens.") 161 | print( 162 | "Note: currently the first run takes a while if your prompt is long, as we are using RNN to preprocess the prompt. Use GPT to build the hidden state for better speed.\n" 163 | ) 164 | 165 | time_slot = {} 166 | time_ref = time.time_ns() 167 | 168 | def record_time(name): 169 | if name not in time_slot: 170 | time_slot[name] = 1e20 171 | tt = (time.time_ns() - time_ref) / 1e9 172 | if tt < time_slot[name]: 173 | time_slot[name] = tt 174 | 175 | init_state = None 176 | init_out = None 177 | state = None 178 | out = None 179 | 180 | for TRIAL in range(1 if DEBUG_DEBUG else NUM_TRIALS): 181 | print(("-" * 50) + '\n' + context, end="") 182 | 183 | time_ref = time.time_ns() 184 | ctx = src_ctx.copy() 185 | 186 | if TRIAL == 0: 187 | for i in range(src_len): 188 | x = ctx[: i + 1] 189 | if i == src_len - 1: 190 | init_out, init_state = model.forward(x, init_state) 191 | else: 192 | init_state = model.forward(x, init_state, preprocess_only=True) 193 | gc.collect() 194 | torch.cuda.empty_cache() 195 | 196 | record_time('preprocess') 197 | out_last = src_len 198 | for i in range(src_len, src_len + (1 if DEBUG_DEBUG else LENGTH_PER_TRIAL)): 199 | x = ctx[: i + 1] 200 | x = x[-ctx_len:] 201 | 202 | if i == src_len: 203 | out = init_out.clone() 204 | state = init_state.clone() 205 | else: 206 | out, state = model.forward(x, state) 207 | if DEBUG_DEBUG: 208 | print("model", np.array(x), "==>", np.array(out), np.max(out.cpu().numpy()), np.min(out.cpu().numpy())) 209 | if TOKEN_MODE == "pile": 210 | out[0] = -999999999 # disable <|endoftext|> 211 | 212 | ttt = tokenizer.sample_logits( 213 | out, 214 | x, 215 | ctx_len, 216 | temperature=TEMPERATURE, 217 | top_p_usual=top_p, 218 | top_p_newline=top_p_newline, 219 | ) 220 | ctx += [ttt] 221 | 222 | if tokenizer.charMode: 223 | char = tokenizer.itos[ttt] 224 | print(char, end="", flush=True) 225 | else: 226 | char = tokenizer.tokenizer.decode(ctx[out_last:]) 227 | if '\ufffd' not in char: # is valid utf8 string? 
228 | print(char, end="", flush=True) 229 | out_last = i+1 230 | 231 | record_time('total') 232 | # print(f'\n\n{time_slot}\n\n') 233 | print( 234 | f"\n\n--- preprocess {round(time_slot['preprocess'], 2)}s, generation {round(time_slot['total']-time_slot['preprocess'], 2)}s ", end = '' 235 | ) 236 | 237 | print(("-" * 50) + '\n') 238 | -------------------------------------------------------------------------------- /RWKV-v4neo/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-v4neo/src/__init__.py -------------------------------------------------------------------------------- /RWKV-v4neo/src/binidx.py: -------------------------------------------------------------------------------- 1 | from lib2to3.pgen2 import token 2 | import os 3 | import torch 4 | import numpy as np 5 | import shutil 6 | import struct 7 | from functools import lru_cache 8 | from itertools import accumulate 9 | 10 | def print_rank_0(*message): 11 | pass 12 | # """If distributed is initialized print only on rank 0.""" 13 | # if torch.distributed.is_initialized(): 14 | # if torch.distributed.get_rank() == 0: 15 | # print(*message, flush=True) 16 | # else: 17 | # print(*message, flush=True) 18 | 19 | def _warmup_mmap_file(path): 20 | pass 21 | # with open(path, "rb") as stream: 22 | # while stream.read(100 * 1024 * 1024): 23 | # pass 24 | 25 | dtypes = { 26 | 1: np.uint8, 27 | 2: np.int8, 28 | 3: np.int16, 29 | 4: np.int32, 30 | 5: np.int64, 31 | 6: float, 32 | 7: np.double, 33 | 8: np.uint16, 34 | } 35 | 36 | def code(dtype): 37 | for k in dtypes.keys(): 38 | if dtypes[k] == dtype: 39 | return k 40 | raise ValueError(dtype) 41 | 42 | def index_file_path(prefix_path): 43 | return prefix_path + ".idx" 44 | 45 | def data_file_path(prefix_path): 46 | return prefix_path + ".bin" 47 | 48 | class MMapIndexedDataset(torch.utils.data.Dataset): 49 | class Index(object): 50 | _HDR_MAGIC = b"MMIDIDX\x00\x00" 51 | 52 | @classmethod 53 | def writer(cls, path, dtype): 54 | class _Writer(object): 55 | def __enter__(self): 56 | self._file = open(path, "wb") 57 | 58 | # Write Magic string so we can check the file format then opening it again. 
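# For reference, the header written below (and parsed by Index.__init__ above) follows the standard
# Megatron-style .idx layout: 9 magic bytes, a little-endian uint64 version (always 1), a uint8
# dtype code from the dtypes table above, then two little-endian uint64 counts (number of
# sequences, number of documents), followed by the int32 sizes, int64 pointers and int64
# document-index arrays. (A summary only; the field widths past the version are stated from the
# standard format, not from this file.)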
59 | self._file.write(cls._HDR_MAGIC) 60 | # Write version number 61 | # Little endian unsigned 64 Bit integer 62 | self._file.write(struct.pack(" 0: 36 | self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document') 37 | self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size 38 | 39 | if args.my_pile_stage > 0: 40 | # assert self.data_size == 332115325534 and self.vocab_size == 50277 41 | self.samples_per_epoch = args.epoch_steps * args.real_bsz 42 | assert self.samples_per_epoch == 40320 43 | rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") 44 | dataset_slot = self.data_size // args.ctx_len 45 | if args.my_pile_stage != 4: 46 | assert MaybeIsPrime(args.magic_prime) 47 | assert args.magic_prime % 3 == 2 48 | assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 49 | elif args.data_type == "numpy": 50 | self.data = np.load(args.data_file).astype("int") 51 | self.vocab_size = args.vocab_size 52 | rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") 53 | self.data_size = len(self.data) 54 | rank_zero_info(f"Data has {self.data_size} tokens.") 55 | elif args.data_type == "uint16": 56 | self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len) 57 | self.vocab_size = args.vocab_size 58 | rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") 59 | self.data_size = self.data.shape[0] 60 | rank_zero_info(f"Data has {self.data_size} samples.") 61 | elif args.data_type == "wds_img": 62 | self.vocab_size = -1 63 | self.data_size = -1 64 | self.data = None 65 | self.error_count = 0 66 | else: 67 | if args.data_type == "dummy": 68 | rank_zero_info("Building dummy data...") 69 | self.data = "" 70 | for i in range(100000): 71 | aa = (i) % 10000 72 | bb = (i * i) % 10000 73 | cc = aa + bb 74 | self.data += f".{aa}+{bb}={cc}." 
75 | else: 76 | self.data = open(args.data_file, "r", encoding=args.data_type).read() 77 | rank_zero_info("Building token list...") 78 | unique = sorted(list(set(self.data))) 79 | self.vocab_size = len(unique) 80 | # rank_zero_info() 81 | # for u in unique: 82 | # print(u, end=' ') 83 | # rank_zero_info('\n\n') 84 | xx = 0 85 | xxObj = {} 86 | for u in unique: 87 | xxObj[xx] = u 88 | xx += 1 89 | with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file: 90 | vocab_file.write(json.dumps(xxObj, ensure_ascii=False)) 91 | self.data_size = len(self.data) 92 | rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.") 93 | self.stoi = {ch: i for i, ch in enumerate(unique)} 94 | self.itos = {i: ch for i, ch in enumerate(unique)} 95 | 96 | def __len__(self): 97 | return self.args.epoch_steps * self.args.micro_bsz 98 | 99 | def __getitem__(self, idx): 100 | args = self.args 101 | rank = self.global_rank 102 | epoch = self.real_epoch 103 | world_size = self.world_size 104 | # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size}") 105 | 106 | if args.data_type == "wds_img": 107 | def init_wds(self, bias=0): 108 | def identity(x): 109 | return x 110 | import webdataset as wds 111 | import torchvision.transforms as transforms 112 | # img_transform = transforms.Compose( 113 | # [transforms.CenterCrop(256)] 114 | # ) 115 | img_transform = transforms.Compose([ 116 | transforms.CenterCrop(512), 117 | transforms.Resize((args.my_img_size)) 118 | ]) 119 | self.data_raw = wds.WebDataset(args.data_file, resampled=True).shuffle(10000, initial=1000, rng=random.Random(epoch*100000+rank+bias*1e9)).decode("torchrgb").to_tuple("jpg", "json", "txt").map_tuple(img_transform, identity, identity) 120 | for pp in self.data_raw.pipeline: 121 | if 'Resampled' in str(pp): 122 | pp.deterministic = True 123 | def worker_seed(): 124 | return rank*100000+epoch+bias*1e9 125 | pp.worker_seed = worker_seed 126 | self.data = iter(self.data_raw) 127 | # print(f"WebDataset loaded for rank {rank} epoch {epoch}") 128 | if self.data == None: 129 | init_wds(self) 130 | trial = 0 131 | while trial < 10: 132 | try: 133 | dd = next(self.data) # jpg, json, txt 134 | break 135 | except: 136 | print(f'[dataloader error - epoch {epoch} rank {rank} - trying a new shuffle]') 137 | self.error_count += 1 138 | init_wds(self, self.error_count) 139 | trial += 1 140 | pass 141 | # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {dd[2]}") 142 | # with open(f"sample_{rank}.txt", "a", encoding="utf-8") as tmp: 143 | # tmp.write(f"epoch {epoch} idx {idx} rank {rank}/{world_size} {int(dd[1]['key'])}\n") 144 | return dd[0], dd[2] 145 | else: 146 | if args.data_type == "uint16": 147 | i = np.random.randint(0, self.data_size-1) 148 | dix = self.data[i] 149 | x = torch.tensor(dix[:-1], dtype=torch.long) 150 | y = torch.tensor(dix[1:], dtype=torch.long) 151 | else: 152 | ctx_len = args.ctx_len 153 | req_len = ctx_len + 1 154 | magic_prime = args.magic_prime 155 | data = self.data 156 | 157 | if args.my_pile_stage > 0 and args.my_pile_stage != 4: 158 | ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank 159 | 160 | if args.my_qa_mask > 0: 161 | ii_orig = ii 162 | if ii % 2 == 0: 163 | ii = (ii // 2) * args.magic_prime 164 | if args.ctx_len == 1024: 165 | magic_prime = 324331313 166 | elif args.ctx_len == 2048: 167 | magic_prime = 162165671 168 | elif args.ctx_len == 4096: 169 | magic_prime = 81082817 170 | data = self.data_pile 171 | else: 172 | ii = ii // 2 173 | 174 | factor = 
(math.sqrt(5) - 1) / 2 175 | factor = int(magic_prime * factor) 176 | i = ((factor * ii * ii * ii) % magic_prime) * ctx_len 177 | if (args.my_qa_mask == 0) or (data == self.data_pile): 178 | i = i + args.my_pile_shift 179 | # print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}") 180 | else: 181 | # cheat: pick a random spot in dataset 182 | i = np.random.randint(0, self.data_size - req_len) 183 | 184 | if args.data_type == "binidx": 185 | dix = data.get(idx=0, offset=i, length=req_len).astype(int) 186 | elif args.data_type == "numpy": 187 | dix = data[i : i + req_len] 188 | else: 189 | dix = [self.stoi[s] for s in data[i : i + req_len]] 190 | 191 | if args.my_qa_mask == 1: 192 | if data == self.data_pile: 193 | z = [1] * ctx_len 194 | else: 195 | z = [0] * ctx_len 196 | z_sum = 0 197 | isGood = False 198 | for i in range(3, ctx_len): 199 | if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187: 200 | isGood = True 201 | if dix[i] == 0: 202 | isGood = False 203 | if isGood: 204 | z[i] = 1 205 | z_sum += 1 206 | if z_sum == 0: 207 | z = [1] * ctx_len 208 | i = np.random.randint(0, self.data_pile_size - req_len) 209 | dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int) 210 | z = torch.tensor(z, dtype=torch.bfloat16) 211 | 212 | x = torch.tensor(dix[:-1], dtype=torch.long) 213 | y = torch.tensor(dix[1:], dtype=torch.long) 214 | 215 | # if ii_orig < 50: 216 | # # if rank == 1: 217 | # print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:]) 218 | # else: 219 | # exit(0) 220 | 221 | if args.my_qa_mask == 1: 222 | return x, y, z 223 | 224 | return x, y 225 | -------------------------------------------------------------------------------- /RWKV-v4neo/src/model_run.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import types 6 | import torch 7 | import math, os, gc 8 | from torch.nn import functional as F 9 | import torch.nn as nn 10 | from typing import List, Dict 11 | 12 | MyModule = nn.Module 13 | def __nop(ob): 14 | return ob 15 | MyFunction = __nop 16 | 17 | # # try torchdynamo 18 | # import torchdynamo 19 | # MyFunction = torchdynamo.optimize(os.environ["RWKV_RUN_BACKEND"]) # !!!BUGGY!!! wrong output 20 | 21 | # try torch jit --> faster for fp32, slower for fp16 (why?) 
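# An aside on the args.magic_prime sampler in src/dataset.py above: with p prime and p % 3 == 2,
# 3 does not divide p - 1, so the map x -> x**3 mod p is a bijection on [0, p); multiplying by the
# golden-ratio factor (nonzero mod a prime) keeps it a bijection, so i = (factor * ii**3 % p) * ctx_len
# visits every dataset slot exactly once while looking pseudo-random. A self-contained check
# (a sketch, not part of the original file):
def _check_cubic_bijection(p=17):  # 17 % 3 == 2, a small stand-in for the real magic_prime
    seen = {pow(x, 3, p) for x in range(p)}
    assert len(seen) == p, "x**3 mod p is not a bijection for this p"
    return True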
22 | if os.environ["RWKV_JIT_ON"] == "1": 23 | MyModule = torch.jit.ScriptModule 24 | MyFunction = torch.jit.script_method 25 | 26 | RWKV_HEAD_QK_DIM = 0 27 | print(f'\nRWKV_HEAD_QK_DIM {RWKV_HEAD_QK_DIM} RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]}\n') 28 | 29 | DEBUG_TIME = False # True False - show trained time-coeffs 30 | 31 | RWKV_RESCALE_LAYER = 6 # set x=x/2 every X layer 32 | 33 | ############################################################################################################ 34 | 35 | class RWKV_RNN(MyModule): 36 | def __init__(self, args): 37 | super().__init__() 38 | 39 | self.args = args 40 | self.FLOAT_MODE = args.FLOAT_MODE 41 | self.RUN_DEVICE = args.RUN_DEVICE 42 | 43 | with torch.no_grad(): 44 | w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu') 45 | if args.lora_r > 0: 46 | # merge LoRA-only slim checkpoint into the main weights 47 | w_lora = torch.load(args.MODEL_LORA + '.pth', map_location='cpu') 48 | for k in w_lora.keys(): 49 | w[k] = w_lora[k] 50 | # merge LoRA weights 51 | keys = set(w.keys()) 52 | for k in keys: 53 | k: str 54 | if k.endswith('.weight'): 55 | prefix = k[:-len('.weight')] 56 | lora_A = prefix + '.lora_A' 57 | lora_B = prefix + '.lora_B' 58 | if lora_A in keys: 59 | assert lora_B in keys 60 | print(f'merging {lora_A} and {lora_B} into {k}') 61 | assert w[lora_B].shape[1] == w[lora_A].shape[0] == args.lora_r 62 | # merging needs matmul, which is slow on cpu; work on gpu if possible 63 | if args.RUN_DEVICE == 'cuda': 64 | w[k] = w[k].cuda() 65 | w[lora_A] = w[lora_A].cuda() 66 | w[lora_B] = w[lora_B].cuda() 67 | w[k] += w[lora_B] @ w[lora_A] * (args.lora_alpha / args.lora_r) 68 | del w[lora_A] 69 | del w[lora_B] 70 | # refine weights and send to correct device 71 | keys = list(w.keys()) 72 | if 'pos_emb_x' in keys: 73 | w['pos_emb'] = (w['pos_emb_x'] + w['pos_emb_y']).reshape(args.ctx_len+1, -1)[:-1,:] 74 | keys = list(w.keys()) 75 | print_need_newline = False 76 | for x in keys: 77 | block_id = 0 78 | if 'blocks.' in x: 79 | block_id = int(x.split('.')[1]) 80 | if 'att.output.weight' in x: 81 | w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) 82 | if 'ffn.value.weight' in x: 83 | w[x] = w[x] / (2 ** int(block_id // RWKV_RESCALE_LAYER)) 84 | 85 | if '.time_' in x: 86 | w[x] = w[x].squeeze() 87 | if DEBUG_TIME: 88 | print(x, w[x].numpy()) 89 | if '.time_decay' in x: 90 | w[x] = w[x].float() 91 | w[x] = -torch.exp(w[x]) 92 | elif '.time_first' in x: 93 | w[x] = w[x].float() 94 | else: 95 | if self.FLOAT_MODE == "fp32": 96 | w[x] = w[x].float() 97 | elif self.FLOAT_MODE == "bf16": 98 | w[x] = w[x].bfloat16() 99 | elif self.FLOAT_MODE == "fp16": 100 | w[x] = w[x].half() 101 | 102 | w[x].requires_grad = False 103 | if args.RUN_DEVICE == 'cuda' and x != 'emb.weight': 104 | w[x] = w[x].cuda() 105 | 106 | if ('blocks.' not in x) or ('blocks.0.' 
in x): 107 | if print_need_newline: 108 | print('\n', end = '') 109 | print_need_newline = False 110 | print(x.ljust(40), str(w[x].dtype).replace('torch.', '').ljust(10), w[x].device) 111 | else: 112 | print_need_newline = True 113 | print('.', end = '', flush = True) 114 | 115 | # store weights in self.w 116 | keys = list(w.keys()) 117 | self.w = types.SimpleNamespace() 118 | for x in keys: 119 | xx = x.split('.') 120 | here = self.w 121 | for i in range(len(xx)): 122 | if xx[i].isdigit(): 123 | ii = int(xx[i]) 124 | if ii not in here: 125 | here[ii] = types.SimpleNamespace() 126 | here = here[ii] 127 | else: 128 | if i == len(xx) - 1: 129 | setattr(here, xx[i], w[x]) 130 | elif not hasattr(here, xx[i]): 131 | if xx[i+1].isdigit(): 132 | setattr(here, xx[i], {}) 133 | else: 134 | setattr(here, xx[i], types.SimpleNamespace()) 135 | here = getattr(here, xx[i]) 136 | 137 | self.eval() 138 | gc.collect() 139 | torch.cuda.empty_cache() 140 | 141 | def LN(self, x, w): 142 | return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias) 143 | 144 | # state[] 0=ffn_xx 1=att_xx 2=att_aa 3=att_bb 4=att_pp 145 | 146 | @MyFunction 147 | def FF(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw): 148 | if self.FLOAT_MODE == "bf16": 149 | xk = x * time_mix_k + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_k) 150 | xr = x * time_mix_r + state[5*i+0].type(torch.bfloat16) * (1 - time_mix_r) 151 | state[5*i+0] = x.float() 152 | elif self.FLOAT_MODE == "fp16": 153 | xk = x * time_mix_k + state[5*i+0].half() * (1 - time_mix_k) 154 | xr = x * time_mix_r + state[5*i+0].half() * (1 - time_mix_r) 155 | state[5*i+0] = x.float() 156 | else: 157 | xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k) 158 | xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r) 159 | state[5*i+0] = x 160 | 161 | r = torch.sigmoid(rw @ xr) 162 | k = torch.square(torch.relu(kw @ xk)) 163 | kv = vw @ k 164 | 165 | return r * kv 166 | 167 | @MyFunction 168 | def SA(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow): 169 | if self.FLOAT_MODE == "bf16": 170 | xk = x * time_mix_k + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_k) 171 | xv = x * time_mix_v + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_v) 172 | xr = x * time_mix_r + state[5*i+1].type(torch.bfloat16) * (1 - time_mix_r) 173 | state[5*i+1] = x.float() 174 | elif self.FLOAT_MODE == "fp16": 175 | xk = x * time_mix_k + state[5*i+1].half() * (1 - time_mix_k) 176 | xv = x * time_mix_v + state[5*i+1].half() * (1 - time_mix_v) 177 | xr = x * time_mix_r + state[5*i+1].half() * (1 - time_mix_r) 178 | state[5*i+1] = x.float() 179 | else: 180 | xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k) 181 | xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v) 182 | xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r) 183 | state[5*i+1] = x 184 | 185 | r = torch.sigmoid(rw @ xr) 186 | k = kw @ xk 187 | v = vw @ xv 188 | 189 | if '16' in self.FLOAT_MODE: 190 | kk = k.float() 191 | vv = v.float() 192 | else: 193 | kk = k 194 | vv = v 195 | aa = state[5*i+2] 196 | bb = state[5*i+3] 197 | pp = state[5*i+4] 198 | ww = time_first + kk 199 | p = torch.maximum(pp, ww) 200 | e1 = torch.exp(pp - p) 201 | e2 = torch.exp(ww - p) 202 | a = e1 * aa + e2 * vv 203 | b = e1 * bb + e2 204 | ww = pp + time_decay 205 | p = torch.maximum(ww, kk) 206 | e1 = torch.exp(ww - p) 207 | e2 = torch.exp(kk - p) 208 | state[5*i+2] = e1 * aa + e2 * vv 209 | state[5*i+3] = e1 * bb + e2 210 | state[5*i+4] = p 211 | if self.FLOAT_MODE == "bf16": 
212 | wkv = (a / b).type(torch.bfloat16) 213 | elif self.FLOAT_MODE == "fp16": 214 | wkv = (a / b).half() 215 | else: 216 | wkv = a / b 217 | 218 | return ow @ (r * wkv) 219 | 220 | def forward(self, ctx, state, preprocess_only = False): 221 | with torch.no_grad(): 222 | w = self.w 223 | args = self.args 224 | 225 | x = w.emb.weight[ctx[-1]] 226 | if self.RUN_DEVICE == 'cuda': 227 | x = x.cuda() 228 | try: 229 | pos_emb = w.pos_emb[len(ctx)-1] 230 | x = x + pos_emb 231 | except: 232 | pass 233 | 234 | if state == None: 235 | state = torch.zeros(args.n_layer * 5, args.n_embd, device=self.RUN_DEVICE) 236 | for i in range(args.n_layer): 237 | state[5*i+4] -= 1e30 238 | 239 | for i in range(args.n_layer): 240 | if i == 0: 241 | x = self.LN(x, w.blocks[i].ln0) 242 | 243 | ww = w.blocks[i].att 244 | x = x + self.SA(self.LN(x, w.blocks[i].ln1), state, i, 245 | ww.time_mix_k, ww.time_mix_v, ww.time_mix_r, ww.time_first, ww.time_decay, 246 | ww.key.weight, ww.value.weight, ww.receptance.weight, ww.output.weight) 247 | 248 | ww = w.blocks[i].ffn 249 | x = x + self.FF(self.LN(x, w.blocks[i].ln2), state, i, 250 | ww.time_mix_k, ww.time_mix_r, 251 | ww.key.weight, ww.value.weight, ww.receptance.weight) 252 | 253 | if (i+1) % RWKV_RESCALE_LAYER == 0: 254 | x = x / 2 255 | 256 | if preprocess_only: 257 | return state 258 | 259 | x = self.LN(x, w.ln_out) 260 | x = w.head.weight @ x 261 | 262 | return x.float(), state 263 | -------------------------------------------------------------------------------- /RWKV-v4neo/src/trainer.py: -------------------------------------------------------------------------------- 1 | import os, math, time, datetime, subprocess 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import pytorch_lightning as pl 5 | from pytorch_lightning.utilities import rank_zero_info, rank_zero_only 6 | from .model import LORA_CONFIG 7 | 8 | def my_save(dd, ff): 9 | if '14b-run1' not in ff: 10 | torch.save(dd, ff) 11 | else: 12 | fn = ff.split('/')[-1] 13 | fff = '/dev/shm/' + fn 14 | torch.save(dd, fff) 15 | subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b-4k/{fn} --quiet", shell=True) 16 | 17 | class train_callback(pl.Callback): 18 | def __init__(self, args): 19 | super().__init__() 20 | self.args = args 21 | 22 | def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): 23 | args = self.args 24 | # if args.cuda_cleanup > 0: 25 | # torch.cuda.empty_cache() 26 | real_step = trainer.global_step + args.epoch_begin * args.epoch_steps 27 | 28 | # LR schedule 29 | w_step = args.warmup_steps 30 | if args.lr_final == args.lr_init or args.epoch_count == 0: 31 | lr = args.lr_init 32 | else: 33 | decay_step = real_step - args.my_pile_edecay * args.epoch_steps 34 | decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps 35 | progress = (decay_step - w_step + 1) / (decay_total - w_step) 36 | progress = min(1, max(0, progress)) 37 | 38 | if args.lr_final == 0 or args.lr_init == 0: # linear decay 39 | lr = args.lr_init + (args.lr_final - args.lr_init) * progress 40 | else: # exp decay 41 | lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1)) 42 | 43 | if trainer.global_step < w_step: 44 | lr = lr * (0.2 + 0.8 * trainer.global_step / w_step) 45 | # if trainer.is_global_zero: 46 | # print(trainer.global_step, decay_step, decay_total, w_step, progress, lr) 47 | 48 | for param_group in trainer.optimizers[0].param_groups: 49 | if args.layerwise_lr > 0: 50 | param_group["lr"] = lr * param_group["my_lr_scale"] 51 | # 
print(param_group["lr"], param_group["my_lr_scale"]) 52 | else: 53 | param_group["lr"] = lr 54 | 55 | trainer.my_lr = lr 56 | # rank_zero_info(f"{real_step} {lr}") 57 | 58 | if trainer.global_step == 0: 59 | if trainer.is_global_zero: # logging 60 | trainer.my_loss_sum = 0 61 | trainer.my_loss_count = 0 62 | trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") 63 | trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n") 64 | try: 65 | print(f"\n{trainer.strategy.config}\n") 66 | trainer.my_log.write(f"{trainer.strategy.config}\n") 67 | except: 68 | pass 69 | trainer.my_log.flush() 70 | if len(args.wandb) > 0: 71 | print("Login to wandb...") 72 | import wandb 73 | wandb.init( 74 | project=args.wandb, 75 | name=args.run_name + " " + args.my_timestamp, 76 | config=args, 77 | save_code=False, 78 | ) 79 | trainer.my_wandb = wandb 80 | 81 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): 82 | args = self.args 83 | if trainer.is_global_zero: # logging 84 | t_now = time.time_ns() 85 | token_per_step = args.ctx_len * args.real_bsz 86 | real_step = trainer.global_step + args.epoch_begin * args.epoch_steps 87 | kt_s = 0 88 | try: 89 | t_cost = (t_now - trainer.my_time_ns) / 1e9 90 | kt_s = token_per_step / t_cost / 1000 91 | self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True) 92 | self.log("Kt/s", kt_s, prog_bar=True, on_step=True) 93 | except: 94 | pass 95 | trainer.my_time_ns = t_now 96 | trainer.my_loss = trainer.my_loss_all.float().mean().item() 97 | trainer.my_loss_sum += trainer.my_loss 98 | trainer.my_loss_count += 1 99 | trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count 100 | self.log("lr", trainer.my_lr, prog_bar=True, on_step=True) 101 | self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True) 102 | # self.log("s", real_step, prog_bar=True, on_step=True) 103 | 104 | if len(args.wandb) > 0: 105 | lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9} 106 | if kt_s > 0: 107 | lll["kt/s"] = kt_s 108 | trainer.my_wandb.log(lll, step=int(real_step)) 109 | if args.magic_prime > 0: 110 | expand_factor = 2 if args.my_qa_mask > 0 else 1 111 | if int(real_step) == int(args.magic_prime * expand_factor // args.real_bsz) - 1: 112 | to_save_dict = pl_module.state_dict() 113 | my_save( 114 | to_save_dict, 115 | f"{args.proj_dir}/rwkv-final.pth", 116 | ) 117 | 118 | 119 | def on_train_epoch_start(self, trainer, pl_module): 120 | args = self.args 121 | dataset = trainer.train_dataloader.dataset.datasets 122 | assert "MyDataset" in str(dataset) 123 | dataset.global_rank = trainer.global_rank 124 | dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) 125 | dataset.world_size = trainer.world_size 126 | # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########') 127 | 128 | def on_train_epoch_end(self, trainer, pl_module): 129 | args = self.args 130 | if trainer.is_global_zero: # logging & save state_dict 131 | if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1: 132 | if args.data_type == 'wds_img': 133 | raw_dict = pl_module.state_dict() 134 | to_save_dict = {} 135 | for k in raw_dict: 136 | if k.startswith('encoder.') or k.startswith('decoder.'): 137 | to_save_dict[k] = raw_dict[k] 138 | else: 139 | to_save_dict = pl_module.state_dict() 140 | 141 | if args.lora: 142 | enable_time_finetune = 'time' in 
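                # With LoRA enabled, the checkpoint written below keeps only the
                # trainable pieces: the injected .lora_ matrices, plus time-mix
                # and LayerNorm parameters when 'time' / 'ln' appear in
                # LORA_CONFIG["parts"]. merge_lora.py (in this repo) recombines
                # such a checkpoint with a base model afterwards.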
LORA_CONFIG["parts"] 143 | enable_ln_finetune = 'ln' in LORA_CONFIG["parts"] 144 | lora_dict = {} 145 | for name, state in to_save_dict.items(): 146 | if ('.lora_' in name 147 | or (enable_time_finetune and '.time_' in name) 148 | or (enable_ln_finetune and '.ln' in name)): 149 | lora_dict[name] = state 150 | to_save_dict = lora_dict 151 | 152 | try: 153 | my_save( 154 | to_save_dict, 155 | f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth", 156 | ) 157 | except Exception as e: 158 | print('Error\n\n', e, '\n\n') 159 | trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") 160 | trainer.my_log.flush() 161 | 162 | trainer.my_loss_sum = 0 163 | trainer.my_loss_count = 0 164 | 165 | 166 | @rank_zero_only 167 | def generate_init_weight(model, init_weight_name): 168 | mm = model.generate_init_weight() 169 | 170 | if model.args.my_pile_stage == 1: 171 | if len(model.args.load_model) > 0: 172 | print(f"Combine weights from {model.args.load_model}...") 173 | load_dict = torch.load(model.args.load_model, map_location="cpu") 174 | for k in load_dict: 175 | assert k in mm 176 | src = load_dict[k] 177 | try: 178 | mm[k] = src.reshape(mm[k].shape) 179 | except: 180 | tmp = mm[k].squeeze().clone() 181 | print(k, src.shape, '-->', mm[k].shape) 182 | ss = src.shape[0] 183 | dd = tmp.shape[0] 184 | for i in range(dd): 185 | pos = i / dd * ss 186 | if pos >= ss - 1: 187 | tmp[i] = src[ss-1] 188 | else: 189 | p0 = int(math.floor(pos)) 190 | ii = pos - p0 191 | tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii) 192 | mm[k] = tmp.reshape(mm[k].shape) 193 | sss = src.squeeze().float().cpu().numpy() 194 | print(sss[:10], '...', sss[-10:]) 195 | mmm = mm[k].squeeze().float().cpu().numpy() 196 | print(mmm[:10], '...', mmm[-10:]) 197 | 198 | print(f"Save to {init_weight_name}...") 199 | torch.save(mm, init_weight_name) 200 | 201 | if model.args.my_pile_stage == 1: 202 | print("Done. 
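A note on the shape-mismatch branch in generate_init_weight() above: it resizes a 1-D parameter by linear interpolation, clamping at the right edge. A minimal standalone sketch of the same idea (the function name and test values are illustrative, not from this repo):

    import torch

    def resize_1d(src: torch.Tensor, dd: int) -> torch.Tensor:
        # Linearly interpolate a 1-D tensor of length ss onto length dd.
        ss = src.shape[0]
        out = torch.empty(dd, dtype=src.dtype)
        for i in range(dd):
            pos = i / dd * ss
            if pos >= ss - 1:
                out[i] = src[ss - 1]  # clamp at the last source element
            else:
                p0 = int(pos)
                frac = pos - p0
                out[i] = src[p0] * (1 - frac) + src[p0 + 1] * frac
        return out

    # resize_1d(torch.tensor([0., 1., 2., 3.]), 8)
    # -> tensor([0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.0])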
Now go for stage 2.") 203 | exit(0) 204 | -------------------------------------------------------------------------------- /RWKV-v4neo/src/utils.py: -------------------------------------------------------------------------------- 1 | import json, time, random, os 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | time_slot = {} 7 | time_ref = time.time_ns() 8 | 9 | def record_time(name): 10 | if name not in time_slot: 11 | time_slot[name] = 1e20 12 | tt = (time.time_ns() - time_ref) / 1e9 13 | if tt < time_slot[name]: 14 | time_slot[name] = tt 15 | 16 | class TOKENIZER(): 17 | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): 18 | if 'list' in str(type(WORD_NAME)): 19 | self.charMode = False 20 | if WORD_NAME[0] == WORD_NAME[1]: 21 | from transformers import PreTrainedTokenizerFast 22 | self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) 23 | else: 24 | from transformers import GPT2TokenizerFast 25 | self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) 26 | self.vocab_size = len(self.tokenizer) 27 | else: 28 | self.charMode = True 29 | with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: 30 | self.word_table = json.load(result_file) 31 | 32 | self.vocab_size = len(self.word_table) 33 | 34 | self.stoi = {v: int(k) for k, v in self.word_table.items()} 35 | self.itos = {int(k): v for k, v in self.word_table.items()} 36 | 37 | self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] 38 | 39 | def refine_context(self, context): 40 | context = context.strip().split('\n') 41 | for c in range(len(context)): 42 | context[c] = context[c].strip().strip('\u3000').strip('\r') 43 | context = list(filter(lambda c: c != '', context)) 44 | context = '\n' + ('\n'.join(context)).strip() 45 | if context == '': 46 | context = '\n' 47 | return context 48 | 49 | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): 50 | # out[self.UNKNOWN_CHAR] = -float('Inf') 51 | lastChar = int(x[-1]) 52 | 53 | probs = F.softmax(out, dim=-1) 54 | 55 | if self.charMode: 56 | if self.itos[lastChar] == '\n': 57 | top_p = top_p_newline 58 | else: 59 | top_p = top_p_usual 60 | else: 61 | top_p = top_p_usual 62 | 63 | if os.environ["RWKV_RUN_DEVICE"] == "cpu": 64 | probs = probs.numpy() 65 | sorted_probs = np.sort(probs)[::-1] 66 | cumulative_probs = np.cumsum(sorted_probs) 67 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) 68 | probs[probs < cutoff] = 0 69 | if temperature != 1.0: 70 | probs = probs.pow(1.0 / temperature) 71 | probs = probs / np.sum(probs) 72 | out = np.random.choice(a=len(probs), p=probs) 73 | return out 74 | else: 75 | sorted_probs = torch.sort(probs, descending=True)[0] 76 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() 77 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) 78 | probs[probs < cutoff] = 0 79 | if temperature != 1.0: 80 | probs = probs.pow(1.0 / temperature) 81 | out = torch.multinomial(probs, num_samples=1)[0] 82 | return out 83 | 84 | def MaybeIsPrime(number): 85 | if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number): 86 | return True 87 | else: 88 | return False 89 | 90 | 91 | def FermatPrimalityTest(number): 92 | if number > 1: 93 | for time in range(3): 94 | randomNumber = random.randint(2, number) - 1 95 | if pow(randomNumber, number - 1, number) != 1: 96 | return False 97 | return True 98 | else: 99 | return False 100 | 101 | 102 | def MillerRabinPrimalityTest(number): 103 | if number == 2: 104 | return 
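# Quick top-p (nucleus) trace for sample_logits() above, with made-up numbers:
# for probs = [0.5, 0.3, 0.1, 0.06, 0.04] and top_p = 0.85, the sorted
# cumulative sums are [0.5, 0.8, 0.9, 0.96, 1.0]; the first index exceeding
# 0.85 is 2, so cutoff = 0.1, everything below 0.1 is zeroed, and sampling
# proceeds over the renormalized {0.5, 0.3, 0.1}.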
84 | def MaybeIsPrime(number):
85 |     if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number):
86 |         return True
87 |     else:
88 |         return False
89 | 
90 | 
91 | def FermatPrimalityTest(number):
92 |     if number > 1:
93 |         for _ in range(3):  # three random trials
94 |             randomNumber = random.randint(2, number) - 1
95 |             if pow(randomNumber, number - 1, number) != 1:
96 |                 return False
97 |         return True
98 |     else:
99 |         return False
100 | 
101 | 
102 | def MillerRabinPrimalityTest(number):
103 |     if number == 2:
104 |         return True
105 |     elif number == 1 or number % 2 == 0:
106 |         return False
107 |     oddPartOfNumber = number - 1
108 |     timesTwoDividNumber = 0
109 |     while oddPartOfNumber % 2 == 0:
110 |         oddPartOfNumber = oddPartOfNumber // 2
111 |         timesTwoDividNumber = timesTwoDividNumber + 1
112 | 
113 |     for _ in range(3):  # three random witnesses
114 |         while True:
115 |             randomNumber = random.randint(2, number) - 1
116 |             if randomNumber != 0 and randomNumber != 1:
117 |                 break
118 | 
119 |         randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)
120 | 
121 |         if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
122 |             iterationNumber = 1
123 | 
124 |             while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
125 |                 randomNumberWithPower = pow(randomNumberWithPower, 2, number)
126 |                 iterationNumber = iterationNumber + 1
127 |             if randomNumberWithPower != (number - 1):
128 |                 return False
129 | 
130 |     return True
131 | 
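The primality helpers above back the magic_prime mechanism seen in trainer.py (the final checkpoint is written when real_step reaches a value derived from args.magic_prime). A quick, illustrative sanity check:

    # MaybeIsPrime is probabilistic, but dependable at this size:
    # MaybeIsPrime(2147483647) -> True   (2^31 - 1, a Mersenne prime)
    # MaybeIsPrime(2147483649) -> False  (3 * 715827883)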
-------------------------------------------------------------------------------- /RWKV-v4neo/verify.py: --------------------------------------------------------------------------------
1 | ########################################################################################################
2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
3 | ########################################################################################################
4 | 
5 | # this is for verifying the results of different models and making sure they agree with each other
6 | 
7 | import os, sys, types
8 | import numpy as np
9 | import torch
10 | np.set_printoptions(precision=4, suppress=True, linewidth=200)
11 | try:
12 |     os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
13 | except:
14 |     pass
15 | torch.backends.cudnn.benchmark = True
16 | torch.backends.cudnn.allow_tf32 = False
17 | torch.backends.cuda.matmul.allow_tf32 = False
18 | 
19 | os.environ['RWKV_FLOAT_MODE'] = 'bf16' # bf16 or fp32
20 | os.environ['RWKV_RUN_DEVICE'] = 'cuda' # currently model_train requires CUDA
21 | RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']
22 | 
23 | TOKEN_MODE = 'pile'
24 | 
25 | if TOKEN_MODE == 'pile':
26 |     WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
27 |     MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
28 |     n_layer = 32
29 |     n_embd = 2560
30 |     ctx_len = 1024
31 |     UNKNOWN_CHAR = None
32 | 
33 | from src.utils import TOKENIZER
34 | tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
35 | if TOKEN_MODE == 'pile':
36 |     tokenizer.vocab_size = 50277
37 | 
38 | ########################################################################################################
39 | 
40 | os.environ["RWKV_JIT_ON"] = "1"
41 | os.environ["RWKV_T_MAX"] = str(ctx_len)
42 | 
43 | from src.model_run import RWKV_RNN
44 | from src.model import RWKV
45 | 
46 | args = types.SimpleNamespace()
47 | args.vocab_size = tokenizer.vocab_size
48 | args.ctx_len = ctx_len
49 | args.n_embd = n_embd
50 | args.n_layer = n_layer
51 | args.head_qk = 0
52 | args.pre_ffn = 0
53 | args.grad_cp = 0
54 | args.my_pos_emb = 0
55 | model_train = RWKV(args).to(RUN_DEVICE)
56 | 
57 | if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
58 |     model_train = model_train.half()
59 | elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
60 |     model_train = model_train.bfloat16()
61 | 
62 | print('loading ' + MODEL_NAME)
63 | m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu')
64 | model_train.load_state_dict(m2)
65 | 
66 | if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
67 |     model_train = model_train.half()
68 | elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
69 |     model_train = model_train.bfloat16()  # redundant re-cast after loading; harmless
70 | 
71 | args.MODEL_NAME = MODEL_NAME
72 | args.RUN_DEVICE = RUN_DEVICE
73 | args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE']
74 | model_rnn = RWKV_RNN(args)
75 | 
76 | ########################################################################################################
77 | 
78 | print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}")
79 | 
80 | # context = '\nIn a'
81 | context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'
82 | 
83 | if TOKEN_MODE == 'pile':
84 |     ctx = tokenizer.tokenizer.encode(context)
85 | print(f'input len {len(ctx)} data {ctx}')
86 | 
87 | ########################################################################################################
88 | 
89 | with torch.no_grad():
90 |     print('\nRWKV-train output')
91 |     out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy()
92 |     print(out, '\n')
93 | 
94 |     print('\nRWKV-RNN output')
95 |     state = None
96 |     out = None
97 |     src_len = len(ctx)
98 |     for i in range(src_len):
99 |         x = ctx[:i+1]
100 |         out, state = model_rnn.forward(x, state)
101 |         if i < 3 or i >= src_len - 3:
102 |             print(out.detach().cpu().numpy())
103 |         if i == 2:
104 |             print('...')
105 | 
-------------------------------------------------------------------------------- /RWKV-vs-MHA.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/RWKV-vs-MHA.png
-------------------------------------------------------------------------------- /Research/better_lr_schedule.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Blealtan/RWKV-LM-LoRA/4987137c31dd49cbbf2b2db3977930bc6ce5b84e/Research/better_lr_schedule.png
--------------------------------------------------------------------------------
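For reference, a minimal end-to-end sampling loop built from the RWKV-v4neo pieces above (a sketch, not part of the repo: the checkpoint path is a placeholder without the .pth suffix, and n_layer/n_embd must match the checkpoint, as in verify.py):

    import os, types
    os.environ['RWKV_FLOAT_MODE'] = 'fp32'
    os.environ['RWKV_RUN_DEVICE'] = 'cpu'

    from src.model_run import RWKV_RNN
    from src.utils import TOKENIZER

    args = types.SimpleNamespace()
    args.MODEL_NAME = '/path/to/RWKV-4-Pile-checkpoint'  # placeholder
    args.RUN_DEVICE = 'cpu'
    args.FLOAT_MODE = 'fp32'
    args.n_layer, args.n_embd, args.ctx_len = 32, 2560, 1024

    tokenizer = TOKENIZER(['20B_tokenizer.json', '20B_tokenizer.json'])
    model = RWKV_RNN(args)

    ctx = tokenizer.tokenizer.encode('\nIn a shocking finding,')
    state = None
    for i in range(len(ctx)):  # feed the prompt token by token, carrying the state
        out, state = model.forward(ctx[:i+1], state)
    for _ in range(50):  # then sample 50 continuation tokens
        token = tokenizer.sample_logits(out, ctx, args.ctx_len,
                                        temperature=1.0, top_p_usual=0.8)
        ctx.append(int(token))
        out, state = model.forward(ctx, state)
    print(tokenizer.tokenizer.decode(ctx))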