├── LEGAL.md ├── README.md ├── bin └── spm_train ├── config └── l3tc │ ├── l3tc_12m.py │ ├── l3tc_200k.py │ ├── l3tc_3m2.py │ └── l3tc_800k.py ├── dataset ├── dataset.py ├── enwik_ascii_dataset.py └── enwik_dataset.py ├── docs └── main_fig.png ├── main.py ├── models ├── RWKV_V4 │ ├── __init__.py │ ├── cuda │ │ ├── wkv_cuda.cu │ │ └── wkv_op.cpp │ ├── prwkv_infer.py │ ├── prwkv_train.py │ ├── ptq │ │ ├── __init__.py │ │ ├── bit_type.py │ │ ├── layers.py │ │ ├── observer │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── build.py │ │ │ ├── ema.py │ │ │ ├── minmax.py │ │ │ ├── omse.py │ │ │ ├── percentile.py │ │ │ ├── ptf.py │ │ │ └── utils.py │ │ └── quantizer │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── build.py │ │ │ ├── log2.py │ │ │ └── uniform.py │ ├── rwkv_tc_hira_infer.py │ ├── rwkv_tc_hira_multi_pred_infer.py │ ├── rwkv_tc_hira_multi_pred_train.py │ ├── rwkv_tc_hira_rbranch_infer.py │ ├── rwkv_tc_hira_rbranch_train.py │ ├── rwkv_tc_hira_train.py │ ├── rwkv_tc_infer.py │ ├── rwkv_tc_train.py │ ├── rwkv_v4_infer.py │ ├── rwkv_v4_multi_pred_infer.py │ ├── rwkv_v4_multi_pred_train.py │ ├── rwkv_v4_quant.py │ ├── rwkv_v4_train.py │ ├── rwkv_v5_infer.py │ └── rwkv_v5_train.py ├── __init__.py ├── compressor.py └── registry.py ├── requirements.txt ├── scripts ├── compressor.py └── preprocessor.py └── util ├── arithmetic_coder.py ├── arithmeticcoding.py ├── decode.py ├── encode.py ├── logger.py ├── misc.py ├── slconfig.py └── utils.py /LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | 3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 4 | 5 | 法律免责声明 6 | 7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # L3TC: Leveraging RWKV for Learned Lossless Low-Complexity Text Compression [AAAI2025] 2 | 3 | Official PyTorch implementation of the paper "L3TC: Leveraging RWKV for Learned Lossless Low-Complexity Text Compression". 4 | 5 | ## Abstract 6 | ![main figure](docs/main_fig.png) 7 | 8 | Learning-based probabilistic models can be combined with an entropy coder for data compression. However, due to the high complexity of learning-based models, their practical application as text compressors has been largely overlooked. To address this issue, our work focuses on a low-complexity design while maintaining compression performance. We introduce a novel Learned Lossless Low-complexity Text Compression method (L3TC). Specifically, we conduct extensive experiments demonstrating that RWKV models achieve the fastest decoding speed with a moderate compression ratio, making it the most suitable backbone for our method. Second, we propose an outlier-aware tokenizer that uses a limited vocabulary to cover frequent tokens while allowing outliers to bypass the prediction and encoding. Third, we propose a novel high-rank reparameterization strategy that enhances the learning capability during training without increasing complexity during inference. Experimental results validate that our method achieves 48% bit saving compared to gzip compressor. 
In addition, L3TC offers compression performance comparable to other learned compressors, with a 50× reduction in model parameters. More importantly, L3TC is the fastest among all learned compressors, providing real-time decoding speeds of up to megabytes per second. 9 | 10 | ## Requirements 11 | 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ## Data Preprocessing 17 | First, download enwik8 and enwik9 to `data/public_text_dataset`. Then run the following script to generate the dictionary and the train/val data. 18 | 19 | ``` 20 | python scripts/preprocessor.py 21 | ``` 22 | 23 | ## Train the Model 24 | 25 | ``` 26 | python main.py --output_dir work_dirs -c ./config/l3tc/l3tc_200k.py --save_log --amp 27 | ``` 28 | 29 | ## Inference & Compression 30 | 31 | ``` 32 | python scripts/compressor.py \ 33 | -c "./config/l3tc/l3tc_200k.py" \ 34 | --pretrain_model_path "work_dirs/l3tc_200k_20241210_135843/checkpoint0019.pth" \ 35 | --tokenizer "dictionary/vocab_enwik8_bpe_16384_0.999/spm_enwik8_bpe_16384_0.999.model" \ 36 | --tmp_processed_dir "data/enwik9_results/l3tc_200k_bpe16k_enwik9" \ 37 | --segment_length 2048 \ 38 | --device cuda \ 39 | --input_file "data/public_text_dataset/enwik9" 40 | ``` 41 | ## L3TC Pretrained Models 42 | 43 | The pretrained models of L3TC can be downloaded from [Google Drive](https://drive.google.com/file/d/1LibLdeHTi3Io0H-ZiYZ6AhYUYKUkOpXz/view?usp=drive_link). 44 | 45 | ## Citation 46 | 47 | If you use our work, please consider citing: 48 | ```bibtex 49 | @misc{zhang2024l3tcleveragingrwkvlearned, 50 | title={L3TC: Leveraging RWKV for Learned Lossless Low-Complexity Text Compression}, 51 | author={Junxuan Zhang and Zhengxue Cheng and Yan Zhao and Shihao Wang and Dajiang Zhou and Guo Lu and Li Song}, 52 | year={2024}, 53 | eprint={2412.16642}, 54 | archivePrefix={arXiv}, 55 | primaryClass={cs.CL}, 56 | url={https://arxiv.org/abs/2412.16642}, 57 | } 58 | ``` 59 | 60 | ## Contact 61 | If you have any questions, please create an issue on this repository or contact junxuan.zjx@antgroup.com or zxcheng@sjtu.edu.cn (Corresponding Author).
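For readers new to learned compression, the following is a minimal sketch of the loop described in the abstract, in which a probability model is combined with an entropy coder. It is an illustration only: the `model(prev, state)` interface, the `encoder` methods, and `BOS_ID` are hypothetical stand-ins, not the actual APIs of `scripts/compressor.py` or `util/arithmetic_coder.py`.

```python
import torch

BOS_ID = 1  # hypothetical start-of-sequence token id

def compress(tokens, model, encoder):
    """Sketch of model-based lossless coding (hypothetical interfaces).

    `model` is assumed to return next-token logits plus an updated
    recurrent state; `encoder` is assumed to expose `encode(cdf, symbol)`
    and `finish()` like a typical arithmetic coder.
    """
    state = None
    prev = BOS_ID
    for token in tokens:
        logits, state = model(torch.tensor([prev]), state)
        probs = torch.softmax(logits[-1], dim=-1)
        cdf = torch.cumsum(probs, dim=-1)  # cumulative distribution for the coder
        encoder.encode(cdf, token)         # costs about -log2(probs[token]) bits
        prev = token
    return encoder.finish()
```

Decompression runs the same model in lockstep: the decoder reproduces the identical distributions and inverts the coder one symbol at a time, which is why the model's single-step decoding speed dominates overall decompression time.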
62 | 63 | ## Acknowledgements 64 | Our code is based on [BlinkDL](https://github.com/BlinkDL/RWKV-LM) 65 | -------------------------------------------------------------------------------- /bin/spm_train: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alipay/L3TC-leveraging-rwkv-for-learned-lossless-low-complexity-text-compression/91dedf6e49a1808345d4916058817d2ae5a16c0b/bin/spm_train -------------------------------------------------------------------------------- /config/l3tc/l3tc_12m.py: -------------------------------------------------------------------------------- 1 | ### 模型文件路径 ### 2 | source_dir = './data/raw_text_data/' 3 | dataset_name = 'enwik' 4 | vocab_path = './dictionary/vocab_enwik8_bpe_16384_0.999/spm_enwik8_bpe_16384_0.999_vocab.json' 5 | train_file = './data/train_data/train_enwik8_bpe_16384_0.999.txt' 6 | test_file = './data/test_data/test_enwik9_bpe_16384_0.999.txt' 7 | 8 | ### 训练参数 ### 9 | random_seed = 1204 10 | epoch = 20 11 | batch_size = 16 12 | epoch_length = 1000000 13 | checkpoint_path = './checkpoint/' 14 | print_freq = 10 15 | ctx_len = 2048 16 | sentence_length = ctx_len 17 | chunk_size = 1 18 | # scheduler = [None] 19 | # scheduler = ['multi_epoch', [5, 10], 0.1] 20 | scheduler = ['step_lr', 10, 0.9999] 21 | # scheduler = ['exponential_lr', 0.999] 22 | clip_max_norm = 5 23 | save_checkpoint_interval = 1 24 | 25 | # optimizer params 26 | betas = (0.9, 0.99) 27 | eps = 1e-8 28 | learning_rate = 1e-4 29 | 30 | ### 模型参数 ### 31 | model_name = "rwkv_tc_hira" 32 | num_hidden_layer = 4 33 | dropout = 0.2 34 | hidden_size = 384 35 | intermediate_size = 1024 36 | rwkv_rank = 4 -------------------------------------------------------------------------------- /config/l3tc/l3tc_200k.py: -------------------------------------------------------------------------------- 1 | ### 模型文件路径 ### 2 | source_dir = './data/raw_text_data/' 3 | dataset_name = 'enwik' 4 | vocab_path = './dictionary/vocab_enwik8_bpe_16384_0.999/spm_enwik8_bpe_16384_0.999_vocab.json' 5 | train_file = './data/train_data/train_enwik8_bpe_16384_0.999.txt' 6 | test_file = './data/test_data/test_enwik9_bpe_16384_0.999.txt' 7 | 8 | ### 训练参数 ### 9 | random_seed = 1204 10 | epoch = 20 11 | batch_size = 32 12 | epoch_length = 1000000 13 | checkpoint_path = './checkpoint/' 14 | print_freq = 10 15 | ctx_len = 2048 16 | sentence_length = ctx_len 17 | chunk_size = 1 18 | # scheduler = [None] 19 | # scheduler = ['multi_epoch', [5, 10], 0.1] 20 | scheduler = ['step_lr', 10, 0.9999] 21 | # scheduler = ['exponential_lr', 0.999] 22 | clip_max_norm = 5 23 | save_checkpoint_interval = 1 24 | 25 | # optimizer params 26 | betas = (0.9, 0.99) 27 | eps = 1e-8 28 | learning_rate = 1e-4 29 | 30 | ### 模型参数 ### 31 | model_name = "rwkv_tc_hira" 32 | num_hidden_layer = 2 33 | dropout = 0.0 34 | hidden_size = 96 35 | intermediate_size = 96 36 | rwkv_rank = 4 -------------------------------------------------------------------------------- /config/l3tc/l3tc_3m2.py: -------------------------------------------------------------------------------- 1 | ### 模型文件路径 ### 2 | source_dir = './data/raw_text_data/' 3 | dataset_name = 'enwik' 4 | vocab_path = './dictionary/vocab_enwik8_bpe_16384_0.999/spm_enwik8_bpe_16384_0.999_vocab.json' 5 | train_file = './data/train_data/train_enwik8_bpe_16384_0.999.txt' 6 | test_file = './data/test_data/test_enwik9_bpe_16384_0.999.txt' 7 | 8 | ### 训练参数 ### 9 | random_seed = 1204 10 | epoch = 20 11 | batch_size = 16 12 | epoch_length = 1000000 13 | 
checkpoint_path = './checkpoint/' 14 | print_freq = 10 15 | ctx_len = 2048 16 | sentence_length = ctx_len 17 | chunk_size = 1 18 | # scheduler = [None] 19 | # scheduler = ['multi_epoch', [5, 10], 0.1] 20 | scheduler = ['step_lr', 10, 0.9999] 21 | # scheduler = ['exponential_lr', 0.999] 22 | clip_max_norm = 5 23 | save_checkpoint_interval = 1 24 | 25 | # optimizer params 26 | betas = (0.9, 0.99) 27 | eps = 1e-8 28 | learning_rate = 1e-4 29 | 30 | ### 模型参数 ### 31 | model_name = "rwkv_tc_hira" 32 | num_hidden_layer = 3 33 | dropout = 0.2 34 | hidden_size = 256 35 | intermediate_size = 512 36 | rwkv_rank = 4 -------------------------------------------------------------------------------- /config/l3tc/l3tc_800k.py: -------------------------------------------------------------------------------- 1 | ### 模型文件路径 ### 2 | source_dir = './data/raw_text_data/' 3 | dataset_name = 'enwik' 4 | vocab_path = './dictionary/vocab_enwik8_bpe_16384_0.999/spm_enwik8_bpe_16384_0.999_vocab.json' 5 | train_file = './data/train_data/train_enwik8_bpe_16384_0.999.txt' 6 | test_file = './data/test_data/test_enwik9_bpe_16384_0.999.txt' 7 | 8 | ### 训练参数 ### 9 | random_seed = 1204 10 | epoch = 20 11 | batch_size = 16 12 | epoch_length = 1000000 13 | checkpoint_path = './checkpoint/' 14 | print_freq = 10 15 | ctx_len = 2048 16 | sentence_length = ctx_len 17 | chunk_size = 1 18 | # scheduler = [None] 19 | # scheduler = ['multi_epoch', [5, 10], 0.1] 20 | scheduler = ['step_lr', 10, 0.9999] 21 | # scheduler = ['exponential_lr', 0.999] 22 | clip_max_norm = 5 23 | save_checkpoint_interval = 1 24 | 25 | # optimizer params 26 | betas = (0.9, 0.99) 27 | eps = 1e-8 28 | learning_rate = 1e-4 29 | 30 | ### 模型参数 ### 31 | model_name = "rwkv_tc_hira" 32 | num_hidden_layer = 2 33 | dropout = 0.0 34 | hidden_size = 176 35 | intermediate_size = 192 36 | rwkv_rank = 4 -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | import pkuseg 5 | import numpy as np 6 | 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class TransformerXLTrainDataSet(Dataset): 12 | def __init__(self, args, corpus_path, word2id_dict): 13 | self.corpus_path = corpus_path 14 | self.descriptions = [] 15 | self.segment_length = args.sentence_length 16 | self.word2id_dict = word2id_dict 17 | self.chunk_size = args.chunk_size 18 | self.epoch_length = args.epoch_length 19 | 20 | if args.debug: 21 | self.epoch_length = args.batch_size * 2 22 | 23 | self.start_token = word2id_dict[''] 24 | self.unknown_token = word2id_dict[''] 25 | self.padding_token = word2id_dict[''] 26 | 27 | with open(corpus_path, 'r', encoding='utf-8') as f: 28 | self.lines = f.readlines() 29 | 30 | self.prob_table = self.get_probabilty_table() 31 | 32 | def get_probabilty_table(self): 33 | sum_length = 0 34 | for line in self.lines: 35 | sum_length += len(line) 36 | 37 | prob_table = [] 38 | for line_id, line in enumerate(self.lines): 39 | prob = len(line) / sum_length 40 | if prob < 0.001: 41 | prob_table.append(line_id) 42 | else: 43 | prob_table.extend([line_id] * int(prob * 1000)) 44 | 45 | return prob_table 46 | 47 | def __len__(self): 48 | return self.epoch_length 49 | 50 | def __getitem__(self, item): 51 | output = {} 52 | 53 | line_id = self.prob_table[np.random.randint(low=0, high=len(self.prob_table))] 54 | line = self.lines[line_id] 55 | line = line.strip() 56 | descriptions_ids = 
[int(x) for x in line.split(',')] 57 | 58 | input_token = [] 59 | output_token = [] 60 | input_types = [] 61 | output_types = [] 62 | 63 | if len(descriptions_ids) <= self.chunk_size * self.segment_length: 64 | start_token_id = 0 65 | else: 66 | start_token_id = np.random.randint(low=0, high=len(descriptions_ids) - self.chunk_size * self.segment_length) 67 | 68 | segments_count = 0 69 | while True: 70 | if start_token_id + segments_count * self.segment_length < len(descriptions_ids) and segments_count < self.chunk_size: 71 | cur_input_type = [] 72 | cur_desc_type = [] 73 | if segments_count == 0: 74 | cur_input_segment = [self.start_token] + descriptions_ids[start_token_id + segments_count * self.segment_length: start_token_id + (segments_count+1) * self.segment_length][:-1] 75 | cur_desc_segment = descriptions_ids[start_token_id + segments_count * self.segment_length: start_token_id + (segments_count+1) * self.segment_length] 76 | else: 77 | cur_input_segment = descriptions_ids[start_token_id + segments_count * self.segment_length - 1: start_token_id + (segments_count+1) * self.segment_length - 1] 78 | cur_desc_segment = descriptions_ids[start_token_id + segments_count * self.segment_length: start_token_id + (segments_count+1) * self.segment_length] 79 | 80 | # 补全padding 81 | if len(cur_input_segment) < self.segment_length: 82 | for i in range(0, self.segment_length - len(cur_input_segment)): 83 | cur_input_segment.append(self.padding_token) 84 | 85 | if len(cur_desc_segment) < self.segment_length: 86 | for i in range(0, self.segment_length - len(cur_desc_segment)): 87 | cur_desc_segment.append(self.padding_token) 88 | 89 | input_token.append(cur_input_segment) 90 | output_token.append(cur_desc_segment) 91 | 92 | # 生成对应的type 93 | for i in cur_input_segment: 94 | if i == self.padding_token: 95 | cur_input_type.append(0) 96 | else: 97 | cur_input_type.append(1) 98 | 99 | # 生成对应的type 100 | for i in cur_desc_segment: 101 | if i == self.padding_token: 102 | cur_desc_type.append(0) 103 | else: 104 | cur_desc_type.append(1) 105 | 106 | input_types.append(cur_input_type) 107 | output_types.append(cur_desc_type) 108 | 109 | segments_count += 1 110 | else: 111 | break 112 | 113 | if len(input_token) < self.chunk_size: 114 | for _ in range(self.chunk_size - len(input_token)): 115 | input_token.append([self.padding_token] * self.segment_length) 116 | output_token.append([self.padding_token] * self.segment_length) 117 | input_types.append([0] * self.segment_length) 118 | output_types.append([0] * self.segment_length) 119 | 120 | output['input_token'] = input_token 121 | output['output_token'] = output_token 122 | output['input_types'] = input_types 123 | output['output_types'] = output_types 124 | else: 125 | output['input_token'] = input_token 126 | output['output_token'] = output_token 127 | output['input_types'] = input_types 128 | output['output_types'] = output_types 129 | 130 | instance = {k: torch.tensor(v, dtype=torch.long) for k, v in output.items()} 131 | return instance 132 | 133 | 134 | class TransformerXLTestDataSet(Dataset): 135 | def __init__(self, args, corpus_path, word2id_dict): 136 | self.corpus_path = corpus_path 137 | self.descriptions = [] 138 | self.segment_length = args.sentence_length 139 | self.model_name = args.model_name 140 | 141 | if 'rwkv' in self.model_name: 142 | self.num_segments = 10 143 | 144 | self.start_token = word2id_dict[''] 145 | self.unknown_token = word2id_dict[''] 146 | self.padding_token = word2id_dict[''] 147 | 148 | with open(corpus_path, 'r', 
encoding='utf-8') as f: 149 | self.lines = f.readlines() 150 | if args.debug: 151 | self.lines = self.lines[:10] 152 | 153 | def __len__(self): 154 | return len(self.lines) 155 | 156 | def __getitem__(self, item): 157 | output = {} 158 | 159 | line = self.lines[item] 160 | line = line.strip() 161 | descriptions_ids = [int(x) for x in line.split(',')] 162 | 163 | # descriptions_ids = self.__gen_input_token(descriptions_ids) 164 | input_token = [] 165 | output_token = [] 166 | input_types = [] 167 | output_types = [] 168 | 169 | segments_count = 0 170 | while True: 171 | if segments_count * self.segment_length < len(descriptions_ids): 172 | cur_input_type = [] 173 | cur_desc_type = [] 174 | if segments_count == 0: 175 | cur_input_segment = [self.start_token] + descriptions_ids[segments_count*self.segment_length: min((segments_count+1) * self.segment_length, len(descriptions_ids))][:-1] 176 | cur_desc_segment = descriptions_ids[segments_count*self.segment_length: min((segments_count+1) * self.segment_length, len(descriptions_ids))] 177 | else: 178 | cur_input_segment = descriptions_ids[segments_count*self.segment_length-1: min((segments_count+1) * self.segment_length, len(descriptions_ids))-1] 179 | cur_desc_segment = descriptions_ids[segments_count*self.segment_length: min((segments_count+1) * self.segment_length, len(descriptions_ids))] 180 | 181 | # 补全padding 182 | if len(cur_input_segment) < self.segment_length: 183 | for i in range(0, self.segment_length - len(cur_input_segment)): 184 | cur_input_segment.append(self.padding_token) 185 | 186 | if len(cur_desc_segment) < self.segment_length: 187 | for i in range(0, self.segment_length - len(cur_desc_segment)): 188 | cur_desc_segment.append(self.padding_token) 189 | 190 | input_token.append(cur_input_segment) 191 | output_token.append(cur_desc_segment) 192 | 193 | # 生成对应的type 194 | for i in cur_input_segment: 195 | if i == self.padding_token: 196 | cur_input_type.append(0) 197 | else: 198 | cur_input_type.append(1) 199 | 200 | # 生成对应的type 201 | for i in cur_desc_segment: 202 | if i == self.padding_token: 203 | cur_desc_type.append(0) 204 | else: 205 | cur_desc_type.append(1) 206 | 207 | input_types.append(cur_input_type) 208 | output_types.append(cur_desc_type) 209 | 210 | segments_count += 1 211 | else: 212 | break 213 | 214 | if 'rwkv' in self.model_name: 215 | if segments_count < self.num_segments: 216 | padding_vals = [[0] * self.segment_length] * (self.num_segments - segments_count) 217 | input_token.extend(padding_vals) 218 | output_token.extend(padding_vals) 219 | input_types.extend(padding_vals) 220 | output_types.extend(padding_vals) 221 | else: 222 | input_token = input_token[:self.num_segments] 223 | output_token = output_token[:self.num_segments] 224 | input_types = input_types[:self.num_segments] 225 | output_types = output_types[:self.num_segments] 226 | 227 | output['input_token'] = input_token 228 | output['output_token'] = output_token 229 | output['input_types'] = input_types 230 | output['output_types'] = output_types 231 | instance = {k: torch.tensor(v, dtype=torch.long) for k, v in output.items()} 232 | 233 | return instance 234 | 235 | 236 | if __name__ == '__main__': 237 | # dataloader = TransformerXLDataSet(CorpusPath) 238 | dataloader = TransformerXLTestSet(EvalPath) 239 | for data in dataloader: 240 | x = 1 241 | break 242 | print('加载完成') 243 | -------------------------------------------------------------------------------- /dataset/enwik_ascii_dataset.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | import pkuseg 5 | import numpy as np 6 | 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class EnWikASCIITrainDataSet(Dataset): 12 | def __init__(self, args, corpus_path, word2id_dict): 13 | self.corpus_path = corpus_path 14 | self.descriptions = [] 15 | self.segment_length = args.sentence_length 16 | self.word2id_dict = word2id_dict 17 | self.chunk_size = args.chunk_size 18 | self.epoch_length = args.epoch_length 19 | 20 | if args.debug: 21 | self.epoch_length = args.batch_size * 2 22 | 23 | self.start_token = word2id_dict[''] 24 | self.unknown_token = word2id_dict[''] 25 | self.padding_token = word2id_dict[''] 26 | 27 | with open(corpus_path, 'rb') as f: 28 | train_data = f.read() 29 | 30 | self.tokens = [int(x) + 3 for x in train_data] 31 | 32 | def __len__(self): 33 | return self.epoch_length 34 | 35 | def __getitem__(self, item): 36 | output = {} 37 | 38 | input_token = [] 39 | output_token = [] 40 | input_types = [] 41 | output_types = [] 42 | 43 | start_token_id = np.random.randint(low=0, high=len(self.tokens) - self.chunk_size * self.segment_length) 44 | 45 | segments_count = 0 46 | while True: 47 | if start_token_id + segments_count * self.segment_length < len(self.tokens) and segments_count < self.chunk_size: 48 | cur_input_type = [] 49 | cur_desc_type = [] 50 | if segments_count == 0: 51 | cur_input_segment = [self.start_token] + self.tokens[start_token_id + segments_count * self.segment_length: start_token_id + (segments_count+1) * self.segment_length][:-1] 52 | cur_desc_segment = self.tokens[start_token_id + segments_count * self.segment_length: start_token_id + (segments_count+1) * self.segment_length] 53 | else: 54 | cur_input_segment = self.tokens[start_token_id + segments_count * self.segment_length - 1: start_token_id + (segments_count+1) * self.segment_length - 1] 55 | cur_desc_segment = self.tokens[start_token_id + segments_count * self.segment_length: start_token_id + (segments_count+1) * self.segment_length] 56 | 57 | # 补全padding 58 | if len(cur_input_segment) < self.segment_length: 59 | for i in range(0, self.segment_length - len(cur_input_segment)): 60 | cur_input_segment.append(self.padding_token) 61 | 62 | if len(cur_desc_segment) < self.segment_length: 63 | for i in range(0, self.segment_length - len(cur_desc_segment)): 64 | cur_desc_segment.append(self.padding_token) 65 | 66 | input_token.append(cur_input_segment) 67 | output_token.append(cur_desc_segment) 68 | 69 | # 生成对应的type 70 | for i in cur_input_segment: 71 | if i == self.padding_token: 72 | cur_input_type.append(0) 73 | else: 74 | cur_input_type.append(1) 75 | 76 | # 生成对应的type 77 | for i in cur_desc_segment: 78 | if i == self.padding_token: 79 | cur_desc_type.append(0) 80 | else: 81 | cur_desc_type.append(1) 82 | 83 | input_types.append(cur_input_type) 84 | output_types.append(cur_desc_type) 85 | 86 | segments_count += 1 87 | else: 88 | break 89 | 90 | if len(input_token) < self.chunk_size: 91 | for _ in range(self.chunk_size - len(input_token)): 92 | input_token.append([self.padding_token] * self.segment_length) 93 | output_token.append([self.padding_token] * self.segment_length) 94 | input_types.append([0] * self.segment_length) 95 | output_types.append([0] * self.segment_length) 96 | 97 | output['input_token'] = input_token 98 | output['output_token'] = output_token 99 | output['input_types'] = input_types 100 | output['output_types'] = 
output_types 101 | else: 102 | output['input_token'] = input_token 103 | output['output_token'] = output_token 104 | output['input_types'] = input_types 105 | output['output_types'] = output_types 106 | 107 | instance = {k: torch.tensor(v, dtype=torch.long) for k, v in output.items()} 108 | return instance 109 | 110 | 111 | class EnWikASCIITestDataSet(Dataset): 112 | def __init__(self, args, corpus_path, word2id_dict): 113 | self.corpus_path = corpus_path 114 | self.descriptions = [] 115 | self.segment_length = args.sentence_length 116 | self.model_name = args.model_name 117 | 118 | self.start_token = word2id_dict[''] 119 | self.unknown_token = word2id_dict[''] 120 | self.padding_token = word2id_dict[''] 121 | 122 | with open(corpus_path, 'r', encoding='utf-8') as f: 123 | lines = f.readlines() 124 | 125 | test_lines = "".join(lines[1128023:1128023+5000]) 126 | test_data = test_lines.encode() 127 | tokens = [int(x) + 3 for x in test_data] 128 | 129 | # 将tokens按照一定长度分割成多份,加速验证 130 | self.seg_tokens = [] 131 | seg_data_len = 10000 132 | num_segs = math.ceil(len(tokens) / seg_data_len) 133 | for i in range(num_segs): 134 | self.seg_tokens.append(tokens[i * seg_data_len: (i+1) * seg_data_len]) 135 | 136 | if 'rwkv' in self.model_name: 137 | self.num_segments = math.ceil(seg_data_len / self.segment_length) 138 | 139 | if args.debug: 140 | self.seg_tokens = self.seg_tokens[:5] 141 | 142 | def __len__(self): 143 | return len(self.seg_tokens) 144 | 145 | def __getitem__(self, item): 146 | output = {} 147 | 148 | descriptions_ids = self.seg_tokens[item] 149 | 150 | input_token = [] 151 | output_token = [] 152 | input_types = [] 153 | output_types = [] 154 | 155 | segments_count = 0 156 | while True: 157 | if segments_count * self.segment_length < len(descriptions_ids): 158 | cur_input_type = [] 159 | cur_desc_type = [] 160 | if segments_count == 0: 161 | cur_input_segment = [self.start_token] + descriptions_ids[segments_count*self.segment_length: min((segments_count+1) * self.segment_length, len(descriptions_ids))][:-1] 162 | cur_desc_segment = descriptions_ids[segments_count*self.segment_length: min((segments_count+1) * self.segment_length, len(descriptions_ids))] 163 | else: 164 | cur_input_segment = descriptions_ids[segments_count*self.segment_length-1: min((segments_count+1) * self.segment_length, len(descriptions_ids))-1] 165 | cur_desc_segment = descriptions_ids[segments_count*self.segment_length: min((segments_count+1) * self.segment_length, len(descriptions_ids))] 166 | 167 | # 补全padding 168 | if len(cur_input_segment) < self.segment_length: 169 | for i in range(0, self.segment_length - len(cur_input_segment)): 170 | cur_input_segment.append(self.padding_token) 171 | 172 | if len(cur_desc_segment) < self.segment_length: 173 | for i in range(0, self.segment_length - len(cur_desc_segment)): 174 | cur_desc_segment.append(self.padding_token) 175 | 176 | input_token.append(cur_input_segment) 177 | output_token.append(cur_desc_segment) 178 | 179 | # 生成对应的type 180 | for i in cur_input_segment: 181 | if i == self.padding_token: 182 | cur_input_type.append(0) 183 | else: 184 | cur_input_type.append(1) 185 | 186 | # 生成对应的type 187 | for i in cur_desc_segment: 188 | if i == self.padding_token: 189 | cur_desc_type.append(0) 190 | else: 191 | cur_desc_type.append(1) 192 | 193 | input_types.append(cur_input_type) 194 | output_types.append(cur_desc_type) 195 | 196 | segments_count += 1 197 | else: 198 | break 199 | 200 | if 'rwkv' in self.model_name: 201 | if segments_count < self.num_segments: 202 | 
padding_vals = [[0] * self.segment_length] * (self.num_segments - segments_count) 203 | input_token.extend(padding_vals) 204 | output_token.extend(padding_vals) 205 | input_types.extend(padding_vals) 206 | output_types.extend(padding_vals) 207 | else: 208 | input_token = input_token[:self.num_segments] 209 | output_token = output_token[:self.num_segments] 210 | input_types = input_types[:self.num_segments] 211 | output_types = output_types[:self.num_segments] 212 | 213 | output['input_token'] = input_token 214 | output['output_token'] = output_token 215 | output['input_types'] = input_types 216 | output['output_types'] = output_types 217 | instance = {k: torch.tensor(v, dtype=torch.long) for k, v in output.items()} 218 | 219 | return instance 220 | 221 | 222 | if __name__ == '__main__': 223 | # dataloader = TransformerXLDataSet(CorpusPath) 224 | dataloader = TransformerXLTestSet(EvalPath) 225 | for data in dataloader: 226 | x = 1 227 | break 228 | print('加载完成') 229 | -------------------------------------------------------------------------------- /dataset/enwik_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import random 4 | import pkuseg 5 | import numpy as np 6 | 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset 9 | from multiprocessing import Pool 10 | 11 | 12 | class EnWikTrainDataSet(Dataset): 13 | def __init__(self, args, corpus_path, word2id_dict): 14 | self.corpus_path = corpus_path 15 | self.descriptions = [] 16 | self.segment_length = args.sentence_length 17 | self.word2id_dict = word2id_dict 18 | self.chunk_size = args.chunk_size 19 | self.epoch_length = args.epoch_length 20 | 21 | if args.debug: 22 | self.epoch_length = args.batch_size * 2 23 | 24 | self.start_token = word2id_dict[''] 25 | self.unknown_token = word2id_dict[''] 26 | self.padding_token = word2id_dict[''] 27 | 28 | with open(corpus_path, 'r', encoding='utf-8') as f: 29 | lines = f.readlines() 30 | 31 | train_data = lines[0].strip() 32 | self.tokens = [] 33 | token_str = "" 34 | 35 | a = torch.ones(10000, 100000).cuda() 36 | b = torch.ones(100000, 10000).cuda() 37 | pbar = tqdm(total=len(train_data)) 38 | for i, c in enumerate(train_data): 39 | if c != ",": 40 | token_str += c 41 | else: 42 | self.tokens.append(int(token_str)) 43 | token_str = "" 44 | 45 | if i % 1000000 == 0: 46 | c = torch.matmul(a, b) 47 | 48 | pbar.update() 49 | pbar.close() 50 | 51 | del lines 52 | del train_data 53 | 54 | # self.tokens = [int(x) for x in train_data.split(",")] 55 | 56 | def __len__(self): 57 | return self.epoch_length 58 | 59 | def __getitem__(self, item): 60 | output = {} 61 | 62 | input_token = [] 63 | output_token = [] 64 | input_types = [] 65 | output_types = [] 66 | 67 | start_token_id = np.random.randint(low=0, high=len(self.tokens) - self.chunk_size * self.segment_length) 68 | 69 | segments_count = 0 70 | while True: 71 | if start_token_id + segments_count * self.segment_length < len(self.tokens) and segments_count < self.chunk_size: 72 | cur_input_type = [] 73 | cur_desc_type = [] 74 | if segments_count == 0: 75 | cur_input_segment = [self.start_token] + self.tokens[start_token_id + segments_count * self.segment_length: start_token_id + (segments_count+1) * self.segment_length][:-1] 76 | cur_desc_segment = self.tokens[start_token_id + segments_count * self.segment_length: start_token_id + (segments_count+1) * self.segment_length] 77 | else: 78 | cur_input_segment = self.tokens[start_token_id + segments_count * 
self.segment_length - 1: start_token_id + (segments_count+1) * self.segment_length - 1] 79 | cur_desc_segment = self.tokens[start_token_id + segments_count * self.segment_length: start_token_id + (segments_count+1) * self.segment_length] 80 | 81 | # 补全padding 82 | if len(cur_input_segment) < self.segment_length: 83 | for i in range(0, self.segment_length - len(cur_input_segment)): 84 | cur_input_segment.append(self.padding_token) 85 | 86 | if len(cur_desc_segment) < self.segment_length: 87 | for i in range(0, self.segment_length - len(cur_desc_segment)): 88 | cur_desc_segment.append(self.padding_token) 89 | 90 | input_token.append(cur_input_segment) 91 | output_token.append(cur_desc_segment) 92 | 93 | # 生成对应的type 94 | for i in cur_input_segment: 95 | if i == self.padding_token: 96 | cur_input_type.append(0) 97 | else: 98 | cur_input_type.append(1) 99 | 100 | # 生成对应的type 101 | for i in cur_desc_segment: 102 | if i == self.padding_token: 103 | cur_desc_type.append(0) 104 | else: 105 | cur_desc_type.append(1) 106 | 107 | input_types.append(cur_input_type) 108 | output_types.append(cur_desc_type) 109 | 110 | segments_count += 1 111 | else: 112 | break 113 | 114 | if len(input_token) < self.chunk_size: 115 | for _ in range(self.chunk_size - len(input_token)): 116 | input_token.append([self.padding_token] * self.segment_length) 117 | output_token.append([self.padding_token] * self.segment_length) 118 | input_types.append([0] * self.segment_length) 119 | output_types.append([0] * self.segment_length) 120 | 121 | output['input_token'] = input_token 122 | output['output_token'] = output_token 123 | output['input_types'] = input_types 124 | output['output_types'] = output_types 125 | else: 126 | output['input_token'] = input_token 127 | output['output_token'] = output_token 128 | output['input_types'] = input_types 129 | output['output_types'] = output_types 130 | 131 | instance = {k: torch.tensor(v, dtype=torch.long) for k, v in output.items()} 132 | return instance 133 | 134 | 135 | class EnWikTestDataSet(Dataset): 136 | def __init__(self, args, corpus_path, word2id_dict): 137 | self.corpus_path = corpus_path 138 | self.descriptions = [] 139 | self.segment_length = args.sentence_length 140 | self.model_name = args.model_name 141 | 142 | self.start_token = word2id_dict[''] 143 | self.unknown_token = word2id_dict[''] 144 | self.padding_token = word2id_dict[''] 145 | 146 | with open(corpus_path, 'r', encoding='utf-8') as f: 147 | lines = f.readlines() 148 | 149 | test_data = lines[0].strip() 150 | tokens = [int(x) for x in test_data.split(",")] 151 | 152 | # 将tokens按照一定长度分割成多份,加速验证 153 | self.seg_tokens = [] 154 | seg_data_len = 10000 155 | num_segs = math.ceil(len(tokens) / seg_data_len) 156 | for i in range(num_segs): 157 | self.seg_tokens.append(tokens[i * seg_data_len: (i+1) * seg_data_len]) 158 | 159 | if 'rwkv' in self.model_name: 160 | self.num_segments = math.ceil(seg_data_len / self.segment_length) 161 | 162 | if args.debug: 163 | self.seg_tokens = self.seg_tokens[:5] 164 | 165 | def __len__(self): 166 | return len(self.seg_tokens) 167 | 168 | def __getitem__(self, item): 169 | output = {} 170 | 171 | descriptions_ids = self.seg_tokens[item] 172 | 173 | input_token = [] 174 | output_token = [] 175 | input_types = [] 176 | output_types = [] 177 | 178 | segments_count = 0 179 | while True: 180 | if segments_count * self.segment_length < len(descriptions_ids): 181 | cur_input_type = [] 182 | cur_desc_type = [] 183 | if segments_count == 0: 184 | cur_input_segment = [self.start_token] + 
descriptions_ids[segments_count*self.segment_length: min((segments_count+1) * self.segment_length, len(descriptions_ids))][:-1] 185 | cur_desc_segment = descriptions_ids[segments_count*self.segment_length: min((segments_count+1) * self.segment_length, len(descriptions_ids))] 186 | else: 187 | cur_input_segment = descriptions_ids[segments_count*self.segment_length-1: min((segments_count+1) * self.segment_length, len(descriptions_ids))-1] 188 | cur_desc_segment = descriptions_ids[segments_count*self.segment_length: min((segments_count+1) * self.segment_length, len(descriptions_ids))] 189 | 190 | # 补全padding 191 | if len(cur_input_segment) < self.segment_length: 192 | for i in range(0, self.segment_length - len(cur_input_segment)): 193 | cur_input_segment.append(self.padding_token) 194 | 195 | if len(cur_desc_segment) < self.segment_length: 196 | for i in range(0, self.segment_length - len(cur_desc_segment)): 197 | cur_desc_segment.append(self.padding_token) 198 | 199 | input_token.append(cur_input_segment) 200 | output_token.append(cur_desc_segment) 201 | 202 | # 生成对应的type 203 | for i in cur_input_segment: 204 | if i == self.padding_token: 205 | cur_input_type.append(0) 206 | else: 207 | cur_input_type.append(1) 208 | 209 | # 生成对应的type 210 | for i in cur_desc_segment: 211 | if i == self.padding_token: 212 | cur_desc_type.append(0) 213 | else: 214 | cur_desc_type.append(1) 215 | 216 | input_types.append(cur_input_type) 217 | output_types.append(cur_desc_type) 218 | 219 | segments_count += 1 220 | else: 221 | break 222 | 223 | if 'rwkv' in self.model_name: 224 | if segments_count < self.num_segments: 225 | padding_vals = [[0] * self.segment_length] * (self.num_segments - segments_count) 226 | input_token.extend(padding_vals) 227 | output_token.extend(padding_vals) 228 | input_types.extend(padding_vals) 229 | output_types.extend(padding_vals) 230 | else: 231 | input_token = input_token[:self.num_segments] 232 | output_token = output_token[:self.num_segments] 233 | input_types = input_types[:self.num_segments] 234 | output_types = output_types[:self.num_segments] 235 | 236 | output['input_token'] = input_token 237 | output['output_token'] = output_token 238 | output['input_types'] = input_types 239 | output['output_types'] = output_types 240 | instance = {k: torch.tensor(v, dtype=torch.long) for k, v in output.items()} 241 | 242 | return instance 243 | 244 | 245 | if __name__ == '__main__': 246 | # dataloader = TransformerXLDataSet(CorpusPath) 247 | dataloader = TransformerXLTestSet(EvalPath) 248 | for data in dataloader: 249 | x = 1 250 | break 251 | print('加载完成') 252 | -------------------------------------------------------------------------------- /docs/main_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alipay/L3TC-leveraging-rwkv-for-learned-lossless-low-complexity-text-compression/91dedf6e49a1808345d4916058817d2ae5a16c0b/docs/main_fig.png -------------------------------------------------------------------------------- /models/RWKV_V4/__init__.py: -------------------------------------------------------------------------------- 1 | from .rwkv_v4_train import build_rwkv_v4 2 | from .rwkv_v5_train import build_rwkv_v5 3 | from .prwkv_train import build_point_rwkv 4 | from .rwkv_tc_train import build_rwkv_tc 5 | from .rwkv_tc_hira_train import build_rwkv_tc_hira 6 | from .rwkv_tc_hira_rbranch_train import build_rwkv_tc_hira_rbranch 7 | from .rwkv_v4_multi_pred_train import build_rwkv_v4_multi_pred 8 | from 
.rwkv_v4_multi_pred_train import build_rwkv_v4_multi_pred 9 | from .rwkv_tc_hira_multi_pred_train import build_rwkv_tc_hira_multi_pred 10 | 11 | from .rwkv_v4_infer import build_rwkv_v4_infer_for_coreml 12 | from .rwkv_v4_infer import build_rwkv_v4_infer_for_onnx 13 | from .rwkv_v4_infer import build_rwkv_v4_infer_for_script 14 | 15 | from .rwkv_v4_quant import build_rwkv_v4_quant_infer 16 | 17 | from .rwkv_v5_infer import build_rwkv_v5_infer_for_coreml 18 | from .rwkv_v5_infer import build_rwkv_v5_infer_for_onnx 19 | from .rwkv_v5_infer import build_rwkv_v5_infer_for_script 20 | 21 | from .prwkv_infer import build_prwkv_infer_for_coreml 22 | from .prwkv_infer import build_prwkv_infer_for_onnx 23 | from .prwkv_infer import build_prwkv_infer_for_script 24 | 25 | from .rwkv_tc_infer import build_rwkv_tc_infer_for_coreml 26 | from .rwkv_tc_infer import build_rwkv_tc_infer_for_onnx 27 | from .rwkv_tc_infer import build_rwkv_tc_infer_for_script 28 | 29 | from .rwkv_tc_hira_infer import build_rwkv_tc_hira_infer_for_coreml 30 | from .rwkv_tc_hira_infer import build_rwkv_tc_hira_infer_for_onnx 31 | from .rwkv_tc_hira_infer import build_rwkv_tc_hira_infer_for_script 32 | 33 | from .rwkv_tc_hira_rbranch_infer import build_rwkv_tc_hira_rbranch_infer_for_coreml 34 | from .rwkv_tc_hira_rbranch_infer import build_rwkv_tc_hira_rbranch_infer_for_onnx 35 | from .rwkv_tc_hira_rbranch_infer import build_rwkv_tc_hira_rbranch_infer_for_script 36 | -------------------------------------------------------------------------------- /models/RWKV_V4/cuda/wkv_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define MIN_VALUE (-1e38) 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, 8 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, 9 | F *__restrict__ const _y) { 10 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 11 | const int _b = idx / C; 12 | const int _c = idx % C; 13 | const int _offset = _b * T * C + _c; 14 | 15 | F u = _u[_c]; 16 | F w = _w[_c]; 17 | const F *__restrict__ const k = _k + _offset; 18 | const F *__restrict__ const v = _v + _offset; 19 | F *__restrict__ const y = _y + _offset; 20 | 21 | F p = 0, q = 0, o = MIN_VALUE; 22 | // p and q are running sums divided by exp(o) (to avoid overflows) 23 | for (int i = 0; i < T; i++) { 24 | const int ii = i * C; 25 | 26 | F no = max(o, u + k[ii]); 27 | F A = exp(o - no); 28 | F B = exp(u + k[ii] - no); 29 | y[ii] = (A * p + B * v[ii]) / (A * q + B); 30 | 31 | no = max(w + o, k[ii]); 32 | A = exp(w + o - no); 33 | B = exp(k[ii] - no); 34 | p = A * p + B * v[ii]; 35 | q = A * q + B; 36 | o = no; 37 | } 38 | } 39 | 40 | template 41 | __global__ void kernel_backward(const int B, const int T, const int C, 42 | const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy, 43 | F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { 44 | const int idx = blockIdx.x * blockDim.x + threadIdx.x; 45 | const int _b = idx / C; 46 | const int _c = idx % C; 47 | const int _offset = _b * T * C + _c; 48 | 49 | F u = _u[_c]; 50 | F w = _w[_c]; 51 | const F *__restrict__ const k = _k + _offset; 52 | const F *__restrict__ const v = _v + _offset; 53 | const F *__restrict__ const gy = _gy + _offset; 54 | 55 | F *__restrict__ const 
gk = _gk + _offset; 56 | F *__restrict__ const gv = _gv + _offset; 57 | 58 | F y[Tmax], z[Tmax], zexp[Tmax]; 59 | 60 | F gw = 0, gu = 0; 61 | F p = 0, q = 0; 62 | F dpdw = 0, dqdw = 0; 63 | F o = MIN_VALUE; 64 | for (int i = 0; i < T; i++) { 65 | const int ii = i * C; 66 | F no = max(o, k[ii] + u); 67 | F A = exp(o - no); 68 | F B = exp(k[ii] + u - no); 69 | 70 | F num = A * p + B * v[ii]; 71 | F iden = 1 / (A * q + B); 72 | 73 | y[i] = num * iden; 74 | z[i] = iden; 75 | zexp[i] = k[ii] + u - no; 76 | 77 | gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A; 78 | gu += gy[ii] * (v[ii] - y[i]) * B * iden; 79 | 80 | no = max(w + o, k[ii]); 81 | A = exp(w + o - no); 82 | B = exp(k[ii] - no); 83 | dpdw = A * (p + dpdw); 84 | dqdw = A * (q + dqdw); 85 | p = A * p + B * v[ii]; 86 | q = A * q + B; 87 | o = no; 88 | } 89 | 90 | F gp = 0, gq = 0; 91 | o = MIN_VALUE; 92 | for (int i = T - 1; i >= 0; i--) { 93 | const int ii = i * C; 94 | F A = gy[ii] * z[i] * exp(zexp[i]); 95 | F B = exp(k[ii] + o); 96 | gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq); 97 | gv[ii] = A + B * gp; 98 | 99 | F no = max(w + o, zexp[i] - k[ii] - u); 100 | A = exp(w + o - no); 101 | B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no); 102 | gp = A * gp + B; 103 | gq = A * gq - B * y[i]; 104 | o = no; 105 | } 106 | 107 | // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass 108 | const int _offsetBC = _b * C + _c; 109 | _gw[_offsetBC] += gw * _w[_c]; 110 | _gu[_offsetBC] += gu; 111 | } 112 | 113 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { 114 | dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance 115 | assert(B * C % threadsPerBlock.x == 0); 116 | dim3 numBlocks(B * C / threadsPerBlock.x); 117 | kernel_forward<<>>(B, T, C, w, u, k, v, y); 118 | } 119 | 120 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) { 121 | dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance 122 | assert(B * C % threadsPerBlock.x == 0); 123 | dim3 numBlocks(B * C / threadsPerBlock.x); 124 | kernel_backward<<>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv); 125 | } 126 | -------------------------------------------------------------------------------- /models/RWKV_V4/cuda/wkv_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); 4 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv); 5 | 6 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 7 | cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); 8 | } 9 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 10 | cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); 11 | } 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("forward", &forward, "wkv forward"); 15 | m.def("backward", &backward, "wkv 
backward"); 16 | } 17 | 18 | TORCH_LIBRARY(wkv, m) { 19 | m.def("forward", forward); 20 | m.def("backward", backward); 21 | } 22 | -------------------------------------------------------------------------------- /models/RWKV_V4/prwkv_infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numbers 4 | import torch.nn as nn 5 | from torch.nn import init 6 | import torch.nn.functional as F 7 | from ..registry import MODULE_BUILD_FUNCS 8 | from torch import Tensor, Size 9 | from typing import Union, List 10 | _shape_t = Union[int, List[int], Size] 11 | 12 | class LayerNorm(nn.Module): 13 | def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, device=None, dtype=None): 14 | factory_kwargs = {'device': device, 'dtype': dtype} 15 | super(LayerNorm, self).__init__() 16 | if isinstance(normalized_shape, numbers.Integral): 17 | # mypy error: incompatible types in assignment 18 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 19 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 20 | self.eps = eps 21 | self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 22 | self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self) -> None: 26 | init.ones_(self.weight) 27 | init.zeros_(self.bias) 28 | 29 | def forward(self, input: Tensor) -> Tensor: 30 | u = torch.mean(input, dim=-1, keepdim=True) 31 | s = torch.mean(input * input, dim=-1, keepdim=True) 32 | s = torch.sqrt(s - u * u + self.eps) 33 | x_normalized = (input - u) / s 34 | output = x_normalized * self.weight + self.bias 35 | return output 36 | 37 | 38 | class RWKV_ChannelMix(nn.Module): 39 | def __init__(self, n_embed, ffn_dim): 40 | super().__init__() 41 | 42 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 43 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 44 | 45 | self.key = nn.Linear(n_embed, ffn_dim, bias=False) 46 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 47 | self.value = nn.Linear(ffn_dim, n_embed, bias=False) 48 | 49 | def forward(self, x, state_ffn): 50 | xk = x * self.time_mix_k + state_ffn * (1 - self.time_mix_k) 51 | xr = x * self.time_mix_r + state_ffn * (1 - self.time_mix_r) 52 | new_ffn = x 53 | 54 | r = torch.sigmoid(self.receptance(xr)) 55 | k = torch.square(torch.relu(self.key(xk))) 56 | kv = self.value(k) 57 | 58 | rkv = r * kv 59 | return rkv, new_ffn 60 | 61 | 62 | class RWKV_TimeMix(nn.Module): 63 | def __init__(self, n_embed): 64 | super().__init__() 65 | self.time_decay = nn.Parameter(torch.ones(n_embed)) 66 | self.time_first = nn.Parameter(torch.ones(n_embed) * math.log(0.3)) 67 | 68 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 69 | self.time_mix_v = nn.Parameter(torch.ones(1, n_embed)) 70 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 71 | 72 | self.key = nn.Linear(n_embed, n_embed, bias=False) 73 | self.value = nn.Linear(n_embed, n_embed, bias=False) 74 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 75 | 76 | self.output = nn.Linear(n_embed, n_embed, bias=False) 77 | 78 | def forward(self, x, state_A, state_B, state_p, state_x): 79 | xk = x * self.time_mix_k + state_x * (1 - self.time_mix_k) 80 | xv = x * self.time_mix_v + state_x * (1 - self.time_mix_v) 81 | xr = x * self.time_mix_r + state_x * (1 - self.time_mix_r) 82 | new_x = x 83 | 84 | k = self.key(xk) 85 | v = self.value(xv) 86 | r = 
torch.sigmoid(self.receptance(xr)) 87 | 88 | ww = self.time_first + k 89 | p = torch.maximum(state_p, ww) 90 | e1 = torch.exp(state_p - p) 91 | e2 = torch.exp(ww - p) 92 | a = e1 * state_A + e2 * v 93 | b = e1 * state_B + e2 94 | 95 | ww = state_p + -torch.exp(self.time_decay) 96 | p = torch.maximum(ww, k) 97 | e1 = torch.exp(ww - p) 98 | e2 = torch.exp(k - p) 99 | new_A = e1 * state_A + e2 * v 100 | new_B = e1 * state_B + e2 101 | new_p = p 102 | 103 | rwkv = r * a / b 104 | rwkv = self.output(rwkv) 105 | return rwkv, new_A, new_B, new_p, new_x 106 | 107 | 108 | class RWKV_TimeMix_ONNX(nn.Module): 109 | def __init__(self, n_embed): 110 | super().__init__() 111 | self.time_decay = nn.Parameter(torch.ones(n_embed)) 112 | self.time_first = nn.Parameter(torch.ones(n_embed) * math.log(0.3)) 113 | 114 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 115 | self.time_mix_v = nn.Parameter(torch.ones(1, n_embed)) 116 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 117 | 118 | self.key = nn.Linear(n_embed, n_embed, bias=False) 119 | self.value = nn.Linear(n_embed, n_embed, bias=False) 120 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 121 | 122 | self.output = nn.Linear(n_embed, n_embed, bias=False) 123 | 124 | def forward(self, x, state_A, state_B, state_p, state_x): 125 | xk = x * self.time_mix_k + state_x * (1 - self.time_mix_k) 126 | xv = x * self.time_mix_v + state_x * (1 - self.time_mix_v) 127 | xr = x * self.time_mix_r + state_x * (1 - self.time_mix_r) 128 | new_x = x 129 | 130 | k = self.key(xk) 131 | v = self.value(xv) 132 | r = torch.sigmoid(self.receptance(xr)) 133 | 134 | ww = self.time_first + k 135 | # p = torch.maximum(state_p, ww) 136 | p = torch.stack([state_p.flatten(), ww.flatten()]).max(dim=0)[0].view(state_p.shape) 137 | # p = torch.where(state_p > ww, state_p, ww) 138 | 139 | e1 = torch.exp(state_p - p) 140 | e2 = torch.exp(ww - p) 141 | a = e1 * state_A + e2 * v 142 | b = e1 * state_B + e2 143 | 144 | ww = state_p + -torch.exp(self.time_decay) 145 | # p = torch.maximum(ww, k) 146 | p = torch.stack([ww.flatten(), k.flatten()]).max(dim=0)[0].view(state_p.shape) 147 | # p = torch.where(ww > k, ww, k) 148 | 149 | e1 = torch.exp(ww - p) 150 | e2 = torch.exp(k - p) 151 | new_A = e1 * state_A + e2 * v 152 | new_B = e1 * state_B + e2 153 | new_p = p 154 | 155 | rwkv = r * a / b 156 | rwkv = self.output(rwkv) 157 | return rwkv, new_A, new_B, new_p, new_x 158 | 159 | 160 | class Block(nn.Module): 161 | def __init__(self, layer_id, n_embed, ffn_dim): 162 | super().__init__() 163 | self.layer_id = layer_id 164 | 165 | self.ln1 = nn.LayerNorm(n_embed) 166 | self.ln2 = nn.LayerNorm(n_embed) 167 | if self.layer_id == 0: 168 | self.ln0 = nn.LayerNorm(n_embed) 169 | 170 | self.att = RWKV_TimeMix(n_embed) 171 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 172 | self.short = nn.Linear(n_embed, n_embed, bias=False) 173 | 174 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 175 | if self.layer_id == 0: 176 | x = self.ln0(x) 177 | 178 | short = F.relu(self.short(x)) 179 | 180 | short_cut = x 181 | x = self.ln1(x) 182 | x, new_A, new_B, new_p, new_x = self.att(x, state_A, state_B, state_p, state_x) 183 | x = short_cut + x 184 | 185 | short_cut = x 186 | x = self.ln2(x) 187 | x, new_ffn = self.ffn(x, state_ffn) 188 | x = short_cut + x 189 | 190 | x = x + short 191 | return x, new_A, new_B, new_p, new_x, new_ffn 192 | 193 | 194 | class Block_ONNX(nn.Module): 195 | def __init__(self, layer_id, n_embed, ffn_dim): 196 | super().__init__() 197 | 
self.layer_id = layer_id 198 | 199 | self.ln1 = nn.LayerNorm(n_embed) 200 | self.ln2 = nn.LayerNorm(n_embed) 201 | if self.layer_id == 0: 202 | self.ln0 = nn.LayerNorm(n_embed) 203 | 204 | self.att = RWKV_TimeMix_ONNX(n_embed) 205 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 206 | self.short = nn.Linear(n_embed, n_embed, bias=False) 207 | 208 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 209 | if self.layer_id == 0: 210 | x = self.ln0(x) 211 | 212 | short = F.relu(self.short(x)) 213 | 214 | short_cut = x 215 | x, new_A, new_B, new_p, new_x = self.att(self.ln1(x), state_A, state_B, state_p, state_x) 216 | x = short_cut + x 217 | 218 | short_cut = x 219 | x, new_ffn = self.ffn(self.ln2(x), state_ffn) 220 | x = short_cut + x 221 | 222 | x = x + short 223 | return x, new_A, new_B, new_p, new_x, new_ffn 224 | 225 | 226 | class Block_Script(nn.Module): 227 | def __init__(self, n_embed, ffn_dim): 228 | super().__init__() 229 | 230 | self.ln1 = nn.LayerNorm(n_embed) 231 | self.ln2 = nn.LayerNorm(n_embed) 232 | 233 | self.att = RWKV_TimeMix(n_embed) 234 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 235 | self.short = nn.Linear(n_embed, n_embed, bias=False) 236 | 237 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 238 | short = F.relu(self.short(x)) 239 | 240 | short_cut = x 241 | x, new_A, new_B, new_p, new_x = self.att(self.ln1(x), state_A, state_B, state_p, state_x) 242 | x = short_cut + x 243 | 244 | short_cut = x 245 | x, new_ffn = self.ffn(self.ln2(x), state_ffn) 246 | x = short_cut + x 247 | 248 | x = x + short 249 | return x, new_A, new_B, new_p, new_x, new_ffn 250 | 251 | 252 | class PRWKV_Infer_For_CoreML(nn.Module): 253 | def __init__(self, 254 | vocab_size=2000, 255 | hidden_size=512, 256 | num_hidden_layers=4, 257 | intermediate_size=1024, 258 | ): 259 | super(PRWKV_Infer_For_CoreML, self).__init__() 260 | self.hidden_size = hidden_size 261 | self.num_hidden_layers = num_hidden_layers 262 | 263 | self.emb = nn.Embedding(vocab_size, hidden_size) 264 | self.blocks = nn.ModuleList([Block(i, hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 265 | self.ln_out = nn.LayerNorm(hidden_size) 266 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 267 | 268 | def forward_initialzation(self, batch_size, device): 269 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 270 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 271 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 272 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 273 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 274 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 275 | return hidden_state 276 | 277 | def forward(self, x, hidden_state): 278 | # x = self.emb(input_token) 279 | # x = torch.matmul(input_onehot, self.emb.weight) 280 | 281 | batch_size = x.size(0) 282 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 283 | new_hidden_state = [] 284 | 285 | for i, block in enumerate(self.blocks): 286 | x, new_A, new_B, new_p, new_x, new_ffn = \ 287 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 288 | 289 | new_hidden_state.append(new_A) 290 | new_hidden_state.append(new_B) 291 | new_hidden_state.append(new_p) 292 | new_hidden_state.append(new_x) 293 | new_hidden_state.append(new_ffn) 294 | 295 | 
new_hidden_state = torch.cat(new_hidden_state) 296 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 297 | new_hidden_state = new_hidden_state.transpose(0, 1) 298 | x = self.ln_out(x) 299 | x = self.head(x) 300 | return x, new_hidden_state 301 | 302 | 303 | class PRWKV_Infer_For_ONNX(nn.Module): 304 | def __init__(self, 305 | vocab_size=2000, 306 | hidden_size=512, 307 | num_hidden_layers=4, 308 | intermediate_size=1024, 309 | ): 310 | super(PRWKV_Infer_For_ONNX, self).__init__() 311 | self.hidden_size = hidden_size 312 | self.num_hidden_layers = num_hidden_layers 313 | 314 | self.emb = nn.Embedding(vocab_size, hidden_size) 315 | self.blocks = nn.ModuleList([Block_ONNX(i, hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 316 | self.ln_out = nn.LayerNorm(hidden_size) 317 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 318 | 319 | def forward_initialzation(self, batch_size, device): 320 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 321 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 322 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 323 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 324 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 325 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 326 | return hidden_state 327 | 328 | def forward(self, x, hidden_state): 329 | # x = self.emb(input_token) 330 | batch_size = x.size(0) 331 | # x = torch.matmul(input_onehot, self.emb.weight) 332 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 333 | new_hidden_state = [] 334 | 335 | for i, block in enumerate(self.blocks): 336 | x, new_A, new_B, new_p, new_x, new_ffn = \ 337 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 338 | 339 | new_hidden_state.append(new_A) 340 | new_hidden_state.append(new_B) 341 | new_hidden_state.append(new_p) 342 | new_hidden_state.append(new_x) 343 | new_hidden_state.append(new_ffn) 344 | 345 | new_hidden_state = torch.cat(new_hidden_state) 346 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 347 | new_hidden_state = new_hidden_state.transpose(0, 1) 348 | x = self.ln_out(x) 349 | x = self.head(x) 350 | return x, new_hidden_state 351 | 352 | 353 | class PRWKV_Infer_For_Script(nn.Module): 354 | def __init__(self, 355 | vocab_size=2000, 356 | hidden_size=512, 357 | num_hidden_layers=4, 358 | intermediate_size=1024, 359 | ): 360 | super(PRWKV_Infer_For_Script, self).__init__() 361 | self.hidden_size = hidden_size 362 | self.num_hidden_layers = num_hidden_layers 363 | 364 | self.emb = nn.Embedding(vocab_size, hidden_size) 365 | self.ln0 = nn.LayerNorm(hidden_size) 366 | self.blocks = nn.ModuleList([Block_Script(hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 367 | self.ln_out = nn.LayerNorm(hidden_size) 368 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 369 | 370 | def forward_initialzation(self, batch_size, device): 371 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 372 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 373 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 374 | state_x = torch.zeros([self.num_hidden_layers, batch_size, 
self.hidden_size]) 375 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 376 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 377 | return hidden_state 378 | 379 | def forward(self, input_token, hidden_state): 380 | x = self.emb(input_token) 381 | batch_size = input_token.size(0) 382 | # x = torch.matmul(input_onehot, self.emb.weight) 383 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 384 | new_hidden_state = [] 385 | 386 | x = self.ln0(x) 387 | for i, block in enumerate(self.blocks): 388 | x, new_A, new_B, new_p, new_x, new_ffn = \ 389 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 390 | 391 | new_hidden_state.append(new_A) 392 | new_hidden_state.append(new_B) 393 | new_hidden_state.append(new_p) 394 | new_hidden_state.append(new_x) 395 | new_hidden_state.append(new_ffn) 396 | 397 | new_hidden_state = torch.cat(new_hidden_state) 398 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 399 | new_hidden_state = new_hidden_state.transpose(0, 1) 400 | x = self.ln_out(x) 401 | x = self.head(x) 402 | return x, new_hidden_state 403 | 404 | 405 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='prwkv_infer_for_coreml') 406 | def build_prwkv_infer_for_coreml(args): 407 | model = PRWKV_Infer_For_CoreML( 408 | vocab_size = args.vocab_size, 409 | hidden_size = args.hidden_size, 410 | num_hidden_layers = args.num_hidden_layer, 411 | intermediate_size = args.intermediate_size 412 | ) 413 | criterion = nn.CrossEntropyLoss(reduction='none') 414 | return model, criterion 415 | 416 | 417 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='prwkv_infer_for_onnx') 418 | def build_prwkv_infer_for_onnx(args): 419 | model = PRWKV_Infer_For_ONNX( 420 | vocab_size = args.vocab_size, 421 | hidden_size = args.hidden_size, 422 | num_hidden_layers = args.num_hidden_layer, 423 | intermediate_size = args.intermediate_size 424 | ) 425 | criterion = nn.CrossEntropyLoss(reduction='none') 426 | return model, criterion 427 | 428 | 429 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='prwkv_infer_for_script') 430 | def build_prwkv_infer_for_script(args): 431 | model = PRWKV_Infer_For_Script( 432 | vocab_size = args.vocab_size, 433 | hidden_size = args.hidden_size, 434 | num_hidden_layers = args.num_hidden_layer, 435 | intermediate_size = args.intermediate_size 436 | ) 437 | criterion = nn.CrossEntropyLoss(reduction='none') 438 | return model, criterion -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 2 | from .bit_type import BIT_TYPE_DICT 3 | from .layers import QAct, QConv2d, QIntLayerNorm, QIntSoftmax, QLinear, QRelPosEmbedding 4 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/bit_type.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 
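# A minimal standalone sketch (not from the original file) verifying the
# integer ranges that the BitType presets defined below imply: a signed
# b-bit type covers [-2**(b-1), 2**(b-1) - 1], an unsigned one [0, 2**b - 1].
for bits, signed, expected in [(8, True, (-128, 127)),
                               (8, False, (0, 255)),
                               (4, False, (0, 15))]:
    lower = -(2 ** (bits - 1)) if signed else 0
    upper = 2 ** (bits - 1) - 1 if signed else 2 ** bits - 1
    assert (lower, upper) == expected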
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | 
6 | 
7 | class BitType:
8 | 
9 |     def __init__(self, bits, signed, name=None):
10 |         self.bits = bits
11 |         self.signed = signed
12 |         if name is not None:
13 |             self.name = name
14 |         else:
15 |             self.update_name()
16 | 
17 |     @property
18 |     def upper_bound(self):
19 |         if not self.signed:
20 |             return 2**self.bits - 1
21 |         return 2**(self.bits - 1) - 1
22 | 
23 |     @property
24 |     def lower_bound(self):
25 |         if not self.signed:
26 |             return 0
27 |         return -(2**(self.bits - 1))
28 | 
29 |     @property
30 |     def range(self):
31 |         return 2**self.bits
32 | 
33 |     def update_name(self):
34 |         self.name = ''
35 |         if not self.signed:
36 |             self.name += 'uint'
37 |         else:
38 |             self.name += 'int'
39 |         self.name += '{}'.format(self.bits)
40 | 
41 | 
42 | BIT_TYPE_LIST = [
43 |     BitType(4, False, 'uint4'),
44 |     BitType(8, True, 'int8'),
45 |     BitType(8, False, 'uint8')
46 | ]
47 | BIT_TYPE_DICT = {bit_type.name: bit_type for bit_type in BIT_TYPE_LIST}
48 | 
-------------------------------------------------------------------------------- /models/RWKV_V4/ptq/layers.py: --------------------------------------------------------------------------------
1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved.
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 | 
6 | from .bit_type import BIT_TYPE_DICT
7 | from .observer import build_observer
8 | from .quantizer import build_quantizer
9 | 
10 | 
11 | class QRelPosEmbedding(nn.Module):
12 | 
13 |     def __init__(self,
14 |                  emb_dim,
15 |                  quant=False,
16 |                  calibrate=False,
17 |                  last_calibrate=False,
18 |                  bit_type=BIT_TYPE_DICT['int8'],
19 |                  calibration_mode='layer_wise',
20 |                  observer_str='minmax',
21 |                  quantizer_str='uniform',
22 |                  ):
23 |         super(QRelPosEmbedding, self).__init__()
24 |         self.quant = quant
25 |         self.calibrate = calibrate
26 |         self.last_calibrate = last_calibrate
27 |         self.bit_type = bit_type
28 |         self.calibration_mode = calibration_mode
29 |         self.observer_str = observer_str
30 |         self.quantizer_str = quantizer_str
31 | 
32 |         self.module_type = 'activation'
33 |         self.observer = build_observer(self.observer_str, self.module_type,
34 |                                        self.bit_type, self.calibration_mode)
35 |         self.quantizer = build_quantizer(self.quantizer_str, self.bit_type,
36 |                                          self.observer, self.module_type)
37 | 
38 |         self.emb_dim = emb_dim
39 |         self.inv_freq = 1 / (10000 ** (torch.arange(0.0, emb_dim, 2.0) / emb_dim))
40 | 
41 |     def forward(self, seq_len, mem_len):
42 |         pos_seq = torch.arange(1.0, seq_len + mem_len + 1)
43 |         sinusoid_inp = pos_seq.unsqueeze(dim=1) * self.inv_freq.unsqueeze(dim=0)
44 |         pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
45 |         pos_emb = pos_emb[mem_len - 1:seq_len + mem_len - 1]
46 | 
47 |         if self.calibrate:
48 |             self.quantizer.observer.update(pos_emb)
49 |             if self.last_calibrate:
50 |                 self.quantizer.update_quantization_params(pos_emb)
51 |         if not self.quant:
52 |             return pos_emb
53 |         pos_emb = self.quantizer(pos_emb)
54 |         return pos_emb
55 | 
56 | 
57 | # class QEmbedding(nn.Embedding):
58 | 
59 | #     def __init__(self,
60 | #                  num_embeddings,
61 | #                  embeddings_dim,
62 | #                  padding_idx=None,
63 | #                  max_norm=None,
64 | #                  norm_type=2.0,
65 | #                  scale_grad_by_freq=False,
66 | #                  sparse=False,
67 | #                  _weight=None,
68 | #                  device=None,
69 | #                  dtype=None,
70 | #                  quant=False,
71 | #                  calibrate=False,
72 | #                  last_calibrate=False,
73 | #                  bit_type=BIT_TYPE_DICT['int8'],
74 | #                  calibration_mode='layer_wise',
75 | #                  observer_str='minmax',
76 | #                  quantizer_str='uniform'):
77 | #         
if torch.__version__[:5] == '1.7.1': 78 | # super(QEmbedding, self).__init__(num_embeddings, 79 | # embeddings_dim, 80 | # padding_idx, 81 | # max_norm, 82 | # norm_type, 83 | # scale_grad_by_freq, 84 | # sparse, 85 | # _weight, 86 | # ) 87 | # else: 88 | # super(QEmbedding, self).__init__(num_embeddings, 89 | # embeddings_dim, 90 | # padding_idx, 91 | # max_norm, 92 | # norm_type, 93 | # scale_grad_by_freq, 94 | # sparse, 95 | # _weight, 96 | # device, 97 | # dtype, 98 | # ) 99 | 100 | # self.quant = quant 101 | # self.calibrate = calibrate 102 | # self.last_calibrate = last_calibrate 103 | # self.bit_type = bit_type 104 | # self.calibration_mode = calibration_mode 105 | # self.observer_str = observer_str 106 | # self.quantizer_str = quantizer_str 107 | 108 | # self.module_type = 'embedding' 109 | # self.observer = build_observer(self.observer_str, self.module_type, 110 | # self.bit_type, self.calibration_mode) 111 | # self.quantizer = build_quantizer(self.quantizer_str, self.bit_type, 112 | # self.observer, self.module_type) 113 | 114 | # def forward(self, x): 115 | # if self.calibrate: 116 | # self.quantizer.observer.update(self.weight) 117 | # if self.last_calibrate: 118 | # self.quantizer.update_quantization_params(x) 119 | # if not self.quant: 120 | # return F.embedding( 121 | # x, 122 | # self.weight, 123 | # self.padding_idx, 124 | # self.max_norm, 125 | # self.norm_type, 126 | # self.scale_grad_by_freq, 127 | # self.sparse, 128 | # ) 129 | 130 | # weight = self.quantizer(self.weight) 131 | # return F.embedding( 132 | # x, 133 | # weight, 134 | # self.padding_idx, 135 | # self.max_norm, 136 | # self.norm_type, 137 | # self.scale_grad_by_freq, 138 | # self.sparse, 139 | # ) 140 | 141 | 142 | class QConv2d(nn.Conv2d): 143 | 144 | def __init__(self, 145 | in_channels, 146 | out_channels, 147 | kernel_size, 148 | stride=1, 149 | padding=0, 150 | dilation=1, 151 | groups=1, 152 | bias=True, 153 | quant=False, 154 | calibrate=False, 155 | last_calibrate=False, 156 | bit_type=BIT_TYPE_DICT['int8'], 157 | calibration_mode='layer_wise', 158 | observer_str='minmax', 159 | quantizer_str='uniform'): 160 | super(QConv2d, self).__init__( 161 | in_channels=in_channels, 162 | out_channels=out_channels, 163 | kernel_size=kernel_size, 164 | stride=stride, 165 | padding=padding, 166 | dilation=dilation, 167 | groups=groups, 168 | bias=bias, 169 | ) 170 | self.quant = quant 171 | self.calibrate = calibrate 172 | self.last_calibrate = last_calibrate 173 | self.bit_type = bit_type 174 | self.calibration_mode = calibration_mode 175 | self.observer_str = observer_str 176 | self.quantizer_str = quantizer_str 177 | 178 | self.module_type = 'conv_weight' 179 | self.observer = build_observer(self.observer_str, self.module_type, 180 | self.bit_type, self.calibration_mode) 181 | self.quantizer = build_quantizer(self.quantizer_str, self.bit_type, 182 | self.observer, self.module_type) 183 | 184 | def forward(self, x): 185 | if self.calibrate: 186 | self.quantizer.observer.update(self.weight) 187 | if self.last_calibrate: 188 | self.quantizer.update_quantization_params(x) 189 | if not self.quant: 190 | return F.conv2d( 191 | x, 192 | self.weight, 193 | self.bias, 194 | self.stride, 195 | self.padding, 196 | self.dilation, 197 | self.groups, 198 | ) 199 | weight = self.quantizer(self.weight) 200 | return F.conv2d(x, weight, self.bias, self.stride, self.padding, 201 | self.dilation, self.groups) 202 | 203 | 204 | class QLinear(nn.Linear): 205 | 206 | def __init__(self, 207 | in_features, 208 | out_features, 209 | 
bias=True, 210 | quant=False, 211 | calibrate=False, 212 | last_calibrate=False, 213 | bit_type=BIT_TYPE_DICT['int8'], 214 | calibration_mode='layer_wise', 215 | observer_str='minmax', 216 | quantizer_str='uniform'): 217 | super(QLinear, self).__init__(in_features, out_features, bias) 218 | 219 | self.quant = quant 220 | self.calibrate = calibrate 221 | self.last_calibrate = last_calibrate 222 | self.bit_type = bit_type 223 | self.calibration_mode = calibration_mode 224 | self.observer_str = observer_str 225 | self.quantizer_str = quantizer_str 226 | 227 | self.module_type = 'linear_weight' 228 | self.observer = build_observer(self.observer_str, self.module_type, 229 | self.bit_type, self.calibration_mode) 230 | self.quantizer = build_quantizer(self.quantizer_str, self.bit_type, 231 | self.observer, self.module_type) 232 | 233 | def forward(self, x): 234 | if self.calibrate: 235 | self.quantizer.observer.update(self.weight) 236 | if self.last_calibrate: 237 | self.quantizer.update_quantization_params(x) 238 | if not self.quant: 239 | return F.linear(x, self.weight, self.bias) 240 | weight = self.quantizer(self.weight) 241 | return F.linear(x, weight, self.bias) 242 | 243 | 244 | class QAct(nn.Module): 245 | 246 | def __init__(self, 247 | quant=False, 248 | calibrate=False, 249 | last_calibrate=False, 250 | bit_type=BIT_TYPE_DICT['int8'], 251 | calibration_mode='layer_wise', 252 | observer_str='minmax', 253 | quantizer_str='uniform'): 254 | super(QAct, self).__init__() 255 | 256 | self.quant = quant 257 | self.calibrate = calibrate 258 | self.last_calibrate = last_calibrate 259 | self.bit_type = bit_type 260 | self.calibration_mode = calibration_mode 261 | self.observer_str = observer_str 262 | self.quantizer_str = quantizer_str 263 | 264 | self.module_type = 'activation' 265 | self.observer = build_observer(self.observer_str, self.module_type, 266 | self.bit_type, self.calibration_mode) 267 | self.quantizer = build_quantizer(self.quantizer_str, self.bit_type, 268 | self.observer, self.module_type) 269 | 270 | def forward(self, x): 271 | if self.calibrate: 272 | self.quantizer.observer.update(x) 273 | if self.last_calibrate: 274 | self.quantizer.update_quantization_params(x) 275 | if not self.quant: 276 | return x 277 | x = self.quantizer(x) 278 | return x 279 | 280 | 281 | class QIntLayerNorm(nn.LayerNorm): 282 | 283 | def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): 284 | super(QIntLayerNorm, self).__init__(normalized_shape, eps, 285 | elementwise_affine) 286 | assert isinstance(normalized_shape, int) 287 | self.mode = 'ln' 288 | 289 | def get_MN(self, x): 290 | bit = 8 291 | N = torch.clamp(bit - 1 - torch.floor(torch.log2(x)), 0, 31) 292 | M = torch.clamp(torch.floor(x * torch.pow(2, N)), 0, 2 ** bit - 1) 293 | return M, N 294 | 295 | def forward(self, 296 | x, 297 | in_quantizer=None, 298 | out_quantizer=None, 299 | in_scale_expand=1): 300 | if self.mode == 'ln': 301 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, 302 | self.eps) 303 | elif self.mode == 'int': 304 | in_scale = in_quantizer.scale 305 | if in_scale_expand != 1: 306 | in_scale = in_scale.unsqueeze(-1).expand( 307 | -1, in_scale_expand).T.reshape(-1) 308 | out_scale = out_quantizer.scale 309 | assert in_scale is not None and out_scale is not None 310 | channel_nums = x.shape[-1] 311 | in_scale = in_scale.reshape(1, -1) 312 | out_scale = out_scale.reshape(1, -1) 313 | x_q = (x / in_scale).round() 314 | in_scale1 = in_scale.min() 315 | in_scale_mask = (in_scale / 
in_scale1).round()
316 | 
317 |             x_q = x_q * in_scale_mask
318 | 
319 |             mean_x_q = x_q.mean(dim=-1) * in_scale1
320 |             std_x_q = (in_scale1 / channel_nums) * torch.sqrt(
321 |                 channel_nums * (x_q**2).sum(dim=-1) - x_q.sum(dim=-1)**2)
322 | 
323 |             # guard against NaN from dividing by a (near-)zero std
324 |             std_x_q = torch.maximum(std_x_q, 1e-8 * torch.ones_like(std_x_q).to(std_x_q))
325 | 
326 |             A = (in_scale1 / std_x_q).unsqueeze(-1) * \
327 |                 self.weight.reshape(1, -1) / out_scale
328 |             A_sign = A.sign()
329 |             M, N = self.get_MN(A.abs())
330 |             B = ((self.bias.reshape(1, -1) -
331 |                   (mean_x_q / std_x_q).unsqueeze(-1) *
332 |                   self.weight.reshape(1, -1)) / out_scale *
333 |                  torch.pow(2, N)).round()
334 | 
335 |             x_q = ((A_sign * M * x_q + B) / torch.pow(2, N)).round()
336 |             x = x_q * out_scale
337 |         else:
338 |             raise NotImplementedError
339 |         return x
340 | 
341 | 
342 | class QIntSoftmax(nn.Module):
343 | 
344 |     def __init__(self,
345 |                  log_i_softmax=False,
346 |                  quant=False,
347 |                  calibrate=False,
348 |                  last_calibrate=False,
349 |                  bit_type=BIT_TYPE_DICT['int8'],
350 |                  calibration_mode='layer_wise',
351 |                  observer_str='minmax',
352 |                  quantizer_str='uniform'):
353 |         super(QIntSoftmax, self).__init__()
354 | 
355 |         self.log_i_softmax = log_i_softmax
356 |         self.quant = quant
357 |         self.calibrate = calibrate
358 |         self.last_calibrate = last_calibrate
359 |         self.bit_type = bit_type
360 |         self.calibration_mode = calibration_mode
361 |         self.observer_str = observer_str
362 |         self.quantizer_str = quantizer_str
363 | 
364 |         self.module_type = 'activation'
365 |         self.observer = build_observer(self.observer_str, self.module_type,
366 |                                        self.bit_type, self.calibration_mode)
367 |         self.quantizer = build_quantizer(self.quantizer_str, self.bit_type,
368 |                                          self.observer, self.module_type)
369 | 
370 |     @staticmethod
371 |     def log_round(x):
372 |         x_log_floor = x.log2().floor()
373 |         big = x_log_floor
374 |         extra_mask = (x - 2**big) >= 2**(big - 1)
375 |         big[extra_mask] = big[extra_mask] + 1
376 |         return big
377 | 
378 |     @staticmethod
379 |     def int_softmax(x, scaling_factor):
380 | 
381 |         def int_polynomial(x_int, scaling_factor):
382 |             coef = [0.35815147, 0.96963238, 1.]  
# ax**2 + bx + c 383 | coef[1] /= coef[0] 384 | coef[2] /= coef[0] 385 | b_int = torch.floor(coef[1] / scaling_factor) 386 | c_int = torch.floor(coef[2] / scaling_factor**2) 387 | z = x_int + b_int 388 | z = x_int * z 389 | z = z + c_int 390 | scaling_factor = coef[0] * scaling_factor**2 391 | return z, scaling_factor 392 | 393 | def int_exp(x_int, scaling_factor): 394 | x0 = -0.6931 # -ln2 395 | n = 30 # sufficiently large integer 396 | x0_int = torch.floor(x0 / scaling_factor) 397 | x_int = torch.max(x_int, n * x0_int) 398 | q = torch.floor(x_int / x0_int) 399 | r = x_int - x0_int * q 400 | exp_int, exp_scaling_factor = int_polynomial(r, scaling_factor) 401 | exp_int = torch.clamp(torch.floor(exp_int * 2**(n - q)), min=0) 402 | scaling_factor = exp_scaling_factor / 2**n 403 | return exp_int, scaling_factor 404 | 405 | x_int = x / scaling_factor 406 | x_int_max, _ = x_int.max(dim=-1, keepdim=True) 407 | x_int = x_int - x_int_max 408 | exp_int, exp_scaling_factor = int_exp(x_int, scaling_factor) 409 | exp_int_sum = exp_int.sum(dim=-1, keepdim=True) 410 | return exp_int, exp_int_sum 411 | 412 | def forward(self, x, scale): 413 | if self.log_i_softmax and scale is not None: 414 | exp_int, exp_int_sum = self.int_softmax(x, scale) 415 | softmax_out = torch.round(exp_int_sum / exp_int) 416 | rounds = self.log_round(softmax_out) 417 | mask = rounds >= 2**self.bit_type.bits 418 | qlog = torch.clamp(rounds, 0, 2**self.bit_type.bits - 1) 419 | deq_softmax = 2**(-qlog) 420 | deq_softmax[mask] = 0 421 | return deq_softmax 422 | else: 423 | x = x.softmax(dim=-1) 424 | if self.calibrate: 425 | self.quantizer.observer.update(x) 426 | if self.last_calibrate: 427 | self.quantizer.update_quantization_params(x) 428 | if not self.quant: 429 | return x 430 | x = self.quantizer(x) 431 | return x 432 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/observer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 2 | from .build import build_observer 3 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/observer/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 
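# Standalone sketch of the folding convention BaseObserver.reshape_tensor
# (defined below) uses: weights and activations are reshaped to
# (channels, samples) so per-channel statistics reduce along dim=1.
# The shapes here are illustrative, not taken from the repo's configs.
import torch

act = torch.randn(2, 3, 4, 4)                        # NCHW activation
v = act.permute(0, 2, 3, 1).reshape(-1, 3).transpose(0, 1)
assert v.shape == (3, 32)                            # 3 channels, 2*4*4 samples

w = torch.randn(16, 3, 3, 3)                         # conv weight
assert w.reshape(w.shape[0], -1).shape == (16, 27)   # one row per out-channel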
2 | import torch 3 | 4 | 5 | class BaseObserver: 6 | 7 | def __init__(self, module_type, bit_type, calibration_mode): 8 | self.module_type = module_type 9 | self.bit_type = bit_type 10 | self.calibration_mode = calibration_mode 11 | self.max_val = None 12 | self.min_val = None 13 | self.eps = torch.finfo(torch.float32).eps 14 | 15 | def reshape_tensor(self, v): 16 | if not isinstance(v, torch.Tensor): 17 | v = torch.tensor(v) 18 | v = v.detach() 19 | if self.module_type in ['conv_weight', 'linear_weight']: 20 | v = v.reshape(v.shape[0], -1) 21 | elif self.module_type == 'activation': 22 | if len(v.shape) == 4: 23 | v = v.permute(0, 2, 3, 1) 24 | v = v.reshape(-1, v.shape[-1]) 25 | v = v.transpose(0, 1) 26 | else: 27 | raise NotImplementedError 28 | return v 29 | 30 | def update(self, v): 31 | # update self.max_val and self.min_val 32 | raise NotImplementedError 33 | 34 | def get_quantization_params(self, *args, **kwargs): 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/observer/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 2 | from .ema import EmaObserver 3 | from .minmax import MinmaxObserver 4 | from .omse import OmseObserver 5 | from .percentile import PercentileObserver 6 | from .ptf import PtfObserver 7 | 8 | str2observer = { 9 | 'minmax': MinmaxObserver, 10 | 'ema': EmaObserver, 11 | 'omse': OmseObserver, 12 | 'percentile': PercentileObserver, 13 | 'ptf': PtfObserver 14 | } 15 | 16 | 17 | def build_observer(observer_str, module_type, bit_type, calibration_mode): 18 | observer = str2observer[observer_str] 19 | return observer(module_type, bit_type, calibration_mode) 20 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/observer/ema.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 
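# Plain-float sketch of the exponential moving average that EmaObserver.update
# (below) applies: the running extreme moves a fraction ema_sigma toward each
# newly observed batch extreme, which smooths calibration noise. The values
# here are made up for illustration.
running, ema_sigma = 1.0, 0.01
for cur_max in (3.0, 3.0, 3.0):
    running = running + ema_sigma * (cur_max - running)
print(round(running, 4))  # 1.0594; far less jumpy than a raw min/max update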
2 | import torch 3 | 4 | from .base import BaseObserver 5 | 6 | 7 | class EmaObserver(BaseObserver): 8 | 9 | def __init__(self, 10 | module_type, 11 | bit_type, 12 | calibration_mode, 13 | ema_sigma=0.01): 14 | super(EmaObserver, self).__init__(module_type, bit_type, 15 | calibration_mode) 16 | self.ema_sigma = ema_sigma 17 | self.symmetric = self.bit_type.signed 18 | 19 | def update(self, v): 20 | v = self.reshape_tensor(v) 21 | cur_max = v.max(axis=1).values 22 | if self.max_val is None: 23 | self.max_val = cur_max 24 | else: 25 | self.max_val = self.max_val + \ 26 | self.ema_sigma * (cur_max - self.max_val) 27 | cur_min = v.min(axis=1).values 28 | if self.min_val is None: 29 | self.min_val = cur_min 30 | else: 31 | self.min_val = self.min_val + \ 32 | self.ema_sigma * (cur_min - self.min_val) 33 | 34 | if self.calibration_mode == 'layer_wise': 35 | self.max_val = self.max_val.max() 36 | self.min_val = self.min_val.min() 37 | 38 | def get_quantization_params(self, *args, **kwargs): 39 | max_val = self.max_val 40 | min_val = self.min_val 41 | 42 | qmax = self.bit_type.upper_bound 43 | qmin = self.bit_type.lower_bound 44 | 45 | scale = torch.ones_like(max_val, dtype=torch.float32) 46 | zero_point = torch.zeros_like(max_val, dtype=torch.int64) 47 | 48 | if self.symmetric: 49 | max_val = torch.max(-min_val, max_val) 50 | scale = max_val / (float(qmax - qmin) / 2) 51 | scale.clamp_(self.eps) 52 | zero_point = torch.zeros_like(max_val, dtype=torch.int64) 53 | else: 54 | scale = (max_val - min_val) / float(qmax - qmin) 55 | scale.clamp_(self.eps) 56 | zero_point = qmin - torch.round(min_val / scale) 57 | zero_point.clamp_(qmin, qmax) 58 | return scale, zero_point 59 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/observer/minmax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 
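# Numeric sketch of the two branches in MinmaxObserver.get_quantization_params
# below, using the int8/uint8 presets and made-up observed extremes.
max_val, min_val = 3.0, -1.0

# symmetric (signed): the scale covers max(|min|, max) over half the range
sym_scale = max(-min_val, max_val) / ((127 - (-128)) / 2)   # ~0.0235, zero_point = 0

# asymmetric (unsigned): the scale spans the full observed range and the zero
# point shifts it so that min_val lands on qmin
asym_scale = (max_val - min_val) / float(255 - 0)           # ~0.0157
asym_zero_point = 0 - round(min_val / asym_scale)           # 64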
2 | import torch 3 | 4 | from .base import BaseObserver 5 | 6 | 7 | class MinmaxObserver(BaseObserver): 8 | 9 | def __init__(self, module_type, bit_type, calibration_mode): 10 | super(MinmaxObserver, self).__init__(module_type, bit_type, 11 | calibration_mode) 12 | self.symmetric = self.bit_type.signed 13 | 14 | def update(self, v): 15 | v = self.reshape_tensor(v) 16 | cur_max = v.max(axis=1).values 17 | if self.max_val is None: 18 | self.max_val = cur_max 19 | else: 20 | self.max_val = torch.max(cur_max, self.max_val) 21 | cur_min = v.min(axis=1).values 22 | if self.min_val is None: 23 | self.min_val = cur_min 24 | else: 25 | self.min_val = torch.min(cur_min, self.min_val) 26 | 27 | if self.calibration_mode == 'layer_wise': 28 | self.max_val = self.max_val.max() 29 | self.min_val = self.min_val.min() 30 | 31 | def get_quantization_params(self, *args, **kwargs): 32 | max_val = self.max_val 33 | min_val = self.min_val 34 | 35 | qmax = self.bit_type.upper_bound 36 | qmin = self.bit_type.lower_bound 37 | 38 | scale = torch.ones_like(max_val, dtype=torch.float32) 39 | zero_point = torch.zeros_like(max_val, dtype=torch.int64) 40 | 41 | if self.symmetric: 42 | max_val = torch.max(-min_val, max_val) 43 | scale = max_val / (float(qmax - qmin) / 2) 44 | scale.clamp_(self.eps) 45 | zero_point = torch.zeros_like(max_val, dtype=torch.int64) 46 | else: 47 | scale = (max_val - min_val) / float(qmax - qmin) 48 | scale.clamp_(self.eps) 49 | zero_point = qmin - torch.round(min_val / scale) 50 | zero_point.clamp_(qmin, qmax) 51 | return scale, zero_point 52 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/observer/omse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 
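# The search below, sketched: OMSE shrinks the observed clipping range in 1%
# steps and keeps the candidate whose fake-quantized reconstruction has the
# lowest L2 error on the calibration tensor, trading outlier coverage for
# resolution on the bulk of the distribution:
#     for i in range(90):
#         new_max, new_min = max_val * (1 - i/100), min_val * (1 - i/100)
#         score = || inputs - fake_quant(inputs; new_max, new_min) ||_2
#     keep the (scale, zero_point) of the best-scoring candidate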
2 | import torch
3 | 
4 | from .base import BaseObserver
5 | from .utils import lp_loss
6 | 
7 | 
8 | class OmseObserver(BaseObserver):
9 | 
10 |     def __init__(self, module_type, bit_type, calibration_mode):
11 |         super(OmseObserver, self).__init__(module_type, bit_type,
12 |                                            calibration_mode)
13 | 
14 |     def update(self, v):
15 |         v = self.reshape_tensor(v)
16 |         cur_max = v.max(axis=1).values
17 |         if self.max_val is None:
18 |             self.max_val = cur_max
19 |         else:
20 |             self.max_val = torch.max(cur_max, self.max_val)
21 |         cur_min = v.min(axis=1).values
22 |         if self.min_val is None:
23 |             self.min_val = cur_min
24 |         else:
25 |             self.min_val = torch.min(cur_min, self.min_val)
26 | 
27 |         if self.calibration_mode == 'layer_wise':
28 |             self.max_val = self.max_val.max()
29 |             self.min_val = self.min_val.min()
30 | 
31 |     def get_quantization_params(self, inputs):
32 |         max_val = self.max_val
33 |         min_val = self.min_val
34 |         qmax = self.bit_type.upper_bound
35 |         qmin = self.bit_type.lower_bound
36 | 
37 |         best_score = 1e+10
38 |         for i in range(90):
39 |             new_max = max_val * (1.0 - (i * 0.01))
40 |             new_min = min_val * (1.0 - (i * 0.01))
41 |             new_scale = (new_max - new_min) / float(qmax - qmin)
42 |             new_scale.clamp_(self.eps)
43 |             new_zero_point = qmin - torch.round(new_min / new_scale)
44 |             new_zero_point.clamp_(qmin, qmax)
45 |             inputs_q = ((inputs / new_scale + new_zero_point).round().clamp(
46 |                 qmin, qmax) - new_zero_point) * new_scale
47 |             # L_p norm minimization as described in LAPQ
48 |             # https://arxiv.org/abs/1911.07190
49 |             score = lp_loss(inputs, inputs_q, p=2.0, reduction='all')
50 |             if score < best_score:
51 |                 best_score = score
52 |                 self.max_val = new_max
53 |                 self.min_val = new_min
54 |                 scale = new_scale
55 |                 zero_point = new_zero_point
56 |         return scale, zero_point
57 | 
-------------------------------------------------------------------------------- /models/RWKV_V4/ptq/observer/percentile.py: --------------------------------------------------------------------------------
1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved.
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | 
6 | from .base import BaseObserver
7 | 
8 | 
9 | class PercentileObserver(BaseObserver):
10 | 
11 |     def __init__(self,
12 |                  module_type,
13 |                  bit_type,
14 |                  calibration_mode,
15 |                  percentile_sigma=0.01,
16 |                  percentile_alpha=0.99999):
17 |         super(PercentileObserver, self).__init__(module_type, bit_type,
18 |                                                  calibration_mode)
19 |         self.percentile_sigma = percentile_sigma
20 |         self.percentile_alpha = percentile_alpha
21 |         self.symmetric = self.bit_type.signed
22 | 
23 |     def update(self, v):
24 |         # channel-wise needs too much time.
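# torch.quantile can fail on large inputs (older PyTorch versions cap it at
# roughly 2**24 elements), so the try/except below falls back to the slower
# numpy percentile on the CPU.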
25 | assert self.calibration_mode == 'layer_wise' 26 | v = self.reshape_tensor(v) 27 | try: 28 | cur_max = torch.quantile(v.reshape(-1), self.percentile_alpha) 29 | cur_min = torch.quantile(v.reshape(-1), 30 | 1.0 - self.percentile_alpha) 31 | except: 32 | cur_max = torch.tensor(np.percentile( 33 | v.reshape(-1).cpu(), self.percentile_alpha * 100), 34 | device=v.device, 35 | dtype=torch.float32) 36 | cur_min = torch.tensor(np.percentile( 37 | v.reshape(-1).cpu(), (1 - self.percentile_alpha) * 100), 38 | device=v.device, 39 | dtype=torch.float32) 40 | if self.max_val is None: 41 | self.max_val = cur_max 42 | else: 43 | self.max_val = self.max_val + \ 44 | self.percentile_sigma * (cur_max - self.max_val) 45 | if self.min_val is None: 46 | self.min_val = cur_min 47 | else: 48 | self.min_val = self.min_val + \ 49 | self.percentile_sigma * (cur_min - self.min_val) 50 | 51 | def get_quantization_params(self, *args, **kwargs): 52 | max_val = self.max_val 53 | min_val = self.min_val 54 | 55 | qmax = self.bit_type.upper_bound 56 | qmin = self.bit_type.lower_bound 57 | 58 | scale = torch.ones_like(max_val, dtype=torch.float32) 59 | zero_point = torch.zeros_like(max_val, dtype=torch.int64) 60 | 61 | if self.symmetric: 62 | max_val = torch.max(-min_val, max_val) 63 | scale = max_val / (float(qmax - qmin) / 2) 64 | scale.clamp_(self.eps) 65 | zero_point = torch.zeros_like(max_val, dtype=torch.int64) 66 | else: 67 | scale = (max_val - min_val) / float(qmax - qmin) 68 | scale.clamp_(self.eps) 69 | zero_point = qmin - torch.round(min_val / scale) 70 | zero_point.clamp_(qmin, qmax) 71 | return scale, zero_point 72 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/observer/ptf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 
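# Power-of-Two Factor (PTF), sketched: the whole tensor shares one base scale
# (scale8 below) derived from the layer-wise min/max, but each channel may
# halve that base scale up to three times (scale4, scale2, scale1). Per
# channel, the factor with the lowest L2 reconstruction error wins, so
# scale_mask holds 2**k with k in {0, 1, 2, 3} and the final per-channel
# scale is scale1 * scale_mask.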
2 | import torch 3 | 4 | from .base import BaseObserver 5 | from .utils import lp_loss 6 | 7 | 8 | class PtfObserver(BaseObserver): 9 | 10 | def __init__(self, module_type, bit_type, calibration_mode): 11 | super(PtfObserver, self).__init__(module_type, bit_type, 12 | calibration_mode) 13 | 14 | def update(self, v): 15 | v = self.reshape_tensor(v) 16 | cur_max = v.max(axis=1).values 17 | if self.max_val is None: 18 | self.max_val = cur_max 19 | else: 20 | self.max_val = torch.max(cur_max, self.max_val) 21 | cur_min = v.min(axis=1).values 22 | if self.min_val is None: 23 | self.min_val = cur_min 24 | else: 25 | self.min_val = torch.min(cur_min, self.min_val) 26 | 27 | if self.calibration_mode == 'layer_wise': 28 | self.max_val = self.max_val.max() 29 | self.min_val = self.min_val.min() 30 | 31 | def get_quantization_params(self, inputs, *args, **kwargs): 32 | max_val = self.max_val 33 | min_val = self.min_val 34 | 35 | qmax = self.bit_type.upper_bound 36 | qmin = self.bit_type.lower_bound 37 | 38 | best_score = 1e+10 39 | max_val_t = max_val.max() 40 | min_val_t = min_val.min() 41 | scale8 = (max_val_t - min_val_t) / float(qmax - qmin) 42 | scale8.clamp_(self.eps) 43 | scale4 = scale8 / 2 44 | scale2 = scale4 / 2 45 | scale1 = scale2 / 2 46 | zero_point = qmin - torch.round(min_val_t / scale8) 47 | zero_point.clamp_(qmin, qmax) 48 | scale_mask = torch.ones_like(max_val) 49 | for j in range(inputs.shape[2]): 50 | data = inputs[..., j].unsqueeze(-1) 51 | data_q1 = ((data / scale1 + zero_point).round().clamp(qmin, qmax) - 52 | zero_point) * scale1 53 | data_q2 = ((data / scale2 + zero_point).round().clamp(qmin, qmax) - 54 | zero_point) * scale2 55 | data_q4 = ((data / scale4 + zero_point).round().clamp(qmin, qmax) - 56 | zero_point) * scale4 57 | data_q8 = ((data / scale8 + zero_point).round().clamp(qmin, qmax) - 58 | zero_point) * scale8 59 | score1 = lp_loss(data, data_q1, p=2.0, reduction='all') 60 | score2 = lp_loss(data, data_q2, p=2.0, reduction='all') 61 | score4 = lp_loss(data, data_q4, p=2.0, reduction='all') 62 | score8 = lp_loss(data, data_q8, p=2.0, reduction='all') 63 | score = [score1, score2, score4, score8] 64 | scale_mask[j] *= 2**score.index(min(score)) 65 | scale = scale1 * scale_mask 66 | return scale, zero_point 67 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/observer/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 2 | def lp_loss(pred, tgt, p=2.0, reduction='none'): 3 | """ 4 | loss function measured in L_p Norm 5 | """ 6 | if reduction == 'none': 7 | return (pred - tgt).abs().pow(p).sum(1).mean() 8 | else: 9 | return (pred - tgt).abs().pow(p).mean() 10 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 2 | from .build import build_quantizer 3 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/quantizer/base.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 
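# How these quantizers are meant to be driven (an illustrative sketch; the
# builder names come from this package, the calibration loop is hypothetical):
#
#     observer = build_observer('minmax', 'activation',
#                               BIT_TYPE_DICT['int8'], 'layer_wise')
#     quantizer = build_quantizer('uniform', BIT_TYPE_DICT['int8'],
#                                 observer, 'activation')
#     for batch in calibration_batches:         # 1) collect ranges
#         observer.update(batch)
#     quantizer.update_quantization_params()    # 2) freeze scale / zero_point
#     x_hat = quantizer(x)                      # 3) forward = dequantize(quant(x))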
2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BaseQuantizer(nn.Module): 7 | 8 | def __init__(self, bit_type, observer, module_type): 9 | super(BaseQuantizer, self).__init__() 10 | self.bit_type = bit_type 11 | self.observer = observer 12 | self.module_type = module_type 13 | 14 | def get_reshape_range(self, inputs): 15 | range_shape = None 16 | if self.module_type == 'conv_weight': 17 | range_shape = (-1, 1, 1, 1) 18 | elif self.module_type == 'linear_weight': 19 | range_shape = (-1, 1) 20 | elif self.module_type == 'activation': 21 | if len(inputs.shape) == 1: 22 | range_shape = (-1) 23 | elif len(inputs.shape) == 2: 24 | range_shape = (1, -1) 25 | elif len(inputs.shape) == 3: 26 | range_shape = (1, 1, -1) 27 | elif len(inputs.shape) == 4: 28 | range_shape = (1, 1, 1, -1) 29 | else: 30 | import ipdb; ipdb.set_trace() 31 | raise NotImplementedError 32 | else: 33 | raise NotImplementedError 34 | return range_shape 35 | 36 | def update_quantization_params(self, *args, **kwargs): 37 | pass 38 | 39 | def quant(self, inputs, scale=None, zero_point=None): 40 | raise NotImplementedError 41 | 42 | def dequantize(self, inputs, scale=None, zero_point=None): 43 | raise NotImplementedError 44 | 45 | def forward(self, inputs): 46 | outputs = self.quant(inputs) 47 | outputs = self.dequantize(outputs) 48 | return outputs 49 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/quantizer/build.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 2 | from .log2 import Log2Quantizer 3 | from .uniform import UniformQuantizer 4 | 5 | str2quantizer = {'uniform': UniformQuantizer, 'log2': Log2Quantizer} 6 | 7 | 8 | def build_quantizer(quantizer_str, bit_type, observer, module_type): 9 | quantizer = str2quantizer[quantizer_str] 10 | return quantizer(bit_type, observer, module_type) 11 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/quantizer/log2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 2 | import torch 3 | 4 | from .base import BaseQuantizer 5 | 6 | 7 | class Log2Quantizer(BaseQuantizer): 8 | 9 | def __init__(self, bit_type, observer, module_type): 10 | super(Log2Quantizer, self).__init__( 11 | bit_type, 12 | observer, 13 | module_type, 14 | ) 15 | self.softmax_mask = None 16 | 17 | def quant(self, inputs): 18 | rounds = torch.round(-1 * inputs.log2()) 19 | self.softmax_mask = rounds >= 2**self.bit_type.bits 20 | outputs = torch.clamp(rounds, 0, 2**self.bit_type.bits - 1) 21 | return outputs 22 | 23 | def dequantize(self, inputs): 24 | outputs = 2**(-1 * inputs) 25 | outputs[self.softmax_mask] = 0 26 | return outputs 27 | -------------------------------------------------------------------------------- /models/RWKV_V4/ptq/quantizer/uniform.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) MEGVII Inc. and its affiliates. All Rights Reserved. 
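# Round-trip sketch of the fake quantization UniformQuantizer implements
# below, with a hand-picked int8 scale of 0.1 and zero_point of 0
# (illustrative values, not calibrated ones): inputs snap to the nearest
# multiple of the scale and saturate at the clamp bounds.
import torch

x = torch.tensor([0.04, 0.06, 15.0])
q = (x / 0.1 + 0).round().clamp(-128, 127)   # tensor([  0.,   1., 127.])
x_hat = (q - 0) * 0.1                        # tensor([ 0.0000,  0.1000, 12.7000])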
2 | import torch 3 | import torch.nn as nn 4 | 5 | from .base import BaseQuantizer 6 | 7 | 8 | class UniformQuantizer(BaseQuantizer): 9 | 10 | def __init__(self, bit_type, observer, module_type): 11 | super(UniformQuantizer, self).__init__(bit_type, observer, module_type) 12 | self.scale = None 13 | self.zero_point = None 14 | 15 | def update_quantization_params(self, *args, **kwargs): 16 | self.scale, self.zero_point = self.observer.get_quantization_params( 17 | *args, **kwargs) 18 | 19 | def quant(self, inputs, scale=None, zero_point=None): 20 | if scale is None: 21 | scale = self.scale 22 | if zero_point is None: 23 | zero_point = self.zero_point 24 | range_shape = self.get_reshape_range(inputs) 25 | scale = scale.reshape(range_shape) 26 | zero_point = zero_point.reshape(range_shape) 27 | outputs = inputs / scale + zero_point 28 | outputs = outputs.round().clamp(self.bit_type.lower_bound, 29 | self.bit_type.upper_bound) 30 | return outputs 31 | 32 | def dequantize(self, inputs, scale=None, zero_point=None): 33 | if scale is None: 34 | scale = self.scale 35 | if zero_point is None: 36 | zero_point = self.zero_point 37 | range_shape = self.get_reshape_range(inputs) 38 | scale = scale.reshape(range_shape) 39 | zero_point = zero_point.reshape(range_shape) 40 | outputs = (inputs - zero_point) * scale 41 | return outputs 42 | -------------------------------------------------------------------------------- /models/RWKV_V4/rwkv_v4_infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numbers 4 | import torch.nn as nn 5 | from torch.nn import init 6 | import torch.nn.functional as F 7 | from ..registry import MODULE_BUILD_FUNCS 8 | from torch import Tensor, Size 9 | from typing import Union, List 10 | _shape_t = Union[int, List[int], Size] 11 | 12 | class LayerNorm(nn.Module): 13 | def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, device=None, dtype=None): 14 | factory_kwargs = {'device': device, 'dtype': dtype} 15 | super(LayerNorm, self).__init__() 16 | if isinstance(normalized_shape, numbers.Integral): 17 | # mypy error: incompatible types in assignment 18 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 19 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 20 | self.eps = eps 21 | self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 22 | self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self) -> None: 26 | init.ones_(self.weight) 27 | init.zeros_(self.bias) 28 | 29 | def forward(self, input: Tensor) -> Tensor: 30 | u = torch.mean(input, dim=-1, keepdim=True) 31 | s = torch.mean(input * input, dim=-1, keepdim=True) 32 | s = torch.sqrt(s - u * u + self.eps) 33 | x_normalized = (input - u) / s 34 | output = x_normalized * self.weight + self.bias 35 | return output 36 | 37 | 38 | class RWKV_ChannelMix(nn.Module): 39 | def __init__(self, n_embed, ffn_dim): 40 | super().__init__() 41 | 42 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 43 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 44 | 45 | self.key = nn.Linear(n_embed, ffn_dim, bias=False) 46 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 47 | self.value = nn.Linear(ffn_dim, n_embed, bias=False) 48 | 49 | def forward(self, x, state_ffn): 50 | xk = x * self.time_mix_k + state_ffn * (1 - self.time_mix_k) 51 | xr = x * self.time_mix_r + state_ffn * (1 - 
self.time_mix_r) 52 | new_ffn = x 53 | 54 | r = torch.sigmoid(self.receptance(xr)) 55 | k = torch.square(torch.relu(self.key(xk))) 56 | kv = self.value(k) 57 | 58 | rkv = r * kv 59 | return rkv, new_ffn 60 | 61 | 62 | class RWKV_TimeMix(nn.Module): 63 | def __init__(self, n_embed): 64 | super().__init__() 65 | self.time_decay = nn.Parameter(torch.ones(n_embed)) 66 | self.time_first = nn.Parameter(torch.ones(n_embed) * math.log(0.3)) 67 | 68 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 69 | self.time_mix_v = nn.Parameter(torch.ones(1, n_embed)) 70 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 71 | 72 | self.key = nn.Linear(n_embed, n_embed, bias=False) 73 | self.value = nn.Linear(n_embed, n_embed, bias=False) 74 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 75 | 76 | self.output = nn.Linear(n_embed, n_embed, bias=False) 77 | 78 | def forward(self, x, state_A, state_B, state_p, state_x): 79 | xk = x * self.time_mix_k + state_x * (1 - self.time_mix_k) 80 | xv = x * self.time_mix_v + state_x * (1 - self.time_mix_v) 81 | xr = x * self.time_mix_r + state_x * (1 - self.time_mix_r) 82 | new_x = x 83 | 84 | k = self.key(xk) 85 | v = self.value(xv) 86 | r = torch.sigmoid(self.receptance(xr)) 87 | 88 | ww = self.time_first + k 89 | p = torch.maximum(state_p, ww) 90 | e1 = torch.exp(state_p - p) 91 | e2 = torch.exp(ww - p) 92 | a = e1 * state_A + e2 * v 93 | b = e1 * state_B + e2 94 | 95 | ww = state_p + -torch.exp(self.time_decay) 96 | p = torch.maximum(ww, k) 97 | e1 = torch.exp(ww - p) 98 | e2 = torch.exp(k - p) 99 | new_A = e1 * state_A + e2 * v 100 | new_B = e1 * state_B + e2 101 | new_p = p 102 | 103 | rwkv = r * a / b 104 | rwkv = self.output(rwkv) 105 | return rwkv, new_A, new_B, new_p, new_x 106 | 107 | 108 | class RWKV_TimeMix_ONNX(nn.Module): 109 | def __init__(self, n_embed): 110 | super().__init__() 111 | self.time_decay = nn.Parameter(torch.ones(n_embed)) 112 | self.time_first = nn.Parameter(torch.ones(n_embed) * math.log(0.3)) 113 | 114 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 115 | self.time_mix_v = nn.Parameter(torch.ones(1, n_embed)) 116 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 117 | 118 | self.key = nn.Linear(n_embed, n_embed, bias=False) 119 | self.value = nn.Linear(n_embed, n_embed, bias=False) 120 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 121 | 122 | self.output = nn.Linear(n_embed, n_embed, bias=False) 123 | 124 | def forward(self, x, state_A, state_B, state_p, state_x): 125 | xk = x * self.time_mix_k + state_x * (1 - self.time_mix_k) 126 | xv = x * self.time_mix_v + state_x * (1 - self.time_mix_v) 127 | xr = x * self.time_mix_r + state_x * (1 - self.time_mix_r) 128 | new_x = x 129 | 130 | k = self.key(xk) 131 | v = self.value(xv) 132 | r = torch.sigmoid(self.receptance(xr)) 133 | 134 | ww = self.time_first + k 135 | # p = torch.maximum(state_p, ww) 136 | p = torch.stack([state_p.flatten(), ww.flatten()]).max(dim=0)[0].view(state_p.shape) 137 | # p = torch.where(state_p > ww, state_p, ww) 138 | 139 | e1 = torch.exp(state_p - p) 140 | e2 = torch.exp(ww - p) 141 | a = e1 * state_A + e2 * v 142 | b = e1 * state_B + e2 143 | 144 | ww = state_p + -torch.exp(self.time_decay) 145 | # p = torch.maximum(ww, k) 146 | p = torch.stack([ww.flatten(), k.flatten()]).max(dim=0)[0].view(state_p.shape) 147 | # p = torch.where(ww > k, ww, k) 148 | 149 | e1 = torch.exp(ww - p) 150 | e2 = torch.exp(k - p) 151 | new_A = e1 * state_A + e2 * v 152 | new_B = e1 * state_B + e2 153 | new_p = p 154 | 155 | rwkv = r * a / 
b 156 | rwkv = self.output(rwkv) 157 | return rwkv, new_A, new_B, new_p, new_x 158 | 159 | 160 | class Block(nn.Module): 161 | def __init__(self, layer_id, n_embed, ffn_dim): 162 | super().__init__() 163 | self.layer_id = layer_id 164 | 165 | self.ln1 = nn.LayerNorm(n_embed) 166 | self.ln2 = nn.LayerNorm(n_embed) 167 | if self.layer_id == 0: 168 | self.ln0 = nn.LayerNorm(n_embed) 169 | 170 | self.att = RWKV_TimeMix(n_embed) 171 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 172 | 173 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 174 | if self.layer_id == 0: 175 | x = self.ln0(x) 176 | 177 | short_cut = x 178 | x = self.ln1(x) 179 | x, new_A, new_B, new_p, new_x = self.att(x, state_A, state_B, state_p, state_x) 180 | x = short_cut + x 181 | 182 | short_cut = x 183 | x = self.ln2(x) 184 | x, new_ffn = self.ffn(x, state_ffn) 185 | x = short_cut + x 186 | return x, new_A, new_B, new_p, new_x, new_ffn 187 | 188 | 189 | class Block_ONNX(nn.Module): 190 | def __init__(self, layer_id, n_embed, ffn_dim): 191 | super().__init__() 192 | self.layer_id = layer_id 193 | 194 | self.ln1 = nn.LayerNorm(n_embed) 195 | self.ln2 = nn.LayerNorm(n_embed) 196 | if self.layer_id == 0: 197 | self.ln0 = nn.LayerNorm(n_embed) 198 | 199 | self.att = RWKV_TimeMix_ONNX(n_embed) 200 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 201 | 202 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 203 | if self.layer_id == 0: 204 | x = self.ln0(x) 205 | 206 | short_cut = x 207 | x, new_A, new_B, new_p, new_x = self.att(self.ln1(x), state_A, state_B, state_p, state_x) 208 | x = short_cut + x 209 | 210 | short_cut = x 211 | x, new_ffn = self.ffn(self.ln2(x), state_ffn) 212 | x = short_cut + x 213 | return x, new_A, new_B, new_p, new_x, new_ffn 214 | 215 | 216 | class Block_Script(nn.Module): 217 | def __init__(self, n_embed, ffn_dim): 218 | super().__init__() 219 | 220 | self.ln1 = nn.LayerNorm(n_embed) 221 | self.ln2 = nn.LayerNorm(n_embed) 222 | 223 | self.att = RWKV_TimeMix(n_embed) 224 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 225 | 226 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 227 | short_cut = x 228 | x, new_A, new_B, new_p, new_x = self.att(self.ln1(x), state_A, state_B, state_p, state_x) 229 | x = short_cut + x 230 | 231 | short_cut = x 232 | x, new_ffn = self.ffn(self.ln2(x), state_ffn) 233 | x = short_cut + x 234 | return x, new_A, new_B, new_p, new_x, new_ffn 235 | 236 | 237 | class RWKV_V4_Infer_For_CoreML(nn.Module): 238 | def __init__(self, 239 | vocab_size=2000, 240 | hidden_size=512, 241 | num_hidden_layers=4, 242 | intermediate_size=1024, 243 | ): 244 | super(RWKV_V4_Infer_For_CoreML, self).__init__() 245 | self.hidden_size = hidden_size 246 | self.num_hidden_layers = num_hidden_layers 247 | 248 | self.emb = nn.Embedding(vocab_size, hidden_size) 249 | self.blocks = nn.ModuleList([Block(i, hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 250 | self.ln_out = nn.LayerNorm(hidden_size) 251 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 252 | 253 | def forward_initialzation(self, batch_size, device): 254 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 255 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 256 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 257 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 258 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, 
self.hidden_size]) 259 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 260 | return hidden_state 261 | 262 | def forward(self, x, hidden_state): 263 | # x = self.emb(input_token) 264 | # x = torch.matmul(input_onehot, self.emb.weight) 265 | 266 | batch_size = x.size(0) 267 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 268 | new_hidden_state = [] 269 | 270 | for i, block in enumerate(self.blocks): 271 | x, new_A, new_B, new_p, new_x, new_ffn = \ 272 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 273 | 274 | new_hidden_state.append(new_A) 275 | new_hidden_state.append(new_B) 276 | new_hidden_state.append(new_p) 277 | new_hidden_state.append(new_x) 278 | new_hidden_state.append(new_ffn) 279 | 280 | new_hidden_state = torch.cat(new_hidden_state) 281 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 282 | new_hidden_state = new_hidden_state.transpose(0, 1) 283 | x = self.ln_out(x) 284 | x = self.head(x) 285 | return x, new_hidden_state 286 | 287 | 288 | class RWKV_V4_Infer_For_ONNX(nn.Module): 289 | def __init__(self, 290 | vocab_size=2000, 291 | hidden_size=512, 292 | num_hidden_layers=4, 293 | intermediate_size=1024, 294 | ): 295 | super(RWKV_V4_Infer_For_ONNX, self).__init__() 296 | self.hidden_size = hidden_size 297 | self.num_hidden_layers = num_hidden_layers 298 | 299 | self.emb = nn.Embedding(vocab_size, hidden_size) 300 | self.blocks = nn.ModuleList([Block_ONNX(i, hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 301 | self.ln_out = nn.LayerNorm(hidden_size) 302 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 303 | 304 | def forward_initialzation(self, batch_size, device): 305 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 306 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 307 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 308 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 309 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 310 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 311 | return hidden_state 312 | 313 | def forward(self, x, hidden_state): 314 | # x = self.emb(input_token) 315 | batch_size = x.size(0) 316 | # x = torch.matmul(input_onehot, self.emb.weight) 317 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 318 | new_hidden_state = [] 319 | 320 | for i, block in enumerate(self.blocks): 321 | x, new_A, new_B, new_p, new_x, new_ffn = \ 322 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 323 | 324 | new_hidden_state.append(new_A) 325 | new_hidden_state.append(new_B) 326 | new_hidden_state.append(new_p) 327 | new_hidden_state.append(new_x) 328 | new_hidden_state.append(new_ffn) 329 | 330 | new_hidden_state = torch.cat(new_hidden_state) 331 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 332 | new_hidden_state = new_hidden_state.transpose(0, 1) 333 | x = self.ln_out(x) 334 | x = self.head(x) 335 | return x, new_hidden_state 336 | 337 | 338 | class RWKV_V4_Infer_For_Script(nn.Module): 339 | def __init__(self, 340 | vocab_size=2000, 341 | hidden_size=512, 342 | num_hidden_layers=4, 343 | intermediate_size=1024, 344 | ): 345 | super(RWKV_V4_Infer_For_Script, 
self).__init__() 346 | self.hidden_size = hidden_size 347 | self.num_hidden_layers = num_hidden_layers 348 | 349 | self.emb = nn.Embedding(vocab_size, hidden_size) 350 | self.ln0 = nn.LayerNorm(hidden_size) 351 | self.blocks = nn.ModuleList([Block_Script(hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 352 | self.ln_out = nn.LayerNorm(hidden_size) 353 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 354 | 355 | def forward_initialzation(self, batch_size, device): 356 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 357 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 358 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 359 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 360 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 361 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 362 | return hidden_state 363 | 364 | def forward(self, input_token, hidden_state): 365 | x = self.emb(input_token) 366 | batch_size = input_token.size(0) 367 | # x = torch.matmul(input_onehot, self.emb.weight) 368 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 369 | new_hidden_state = [] 370 | 371 | x = self.ln0(x) 372 | for i, block in enumerate(self.blocks): 373 | x, new_A, new_B, new_p, new_x, new_ffn = \ 374 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 375 | 376 | new_hidden_state.append(new_A) 377 | new_hidden_state.append(new_B) 378 | new_hidden_state.append(new_p) 379 | new_hidden_state.append(new_x) 380 | new_hidden_state.append(new_ffn) 381 | 382 | new_hidden_state = torch.cat(new_hidden_state) 383 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 384 | new_hidden_state = new_hidden_state.transpose(0, 1) 385 | x = self.ln_out(x) 386 | x = self.head(x) 387 | return x, new_hidden_state 388 | 389 | 390 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='rwkv_v4_infer_for_coreml') 391 | def build_rwkv_v4_infer_for_coreml(args): 392 | model = RWKV_V4_Infer_For_CoreML( 393 | vocab_size = args.vocab_size, 394 | hidden_size = args.hidden_size, 395 | num_hidden_layers = args.num_hidden_layer, 396 | intermediate_size = args.intermediate_size 397 | ) 398 | criterion = nn.CrossEntropyLoss(reduction='none') 399 | return model, criterion 400 | 401 | 402 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='rwkv_v4_infer_for_onnx') 403 | def build_rwkv_v4_infer_for_onnx(args): 404 | model = RWKV_V4_Infer_For_ONNX( 405 | vocab_size = args.vocab_size, 406 | hidden_size = args.hidden_size, 407 | num_hidden_layers = args.num_hidden_layer, 408 | intermediate_size = args.intermediate_size 409 | ) 410 | criterion = nn.CrossEntropyLoss(reduction='none') 411 | return model, criterion 412 | 413 | 414 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='rwkv_v4_infer_for_script') 415 | def build_rwkv_v4_infer_for_script(args): 416 | model = RWKV_V4_Infer_For_Script( 417 | vocab_size = args.vocab_size, 418 | hidden_size = args.hidden_size, 419 | num_hidden_layers = args.num_hidden_layer, 420 | intermediate_size = args.intermediate_size 421 | ) 422 | criterion = nn.CrossEntropyLoss(reduction='none') 423 | return model, criterion -------------------------------------------------------------------------------- /models/RWKV_V4/rwkv_v4_multi_pred_infer.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numbers 4 | import torch.nn as nn 5 | from torch.nn import init 6 | import torch.nn.functional as F 7 | from ..registry import MODULE_BUILD_FUNCS 8 | from torch import Tensor, Size 9 | from typing import Union, List 10 | _shape_t = Union[int, List[int], Size] 11 | 12 | class LayerNorm(nn.Module): 13 | def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, device=None, dtype=None): 14 | factory_kwargs = {'device': device, 'dtype': dtype} 15 | super(LayerNorm, self).__init__() 16 | if isinstance(normalized_shape, numbers.Integral): 17 | # mypy error: incompatible types in assignment 18 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 19 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 20 | self.eps = eps 21 | self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 22 | self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self) -> None: 26 | init.ones_(self.weight) 27 | init.zeros_(self.bias) 28 | 29 | def forward(self, input: Tensor) -> Tensor: 30 | u = torch.mean(input, dim=-1, keepdim=True) 31 | s = torch.mean(input * input, dim=-1, keepdim=True) 32 | s = torch.sqrt(s - u * u + self.eps) 33 | x_normalized = (input - u) / s 34 | output = x_normalized * self.weight + self.bias 35 | return output 36 | 37 | 38 | class RWKV_ChannelMix(nn.Module): 39 | def __init__(self, n_embed, ffn_dim): 40 | super().__init__() 41 | 42 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 43 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 44 | 45 | self.key = nn.Linear(n_embed, ffn_dim, bias=False) 46 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 47 | self.value = nn.Linear(ffn_dim, n_embed, bias=False) 48 | 49 | def forward(self, x, state_ffn): 50 | xk = x * self.time_mix_k + state_ffn * (1 - self.time_mix_k) 51 | xr = x * self.time_mix_r + state_ffn * (1 - self.time_mix_r) 52 | new_ffn = x 53 | 54 | r = torch.sigmoid(self.receptance(xr)) 55 | k = torch.square(torch.relu(self.key(xk))) 56 | kv = self.value(k) 57 | 58 | rkv = r * kv 59 | return rkv, new_ffn 60 | 61 | 62 | class RWKV_TimeMix(nn.Module): 63 | def __init__(self, n_embed): 64 | super().__init__() 65 | self.time_decay = nn.Parameter(torch.ones(n_embed)) 66 | self.time_first = nn.Parameter(torch.ones(n_embed) * math.log(0.3)) 67 | 68 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 69 | self.time_mix_v = nn.Parameter(torch.ones(1, n_embed)) 70 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 71 | 72 | self.key = nn.Linear(n_embed, n_embed, bias=False) 73 | self.value = nn.Linear(n_embed, n_embed, bias=False) 74 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 75 | 76 | self.output = nn.Linear(n_embed, n_embed, bias=False) 77 | 78 | def forward(self, x, state_A, state_B, state_p, state_x): 79 | xk = x * self.time_mix_k + state_x * (1 - self.time_mix_k) 80 | xv = x * self.time_mix_v + state_x * (1 - self.time_mix_v) 81 | xr = x * self.time_mix_r + state_x * (1 - self.time_mix_r) 82 | new_x = x 83 | 84 | k = self.key(xk) 85 | v = self.value(xv) 86 | r = torch.sigmoid(self.receptance(xr)) 87 | 88 | ww = self.time_first + k 89 | p = torch.maximum(state_p, ww) 90 | e1 = torch.exp(state_p - p) 91 | e2 = torch.exp(ww - p) 92 | a = e1 * state_A + e2 * v 93 | b = e1 * state_B + e2 94 | 95 | ww = state_p + 
-torch.exp(self.time_decay) 96 | p = torch.maximum(ww, k) 97 | e1 = torch.exp(ww - p) 98 | e2 = torch.exp(k - p) 99 | new_A = e1 * state_A + e2 * v 100 | new_B = e1 * state_B + e2 101 | new_p = p 102 | 103 | rwkv = r * a / b 104 | rwkv = self.output(rwkv) 105 | return rwkv, new_A, new_B, new_p, new_x 106 | 107 | 108 | class RWKV_TimeMix_ONNX(nn.Module): 109 | def __init__(self, n_embed): 110 | super().__init__() 111 | self.time_decay = nn.Parameter(torch.ones(n_embed)) 112 | self.time_first = nn.Parameter(torch.ones(n_embed) * math.log(0.3)) 113 | 114 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 115 | self.time_mix_v = nn.Parameter(torch.ones(1, n_embed)) 116 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 117 | 118 | self.key = nn.Linear(n_embed, n_embed, bias=False) 119 | self.value = nn.Linear(n_embed, n_embed, bias=False) 120 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 121 | 122 | self.output = nn.Linear(n_embed, n_embed, bias=False) 123 | 124 | def forward(self, x, state_A, state_B, state_p, state_x): 125 | xk = x * self.time_mix_k + state_x * (1 - self.time_mix_k) 126 | xv = x * self.time_mix_v + state_x * (1 - self.time_mix_v) 127 | xr = x * self.time_mix_r + state_x * (1 - self.time_mix_r) 128 | new_x = x 129 | 130 | k = self.key(xk) 131 | v = self.value(xv) 132 | r = torch.sigmoid(self.receptance(xr)) 133 | 134 | ww = self.time_first + k 135 | # p = torch.maximum(state_p, ww) 136 | p = torch.stack([state_p.flatten(), ww.flatten()]).max(dim=0)[0].view(state_p.shape) 137 | # p = torch.where(state_p > ww, state_p, ww) 138 | 139 | e1 = torch.exp(state_p - p) 140 | e2 = torch.exp(ww - p) 141 | a = e1 * state_A + e2 * v 142 | b = e1 * state_B + e2 143 | 144 | ww = state_p + -torch.exp(self.time_decay) 145 | # p = torch.maximum(ww, k) 146 | p = torch.stack([ww.flatten(), k.flatten()]).max(dim=0)[0].view(state_p.shape) 147 | # p = torch.where(ww > k, ww, k) 148 | 149 | e1 = torch.exp(ww - p) 150 | e2 = torch.exp(k - p) 151 | new_A = e1 * state_A + e2 * v 152 | new_B = e1 * state_B + e2 153 | new_p = p 154 | 155 | rwkv = r * a / b 156 | rwkv = self.output(rwkv) 157 | return rwkv, new_A, new_B, new_p, new_x 158 | 159 | 160 | class Block(nn.Module): 161 | def __init__(self, layer_id, n_embed, ffn_dim): 162 | super().__init__() 163 | self.layer_id = layer_id 164 | 165 | self.ln1 = nn.LayerNorm(n_embed) 166 | self.ln2 = nn.LayerNorm(n_embed) 167 | if self.layer_id == 0: 168 | self.ln0 = nn.LayerNorm(n_embed) 169 | 170 | self.att = RWKV_TimeMix(n_embed) 171 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 172 | 173 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 174 | if self.layer_id == 0: 175 | x = self.ln0(x) 176 | 177 | short_cut = x 178 | x = self.ln1(x) 179 | x, new_A, new_B, new_p, new_x = self.att(x, state_A, state_B, state_p, state_x) 180 | x = short_cut + x 181 | 182 | short_cut = x 183 | x = self.ln2(x) 184 | x, new_ffn = self.ffn(x, state_ffn) 185 | x = short_cut + x 186 | return x, new_A, new_B, new_p, new_x, new_ffn 187 | 188 | 189 | class Block_ONNX(nn.Module): 190 | def __init__(self, layer_id, n_embed, ffn_dim): 191 | super().__init__() 192 | self.layer_id = layer_id 193 | 194 | self.ln1 = nn.LayerNorm(n_embed) 195 | self.ln2 = nn.LayerNorm(n_embed) 196 | if self.layer_id == 0: 197 | self.ln0 = nn.LayerNorm(n_embed) 198 | 199 | self.att = RWKV_TimeMix_ONNX(n_embed) 200 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 201 | 202 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 203 | if self.layer_id == 
0: 204 | x = self.ln0(x) 205 | 206 | short_cut = x 207 | x, new_A, new_B, new_p, new_x = self.att(self.ln1(x), state_A, state_B, state_p, state_x) 208 | x = short_cut + x 209 | 210 | short_cut = x 211 | x, new_ffn = self.ffn(self.ln2(x), state_ffn) 212 | x = short_cut + x 213 | return x, new_A, new_B, new_p, new_x, new_ffn 214 | 215 | 216 | class Block_Script(nn.Module): 217 | def __init__(self, n_embed, ffn_dim): 218 | super().__init__() 219 | 220 | self.ln1 = nn.LayerNorm(n_embed) 221 | self.ln2 = nn.LayerNorm(n_embed) 222 | 223 | self.att = RWKV_TimeMix(n_embed) 224 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 225 | 226 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 227 | short_cut = x 228 | x, new_A, new_B, new_p, new_x = self.att(self.ln1(x), state_A, state_B, state_p, state_x) 229 | x = short_cut + x 230 | 231 | short_cut = x 232 | x, new_ffn = self.ffn(self.ln2(x), state_ffn) 233 | x = short_cut + x 234 | return x, new_A, new_B, new_p, new_x, new_ffn 235 | 236 | 237 | class RWKV_V4_Infer_For_CoreML(nn.Module): 238 | def __init__(self, 239 | vocab_size=2000, 240 | hidden_size=512, 241 | num_hidden_layers=4, 242 | intermediate_size=1024, 243 | ): 244 | super(RWKV_V4_Infer_For_CoreML, self).__init__() 245 | self.hidden_size = hidden_size 246 | self.num_hidden_layers = num_hidden_layers 247 | 248 | self.emb = nn.Embedding(vocab_size, hidden_size) 249 | self.blocks = nn.ModuleList([Block(i, hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 250 | self.ln_out = nn.LayerNorm(hidden_size) 251 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 252 | 253 | def forward_initialzation(self, batch_size, device): 254 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 255 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 256 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 257 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 258 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 259 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 260 | return hidden_state 261 | 262 | def forward(self, x, hidden_state): 263 | # x = self.emb(input_token) 264 | # x = torch.matmul(input_onehot, self.emb.weight) 265 | 266 | batch_size = x.size(0) 267 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 268 | new_hidden_state = [] 269 | 270 | for i, block in enumerate(self.blocks): 271 | x, new_A, new_B, new_p, new_x, new_ffn = \ 272 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 273 | 274 | new_hidden_state.append(new_A) 275 | new_hidden_state.append(new_B) 276 | new_hidden_state.append(new_p) 277 | new_hidden_state.append(new_x) 278 | new_hidden_state.append(new_ffn) 279 | 280 | new_hidden_state = torch.cat(new_hidden_state) 281 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 282 | new_hidden_state = new_hidden_state.transpose(0, 1) 283 | x = self.ln_out(x) 284 | x = self.head(x) 285 | return x, new_hidden_state 286 | 287 | 288 | class RWKV_V4_Infer_For_ONNX(nn.Module): 289 | def __init__(self, 290 | vocab_size=2000, 291 | hidden_size=512, 292 | num_hidden_layers=4, 293 | intermediate_size=1024, 294 | ): 295 | super(RWKV_V4_Infer_For_ONNX, self).__init__() 296 | self.hidden_size = hidden_size 297 | self.num_hidden_layers = num_hidden_layers 298 | 299 | 
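        # Same wrapper as RWKV_V4_Infer_For_CoreML above, but assembled from Block_ONNX, whose
        # time-mix replaces torch.maximum with an equivalent stack/max/view sequence, presumably
        # because the ONNX exporter handles that form more reliably.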
self.emb = nn.Embedding(vocab_size, hidden_size) 300 | self.blocks = nn.ModuleList([Block_ONNX(i, hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 301 | self.ln_out = nn.LayerNorm(hidden_size) 302 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 303 | 304 | def forward_initialzation(self, batch_size, device): 305 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 306 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 307 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 308 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 309 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 310 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 311 | return hidden_state 312 | 313 | def forward(self, x, hidden_state): 314 | # x = self.emb(input_token) 315 | batch_size = x.size(0) 316 | # x = torch.matmul(input_onehot, self.emb.weight) 317 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 318 | new_hidden_state = [] 319 | 320 | for i, block in enumerate(self.blocks): 321 | x, new_A, new_B, new_p, new_x, new_ffn = \ 322 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 323 | 324 | new_hidden_state.append(new_A) 325 | new_hidden_state.append(new_B) 326 | new_hidden_state.append(new_p) 327 | new_hidden_state.append(new_x) 328 | new_hidden_state.append(new_ffn) 329 | 330 | new_hidden_state = torch.cat(new_hidden_state) 331 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 332 | new_hidden_state = new_hidden_state.transpose(0, 1) 333 | x = self.ln_out(x) 334 | x = self.head(x) 335 | return x, new_hidden_state 336 | 337 | 338 | class RWKV_V4_Infer_For_Script(nn.Module): 339 | def __init__(self, 340 | vocab_size=2000, 341 | hidden_size=512, 342 | num_hidden_layers=4, 343 | intermediate_size=1024, 344 | ): 345 | super(RWKV_V4_Infer_For_Script, self).__init__() 346 | self.hidden_size = hidden_size 347 | self.num_hidden_layers = num_hidden_layers 348 | 349 | self.emb = nn.Embedding(vocab_size, hidden_size) 350 | self.ln0 = nn.LayerNorm(hidden_size) 351 | self.blocks = nn.ModuleList([Block_Script(hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 352 | self.ln_out = nn.LayerNorm(hidden_size) 353 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 354 | 355 | def forward_initialzation(self, batch_size, device): 356 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 357 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 358 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 359 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 360 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 361 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 362 | return hidden_state 363 | 364 | def forward(self, input_token, hidden_state): 365 | x = self.emb(input_token) 366 | batch_size = input_token.size(0) 367 | # x = torch.matmul(input_onehot, self.emb.weight) 368 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 369 | new_hidden_state = [] 370 | 371 | x = self.ln0(x) 372 | for i, block in enumerate(self.blocks): 373 | x, new_A, new_B, new_p, new_x, 
new_ffn = \ 374 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 375 | 376 | new_hidden_state.append(new_A) 377 | new_hidden_state.append(new_B) 378 | new_hidden_state.append(new_p) 379 | new_hidden_state.append(new_x) 380 | new_hidden_state.append(new_ffn) 381 | 382 | new_hidden_state = torch.cat(new_hidden_state) 383 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 384 | new_hidden_state = new_hidden_state.transpose(0, 1) 385 | x = self.ln_out(x) 386 | x = self.head(x) 387 | return x, new_hidden_state 388 | 389 | 390 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='rwkv_v4_infer_for_coreml') 391 | def build_rwkv_v4_infer_for_coreml(args): 392 | model = RWKV_V4_Infer_For_CoreML( 393 | vocab_size = args.vocab_size, 394 | hidden_size = args.hidden_size, 395 | num_hidden_layers = args.num_hidden_layer, 396 | intermediate_size = args.intermediate_size 397 | ) 398 | criterion = nn.CrossEntropyLoss(reduction='none') 399 | return model, criterion 400 | 401 | 402 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='rwkv_v4_infer_for_onnx') 403 | def build_rwkv_v4_infer_for_onnx(args): 404 | model = RWKV_V4_Infer_For_ONNX( 405 | vocab_size = args.vocab_size, 406 | hidden_size = args.hidden_size, 407 | num_hidden_layers = args.num_hidden_layer, 408 | intermediate_size = args.intermediate_size 409 | ) 410 | criterion = nn.CrossEntropyLoss(reduction='none') 411 | return model, criterion 412 | 413 | 414 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='rwkv_v4_infer_for_script') 415 | def build_rwkv_v4_infer_for_script(args): 416 | model = RWKV_V4_Infer_For_Script( 417 | vocab_size = args.vocab_size, 418 | hidden_size = args.hidden_size, 419 | num_hidden_layers = args.num_hidden_layer, 420 | intermediate_size = args.intermediate_size 421 | ) 422 | criterion = nn.CrossEntropyLoss(reduction='none') 423 | return model, criterion -------------------------------------------------------------------------------- /models/RWKV_V4/rwkv_v5_infer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numbers 4 | import torch.nn as nn 5 | from torch.nn import init 6 | import torch.nn.functional as F 7 | from ..registry import MODULE_BUILD_FUNCS 8 | from torch import Tensor, Size 9 | from typing import Union, List 10 | _shape_t = Union[int, List[int], Size] 11 | 12 | class LayerNorm(nn.Module): 13 | def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, device=None, dtype=None): 14 | factory_kwargs = {'device': device, 'dtype': dtype} 15 | super(LayerNorm, self).__init__() 16 | if isinstance(normalized_shape, numbers.Integral): 17 | # mypy error: incompatible types in assignment 18 | normalized_shape = (normalized_shape,) # type: ignore[assignment] 19 | self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] 20 | self.eps = eps 21 | self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 22 | self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs)) 23 | self.reset_parameters() 24 | 25 | def reset_parameters(self) -> None: 26 | init.ones_(self.weight) 27 | init.zeros_(self.bias) 28 | 29 | def forward(self, input: Tensor) -> Tensor: 30 | u = torch.mean(input, dim=-1, keepdim=True) 31 | s = torch.mean(input * input, dim=-1, keepdim=True) 32 | s = torch.sqrt(s - u * u + self.eps) 33 | x_normalized = (input - u) / s 34 | output = x_normalized * self.weight + self.bias 35 | 
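        # Hand-rolled LayerNorm: variance is computed explicitly as E[x^2] - E[x]^2 with eps
        # inside the sqrt, numerically matching nn.LayerNorm while decomposing into primitive
        # ops, presumably for easier CoreML/ONNX export.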
return output 36 | 37 | 38 | class RWKV_ChannelMix(nn.Module): 39 | def __init__(self, n_embed, ffn_dim): 40 | super().__init__() 41 | 42 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 43 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 44 | 45 | self.key = nn.Linear(n_embed, ffn_dim, bias=False) 46 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 47 | self.value = nn.Linear(ffn_dim, n_embed, bias=False) 48 | 49 | def forward(self, x, state_ffn): 50 | xk = x * self.time_mix_k + state_ffn * (1 - self.time_mix_k) 51 | xr = x * self.time_mix_r + state_ffn * (1 - self.time_mix_r) 52 | new_ffn = x 53 | 54 | r = torch.sigmoid(self.receptance(xr)) 55 | k = torch.square(torch.relu(self.key(xk))) 56 | kv = self.value(k) 57 | 58 | rkv = r * kv 59 | return rkv, new_ffn 60 | 61 | 62 | class RWKV_TimeMix(nn.Module): 63 | def __init__(self, n_embed): 64 | super().__init__() 65 | self.time_decay = nn.Parameter(torch.ones(n_embed)) 66 | self.time_first = nn.Parameter(torch.ones(n_embed) * math.log(0.3)) 67 | 68 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 69 | self.time_mix_v = nn.Parameter(torch.ones(1, n_embed)) 70 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 71 | self.time_mix_g = nn.Parameter(torch.ones(1, n_embed)) 72 | 73 | self.key = nn.Linear(n_embed, n_embed, bias=False) 74 | self.value = nn.Linear(n_embed, n_embed, bias=False) 75 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 76 | 77 | self.output = nn.Linear(n_embed, n_embed, bias=False) 78 | self.gate = nn.Linear(n_embed, n_embed, bias=False) 79 | 80 | def forward(self, x, state_A, state_B, state_p, state_x): 81 | xk = x * self.time_mix_k + state_x * (1 - self.time_mix_k) 82 | xv = x * self.time_mix_v + state_x * (1 - self.time_mix_v) 83 | xr = x * self.time_mix_r + state_x * (1 - self.time_mix_r) 84 | xg = x * self.time_mix_g + state_x * (1 - self.time_mix_g) 85 | new_x = x 86 | 87 | k = self.key(xk) 88 | v = self.value(xv) 89 | r = torch.sigmoid(self.receptance(xr)) 90 | g = F.relu(self.gate(xg)) 91 | 92 | ww = self.time_first + k 93 | p = torch.maximum(state_p, ww) # running max of the exponents (log-sum-exp trick) so the exp() calls below cannot overflow 94 | e1 = torch.exp(state_p - p) 95 | e2 = torch.exp(ww - p) 96 | a = e1 * state_A + e2 * v 97 | b = e1 * state_B + e2 98 | 99 | ww = state_p + -torch.exp(self.time_decay) 100 | p = torch.maximum(ww, k) 101 | e1 = torch.exp(ww - p) 102 | e2 = torch.exp(k - p) 103 | new_A = e1 * state_A + e2 * v 104 | new_B = e1 * state_B + e2 105 | new_p = p 106 | 107 | rwkv = r * a / b # WKV: exp-weighted average of past values, gated by receptance r 108 | rwkv = self.output(rwkv * g) 109 | return rwkv, new_A, new_B, new_p, new_x 110 | 111 | 112 | class RWKV_TimeMix_ONNX(nn.Module): 113 | def __init__(self, n_embed): 114 | super().__init__() 115 | self.time_decay = nn.Parameter(torch.ones(n_embed)) 116 | self.time_first = nn.Parameter(torch.ones(n_embed) * math.log(0.3)) 117 | 118 | self.time_mix_k = nn.Parameter(torch.ones(1, n_embed)) 119 | self.time_mix_v = nn.Parameter(torch.ones(1, n_embed)) 120 | self.time_mix_r = nn.Parameter(torch.ones(1, n_embed)) 121 | self.time_mix_g = nn.Parameter(torch.ones(1, n_embed)) 122 | 123 | self.key = nn.Linear(n_embed, n_embed, bias=False) 124 | self.value = nn.Linear(n_embed, n_embed, bias=False) 125 | self.receptance = nn.Linear(n_embed, n_embed, bias=False) 126 | 127 | self.output = nn.Linear(n_embed, n_embed, bias=False) 128 | self.gate = nn.Linear(n_embed, n_embed, bias=False) # bugfix: forward() below uses self.gate, which this export variant never defined (cf. RWKV_TimeMix above) 129 | def forward(self, x, state_A, state_B, state_p, state_x): 130 | xk = x * self.time_mix_k + state_x * (1 - self.time_mix_k) 131 | xv = x * self.time_mix_v + state_x * (1 - self.time_mix_v) 132 | xr = x * self.time_mix_r +
state_x * (1 - self.time_mix_r) 133 | xg = x * self.time_mix_g + state_x * (1 - self.time_mix_g) 134 | new_x = x 135 | 136 | k = self.key(xk) 137 | v = self.value(xv) 138 | r = torch.sigmoid(self.receptance(xr)) 139 | g = F.relu(self.gate(xg)) 140 | 141 | ww = self.time_first + k 142 | # p = torch.maximum(state_p, ww) 143 | p = torch.stack([state_p.flatten(), ww.flatten()]).max(dim=0)[0].view(state_p.shape) 144 | # p = torch.where(state_p > ww, state_p, ww) 145 | 146 | e1 = torch.exp(state_p - p) 147 | e2 = torch.exp(ww - p) 148 | a = e1 * state_A + e2 * v 149 | b = e1 * state_B + e2 150 | 151 | ww = state_p + -torch.exp(self.time_decay) 152 | # p = torch.maximum(ww, k) 153 | p = torch.stack([ww.flatten(), k.flatten()]).max(dim=0)[0].view(state_p.shape) 154 | # p = torch.where(ww > k, ww, k) 155 | 156 | e1 = torch.exp(ww - p) 157 | e2 = torch.exp(k - p) 158 | new_A = e1 * state_A + e2 * v 159 | new_B = e1 * state_B + e2 160 | new_p = p 161 | 162 | rwkv = r * a / b 163 | rwkv = self.output(rwkv * g) 164 | return rwkv, new_A, new_B, new_p, new_x 165 | 166 | 167 | class Block(nn.Module): 168 | def __init__(self, layer_id, n_embed, ffn_dim): 169 | super().__init__() 170 | self.layer_id = layer_id 171 | 172 | self.ln1 = nn.LayerNorm(n_embed) 173 | self.ln2 = nn.LayerNorm(n_embed) 174 | if self.layer_id == 0: 175 | self.ln0 = nn.LayerNorm(n_embed) 176 | 177 | self.att = RWKV_TimeMix(n_embed) 178 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 179 | 180 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 181 | if self.layer_id == 0: 182 | x = self.ln0(x) 183 | 184 | short_cut = x 185 | x = self.ln1(x) 186 | x, new_A, new_B, new_p, new_x = self.att(x, state_A, state_B, state_p, state_x) 187 | x = short_cut + x 188 | 189 | short_cut = x 190 | x = self.ln2(x) 191 | x, new_ffn = self.ffn(x, state_ffn) 192 | x = short_cut + x 193 | return x, new_A, new_B, new_p, new_x, new_ffn 194 | 195 | 196 | class Block_ONNX(nn.Module): 197 | def __init__(self, layer_id, n_embed, ffn_dim): 198 | super().__init__() 199 | self.layer_id = layer_id 200 | 201 | self.ln1 = nn.LayerNorm(n_embed) 202 | self.ln2 = nn.LayerNorm(n_embed) 203 | if self.layer_id == 0: 204 | self.ln0 = nn.LayerNorm(n_embed) 205 | 206 | self.att = RWKV_TimeMix_ONNX(n_embed) 207 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 208 | 209 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 210 | if self.layer_id == 0: 211 | x = self.ln0(x) 212 | 213 | short_cut = x 214 | x, new_A, new_B, new_p, new_x = self.att(self.ln1(x), state_A, state_B, state_p, state_x) 215 | x = short_cut + x 216 | 217 | short_cut = x 218 | x, new_ffn = self.ffn(self.ln2(x), state_ffn) 219 | x = short_cut + x 220 | return x, new_A, new_B, new_p, new_x, new_ffn 221 | 222 | 223 | class Block_Script(nn.Module): 224 | def __init__(self, n_embed, ffn_dim): 225 | super().__init__() 226 | 227 | self.ln1 = nn.LayerNorm(n_embed) 228 | self.ln2 = nn.LayerNorm(n_embed) 229 | 230 | self.att = RWKV_TimeMix(n_embed) 231 | self.ffn = RWKV_ChannelMix(n_embed, ffn_dim) 232 | 233 | def forward(self, x, state_A, state_B, state_p, state_x, state_ffn): 234 | short_cut = x 235 | x, new_A, new_B, new_p, new_x = self.att(self.ln1(x), state_A, state_B, state_p, state_x) 236 | x = short_cut + x 237 | 238 | short_cut = x 239 | x, new_ffn = self.ffn(self.ln2(x), state_ffn) 240 | x = short_cut + x 241 | return x, new_A, new_B, new_p, new_x, new_ffn 242 | 243 | 244 | class RWKV_V5_Infer_For_CoreML(nn.Module): 245 | def __init__(self, 246 | vocab_size=2000, 247 | 
hidden_size=512, 248 | num_hidden_layers=4, 249 | intermediate_size=1024, 250 | ): 251 | super(RWKV_V5_Infer_For_CoreML, self).__init__() 252 | self.hidden_size = hidden_size 253 | self.num_hidden_layers = num_hidden_layers 254 | 255 | self.emb = nn.Embedding(vocab_size, hidden_size) 256 | self.blocks = nn.ModuleList([Block(i, hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 257 | self.ln_out = nn.LayerNorm(hidden_size) 258 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 259 | 260 | def forward_initialzation(self, batch_size, device): 261 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 262 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 263 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 264 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 265 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 266 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 267 | return hidden_state 268 | 269 | def forward(self, x, hidden_state): 270 | # x = self.emb(input_token) 271 | # x = torch.matmul(input_onehot, self.emb.weight) 272 | 273 | batch_size = x.size(0) 274 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 275 | new_hidden_state = [] 276 | 277 | for i, block in enumerate(self.blocks): 278 | x, new_A, new_B, new_p, new_x, new_ffn = \ 279 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 280 | 281 | new_hidden_state.append(new_A) 282 | new_hidden_state.append(new_B) 283 | new_hidden_state.append(new_p) 284 | new_hidden_state.append(new_x) 285 | new_hidden_state.append(new_ffn) 286 | 287 | new_hidden_state = torch.cat(new_hidden_state) 288 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 289 | new_hidden_state = new_hidden_state.transpose(0, 1) 290 | x = self.ln_out(x) 291 | x = self.head(x) 292 | return x, new_hidden_state 293 | 294 | 295 | class RWKV_V5_Infer_For_ONNX(nn.Module): 296 | def __init__(self, 297 | vocab_size=2000, 298 | hidden_size=512, 299 | num_hidden_layers=4, 300 | intermediate_size=1024, 301 | ): 302 | super(RWKV_V5_Infer_For_ONNX, self).__init__() 303 | self.hidden_size = hidden_size 304 | self.num_hidden_layers = num_hidden_layers 305 | 306 | self.emb = nn.Embedding(vocab_size, hidden_size) 307 | self.blocks = nn.ModuleList([Block_ONNX(i, hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 308 | self.ln_out = nn.LayerNorm(hidden_size) 309 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 310 | 311 | def forward_initialzation(self, batch_size, device): 312 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 313 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 314 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 315 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 316 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 317 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 318 | return hidden_state 319 | 320 | def forward(self, x, hidden_state): 321 | # x = self.emb(input_token) 322 | batch_size = x.size(0) 323 | # x = torch.matmul(input_onehot, self.emb.weight) 324 | state_A, state_B, state_p, state_x, state_ffn = 
hidden_state.split(1, dim=0) 325 | new_hidden_state = [] 326 | 327 | for i, block in enumerate(self.blocks): 328 | x, new_A, new_B, new_p, new_x, new_ffn = \ 329 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 330 | 331 | new_hidden_state.append(new_A) 332 | new_hidden_state.append(new_B) 333 | new_hidden_state.append(new_p) 334 | new_hidden_state.append(new_x) 335 | new_hidden_state.append(new_ffn) 336 | 337 | new_hidden_state = torch.cat(new_hidden_state) 338 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 339 | new_hidden_state = new_hidden_state.transpose(0, 1) 340 | x = self.ln_out(x) 341 | x = self.head(x) 342 | return x, new_hidden_state 343 | 344 | 345 | class RWKV_V5_Infer_For_Script(nn.Module): 346 | def __init__(self, 347 | vocab_size=2000, 348 | hidden_size=512, 349 | num_hidden_layers=4, 350 | intermediate_size=1024, 351 | ): 352 | super(RWKV_V5_Infer_For_Script, self).__init__() 353 | self.hidden_size = hidden_size 354 | self.num_hidden_layers = num_hidden_layers 355 | 356 | self.emb = nn.Embedding(vocab_size, hidden_size) 357 | self.ln0 = nn.LayerNorm(hidden_size) 358 | self.blocks = nn.ModuleList([Block_Script(hidden_size, intermediate_size) for i in range(num_hidden_layers)]) 359 | self.ln_out = nn.LayerNorm(hidden_size) 360 | self.head = nn.Linear(hidden_size, vocab_size, bias=False) 361 | 362 | def forward_initialzation(self, batch_size, device): 363 | state_A = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 364 | state_B = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 365 | state_p = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) - 1e30 366 | state_x = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 367 | state_ffn = torch.zeros([self.num_hidden_layers, batch_size, self.hidden_size]) 368 | hidden_state = torch.stack([state_A, state_B, state_p, state_x, state_ffn]).to(device) 369 | return hidden_state 370 | 371 | def forward(self, input_token, hidden_state): 372 | x = self.emb(input_token) 373 | batch_size = input_token.size(0) 374 | # x = torch.matmul(input_onehot, self.emb.weight) 375 | state_A, state_B, state_p, state_x, state_ffn = hidden_state.split(1, dim=0) 376 | new_hidden_state = [] 377 | 378 | x = self.ln0(x) 379 | for i, block in enumerate(self.blocks): 380 | x, new_A, new_B, new_p, new_x, new_ffn = \ 381 | block(x, state_A[0, i], state_B[0, i], state_p[0, i], state_x[0, i], state_ffn[0, i]) 382 | 383 | new_hidden_state.append(new_A) 384 | new_hidden_state.append(new_B) 385 | new_hidden_state.append(new_p) 386 | new_hidden_state.append(new_x) 387 | new_hidden_state.append(new_ffn) 388 | 389 | new_hidden_state = torch.cat(new_hidden_state) 390 | new_hidden_state = new_hidden_state.view([self.num_hidden_layers, 5, batch_size, self.hidden_size]) 391 | new_hidden_state = new_hidden_state.transpose(0, 1) 392 | x = self.ln_out(x) 393 | x = self.head(x) 394 | return x, new_hidden_state 395 | 396 | 397 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='rwkv_v5_infer_for_coreml') 398 | def build_rwkv_v5_infer_for_coreml(args): 399 | model = RWKV_V5_Infer_For_CoreML( 400 | vocab_size = args.vocab_size, 401 | hidden_size = args.hidden_size, 402 | num_hidden_layers = args.num_hidden_layer, 403 | intermediate_size = args.intermediate_size 404 | ) 405 | criterion = nn.CrossEntropyLoss(reduction='none') 406 | return model, criterion 407 | 408 | 409 | 
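# Usage sketch: the registry maps each name registered here to its build function, so a
# config only needs the model name plus the attributes read below; `args` is assumed to be
# a config/argparse namespace exposing vocab_size, hidden_size, num_hidden_layer and
# intermediate_size.
#
#     build_func = MODULE_BUILD_FUNCS.get('rwkv_v5_infer_for_coreml')
#     model, criterion = build_func(args)
#     hidden_state = model.forward_initialzation(batch_size=1, device='cpu')
#     logits, hidden_state = model(x, hidden_state)  # x: pre-embedded token, [batch, hidden_size];
#                                                    # the emb lookup is commented out in these wrappers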
@MODULE_BUILD_FUNCS.registe_with_name(module_name='rwkv_v5_infer_for_onnx') 410 | def build_rwkv_v5_infer_for_onnx(args): 411 | model = RWKV_V5_Infer_For_ONNX( 412 | vocab_size = args.vocab_size, 413 | hidden_size = args.hidden_size, 414 | num_hidden_layers = args.num_hidden_layer, 415 | intermediate_size = args.intermediate_size 416 | ) 417 | criterion = nn.CrossEntropyLoss(reduction='none') 418 | return model, criterion 419 | 420 | 421 | @MODULE_BUILD_FUNCS.registe_with_name(module_name='rwkv_v5_infer_for_script') 422 | def build_rwkv_v5_infer_for_script(args): 423 | model = RWKV_V5_Infer_For_Script( 424 | vocab_size = args.vocab_size, 425 | hidden_size = args.hidden_size, 426 | num_hidden_layers = args.num_hidden_layer, 427 | intermediate_size = args.intermediate_size 428 | ) 429 | criterion = nn.CrossEntropyLoss(reduction='none') 430 | return model, criterion -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .RWKV_V4 import build_rwkv_v4 -------------------------------------------------------------------------------- /models/compressor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class XLCompressorForCoreML(nn.Module): 6 | def __init__(self, model, num_tokens): 7 | super(XLCompressorForCoreML, self).__init__() 8 | self.next_token_predictor = model 9 | self.num_tokens = num_tokens 10 | 11 | def forward(self, input_onehots, memories): 12 | outputs = [] 13 | for token_id in range(self.num_tokens): 14 | input_onehot = input_onehots[:, token_id:token_id+1, :] 15 | output, memories = self.next_token_predictor(input_onehot, memories) 16 | outputs.append(output) 17 | 18 | if self.num_tokens > 1: 19 | outputs = torch.cat(outputs, dim=1) 20 | else: 21 | outputs = outputs[0] 22 | return outputs, memories 23 | 24 | 25 | class XLCompressorCache(nn.Module): 26 | def __init__(self, model, num_tokens): 27 | super(XLCompressorCache, self).__init__() 28 | self.next_token_predictor = model 29 | self.num_tokens = num_tokens 30 | 31 | def forward(self, input_onehots, k_cache, v_cache): 32 | outputs = [] 33 | for token_id in range(self.num_tokens): 34 | input_onehot = input_onehots[:, token_id:token_id + 1, :] 35 | output, k_cache, v_cache = self.next_token_predictor(input_onehot, k_cache, v_cache) 36 | outputs.append(output) 37 | 38 | if self.num_tokens > 1: 39 | outputs = torch.cat(outputs, dim=1) 40 | else: 41 | outputs = outputs[0] 42 | return outputs, k_cache, v_cache 43 | 44 | 45 | class XLCompressorForXNN(nn.Module): 46 | def __init__(self, model, num_tokens): 47 | super(XLCompressorForXNN, self).__init__() 48 | self.next_token_predictor = model 49 | self.num_tokens = num_tokens 50 | 51 | def forward(self, input_onehots, memories): 52 | outputs = [] 53 | for token_id in range(self.num_tokens): 54 | input_onehot = input_onehots[:, token_id:token_id + 1, :] 55 | output, memories = self.next_token_predictor(input_onehot, memories) 56 | outputs.append(output) 57 | 58 | if self.num_tokens > 1: 59 | outputs = torch.cat(outputs, dim=1) 60 | else: 61 | outputs = outputs[0] 62 | return outputs, memories 63 | 64 | 65 | class RWKVCompressorForCoreML(nn.Module): 66 | def __init__(self, model, num_tokens): 67 | super(RWKVCompressorForCoreML, self).__init__() 68 | self.next_token_predictor = model 69 | self.num_tokens = num_tokens 70 | 71 | def forward(self, input_onehots, 
hidden_state): 72 | outputs = [] 73 | for token_id in range(self.num_tokens): 74 | input_onehot = input_onehots[token_id, :, :] 75 | output, hidden_state = self.next_token_predictor(input_onehot, hidden_state) 76 | outputs.append(output) 77 | 78 | if self.num_tokens > 1: 79 | outputs = torch.stack(outputs, dim=0) 80 | else: 81 | outputs = outputs[0].unsqueeze(dim=0) 82 | return outputs, hidden_state 83 | 84 | -------------------------------------------------------------------------------- /models/registry.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Yihao Chen 3 | # @Date: 2021-08-16 16:03:17 4 | # @Last Modified by: Shilong Liu 5 | # @Last Modified time: 2022-01-23 15:26 6 | # modified from mmcv 7 | 8 | import inspect 9 | from functools import partial 10 | 11 | 12 | class Registry(object): 13 | 14 | def __init__(self, name): 15 | self._name = name 16 | self._module_dict = dict() 17 | 18 | def __repr__(self): 19 | format_str = self.__class__.__name__ + '(name={}, items={})'.format( 20 | self._name, list(self._module_dict.keys())) 21 | return format_str 22 | 23 | def __len__(self): 24 | return len(self._module_dict) 25 | 26 | @property 27 | def name(self): 28 | return self._name 29 | 30 | @property 31 | def module_dict(self): 32 | return self._module_dict 33 | 34 | def get(self, key): 35 | return self._module_dict.get(key, None) 36 | 37 | def registe_with_name(self, module_name=None, force=False): 38 | return partial(self.register, module_name=module_name, force=force) 39 | 40 | def register(self, module_build_function, module_name=None, force=False): 41 | """Register a module build function. 42 | Args: 43 | module_build_function (callable): module build function to be registered. 44 | """ 45 | if not inspect.isfunction(module_build_function): 46 | raise TypeError('module_build_function must be a function, but got {}'.format( 47 | type(module_build_function))) 48 | if module_name is None: 49 | module_name = module_build_function.__name__ 50 | if not force and module_name in self._module_dict: 51 | raise KeyError('{} is already registered in {}'.format( 52 | module_name, self.name)) 53 | self._module_dict[module_name] = module_build_function 54 | 55 | return module_build_function 56 | 57 | MODULE_BUILD_FUNCS = Registry('model build functions') 58 | 59 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipdb 2 | addict 3 | openpyxl 4 | pkuseg 5 | yapf 6 | sentencepiece 7 | thop 8 | tiktoken -------------------------------------------------------------------------------- /scripts/preprocessor.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | import random 3 | import json 4 | import os 5 | import glob 6 | import sentencepiece as spm 7 | from collections import Counter, OrderedDict 8 | 9 | from tqdm import tqdm 10 | 11 | 12 | class PretrainProcess(object): 13 | def __init__(self): 14 | self.TrainFile = "data/public_text_dataset/enwik8" 15 | self.ValFile = "data/public_text_dataset/enwik9" # only the middle 5000 lines are used, for quick evaluation 16 | self.vocab_size = 16384 # 4096 for enwik8 and 16384 for enwik9 in nncp 17 | self.model_type = 'bpe' # options: unigram (default), bpe, char or word; the word type requires pretokenized input sentences 18 | self.coverage = 0.999 19 | 20 | def get_file_infos(self): 21 | with open(self.TrainFile, "r", encoding="utf-8") as f: 22 | lines = f.readlines() 23 | 24 |
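        # Note: enwik8 is a single Wikipedia XML dump read line by line here; the tokenization
        # below reuses the same line-by-line loop, which is why '\n' and '\t' are kept as
        # user_defined_symbols in do_spm_training().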
max_length = max([len(line) for line in lines]) 25 | print(f"There are {len(lines)} lines in enwik8, the length of the longest sentence is {max_length}") 26 | 27 | def do_spm_training(self): 28 | os.makedirs(f"dictionary/vocab_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}", exist_ok=True) 29 | # train_param = f'--input=./temp_process_spm_train_data.txt --pad_id=0 --unk_id=1 \ 30 | # --bos_id=2 --eos_id=-1 \ 31 | # --model_prefix=../dictionary/vocab_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}/spm_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage} \ 32 | # --vocab_size={self.vocab_size} \ 33 | # --character_coverage={self.coverage} \ 34 | # --max_sentence_length=10000 \ 35 | # --add_dummy_prefix=0 \ 36 | # --remove_extra_whitespaces=0 \ 37 | # --user_defined_symbols=\n,\t \ 38 | # --model_type={self.model_type}' 39 | # spm.SentencePieceTrainer.Train(train_param) 40 | 41 | cmd = f'''bin/spm_train \ 42 | --input={self.TrainFile} --pad_id=0 --unk_id=1 \ 43 | --bos_id=2 --eos_id=-1 \ 44 | --model_prefix=dictionary/vocab_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}/spm_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage} \ 45 | --vocab_size={self.vocab_size} \ 46 | --character_coverage={self.coverage} \ 47 | --max_sentence_length=10000 \ 48 | --add_dummy_prefix=0 \ 49 | --remove_extra_whitespaces=0 \ 50 | --user_defined_symbols="\n,\t" \ 51 | --model_type={self.model_type}''' 52 | print(cmd) 53 | os.system(cmd) 54 | 55 | # rm_cmd = "rm -f ./temp_process_spm_train_data.txt" 56 | # os.system(rm_cmd) 57 | 58 | def generate_dictionary(self): 59 | spm_vocab_file = f"dictionary/vocab_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}/spm_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}.vocab" 60 | self.symb2id_dict = OrderedDict() 61 | with open(spm_vocab_file, 'r', encoding="utf-8") as f: 62 | for line in f: 63 | if line == "\t\t0\n": 64 | self.symb2id_dict["\t"] = len(self.symb2id_dict) 65 | elif line == "\n": 66 | self.symb2id_dict["\n"] = len(self.symb2id_dict) 67 | elif line == "\t0\n": 68 | continue 69 | else: 70 | symb = line.strip().split()[0] 71 | self.symb2id_dict[symb] = len(self.symb2id_dict) 72 | 73 | with open(f"dictionary/vocab_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}/spm_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}_vocab.json", "w") as obj_f: 74 | json.dump(self.symb2id_dict, obj_f, indent=4, ensure_ascii=False) 75 | 76 | def convert_rawtext_into_labeltext(self): 77 | sp_processor = spm.SentencePieceProcessor() 78 | sp_processor.Load(f"dictionary/vocab_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}/spm_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}.model") 79 | 80 | os.makedirs("data/train_data", exist_ok=True) 81 | os.makedirs("data/test_data", exist_ok=True) 82 | 83 | # generate the training data 84 | if True: 85 | train_data_filename = f"data/train_data/train_enwik8_{self.model_type}_{self.vocab_size}_{self.coverage}.txt" 86 | 87 | with open(train_data_filename, "w") as output_file: 88 | pass 89 | 90 | # treat enwik8 as a single line of data 91 | with open(self.TrainFile, "r", encoding="utf-8") as f: 92 | lines = f.readlines() 93 | 94 | train_tokens = [] 95 | for line in tqdm(lines): 96 | proto = sp_processor.encode(line, out_type='immutable_proto') 97 | tokens = [] 98 | for n in proto.pieces: 99 | if n.begin == n.end: 100 | continue 101 | tokens.append(str(n.id)) 102 | 103 | train_tokens.extend(tokens) 104 | 105 | train_tokens = ",".join(train_tokens) 106 | with open(train_data_filename, "a") as
output_file: 107 | output_file.write(train_tokens + "\n") 108 | 109 | # generate the test data 110 | if True: 111 | test_data_filename = f"data/test_data/test_enwik9_{self.model_type}_{self.vocab_size}_{self.coverage}.txt" 112 | with open(test_data_filename, "w") as output_file: 113 | pass 114 | 115 | # treat enwik9 as a single line of data 116 | with open(self.ValFile, "r", encoding="utf-8") as f: 117 | lines = f.readlines() 118 | 119 | test_lines = lines[1128023:1128023+5000] 120 | test_tokens = [] 121 | for line in tqdm(test_lines): 122 | proto = sp_processor.encode(line, out_type='immutable_proto') 123 | tokens = [] 124 | for n in proto.pieces: 125 | if n.begin == n.end: 126 | continue 127 | tokens.append(str(n.id)) 128 | 129 | test_tokens.extend(tokens) 130 | 131 | test_tokens = ",".join(test_tokens) 132 | with open(test_data_filename, "a") as output_file: 133 | output_file.write(test_tokens + "\n") 134 | 135 | 136 | if __name__ == '__main__': 137 | pp = PretrainProcess() 138 | 139 | pp.get_file_infos() 140 | 141 | pp.do_spm_training() 142 | 143 | pp.generate_dictionary() 144 | 145 | pp.convert_rawtext_into_labeltext() 146 | -------------------------------------------------------------------------------- /util/arithmetic_coder.py: -------------------------------------------------------------------------------- 1 | #@title Arithmetic Coding Library 2 | 3 | # 4 | # Reference arithmetic coding 5 | # Copyright (c) Project Nayuki 6 | # 7 | # https://www.nayuki.io/page/reference-arithmetic-coding 8 | # https://github.com/nayuki/Reference-arithmetic-coding 9 | # 10 | 11 | import sys 12 | python3 = sys.version_info.major >= 3 13 | 14 | 15 | # ---- Arithmetic coding core classes ---- 16 | 17 | # Provides the state and behaviors that arithmetic coding encoders and decoders share. 18 | class ArithmeticCoderBase(object): 19 | 20 | # Constructs an arithmetic coder, which initializes the code range. 21 | def __init__(self, numbits): 22 | if numbits < 1: 23 | raise ValueError("State size out of range") 24 | 25 | # -- Configuration fields -- 26 | # Number of bits for the 'low' and 'high' state variables. Must be at least 1. 27 | # - Larger values are generally better - they allow a larger maximum frequency total (maximum_total), 28 | # and they reduce the approximation error inherent in adapting fractions to integers; 29 | # both effects reduce the data encoding loss and asymptotically approach the efficiency 30 | # of arithmetic coding using exact fractions. 31 | # - But larger state sizes increase the computation time for integer arithmetic, 32 | # and compression gains beyond ~30 bits are essentially zero in real-world applications. 33 | # - Python has native bigint arithmetic, so there is no upper limit to the state size. 34 | # For Java and C++ where using native machine-sized integers makes the most sense, 35 | # they have a recommended value of num_state_bits=32 as the most versatile setting. 36 | self.num_state_bits = numbits 37 | # Maximum range (high+1-low) during coding (trivial), which is 2^num_state_bits = 1000...000. 38 | self.full_range = 1 << self.num_state_bits 39 | # The top bit at width num_state_bits, which is 0100...000. 40 | self.half_range = self.full_range >> 1 # Non-zero 41 | # The second highest bit at width num_state_bits, which is 0010...000. This is zero when num_state_bits=1. 42 | self.quarter_range = self.half_range >> 1 # Can be zero 43 | # Minimum range (high+1-low) during coding (non-trivial), which is 0010...010.
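        # Worked example for num_state_bits = 32: full_range = 2^32, half_range = 2^31,
        # quarter_range = 2^30, so minimum_range = maximum_total = 2^30 + 2; every frequency
        # table passed to update() must keep its cumulative total at or below that bound.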
44 | self.minimum_range = self.quarter_range + 2 # At least 2 45 | # Maximum allowed total from a frequency table at all times during coding. This differs from Java 46 | # and C++ because Python's native bigint avoids constraining the size of intermediate computations. 47 | self.maximum_total = self.minimum_range 48 | # Bit mask of num_state_bits ones, which is 0111...111. 49 | self.state_mask = self.full_range - 1 50 | 51 | # -- State fields -- 52 | # Low end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 0s. 53 | self.low = 0 54 | # High end of this arithmetic coder's current range. Conceptually has an infinite number of trailing 1s. 55 | self.high = self.state_mask 56 | 57 | 58 | # Updates the code range (low and high) of this arithmetic coder as a result 59 | # of processing the given symbol with the given frequency table. 60 | # Invariants that are true before and after encoding/decoding each symbol 61 | # (letting full_range = 2^num_state_bits): 62 | # - 0 <= low <= code <= high < full_range. ('code' exists only in the decoder.) 63 | # Therefore these variables are unsigned integers of num_state_bits bits. 64 | # - low < 1/2 * full_range <= high. 65 | # In other words, they are in different halves of the full range. 66 | # - (low < 1/4 * full_range) || (high >= 3/4 * full_range). 67 | # In other words, they are not both in the middle two quarters. 68 | # - Let range = high - low + 1, then full_range/4 < minimum_range 69 | # <= range <= full_range. These invariants for 'range' essentially 70 | # dictate the maximum total that the incoming frequency table can have. 71 | def update(self, freqs, symbol): 72 | # State check 73 | low = self.low 74 | high = self.high 75 | # if low >= high or (low & self.state_mask) != low or (high & self.state_mask) != high: 76 | # raise AssertionError("Low or high out of range") 77 | range = high - low + 1 78 | # if not (self.minimum_range <= range <= self.full_range): 79 | # raise AssertionError("Range out of range") 80 | 81 | # Frequency table values check 82 | total = int(freqs[-1]) 83 | symlow = int(freqs[symbol-1]) if symbol > 0 else 0 84 | symhigh = int(freqs[symbol]) 85 | #total = freqs.get_total() 86 | #symlow = freqs.get_low(symbol) 87 | #symhigh = freqs.get_high(symbol) 88 | # if symlow == symhigh: 89 | # raise ValueError("Symbol has zero frequency") 90 | # if total > self.maximum_total: 91 | # raise ValueError("Cannot code symbol because total is too large") 92 | 93 | # Update range 94 | newlow = low + symlow * range // total 95 | newhigh = low + symhigh * range // total - 1 96 | self.low = newlow 97 | self.high = newhigh 98 | 99 | # While low and high have the same top bit value, shift them out 100 | while ((self.low ^ self.high) & self.half_range) == 0: 101 | self.shift() 102 | self.low = ((self.low << 1) & self.state_mask) 103 | self.high = ((self.high << 1) & self.state_mask) | 1 104 | # Now low's top bit must be 0 and high's top bit must be 1 105 | 106 | # While low's top two bits are 01 and high's are 10, delete the second highest bit of both 107 | while (self.low & ~self.high & self.quarter_range) != 0: 108 | self.underflow() 109 | self.low = (self.low << 1) ^ self.half_range 110 | self.high = ((self.high ^ self.half_range) << 1) | self.half_range | 1 111 | 112 | 113 | # Called to handle the situation when the top bit of 'low' and 'high' are equal. 114 | def shift(self): 115 | raise NotImplementedError() 116 | 117 | 118 | # Called to handle the situation when low=01(...) and high=10(...). 
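    # In the encoder this defers the ambiguous bit by incrementing num_underflow; in the
    # decoder it splices the next input bit into the buffered code word. See the concrete
    # overrides in ArithmeticEncoder and ArithmeticDecoder below.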
119 | def underflow(self): 120 | raise NotImplementedError() 121 | 122 | 123 | # Encodes symbols and writes to an arithmetic-coded bit stream. 124 | class ArithmeticEncoder(ArithmeticCoderBase): 125 | 126 | # Constructs an arithmetic coding encoder based on the given bit output stream. 127 | def __init__(self, numbits, bitout): 128 | super(ArithmeticEncoder, self).__init__(numbits) 129 | # The underlying bit output stream. 130 | self.output = bitout 131 | # Number of saved underflow bits. This value can grow without bound. 132 | self.num_underflow = 0 133 | 134 | 135 | # Encodes the given symbol based on the given frequency table. 136 | # This updates this arithmetic coder's state and may write out some bits. 137 | def write(self, freqs, symbol): 138 | self.update(freqs, symbol) 139 | 140 | 141 | # Terminates the arithmetic coding by flushing any buffered bits, so that the output can be decoded properly. 142 | # It is important that this method must be called at the end of the each encoding process. 143 | # Note that this method merely writes data to the underlying output stream but does not close it. 144 | def finish(self): 145 | self.output.write(1) 146 | 147 | 148 | def shift(self): 149 | bit = self.low >> (self.num_state_bits - 1) 150 | self.output.write(bit) 151 | 152 | # Write out the saved underflow bits 153 | for _ in range(self.num_underflow): 154 | self.output.write(bit ^ 1) 155 | self.num_underflow = 0 156 | 157 | 158 | def underflow(self): 159 | self.num_underflow += 1 160 | 161 | 162 | # Reads from an arithmetic-coded bit stream and decodes symbols. 163 | class ArithmeticDecoder(ArithmeticCoderBase): 164 | 165 | # Constructs an arithmetic coding decoder based on the 166 | # given bit input stream, and fills the code bits. 167 | def __init__(self, numbits, bitin): 168 | super(ArithmeticDecoder, self).__init__(numbits) 169 | # The underlying bit input stream. 170 | self.input = bitin 171 | # The current raw code bits being buffered, which is always in the range [low, high]. 172 | self.code = 0 173 | for _ in range(self.num_state_bits): 174 | self.code = self.code << 1 | self.read_code_bit() 175 | 176 | 177 | # Decodes the next symbol based on the given frequency table and returns it. 178 | # Also updates this arithmetic coder's state and may read in some bits. 179 | def read(self, freqs): 180 | #if not isinstance(freqs, CheckedFrequencyTable): 181 | # freqs = CheckedFrequencyTable(freqs) 182 | 183 | # Translate from coding range scale to frequency table scale 184 | total = int(freqs[-1]) 185 | #total = freqs.get_total() 186 | #if total > self.maximum_total: 187 | # raise ValueError("Cannot decode symbol because total is too large") 188 | range = self.high - self.low + 1 189 | offset = self.code - self.low 190 | value = ((offset + 1) * total - 1) // range 191 | #assert value * range // total <= offset 192 | #assert 0 <= value < total 193 | 194 | # A kind of binary search. Find highest symbol such that freqs.get_low(symbol) <= value. 
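        # Loop invariant: low(start) <= value < low(end), where low(i) = freqs[i-1] (0 when
        # i == 0) is the cumulative count below symbol i; the half-open window [start, end)
        # halves on each pass until exactly one candidate symbol remains.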
195 | start = 0 196 | end = len(freqs) 197 | #end = freqs.get_symbol_limit() 198 | while end - start > 1: 199 | middle = (start + end) >> 1 200 | low = int(freqs[middle-1]) if middle > 0 else 0 201 | #if freqs.get_low(middle) > value: 202 | if low > value: 203 | end = middle 204 | else: 205 | start = middle 206 | #assert start + 1 == end 207 | 208 | symbol = start 209 | #assert freqs.get_low(symbol) * range // total <= offset < freqs.get_high(symbol) * range // total 210 | self.update(freqs, symbol) 211 | #if not (self.low <= self.code <= self.high): 212 | # raise AssertionError("Code out of range") 213 | return symbol 214 | 215 | 216 | def shift(self): 217 | self.code = ((self.code << 1) & self.state_mask) | self.read_code_bit() 218 | 219 | 220 | def underflow(self): 221 | self.code = (self.code & self.half_range) | ((self.code << 1) & (self.state_mask >> 1)) | self.read_code_bit() 222 | 223 | 224 | # Returns the next bit (0 or 1) from the input stream. The end 225 | # of stream is treated as an infinite number of trailing zeros. 226 | def read_code_bit(self): 227 | temp = self.input.read() 228 | if temp == -1: 229 | temp = 0 230 | return temp 231 | 232 | 233 | # ---- Bit-oriented I/O streams ---- 234 | 235 | # A stream of bits that can be read. Because they come from an underlying byte stream, 236 | # the total number of bits is always a multiple of 8. The bits are read in big endian. 237 | class BitInputStream(object): 238 | 239 | # Constructs a bit input stream based on the given byte input stream. 240 | def __init__(self, inp): 241 | # The underlying byte stream to read from 242 | self.input = inp 243 | # Either in the range [0x00, 0xFF] if bits are available, or -1 if end of stream is reached 244 | self.currentbyte = 0 245 | # Number of remaining bits in the current byte, always between 0 and 7 (inclusive) 246 | self.numbitsremaining = 0 247 | 248 | 249 | # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or -1 if 250 | # the end of stream is reached. The end of stream always occurs on a byte boundary. 251 | def read(self): 252 | if self.currentbyte == -1: 253 | return -1 254 | if self.numbitsremaining == 0: 255 | temp = self.input.read(1) 256 | if len(temp) == 0: 257 | self.currentbyte = -1 258 | return -1 259 | self.currentbyte = temp[0] if python3 else ord(temp) 260 | self.numbitsremaining = 8 261 | assert self.numbitsremaining > 0 262 | self.numbitsremaining -= 1 263 | return (self.currentbyte >> self.numbitsremaining) & 1 264 | 265 | 266 | # Reads a bit from this stream. Returns 0 or 1 if a bit is available, or raises an EOFError 267 | # if the end of stream is reached. The end of stream always occurs on a byte boundary. 268 | def read_no_eof(self): 269 | result = self.read() 270 | if result != -1: 271 | return result 272 | else: 273 | raise EOFError() 274 | 275 | 276 | # Closes this stream and the underlying input stream. 277 | def close(self): 278 | self.input.close() 279 | self.currentbyte = -1 280 | self.numbitsremaining = 0 281 | 282 | 283 | # A stream where bits can be written to. Because they are written to an underlying 284 | # byte stream, the end of the stream is padded with 0's up to a multiple of 8 bits. 285 | # The bits are written in big endian. 286 | class BitOutputStream(object): 287 | 288 | # Constructs a bit output stream based on the given byte output stream. 
289 | def __init__(self, out): 290 | self.output = out # The underlying byte stream to write to 291 | self.currentbyte = 0 # The accumulated bits for the current byte, always in the range [0x00, 0xFF] 292 | self.numbitsfilled = 0 # Number of accumulated bits in the current byte, always between 0 and 7 (inclusive) 293 | 294 | 295 | # Writes a bit to the stream. The given bit must be 0 or 1. 296 | def write(self, b): 297 | if b not in (0, 1): 298 | raise ValueError("Argument must be 0 or 1") 299 | self.currentbyte = (self.currentbyte << 1) | b 300 | self.numbitsfilled += 1 301 | if self.numbitsfilled == 8: 302 | towrite = bytes((self.currentbyte,)) if python3 else chr(self.currentbyte) 303 | self.output.write(towrite) 304 | self.currentbyte = 0 305 | self.numbitsfilled = 0 306 | 307 | 308 | # Closes this stream and the underlying output stream. If called when this 309 | # bit stream is not at a byte boundary, then the minimum number of "0" bits 310 | # (between 0 and 7 of them) are written as padding to reach the next byte boundary. 311 | def close(self): 312 | while self.numbitsfilled != 0: 313 | self.write(0) 314 | self.output.close() 315 | 316 | -------------------------------------------------------------------------------- /util/decode.py: -------------------------------------------------------------------------------- 1 | from arithmeticcoding import * 2 | 3 | 4 | def main(): 5 | string_to_be_coding = "ABCDEdkkjkljkljkkjkkkklmkmmmmmmmmmmmmaaaasldkkjlfjklajflksjlkfjskljklfjskljfdlskjfklsjlfjlsaacc" 6 | string_dict = dict() 7 | for symbol in string_to_be_coding: 8 | if symbol in string_dict: 9 | string_dict[symbol]["count"] += 1 10 | else: 11 | string_dict[symbol] = {"count": 1} 12 | 13 | char_table = [] 14 | freq_table = [] 15 | for key_id, key in enumerate(string_dict): 16 | char_table.append(key) 17 | freq_table.append(string_dict[key]["count"]) 18 | string_dict[key]["idx"] = key_id 19 | 20 | input_file = "./compress_code.bin" 21 | 22 | # build the input stream 23 | bitin = BitInputStream(open(input_file, mode='rb')) 24 | 25 | # build arithmetic decoder 26 | dec = ArithmeticDecoder(bitin) 27 | 28 | output_string = "" 29 | for i in range(len(string_to_be_coding)): 30 | # build frequency table for coding 31 | # freq_table[-1] = freq_table[-1] + 1 32 | freq = SimpleFrequencyTable(freq_table) 33 | 34 | symbol = dec.read(freq) 35 | output_string += char_table[symbol] 36 | 37 | print(output_string) 38 | print(string_to_be_coding) 39 | print(output_string == string_to_be_coding) 40 | 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /util/encode.py: -------------------------------------------------------------------------------- 1 | from arithmeticcoding import * 2 | import base64 3 | 4 | 5 | def main(): 6 | string_to_be_coding = "ABCDEdkkjkljkljkkjkkkklmkmmmmmmmmmmmmaaaasldkkjlfjklajflksjlkfjskljklfjskljfdlskjfklsjlfjlsaacc" 7 | string_dict = dict() 8 | for symbol in string_to_be_coding: 9 | if symbol in string_dict: 10 | string_dict[symbol]["count"] += 1 11 | else: 12 | string_dict[symbol] = {"count": 1} 13 | 14 | char_table = [] 15 | freq_table = [] 16 | for key_id, key in enumerate(string_dict): 17 | char_table.append(key) 18 | freq_table.append(string_dict[key]["count"]) 19 | string_dict[key]["idx"] = key_id 20 | 21 | output_file = "./compress_code.bin" 22 | 23 | # build the output stream 24 | bitout = BitOutputStream(open(output_file, "wb+")) 25 | 26 | # build arithmetic encoder 27 | enc = 
ArithmeticEncoder(bitout) 28 | 29 | print("".join(char_table)) 30 | print(freq_table) 31 | freq = SimpleFrequencyTable(freq_table) 32 | 33 | # encoding string 34 | for symbol in string_to_be_coding: 35 | symbol_id = string_dict[symbol]["idx"] 36 | 37 | # build frequency table for coding 38 | # freq_table[-1] = freq_table[-1] + 1 39 | 40 | # import ipdb; ipdb.set_trace() 41 | 42 | enc.write(freq, symbol_id) 43 | 44 | enc.finish() 45 | bitout.close() 46 | 47 | 48 | if __name__ == "__main__": 49 | main() -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import functools 3 | import logging 4 | import os 5 | import sys 6 | from termcolor import colored 7 | 8 | 9 | class _ColorfulFormatter(logging.Formatter): 10 | def __init__(self, *args, **kwargs): 11 | self._root_name = kwargs.pop("root_name") + "." 12 | self._abbrev_name = kwargs.pop("abbrev_name", "") 13 | if len(self._abbrev_name): 14 | self._abbrev_name = self._abbrev_name + "." 15 | super(_ColorfulFormatter, self).__init__(*args, **kwargs) 16 | 17 | def formatMessage(self, record): 18 | record.name = record.name.replace(self._root_name, self._abbrev_name) 19 | log = super(_ColorfulFormatter, self).formatMessage(record) 20 | if record.levelno == logging.WARNING: 21 | prefix = colored("WARNING", "red", attrs=["blink"]) 22 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: 23 | prefix = colored("ERROR", "red", attrs=["blink", "underline"]) 24 | else: 25 | return log 26 | return prefix + " " + log 27 | 28 | 29 | # so that calling setup_logger multiple times won't add many handlers 30 | @functools.lru_cache() 31 | def setup_logger( 32 | output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None 33 | ): 34 | """ 35 | Initialize the detectron2 logger and set its verbosity level to "INFO". 36 | 37 | Args: 38 | output (str): a file name or a directory to save log. If None, will not save log file. 39 | If ends with ".txt" or ".log", assumed to be a file name. 40 | Otherwise, logs will be saved to `output/log.txt`. 
41 | name (str): the root module name of this logger 42 | 43 | Returns: 44 | logging.Logger: a logger 45 | """ 46 | logger = logging.getLogger(name) 47 | logger.setLevel(logging.DEBUG) 48 | logger.propagate = False 49 | 50 | if abbrev_name is None: 51 | abbrev_name = name 52 | 53 | plain_formatter = logging.Formatter( 54 | '[%(asctime)s.%(msecs)03d]: %(message)s', 55 | datefmt='%m/%d %H:%M:%S' 56 | ) 57 | # stdout logging: master only 58 | if distributed_rank == 0: 59 | ch = logging.StreamHandler(stream=sys.stdout) 60 | ch.setLevel(logging.DEBUG) 61 | if color: 62 | formatter = _ColorfulFormatter( 63 | colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s", 64 | datefmt="%m/%d %H:%M:%S", 65 | root_name=name, 66 | abbrev_name=str(abbrev_name), 67 | ) 68 | else: 69 | formatter = plain_formatter 70 | ch.setFormatter(formatter) 71 | logger.addHandler(ch) 72 | 73 | # file logging: all workers 74 | if output is not None: 75 | if output.endswith(".txt") or output.endswith(".log"): 76 | filename = output 77 | else: 78 | filename = os.path.join(output, "log.txt") 79 | if distributed_rank > 0: 80 | filename = filename + f".rank{distributed_rank}" 81 | 82 | os.makedirs(os.path.dirname(filename), exist_ok=True) 83 | 84 | fh = logging.StreamHandler(_cached_log_stream(filename)) 85 | fh.setLevel(logging.DEBUG) 86 | fh.setFormatter(plain_formatter) 87 | logger.addHandler(fh) 88 | 89 | return logger 90 | 91 | 92 | # cache the opened file object, so that different calls to `setup_logger` 93 | # with the same file name can safely write to the same file. 94 | @functools.lru_cache(maxsize=None) 95 | def _cached_log_stream(filename): 96 | return open(filename, "a") 97 | -------------------------------------------------------------------------------- /util/slconfig.py: -------------------------------------------------------------------------------- 1 | # ========================================================== 2 | # Modified from mmcv 3 | # ========================================================== 4 | import os, sys 5 | import os.path as osp 6 | import ast 7 | import tempfile 8 | import shutil 9 | from importlib import import_module 10 | 11 | from argparse import Action 12 | 13 | from addict import Dict 14 | from yapf.yapflib.yapf_api import FormatCode 15 | 16 | BASE_KEY = '_base_' 17 | DELETE_KEY = '_delete_' 18 | RESERVED_KEYS = ['filename', 'text', 'pretty_text', 'get', 'dump', 'merge_from_dict'] 19 | 20 | 21 | def check_file_exist(filename, msg_tmpl='file "{}" does not exist'): 22 | if not osp.isfile(filename): 23 | raise FileNotFoundError(msg_tmpl.format(filename)) 24 | 25 | class ConfigDict(Dict): 26 | 27 | def __missing__(self, name): 28 | raise KeyError(name) 29 | 30 | def __getattr__(self, name): 31 | try: 32 | value = super(ConfigDict, self).__getattr__(name) 33 | except KeyError: 34 | ex = AttributeError(f"'{self.__class__.__name__}' object has no " 35 | f"attribute '{name}'") 36 | except Exception as e: 37 | ex = e 38 | else: 39 | return value 40 | raise ex 41 | 42 | 43 | class SLConfig(object): 44 | """ 45 | config files. 46 | only support .py file as config now. 
/util/slconfig.py:
--------------------------------------------------------------------------------
# ==========================================================
# Modified from mmcv
# ==========================================================
import os, sys
import os.path as osp
import ast
import tempfile
import shutil
from importlib import import_module

from argparse import Action

from addict import Dict
from yapf.yapflib.yapf_api import FormatCode

BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
RESERVED_KEYS = ['filename', 'text', 'pretty_text', 'get', 'dump', 'merge_from_dict']


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


class ConfigDict(Dict):

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            ex = AttributeError(f"'{self.__class__.__name__}' object has no "
                                f"attribute '{name}'")
        except Exception as e:
            ex = e
        else:
            return value
        raise ex


class SLConfig(object):
    """
    A facility for config files, supporting .py, .yml/.yaml and .json formats.

    ref: mmcv.utils.config

    Example:
        >>> cfg = SLConfig(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = SLConfig.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config (path: /home/kchen/projects/mmcv/tests/data/config/a.py): "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def _validate_py_syntax(filename):
        with open(filename) as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}')

    @staticmethod
    def _file2dict(filename):
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.lower().endswith('.py'):
            with tempfile.TemporaryDirectory() as temp_config_dir:
                temp_config_file = tempfile.NamedTemporaryFile(
                    dir=temp_config_dir, suffix='.py')
                temp_config_name = osp.basename(temp_config_file.name)
                shutil.copyfile(filename,
                                osp.join(temp_config_dir, temp_config_name))
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                SLConfig._validate_py_syntax(filename)
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
                # delete the imported module
                del sys.modules[temp_module_name]
                # close the temp file
                temp_config_file.close()
        elif filename.lower().endswith(('.yml', '.yaml', '.json')):
            from .slio import slload
            cfg_dict = slload(filename)
        else:
            raise IOError('Only py/yml/yaml/json types are supported now!')

        cfg_text = filename + '\n'
        with open(filename, 'r') as f:
            cfg_text += f.read()

        # parse the base file(s) referenced by `_base_`
        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                # TODO: allow duplicate keys while warning the user
                base_cfg_dict.update(c)

            base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text
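    # Illustration of the `_base_` mechanism handled above, as a hypothetical
    # pair of config files (not taken from config/l3tc/):
    #
    #   # base.py
    #   model = dict(type='rwkv', n_layer=6)
    #   lr = 1e-3
    #
    #   # child.py
    #   _base_ = ['base.py']
    #   model = dict(n_layer=12)   # merged into the base dict -> n_layer=12
    #
    # SLConfig.fromfile('child.py') then yields
    # dict(model=dict(type='rwkv', n_layer=12), lr=1e-3).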
    @staticmethod
    def _merge_a_into_b(a, b):
        """Merge dict `a` into dict `b` (non-inplace).

        Values in `a` overwrite those in `b`. `b` is copied first to avoid
        in-place modification.

        Args:
            a (dict): the dict to merge from (higher priority).
            b (dict): the dict to merge into (lower priority).

        Returns:
            dict: the merged dict.
        """

        if not isinstance(a, dict):
            return a

        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
                if not isinstance(b[k], dict) and not isinstance(b[k], list):
                    raise TypeError(
                        f'{k}={v} in child config cannot inherit from base '
                        f'because {k} is a dict in the child config but is of '
                        f'type {type(b[k])} in base config. You may set '
                        f'`{DELETE_KEY}=True` to ignore the base config')
                b[k] = SLConfig._merge_a_into_b(v, b[k])
            elif isinstance(b, list):
                try:
                    _ = int(k)
                except (ValueError, TypeError):
                    raise TypeError(
                        f'b is a list, '
                        f'index {k} should be an int, but got {type(k)}'
                    )
                b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
            else:
                b[k] = v

        return b
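    # A hypothetical illustration of the `_delete_` switch handled above:
    # with base b = dict(opt=dict(type='SGD', lr=0.1)) and child
    # a = dict(opt=dict(_delete_=True, type='Adam')), the merged result is
    # dict(opt=dict(type='Adam')) -- the base `opt` dict is discarded
    # instead of being merged into.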
    @staticmethod
    def fromfile(filename):
        cfg_dict, cfg_text = SLConfig._file2dict(filename)
        return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        super(SLConfig, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(SLConfig, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super(SLConfig, self).__setattr__('_text', text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):

        indent = 4

        def _indent(s_, num_spaces):
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f'{k_str}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dicts
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f'{k_str}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= \
                    (not str(key_name).isidentifier())
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f'{k_str}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self, file=None):
        if file is None:
            return self.pretty_text
        else:
            with open(file, 'w') as f:
                f.write(self.pretty_text)
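    # Usage sketch for `dump`/`pretty_text` (illustrative output path only):
    #   cfg = SLConfig.fromfile('config/l3tc/l3tc_200k.py')
    #   print(cfg.pretty_text)        # yapf-formatted config source
    #   cfg.dump('work_dirs/cfg.py')  # write the formatted config to disk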
376 | """ 377 | option_cfg_dict = {} 378 | for full_key, v in options.items(): 379 | d = option_cfg_dict 380 | key_list = full_key.split('.') 381 | for subkey in key_list[:-1]: 382 | d.setdefault(subkey, ConfigDict()) 383 | d = d[subkey] 384 | subkey = key_list[-1] 385 | d[subkey] = v 386 | 387 | cfg_dict = super(SLConfig, self).__getattribute__('_cfg_dict') 388 | super(SLConfig, self).__setattr__( 389 | '_cfg_dict', SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict)) 390 | 391 | # for multiprocess 392 | def __setstate__(self, state): 393 | self.__init__(state) 394 | 395 | 396 | def copy(self): 397 | return SLConfig(self._cfg_dict.copy()) 398 | 399 | def deepcopy(self): 400 | return SLConfig(self._cfg_dict.deepcopy()) 401 | 402 | 403 | class DictAction(Action): 404 | """ 405 | argparse action to split an argument into KEY=VALUE form 406 | on the first = and append to a dictionary. List options should 407 | be passed as comma separated values, i.e KEY=V1,V2,V3 408 | """ 409 | 410 | @staticmethod 411 | def _parse_int_float_bool(val): 412 | try: 413 | return int(val) 414 | except ValueError: 415 | pass 416 | try: 417 | return float(val) 418 | except ValueError: 419 | pass 420 | if val.lower() in ['true', 'false']: 421 | return True if val.lower() == 'true' else False 422 | if val.lower() in ['none', 'null']: 423 | return None 424 | return val 425 | 426 | def __call__(self, parser, namespace, values, option_string=None): 427 | options = {} 428 | for kv in values: 429 | key, val = kv.split('=', maxsplit=1) 430 | val = [self._parse_int_float_bool(v) for v in val.split(',')] 431 | if len(val) == 1: 432 | val = val[0] 433 | options[key] = val 434 | setattr(namespace, self.dest, options) 435 | 436 | -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | class BestMetricSingle(): 2 | def __init__(self, init_res=0.0, better='large') -> None: 3 | self.init_res = init_res 4 | self.best_res = init_res 5 | self.best_ep = -1 6 | 7 | self.better = better 8 | assert better in ['large', 'small'] 9 | 10 | def isbetter(self, new_res, old_res): 11 | if self.better == 'large': 12 | return new_res > old_res 13 | if self.better == 'small': 14 | return new_res < old_res 15 | 16 | def update(self, new_res, ep): 17 | if self.isbetter(new_res, self.best_res): 18 | self.best_res = new_res 19 | self.best_ep = ep 20 | return True 21 | return False 22 | 23 | def __str__(self) -> str: 24 | return "best_res: {}\t best_ep: {}".format(self.best_res, self.best_ep) 25 | 26 | def __repr__(self) -> str: 27 | return self.__str__() 28 | 29 | def summary(self) -> dict: 30 | return { 31 | 'best_res': self.best_res, 32 | 'best_ep': self.best_ep, 33 | } 34 | 35 | 36 | 37 | class BestMetricHolder(): 38 | def __init__(self, init_res=0.0, better='large', use_ema=False) -> None: 39 | self.best_all = BestMetricSingle(init_res, better) 40 | self.use_ema = use_ema 41 | if use_ema: 42 | self.best_ema = BestMetricSingle(init_res, better) 43 | self.best_regular = BestMetricSingle(init_res, better) 44 | 45 | 46 | def update(self, new_res, epoch, is_ema=False): 47 | """ 48 | return if the results is the best. 
49 | """ 50 | if not self.use_ema: 51 | return self.best_all.update(new_res, epoch) 52 | else: 53 | if is_ema: 54 | self.best_ema.update(new_res, epoch) 55 | return self.best_all.update(new_res, epoch) 56 | else: 57 | self.best_regular.update(new_res, epoch) 58 | return self.best_all.update(new_res, epoch) 59 | 60 | def summary(self): 61 | if not self.use_ema: 62 | return self.best_all.summary() 63 | 64 | res = {} 65 | res.update({f'all_{k}':v for k,v in self.best_all.summary().items()}) 66 | res.update({f'regular_{k}':v for k,v in self.best_regular.summary().items()}) 67 | res.update({f'ema_{k}':v for k,v in self.best_ema.summary().items()}) 68 | return res 69 | 70 | def __repr__(self) -> str: 71 | return json.dumps(self.summary(), indent=2) 72 | 73 | def __str__(self) -> str: 74 | return self.__repr__() 75 | 76 | 77 | 78 | class SmoothedValue(object): 79 | """Track a series of values and provide access to smoothed values over a 80 | window or the global series average. 81 | """ 82 | 83 | def __init__(self, window_size=20, fmt=None): 84 | if fmt is None: 85 | fmt = "{value:.4f} ({global_avg:.4f})" 86 | self.total = 0.0 87 | self.count = 0 88 | self.fmt = fmt 89 | self.value_list = [] 90 | self.window_size = window_size 91 | 92 | def update(self, value, n=1): 93 | if len(self.value_list) >= self.window_size: 94 | self.value_list = self.value_list[-self.window_size+1:] 95 | self.value_list.append(value) 96 | else: 97 | self.value_list.append(value) 98 | self.count += n 99 | self.total += value * n 100 | 101 | @property 102 | def avg(self): 103 | if not len(self.value_list): 104 | return 0 105 | return np.array(self.value_list).mean() 106 | 107 | @property 108 | def global_avg(self): 109 | return self.total / self.count 110 | 111 | @property 112 | def max(self): 113 | if not len(self.value_list): 114 | return 0 115 | return np.array(self.value_list).max() 116 | 117 | @property 118 | def value(self): 119 | return self.value_list[-1] 120 | 121 | def __str__(self): 122 | return self.fmt.format( 123 | avg=self.avg, 124 | global_avg=self.global_avg, 125 | max=self.max, 126 | value=self.value) 127 | --------------------------------------------------------------------------------