├── .gitignore ├── README.md ├── Readme_Eng.md ├── client.py ├── configuration.py ├── cuda ├── __init__.py ├── gemm_fp16_cublas.cpp ├── operators.cu ├── rwkv6.cu ├── rwkv6_op.cpp ├── wkv5_cuda.cu ├── wkv5_op.cpp ├── wkv6_bi_cuda.cu ├── wkv6_bi_op.cpp ├── wkv6_cuda.cu ├── wkv6_op.cpp ├── wkv6infctx_cuda.cu ├── wkv6infctx_op.cpp ├── wkv6state_cuda.cu ├── wkv6state_op.cpp └── wrapper.cpp ├── docker ├── DockerfileClient ├── DockerfileIndexService ├── DockerfileLLMService └── DockerfileProxyService ├── docs ├── CONTRIBUTING.md ├── User_guide.md └── img │ ├── RWKV-RAG-Base-Model-Manage.png │ ├── RWKV-RAG-CHAT-1-Query.png │ ├── RWKV-RAG-CHAT-2-Get-Text.png │ ├── RWKV-RAG-CHAT-3-Rerank.png │ ├── RWKV-RAG-CHAT-4-Chat.png │ ├── RWKV-RAG-Manage-Database.gif │ ├── RWKV-RAG-Search-From-Internet.png │ ├── RWKV-RAG-Tuning-Data.png │ ├── RWKV-RAG-Tuning-Service-Mange.png │ ├── RWKV-RAG-WebUI-client.png │ └── models_example.png ├── etc ├── index_service_config.yml ├── llm_service_config.yml ├── proxy_service_config.yml └── ragq.yml ├── fla ├── __init__.py ├── layers │ ├── __init__.py │ ├── abc.py │ ├── based.py │ ├── delta_net.py │ ├── gated_abc.py │ ├── gla.py │ ├── hgrn.py │ ├── hgrn2.py │ ├── linear_attn.py │ ├── multiscale_retention.py │ ├── rebased.py │ ├── rwkv6.py │ └── simple_gla.py ├── models │ ├── __init__.py │ ├── abc │ │ ├── __init__.py │ │ ├── configuration_abc.py │ │ └── modeling_abc.py │ ├── delta_net │ │ ├── __init__.py │ │ ├── configuration_delta_net.py │ │ └── modeling_delta_net.py │ ├── gla │ │ ├── __init__.py │ │ ├── configuration_gla.py │ │ └── modeling_gla.py │ ├── hgrn │ │ ├── __init__.py │ │ ├── configuration_hgrn.py │ │ └── modeling_hgrn.py │ ├── hgrn2 │ │ ├── __init__.py │ │ ├── configuration_hgrn2.py │ │ └── modeling_hgrn2.py │ ├── linear_attn │ │ ├── __init__.py │ │ ├── configuration_linear_attn.py │ │ └── modeling_linear_attn.py │ ├── retnet │ │ ├── __init__.py │ │ ├── configuration_retnet.py │ │ └── modeling_retnet.py │ ├── rwkv6 │ │ ├── __init__.py │ │ ├── configuration_rwkv6.py │ │ └── modeling_rwkv6.py │ ├── transformer │ │ ├── __init__.py │ │ ├── configuration_transformer.py │ │ └── modeling_transformer.py │ └── utils.py ├── modules │ ├── __init__.py │ ├── activations.py │ ├── convolution.py │ ├── feature_map.py │ ├── fused_cross_entropy.py │ ├── fused_norm_gate.py │ ├── l2norm.py │ ├── layernorm.py │ └── rotary.py ├── ops │ ├── __init__.py │ ├── abc │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_gate.py │ │ ├── naive.py │ │ └── recurrent_fuse.py │ ├── based │ │ ├── __init__.py │ │ ├── chunk_fuse.py │ │ ├── naive.py │ │ └── parallel.py │ ├── delta_rule │ │ ├── README.md │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_fuse.py │ │ ├── naive.py │ │ ├── recurrent_fuse.py │ │ ├── utils.py │ │ └── wy_fast.py │ ├── gla │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_fuse.py │ │ ├── chunk_util.py │ │ ├── naive.py │ │ └── recurrent_fuse.py │ ├── hgrn │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── naive.py │ │ └── recurrent_fuse.py │ ├── linear_attn │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_fuse.py │ │ ├── naive.py │ │ └── recurrent_fuse.py │ ├── rebased │ │ ├── __init__.py │ │ ├── naive.py │ │ └── parallel.py │ ├── retention │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_fuse.py │ │ ├── naive.py │ │ ├── parallel.py │ │ └── recurrent_fuse.py │ ├── rotary.py │ ├── rwkv4 │ │ ├── __init__.py │ │ └── recurrent_fuse.py │ ├── rwkv6 │ │ ├── __init__.py │ │ ├── chunk.py │ │ ├── chunk_naive.py │ │ ├── recurrent_fuse.py │ │ └── recurrent_naive.py │ ├── simple_gla │ │ ├── 
README.md │ │ ├── __init__.py │ │ ├── chunk.py │ │ └── naive.py │ └── utils.py └── utils.py ├── proxy.py ├── required ├── index_service_requirements.txt ├── llm_service_requirements.txt └── requirements.txt ├── service.py ├── src ├── __init__.py ├── clients │ ├── __init__.py │ ├── files_service.py │ ├── index_client.py │ └── llm_client.py ├── core │ ├── __init__.py │ └── singleton.py ├── services │ ├── __init__.py │ ├── abc │ │ └── __init__.py │ ├── index_service.py │ └── llm_service.py ├── utils │ ├── __init__.py │ ├── internet.py │ ├── loader.py │ └── tools.py └── vectordb │ ├── __init__.py │ ├── abc │ └── __init__.py │ ├── chroma.py │ └── errors.py └── tokenizer ├── __init__.py ├── rwkv_tokenizer.py └── rwkv_vocab_v20230424.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | .idea/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | test/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | wandb/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 
112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | midi_resource/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /Readme_Eng.md: -------------------------------------------------------------------------------- 1 | RAG in one click 2 | This is a one-click RAG system for RWKV named "AIIRWKV". AIIRWKV employs asynchronous processing, which allows services to be maintained and updated independently. This design keeps encapsulation minimal while remaining highly extensible. Moreover, AIIRWKV integrates one-click tools for StateTune, an extremely efficient fine-tuning method exclusive to RWKV. It also supports LoRA and PiSSA, providing convenient PEFT (Parameter-Efficient Fine-Tuning) options for users to tackle various downstream tasks. The models used in this framework are tuned on Chinese datasets, so AIIRWKV currently performs better on Chinese tasks; English-tuned models are coming soon. 3 | 4 | System design 5 | Even a minimal RAG system involves several sub-systems, and these sub-systems interact with each other. To increase development flexibility and flatten the learning curve, a queue-based RAG system is designed. 6 | 7 | Every component must be pluggable and easy to scale, which means RPC must not be hard-wired to a particular transport such as TCP, in-process, or inter-process communication. 8 | 9 | The natural design pattern is a pub-sub model in which every component connects to a broker (or proxy) to send requests and receive responses. Heavyweight message queues such as RabbitMQ or RocketMQ are commonly used to ensure efficiency and reliability, but a message-queue service is yet another system to administer and maintain. 10 | 11 | The design here instead uses ZeroMQ, a broker-free queue library, as the queue service. 12 | 13 | Thanks to ZeroMQ's reliable and high-performance implementation, this framework can scale from a single resource-constrained node to a large multi-node deployment.
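To make the queue-based design concrete, the sketch below shows what such a broker-free proxy could look like with pyzmq, the ZeroMQ binding this project installs. It is illustrative only: the port numbers and socket roles are assumptions made for this example, not the actual settings of proxy.py or the etc/*.yml files. A ROUTER socket faces the clients, a DEALER socket faces the back-end services, and zmq.proxy() forwards messages between the two, so individual services can be restarted or scaled out without reconfiguring anything else.

```python
# Illustrative sketch of a ZeroMQ front-end/back-end proxy.
# Ports and socket roles are assumptions, not the settings used by proxy.py.
import zmq

def run_proxy(frontend_port: int = 7781, backend_port: int = 7782) -> None:
    ctx = zmq.Context.instance()

    frontend = ctx.socket(zmq.ROUTER)            # clients send requests here
    frontend.bind(f"tcp://*:{frontend_port}")

    backend = ctx.socket(zmq.DEALER)             # services pick up work here
    backend.bind(f"tcp://*:{backend_port}")

    try:
        zmq.proxy(frontend, backend)             # blocks; forwards messages both ways
    finally:
        frontend.close()
        backend.close()
        ctx.term()

if __name__ == "__main__":
    run_proxy()
```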
14 | 15 | The RWKV-RAG system looks like: 16 | 17 | 18 | Download Models 19 | Please download baseline models from https://huggingface.co/BlinkDL 20 | Please download the state for the chatbot from: https://huggingface.co/SupYumm/rwkv6_7b_qabot/tree/main 21 | There are several options for embedding models and rerank models: 22 | Please download the RWKV embedding model from: https://huggingface.co/yueyulin/rwkv6_emb_4k_base 23 | Please download the BGEM3 embedding model from: https://huggingface.co/BAAI/bge-m3 24 | Please download the BGEM3 reranker from: https://huggingface.co/BAAI/bge-reranker-v2-m3 25 | Please feel free to switch to your own embedding model and reranker in the config YAML. Currently, BGEM3 is an ideal option; however, RWKV embedding models and rerankers with better performance are coming soon. 26 | 27 | The following part describes the implementation; it will be updated in the future as more features are added, but the basic design will stay the same. 28 | 29 | Quick Start 30 | Requirements 31 | Please install the dependencies in required/requirements.txt: 32 | 33 | pip install -r required/requirements.txt 34 | The following are the recommended VRAM sizes for RWKV models of different parameter counts: 35 | 36 | SIZE VRAM 37 | 1.6b 4G 38 | 3b 7.5G 39 | 7b 18G 40 | 12b 24G 41 | 14b 30G 42 | Modifying Configuration 43 | You can control the activation or deactivation of all services through the configuration file ragq.yml. By default, all services are enabled. Before use, you need to modify the following configuration items for some services. 44 | 45 | LLM Service 46 | Embedding, reranking, and generating text. 47 | 48 | base_model_file: RWKV baseline model path; refer to the RWKV base model downloads or the Download Models section above 49 | bgem3_path: embedding model path; recommended: bge-m3 50 | rerank_path: rerank model path; recommended: BAAI/bge-reranker-v2-m3 51 | state_path: state file path; the state is generated by state tuning 52 | Index Service 53 | chroma_path: ChromaDB path 54 | chroma_port: ChromaDB port 55 | chroma_host: ChromaDB host IP 56 | sqlite_db_path: SQLite database path 57 | Tuning Service 58 | The default values can be used for rwkv6_1.6b. 59 | 60 | Start Services 61 | python3 service.py 62 | Start Client 63 | streamlit run client.py 64 | Open the URL provided by Streamlit in your browser. 65 | 66 | Notes 67 | It is recommended to use Python 3.10 or Python 3.9. 68 | PyTorch Lightning must be version 1.9.5. 69 | When the fine-tuning feature is used, the current version loads the baseline model again, so GPU memory needs to be allocated sensibly to avoid errors from insufficient VRAM. 70 | Handbooks 71 | Manage Vector Database 72 | This UI supports searching VDB collections, creating and deleting collections, and managing the contents of a collection. 73 | 74 | knowledge manager 75 | Building the knowledge base 76 | This UI supports three different methods of indexing content into the knowledge base: typing by hand, uploading from the local computer, and uploading from the local server. AIIRWKV also supports internet search, indexing real-time data from the internet into the knowledge base. Users can choose the chunk size and chunk overlap according to their situation. 77 | 78 | knowledge manager 79 | Fine-tune RWKV models in one click 80 | WandB 81 | Please register for WandB to monitor the status of the fine-tuning process, especially the loss curve. A progress bar that tracks the fine-tuning process is displayed in the backend terminal. 82 | 83 | Setting up fine-tuning parameters 84 | VRAM requirements for fine-tuning RWKV models with 1024 ctx.
85 | 86 | Size fp16 int8 nf4 87 | RWKV6-1.6B 5.8GB GPU 4.5GB GPU 3.9GB GPU 88 | RWKV6-3B 8.7GB GPU 6.2GB GPU 4.9GB GPU 89 | RWKV6-7B 17.8GB GPU 11.9GB GPU 8.5GB GPU 90 | For detailed explanations of the other parameters and hyperparameters, please refer to the official tutorial at: https://rwkv.cn/RWKV-Fine-Tuning/State-Tuning 91 | 92 | knowledge manager 93 | RAG CHATBOT 94 | Retrieve the most relevant information from the knowledge base, then ask questions about that information. Users can change the base model and state dynamically in the UI. AIIRWKV is a chatbot that can deliver precise answers based on all the information from the last 6 rounds of conversation. Users can always switch states for different downstream tasks. 95 | 96 | knowledge manager 97 | Future Direction 98 | The multi-modal framework, primarily focused on ASR and vision, will be available online soon. Additionally, GraphRAG and prompt optimization are also forthcoming. 99 | 100 | Acknowledgement 101 | All RWKV tuning services are adapted from J.L 102 | All RWKV models are from BlinkDL 103 | Authors: YYnil; Ojiyum; LonghuaLiu 104 | -------------------------------------------------------------------------------- /configuration.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | 4 | from src.core import SingletonMeta 5 | 6 | class LLMServiceConfig(metaclass=SingletonMeta): 7 | def __init__(self, config_file): 8 | if not os.path.exists(config_file): 9 | raise FileNotFoundError(f"Config file {config_file} not found") 10 | with open(config_file) as f: 11 | try: 12 | self.config = yaml.safe_load(f) 13 | except yaml.YAMLError as exc: 14 | raise ValueError(f"Invalid config file {config_file}") 15 | self.config_file_path = config_file 16 | self.default_base_model_path = '' # default base model path 17 | self.default_bgem3_path = '' # default bgem3 path 18 | self.default_rerank_path = '' # default rerank path 19 | self.default_state_path = '' # default state file 20 | self.validate_llm_service_config(self.config) 21 | 22 | 23 | def validate_llm_service_config(self, settings): 24 | base_model_file = settings.get("base_model_path", '') 25 | if not base_model_file: 26 | raise ValueError(f"base_model_path is required for llm service") 27 | if not os.path.exists(base_model_file): 28 | raise FileNotFoundError(f"base_model_path {base_model_file} not found for {self.config_file_path}") 29 | 30 | bgem3_path = settings.get("embedding_path", '') 31 | if not bgem3_path: 32 | raise ValueError(f"embedding_path is required for llm service") 33 | if not os.path.exists(bgem3_path): 34 | raise FileNotFoundError(f"embedding_path {bgem3_path} not found for {self.config_file_path}") 35 | rerank_path = settings.get("reranker_path", '') 36 | if not rerank_path: 37 | raise ValueError(f"reranker_path is required for llm service") 38 | if not os.path.exists(rerank_path): 39 | raise FileNotFoundError(f"reranker_path {rerank_path} not found for {self.config_file_path}") 40 | state_path = settings.get("state_path", '') or '' 41 | if state_path: 42 | if not os.path.exists(state_path): 43 | raise FileNotFoundError(f"state_path {state_path} not found for {self.config_file_path}") 44 | self.default_base_model_path = base_model_file.strip() 45 | self.default_bgem3_path = bgem3_path.strip() 46 | self.default_rerank_path = rerank_path.strip() 47 | self.default_state_path = state_path 48 | 49 | # def set_llm_service_config(self, base_model_path=None, embedding_path=None, reranker_path=None, state_path=None): 50 | # is_save = False 51 | # if
base_model_path and base_model_path != self.default_base_model_path: 52 | # self.default_base_model_path = base_model_path.strip() 53 | # self.config['base_model_path'] = base_model_path 54 | # is_save = True 55 | # if embedding_path and embedding_path != self.default_bgem3_path: 56 | # self.default_bgem3_path = embedding_path 57 | # self.config['embedding_path'] = embedding_path 58 | # is_save = True 59 | # if reranker_path and reranker_path != self.default_rerank_path: 60 | # self.default_rerank_path = reranker_path 61 | # self.config['reranker_path'] = reranker_path 62 | # if state_path and state_path != self.default_state_path: 63 | # self.default_state_path = state_path 64 | # self.config['state_path'] = state_path 65 | # is_save = True 66 | # if is_save: 67 | # with open(self.config_file_path, "w") as f: 68 | # yaml.dump(self.config, f) 69 | 70 | 71 | class IndexServiceConfig(metaclass=SingletonMeta): 72 | def __init__(self, config_file): 73 | if not os.path.exists(config_file): 74 | raise FileNotFoundError(f"Config file {config_file} not found") 75 | with open(config_file) as f: 76 | try: 77 | self.config = yaml.safe_load(f) 78 | except yaml.YAMLError as exc: 79 | raise ValueError(f"Invalid config file {config_file}") 80 | self.config_file_path = config_file 81 | self.validate_index_service_config(self.config) 82 | 83 | def validate_index_service_config(self, settings): 84 | vectordb_name = settings.get("vectordb_name", '') 85 | vectordb_host = settings.get("vectordb_host", '') 86 | if not vectordb_name: 87 | raise ValueError(f"vectordb_name is required for index service") 88 | if not vectordb_host: 89 | raise ValueError(f"vectordb_host is required for index service") 90 | 91 | vectordb_port = settings.get("vectordb_port", '') 92 | if not (isinstance(vectordb_port, int) or (isinstance(vectordb_port, str) and vectordb_port.isdigit())): 93 | raise ValueError(f"vectordb_port is required for index service") 94 | 95 | # class TuningServiceConfig(metaclass=SingletonMeta): 96 | # def __init__(self, config_file): 97 | # if not os.path.exists(config_file): 98 | # raise FileNotFoundError(f"Config file {config_file} not found") 99 | # with open(config_file) as f: 100 | # try: 101 | # self.config = yaml.safe_load(f) 102 | # except yaml.YAMLError as exc: 103 | # raise ValueError(f"Invalid config file {config_file}") 104 | # self.config_file_path = config_file 105 | 106 | 107 | class ClientConfig(metaclass=SingletonMeta): 108 | def __init__(self, config_file, validate=True): 109 | if not os.path.exists(config_file): 110 | raise FileNotFoundError(f"Config file {config_file} not found") 111 | with open(config_file) as f: 112 | try: 113 | self.config = yaml.safe_load(f) 114 | except yaml.YAMLError as exc: 115 | raise ValueError(f"Invalid config file {config_file}") 116 | if validate: 117 | self.validate() 118 | 119 | def validate(self): 120 | """ 121 | Validate Configuration File 122 | """ 123 | base_setting = self.config.get('base', {}) 124 | sqlite_db_path = base_setting.get("sqlite_db_path", '') 125 | if not sqlite_db_path: 126 | raise ValueError(f"sqlite_db_path is required") 127 | sqlite_db_path_dir = os.path.dirname(sqlite_db_path) 128 | if not os.path.exists(sqlite_db_path_dir): 129 | os.makedirs(sqlite_db_path_dir) 130 | knowledge_base_path = base_setting.get("knowledge_base_path", '') 131 | if knowledge_base_path: 132 | if not os.path.exists(knowledge_base_path): 133 | os.makedirs(knowledge_base_path) 134 | 135 | 136 | 137 | 
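# --- Usage sketch (added for illustration; not part of the original module) ---
# A minimal example of how these configuration classes might be loaded. The file
# names follow the repository's etc/ directory, but which YAML file belongs to
# which class is an assumption made here for illustration only.
if __name__ == '__main__':
    llm_cfg = LLMServiceConfig('etc/llm_service_config.yml')        # validates model/embedding/reranker paths
    index_cfg = IndexServiceConfig('etc/index_service_config.yml')  # validates vectordb_name/host/port
    client_cfg = ClientConfig('etc/ragq.yml')                       # validates sqlite/knowledge-base paths
    print(llm_cfg.default_base_model_path)
    print(index_cfg.config.get('vectordb_host'))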
-------------------------------------------------------------------------------- /cuda/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/cuda/__init__.py -------------------------------------------------------------------------------- /cuda/gemm_fp16_cublas.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #define CUBLAS_CHECK(condition) \ 10 | for (cublasStatus_t _cublas_check_status = (condition); \ 11 | _cublas_check_status != CUBLAS_STATUS_SUCCESS;) \ 12 | throw std::runtime_error("cuBLAS error " + \ 13 | std::to_string(_cublas_check_status) + " at " + \ 14 | std::to_string(__LINE__)); 15 | 16 | #define CUDA_CHECK(condition) \ 17 | for (cudaError_t _cuda_check_status = (condition); \ 18 | _cuda_check_status != cudaSuccess;) \ 19 | throw std::runtime_error( \ 20 | "CUDA error " + std::string(cudaGetErrorString(_cuda_check_status)) + \ 21 | " at " + std::to_string(__LINE__)); 22 | 23 | /* 24 | NOTE: blas gemm is column-major by default, but we need row-major output. 25 | The data of row-major, transposed matrix is exactly the same as the 26 | column-major, non-transposed matrix, and C = A * B ---> C^T = B^T * A^T 27 | */ 28 | void gemm_fp16_cublas(torch::Tensor a, torch::Tensor b, torch::Tensor c) { 29 | const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); 30 | const auto cuda_data_type = CUDA_R_16F; 31 | const auto cuda_c_data_type = 32 | c.dtype() == torch::kFloat32 ? CUDA_R_32F : CUDA_R_16F; 33 | const auto compute_type = CUDA_R_32F; 34 | const float sp_alpha = 1.f; 35 | // swap a and b, and use CUBLAS_OP_N. see the notes above 36 | std::swap(a, b); 37 | const cublasOperation_t cublas_trans_a = CUBLAS_OP_N; 38 | const cublasOperation_t cublas_trans_b = CUBLAS_OP_N; 39 | // m = (B^T).size(0) = B.size(1), and = A.size(1) after swap, 40 | // negative axis is used because of the existence of batch matmul. 
41 | const int m = a.size(-1); 42 | const int k = a.size(-2); 43 | const int n = b.size(-2); 44 | const int cublas_lda = m; 45 | const int cublas_ldb = k; 46 | const int cublas_ldc = m; 47 | cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle(); 48 | 49 | #if CUDA_VERSION >= 11000 50 | cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; 51 | #else 52 | cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT_TENSOR_OP; 53 | #endif 54 | const float sp_beta = 0.f; 55 | if (a.sizes().size() == 2 && b.sizes().size() == 2) { 56 | CUBLAS_CHECK(cublasGemmEx( 57 | cublas_handle, cublas_trans_a, cublas_trans_b, m, n, k, &sp_alpha, 58 | a.data_ptr(), cuda_data_type, cublas_lda, b.data_ptr(), cuda_data_type, 59 | cublas_ldb, &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, 60 | compute_type, algo)); 61 | } else { 62 | // batch matmul 63 | assert(a.sizes().size() == 3 && b.sizes().size() == 3); 64 | 65 | const long long int cublas_stride_a = m * k; 66 | const long long int cublas_stride_b = k * n; 67 | const long long int cublas_stride_c = m * n; 68 | CUBLAS_CHECK(cublasGemmStridedBatchedEx( 69 | cublas_handle, cublas_trans_a, cublas_trans_b, m, 70 | n, k, &sp_alpha, a.data_ptr(), cuda_data_type, cublas_lda, 71 | cublas_stride_a, b.data_ptr(), cuda_data_type, cublas_ldb, cublas_stride_b, 72 | &sp_beta, c.data_ptr(), cuda_c_data_type, cublas_ldc, cublas_stride_c, 73 | a.size(0), compute_type, algo)); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /cuda/rwkv6.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | typedef at::Half fp16; 6 | typedef float fp32; 7 | 8 | template 9 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, float *__restrict__ _state, 10 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 11 | F *__restrict__ const _y) 12 | { 13 | const int b = blockIdx.x / H; 14 | const int h = blockIdx.x % H; 15 | const int i = threadIdx.x; 16 | _u += h*_N_; 17 | _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! 
18 | 19 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 20 | 21 | float state[_N_]; 22 | #pragma unroll 23 | for (int j = 0; j < _N_; j++) 24 | state[j] = _state[j]; 25 | 26 | __syncthreads(); 27 | u[i] = float(_u[i]); 28 | __syncthreads(); 29 | 30 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 31 | { 32 | __syncthreads(); 33 | w[i] = _w[t]; 34 | r[i] = float(_r[t]); 35 | k[i] = float(_k[t]); 36 | __syncthreads(); 37 | 38 | const float v = float(_v[t]); 39 | float y = 0; 40 | 41 | #pragma unroll 42 | for (int j = 0; j < _N_; j+=4) 43 | { 44 | const float4& r_ = (float4&)(r[j]); 45 | const float4& k_ = (float4&)(k[j]); 46 | const float4& w_ = (float4&)(w[j]); 47 | const float4& u_ = (float4&)(u[j]); 48 | float4& s = (float4&)(state[j]); 49 | float4 x; 50 | 51 | x.x = k_.x * v; 52 | x.y = k_.y * v; 53 | x.z = k_.z * v; 54 | x.w = k_.w * v; 55 | 56 | y += r_.x * (u_.x * x.x + s.x); 57 | y += r_.y * (u_.y * x.y + s.y); 58 | y += r_.z * (u_.z * x.z + s.z); 59 | y += r_.w * (u_.w * x.w + s.w); 60 | 61 | s.x = s.x * w_.x + x.x; 62 | s.y = s.y * w_.y + x.y; 63 | s.z = s.z * w_.z + x.z; 64 | s.w = s.w * w_.w + x.w; 65 | } 66 | _y[t] = F(y); 67 | } 68 | #pragma unroll 69 | for (int j = 0; j < _N_; j++) 70 | _state[j] = state[j]; 71 | } 72 | 73 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 74 | { 75 | assert(H*_N_ == C); 76 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 77 | } 78 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y) 79 | { 80 | assert(H*_N_ == C); 81 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 82 | } 83 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y) 84 | { 85 | assert(H*_N_ == C); 86 | kernel_forward<<>>(B, T, C, H, state, r, k, v, w, u, y); 87 | } 88 | -------------------------------------------------------------------------------- /cuda/rwkv6_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | #include 4 | typedef at::BFloat16 bf16; 5 | typedef at::Half fp16; 6 | typedef float fp32; 7 | 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y); 10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); 11 | 12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 13 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 14 | cuda_forward_bf16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 15 | } 16 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 17 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 18 | cuda_forward_fp16(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 19 | } 20 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t 
H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 21 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 22 | cuda_forward_fp32(B, T, C, H, state.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 23 | } 24 | 25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 26 | m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16"); 27 | m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16"); 28 | m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32"); 29 | } 30 | TORCH_LIBRARY(rwkv6, m) { 31 | m.def("forward_bf16", forward_bf16); 32 | m.def("forward_fp16", forward_fp16); 33 | m.def("forward_fp32", forward_fp32); 34 | } 35 | -------------------------------------------------------------------------------- /cuda/wkv5_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _w += h*_N_; 15 | _u += h*_N_; 16 | 17 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 18 | float state[_N_] = {0}; 19 | 20 | __syncthreads(); 21 | w[i] = _w[i]; 22 | u[i] = float(_u[i]); 23 | __syncthreads(); 24 | 25 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 26 | { 27 | __syncthreads(); 28 | r[i] = float(_r[t]); 29 | k[i] = float(_k[t]); 30 | __syncthreads(); 31 | 32 | const float v = float(_v[t]); 33 | float y = 0; 34 | 35 | #pragma unroll 36 | for (int j = 0; j < _N_; j+=4) 37 | { 38 | const float4& r_ = (float4&)(r[j]); 39 | const float4& k_ = (float4&)(k[j]); 40 | const float4& w_ = (float4&)(w[j]); 41 | const float4& u_ = (float4&)(u[j]); 42 | float4& s = (float4&)(state[j]); 43 | float4 x; 44 | 45 | x.x = k_.x * v; 46 | x.y = k_.y * v; 47 | x.z = k_.z * v; 48 | x.w = k_.w * v; 49 | 50 | y += r_.x * (u_.x * x.x + s.x); 51 | y += r_.y * (u_.y * x.y + s.y); 52 | y += r_.z * (u_.z * x.z + s.z); 53 | y += r_.w * (u_.w * x.w + s.w); 54 | 55 | s.x = s.x * w_.x + x.x; 56 | s.y = s.y * w_.y + x.y; 57 | s.z = s.z * w_.z + x.z; 58 | s.w = s.w * w_.w + x.w; 59 | } 60 | _y[t] = F(y); 61 | } 62 | } 63 | 64 | template 65 | __global__ void kernel_backward(const int B, const int T, const int C, const int H, 66 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy, 67 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu) 68 | { 69 | const int b = blockIdx.x / H; 70 | const int h = blockIdx.x % H; 71 | const int i = threadIdx.x; 72 | _w += h*_N_; 73 | _u += h*_N_; 74 | __w += h*_N_; 75 | 76 | __shared__ float w_[_N_], u_[_N_]; 77 | __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_]; 78 | __syncthreads(); 79 | w_[i] = _w[i]; 80 | u_[i] = float(_u[i]); 81 | __syncthreads(); 82 | 83 | const float w = w_[i]; 84 | const float ww = __w[i]; 85 | const float u = u_[i]; 86 | 87 | float state[_N_] = {0}, saaaa[_N_] = {0}, 
sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 88 | 89 | float gw = 0, gu = 0; 90 | const int t000 = b*T*C + h*_N_ + i; 91 | const int t111 = (b+1)*T*C + h*_N_ + i; 92 | const int t222 = t111 - 2*C; 93 | 94 | for (int t = t000; t < t111; t += C) 95 | { 96 | __syncthreads(); 97 | v[i] = float(_v[t]); 98 | gy[i] = float(_gy[t]); 99 | __syncthreads(); 100 | 101 | const float k = float(_k[t]); 102 | float gr = 0, gu_ = 0; 103 | 104 | #pragma unroll 105 | for (int j = 0; j < _N_; j++) 106 | { 107 | float& s = state[j]; 108 | float x = k * v[j]; 109 | 110 | gr += (u * x + s) * gy[j]; 111 | gu_ += x * gy[j]; 112 | s = s * w + x; 113 | } 114 | _gr[t] = F(gr); 115 | gu += float(_r[t]) * gu_; 116 | } 117 | _gu[b*C + h*_N_ + i] = F(gu); 118 | 119 | for (int t = t000; t < t222; t += C) 120 | { 121 | __syncthreads(); 122 | v[i] = float(_v[t]); 123 | gy[i] = float(_gy[t + 2*C]); 124 | __syncthreads(); 125 | 126 | const float k = float(_k[t]); 127 | float gw_ = 0; 128 | 129 | #pragma unroll 130 | for (int j = 0; j < _N_; j++) 131 | { 132 | float& s = saaaa[j]; 133 | float& s2 = sbbbb[j]; 134 | float x = k * v[j]; 135 | 136 | float tmp = w * (x + s); 137 | s = tmp; 138 | s2 = tmp + w * s2; 139 | gw_ += s2 * gy[j]; 140 | } 141 | gw += float(_r[t + 2*C]) * gw_; 142 | } 143 | _gw[b*C + h*_N_ + i] = F(ww * gw); 144 | 145 | for (int t = t111 - C; t >= t000; t -= C) 146 | { 147 | __syncthreads(); 148 | v[i] = float(_v[t]); 149 | gy[i] = float(_gy[t]); 150 | __syncthreads(); 151 | 152 | const float rr = float(_r[t]); 153 | float gk = 0; 154 | 155 | #pragma unroll 156 | for (int j = 0; j < _N_; j++) 157 | { 158 | float& s = scccc[j]; 159 | float x = rr * gy[j]; 160 | 161 | gk += (u * x + s) * v[j]; 162 | s = x + s * w; 163 | } 164 | _gk[t] = F(gk); 165 | } 166 | 167 | for (int t = t111 - C; t >= t000; t -= C) 168 | { 169 | __syncthreads(); 170 | r[i] = float(_r[t]); 171 | k[i] = float(_k[t]); 172 | __syncthreads(); 173 | 174 | const float gyy = float(_gy[t]); 175 | float gv = 0; 176 | 177 | #pragma unroll 178 | for (int j = 0; j < _N_; j++) 179 | { 180 | float& s = sdddd[j]; 181 | float x = gyy * r[j]; 182 | 183 | gv += (u_[j] * x + s) * k[j]; 184 | s = x + s * w_[j]; 185 | } 186 | _gv[t] = F(gv); 187 | } 188 | } 189 | 190 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 191 | { 192 | assert(H*_N_ == C); 193 | assert(_N_%4 == 0); 194 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 195 | } 196 | 197 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 198 | { 199 | assert(H*_N_ == C); 200 | assert(_N_%4 == 0); 201 | kernel_backward<<>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu); 202 | } 203 | -------------------------------------------------------------------------------- /cuda/wkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), 
k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), ww.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv5 forward"); 16 | m.def("backward", &backward, "wkv5 backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv5, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /cuda/wkv6_bi_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, const int * const mask, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, const int * const mask, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, const torch::Tensor &mask, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, mask.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, const torch::Tensor &mask, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, mask.data_ptr(), r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6 bi forward"); 16 | m.def("backward", &backward, "wkv6 bi backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6bi, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } -------------------------------------------------------------------------------- /cuda/wkv6_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "ATen/ATen.h" 4 | typedef at::BFloat16 bf16; 5 | 6 | template 7 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 8 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, 9 | F *__restrict__ const _y) 10 | { 11 | const int b = blockIdx.x / H; 12 | const int h = blockIdx.x % H; 13 | const int i = threadIdx.x; 14 | _u += h*_N_; 15 | 16 | __shared__ float r[_N_], k[_N_], u[_N_], w[_N_]; 17 | float state[_N_] = {0}; 18 | 19 | __syncthreads(); 20 | u[i] = float(_u[i]); 21 | __syncthreads(); 22 | 23 | for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C) 24 | { 25 | __syncthreads(); 26 | w[i] = exp(_w[t]); 27 | r[i] = float(_r[t]); 28 | 
k[i] = float(_k[t]); 29 | __syncthreads(); 30 | 31 | const float v = float(_v[t]); 32 | float y = 0; 33 | 34 | #pragma unroll 35 | for (int j = 0; j < _N_; j+=4) 36 | { 37 | const float4& r_ = (float4&)(r[j]); 38 | const float4& k_ = (float4&)(k[j]); 39 | const float4& w_ = (float4&)(w[j]); 40 | const float4& u_ = (float4&)(u[j]); 41 | float4& s = (float4&)(state[j]); 42 | float4 x; 43 | 44 | x.x = k_.x * v; 45 | x.y = k_.y * v; 46 | x.z = k_.z * v; 47 | x.w = k_.w * v; 48 | 49 | y += r_.x * (u_.x * x.x + s.x); 50 | y += r_.y * (u_.y * x.y + s.y); 51 | y += r_.z * (u_.z * x.z + s.z); 52 | y += r_.w * (u_.w * x.w + s.w); 53 | 54 | s.x = s.x * w_.x + x.x; 55 | s.y = s.y * w_.y + x.y; 56 | s.z = s.z * w_.z + x.z; 57 | s.w = s.w * w_.w + x.w; 58 | } 59 | _y[t] = F(y); 60 | } 61 | } 62 | 63 | template 64 | __global__ void kernel_backward_111(const int B, const int T, const int C, const int H, 65 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy, 66 | F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu) 67 | { 68 | const int b = blockIdx.x / H; 69 | const int h = blockIdx.x % H; 70 | const int i = threadIdx.x; 71 | _u += h*_N_; 72 | 73 | __shared__ float u_[_N_]; 74 | __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; 75 | __syncthreads(); 76 | u_[i] = float(_u[i]); 77 | __syncthreads(); 78 | 79 | const float u = u_[i]; 80 | 81 | float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; 82 | 83 | const int t_0 = b*T*C + h*_N_ + i; 84 | const int t_T_1 = t_0 + (T-1)*C; 85 | const int t_T = t_0 + T*C; 86 | 87 | float gu = 0; 88 | for (int t = t_0; t < t_T; t += C) 89 | { 90 | __syncthreads(); 91 | v[i] = float(_v[t]); 92 | gy[i] = float(_gy[t]); 93 | __syncthreads(); 94 | 95 | const float k = float(_k[t]); 96 | const float w = exp(_w[t]); 97 | float gr = 0, gu_ = 0; 98 | 99 | #pragma unroll 100 | for (int j = 0; j < _N_; j++) 101 | { 102 | float& s = state[j]; 103 | float x = k * v[j]; 104 | 105 | gr += (u * x + s) * gy[j]; 106 | gu_ += x * gy[j]; 107 | s = s * w + x; 108 | } 109 | _gr[t] = F(gr); 110 | gu += float(_r[t]) * gu_; 111 | } 112 | _gu[b*C + h*_N_ + i] = F(gu); 113 | 114 | for (int t = t_T_1; t >= t_0; t -= C) 115 | { 116 | __syncthreads(); 117 | v[i] = float(_v[t]); 118 | gy[i] = float(_gy[t]); 119 | __syncthreads(); 120 | 121 | const float rr = float(_r[t]); 122 | const float w = exp(_w[t]); 123 | float gk = 0; 124 | 125 | #pragma unroll 126 | for (int j = 0; j < _N_; j++) 127 | { 128 | float& s = scccc[j]; 129 | float x = rr * gy[j]; 130 | 131 | gk += (u * x + s) * v[j]; 132 | s = x + s * w; 133 | } 134 | _gk[t] = F(gk); 135 | } 136 | 137 | for (int t = t_T_1; t >= t_0; t -= C) 138 | { 139 | __syncthreads(); 140 | r[i] = float(_r[t]); 141 | k[i] = float(_k[t]); 142 | w_[i] = exp(_w[t]); 143 | __syncthreads(); 144 | 145 | const float gyy = float(_gy[t]); 146 | float gv = 0; 147 | 148 | #pragma unroll 149 | for (int j = 0; j < _N_; j++) 150 | { 151 | float& s = sdddd[j]; 152 | float x = gyy * r[j]; 153 | 154 | gv += (u_[j] * x + s) * k[j]; 155 | s = x + s * w_[j]; 156 | } 157 | _gv[t] = F(gv); 158 | } 159 | } 160 | 161 | template 162 | __global__ void kernel_backward_222(const int B, const int T, const int C, const int H, 163 | const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F 
*__restrict__ const _gy, 164 | F *__restrict__ const _gw) 165 | { 166 | const int b = blockIdx.x / H; 167 | const int h = blockIdx.x % H; 168 | const int i = threadIdx.x; 169 | 170 | __shared__ float v[_N_], gy[_N_]; 171 | float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0}; 172 | 173 | const int t_0 = b*T*C + h*_N_ + i; 174 | const int t_1 = t_0 + C; 175 | const int t_2 = t_0 + 2*C; 176 | const int t_T_1 = t_0 + (T-1)*C; 177 | 178 | for (int t = t_T_1; t > t_1; t -= C) 179 | { 180 | __syncthreads(); 181 | gy[i] = float(_gy[t]); 182 | v[i] = float(_v[t-2*C]); 183 | __syncthreads(); 184 | 185 | const float r = float(_r[t]); 186 | const float w = exp(_w[t-C]); 187 | float sum = 0.0f; 188 | 189 | #pragma unroll 190 | for (int j = 0; j < _N_; j++) 191 | { 192 | float& s = saaaa[j]; 193 | float x = r * gy[j]; 194 | s = (s + x) * w; 195 | sum += s * v[j]; 196 | } 197 | sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]); 198 | } 199 | 200 | float sss = sbbbb[0]; 201 | _gw[t_0] = 0; 202 | _gw[t_1] = F(sss * _w[t_1]); 203 | 204 | for (int t = t_2; t < t_T_1; t += C) 205 | { 206 | __syncthreads(); 207 | gy[i] = float(_gy[t]); 208 | v[i] = float(_v[t-2*C]); 209 | __syncthreads(); 210 | 211 | const float w = exp(_w[t-C]); 212 | const float k = float(_k[t-2*C]); 213 | float sum = 0.0f; 214 | 215 | #pragma unroll 216 | for (int j = 0; j < _N_; j++) 217 | { 218 | float& s = scccc[j]; 219 | float x = k * v[j]; 220 | s = (s + x) * w; 221 | sum += s * gy[j]; 222 | } 223 | sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t])); 224 | _gw[t] = F(sss * _w[t]); 225 | } 226 | _gw[t_T_1] = 0; 227 | } 228 | 229 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y) 230 | { 231 | assert(H*_N_ == C); 232 | assert(_N_%4 == 0); 233 | kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); 234 | } 235 | 236 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu) 237 | { 238 | assert(H*_N_ == C); 239 | assert(_N_%4 == 0); 240 | kernel_backward_111<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu); 241 | kernel_backward_222<<>>(B, T, C, H, r, k, v, w, u, gy, gw); 242 | } 243 | -------------------------------------------------------------------------------- /cuda/wkv6_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | 
m.def("forward", &forward, "wkv6 forward"); 16 | m.def("backward", &backward, "wkv6 backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /cuda/wkv6infctx_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6state forward"); 16 | m.def("backward", &backward, "wkv6state backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6state, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /cuda/wkv6state_op.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), y.data_ptr()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { 12 | cuda_backward(B, T, C, H, r.data_ptr(), k.data_ptr(), v.data_ptr(), w.data_ptr(), u.data_ptr(), s.data_ptr(), gy.data_ptr(), gr.data_ptr(), gk.data_ptr(), gv.data_ptr(), gw.data_ptr(), gu.data_ptr(), gs.data_ptr()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6state forward"); 16 | m.def("backward", &backward, "wkv6state backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6state, m) { 20 | m.def("forward", 
forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /cuda/wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "ATen/ATen.h" 3 | #include 4 | #include 5 | 6 | typedef at::Half fp16; 7 | 8 | template 9 | void cuda_wkv_forward(int B, int T, int C, 10 | float *w, float *u, F *k, F *v, F *y, 11 | float *aa, float *bb, float *pp); 12 | template 13 | void cuda_mm8_seq(int B, int N, int M, 14 | F *x, int x_stride, 15 | uint8_t *w, int w_stride, 16 | F *mx, F *rx, 17 | F *my, F *ry, 18 | F *y, int y_stride); 19 | template 20 | void cuda_mm8_one(int N, int M, 21 | F *x, 22 | uint8_t *w, int w_stride, 23 | F *mx, F *rx, 24 | F *my, F *ry, 25 | float *y); 26 | 27 | void wkv_forward(int64_t B, int64_t T, int64_t C, 28 | torch::Tensor &w, torch::Tensor &u, 29 | torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, 30 | torch::Tensor &aa, torch::Tensor &bb, torch::Tensor &pp) { 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); 32 | switch (k.scalar_type()) { 33 | case c10::ScalarType::Half: 34 | cuda_wkv_forward(B, T, C, 35 | w.data_ptr(), u.data_ptr(), 36 | k.data_ptr(), v.data_ptr(), y.data_ptr(), 37 | aa.data_ptr(), bb.data_ptr(), pp.data_ptr()); 38 | break; 39 | case c10::ScalarType::Float: 40 | cuda_wkv_forward(B, T, C, 41 | w.data_ptr(), u.data_ptr(), 42 | k.data_ptr(), v.data_ptr(), y.data_ptr(), 43 | aa.data_ptr(), bb.data_ptr(), pp.data_ptr()); 44 | break; 45 | default: 46 | assert(false && "Only FP16 and FP32 are currently supported"); 47 | } 48 | } 49 | 50 | void mm8_seq(int64_t B, int64_t N, int64_t M, 51 | torch::Tensor &x, torch::Tensor &w, 52 | torch::Tensor &mx, torch::Tensor &rx, 53 | torch::Tensor &my, torch::Tensor &ry, 54 | torch::Tensor &y) { 55 | assert(x.stride(1) == 1); 56 | assert(w.stride(1) == 1); 57 | assert(mx.stride(0) == 1 && rx.stride(0) == 1); 58 | assert(my.stride(0) == 1 && ry.stride(0) == 1); 59 | assert(y.stride(1) == 1); 60 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); 61 | switch (x.scalar_type()) { 62 | case c10::ScalarType::Half: 63 | cuda_mm8_seq( 64 | B, N, M, 65 | x.data_ptr(), x.stride(0), 66 | w.data_ptr(), w.stride(0), 67 | mx.data_ptr(), rx.data_ptr(), 68 | my.data_ptr(), ry.data_ptr(), 69 | y.data_ptr(), y.stride(0)); 70 | break; 71 | case c10::ScalarType::Float: 72 | cuda_mm8_seq( 73 | B, N, M, 74 | x.data_ptr(), x.stride(0), 75 | w.data_ptr(), w.stride(0), 76 | mx.data_ptr(), rx.data_ptr(), 77 | my.data_ptr(), ry.data_ptr(), 78 | y.data_ptr(), y.stride(0)); 79 | break; 80 | default: 81 | assert(false && "Only FP16 and FP32 are currently supported"); 82 | } 83 | } 84 | void mm8_one(int64_t N, int64_t M, 85 | torch::Tensor &x, torch::Tensor &w, 86 | torch::Tensor &mx, torch::Tensor &rx, 87 | torch::Tensor &my, torch::Tensor &ry, 88 | torch::Tensor &y) { 89 | assert(x.stride(0) == 1); 90 | assert(w.stride(1) == 1); 91 | assert(mx.stride(0) == 1 && rx.stride(0) == 1); 92 | assert(my.stride(0) == 1 && ry.stride(0) == 1); 93 | assert(y.stride(0) == 1); 94 | const at::cuda::OptionalCUDAGuard device_guard(device_of(w)); 95 | switch (x.scalar_type()) { 96 | case c10::ScalarType::Half: 97 | cuda_mm8_one( 98 | N, M, 99 | x.data_ptr(), 100 | w.data_ptr(), w.stride(0), 101 | mx.data_ptr(), rx.data_ptr(), 102 | my.data_ptr(), ry.data_ptr(), 103 | y.data_ptr()); 104 | break; 105 | case c10::ScalarType::Float: 106 | cuda_mm8_one( 107 | N, M, 108 | x.data_ptr(), 109 | w.data_ptr(), 
w.stride(0), 110 | mx.data_ptr(), rx.data_ptr(), 111 | my.data_ptr(), ry.data_ptr(), 112 | y.data_ptr()); 113 | break; 114 | default: 115 | assert(false && "Only FP16 and FP32 are currently supported"); 116 | } 117 | } 118 | 119 | using torch::Tensor; 120 | 121 | #ifndef DISABLE_CUBLAS_GEMM 122 | void gemm_fp16_cublas(Tensor a, Tensor b, Tensor c); 123 | #endif 124 | 125 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 126 | m.def("wkv_forward", &wkv_forward, "wkv forward"); 127 | m.def("mm8_seq", &mm8_seq, "mm8 seq"); 128 | m.def("mm8_one", &mm8_one, "mm8 one"); 129 | #ifndef DISABLE_CUBLAS_GEMM 130 | m.def("gemm_fp16_cublas", &gemm_fp16_cublas, "gemv fp16 cublas"); 131 | #endif 132 | } 133 | 134 | TORCH_LIBRARY(rwkv, m) { 135 | m.def("wkv_forward", wkv_forward); 136 | m.def("mm8_seq", mm8_seq); 137 | m.def("mm8_one", mm8_one); 138 | #ifndef DISABLE_CUBLAS_GEMM 139 | m.def("gemm_fp16_cublas", gemm_fp16_cublas); 140 | #endif 141 | } 142 | -------------------------------------------------------------------------------- /docker/DockerfileClient: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | LABEL maintainer=RWKV-RAG-Client 3 | 4 | WORKDIR /root 5 | 6 | RUN mkdir -p /root/data # 数据存放位置 7 | 8 | # 初始化环境 9 | RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ 10 | echo "Asia/Shanghai" > /etc/timezone 11 | 12 | RUN apt-get update -y && \ 13 | apt-get install -y git && \ 14 | apt-get install -y --no-install-recommends libglib2.0-0 && \ 15 | apt-get clean && \ 16 | rm -rf /var/lib/apt/lists/* 17 | RUN pip install --upgrade pip setuptools 18 | RUN pip install --index-url https://pypi.python.org/simple/ pipx && \ 19 | pipx install poetry --force 20 | 21 | # 下载代码 22 | 23 | RUN git clone https://github.com/AIIRWKV/RWKV-RAG.git 24 | 25 | # 安装 Python 依赖 26 | WORKDIR /root/RWKV-RAG 27 | RUN pip install -r required/requirements.txt 28 | 29 | # 安装浏览器驱动 30 | RUN playwright install 31 | RUN playwright install-deps 32 | RUN apt-get update && apt-get install libnss3 \ 33 | libnspr4 \ 34 | libdbus-1-3 \ 35 | libatk1.0-0 \ 36 | libatk-bridge2.0-0 \ 37 | libcups2 \ 38 | libdrm2 \ 39 | libatspi2.0-0 \ 40 | libxcomposite1 \ 41 | libxdamage1 \ 42 | libxfixes3 \ 43 | libxrandr2 \ 44 | libgbm1 \ 45 | libxkbcommon0 \ 46 | libasound2 47 | 48 | # 安装 tesseract-ocr 49 | RUN apt-get update && apt-get install -y tesseract-ocr \ 50 | tesseract-ocr-chi-sim 51 | 52 | 53 | # 启动应用程序 54 | ENTRYPOINT ["streamlit", "run", "client.py"] 55 | -------------------------------------------------------------------------------- /docker/DockerfileIndexService: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | LABEL maintainer=RWKV-RAG-Index-Service 3 | 4 | WORKDIR /root 5 | 6 | # 初始化环境 7 | RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ 8 | echo "Asia/Shanghai" > /etc/timezone 9 | 10 | RUN apt-get update -y && \ 11 | apt-get install -y git && \ 12 | apt-get install -y --no-install-recommends libglib2.0-0 && \ 13 | apt-get clean && \ 14 | rm -rf /var/lib/apt/lists/* 15 | RUN pip install --upgrade pip setuptools 16 | RUN pip install --index-url https://pypi.python.org/simple/ pipx && \ 17 | pipx install poetry --force 18 | 19 | # 下载代码 20 | RUN git clone https://github.com/AIIRWKV/RWKV-RAG.git 21 | 22 | # 安装 Python 依赖 23 | WORKDIR /root/RWKV-RAG 24 | RUN pip install -r required/index_service_requirements.txt 25 | 26 | # 启动应用程序 27 | ENTRYPOINT ["python3.10", "service.py", "--service_name", "index_service"] 
-------------------------------------------------------------------------------- /docker/DockerfileLLMService: -------------------------------------------------------------------------------- 1 | # 基础镜像,使用支持 CUDA 的镜像 2 | FROM nvidia/cuda:12.1.0-cudnn8-runtime-ubuntu22.04 3 | LABEL maintainer=RWKV-RAG-LLM-Service 4 | 5 | # 设置工作目录 6 | WORKDIR /root 7 | RUN mkdir -p /root/models # 模型文件挂载位置 8 | 9 | # 初始化环境 10 | RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ 11 | echo "Asia/Shanghai" > /etc/timezone 12 | 13 | # 安装 Python 3.10,并添加 deadsnakes PPA 14 | RUN apt-get update && apt-get install -y \ 15 | software-properties-common && \ 16 | add-apt-repository ppa:deadsnakes/ppa && \ 17 | apt-get update && apt-get install -y \ 18 | python3.10 \ 19 | python3.10-dev \ 20 | python3.10-venv \ 21 | python3-pip \ 22 | curl \ 23 | gnupg \ 24 | ca-certificates \ 25 | && rm -rf /var/lib/apt/lists/* 26 | 27 | # 添加 CUDA 存储库的 GPG 密钥并设置 keyring 28 | RUN curl -fsSL https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-archive-keyring.gpg | tee /usr/share/keyrings/cuda-archive-keyring.gpg > /dev/null 29 | 30 | # 添加 CUDA 存储库,并使用 keyring 签署 31 | RUN echo "deb [signed-by=/usr/share/keyrings/cuda-archive-keyring.gpg] https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" > /etc/apt/sources.list.d/cuda.list 32 | 33 | # 删除重复的存储库配置 34 | RUN rm -f /etc/apt/sources.list.d/cuda-ubuntu2204-x86_64.list 35 | 36 | # 更新并安装 CUDA 工具包,包括 nvcc 37 | RUN apt-get update && apt-get install -y \ 38 | libcublas-12-1 \ 39 | cuda-toolkit-12-1 \ 40 | --allow-change-held-packages \ 41 | && rm -rf /var/lib/apt/lists/* 42 | 43 | # 设置 CUDA_HOME 环境变量 44 | ENV CUDA_HOME /usr/local/cuda 45 | 46 | RUN apt-get update -y && \ 47 | apt-get install -y git && \ 48 | apt-get install -y --no-install-recommends libglib2.0-0 && \ 49 | apt-get clean && \ 50 | rm -rf /var/lib/apt/lists/* 51 | RUN pip install --upgrade pip setuptools 52 | RUN pip install --index-url https://pypi.python.org/simple/ pipx && \ 53 | pipx install poetry --force 54 | 55 | RUN git clone https://github.com/AIIRWKV/RWKV-RAG.git 56 | 57 | # 安装 Python 依赖 58 | WORKDIR /root/RWKV-RAG 59 | RUN pip install -r required/llm_service_requirements.txt 60 | 61 | 62 | # 启动应用程序 63 | ENTRYPOINT ["python3.10", "service.py", "--service_name", "llm_service"] 64 | 65 | -------------------------------------------------------------------------------- /docker/DockerfileProxyService: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | LABEL maintainer=RWKV-RAG-Proxy-Service 3 | 4 | WORKDIR /root 5 | 6 | # 初始化环境 7 | RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ 8 | echo "Asia/Shanghai" > /etc/timezone 9 | 10 | RUN apt-get update -y && \ 11 | apt-get install -y git && \ 12 | apt-get install -y --no-install-recommends libglib2.0-0 && \ 13 | apt-get clean && \ 14 | rm -rf /var/lib/apt/lists/* 15 | RUN pip install --upgrade pip setuptools 16 | RUN pip install --index-url https://pypi.python.org/simple/ pipx && \ 17 | pipx install poetry --force 18 | RUN pip install pyzmq==26.0.3 19 | RUN pip install PyYAML==6.0.1 20 | 21 | # 下载代码 22 | RUN git clone https://github.com/AIIRWKV/RWKV-RAG.git 23 | 24 | # 安装 Python 依赖 25 | WORKDIR /root/RWKV-RAG 26 | 27 | # 启动应用程序 28 | ENTRYPOINT ["python3.10", "proxy.py"] -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: 
-------------------------------------------------------------------------------- 1 | # 贡献指南 2 | 3 | 本文档提供了向 RWKV-RAG 提交贡献指南和主要注意事项。 4 | 5 | - 要报告错误,请向我们提交[GitHub issue](https://github.com/AIIRWKV/RWKV-RAG/issues)。 6 | 7 | 8 | ## 您可以做出什么贡献 9 | 10 | 11 | 下面的列表提到了您可以做出的一些贡献,但这并不是完整的列表。 12 | 13 | - 提出或实现新功能 14 | - 修复错误 15 | - 添加测试用例或演示 16 | - 发布博客或教程 17 | - 对现有文档、代码或注释的更新。 18 | - 建议更用户友好的错误代码 19 | 20 | 21 | ## 提交拉取请求 (PR) 22 | 23 | ### 常规流程 24 | 25 | 1. fork我们的 GitHub 项目仓库。 26 | 2. 将你的 fork 克隆到本地机器: 27 | 3. 创建本地分支: git checkout -b my-branch 28 | 4. 在提交信息中提供足够的信息 `git commit -m 'Provide sufficient info in your commit message'` 29 | 5. 将更改提交到本地分支,并推送到 GitHub:(包括必要的提交信息) `git push origin my-branch` 30 | 6. 提交拉取请求以供审核。 31 | 32 | 33 | ### 提交 PR 之前 34 | 35 | - 考虑将大型 PR 拆分为多个较小的独立 PR,以保留可追溯的开发历史。 36 | - 确保您的 PR 只解决一个问题,或者将任何不相关的更改保持在较小范围内。 37 | - 贡献新功能时添加测试用例。它们可证明您的代码功能正常,并可防止未来更改带来的潜在问题。 38 | 39 | 40 | ### 描述你的 PR 41 | - 确保您的 PR 标题简洁明了,提供所有必需的信息。 42 | - 如果适用,请在 PR 描述中引用相应的 GitHub issue。 43 | - 在您的描述中包含足够的设计细节,以说明重大变化。 44 | 45 | 46 | ### 审查并合并 PR 47 | 48 | 确保您的 PR 在合并之前通过所有持续集成 (CI) 测试。 -------------------------------------------------------------------------------- /docs/User_guide.md: -------------------------------------------------------------------------------- 1 | ## 用户指南 2 | 3 | 4 | ### 知识库管理 5 | 6 | 知识库管理界面用于管理存储在向量数据库中的知识库,一个collection就是一个知识库,默认都会创建一个名为initial的知识库。支持对知识库进行新增、删除和查询知识库内容等操作。 7 | 8 | > [!TIP] 9 | > 10 | > 由于Streamlit架构的限制,新增、删除知识库后,建议刷新 Web 页面同步最新改动。 11 | 12 | ![RWKV-RAG-WebUI-knowledge-manager](./img/RWKV-RAG-Manage-Database.gif) 13 | 14 | --- 15 | 16 | ### 知识入库 17 | 18 | 知识入库界面用于将文本内容**分块索引**到现有的知识库中,已入库的知识可以被检索,用于问答机器人或其他下游服务。 19 | 20 | RWKV-RAG 支持三种不同的知识入库方法,这些方法支持解析 TXT、PDF和Excel 三种文件格式: 21 | 22 | - **手动输入:** 在输入框中手动输入或粘贴文本内容,系统会按行对文本进行Chunking(**分块**) 23 | - **从本地计算机上传到服务器端:** 从你的本地客户端往服务器端上传一个文件,系统会按照固定长度和块重叠字符数对文件进行Chunking(**分块**) 24 | - **从服务器端本地上传:** 如果你需要将服务器中**某个文件**或者**某个目录**下所有文件的内容加入知识库,填写文件或者目录的路径,系统会按照固定长度和块重叠字符数对文件进行Chunking(**分块**) 25 | 26 | 27 | > [!WARNING] 28 | > 29 | > 支持文本格式或图片格式的PDF文件入库,但是需要提前安装**tesseract**,并需要安装中文语言包(**chi_sim**) 30 | 31 | > [!TIP] 32 | > 33 | > RWKV-RAG 也支持从互联网上搜索知识,并将搜索到的知识文本以 TXT 格式保存到**服务器端的指定目录**。 34 | > 35 | > **联网搜索得到的 txt 文本文件仍然需要进行知识入库,才能加入现有知识库中。** 36 | 37 | ![联网搜索知识](./img/RWKV-RAG-Search-From-Internet.png) 38 | 39 | --- 40 | 41 | ### 知识问答机器人 42 | 43 | RWKV-RAG 系统提供基于知识库的问答机器人(RWKV-RAG-CHAT)。用户可以从现有的知识库中检索特定主题的知识,然后利用提取到的知识与模型进行聊天,以增强模型的回答效果。 44 | 45 | RWKV-RAG-CHAT 的工作流程如下: 46 | 47 | 1. **输入查询内容,点击 “召回” 按钮** 48 | 49 | ![RWKV-RAG-CHAT-1-Query](./img/RWKV-RAG-CHAT-1-Query.png) 50 | 51 | 2. **RWKV-RAG 从知识库中提取最相关的知识(文本块)** 52 | 53 | ![RWKV-RAG-CHAT-2-Get-Text](./img/RWKV-RAG-CHAT-2-Get-Text.png) 54 | 55 | 3. **rerank 模型对提取出来的文本块进行匹配度打分,选出最佳匹配知识** 56 | 57 | ![RWKV-RAG-CHAT-3-Rerank](img/RWKV-RAG-CHAT-3-Rerank.png) 58 | 59 | 4. 
**在底部输入框中输入问题并点击 “发送” 按钮,与模型聊天** 60 | 61 | ![RWKV-RAG-CHAT-4-Chat](./img/RWKV-RAG-CHAT-4-Chat.png) 62 | -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-Base-Model-Manage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-Base-Model-Manage.png -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-CHAT-1-Query.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-CHAT-1-Query.png -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-CHAT-2-Get-Text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-CHAT-2-Get-Text.png -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-CHAT-3-Rerank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-CHAT-3-Rerank.png -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-CHAT-4-Chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-CHAT-4-Chat.png -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-Manage-Database.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-Manage-Database.gif -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-Search-From-Internet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-Search-From-Internet.png -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-Tuning-Data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-Tuning-Data.png -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-Tuning-Service-Mange.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-Tuning-Service-Mange.png -------------------------------------------------------------------------------- /docs/img/RWKV-RAG-WebUI-client.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/RWKV-RAG-WebUI-client.png -------------------------------------------------------------------------------- /docs/img/models_example.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/docs/img/models_example.png -------------------------------------------------------------------------------- /etc/index_service_config.yml: -------------------------------------------------------------------------------- 1 | back_end: 2 | host: 3 | protocol: tcp 4 | port: 7784 5 | vectordb_host: 6 | vectordb_name: chromadb 7 | vectordb_port: 9998 -------------------------------------------------------------------------------- /etc/llm_service_config.yml: -------------------------------------------------------------------------------- 1 | back_end: 2 | host: 3 | port: 7782 4 | protocol: tcp 5 | base_model_path: /root/models/RWKV-x060-World-7B-v2.1-20240507-ctx4096.pth 6 | embedding_path: /root/models/bge-m31 7 | reranker_path: /root/models/BAAIbge-reranker-v2-m3 8 | -------------------------------------------------------------------------------- /etc/proxy_service_config.yml: -------------------------------------------------------------------------------- 1 | llm: 2 | front_end: 3 | host: 0.0.0.0 4 | protocol: tcp 5 | port: 7781 6 | back_end: 7 | host: 0.0.0.0 8 | protocol: tcp 9 | port: 7782 10 | index: 11 | front_end: 12 | host: 0.0.0.0 13 | protocol: tcp 14 | port: 7783 15 | back_end: 16 | host: 0.0.0.0 17 | protocol: tcp 18 | port: 7784 -------------------------------------------------------------------------------- /etc/ragq.yml: -------------------------------------------------------------------------------- 1 | llm: 2 | front_end: 3 | host: 4 | protocol: tcp 5 | port: 7781 6 | index: 7 | front_end: 8 | host: 9 | protocol: tcp 10 | port: 7783 11 | base: 12 | knowledge_base_path: /root/data 13 | sqlite_db_path: /root/data/files_services.db -------------------------------------------------------------------------------- /fla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.layers import (ABCAttention, BasedLinearAttention, DeltaNet, 4 | GatedLinearAttention, HGRN2Attention, LinearAttention, 5 | MultiScaleRetention, ReBasedLinearAttention) 6 | from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM, 7 | DeltaNetModel, GLAForCausalLM, GLAModel, 8 | HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM, 9 | HGRNModel, LinearAttentionForCausalLM, 10 | LinearAttentionModel, RetNetForCausalLM, RetNetModel, 11 | RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM, 12 | TransformerModel) 13 | from fla.ops import (chunk_gla, chunk_retention, fused_chunk_based, 14 | fused_chunk_gla, fused_chunk_retention) 15 | 16 | __all__ = [ 17 | 'ABCAttention', 18 | 'BasedLinearAttention', 19 | 'DeltaNet', 20 | 'HGRN2Attention', 21 | 'GatedLinearAttention', 22 | 'LinearAttention', 23 | 'MultiScaleRetention', 24 | 'ReBasedLinearAttention', 25 | 'ABCForCausalLM', 26 | 'ABCModel', 27 | 'DeltaNetForCausalLM', 28 | 'DeltaNetModel', 29 | 'HGRNForCausalLM', 30 | 'HGRNModel', 31 | 'HGRN2ForCausalLM', 32 | 'HGRN2Model', 33 | 'GLAForCausalLM', 34 | 'GLAModel', 35 | 'LinearAttentionForCausalLM', 36 | 'LinearAttentionModel', 37 | 'RetNetForCausalLM', 38 | 'RetNetModel', 39 | 'RWKV6ForCausalLM', 40 | 'RWKV6Model', 41 | 'TransformerForCausalLM', 42 | 'TransformerModel', 43 | 'chunk_gla', 44 | 'chunk_retention', 45 | 'fused_chunk_based', 46 | 'fused_chunk_gla', 47 | 'fused_chunk_retention' 48 | ] 49 | 50 | __version__ = '0.1' 51 | 
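`fla/__init__.py` above simply re-exports the layer classes, the `transformers`-style model classes and the fused kernel entry points so that everything is importable from the package root. As a rough illustration (a sketch with arbitrary config values, not an example shipped with the repo), one of the exported causal LMs can be built directly from its config class:

```python
# Minimal sketch using the classes re-exported by fla/__init__.py; the config
# values are arbitrary, and the Triton kernels behind the model need a CUDA device.
import torch
from fla.models import GLAConfig, GLAForCausalLM

config = GLAConfig(vocab_size=32000, hidden_size=512, num_hidden_layers=4, num_heads=4)
model = GLAForCausalLM(config).to("cuda", dtype=torch.bfloat16)

input_ids = torch.randint(0, config.vocab_size, (1, 32), device="cuda")
logits = model(input_ids).logits   # standard transformers CausalLMOutput interface
print(logits.shape)                # torch.Size([1, 32, 32000])
```

Because each `fla.models.*` sub-package also registers its config with `AutoConfig`/`AutoModelForCausalLM` (see the `__init__.py` files further down), the same models can equally be constructed and loaded through the usual `transformers` auto classes.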
-------------------------------------------------------------------------------- /fla/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .abc import ABCAttention 4 | from .based import BasedLinearAttention 5 | from .delta_net import DeltaNet 6 | from .gla import GatedLinearAttention 7 | from .hgrn import HGRNAttention 8 | from .hgrn2 import HGRN2Attention 9 | from .linear_attn import LinearAttention 10 | from .multiscale_retention import MultiScaleRetention 11 | from .rebased import ReBasedLinearAttention 12 | from .rwkv6 import RWKV6Attention 13 | 14 | __all__ = [ 15 | 'ABCAttention', 16 | 'BasedLinearAttention', 17 | 'DeltaNet', 18 | 'GatedLinearAttention', 19 | 'HGRNAttention', 20 | 'HGRN2Attention', 21 | 'LinearAttention', 22 | 'MultiScaleRetention', 23 | 'ReBasedLinearAttention', 24 | 'RWKV6Attention' 25 | ] 26 | -------------------------------------------------------------------------------- /fla/layers/based.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Linear attention in Based. 5 | https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from einops import rearrange 11 | 12 | from fla.modules.feature_map import TaylorFeatureMap 13 | from fla.ops.based import parallel_based 14 | from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn 15 | 16 | 17 | class BasedLinearAttention(nn.Module): 18 | def __init__( 19 | self, 20 | hidden_size: int, 21 | l_max: int = 2048, 22 | feature_dim: int = 16, 23 | num_key_value_heads: int = 12, 24 | num_heads: int = 12, 25 | feature_name: str = "taylor_exp", 26 | eps: float = 1e-12, 27 | causal: bool = True, 28 | mode: str = "parallel", 29 | ): 30 | super().__init__() 31 | self.hidden_size= hidden_size 32 | self.l_max = l_max 33 | self.mode = mode 34 | assert self.mode in ["fused_chunk", "parallel", 'chunk'] 35 | 36 | # linear attention 37 | self.feature_name = feature_name 38 | self.feature_dim = feature_dim 39 | self.num_key_value_heads = num_key_value_heads 40 | self.num_heads = num_heads 41 | self.head_dim = self.hidden_size // self.num_key_value_heads 42 | self.causal = causal 43 | 44 | self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) 45 | self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) 46 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 47 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 48 | self.dropout = nn.Identity() 49 | self.feature_map = TaylorFeatureMap(feature_dim) 50 | self.eps = eps 51 | 52 | self.apply(self._initialize_weights) 53 | 54 | def _initialize_weights(self, module: nn.Module): 55 | if getattr(module, "_is_hf_initialized", False): 56 | return 57 | if isinstance(module, nn.Linear): 58 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 59 | if module.bias is not None: 60 | nn.init.zeros_(module.bias) 61 | module._is_hf_initialized = True 62 | 63 | def forward(self, hidden_states: torch.Tensor, **kwargs): 64 | mode = self.mode 65 | q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) 66 | q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) 67 | if mode == "fused_chunk": 68 | q, k = self.feature_map(q), 
self.feature_map(k) 69 | o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1) 70 | elif mode == 'chunk': 71 | q, k = self.feature_map(q), self.feature_map(k) 72 | o = chunk_linear_attn(q, k, v, normalize=True, scale=1) 73 | elif mode == 'parallel': 74 | assert q.shape[-1] <= 128 75 | o = parallel_based(q, k, v, True, True) 76 | o = rearrange(o, "b h l d -> b l (h d)") 77 | o = self.o_proj(o) 78 | o = self.dropout(o) 79 | return o 80 | 81 | # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 82 | 83 | def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): 84 | """ 85 | x (torch.Tensor): tensor of shape (b, d, l) 86 | y (torch.Tensor): tensor of shape (b, d, l) 87 | """ 88 | # hidden_states = hidden_states.transpose(1, 2) 89 | b, l, _ = hidden_states.size() 90 | q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) 91 | 92 | q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) 93 | k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2) 94 | v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2) 95 | 96 | # Linear attention 97 | q, k = self.feature_map(q), self.feature_map(k) 98 | q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) 99 | 100 | # Compute attention 101 | if self.causal: 102 | y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) 103 | else: 104 | y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) 105 | y = rearrange(y, 'b h l d -> b l (h d)') 106 | y = self.o_proj(y.to(hidden_states.dtype)) 107 | y = self.dropout(y) 108 | return y.to(hidden_states.dtype) 109 | -------------------------------------------------------------------------------- /fla/layers/hgrn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823] 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange 13 | from transformers.cache_utils import Cache 14 | 15 | from fla.modules import FusedRMSNormSwishGate, ShortConvolution 16 | from fla.modules.activations import swiglu 17 | from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn 18 | 19 | 20 | class HGRNAttention(nn.Module): 21 | 22 | def __init__( 23 | self, 24 | mode: str = 'chunk', 25 | hidden_size: int = 1024, 26 | num_heads: Optional[int] = None, 27 | expand_ratio: Optional[int] = 1, 28 | use_short_conv: bool = False, 29 | conv_size: int = 4, 30 | conv_bias: bool = False, 31 | share_conv_kernel: bool = True, 32 | elementwise_affine: Optional[bool] = True, 33 | norm_eps: float = 1e-5, 34 | layer_idx: int = None 35 | ) -> HGRNAttention: 36 | super().__init__() 37 | 38 | self.mode = mode 39 | self.hidden_size = hidden_size 40 | self.num_heads = num_heads 41 | self.expand_ratio = expand_ratio 42 | self.input_dim = int(hidden_size * expand_ratio) 43 | self.head_dim = self.input_dim // self.num_heads 44 | 45 | self.use_short_conv = use_short_conv 46 | self.conv_size = conv_size 47 | self.conv_bias = conv_bias 48 | self.share_conv_kernel = share_conv_kernel 49 | 50 | self.layer_idx = layer_idx 51 | 52 | assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
53 | assert self.hidden_size % num_heads == 0, f"hidden size must be divisible by num_heads of {num_heads}" 54 | 55 | self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) 56 | self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False) 57 | self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False) 58 | 59 | if use_short_conv: 60 | self.conv_size = conv_size 61 | if share_conv_kernel: 62 | self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') 63 | else: 64 | self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') 65 | self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') 66 | self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') 67 | 68 | self.g_norm = FusedRMSNormSwishGate(self.input_dim, elementwise_affine, norm_eps) 69 | self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) 70 | 71 | self.apply(self._initialize_weights) 72 | 73 | def _initialize_weights(self, module: nn.Module): 74 | if getattr(module, "_is_hf_initialized", False): 75 | return 76 | if isinstance(module, nn.Linear): 77 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 78 | if module.bias is not None: 79 | nn.init.zeros_(module.bias) 80 | module._is_hf_initialized = True 81 | 82 | def forward( 83 | self, 84 | hidden_states: torch.Tensor, 85 | attention_mask: Optional[torch.Tensor] = None, 86 | past_key_values: Optional[Cache] = None, 87 | use_cache: Optional[bool] = False, 88 | output_attentions: Optional[bool] = False, 89 | lower_bound: Optional[torch.Tensor] = None, 90 | **kwargs 91 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: 92 | # launching the triton kernel for just one token will actually be slower 93 | mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode 94 | 95 | last_state = past_key_values[self.layer_idx] if use_cache else None 96 | if self.use_short_conv: 97 | conv_state = last_state[0] if use_cache else None 98 | if self.share_conv_kernel: 99 | # conv state is updated inplace 100 | hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) 101 | i = self.i_proj(hidden_states) 102 | f = self.f_proj(hidden_states) 103 | else: 104 | conv_state_i = last_state[2] if use_cache else None 105 | conv_state_f = last_state[1] if use_cache else None 106 | i = self.i_conv1d(self.i_proj(hidden_states), attention_mask, conv_state_i) 107 | f = self.f_conv1d(self.f_proj(hidden_states), attention_mask, conv_state_f) 108 | else: 109 | i = self.i_proj(hidden_states) 110 | f = self.f_proj(hidden_states) 111 | 112 | # the lower bound for the first layer is zero 113 | if lower_bound is None or self.layer_idx == 0: 114 | i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f) 115 | else: 116 | g = lower_bound + (1 - lower_bound) * f.sigmoid() 117 | i, f = swiglu(i, 1 - g), g.log() 118 | 119 | # dealing with left-padding 120 | if attention_mask is not None: 121 | i = i.mul_(attention_mask.unsqueeze(-1)) 122 | i, f = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (i, f)) 123 | 124 | recurrent_state = last_state[-1] if use_cache else None 125 | if mode == 'chunk': 126 | o, recurrent_state = chunk_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache) 127 | elif mode == 'fused_recurrent': 128 | o, recurrent_state = fused_recurrent_hgrn(i, f, initial_state=recurrent_state, output_final_state=use_cache) 129 | else: 130 | raise NotImplementedError(f"Not supported mode `{mode}`.") 131 | 132 | if past_key_values is not 
None: 133 | if self.use_short_conv: 134 | if self.share_conv_kernel: 135 | last_state = (conv_state, recurrent_state) 136 | else: 137 | last_state = (conv_state_i, conv_state_f, recurrent_state) 138 | else: 139 | last_state = (recurrent_state,) 140 | past_key_values.update(last_state, self.layer_idx, i.shape[2]) 141 | 142 | o = self.g_norm(self.g_proj(hidden_states), rearrange(o, 'b h l d -> b l (h d)')) 143 | o = self.o_proj(o) 144 | 145 | return o, None, past_key_values 146 | 147 | def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: 148 | param = next(self.parameters()) 149 | state = tuple() 150 | if self.use_short_conv: 151 | if self.share_conv_kernel: 152 | state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) 153 | else: 154 | state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size), 155 | param.new_zeros(batch_size, self.hidden_size, self.conv_size), 156 | param.new_zeros(batch_size, self.hidden_size, self.conv_size)) 157 | state += (param.new_zeros(batch_size, self.num_heads, self.head_dim),) 158 | return state 159 | 160 | def state_size(self, **kwargs) -> int: 161 | state_size = self.hidden_size 162 | for module in self.children(): 163 | if isinstance(module, ShortConvolution): 164 | state_size += module.state_size 165 | return state_size 166 | -------------------------------------------------------------------------------- /fla/layers/hgrn2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904] 4 | 5 | from __future__ import annotations 6 | 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange 13 | from transformers.cache_utils import Cache 14 | 15 | from fla.modules import RMSNorm, ShortConvolution 16 | from fla.modules.activations import swish 17 | from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla 18 | 19 | 20 | class HGRN2Attention(nn.Module): 21 | 22 | def __init__( 23 | self, 24 | mode: str = 'chunk', 25 | hidden_size: int = 1024, 26 | num_heads: Optional[int] = None, 27 | expand_ratio: Optional[int] = 128, 28 | use_short_conv: bool = False, 29 | conv_size: int = 4, 30 | conv_bias: bool = False, 31 | share_conv_kernel: bool = True, 32 | elementwise_affine: Optional[bool] = True, 33 | norm_eps: float = 1e-5, 34 | layer_idx: int = None 35 | ) -> HGRN2Attention: 36 | super().__init__() 37 | 38 | self.mode = mode 39 | self.hidden_size = hidden_size 40 | 41 | if expand_ratio is None and num_heads is not None: 42 | expand_ratio = hidden_size // num_heads 43 | elif expand_ratio is not None and num_heads is None: 44 | num_heads = hidden_size // expand_ratio 45 | else: 46 | raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.") 47 | self.num_heads = num_heads 48 | self.expand_ratio = expand_ratio 49 | 50 | self.use_short_conv = use_short_conv 51 | self.conv_size = conv_size 52 | self.conv_bias = conv_bias 53 | self.share_conv_kernel = share_conv_kernel 54 | 55 | self.forget_dim = int(self.num_heads * self.expand_ratio) 56 | self.input_dim = hidden_size 57 | self.layer_idx = layer_idx 58 | 59 | assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`." 
60 | assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}" 61 | assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}" 62 | 63 | self.head_f_dim = self.expand_ratio 64 | self.head_i_dim = self.hidden_size // num_heads 65 | 66 | self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) 67 | self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) 68 | self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) 69 | 70 | if use_short_conv: 71 | self.conv_size = conv_size 72 | if share_conv_kernel: 73 | self.h_conv1d = ShortConvolution(hidden_size, conv_size, activation='silu') 74 | else: 75 | self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu') 76 | self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation='silu') 77 | self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation='silu') 78 | 79 | self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, norm_eps) 80 | self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) 81 | 82 | self.apply(self._initialize_weights) 83 | 84 | def _initialize_weights(self, module: nn.Module): 85 | if getattr(module, "_is_hf_initialized", False): 86 | return 87 | if isinstance(module, nn.Linear): 88 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 89 | if module.bias is not None: 90 | nn.init.zeros_(module.bias) 91 | module._is_hf_initialized = True 92 | 93 | def forward( 94 | self, 95 | hidden_states: torch.Tensor, 96 | attention_mask: Optional[torch.Tensor] = None, 97 | past_key_values: Optional[Cache] = None, 98 | use_cache: Optional[bool] = False, 99 | output_attentions: Optional[bool] = False, 100 | lower_bound: Optional[torch.Tensor] = None, 101 | **kwargs 102 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: 103 | # launching the triton kernel for just one token will actually be slower 104 | mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode 105 | 106 | last_state = past_key_values[self.layer_idx] if use_cache else None 107 | if self.use_short_conv: 108 | conv_state = last_state[0] if use_cache else None 109 | if self.share_conv_kernel: 110 | # conv state is updated inplace 111 | hidden_states = self.h_conv1d(hidden_states, attention_mask, conv_state) 112 | q = self.q_proj(hidden_states) 113 | f = self.f_proj(hidden_states) 114 | i = self.i_proj(hidden_states) 115 | else: 116 | conv_state_q = last_state[0] if use_cache else None 117 | conv_state_f = last_state[1] if use_cache else None 118 | conv_state_i = last_state[2] if use_cache else None 119 | q = self.q_proj(hidden_states) 120 | f = self.f_proj(hidden_states) 121 | i = self.i_proj(hidden_states) 122 | q = self.q_conv1d(q, attention_mask, conv_state_q) 123 | f = self.f_conv1d(f, attention_mask, conv_state_f) 124 | i = self.i_conv1d(i, attention_mask, conv_state_i) 125 | else: 126 | q = self.q_proj(hidden_states) 127 | f = self.f_proj(hidden_states) 128 | i = self.i_proj(hidden_states) 129 | 130 | # dealing with left-padding 131 | if attention_mask is not None: 132 | i = i.mul_(attention_mask.unsqueeze(-1)) 133 | 134 | q = swish(q) 135 | # the lower bound for the first layer is zero 136 | if lower_bound is None or self.layer_idx == 0: 137 | k, g = 1 - f.sigmoid(), F.logsigmoid(f) 138 | else: 139 | g = lower_bound + (1 - lower_bound) * f.sigmoid() 140 | k, g = 1 - g, g.log() 141 | q, k, i, g = map(lambda x: rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads), (q, k, i, g)) 142 
| 143 | recurrent_state = last_state[-1] if use_cache else None 144 | if mode == 'fused_recurrent': 145 | o, recurrent_state = fused_recurrent_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache) 146 | elif mode == 'fused_chunk': 147 | o, recurrent_state = fused_chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache) 148 | elif mode == 'chunk': 149 | o, recurrent_state = chunk_gla(q, k, i, g, initial_state=recurrent_state, output_final_state=use_cache) 150 | else: 151 | raise NotImplementedError(f"Not supported mode `{mode}`.") 152 | 153 | if past_key_values is not None: 154 | if self.use_short_conv: 155 | if self.share_conv_kernel: 156 | last_state = (conv_state, recurrent_state) 157 | else: 158 | last_state = (conv_state_q, conv_state_f, conv_state_i, recurrent_state) 159 | else: 160 | last_state = (recurrent_state,) 161 | past_key_values.update(last_state, self.layer_idx, q.shape[2]) 162 | 163 | o = self.g_norm(rearrange(o, 'b h l d -> b l (h d)')) 164 | o = self.o_proj(o) 165 | 166 | return o, None, past_key_values 167 | 168 | def init_state(self, batch_size: int) -> Tuple[torch.Tensor]: 169 | param = next(self.parameters()) 170 | state = tuple() 171 | if self.use_short_conv: 172 | if self.share_conv_kernel: 173 | state += (param.new_zeros(batch_size, self.hidden_size, self.conv_size),) 174 | else: 175 | state += (param.new_zeros(batch_size, self.forget_dim, self.conv_size), 176 | param.new_zeros(batch_size, self.forget_dim, self.conv_size), 177 | param.new_zeros(batch_size, self.input_dim, self.conv_size)) 178 | state += (param.new_zeros(batch_size, self.num_heads, self.head_f_dim, self.head_i_dim),) 179 | return state 180 | 181 | def state_size(self, **kwargs) -> int: 182 | state_size = self.forget_dim * self.head_i_dim 183 | for module in self.children(): 184 | if isinstance(module, ShortConvolution): 185 | state_size += module.state_size 186 | return state_size 187 | -------------------------------------------------------------------------------- /fla/layers/linear_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | from fla.modules import RMSNorm 8 | from fla.modules.feature_map import (DPFPFeatureMap, HadamardFeatureMap, 9 | HedgehogFeatureMap, T2RFeatureMap) 10 | from fla.ops.linear_attn import (chunk_linear_attn, fused_chunk_linear_attn, 11 | fused_recurrent_linear_attn) 12 | 13 | 14 | class LinearAttention(nn.Module): 15 | def __init__( 16 | self, 17 | hidden_size: str = 1024, 18 | expand_k: int = 1.0, 19 | expand_v: int = 1.0, 20 | num_heads: int = 8, 21 | mode: str = 'chunk', 22 | feature_map: str = 'elementwise_product', 23 | tie_feature_map_qk: bool = False, 24 | output_norm: str = 'rmsnorm', 25 | norm_q: bool = False, 26 | norm_k: bool = False, 27 | # standard linear attention normalization 28 | do_feature_map_norm: bool = False, 29 | elementwise_affine: bool = True, 30 | norm_eps: float = 1e-5, 31 | **kwargs, 32 | ): 33 | super().__init__() 34 | assert feature_map in ['elu', 'relu', 'hedgehog', 't2r', 'dpfp', 35 | 'identity', 'elementwise_product'], f"Not supported feature map `{feature_map}`." 36 | 37 | assert output_norm in ['rmsnorm', 'identity'], f"Not supported output norm `{output_norm}`." 
38 | 39 | self.hidden_size 40 | self.mode = mode 41 | self.key_dim = int(hidden_size * expand_k) 42 | self.value_dim = int(hidden_size * expand_v) 43 | self.num_heads = num_heads 44 | 45 | assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 46 | assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" 47 | assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" 48 | 49 | self.head_qk_dim = self.key_dim // num_heads 50 | self.head_v_dim = self.value_dim // num_heads 51 | 52 | if feature_map == 'hedgehog': 53 | if tie_feature_map_qk: 54 | self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim) 55 | else: 56 | self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_qk_dim) 57 | self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_qk_dim) 58 | 59 | elif feature_map == 't2r': 60 | if tie_feature_map_qk: 61 | self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim) 62 | else: 63 | self.feature_map_q = T2RFeatureMap(head_dim=self.head_qk_dim) 64 | self.feature_map_k = T2RFeatureMap(head_dim=self.head_qk_dim) 65 | 66 | elif feature_map == 'elementwise_product': 67 | if tie_feature_map_qk: 68 | self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim) 69 | else: 70 | self.feature_map_q = HadamardFeatureMap(head_dim=self.head_qk_dim) 71 | self.feature_map_k = HadamardFeatureMap(head_dim=self.head_qk_dim) 72 | 73 | elif feature_map == 'dpfp': 74 | self.feature_map_q = DPFPFeatureMap(head_dim=self.head_qk_dim) 75 | self.feature_map_k = DPFPFeatureMap(head_dim=self.head_qk_dim) 76 | 77 | elif feature_map == 'elu': 78 | def elu(x): 79 | return F.elu(x) + 1 80 | self.feature_map_q = elu 81 | self.feature_map_k = elu 82 | 83 | elif feature_map == 'relu': 84 | self.feature_map_q = nn.ReLU() 85 | self.feature_map_k = nn.ReLU() 86 | 87 | elif feature_map == 'identity': 88 | self.feature_map_q = nn.Identity() 89 | self.feature_map_k = nn.Identity() 90 | else: 91 | raise NotImplementedError 92 | 93 | self.do_feature_map_norm = do_feature_map_norm 94 | if output_norm == 'rmsnorm': 95 | self.norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps) 96 | elif output_norm == 'identity': 97 | self.norm = nn.Identity() 98 | else: 99 | raise NotImplementedError 100 | 101 | self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) 102 | self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) 103 | self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) 104 | self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) 105 | 106 | self.norm_q = norm_q 107 | self.norm_k = norm_k 108 | 109 | self.apply(self._initialize_weights) 110 | 111 | def _initialize_weights(self, module: nn.Module): 112 | if getattr(module, "_is_hf_initialized", False): 113 | return 114 | if isinstance(module, nn.Linear): 115 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 116 | if module.bias is not None: 117 | nn.init.zeros_(module.bias) 118 | module._is_hf_initialized = True 119 | 120 | def forward(self, x): 121 | mode = self.mode 122 | q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) 123 | k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) 124 | v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) 125 | q = self.feature_map_q(q) 126 | k = self.feature_map_k(k) 127 | if self.norm_q: 128 | q = q / (q.sum(-1, keepdim=True) + 1e-4) 
129 | if self.norm_k: 130 | k = k / (k.sum(-1, keepdim=True) + 1e-4) 131 | 132 | if mode == 'chunk': 133 | o = chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm) 134 | elif mode == 'fused_chunk': 135 | o = fused_chunk_linear_attn(q, k, v, normalize=self.do_feature_map_norm) 136 | elif mode == 'fused_recurrent': 137 | o = fused_recurrent_linear_attn(q, k, v, normalize=self.do_feature_map_norm) 138 | else: 139 | raise NotImplementedError 140 | o = self.norm(o) 141 | o = rearrange(o, 'b h n d -> b n (h d)') 142 | o = self.o_proj(o) 143 | return o -------------------------------------------------------------------------------- /fla/layers/rebased.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py 5 | """ 6 | 7 | from __future__ import annotations 8 | 9 | from typing import Optional 10 | 11 | import torch 12 | import torch.nn as nn 13 | from einops import rearrange 14 | 15 | from fla.modules.feature_map import RebasedFeatureMap 16 | from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn 17 | from fla.ops.rebased import parallel_rebased 18 | 19 | 20 | class ReBasedLinearAttention(nn.Module): 21 | def __init__( 22 | self, 23 | hidden_size: int, 24 | l_max: int = 2048, 25 | feature_dim: int = 16, 26 | num_key_value_heads: int = 16, 27 | num_heads: int = 16, 28 | use_gamma: Optional[bool] = True, 29 | use_beta: Optional[bool] = True, 30 | normalize: Optional[bool] = True, 31 | causal: bool = True, 32 | eps: float = 1e-5, 33 | mode: str = "parallel", 34 | layer_idx: Optional[int] = None, 35 | **kwargs 36 | ) -> ReBasedLinearAttention: 37 | super().__init__() 38 | self.hidden_size = hidden_size 39 | self.l_max = l_max 40 | self.mode = mode 41 | assert self.mode in ["fused_chunk", "parallel", 'chunk'] 42 | 43 | # linear attention 44 | self.feature_dim = feature_dim 45 | self.num_key_value_heads = num_key_value_heads 46 | self.num_heads = num_heads 47 | self.head_dim = self.hidden_size // self.num_key_value_heads 48 | self.use_gamma = use_gamma 49 | self.use_beta = use_beta 50 | self.normalize = normalize 51 | self.causal = causal 52 | 53 | self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize) 54 | self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) 55 | self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) 56 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) 57 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) 58 | self.dropout = nn.Identity() 59 | self.eps = eps 60 | 61 | self.apply(self._initialize_weights) 62 | 63 | def _initialize_weights(self, module: nn.Module): 64 | if getattr(module, "_is_hf_initialized", False): 65 | return 66 | if isinstance(module, nn.Linear): 67 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 68 | if module.bias is not None: 69 | nn.init.zeros_(module.bias) 70 | module._is_hf_initialized = True 71 | 72 | def forward(self, hidden_states: torch.Tensor, **kwargs): 73 | mode = self.mode 74 | q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) 75 | q, k, v = map(lambda x: rearrange(x, "b l (h d) -> b h l d", h=self.num_heads), [q, k, v]) 76 | q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 
'parallel')) 77 | if mode == "fused_chunk": 78 | o = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1) 79 | elif mode == 'chunk': 80 | o = chunk_linear_attn(q, k, v, normalize=True, scale=1) 81 | elif mode == 'parallel': 82 | assert q.shape[-1] <= 128 83 | o = parallel_rebased(q, k, v, self.eps, True, True) 84 | o = rearrange(o, "b h l d -> b l (h d)") 85 | o = self.o_proj(o) 86 | o = self.dropout(o) 87 | return o 88 | 89 | # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 90 | def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): 91 | """ 92 | x (torch.Tensor): tensor of shape (b, d, l) 93 | y (torch.Tensor): tensor of shape (b, d, l) 94 | """ 95 | # hidden_states = hidden_states.transpose(1, 2) 96 | b, l, _ = hidden_states.size() 97 | q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) 98 | 99 | q = q.view(b, l, self.num_heads, self.feature_dim).transpose(1, 2) 100 | k = k.view(b, l, self.num_key_value_heads, self.feature_dim).transpose(1, 2) 101 | v = v.view(b, l, self.num_key_value_heads, self.head_dim).transpose(1, 2) 102 | 103 | # Linear attention 104 | q, k = self.feature_map(q), self.feature_map(k) 105 | q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) 106 | 107 | # Compute attention 108 | if self.causal: 109 | y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) 110 | else: 111 | y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) 112 | y = rearrange(y, 'b h l d -> b l (h d)') 113 | y = self.o_proj(y.to(hidden_states.dtype)) 114 | y = self.dropout(y) 115 | return y.to(hidden_states.dtype) 116 | -------------------------------------------------------------------------------- /fla/layers/simple_gla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from einops import rearrange 11 | from transformers.activations import ACT2FN 12 | 13 | from fla.modules import FusedRMSNormSwishGate, RMSNorm 14 | from fla.ops.simple_gla import chunk_simple_gla 15 | 16 | 17 | class SimpleGatedLinearAttention(nn.Module): 18 | r""" 19 | The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa 20 | This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise. 21 | 22 | Args: 23 | mode (str, Optional): 24 | Which GLA kernel to use. 25 | Currently available: `chunk`. 26 | Default: `chunk`. 27 | hidden_size (int, Optional): 28 | The hidden size of the input. Default: 1024. 29 | expand_k (float, Optional): 30 | The expansion ratio for the key dim. Default: 0.5. 31 | expand_v (float, Optional): 32 | The expansion ratio for the value dim. Default: 1.0. 33 | num_heads (int, Optional): 34 | The number of heads. Default: 4. 35 | gate_fn (str, Optional): 36 | The activation function for the output gate. Default: `swish`. 37 | elementwise_affine (bool, Optional): 38 | If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`. 39 | norm_eps (float, Optional): 40 | The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. 41 | gate_logit_normalizer (int, Optional): 42 | The normalizer for the gate logits, appied after `logsigmoid`. 
Default: 16. 43 | fuse_norm (bool, Optional): 44 | Whether to fuse the norm and the output gate for better memory footprint. Default: `True`. 45 | layer_idx (int, Optional): 46 | The index of the layer. Default: None. 47 | """ 48 | 49 | def __init__( 50 | self, 51 | mode: str = 'chunk', 52 | hidden_size: int = 1024, 53 | expand_k: float = 1.0, 54 | expand_v: float = 2.0, 55 | num_heads: int = 4, 56 | gate_fn: str = 'swish', 57 | elementwise_affine: Optional[bool] = True, 58 | norm_eps: float = 1e-5, 59 | gate_logit_normalizer: int = 16, 60 | fuse_norm: bool = True, 61 | **kwargs 62 | ) -> SimpleGatedLinearAttention: 63 | super().__init__() 64 | self.hidden_size = hidden_size 65 | 66 | self.mode = mode 67 | self.key_dim = int(hidden_size * expand_k) 68 | self.value_dim = int(hidden_size * expand_v) 69 | assert mode in ['chunk'], f"Not suppoerted mode `{mode}`." 70 | assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" 71 | assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" 72 | self.num_heads = num_heads 73 | self.head_qk_dim = self.key_dim // num_heads 74 | self.head_v_dim = self.value_dim // num_heads 75 | self.gate_fn = ACT2FN[gate_fn] 76 | 77 | self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) 78 | self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) 79 | self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) 80 | self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) 81 | 82 | self.gk_proj = nn.Linear(hidden_size, self.num_heads) 83 | self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) 84 | 85 | if gate_fn == 'swish' and fuse_norm: 86 | self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps) 87 | self.fuse_norm_and_gate = True 88 | else: 89 | self.fuse_norm_and_gate = False 90 | self.g_norm = RMSNorm(self.head_v_dim, elementwise_affine, norm_eps) 91 | 92 | self.gate_logit_normalizer = gate_logit_normalizer 93 | 94 | self.apply(self._initialize_weights) 95 | 96 | def _initialize_weights(self, module: nn.Module): 97 | if getattr(module, "_is_hf_initialized", False): 98 | return 99 | if isinstance(module, nn.Linear): 100 | nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) 101 | if module.bias is not None: 102 | nn.init.zeros_(module.bias) 103 | module._is_hf_initialized = True 104 | 105 | def forward(self, x): 106 | mode = self.mode 107 | q = rearrange(self.q_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) 108 | k = rearrange(self.k_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) 109 | v = rearrange(self.v_proj(x), 'b n (h d) -> b h n d', h=self.num_heads) 110 | gk = rearrange(self.gk_proj(x), 'b n h -> b h n') 111 | gk = (F.logsigmoid(gk) / self.gate_logit_normalizer) 112 | 113 | if mode == 'chunk': 114 | o = chunk_simple_gla(q, k, v, gk) 115 | else: 116 | raise NotImplementedError(f"Not supported mode `{mode}`.") 117 | 118 | o = rearrange(o, 'b h l d -> b l h d') 119 | g = self.g_proj(x) 120 | 121 | if self.fuse_norm_and_gate: 122 | g = rearrange(g, 'b l (h d) -> b l h d', h=self.num_heads) 123 | o = self.g_norm_swish_gate(o, g) 124 | o = rearrange(o, 'b l h d -> b l (h d)') 125 | else: 126 | o = self.g_norm(o) 127 | o = rearrange(o, 'b l h d -> b l (h d)') 128 | o = o * self.gate_fn(g) 129 | o = self.o_proj(o) 130 | return o 131 | -------------------------------------------------------------------------------- /fla/models/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel 4 | from fla.models.delta_net import (DeltaNetConfig, DeltaNetForCausalLM, 5 | DeltaNetModel) 6 | from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel 7 | from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel 8 | from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model 9 | from fla.models.linear_attn import (LinearAttentionConfig, 10 | LinearAttentionForCausalLM, 11 | LinearAttentionModel) 12 | from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel 13 | from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model 14 | from fla.models.transformer import (TransformerConfig, TransformerForCausalLM, 15 | TransformerModel) 16 | 17 | __all__ = [ 18 | 'ABCConfig', 'ABCForCausalLM', 'ABCModel', 19 | 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel', 20 | 'GLAConfig', 'GLAForCausalLM', 'GLAModel', 21 | 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel', 22 | 'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model', 23 | 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel', 24 | 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel', 25 | 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model', 26 | 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel' 27 | ] 28 | -------------------------------------------------------------------------------- /fla/models/abc/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.abc.configuration_abc import ABCConfig 6 | from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel 7 | 8 | AutoConfig.register(ABCConfig.model_type, ABCConfig) 9 | AutoModel.register(ABCConfig, ABCModel) 10 | AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM) 11 | 12 | 13 | __all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel'] 14 | -------------------------------------------------------------------------------- /fla/models/abc/configuration_abc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class ABCConfig(PretrainedConfig): 9 | 10 | model_type = 'abc' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | gate_low_rank_dim: int = 16, 18 | clamp_min: float = -32, 19 | clamp_max: float = 32, 20 | hidden_ratio: Optional[int] = 4, 21 | intermediate_size: Optional[int] = None, 22 | num_hidden_layers: int = 24, 23 | num_heads: int = 4, 24 | num_slots: Optional[int] = 64, 25 | use_short_conv: bool = True, 26 | conv_size: int = 4, 27 | share_conv_kernel: bool = True, 28 | exapnd_k: float = 0.5, 29 | exapnd_v: float = 1, 30 | hidden_act: str = "swish", 31 | max_position_embeddings: int = 2048, 32 | elementwise_affine: Optional[bool] = True, 33 | norm_eps: float = 1e-6, 34 | use_cache: bool = True, 35 | pad_token_id: int = None, 36 | bos_token_id: int = 1, 37 | eos_token_id: int = 2, 38 | initializer_range: float = 0.02, 39 | tie_word_embeddings: bool = False, 40 | fuse_norm: bool = True, 41 | fuse_cross_entropy: bool = True, 42 | **kwargs 43 | ): 44 | self.vocab_size = vocab_size 
45 | self.max_position_embeddings = max_position_embeddings 46 | self.hidden_size = hidden_size 47 | self.gate_low_rank_dim = gate_low_rank_dim 48 | self.clamp_min = clamp_min 49 | self.clamp_max = clamp_max 50 | self.hidden_ratio = hidden_ratio 51 | self.intermediate_size = intermediate_size 52 | self.num_hidden_layers = num_hidden_layers 53 | self.num_heads = num_heads 54 | self.num_slots = num_slots 55 | self.use_short_conv = use_short_conv 56 | self.conv_size = conv_size 57 | self.share_conv_kernel = share_conv_kernel 58 | self.expand_k = exapnd_k 59 | self.expand_v = exapnd_v 60 | self.hidden_act = hidden_act 61 | self.elementwise_affine = elementwise_affine 62 | self.norm_eps = norm_eps 63 | self.use_cache = use_cache 64 | self.initializer_range = initializer_range 65 | self.fuse_cross_entropy = fuse_cross_entropy 66 | self.fuse_norm = fuse_norm 67 | 68 | super().__init__( 69 | pad_token_id=pad_token_id, 70 | bos_token_id=bos_token_id, 71 | eos_token_id=eos_token_id, 72 | tie_word_embeddings=tie_word_embeddings, 73 | **kwargs, 74 | ) 75 | -------------------------------------------------------------------------------- /fla/models/delta_net/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.delta_net.configuration_delta_net import \ 6 | DeltaNetConfig 7 | from fla.models.delta_net.modeling_delta_net import ( 8 | DeltaNetForCausalLM, DeltaNetModel) 9 | 10 | AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig) 11 | AutoModel.register(DeltaNetConfig, DeltaNetModel) 12 | AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM) 13 | 14 | __all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel'] 15 | -------------------------------------------------------------------------------- /fla/models/delta_net/configuration_delta_net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class DeltaNetConfig(PretrainedConfig): 9 | 10 | model_type = 'delta_net' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | expand_k: int = 1, 18 | expand_v: int = 1, 19 | use_gate: bool = False, 20 | use_short_conv: bool = True, 21 | conv_size: int = 4, 22 | share_conv_kernel: bool = False, 23 | use_rope: bool = False, 24 | use_beta: bool = True, 25 | use_output_norm: bool = True, 26 | hidden_ratio: Optional[int] = 4, 27 | intermediate_size: Optional[int] = None, 28 | num_hidden_layers: int = 24, 29 | num_heads: int = 4, 30 | attn_mode: str = "chunk", 31 | qk_norm: str = 'l2', 32 | qk_activation: str = 'silu', 33 | chunk_size: int = 64, 34 | hidden_act: str = "swish", 35 | max_position_embeddings: int = 2048, 36 | rms_norm_eps: float = 1e-6, 37 | use_cache: bool = True, 38 | pad_token_id: int = None, 39 | bos_token_id: int = 1, 40 | eos_token_id: int = 2, 41 | tie_word_embeddings: bool = False, 42 | initializer_range: float = 0.02, 43 | fuse_cross_entropy: bool = True, 44 | **kwargs 45 | ): 46 | self.vocab_size = vocab_size 47 | self.max_position_embeddings = max_position_embeddings 48 | self.hidden_size = hidden_size 49 | self.expand_k = expand_k 50 | self.expand_v = expand_v 51 | self.hidden_ratio = hidden_ratio 52 | 
self.intermediate_size = intermediate_size 53 | self.num_hidden_layers = num_hidden_layers 54 | self.num_heads = num_heads 55 | self.attn_mode = attn_mode 56 | self.hidden_act = hidden_act 57 | self.rms_norm_eps = rms_norm_eps 58 | self.use_cache = use_cache 59 | self.initializer_range = initializer_range 60 | self.fuse_cross_entropy = fuse_cross_entropy 61 | self.use_gate = use_gate 62 | self.use_short_conv = use_short_conv 63 | self.conv_size = conv_size 64 | self.share_conv_kernel = share_conv_kernel 65 | self.use_rope = use_rope 66 | self.use_beta = use_beta 67 | self.use_output_norm = use_output_norm 68 | self.qk_norm = qk_norm 69 | self.qk_activation = qk_activation 70 | 71 | super().__init__( 72 | pad_token_id=pad_token_id, 73 | bos_token_id=bos_token_id, 74 | eos_token_id=eos_token_id, 75 | tie_word_embeddings=tie_word_embeddings, 76 | **kwargs, 77 | ) 78 | -------------------------------------------------------------------------------- /fla/models/gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.gla.configuration_gla import GLAConfig 6 | from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel 7 | 8 | AutoConfig.register(GLAConfig.model_type, GLAConfig) 9 | AutoModel.register(GLAConfig, GLAModel) 10 | AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM) 11 | 12 | 13 | __all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel'] 14 | -------------------------------------------------------------------------------- /fla/models/gla/configuration_gla.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class GLAConfig(PretrainedConfig): 9 | 10 | model_type = 'gla' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | expand_k: int = 0.5, 18 | expand_v: int = 1, 19 | hidden_ratio: Optional[int] = 4, 20 | intermediate_size: Optional[int] = None, 21 | num_hidden_layers: int = 24, 22 | num_heads: int = 4, 23 | num_kv_heads: Optional[int] = None, 24 | feature_map: Optional[str] = None, 25 | attn_mode: str = "chunk", 26 | use_short_conv: bool = False, 27 | conv_size: int = 4, 28 | share_conv_kernel: bool = True, 29 | use_output_gate: bool = True, 30 | clamp_min: Optional[float] = None, 31 | hidden_act: str = "swish", 32 | max_position_embeddings: int = 2048, 33 | elementwise_affine: Optional[bool] = True, 34 | norm_eps: float = 1e-6, 35 | use_gk: bool = True, 36 | use_gv: bool = False, 37 | use_cache: bool = True, 38 | pad_token_id: int = None, 39 | bos_token_id: int = 1, 40 | eos_token_id: int = 2, 41 | tie_word_embeddings: bool = False, 42 | initializer_range: float = 0.02, 43 | fuse_norm: bool = True, 44 | fuse_cross_entropy: bool = True, 45 | **kwargs 46 | ): 47 | self.vocab_size = vocab_size 48 | self.max_position_embeddings = max_position_embeddings 49 | self.hidden_size = hidden_size 50 | self.expand_k = expand_k 51 | self.expand_v = expand_v 52 | self.hidden_ratio = hidden_ratio 53 | self.intermediate_size = intermediate_size 54 | self.num_hidden_layers = num_hidden_layers 55 | self.num_heads = num_heads 56 | self.num_kv_heads = num_kv_heads 57 | self.feature_map = feature_map 58 | self.attn_mode = attn_mode 59 | self.clamp_min = 
clamp_min 60 | self.hidden_act = hidden_act 61 | self.elementwise_affine = elementwise_affine 62 | self.norm_eps = norm_eps 63 | self.use_gk = use_gk 64 | self.use_gv = use_gv 65 | self.use_cache = use_cache 66 | self.initializer_range = initializer_range 67 | self.fuse_norm = fuse_norm 68 | self.fuse_cross_entropy = fuse_cross_entropy 69 | self.use_short_conv = use_short_conv 70 | self.conv_size = conv_size 71 | self.share_conv_kernel = share_conv_kernel 72 | self.use_output_gate = use_output_gate 73 | 74 | super().__init__( 75 | pad_token_id=pad_token_id, 76 | bos_token_id=bos_token_id, 77 | eos_token_id=eos_token_id, 78 | tie_word_embeddings=tie_word_embeddings, 79 | **kwargs, 80 | ) 81 | -------------------------------------------------------------------------------- /fla/models/hgrn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.hgrn.configuration_hgrn import HGRNConfig 6 | from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel 7 | 8 | AutoConfig.register(HGRNConfig.model_type, HGRNConfig) 9 | AutoModel.register(HGRNConfig, HGRNModel) 10 | AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM) 11 | 12 | 13 | __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel'] 14 | -------------------------------------------------------------------------------- /fla/models/hgrn/configuration_hgrn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class HGRNConfig(PretrainedConfig): 9 | 10 | model_type = 'hgrn' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "chunk", 16 | vocab_size: int = 32000, 17 | hidden_size: int = 2048, 18 | num_hidden_layers: int = 24, 19 | num_heads: Optional[int] = 1, 20 | expand_ratio: Optional[int] = 1, 21 | use_short_conv: bool = False, 22 | conv_size: int = 4, 23 | share_conv_kernel: bool = True, 24 | use_lower_bound: bool = True, 25 | hidden_ratio: Optional[int] = 4, 26 | intermediate_size: Optional[int] = None, 27 | hidden_act: str = "swish", 28 | max_position_embeddings: int = 2048, 29 | elementwise_affine: Optional[bool] = True, 30 | norm_eps: float = 1e-6, 31 | use_cache: bool = True, 32 | pad_token_id: int = None, 33 | bos_token_id: int = 1, 34 | eos_token_id: int = 2, 35 | tie_word_embeddings: bool = False, 36 | initializer_range: float = 0.02, 37 | fuse_cross_entropy: bool = True, 38 | **kwargs 39 | ): 40 | self.attn_mode = attn_mode 41 | self.vocab_size = vocab_size 42 | self.max_position_embeddings = max_position_embeddings 43 | self.hidden_size = hidden_size 44 | self.num_hidden_layers = num_hidden_layers 45 | self.num_heads = num_heads 46 | self.expand_ratio = expand_ratio 47 | self.use_short_conv = use_short_conv 48 | self.conv_size = conv_size 49 | self.share_conv_kernel = share_conv_kernel 50 | self.use_lower_bound = use_lower_bound 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.hidden_act = hidden_act 54 | self.elementwise_affine = elementwise_affine 55 | self.norm_eps = norm_eps 56 | self.use_cache = use_cache 57 | self.initializer_range = initializer_range 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | 
bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /fla/models/hgrn2/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config 6 | from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model 7 | 8 | AutoConfig.register(HGRN2Config.model_type, HGRN2Config) 9 | AutoModel.register(HGRN2Config, HGRN2Model) 10 | AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM) 11 | 12 | 13 | __all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model'] 14 | -------------------------------------------------------------------------------- /fla/models/hgrn2/configuration_hgrn2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class HGRN2Config(PretrainedConfig): 9 | 10 | model_type = 'hgrn2' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | num_hidden_layers: int = 24, 18 | attn_mode: str = "chunk", 19 | num_heads: Optional[int] = None, 20 | expand_ratio: Optional[int] = 128, 21 | use_short_conv: bool = False, 22 | conv_size: int = 4, 23 | share_conv_kernel: bool = True, 24 | use_lower_bound: bool = True, 25 | hidden_ratio: Optional[int] = 4, 26 | intermediate_size: Optional[int] = None, 27 | hidden_act: str = "swish", 28 | max_position_embeddings: int = 2048, 29 | elementwise_affine: Optional[bool] = True, 30 | norm_eps: float = 1e-6, 31 | use_cache: bool = True, 32 | pad_token_id: int = None, 33 | bos_token_id: int = 1, 34 | eos_token_id: int = 2, 35 | tie_word_embeddings: bool = False, 36 | initializer_range: float = 0.02, 37 | fuse_cross_entropy: bool = True, 38 | **kwargs 39 | ): 40 | self.vocab_size = vocab_size 41 | self.max_position_embeddings = max_position_embeddings 42 | self.hidden_size = hidden_size 43 | self.num_hidden_layers = num_hidden_layers 44 | self.attn_mode = attn_mode 45 | self.num_heads = num_heads 46 | self.expand_ratio = expand_ratio 47 | self.use_short_conv = use_short_conv 48 | self.conv_size = conv_size 49 | self.share_conv_kernel = share_conv_kernel 50 | self.use_lower_bound = use_lower_bound 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.hidden_act = hidden_act 54 | self.elementwise_affine = elementwise_affine 55 | self.norm_eps = norm_eps 56 | self.use_cache = use_cache 57 | self.initializer_range = initializer_range 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /fla/models/linear_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.linear_attn.configuration_linear_attn import \ 6 | LinearAttentionConfig 7 | from 
fla.models.linear_attn.modeling_linear_attn import ( 8 | LinearAttentionForCausalLM, LinearAttentionModel) 9 | 10 | AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig) 11 | AutoModel.register(LinearAttentionConfig, LinearAttentionModel) 12 | AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM) 13 | 14 | __all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel'] 15 | -------------------------------------------------------------------------------- /fla/models/linear_attn/configuration_linear_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class LinearAttentionConfig(PretrainedConfig): 9 | 10 | model_type = 'linear_attn' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | expand_k: int = 1, 18 | expand_v: int = 1, 19 | hidden_ratio: Optional[int] = 4, 20 | intermediate_size: Optional[int] = None, 21 | num_hidden_layers: int = 24, 22 | num_heads: int = 4, 23 | attn_mode: str = "fused_chunk", 24 | feature_map: str = "elementwise_product", 25 | tie_feature_map_qk: bool = False, 26 | norm_q: bool = False, 27 | norm_k: bool = False, 28 | norm_feature_map: bool = False, 29 | hidden_act: str = "swish", 30 | max_position_embeddings: int = 2048, 31 | elementwise_affine: Optional[bool] = True, 32 | norm_eps: float = 1e-6, 33 | use_cache: bool = True, 34 | pad_token_id: int = None, 35 | bos_token_id: int = 1, 36 | eos_token_id: int = 2, 37 | tie_word_embeddings: bool = False, 38 | initializer_range: float = 0.02, 39 | fuse_cross_entropy: bool = True, 40 | **kwargs 41 | ): 42 | self.vocab_size = vocab_size 43 | self.max_position_embeddings = max_position_embeddings 44 | self.hidden_size = hidden_size 45 | self.expand_k = expand_k 46 | self.expand_v = expand_v 47 | self.hidden_ratio = hidden_ratio 48 | self.intermediate_size = intermediate_size 49 | self.num_hidden_layers = num_hidden_layers 50 | self.num_heads = num_heads 51 | self.attn_mode = attn_mode 52 | self.feature_map = feature_map 53 | self.tie_feature_map_qk = tie_feature_map_qk 54 | self.norm_q = norm_q 55 | self.norm_k = norm_k 56 | self.norm_feature_map = norm_feature_map 57 | self.hidden_act = hidden_act 58 | self.elementwise_affine = elementwise_affine 59 | self.norm_eps = norm_eps 60 | self.use_cache = use_cache 61 | self.initializer_range = initializer_range 62 | self.fuse_cross_entropy = fuse_cross_entropy 63 | 64 | super().__init__( 65 | pad_token_id=pad_token_id, 66 | bos_token_id=bos_token_id, 67 | eos_token_id=eos_token_id, 68 | tie_word_embeddings=tie_word_embeddings, 69 | **kwargs, 70 | ) 71 | -------------------------------------------------------------------------------- /fla/models/retnet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.retnet.configuration_retnet import RetNetConfig 6 | from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel 7 | 8 | AutoConfig.register(RetNetConfig.model_type, RetNetConfig) 9 | AutoModel.register(RetNetConfig, RetNetModel) 10 | AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM) 11 | 12 | 13 | __all__ = ['RetNetConfig', 
'RetNetForCausalLM', 'RetNetModel'] 14 | -------------------------------------------------------------------------------- /fla/models/retnet/configuration_retnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Optional 6 | 7 | from transformers.configuration_utils import PretrainedConfig 8 | 9 | 10 | class RetNetConfig(PretrainedConfig): 11 | 12 | model_type = 'retnet' 13 | keys_to_ignore_at_inference = ['past_key_values'] 14 | 15 | def __init__( 16 | self, 17 | vocab_size: int = 32000, 18 | hidden_size: int = 2048, 19 | expand_k: int = 1, 20 | expand_v: int = 2, 21 | hidden_ratio: Optional[int] = 2, 22 | intermediate_size: Optional[int] = None, 23 | num_hidden_layers: int = 24, 24 | num_heads: int = 8, 25 | num_kv_heads: Optional[int] = None, 26 | feature_map: Optional[str] = None, 27 | attn_mode: str = "fused_chunk", 28 | hidden_act: str = "swish", 29 | use_short_conv: bool = False, 30 | conv_size: int = 4, 31 | share_conv_kernel: bool = True, 32 | use_output_gate: bool = True, 33 | max_position_embeddings: int = 2048, 34 | elementwise_affine: Optional[bool] = True, 35 | norm_eps: float = 1e-6, 36 | use_cache: bool = True, 37 | pad_token_id: int = None, 38 | bos_token_id: int = 1, 39 | eos_token_id: int = 2, 40 | tie_word_embeddings: bool = False, 41 | initializer_range: float = 0.02, 42 | fuse_norm: bool = True, 43 | fuse_cross_entropy: bool = True, 44 | **kwargs 45 | ) -> RetNetConfig: 46 | self.vocab_size = vocab_size 47 | self.max_position_embeddings = max_position_embeddings 48 | self.hidden_size = hidden_size 49 | self.expand_k = expand_k 50 | self.expand_v = expand_v 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.num_hidden_layers = num_hidden_layers 54 | self.num_heads = num_heads 55 | self.num_kv_heads = num_kv_heads 56 | self.feature_map = feature_map 57 | self.attn_mode = attn_mode 58 | self.hidden_act = hidden_act 59 | self.use_short_conv = use_short_conv 60 | self.conv_size = conv_size 61 | self.share_conv_kernel = share_conv_kernel 62 | self.use_output_gate = use_output_gate 63 | self.elementwise_affine = elementwise_affine 64 | self.norm_eps = norm_eps 65 | self.use_cache = use_cache 66 | self.initializer_range = initializer_range 67 | self.fuse_norm = fuse_norm 68 | self.fuse_cross_entropy = fuse_cross_entropy 69 | 70 | super().__init__( 71 | pad_token_id=pad_token_id, 72 | bos_token_id=bos_token_id, 73 | eos_token_id=eos_token_id, 74 | tie_word_embeddings=tie_word_embeddings, 75 | **kwargs, 76 | ) 77 | -------------------------------------------------------------------------------- /fla/models/rwkv6/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config 6 | from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model 7 | 8 | AutoConfig.register(RWKV6Config.model_type, RWKV6Config) 9 | AutoModel.register(RWKV6Config, RWKV6Model) 10 | AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM) 11 | 12 | 13 | __all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model'] 14 | -------------------------------------------------------------------------------- /fla/models/rwkv6/configuration_rwkv6.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class RWKV6Config(PretrainedConfig): 9 | 10 | model_type = 'rwkv6' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "chunk", 16 | vocab_size: int = 32000, 17 | hidden_size: int = 2048, 18 | expand_k: int = 0.5, 19 | expand_v: int = 1, 20 | hidden_ratio: Optional[int] = 3.5, 21 | intermediate_size: Optional[int] = None, 22 | use_glu: Optional[bool] = False, 23 | num_hidden_layers: int = 24, 24 | num_heads: int = 4, 25 | proj_low_rank_dim: int = 32, 26 | gate_low_rank_dim: int = 64, 27 | hidden_act: str = "sqrelu", 28 | max_position_embeddings: int = 2048, 29 | eps: float = 1e-6, 30 | use_cache: bool = True, 31 | pad_token_id: int = None, 32 | bos_token_id: int = 1, 33 | eos_token_id: int = 2, 34 | tie_word_embeddings: bool = False, 35 | initializer_range: float = 0.02, 36 | fuse_norm: bool = True, 37 | fuse_cross_entropy: bool = True, 38 | **kwargs 39 | ): 40 | self.vocab_size = vocab_size 41 | self.max_position_embeddings = max_position_embeddings 42 | self.hidden_size = hidden_size 43 | self.expand_k = expand_k 44 | self.expand_v = expand_v 45 | self.hidden_ratio = hidden_ratio 46 | self.intermediate_size = intermediate_size 47 | self.use_glu = use_glu 48 | self.num_hidden_layers = num_hidden_layers 49 | self.num_heads = num_heads 50 | self.proj_low_rank_dim = proj_low_rank_dim 51 | self.gate_low_rank_dim = gate_low_rank_dim 52 | self.attn_mode = attn_mode 53 | self.hidden_act = hidden_act 54 | self.eps = eps 55 | self.use_cache = use_cache 56 | self.initializer_range = initializer_range 57 | self.fuse_norm = fuse_norm 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /fla/models/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.transformer.configuration_transformer import TransformerConfig 6 | from fla.models.transformer.modeling_transformer import ( 7 | TransformerForCausalLM, TransformerModel) 8 | 9 | AutoConfig.register(TransformerConfig.model_type, TransformerConfig) 10 | AutoModel.register(TransformerConfig, TransformerModel) 11 | AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM) 12 | 13 | 14 | __all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'] 15 | -------------------------------------------------------------------------------- /fla/models/transformer/configuration_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class TransformerConfig(PretrainedConfig): 9 | 10 | model_type = 'transformer' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | hidden_ratio: Optional[int] = 4, 18 | intermediate_size: 
Optional[int] = None, 19 | num_hidden_layers: int = 24, 20 | num_heads: int = 32, 21 | num_kv_heads: int = None, 22 | hidden_act: str = "swish", 23 | max_position_embeddings: int = 2048, 24 | initializer_range: float = 0.02, 25 | elementwise_affine: Optional[bool] = True, 26 | norm_eps: float = 1e-6, 27 | use_cache: bool = True, 28 | pad_token_id: int = None, 29 | bos_token_id: int = 1, 30 | eos_token_id: int = 2, 31 | tie_word_embeddings: bool = False, 32 | attention_bias: bool = False, 33 | fuse_norm: bool = True, 34 | fuse_cross_entropy: bool = True, 35 | **kwargs, 36 | ): 37 | self.vocab_size = vocab_size 38 | self.max_position_embeddings = max_position_embeddings 39 | self.hidden_size = hidden_size 40 | self.hidden_ratio = hidden_ratio 41 | self.intermediate_size = intermediate_size 42 | self.num_hidden_layers = num_hidden_layers 43 | self.num_heads = num_heads 44 | self.num_kv_heads = num_kv_heads 45 | 46 | self.hidden_act = hidden_act 47 | self.initializer_range = initializer_range 48 | self.elementwise_affine = elementwise_affine 49 | self.norm_eps = norm_eps 50 | self.use_cache = use_cache 51 | self.attention_bias = attention_bias 52 | self.fuse_cross_entropy = fuse_cross_entropy 53 | self.fuse_norm = fuse_norm 54 | 55 | super().__init__( 56 | pad_token_id=pad_token_id, 57 | bos_token_id=bos_token_id, 58 | eos_token_id=eos_token_id, 59 | tie_word_embeddings=tie_word_embeddings, 60 | **kwargs, 61 | ) 62 | -------------------------------------------------------------------------------- /fla/models/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from typing import Any, Dict, List, Optional, Tuple 6 | 7 | import torch 8 | from transformers.cache_utils import Cache 9 | 10 | 11 | class RecurrentCache(Cache): 12 | """ 13 | A cache used for storing hidden states produced by flash linear attention models. 14 | 15 | It stores the states of each layer as the tensor of shape `[batch_size, key_dim, value_dim]`. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | seen_tokens: int = 0 21 | ) -> RecurrentCache: 22 | 23 | self.states: List[torch.Tensor] = [] 24 | self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen 25 | 26 | def __getitem__(self, layer_idx: int) -> torch.Tensor: 27 | if layer_idx < len(self): 28 | return self.states[layer_idx] 29 | else: 30 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") 31 | 32 | def __iter__(self): 33 | for state in self.states: 34 | yield state 35 | 36 | def __len__(self): 37 | return len(self.states) 38 | 39 | def update( 40 | self, 41 | state: Tuple[torch.Tensor], 42 | layer_idx: int, 43 | offset: Optional[int] = 1, 44 | cache_kwargs: Optional[Dict[str, Any]] = None, 45 | ) -> Tuple[torch.Tensor]: 46 | """ 47 | Updates the cache with the new `state` for the layer `layer_idx`. 48 | 49 | Parameters: 50 | state (`Tuple[torch.Tensor]`): 51 | The new state to cache. 52 | layer_idx (`int`): 53 | The index of the layer to cache the states for. 54 | offset (`int`): 55 | The offset of current fed tokens. 56 | cache_kwargs (`Dict[str, Any]`, `optional`): 57 | Additional arguments for the cache subclass. 58 | 59 | Return: 60 | The updated state. 
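        Example (illustrative only; shapes are arbitrary, following the
        `[batch_size, key_dim, value_dim]` layout described in the class docstring):

            >>> cache = RecurrentCache()
            >>> state = torch.zeros(2, 64, 128)
            >>> state = cache.update(state, layer_idx=0)
            >>> cache.get_seq_length()
            1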
61 | """ 62 | 63 | if isinstance(state, torch.Tensor): 64 | state = (state,) 65 | if len(self.states) <= layer_idx: 66 | self.states.append(state) 67 | else: 68 | for i, s in enumerate(state): 69 | self.states[layer_idx][i].copy_(s) 70 | # update the number of seen tokens once we achieve the last layer 71 | if layer_idx == len(self) - 1: 72 | self._seen_tokens += offset 73 | 74 | return state 75 | 76 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 77 | """Returns the sequence length of the cached states. A layer index can be optionally passed.""" 78 | if len(self.states) <= layer_idx: 79 | return 0 80 | return self._seen_tokens 81 | 82 | def get_max_length(self) -> Optional[int]: 83 | """Returns the maximum sequence length of the cached states. RecurrentCache does not have a maximum length.""" 84 | return None 85 | 86 | def reorder_cache(self, beam_idx: torch.LongTensor): 87 | """Reorders the cache for beam search, given the selected beam indices.""" 88 | for layer_idx in range(len(self.states)): 89 | device = self.states[layer_idx].device 90 | self.states[layer_idx] = self.states[layer_idx].index_select(0, beam_idx.to(device)) 91 | 92 | def to_legacy_cache(self) -> Tuple[torch.Tensor]: 93 | return tuple(self.states) 94 | 95 | @classmethod 96 | def from_legacy_cache( 97 | cls, 98 | past_key_values: Optional[Tuple[torch.Tensor]] = None, 99 | seen_tokens: int = 0 100 | ) -> RecurrentCache: 101 | """Converts a cache in the legacy cache format into an equivalent `RecurrentCache`.""" 102 | 103 | cache = cls(seen_tokens) 104 | if past_key_values is not None: 105 | for layer_idx in range(len(past_key_values)): 106 | cache.update(past_key_values[layer_idx], layer_idx) 107 | return cache 108 | -------------------------------------------------------------------------------- /fla/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution, 4 | ShortConvolution) 5 | from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss 6 | from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate, 7 | FusedLayerNormSwishGateLinear, 8 | FusedRMSNormSwishGate, 9 | FusedRMSNormSwishGateLinear) 10 | from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm, 11 | RMSNormLinear) 12 | from fla.modules.rotary import RotaryEmbedding 13 | 14 | __all__ = [ 15 | 'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution', 16 | 'FusedCrossEntropyLoss', 17 | 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear', 18 | 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear', 19 | 'RotaryEmbedding' 20 | ] 21 | -------------------------------------------------------------------------------- /fla/modules/l2norm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.cuda.amp import custom_fwd, custom_bwd 6 | import triton 7 | import triton.language as tl 8 | 9 | @triton.autotune( 10 | configs=[ 11 | triton.Config({}, num_warps=1), 12 | triton.Config({}, num_warps=2), 13 | triton.Config({}, num_warps=4), 14 | triton.Config({}, num_warps=8), 15 | triton.Config({}, num_warps=16), 16 | triton.Config({}, num_warps=32), 17 | ], 18 | key=["N"], 19 | ) 20 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 21 | # 
@triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) 22 | @triton.jit 23 | def _l2_norm_fwd_1pass_kernel( 24 | X, # pointer to the input 25 | Y, # pointer to the output 26 | stride_x_row, # how much to increase the pointer when moving by 1 row 27 | N, # number of columns in X 28 | eps, # epsilon to avoid division by zero 29 | BLOCK_N: tl.constexpr, 30 | ): 31 | # Map the program id to the row of X and Y it should compute. 32 | row = tl.program_id(0) 33 | X += row * stride_x_row 34 | Y += row * stride_x_row 35 | # Compute mean and variance 36 | cols = tl.arange(0, BLOCK_N) 37 | x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) 38 | xbar = tl.where(cols < N, x, 0.0) 39 | var = tl.sum(xbar * xbar, axis=0) 40 | rstd = 1 / tl.sqrt(var + eps) 41 | # tl.store(Rstd + row, rstd) 42 | # Normalize and apply linear transformation 43 | mask = cols < N 44 | y = x * rstd 45 | # Write output 46 | tl.store(Y + cols, y, mask=mask) 47 | 48 | 49 | @triton.autotune( 50 | configs=[ 51 | triton.Config({}, num_warps=1), 52 | triton.Config({}, num_warps=2), 53 | triton.Config({}, num_warps=4), 54 | triton.Config({}, num_warps=8), 55 | triton.Config({}, num_warps=16), 56 | triton.Config({}, num_warps=32), 57 | ], 58 | key=["N"], 59 | ) 60 | # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) 61 | # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) 62 | # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) 63 | # @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) 64 | @triton.jit 65 | def _l2_norm_bwd_kernel( 66 | X, # pointer to the input 67 | # Y, # pointer to the output to be recomputed 68 | DY, # pointer to the output gradient 69 | DX, # pointer to the input gradient 70 | stride_x_row, # how much to increase the pointer when moving by 1 row 71 | N, # number of columns in X 72 | eps, # epsilon to avoid division by zero 73 | BLOCK_N: tl.constexpr, 74 | ): 75 | # Map the program id to the elements of X, DX, and DY it should compute. 76 | # Map the program id to the row of X and Y it should compute. 
77 | row = tl.program_id(0) 78 | X += row * stride_x_row 79 | DX += row * stride_x_row 80 | DY += row * stride_x_row 81 | 82 | # Y += row * stride_y_row 83 | cols = tl.arange(0, BLOCK_N) 84 | x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) 85 | x = tl.where(cols < N, x, 0.0) 86 | var = tl.sum(x * x) 87 | rstd = 1 / tl.sqrt(var + eps) 88 | # tl.store(Rstd + row, rstd) 89 | # Normalize and apply linear transformation 90 | mask = cols < N 91 | # y = x * rstd 92 | dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32) 93 | dy = tl.where(cols < N, dy, 0.0) 94 | # dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x 95 | dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x 96 | tl.store(DX + cols, dx, mask=mask) 97 | 98 | def _l2_norm_fwd( 99 | x, eps=1e-6 100 | ): 101 | x_shape_og = x.shape 102 | x = x.reshape(-1, x.shape[-1]) 103 | if x.stride(-1) != 1: 104 | x = x.contiguous() 105 | M, N = x.shape 106 | assert x.stride(-1) == 1 107 | # allocate output 108 | y = torch.empty_like(x) 109 | assert y.stride(-1) == 1 110 | N = x.shape[-1] 111 | M = x.shape[0] 112 | # rstd = torch.empty((M,), dtype=torch.float32, device="cuda") 113 | # Less than 64KB per feature: enqueue fused kernel 114 | MAX_FUSED_SIZE = 65536 // x.element_size() 115 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 116 | if N > BLOCK_N: 117 | raise RuntimeError( 118 | "This layer norm doesn't support feature dim >= 64KB.") 119 | # heuristics for number of warps 120 | with torch.cuda.device(x.device.index): 121 | _l2_norm_fwd_1pass_kernel[(M,)]( 122 | x, 123 | y, 124 | x.stride(0), 125 | N, 126 | eps, 127 | # is_rms_norm, 128 | BLOCK_N, 129 | # residual is not None, 130 | # residual_out is not None, 131 | # bias is not None, 132 | ) 133 | return y.reshape(x_shape_og) 134 | 135 | def _l2_norm_bwd( 136 | x, dy, eps=1e-5, 137 | ): 138 | x_shape_og = x.shape 139 | x = x.reshape(-1, dy.shape[-1]) 140 | dy = dy.reshape(-1, dy.shape[-1]) 141 | if dy.stride(-1) != 1: 142 | dy = dy.contiguous() 143 | assert dy.shape == x.shape 144 | # allocate output 145 | dx = torch.empty_like(x) 146 | N = x.shape[-1] 147 | M = x.shape[0] 148 | assert x.stride(-1) == 1 149 | assert dy.stride(-1) == 1 150 | # rstd = torch.empty((M,), dtype=torch.float32, device="cuda") 151 | # Less than 64KB per feature: enqueue fused kernel 152 | MAX_FUSED_SIZE = 65536 // x.element_size() 153 | BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) 154 | if N > BLOCK_N: 155 | raise RuntimeError( 156 | "This layer norm doesn't support feature dim >= 64KB.") 157 | # heuristics for number of warps 158 | with torch.cuda.device(x.device.index): 159 | _l2_norm_bwd_kernel[(M,)]( 160 | x, 161 | dy, 162 | dx, 163 | x.stride(0), 164 | N, 165 | eps, 166 | BLOCK_N, 167 | ) 168 | return dx.reshape(x_shape_og) 169 | 170 | 171 | class L2NormFN(torch.autograd.Function): 172 | @staticmethod 173 | def forward( 174 | ctx, 175 | x, 176 | eps=1e-6, 177 | ): 178 | # reshape input data into 2D tensor 179 | y = _l2_norm_fwd(x, eps) 180 | ctx.x_shape_og = x_shape_og 181 | ctx.eps = eps 182 | ctx.x_dtype = x.dtype 183 | ctx.save_for_backward(x) 184 | return y 185 | 186 | @staticmethod 187 | def backward(ctx, dy, *args): 188 | x, = ctx.saved_tensors 189 | dx = _l2_norm_bwd( 190 | x, 191 | dy, 192 | ctx.eps, 193 | ) 194 | return ( 195 | dx, 196 | None 197 | ) 198 | 199 | l2_norm_fn = L2NormFN.apply 200 | 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /fla/ops/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .based import fused_chunk_based, parallel_based 4 | from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla 5 | from .retention import (chunk_retention, fused_chunk_retention, 6 | fused_recurrent_retention, parallel_retention) 7 | 8 | __all__ = [ 9 | 'fused_chunk_based', 10 | 'parallel_based', 11 | 'chunk_gla', 12 | 'fused_chunk_gla', 13 | 'fused_recurrent_gla', 14 | 'chunk_retention', 15 | 'fused_chunk_retention', 16 | 'fused_recurrent_retention', 17 | 'parallel_retention' 18 | ] 19 | -------------------------------------------------------------------------------- /fla/ops/abc/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_abc 4 | from .chunk_gate import chunk_gated_abc 5 | from .recurrent_fuse import fused_recurrent_gated_abc 6 | 7 | __all__ = [ 8 | 'chunk_abc', 9 | 'chunk_gated_abc', 10 | 'fused_recurrent_gated_abc' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/abc/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | def naive_recurrent_abc( 9 | q: torch.Tensor, 10 | k: torch.Tensor, 11 | v: torch.Tensor, 12 | s: torch.Tensor, 13 | g: Optional[torch.Tensor] = None, 14 | scale: Optional[int] = None, 15 | initial_state: Optional[torch.Tensor] = None, 16 | output_final_state: Optional[bool] = False 17 | ) -> torch.Tensor: 18 | dtype = q.dtype 19 | 20 | # [batch_size, n_heads, seq_len, n_slots] 21 | if g is None: 22 | z = s.float().logcumsumexp(2) 23 | g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z 24 | s = torch.exp(s - z) 25 | q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g)) 26 | B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] 27 | 28 | hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device) 29 | ok = torch.zeros_like(s) 30 | 31 | if scale is None: 32 | scale = q.shape[-1] ** -0.5 33 | 34 | final_state = None 35 | if initial_state is not None: 36 | hk += initial_state[0] 37 | 38 | for i in range(T): 39 | q_i = q[:, :, i] * scale 40 | k_i = k[:, :, i] 41 | v_i = s[:, :, i] 42 | g_i = g[:, :, i].exp() 43 | hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :] 44 | ok[:, :, i] = (q_i[..., None] * hk).sum(-2) 45 | 46 | qv = ok.softmax(-1) 47 | hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device) 48 | ov = torch.zeros_like(v) 49 | if initial_state is not None: 50 | hv += initial_state[1] 51 | 52 | for i in range(T): 53 | q_i = qv[:, :, i] 54 | k_i = s[:, :, i] 55 | v_i = v[:, :, i] 56 | g_i = g[:, :, i].exp() 57 | hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :] 58 | ov[:, :, i] = (q_i[..., None] * hv).sum(-2) 59 | 60 | if output_final_state: 61 | final_state = (hk, hv) 62 | return ov.to(dtype), final_state 63 | 64 | 65 | def naive_cumsum_abc( 66 | q: torch.Tensor, 67 | k: torch.Tensor, 68 | v: torch.Tensor, 69 | s: torch.Tensor 70 | ) -> torch.Tensor: 71 | """ 72 | A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper. 73 | This is just for demonstration purposes, with no numerical stabilities guaranteed. 
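    Scores `s` act as write strengths into the `n_slots` memory slots: keys and values
    are accumulated into slot-wise running averages via exponentiated cumulative sums,
    and the output is a softmax-weighted read of the value slots driven by the query.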
74 | """ 75 | 76 | dtype = q.dtype 77 | q, k, v, s = map(lambda x: x.float(), (q, k, v, s)) 78 | 79 | scale = q.shape[-1] ** -0.5 80 | # [batch_size, n_heads, seq_len, n_slots] 81 | s = (s - s.max(2, True)[0]).exp() 82 | z = s.cumsum(2) 83 | # [batch_size, n_heads, seq_len, n_slots, d_head] 84 | K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) 85 | V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) 86 | # [batch_size, n_heads, seq_len, n_slots] 87 | p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1) 88 | # [batch_size, n_heads, seq_len, d_head] 89 | o = torch.einsum('...m,...md->...d', p, V) 90 | return o.to(dtype), None 91 | -------------------------------------------------------------------------------- /fla/ops/based/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk_fuse import fused_chunk_based 4 | from .parallel import parallel_based 5 | 6 | __all__ = [ 7 | 'fused_chunk_based', 8 | 'parallel_based' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/based/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | from fla.ops.based.chunk_fuse import fused_chunk_based 7 | from fla.ops.based.parallel import parallel_based 8 | 9 | 10 | def naive_parallel_based(q, k, v, use_scale=True, use_norm=True): 11 | if use_scale: 12 | q = q * (q.shape[-1] ** -0.5) 13 | attn = q @ k.transpose(-2, -1) 14 | attn = 1 + attn + 1/2 * (attn ** 2) 15 | attn.masked_fill_(~torch.tril(torch.ones( 16 | q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) 17 | o = attn @ v 18 | if use_norm: 19 | z = attn.sum(-1) 20 | return o / (z[..., None] + 1e-6) 21 | else: 22 | return o 23 | 24 | 25 | def naive_chunk_based(q, k, v, chunk_size=256): 26 | q = q * (q.shape[-1] ** -0.5) 27 | 28 | # compute normalizer. 
29 | k_cumsum = torch.cumsum(k, dim=-2) 30 | kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3) 31 | # first 32 | z = (q * k_cumsum).sum(-1) 33 | # second order 34 | z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5 35 | # zero-th order 36 | z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :] 37 | 38 | # compute o 39 | # constant term 40 | _o = v.cumsum(-2) 41 | 42 | q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) 43 | 44 | k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) 45 | v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) 46 | 47 | intra_chunk_attn = q @ k.transpose(-2, -1) 48 | intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2) 49 | intra_chunk_attn.masked_fill_( 50 | ~torch.tril( 51 | torch.ones(chunk_size, chunk_size, 52 | dtype=torch.bool, device=q.device), 53 | ), 0) 54 | o = intra_chunk_attn @ v 55 | 56 | # quadractic term 57 | kv = torch.einsum( 58 | 'b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v) 59 | kv = kv.cumsum(2) 60 | kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) 61 | 62 | o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q) 63 | 64 | # linear term 65 | kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v) 66 | kv = kv.cumsum(2) 67 | kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) 68 | o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q) 69 | 70 | o = rearrange(o, 'b h n c d -> b h (n c) d') 71 | o = o + _o 72 | return o / (z[..., None] + 1e-6) 73 | -------------------------------------------------------------------------------- /fla/ops/delta_rule/README.md: -------------------------------------------------------------------------------- 1 | - Delta Rule 2 | 3 | The implementation of delta rule described in https://arxiv.org/abs/2102.11174 4 | 5 | -------------------------------------------------------------------------------- /fla/ops/delta_rule/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk_fuse import fused_chunk_delta_rule 4 | from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule 5 | from .chunk import chunk_delta_rule 6 | 7 | __all__ = [ 8 | 'fused_chunk_delta_rule', 9 | 'fused_recurrent_linear_attn_delta_rule', 10 | 'chunk_delta_rule' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/delta_rule/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def delta_rule_recurrence(q, k, v, beta): 8 | b, h, l, d_k = q.shape 9 | d_v = v.shape[-1] 10 | o = torch.zeros_like(v) 11 | S = torch.zeros(b, h, d_k, d_v).to(v) 12 | q = q * (d_k ** -0.5) 13 | for i in range(l): 14 | _k = k[:, :, i] 15 | _q = q[:, :, i] 16 | _v = v[:, :, i].clone() 17 | beta_i = beta[:, :, i] 18 | _v = _v - (S.clone() * _k[..., None]).sum(-2) 19 | _v = _v * beta_i[..., None] 20 | S = S.clone() + _k.unsqueeze(-1) * _v.unsqueeze(-2) 21 | o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) 22 | return o 23 | 24 | 25 | def delta_rule_chunkwise(q, k, v, beta, chunk_size=32): 26 | b, h, l, d_k = q.shape 27 | d_v = v.shape[-1] 28 | q = q * (d_k ** -0.5) 29 | v = v * beta[..., None] 30 | k_beta = k * beta[..., None] 31 | 32 | assert l % chunk_size == 0 33 | 34 | # note that diagonal is masked. 
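    # The loop below is forward substitution: it inverts the unit-lower-triangular
    # in-chunk transition matrix (a WY-representation-style trick), so the per-token
    # delta-rule updates within a chunk reduce to two matmuls:
    # `attn @ v` yields the corrected values (u in the code) and `attn @ k_beta` the
    # weights (w) used to query the running state for the correction term.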
35 | mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) 36 | q, k, v, k_beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, k_beta]) 37 | attn = -(k_beta @ k.transpose(-1, -2)).masked_fill(mask, 0) 38 | 39 | for i in range(1, chunk_size): 40 | attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) 41 | 42 | attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) 43 | # u 44 | k_cumsum = attn @ v 45 | # w 46 | k_cumdecay = attn @ k_beta 47 | 48 | v = k_cumsum 49 | S = k.new_zeros(b, h, d_k, d_v) 50 | o = torch.zeros_like(v) 51 | mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) 52 | for i in range(0, l // chunk_size): 53 | q_i, k_i, v_i = q[:, :, i], k[:, :, i], v[:, :, i] 54 | attn = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) 55 | v_prime = k_cumdecay[:, :, i] @ S 56 | v_new = v_i - v_prime 57 | o_inter = q_i @ S 58 | o[:, :, i] = o_inter + attn @ v_new 59 | # chunk state update 60 | S = S + k_i.transpose(-1, -2) @ v_new 61 | 62 | return rearrange(o, 'b h n c d -> b h (n c) d') 63 | -------------------------------------------------------------------------------- /fla/ops/gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_gla 4 | from .chunk_fuse import fused_chunk_gla 5 | from .recurrent_fuse import fused_recurrent_gla 6 | 7 | __all__ = [ 8 | 'chunk_gla', 9 | 'fused_chunk_gla', 10 | 'fused_recurrent_gla' 11 | ] 12 | -------------------------------------------------------------------------------- /fla/ops/gla/chunk_util.py: -------------------------------------------------------------------------------- 1 | import triton 2 | import triton.language as tl 3 | 4 | inv_ln2 = 1.44269504 5 | 6 | 7 | 8 | @triton.jit 9 | def fwd_decay_cumsum( 10 | g, 11 | g_o, 12 | s_qk_h, 13 | s_qk_t, 14 | s_qk_d, 15 | B, 16 | H, 17 | T, 18 | scale, 19 | BT: tl.constexpr, 20 | BK: tl.constexpr, 21 | DK: tl.constexpr 22 | ): 23 | i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) 24 | p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 25 | p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 26 | cum_decay = tl.zeros([BK], dtype=tl.float32) 27 | mask = (i_k * BK + tl.arange(0, BK)) < DK 28 | 29 | for i in range(BT): 30 | _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 31 | cum_decay += _g * inv_ln2 32 | tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask) 33 | p_g += DK 34 | p_go += DK 35 | 36 | @triton.jit 37 | def prepare_qg_kg( 38 | q, 39 | k, 40 | g, 41 | qg, 42 | kg, 43 | s_qk_h, 44 | s_qk_t, 45 | s_qk_d, 46 | B, 47 | H, 48 | T, 49 | scale, 50 | BT: tl.constexpr, 51 | BK: tl.constexpr, 52 | DK: tl.constexpr 53 | ): 54 | 55 | i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) 56 | p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 57 | p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 58 | p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 59 | p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 60 | p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK) 61 | 62 | mask = (i_k * BK + tl.arange(0, BK)) < DK 63 | 64 | last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK)) 65 | 
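    # Within a chunk, q is scaled by exp2(cumulative log-decay) and k by
    # exp2(last_decay - cumulative log-decay), so intra-chunk q/k products telescope
    # and the inter-chunk state update only needs the total decay of the chunk.
    # (Gates were pre-multiplied by inv_ln2 in `fwd_decay_cumsum`, hence exp2 here.)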
66 | for i in range(BT): 67 | _q = tl.load(p_q, mask=mask, other=0) 68 | _k = tl.load(p_k, mask=mask, other=0) 69 | _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 70 | _q *= tl.math.exp2(_g) * scale 71 | _k *= tl.math.exp2(last_decay - _g) 72 | tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask) 73 | tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask) 74 | p_q += DK 75 | p_g += DK 76 | p_k += DK 77 | p_kg += DK 78 | p_qg += DK 79 | 80 | 81 | @triton.jit 82 | def bwd_decay_global_cumsum( 83 | dq_inner, 84 | dq_inter, 85 | dk_inner, 86 | dk_inter, 87 | q, k, g, dg, 88 | s_qk_h, 89 | s_qk_t, 90 | s_qk_d, 91 | B, 92 | H, 93 | T, 94 | scale, 95 | BT: tl.constexpr, 96 | BK: tl.constexpr, 97 | DK: tl.constexpr 98 | ): 99 | i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) 100 | p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 101 | p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 102 | p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 103 | p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 104 | p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 105 | p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 106 | p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 107 | p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK 108 | cum_grad_dg = tl.zeros([BK], dtype=tl.float32) 109 | mask = (i_k * BK + tl.arange(0, BK)) < DK 110 | last_g = tl.zeros([BK], dtype=tl.float32) 111 | for j in range(BT-1, -1, -1): 112 | _g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 113 | if j == (BT-1): 114 | last_g = _g 115 | _dq1 = tl.load(p_dq_inner, mask=mask, other=0) 116 | _dq2 = tl.load(p_dq_inter, mask=mask, other=0) 117 | _dq2 *= tl.math.exp2(_g) 118 | _dq = _dq1 + _dq2 119 | tl.store(p_dq_inter, _dq, mask=mask) 120 | _dk1 = tl.load(p_dk_inner, mask=mask, other=0) 121 | _dk2 = tl.load(p_dk_inter, mask=mask, other=0) 122 | _dk2 *= tl.math.exp2(last_g - _g) 123 | _dk = _dk1 + _dk2 124 | tl.store(p_dk_inter, _dk, mask=mask) 125 | _q = tl.load(p_q, mask=mask, other=0) 126 | _k = tl.load(p_k, mask=mask, other=0) 127 | _dg = _dq * _q - _dk * _k 128 | cum_grad_dg += _dg 129 | tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask) 130 | p_g -= DK 131 | p_k -= DK 132 | p_q -= DK 133 | p_dq_inner -= DK 134 | p_dk_inner -= DK 135 | p_dq_inter -= DK 136 | p_dk_inter -= DK 137 | p_dg -= DK 138 | 139 | -------------------------------------------------------------------------------- /fla/ops/gla/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from fla.ops.gla.recurrent_fuse import fused_recurrent_gla 7 | 8 | 9 | def ceildiv(a, b): 10 | return -(a // -b) 11 | 12 | 13 | def naive_recurrent_gla( 14 | q, 15 | k, 16 | v, 17 | gk, 18 | initial_state=None, 19 | output_final_state=False, 20 | causal=True 21 | ): 22 | orig_dtype = q.dtype 23 | q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk)) 24 | batch_size, n_heads, seq_len, d_head_k = q.shape 25 | _, _, _, d_head_v = v.shape 26 | h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) 27 | o = torch.zeros_like(v) 28 | scale = d_head_k ** -0.5 29 | 30 | if initial_state 
is not None: 31 | h += initial_state 32 | 33 | for i in range(seq_len): 34 | q_i = q[:, :, i, :] * scale 35 | k_i = k[:, :, i] 36 | v_i = v[:, :, i, :] 37 | gk_i = gk[:, :, i].exp() 38 | kv_i = k_i[..., None] * v_i[..., None, :] 39 | h = h * gk_i[..., None] + kv_i 40 | o_i = (q_i[..., None] * h).sum(-2) 41 | o[:, :, i] = o_i 42 | 43 | if causal: 44 | return o.to(orig_dtype), h 45 | else: 46 | o_reverse = torch.zeros_like(v) 47 | h = torch.zeros(batch_size, n_heads, d_head_k, d_head_v, dtype=torch.float32, device=q.device) 48 | for i in range(seq_len-1, -1, -1): 49 | q_i = q[:, :, i, :] * scale 50 | k_i = k[:, :, i] 51 | v_i = v[:, :, i, :] 52 | gk_i = gk[:, :, i].exp() 53 | kv_i = k_i[..., None] * v_i[..., None, :] 54 | h = h * gk_i[..., None] + kv_i 55 | o_i = (q_i[..., None] * h).sum(-2) 56 | o_reverse[:, :, i] = o_i 57 | 58 | return o, o_reverse 59 | 60 | -------------------------------------------------------------------------------- /fla/ops/hgrn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_hgrn 4 | from .recurrent_fuse import fused_recurrent_hgrn 5 | 6 | __all__ = [ 7 | 'chunk_hgrn', 8 | 'fused_recurrent_hgrn' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/hgrn/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | def naive_recurrent_hgrn( 9 | x: torch.Tensor, 10 | g: torch.Tensor, 11 | initial_state: Optional[torch.Tensor] = None, 12 | output_final_state: Optional[bool] = False 13 | ) -> torch.Tensor: 14 | dtype = x.dtype 15 | x, g = map(lambda i: i.float(), (x, g)) 16 | B, H, T, D = x.shape 17 | 18 | h = torch.zeros(B, H, D, dtype=torch.float, device=x.device) 19 | o = torch.zeros_like(x) 20 | 21 | final_state = None 22 | if initial_state is not None: 23 | h += initial_state.detach() 24 | 25 | for i in range(T): 26 | h = g[:, :, i].exp() * h + x[:, :, i] 27 | o[:, :, i] = h 28 | 29 | if output_final_state: 30 | final_state = h 31 | return o.to(dtype), final_state 32 | -------------------------------------------------------------------------------- /fla/ops/hgrn/recurrent_fuse.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2023, Songlin Yang 4 | 5 | from typing import Tuple 6 | 7 | import torch 8 | import triton 9 | import triton.language as tl 10 | 11 | from fla.utils import contiguous 12 | 13 | 14 | @triton.autotune( 15 | configs=[ 16 | triton.Config({'BD': 32}, num_warps=1), 17 | triton.Config({'BD': 32}, num_warps=2), 18 | triton.Config({'BD': 32}, num_warps=4), 19 | triton.Config({'BD': 32}, num_warps=8), 20 | triton.Config({'BD': 64}, num_warps=1), 21 | triton.Config({'BD': 64}, num_warps=2), 22 | triton.Config({'BD': 64}, num_warps=4), 23 | triton.Config({'BD': 64}, num_warps=8), 24 | triton.Config({'BD': 128}, num_warps=1), 25 | triton.Config({'BD': 128}, num_warps=2), 26 | triton.Config({'BD': 128}, num_warps=4), 27 | triton.Config({'BD': 128}, num_warps=8), 28 | ], 29 | key=['D'] 30 | ) 31 | @triton.jit 32 | def fused_recurrent_hgrn_fwd_kernel( 33 | x, 34 | g, 35 | o, 36 | h0, 37 | ht, 38 | T: tl.constexpr, 39 | D: tl.constexpr, 40 | BD: tl.constexpr, 41 | USE_INITIAL_STATE: tl.constexpr, 42 | STORE_FINAL_STATE: tl.constexpr 43 | ): 44 | i_d, i_bh = tl.program_id(0), tl.program_id(1) 45 | o_d 
= i_d * BD + tl.arange(0, BD) 46 | mask = o_d < D 47 | 48 | p_x = x + i_bh * T * D + o_d 49 | p_g = g + i_bh * T * D + o_d 50 | p_o = o + i_bh * T * D + o_d 51 | 52 | b_h = tl.zeros([BD], dtype=tl.float32) 53 | if USE_INITIAL_STATE: 54 | p_h0 = h0 + i_bh * D + o_d 55 | b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32) 56 | for _ in range(0, T): 57 | b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32) 58 | b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 59 | b_h = tl.exp(b_g) * b_h + b_x 60 | tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask) 61 | 62 | p_x += D 63 | p_g += D 64 | p_o += D 65 | 66 | if STORE_FINAL_STATE: 67 | p_ht = ht + i_bh * D + o_d 68 | tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask) 69 | 70 | 71 | @triton.autotune( 72 | configs=[ 73 | triton.Config({'BD': 32}, num_warps=1), 74 | triton.Config({'BD': 32}, num_warps=2), 75 | triton.Config({'BD': 32}, num_warps=4), 76 | triton.Config({'BD': 32}, num_warps=8), 77 | triton.Config({'BD': 64}, num_warps=1), 78 | triton.Config({'BD': 64}, num_warps=2), 79 | triton.Config({'BD': 64}, num_warps=4), 80 | triton.Config({'BD': 64}, num_warps=8), 81 | triton.Config({'BD': 128}, num_warps=1), 82 | triton.Config({'BD': 128}, num_warps=2), 83 | triton.Config({'BD': 128}, num_warps=4), 84 | triton.Config({'BD': 128}, num_warps=8), 85 | ], 86 | key=['D'] 87 | ) 88 | @triton.jit 89 | def fused_recurrent_hgrn_bwd_kernel( 90 | g, 91 | o, 92 | dx, 93 | dg, 94 | do, 95 | h0, 96 | T: tl.constexpr, 97 | D: tl.constexpr, 98 | BD: tl.constexpr, 99 | USE_INITIAL_STATE: tl.constexpr 100 | ): 101 | i_d, i_bh = tl.program_id(0), tl.program_id(1) 102 | o_d = i_d * BD + tl.arange(0, BD) 103 | mask = o_d < D 104 | 105 | p_g = g + (i_bh * T + T - 1) * D + o_d 106 | p_o = o + (i_bh * T + T - 2) * D + o_d 107 | p_dx = dx + (i_bh * T + T - 1) * D + o_d 108 | p_dg = dg + (i_bh * T + T - 1) * D + o_d 109 | p_do = do + (i_bh * T + T - 1) * D + o_d 110 | 111 | b_dh = tl.zeros([BD], dtype=tl.float32) 112 | for i in range(T - 1, -1, -1): 113 | b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) 114 | b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) 115 | if i > 0: 116 | b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32) 117 | elif USE_INITIAL_STATE: 118 | b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32) 119 | else: 120 | b_o = tl.zeros([BD], dtype=tl.float32) 121 | 122 | b_dh = b_dh + b_do 123 | b_dx = b_dh 124 | b_dh = b_dh * tl.exp(b_g) 125 | b_dg = b_dh * b_o 126 | tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) 127 | tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask) 128 | 129 | p_g -= D 130 | p_o -= D 131 | p_dx -= D 132 | p_dg -= D 133 | p_do -= D 134 | 135 | 136 | class FusedRecurrentHGRNFunction(torch.autograd.Function): 137 | 138 | @staticmethod 139 | @contiguous 140 | def forward(ctx, x, g, initial_state=None, output_final_state=False): 141 | B, H, T, D = x.shape 142 | 143 | final_state = None 144 | if output_final_state: 145 | final_state = x.new_empty(B, H, D) 146 | 147 | o = torch.empty_like(x) 148 | def grid(meta): return (triton.cdiv(D, meta['BD']), B * H) 149 | fused_recurrent_hgrn_fwd_kernel[grid]( 150 | x, g, o, initial_state, final_state, 151 | T, D, 152 | USE_INITIAL_STATE=initial_state is not None, 153 | STORE_FINAL_STATE=final_state is not None 154 | ) 155 | ctx.save_for_backward(g, o, initial_state) 156 | return o, final_state 157 | 158 | @staticmethod 159 | @contiguous 160 | def backward(ctx, do, dht=None): 161 | g, o, initial_state = ctx.saved_tensors 162 | B, H, 
T, D = do.shape 163 | 164 | dx = torch.empty_like(o) 165 | dg = torch.empty_like(g) 166 | def grid(meta): return (triton.cdiv(D, meta['BD']), B * H) 167 | fused_recurrent_hgrn_bwd_kernel[grid]( 168 | g, o, dx, dg, do, initial_state, 169 | T, D, 170 | USE_INITIAL_STATE=initial_state is not None, 171 | ) 172 | 173 | return dx, dg, None, None 174 | 175 | 176 | def fused_recurrent_hgrn( 177 | x: torch.Tensor, 178 | g: torch.Tensor, 179 | initial_state: torch.Tensor = None, 180 | output_final_state: bool = False 181 | ) -> Tuple[torch.Tensor, torch.Tensor]: 182 | if initial_state is not None: 183 | initial_state = initial_state.detach() 184 | o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state) 185 | return o, final_state 186 | -------------------------------------------------------------------------------- /fla/ops/linear_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_linear_attn 4 | from .chunk_fuse import fused_chunk_linear_attn 5 | from .recurrent_fuse import fused_recurrent_linear_attn 6 | 7 | __all__ = [ 8 | 'chunk_linear_attn', 9 | 'fused_chunk_linear_attn', 10 | 'fused_recurrent_linear_attn' 11 | ] 12 | 13 | -------------------------------------------------------------------------------- /fla/ops/linear_attn/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def torch_chunk_linear_attn(q, k, v, chunk_size=64): 8 | q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] **-0.5) 9 | k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 10 | v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 11 | kv = k.transpose(-1, -2) @ v 12 | kv = kv.cumsum(2) 13 | kv = torch.cat([ 14 | torch.zeros_like(kv[:, :, :1]), 15 | kv[:, :, :-1] 16 | ], dim=2) 17 | inter = q @ kv 18 | intra = ((q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)) @ v 19 | o = inter + intra 20 | return rearrange(o, 'b h n c d -> b h (n c) d') 21 | -------------------------------------------------------------------------------- /fla/ops/rebased/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .parallel import parallel_rebased 4 | 5 | __all__ = [ 6 | 'parallel_rebased' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/rebased/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | from fla.ops.rebased.parallel import parallel_rebased 7 | 8 | def naive_parallel_rebased(q, k, v, use_scale=True, use_norm=True): 9 | if use_scale: 10 | q = q * (q.shape[-1] ** -0.5) 11 | attn = q @ k.transpose(-2, -1) 12 | attn = (attn ** 2) 13 | attn.masked_fill_(~torch.tril(torch.ones( 14 | q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) 15 | o = attn @ v 16 | if use_norm: 17 | z = attn.sum(-1) 18 | return o / (z[..., None] + 1e-6) 19 | else: 20 | return o 21 | -------------------------------------------------------------------------------- /fla/ops/retention/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 
| from .chunk import chunk_retention 4 | from .chunk_fuse import fused_chunk_retention 5 | from .parallel import parallel_retention 6 | from .recurrent_fuse import fused_recurrent_retention 7 | 8 | __all__ = [ 9 | 'chunk_retention', 10 | 'fused_chunk_retention', 11 | 'parallel_retention', 12 | 'fused_recurrent_retention' 13 | ] 14 | -------------------------------------------------------------------------------- /fla/ops/retention/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def naive_retention(q, k, v): 7 | orig_type = q.dtype 8 | q, k, v = q.float(), k.float(), v.float() 9 | _, n_heads, seq_len, d_head = q.shape 10 | s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2() 11 | n = q.new_tensor(range(seq_len), dtype=torch.float) 12 | n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n) 13 | s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype)) 14 | o = torch.einsum('bhqk,bhkd->bhqd', s, v) 15 | return o.to(orig_type) 16 | -------------------------------------------------------------------------------- /fla/ops/rwkv4/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .recurrent_fuse import fused_recurrent_rwkv4 4 | 5 | __all__ = [ 6 | 'fused_recurrent_rwkv4' 7 | ] 8 | -------------------------------------------------------------------------------- /fla/ops/rwkv6/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_rwkv6 4 | from .recurrent_fuse import fused_recurrent_rwkv6 5 | 6 | __all__ = [ 7 | 'chunk_rwkv6', 8 | 'fused_recurrent_rwkv6' 9 | ] 10 | -------------------------------------------------------------------------------- /fla/ops/rwkv6/chunk_naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | from fla.ops.rwkv6.chunk import chunk_rwkv6 7 | from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6 8 | 9 | 10 | def naive_chunk_rwkv6( 11 | q, 12 | k, 13 | v, 14 | w, 15 | u, 16 | chunk_size=32, 17 | initial_state=None, 18 | output_final_state=True, 19 | ): 20 | assert q.shape[-2] % chunk_size == 0 21 | orig_dtype = q.dtype 22 | num_chunk = q.shape[-2] // chunk_size 23 | u = u.unsqueeze(0) 24 | 25 | q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w)) 26 | 27 | w_cumsum = w.cumsum(-2) 28 | 29 | kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp() 30 | wkv = kw.transpose(-1, -2) @ v 31 | 32 | wkv_new = torch.zeros_like(wkv) 33 | 34 | for i in range(num_chunk - 1): 35 | wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i] 36 | 37 | o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp())) 38 | 39 | o_intra = torch.zeros_like(o_inter) 40 | for i in range(chunk_size): 41 | attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1) 42 | mask = (torch.arange(0, chunk_size) < i).to(attn.device) 43 | attn.masked_fill_(~mask, 0) 44 | intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2) 45 | intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i] 
46 | o_intra[:, :, :, i] = intra_inter_o + intra_intra_o 47 | o = o_inter + o_intra 48 | return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype) 49 | -------------------------------------------------------------------------------- /fla/ops/rwkv6/recurrent_naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | def naive_recurrent_rwkv6( 9 | q: torch.Tensor, 10 | k: torch.Tensor, 11 | v: torch.Tensor, 12 | w: torch.Tensor, 13 | u: torch.Tensor, 14 | scale: Optional[float] = None, 15 | initial_state: Optional[torch.Tensor] = None, 16 | output_final_state: Optional[bool] = False 17 | ): 18 | orig_dtype = q.dtype 19 | B, H, T, K, V = *q.shape, v.shape[-1] 20 | q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u)) 21 | h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) 22 | o = torch.zeros_like(v) 23 | 24 | if scale is None: 25 | scale = K ** -0.5 26 | 27 | if initial_state is not None: 28 | h += initial_state 29 | 30 | for i in range(T): 31 | q_i = q[:, :, i, :] * scale 32 | k_i = k[:, :, i] 33 | v_i = v[:, :, i, :] 34 | w_i = w[:, :, i].exp() 35 | kv_i = k_i[..., None] * v_i[..., None, :] 36 | o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None] 37 | o[:, :, i] = o_i.sum(-2) 38 | h = h * w_i[..., None] + kv_i 39 | ht = h if output_final_state else None 40 | return o.to(orig_dtype), ht 41 | 42 | 43 | def naive_recurrent_rwkv6_bwd( 44 | q, 45 | k, 46 | v, 47 | w, 48 | u, 49 | o, 50 | do, 51 | initial_state=None, 52 | output_final_state=False 53 | ): 54 | q, k, v, w, u, o, do = map(lambda x: x.float(), (q, k, v, w, u, o, do)) 55 | B, H, T, K, V = *q.shape, v.shape[-1] 56 | h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) 57 | dq = torch.zeros_like(q) 58 | dq_aux = torch.zeros_like(q) 59 | 60 | if initial_state is not None: 61 | h += initial_state 62 | 63 | for i in range(T): 64 | k_i = k[:, :, i] 65 | v_i = v[:, :, i] 66 | w_i = w[:, :, i].exp() 67 | kv_i = k_i[..., None] * v_i[..., None, :] 68 | h_i = (h + u[None, ..., None] * kv_i) 69 | dq_i = (do[:, :, i, None, :] * h_i).sum(-1) 70 | dq_aux_i = (do[:, :, i, None, :] * h).sum(-1) 71 | dq[:, :, i] = dq_i 72 | dq_aux[:, :, i] = dq_aux_i 73 | h = h * w_i[..., None] + kv_i 74 | 75 | du = torch.zeros_like(u) 76 | dh = torch.zeros_like(h) 77 | dk = torch.zeros_like(k) 78 | dk_aux = torch.zeros_like(k) 79 | dv = torch.zeros_like(v) 80 | 81 | for i in range(T - 1, -1, -1): 82 | d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None] 83 | k_i = k[:, :, i] 84 | v_i = v[:, :, i] 85 | du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1) 86 | du += du_i 87 | dk_i = (dh * v_i[..., None, :]).sum(-1) 88 | dk_aux[:, :, i] = dk_i 89 | dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1) 90 | dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2) 91 | dv_i += (dh * k_i[..., None]).sum(-2) 92 | 93 | dk[:, :, i] = dk_i 94 | dv[:, :, i] = dv_i 95 | dh = dh * w[:, :, i, :, None].exp() + d_kv_i 96 | 97 | # dw = q * dq_aux - k * dk_aux 98 | dw = torch.zeros_like(w) 99 | for i in range(T - 2, -1, -1): 100 | dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i] 101 | 102 | return dq, dk, dv, dw, du 103 | -------------------------------------------------------------------------------- /fla/ops/simple_gla/README.md: -------------------------------------------------------------------------------- 1 | - Simple GLA 2 | 3 | Gating 
mechanism in https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise instead of elementwise. As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. It is faster than GLA but has less expressive power. I will use it as a baseline for the GLA. 4 | 5 | $S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar. -------------------------------------------------------------------------------- /fla/ops/simple_gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_simple_gla 4 | 5 | __all__ = [ 6 | 'chunk_simple_gla' 7 | ] 8 | 9 | -------------------------------------------------------------------------------- /fla/ops/simple_gla/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def torch_simple_gla(q, k, v, g, chunk_size=64): 8 | q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5) 9 | k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 10 | v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 11 | g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size) 12 | g = g.cumsum(-1) 13 | kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None]) 14 | S = torch.zeros_like(kv) 15 | 16 | for i in range(1, g.shape[-2]): 17 | S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1] 18 | 19 | inter = (q * g[..., None].exp()) @ S 20 | attn = q @ k.transpose(-1, -2) 21 | attn = attn * (g[..., None] - g[..., None, :]).exp() 22 | attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) 23 | intra = attn @ v 24 | o = inter + intra 25 | return rearrange(o, 'b h n c d -> b h (n c) d') 26 | 27 | 28 | def torch_simple_gla_recurrent(q, k, v, g, chunk_size=64): 29 | # q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5) 30 | # k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 31 | # v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 32 | # g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size) 33 | # g = g.cumsum(-1) 34 | # kv = k.transpose(-1, -2) @ v 35 | 36 | B, H, T, DK = q.shape 37 | q = q * (DK ** -0.5) 38 | _, _, _, DV = v.shape 39 | S = torch.zeros(B, H, DK, DV).to(q) 40 | o = torch.zeros(B, H, T, DV).to(q) 41 | for i in range(T): 42 | gate = g[:, :, i].exp() 43 | key = k[:, :, i] 44 | value = v[:, :, i] 45 | kv = key.unsqueeze(-1) * value.unsqueeze(-2) 46 | S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv 47 | q_i = q[:, :, i, :] 48 | o_i = (q_i.unsqueeze(-1) * S).sum(-2) 49 | o[:, :, i] = o_i 50 | 51 | return o 52 | 53 | -------------------------------------------------------------------------------- /fla/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import functools 4 | 5 | import torch 6 | 7 | 8 | def contiguous(fn): 9 | @functools.wraps(fn) 10 | def wrapper(ctx, *args, **kwargs): 11 | return fn(ctx, 12 | *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), 13 | **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) 14 | return wrapper 15 | 16 | 17 | def require_version(version, hint): 18 | def decorator(fn): 19 | @functools.wraps(fn) 20 | 
def wrapper(ctx, *args, **kwargs): 21 | from transformers.utils.versions import require_version 22 | require_version(version, hint) 23 | return fn(ctx, 24 | *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), 25 | **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) 26 | return wrapper 27 | return decorator 28 | 29 | 30 | def checkpoint(func): 31 | def wrapper(*args, **kwargs): 32 | return torch.utils.checkpoint.checkpoint(func, *args, **kwargs) 33 | return wrapper 34 | -------------------------------------------------------------------------------- /proxy.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Start the proxies. 4 | Each subsystem communicates with the others through a message queue; 5 | if this is unclear, see how the ZeroMQ proxy pattern works. 6 | """ 7 | import multiprocessing 8 | import zmq 9 | 10 | from configuration import ClientConfig 11 | 12 | def start_proxy(frontend_url, backend_url): 13 | print(f'\033[91mstart proxy {frontend_url} {backend_url}\033[0m') 14 | context = zmq.Context() 15 | frontend = context.socket(zmq.ROUTER) 16 | frontend.bind(frontend_url) 17 | backend = context.socket(zmq.DEALER) 18 | backend.bind(backend_url) 19 | zmq.proxy(frontend, backend) 20 | 21 | 22 | def start_service_proxy(): 23 | project_config = ClientConfig('etc/proxy_service_config.yml', validate=False) 24 | config = project_config.config 25 | for service_name, service_config in config.items(): 26 | if 'front_end' in service_config and 'back_end' in service_config: 27 | backend_url = service_config["back_end"]["protocol"] + "://" + service_config["back_end"]["host"] + ":" + str( 28 | service_config["back_end"]["port"]) 29 | frontend_url = service_config["front_end"]["protocol"] + "://" + service_config["front_end"]["host"] + ":" + str( 30 | service_config["front_end"]["port"]) 31 | print( 32 | f"\033[91mStarting service {service_name}_service with backend url {backend_url} and frontend url {frontend_url}\033[0m") 33 | multiprocessing.Process(target=start_proxy, args=(frontend_url, backend_url)).start() 34 | print(f"\033[91mService {service_name}_service started\033[0m") 35 | 36 | 37 | if __name__ == '__main__': 38 | start_service_proxy() -------------------------------------------------------------------------------- /required/index_service_requirements.txt: -------------------------------------------------------------------------------- 1 | PyYAML==6.0.1 2 | pyzmq==26.0.3 3 | msgpack_python==0.5.6 4 | chromadb==0.5.13 -------------------------------------------------------------------------------- /required/llm_service_requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.4 2 | pandas==2.2.2 3 | torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121 4 | FlagEmbedding==1.2.10 5 | PyYAML==6.0.1 6 | rwkv==0.8.26 7 | pyzmq==26.0.3 8 | msgpack_python==0.5.6 -------------------------------------------------------------------------------- /required/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.35.0 2 | pyzmq==26.0.3 3 | msgpack_python==0.5.6 4 | playwright==1.44.0 5 | PyMuPDF==1.24.4 6 | beautifulsoup4==4.12.3 7 | PyYAML==6.0.1 -------------------------------------------------------------------------------- /service.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import multiprocessing 4 | import sys 5 | import argparse 6 | 7 | #from src.services import
* 8 | from src.services import public_service_workers 9 | from src.services import AbstractServiceWorker 10 | from configuration import LLMServiceConfig, IndexServiceConfig 11 | 12 | 13 | 14 | def start_process(service_cls: "AbstractServiceWorker", backend_url: str,config: dict): 15 | service_instance = service_cls(backend_url,config) 16 | print(f"\033[93mStarting service worker {service_cls} with backend url {backend_url} at process {os.getpid()}\033[0m") 17 | service_instance.run() 18 | 19 | 20 | def start_service(service_cls :"AbstractServiceWorker", config:dict): 21 | name = service_cls.__name__ 22 | back_end = config.get("back_end", {}) 23 | protocol = back_end.get("protocol","tcp") 24 | host = back_end.get("host","0.0.0.0") 25 | port = back_end.get("port", '') 26 | backend_url = f"{protocol}://{host}:{port}" 27 | start_process(service_cls, backend_url, config) 28 | print(f"\033[91mService {name} started\033[0m") 29 | 30 | 31 | def main(service_name:str = None): 32 | services = public_service_workers.keys() 33 | if service_name: 34 | if service_name not in services: 35 | print(f"Service {service_name} not found, service_name must be one of {services}") 36 | return 37 | services = [service_name] 38 | 39 | print(f"Starting services {services}") 40 | for service_module_name in services: 41 | if service_module_name == 'llm_service': 42 | config_service = LLMServiceConfig(f'etc/{service_module_name}_config.yml').config 43 | elif service_module_name == 'index_service': 44 | config_service = IndexServiceConfig(f'etc/{service_module_name}_config.yml').config 45 | # elif service == 'tuning_service': 46 | # config_service = TuningServiceConfig(f'etc/{service}_config.yml').config 47 | else: 48 | continue 49 | print(f"Starting service {service_module_name}") 50 | # worker class name as a string 51 | class_name = public_service_workers.get(service_module_name, None) 52 | module = importlib.import_module(f'src.services.{service_module_name}') # fails here if the service worker module cannot be imported 53 | service_cls = getattr(module, class_name) 54 | if service_cls: 55 | start_service(service_cls, config_service) 56 | 57 | 58 | if __name__ == "__main__": 59 | parser = argparse.ArgumentParser(description='RWKV-RAG-Service') 60 | parser.add_argument('--service_name', nargs='?', help='Service name',default=None) 61 | args = parser.parse_args() 62 | service_name = args.service_name 63 | if not service_name: 64 | service_name = os.environ.get('RWKV-RAG-SERVICE-NAME', None) 65 | main(service_name) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/src/__init__.py -------------------------------------------------------------------------------- /src/clients/__init__.py: -------------------------------------------------------------------------------- 1 | from .index_client import IndexClient 2 | from .llm_client import LLMClient 3 | from .files_service import FileStatusManager 4 | 5 | __all__ = [ 6 | "IndexClient", 7 | "LLMClient", 8 | "FileStatusManager" 9 | ] 10 | -------------------------------------------------------------------------------- /src/clients/index_client.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import msgpack 4 | import zmq 5 | 6 | 7 | class IndexClient: 8 | def __init__(self,frontend_url) -> None: 9 | self.context = zmq.Context() 10 | self.socket =
self.context.socket(zmq.REQ) 11 | self.socket.connect(frontend_url) 12 | self.socket.setsockopt(zmq.RCVTIMEO, 60000 * 5) 13 | 14 | def index_config(self,config): 15 | cmd = {"cmd": "INDEX_CONFIG"} 16 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 17 | msg = self.socket.recv() 18 | resp = msgpack.unpackb(msg, raw=False) 19 | return resp 20 | 21 | 22 | def index_texts(self,texts, embeddings:List[List[float]], keys=None,collection_name=None): 23 | cmd = {"cmd": "INDEX_TEXTS", 24 | "texts": texts, 25 | "embeddings": embeddings, 26 | 'collection_name':collection_name} 27 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 28 | msg = self.socket.recv() 29 | resp = msgpack.unpackb(msg, raw=False) 30 | return resp 31 | 32 | def show_collection(self): 33 | cmd = {"cmd":'SHOW_COLLECTIONS'} 34 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 35 | msg = self.socket.recv() 36 | resp = msgpack.unpackb(msg, raw=False) 37 | return resp 38 | 39 | def create_collection(self,collection_name=None): 40 | cmd = {"cmd":'CREATE_COLLECTION','collection_name':collection_name} 41 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 42 | msg = self.socket.recv() 43 | resp = msgpack.unpackb(msg, raw=False) 44 | return resp 45 | 46 | def delete_collection(self,collection_name=None): 47 | cmd = {"cmd":'DELETE_COLLECTION','collection_name':collection_name} 48 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 49 | msg = self.socket.recv() 50 | resp = msgpack.unpackb(msg, raw=False) 51 | return resp 52 | 53 | def search_nearby(self,embeddings,collection_name): 54 | cmd = {"cmd": "SEARCH_NEARBY", "embeddings": embeddings, 'collection_name':collection_name} 55 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 56 | msg = self.socket.recv() 57 | resp = msgpack.unpackb(msg, raw=False) 58 | return resp 59 | -------------------------------------------------------------------------------- /src/clients/llm_client.py: -------------------------------------------------------------------------------- 1 | import msgpack 2 | import zmq 3 | 4 | class LLMClient: 5 | def __init__(self,url) -> None: 6 | self.context = zmq.Context() 7 | self.socket = self.context.socket(zmq.REQ) 8 | self.socket.connect(url) 9 | self.socket.setsockopt(zmq.RCVTIMEO, 60000 * 5) 10 | 11 | def llm_config(self): 12 | cmd = {"cmd": "LLM_CONFIG"} 13 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 14 | msg = self.socket.recv() 15 | resp = msgpack.unpackb(msg, raw=False) 16 | return resp 17 | 18 | def encode(self,texts): 19 | cmd = {"cmd": "GET_EMBEDDINGS", "texts": texts} 20 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 21 | msg = self.socket.recv() 22 | resp = msgpack.unpackb(msg, raw=False) 23 | return resp 24 | 25 | def cross_encode(self,texts_0,texts_1): 26 | cmd = {"cmd": "GET_CROSS_SCORES", "texts_0": texts_0,"texts_1": texts_1} 27 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 28 | msg = self.socket.recv() 29 | resp = msgpack.unpackb(msg, raw=False) 30 | return resp 31 | 32 | def beam_generate(self, instruction, input_text, token_count=128, num_beams=5): 33 | cmd = { 34 | "cmd": "BEAM_GENERATE", 35 | "instruction": instruction, 36 | "input_text": input_text, 37 | "token_count": token_count, 38 | "num_beams":num_beams 39 | } 40 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 41 | msg = self.socket.recv() 42 | resp = msgpack.unpackb(msg, raw=False) 43 | return resp 44 | 45 | def sampling_generate(self, instruction, input_text, state_file, token_count=1200, temperature=0.3, 
46 | top_p=0.2,template_prompt=None, base_model_path=None): 47 | cmd = { 48 | "cmd": "SAMPLING_GENERATE", 49 | "instruction": instruction, 50 | "input_text": input_text, 51 | "token_count": token_count, 52 | "top_p": top_p, 53 | "state_file": state_file, 54 | "temperature": temperature, 55 | "template_prompt": template_prompt, 56 | "base_model_path": base_model_path 57 | 58 | } 59 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 60 | msg = self.socket.recv() 61 | resp = msgpack.unpackb(msg, raw=False) 62 | return resp 63 | 64 | 65 | def reload_base_model(self, base_model_path): 66 | cmd = { 67 | "cmd": "RELOAD_BASE_MODEL", 68 | "base_model_path": base_model_path 69 | } 70 | self.socket.send(msgpack.packb(cmd, use_bin_type=True)) 71 | self.socket.recv() 72 | return True 73 | -------------------------------------------------------------------------------- /src/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .singleton import SingletonMeta 2 | 3 | __all__ = ['SingletonMeta'] -------------------------------------------------------------------------------- /src/core/singleton.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import threading 3 | 4 | class SingletonMeta(type): 5 | _instances = {} 6 | _lock = threading.Lock() 7 | 8 | def __call__(cls, *args, **kwargs): 9 | if cls not in cls._instances: 10 | with cls._lock: 11 | if cls not in cls._instances: 12 | instance = super().__call__(*args, **kwargs) 13 | cls._instances[cls] = instance 14 | return cls._instances[cls] -------------------------------------------------------------------------------- /src/services/__init__.py: -------------------------------------------------------------------------------- 1 | from .abc import AbstractServiceWorker 2 | #from .index_service import IndexServiceWorker 3 | #from .llm_service import LLMServiceWorker, LLMService 4 | 5 | public_service_workers = { 6 | 'index_service': 'IndexServiceWorker', 7 | 'llm_service': 'LLMServiceWorker', 8 | } 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /src/services/abc/__init__.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | from abc import ABC, abstractmethod 4 | 5 | import zmq 6 | import msgpack 7 | 8 | 9 | class AbstractServiceWorker(ABC): 10 | UNSUPPORTED_COMMAND = 'Unsupported command' 11 | 12 | def __init__(self, backend_url, config): 13 | self.init_with_config(config) 14 | self.backend_url = backend_url 15 | self.context = zmq.Context() 16 | self.socket = self.context.socket(zmq.REP) 17 | self.socket.connect(backend_url) 18 | self.service_config = config # service configuration 19 | print( 20 | f"\033[93m Service worker {self.__class__.__name__} connected to {backend_url} at process {os.getpid()}\033[0m") 21 | 22 | @abstractmethod 23 | def init_with_config(self, config): 24 | pass 25 | 26 | @abstractmethod 27 | def process(self, cmd): 28 | pass 29 | 30 | def run(self): 31 | while True: 32 | message = self.socket.recv() 33 | cmd = msgpack.unpackb(message, raw=False) 34 | try: 35 | resp = self.process(cmd) 36 | if resp == AbstractServiceWorker.UNSUPPORTED_COMMAND: 37 | resp = {"code": 400, "error": "Unsupported command"} 38 | else: 39 | resp = {"code": 200, "value": resp} 40 | self.socket.send(msgpack.packb(resp, use_bin_type=True)) 41 | except Exception as e: 42 | resp = {"code": 400, "error": str(e)} 43 |
self.socket.send(msgpack.packb(resp, use_bin_type=True)) 44 | 45 | -------------------------------------------------------------------------------- /src/services/index_service.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Lock 2 | 3 | from src.services import AbstractServiceWorker 4 | from src.vectordb import INIT_VECTORDB_COLLECTION_NAME, VectorDBError 5 | 6 | 7 | class IndexServiceWorker(AbstractServiceWorker): 8 | lock = Lock() 9 | 10 | def init_with_config(self, config: dict): 11 | # vector database configuration 12 | self.vectordb_name = config.get("vectordb_name") 13 | if not self.vectordb_name: 14 | self.vectordb_name = "chromadb" 15 | self.vectordb_port = config.get("vectordb_port") 16 | self.vectordb_host = config.get("vectordb_host", ) 17 | self.vectordb_manager = None # vector DB manager instance 18 | self.init_once() 19 | self.init_vectordb_db() 20 | 21 | def init_vectordb_db(self): 22 | """ 23 | Init the vectordb 24 | """ 25 | if self.lock.acquire(False): # TODO: a distributed lock is needed in cluster mode 26 | try: 27 | manager = self.vectordb_manager 28 | if not manager.has_collection(INIT_VECTORDB_COLLECTION_NAME): 29 | manager.create_collection(INIT_VECTORDB_COLLECTION_NAME) 30 | print(f"{self.vectordb_name} collection {INIT_VECTORDB_COLLECTION_NAME} is created") 31 | print(f"{self.vectordb_name} collection {INIT_VECTORDB_COLLECTION_NAME} is ready") 32 | finally: 33 | self.lock.release() 34 | 35 | def init_once(self): 36 | if self.vectordb_name == 'chromadb': 37 | from src.vectordb import ChromaDBManager 38 | self.vectordb_manager = ChromaDBManager(self.vectordb_port, self.vectordb_host) 39 | else: 40 | raise VectorDBError(f'Unsupported vector database type: {self.vectordb_name}') 41 | 42 | def cmd_index_texts(self, cmd: dict): 43 | return self.vectordb_manager.add(cmd) 44 | 45 | def cmd_show_collections(self, cmd: dict): 46 | return self.vectordb_manager.show_collections(cmd) 47 | 48 | def cmd_create_collection(self, cmd: dict): 49 | collection_name = cmd['collection_name'] 50 | return self.vectordb_manager.create_collection(collection_name) 51 | 52 | def cmd_delete_collection(self, cmd: dict): 53 | collection_name = cmd['collection_name'] 54 | return self.vectordb_manager.delete_collection(collection_name) 55 | 56 | def cmd_search_nearby(self, cmd: dict): 57 | return self.vectordb_manager.search_nearby(cmd) 58 | 59 | def process(self, cmd: dict): 60 | cmd_name = cmd.get('cmd', '').lower() 61 | function_name = f'cmd_{cmd_name}' 62 | if hasattr(self, function_name) and callable(getattr(self, function_name)): 63 | return getattr(self, function_name)(cmd) 64 | return IndexServiceWorker.UNSUPPORTED_COMMAND -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/internet.py: -------------------------------------------------------------------------------- 1 | from playwright.async_api import async_playwright 2 | from bs4 import BeautifulSoup 3 | import os 4 | 5 | async def search_on_baike(query, output_directory='.', filename=None): 6 | 7 | if filename is None: 8 | filename = f'{query}.txt' 9 | filepath = os.path.join(output_directory, filename) 10 | 11 | async with async_playwright() as p: 12 | browser = await p.chromium.launch(headless=True) 13 | page = await
browser.new_page() 14 | 15 | await page.goto('https://baike.baidu.com/') 16 | 17 | await page.wait_for_selector('#root > div > div.index-module_pageHeader__jSG5w > div.lemmaSearchBarWrapper.undefined > div > div > div > div > input', timeout=5000) 18 | await page.fill('#root > div > div.index-module_pageHeader__jSG5w > div.lemmaSearchBarWrapper.undefined > div > div > div > div > input', query) 19 | await page.click('#root > div > div.index-module_pageHeader__jSG5w > div.lemmaSearchBarWrapper.undefined > div > div > div > button.lemmaBtn') 20 | 21 | await page.wait_for_timeout(5000) 22 | html_content = await page.content() 23 | soup = BeautifulSoup(html_content, 'html.parser') 24 | content_div = soup.find('div', {'class': 'J-lemma-content'}) 25 | if content_div: 26 | content_text = content_div.get_text().strip() 27 | 28 | # Save content_text to the specified file 29 | with open(filepath, 'w', encoding='utf-8') as f: 30 | f.write(content_text) 31 | return '' 32 | else: 33 | return "Baidu Baike has no introduction for this entry; please try a different keyword" 34 | -------------------------------------------------------------------------------- /src/utils/tools.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import hashlib 3 | from typing import Union 4 | 5 | 6 | def calculate_string_md5(text: Union[str, bytes]): 7 | """ 8 | Compute the MD5 digest of a string. 9 | """ 10 | if isinstance(text, str): 11 | text = text.encode('utf-8') 12 | md5_hash = hashlib.md5() 13 | md5_hash.update(text) 14 | return md5_hash.hexdigest() -------------------------------------------------------------------------------- /src/vectordb/__init__.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | """ 3 | Vector database adapter layer. 4 | """ 5 | from .abc import AbstractVectorDBManager 6 | from .abc import VECTOR_DB_DIMENSION 7 | from .abc import VECTORDB_USED_LIMIT 8 | from .abc import RECALL_NUMBER 9 | from .abc import INIT_VECTORDB_COLLECTION_NAME 10 | from .abc import TEXT_MAX_LENGTH 11 | from .chroma import ChromaDBManager 12 | from .errors import VectorDBError, VectorDBCollectionNotExistError, VectorDBCollectionExistError -------------------------------------------------------------------------------- /src/vectordb/abc/__init__.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | from typing import List 3 | from abc import ABC, abstractmethod 4 | 5 | VECTOR_DB_DIMENSION = 1024 # embedding dimension 6 | 7 | TEXT_MAX_LENGTH = 512 # maximum length of a single text to embed 8 | 9 | RECALL_NUMBER = 3 # number of results to recall 10 | INIT_VECTORDB_COLLECTION_NAME = 'initial' 11 | VECTORDB_USED_LIMIT = {'linux': ['chromadb', 'milvus_lite'], 12 | 'windows': ['chromadb'] 13 | } 14 | 15 | 16 | 17 | class AbstractVectorDBManager(ABC): 18 | 19 | def __init__(self, db_port: int, db_host: str): 20 | self.db_port = db_port 21 | self.db_host = db_host 22 | self._client = None 23 | 24 | @abstractmethod 25 | def client(self): 26 | """ 27 | Initialize the database connection. 28 | :return: 29 | """ 30 | pass 31 | 32 | @abstractmethod 33 | def show_collections(self, page: int=None, page_size: int=None): 34 | """ 35 | List collections. 36 | :param page: 37 | :param page_size: 38 | :return: 39 | """ 40 | 41 | @abstractmethod 42 | def has_collection(self, collection_name: str) -> bool: 43 | """ 44 | Check whether a collection exists. 45 | :param collection_name: 46 | :return: 47 | """ 48 | 49 | @abstractmethod 50 | def create_collection(self, collection_name: str): 51 | """ 52 | Create a collection. 53 | :param collection_name: 54 | :return: 55 | """ 56 | 57 | @abstractmethod 58 | def
delete_collection(self, collection_name: str): 59 | """ 60 | Delete a collection. 61 | :param collection_name: 62 | :return: 63 | """ 64 | 65 | @abstractmethod 66 | def add(self, kwargs: dict)-> List[str]: 67 | """ 68 | Add vectors. 69 | :param kwargs: must contain the following keys 70 | keys: List[str] 71 | texts: List[str] 72 | collection_name: str 73 | embeddings: List[numpy.ndarray[numpy.float16]] 74 | :return: list of ids 75 | """ 76 | 77 | @abstractmethod 78 | def search_nearby(self, kwargs: dict) -> list[str]: 79 | """ 80 | Search for nearby vectors. 81 | :param kwargs: must contain the following keys: 82 | collection_name: str 83 | embeddings: List[numpy.ndarray[numpy.float16]] 84 | :return: 85 | """ 86 | 87 | @staticmethod 88 | def padding_vectors(vector: list): 89 | if len(vector) < VECTOR_DB_DIMENSION: 90 | vector += [0] * (VECTOR_DB_DIMENSION - len(vector)) 91 | elif len(vector) > VECTOR_DB_DIMENSION: 92 | vector = vector[:VECTOR_DB_DIMENSION] 93 | return vector -------------------------------------------------------------------------------- /src/vectordb/chroma.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | """ 3 | Supported on both Windows and Linux. 4 | """ 5 | from datetime import datetime 6 | 7 | from src.vectordb import RECALL_NUMBER 8 | from src.vectordb import AbstractVectorDBManager 9 | from src.utils.tools import calculate_string_md5 10 | from .errors import VectorDBCollectionNotExistError, VectorDBError 11 | 12 | 13 | class ChromaDBManager(AbstractVectorDBManager): 14 | 15 | def client(self): 16 | import chromadb 17 | if self._client is None: 18 | print(self.db_host) 19 | try: 20 | self._client = chromadb.HttpClient(host=self.db_host, 21 | port=self.db_port) 22 | return self._client 23 | except Exception as e: 24 | raise VectorDBError('Failed to connect to the Chroma service') 25 | return self._client 26 | 27 | 28 | def has_collection(self, collection_name: str) -> bool: 29 | chroma_client = self.client() 30 | try: 31 | collection = chroma_client.get_collection(collection_name) 32 | except: 33 | return False 34 | if collection: 35 | return True 36 | return False 37 | 38 | def show_collections(self, page: int=None, page_size: int=None): 39 | chroma_client = self.client() 40 | offset = (page - 1) * page_size if page is not None and page_size is not None else None 41 | collections = chroma_client.list_collections(page_size, offset) 42 | return [(i.name, i.metadata) for i in collections] if collections else [] 43 | 44 | def create_collection(self, collection_name: str): 45 | now = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 46 | client = self.client() 47 | client.get_or_create_collection(collection_name, 48 | metadata={"hnsw:space": "cosine", 49 | "create_time": now}) 50 | return True 51 | 52 | def delete_collection(self, collection_name: str): 53 | 54 | client = self.client() 55 | try: 56 | client.delete_collection(collection_name) 57 | except: 58 | raise VectorDBCollectionNotExistError() 59 | return True 60 | 61 | def add(self,kwargs:dict): 62 | keys = kwargs.get("keys") 63 | values = kwargs["texts"] 64 | collection_name = kwargs.get('collection_name') 65 | embeddings = kwargs.get('embeddings') 66 | 67 | if keys is None or isinstance(keys, list) is False or len(keys) != len(values): 68 | keys = [calculate_string_md5(value) for value in values] 69 | client = self.client() 70 | new_embeddings = [eb for eb in embeddings] 71 | try: 72 | collection = client.get_collection(collection_name) 73 | except: 74 | raise VectorDBCollectionNotExistError() 75 | collection.add( 76 | ids=keys, 77 | embeddings=new_embeddings, 78 | documents=values 79 | ) 80 | # index
the value 81 | return keys 82 | 83 | def search_nearby(self, kwargs: dict): 84 | collection_name = kwargs.get('collection_name') 85 | embeddings = kwargs.get('embeddings') 86 | client = self.client() 87 | try: 88 | collection = client.get_collection(collection_name) 89 | except: 90 | raise VectorDBCollectionNotExistError() 91 | search_result = collection.query( 92 | query_embeddings=embeddings, 93 | n_results=RECALL_NUMBER, 94 | include=['documents']) 95 | return search_result['documents'][0] 96 | 97 | -------------------------------------------------------------------------------- /src/vectordb/errors.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | 3 | class VectorDBError(Exception): 4 | """ 5 | Custom exception for vector database errors. 6 | """ 7 | def __init__(self, message="Unknown error"): 8 | self.message = message 9 | super().__init__(self.message) 10 | 11 | def __str__(self): 12 | return self.message 13 | 14 | 15 | class VectorDBCollectionNotExistError(VectorDBError): 16 | """ 17 | Collection does not exist. 18 | """ 19 | def __init__(self, message="Collection does not exist"): 20 | self.message = message 21 | super().__init__(self.message) 22 | 23 | 24 | class VectorDBCollectionExistError(VectorDBError): 25 | """ 26 | Collection already exists. 27 | """ 28 | def __init__(self, message="Collection already exists"): 29 | self.message = message 30 | super().__init__(self.message) -------------------------------------------------------------------------------- /tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIIRWKV/RWKV-RAG/0ccd84591af280f0d20efd932041bb00a4430fb2/tokenizer/__init__.py -------------------------------------------------------------------------------- /tokenizer/rwkv_tokenizer.py: -------------------------------------------------------------------------------- 1 | ######################################################################################################## 2 | # The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM 3 | ######################################################################################################## 4 | import os 5 | class TRIE: 6 | __slots__ = tuple("ch,to,values,front".split(",")) 7 | to:list 8 | values:set 9 | def __init__(self, front=None, ch=None): 10 | self.ch = ch 11 | self.to = [None for ch in range(256)] 12 | self.values = set() 13 | self.front = front 14 | 15 | def __repr__(self): 16 | fr = self 17 | ret = [] 18 | while(fr!=None): 19 | if(fr.ch!=None): 20 | ret.append(fr.ch) 21 | fr = fr.front 22 | return "<TRIE %s %s>"%(ret[::-1], self.values) 23 | 24 | def add(self, key:bytes, idx:int=0, val=None): 25 | if(idx == len(key)): 26 | if(val is None): 27 | val = key 28 | self.values.add(val) 29 | return self 30 | ch = key[idx] 31 | if(self.to[ch] is None): 32 | self.to[ch] = TRIE(front=self, ch=ch) 33 | return self.to[ch].add(key, idx=idx+1, val=val) 34 | 35 | def find_longest(self, key:bytes, idx:int=0): 36 | u:TRIE = self 37 | ch:int = key[idx] 38 | 39 | while(u.to[ch] is not None): 40 | u = u.to[ch] 41 | idx += 1 42 | if(u.values): 43 | ret = idx, u, u.values 44 | if(idx==len(key)): 45 | break 46 | ch = key[idx] 47 | return ret 48 | 49 | class TRIE_TOKENIZER(): 50 | def __init__(self, file_name=''): # defaults to the vocab file in this directory 51 | if not file_name: 52 | current_dir = os.path.dirname(os.path.abspath(__file__)) 53 | file_name = os.path.join(current_dir, 'rwkv_vocab_v20230424.txt') 54 | self.idx2token = {} 55 | sorted = [] # must be already sorted 56 | with open(file_name, "r", encoding="utf-8") as f: 57 | lines =
f.readlines() 58 | for l in lines: 59 | idx = int(l[:l.index(' ')]) 60 | x = eval(l[l.index(' '):l.rindex(' ')]) 61 | x = x.encode("utf-8") if isinstance(x, str) else x 62 | assert isinstance(x, bytes) 63 | assert len(x) == int(l[l.rindex(' '):]) 64 | sorted += [x] 65 | self.idx2token[idx] = x 66 | 67 | self.token2idx = {} 68 | for k,v in self.idx2token.items(): 69 | self.token2idx[v] = int(k) 70 | 71 | self.root = TRIE() 72 | for t, i in self.token2idx.items(): 73 | _ = self.root.add(t, val=(t, i)) 74 | 75 | def encodeBytes(self, src:bytes): 76 | idx:int = 0 77 | tokens = [] 78 | while (idx < len(src)): 79 | _idx:int = idx 80 | idx, _, values = self.root.find_longest(src, idx) 81 | assert(idx != _idx) 82 | _, token = next(iter(values)) 83 | tokens.append(token) 84 | return tokens 85 | 86 | def decodeBytes(self, tokens): 87 | return b''.join(map(lambda i: self.idx2token[i], tokens)) 88 | 89 | def encode(self, src): 90 | return self.encodeBytes(src.encode("utf-8")) 91 | 92 | def decode(self, tokens): 93 | try: 94 | return self.decodeBytes(tokens).decode('utf-8') 95 | except: 96 | return '\ufffd' # bad utf-8 97 | 98 | def printTokens(self, tokens): 99 | for i in tokens: 100 | s = self.idx2token[i] 101 | try: 102 | s = s.decode('utf-8') 103 | except: 104 | pass 105 | print(f'{repr(s)}{i}', end=' ') 106 | print() 107 | --------------------------------------------------------------------------------
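A minimal usage sketch for the trie tokenizer above (illustrative only, not a file from the repository; it assumes rwkv_vocab_v20230424.txt sits next to rwkv_tokenizer.py, which is what the default constructor expects):

from tokenizer.rwkv_tokenizer import TRIE_TOKENIZER

# Build the trie from the bundled vocabulary file.
tokenizer = TRIE_TOKENIZER()

# encode() greedily matches the longest vocabulary entry over the UTF-8 bytes.
tokens = tokenizer.encode("RWKV-RAG retrieves, reranks and then generates.")

# decode() concatenates the byte strings back; invalid UTF-8 falls back to '\ufffd'.
assert tokenizer.decode(tokens) == "RWKV-RAG retrieves, reranks and then generates."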