├── .gitignore ├── README.md ├── README_CN.md ├── cuda ├── wkv5_cuda.cu ├── wkv5_op.cpp ├── wkv6_cuda.cu ├── wkv6_op.cpp ├── wkv6infctx_cuda.cu ├── wkv6infctx_op.cpp ├── wkv6state_cuda.cu └── wkv6state_op.cpp ├── demo ├── demo-predict.sh └── demo-state-tuning.sh ├── fla ├── __init__.py ├── layers │ ├── __init__.py │ ├── abc.py │ ├── based.py │ ├── delta_net.py │ ├── gated_abc.py │ ├── gla.py │ ├── hgrn.py │ ├── hgrn2.py │ ├── linear_attn.py │ ├── multiscale_retention.py │ ├── rebased.py │ ├── rwkv6.py │ └── simple_gla.py ├── models │ ├── __init__.py │ ├── abc │ │ ├── __init__.py │ │ ├── configuration_abc.py │ │ └── modeling_abc.py │ ├── delta_net │ │ ├── __init__.py │ │ ├── configuration_delta_net.py │ │ └── modeling_delta_net.py │ ├── gla │ │ ├── __init__.py │ │ ├── configuration_gla.py │ │ └── modeling_gla.py │ ├── hgrn │ │ ├── __init__.py │ │ ├── configuration_hgrn.py │ │ └── modeling_hgrn.py │ ├── hgrn2 │ │ ├── __init__.py │ │ ├── configuration_hgrn2.py │ │ └── modeling_hgrn2.py │ ├── linear_attn │ │ ├── __init__.py │ │ ├── configuration_linear_attn.py │ │ └── modeling_linear_attn.py │ ├── mamba │ │ ├── __init__.py │ │ ├── configuration_mamba.py │ │ └── modeling_mamba.py │ ├── retnet │ │ ├── __init__.py │ │ ├── configuration_retnet.py │ │ └── modeling_retnet.py │ ├── rwkv6 │ │ ├── __init__.py │ │ ├── configuration_rwkv6.py │ │ └── modeling_rwkv6.py │ ├── transformer │ │ ├── __init__.py │ │ ├── configuration_transformer.py │ │ └── modeling_transformer.py │ └── utils.py ├── modules │ ├── __init__.py │ ├── activations.py │ ├── convolution.py │ ├── feature_map.py │ ├── fused_cross_entropy.py │ ├── fused_norm_gate.py │ ├── l2norm.py │ ├── layernorm.py │ └── rotary.py ├── ops │ ├── __init__.py │ ├── abc │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_gate.py │ │ ├── naive.py │ │ └── recurrent_fuse.py │ ├── based │ │ ├── __init__.py │ │ ├── chunk_fuse.py │ │ ├── naive.py │ │ └── parallel.py │ ├── delta_rule │ │ ├── README.md │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_fuse.py │ │ ├── naive.py │ │ ├── recurrent_fuse.py │ │ ├── utils.py │ │ └── wy_fast.py │ ├── gla │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_fuse.py │ │ ├── chunk_util.py │ │ ├── naive.py │ │ └── recurrent_fuse.py │ ├── hgrn │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── naive.py │ │ └── recurrent_fuse.py │ ├── linear_attn │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_fuse.py │ │ ├── naive.py │ │ └── recurrent_fuse.py │ ├── rebased │ │ ├── __init__.py │ │ ├── naive.py │ │ └── parallel.py │ ├── retention │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_fuse.py │ │ ├── naive.py │ │ ├── parallel.py │ │ └── recurrent_fuse.py │ ├── rotary.py │ ├── rwkv4 │ │ ├── __init__.py │ │ └── recurrent_fuse.py │ ├── rwkv6 │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_naive.py │ │ ├── recurrent_fuse.py │ │ └── recurrent_naive.py │ ├── simple_gla │ │ ├── README.md │ │ ├── __init__.py │ │ ├── chunk.py │ │ └── naive.py │ └── utils.py └── utils.py ├── merge ├── merge.py ├── merge_lora.py ├── merge_pissa.py └── merge_state.py ├── output └── model output dir.txt ├── requirements.txt ├── src ├── __init__.py ├── asr.py ├── binidx.py ├── dataset2.py ├── infctx_module.py ├── model.py ├── rwkvLinear.py ├── speech_encoder.py ├── trainer.py └── utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | wandb/ 3 | src/__pycache__/ -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | ## Speech tasks with frozen RWKV language models 2 | 3 | - [中文说明](README_CN.md) 4 | - [English](README.md) 5 | 6 | This repo is an exploratory experiment to enable frozen pretrained RWKV language models to accept speech input. Generally, LLMs trained on text data are not directly applicable to speech recognition tasks, and there are many solutions (such as adapters + pretrained audio encoders, or neural audio codecs) to bridge the gap between text and speech. We followed the idea of [SLAM_ASR](https://arxiv.org/abs/2402.08846), used the RWKV language model as the LLM, and, instead of writing a prompt template, directly finetuned the initial state of the RWKV model. We achieved 4.6% WER on the Librispeech 960h Clean test set (6.9% on the Other test set) with a 3B RWKV model. 7 | 8 | The code in this repo is developed on top of [RWKV-PEFT](https://github.com/JL-er/RWKV-PEFT). The current implementation of the speech encoder and adapter is based on [SLAM_ASR](https://arxiv.org/abs/2402.08846#). 9 | 10 | ### Roadmap 11 | 12 | We want to explore compute-efficient and high-performance ways to extend text-based RWKV models into multimodal ones. In the audio and speech modality, these are the tasks we are attempting: 13 | 14 | - [x] ASR in a single language 15 | - [x] ASR in multiple languages 16 | - [x] Speech translation 17 | - [x] Voice-input question answering (like GPT-4o) 18 | - [ ] Other audio tasks 19 | - [ ] Multi-turn conversation 20 | 21 | ### Environment 22 | 23 | The following commands will create a new conda environment and install the required packages: 24 | 25 | ```bash 26 | conda create -n rwkv python=3.10 27 | conda activate rwkv 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ### Training 32 | 33 | 1. Download the RWKV-6-World model files from one of the following links. We used the 3B model in our experiments, i.e. RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth. 34 | 35 | - [Hugging Face](https://huggingface.co/BlinkDL/rwkv-6-world/tree/main) 36 | - [Hf Mirror (CN)](https://hf-mirror.com/BlinkDL/rwkv-6-world/tree/main) 37 | - [Modelscope](https://modelscope.cn/models/Blink_DL/rwkv-6-world/files) 38 | 39 | 2. Open ```demo/demo-state-tuning.sh```. Set ```OP=train``` for training and ```load_model=path/to/your/model/```. Modify ```n_layer``` and ```n_embd``` according to the table below: 40 | 41 | | Model | n_layer | n_embd | 42 | | --------- | ---- | ---- | 43 | | 1.6B | 24 | 2048 | 44 | | 3B | 32 | 2560 | 45 | | 7B | 32 | 4096 | 46 | | 14B | 61 | 4096 | 47 | 48 | Other parameters for training: 49 | | parameter | description | 50 | | --------- | ---- | 51 | | micro_bsz | batch size for each device | 52 | | epoch_steps | number of steps in one epoch; set this to (dataset size / real batch size) | 53 | | device | number of GPUs used for training | 54 | 55 | The default setting trains a 3B RWKV model on the Librispeech 960h dataset, with 4 devices and a batch size of 4 per device (real batch size = 16). 56 | 57 | 3. The script will overwrite the .pth files in ```output/```. Before training, move any .pth model files you still need from this directory to another location. 58 | 4. Run ```sh demo/demo-state-tuning.sh``` to start the training process. 59 | 60 | The training process looks like this: 61 | 62 | - It first loads the provided RWKV model and a speech encoder model from Hugging Face. An adapter and an initial state for the RWKV model are initialized randomly (see the sketch below). 
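The following is a rough sketch of this trainable/frozen split, with hypothetical module and parameter names (the actual wiring lives in `src/model.py` and `src/speech_encoder.py`, and the real state shapes depend on the RWKV checkpoint); it is a sketch, not the repo's implementation:

```python
# Minimal sketch only -- placeholder modules stand in for WavLM and the RWKV-6 LM,
# which are loaded from real checkpoints (and kept frozen) in the actual code.
import torch
import torch.nn as nn


class SpeechRWKVSketch(nn.Module):
    def __init__(self, enc_dim=1024, rwkv_dim=2560, n_layer=32, head_size=64):
        super().__init__()
        self.speech_encoder = nn.Linear(80, enc_dim)   # stands in for WavLM Large (frozen)
        self.rwkv = nn.Identity()                      # stands in for the frozen RWKV-6 LM
        for p in self.speech_encoder.parameters():
            p.requires_grad_(False)
        # trainable parts: the adapter and the per-layer initial state
        self.adapter = nn.Sequential(
            nn.Linear(enc_dim, rwkv_dim), nn.GELU(), nn.Linear(rwkv_dim, rwkv_dim))
        self.init_state = nn.Parameter(
            torch.zeros(n_layer, rwkv_dim // head_size, head_size, head_size))

    def forward(self, audio_features):
        feats = self.speech_encoder(audio_features)    # (B, T, enc_dim), frozen encoder
        embeds = self.adapter(feats)                   # (B, T, rwkv_dim), trained
        # the trained initial state is fed to the RWKV blocks together with the
        # adapted embeddings; the placeholder LM just passes the embeddings through
        return self.rwkv(embeds), self.init_state
```

Only the adapter and the initial state receive gradients; the speech encoder and the RWKV weights stay frozen.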
63 | - The (symbolically) simplified formula for this model is: 64 | 65 | ``` 66 | RWKV( [InitialState], [Adapter](SpeechEncoder(audio))) -> "The weather is good. " 67 | ``` 68 | 69 | Modules and variables in `[ ]` will be trained, the rest is all frozen. 70 | 71 | There are also some codes to enable other PEFT training of the whole model. Note that not all methods are fully adapted to speech modality training as of now, and we are still actively working on this. 72 | 73 | ### Evaluation 74 | 75 | Follow the instruction in Training, but modify ```OP=eval``` in ```demo/demo-state-tuning.sh```. The trained model in ```output/``` will be used to calculate the WER of the model in ```output/``` on the clean test set and the other test set of Librispeech. 76 | 77 | ### Audio File Prediction 78 | 79 | Open ```demo/demo-predict.sh``` and modify ```file_path=path/to/your/audio/file```. Run ```sh demo/demo-predict.sh``` to load trained weights in ```output/``` and predict the content of the input audio file. 80 | 81 | ### Pretrained weights 82 | 83 | Download the pretrained weights from the following link: 84 | 85 | ASR:https://huggingface.co/JerryAGENDD/RWKV-ASR/tree/main/ASR 86 | 87 | SpeechTranslate:https://huggingface.co/JerryAGENDD/RWKV-ASR/tree/main/ST 88 | 89 | SpeechQA:https://huggingface.co/JerryAGENDD/RWKV-ASR/tree/main/SpeechQA 90 | 91 | The pretrained weights contain the necessary parameters for the adapter and the RWKV initial state. These weights are trained using WavLM Large as the speech encoder and RWKV-3B as the language model (script default configuration). Place the weights in the ```output/``` directory for the script to load them. 92 | 93 | ### Speech Chat with RWKV 94 | 95 | A script for real-time speech conversation with RWKV: 96 | 97 | https://github.com/AGENDD/RWKV-SpeechChat 98 | 99 | You can use the trained weights to interact with RWKV in real time. -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | ## 使用预训练的 RWKV 语言模型进行语音识别 2 | 3 | - [中文说明](README_CN.md) 4 | - [English](README.md) 5 | 6 | 本仓库是一个探索性实验,旨在使预训练的 RWKV 语言模型能够接受语音输入。通常,在文本数据上训练的 LLM 不直接适用于语音识别任务,有很多解决方案(例如适配器 + 预训练音频编码器或神经音频编解码器)可以弥合文本和语音之间的差距。我们遵循了 [SLAM_ASR](https://arxiv.org/abs/2402.08846) 的思路,使用 RWKV 语言模型作为 LLM,而不是直接编写提示模板,我们直接微调了 RWKV 模型的初始状态。在 Librispeech 960h Clean 测试集上,我们使用 3B RWKV 模型实现了 4.6% 的 WER(Other 测试集为 6.9%)。 7 | 8 | 本仓库的代码基于 [RWKV-PEFT](https://github.com/JL-er/RWKV-PEFT) 开发。当前的语音编码器和适配器实现基于 [SLAM_ASR](https://arxiv.org/abs/2402.08846#)。 9 | 10 | ### 路线图 11 | 12 | 我们希望探索计算效率高、性能优越的方式将基于文本的 RWKV 扩展到多模态模型。在音频和语音领域,我们正在尝试以下任务: 13 | 14 | - [x] 单语言 ASR 15 | - [x] 多语言 ASR 16 | - [x] 语音翻译 17 | - [x] 语音输入问答(如 GPT-4o) 18 | - [ ] 其他音频任务 19 | - [ ] 多轮对话 20 | 21 | ### 环境 22 | 23 | 以下命令将创建一个新的 conda 环境并安装所需的包: 24 | 25 | ```bash 26 | conda create -n rwkv python=3.10 27 | conda activate rwkv 28 | pip install -r requirements.txt 29 | ``` 30 | 31 | ### 训练 32 | 33 | 1. 从以下链接之一下载 RWKV-6-World 模型文件。我们在实验中使用了 3B 模型,即 RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth。 34 | 35 | - [Hugging Face](https://huggingface.co/BlinkDL/rwkv-6-world/tree/main) 36 | - [Hf Mirror (CN)](https://hf-mirror.com/BlinkDL/rwkv-6-world/tree/main) 37 | - [Modelscope](https://modelscope.cn/models/Blink_DL/rwkv-6-world/files) 38 | 39 | 2. 
打开 ```demo/demo-state-tuning.sh```。将 ```OP=train``` 设置为训练,并将 ```load_model=path/to/your/model/``` 设置为您的模型路径。根据以下表修改 ```n_layer``` 和 ```n_embd```: 40 | 41 | | 模型 | n_layer | n_embd | 42 | | --------- | ---- | ---- | 43 | | 1.6B | 24 | 2048 | 44 | | 3B | 32 | 2560 | 45 | | 7B | 32 | 4096 | 46 | | 14B | 61 | 4096 | 47 | 48 | 其他训练参数: 49 | | 参数 | 描述 | 50 | | --------- | ---- | 51 | | micro_bsz | 每个设备的批量大小 | 52 | | epoch_steps | 每个 epoch 的步骤数。请根据(数据集大小 / 实际批量大小)进行修改 | 53 | | device | 用于训练的 GPU 数量 | 54 | 55 | 默认设置将在 4 个设备上训练 3B rwkv 模型,每个设备的批量大小为 4(实际批量大小 = 16)。 56 | 57 | 3. 该脚本将覆盖 ```output/``` 中的 .pth 文件。确保在训练前将所需的 .pth 模型文件保存到其他目录下! 58 | 4. 运行 ```sh demo/demo-state-tuning.sh``` 以开始训练过程。 59 | 60 | 训练过程如下: 61 | 62 | - 它首先加载RWKV模型和从huggingface下载的语音编码模型。将随机初始化适配器和 RWKV 模型的初始状态。 63 | - 模型的(符号)简化公式如下: 64 | 65 | ``` 66 | RWKV( [InitialState], [Adapter](SpeechEncoder(audio))) -> "The weather is good. 67 | ``` 68 | 69 | 用`[ ]`包围的部分会被训练,其他参数是锁定的。 70 | 71 | 还有一些代码可以启用整个模型的其他 PEFT 训练。目前,我们还没有完全适配于语音模态训练,我们仍在积极开发中。 72 | 73 | ### 评估 74 | 75 | 参考训练的步骤,但设定`demo/demo-state-tuning.sh`里的`OP=eval`。保存在`output/`中的模型将被用于评估,脚本会计算Librispeech 960h Clean和Other测试集的WER。 76 | 77 | 78 | ### 音频文件预测 79 | 80 | 打开```demo/demo-predict.sh```并修改```file_path```为输入音频的路径。运行```sh demo/demo-predict.sh```来从```output/```加载训练权重并预测音频内容。 81 | 82 | ### 预训练权重 83 | 84 | 下载预训练权重,请访问以下链接: 85 | 86 | 语音识别:https://huggingface.co/JerryAGENDD/RWKV-ASR/tree/main/ASR 87 | 88 | 语音翻译:https://huggingface.co/JerryAGENDD/RWKV-ASR/tree/main/ST 89 | 90 | 语音问答:https://huggingface.co/JerryAGENDD/RWKV-ASR/tree/main/SpeechQA 91 | 92 | 预训练权重包含适配器和RWKV初始状态的必要参数。这些权重是使用WavLM Large作为语音编码器和RWKV-3B作为语言模型(脚本默认配置)进行训练的。请将权重放置在```output/```目录中,以便脚本加载它们。 93 | 94 | ### RWKV 语音对话 95 | 96 | 这是一个与 RWKV 进行实时语音对话的脚本: 97 | 98 | https://github.com/AGENDD/RWKV-SpeechChat 99 | 100 | 您可以使用训练后的权重与 RWKV 进行实时语音交互。 -------------------------------------------------------------------------------- /cuda/wkv5_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _w += h*_N_; 15 | _u += h*_N_; 16 | 17 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 18 | float state[_N_] = {0}; 19 | 20 | __syncthreads(); 21 | w[i] = _w[i]; 22 | u[i] = float(_u[i]); 23 | __syncthreads(); 24 | 25 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 26 | { 27 | __syncthreads(); 28 | r[i] = float(_r[t]); 29 | k[i] = float(_k[t]); 30 | __syncthreads(); 31 | 32 | const float v = float(_v[t]); 33 | float y = 0; 34 | 35 | #pragma unroll 36 | for (int j = 0; j < _N_; j+=4) 37 | { 38 | const float4& r_ = (float4&)(r[j]); 39 | const float4& k_ = (float4&)(k[j]); 40 | const float4& w_ = (float4&)(w[j]); 41 | const float4& u_ = (float4&)(u[j]); 42 | float4& s = (float4&)(state[j]); 43 | float4 x; 44 | 45 | x.x = k_.x * v; 46 | x.y = k_.y * v; 47 | x.z = k_.z * v; 48 | x.w = k_.w * v; 49 | 50 | y += r_.x * (u_.x * x.x + s.x); 51 | y += r_.y * (u_.y * x.y + s.y); 52 | y += r_.z * (u_.z * x.z + s.z); 53 | y += r_.w * (u_.w * x.w + s.w); 54 | 55 | s.x = s.x * w_.x + x.x; 56 | s.y = s.y * w_.y + 
x.y; 57 | s.z = s.z * w_.z + x.z; 58 | s.w = s.w * w_.w + x.w; 59 | } 60 | _y[t] = F(y); 61 | } 62 | } 63 | 64 | template 65 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 66 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, 67 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 68 | { 69 | const int b = blockIdx.x / H; 70 | const int h = blockIdx.x % H; 71 | const int i = threadIdx.x; 72 | _w += h*_N_; 73 | _u += h*_N_; 74 | __w += h*_N_; 75 | 76 | __shared__ float w_[_N_], u_[_N_]; 77 | __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_]; 78 | __syncthreads(); 79 | w_[i] = _w[i]; 80 | u_[i] = float(_u[i]); 81 | __syncthreads(); 82 | 83 | const float w = w_[i]; 84 | const float ww = __w[i]; 85 | const float u = u_[i]; 86 | 87 | float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 88 | 89 | float gw = 0, gu = 0; 90 | const int t000 = b*T*C + h*_N_ + i; 91 | const int t111 = (b+1)*T*C + h*_N_ + i; 92 | const int t222 = t111 - 2*C; 93 | 94 | for (int t = t000; t < t111; t += C) 95 | { 96 | __syncthreads(); 97 | v[i] = float(_v[t]); 98 | gy[i] = float(_gy[t]); 99 | __syncthreads(); 100 | 101 | const float k = float(_k[t]); 102 | float gr = 0, gu_ = 0; 103 | 104 | #pragma unroll 105 | for (int j = 0; j < _N_; j++) 106 | { 107 | float& s = state[j]; 108 | float x = k * v[j]; 109 | 110 | gr += (u * x + s) * gy[j]; 111 | gu_ += x * gy[j]; 112 | s = s * w + x; 113 | } 114 | _gr[t] = F(gr); 115 | gu += float(_r[t]) * gu_; 116 | } 117 | _gu[b*C + h*_N_ + i] = F(gu); 118 | 119 | for (int t = t000; t < t222; t += C) 120 | { 121 | __syncthreads(); 122 | v[i] = float(_v[t]); 123 | gy[i] = float(_gy[t + 2*C]); 124 | __syncthreads(); 125 | 126 | const float k = float(_k[t]); 127 | float gw_ = 0; 128 | 129 | #pragma unroll 130 | for (int j = 0; j < _N_; j++) 131 | { 132 | float& s = saaaa[j]; 133 | float& s2 = sbbbb[j]; 134 | float x = k * v[j]; 135 | 136 | float tmp = w * (x + s); 137 | s = tmp; 138 | s2 = tmp + w * s2; 139 | gw_ += s2 * gy[j]; 140 | } 141 | gw += float(_r[t + 2*C]) * gw_; 142 | } 143 | _gw[b*C + h*_N_ + i] = F(ww * gw); 144 | 145 | for (int t = t111 - C; t >= t000; t -= C) 146 | { 147 | __syncthreads(); 148 | v[i] = float(_v[t]); 149 | gy[i] = float(_gy[t]); 150 | __syncthreads(); 151 | 152 | const float rr = float(_r[t]); 153 | float gk = 0; 154 | 155 | #pragma unroll 156 | for (int j = 0; j < _N_; j++) 157 | { 158 | float& s = scccc[j]; 159 | float x = rr * gy[j]; 160 | 161 | gk += (u * x + s) * v[j]; 162 | s = x + s * w; 163 | } 164 | _gk[t] = F(gk); 165 | } 166 | 167 | for (int t = t111 - C; t >= t000; t -= C) 168 | { 169 | __syncthreads(); 170 | r[i] = float(_r[t]); 171 | k[i] = float(_k[t]); 172 | __syncthreads(); 173 | 174 | const float gyy = float(_gy[t]); 175 | float gv = 0; 176 | 177 | #pragma unroll 178 | for (int j = 0; j < _N_; j++) 179 | { 180 | float& s = sdddd[j]; 181 | float x = gyy * r[j]; 182 | 183 | gv += (u_[j] * x + s) * k[j]; 184 | s = x + s * w_[j]; 185 | } 186 | _gv[t] = F(gv); 187 | } 188 | } 189 | 190 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 191 | { 192 | assert(H*_N_ == C); 193 | assert(_N_%4 == 0); 194 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 195 | } 196 | 197 | 
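// cuda_backward (below) launches kernel_backward with one thread block per (batch, head)
// pair and _N_ threads, one per channel, mirroring cuda_forward above. It back-propagates
// through the WKV5 recurrence computed by kernel_forward,
//   y[t] = sum_j r[t][j] * (u[j] * k[t][j] * v[t] + state[j]),  state[j] = state[j] * w[j] + k[t][j] * v[t],
// and produces the gradients gr, gk, gv, gw and gu; the extra `ww` input is the precomputed
// factor that scales the accumulated decay gradient before it is written to gw.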
void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 198 | { 199 | assert(H*_N_ == C); 200 | assert(_N_%4 == 0); 201 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 202 | } 203 | -------------------------------------------------------------------------------- /cuda/wkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv5 forward"); 16 | m.def("backward", &backward, "wkv5 backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv5, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /cuda/wkv6_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _u += h*_N_; 15 | 16 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 17 | float state[_N_] = {0}; 18 | 19 | __syncthreads(); 20 | u[i] = float(_u[i]); 21 | __syncthreads(); 22 | 23 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 24 | { 25 | __syncthreads(); 26 | w[i] = exp(_w[t]); 27 | r[i] = float(_r[t]); 28 | k[i] = float(_k[t]); 29 | __syncthreads(); 30 | 31 | const float v = float(_v[t]); 32 | float y = 0; 33 | 34 | #pragma unroll 35 | for (int j = 0; j < _N_; j+=4) 36 | { 37 | const float4& r_ = (float4&)(r[j]); 38 | const float4& k_ = (float4&)(k[j]); 39 | const float4& w_ = (float4&)(w[j]); 40 | const float4& u_ = (float4&)(u[j]); 41 | float4& s = (float4&)(state[j]); 42 | float4 x; 43 | 44 | x.x = k_.x * v; 45 | x.y = k_.y * v; 46 | x.z = k_.z * v; 47 | x.w = k_.w * v; 48 | 49 | y += r_.x * (u_.x * x.x + s.x); 50 | y += r_.y * (u_.y * x.y + s.y); 51 | y += r_.z * (u_.z * x.z + s.z); 52 | y += r_.w * (u_.w * x.w + s.w); 
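            // state update below: s <- s * w + k*v; unlike WKV5, the decay w = exp(_w[t])
            // is reloaded every token here, making it data-dependent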
53 | 54 | s.x = s.x * w_.x + x.x; 55 | s.y = s.y * w_.y + x.y; 56 | s.z = s.z * w_.z + x.z; 57 | s.w = s.w * w_.w + x.w; 58 | } 59 | _y[t] = F(y); 60 | } 61 | } 62 | 63 | template 64 | __global__ void kernel_backward_111(const int B, const int T, const int C, const int H, 65 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 66 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu) 67 | { 68 | const int b = blockIdx.x / H; 69 | const int h = blockIdx.x % H; 70 | const int i = threadIdx.x; 71 | _u += h*_N_; 72 | 73 | __shared__ float u_[_N_]; 74 | __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; 75 | __syncthreads(); 76 | u_[i] = float(_u[i]); 77 | __syncthreads(); 78 | 79 | const float u = u_[i]; 80 | 81 | float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 82 | 83 | const int t_0 = b*T*C + h*_N_ + i; 84 | const int t_T_1 = t_0 + (T-1)*C; 85 | const int t_T = t_0 + T*C; 86 | 87 | float gu = 0; 88 | for (int t = t_0; t < t_T; t += C) 89 | { 90 | __syncthreads(); 91 | v[i] = float(_v[t]); 92 | gy[i] = float(_gy[t]); 93 | __syncthreads(); 94 | 95 | const float k = float(_k[t]); 96 | const float w = exp(_w[t]); 97 | float gr = 0, gu_ = 0; 98 | 99 | #pragma unroll 100 | for (int j = 0; j < _N_; j++) 101 | { 102 | float& s = state[j]; 103 | float x = k * v[j]; 104 | 105 | gr += (u * x + s) * gy[j]; 106 | gu_ += x * gy[j]; 107 | s = s * w + x; 108 | } 109 | _gr[t] = F(gr); 110 | gu += float(_r[t]) * gu_; 111 | } 112 | _gu[b*C + h*_N_ + i] = F(gu); 113 | 114 | for (int t = t_T_1; t >= t_0; t -= C) 115 | { 116 | __syncthreads(); 117 | v[i] = float(_v[t]); 118 | gy[i] = float(_gy[t]); 119 | __syncthreads(); 120 | 121 | const float rr = float(_r[t]); 122 | const float w = exp(_w[t]); 123 | float gk = 0; 124 | 125 | #pragma unroll 126 | for (int j = 0; j < _N_; j++) 127 | { 128 | float& s = scccc[j]; 129 | float x = rr * gy[j]; 130 | 131 | gk += (u * x + s) * v[j]; 132 | s = x + s * w; 133 | } 134 | _gk[t] = F(gk); 135 | } 136 | 137 | for (int t = t_T_1; t >= t_0; t -= C) 138 | { 139 | __syncthreads(); 140 | r[i] = float(_r[t]); 141 | k[i] = float(_k[t]); 142 | w_[i] = exp(_w[t]); 143 | __syncthreads(); 144 | 145 | const float gyy = float(_gy[t]); 146 | float gv = 0; 147 | 148 | #pragma unroll 149 | for (int j = 0; j < _N_; j++) 150 | { 151 | float& s = sdddd[j]; 152 | float x = gyy * r[j]; 153 | 154 | gv += (u_[j] * x + s) * k[j]; 155 | s = x + s * w_[j]; 156 | } 157 | _gv[t] = F(gv); 158 | } 159 | } 160 | 161 | template 162 | __global__ void kernel_backward_222(const int B, const int T, const int C, const int H, 163 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 164 | F *__restrict__ const _gw) 165 | { 166 | const int b = blockIdx.x / H; 167 | const int h = blockIdx.x % H; 168 | const int i = threadIdx.x; 169 | 170 | __shared__ float v[_N_], gy[_N_]; 171 | float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0}; 172 | 173 | const int t_0 = b*T*C + h*_N_ + i; 174 | const int t_1 = t_0 + C; 175 | const int t_2 = t_0 + 2*C; 176 | const int t_T_1 = t_0 + (T-1)*C; 177 | 178 | for (int t = t_T_1; t > t_1; t -= C) 179 | { 180 | __syncthreads(); 181 | gy[i] = float(_gy[t]); 182 | v[i] = float(_v[t-2*C]); 183 | __syncthreads(); 184 | 185 | const float r = 
float(_r[t]); 186 | const float w = exp(_w[t-C]); 187 | float sum = 0.0f; 188 | 189 | #pragma unroll 190 | for (int j = 0; j < _N_; j++) 191 | { 192 | float& s = saaaa[j]; 193 | float x = r * gy[j]; 194 | s = (s + x) * w; 195 | sum += s * v[j]; 196 | } 197 | sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]); 198 | } 199 | 200 | float sss = sbbbb[0]; 201 | _gw[t_0] = 0; 202 | _gw[t_1] = F(sss * _w[t_1]); 203 | 204 | for (int t = t_2; t < t_T_1; t += C) 205 | { 206 | __syncthreads(); 207 | gy[i] = float(_gy[t]); 208 | v[i] = float(_v[t-2*C]); 209 | __syncthreads(); 210 | 211 | const float w = exp(_w[t-C]); 212 | const float k = float(_k[t-2*C]); 213 | float sum = 0.0f; 214 | 215 | #pragma unroll 216 | for (int j = 0; j < _N_; j++) 217 | { 218 | float& s = scccc[j]; 219 | float x = k * v[j]; 220 | s = (s + x) * w; 221 | sum += s * gy[j]; 222 | } 223 | sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t])); 224 | _gw[t] = F(sss * _w[t]); 225 | } 226 | _gw[t_T_1] = 0; 227 | } 228 | 229 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 230 | { 231 | assert(H*_N_ == C); 232 | assert(_N_%4 == 0); 233 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 234 | } 235 | 236 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 237 | { 238 | assert(H*_N_ == C); 239 | assert(_N_%4 == 0); 240 | kernel_backward_111<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu); 241 | kernel_backward_222<<>>(B, T, C, H, r, k, v, w, u, gy, gw); 242 | } 243 | -------------------------------------------------------------------------------- /cuda/wkv6_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6 forward"); 16 | m.def("backward", &backward, "wkv6 backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /cuda/wkv6infctx_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); 6 | void cuda_backward(int B, int 
T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6state forward"); 16 | m.def("backward", &backward, "wkv6state backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6state, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /cuda/wkv6state_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6state forward"); 16 | m.def("backward", &backward, "wkv6state backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6state, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /demo/demo-predict.sh: -------------------------------------------------------------------------------- 1 | 2 | # 3B 3 | load_model='RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth' 4 | #7B 5 | # load_model='RWKV-x060-World-7B-v2.1-20240507-ctx4096.pth' 6 | 7 | #model output dir 8 | proj_dir='output' 9 | 10 | # 3B 11 | n_layer=32 12 | n_embd=2560 13 | 14 | # 7B 15 | # n_layer=32 16 | # n_embd=4096 17 | 18 | micro_bsz=4 19 | epoch_steps=18089 20 | ctx_len=1024 21 | device=4 22 | 
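# the variables above are forwarded to train.py as the flags of the same name
# (device is passed to --devices); epoch_save below is passed to --epoch_save,
# i.e. how many epochs pass between checkpoints written to $proj_dir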
epoch_save=1 23 | 24 | file_path="path/to/your/audio/file" 25 | OP="predict" 26 | 27 | QUANT='nf4' 28 | export HF_ENDPOINT=https://hf-mirror.com 29 | python train.py --load_model $load_model --devices $device --file_path $path_file\ 30 | --proj_dir $proj_dir \ 31 | --data_type binidx --vocab_size 65536 \ 32 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1000 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ 33 | --n_layer $n_layer --n_embd $n_embd \ 34 | --pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 100 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ 35 | --accelerator gpu --strategy deepspeed_stage_1 --grad_cp 1 --op $OP \ 36 | --precision bf16 \ 37 | --my_testing "x060" \ 38 | --train_type "state" --dataload pad 39 | # --quant $QUANT 40 | -------------------------------------------------------------------------------- /demo/demo-state-tuning.sh: -------------------------------------------------------------------------------- 1 | 2 | # 3B 3 | load_model='RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth' 4 | #7B 5 | # load_model='RWKV-x060-World-7B-v2.1-20240507-ctx4096.pth' 6 | 7 | #model output dir 8 | proj_dir='output' 9 | 10 | # 3B 11 | n_layer=32 12 | n_embd=2560 13 | 14 | # 7B 15 | # n_layer=32 16 | # n_embd=4096 17 | 18 | micro_bsz=4 19 | epoch_steps=18089 20 | ctx_len=1024 21 | device=4 22 | epoch_save=1 23 | 24 | OP="train" 25 | 26 | QUANT='nf4' 27 | 28 | python train.py --load_model $load_model --devices $device \ 29 | --proj_dir $proj_dir \ 30 | --data_type binidx --vocab_size 65536 \ 31 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1000 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ 32 | --n_layer $n_layer --n_embd $n_embd \ 33 | --pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 100 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ 34 | --accelerator gpu --strategy deepspeed_stage_1 --grad_cp 1 --op $OP \ 35 | --precision bf16 \ 36 | --my_testing "x060" \ 37 | --train_type "state" --dataload pad 38 | # --quant $QUANT 39 | -------------------------------------------------------------------------------- /fla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.layers import (ABCAttention, BasedLinearAttention, DeltaNet, 4 | GatedLinearAttention, HGRN2Attention, LinearAttention, 5 | MultiScaleRetention, ReBasedLinearAttention) 6 | from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM, 7 | DeltaNetModel, GLAForCausalLM, GLAModel, 8 | HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM, 9 | HGRNModel, LinearAttentionForCausalLM, 10 | LinearAttentionModel, RetNetForCausalLM, RetNetModel, 11 | RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM, 12 | TransformerModel) 13 | from fla.ops import (chunk_gla, chunk_retention, fused_chunk_based, 14 | fused_chunk_gla, fused_chunk_retention) 15 | 16 | __all__ = [ 17 | 'ABCAttention', 18 | 'BasedLinearAttention', 19 | 'DeltaNet', 20 | 'HGRN2Attention', 21 | 'GatedLinearAttention', 22 | 'LinearAttention', 23 | 'MultiScaleRetention', 24 | 'ReBasedLinearAttention', 25 | 'ABCForCausalLM', 26 | 'ABCModel', 27 | 'DeltaNetForCausalLM', 28 | 'DeltaNetModel', 29 | 'HGRNForCausalLM', 30 | 'HGRNModel', 31 | 'HGRN2ForCausalLM', 32 | 'HGRN2Model', 33 | 'GLAForCausalLM', 34 | 'GLAModel', 35 | 'LinearAttentionForCausalLM', 36 | 'LinearAttentionModel', 37 | 'RetNetForCausalLM', 38 | 'RetNetModel', 39 | 'RWKV6ForCausalLM', 40 | 'RWKV6Model', 41 | 'TransformerForCausalLM', 
42 | 'TransformerModel', 43 | 'chunk_gla', 44 | 'chunk_retention', 45 | 'fused_chunk_based', 46 | 'fused_chunk_gla', 47 | 'fused_chunk_retention' 48 | ] 49 | 50 | __version__ = '0.1' 51 | -------------------------------------------------------------------------------- /fla/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .abc import ABCAttention 4 | from .based import BasedLinearAttention 5 | from .delta_net import DeltaNet 6 | from .gla import GatedLinearAttention 7 | from .hgrn import HGRNAttention 8 | from .hgrn2 import HGRN2Attention 9 | from .linear_attn import LinearAttention 10 | from .multiscale_retention import MultiScaleRetention 11 | from .rebased import ReBasedLinearAttention 12 | from .rwkv6 import RWKV6Attention 13 | 14 | __all__ = [ 15 | 'ABCAttention', 16 | 'BasedLinearAttention', 17 | 'DeltaNet', 18 | 'GatedLinearAttention', 19 | 'HGRNAttention', 20 | 'HGRN2Attention', 21 | 'LinearAttention', 22 | 'MultiScaleRetention', 23 | 'ReBasedLinearAttention', 24 | 'RWKV6Attention' 25 | ] 26 | -------------------------------------------------------------------------------- /fla/layers/abc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | import warnings 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | import torch.nn as nn 10 | from einops import rearrange 11 | from transformers.cache_utils import Cache 12 | 13 | from fla.modules import (FusedRMSNormSwishGate, RMSNorm, RotaryEmbedding, 14 | ShortConvolution) 15 | from fla.modules.activations import swiglu, swish 16 | from fla.modules.convolution import proj_then_conv1d 17 | from fla.ops.abc.chunk import chunk_abc 18 | 19 | 20 | class ABCAttention(nn.Module): 21 | 22 | def __init__( 23 | self, 24 | hidden_size: int = 1024, 25 | expand_k: float = 0.5, 26 | expand_v: float = 1.0, 27 | num_heads: int = 4, 28 | use_short_conv: bool = False, 29 | conv_size: int = 4, 30 | conv_bias: bool = False, 31 | share_conv_kernel: bool = True, 32 | num_slots: Optional[int] = None, 33 | elementwise_affine: Optional[bool] = True, 34 | norm_eps: float = 1e-5, 35 | gate_low_rank_dim: int = 16, 36 | gate_logit_normalizer: int = 16, 37 | use_input_gate: bool = False, 38 | use_output_gate: bool = True, 39 | use_norm: bool = True, 40 | clamp_min: Optional[float] = -32, 41 | clamp_max: Optional[float] = 32, 42 | layer_idx: Optional[int] = None, 43 | **kwargs 44 | ) -> ABCAttention: 45 | super().__init__() 46 | 47 | self.hidden_size = hidden_size 48 | self.expand_k = expand_k 49 | self.expand_v = expand_v 50 | self.num_heads = num_heads 51 | self.key_dim = int(self.hidden_size * self.expand_k) 52 | self.value_dim = int(self.hidden_size * self.expand_v) 53 | self.head_k_dim = self.key_dim // self.num_heads 54 | self.head_v_dim = self.value_dim // self.num_heads 55 | 56 | self.use_short_conv = use_short_conv 57 | self.conv_size = conv_size 58 | self.conv_bias = conv_bias 59 | self.share_conv_kernel = share_conv_kernel 60 | 61 | self.gate_low_rank_dim = gate_low_rank_dim 62 | self.gate_logit_normalizer = gate_logit_normalizer 63 | 64 | self.use_input_gate = use_input_gate 65 | self.use_output_gate = use_output_gate 66 | self.use_norm = use_norm 67 | 68 | if num_slots is None: 69 | num_slots = self.head_k_dim 70 | self.num_slots = num_slots 71 | 72 | self.norm_eps = norm_eps 73 | 74 | self.clamp_min = clamp_min 75 | self.clamp_max = clamp_max 76 
| self.layer_idx = layer_idx 77 | 78 | if layer_idx is None: 79 | warnings.warn( 80 | f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " 81 | "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " 82 | "when creating this class." 83 | ) 84 | 85 | self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) 86 | self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) 87 | self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) 88 | 89 | if use_output_gate: 90 | self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) 91 | self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False) 92 | self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) 93 | 94 | if use_short_conv: 95 | self.conv_size = conv_size 96 | if share_conv_kernel: 97 | self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') 98 | else: 99 | self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') 100 | self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') 101 | self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') 102 | 103 | if self.use_norm: 104 | if self.use_output_gate: 105 | self.g_norm = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps) 106 | else: 107 | self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps) 108 | 109 | if self.use_rope: 110 | self.rotary = RotaryEmbedding(self.head_k_dim) 111 | 112 | self.apply(self._initialize_weights) 113 | 114 | def _initialize_weights(self, module: nn.Module): 115 | if getattr(module, "_is_hf_initialized", False): 116 | return 117 | if isinstance(module, nn.Linear): 118 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 119 | if module.bias is not None: 120 | nn.init.zeros_(module.bias) 121 | module._is_hf_initialized = True 122 | 123 | def forward( 124 | self, 125 | hidden_states: torch.Tensor, 126 | attention_mask: Optional[torch.Tensor] = None, 127 | past_key_values: Optional[Cache] = None, 128 | use_cache: Optional[bool] = False, 129 | output_attentions: Optional[bool] = False, 130 | **kwargs 131 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: 132 | 133 | if self.use_short_conv: 134 | if self.share_conv_kernel: 135 | hidden_states = self.h_conv1d(hidden_states) 136 | q = self.q_proj(hidden_states) 137 | k = self.k_proj(hidden_states) 138 | v = self.v_proj(hidden_states) 139 | else: 140 | q = proj_then_conv1d(hidden_states, self.q_proj.weight, self.q_conv1d.weight, self.q_conv1d.bias) 141 | k = proj_then_conv1d(hidden_states, self.k_proj.weight, self.k_conv1d.weight, self.k_conv1d.bias) 142 | v = proj_then_conv1d(hidden_states, self.v_proj.weight, self.v_conv1d.weight, self.v_conv1d.bias) 143 | else: 144 | q = self.q_proj(hidden_states) 145 | k = self.k_proj(hidden_states) 146 | v = self.v_proj(hidden_states) 147 | 148 | if self.use_input_gate: 149 | q, k, v = map(lambda x: swish(x), (q, k, v)) 150 | 151 | if self.use_rope: 152 | q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads) 153 | k = rearrange(k, '... (h d) -> ... 
h d', h=self.num_heads) 154 | seqlen_offset = 0 155 | if past_key_values is not None: 156 | seqlen_offset = past_key_values.get_seq_length(self.layer_idx) 157 | q, k = self.rotary(q, k, seqlen_offset) 158 | q = rearrange(q, 'b n h d -> b h n d', h=self.num_heads) 159 | k = rearrange(k, 'b n h d -> b h n d', h=self.num_heads) 160 | else: 161 | q = rearrange(q, 'b n (h d) -> b h n d', h=self.num_heads) 162 | k = rearrange(k, 'b n (h d) -> b h n d', h=self.num_heads) 163 | v = rearrange(v, 'b n (h d) -> b h n d', h=self.num_heads) 164 | 165 | # [batch_size, n_heads, seq_len, num_slots] 166 | s = rearrange(self.s_proj(hidden_states), 'b t (h m) -> b h t m', h=self.num_heads) 167 | s = s.clamp_(self.clamp_min, self.clamp_max) 168 | 169 | last_state = past_key_values[self.layer_idx] if use_cache else None 170 | o, last_state = chunk_abc(q, k, v, s, initial_state=last_state, output_final_state=use_cache) 171 | if past_key_values is not None and last_state is not None: 172 | past_key_values.update(last_state, self.layer_idx, q.shape[2]) 173 | 174 | o = rearrange(o, 'b h t d -> b t h d') 175 | if self.use_norm and not self.use_output_gate: 176 | o = self.g_norm(o) 177 | elif self.use_output_gate: 178 | g = rearrange(self.g_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_heads) 179 | o = self.g_norm(o, g) if self.use_norm else swiglu(g, o) 180 | o = rearrange(o, 'b t h d -> b t (h d)') 181 | o = self.o_proj(o) 182 | 183 | return o, None, past_key_values 184 | 185 | def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: 186 | param = next(self.parameters()) 187 | state = tuple() 188 | if self.use_short_conv: 189 | state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) 190 | state += (param.new_zeros(batch_size, self.num_heads, self.head_k_dim, self.num_slots), 191 | param.new_zeros(batch_size, self.num_heads, self.num_slots, self.head_v_dim)) 192 | return state 193 | 194 | def state_size(self, sequence_length: int = 2048): 195 | return self.num_heads * self.key_dim * self.head_v_dim 196 | -------------------------------------------------------------------------------- /fla/layers/based.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Linear attention in Based. 
5 | https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from einops import rearrange 11 | 12 | from fla.modules.feature_map import TaylorFeatureMap 13 | from fla.ops.based import parallel_based 14 | from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn 15 | 16 | 17 | class BasedLinearAttention(nn.Module): 18 | def __init__( 19 | self, 20 | hidden_size: int, 21 | l_max: int = 2048, 22 | feature_dim: int = 16, 23 | num_key_value_heads: int = 12, 24 | num_heads: int = 12, 25 | feature_name: str = "taylor_exp", 26 | eps: float = 1e-12, 27 | causal: bool = True, 28 | mode: str = "parallel", 29 | ): 30 | super().__init__() 31 | self.hidden_size 32 | self.l_max = l_max 33 | self.mode = mode 34 | assert self.mode in ["fused_chunk", "parallel", 'chunk'] 35 | 36 | # linear attention 37 | self.feature_name = feature_name 38 | self.feature_dim = feature_dim 39 | self.num_key_value_heads = num_key_value_heads 40 | self.num_heads = num_heads 41 | self.head_dim = self.hidden_size // self.num_key_value_heads 42 | self.causal = causal 43 | 44 | self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) 45 | self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) 46 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 47 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 48 | self.dropout = nn.Identity() 49 | self.feature_map = TaylorFeatureMap(feature_dim) 50 | self.eps = eps 51 | 52 | self.apply(self._initialize_weights) 53 | 54 | def _initialize_weights(self, module: nn.Module): 55 | if getattr(module, "_is_hf_initialized", False): 56 | return 57 | if isinstance(module, nn.Linear): 58 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 59 | if module.bias is not None: 60 | nn.init.zeros_(module.bias) 61 | module._is_hf_initialized = True 62 | 63 | def forward(self, hidden_states: torch.Tensor, **kwargs): 64 | mode = self.mode 65 | q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) 66 | q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) 67 | if mode == "fused_chunk": 68 | q, k = self.feature_map(q), self.feature_map(k) 69 | o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1) 70 | elif mode == 'chunk': 71 | q, k = self.feature_map(q), self.feature_map(k) 72 | o = chunk_linear_attn(q, k, v, normalize=True, scale=1) 73 | elif mode == 'parallel': 74 | assert q.shape[-1] <= 128 75 | o = parallel_based(q, k, v, True, True) 76 | o = rearrange(o, "b h l d -> b l (h d)") 77 | o = self.o_proj(o) 78 | o = self.dropout(o) 79 | return o 80 | 81 | # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 82 | 83 | def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): 84 | """ 85 | x (torch.Tensor): tensor of shape (b, d, l) 86 | y (torch.Tensor): tensor of shape (b, d, l) 87 | """ 88 | # hidden_states = hidden_states.transpose(1, 2) 89 | b, l, _ = hidden_states.size() 90 | q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) 91 | 92 | q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) 93 | k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2) 94 | v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2) 95 | 96 | 
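        # Reference (unfused) linear-attention path: q and k are lifted by the Taylor
        # feature map below, and the causal branch computes
        #   y_t = phi(q_t) . cumsum_{s<=t}(phi(k_s) v_s) / (phi(q_t) . cumsum_{s<=t} phi(k_s) + eps)
        # which the fused/parallel kernels used in forward() compute more efficiently.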
# Linear attention 97 | q, k = self.feature_map(q), self.feature_map(k) 98 | q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) 99 | 100 | # Compute attention 101 | if self.causal: 102 | y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) 103 | else: 104 | y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) 105 | y = rearrange(y, 'b h l d -> b l (h d)') 106 | y = self.o_proj(y.to(hidden_states.dtype)) 107 | y = self.dropout(y) 108 | return y.to(hidden_states.dtype) 109 | 110 | 111 | if __name__ == '__main__': 112 | batch = 4 113 | seq_len = 1024 114 | hidden_size = 1024 115 | dtype = torch.float32 116 | x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True) 117 | dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda() 118 | model = BasedLinearAttention(hidden_size, mode='chunk').to(dtype).cuda() 119 | y = model(x) 120 | y.backward(dy, retain_graph=True) 121 | x_grad, x.grad = x.grad, None 122 | y2 = model.forward_reference(x) 123 | y2.backward(dy) 124 | assert y.allclose(y2, 0, 1e-4), breakpoint() 125 | assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint() 126 | print("Pass") 127 | -------------------------------------------------------------------------------- /fla/layers/hgrn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823] 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange 13 | from transformers.cache_utils import Cache 14 | 15 | from fla.modules import FusedRMSNormSwishGate, ShortConvolution 16 | from fla.modules.activations import swiglu 17 | from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn 18 | 19 | 20 | class HGRNAttention(nn.Module): 21 | 22 | def __init__( 23 | self, 24 | mode: str = 'chunk', 25 | hidden_size: int = 1024, 26 | num_heads: Optional[int] = None, 27 | expand_ratio: Optional[int] = 1, 28 | use_short_conv: bool = False, 29 | conv_size: int = 4, 30 | conv_bias: bool = False, 31 | share_conv_kernel: bool = True, 32 | elementwise_affine: Optional[bool] = True, 33 | norm_eps: float = 1e-5, 34 | layer_idx: int = None 35 | ) -> HGRNAttention: 36 | super().__init__() 37 | 38 | self.mode = mode 39 | self.hidden_size = hidden_size 40 | self.num_heads = num_heads 41 | self.expand_ratio = expand_ratio 42 | self.input_dim = int(hidden_size * expand_ratio) 43 | self.head_dim = self.input_dim // self.num_heads 44 | 45 | self.use_short_conv = use_short_conv 46 | self.conv_size = conv_size 47 | self.conv_bias = conv_bias 48 | self.share_conv_kernel = share_conv_kernel 49 | 50 | self.layer_idx = layer_idx 51 | 52 | assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
53 | assert self.hidden_size % num_heads == 0, f"hidden size must be divisible by num_heads of {num_heads}" 54 | 55 | self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) 56 | self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False) 57 | self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False) 58 | 59 | if use_short_conv: 60 | self.conv_size = conv_size 61 | if share_conv_kernel: 62 | self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') 63 | else: 64 | self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') 65 | self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') 66 | self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') 67 | 68 | self.g_norm = FusedRMSNormSwishGate(self.input_dim, elementwise_affine, norm_eps) 69 | self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) 70 | 71 | self.apply(self._initialize_weights) 72 | 73 | def _initialize_weights(self, module: nn.Module): 74 | if getattr(module, "_is_hf_initialized", False): 75 | return 76 | if isinstance(module, nn.Linear): 77 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 78 | if module.bias is not None: 79 | nn.init.zeros_(module.bias) 80 | module._is_hf_initialized = True 81 | 82 | def forward( 83 | self, 84 | hidden_states: torch.Tensor, 85 | attention_mask: Optional[torch.Tensor] = None, 86 | past_key_values: Optional[Cache] = None, 87 | use_cache: Optional[bool] = False, 88 | output_attentions: Optional[bool] = False, 89 | lower_bound: Optional[torch.Tensor] = None, 90 | **kwargs 91 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: 92 | # launching the triton kernel for just one token will actually be slower 93 | mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode 94 | 95 | last_state = past_key_values[self.layer_idx] if use_cache else None 96 | if self.use_short_conv: 97 | conv_state = last_state[0] if use_cache else None 98 | if self.share_conv_kernel: 99 | # conv state is updated inplace 100 | hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) 101 | i = self.i_proj(hidden_states) 102 | f = self.f_proj(hidden_states) 103 | else: 104 | conv_state_i = last_state[2] if use_cache else None 105 | conv_state_f = last_state[1] if use_cache else None 106 | i = self.i_conv1d(self.i_proj(hidden_states), attention_mask, conv_state_i) 107 | f = self.f_conv1d(self.f_proj(hidden_states), attention_mask, conv_state_f) 108 | else: 109 | i = self.i_proj(hidden_states) 110 | f = self.f_proj(hidden_states) 111 | 112 | # the lower bound for the first layer is zero 113 | if lower_bound is None or self.layer_idx == 0: 114 | i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f) 115 | else: 116 | g = lower_bound + (1 - lower_bound) * f.sigmoid() 117 | i, f = swiglu(i, 1 - g), g.log() 118 | 119 | # dealing with left-padding 120 | if attention_mask is not None: 121 | i = i.mul_(attention_mask.unsqueeze(-1)) 122 | i, f = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (i, f)) 123 | 124 | recurrent_state = last_state[-1] if use_cache else None 125 | if mode == 'chunk': 126 | o, recurrent_state = chunk_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache) 127 | elif mode == 'fused_recurrent': 128 | o, recurrent_state = fused_recurrent_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache) 129 | else: 130 | raise NotImplementedError(f"Not supported mode `{mode}`.") 131 | 132 | if past_key_values is not 
None: 133 | if self.use_short_conv: 134 | if self.share_conv_kernel: 135 | last_state = (conv_state, recurrent_state) 136 | else: 137 | last_state = (conv_state_i, conv_state_f, recurrent_state) 138 | else: 139 | last_state = (recurrent_state,) 140 | past_key_values.update(last_state, self.layer_idx, i.shape[2]) 141 | 142 | o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)')) 143 | o = self.o_proj(o) 144 | 145 | return o, None, past_key_values 146 | 147 | def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: 148 | param = next(self.parameters()) 149 | state = tuple() 150 | if self.use_short_conv: 151 | if self.share_conv_kernel: 152 | state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) 153 | else: 154 | state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size), 155 | param.new_zeros(batch_size, self.hidden_size, self.conv_size), 156 | param.new_zeros(batch_size, self.hidden_size, self.conv_size)) 157 | state += (param.new_zeros(batch_size, self.num_heads, self.head_dim),) 158 | return state 159 | 160 | def state_size(self, **kwargs) -> int: 161 | state_size = self.hidden_size 162 | for module in self.children(): 163 | if isinstance(module, ShortConvolution): 164 | state_size += module.state_size 165 | return state_size 166 | -------------------------------------------------------------------------------- /fla/layers/hgrn2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904] 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange 13 | from transformers.cache_utils import Cache 14 | 15 | from fla.modules import RMSNorm, ShortConvolution 16 | from fla.modules.activations import swish 17 | from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla 18 | 19 | 20 | class HGRN2Attention(nn.Module): 21 | 22 | def __init__( 23 | self, 24 | mode: str = 'chunk', 25 | hidden_size: int = 1024, 26 | num_heads: Optional[int] = None, 27 | expand_ratio: Optional[int] = 128, 28 | use_short_conv: bool = False, 29 | conv_size: int = 4, 30 | conv_bias: bool = False, 31 | share_conv_kernel: bool = True, 32 | elementwise_affine: Optional[bool] = True, 33 | norm_eps: float = 1e-5, 34 | layer_idx: int = None 35 | ) -> HGRN2Attention: 36 | super().__init__() 37 | 38 | self.mode = mode 39 | self.hidden_size = hidden_size 40 | 41 | if expand_ratio is None and num_heads is not None: 42 | expand_ratio = hidden_size // num_heads 43 | elif expand_ratio is not None and num_heads is None: 44 | num_heads = hidden_size // expand_ratio 45 | else: 46 | raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.") 47 | self.num_heads = num_heads 48 | self.expand_ratio = expand_ratio 49 | 50 | self.use_short_conv = use_short_conv 51 | self.conv_size = conv_size 52 | self.conv_bias = conv_bias 53 | self.share_conv_kernel = share_conv_kernel 54 | 55 | self.forget_dim = int(self.num_heads * self.expand_ratio) 56 | self.input_dim = hidden_size 57 | self.layer_idx = layer_idx 58 | 59 | assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`." 
60 | assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}" 61 | assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}" 62 | 63 | self.head_f_dim = self.expand_ratio 64 | self.head_i_dim = self.hidden_size // num_heads 65 | 66 | self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) 67 | self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) 68 | self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) 69 | 70 | if use_short_conv: 71 | self.conv_size = conv_size 72 | if share_conv_kernel: 73 | self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') 74 | else: 75 | self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu') 76 | self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu') 77 | self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') 78 | 79 | self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, norm_eps) 80 | self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) 81 | 82 | self.apply(self._initialize_weights) 83 | 84 | def _initialize_weights(self, module: nn.Module): 85 | if getattr(module, "_is_hf_initialized", False): 86 | return 87 | if isinstance(module, nn.Linear): 88 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 89 | if module.bias is not None: 90 | nn.init.zeros_(module.bias) 91 | module._is_hf_initialized = True 92 | 93 | def forward( 94 | self, 95 | hidden_states: torch.Tensor, 96 | attention_mask: Optional[torch.Tensor] = None, 97 | past_key_values: Optional[Cache] = None, 98 | use_cache: Optional[bool] = False, 99 | output_attentions: Optional[bool] = False, 100 | lower_bound: Optional[torch.Tensor] = None, 101 | **kwargs 102 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: 103 | # launching the triton kernel for just one token will actually be slower 104 | mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode 105 | 106 | last_state = past_key_values[self.layer_idx] if use_cache else None 107 | if self.use_short_conv: 108 | conv_state = last_state[0] if use_cache else None 109 | if self.share_conv_kernel: 110 | # conv state is updated inplace 111 | hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) 112 | q = self.q_proj(hidden_states) 113 | f = self.f_proj(hidden_states) 114 | i = self.i_proj(hidden_states) 115 | else: 116 | conv_state_q = last_state[0] if use_cache else None 117 | conv_state_f = last_state[1] if use_cache else None 118 | conv_state_i = last_state[2] if use_cache else None 119 | q = self.q_proj(hidden_states) 120 | f = self.f_proj(hidden_states) 121 | i = self.i_proj(hidden_states) 122 | q = self.q_conv1d(q, attention_mask, conv_state_q) 123 | f = self.f_conv1d(f, attention_mask, conv_state_f) 124 | i = self.i_conv1d(i, attention_mask, conv_state_i) 125 | else: 126 | q = self.q_proj(hidden_states) 127 | f = self.f_proj(hidden_states) 128 | i = self.i_proj(hidden_states) 129 | 130 | # dealing with left-padding 131 | if attention_mask is not None: 132 | i = i.mul_(attention_mask.unsqueeze(-1)) 133 | 134 | q = swish(q) 135 | # the lower bound for the first layer is zero 136 | if lower_bound is None or self.layer_idx == 0: 137 | k, g = 1 - f.sigmoid(), F.logsigmoid(f) 138 | else: 139 | g = lower_bound + (1 - lower_bound) * f.sigmoid() 140 | k, g = 1 - g, g.log() 141 | q, k, i, g = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, i, g)) 142 
| 143 | recurrent_state = last_state[-1] if use_cache else None 144 | if mode == 'fused_recurrent': 145 | o, recurrent_state = fused_recurrent_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache) 146 | elif mode == 'fused_chunk': 147 | o, recurrent_state = fused_chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache) 148 | elif mode == 'chunk': 149 | o, recurrent_state = chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache) 150 | else: 151 | raise NotImplementedError(f"Not supported mode `{mode}`.") 152 | 153 | if past_key_values is not None: 154 | if self.use_short_conv: 155 | if self.share_conv_kernel: 156 | last_state = (conv_state, recurrent_state) 157 | else: 158 | last_state = (conv_state_q, conv_state_f, conv_state_i, recurrent_state) 159 | else: 160 | last_state = (recurrent_state,) 161 | past_key_values.update(last_state, self.layer_idx, q.shape[2]) 162 | 163 | o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)')) 164 | o = self.o_proj(o) 165 | 166 | return o, None, past_key_values 167 | 168 | def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: 169 | param = next(self.parameters()) 170 | state = tuple() 171 | if self.use_short_conv: 172 | if self.share_conv_kernel: 173 | state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) 174 | else: 175 | state += (param.new_zeros(batch_size, self.forget_dim, self.conv_size), 176 | param.new_zeros(batch_size, self.forget_dim, self.conv_size), 177 | param.new_zeros(batch_size, self.input_dim, self.conv_size)) 178 | state += (param.new_zeros(batch_size, self.num_heads, self.head_f_dim, self.head_i_dim),) 179 | return state 180 | 181 | def state_size(self, **kwargs) -> int: 182 | state_size = self.forget_dim * self.head_i_dim 183 | for module in self.children(): 184 | if isinstance(module, ShortConvolution): 185 | state_size += module.state_size 186 | return state_size 187 | -------------------------------------------------------------------------------- /fla/layers/linear_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | from fla.modules import RMSNorm 8 | from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap, 9 | HedgehogFeatureMap, T2RFeatureMap) 10 | from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn, 11 | fused_recurrent_linear_attn) 12 | 13 | 14 | class LinearAttention(nn.Module): 15 | def __init__( 16 | self, 17 | hidden_size: str = 1024, 18 | expand_k: int = 1.0, 19 | expand_v: int = 1.0, 20 | num_heads: int = 8, 21 | mode: str = 'chunk', 22 | feature_map: str = 'elementwise_product', 23 | tie_feature_map_qk: bool = False, 24 | output_norm: str = 'rmsnorm', 25 | norm_q: bool = False, 26 | norm_k: bool = False, 27 | # standard linear attention normalization 28 | do_feature_map_norm: bool = False, 29 | elementwise_affine: bool = True, 30 | norm_eps: float = 1e-5, 31 | **kwargs, 32 | ): 33 | super().__init__() 34 | assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp', 35 | 'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`." 36 | 37 | assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`." 
38 |
39 | self.hidden_size = hidden_size
40 | self.mode = mode
41 | self.key_dim = int(hidden_size * expand_k)
42 | self.value_dim = int(hidden_size * expand_v)
43 | self.num_heads = num_heads
44 |
45 | assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
46 | assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}"
47 | assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}"
48 |
49 | self.head_qk_dim = self.key_dim // num_heads
50 | self.head_v_dim = self.value_dim // num_heads
51 |
52 | if feature_map == 'hedgehog':
53 | if tie_feature_map_qk:
54 | self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
55 | else:
56 | self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim)
57 | self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim)
58 |
59 | elif feature_map == 't2r':
60 | if tie_feature_map_qk:
61 | self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
62 | else:
63 | self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim)
64 | self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim)
65 |
66 | elif feature_map == 'elementwise_product':
67 | if tie_feature_map_qk:
68 | self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
69 | else:
70 | self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim)
71 | self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim)
72 |
73 | elif feature_map == 'dpfp':
74 | self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim)
75 | self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim)
76 |
77 | elif feature_map == 'elu':
78 | def elu(x):
79 | return F.elu(x) + 1
80 | self.feature_map_q = elu
81 | self.feature_map_k = elu
82 |
83 | elif feature_map == 'relu':
84 | self.feature_map_q = nn.ReLU()
85 | self.feature_map_k = nn.ReLU()
86 |
87 | elif feature_map == 'identity':
88 | self.feature_map_q = nn.Identity()
89 | self.feature_map_k = nn.Identity()
90 | else:
91 | raise NotImplementedError
92 |
93 | self.do_feature_map_norm = do_feature_map_norm
94 | if output_norm == 'rmsnorm':
95 | self.norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps)
96 | elif output_norm == 'identity':
97 | self.norm = nn.Identity()
98 | else:
99 | raise NotImplementedError
100 |
101 | self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
102 | self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
103 | self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
104 | self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)
105 |
106 | self.norm_q = norm_q
107 | self.norm_k = norm_k
108 |
109 | self.apply(self._initialize_weights)
110 |
111 | def _initialize_weights(self, module: nn.Module):
112 | if getattr(module, "_is_hf_initialized", False):
113 | return
114 | if isinstance(module, nn.Linear):
115 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
116 | if module.bias is not None:
117 | nn.init.zeros_(module.bias)
118 | module._is_hf_initialized = True
119 |
120 | def forward(self, x):
121 | mode = self.mode
122 | q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
123 | k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
124 | v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads)
125 | q = self.feature_map_q(q)
126 | k = self.feature_map_k(k)
127 | if self.norm_q:
128 | q = q / (q.sum(-1, keepdim=True) + 1e-4)
129 | if self.norm_k: 130 | k = k / (k.sum(-1, keepdim=True) + 1e-4) 131 | 132 | if mode == 'chunk': 133 | o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm) 134 | elif mode == 'fused_chunk': 135 | o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm) 136 | elif mode == 'fused_recurrent': 137 | o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm) 138 | else: 139 | raise NotImplementedError 140 | o = self.norm(o) 141 | o = rearrange(o, 'b h n d -> b n (h d)') 142 | o = self.o_proj(o) 143 | return o 144 | 145 | 146 | if __name__ == '__main__': 147 | import torch 148 | batch = 4 149 | seq_len = 1024 150 | hidden_size = 1024 151 | x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True) 152 | model = LinearAttention(hidden_size, feature_map='dplp').to(torch.bfloat16).cuda() 153 | y = model(x) 154 | print(y.shape) 155 | y.sum().backward() 156 | print(x.grad.shape) 157 | -------------------------------------------------------------------------------- /fla/layers/rebased.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Optional 10 | 11 | import torch 12 | import torch.nn as nn 13 | from einops import rearrange 14 | 15 | from fla.modules.feature_map import RebasedFeatureMap 16 | from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn 17 | from fla.ops.rebased import parallel_rebased 18 | 19 | 20 | class ReBasedLinearAttention(nn.Module): 21 | def __init__( 22 | self, 23 | hidden_size: int, 24 | l_max: int = 2048, 25 | feature_dim: int = 16, 26 | num_key_value_heads: int = 16, 27 | num_heads: int = 16, 28 | use_gamma: Optional[bool] = True, 29 | use_beta: Optional[bool] = True, 30 | normalize: Optional[bool] = True, 31 | causal: bool = True, 32 | eps: float = 1e-5, 33 | mode: str = "parallel", 34 | layer_idx: Optional[int] = None, 35 | **kwargs 36 | ) -> ReBasedLinearAttention: 37 | super().__init__() 38 | self.hidden_size = hidden_size 39 | self.l_max = l_max 40 | self.mode = mode 41 | assert self.mode in ["fused_chunk", "parallel", 'chunk'] 42 | 43 | # linear attention 44 | self.feature_dim = feature_dim 45 | self.num_key_value_heads = num_key_value_heads 46 | self.num_heads = num_heads 47 | self.head_dim = self.hidden_size // self.num_key_value_heads 48 | self.use_gamma = use_gamma 49 | self.use_beta = use_beta 50 | self.normalize = normalize 51 | self.causal = causal 52 | 53 | self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize) 54 | self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) 55 | self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) 56 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 57 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 58 | self.dropout = nn.Identity() 59 | self.eps = eps 60 | 61 | self.apply(self._initialize_weights) 62 | 63 | def _initialize_weights(self, module: nn.Module): 64 | if getattr(module, "_is_hf_initialized", False): 65 | return 66 | if isinstance(module, nn.Linear): 67 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 68 | if module.bias is not None: 69 | nn.init.zeros_(module.bias) 70 | 
module._is_hf_initialized = True 71 | 72 | def forward(self, hidden_states: torch.Tensor, **kwargs): 73 | mode = self.mode 74 | q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) 75 | q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) 76 | q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel')) 77 | if mode == "fused_chunk": 78 | o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1) 79 | elif mode == 'chunk': 80 | o = chunk_linear_attn(q, k, v, normalize=True, scale=1) 81 | elif mode == 'parallel': 82 | assert q.shape[-1] <= 128 83 | o = parallel_rebased(q, k, v, self.eps, True, True) 84 | o = rearrange(o, "b h l d -> b l (h d)") 85 | o = self.o_proj(o) 86 | o = self.dropout(o) 87 | return o 88 | 89 | # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 90 | def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): 91 | """ 92 | x (torch.Tensor): tensor of shape (b, d, l) 93 | y (torch.Tensor): tensor of shape (b, d, l) 94 | """ 95 | # hidden_states = hidden_states.transpose(1, 2) 96 | b, l, _ = hidden_states.size() 97 | q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) 98 | 99 | q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) 100 | k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2) 101 | v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2) 102 | 103 | # Linear attention 104 | q, k = self.feature_map(q), self.feature_map(k) 105 | q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) 106 | 107 | # Compute attention 108 | if self.causal: 109 | y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) 110 | else: 111 | y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) 112 | y = rearrange(y, 'b h l d -> b l (h d)') 113 | y = self.o_proj(y.to(hidden_states.dtype)) 114 | y = self.dropout(y) 115 | return y.to(hidden_states.dtype) 116 | 117 | 118 | if __name__ == '__main__': 119 | batch = 4 120 | seq_len = 1024 121 | hidden_size = 1024 122 | dtype = torch.float32 123 | x = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda().requires_grad_(True) 124 | dy = torch.randn(batch, seq_len, hidden_size).to(dtype).cuda() 125 | model = ReBasedLinearAttention(hidden_size=hidden_size, mode='parallel').to(dtype).cuda() 126 | 127 | y = model(x) 128 | y.backward(dy, retain_graph=True) 129 | x_grad, x.grad = x.grad, None 130 | print(model.mode) 131 | model.mode = 'fused_chunk' 132 | y2 = model(x) 133 | print(model.mode) 134 | y2.backward(dy) 135 | # assert y.allclose(y2, 0, 1e-4), breakpoint() 136 | # assert x_grad.allclose(x.grad, 0, 1e-4), breakpoint() 137 | print("Pass") 138 | -------------------------------------------------------------------------------- /fla/layers/simple_gla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from einops import rearrange 11 | from transformers.activations import ACT2FN 12 | 13 | from fla.modules import FusedRMSNormSwishGate, RMSNorm 14 | from fla.ops.simple_gla import chunk_simple_gla 15 | 16 | 17 | class SimpleGatedLinearAttention(nn.Module): 18 | r""" 19 | The 
layer implementation for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa
20 | This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise.
21 |
22 | Args:
23 | mode (str, Optional):
24 | Which GLA kernel to use.
25 | Currently available: `chunk`.
26 | Default: `chunk`.
27 | hidden_size (int, Optional):
28 | The hidden size of the input. Default: 1024.
29 | expand_k (float, Optional):
30 | The expansion ratio for the key dim. Default: 1.0.
31 | expand_v (float, Optional):
32 | The expansion ratio for the value dim. Default: 2.0.
33 | num_heads (int, Optional):
34 | The number of heads. Default: 4.
35 | gate_fn (str, Optional):
36 | The activation function for the output gate. Default: `swish`.
37 | elementwise_affine (bool, Optional):
38 | If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`.
39 | norm_eps (float, Optional):
40 | The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5.
41 | gate_logit_normalizer (int, Optional):
42 | The normalizer for the gate logits, applied after `logsigmoid`. Default: 16.
43 | fuse_norm (bool, Optional):
44 | Whether to fuse the norm and the output gate for better memory footprint. Default: `True`.
45 | layer_idx (int, Optional):
46 | The index of the layer. Default: None.
47 | """
48 |
49 | def __init__(
50 | self,
51 | mode: str = 'chunk',
52 | hidden_size: int = 1024,
53 | expand_k: float = 1.0,
54 | expand_v: float = 2.0,
55 | num_heads: int = 4,
56 | gate_fn: str = 'swish',
57 | elementwise_affine: Optional[bool] = True,
58 | norm_eps: float = 1e-5,
59 | gate_logit_normalizer: int = 16,
60 | fuse_norm: bool = True,
61 | **kwargs
62 | ) -> SimpleGatedLinearAttention:
63 | super().__init__()
64 | self.hidden_size = hidden_size
65 |
66 | self.mode = mode
67 | self.key_dim = int(hidden_size * expand_k)
68 | self.value_dim = int(hidden_size * expand_v)
69 | assert mode in ['chunk'], f"Not supported mode `{mode}`."
70 | assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" 71 | assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" 72 | self.num_heads = num_heads 73 | self.head_qk_dim = self.key_dim // num_heads 74 | self.head_v_dim = self.value_dim // num_heads 75 | self.gate_fn = ACT2FN[gate_fn] 76 | 77 | self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) 78 | self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) 79 | self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) 80 | self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) 81 | 82 | self.gk_proj = nn.Linear(hidden_size, self.num_heads) 83 | self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) 84 | 85 | if gate_fn == 'swish' and fuse_norm: 86 | self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps) 87 | self.fuse_norm_and_gate = True 88 | else: 89 | self.fuse_norm_and_gate = False 90 | self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps) 91 | 92 | self.gate_logit_normalizer = gate_logit_normalizer 93 | 94 | self.apply(self._initialize_weights) 95 | 96 | def _initialize_weights(self, module: nn.Module): 97 | if getattr(module, "_is_hf_initialized", False): 98 | return 99 | if isinstance(module, nn.Linear): 100 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 101 | if module.bias is not None: 102 | nn.init.zeros_(module.bias) 103 | module._is_hf_initialized = True 104 | 105 | def forward(self, x): 106 | mode = self.mode 107 | q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) 108 | k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) 109 | v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) 110 | gk = rearrange(self.gk_proj(x), 'b n h -> b h n') 111 | gk = (F.logsigmoid(gk) / self.gate_logit_normalizer) 112 | 113 | if mode == 'chunk': 114 | o = chunk_simple_gla(q, k, v, gk) 115 | else: 116 | raise NotImplementedError(f"Not supported mode `{mode}`.") 117 | 118 | o = rearrange(o, 'b h l d -> b l h d') 119 | g = self.g_proj(x) 120 | 121 | if self.fuse_norm_and_gate: 122 | g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads) 123 | o = self.g_norm_swish_gate(o, g) 124 | o = rearrange(o, 'b l h d -> b l (h d)') 125 | else: 126 | o = self.g_norm(o) 127 | o = rearrange(o, 'b l h d -> b l (h d)') 128 | o = o * self.gate_fn(g) 129 | o = self.o_proj(o) 130 | return o 131 | 132 | 133 | if __name__ == '__main__': 134 | batch = 4 135 | seq_len = 1024 136 | 137 | hidden_size = 2048 138 | x = torch.randn(batch, seq_len, hidden_size).to(torch.bfloat16).cuda().requires_grad_(True) 139 | model = SimpleGatedLinearAttention(hidden_size=hidden_size, mode='chunk').to(torch.bfloat16).cuda() 140 | y = model(x) 141 | print(y.shape) 142 | y.sum().backward() 143 | print(x.grad.shape) 144 | -------------------------------------------------------------------------------- /fla/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel 4 | from fla.models.delta_net import (DeltaNetConfig, DeltaNetForCausalLM, 5 | DeltaNetModel) 6 | from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel 7 | from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel 8 | from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model 9 | from fla.models.linear_attn import 
(LinearAttentionConfig, 10 | LinearAttentionForCausalLM, 11 | LinearAttentionModel) 12 | from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel 13 | from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel 14 | from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model 15 | from fla.models.transformer import (TransformerConfig, TransformerForCausalLM, 16 | TransformerModel) 17 | 18 | __all__ = [ 19 | 'ABCConfig', 'ABCForCausalLM', 'ABCModel', 20 | 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel', 21 | 'GLAConfig', 'GLAForCausalLM', 'GLAModel', 22 | 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel', 23 | 'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model', 24 | 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel', 25 | 'MambaConfig', 'MambaForCausalLM', 'MambaModel', 26 | 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel', 27 | 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model', 28 | 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel' 29 | ] 30 | -------------------------------------------------------------------------------- /fla/models/abc/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.abc.configuration_abc import ABCConfig 6 | from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel 7 | 8 | AutoConfig.register(ABCConfig.model_type, ABCConfig) 9 | AutoModel.register(ABCConfig, ABCModel) 10 | AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM) 11 | 12 | 13 | __all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel'] 14 | -------------------------------------------------------------------------------- /fla/models/abc/configuration_abc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class ABCConfig(PretrainedConfig): 9 | 10 | model_type = 'abc' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | gate_low_rank_dim: int = 16, 18 | clamp_min: float = -32, 19 | clamp_max: float = 32, 20 | hidden_ratio: Optional[int] = 4, 21 | intermediate_size: Optional[int] = None, 22 | num_hidden_layers: int = 24, 23 | num_heads: int = 4, 24 | num_slots: Optional[int] = 64, 25 | use_short_conv: bool = True, 26 | conv_size: int = 4, 27 | share_conv_kernel: bool = True, 28 | exapnd_k: float = 0.5, 29 | exapnd_v: float = 1, 30 | hidden_act: str = "swish", 31 | max_position_embeddings: int = 2048, 32 | elementwise_affine: Optional[bool] = True, 33 | norm_eps: float = 1e-6, 34 | use_cache: bool = True, 35 | pad_token_id: int = None, 36 | bos_token_id: int = 1, 37 | eos_token_id: int = 2, 38 | initializer_range: float = 0.02, 39 | tie_word_embeddings: bool = False, 40 | fuse_norm: bool = True, 41 | fuse_cross_entropy: bool = True, 42 | **kwargs 43 | ): 44 | self.vocab_size = vocab_size 45 | self.max_position_embeddings = max_position_embeddings 46 | self.hidden_size = hidden_size 47 | self.gate_low_rank_dim = gate_low_rank_dim 48 | self.clamp_min = clamp_min 49 | self.clamp_max = clamp_max 50 | self.hidden_ratio = hidden_ratio 51 | self.intermediate_size = intermediate_size 52 | self.num_hidden_layers = num_hidden_layers 53 | self.num_heads = num_heads 54 | self.num_slots = 
num_slots 55 | self.use_short_conv = use_short_conv 56 | self.conv_size = conv_size 57 | self.share_conv_kernel = share_conv_kernel 58 | self.expand_k = exapnd_k 59 | self.expand_v = exapnd_v 60 | self.hidden_act = hidden_act 61 | self.elementwise_affine = elementwise_affine 62 | self.norm_eps = norm_eps 63 | self.use_cache = use_cache 64 | self.initializer_range = initializer_range 65 | self.fuse_cross_entropy = fuse_cross_entropy 66 | self.fuse_norm = fuse_norm 67 | 68 | super().__init__( 69 | pad_token_id=pad_token_id, 70 | bos_token_id=bos_token_id, 71 | eos_token_id=eos_token_id, 72 | tie_word_embeddings=tie_word_embeddings, 73 | **kwargs, 74 | ) 75 | -------------------------------------------------------------------------------- /fla/models/delta_net/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.delta_net.configuration_delta_net import \ 6 | DeltaNetConfig 7 | from fla.models.delta_net.modeling_delta_net import ( 8 | DeltaNetForCausalLM, DeltaNetModel) 9 | 10 | AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig) 11 | AutoModel.register(DeltaNetConfig, DeltaNetModel) 12 | AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM) 13 | 14 | __all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel'] 15 | -------------------------------------------------------------------------------- /fla/models/delta_net/configuration_delta_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class DeltaNetConfig(PretrainedConfig): 9 | 10 | model_type = 'delta_net' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | expand_k: int = 1, 18 | expand_v: int = 1, 19 | use_gate: bool = False, 20 | use_short_conv: bool = True, 21 | conv_size: int = 4, 22 | share_conv_kernel: bool = False, 23 | use_rope: bool = False, 24 | use_beta: bool = True, 25 | use_output_norm: bool = True, 26 | hidden_ratio: Optional[int] = 4, 27 | intermediate_size: Optional[int] = None, 28 | num_hidden_layers: int = 24, 29 | num_heads: int = 4, 30 | attn_mode: str = "chunk", 31 | qk_norm: str = 'l2', 32 | qk_activation: str = 'silu', 33 | chunk_size: int = 64, 34 | hidden_act: str = "swish", 35 | max_position_embeddings: int = 2048, 36 | rms_norm_eps: float = 1e-6, 37 | use_cache: bool = True, 38 | pad_token_id: int = None, 39 | bos_token_id: int = 1, 40 | eos_token_id: int = 2, 41 | tie_word_embeddings: bool = False, 42 | initializer_range: float = 0.02, 43 | fuse_cross_entropy: bool = True, 44 | **kwargs 45 | ): 46 | self.vocab_size = vocab_size 47 | self.max_position_embeddings = max_position_embeddings 48 | self.hidden_size = hidden_size 49 | self.expand_k = expand_k 50 | self.expand_v = expand_v 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.num_hidden_layers = num_hidden_layers 54 | self.num_heads = num_heads 55 | self.attn_mode = attn_mode 56 | self.hidden_act = hidden_act 57 | self.rms_norm_eps = rms_norm_eps 58 | self.use_cache = use_cache 59 | self.initializer_range = initializer_range 60 | self.fuse_cross_entropy = fuse_cross_entropy 61 | self.use_gate = use_gate 62 | self.use_short_conv = 
use_short_conv 63 | self.conv_size = conv_size 64 | self.share_conv_kernel = share_conv_kernel 65 | self.use_rope = use_rope 66 | self.use_beta = use_beta 67 | self.use_output_norm = use_output_norm 68 | self.qk_norm = qk_norm 69 | self.qk_activation = qk_activation 70 | 71 | super().__init__( 72 | pad_token_id=pad_token_id, 73 | bos_token_id=bos_token_id, 74 | eos_token_id=eos_token_id, 75 | tie_word_embeddings=tie_word_embeddings, 76 | **kwargs, 77 | ) 78 | -------------------------------------------------------------------------------- /fla/models/gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.gla.configuration_gla import GLAConfig 6 | from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel 7 | 8 | AutoConfig.register(GLAConfig.model_type, GLAConfig) 9 | AutoModel.register(GLAConfig, GLAModel) 10 | AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM) 11 | 12 | 13 | __all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel'] 14 | -------------------------------------------------------------------------------- /fla/models/gla/configuration_gla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class GLAConfig(PretrainedConfig): 9 | 10 | model_type = 'gla' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | expand_k: int = 0.5, 18 | expand_v: int = 1, 19 | hidden_ratio: Optional[int] = 4, 20 | intermediate_size: Optional[int] = None, 21 | num_hidden_layers: int = 24, 22 | num_heads: int = 4, 23 | num_kv_heads: Optional[int] = None, 24 | feature_map: Optional[str] = None, 25 | attn_mode: str = "chunk", 26 | use_short_conv: bool = False, 27 | conv_size: int = 4, 28 | share_conv_kernel: bool = True, 29 | use_output_gate: bool = True, 30 | clamp_min: Optional[float] = None, 31 | hidden_act: str = "swish", 32 | max_position_embeddings: int = 2048, 33 | elementwise_affine: Optional[bool] = True, 34 | norm_eps: float = 1e-6, 35 | use_gk: bool = True, 36 | use_gv: bool = False, 37 | use_cache: bool = True, 38 | pad_token_id: int = None, 39 | bos_token_id: int = 1, 40 | eos_token_id: int = 2, 41 | tie_word_embeddings: bool = False, 42 | initializer_range: float = 0.02, 43 | fuse_norm: bool = True, 44 | fuse_cross_entropy: bool = True, 45 | **kwargs 46 | ): 47 | self.vocab_size = vocab_size 48 | self.max_position_embeddings = max_position_embeddings 49 | self.hidden_size = hidden_size 50 | self.expand_k = expand_k 51 | self.expand_v = expand_v 52 | self.hidden_ratio = hidden_ratio 53 | self.intermediate_size = intermediate_size 54 | self.num_hidden_layers = num_hidden_layers 55 | self.num_heads = num_heads 56 | self.num_kv_heads = num_kv_heads 57 | self.feature_map = feature_map 58 | self.attn_mode = attn_mode 59 | self.clamp_min = clamp_min 60 | self.hidden_act = hidden_act 61 | self.elementwise_affine = elementwise_affine 62 | self.norm_eps = norm_eps 63 | self.use_gk = use_gk 64 | self.use_gv = use_gv 65 | self.use_cache = use_cache 66 | self.initializer_range = initializer_range 67 | self.fuse_norm = fuse_norm 68 | self.fuse_cross_entropy = fuse_cross_entropy 69 | self.use_short_conv = use_short_conv 70 | self.conv_size = conv_size 71 
| self.share_conv_kernel = share_conv_kernel 72 | self.use_output_gate = use_output_gate 73 | 74 | super().__init__( 75 | pad_token_id=pad_token_id, 76 | bos_token_id=bos_token_id, 77 | eos_token_id=eos_token_id, 78 | tie_word_embeddings=tie_word_embeddings, 79 | **kwargs, 80 | ) 81 | -------------------------------------------------------------------------------- /fla/models/hgrn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.hgrn.configuration_hgrn import HGRNConfig 6 | from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel 7 | 8 | AutoConfig.register(HGRNConfig.model_type, HGRNConfig) 9 | AutoModel.register(HGRNConfig, HGRNModel) 10 | AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM) 11 | 12 | 13 | __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel'] 14 | -------------------------------------------------------------------------------- /fla/models/hgrn/configuration_hgrn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class HGRNConfig(PretrainedConfig): 9 | 10 | model_type = 'hgrn' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "chunk", 16 | vocab_size: int = 32000, 17 | hidden_size: int = 2048, 18 | num_hidden_layers: int = 24, 19 | num_heads: Optional[int] = 1, 20 | expand_ratio: Optional[int] = 1, 21 | use_short_conv: bool = False, 22 | conv_size: int = 4, 23 | share_conv_kernel: bool = True, 24 | use_lower_bound: bool = True, 25 | hidden_ratio: Optional[int] = 4, 26 | intermediate_size: Optional[int] = None, 27 | hidden_act: str = "swish", 28 | max_position_embeddings: int = 2048, 29 | elementwise_affine: Optional[bool] = True, 30 | norm_eps: float = 1e-6, 31 | use_cache: bool = True, 32 | pad_token_id: int = None, 33 | bos_token_id: int = 1, 34 | eos_token_id: int = 2, 35 | tie_word_embeddings: bool = False, 36 | initializer_range: float = 0.02, 37 | fuse_cross_entropy: bool = True, 38 | **kwargs 39 | ): 40 | self.attn_mode = attn_mode 41 | self.vocab_size = vocab_size 42 | self.max_position_embeddings = max_position_embeddings 43 | self.hidden_size = hidden_size 44 | self.num_hidden_layers = num_hidden_layers 45 | self.num_heads = num_heads 46 | self.expand_ratio = expand_ratio 47 | self.use_short_conv = use_short_conv 48 | self.conv_size = conv_size 49 | self.share_conv_kernel = share_conv_kernel 50 | self.use_lower_bound = use_lower_bound 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.hidden_act = hidden_act 54 | self.elementwise_affine = elementwise_affine 55 | self.norm_eps = norm_eps 56 | self.use_cache = use_cache 57 | self.initializer_range = initializer_range 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /fla/models/hgrn2/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, 
AutoModelForCausalLM 4 | 5 | from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config 6 | from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model 7 | 8 | AutoConfig.register(HGRN2Config.model_type, HGRN2Config) 9 | AutoModel.register(HGRN2Config, HGRN2Model) 10 | AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM) 11 | 12 | 13 | __all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model'] 14 | -------------------------------------------------------------------------------- /fla/models/hgrn2/configuration_hgrn2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class HGRN2Config(PretrainedConfig): 9 | 10 | model_type = 'hgrn2' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | num_hidden_layers: int = 24, 18 | attn_mode: str = "chunk", 19 | num_heads: Optional[int] = None, 20 | expand_ratio: Optional[int] = 128, 21 | use_short_conv: bool = False, 22 | conv_size: int = 4, 23 | share_conv_kernel: bool = True, 24 | use_lower_bound: bool = True, 25 | hidden_ratio: Optional[int] = 4, 26 | intermediate_size: Optional[int] = None, 27 | hidden_act: str = "swish", 28 | max_position_embeddings: int = 2048, 29 | elementwise_affine: Optional[bool] = True, 30 | norm_eps: float = 1e-6, 31 | use_cache: bool = True, 32 | pad_token_id: int = None, 33 | bos_token_id: int = 1, 34 | eos_token_id: int = 2, 35 | tie_word_embeddings: bool = False, 36 | initializer_range: float = 0.02, 37 | fuse_cross_entropy: bool = True, 38 | **kwargs 39 | ): 40 | self.vocab_size = vocab_size 41 | self.max_position_embeddings = max_position_embeddings 42 | self.hidden_size = hidden_size 43 | self.num_hidden_layers = num_hidden_layers 44 | self.attn_mode = attn_mode 45 | self.num_heads = num_heads 46 | self.expand_ratio = expand_ratio 47 | self.use_short_conv = use_short_conv 48 | self.conv_size = conv_size 49 | self.share_conv_kernel = share_conv_kernel 50 | self.use_lower_bound = use_lower_bound 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.hidden_act = hidden_act 54 | self.elementwise_affine = elementwise_affine 55 | self.norm_eps = norm_eps 56 | self.use_cache = use_cache 57 | self.initializer_range = initializer_range 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /fla/models/linear_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.linear_attn.configuration_linear_attn import \ 6 | LinearAttentionConfig 7 | from fla.models.linear_attn.modeling_linear_attn import ( 8 | LinearAttentionForCausalLM, LinearAttentionModel) 9 | 10 | AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig) 11 | AutoModel.register(LinearAttentionConfig, LinearAttentionModel) 12 | AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM) 13 | 14 | __all__ = ['LinearAttentionConfig', 
'LinearAttentionForCausalLM', 'LinearAttentionModel'] 15 | -------------------------------------------------------------------------------- /fla/models/linear_attn/configuration_linear_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class LinearAttentionConfig(PretrainedConfig): 9 | 10 | model_type = 'linear_attn' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | expand_k: int = 1, 18 | expand_v: int = 1, 19 | hidden_ratio: Optional[int] = 4, 20 | intermediate_size: Optional[int] = None, 21 | num_hidden_layers: int = 24, 22 | num_heads: int = 4, 23 | attn_mode: str = "fused_chunk", 24 | feature_map: str = "elementwise_product", 25 | tie_feature_map_qk: bool = False, 26 | norm_q: bool = False, 27 | norm_k: bool = False, 28 | norm_feature_map: bool = False, 29 | hidden_act: str = "swish", 30 | max_position_embeddings: int = 2048, 31 | elementwise_affine: Optional[bool] = True, 32 | norm_eps: float = 1e-6, 33 | use_cache: bool = True, 34 | pad_token_id: int = None, 35 | bos_token_id: int = 1, 36 | eos_token_id: int = 2, 37 | tie_word_embeddings: bool = False, 38 | initializer_range: float = 0.02, 39 | fuse_cross_entropy: bool = True, 40 | **kwargs 41 | ): 42 | self.vocab_size = vocab_size 43 | self.max_position_embeddings = max_position_embeddings 44 | self.hidden_size = hidden_size 45 | self.expand_k = expand_k 46 | self.expand_v = expand_v 47 | self.hidden_ratio = hidden_ratio 48 | self.intermediate_size = intermediate_size 49 | self.num_hidden_layers = num_hidden_layers 50 | self.num_heads = num_heads 51 | self.attn_mode = attn_mode 52 | self.feature_map = feature_map 53 | self.tie_feature_map_qk = tie_feature_map_qk 54 | self.norm_q = norm_q 55 | self.norm_k = norm_k 56 | self.norm_feature_map = norm_feature_map 57 | self.hidden_act = hidden_act 58 | self.elementwise_affine = elementwise_affine 59 | self.norm_eps = norm_eps 60 | self.use_cache = use_cache 61 | self.initializer_range = initializer_range 62 | self.fuse_cross_entropy = fuse_cross_entropy 63 | 64 | super().__init__( 65 | pad_token_id=pad_token_id, 66 | bos_token_id=bos_token_id, 67 | eos_token_id=eos_token_id, 68 | tie_word_embeddings=tie_word_embeddings, 69 | **kwargs, 70 | ) 71 | -------------------------------------------------------------------------------- /fla/models/mamba/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.mamba.configuration_mamba import MambaConfig 6 | from fla.models.mamba.modeling_mamba import (MambaBlock, MambaForCausalLM, 7 | MambaModel) 8 | 9 | AutoConfig.register(MambaConfig.model_type, MambaConfig, True) 10 | AutoModel.register(MambaConfig, MambaModel, True) 11 | AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True) 12 | 13 | 14 | __all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock'] 15 | -------------------------------------------------------------------------------- /fla/models/mamba/configuration_mamba.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """MAMBA configuration""" 16 | 17 | import math 18 | 19 | from transformers.configuration_utils import PretrainedConfig 20 | 21 | 22 | class MambaConfig(PretrainedConfig): 23 | """ 24 | This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA 25 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 26 | defaults will yield a similar configuration to that of the MAMBA 27 | [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture. 28 | 29 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 30 | documentation from [`PretrainedConfig`] for more information. 31 | 32 | 33 | Args: 34 | vocab_size (`int`, *optional*, defaults to 50280): 35 | Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the 36 | `inputs_ids` passed when calling [`MambaModel`]. 37 | hidden_size (`int`, *optional*, defaults to 768): 38 | Dimensionality of the embeddings and hidden states. 39 | state_size (`int`, *optional*, defaults to 16): shape of the state space latents. 40 | num_hidden_layers (`int`, *optional*, defaults to 32): 41 | Number of hidden layers in the model. 42 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): 43 | The epsilon to use in the layer normalization layers. 44 | pad_token_id (`int`, *optional*, defaults to 0): 45 | Padding token id. 46 | bos_token_id (`int`, *optional*, defaults to 0): 47 | The id of the beginning of sentence token in the vocabulary. 48 | eos_token_id (`int`, *optional*, defaults to 0): 49 | The id of the end of sentence token in the vocabulary. 50 | expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size. 51 | conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. 52 | use_bias (`bool`, *optional*, defaults to `False`): 53 | Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block 54 | use_conv_bias (`bool`, *optional*, defaults to `True`): 55 | Whether or not to use bias in the convolution layer of the mixer block. 56 | hidden_act (`str`, *optional*, defaults to `"silu"`): 57 | The non-linear activation function (function or string) in the decoder. 58 | initializer_range (`float`, *optional*, defaults to 0.1): 59 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 60 | residual_in_fp32 (`bool`, *optional*, defaults to `True`): 61 | Whether or not residuals should be in `float32`. 62 | If set to `False` residuals will keep the same `dtype` as the rest of the model 63 | time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): 64 | Rank of the the discretization projection matrix. 
65 | `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` 66 | time_step_scale (`float`, *optional*, defaults to 1.0): 67 | Scale used used to scale `dt_proj.bias`. 68 | time_step_min (`float`, *optional*, defaults to 0.001): 69 | Minimum `time_step` used to bound `dt_proj.bias`. 70 | time_step_max (`float`, *optional*, defaults to 0.1): 71 | Maximum `time_step` used to bound `dt_proj.bias`. 72 | time_step_init_scheme (`float`, *optional*, defaults to `"random"`): 73 | Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]` 74 | time_step_floor (`float`, *optional*, defaults to 0.0001): 75 | Minimum clamping value of the `dt_proj.bias` layer initialization. 76 | rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): 77 | Whether or not to rescale `out_proj` weights when initializing. 78 | use_cache (`bool`, *optional*, defaults to `True`): 79 | Whether or not the cache should be used. 80 | 81 | 82 | Example: 83 | 84 | ```python 85 | >>> from transformers import MambaConfig, MambaModel 86 | 87 | >>> # Initializing a Mamba configuration 88 | >>> configuration = MambaConfig() 89 | 90 | >>> # Initializing a model (with random weights) from the configuration 91 | >>> model = MambaModel(configuration) 92 | 93 | >>> # Accessing the model configuration 94 | >>> configuration = model.config 95 | ```""" 96 | 97 | model_type = "mamba" 98 | 99 | def __init__( 100 | self, 101 | vocab_size=32000, 102 | hidden_size=2048, 103 | state_size=16, 104 | num_hidden_layers=48, 105 | layer_norm_epsilon=1e-5, 106 | pad_token_id= 0, 107 | bos_token_id= 1, 108 | eos_token_id= 2, 109 | expand=2, 110 | conv_kernel=4, 111 | use_bias=False, 112 | use_conv_bias=True, 113 | hidden_act="silu", 114 | initializer_range=0.1, 115 | residual_in_fp32=False, 116 | time_step_rank="auto", 117 | time_step_scale=1.0, 118 | time_step_min=0.001, 119 | time_step_max=0.1, 120 | time_step_init_scheme="random", 121 | time_step_floor=1e-4, 122 | rescale_prenorm_residual=False, 123 | use_cache=True, 124 | fuse_norm: bool = True, 125 | fuse_cross_entropy: bool = True, 126 | tie_word_embeddings: bool = False, 127 | **kwargs, 128 | ): 129 | self.vocab_size = vocab_size 130 | self.hidden_size = hidden_size 131 | self.state_size = state_size 132 | self.num_hidden_layers = num_hidden_layers 133 | self.layer_norm_epsilon = layer_norm_epsilon 134 | self.conv_kernel = conv_kernel 135 | self.expand = expand 136 | self.intermediate_size = int(expand * self.hidden_size) 137 | self.bos_token_id = bos_token_id 138 | self.eos_token_id = eos_token_id 139 | self.pad_token_id = pad_token_id 140 | self.use_bias = use_bias 141 | self.use_conv_bias = use_conv_bias 142 | self.hidden_act = hidden_act 143 | self.initializer_range = initializer_range 144 | self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank 145 | self.time_step_scale = time_step_scale 146 | self.time_step_min = time_step_min 147 | self.time_step_max = time_step_max 148 | self.time_step_init_scheme = time_step_init_scheme 149 | self.time_step_floor = time_step_floor 150 | self.rescale_prenorm_residual = rescale_prenorm_residual 151 | self.residual_in_fp32 = residual_in_fp32 152 | self.use_cache = use_cache 153 | self.fuse_cross_entropy = fuse_cross_entropy 154 | self.fuse_norm = fuse_norm 155 | 156 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs) 157 | 
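The two derived fields in `MambaConfig` above (`intermediate_size` and the `"auto"` resolution of `time_step_rank`) are easy to misread. A minimal sketch of how they resolve with the defaults shown, assuming the package layout in this repo (`fla.models.mamba`) is importable as-is:

```python
import math

from fla.models.mamba import MambaConfig

config = MambaConfig()  # hidden_size=2048, expand=2, time_step_rank="auto"

# intermediate_size = int(expand * hidden_size)
assert config.intermediate_size == 4096

# "auto" resolves to math.ceil(hidden_size / 16) at construction time
assert config.time_step_rank == math.ceil(2048 / 16) == 128
```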
-------------------------------------------------------------------------------- /fla/models/retnet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.retnet.configuration_retnet import RetNetConfig 6 | from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel 7 | 8 | AutoConfig.register(RetNetConfig.model_type, RetNetConfig) 9 | AutoModel.register(RetNetConfig, RetNetModel) 10 | AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM) 11 | 12 | 13 | __all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel'] 14 | -------------------------------------------------------------------------------- /fla/models/retnet/configuration_retnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Optional 6 | 7 | from transformers.configuration_utils import PretrainedConfig 8 | 9 | 10 | class RetNetConfig(PretrainedConfig): 11 | 12 | model_type = 'retnet' 13 | keys_to_ignore_at_inference = ['past_key_values'] 14 | 15 | def __init__( 16 | self, 17 | vocab_size: int = 32000, 18 | hidden_size: int = 2048, 19 | expand_k: int = 1, 20 | expand_v: int = 2, 21 | hidden_ratio: Optional[int] = 2, 22 | intermediate_size: Optional[int] = None, 23 | num_hidden_layers: int = 24, 24 | num_heads: int = 8, 25 | num_kv_heads: Optional[int] = None, 26 | feature_map: Optional[str] = None, 27 | attn_mode: str = "fused_chunk", 28 | hidden_act: str = "swish", 29 | use_short_conv: bool = False, 30 | conv_size: int = 4, 31 | share_conv_kernel: bool = True, 32 | use_output_gate: bool = True, 33 | max_position_embeddings: int = 2048, 34 | elementwise_affine: Optional[bool] = True, 35 | norm_eps: float = 1e-6, 36 | use_cache: bool = True, 37 | pad_token_id: int = None, 38 | bos_token_id: int = 1, 39 | eos_token_id: int = 2, 40 | tie_word_embeddings: bool = False, 41 | initializer_range: float = 0.02, 42 | fuse_norm: bool = True, 43 | fuse_cross_entropy: bool = True, 44 | **kwargs 45 | ) -> RetNetConfig: 46 | self.vocab_size = vocab_size 47 | self.max_position_embeddings = max_position_embeddings 48 | self.hidden_size = hidden_size 49 | self.expand_k = expand_k 50 | self.expand_v = expand_v 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.num_hidden_layers = num_hidden_layers 54 | self.num_heads = num_heads 55 | self.num_kv_heads = num_kv_heads 56 | self.feature_map = feature_map 57 | self.attn_mode = attn_mode 58 | self.hidden_act = hidden_act 59 | self.use_short_conv = use_short_conv 60 | self.conv_size = conv_size 61 | self.share_conv_kernel = share_conv_kernel 62 | self.use_output_gate = use_output_gate 63 | self.elementwise_affine = elementwise_affine 64 | self.norm_eps = norm_eps 65 | self.use_cache = use_cache 66 | self.initializer_range = initializer_range 67 | self.fuse_norm = fuse_norm 68 | self.fuse_cross_entropy = fuse_cross_entropy 69 | 70 | super().__init__( 71 | pad_token_id=pad_token_id, 72 | bos_token_id=bos_token_id, 73 | eos_token_id=eos_token_id, 74 | tie_word_embeddings=tie_word_embeddings, 75 | **kwargs, 76 | ) 77 | -------------------------------------------------------------------------------- /fla/models/rwkv6/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 
| from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config 6 | from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model 7 | 8 | AutoConfig.register(RWKV6Config.model_type, RWKV6Config) 9 | AutoModel.register(RWKV6Config, RWKV6Model) 10 | AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM) 11 | 12 | 13 | __all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model'] 14 | -------------------------------------------------------------------------------- /fla/models/rwkv6/configuration_rwkv6.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class RWKV6Config(PretrainedConfig): 9 | 10 | model_type = 'rwkv6' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "chunk", 16 | vocab_size: int = 32000, 17 | hidden_size: int = 2048, 18 | expand_k: int = 0.5, 19 | expand_v: int = 1, 20 | hidden_ratio: Optional[int] = 3.5, 21 | intermediate_size: Optional[int] = None, 22 | use_glu: Optional[bool] = False, 23 | num_hidden_layers: int = 24, 24 | num_heads: int = 4, 25 | proj_low_rank_dim: int = 32, 26 | gate_low_rank_dim: int = 64, 27 | hidden_act: str = "sqrelu", 28 | max_position_embeddings: int = 2048, 29 | eps: float = 1e-6, 30 | use_cache: bool = True, 31 | pad_token_id: int = None, 32 | bos_token_id: int = 1, 33 | eos_token_id: int = 2, 34 | tie_word_embeddings: bool = False, 35 | initializer_range: float = 0.02, 36 | fuse_norm: bool = True, 37 | fuse_cross_entropy: bool = True, 38 | **kwargs 39 | ): 40 | self.vocab_size = vocab_size 41 | self.max_position_embeddings = max_position_embeddings 42 | self.hidden_size = hidden_size 43 | self.expand_k = expand_k 44 | self.expand_v = expand_v 45 | self.hidden_ratio = hidden_ratio 46 | self.intermediate_size = intermediate_size 47 | self.use_glu = use_glu 48 | self.num_hidden_layers = num_hidden_layers 49 | self.num_heads = num_heads 50 | self.proj_low_rank_dim = proj_low_rank_dim 51 | self.gate_low_rank_dim = gate_low_rank_dim 52 | self.attn_mode = attn_mode 53 | self.hidden_act = hidden_act 54 | self.eps = eps 55 | self.use_cache = use_cache 56 | self.initializer_range = initializer_range 57 | self.fuse_norm = fuse_norm 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /fla/models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.transformer.configuration_transformer import TransformerConfig 6 | from fla.models.transformer.modeling_transformer import ( 7 | TransformerForCausalLM, TransformerModel) 8 | 9 | AutoConfig.register(TransformerConfig.model_type, TransformerConfig) 10 | AutoModel.register(TransformerConfig, TransformerModel) 11 | AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM) 12 | 13 | 14 | __all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'] 15 | 
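Each of the `__init__.py` files above follows the same pattern: register the config class and the model classes with the `transformers` Auto factories. A minimal sketch of what that registration enables, using `GLAConfig` purely as an example (any of the registered families works the same way):

```python
from transformers import AutoModel, AutoModelForCausalLM

from fla.models import GLAConfig  # importing fla.models runs the registrations

config = GLAConfig(hidden_size=1024, num_hidden_layers=12, num_heads=4)

# Because GLAConfig.model_type ('gla') is registered, the Auto classes can
# resolve the matching model class directly from the config object.
model = AutoModelForCausalLM.from_config(config)
backbone = AutoModel.from_config(config)
```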
-------------------------------------------------------------------------------- /fla/models/transformer/configuration_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class TransformerConfig(PretrainedConfig): 9 | 10 | model_type = 'transformer' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | hidden_ratio: Optional[int] = 4, 18 | intermediate_size: Optional[int] = None, 19 | num_hidden_layers: int = 24, 20 | num_heads: int = 32, 21 | num_kv_heads: int = None, 22 | hidden_act: str = "swish", 23 | max_position_embeddings: int = 2048, 24 | initializer_range: float = 0.02, 25 | elementwise_affine: Optional[bool] = True, 26 | norm_eps: float = 1e-6, 27 | use_cache: bool = True, 28 | pad_token_id: int = None, 29 | bos_token_id: int = 1, 30 | eos_token_id: int = 2, 31 | tie_word_embeddings: bool = False, 32 | attention_bias: bool = False, 33 | fuse_norm: bool = True, 34 | fuse_cross_entropy: bool = True, 35 | **kwargs, 36 | ): 37 | self.vocab_size = vocab_size 38 | self.max_position_embeddings = max_position_embeddings 39 | self.hidden_size = hidden_size 40 | self.hidden_ratio = hidden_ratio 41 | self.intermediate_size = intermediate_size 42 | self.num_hidden_layers = num_hidden_layers 43 | self.num_heads = num_heads 44 | self.num_kv_heads = num_kv_heads 45 | 46 | self.hidden_act = hidden_act 47 | self.initializer_range = initializer_range 48 | self.elementwise_affine = elementwise_affine 49 | self.norm_eps = norm_eps 50 | self.use_cache = use_cache 51 | self.attention_bias = attention_bias 52 | self.fuse_cross_entropy = fuse_cross_entropy 53 | self.fuse_norm = fuse_norm 54 | 55 | super().__init__( 56 | pad_token_id=pad_token_id, 57 | bos_token_id=bos_token_id, 58 | eos_token_id=eos_token_id, 59 | tie_word_embeddings=tie_word_embeddings, 60 | **kwargs, 61 | ) 62 | -------------------------------------------------------------------------------- /fla/models/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any, Dict, List, Optional, Tuple 6 | 7 | import torch 8 | from transformers.cache_utils import Cache 9 | 10 | 11 | class RecurrentCache(Cache): 12 | """ 13 | A cache used for storing hidden states produced by flash linear attention models. 14 | 15 | It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`. 
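    A minimal usage sketch (illustrative only, not part of the original docstring):

        cache = RecurrentCache()
        state = torch.zeros(2, 64, 64)    # [batch_size, key_dim, value_dim]
        cache.update(state, layer_idx=0, offset=16)
        assert cache.get_seq_length() == 16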
16 | """ 17 | 18 | def __init__( 19 | self, 20 | seen_tokens: int = 0 21 | ) -> RecurrentCache: 22 | 23 | self.states: List[torch.Tensor] = [] 24 | self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen 25 | 26 | def __getitem__(self, layer_idx: int) -> torch.Tensor: 27 | if layer_idx < len(self): 28 | return self.states[layer_idx] 29 | else: 30 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") 31 | 32 | def __iter__(self): 33 | for state in self.states: 34 | yield state 35 | 36 | def __len__(self): 37 | return len(self.states) 38 | 39 | def update( 40 | self, 41 | state: Tuple[torch.Tensor], 42 | layer_idx: int, 43 | offset: Optional[int] = 1, 44 | cache_kwargs: Optional[Dict[str, Any]] = None, 45 | ) -> Tuple[torch.Tensor]: 46 | """ 47 | Updates the cache with the new `state` for the layer `layer_idx`. 48 | 49 | Parameters: 50 | state (`Tuple[torch.Tensor]`): 51 | The new state to cache. 52 | layer_idx (`int`): 53 | The index of the layer to cache the states for. 54 | offset (`int`): 55 | The offset of current fed tokens. 56 | cache_kwargs (`Dict[str, Any]`, `optional`): 57 | Additional arguments for the cache subclass. 58 | 59 | Return: 60 | The updated state. 61 | """ 62 | 63 | if isinstance(state, torch.Tensor): 64 | state = (state,) 65 | if len(self.states) <= layer_idx: 66 | self.states.append(state) 67 | else: 68 | for i, s in enumerate(state): 69 | self.states[layer_idx][i].copy_(s) 70 | # update the number of seen tokens once we achieve the last layer 71 | if layer_idx == len(self) - 1: 72 | self._seen_tokens += offset 73 | 74 | return state 75 | 76 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 77 | """Returns the sequence length of the cached states. A layer index can be optionally passed.""" 78 | if len(self.states) <= layer_idx: 79 | return 0 80 | return self._seen_tokens 81 | 82 | def get_max_length(self) -> Optional[int]: 83 | """Returns the maximum sequence length of the cached states. 
RecurrentCache does not have a maximum length.""" 84 | return None 85 | 86 | def reorder_cache(self, beam_idx: torch.LongTensor): 87 | """Reorders the cache for beam search, given the selected beam indices.""" 88 | for layer_idx in range(len(self.states)): 89 | device = self.states[layer_idx].device 90 | self.states[layer_idx] = self.states[layer_idx].index_select(0, beam_idx.to(device)) 91 | 92 | def to_legacy_cache(self) -> Tuple[torch.Tensor]: 93 | return tuple(self.states) 94 | 95 | @classmethod 96 | def from_legacy_cache( 97 | cls, 98 | past_key_values: Optional[Tuple[torch.Tensor]] = None, 99 | seen_tokens: int = 0 100 | ) -> RecurrentCache: 101 | """Converts a cache in the legacy cache format into an equivalent `RecurrentCache`.""" 102 | 103 | cache = cls(seen_tokens) 104 | if past_key_values is not None: 105 | for layer_idx in range(len(past_key_values)): 106 | cache.update(past_key_values[layer_idx], layer_idx) 107 | return cache 108 | -------------------------------------------------------------------------------- /fla/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution, 4 | ShortConvolution) 5 | from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss 6 | from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate, 7 | FusedLayerNormSwishGateLinear, 8 | FusedRMSNormSwishGate, 9 | FusedRMSNormSwishGateLinear) 10 | from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm, 11 | RMSNormLinear) 12 | from fla.modules.rotary import RotaryEmbedding 13 | 14 | __all__ = [ 15 | 'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution', 16 | 'FusedCrossEntropyLoss', 17 | 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear', 18 | 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear', 19 | 'RotaryEmbedding' 20 | ] 21 | -------------------------------------------------------------------------------- /fla/modules/feature_map.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | import math 6 | from typing import Optional 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | 12 | from fla.modules.layernorm import layer_norm_fn 13 | from fla.utils import checkpoint 14 | 15 | 16 | @checkpoint 17 | def flatten_diag_outer_product(x, y): 18 | z = torch.einsum("...i,...j->...ij", x, y) 19 | N = z.size(-1) 20 | indicies = torch.triu_indices(N, N) 21 | return z[..., indicies[0], indicies[1]] 22 | 23 | 24 | @checkpoint 25 | def flatten_diag_outer_product_off1(x, y): 26 | z = torch.einsum("...i,...j->...ij", x, y) 27 | N = z.size(-1) 28 | indicies = torch.triu_indices(N, N, 1) 29 | indices2 = torch.arange(0, N) 30 | return z[..., indicies[0], indicies[1]], z[..., indices2, indices2] 31 | 32 | 33 | def is_power_of_2(n): 34 | return (n & (n - 1) == 0) and n != 0 35 | 36 | 37 | class HedgehogFeatureMap(nn.Module): 38 | 39 | r""" 40 | Hedgehog feature map as introduced in 41 | `The Hedgehog & the Porcupine: Expressive Linear Attentions with Softmax Mimicry `_ 42 | """ 43 | 44 | def __init__( 45 | self, 46 | head_dim: int 47 | ) -> HedgehogFeatureMap: 48 | super().__init__() 49 | # Trainable map 50 | self.layer = nn.Linear(head_dim, head_dim) 51 | self.init_weights_() 52 | 53 | def init_weights_(self): 54 | 
"""Initialize trainable map as identity""" 55 | with torch.no_grad(): 56 | identity = torch.eye(*self.layer.weight.shape[-2:], dtype=torch.float) 57 | self.layer.weight.copy_(identity.to(self.layer.weight)) 58 | nn.init.zeros_(self.layer.bias) 59 | 60 | def forward(self, x: torch.Tensor): 61 | x = self.layer(x) # shape b, h, l, d 62 | return torch.cat([2*x, -2*x], dim=-1).softmax(-1) 63 | 64 | 65 | class T2RFeatureMap(nn.Module): 66 | 67 | r""" 68 | Simple linear mapping feature map as in 69 | `Finetuning Pretrained Transformers into RNNs `_ 70 | """ 71 | 72 | def __init__( 73 | self, 74 | head_dim: int, 75 | dot_dim: int = None 76 | ) -> T2RFeatureMap: 77 | super().__init__() 78 | # Trainable map 79 | if dot_dim is None: 80 | dot_dim = head_dim 81 | self.layer = nn.Linear(head_dim, dot_dim) 82 | 83 | def forward(self, x: torch.Tensor): 84 | return self.layer(x).relu() 85 | 86 | 87 | class DPFPFeatureMap(nn.Module): 88 | 89 | r""" 90 | Deterministic Parameter-Free Projection (DPFP) feature map in 91 | `Linear Transformers Are Secretly Fast Weight Programmers `_ 92 | """ 93 | 94 | def __init__( 95 | self, 96 | head_dim: int, 97 | nu: int = 4 98 | ) -> DPFPFeatureMap: 99 | super().__init__() 100 | self.nu = nu 101 | 102 | def forward(self, x: torch.Tensor): 103 | x = torch.cat([x.relu(), -x.relu()], dim=-1) 104 | x_rolled = torch.cat([x.roll(shifts=j, dims=-1) for j in range(1, self.nu+1)], dim=-1) 105 | x_repeat = torch.cat([x] * self.nu, dim=-1) 106 | return x_repeat * x_rolled 107 | 108 | 109 | class HadamardFeatureMap(nn.Module): 110 | def __init__( 111 | self, 112 | head_dim: int 113 | ) -> HadamardFeatureMap: 114 | super().__init__() 115 | # Trainable map 116 | self.layer1 = nn.Linear(head_dim, head_dim) 117 | self.layer2 = nn.Linear(head_dim, head_dim) 118 | 119 | def forward(self, x: torch.Tensor): 120 | return self.layer1(x) * self.layer2(x) 121 | 122 | 123 | class LearnableOuterProductFeatureMap(nn.Module): 124 | def __init__( 125 | self, 126 | head_dim: int, 127 | feature_dim: int 128 | ) -> LearnableOuterProductFeatureMap: 129 | super().__init__() 130 | # Trainable map 131 | self.layer1 = nn.Linear(head_dim, feature_dim, bias=False) 132 | self.layer2 = nn.Linear(head_dim, feature_dim, bias=False) 133 | self.normalizer = feature_dim ** -0.5 134 | 135 | def forward(self, x: torch.Tensor): 136 | return flatten_diag_outer_product(self.layer1(x), self.layer2(x)) 137 | 138 | 139 | class LearnablePolySketchNonNegativeFeatureMap(nn.Module): 140 | 141 | def __init__( 142 | self, 143 | head_dim: int, 144 | sketch_size: Optional[int] = None, 145 | degree: Optional[int] = 2 146 | ) -> LearnablePolySketchNonNegativeFeatureMap: 147 | super().__init__() 148 | 149 | assert is_power_of_2(degree) and degree >= 2, f"The degree {degree} must be a power of 2" 150 | 151 | self.head_dim = head_dim 152 | self.sketch_size = sketch_size if sketch_size is not None else head_dim 153 | self.degree = degree 154 | 155 | self.gamma = nn.Parameter(torch.ones(head_dim)) 156 | self.beta = nn.Parameter(torch.zeros(head_dim)) 157 | # NOTE: the sketch layers defined here are quite different from the original paper 158 | # currently we simply use linear layers without any non-linear activations 159 | self.sketches1 = nn.ModuleList([ 160 | nn.Linear(head_dim, sketch_size, bias=False), 161 | *[nn.Linear(sketch_size, sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)] 162 | ]) 163 | self.sketches2 = nn.ModuleList([ 164 | nn.Linear(head_dim, sketch_size, bias=False), 165 | *[nn.Linear(sketch_size, 
sketch_size, bias=False) for _ in range(int(math.log2(self.degree)) - 2)] 166 | ]) 167 | 168 | def forward(self, x: torch.Tensor): 169 | # Section 2.1 170 | x = layer_norm_fn(x, self.gamma, self.beta) 171 | # first map the input to sketch size with learnable parameters 172 | x = self.sketches1[0](x) * self.sketches2[0](x) * self.head_dim ** -0.5 173 | for i in range(1, int(math.log2(self.degree)) - 1): 174 | x = self.sketches1[i](x) * self.sketches2[i](x) * self.head_dim ** -0.5 175 | # do sketch mapping for log2(p) - 1 times in total 176 | # do p=2 mapping to ensure non-negativity 177 | return flatten_diag_outer_product(x, x) 178 | 179 | 180 | class TaylorFeatureMap(nn.Module): 181 | def __init__( 182 | self, 183 | head_dim: int 184 | ) -> TaylorFeatureMap: 185 | super().__init__() 186 | self.head_dim = head_dim 187 | self.r2 = math.sqrt(2) 188 | self.rd = math.sqrt(self.head_dim) 189 | self.rrd = math.sqrt(self.rd) 190 | 191 | def forward(self, x: torch.Tensor): 192 | x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) 193 | return torch.cat([torch.ones_like(x[..., 0:1]), x / self.rrd, x2_2 / (self.rd * self.r2), x2_1 / self.rd], dim=-1) 194 | 195 | 196 | class RebasedFeatureMap(nn.Module): 197 | 198 | def __init__( 199 | self, 200 | head_dim: int, 201 | use_gamma: Optional[bool] = True, 202 | use_beta: Optional[bool] = True, 203 | normalize: Optional[bool] = True 204 | ) -> RebasedFeatureMap: 205 | super().__init__() 206 | 207 | self.head_dim = head_dim 208 | self.use_gamma = use_gamma 209 | self.use_beta = use_beta 210 | self.normalize = normalize 211 | 212 | self.gamma = None 213 | self.beta = None 214 | if use_gamma: 215 | self.gamma = nn.Parameter(torch.ones(head_dim)) 216 | if use_beta: 217 | self.beta = nn.Parameter(torch.zeros(head_dim)) 218 | 219 | def forward(self, x: torch.Tensor, flatten: Optional[bool] = True): 220 | if self.use_beta and self.use_gamma and self.normalize: 221 | x = layer_norm_fn(x, self.gamma, self.beta) 222 | elif self.normalize: 223 | x = F.layer_norm(x, (self.head_dim,), self.gamma, self.beta) 224 | elif self.use_gamma and self.use_beta: 225 | x = torch.addcmul(self.beta, x, self.gamma) 226 | elif self.use_gamma: 227 | x = x.mul(self.gamma) 228 | else: 229 | raise RuntimeError(f"Not supported combination of `use_gamma`, `use_beta` and `normalize`, " 230 | f"which is currentlt set as (`{self.use_gamma}`, `{self.use_beta}`, `{self.normalize}`)") 231 | if not flatten: 232 | return x 233 | x2_1, x2_2 = flatten_diag_outer_product_off1(x, x) 234 | # rebased use learnable parameters to approximate any quadratic function 235 | return torch.cat([x2_2 * self.head_dim ** -0.5, x2_1 * (2 / self.head_dim) ** 0.5], dim=-1) 236 | -------------------------------------------------------------------------------- /fla/modules/l2norm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.cuda.amp import custom_fwd, custom_bwd 6 | import triton 7 | import triton.language as tl 8 | 9 | @triton.autotune( 10 | configs=[ 11 | triton.Config({}, num_warps=1), 12 | triton.Config({}, num_warps=2), 13 | triton.Config({}, num_warps=4), 14 | triton.Config({}, num_warps=8), 15 | triton.Config({}, num_warps=16), 16 | triton.Config({}, num_warps=32), 17 | ], 18 | key=["N"], 19 | ) 20 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 21 | # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) 22 | 
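# Forward kernel: each Triton program normalizes one row of the flattened 2D input,
#   y[row, :] = x[row, :] / sqrt(sum(x[row, :] ** 2) + eps)
# The whole row must fit into a single block of BLOCK_N elements (enforced by the
# caller through MAX_FUSED_SIZE), so no cross-block reduction is needed.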
@triton.jit 23 | def _l2_norm_fwd_1pass_kernel( 24 | X, # pointer to the input 25 | Y, # pointer to the output 26 | stride_x_row, # how much to increase the pointer when moving by 1 row 27 | N, # number of columns in X 28 | eps, # epsilon to avoid division by zero 29 | BLOCK_N: tl.constexpr, 30 | ): 31 | # Map the program id to the row of X and Y it should compute. 32 | row = tl.program_id(0) 33 | X += row * stride_x_row 34 | Y += row * stride_x_row 35 | # Compute mean and variance 36 | cols = tl.arange(0, BLOCK_N) 37 | x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) 38 | xbar = tl.where(cols < N, x, 0.0) 39 | var = tl.sum(xbar * xbar, axis=0) 40 | rstd = 1 / tl.sqrt(var + eps) 41 | # tl.store(Rstd + row, rstd) 42 | # Normalize and apply linear transformation 43 | mask = cols < N 44 | y = x * rstd 45 | # Write output 46 | tl.store(Y + cols, y, mask=mask) 47 | 48 | 49 | @triton.autotune( 50 | configs=[ 51 | triton.Config({}, num_warps=1), 52 | triton.Config({}, num_warps=2), 53 | triton.Config({}, num_warps=4), 54 | triton.Config({}, num_warps=8), 55 | triton.Config({}, num_warps=16), 56 | triton.Config({}, num_warps=32), 57 | ], 58 | key=["N"], 59 | ) 60 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 61 | # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) 62 | # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) 63 | # @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) 64 | @triton.jit 65 | def _l2_norm_bwd_kernel( 66 | X, # pointer to the input 67 | # Y, # pointer to the output to be recomputed 68 | DY, # pointer to the output gradient 69 | DX, # pointer to the input gradient 70 | stride_x_row, # how much to increase the pointer when moving by 1 row 71 | N, # number of columns in X 72 | eps, # epsilon to avoid division by zero 73 | BLOCK_N: tl.constexpr, 74 | ): 75 | # Map the program id to the elements of X, DX, and DY it should compute. 76 | # Map the program id to the row of X and Y it should compute. 
77 | row = tl.program_id(0) 78 | X += row * stride_x_row 79 | DX += row * stride_x_row 80 | DY += row * stride_x_row 81 | 82 | # Y += row * stride_y_row 83 | cols = tl.arange(0, BLOCK_N) 84 | x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) 85 | x = tl.where(cols < N, x, 0.0) 86 | var = tl.sum(x * x) 87 | rstd = 1 / tl.sqrt(var + eps) 88 | # tl.store(Rstd + row, rstd) 89 | # Normalize and apply linear transformation 90 | mask = cols < N 91 | # y = x * rstd 92 | dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32) 93 | dy = tl.where(cols < N, dy, 0.0) 94 | # dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x 95 | dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x 96 | tl.store(DX + cols, dx, mask=mask) 97 | 98 | def _l2_norm_fwd( 99 | x, eps=1e-6 100 | ): 101 | x_shape_og = x.shape 102 | x = x.reshape(-1, x.shape[-1]) 103 | if x.stride(-1) != 1: 104 | x = x.contiguous() 105 | M, N = x.shape 106 | assert x.stride(-1) == 1 107 | # allocate output 108 | y = torch.empty_like(x) 109 | assert y.stride(-1) == 1 110 | N = x.shape[-1] 111 | M = x.shape[0] 112 | # rstd = torch.empty((M,), dtype=torch.float32, device="cuda") 113 | # Less than 64KB per feature: enqueue fused kernel 114 | MAX_FUSED_SIZE = 65536 // x.element_size() 115 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 116 | if N > BLOCK_N: 117 | raise RuntimeError( 118 | "This layer norm doesn't support feature dim >= 64KB.") 119 | # heuristics for number of warps 120 | with torch.cuda.device(x.device.index): 121 | _l2_norm_fwd_1pass_kernel[(M,)]( 122 | x, 123 | y, 124 | x.stride(0), 125 | N, 126 | eps, 127 | # is_rms_norm, 128 | BLOCK_N, 129 | # residual is not None, 130 | # residual_out is not None, 131 | # bias is not None, 132 | ) 133 | return y.reshape(x_shape_og) 134 | 135 | def _l2_norm_bwd( 136 | x, dy, eps=1e-5, 137 | ): 138 | x_shape_og = x.shape 139 | x = x.reshape(-1, dy.shape[-1]) 140 | dy = dy.reshape(-1, dy.shape[-1]) 141 | if dy.stride(-1) != 1: 142 | dy = dy.contiguous() 143 | assert dy.shape == x.shape 144 | # allocate output 145 | dx = torch.empty_like(x) 146 | N = x.shape[-1] 147 | M = x.shape[0] 148 | assert x.stride(-1) == 1 149 | assert dy.stride(-1) == 1 150 | # rstd = torch.empty((M,), dtype=torch.float32, device="cuda") 151 | # Less than 64KB per feature: enqueue fused kernel 152 | MAX_FUSED_SIZE = 65536 // x.element_size() 153 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 154 | if N > BLOCK_N: 155 | raise RuntimeError( 156 | "This layer norm doesn't support feature dim >= 64KB.") 157 | # heuristics for number of warps 158 | with torch.cuda.device(x.device.index): 159 | _l2_norm_bwd_kernel[(M,)]( 160 | x, 161 | dy, 162 | dx, 163 | x.stride(0), 164 | N, 165 | eps, 166 | BLOCK_N, 167 | ) 168 | return dx.reshape(x_shape_og) 169 | 170 | 171 | class L2NormFN(torch.autograd.Function): 172 | @staticmethod 173 | def forward( 174 | ctx, 175 | x, 176 | eps=1e-6, 177 | ): 178 | # reshape input data into 2D tensor 179 | y = _l2_norm_fwd(x, eps) 180 | ctx.x_shape_og = x_shape_og 181 | ctx.eps = eps 182 | ctx.x_dtype = x.dtype 183 | ctx.save_for_backward(x) 184 | return y 185 | 186 | @staticmethod 187 | def backward(ctx, dy, *args): 188 | x, = ctx.saved_tensors 189 | dx = _l2_norm_bwd( 190 | x, 191 | dy, 192 | ctx.eps, 193 | ) 194 | return ( 195 | dx, 196 | None 197 | ) 198 | 199 | l2_norm_fn = L2NormFN.apply 200 | 201 | if __name__ == '__main__': 202 | x = torch.rand(10, 10, 100).cuda().requires_grad_(True) 203 | y = torch.nn.functional.normalize(x, dim=-1, 
p=2) 204 | dy = torch.rand_like(y) 205 | y.backward(dy, retain_graph=True) 206 | x_grad, x.grad = x.grad, None 207 | y2 = l2_norm_fn(x, 1e-6) 208 | print((y-y2).abs().max()) 209 | y2.backward(dy, retain_graph=True) 210 | x_grad2, x.grad = x.grad, None 211 | print((x_grad2-x_grad).abs().max()) 212 | breakpoint() 213 | 214 | 215 | 216 | 217 | -------------------------------------------------------------------------------- /fla/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .based import fused_chunk_based, parallel_based 4 | from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla 5 | from .retention import (chunk_retention, fused_chunk_retention, 6 | fused_recurrent_retention, parallel_retention) 7 | 8 | __all__ = [ 9 | 'fused_chunk_based', 10 | 'parallel_based', 11 | 'chunk_gla', 12 | 'fused_chunk_gla', 13 | 'fused_recurrent_gla', 14 | 'chunk_retention', 15 | 'fused_chunk_retention', 16 | 'fused_recurrent_retention', 17 | 'parallel_retention' 18 | ] 19 | -------------------------------------------------------------------------------- /fla/ops/abc/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_abc 4 | from .chunk_gate import chunk_gated_abc 5 | from .recurrent_fuse import fused_recurrent_gated_abc 6 | 7 | __all__ = [ 8 | 'chunk_abc', 9 | 'chunk_gated_abc', 10 | 'fused_recurrent_gated_abc' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/abc/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | def naive_recurrent_abc( 9 | q: torch.Tensor, 10 | k: torch.Tensor, 11 | v: torch.Tensor, 12 | s: torch.Tensor, 13 | g: Optional[torch.Tensor] = None, 14 | scale: Optional[int] = None, 15 | initial_state: Optional[torch.Tensor] = None, 16 | output_final_state: Optional[bool] = False 17 | ) -> torch.Tensor: 18 | dtype = q.dtype 19 | 20 | # [batch_size, n_heads, seq_len, n_slots] 21 | if g is None: 22 | z = s.float().logcumsumexp(2) 23 | g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z 24 | s = torch.exp(s - z) 25 | q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g)) 26 | B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] 27 | 28 | hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device) 29 | ok = torch.zeros_like(s) 30 | 31 | if scale is None: 32 | scale = q.shape[-1] ** -0.5 33 | 34 | final_state = None 35 | if initial_state is not None: 36 | hk += initial_state[0] 37 | 38 | for i in range(T): 39 | q_i = q[:, :, i] * scale 40 | k_i = k[:, :, i] 41 | v_i = s[:, :, i] 42 | g_i = g[:, :, i].exp() 43 | hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :] 44 | ok[:, :, i] = (q_i[..., None] * hk).sum(-2) 45 | 46 | qv = ok.softmax(-1) 47 | hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device) 48 | ov = torch.zeros_like(v) 49 | if initial_state is not None: 50 | hv += initial_state[1] 51 | 52 | for i in range(T): 53 | q_i = qv[:, :, i] 54 | k_i = s[:, :, i] 55 | v_i = v[:, :, i] 56 | g_i = g[:, :, i].exp() 57 | hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :] 58 | ov[:, :, i] = (q_i[..., None] * hv).sum(-2) 59 | 60 | if output_final_state: 61 | final_state = (hk, hv) 62 | return ov.to(dtype), final_state 63 | 64 | 65 | def naive_cumsum_abc( 66 | q: torch.Tensor, 67 | 
k: torch.Tensor, 68 | v: torch.Tensor, 69 | s: torch.Tensor 70 | ) -> torch.Tensor: 71 | """ 72 | A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper. 73 | This is just for demonstration purposes, with no numerical stabilities guaranteed. 74 | """ 75 | 76 | dtype = q.dtype 77 | q, k, v, s = map(lambda x: x.float(), (q, k, v, s)) 78 | 79 | scale = q.shape[-1] ** -0.5 80 | # [batch_size, n_heads, seq_len, n_slots] 81 | s = (s - s.max(2, True)[0]).exp() 82 | z = s.cumsum(2) 83 | # [batch_size, n_heads, seq_len, n_slots, d_head] 84 | K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) 85 | V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) 86 | # [batch_size, n_heads, seq_len, n_slots] 87 | p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1) 88 | # [batch_size, n_heads, seq_len, d_head] 89 | o = torch.einsum('...m,...md->...d', p, V) 90 | return o.to(dtype), None 91 | -------------------------------------------------------------------------------- /fla/ops/based/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk_fuse import fused_chunk_based 4 | from .parallel import parallel_based 5 | 6 | __all__ = [ 7 | 'fused_chunk_based', 8 | 'parallel_based' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/based/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | from fla.ops.based.chunk_fuse import fused_chunk_based 7 | from fla.ops.based.parallel import parallel_based 8 | 9 | 10 | def naive_parallel_based(q, k, v, use_scale=True, use_norm=True): 11 | if use_scale: 12 | q = q * (q.shape[-1] ** -0.5) 13 | attn = q @ k.transpose(-2, -1) 14 | attn = 1 + attn + 1/2 * (attn ** 2) 15 | attn.masked_fill_(~torch.tril(torch.ones( 16 | q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) 17 | o = attn @ v 18 | if use_norm: 19 | z = attn.sum(-1) 20 | return o / (z[..., None] + 1e-6) 21 | else: 22 | return o 23 | 24 | 25 | def naive_chunk_based(q, k, v, chunk_size=256): 26 | q = q * (q.shape[-1] ** -0.5) 27 | 28 | # compute normalizer. 
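    # The Based feature map is the 2nd-order Taylor expansion of exp(q·k):
    #   phi(q)·phi(k) = 1 + q·k + (q·k) ** 2 / 2
    # so the normalizer z_i = sum_{j<=i} (1 + q_i·k_j + (q_i·k_j) ** 2 / 2) is built below
    # from running sums of k (first-order term) and of the outer products k⊗k
    # (second-order term); the constant term simply counts the prefix length.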
29 | k_cumsum = torch.cumsum(k, dim=-2) 30 | kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3) 31 | # first 32 | z = (q * k_cumsum).sum(-1) 33 | # second order 34 | z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5 35 | # zero-th order 36 | z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :] 37 | 38 | # compute o 39 | # constant term 40 | _o = v.cumsum(-2) 41 | 42 | q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) 43 | 44 | k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) 45 | v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) 46 | 47 | intra_chunk_attn = q @ k.transpose(-2, -1) 48 | intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2) 49 | intra_chunk_attn.masked_fill_( 50 | ~torch.tril( 51 | torch.ones(chunk_size, chunk_size, 52 | dtype=torch.bool, device=q.device), 53 | ), 0) 54 | o = intra_chunk_attn @ v 55 | 56 | # quadractic term 57 | kv = torch.einsum( 58 | 'b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v) 59 | kv = kv.cumsum(2) 60 | kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) 61 | 62 | o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q) 63 | 64 | # linear term 65 | kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v) 66 | kv = kv.cumsum(2) 67 | kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) 68 | o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q) 69 | 70 | o = rearrange(o, 'b h n c d -> b h (n c) d') 71 | o = o + _o 72 | return o / (z[..., None] + 1e-6) 73 | 74 | 75 | if __name__ == "__main__": 76 | B = 4 77 | H = 4 78 | L = 128 79 | # D = 15 80 | dtype = torch.float32 81 | q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) 82 | k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) 83 | v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True) 84 | 85 | do = torch.randn_like(v).cuda() 86 | ref = naive_parallel_based(q, k, v, True, True) 87 | ref.backward(do, retain_graph=True) 88 | ref_dq, q.grad = q.grad.clone(), None 89 | ref_dk, k.grad = k.grad.clone(), None 90 | ref_dv, v.grad = v.grad.clone(), None 91 | 92 | # tri = naive_chunk_based(q, k, v) 93 | # tri.backward(do, retain_graph=True) 94 | # tri_dq, q.grad = q.grad.clone(), None 95 | # tri_dk, k.grad = k.grad.clone(), None 96 | # tri_dv, v.grad = v.grad.clone(), None 97 | 98 | # assert ref.allclose(tri, 0, 1e-4), breakpoint() 99 | # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() 100 | # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() 101 | # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() 102 | 103 | tri = fused_chunk_based(q, k, v, True, True) 104 | tri.backward(do, retain_graph=True) 105 | tri_dq, q.grad = q.grad.clone(), None 106 | tri_dk, k.grad = k.grad.clone(), None 107 | tri_dv, v.grad = v.grad.clone(), None 108 | print((ref-tri).abs().max()) 109 | print((ref_dq-tri_dq).abs().max()) 110 | print((ref_dk-tri_dk).abs().max()) 111 | print((ref_dv-tri_dv).abs().max()) 112 | 113 | # assert ref.allclose(tri, 0, 1e-4), breakpoint() 114 | # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() 115 | # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() 116 | # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() 117 | 118 | tri = parallel_based(q, k, v, True, True) 119 | tri.backward(do, retain_graph=True) 120 | tri_dq, q.grad = q.grad.clone(), None 121 | tri_dk, k.grad = k.grad.clone(), None 122 | tri_dv, v.grad = v.grad.clone(), None 123 | 124 | 
print((ref-tri).abs().max()) 125 | print((ref_dq-tri_dq).abs().max()) 126 | print((ref_dk-tri_dk).abs().max()) 127 | print((ref_dv-tri_dv).abs().max()) 128 | 129 | # assert ref.allclose(tri, 0, 1e-4), breakpoint() 130 | # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() 131 | # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() 132 | # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() 133 | -------------------------------------------------------------------------------- /fla/ops/delta_rule/README.md: -------------------------------------------------------------------------------- 1 | - Delta Rule 2 | 3 | The implementation of delta rule described in https://arxiv.org/abs/2102.11174 4 | 5 | -------------------------------------------------------------------------------- /fla/ops/delta_rule/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk_fuse import fused_chunk_delta_rule 4 | from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule 5 | from .chunk import chunk_delta_rule 6 | 7 | __all__ = [ 8 | 'fused_chunk_delta_rule', 9 | 'fused_recurrent_linear_attn_delta_rule', 10 | 'chunk_delta_rule' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/delta_rule/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def delta_rule_recurrence(q, k, v, beta): 8 | b, h, l, d_k = q.shape 9 | d_v = v.shape[-1] 10 | o = torch.zeros_like(v) 11 | S = torch.zeros(b, h, d_k, d_v).to(v) 12 | q = q * (d_k ** -0.5) 13 | for i in range(l): 14 | _k = k[:, :, i] 15 | _q = q[:, :, i] 16 | _v = v[:, :, i].clone() 17 | beta_i = beta[:, :, i] 18 | _v = _v - (S.clone() * _k[..., None]).sum(-2) 19 | _v = _v * beta_i[..., None] 20 | S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) 21 | o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) 22 | return o 23 | 24 | 25 | def delta_rule_chunkwise(q, k, v, beta, chunk_size=32): 26 | b, h, l, d_k = q.shape 27 | d_v = v.shape[-1] 28 | q = q * (d_k ** -0.5) 29 | v = v * beta[..., None] 30 | k_beta = k * beta[..., None] 31 | 32 | assert l % chunk_size == 0 33 | 34 | # note that diagonal is masked. 
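    # With A = strictly_lower(K_beta @ K^T), the loop below plus the identity added right
    # after it yield T = (I + A)^{-1} via forward substitution (well defined because A is
    # strictly lower triangular), cf. the WY-style representation used by chunkwise
    # delta-rule algorithms. Then u = T @ (beta * V) and w = T @ (beta * K) are the
    # per-chunk pseudo-values and the coefficients used to read the running state S.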
35 | mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) 36 | q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta]) 37 | attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) 38 | 39 | for i in range(1, chunk_size): 40 | attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) 41 | 42 | attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) 43 | # u 44 | k_cumsum = attn @ v 45 | # w 46 | k_cumdecay = attn @ k_beta 47 | 48 | v = k_cumsum 49 | S = k.new_zeros(b, h, d_k, d_v) 50 | o = torch.zeros_like(v) 51 | mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) 52 | for i in range(0, l // chunk_size): 53 | q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i] 54 | attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) 55 | v_prime = k_cumdecay[:, :, i] @ S 56 | v_new = v_i - v_prime 57 | o_inter = q_i @ S 58 | o[:, :, i] = o_inter + attn @ v_new 59 | # chunk state update 60 | S = S + k_i.transpose(-1, -2) @ v_new 61 | 62 | return rearrange(o, 'b h n c d -> b h (n c) d') 63 | 64 | 65 | if __name__ == '__main__': 66 | B = 2 67 | H = 4 68 | L = 256 69 | DK = 128 70 | DV = 128 71 | q = (torch.randn(B, H, L, DK)).cuda().requires_grad_(True) 72 | k = (torch.randn(B, H, L, DK)).cuda() 73 | k = torch.nn.functional.normalize(k, dim=-1, p=2).requires_grad_(True) 74 | v = (torch.randn(B, H, L, DV)).cuda().requires_grad_(True) 75 | beta = torch.randn(B, H, L).cuda().sigmoid().requires_grad_(True) 76 | 77 | o = delta_rule_recurrence(q, k, v, beta) 78 | do = torch.randn(B, H, L, DV).cuda() 79 | o.backward(do, retain_graph=True) 80 | q_grad, q.grad = q.grad, None 81 | k_grad, k.grad = k.grad, None 82 | v_grad, v.grad = v.grad, None 83 | beta_grad, beta.grad = beta.grad, None 84 | 85 | o2 = delta_rule_chunkwise(q, k, v, beta) 86 | o2.backward(do) 87 | assert torch.allclose(o, o2, atol=1e-4), breakpoint() 88 | assert torch.allclose(q.grad, q_grad, atol=1e-4), breakpoint() 89 | assert torch.allclose(k.grad, k_grad, atol=1e-4), breakpoint() 90 | assert torch.allclose(v.grad, v_grad, atol=1e-4), breakpoint() 91 | assert torch.allclose(beta.grad, beta_grad, atol=1e-4), breakpoint() 92 | print("All passed!") 93 | -------------------------------------------------------------------------------- /fla/ops/gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_gla 4 | from .chunk_fuse import fused_chunk_gla 5 | from .recurrent_fuse import fused_recurrent_gla 6 | 7 | __all__ = [ 8 | 'chunk_gla', 9 | 'fused_chunk_gla', 10 | 'fused_recurrent_gla' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/gla/chunk_util.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | inv_ln2 = 1.44269504 5 | 6 | 7 | 8 | @triton.jit 9 | def fwd_decay_cumsum( 10 | g, 11 | g_o, 12 | s_qk_h, 13 | s_qk_t, 14 | s_qk_d, 15 | B, 16 | H, 17 | T, 18 | scale, 19 | BT: tl.constexpr, 20 | BK: tl.constexpr, 21 | DK: tl.constexpr 22 | ): 23 | i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) 24 | p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 25 | p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 26 | cum_decay = tl.zeros([BK], 
dtype=tl.float32) 27 | mask = (i_k * BK + tl.arange(0, BK)) < DK 28 | 29 | for i in range(BT): 30 | _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 31 | cum_decay += _g * inv_ln2 32 | tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask) 33 | p_g += DK 34 | p_go += DK 35 | 36 | @triton.jit 37 | def prepare_qg_kg( 38 | q, 39 | k, 40 | g, 41 | qg, 42 | kg, 43 | s_qk_h, 44 | s_qk_t, 45 | s_qk_d, 46 | B, 47 | H, 48 | T, 49 | scale, 50 | BT: tl.constexpr, 51 | BK: tl.constexpr, 52 | DK: tl.constexpr 53 | ): 54 | 55 | i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) 56 | p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 57 | p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 58 | p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 59 | p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 60 | p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 61 | 62 | mask = (i_k * BK + tl.arange(0, BK)) < DK 63 | 64 | last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK)) 65 | 66 | for i in range(BT): 67 | _q = tl.load(p_q, mask=mask, other=0) 68 | _k = tl.load(p_k, mask=mask, other=0) 69 | _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 70 | _q *= tl.math.exp2(_g) * scale 71 | _k *= tl.math.exp2(last_decay - _g) 72 | tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask) 73 | tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask) 74 | p_q += DK 75 | p_g += DK 76 | p_k += DK 77 | p_kg += DK 78 | p_qg += DK 79 | 80 | 81 | @triton.jit 82 | def bwd_decay_global_cumsum( 83 | dq_inner, 84 | dq_inter, 85 | dk_inner, 86 | dk_inter, 87 | q, k, g, dg, 88 | s_qk_h, 89 | s_qk_t, 90 | s_qk_d, 91 | B, 92 | H, 93 | T, 94 | scale, 95 | BT: tl.constexpr, 96 | BK: tl.constexpr, 97 | DK: tl.constexpr 98 | ): 99 | i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) 100 | p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 101 | p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 102 | p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 103 | p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 104 | p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 105 | p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 106 | p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 107 | p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 108 | cum_grad_dg = tl.zeros([BK], dtype=tl.float32) 109 | mask = (i_k * BK + tl.arange(0, BK)) < DK 110 | last_g = tl.zeros([BK], dtype=tl.float32) 111 | for j in range(BT-1, -1, -1): 112 | _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 113 | if j == (BT-1): 114 | last_g = _g 115 | _dq1 = tl.load(p_dq_inner, mask=mask, other=0) 116 | _dq2 = tl.load(p_dq_inter, mask=mask, other=0) 117 | _dq2 *= tl.math.exp2(_g) 118 | _dq = _dq1 + _dq2 119 | tl.store(p_dq_inter, _dq, mask=mask) 120 | _dk1 = tl.load(p_dk_inner, mask=mask, other=0) 121 | _dk2 = tl.load(p_dk_inter, mask=mask, other=0) 122 | _dk2 *= tl.math.exp2(last_g - _g) 123 | _dk = _dk1 + _dk2 124 | tl.store(p_dk_inter, _dk, mask=mask) 125 | _q = tl.load(p_q, mask=mask, other=0) 126 | _k = tl.load(p_k, mask=mask, other=0) 127 | _dg = _dq * _q - _dk * _k 128 | cum_grad_dg += _dg 
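        # dg at position j is the reverse (suffix) cumulative sum of the per-token
        # contributions dq * q - dk * k within the chunk: the loop walks backwards,
        # accumulates into cum_grad_dg, stores it, then steps every pointer back one row.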
129 | tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask) 130 | p_g -= DK 131 | p_k -= DK 132 | p_q -= DK 133 | p_dq_inner -= DK 134 | p_dk_inner -= DK 135 | p_dq_inter -= DK 136 | p_dk_inter -= DK 137 | p_dg -= DK 138 | 139 | -------------------------------------------------------------------------------- /fla/ops/gla/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from fla.ops.gla.recurrent_fuse import fused_recurrent_gla 7 | 8 | 9 | def ceildiv(a, b): 10 | return -(a // -b) 11 | 12 | 13 | def naive_recurrent_gla( 14 | q, 15 | k, 16 | v, 17 | gk, 18 | initial_state=None, 19 | output_final_state=False, 20 | causal=True 21 | ): 22 | orig_dtype = q.dtype 23 | q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk)) 24 | batch_size, n_heads, seq_len, d_head_k = q.shape 25 | _, _, _, d_head_v = v.shape 26 | h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) 27 | o = torch.zeros_like(v) 28 | scale = d_head_k ** -0.5 29 | 30 | if initial_state is not None: 31 | h += initial_state 32 | 33 | for i in range(seq_len): 34 | q_i = q[:, :, i, :] * scale 35 | k_i = k[:, :, i] 36 | v_i = v[:, :, i, :] 37 | gk_i = gk[:, :, i].exp() 38 | kv_i = k_i[..., None] * v_i[..., None, :] 39 | h = h * gk_i[..., None] + kv_i 40 | o_i = (q_i[..., None] * h).sum(-2) 41 | o[:, :, i] = o_i 42 | 43 | if causal: 44 | return o.to(orig_dtype), h 45 | else: 46 | o_reverse = torch.zeros_like(v) 47 | h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) 48 | for i in range(seq_len-1, -1, -1): 49 | q_i = q[:, :, i, :] * scale 50 | k_i = k[:, :, i] 51 | v_i = v[:, :, i, :] 52 | gk_i = gk[:, :, i].exp() 53 | kv_i = k_i[..., None] * v_i[..., None, :] 54 | h = h * gk_i[..., None] + kv_i 55 | o_i = (q_i[..., None] * h).sum(-2) 56 | o_reverse[:, :, i] = o_i 57 | 58 | return o, o_reverse 59 | 60 | 61 | if __name__ == "__main__": 62 | B = 4 63 | H = 4 64 | L = 512 65 | D = 128 66 | dtype = torch.float32 67 | q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True) 68 | k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(True) 69 | v = torch.randn(B, H, L, D).cuda().to(dtype).requires_grad_(True) 70 | g = F.logsigmoid(torch.rand(B, H, L, D)).cuda( 71 | ).clamp_min(-1).to(torch.float32).requires_grad_(True) 72 | 73 | do = torch.rand_like(v).cuda() 74 | do2 = torch.rand_like(v).cuda() 75 | intial_state = torch.rand(B, H, D, D).cuda() 76 | 77 | ref, ref_rev = naive_recurrent_gla(q, k, v, g, causal=False) 78 | 79 | ref.backward(do, retain_graph=True) 80 | ref_rev.backward(do2, retain_graph=True) 81 | 82 | ref_dq, q.grad = q.grad.clone(), None 83 | ref_dk, k.grad = k.grad.clone(), None 84 | ref_dv, v.grad = v.grad.clone(), None 85 | ref_dg, g.grad = g.grad.clone(), None 86 | 87 | tri, tri_rev = fused_recurrent_gla( 88 | q, k, v, g, initial_state=None, scale=D**-0.5, output_final_state=False, causal=False) 89 | tri.backward(do, retain_graph=True) 90 | tri_rev.backward(do2, retain_graph=True) 91 | tri_dq, q.grad = q.grad.clone(), None 92 | tri_dk, k.grad = k.grad.clone(), None 93 | tri_dv, v.grad = v.grad.clone(), None 94 | tri_dg, g.grad = g.grad.clone(), None 95 | 96 | assert ref.allclose(tri, 0, 1e-5), breakpoint() 97 | assert ref_rev.allclose(tri_rev, 0, 1e-5), breakpoint() 98 | assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint() 99 | assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint() 
100 | assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint() 101 | assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint() 102 | 103 | # tri = fused_chunk_gla(q, k, v, g) 104 | # tri.backward(do, retain_graph=True) 105 | # tri_dq, q.grad = q.grad.clone(), None 106 | # tri_dk, k.grad = k.grad.clone(), None 107 | # tri_dv, v.grad = v.grad.clone(), None 108 | # tri_dg, g.grad = g.grad.clone(), None 109 | 110 | # assert ref.allclose(tri, 0, 1e-5), breakpoint() 111 | # assert ref_dq.allclose(tri_dq, 0, 1e-5), breakpoint() 112 | # assert ref_dk.allclose(tri_dk, 0, 1e-5), breakpoint() 113 | # assert ref_dv.allclose(tri_dv, 0, 1e-5), breakpoint() 114 | # assert ref_dg.allclose(tri_dg, 0, 1e-4), breakpoint() 115 | # breakpoint() 116 | print("Pass") 117 | -------------------------------------------------------------------------------- /fla/ops/hgrn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_hgrn 4 | from .recurrent_fuse import fused_recurrent_hgrn 5 | 6 | __all__ = [ 7 | 'chunk_hgrn', 8 | 'fused_recurrent_hgrn' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/hgrn/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | def naive_recurrent_hgrn( 9 | x: torch.Tensor, 10 | g: torch.Tensor, 11 | initial_state: Optional[torch.Tensor] = None, 12 | output_final_state: Optional[bool] = False 13 | ) -> torch.Tensor: 14 | dtype = x.dtype 15 | x, g = map(lambda i: i.float(), (x, g)) 16 | B, H, T, D = x.shape 17 | 18 | h = torch.zeros(B, H, D, dtype=torch.float, device=x.device) 19 | o = torch.zeros_like(x) 20 | 21 | final_state = None 22 | if initial_state is not None: 23 | h += initial_state.detach() 24 | 25 | for i in range(T): 26 | h = g[:, :, i].exp() * h + x[:, :, i] 27 | o[:, :, i] = h 28 | 29 | if output_final_state: 30 | final_state = h 31 | return o.to(dtype), final_state 32 | -------------------------------------------------------------------------------- /fla/ops/hgrn/recurrent_fuse.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2023, Songlin Yang 4 | 5 | from typing import Tuple 6 | 7 | import torch 8 | import triton 9 | import triton.language as tl 10 | 11 | from fla.utils import contiguous 12 | 13 | 14 | @triton.autotune( 15 | configs=[ 16 | triton.Config({'BD': 32}, num_warps=1), 17 | triton.Config({'BD': 32}, num_warps=2), 18 | triton.Config({'BD': 32}, num_warps=4), 19 | triton.Config({'BD': 32}, num_warps=8), 20 | triton.Config({'BD': 64}, num_warps=1), 21 | triton.Config({'BD': 64}, num_warps=2), 22 | triton.Config({'BD': 64}, num_warps=4), 23 | triton.Config({'BD': 64}, num_warps=8), 24 | triton.Config({'BD': 128}, num_warps=1), 25 | triton.Config({'BD': 128}, num_warps=2), 26 | triton.Config({'BD': 128}, num_warps=4), 27 | triton.Config({'BD': 128}, num_warps=8), 28 | ], 29 | key=['D'] 30 | ) 31 | @triton.jit 32 | def fused_recurrent_hgrn_fwd_kernel( 33 | x, 34 | g, 35 | o, 36 | h0, 37 | ht, 38 | T: tl.constexpr, 39 | D: tl.constexpr, 40 | BD: tl.constexpr, 41 | USE_INITIAL_STATE: tl.constexpr, 42 | STORE_FINAL_STATE: tl.constexpr 43 | ): 44 | i_d, i_bh = tl.program_id(0), tl.program_id(1) 45 | o_d = i_d * BD + tl.arange(0, BD) 46 | mask = o_d < D 47 | 48 | p_x = x + i_bh * T * D + o_d 49 | p_g = g + i_bh * T * D 
+ o_d 50 | p_o = o + i_bh * T * D + o_d 51 | 52 | b_h = tl.zeros([BD], dtype=tl.float32) 53 | if USE_INITIAL_STATE: 54 | p_h0 = h0 + i_bh * D + o_d 55 | b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32) 56 | for _ in range(0, T): 57 | b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32) 58 | b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 59 | b_h = tl.exp(b_g) * b_h + b_x 60 | tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask) 61 | 62 | p_x += D 63 | p_g += D 64 | p_o += D 65 | 66 | if STORE_FINAL_STATE: 67 | p_ht = ht + i_bh * D + o_d 68 | tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask) 69 | 70 | 71 | @triton.autotune( 72 | configs=[ 73 | triton.Config({'BD': 32}, num_warps=1), 74 | triton.Config({'BD': 32}, num_warps=2), 75 | triton.Config({'BD': 32}, num_warps=4), 76 | triton.Config({'BD': 32}, num_warps=8), 77 | triton.Config({'BD': 64}, num_warps=1), 78 | triton.Config({'BD': 64}, num_warps=2), 79 | triton.Config({'BD': 64}, num_warps=4), 80 | triton.Config({'BD': 64}, num_warps=8), 81 | triton.Config({'BD': 128}, num_warps=1), 82 | triton.Config({'BD': 128}, num_warps=2), 83 | triton.Config({'BD': 128}, num_warps=4), 84 | triton.Config({'BD': 128}, num_warps=8), 85 | ], 86 | key=['D'] 87 | ) 88 | @triton.jit 89 | def fused_recurrent_hgrn_bwd_kernel( 90 | g, 91 | o, 92 | dx, 93 | dg, 94 | do, 95 | h0, 96 | T: tl.constexpr, 97 | D: tl.constexpr, 98 | BD: tl.constexpr, 99 | USE_INITIAL_STATE: tl.constexpr 100 | ): 101 | i_d, i_bh = tl.program_id(0), tl.program_id(1) 102 | o_d = i_d * BD + tl.arange(0, BD) 103 | mask = o_d < D 104 | 105 | p_g = g + (i_bh * T + T - 1) * D + o_d 106 | p_o = o + (i_bh * T + T - 2) * D + o_d 107 | p_dx = dx + (i_bh * T + T - 1) * D + o_d 108 | p_dg = dg + (i_bh * T + T - 1) * D + o_d 109 | p_do = do + (i_bh * T + T - 1) * D + o_d 110 | 111 | b_dh = tl.zeros([BD], dtype=tl.float32) 112 | for i in range(T - 1, -1, -1): 113 | b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 114 | b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) 115 | if i > 0: 116 | b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32) 117 | elif USE_INITIAL_STATE: 118 | b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32) 119 | else: 120 | b_o = tl.zeros([BD], dtype=tl.float32) 121 | 122 | b_dh = b_dh + b_do 123 | b_dx = b_dh 124 | b_dh = b_dh * tl.exp(b_g) 125 | b_dg = b_dh * b_o 126 | tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) 127 | tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask) 128 | 129 | p_g -= D 130 | p_o -= D 131 | p_dx -= D 132 | p_dg -= D 133 | p_do -= D 134 | 135 | 136 | class FusedRecurrentHGRNFunction(torch.autograd.Function): 137 | 138 | @staticmethod 139 | @contiguous 140 | def forward(ctx, x, g, initial_state=None, output_final_state=False): 141 | B, H, T, D = x.shape 142 | 143 | final_state = None 144 | if output_final_state: 145 | final_state = x.new_empty(B, H, D) 146 | 147 | o = torch.empty_like(x) 148 | def grid(meta): return (triton.cdiv(D, meta['BD']), B * H) 149 | fused_recurrent_hgrn_fwd_kernel[grid]( 150 | x, g, o, initial_state, final_state, 151 | T, D, 152 | USE_INITIAL_STATE=initial_state is not None, 153 | STORE_FINAL_STATE=final_state is not None 154 | ) 155 | ctx.save_for_backward(g, o, initial_state) 156 | return o, final_state 157 | 158 | @staticmethod 159 | @contiguous 160 | def backward(ctx, do, dht=None): 161 | g, o, initial_state = ctx.saved_tensors 162 | B, H, T, D = do.shape 163 | 164 | dx = torch.empty_like(o) 165 | dg = torch.empty_like(g) 166 | def grid(meta): return 
(triton.cdiv(D, meta['BD']), B * H) 167 | fused_recurrent_hgrn_bwd_kernel[grid]( 168 | g, o, dx, dg, do, initial_state, 169 | T, D, 170 | USE_INITIAL_STATE=initial_state is not None, 171 | ) 172 | 173 | return dx, dg, None, None 174 | 175 | 176 | def fused_recurrent_hgrn( 177 | x: torch.Tensor, 178 | g: torch.Tensor, 179 | initial_state: torch.Tensor = None, 180 | output_final_state: bool = False 181 | ) -> Tuple[torch.Tensor, torch.Tensor]: 182 | if initial_state is not None: 183 | initial_state = initial_state.detach() 184 | o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state) 185 | return o, final_state 186 | -------------------------------------------------------------------------------- /fla/ops/linear_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_linear_attn 4 | from .chunk_fuse import fused_chunk_linear_attn 5 | from .recurrent_fuse import fused_recurrent_linear_attn 6 | 7 | __all__ = [ 8 | 'chunk_linear_attn', 9 | 'fused_chunk_linear_attn', 10 | 'fused_recurrent_linear_attn' 11 | ] 12 | 13 | -------------------------------------------------------------------------------- /fla/ops/linear_attn/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def torch_chunk_linear_attn(q, k, v, chunk_size=64): 8 | q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] **-0.5) 9 | k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 10 | v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 11 | kv = k.transpose(-1, -2) @ v 12 | kv = kv.cumsum(2) 13 | kv = torch.cat([ 14 | torch.zeros_like(kv[:, :, :1]), 15 | kv[:, :, :-1] 16 | ], dim=2) 17 | inter = q @ kv 18 | intra = ((q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)) @ v 19 | o = inter + intra 20 | return rearrange(o, 'b h n c d -> b h (n c) d') 21 | -------------------------------------------------------------------------------- /fla/ops/rebased/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .parallel import parallel_rebased 4 | 5 | __all__ = [ 6 | 'parallel_rebased' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/rebased/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | from fla.ops.rebased.parallel import parallel_rebased 7 | 8 | def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True): 9 | if use_scale: 10 | q = q * (q.shape[-1] ** -0.5) 11 | attn = q @ k.transpose(-2, -1) 12 | attn = (attn ** 2) 13 | attn.masked_fill_(~torch.tril(torch.ones( 14 | q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) 15 | o = attn @ v 16 | if use_norm: 17 | z = attn.sum(-1) 18 | return o / (z[..., None] + 1e-6) 19 | else: 20 | return o 21 | 22 | 23 | if __name__ == "__main__": 24 | B = 4 25 | H = 4 26 | L = 128 27 | # D = 15 28 | dtype = torch.float32 29 | q = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) 30 | k = (torch.randn(B, H, L, 16).cuda().to(dtype)).requires_grad_(True) 31 | v = torch.randn(B, H, L, 128).cuda().to(dtype).requires_grad_(True) 32 | 33 
| do = torch.randn_like(v).cuda() 34 | ref = naive_parallel_rebased(q, k, v, True, True) 35 | ref.backward(do, retain_graph=True) 36 | ref_dq, q.grad = q.grad.clone(), None 37 | ref_dk, k.grad = k.grad.clone(), None 38 | ref_dv, v.grad = v.grad.clone(), None 39 | 40 | # tri = naive_chunk_based(q, k, v) 41 | # tri.backward(do, retain_graph=True) 42 | # tri_dq, q.grad = q.grad.clone(), None 43 | # tri_dk, k.grad = k.grad.clone(), None 44 | # tri_dv, v.grad = v.grad.clone(), None 45 | 46 | # assert ref.allclose(tri, 0, 1e-4), breakpoint() 47 | # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() 48 | # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() 49 | # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() 50 | 51 | tri = parallel_rebased(q, k, v, 1e-6, True, True) 52 | tri.backward(do, retain_graph=True) 53 | tri_dq, q.grad = q.grad.clone(), None 54 | tri_dk, k.grad = k.grad.clone(), None 55 | tri_dv, v.grad = v.grad.clone(), None 56 | print((ref-tri).abs().max()) 57 | print((ref_dq-tri_dq).abs().max()) 58 | print((ref_dk-tri_dk).abs().max()) 59 | print((ref_dv-tri_dv).abs().max()) 60 | 61 | # assert ref.allclose(tri, 0, 1e-4), breakpoint() 62 | # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() 63 | # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() 64 | # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() 65 | 66 | # tri = parallel_based(q, k, v, True, True) 67 | # tri.backward(do, retain_graph=True) 68 | # tri_dq, q.grad = q.grad.clone(), None 69 | # tri_dk, k.grad = k.grad.clone(), None 70 | # tri_dv, v.grad = v.grad.clone(), None 71 | 72 | # print((ref-tri).abs().max()) 73 | # print((ref_dq-tri_dq).abs().max()) 74 | # print((ref_dk-tri_dk).abs().max()) 75 | # print((ref_dv-tri_dv).abs().max()) 76 | 77 | # assert ref.allclose(tri, 0, 1e-4), breakpoint() 78 | # assert ref_dq.allclose(tri_dq, 0, 1e-4), breakpoint() 79 | # assert ref_dk.allclose(tri_dk, 0, 1e-4), breakpoint() 80 | # assert ref_dv.allclose(tri_dv, 0, 1e-4), breakpoint() 81 | -------------------------------------------------------------------------------- /fla/ops/retention/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_retention 4 | from .chunk_fuse import fused_chunk_retention 5 | from .parallel import parallel_retention 6 | from .recurrent_fuse import fused_recurrent_retention 7 | 8 | __all__ = [ 9 | 'chunk_retention', 10 | 'fused_chunk_retention', 11 | 'parallel_retention', 12 | 'fused_recurrent_retention' 13 | ] 14 | -------------------------------------------------------------------------------- /fla/ops/retention/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def naive_retention(q, k, v): 7 | orig_type = q.dtype 8 | q, k, v = q.float(), k.float(), v.float() 9 | _, n_heads, seq_len, d_head = q.shape 10 | s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. 
- q.new_tensor(range(n_heads), dtype=torch.float))).log2() 11 | n = q.new_tensor(range(seq_len), dtype=torch.float) 12 | n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n) 13 | s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype)) 14 | o = torch.einsum('bhqk,bhkd->bhqd', s, v) 15 | return o.to(orig_type) 16 | -------------------------------------------------------------------------------- /fla/ops/rwkv4/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .recurrent_fuse import fused_recurrent_rwkv4 4 | 5 | __all__ = [ 6 | 'fused_recurrent_rwkv4' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/rwkv6/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_rwkv6 4 | from .recurrent_fuse import fused_recurrent_rwkv6 5 | 6 | __all__ = [ 7 | 'chunk_rwkv6', 8 | 'fused_recurrent_rwkv6' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/rwkv6/chunk_naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | from fla.ops.rwkv6.chunk import chunk_rwkv6 7 | from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6 8 | 9 | 10 | def naive_chunk_rwkv6( 11 | q, 12 | k, 13 | v, 14 | w, 15 | u, 16 | chunk_size=32, 17 | initial_state=None, 18 | output_final_state=True, 19 | ): 20 | assert q.shape[-2] % chunk_size == 0 21 | orig_dtype = q.dtype 22 | num_chunk = q.shape[-2] // chunk_size 23 | u = u.unsqueeze(0) 24 | 25 | q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w)) 26 | 27 | w_cumsum = w.cumsum(-2) 28 | 29 | kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp() 30 | wkv = kw.transpose(-1, -2) @ v 31 | 32 | wkv_new = torch.zeros_like(wkv) 33 | 34 | for i in range(num_chunk - 1): 35 | wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i] 36 | 37 | o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp())) 38 | 39 | o_intra = torch.zeros_like(o_inter) 40 | for i in range(chunk_size): 41 | attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1) 42 | mask = (torch.arange(0, chunk_size) < i).to(attn.device) 43 | attn.masked_fill_(~mask, 0) 44 | intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2) 45 | intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i] 46 | o_intra[:, :, :, i] = intra_inter_o + intra_intra_o 47 | o = o_inter + o_intra 48 | return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype) 49 | 50 | 51 | if __name__ == "__main__": 52 | B = 4 53 | H = 4 54 | L = 1024 55 | D = 100 56 | dtype = torch.bfloat16 57 | require_grad = True 58 | q = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad) 59 | k = (torch.randn(B, H, L, D).cuda().to(dtype)).requires_grad_(require_grad) 60 | v = torch.randn(B, H, L, 2*D).cuda().to(dtype).requires_grad_(require_grad) 61 | w = torch.nn.functional.logsigmoid(torch.randn(B, H, L, D)).cuda().to(dtype).requires_grad_(require_grad) 62 | u = (torch.randn(H, D).cuda().to(dtype)).requires_grad_(require_grad) 63 | do = torch.rand_like(v).cuda() 64 | o2, _ = 
chunk_rwkv6(q, k, v, w.clone(), u) 65 | o, _ = fused_recurrent_rwkv6(q, k, v, w, u, scale=1.0) 66 | o.backward(do) 67 | dq, q.grad = q.grad.clone(), None 68 | dk, k.grad = k.grad.clone(), None 69 | dv, v.grad = v.grad.clone(), None 70 | dw, w.grad = w.grad.clone(), None 71 | du, u.grad = u.grad.clone(), None 72 | print((o - o2).abs().max()) 73 | o2.backward(do) 74 | print((o-o2).abs().max()) 75 | print((q.grad - dq).abs().max()) 76 | print((k.grad - dk).abs().max()) 77 | print((v.grad - dv).abs().max()) 78 | print((w.grad - dw).abs().max()) 79 | print((u.grad - du).abs().max()) 80 | -------------------------------------------------------------------------------- /fla/ops/rwkv6/recurrent_naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | def naive_recurrent_rwkv6( 9 | q: torch.Tensor, 10 | k: torch.Tensor, 11 | v: torch.Tensor, 12 | w: torch.Tensor, 13 | u: torch.Tensor, 14 | scale: Optional[float] = None, 15 | initial_state: Optional[torch.Tensor] = None, 16 | output_final_state: Optional[bool] = False 17 | ): 18 | orig_dtype = q.dtype 19 | B, H, T, K, V = *q.shape, v.shape[-1] 20 | q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u)) 21 | h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) 22 | o = torch.zeros_like(v) 23 | 24 | if scale is None: 25 | scale = K ** -0.5 26 | 27 | if initial_state is not None: 28 | h += initial_state 29 | 30 | for i in range(T): 31 | q_i = q[:, :, i, :] * scale 32 | k_i = k[:, :, i] 33 | v_i = v[:, :, i, :] 34 | w_i = w[:, :, i].exp() 35 | kv_i = k_i[..., None] * v_i[..., None, :] 36 | o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None] 37 | o[:, :, i] = o_i.sum(-2) 38 | h = h * w_i[..., None] + kv_i 39 | ht = h if output_final_state else None 40 | return o.to(orig_dtype), ht 41 | 42 | 43 | def naive_recurrent_rwkv6_bwd( 44 | q, 45 | k, 46 | v, 47 | w, 48 | u, 49 | o, 50 | do, 51 | initial_state=None, 52 | output_final_state=False 53 | ): 54 | q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do)) 55 | B, H, T, K, V = *q.shape, v.shape[-1] 56 | h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) 57 | dq = torch.zeros_like(q) 58 | dq_aux = torch.zeros_like(q) 59 | 60 | if initial_state is not None: 61 | h += initial_state 62 | 63 | for i in range(T): 64 | k_i = k[:, :, i] 65 | v_i = v[:, :, i] 66 | w_i = w[:, :, i].exp() 67 | kv_i = k_i[..., None] * v_i[..., None, :] 68 | h_i = (h + u[None, ..., None] * kv_i) 69 | dq_i = (do[:, :, i, None, :] * h_i).sum(-1) 70 | dq_aux_i = (do[:, :, i, None, :] * h).sum(-1) 71 | dq[:, :, i] = dq_i 72 | dq_aux[:, :, i] = dq_aux_i 73 | h = h * w_i[..., None] + kv_i 74 | 75 | du = torch.zeros_like(u) 76 | dh = torch.zeros_like(h) 77 | dk = torch.zeros_like(k) 78 | dk_aux = torch.zeros_like(k) 79 | dv = torch.zeros_like(v) 80 | 81 | for i in range(T - 1, -1, -1): 82 | d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None] 83 | k_i = k[:, :, i] 84 | v_i = v[:, :, i] 85 | du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1) 86 | du += du_i 87 | dk_i = (dh * v_i[..., None, :]).sum(-1) 88 | dk_aux[:, :, i] = dk_i 89 | dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1) 90 | dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2) 91 | dv_i += (dh * k_i[..., None]).sum(-2) 92 | 93 | dk[:, :, i] = dk_i 94 | dv[:, :, i] = dv_i 95 | dh = dh * w[:, :, i, :, None].exp() + d_kv_i 96 | 97 | # dw = q * dq_aux - k * dk_aux 98 | dw = 
torch.zeros_like(w) 99 | for i in range(T - 2, -1, -1): 100 | dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i] 101 | 102 | return dq, dk, dv, dw, du 103 | -------------------------------------------------------------------------------- /fla/ops/simple_gla/README.md: -------------------------------------------------------------------------------- 1 | - Simple GLA 2 | 3 | Gating mechanism in https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise instead of elementwise. As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA. 4 | 5 | $S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar. -------------------------------------------------------------------------------- /fla/ops/simple_gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_simple_gla 4 | 5 | __all__ = [ 6 | 'chunk_simple_gla' 7 | ] 8 | 9 | -------------------------------------------------------------------------------- /fla/ops/simple_gla/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def torch_simple_gla(q, k, v, g, chunk_size=64): 8 | q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5) 9 | k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 10 | v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 11 | g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size) 12 | g = g.cumsum(-1) 13 | kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None]) 14 | S = torch.zeros_like(kv) 15 | 16 | for i in range(1, g.shape[-2]): 17 | S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1] 18 | 19 | inter = (q * g[..., None].exp()) @ S 20 | attn = q @ k.transpose(-1, -2) 21 | attn = attn * (g[..., None] - g[..., None, :]).exp() 22 | attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) 23 | intra = attn @ v 24 | o = inter + intra 25 | return rearrange(o, 'b h n c d -> b h (n c) d') 26 | 27 | 28 | def torch_simple_gla_recurrent(q, k, v, g, chunk_size=64): 29 | # q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5) 30 | # k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 31 | # v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 32 | # g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size) 33 | # g = g.cumsum(-1) 34 | # kv = k.transpose(-1, -2) @ v 35 | 36 | B, H, T, DK = q.shape 37 | q = q * (DK ** -0.5) 38 | _, _, _, DV = v.shape 39 | S = torch.zeros(B, H, DK, DV).to(q) 40 | o = torch.zeros(B, H, T, DV).to(q) 41 | for i in range(T): 42 | gate = g[:, :, i].exp() 43 | key = k[:, :, i] 44 | value = v[:, :, i] 45 | kv = key.unsqueeze(-1) * value.unsqueeze(-2) 46 | S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv 47 | q_i = q[:, :, i, :] 48 | o_i = (q_i.unsqueeze(-1) * S).sum(-2) 49 | o[:, :, i] = o_i 50 | 51 | return o 52 | 53 | -------------------------------------------------------------------------------- /fla/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import functools 4 | 5 | import 
torch 6 | 7 | 8 | def contiguous(fn): 9 | @functools.wraps(fn) 10 | def wrapper(ctx, *args, **kwargs): 11 | return fn(ctx, 12 | *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), 13 | **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) 14 | return wrapper 15 | 16 | 17 | def require_version(version, hint): 18 | def decorator(fn): 19 | @functools.wraps(fn) 20 | def wrapper(ctx, *args, **kwargs): 21 | from transformers.utils.versions import require_version 22 | require_version(version, hint) 23 | return fn(ctx, 24 | *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), 25 | **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) 26 | return wrapper 27 | return decorator 28 | 29 | 30 | def checkpoint(func): 31 | def wrapper(*args, **kwargs): 32 | return torch.utils.checkpoint.checkpoint(func, *args, **kwargs) 33 | return wrapper 34 | -------------------------------------------------------------------------------- /merge/merge.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | import sys 4 | from typing import Dict 5 | import typing 6 | import torch 7 | import bitsandbytes as bnb 8 | from argparse import ArgumentParser 9 | 10 | parser = ArgumentParser() 11 | parser.add_argument("--type", default="pissa", type=str) 12 | parser.add_argument("--base_model", default="", type=str) 13 | parser.add_argument("--lora_init", default="none", type=str) 14 | parser.add_argument("--lora_checkpoint", default="", type=str) 15 | parser.add_argument("--output", default="", type=str) 16 | parser.add_argument("--quant", default="none", type=str) 17 | parser.add_argument("--device", default="cuda", type=str) 18 | parser.add_argument("--lora_alpha", default=16, type=int) 19 | args = parser.parse_args() 20 | device= args.device 21 | base_model = args.base_model 22 | init_lora= args.lora_init 23 | lora= args.lora_checkpoint 24 | output= args.output 25 | quant= args.quant 26 | lora_alpha = args.lora_alpha 27 | 28 | with torch.no_grad(): 29 | w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') 30 | # merge LoRA-only slim checkpoint into the main weights 31 | w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') 32 | 33 | if args.type=='pissa': 34 | w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu') 35 | for k in w_lora.keys(): 36 | w[k] = w_lora[k] 37 | output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() 38 | # merge LoRA weights 39 | keys = list(w.keys()) 40 | for k in keys: 41 | if k.endswith('.weight'): 42 | prefix = k[:-len('.weight')] 43 | lora_A = prefix + '.lora_A' 44 | lora_B = prefix + '.lora_B' 45 | init_lora_A = prefix + '.init_lora_A' 46 | init_lora_B = prefix + '.init_lora_B' 47 | if lora_A in keys: 48 | assert lora_B in keys 49 | print(f'merging {lora_A} and {lora_B} into {k}') 50 | assert w[lora_B].shape[1] == w[lora_A].shape[0] 51 | lora_r = w[lora_B].shape[1] 52 | w[k] = w[k].to(device=device) 53 | w[lora_A] = w[lora_A].to(device=device) 54 | w[lora_B] = w[lora_B].to(device=device) 55 | 56 | if args.type=='pissa': 57 | w_init_lora[init_lora_A] = w_init_lora[init_lora_A].to(device=device) 58 | w_init_lora[init_lora_B] = w_init_lora[init_lora_B].to(device=device) 59 | if quant=='4bit': 60 | qw,qs = bnb.functional.quantize_4bit(w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]) 61 | w[k] = 
(bnb.functional.dequantize_4bit(qw,quant_state=qs)).to(dtype=torch.bfloat16) 62 | elif quant == 'nf4': 63 | qw,qs = bnb.functional.quantize_nf4(w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]) 64 | w[k] = (bnb.functional.dequantize_nf4(qw,quant_state=qs)).to(dtype=torch.bfloat16) 65 | elif quant == 'fp4': 66 | qw,qs = bnb.functional.quantize_fp4(w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]) 67 | w[k] = (bnb.functional.dequantize_fp4(qw,quant_state=qs)).to(dtype=torch.bfloat16) 68 | elif quant == 'int8': 69 | qw,qs = bnb.functional.quantize(w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]) 70 | w[k] = (bnb.functional.dequantize(qw,state=qs)).to(dtype=torch.bfloat16) 71 | else: 72 | w[k] = (w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]).to(dtype=torch.bfloat16) 73 | w[k] += w[lora_B] @ w[lora_A] 74 | else: 75 | if quant=='4bit': 76 | qw,qs = bnb.functional.quantize_4bit(w[k]) 77 | w[k] = (bnb.functional.dequantize_4bit(qw,quant_state=qs)).to(dtype=torch.bfloat16) 78 | elif quant=='nf4': 79 | qw,qs = bnb.functional.quantize_nf4(w[k]) 80 | w[k] = (bnb.functional.dequantize_nf4(qw,quant_state=qs)).to(dtype=torch.bfloat16) 81 | elif quant=='fp4': 82 | qw,qs = bnb.functional.quantize_fp4(w[k]) 83 | w[k] = (bnb.functional.dequantize_fp4(qw,quant_state=qs)).to(dtype=torch.bfloat16) 84 | elif quant=='int8': 85 | qw,qs = bnb.functional.quantize(w[k]) 86 | w[k] = (bnb.functional.dequantize(qw,state=qs)).to(dtype=torch.bfloat16) 87 | w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) 88 | output_w[k] = w[k].to(device='cpu', copy=True) 89 | del w[k] 90 | del w[lora_A] 91 | del w[lora_B] 92 | continue 93 | 94 | if 'lora' not in k: 95 | print(f'retaining {k}') 96 | output_w[k] = w[k].clone() 97 | del w[k] 98 | torch.save(output_w, output) -------------------------------------------------------------------------------- /merge/merge_lora.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | import sys 4 | from typing import Dict 5 | import typing 6 | import torch 7 | 8 | if '-h' in sys.argv or '--help' in sys.argv: 9 | print(f'Usage: python3 {sys.argv[0]} [--use-gpu] ') 10 | 11 | if sys.argv[1] == '--use-gpu': 12 | device = 'cuda' 13 | lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5] 14 | else: 15 | device = 'cpu' 16 | lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4] 17 | 18 | 19 | with torch.no_grad(): 20 | w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') 21 | # merge LoRA-only slim checkpoint into the main weights 22 | w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') 23 | for k in w_lora.keys(): 24 | w[k] = w_lora[k] 25 | output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() 26 | # merge LoRA weights 27 | keys = list(w.keys()) 28 | for k in keys: 29 | if k.endswith('.weight'): 30 | prefix = k[:-len('.weight')] 31 | lora_A = prefix + '.lora_A' 32 | lora_B = prefix + '.lora_B' 33 | if lora_A in keys: 34 | assert lora_B in keys 35 | print(f'merging {lora_A} and {lora_B} into {k}') 36 | assert w[lora_B].shape[1] == w[lora_A].shape[0] 37 | lora_r = w[lora_B].shape[1] 38 | w[k] = w[k].to(device=device) 39 | w[lora_A] = w[lora_A].to(device=device) 40 | w[lora_B] = w[lora_B].to(device=device) 41 | w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) 42 | output_w[k] = w[k].to(device='cpu', copy=True) 43 | del w[k] 44 | del 
w[lora_A] 45 | del w[lora_B] 46 | continue 47 | 48 | if 'lora' not in k: 49 | print(f'retaining {k}') 50 | output_w[k] = w[k].clone() 51 | del w[k] 52 | torch.save(output_w, output) 53 | -------------------------------------------------------------------------------- /merge/merge_pissa.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | import sys 4 | from typing import Dict 5 | import typing 6 | import torch 7 | 8 | if '-h' in sys.argv or '--help' in sys.argv: 9 | print(f'Usage: python3 {sys.argv[0]} [--use-gpu] ') 10 | 11 | if sys.argv[1] == '--use-gpu': 12 | device = 'cuda' 13 | base_model, init_lora, lora, output = sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5] 14 | else: 15 | device = 'cpu' 16 | base_model, init_lora, lora, output = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4] 17 | 18 | 19 | with torch.no_grad(): 20 | w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') 21 | # merge LoRA-only slim checkpoint into the main weights 22 | w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') 23 | w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu') 24 | for k in w_lora.keys(): 25 | w[k] = w_lora[k] 26 | output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() 27 | # merge LoRA weights 28 | keys = list(w.keys()) 29 | for k in keys: 30 | if k.endswith('.weight'): 31 | prefix = k[:-len('.weight')] 32 | lora_A = prefix + '.lora_A' 33 | lora_B = prefix + '.lora_B' 34 | init_lora_A = prefix + '.init_lora_A' 35 | init_lora_B = prefix + '.init_lora_B' 36 | if lora_A in keys: 37 | assert lora_B in keys 38 | print(f'merging {lora_A} and {lora_B} into {k}') 39 | assert w[lora_B].shape[1] == w[lora_A].shape[0] 40 | lora_r = w[lora_B].shape[1] 41 | w[k] = w[k].to(device=device) 42 | w[lora_A] = w[lora_A].to(device=device) 43 | w[lora_B] = w[lora_B].to(device=device) 44 | w_init_lora[init_lora_A] = w_init_lora[init_lora_A].to(device=device) 45 | w_init_lora[init_lora_B] = w_init_lora[init_lora_B].to(device=device) 46 | w[k] = (w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]).to(dtype=torch.bfloat16) 47 | w[k] += w[lora_B] @ w[lora_A] 48 | output_w[k] = w[k].to(device='cpu', copy=True) 49 | del w[k] 50 | del w[lora_A] 51 | del w[lora_B] 52 | continue 53 | 54 | if 'lora' not in k: 55 | print(f'retaining {k}') 56 | output_w[k] = w[k].clone() 57 | del w[k] 58 | torch.save(output_w, output) -------------------------------------------------------------------------------- /merge/merge_state.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | import sys 4 | from typing import Dict 5 | import typing 6 | import torch 7 | import bitsandbytes as bnb 8 | from argparse import ArgumentParser 9 | 10 | parser = ArgumentParser() 11 | parser.add_argument("--base_model", default="", type=str) 12 | parser.add_argument("--state_checkpoint", default="", type=str) 13 | parser.add_argument("--output", default="", type=str) 14 | # parser.add_argument("--quant", default="none", type=str) 15 | parser.add_argument("--device", default="cuda", type=str) 16 | # parser.add_argument("--lora_alpha", default=16, type=int) 17 | args = parser.parse_args() 18 | device= args.device 19 | base_model = args.base_model 20 | state= args.state_checkpoint 21 | output= args.output 22 | 23 | 24 | with torch.no_grad(): 25 | w: Dict[str, torch.Tensor] = torch.load(base_model, 
map_location='cpu') 26 | # merge LoRA-only slim checkpoint into the main weights 27 | w_state: Dict[str, torch.Tensor] = torch.load(state, map_location='cpu') 28 | 29 | for k in w_state.keys(): 30 | print(k) 31 | w[k] = w_state[k] 32 | # merge LoRA weights 33 | for k in w.keys(): 34 | print(k) 35 | 36 | torch.save(w, output) -------------------------------------------------------------------------------- /output/model output dir.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGENDD/RWKV-ASR/1f3c0d90db76c426820112476f125e988fe16130/output/model output dir.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==1.9.5 2 | bitsandbytes 3 | deepspeed 4 | einops 5 | triton==2.2.0 6 | transformers[torch] 7 | datasets 8 | evaluate 9 | jiwer 10 | tqdm -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AGENDD/RWKV-ASR/1f3c0d90db76c426820112476f125e988fe16130/src/__init__.py -------------------------------------------------------------------------------- /src/dataset2.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | 5 | import json, math, random, os, sys 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | from pytorch_lightning.utilities import rank_zero_info 10 | from .binidx import MMapIndexedDataset 11 | from .utils import MaybeIsPrime 12 | from rwkv.utils import PIPELINE 13 | import librosa 14 | pipeline = PIPELINE('rwkv6', "rwkv_vocab_v20230424") 15 | 16 | class MyDataset(Dataset): 17 | def __init__(self, args, hf_dataset): 18 | self.args = args 19 | self.hf_dataset = hf_dataset 20 | 21 | def __len__(self): 22 | return len(self.hf_dataset) 23 | 24 | def __getitem__(self, idx): 25 | 26 | while(True): 27 | try: 28 | sample = self.hf_dataset[idx] 29 | break 30 | except: 31 | idx = idx+1 32 | 33 | 34 | if('translation'in sample.keys()): 35 | #covost2 36 | answer = sample['translation'] 37 | audio = sample['audio']['array'] 38 | audio = librosa.resample(audio,orig_sr= 48000,target_sr= 16000) 39 | elif('sentence' in sample.keys()): 40 | #common voice 41 | answer = sample['sentence'] 42 | audio = sample['audio']['array'] 43 | audio = librosa.resample(audio,orig_sr= 48000,target_sr= 16000) 44 | elif('audio' in sample.keys()): 45 | #librispeech 46 | audio = sample['audio']['array'] 47 | answer = sample['text'] 48 | else: 49 | #en-final 50 | audio = sample['speech'] 51 | answer = sample['text'] 52 | 53 | # print(f"speech input{idx}:{len(audio)}") 54 | return audio, answer.lower() 55 | -------------------------------------------------------------------------------- /src/infctx_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | ######state 3 | class TimeMixState: 4 | def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor): 5 | self.shift_state = shift_state 6 | self.wkv_state = wkv_state 7 | 8 | 9 | 
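# ChannelMixState mirrors TimeMixState but only carries the token-shift buffer; the channel-mix block keeps no wkv state.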
class ChannelMixState: 10 | def __init__(self, shift_state: torch.Tensor): 11 | self.shift_state = shift_state 12 | 13 | 14 | class BlockState: 15 | def __init__(self, time_mix_state: TimeMixState, 16 | channel_mix_state: ChannelMixState): 17 | self.time_mix_state = time_mix_state 18 | self.channel_mix_state = channel_mix_state 19 | 20 | class BlockStateList: 21 | 22 | def __init__(self, shift_states, wkv_states): 23 | self.wkv_states = wkv_states 24 | self.shift_states = shift_states 25 | 26 | @staticmethod 27 | def create(N, B, C, H, device, dtype): 28 | result = BlockStateList.empty(N, B, C, H, device, dtype) 29 | result.wkv_states[:] = 0 30 | result.wkv_states[:] = 0 31 | result.shift_states[:] = 0 32 | return result 33 | 34 | @staticmethod 35 | def empty(N, B, C, H, device, dtype): 36 | wkv_states = torch.empty((N, B, H, C//H, C//H), 37 | device=device, 38 | dtype=torch.bfloat16) 39 | shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype) 40 | return BlockStateList(shift_states, wkv_states) 41 | 42 | def __getitem__(self, layer: int): 43 | return BlockState( 44 | TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]), 45 | ChannelMixState(self.shift_states[layer, 1])) 46 | 47 | def __setitem__(self, layer: int, state: BlockState): 48 | self.shift_states[layer, 0] = state.time_mix_state.shift_state 49 | self.wkv_states[layer] = state.time_mix_state.wkv_state 50 | self.shift_states[layer, 1] = state.channel_mix_state.shift_state 51 | 52 | 53 | -------------------------------------------------------------------------------- /src/rwkvLinear.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import torch.nn as nn 3 | import bitsandbytes as bnb 4 | from torch.nn import functional as F 5 | from torch._lowrank import svd_lowrank 6 | import functools 7 | 8 | def rwkv_quantize(quant_type, weight): 9 | if quant_type=='4bit': 10 | qweight, qstate= bnb.functional.quantize_4bit((weight.data).to('cuda')) 11 | elif quant_type=='nf4': 12 | qweight, qstate= bnb.functional.quantize_nf4((weight.data).to('cuda')) 13 | elif quant_type=='fp4': 14 | qweight, qstate= bnb.functional.quantize_fp4((weight.data).to('cuda')) 15 | elif quant_type=='int8': 16 | qweight, qstate= bnb.functional.quantize((weight.data).to('cuda')) 17 | return qweight, qstate 18 | 19 | 20 | def rwkv_dequantize(quant_type, weight, qstate): 21 | if quant_type=='4bit': 22 | deweight= bnb.functional.dequantize_4bit(weight.data,quant_state=qstate) 23 | elif quant_type=='nf4': 24 | deweight= bnb.functional.dequantize_nf4(weight.data,quant_state=qstate) 25 | elif quant_type=='fp4': 26 | deweight= bnb.functional.dequantize_fp4(weight.data,quant_state=qstate) 27 | elif quant_type=='int8': 28 | deweight= bnb.functional.dequantize(weight.data,state=qstate) 29 | return deweight 30 | 31 | 32 | 33 | LORA_CONFIG = { 34 | "r": 0, 35 | "alpha": 0, 36 | "dropout": 0, 37 | "parts": {"att", "ln", "time", "ffn"}, 38 | "quant": False, 39 | } 40 | class LoraLinear(nn.Module): 41 | 42 | def __init__(self, in_features: int, out_features: int, bias: bool): 43 | super().__init__() 44 | 45 | self.weight = nn.Parameter(torch.empty((out_features, in_features))) 46 | assert bias == False, "Biased LoraLinear not supported" 47 | 48 | r, alpha, dropout = LORA_CONFIG["r"], LORA_CONFIG[ 49 | "alpha"], LORA_CONFIG["dropout"] 50 | self.lora_A = nn.Parameter(torch.empty(r, in_features)) 51 | self.lora_B = nn.Parameter(torch.empty(out_features, r)) 52 | self.lora_dropout = nn.Dropout(dropout) 
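# scaling = alpha / r below; forward() applies the low-rank update as W x + scaling * lora_B(lora_A(dropout(x)))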
53 | self.scaling = alpha / r 54 | self.r = r 55 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 56 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) 57 | nn.init.zeros_(self.lora_B) 58 | self.pissa = False 59 | self.is_quant = False 60 | 61 | def pissa_load(self, init_A, init_B): 62 | self.pissa = True 63 | self.weight.data = self.weight.data - init_B @ init_A 64 | 65 | 66 | def pissa_init(self, svd_niter): 67 | 68 | self.pissa = True 69 | Ur, Sr, Vr = svd_lowrank(self.weight.data, self.r, niter=svd_niter) 70 | Vhr = Vr.t() 71 | lora_A = torch.diag(torch.sqrt(Sr)) @ Vhr 72 | lora_B = Ur @ torch.diag(torch.sqrt(Sr)) 73 | self.lora_A.data = lora_A 74 | self.lora_B.data = lora_B 75 | self.weight.data = self.weight.data - lora_B @ lora_A 76 | def quant(self, quant_type): 77 | self.is_quant = True 78 | self.quant_type = quant_type 79 | self.weight.data, self.qstate= rwkv_quantize(self.quant_type, (self.weight.data).to('cuda')) 80 | 81 | def forward(self, x): 82 | 83 | if self.is_quant: 84 | if self.pissa: 85 | return ( 86 | F.linear(x, rwkv_dequantize(self.quant_type, self.weight.data, self.qstate).to(torch.bfloat16)) + 87 | F.linear(F.linear(x, self.lora_A), self.lora_B)) 88 | return ( 89 | F.linear(x, rwkv_dequantize(self.quant_type, self.weight.data, self.qstate)) + self.scaling * 90 | F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B)) 91 | 92 | if self.pissa: 93 | return ( 94 | F.linear(x, self.weight) + 95 | F.linear(F.linear(x, self.lora_A), self.lora_B)) 96 | return ( 97 | F.linear(x, self.weight) + self.scaling * 98 | F.linear(F.linear(self.lora_dropout(x), self.lora_A), self.lora_B)) 99 | 100 | 101 | class QuantLinear(nn.Module): 102 | def __init__(self, in_features: int, out_features: int, bias: bool): 103 | super().__init__() 104 | 105 | self.weight = nn.Parameter(torch.empty((out_features, in_features))) 106 | assert bias == False, "Biased QuantLinear not supported" 107 | self.is_quant = False 108 | 109 | def quant(self, quant_type): 110 | self.is_quant = True 111 | self.quant_type = quant_type 112 | #self.dummy_tensor = nn.Parameter(torch.zeros(1)) 113 | self.weight.data, self.qstate= rwkv_quantize(self.quant_type, (self.weight.data).to('cuda')) 114 | def forward(self, x): 115 | 116 | if self.is_quant: 117 | return F.linear(x, rwkv_dequantize(self.quant_type, self.weight.data, self.qstate).to(torch.bfloat16)) 118 | else: 119 | return F.linear(x, self.weight) 120 | 121 | 122 | @functools.wraps(LoraLinear) 123 | def make_linear_att(*args, **kwargs): 124 | if "att" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0: 125 | return LoraLinear(*args, **kwargs) 126 | elif LORA_CONFIG["quant"]: 127 | return QuantLinear(*args, **kwargs) 128 | else: 129 | return nn.Linear(*args, **kwargs) 130 | 131 | 132 | @functools.wraps(LoraLinear) 133 | def make_linear_ffn(*args, **kwargs): 134 | if "ffn" in LORA_CONFIG["parts"] and LORA_CONFIG["r"] > 0: 135 | return LoraLinear(*args, **kwargs) 136 | elif LORA_CONFIG["quant"]: 137 | return QuantLinear(*args, **kwargs) 138 | else: 139 | return nn.Linear(*args, **kwargs) -------------------------------------------------------------------------------- /src/speech_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 5 | import numpy as np 6 | 7 | from transformers import AutoProcessor, AutoModel 8 | from transformers import 
Wav2Vec2FeatureExtractor 9 | from transformers import Wav2Vec2Processor 10 | from transformers import Wav2Vec2CTCTokenizer 11 | 12 | 13 | class SpeechEncoder(nn.Module): 14 | def __init__( 15 | self, 16 | model_id, 17 | project_dim, 18 | downsample_K=5, 19 | hidden_dim=2048, 20 | train_mode="adapter", 21 | device="cuda", 22 | ): 23 | assert train_mode in ["adapter", "full"] 24 | super(SpeechEncoder, self).__init__() 25 | 26 | feature_extractor = Wav2Vec2FeatureExtractor( 27 | feature_size=1, 28 | sampling_rate=16000, 29 | padding_value=0.0, 30 | do_normalize=True, 31 | return_attention_mask=False, 32 | ) 33 | self.device = device 34 | self.processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft") 35 | self.time_reduction_factor = int( 36 | self.processor.feature_extractor.sampling_rate / 50 37 | ) 38 | self.padding_length = 320 39 | self.model = AutoModel.from_pretrained(model_id).to(self.device,dtype=torch.bfloat16) 40 | self.model_output_dim = self.model.config.hidden_size 41 | self.downsample_K = downsample_K 42 | self.project_dim = project_dim 43 | if hidden_dim is None: 44 | self.hidden_dim = self.project_dim * 2 45 | else: 46 | self.hidden_dim = hidden_dim 47 | # adapter shall be a Linear(Relu(Linear)) structure 48 | self.adapter = nn.Sequential( 49 | nn.Linear(self.model_output_dim * self.downsample_K, self.hidden_dim), 50 | nn.ReLU(), 51 | nn.Linear(self.hidden_dim, self.project_dim), 52 | ).to(self.device,dtype=torch.bfloat16) 53 | self.set_gradient(train_mode) 54 | 55 | def set_gradient(self, train_mode): 56 | """ 57 | if train_mode is "adapter", only train the adapter layers, otherwise train the whole model 58 | """ 59 | if train_mode == "adapter": 60 | for param in self.model.parameters(): 61 | param.requires_grad = False 62 | for param in self.adapter.parameters(): 63 | param.requires_grad = True 64 | else: 65 | for param in self.model.parameters(): 66 | param.requires_grad = True 67 | for param in self.adapter.parameters(): 68 | param.requires_grad = True 69 | 70 | def calculate_mask(self, input_dict): 71 | """ 72 | Also need to handle the masking issue, to let the model not to attend to the padding tokens 73 | """ 74 | attention_mask = input_dict["attention_mask"] # [batch, num_samples] 75 | length_in_samples = ( 76 | attention_mask.shape[1] // self.padding_length * self.padding_length 77 | ) 78 | # calculate the mask length 79 | mask_length = length_in_samples // self.time_reduction_factor 80 | # create the mask 81 | mask = attention_mask[:, :: (self.time_reduction_factor * self.downsample_K)] 82 | return mask 83 | 84 | def forward(self, x): 85 | input_dict = self.processor( 86 | x, return_tensors="pt", padding=True, sampling_rate=16000 87 | ).to(self.device,dtype=torch.bfloat16) 88 | mask = self.calculate_mask(input_dict) 89 | x = self.model(**input_dict).last_hidden_state 90 | # reshape the output from [batch_size, num_frames, hidden_size] to [batch_size, num_frames//downsample_K, hidden_size*downsample_K] 91 | x = x.unfold(1, self.downsample_K, self.downsample_K).flatten(2) 92 | x = self.adapter(x) 93 | mask = mask[:, : x.shape[1]] 94 | return x, mask 95 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import json, time, random, os 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | time_slot = {} 7 | time_ref = time.time_ns() 8 | 9 | def record_time(name): 10 | if name not in 
time_slot: 11 | time_slot[name] = 1e20 12 | tt = (time.time_ns() - time_ref) / 1e9 13 | if tt < time_slot[name]: 14 | time_slot[name] = tt 15 | 16 | class TOKENIZER(): 17 | def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'): 18 | if 'list' in str(type(WORD_NAME)): 19 | self.charMode = False 20 | if WORD_NAME[0] == WORD_NAME[1]: 21 | from transformers import PreTrainedTokenizerFast 22 | self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0]) 23 | else: 24 | from transformers import GPT2TokenizerFast 25 | self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1]) 26 | self.vocab_size = len(self.tokenizer) 27 | else: 28 | self.charMode = True 29 | with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file: 30 | self.word_table = json.load(result_file) 31 | 32 | self.vocab_size = len(self.word_table) 33 | 34 | self.stoi = {v: int(k) for k, v in self.word_table.items()} 35 | self.itos = {int(k): v for k, v in self.word_table.items()} 36 | 37 | self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR] 38 | 39 | def refine_context(self, context): 40 | context = context.strip().split('\n') 41 | for c in range(len(context)): 42 | context[c] = context[c].strip().strip('\u3000').strip('\r') 43 | context = list(filter(lambda c: c != '', context)) 44 | context = '\n' + ('\n'.join(context)).strip() 45 | if context == '': 46 | context = '\n' 47 | return context 48 | 49 | def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None): 50 | # out[self.UNKNOWN_CHAR] = -float('Inf') 51 | lastChar = int(x[-1]) 52 | 53 | probs = F.softmax(out, dim=-1) 54 | 55 | if self.charMode: 56 | if self.itos[lastChar] == '\n': 57 | top_p = top_p_newline 58 | else: 59 | top_p = top_p_usual 60 | else: 61 | top_p = top_p_usual 62 | 63 | if os.environ["RWKV_RUN_DEVICE"] == "cpu": 64 | probs = probs.numpy() 65 | sorted_probs = np.sort(probs)[::-1] 66 | cumulative_probs = np.cumsum(sorted_probs) 67 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) 68 | probs[probs < cutoff] = 0 69 | if temperature != 1.0: 70 | probs = probs.pow(1.0 / temperature) 71 | probs = probs / np.sum(probs) 72 | out = np.random.choice(a=len(probs), p=probs) 73 | return out 74 | else: 75 | sorted_probs = torch.sort(probs, descending=True)[0] 76 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy() 77 | cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)]) 78 | probs[probs < cutoff] = 0 79 | if temperature != 1.0: 80 | probs = probs.pow(1.0 / temperature) 81 | out = torch.multinomial(probs, num_samples=1)[0] 82 | return out 83 | 84 | def MaybeIsPrime(number): 85 | if FermatPrimalityTest(number) and MillerRabinPrimalityTest(number): 86 | return True 87 | else: 88 | return False 89 | 90 | 91 | def FermatPrimalityTest(number): 92 | if number > 1: 93 | for time in range(3): 94 | randomNumber = random.randint(2, number) - 1 95 | if pow(randomNumber, number - 1, number) != 1: 96 | return False 97 | return True 98 | else: 99 | return False 100 | 101 | 102 | def MillerRabinPrimalityTest(number): 103 | if number == 2: 104 | return True 105 | elif number == 1 or number % 2 == 0: 106 | return False 107 | oddPartOfNumber = number - 1 108 | timesTwoDividNumber = 0 109 | while oddPartOfNumber % 2 == 0: 110 | oddPartOfNumber = oddPartOfNumber // 2 111 | timesTwoDividNumber = timesTwoDividNumber + 1 112 | 113 | for time in range(3): 114 | while True: 115 | randomNumber = random.randint(2, number) - 1 116 | if randomNumber != 0 and randomNumber != 1: 117 | break 118 | 119 | 
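# Miller-Rabin witness step: raise the random base to the odd part of (number - 1), then square it repeatedly below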
randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number) 120 | 121 | if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1): 122 | iterationNumber = 1 123 | 124 | while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1): 125 | randomNumberWithPower = pow(randomNumberWithPower, 2, number) 126 | iterationNumber = iterationNumber + 1 127 | if randomNumberWithPower != (number - 1): 128 | return False 129 | 130 | return True 131 | --------------------------------------------------------------------------------
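A quick note on the top-p truncation used in `TOKENIZER.sample_logits` (src/utils.py above): tokens are sorted by probability, the cutoff is the probability of the first token whose cumulative mass exceeds `top_p`, and everything below the cutoff is zeroed before temperature shaping and sampling. The sketch below restates that GPU branch as a standalone helper for clarity; the function name `top_p_sample` and the random logits are illustrative only and not part of the repo.

```python
import torch
import numpy as np

def top_p_sample(logits: torch.Tensor, top_p: float = 0.85, temperature: float = 1.0) -> int:
    # mirrors the GPU branch of TOKENIZER.sample_logits
    probs = torch.softmax(logits.float(), dim=-1)
    sorted_probs = torch.sort(probs, descending=True)[0]
    cumulative = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
    # probability of the first token whose cumulative mass exceeds top_p
    cutoff = float(sorted_probs[np.argmax(cumulative > top_p)])
    probs[probs < cutoff] = 0  # drop the tail of the distribution
    if temperature != 1.0:
        probs = probs.pow(1.0 / temperature)
    return int(torch.multinomial(probs, num_samples=1)[0])  # multinomial renormalizes internally

print(top_p_sample(torch.randn(65536)))  # hypothetical vocab-sized logits
```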