├── .gitignore ├── LICENSE ├── README.md ├── README_zh-CN.md ├── assets ├── logo.png ├── scaling_law.svg ├── sentence_vae.svg └── sllm.svg ├── config ├── SLLM │ ├── SLLM-1.3b │ │ ├── sllm_1.3b_base.yaml │ │ ├── sllm_1.3b_h1_all.yaml │ │ ├── sllm_1.3b_h2_all.yaml │ │ └── sllm_1.3b_h4_all.yaml │ ├── SLLM-125m │ │ ├── sllm_125m_base.yaml │ │ ├── sllm_125m_h1_all.yaml │ │ ├── sllm_125m_h2_all.yaml │ │ └── sllm_125m_h4_all.yaml │ ├── SLLM-350m │ │ ├── sllm_350m_base.yaml │ │ ├── sllm_350m_h1_all.yaml │ │ ├── sllm_350m_h2_all.yaml │ │ └── sllm_350m_h4_all.yaml │ └── sllm_base.yaml └── SVAE │ ├── SVAE-1024 │ ├── svae_1024_base.yaml │ ├── svae_1024_h1.yaml │ ├── svae_1024_h2.yaml │ └── svae_1024_h4.yaml │ ├── SVAE-2048 │ ├── svae_2048_base.yaml │ ├── svae_2048_h1.yaml │ ├── svae_2048_h2.yaml │ └── svae_2048_h4.yaml │ ├── SVAE-768 │ ├── svae_768_base.yaml │ ├── svae_768_h1.yaml │ ├── svae_768_h2.yaml │ └── svae_768_h4.yaml │ └── svae_base.yaml ├── requirements.txt ├── sentence_vae ├── __init__.py ├── data │ ├── __init__.py │ ├── data_collate.py │ ├── data_eval.py │ └── tele_ds_dataset.py ├── models │ ├── __init__.py │ ├── focal_loss.py │ ├── positional_encoding.py │ ├── sentence_decoder.py │ ├── sentence_encoder.py │ ├── sentence_llm_model.py │ └── sentence_vae_model.py └── utils │ ├── __init__.py │ ├── config.py │ ├── http.py │ ├── llm.py │ └── weights.py ├── setup.py └── tools ├── analysis ├── analysis_llm.py ├── analysis_sllm.py ├── analysis_svae.py ├── statistic_dataset.py └── statistic_vocab.py ├── benchmark ├── benchmark_llm.py └── benchmark_sllm.py ├── demo ├── demo_llm.py ├── demo_sllm.py └── demo_svae.py ├── eval ├── eval_llm.py ├── eval_sllm.py └── eval_svae.py └── train ├── train_sllm.py └── train_svae.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Linux ### 2 | *~ 3 | 4 | # user experiments directory 5 | model_repo* 6 | datasets/ 7 | exp* 8 | test* 9 | temp_config/ 10 | 11 | # temporary files which can be 
created if a process still has a handle open of a deleted file 12 | .fuse_hidden* 13 | 14 | # KDE directory preferences 15 | .directory 16 | 17 | # Linux trash folder which might appear on any partition or disk 18 | .Trash-* 19 | 20 | # .nfs files are created when an open file is removed but is still being accessed 21 | .nfs* 22 | 23 | ### PyCharm ### 24 | # User-specific stuff 25 | .idea 26 | 27 | # CMake 28 | cmake-build-*/ 29 | 30 | # Mongo Explorer plugin 31 | .idea/**/mongoSettings.xml 32 | 33 | # File-based project format 34 | *.iws 35 | 36 | # IntelliJ 37 | out/ 38 | 39 | # mpeltonen/sbt-idea plugin 40 | .idea_modules/ 41 | 42 | # JIRA plugin 43 | atlassian-ide-plugin.xml 44 | 45 | # Cursive Clojure plugin 46 | .idea/replstate.xml 47 | 48 | # Crashlytics plugin (for Android Studio and IntelliJ) 49 | com_crashlytics_export_strings.xml 50 | crashlytics.properties 51 | crashlytics-build.properties 52 | fabric.properties 53 | 54 | # Editor-based Rest Client 55 | .idea/httpRequests 56 | 57 | # Android studio 3.1+ serialized cache file 58 | .idea/caches/build_file_checksums.ser 59 | 60 | # JetBrains templates 61 | **___jb_tmp___ 62 | 63 | ### Python ### 64 | # Byte-compiled / optimized / DLL files 65 | __pycache__/ 66 | *.py[cod] 67 | *$py.class 68 | 69 | # C extensions 70 | *.so 71 | 72 | # Distribution / packaging 73 | .Python 74 | build/ 75 | develop-eggs/ 76 | dist/ 77 | downloads/ 78 | eggs/ 79 | .eggs/ 80 | lib/ 81 | lib64/ 82 | parts/ 83 | sdist/ 84 | var/ 85 | wheels/ 86 | pip-wheel-metadata/ 87 | share/python-wheels/ 88 | *.egg-info/ 89 | .installed.cfg 90 | *.egg 91 | MANIFEST 92 | 93 | # PyInstaller 94 | # Usually these files are written by a python script from a template 95 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
96 | *.manifest 97 | *.spec 98 | 99 | # Installer logs 100 | pip-log.txt 101 | pip-delete-this-directory.txt 102 | 103 | # Unit test / coverage reports 104 | htmlcov/ 105 | .tox/ 106 | .nox/ 107 | .coverage 108 | .coverage.* 109 | .cache 110 | nosetests.xml 111 | coverage.xml 112 | *.cover 113 | .hypothesis/ 114 | .pytest_cache/ 115 | 116 | # Translations 117 | *.mo 118 | *.pot 119 | 120 | # Django stuff: 121 | *.log 122 | local_settings.py 123 | db.sqlite3 124 | 125 | # Flask stuff: 126 | instance/ 127 | .webassets-cache 128 | 129 | # Scrapy stuff: 130 | .scrapy 131 | 132 | # Sphinx documentation 133 | docs/_build/ 134 | docs/build/ 135 | 136 | # PyBuilder 137 | target/ 138 | 139 | # Jupyter Notebook 140 | .ipynb_checkpoints 141 | 142 | # IPython 143 | profile_default/ 144 | ipython_config.py 145 | 146 | # pyenv 147 | .python-version 148 | 149 | # pipenv 150 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 151 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 152 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 153 | # install all needed dependencies. 
154 | #Pipfile.lock 155 | 156 | # celery beat schedule file 157 | celerybeat-schedule 158 | 159 | # SageMath parsed files 160 | *.sage.py 161 | 162 | # Environments 163 | .env 164 | .venv 165 | env/ 166 | venv/ 167 | ENV/ 168 | env.bak/ 169 | venv.bak/ 170 | 171 | # Spyder project settings 172 | .spyderproject 173 | .spyproject 174 | 175 | # Rope project settings 176 | .ropeproject 177 | 178 | # mkdocs documentation 179 | /site 180 | 181 | # mypy 182 | .mypy_cache/ 183 | .dmypy.json 184 | dmypy.json 185 | 186 | # Pyre type checker 187 | .pyre/ 188 | 189 | ### Vim ### 190 | # Swap 191 | [._]*.s[a-v][a-z] 192 | [._]*.sw[a-p] 193 | [._]s[a-rt-v][a-z] 194 | [._]ss[a-gi-z] 195 | [._]sw[a-p] 196 | 197 | # Session 198 | Session.vim 199 | 200 | # Temporary 201 | .netrwhist 202 | # Auto-generated tag files 203 | tags 204 | # Persistent undo 205 | [._]*.un~ 206 | 207 | # output 208 | docs/api 209 | .code-workspace.code-workspace 210 | *.pkl 211 | *.npy 212 | *.pth 213 | *.onnx 214 | *.engine 215 | events.out.tfevents* 216 | 217 | # vscode 218 | *.code-workspace 219 | .vscode 220 | 221 | # vim 222 | .vim 223 | 224 | # OS generated files 225 | .DS_Store 226 | .DS_Store? 227 | .Trashes 228 | ehthumbs.db 229 | Thumbs.db 230 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT LICENSE 2 | 3 | Copyright (C) 2024 School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), Northwestern PolyTechnical University, and Institute of Artificial Intelligence (TeleAI), China Telecom. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | ![Static Badge](https://img.shields.io/badge/license-MIT-green) ![Static Badge](https://img.shields.io/badge/version-0.0.1-blue) [![Static Badge](https://img.shields.io/badge/paper-arXiv-red)](https://arxiv.org/abs/2408.00655) 5 | 6 | Hongjun An1,2*,Yifan Chen1,2*,Zhe Sun1,2✉ & Xuelong Li1,2✉ 7 | 8 | 1School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), Northwestern PolyTechnical University 9 | 10 | 2Institute of Artificial Intelligence (TeleAI), China Telecom 11 | 12 | 13 | English | [简体中文](README_zh-CN.md) 14 | 15 |
16 | 17 | # 1.Introduction 18 | 19 | Current large language models (LLMs) primarily utilize next-token prediction method for inference, which significantly impedes their processing speed. In this [paper](https://arxiv.org/abs/2408.00655), we introduce a novel inference methodology termed next-sentence prediction, aiming at enhancing the inference efficiency of LLMs. We present Sentence Variational Autoencoder (SentenceVAE), which includes a Sentence Encoder to compress multiple tokens in a sentence into a single token, and a Sentence Decoder to reconstruct it. 20 | 21 |
22 |
23 | Fig. 1. The schematic form of SentenceVAE. 24 |
25 | 26 | 27 | By integrating SentenceVAE into the input and output layers of LLMs, we develop Sentence-level LLMs (SLLMs) that employ a sentence-by-sentence inference method. 28 | 29 |
30 |
31 | Fig. 2. (a) The schematic form of published LLMs. (b) The schematic form of SLLMs, which embedded with SentenceVAEs. 32 |
33 | 34 | The SLLMs can maintain the integrity of the original semantic content by segmenting the context into sentences, thereby improving accuracy while boosting inference speed. Moreover, compared to previous LLMs, SLLMs process fewer tokens over equivalent context length, significantly reducing memory demands for self-attention computation and facilitating the handling of longer context. Extensive experiments on [Wanjuan dataset](https://github.com/opendatalab/WanJuan1.0/) have revealed that the proposed method can accelerate inference speed by 204 ~ 365%, reduce perplexity (PPL) to 46 ~ 75% of its original metric, and decrease memory overhead by 86 ~ 91% for the equivalent context length, compared to previous token-by-token methods. 35 | 36 |
37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 |
ModelTotal ParamsAverage PPLMean output throughput (toks/s)Mean GPU memory (KB/token)
OPT↓SLLM↓Δ↓OPT↑SLLM↑Δ↑OPT↓SLLM↓Δ↓
SLLM-125M-H1214M26.7531.68+18.4%214.57652.78+204.2%73.1512.03-83.6%
SLLM-125M-H2226M44.60+66.7%539.80+151.6%7.08-90.3%
SLLM-125M-H4250M14.32-46.5%332.12+54.8%10.00-86.3%
SLLM-350M-H1429M25.1824.84-1.4%144.33481.39+233.5%197.5929.98-84.8%
SLLM-350M-H2450M14.81-41.2%442.23+206.4%26.78-86.4%
SLLM-350M-H4492M10.17-59.6%315.61+118.7%17.73-91.0%
SLLM-1.3B-H11.61B15.958.76-45.1%119.07479.71+302.9%400.0157.07-85.7%
SLLM-1.3B-H21.69B3.84-75.9%553.95+365.2%55.14-86.2%
146 |
147 | 148 | In addition, by corroborating the Scaling Law, we extrapolated the feasibility of our methodologies to larger-scale models. 149 | 150 |
151 |
152 | Fig. 3. Scaling Law of (a) SLLMs and (b) SVAEs. 153 |
154 | 155 | # 2.Quick Start 156 | 157 |
158 | Installation 159 | 160 | Step1. Install SentenceVAE from source. 161 | 162 | ```sh 163 | git clone https://github.com/BestAnHongjun/SentenceVAE.git 164 | cd SentenceVAE 165 | pip3 install -e . # or python3 setup.py develop 166 | ``` 167 | 168 |
169 | 170 |
171 | Prepare OPT models 172 | 173 | Step1. Create a folder named `model_repo` under `SentenceVAE` to save OPT series models. 174 | 175 | ```sh 176 | cd SentenceVAE 177 | mkdir -p model_repo 178 | ``` 179 | 180 | Step2. Navigate to the `model_repo` directory with `cd` and initialize [`git-lfs`](https://git-lfs.com). 181 | 182 | ```sh 183 | cd model_repo 184 | git lfs install 185 | ``` 186 | 187 | Step3. Download [OPT-125M](https://huggingface.co/facebook/opt-125m) model for SentenceVAE-768 series and SLLM-125M series. 188 | 189 | ```sh 190 | git clone https://huggingface.co/facebook/opt-125m 191 | ``` 192 | 193 | Step4. Download [OPT-350M](https://huggingface.co/facebook/opt-350m) model for SentenceVAE-1024 series and SLLM-350M series. 194 | 195 | ```sh 196 | git clone https://huggingface.co/facebook/opt-350m 197 | ``` 198 | 199 | Step5. Download [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b) model for Sentence-2048 series and SLLM-1.3B series. 200 | 201 | ```sh 202 | git clone https://huggingface.co/facebook/opt-1.3b 203 | ``` 204 | 205 |
206 | 207 |
208 | SentenceVAE Demo 209 | 210 | Step1. Download a pretrained model from table below. 211 | 212 |
213 | 214 | |Model|Hidden Size|Hidden Layers|Loss↓|PPL↓|Download Link| 215 | |:-:|:-:|:-:|:-:|:-:|:-:| 216 | |SVAE-768-H1|768|1|1.339|3.605|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-768-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-768-H1.pth&sign=d8239cd4b6979b61ee0b969ef54f1a78&nonce=1723800234481)| 217 | |SVAE-768-H2|768|2|1.019|2.588|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-768-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-768-H2.pth&sign=c11ca77f7934d4b441e7a6ae5359157f&nonce=1723800264673)| 218 | |SVAE-768-H4|768|4|**0.5598**|**1.649**|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-768-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-768-H4.pth&sign=829995892eba42caf3f28a0f77a28d9e&nonce=1723800281621)| 219 | |SVAE-1024-H1|1024|1|0.9266|2.406|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-1024-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-1024-H1.pth&sign=b3d5202c64d117389131def1b35e2f33&nonce=1723800301123)| 220 | |SVAE-1024-H2|1024|2|0.6610|1.845|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-1024-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-1024-H2.pth&sign=f4ba19e8f474068598f8186be14a7ab4&nonce=1723800319623)| 221 | |SVAE-1024-H4|1024|4|**0.3704**|**1.384**|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-1024-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-1024-H4.pth&sign=6ff668d782383e4bf01d1337a98910b3&nonce=1723800343431)| 222 | |SVAE-2048-H1|2048|1|0.5165|1.622|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-2048-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-2048-H1.pth&sign=4d1ef8d0d0cf0f48e406eb73d74bb5cf&nonce=1723800363566)| 223 | |SVAE-2048-H2|2048|2|0.2845|1.292|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-2048-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-2048-H2.pth&sign=7cc09034413bcfdb0fcbd875e6ad4be4&nonce=1723800379541)| 224 | |SVAE-2048-H4|2048|4|**0.1270**|**1.115**|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-2048-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-2048-H4.pth&sign=8d09686dbcd6aaf0aeaf70a537de1836&nonce=1723800393625)| 225 | 226 |
227 | 228 | Step2. Run demo script under `tools/demo` folder. Here's an example: 229 | 230 | ```sh 231 | cd SentenceVAE 232 | 233 | python3 tools/demo/demo_svae.py \ 234 | -c config/SVAE/SVAE-768/svae_768_h4.yaml \ 235 | --checkpoint /path/to/pretrained/checkpoint \ 236 | --input "What's your name?" 237 | ``` 238 | 239 | **Arguments**: 240 | * `-c`,`--config`: path to the corresponding configuration file, please reference [this folder](config/SVAE/). 241 | * `--checkpoint`: path to the checkpoint file you just downloaded. 242 | * `--input`: A sentence you want to test. 243 | * It must be a separate sentence ending with punctuation marks such as commas, periods, etc. Please refer to the [paper](https://arxiv.org/abs/2408.00655) for specific reasons. 244 | * Currently, only English is supported. 245 | 246 | The model will compress this sentence into a single vector, decode and restore it for output. In an ideal state, the output and input should be consistent. 247 | 248 |
249 | 250 |
251 | 252 | SentenceLLM Demo 253 | 254 | **Notice**: Please be aware that, as SFT datasets are typically commercial secrets and difficult for us to access, all the models listed below are **pre-trained models**, not general-purpose conversation models. Therefore, the **PPL** (Perplexity) metric should be used to assess model quality, not conversational performance. If you treat them as Q&A models, you're likely to get gibberish outputs (***in fact, even our baseline OPT model will output gibberish***). We recommend fine-tuning these models on private SFT datasets to explore their potential as general-purpose conversation models. 255 | 256 | Step1. Download a pretrained model from table below. 257 | 258 |
259 | 260 | |Model|Download Link| 261 | |:-:|:-:| 262 | |SLLM-125M-H1|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-125M-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-125M-H1.pth&sign=b170645b08b03ee1240b95267c7454ca&nonce=1723800424069)| 263 | |SLLM-125M-H2|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-125M-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-125M-H2.pth&sign=fb3e59ef4c30dcea0d732af741d183b2&nonce=1723800472802)| 264 | |SLLM-125M-H4|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-125M-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-125M-H4.pth&sign=d8d1c1aa1d516e26fe21cb1ca3220e62&nonce=1723800484902)| 265 | |SLLM-350M-H1|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-350M-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-350M-H1.pth&sign=bc5e5e777a1bc41a1a564a7e52a2bf94&nonce=1723800502532)| 266 | |SLLM-350M-H2|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-350M-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-350M-H2.pth&sign=6c8ba81b806649366df99c96fbe3e4ed&nonce=1723800517836)| 267 | |SLLM-350M-H4|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-350M-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-350M-H4.pth&sign=e5722a48a5ff516e61cf9efdc1ee8230&nonce=1723800534148)| 268 | |SLLM-1.3B-H1|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-1.3B-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-1.3B-H1.pth&sign=77edd326c8e46eebd98f7f545f4d4e0c&nonce=1723800549084)| 269 | |SLLM-1.3B-H2|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-1.3B-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-1.3B-H2.pth&sign=54dc841b8a067afe7a2fbd16a6c0a2e5&nonce=1723800565365)| 270 | 271 |
272 | 273 | Step2. Run demo script under `tools/demo` folder. Here's an example: 274 | 275 | ```sh 276 | cd SentenceVAE 277 | 278 | python3 tools/demo/demo_sllm.py \ 279 | -c config/SLLM/SLLM-125m/sllm_125m_h4_all.yaml \ 280 | --checkpoint /path/to/pretrained/checkpoint \ 281 | --input "What's your name?" 282 | ``` 283 | 284 | **Arguments**: 285 | * `-c`,`--config`: path to the corresponding configuration file, please reference [this folder](config/SLLM/). 286 | * `--checkpoint`: path to the checkpoint file you just downloaded. 287 | * `--input`: Your input sentence. 288 | 289 | 290 |
291 | 292 | # 3.Tutorials 293 | 294 | Under writing... 295 | 296 |
297 | Train Models 298 | 299 | * [Prepare Datasets](#) 300 | * [Train SentenceVAEs](#) 301 | * [Train SentenceLLMs](#) 302 | 303 |
304 | 305 |
306 | Eval Models 307 | 308 | * [Eval OPT models (baseline)](#) 309 | * [Eval SentenceVAEs](#) 310 | * [Eval SentenceLLMs](#) 311 | 312 |
313 | 314 |
315 | Test Benchmarks 316 | 317 | * [Test benchmarks of SentenceVAEs](#) 318 | * [Test benchmarks of SentenceLLMs](#) 319 | 320 |
321 | 322 | # 4.Cite SentenceVAE 323 | 324 | If you use SentenceVAE in your research, please cite our work by using the following BibTeX entry: 325 | 326 | ```bibtex 327 | @article{an2024sentencevae, 328 | title={SentenceVAE: Enable Next-sentence Prediction for Large Language Models with Faster Speed, Higher Accuracy and Longer Context}, 329 | author={An, Hongjun and Chen, Yifan and Sun, Zhe and Li, Xuelong}, 330 | journal={arXiv preprint arXiv:2408.00655}, 331 | year={2024} 332 | } 333 | ``` -------------------------------------------------------------------------------- /README_zh-CN.md: -------------------------------------------------------------------------------- 1 | TODO. -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BestAnHongjun/SentenceVAE/8e323e4a5db3b9f36ccc3435d88804e60dfabf69/assets/logo.png -------------------------------------------------------------------------------- /config/SLLM/SLLM-1.3b/sllm_1.3b_base.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ../sllm_base.yaml 2 | 3 | svae: 4 | ref_model_dir: "model_repo/opt-1.3b" 5 | 6 | llm: 7 | ref_model_dir: "model_repo/opt-1.3b" 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-1.3b/sllm_1.3b_h1_all.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./sllm_1.3b_base.yaml 2 | 3 | svae: 4 | num_hidden_layers: 1 5 | 6 | llm: 7 | finetune_layers: -1 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-1.3b/sllm_1.3b_h2_all.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./sllm_1.3b_base.yaml 2 | 3 | svae: 4 | num_hidden_layers: 2 5 | 6 | llm: 7 | 
finetune_layers: -1 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-1.3b/sllm_1.3b_h4_all.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./sllm_1.3b_base.yaml 2 | 3 | svae: 4 | num_hidden_layers: 4 5 | 6 | llm: 7 | finetune_layers: -1 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-125m/sllm_125m_base.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ../sllm_base.yaml 2 | 3 | svae: 4 | ref_model_dir: "model_repo/opt-125m" 5 | 6 | llm: 7 | ref_model_dir: "model_repo/opt-125m" 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-125m/sllm_125m_h1_all.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./sllm_125m_base.yaml 2 | 3 | svae: 4 | num_hidden_layers: 1 5 | 6 | llm: 7 | finetune_layers: -1 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-125m/sllm_125m_h2_all.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./sllm_125m_base.yaml 2 | 3 | svae: 4 | num_hidden_layers: 2 5 | 6 | llm: 7 | finetune_layers: -1 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-125m/sllm_125m_h4_all.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./sllm_125m_base.yaml 2 | 3 | svae: 4 | num_hidden_layers: 4 5 | 6 | llm: 7 | finetune_layers: -1 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-350m/sllm_350m_base.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ../sllm_base.yaml 2 | 3 | svae: 4 | ref_model_dir: 
"model_repo/opt-350m" 5 | 6 | llm: 7 | ref_model_dir: "model_repo/opt-350m" 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-350m/sllm_350m_h1_all.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./sllm_350m_base.yaml 2 | 3 | svae: 4 | num_hidden_layers: 1 5 | 6 | llm: 7 | finetune_layers: -1 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-350m/sllm_350m_h2_all.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./sllm_350m_base.yaml 2 | 3 | svae: 4 | num_hidden_layers: 2 5 | 6 | llm: 7 | finetune_layers: -1 8 | -------------------------------------------------------------------------------- /config/SLLM/SLLM-350m/sllm_350m_h4_all.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./sllm_350m_base.yaml 2 | 3 | svae: 4 | num_hidden_layers: 4 5 | 6 | llm: 7 | finetune_layers: -1 8 | -------------------------------------------------------------------------------- /config/SLLM/sllm_base.yaml: -------------------------------------------------------------------------------- 1 | device: "cuda" 2 | dtype: "fp32" 3 | 4 | svae: 5 | learnable_add: false 6 | load_ref_model: false 7 | ref_model_dir: "model_repo/opt-125m" 8 | ref_model_dtype: null 9 | finetune_embedding: true 10 | model_path: null 11 | 12 | llm: 13 | ref_model_dir: "model_repo/opt-125m" 14 | ref_model_dtype: null 15 | finetune_layers: -1 16 | 17 | finetune_svae: true 18 | max_sen_len: 64 19 | max_sen_num: 64 20 | batch_size: 1 21 | base_lr: 0.000001 22 | resume_train: true 23 | dataloader_num_workers: 32 24 | dataloader_prefetch_factor: 20 25 | save_checkpoint_iters: 5000 26 | max_iters: 1600000 27 | warmup_iters: 5000 28 | val_iters: 5000 29 | cosineannealinglr_tmax: 20000 30 | max_keep_ckpts: 2 31 | 
-------------------------------------------------------------------------------- /config/SVAE/SVAE-1024/svae_1024_base.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ../svae_base.yaml 2 | 3 | ref_model_dir: "model_repo/opt-350m" 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-1024/svae_1024_h1.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./svae_1024_base.yaml 2 | 3 | num_hidden_layers: 1 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-1024/svae_1024_h2.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./svae_1024_base.yaml 2 | 3 | num_hidden_layers: 2 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-1024/svae_1024_h4.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./svae_1024_base.yaml 2 | 3 | num_hidden_layers: 4 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-2048/svae_2048_base.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ../svae_base.yaml 2 | 3 | ref_model_dir: "model_repo/opt-1.3b" 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-2048/svae_2048_h1.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./svae_2048_base.yaml 2 | 3 | num_hidden_layers: 1 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-2048/svae_2048_h2.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./svae_2048_base.yaml 2 | 3 | num_hidden_layers: 2 4 
| -------------------------------------------------------------------------------- /config/SVAE/SVAE-2048/svae_2048_h4.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./svae_2048_base.yaml 2 | 3 | num_hidden_layers: 4 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-768/svae_768_base.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ../svae_base.yaml 2 | 3 | ref_model_dir: "model_repo/opt-125m" 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-768/svae_768_h1.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./svae_768_base.yaml 2 | 3 | num_hidden_layers: 1 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-768/svae_768_h2.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./svae_768_base.yaml 2 | 3 | num_hidden_layers: 2 4 | -------------------------------------------------------------------------------- /config/SVAE/SVAE-768/svae_768_h4.yaml: -------------------------------------------------------------------------------- 1 | __base__: !include ./svae_768_base.yaml 2 | 3 | num_hidden_layers: 4 4 | -------------------------------------------------------------------------------- /config/SVAE/svae_base.yaml: -------------------------------------------------------------------------------- 1 | device: "cuda" 2 | dtype: "fp32" 3 | learnable_add: false 4 | load_ref_model: false 5 | ref_model_dir: "model_repo/opt-125m" 6 | ref_model_dtype: null 7 | num_hidden_layers: 1 8 | finetune_embedding: true 9 | model_path: null 10 | max_seq_len: 64 11 | batch_size: 128 12 | base_lr: 0.0000001 13 | resume_train: true 14 | dataloader_num_workers: 32 15 | dataloader_prefetch_factor: 20 16 | 
save_checkpoint_iters: 5000 17 | max_iters: 300000 18 | warmup_iters: 5000 19 | val_iters: 5000 20 | cosineannealinglr_tmax: 20000 21 | max_keep_ckpts: 2 22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mmengine 2 | torch 3 | transformers 4 | tensorboard 5 | accelerate 6 | pyyaml 7 | pyyaml-include 8 | tqdm 9 | numpy 10 | -------------------------------------------------------------------------------- /sentence_vae/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 28 | 29 | __version__ = "0.0.1" 30 | -------------------------------------------------------------------------------- /sentence_vae/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 28 | 29 | from .tele_ds_dataset import TeleDSDataset 30 | from .data_collate import SentenceCollate, PassageCollate 31 | from .data_eval import SVAE_PPL, SLLM_PPL 32 | -------------------------------------------------------------------------------- /sentence_vae/data/data_collate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
class SentenceCollate:
    """DataLoader collate callable: tokenize a batch of raw strings into
    `(input_ids, attention_mask)` tensors, optionally right-padded to a
    fixed width so every batch has identical shape.
    """

    def __init__(
        self,
        tokenizer,
        max_len=1024,
        padding=True,
        fix_len=True
    ):
        # tokenizer: must expose batch_encode_plus() and pad_token_id
        #            (HuggingFace-style interface -- assumed, TODO confirm).
        # max_len:   truncation limit; also the fixed width when fix_len=True.
        # padding:   forwarded to the tokenizer (pads to longest in batch).
        # fix_len:   when True, always pad the batch out to exactly max_len.
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.padding = padding
        self.fix_len = fix_len

    def __call__(self, texts):
        """Tokenize `texts` (list of str) -> (input_ids, attention_mask)."""
        encoded = self.tokenizer.batch_encode_plus(
            texts,
            padding=self.padding,
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )

        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']

        if self.fix_len:
            # The tokenizer only pads to the longest sequence in the batch;
            # extend on the right up to max_len with pad tokens / zero mask.
            batch, seq_len = input_ids.shape
            if seq_len < self.max_len:
                pad_ids = torch.zeros(
                    (batch, self.max_len - seq_len),
                    dtype=input_ids.dtype,
                    device=input_ids.device
                ).fill_(self.tokenizer.pad_token_id)
                pad_mask = torch.zeros(
                    (batch, self.max_len - seq_len),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device
                )

                input_ids = torch.concat((input_ids, pad_ids), dim=1)
                attention_mask = torch.concat((attention_mask, pad_mask), dim=1)

        return input_ids, attention_mask


class PassageCollate:
    """DataLoader collate callable: split each passage into sentences
    (Chinese and English punctuation aware), tokenize every sentence, and
    return fixed-shape `(sentence_mask, sentence_tokens, token_mask)`
    tensors of shape (batch, max_sentence_num[, max_sentence_len]).
    """

    def __init__(
        self,
        tokenizer,
        max_sentence_len=512,
        max_sentence_num=512,
        padding=True,
        fix_len=True
    ):
        # tokenizer: HuggingFace-style tokenizer (batch_encode_plus,
        #            pad_token_id) -- assumed, TODO confirm.
        # max_sentence_len: per-sentence token truncation/padding width.
        # max_sentence_num: maximum sentences kept per passage.
        self.tokenizer = tokenizer
        self.max_sentence_len = max_sentence_len
        self.max_sentence_num = max_sentence_num
        self.padding = padding
        self.fix_len = fix_len

        # NOTE(review): this compiled pattern is never used below --
        # cut_sentence_func() scans characters manually. Presumably a
        # leftover from an earlier regex-based splitter; verify before
        # removing (external code might rely on the attribute).
        self._re_sentence = re.compile(
            '([,。!?;:\?])([^”’])|' + '([,.!?;:\?])([^"\'])|' +
            '(\…{2})([^”’])|' + '(\.{6})([^”’])|' +
            '([。!?\?][”’])([^,。!?\?])|' + '([.!?\?]["\'])([^,。!?\?])'
        )


    def cut_sentence_func(self, para):
        """Split `para` into at most max_sentence_num sentence strings.

        Returns (sentences, sentence_mask): `sentences` is padded with ""
        up to max_sentence_num; `sentence_mask` is an int64 tensor with 1
        for real sentences and 0 for padding slots.
        """
        # p: scan position, last_p: start of the current sentence.
        p, last_p, length = 0, 0, len(para)
        sentences = []
        sentence_mask = []

        while p < length and len(sentences) < self.max_sentence_num:
            # Two-character closers (ellipsis, punctuation+closing quote in
            # both fullwidth and ASCII forms) end a sentence after both chars.
            if p + 1 < length and (
                para[p:p+2] == "……" or
                para[p:p+2] == ",”" or
                para[p:p+2] == "”," or
                para[p:p+2] == "。”" or
                para[p:p+2] == "”。" or
                para[p:p+2] == "!”" or
                para[p:p+2] == "”!" or
                para[p:p+2] == "?”" or
                para[p:p+2] == "”?" or
                para[p:p+2] == ',"' or
                para[p:p+2] == '",' or
                para[p:p+2] == '."' or
                para[p:p+2] == '".' or
                para[p:p+2] == '!"' or
                para[p:p+2] == '"!' or
                para[p:p+2] == '?"' or
                para[p:p+2] == '"?'
            ):
                p += 2
                # cut
            # Single-character sentence enders (fullwidth and ASCII).
            elif para[p] in [
                ',', '。', '”', '!', '?', ';', ':',
                ',', '.', '!', '?', ';', ':', '"', '\n', '\r'
            ]:
                p += 1
                # cut
            else:
                # Ordinary character: keep scanning without cutting.
                p += 1
                continue

            # Swallow trailing whitespace into the sentence, then emit it.
            while p < length and para[p] in [' ', '\n', '\r']: p += 1
            sentences.append(para[last_p:p])
            sentence_mask.append(1)
            last_p = p

        # NOTE(review): text after the final delimiter (last_p..length) is
        # silently discarded when the passage does not end in punctuation --
        # confirm this truncation is intended for the training pipeline.

        # Pad the sentence list out to max_sentence_num with empty strings.
        delta = self.max_sentence_num - len(sentences)
        if delta:
            sentences.extend(["" for _ in range(delta)])
            sentence_mask.extend([0 for _ in range(delta)])

        sentence_mask = torch.tensor(sentence_mask, dtype=torch.int64)

        return sentences, sentence_mask


    def __call__(self, texts):
        """Collate a batch of passages.

        Returns:
            batch_sentence_mask: (batch, max_sentence_num) int64, 1 = real sentence.
            batch_sentence_toks: (batch, max_sentence_num, max_sentence_len) token ids.
            batch_tok_mask:      same shape, 1 = real token.
        """
        batch_size = len(texts)

        batch_sentence_mask = torch.zeros(
            (batch_size, self.max_sentence_num),
            dtype=torch.int64
        )
        batch_sentence_toks = torch.zeros(
            (batch_size, self.max_sentence_num, self.max_sentence_len),
            dtype=torch.int64
        )
        batch_tok_mask = torch.zeros(
            (batch_size, self.max_sentence_num, self.max_sentence_len),
            dtype=torch.int64
        )

        for i, text in enumerate(texts):
            # cut_sentence_func pads to exactly max_sentence_num entries,
            # so sentence_toks below always has max_sentence_num rows.
            sentences, sentence_mask = self.cut_sentence_func(text)
            encoded = self.tokenizer.batch_encode_plus(
                sentences,
                padding=self.padding,
                truncation=True,
                max_length=self.max_sentence_len,
                return_tensors="pt"
            )

            sentence_toks, tok_mask = encoded['input_ids'], encoded['attention_mask']
            sentence_num, seq_len = sentence_toks.shape
            # Right-pad token dimension to the fixed max_sentence_len.
            if seq_len < self.max_sentence_len:
                pad_ids = torch.zeros(
                    (sentence_num, self.max_sentence_len - seq_len),
                    dtype=sentence_toks.dtype,
                    device=sentence_toks.device
                ).fill_(self.tokenizer.pad_token_id)
                pad_mask = torch.zeros(
                    (sentence_num, self.max_sentence_len - seq_len),
                    dtype=tok_mask.dtype,
                    device=tok_mask.device
                )
                sentence_toks = torch.concat((sentence_toks, pad_ids), dim=1)
                tok_mask = torch.concat((tok_mask, pad_mask), dim=1)

            batch_sentence_mask[i] = sentence_mask
            batch_sentence_toks[i] = sentence_toks
            batch_tok_mask[i] = tok_mask

        return batch_sentence_mask, batch_sentence_toks, batch_tok_mask
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
class SVAE_PPL(BaseMetric):
    """Perplexity metric for the SentenceVAE reconstruction task."""

    def process(self, data_batch, data_samples):
        """Record the masked token-level cross-entropy of one batch."""
        logits, mask, target_ids = data_samples
        # Targets are shifted right by one position, then both sides are
        # restricted to the valid (masked-in) token positions.
        batch_loss = F.cross_entropy(logits[mask], target_ids[:, 1:][mask])
        self.results.append({'loss': batch_loss.item()})

    def compute_metrics(self, results):
        """Average the stored losses and convert to perplexity."""
        mean_loss = np.mean([item['loss'] for item in results])
        return {'eval_loss': mean_loss, 'eval_ppl': np.exp(mean_loss)}


class SLLM_PPL(BaseMetric):
    """Perplexity and stop-prediction-loss metric for the SentenceLLM."""

    def process(self, data_batch, data_sample):
        """Record the stop-head loss and language-model loss of one batch."""
        stop_loss, ppl_loss = data_sample
        self.results.append({
            'stop_loss': stop_loss.item(),
            'ppl_loss': ppl_loss.item(),
        })

    def compute_metrics(self, results):
        """Average the stored losses; the LM loss also yields a perplexity."""
        mean_stop = np.mean([item['stop_loss'] for item in results])
        mean_lm = np.mean([item['ppl_loss'] for item in results])
        return {
            'eval_stop_loss': mean_stop,
            'eval_loss': mean_lm,
            'eval_ppl': np.exp(mean_lm),
        }
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
class TeleDSDataset(Dataset):
    """Map-style dataset that pulls text samples from a TeleDS HTTP server.

    The first ``eval_samples`` indices on the server are reserved for the
    evaluation split; training indices are shifted past that region.
    """

    def __init__(
        self,
        server_ip="127.0.0.1",
        server_port=8000,
        eval_mode=False,
        eval_samples=1000
    ):
        self.server_url = f"http://{server_ip}:{server_port}"
        self.eval_mode = eval_mode
        self.eval_samples = eval_samples

    def __len__(self):
        # Block until the server reports readiness, polling once a second.
        while True:
            status = fetch_text_with_retry(f"{self.server_url}/status").strip()
            if status == "ready.":
                break
            print("TeleDS Server is not ready, retrying.")
            sleep(1)
        total = int(fetch_text_with_retry(f"{self.server_url}/count").strip())
        # The server must hold more samples than the reserved eval prefix.
        assert total > self.eval_samples
        if self.eval_mode:
            return self.eval_samples
        return total - self.eval_samples

    def __getitem__(self, idx):
        # Training samples live after the reserved evaluation prefix.
        offset = idx if self.eval_mode else idx + self.eval_samples
        return fetch_text_with_retry(f"{self.server_url}/data/{offset}")
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 28 | 29 | from .positional_encoding import PositionalEncodding 30 | from .sentence_encoder import SentenceEncoder 31 | from .sentence_decoder import SentenceDecoder 32 | from .sentence_vae_model import SentenceVAE 33 | from .sentence_llm_model import SentenceLLM 34 | from .focal_loss import FocalLoss 35 | -------------------------------------------------------------------------------- /sentence_vae/models/focal_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 
4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    Down-weights well-classified examples by ``(1 - p_t) ** gamma`` so
    training focuses on hard examples.  NaN per-element losses (e.g. from
    rows whose logits are all NaN) are dropped before reduction.

    Args:
        gamma: focusing exponent; 0 reduces this to (eps-stabilised)
            cross-entropy.
        reduction: 'mean', 'sum', or anything else for no reduction.
        eps: additive constant inside the log for numerical safety.
    """

    def __init__(self, gamma=2.0, reduction='mean', eps=1e-6):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps

    def forward(self, logits, targets):
        """Compute the focal loss.

        Args:
            logits: float tensor (..., num_classes) of unnormalised scores.
            targets: int64 tensor (...) of class indices.

        Returns:
            Scalar tensor for 'mean'/'sum'; otherwise the per-element loss
            vector with NaN entries removed.
        """
        # Numerically stable log-softmax: subtract the row max before exp.
        max_logits = torch.max(logits, dim=-1, keepdim=True)[0]
        log_probs = logits - max_logits
        log_probs = log_probs - torch.log(torch.sum(torch.exp(log_probs), dim=-1, keepdim=True))
        probs = torch.exp(log_probs)

        # Probability assigned to the true class of each element.
        targets_one_hot = F.one_hot(targets, num_classes=logits.size(-1))
        targets_probs = torch.sum(probs * targets_one_hot, dim=-1)

        focal_weight = (1 - targets_probs) ** self.gamma
        log_targets_probs = torch.log(targets_probs + self.eps)
        loss = -focal_weight * log_targets_probs

        # Drop NaN entries before reducing.
        loss = loss[~torch.isnan(loss)]
        if loss.size(0) == 0:
            # BUGFIX: the original `torch.tensor(0, requires_grad=True)`
            # raises "Only Tensors of floating point and complex dtype can
            # require gradients" (integer literal). Return a float zero of
            # the logits' dtype on the targets' device instead.
            return torch.zeros((), dtype=logits.dtype, device=targets.device, requires_grad=True)

        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)

        return loss
class PositionalEncodding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    Precomputes the classic Transformer sin/cos table once at construction
    and serves read-only slices of it at forward time.

    Note: the table is a plain attribute, not a registered buffer, so it is
    absent from ``state_dict()`` and is not moved by a later ``.to()`` --
    changing this would alter checkpoint compatibility, so it is kept as-is.

    Args:
        hidden_size: model width; must be even for interleaved sin/cos.
        max_len: maximum sequence length covered by the table.
        device: placement of the table (default: cuda).
        dtype: storage dtype of the table (default: float16).
    """

    def __init__(
        self,
        hidden_size: int,
        max_len: int = 4096,
        device: torch.device = None,
        dtype: torch.dtype = None
    ):
        super(PositionalEncodding, self).__init__()
        # BUGFIX: error-message typo corrected ("go size=" -> "got size=").
        assert hidden_size % 2 == 0, \
            f"Cannot use sin/cos positional encoding with odd hidden_size (got size={hidden_size})."

        device = device if device is not None else torch.device('cuda')
        dtype = dtype if dtype is not None else torch.float16

        # pe[pos, 2i]   = sin(pos / 10000^(2i/hidden_size))
        # pe[pos, 2i+1] = cos(pos / 10000^(2i/hidden_size))
        pe = torch.zeros(max_len, hidden_size)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, hidden_size, 2) * -(math.log(10000.0) / hidden_size)))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Shape (1, max_len, hidden_size) so it broadcasts over the batch;
        # the table is constant, so no gradients flow into it.
        self.pe = pe.unsqueeze(0)
        self.pe.requires_grad = False

        self.pe = self.pe.to(dtype)
        self.pe = self.pe.to(device)


    def forward(self, seq_len):
        """Return the (1, seq_len, hidden_size) positional-encoding slice."""
        return self.pe[:, :seq_len, :]
20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 28 | 29 | import torch 30 | import torch.nn as nn 31 | from typing import Union 32 | 33 | from sentence_vae.models import PositionalEncodding 34 | from sentence_vae.utils.llm import get_model 35 | from sentence_vae.utils.weights import load_embedding_state_dict 36 | 37 | 38 | class SentenceDecoder(nn.Module): 39 | def __init__( 40 | self, 41 | hidden_size: int, 42 | vocab_size: int, 43 | device: torch.device = None, 44 | dtype: torch.dtype = None, 45 | load_ref_model: Union[bool, nn.Module] = False, 46 | ref_model_dir: str = None, 47 | ref_model_dtype: torch.dtype = None, 48 | finetune_embedding: bool = True, 49 | word_embed_proj_dim: int = None, 50 | num_attention_heads: int = 16, 51 | num_hidden_layers: int = 1, 52 | max_seq_len: int = 1024, 53 | dropout: float = 0.1, 54 | pad_id=1, 55 | ): 56 | super(SentenceDecoder, self).__init__() 57 | word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size 58 | 59 | self.device = device if device is not None else torch.device('cuda') 60 | self.dtype = dtype if dtype is not None else torch.float16 61 | 62 | self.embed_token = nn.Embedding(vocab_size, word_embed_proj_dim, padding_idx=pad_id) 63 | self.embed_positions = PositionalEncodding(hidden_size, max_seq_len, dtype=self.dtype, device=self.device) 64 | 65 | if word_embed_proj_dim != hidden_size: 66 | self.project_in = nn.Linear(word_embed_proj_dim, hidden_size, bias=False) 67 | else: 68 | self.project_in = None 69 | 70 | 
decoder_layer = nn.TransformerDecoderLayer( 71 | d_model=hidden_size, 72 | nhead=num_attention_heads, 73 | dim_feedforward=hidden_size * 2, 74 | dropout=dropout, 75 | batch_first=True, 76 | ) 77 | 78 | self.decoder = nn.TransformerDecoder(decoder_layer, num_hidden_layers) 79 | self.linear = nn.Linear(hidden_size, vocab_size) 80 | 81 | if isinstance(load_ref_model, nn.Module): 82 | ref_model_dtype = ref_model_dtype if ref_model_dtype is not None else self.dtype 83 | ref_model = load_ref_model 84 | ref_model_state_dict = load_embedding_state_dict(ref_model) 85 | assert ref_model_state_dict is not None, f"Model {ref_model_dir} does not have an Embedding layer." 86 | self.embed_token.load_state_dict(ref_model_state_dict) 87 | elif load_ref_model and ref_model_dir is not None: 88 | ref_model_dtype = ref_model_dtype if ref_model_dtype is not None else self.dtype 89 | ref_model = get_model(ref_model_dir, ref_model_dtype, 'cpu') 90 | ref_model_state_dict = load_embedding_state_dict(ref_model) 91 | assert ref_model_state_dict is not None, f"Model {ref_model_dir} does not have an Embedding layer." 
92 | self.embed_token.load_state_dict(ref_model_state_dict) 93 | if not finetune_embedding: 94 | self.embed_token.weight.requires_grad = False 95 | 96 | self.to(self.dtype) 97 | self.to(self.device) 98 | 99 | 100 | def forward(self, input_ids, sentence_embed, attention_mask=None): 101 | _, seq_len = input_ids.shape 102 | if attention_mask is not None and attention_mask.dtype is not torch.bool: 103 | attention_mask = ~attention_mask.to(torch.bool) 104 | 105 | inputs_emb = self.embed_token(input_ids) 106 | pos_emb = self.embed_positions(seq_len) 107 | if self.project_in is not None: 108 | inputs_emb = self.project_in(inputs_emb) 109 | inputs_emb = inputs_emb + pos_emb 110 | 111 | future_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=self.device) == -torch.inf 112 | 113 | hidden_state = self.decoder( 114 | inputs_emb, 115 | sentence_embed, 116 | tgt_mask=future_mask, 117 | tgt_key_padding_mask=attention_mask, 118 | tgt_is_causal=True 119 | ) 120 | output = self.linear(hidden_state) 121 | 122 | return output 123 | 124 | 125 | def streaming_generate( 126 | self, sentence_embed, 127 | max_output_len:int=64, 128 | bos_token_id: int=2, 129 | eos_token_id: int=2 130 | ): 131 | output_ids = torch.tensor([[bos_token_id]], dtype=torch.long, device=self.device) 132 | while len(output_ids) < max_output_len: 133 | logits = self.forward(output_ids, sentence_embed) 134 | new_id = torch.argmax(logits[:, -1:], dim=-1) 135 | output_ids = torch.concat((output_ids, new_id), dim=1) 136 | yield new_id.item() 137 | if new_id.item() == eos_token_id: 138 | break 139 | 140 | 141 | def generate( 142 | self, sentence_embed, 143 | max_output_len:int=64, 144 | bos_token_id: int=2, 145 | eos_token_id: int=2 146 | ): 147 | output_ids = [] 148 | for output_id in self.streaming_generate(sentence_embed, max_output_len, bos_token_id, eos_token_id): 149 | output_ids.append(output_id) 150 | return output_ids 151 | 
-------------------------------------------------------------------------------- /sentence_vae/models/sentence_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
import torch
import torch.nn as nn
from typing import Union

from sentence_vae.models import PositionalEncodding
from sentence_vae.utils.llm import get_model
from sentence_vae.utils.weights import load_embedding_state_dict


class SentenceEncoder(nn.Module):
    """Transformer encoder that pools a token sequence into one sentence
    embedding of shape ``(batch, 1, hidden_size)``.

    The token embedding can optionally be initialized from a reference LLM,
    given either as an already-loaded ``nn.Module`` or as a checkpoint
    directory loaded via ``get_model``.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        device: torch.device = None,
        dtype: torch.dtype = None,
        learnable_add: bool = True,
        load_ref_model: Union[bool, nn.Module] = False,
        ref_model_dir: str = None,
        ref_model_dtype: torch.dtype = None,
        finetune_embedding: bool = True,
        word_embed_proj_dim: int = None,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 1,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
        pad_id=1,
    ):
        """
        Args:
            hidden_size: width of the transformer encoder.
            vocab_size: size of the token vocabulary.
            device: target device; defaults to CUDA.
            dtype: parameter dtype; defaults to float16.
            learnable_add: pool with a learned per-position weight instead of a plain sum.
            load_ref_model: ``True`` (load from ``ref_model_dir``), an ``nn.Module``
                to copy the embedding from, or ``False`` to skip.
            ref_model_dir: checkpoint directory of the reference model.
            ref_model_dtype: dtype used when loading the reference model.
            finetune_embedding: when ``False``, freeze the token embedding.
            word_embed_proj_dim: token-embedding width; projected to
                ``hidden_size`` when they differ.
            pad_id: padding token id (``padding_idx`` of the embedding).
        """
        super().__init__()
        # Fall back to the encoder width when no separate projection dim is given.
        word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size

        self.device = device if device is not None else torch.device('cuda')
        self.dtype = dtype if dtype is not None else torch.float16
        self.learnable_add = learnable_add

        self.embed_token = nn.Embedding(vocab_size, word_embed_proj_dim, padding_idx=pad_id)
        self.embed_positions = PositionalEncodding(hidden_size, max_seq_len, dtype=self.dtype, device=self.device)

        # Project token embeddings to the encoder width only when necessary.
        if word_embed_proj_dim != hidden_size:
            self.project_in = nn.Linear(word_embed_proj_dim, hidden_size, bias=False)
        else:
            self.project_in = None

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            dim_feedforward=hidden_size * 2,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
        if self.learnable_add:
            # Scalar weight per position for the learnable weighted-sum pooling.
            self.la = nn.Linear(hidden_size, 1)
        self.lnorm = nn.LayerNorm(hidden_size, device=self.device, dtype=self.dtype)

        # Optionally copy the token embedding from a reference model.
        # (The two original branches duplicated the load logic; merged here.)
        ref_model = None
        if isinstance(load_ref_model, nn.Module):
            ref_model = load_ref_model
        elif load_ref_model and ref_model_dir is not None:
            ref_model_dtype = ref_model_dtype if ref_model_dtype is not None else self.dtype
            ref_model = get_model(ref_model_dir, ref_model_dtype, 'cpu')
        if ref_model is not None:
            ref_model_state_dict = load_embedding_state_dict(ref_model)
            assert ref_model_state_dict is not None, f"Model {ref_model_dir} does not have an Embedding layer."
            self.embed_token.load_state_dict(ref_model_state_dict)
        if not finetune_embedding:
            self.embed_token.weight.requires_grad = False

        self.to(self.dtype)
        self.to(self.device)

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` (batch, seq_len) into (batch, 1, hidden_size).

        ``attention_mask`` follows the HF convention (1 = token, 0 = pad) and
        is inverted into a bool key-padding mask (True = pad) internally.
        NOTE(review): a mask that is already ``torch.bool`` is passed through
        uninverted, i.e. it is assumed to already mean True = pad — confirm
        against callers.
        """
        _, seq_len = input_ids.shape
        if attention_mask is not None and attention_mask.dtype is not torch.bool:
            attention_mask = ~attention_mask.to(torch.bool)

        inputs_emb = self.embed_token(input_ids)
        pos_emb = self.embed_positions(seq_len)
        if self.project_in is not None:
            inputs_emb = self.project_in(inputs_emb)
        inputs_emb = inputs_emb + pos_emb

        hidden_state = self.encoder(inputs_emb, src_key_padding_mask=attention_mask)
        # BUG FIX: only zero padded positions when a mask was given. The
        # original ran `hidden_state[attention_mask] = 0` unconditionally;
        # with attention_mask=None that is `hidden_state[None] = 0`, which
        # silently zeroes the WHOLE tensor (and SentenceLLM.streaming_generate
        # does call this encoder without a mask).
        if attention_mask is not None:
            hidden_state[attention_mask] = 0  # (batch, seq_len, hidden)

        if self.learnable_add:
            # Learnable weighted sum over the sequence dimension; padded
            # positions contribute nothing because they were zeroed above.
            alpha = self.la(hidden_state)  # (batch, seq_len, 1)
            sentence_emb = torch.sum(hidden_state * alpha, dim=-2, keepdim=True)
        else:
            sentence_emb = torch.sum(hidden_state, dim=-2, keepdim=True)
        sentence_emb_norm = self.lnorm(sentence_emb)

        return sentence_emb_norm
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmengine.model import BaseModel

from .focal_loss import FocalLoss
from .sentence_vae_model import SentenceVAE
from sentence_vae.utils.llm import get_model


class SentenceLLM(BaseModel):
    """Sentence-level language model.

    A SentenceVAE encoder turns each sentence into one embedding; the decoder
    layers of a pretrained LLM (loaded via ``get_model``) predict the next
    sentence embedding; a small linear head (``self.fc``) predicts whether
    generation should stop; the SentenceVAE decoder turns the predicted
    embedding back into tokens.
    """

    def __init__(
        self,
        svae_hidden_size: int,
        svae_vocab_size: int,
        svae_learnable_add: bool = True,
        svae_load_ref_model: bool = False,
        svae_ref_model_dir: str = None,
        svae_ref_model_dtype: torch.dtype = None,
        svae_finetune_embedding: bool = True,
        svae_word_embed_proj_dim: int = None,
        svae_num_attention_heads: int = 16,
        svae_num_hidden_layers: int = 1,
        svae_model_path: str = None,
        llm_ref_model_dir: str = None,
        llm_ref_model_dtype: torch.dtype = None,
        llm_finetune_layers: int = -1,
        finetune_svae: bool = True,
        max_sentence_len: int = 512,
        max_sentence_num: int = 512,
        dropout: float = 0.1,
        bos_id=2,
        pad_id=1,
        end_id=2,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        """
        Args:
            svae_*: forwarded to the ``SentenceVAE`` sub-model.
            svae_model_path: optional checkpoint; its ``state_dict`` is loaded
                into the SentenceVAE.
            llm_ref_model_dir: checkpoint of the backbone LLM whose positional
                embedding and decoder layers are reused (OPT-style layout:
                ``llm.model.decoder.embed_positions`` / ``.layers``).
            llm_finetune_layers: when >= 0, only the first and last
                ``llm_finetune_layers`` decoder layers stay trainable.
            finetune_svae: when ``False``, the whole SentenceVAE is frozen.
            bos_id / pad_id / end_id: special token ids.
        """
        super().__init__()
        self.bos_token_id = bos_id
        self.pad_token_id = pad_id
        self.eos_token_id = end_id
        self.hidden_size = svae_hidden_size
        self.llm_finetune_layers = llm_finetune_layers
        self.max_sentence_len = max_sentence_len
        self.max_sentence_num = max_sentence_num

        self.device = device if device is not None else torch.device("cuda")
        self.dtype = dtype if dtype is not None else torch.float16

        self.svae = SentenceVAE(
            hidden_size=svae_hidden_size, vocab_size=svae_vocab_size, device=self.device, dtype=self.dtype,
            learnable_add=svae_learnable_add, load_ref_model=svae_load_ref_model, ref_model_dir=svae_ref_model_dir, ref_model_dtype=svae_ref_model_dtype,
            finetune_embedding=svae_finetune_embedding, word_embed_proj_dim=svae_word_embed_proj_dim,
            num_attention_heads=svae_num_attention_heads, num_hidden_layers=svae_num_hidden_layers, max_seq_len=max_sentence_len, dropout=dropout,
            bos_id=bos_id, pad_id=pad_id, end_id=end_id
        )

        # Reuse the backbone LLM's positional embedding and decoder layers;
        # the rest of the loaded model is discarded.
        llm_ref_model_dtype = llm_ref_model_dtype if llm_ref_model_dtype else self.dtype
        llm = get_model(llm_ref_model_dir, llm_ref_model_dtype, self.device, True)
        self.llm_pe = llm.model.decoder.embed_positions
        self.llm_layers = llm.model.decoder.layers

        # 2-way head: should generation stop after this sentence?
        self.fc = nn.Linear(svae_hidden_size, 2)
        self.focal_loss = FocalLoss()

        if svae_model_path is not None:
            print(f"Loading {svae_model_path}")
            ckpt = torch.load(svae_model_path)
            self.svae.load_state_dict(ckpt['state_dict'])

        self.to(self.dtype)
        self.to(self.device)

        if not finetune_svae:
            self._freeze_model(self.svae)
        # Freeze the middle decoder layers, keeping `llm_finetune_layers`
        # layers trainable at each end.
        if llm_finetune_layers >= 0 and llm_finetune_layers * 2 < len(self.llm_layers):
            for i in range(llm_finetune_layers, len(self.llm_layers) - llm_finetune_layers):
                self._freeze_model(self.llm_layers[i])

    def _freeze_model(self, model: nn.Module):
        """Disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def forward(self, sentence_mask, sentence_toks, tok_mask, mode='loss'):
        """Training / evaluation forward pass.

        Args:
            sentence_mask: (batch, sen_num) — 1 for real sentences, 0 for padding.
            sentence_toks: (batch, sen_num, seq_len) token ids per sentence.
            tok_mask: (batch, sen_num, seq_len) — 1 for real tokens, 0 for padding.
            mode: ``'loss'`` returns a dict of losses; ``'predict'`` returns
                ``(stop_loss, ppl_loss)``.
        """
        sentence_mask = sentence_mask.to(self.device)
        sentence_toks = sentence_toks.to(self.device)
        tok_mask = tok_mask.to(self.device)
        batch_size, sen_num, seq_len = sentence_toks.shape

        sentence_embedding = torch.zeros(
            (batch_size, sen_num, self.hidden_size),
            dtype=self.dtype, device=self.device
        )

        # encoder: one sentence embedding per sentence; padded sentences zeroed.
        for b in range(batch_size):
            sentence_embedding[b] = self.svae.encoder(sentence_toks[b], tok_mask[b]).view(sen_num, self.hidden_size)
            sentence_embedding[b][~sentence_mask[b].bool()] = 0

        # llm: run the reused decoder layers over the sentence embeddings.
        pos_emb = self.llm_pe(sentence_mask, 0)
        hidden_state = sentence_embedding + pos_emb
        for layer in self.llm_layers:
            hidden_state = layer(hidden_state)[0]
        stop_flag = self.fc(hidden_state)

        # Stop target: shift sentence_mask left by one, so the last real
        # sentence's target is 0 (stop) and earlier ones are 1 (continue).
        sentence_pad = torch.zeros((batch_size, 1), dtype=sentence_mask.dtype, device=sentence_mask.device)
        tgt_stop_flag = torch.concat((sentence_mask, sentence_pad), dim=1)[:, 1:]
        sen_lens = torch.sum(sentence_mask, dim=1)
        stop_loss = 0
        for b in range(batch_size):
            sen_len = sen_lens[b]
            stop_loss += self.focal_loss(stop_flag[b, :sen_len], tgt_stop_flag[b, :sen_len])
        stop_loss /= batch_size

        # decoder: reconstruct sentence b+1 tokens from hidden state b.
        hidden_state = hidden_state.view(batch_size, sen_num, 1, self.hidden_size)
        decode_loss = 0
        ppl_loss = 0
        for b in range(batch_size):
            sen_len = sen_lens[b]
            input_ids = sentence_toks[b, 1:sen_len]
            attention_mask = tok_mask[b, :sen_len-1]
            sentence_embd = hidden_state[b, :sen_len-1]
            output = self.svae.decoder(input_ids, sentence_embd, attention_mask)
            # Targets: input shifted by one with EOS written at each
            # sentence's true length (scatter_ is in-place on a fresh tensor).
            pad_ids = torch.zeros((sen_len-1, 1), device=self.device, dtype=input_ids.dtype).fill_(self.pad_token_id)
            tgt_ids = torch.cat((input_ids, pad_ids), dim=1)
            seq_lens = torch.sum(attention_mask, dim=1, keepdim=True)
            tgt_ids.scatter_(1, seq_lens, self.eos_token_id)
            attention_mask = attention_mask.bool()
            decode_loss += self.focal_loss(output[attention_mask], tgt_ids[:, 1:][attention_mask])
            ppl_loss += F.cross_entropy(output[attention_mask], tgt_ids[:, 1:][attention_mask])
            # Free the per-sentence intermediates eagerly to reduce peak memory.
            del input_ids
            del attention_mask
            del sentence_embd
            del output
            del pad_ids
            del tgt_ids
        decode_loss /= batch_size
        ppl_loss /= batch_size

        if mode == 'loss':
            # Once the stop head is essentially solved, optimize decoding only.
            if stop_loss < 1e-2:
                return {"decode_loss": decode_loss}
            return {"stop_loss": stop_loss, "decode_loss": decode_loss}
        elif mode == 'predict':
            return stop_loss, ppl_loss

    def streaming_generate(self, sentence_mask, sentence_toks, tok_mask, max_sentence_num=64, token_level=False):
        """Yield generated sentences (lists of token ids) one at a time.

        Only batch size 1 is supported. ``token_level`` is currently unused.
        Uses per-layer KV caches so each new sentence embedding only runs one
        decoder step.
        """
        batch_size, sen_num, seq_len = sentence_toks.shape
        assert batch_size == 1

        # Trim the prompt to its real sentence count.
        sentence_mask = sentence_mask[:batch_size].to(self.device)
        sen_num = torch.sum(sentence_mask)
        sentence_mask = sentence_mask[:batch_size, :sen_num]
        sentence_toks = sentence_toks[:batch_size, :sen_num].to(self.device)
        tok_mask = tok_mask[:batch_size, :sen_num].to(self.device)

        sentence_embedding = torch.zeros(
            (batch_size, sen_num, self.hidden_size),
            dtype=self.dtype, device=self.device
        )

        # encoder
        for b in range(batch_size):
            sentence_embedding[b] = self.svae.encoder(sentence_toks[b], tok_mask[b]).view(sen_num, self.hidden_size)
            sentence_embedding[b][~sentence_mask[b].bool()] = 0

        # llm prefill: process the whole prompt and collect the KV caches.
        pos_emb = self.llm_pe(sentence_mask, 0)
        hidden_state = sentence_embedding + pos_emb
        past_key_values = []
        past_key_values_length = sen_num
        for layer in self.llm_layers:
            hidden_state, kv_cache = layer(hidden_state, use_cache=True)
            past_key_values.append(kv_cache)

        while past_key_values_length < max_sentence_num:
            # Stop head on the last position: class 0 means "stop".
            stop_flag = torch.argmax(self.fc(hidden_state[:batch_size, -1:, :]), dim=-1)
            if stop_flag.item() == 0:
                break
            new_sentence = self.svae.decoder.generate(
                hidden_state[:batch_size, -1:, :],
                self.max_sentence_len,
                self.bos_token_id,
                self.eos_token_id
            )

            yield new_sentence

            # Re-encode the emitted sentence and advance the caches one step.
            # NOTE(review): the encoder is called here without an attention
            # mask — verify SentenceEncoder.forward handles mask=None.
            new_sentence = torch.tensor([new_sentence], dtype=torch.long, device=self.device)
            new_sentence_embedding = self.svae.encoder(new_sentence)
            pos_emb = self.llm_pe(torch.ones((1, past_key_values_length + 1), dtype=torch.long, device=self.device), past_key_values_length)
            past_key_values_length += 1
            hidden_state = new_sentence_embedding + pos_emb
            for i, layer in enumerate(self.llm_layers):
                hidden_state, kv_cache = layer(hidden_state, past_key_value=past_key_values[i], use_cache=True)
                past_key_values[i] = kv_cache
import torch

from mmengine.model import BaseModel

from .focal_loss import FocalLoss
from .sentence_encoder import SentenceEncoder
from .sentence_decoder import SentenceDecoder
from sentence_vae.utils import get_model


class SentenceVAE(BaseModel):
    """Sentence auto-encoder: a SentenceEncoder compresses a token sequence
    into one embedding and a SentenceDecoder reconstructs the tokens from it.

    Training target is the input shifted by one with EOS written at each
    sequence's true length; the loss is a focal loss over unpadded positions.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        device: torch.device = None,  # annotation fixed: was torch.dtype
        dtype: torch.dtype = None,
        learnable_add: bool = True,
        load_ref_model: bool = False,
        ref_model_dir: str = None,
        ref_model_dtype: torch.dtype = None,
        finetune_embedding: bool = True,
        word_embed_proj_dim: int = None,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 1,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        bos_id=2,
        pad_id=1,
        end_id=2,
    ):
        """
        Args:
            hidden_size / vocab_size: transformer width and vocabulary size.
            load_ref_model / ref_model_dir / ref_model_dtype: when enabled,
                the reference model is loaded once here (on CPU) and the same
                instance is shared by encoder and decoder for embedding init.
            bos_id / pad_id / end_id: special token ids.
            Remaining args are forwarded to encoder and decoder.
        """
        super().__init__()
        self.bos_token_id = bos_id
        self.pad_token_id = pad_id
        self.eos_token_id = end_id
        self.max_seq_len = max_seq_len

        self.device = device if device is not None else torch.device("cuda")
        self.dtype = dtype if dtype is not None else torch.float16

        # Load the reference model once and pass the module itself down, so
        # encoder and decoder don't each load it from disk.
        if load_ref_model and ref_model_dir is not None:
            ref_model_dtype = ref_model_dtype if ref_model_dtype is not None else self.dtype
            load_ref_model = get_model(ref_model_dir, ref_model_dtype, 'cpu')
        else:
            load_ref_model = False

        self.encoder = SentenceEncoder(
            hidden_size, vocab_size, device, dtype,
            learnable_add, load_ref_model, ref_model_dir, ref_model_dtype,
            finetune_embedding, word_embed_proj_dim,
            num_attention_heads, num_hidden_layers, max_seq_len,
            dropout, pad_id
        )
        self.decoder = SentenceDecoder(
            hidden_size, vocab_size, device, dtype,
            load_ref_model, ref_model_dir, ref_model_dtype,
            finetune_embedding, word_embed_proj_dim,
            num_attention_heads, num_hidden_layers, max_seq_len,
            dropout, pad_id
        )

        self.focal_loss = FocalLoss()

    def forward(self, input_ids, attention_mask=None, mode='loss'):
        """Encode then decode ``input_ids``.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: (batch, seq_len), 1 = token, 0 = pad; all-ones
                when omitted.
            mode: 'loss' → ``{'total_loss': ...}``; 'predict' →
                ``(output, attention_mask, tgt_ids)``; anything else → raw logits.
        """
        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        else:
            attention_mask = torch.ones(input_ids.shape, dtype=torch.int64, device=self.device)

        sentence_embd = self.encoder(input_ids, attention_mask)

        output = self.decoder(input_ids, sentence_embd, attention_mask)
        # Target: input shifted left by one, padded, with EOS written at each
        # row's true length (scatter_ mutates the freshly built tgt_ids).
        batch, _ = input_ids.shape
        pad_ids = torch.zeros((batch, 1), device=self.device, dtype=input_ids.dtype).fill_(self.pad_token_id)
        tgt_ids = torch.cat((input_ids, pad_ids), dim=1)
        seq_lens = torch.sum(attention_mask, dim=1, keepdim=True)
        tgt_ids.scatter_(1, seq_lens, self.eos_token_id)
        attention_mask = attention_mask.bool()
        if mode == 'loss':
            loss = self.focal_loss(output[attention_mask], tgt_ids[:, 1:][attention_mask])
            return {'total_loss': loss}
        elif mode == 'predict':
            return output, attention_mask, tgt_ids
        else:
            return output

    def streaming_generate(self, input_ids, attention_mask=None, max_output_len=64):
        """Encode the prompt once, then yield decoded token ids one by one.

        Only batch size 1 is supported; delegates token-by-token decoding to
        ``self.decoder.streaming_generate``.
        """
        batch_size = input_ids.size(0)
        assert batch_size == 1

        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        else:
            attention_mask = torch.ones(input_ids.shape, dtype=torch.int64, device=self.device)

        sentence_embd = self.encoder(input_ids, attention_mask)

        for output_id in self.decoder.streaming_generate(
            sentence_embd,
            max_output_len,
            self.bos_token_id,
            self.eos_token_id
        ):
            yield output_id

    def generate(self, input_ids, attention_mask=None, max_output_len=64):
        """Non-streaming wrapper: collect streaming output into a plain list of ints."""
        output_ids = []
        for output_id in self.streaming_generate(input_ids, attention_mask, max_output_len):
            output_ids.append(output_id.item())
        return output_ids
IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 28 | 29 | from .http import fetch_text_with_retry 30 | from .llm import get_model, get_tokenizer, get_config, get_dtype 31 | from .weights import init_model_weights, load_embedding_state_dict 32 | from .config import load_yaml 33 | -------------------------------------------------------------------------------- /sentence_vae/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
import os


def recursive_update(d, u):
    """Deep-merge mapping ``u`` into ``d`` in place and return ``d``.

    Values from ``u`` win; nested dicts are merged recursively instead of
    being replaced wholesale.
    """
    for key, value in u.items():
        if isinstance(value, dict):
            nested = d.get(key, {})
            d[key] = recursive_update(nested, value)
        else:
            d[key] = value
    return d


def load_yaml(yaml_path):
    """Load a YAML config, resolving ``!include`` tags and ``__base__`` chains.

    ``!include`` paths are resolved relative to the config file's directory;
    each ``__base__`` dict is merged underneath the current config until none
    remain.
    """
    import yaml
    import yaml_include

    base_dir = os.path.dirname(os.path.abspath(yaml_path))
    yaml.add_constructor("!include", yaml_include.Constructor(base_dir=base_dir))
    with open(yaml_path, "r") as fp:
        config = yaml.full_load(fp)
    while "__base__" in config:
        base = config.pop("__base__")
        config = recursive_update(base, config)
    return config
import requests
from time import sleep


def fetch_text_with_retry(url, retries=3, timeout=10, backoff_factor=0.3):
    """
    Fetch text from the given URL with retries on failure.

    Parameters:
        url (str): The URL to fetch the text from.
        retries (int): Number of retries before giving up.
        timeout (int): Timeout in seconds for the HTTP request.
        backoff_factor (float): Factor for exponential backoff between retries.

    Returns:
        str: The fetched text, or a fixed fallback banner string if all
            retries fail. (Docstring fixed: this function never returns None.)
    """
    attempt = 0

    while attempt < retries:
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()  # treat 4xx/5xx as failures too
            response.encoding = 'utf-8'
            return response.text
        except requests.exceptions.RequestException as e:
            attempt += 1
            # Exponential backoff: backoff_factor * 2^attempt seconds.
            wait = backoff_factor * (2 ** attempt)
            # print(f"Attempt {attempt} failed: {e}. Retrying in {wait:.1f} seconds...")
            sleep(wait)

    print(f"All attempts to fetch the URL have failed when load {url}.")
    return "Hello, welcome to TeleDS! Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), Northwestern PolyTechnical University, and Institute of Artificial Intelligence (TeleAI), China Telecom."
import torch


def get_config(ckpt_path):
    """Load a HuggingFace model config from a checkpoint path or model name."""
    # Imported lazily (matching load_yaml's lazy-import style) so light-weight
    # helpers such as get_dtype work without the heavy transformers import.
    from transformers import AutoConfig
    print(f"Loading config from {ckpt_path}")
    cfg = AutoConfig.from_pretrained(ckpt_path)
    return cfg


def get_tokenizer(ckpt_path, max_seq_len):
    """Load a right-padding tokenizer capped at ``max_seq_len`` tokens."""
    from transformers import AutoTokenizer
    # Typo fixed in the log message ("Initializaing").
    print(f"Initializing tokenizer from {ckpt_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        ckpt_path,
        model_max_length=max_seq_len,
        padding_side="right",
        trust_remote_code=True
    )
    return tokenizer


def get_dtype(dtype):
    """Normalize ``dtype`` (a ``torch.dtype`` or a name like "fp16"/"bfloat16")
    to a ``torch.dtype``.

    Raises:
        NotImplementedError: for unrecognized dtype names.
    """
    if isinstance(dtype, torch.dtype):
        return dtype

    aliases = {
        "bf16": torch.bfloat16,
        "bfloat16": torch.bfloat16,
        "fp16": torch.float16,
        "float16": torch.float16,
        "fp32": torch.float32,
        "float32": torch.float32,
    }
    if dtype not in aliases:
        raise NotImplementedError(f"Unknown dtype {dtype}")
    return aliases[dtype]


def is_model_on_gpu(model) -> bool:
    """Returns if the model is fully loaded on GPUs."""
    return all("cuda" in str(param.device) for param in model.parameters())


def get_model(ckpt_path, dtype="fp16", device="cuda", dist=False):
    """Load a causal LM in eval mode.

    Args:
        ckpt_path: checkpoint path or model name.
        dtype: ``torch.dtype`` or name accepted by ``get_dtype``.
        device: "cuda" or "cpu"; controls the ``device_map``.
        dist: pin the whole model to ``device`` (one device per process),
            overriding the "auto"/"cpu" device map.
    """
    from transformers import AutoModelForCausalLM
    print(f"Initializing model from {ckpt_path}")
    dtype = get_dtype(dtype)
    model_kwargs = {"torch_dtype": dtype}

    device_map = "auto"
    if device == "cpu":
        device_map = "cpu"
    if dist:
        device_map = {"": device}

    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device_map, trust_remote_code=True, **model_kwargs)
    model.eval()

    if device == "cuda":
        if not is_model_on_gpu(model):
            print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

    return model
import torch.nn as nn


def init_model_weights(model: nn.Module, std: float = 1e-5):
    """Re-initialize every Linear and Embedding in ``model`` in place.

    Weights are drawn from N(0, std); Linear biases and the Embedding row at
    ``padding_idx`` are zeroed.
    """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


def load_embedding_state_dict(model: nn.Module):
    """Return the state dict of the first ``nn.Embedding`` inside ``model``,
    or ``None`` when the model contains no embedding layer."""
    embedding = next(
        (module for module in model.modules() if isinstance(module, nn.Embedding)),
        None,
    )
    return None if embedding is None else embedding.state_dict()
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
import re
import setuptools


def get_install_requirements():
    """Read requirements.txt and return its non-comment, non-empty lines."""
    with open("requirements.txt", "r", encoding="utf-8") as f:
        reqs = [x.strip() for x in f.read().splitlines()]
    # drop comments and blank lines (blank lines previously slipped through)
    return [x for x in reqs if x and not x.startswith("#")]


def get_sentence_vae_version():
    """Extract __version__ from sentence_vae/__init__.py without importing it."""
    with open("sentence_vae/__init__.py", "r", encoding="utf-8") as f:
        version = re.search(
            r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
            f.read(), re.MULTILINE
        ).group(1)
    return version


def get_long_description():
    """Return the README.md content used as the PyPI long description."""
    with open("README.md", "r", encoding="utf-8") as f:
        long_description = f.read()
    return long_description


setuptools.setup(
    name="sentence_vae",
    version=get_sentence_vae_version(),
    author="Coder.AN",
    url="https://github.com/BestAnHongjun/SentenceVAE",
    # BUGFIX: exclude takes an iterable of patterns; the previous bare string
    # ("tools") was iterated character-by-character and excluded nothing.
    packages=setuptools.find_packages(exclude=("tools", "tools.*")),
    python_requires=">=3.8",
    install_requires=get_install_requirements(),
    setup_requires=["wheel"],  # avoid building error when pip is not updated
    long_description=get_long_description(),
    long_description_content_type="text/markdown",
    include_package_data=True,  # include files in MANIFEST.in
    classifiers=[
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
        # BUGFIX: "MIT Software License" is not a registered trove classifier.
        "License :: OSI Approved :: MIT License",
    ],
    project_urls={
        "paper": "https://arxiv.org/abs/2408.00655",
        "Source": "https://github.com/BestAnHongjun/SentenceVAE",
    },
)
4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
import torch
import argparse
from mmengine.analysis import get_model_complexity_info
from sentence_vae.utils import get_model


def make_parser():
    """Build the CLI parser for the LLM complexity-analysis tool."""
    parser = argparse.ArgumentParser("SentenceVAE analysis parser.")
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--model_dtype", type=str, default="fp16")
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--max_seq_len", type=int, default=1024)
    parser.add_argument("--batch_size", type=int, default=1)
    return parser


def main(args):
    """Load the reference LLM and print its FLOPs/parameter table."""
    model = get_model(args.model_dir, args.model_dtype, args.device).eval()
    # dummy all-ones token ids, shaped (batch, seq_len), on the target device
    dummy_ids = torch.ones(
        (args.batch_size, args.max_seq_len),
        dtype=torch.long,
        device=torch.device(args.device),
    )
    analysis = get_model_complexity_info(model, inputs=dummy_ids)
    print(analysis['out_table'])


if __name__ == "__main__":
    main(make_parser().parse_args())
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
import torch
import argparse

from mmengine.analysis import get_model_complexity_info

from sentence_vae.utils import get_config, get_dtype, load_yaml
from sentence_vae.models import SentenceLLM


def make_parser():
    """Build the CLI parser for the SentenceLLM complexity-analysis tool."""
    parser = argparse.ArgumentParser("SentenceLLM train parser.")
    parser.add_argument("-c", "--config", type=str, required=True)
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--cards", type=int, default=4)
    return parser


def main(args):
    """Build a SentenceLLM from the YAML config and print its complexity table."""
    cfg = load_yaml(args.config)
    ref_cfg = get_config(cfg["svae"]["ref_model_dir"])

    model = SentenceLLM(
        svae_hidden_size=ref_cfg.hidden_size,
        svae_vocab_size=ref_cfg.vocab_size,
        svae_learnable_add=cfg["svae"]["learnable_add"],
        svae_load_ref_model=cfg["svae"]["load_ref_model"],
        svae_ref_model_dir=cfg["svae"]["ref_model_dir"],
        svae_ref_model_dtype=ref_cfg.torch_dtype,
        svae_finetune_embedding=cfg["svae"]["finetune_embedding"],
        svae_word_embed_proj_dim=ref_cfg.word_embed_proj_dim,
        svae_num_attention_heads=ref_cfg.num_attention_heads,
        svae_num_hidden_layers=cfg["svae"]["num_hidden_layers"],
        svae_model_path=cfg["svae"]["model_path"],
        llm_ref_model_dir=cfg["llm"]["ref_model_dir"],
        llm_ref_model_dtype=ref_cfg.torch_dtype,
        llm_finetune_layers=cfg["llm"]["finetune_layers"],
        finetune_svae=cfg["finetune_svae"],
        max_sentence_len=cfg["max_sen_len"],
        max_sentence_num=cfg["max_sen_num"],
        dropout=ref_cfg.dropout,
        bos_id=ref_cfg.bos_token_id,
        pad_id=ref_cfg.pad_token_id,
        end_id=ref_cfg.eos_token_id,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"])
    ).eval()

    # effective batch spread over all cards
    total_batch = cfg["batch_size"] * args.cards
    dev = torch.device(args.device)

    # dummy inputs: (batch, sen_num) sentence mask plus
    # (batch, sen_num, sen_len) token ids and token mask
    sen_mask = torch.ones((total_batch, cfg["max_sen_num"]), dtype=torch.long, device=dev)
    sen_toks = torch.ones((total_batch, cfg["max_sen_num"], cfg["max_sen_len"]), dtype=torch.long, device=dev)
    tok_mask = torch.ones_like(sen_toks)

    analysis = get_model_complexity_info(model, inputs=(sen_mask, sen_toks, tok_mask))
    print(analysis['out_table'])


if __name__ == "__main__":
    main(make_parser().parse_args())
import torch
import argparse

from mmengine.analysis import get_model_complexity_info

from sentence_vae.utils import get_config, get_dtype, load_yaml
from sentence_vae.models import SentenceVAE


def make_parser():
    """Build the CLI parser for the SentenceVAE complexity-analysis tool."""
    parser = argparse.ArgumentParser("SentenceLLM train parser.")
    parser.add_argument("-c", "--config", type=str, required=True)
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--cards", type=int, default=1)
    return parser


def main(args):
    """Build a SentenceVAE from the YAML config and print its complexity table."""
    cfg = load_yaml(args.config)
    ref_cfg = get_config(cfg["ref_model_dir"])

    # ref_model_dtype may be absent from the config; only convert when given
    ref_dtype = None
    if cfg["ref_model_dtype"] is not None:
        ref_dtype = get_dtype(cfg["ref_model_dtype"])

    model = SentenceVAE(
        hidden_size=ref_cfg.hidden_size,
        vocab_size=ref_cfg.vocab_size,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"]),
        learnable_add=cfg["learnable_add"],
        load_ref_model=cfg["load_ref_model"],
        ref_model_dir=cfg["ref_model_dir"],
        ref_model_dtype=ref_dtype,
        finetune_embedding=cfg["finetune_embedding"],
        num_attention_heads=ref_cfg.num_attention_heads,
        num_hidden_layers=cfg["num_hidden_layers"],
        max_seq_len=cfg["max_seq_len"],
        dropout=ref_cfg.dropout,
        bos_id=ref_cfg.bos_token_id,
        pad_id=ref_cfg.pad_token_id,
        end_id=ref_cfg.eos_token_id
    ).eval()

    # effective batch spread over all cards
    total_batch = cfg["batch_size"] * args.cards
    dev = torch.device(args.device)

    # dummy all-ones token ids and attention mask, shaped (batch, seq_len)
    sentences = torch.ones((total_batch, cfg["max_seq_len"]), dtype=torch.long, device=dev)
    sentence_mask = torch.ones_like(sentences)

    analysis = get_model_complexity_info(model, inputs=(sentences, sentence_mask))
    print(analysis['out_table'])


if __name__ == "__main__":
    main(make_parser().parse_args())
import argparse
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from mmengine.dataset import DefaultSampler

from sentence_vae.utils import get_tokenizer
from sentence_vae.data import TeleDSDataset, SentenceCollate, PassageCollate


def make_parser():
    """Build the CLI parser for the TeleDS dataset token-count statistic."""
    parser = argparse.ArgumentParser("SentenceVAE statistic parser.")
    parser.add_argument("--tokenizer_dir", type=str)
    parser.add_argument("--max_sen_len", type=int, default=64)
    parser.add_argument("--max_sen_num", type=int, default=64)
    parser.add_argument("--mode", choices=["sentence", "passage"], default='sentence')
    parser.add_argument("--card_size", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--teleds_ip", type=str, default="127.0.0.1")
    parser.add_argument("--teleds_port", type=int, default=8000)
    parser.add_argument("--num_workers", type=int, default=16)
    parser.add_argument("--prefetch_factor", type=int, default=5)
    parser.add_argument("--max_iters", type=int, default=300000)
    return parser


def main(args):
    """Stream the TeleDS dataset and report the total number of valid tokens."""
    tokenizer = get_tokenizer(ckpt_path=args.tokenizer_dir, max_seq_len=args.max_sen_len)

    dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port)
    sampler = DefaultSampler(dataset, shuffle=False)
    if args.mode == "sentence":
        collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=args.max_sen_len, padding=True)
    else:
        collate_fn = PassageCollate(
            tokenizer=tokenizer,
            max_sentence_len=args.max_sen_len,
            max_sentence_num=args.max_sen_num,
            padding=True,
        )

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size * args.card_size,
        sampler=sampler,
        collate_fn=collate_fn,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch_factor
    )

    all_tokens = 0
    total_len = min(len(dataloader), args.max_iters)

    pbar = tqdm(total=total_len)
    for i, data in enumerate(dataloader):
        if i >= total_len:
            break
        if args.mode == 'sentence':
            input_ids, attention_mask = data
            all_tokens += torch.sum(attention_mask).item()
        else:
            batch_sentence_mask, batch_sentence_toks, batch_tok_mask = data
            for b in range(batch_sentence_mask.size(0)):
                # BUGFIX: the sentence mask is an integer 0/1 tensor; indexing
                # with it performed fancy (gather) indexing — repeatedly picking
                # rows 0/1 — instead of masking. Cast to bool so only the token
                # rows of valid sentences are counted.
                # (assumes PassageCollate emits 0/1 masks — TODO confirm)
                valid_rows = batch_tok_mask[b][batch_sentence_mask[b].bool()]
                all_tokens += torch.sum(valid_rows).item()
        # BUGFIX: advance the bar once per batch in both modes (it previously
        # advanced once per passage, overshooting its total in passage mode)
        pbar.set_postfix(tokens=all_tokens)
        pbar.update(1)
    pbar.close()
    print("Total tokens:", all_tokens)


if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
import torch
import argparse

from sentence_vae.utils import get_model, get_tokenizer
from sentence_vae.data import SentenceCollate


def make_parser():
    """Build the CLI parser for the vocabulary byte-length statistic."""
    parser = argparse.ArgumentParser("SentenceVAE analysis parser.")
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--max_seq_len", type=int, default=1024)
    return parser


def main(args):
    """Report the average UTF-8 byte length of the tokenizer's vocabulary entries."""
    tokenizer = get_tokenizer(ckpt_path=args.model_dir, max_seq_len=args.max_seq_len)
    vocab = tokenizer.get_vocab()

    total_bytes = sum(len(token.encode('utf-8')) for token in vocab)
    average_bytes = total_bytes / len(vocab)
    print(f"Average bytes per token: {average_bytes:.2f}")


if __name__ == "__main__":
    main(make_parser().parse_args())
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
import time
import argparse
import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from mmengine.dataset import DefaultSampler

from sentence_vae.utils import get_model, get_tokenizer
from sentence_vae.data import TeleDSDataset, SentenceCollate


def make_parser():
    """Build the CLI parser for the baseline-LLM throughput/memory benchmark."""
    parser = argparse.ArgumentParser("SentenceVAE benchmark parser.")
    parser.add_argument("--server", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--model_dtype", type=str, default="fp16")
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--max_seq_len", type=int, default=1024)
    parser.add_argument("--max_eval_samples", type=int, default=10)
    return parser


def main(args):
    """Benchmark prefill/decode token throughput and CUDA memory growth.

    Uses the first half of each sample as the prompt and greedy-decodes
    the rest, timing the two phases separately.
    """
    model = get_model(args.model_dir, args.model_dtype, args.device).eval()
    tokenizer = get_tokenizer(ckpt_path=args.model_dir, max_seq_len=args.max_seq_len)

    eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True, eval_samples=args.max_eval_samples)
    eval_sampler = DefaultSampler(eval_dataset, shuffle=False)
    eval_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=args.max_seq_len, padding=True, fix_len=False)

    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=1,
        sampler=eval_sampler,
        collate_fn=eval_collate_fn,
        num_workers=8,
        prefetch_factor=20
    )

    input_toks, output_toks = 0, 0
    input_times, output_times = 0.0, 0.0
    mem_usage = []
    device = torch.device(args.device)

    with torch.no_grad():
        for data in tqdm(eval_dataloader):
            input_ids, attention_mask = data
            # BUGFIX: convert to a plain int so the token accumulators do not
            # silently become 0-d tensors.
            seq_len = (torch.sum(attention_mask) // 2).item()
            if seq_len < 1:
                continue
            input_ids = input_ids[:1, :seq_len].to(device)

            # prefill phase
            start_time = time.perf_counter()
            output = model(input_ids)
            input_time = time.perf_counter()

            input_toks += seq_len
            input_times += input_time - start_time

            # decode phase: greedy, one token at a time, reusing the KV cache
            while True:
                logits = output.logits
                past_key_values = output.past_key_values
                new_id = torch.argmax(logits[:1, -1:], dim=-1)
                input_ids = torch.concat((input_ids, new_id), dim=1)
                if new_id.item() == tokenizer.eos_token_id:
                    break
                if input_ids.size(1) >= args.max_seq_len:
                    break
                output = model(new_id, past_key_values=past_key_values)
                mem_usage.append([input_ids.size(1), torch.cuda.memory_allocated() / 1024])

            output_time = time.perf_counter()
            output_toks += input_ids.size(1) - seq_len
            output_times += output_time - input_time

    # BUGFIX: guard against an empty run (no usable samples) instead of
    # crashing with ZeroDivisionError / polyfit on an empty array.
    if not mem_usage or input_times == 0 or output_times == 0:
        print("No samples were benchmarked; nothing to report.")
        return

    # linear fit of memory vs. sequence length: slope k (KB/token), intercept b (KB)
    mem_usage = np.array(mem_usage)
    k, b = np.polyfit(mem_usage[:, 0], mem_usage[:, 1], 1)
    print(f"显存占用: {b / 1024:.2f}MB + {k:.2f}KB/token")
    print(f"Input: {input_toks / input_times:.2f} tokens/s")
    print(f"Output: {output_toks / output_times:.2f} tokens/s")


if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
import os
import time
import argparse
import numpy as np
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

from mmengine.dataset import DefaultSampler

from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml
from sentence_vae.models import SentenceLLM
from sentence_vae.data import TeleDSDataset, PassageCollate


def make_parser():
    """Build the CLI parser for the SentenceLLM throughput/memory benchmark."""
    parser = argparse.ArgumentParser("SentenceLLM benchmark parser.")
    parser.add_argument("--server", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("-c", "--config", type=str, required=True)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--max_eval_samples", type=int, default=10)
    return parser


def main(args):
    """Benchmark SentenceLLM streaming generation: prefill/decode token
    throughput and CUDA memory growth, using the first half of each passage
    (by sentence count) as the prompt."""
    cfg = load_yaml(args.config)
    expn = os.path.splitext(os.path.basename(args.config))[0]
    svae_ref_model_cfg = get_config(cfg["svae"]["ref_model_dir"])

    model = SentenceLLM(
        svae_hidden_size=svae_ref_model_cfg.hidden_size,
        svae_vocab_size=svae_ref_model_cfg.vocab_size,
        svae_learnable_add=cfg["svae"]["learnable_add"],
        svae_load_ref_model=cfg["svae"]["load_ref_model"],
        svae_ref_model_dir=cfg["svae"]["ref_model_dir"],
        svae_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
        svae_finetune_embedding=cfg["svae"]["finetune_embedding"],
        svae_word_embed_proj_dim=svae_ref_model_cfg.word_embed_proj_dim,
        svae_num_attention_heads=svae_ref_model_cfg.num_attention_heads,
        svae_num_hidden_layers=cfg["svae"]["num_hidden_layers"],
        svae_model_path=cfg["svae"]["model_path"],
        llm_ref_model_dir=cfg["llm"]["ref_model_dir"],
        llm_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
        llm_finetune_layers=cfg["llm"]["finetune_layers"],
        finetune_svae=cfg["finetune_svae"],
        max_sentence_len=cfg["max_sen_len"],
        max_sentence_num=cfg["max_sen_num"],
        dropout=svae_ref_model_cfg.dropout,
        bos_id=svae_ref_model_cfg.bos_token_id,
        pad_id=svae_ref_model_cfg.pad_token_id,
        end_id=svae_ref_model_cfg.eos_token_id,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"])
    )

    exp_dir = f"exp/SentenceVAE-{expn}"
    ckpt = args.checkpoint
    if ckpt is None:
        for ckpt_name in os.listdir(exp_dir):
            if "best" in ckpt_name:
                ckpt = torch.load(os.path.join(exp_dir, ckpt_name))['state_dict']
                # BUGFIX: stop at the first "best" checkpoint instead of
                # loading every matching file and keeping only the last.
                break
        assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}."
    else:
        assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found."
        ckpt = torch.load(ckpt)["state_dict"]
    model.load_state_dict(ckpt)
    model.eval()

    tokenizer = get_tokenizer(ckpt_path=cfg["svae"]["ref_model_dir"], max_seq_len=cfg["max_sen_len"])

    eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True, eval_samples=args.max_eval_samples)
    eval_sampler = DefaultSampler(eval_dataset, shuffle=False)
    eval_collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True)

    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=1,
        sampler=eval_sampler,
        collate_fn=eval_collate_fn,
        num_workers=cfg["dataloader_num_workers"],
        prefetch_factor=cfg["dataloader_prefetch_factor"]
    )

    input_toks, output_toks = 0, 0
    input_times, output_times = 0.0, 0.0
    mem_usage = []
    device = torch.device(args.device)

    with torch.no_grad():
        for data in tqdm(eval_dataloader):
            batch_sentence_mask, batch_sentence_toks, batch_tok_mask = data
            # use the first half of the passage (in sentences) as the prompt
            sentence_num = torch.sum(batch_sentence_mask).item() // 2
            if sentence_num < 1:
                continue
            seq_len = torch.sum(batch_tok_mask[0, :sentence_num]).item()
            input_toks += seq_len

            # BUGFIX: the sentence axis must be truncated by the number of
            # prompt sentences (sentence_num), not by the token count
            # (seq_len), which indexes a different dimension's scale.
            batch_sentence_mask = batch_sentence_mask[:1, :sentence_num].to(device)
            batch_sentence_toks = batch_sentence_toks[:1, :sentence_num, :].to(device)
            batch_tok_mask = batch_tok_mask[:1, :sentence_num, :].to(device)

            # prefill phase: timed until the first streamed chunk arrives
            tokens_i = seq_len
            input_time = -1
            start_time = time.perf_counter()

            for i, new_tokens in enumerate(model.streaming_generate(
                batch_sentence_mask,
                batch_sentence_toks,
                batch_tok_mask
            )):
                if i == 0:
                    input_time = time.perf_counter()
                tokens_i += len(new_tokens)
                mem_usage.append([tokens_i, torch.cuda.memory_allocated() / 1024])

            # nothing was generated for this sample
            if input_time < 0:
                continue

            output_time = time.perf_counter()
            output_toks += tokens_i - seq_len
            output_times += output_time - input_time
            input_times += input_time - start_time

    # BUGFIX: guard against an empty run (no usable samples) instead of
    # crashing with ZeroDivisionError / polyfit on an empty array.
    if not mem_usage or input_times == 0 or output_times == 0:
        print("No samples were benchmarked; nothing to report.")
        return

    # linear fit of memory vs. token count: slope k (KB/token), intercept b (KB)
    mem_usage = np.array(mem_usage)
    k, b = np.polyfit(mem_usage[:, 0], mem_usage[:, 1], 1)
    print(f"显存占用: {b / 1024:.2f}MB + {k:.2f}KB/token")
    print(f"Input: {input_toks / input_times:.2f} tokens/s")
    print(f"Output: {output_toks / output_times:.2f} tokens/s")


if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
28 | 29 | import torch 30 | import argparse 31 | 32 | from sentence_vae.utils import get_model, get_tokenizer 33 | from sentence_vae.data import SentenceCollate 34 | 35 | 36 | def make_parser(): 37 | parser = argparse.ArgumentParser("SentenceVAE benchmark parser.") 38 | parser.add_argument("--model_dir", type=str, required=True) 39 | parser.add_argument("--model_dtype", type=str, default="fp16") 40 | parser.add_argument("--device", type=str, default='cuda') 41 | parser.add_argument("--max_seq_len", type=int, default=1024) 42 | parser.add_argument("--input", type=str, default="Hello,") 43 | return parser 44 | 45 | 46 | def main(args): 47 | model = get_model(args.model_dir, args.model_dtype, args.device).eval() 48 | tokenizer = get_tokenizer(ckpt_path=args.model_dir, max_seq_len=args.max_seq_len) 49 | 50 | collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=args.max_seq_len, padding=True, fix_len=False) 51 | input_ids = tokenizer.batch_encode_plus([args.input], return_tensors="pt")['input_ids'] 52 | 53 | device = torch.device(args.device) 54 | input_ids = input_ids.to(device) 55 | output = model(input_ids) 56 | 57 | print("Input:", args.input) 58 | print("Output:") 59 | while True: 60 | logits = output.logits 61 | past_key_values = output.past_key_values 62 | new_id = torch.argmax(logits[:1, -1:], dim=-1) 63 | input_ids = torch.concat((input_ids, new_id), dim=1) 64 | if new_id.item() == tokenizer.eos_token_id: 65 | break 66 | if input_ids.size(1) >= args.max_seq_len: 67 | break 68 | output_word = tokenizer.decode(new_id.item(), skip_special_tokens=True) 69 | print(output_word, end="") 70 | output = model(new_id, past_key_values=past_key_values) 71 | print() 72 | 73 | 74 | if __name__ == "__main__": 75 | args = make_parser().parse_args() 76 | main(args) 77 | -------------------------------------------------------------------------------- /tools/demo/demo_sllm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
28 | 29 | import os 30 | import torch 31 | import argparse 32 | 33 | from sentence_vae.models import SentenceLLM 34 | from sentence_vae.data import PassageCollate 35 | from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml 36 | 37 | 38 | def make_parser(): 39 | parser = argparse.ArgumentParser("SentenceVAE demo parser.") 40 | parser.add_argument("-c", "--config", type=str, required=True) 41 | parser.add_argument("--checkpoint", type=str, default=None) 42 | parser.add_argument("--input", type=str, default="Hello,") 43 | return parser 44 | 45 | 46 | def main(args): 47 | cfg = load_yaml(args.config) 48 | expn = os.path.splitext(os.path.basename(args.config))[0] 49 | svae_ref_model_cfg = get_config(cfg["svae"]["ref_model_dir"]) 50 | 51 | model = SentenceLLM( 52 | svae_hidden_size=svae_ref_model_cfg.hidden_size, 53 | svae_vocab_size=svae_ref_model_cfg.vocab_size, 54 | svae_learnable_add=cfg["svae"]["learnable_add"], 55 | svae_load_ref_model=cfg["svae"]["load_ref_model"], 56 | svae_ref_model_dir=cfg["svae"]["ref_model_dir"], 57 | svae_ref_model_dtype=svae_ref_model_cfg.torch_dtype, 58 | svae_finetune_embedding=cfg["svae"]["finetune_embedding"], 59 | svae_word_embed_proj_dim=svae_ref_model_cfg.word_embed_proj_dim, 60 | svae_num_attention_heads=svae_ref_model_cfg.num_attention_heads, 61 | svae_num_hidden_layers=cfg["svae"]["num_hidden_layers"], 62 | svae_model_path=cfg["svae"]["model_path"], 63 | llm_ref_model_dir=cfg["llm"]["ref_model_dir"], 64 | llm_ref_model_dtype=svae_ref_model_cfg.torch_dtype, 65 | llm_finetune_layers=cfg["llm"]["finetune_layers"], 66 | finetune_svae=cfg["finetune_svae"], 67 | max_sentence_len=cfg["max_sen_len"], 68 | max_sentence_num=cfg["max_sen_num"], 69 | dropout=svae_ref_model_cfg.dropout, 70 | bos_id=svae_ref_model_cfg.bos_token_id, 71 | pad_id=svae_ref_model_cfg.pad_token_id, 72 | end_id=svae_ref_model_cfg.eos_token_id, 73 | device=torch.device(cfg["device"]), 74 | dtype=get_dtype(cfg["dtype"]) 75 | ) 76 | 77 | exp_dir = 
f"exp/SentenceVAE-{expn}" 78 | ckpt_list = os.listdir(exp_dir) 79 | ckpt = args.checkpoint 80 | if ckpt is None: 81 | for ckpt_path in ckpt_list: 82 | if "best" in ckpt_path: 83 | ckpt_path = os.path.join(exp_dir, ckpt_path) 84 | ckpt = torch.load(ckpt_path)['state_dict'] 85 | assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}." 86 | else: 87 | assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found." 88 | ckpt = torch.load(ckpt)["state_dict"] 89 | model.load_state_dict(ckpt) 90 | model.eval() 91 | 92 | tokenizer = get_tokenizer(ckpt_path=cfg["svae"]["ref_model_dir"], max_seq_len=cfg["max_sen_len"]) 93 | collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True) 94 | batch_sentence_mask, batch_sentence_toks, batch_tok_mask = collate_fn([args.input]) 95 | 96 | print(f"Input: {args.input}") 97 | print("Output:") 98 | 99 | for output_ids in model.streaming_generate( 100 | batch_sentence_mask, 101 | batch_sentence_toks, 102 | batch_tok_mask 103 | ): 104 | new_sentence = tokenizer.decode(output_ids, skip_special_tokens=True) 105 | print(new_sentence) 106 | 107 | 108 | if __name__ == "__main__": 109 | args = make_parser().parse_args() 110 | main(args) 111 | -------------------------------------------------------------------------------- /tools/demo/demo_svae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

import os
import torch
import argparse

from sentence_vae.models import SentenceVAE
from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml


def make_parser():
    """Build the CLI parser for the SentenceVAE reconstruction demo."""
    parser = argparse.ArgumentParser("SentenceVAE demo parser.")
    parser.add_argument("-c", "--config", type=str, required=True)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--input", type=str, default=None)
    return parser


def main(args):
    """Round-trip one or more sentences through a trained SentenceVAE.

    With no ``--input``, a built-in list of sample sentences is used. Each
    text is tokenized and streamed through the model; reconstructed tokens
    are printed as they are generated.
    """
    cfg = load_yaml(args.config)
    expn = os.path.splitext(os.path.basename(args.config))[0]
    ref_model_cfg = get_config(cfg["ref_model_dir"])

    model = SentenceVAE(
        hidden_size=ref_model_cfg.hidden_size,
        vocab_size=ref_model_cfg.vocab_size,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"]),
        learnable_add=cfg["learnable_add"],
        load_ref_model=cfg["load_ref_model"],
        ref_model_dir=cfg["ref_model_dir"],
        ref_model_dtype=get_dtype(cfg["ref_model_dtype"]) if cfg["ref_model_dtype"] is not None else None,
        finetune_embedding=cfg["finetune_embedding"],
        num_attention_heads=ref_model_cfg.num_attention_heads,
        num_hidden_layers=cfg["num_hidden_layers"],
        max_seq_len=cfg["max_seq_len"],
        dropout=ref_model_cfg.dropout,
        bos_id=ref_model_cfg.bos_token_id,
        pad_id=ref_model_cfg.pad_token_id,
        end_id=ref_model_cfg.eos_token_id
    )
    exp_dir = f"exp/{expn}"
    ckpt_list = os.listdir(exp_dir)
    ckpt = args.checkpoint
    if ckpt is None:
        # No explicit checkpoint: load the trainer-saved "best" checkpoint.
        for ckpt_path in ckpt_list:
            if "best" in ckpt_path:
                ckpt_path = os.path.join(exp_dir, ckpt_path)
                ckpt = torch.load(ckpt_path)['state_dict']
                # Stop at the first match instead of deserializing every
                # "best*" file and keeping only the last.
                break
        assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}."
    else:
        assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found."
        ckpt = torch.load(ckpt)["state_dict"]
    model.load_state_dict(ckpt)
    model.eval()

    tokenizer = get_tokenizer(ckpt_path=cfg["ref_model_dir"], max_seq_len=cfg["max_seq_len"])

    if args.input is None:
        input_texts = [
            'I love China.',
            'We come from Northwestern Polytechnical University.',
            "Hello,",
            "Welcome to TeleAI!",
            "What's your name?",
            "What's your problem?",
            "Hello, my dear friend",
            "Today is Friday.",
            "There is Institute of Artificial Intelligence (TeleAI), China Telecom.",
            "One two three four five six seven eight nine ten~",
            "Hahaha... and you?",
            "Yao yao ling xian!"
        ]
    else:
        input_texts = [args.input]

    # (Removed a dead batch_encode_plus pass whose input_ids/attention_mask
    # were never used — each text is re-encoded individually below — and a
    # commented-out batch 'predict' code path.)
    for idx, input_text in enumerate(input_texts):
        print("--------------------")
        print(f"[{idx}]\tTest Input: ", end="")
        print(input_text)
        input_ids = tokenizer.encode(input_text, return_tensors='pt')
        print(f"\tTokens:{input_ids.size(1)}")
        print("\tVAE Output: ", end="")
        for output_id in model.streaming_generate(input_ids):
            output_word = tokenizer.decode(output_id, skip_special_tokens=True)
            print(output_word, end='')
        print()


if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
-------------------------------------------------------------------------------- /tools/eval/eval_llm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
28 | 29 | import argparse 30 | import numpy as np 31 | from tqdm import tqdm 32 | 33 | import torch 34 | import torch.nn.functional as F 35 | from torch.utils.data import DataLoader 36 | 37 | from mmengine.dataset import DefaultSampler 38 | 39 | from sentence_vae.utils import get_model, get_tokenizer 40 | from sentence_vae.data import TeleDSDataset, SentenceCollate 41 | 42 | 43 | def make_parser(): 44 | parser = argparse.ArgumentParser("SentenceVAE eval parser.") 45 | parser.add_argument("--server", type=str, default="127.0.0.1") 46 | parser.add_argument("--port", type=int, default=8001) 47 | parser.add_argument("--model_dir", type=str, required=True) 48 | parser.add_argument("--model_dtype", type=str, default="fp16") 49 | parser.add_argument("--device", type=str, default='cuda') 50 | parser.add_argument("--max_seq_len", type=int, default=2048) 51 | parser.add_argument("--batch_size", type=int, default=8) 52 | parser.add_argument("--max_eval_samples", type=int, default=1000) 53 | return parser 54 | 55 | 56 | def main(args): 57 | model = get_model(args.model_dir, args.model_dtype, args.device).eval() 58 | tokenizer = get_tokenizer(ckpt_path=args.model_dir, max_seq_len=args.max_seq_len) 59 | 60 | eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True, eval_samples=args.max_eval_samples) 61 | eval_sampler = DefaultSampler(eval_dataset, shuffle=False) 62 | eval_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=args.max_seq_len, padding=True) 63 | 64 | eval_dataloader = DataLoader( 65 | dataset=eval_dataset, 66 | batch_size=args.batch_size, 67 | sampler=eval_sampler, 68 | collate_fn=eval_collate_fn, 69 | num_workers=8, 70 | prefetch_factor=20 71 | ) 72 | 73 | loss_list = [] 74 | device = torch.device(args.device) 75 | for data in tqdm(eval_dataloader): 76 | input_ids, attention_mask = data 77 | input_ids = input_ids.to(device) 78 | attention_mask = attention_mask.to(device) 79 | with torch.no_grad(): 80 | output = 
model(input_ids, attention_mask).logits 81 | batch, _ = input_ids.shape 82 | pad_ids = torch.zeros((batch, 1), device=device, dtype=input_ids.dtype).fill_(tokenizer.pad_token_id) 83 | tgt_ids = torch.cat((input_ids, pad_ids), dim=1) 84 | seq_lens = torch.sum(attention_mask, dim=1, keepdim=True) 85 | tgt_ids.scatter_(1, seq_lens, tokenizer.eos_token_id) 86 | attention_mask = attention_mask.bool() 87 | loss = F.cross_entropy(output[attention_mask], tgt_ids[:, 1:][attention_mask]).item() 88 | loss_list.append(loss) 89 | mean_loss = np.mean(np.array(loss_list)) 90 | ppl = np.exp(mean_loss) 91 | 92 | print("Exp:", args.model_dir) 93 | print("Best PPL:", ppl) 94 | 95 | 96 | if __name__ == "__main__": 97 | args = make_parser().parse_args() 98 | main(args) 99 | -------------------------------------------------------------------------------- /tools/eval/eval_sllm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 
20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 28 | 29 | import os 30 | import argparse 31 | from tqdm import tqdm 32 | 33 | import torch 34 | from torch.utils.data import DataLoader 35 | 36 | from mmengine.dataset import DefaultSampler 37 | 38 | from sentence_vae.models import SentenceLLM 39 | from sentence_vae.utils import get_tokenizer, get_config, get_dtype, load_yaml 40 | from sentence_vae.data import TeleDSDataset, PassageCollate, SLLM_PPL 41 | 42 | 43 | def make_parser(): 44 | parser = argparse.ArgumentParser("SentenceVAE eval parser.") 45 | parser.add_argument("--server", type=str, default="127.0.0.1") 46 | parser.add_argument("--port", type=int, default=8001) 47 | parser.add_argument("-c", "--config", type=str, required=True) 48 | parser.add_argument("--checkpoint", type=str, default=None) 49 | parser.add_argument("--device", type=str, default='cuda') 50 | parser.add_argument("--max_eval_samples", type=int, default=1000) 51 | return parser 52 | 53 | 54 | def main(args): 55 | cfg = load_yaml(args.config) 56 | expn = os.path.splitext(os.path.basename(args.config))[0] 57 | svae_ref_model_cfg = get_config(cfg["svae"]["ref_model_dir"]) 58 | 59 | model = SentenceLLM( 60 | svae_hidden_size=svae_ref_model_cfg.hidden_size, 61 | svae_vocab_size=svae_ref_model_cfg.vocab_size, 62 | svae_learnable_add=cfg["svae"]["learnable_add"], 63 | svae_load_ref_model=cfg["svae"]["load_ref_model"], 64 | svae_ref_model_dir=cfg["svae"]["ref_model_dir"], 65 | svae_ref_model_dtype=svae_ref_model_cfg.torch_dtype, 66 | 
svae_finetune_embedding=cfg["svae"]["finetune_embedding"], 67 | svae_word_embed_proj_dim=svae_ref_model_cfg.word_embed_proj_dim, 68 | svae_num_attention_heads=svae_ref_model_cfg.num_attention_heads, 69 | svae_num_hidden_layers=cfg["svae"]["num_hidden_layers"], 70 | svae_model_path=cfg["svae"]["model_path"], 71 | llm_ref_model_dir=cfg["llm"]["ref_model_dir"], 72 | llm_ref_model_dtype=svae_ref_model_cfg.torch_dtype, 73 | llm_finetune_layers=cfg["llm"]["finetune_layers"], 74 | finetune_svae=cfg["finetune_svae"], 75 | max_sentence_len=cfg["max_sen_len"], 76 | max_sentence_num=cfg["max_sen_num"], 77 | dropout=svae_ref_model_cfg.dropout, 78 | bos_id=svae_ref_model_cfg.bos_token_id, 79 | pad_id=svae_ref_model_cfg.pad_token_id, 80 | end_id=svae_ref_model_cfg.eos_token_id, 81 | device=torch.device(cfg["device"]), 82 | dtype=get_dtype(cfg["dtype"]) 83 | ) 84 | 85 | exp_dir = f"exp/SentenceVAE-{expn}" 86 | ckpt_list = os.listdir(exp_dir) 87 | ckpt = args.checkpoint 88 | if ckpt is None: 89 | for ckpt_path in ckpt_list: 90 | if "best" in ckpt_path: 91 | ckpt_path = os.path.join(exp_dir, ckpt_path) 92 | ckpt = torch.load(ckpt_path)['state_dict'] 93 | assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}." 94 | else: 95 | assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found." 
96 | ckpt = torch.load(ckpt)["state_dict"] 97 | model.load_state_dict(ckpt) 98 | model.eval() 99 | 100 | tokenizer = get_tokenizer(ckpt_path=cfg["svae"]["ref_model_dir"], max_seq_len=cfg["max_sen_len"]) 101 | 102 | eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True) 103 | eval_sampler = DefaultSampler(eval_dataset, shuffle=False) 104 | eval_collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True) 105 | 106 | 107 | eval_dataloader = DataLoader( 108 | dataset=eval_dataset, 109 | batch_size=cfg["batch_size"], 110 | sampler=eval_sampler, 111 | collate_fn=eval_collate_fn, 112 | num_workers=cfg["dataloader_num_workers"], 113 | prefetch_factor=cfg["dataloader_prefetch_factor"] 114 | ) 115 | 116 | metric = SLLM_PPL() 117 | device = torch.device(args.device) 118 | for data in tqdm(eval_dataloader): 119 | batch_sentence_mask, batch_sentence_toks, batch_tok_mask = data 120 | batch_sentence_mask = batch_sentence_mask.to(device) 121 | batch_sentence_toks = batch_sentence_toks.to(device) 122 | batch_tok_mask = batch_tok_mask.to(device) 123 | output = model(batch_sentence_mask, batch_sentence_toks, batch_tok_mask, mode='predict') 124 | metric.process(batch_sentence_mask, output) 125 | results = metric.compute_metrics(metric.results) 126 | 127 | print("Exp:", expn) 128 | print("Best PPL:", results["eval_ppl"]) 129 | 130 | 131 | if __name__ == "__main__": 132 | args = make_parser().parse_args() 133 | main(args) 134 | -------------------------------------------------------------------------------- /tools/eval/eval_svae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 
4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
28 | 29 | import os 30 | import argparse 31 | from tqdm import tqdm 32 | 33 | import torch 34 | from torch.utils.data import DataLoader 35 | 36 | from mmengine.dataset import DefaultSampler 37 | 38 | from sentence_vae.models import SentenceVAE 39 | from sentence_vae.utils import get_tokenizer, get_config, get_dtype, load_yaml 40 | from sentence_vae.data import TeleDSDataset, SentenceCollate, SVAE_PPL 41 | 42 | 43 | def make_parser(): 44 | parser = argparse.ArgumentParser("SentenceVAE eval parser.") 45 | parser.add_argument("--server", type=str, default="127.0.0.1") 46 | parser.add_argument("--port", type=int, default=8000) 47 | parser.add_argument("-c", "--config", type=str, required=True) 48 | parser.add_argument("--checkpoint", type=str, default=None) 49 | parser.add_argument("--device", type=str, default='cuda') 50 | parser.add_argument("--max_eval_samples", type=int, default=1000) 51 | return parser 52 | 53 | 54 | def main(args): 55 | cfg = load_yaml(args.config) 56 | expn = os.path.splitext(os.path.basename(args.config))[0] 57 | ref_model_cfg = get_config(cfg["ref_model_dir"]) 58 | 59 | model = SentenceVAE( 60 | hidden_size=ref_model_cfg.hidden_size, 61 | vocab_size=ref_model_cfg.vocab_size, 62 | device=torch.device(cfg["device"]), 63 | dtype=get_dtype(cfg["dtype"]), 64 | learnable_add=cfg["learnable_add"], 65 | load_ref_model=cfg["load_ref_model"], 66 | ref_model_dir=cfg["ref_model_dir"], 67 | ref_model_dtype=get_dtype(cfg["ref_model_dtype"]) if cfg["ref_model_dtype"] is not None else None, 68 | finetune_embedding=cfg["finetune_embedding"], 69 | num_attention_heads=ref_model_cfg.num_attention_heads, 70 | num_hidden_layers=cfg["num_hidden_layers"], 71 | max_seq_len=cfg["max_seq_len"], 72 | dropout=ref_model_cfg.dropout, 73 | bos_id=ref_model_cfg.bos_token_id, 74 | pad_id=ref_model_cfg.pad_token_id, 75 | end_id=ref_model_cfg.eos_token_id 76 | ) 77 | exp_dir = f"exp/{expn}" 78 | ckpt_list = os.listdir(exp_dir) 79 | ckpt = args.checkpoint 80 | if ckpt is None: 81 
| for ckpt_path in ckpt_list: 82 | if "best" in ckpt_path: 83 | ckpt_path = os.path.join(exp_dir, ckpt_path) 84 | ckpt = torch.load(ckpt_path)['state_dict'] 85 | assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}." 86 | else: 87 | assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found." 88 | ckpt = torch.load(ckpt)["state_dict"] 89 | model.load_state_dict(ckpt) 90 | model.eval() 91 | 92 | tokenizer = get_tokenizer(ckpt_path=cfg["ref_model_dir"], max_seq_len=cfg["max_seq_len"]) 93 | 94 | eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True) 95 | eval_sampler = DefaultSampler(eval_dataset, shuffle=False) 96 | eval_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=cfg["max_seq_len"], padding=True) 97 | 98 | eval_dataloader = DataLoader( 99 | dataset=eval_dataset, 100 | batch_size=cfg["batch_size"], 101 | sampler=eval_sampler, 102 | collate_fn=eval_collate_fn, 103 | num_workers=cfg["dataloader_num_workers"], 104 | prefetch_factor=cfg["dataloader_prefetch_factor"] 105 | ) 106 | 107 | metric = SVAE_PPL() 108 | device = torch.device(args.device) 109 | for data in tqdm(eval_dataloader): 110 | input_ids, attention_mask = data 111 | input_ids = input_ids.to(device) 112 | attention_mask = attention_mask.to(device) 113 | output = model(input_ids, attention_mask, mode='predict') 114 | metric.process(input_ids, output) 115 | results = metric.compute_metrics(metric.results) 116 | 117 | print("Exp:", expn) 118 | print("Best PPL:", results["eval_ppl"]) 119 | 120 | 121 | if __name__ == "__main__": 122 | args = make_parser().parse_args() 123 | main(args) 124 | -------------------------------------------------------------------------------- /tools/train/train_sllm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial 
Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
28 | 29 | import os 30 | import argparse 31 | 32 | import torch 33 | from torch.utils.data import DataLoader 34 | 35 | from mmengine.runner import Runner 36 | from mmengine.dataset import DefaultSampler 37 | from mmengine.dist.utils import init_dist 38 | 39 | from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml 40 | from sentence_vae.models import SentenceLLM 41 | from sentence_vae.data import TeleDSDataset, PassageCollate, SLLM_PPL 42 | 43 | torch.multiprocessing.set_sharing_strategy('file_system') 44 | 45 | 46 | def make_parser(): 47 | parser = argparse.ArgumentParser("SentenceLLM train parser.") 48 | parser.add_argument("-c", "--config", type=str, required=True) 49 | parser.add_argument("--local-rank", "--local_rank", type=int, default=0) 50 | parser.add_argument("--launcher", choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') 51 | parser.add_argument("--teleds_ip", type=str, default="127.0.0.1") 52 | parser.add_argument("--teleds_port", type=int, default=8001) 53 | return parser 54 | 55 | 56 | def main(args): 57 | cfg = load_yaml(args.config) 58 | expn = os.path.splitext(os.path.basename(args.config))[0] 59 | svae_ref_model_cfg = get_config(cfg["svae"]["ref_model_dir"]) 60 | 61 | if args.launcher != 'none': 62 | init_dist(args.launcher) 63 | 64 | model = SentenceLLM( 65 | svae_hidden_size=svae_ref_model_cfg.hidden_size, 66 | svae_vocab_size=svae_ref_model_cfg.vocab_size, 67 | svae_learnable_add=cfg["svae"]["learnable_add"], 68 | svae_load_ref_model=cfg["svae"]["load_ref_model"], 69 | svae_ref_model_dir=cfg["svae"]["ref_model_dir"], 70 | svae_ref_model_dtype=svae_ref_model_cfg.torch_dtype, 71 | svae_finetune_embedding=cfg["svae"]["finetune_embedding"], 72 | svae_word_embed_proj_dim=svae_ref_model_cfg.word_embed_proj_dim, 73 | svae_num_attention_heads=svae_ref_model_cfg.num_attention_heads, 74 | svae_num_hidden_layers=cfg["svae"]["num_hidden_layers"], 75 | svae_model_path=cfg["svae"]["model_path"], 76 | 
llm_ref_model_dir=cfg["llm"]["ref_model_dir"], 77 | llm_ref_model_dtype=svae_ref_model_cfg.torch_dtype, 78 | llm_finetune_layers=cfg["llm"]["finetune_layers"], 79 | finetune_svae=cfg["finetune_svae"], 80 | max_sentence_len=cfg["max_sen_len"], 81 | max_sentence_num=cfg["max_sen_num"], 82 | dropout=svae_ref_model_cfg.dropout, 83 | bos_id=svae_ref_model_cfg.bos_token_id, 84 | pad_id=svae_ref_model_cfg.pad_token_id, 85 | end_id=svae_ref_model_cfg.eos_token_id, 86 | device=torch.device(cfg["device"]), 87 | dtype=get_dtype(cfg["dtype"]) 88 | ) 89 | 90 | tokenizer = get_tokenizer(ckpt_path=cfg["svae"]["ref_model_dir"], max_seq_len=cfg["max_sen_len"]) 91 | 92 | train_dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port) 93 | train_sampler = DefaultSampler(train_dataset, shuffle=False) 94 | train_collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True) 95 | 96 | eval_dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port, eval_mode=True) 97 | eval_sampler = DefaultSampler(eval_dataset, shuffle=False) 98 | eval_collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True) 99 | 100 | train_dataloader = DataLoader( 101 | dataset=train_dataset, 102 | batch_size=cfg["batch_size"], 103 | sampler=train_sampler, 104 | collate_fn=train_collate_fn, 105 | num_workers=cfg["dataloader_num_workers"], 106 | prefetch_factor=cfg["dataloader_prefetch_factor"] 107 | ) 108 | eval_dataloader = DataLoader( 109 | dataset=eval_dataset, 110 | batch_size=cfg["batch_size"], 111 | sampler=eval_sampler, 112 | collate_fn=eval_collate_fn, 113 | num_workers=cfg["dataloader_num_workers"], 114 | prefetch_factor=cfg["dataloader_prefetch_factor"] 115 | ) 116 | 117 | learning_rate = cfg["batch_size"] * cfg["base_lr"] 118 | 119 | default_hooks=dict(checkpoint=dict( 120 | type='CheckpointHook', 121 | 
by_epoch=False, 122 | interval=cfg["save_checkpoint_iters"], 123 | max_keep_ckpts=cfg["max_keep_ckpts"], 124 | save_best='eval_ppl', rule='less', published_keys=['meta', 'state_dict'] 125 | )) 126 | runner = Runner( 127 | model=model, 128 | work_dir=f"exp/SentenceVAE-{expn}", 129 | train_dataloader=train_dataloader, 130 | val_dataloader=eval_dataloader, 131 | val_cfg=dict(), 132 | val_evaluator=dict(type=SLLM_PPL), 133 | train_cfg=dict(by_epoch=False, max_iters=cfg["max_iters"], val_interval=cfg["val_iters"]), 134 | optim_wrapper=dict(type="AmpOptimWrapper", optimizer=dict(type='AdamW', lr=learning_rate, weight_decay=0.01), clip_grad=dict(max_norm=1)), 135 | param_scheduler=[ 136 | dict(type='LinearLR', start_factor=1e-3, by_epoch=False, begin=0, end=cfg["warmup_iters"]), 137 | dict(type='CosineAnnealingLR', by_epoch=False, T_max=cfg["cosineannealinglr_tmax"]) 138 | ], 139 | visualizer=dict(type='Visualizer', vis_backends=[dict(type='TensorboardVisBackend')]), 140 | default_hooks=default_hooks, 141 | custom_hooks=[dict(type='EMAHook')], 142 | resume=cfg["resume_train"] 143 | ) 144 | runner.train() 145 | 146 | 147 | if __name__ == "__main__": 148 | args = make_parser().parse_args() 149 | main(args) 150 | -------------------------------------------------------------------------------- /tools/train/train_svae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), 2 | # Northwestern PolyTechnical University, 3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom. 4 | # 5 | # Author: Coder.AN (an.hongjun@foxmail.com) 6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn) 7 | # 8 | # 9 | # This software is licensed under the MIT License. 
10 | # 11 | # Permission is hereby granted, free of charge, to any person obtaining a copy 12 | # of this software and associated documentation files (the "Software"), to deal 13 | # in the Software without restriction, including without limitation the rights 14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | # copies of the Software, and to permit persons to whom the Software is 16 | # furnished to do so, subject to the following conditions: 17 | # 18 | # The above copyright notice and this permission notice shall be included in 19 | # all copies or substantial portions of the Software. 20 | # 21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | # THE SOFTWARE. 
28 | 29 | import os 30 | import argparse 31 | 32 | import torch 33 | from torch.utils.data import DataLoader 34 | 35 | from mmengine.runner import Runner 36 | from mmengine.dataset import DefaultSampler 37 | from mmengine.dist.utils import init_dist 38 | 39 | from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml 40 | from sentence_vae.models import SentenceVAE 41 | from sentence_vae.data import TeleDSDataset, SentenceCollate, SVAE_PPL 42 | 43 | 44 | def make_parser(): 45 | parser = argparse.ArgumentParser("SentenceVAE train parser.") 46 | parser.add_argument("-c", "--config", type=str, required=True) 47 | parser.add_argument("--local-rank", "--local_rank", type=int, default=0) 48 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') 49 | parser.add_argument("--teleds_ip", type=str, default="127.0.0.1") 50 | parser.add_argument("--teleds_port", type=int, default=8000) 51 | return parser 52 | 53 | 54 | def main(args): 55 | cfg = load_yaml(args.config) 56 | expn = os.path.splitext(os.path.basename(args.config))[0] 57 | ref_model_cfg = get_config(cfg["ref_model_dir"]) 58 | 59 | if args.launcher != 'none': 60 | init_dist(args.launcher) 61 | 62 | model = SentenceVAE( 63 | hidden_size=ref_model_cfg.hidden_size, 64 | vocab_size=ref_model_cfg.vocab_size, 65 | device=torch.device(cfg["device"]), 66 | dtype=get_dtype(cfg["dtype"]), 67 | learnable_add=cfg["learnable_add"], 68 | load_ref_model=cfg["load_ref_model"], 69 | ref_model_dir=cfg["ref_model_dir"], 70 | ref_model_dtype=get_dtype(cfg["ref_model_dtype"]) if cfg["ref_model_dtype"] is not None else None, 71 | finetune_embedding=cfg["finetune_embedding"], 72 | num_attention_heads=ref_model_cfg.num_attention_heads, 73 | num_hidden_layers=cfg["num_hidden_layers"], 74 | max_seq_len=cfg["max_seq_len"], 75 | dropout=ref_model_cfg.dropout, 76 | bos_id=ref_model_cfg.bos_token_id, 77 | pad_id=ref_model_cfg.pad_token_id, 78 | 
end_id=ref_model_cfg.eos_token_id 79 | ) 80 | 81 | tokenizer = get_tokenizer(ckpt_path=cfg["ref_model_dir"], max_seq_len=cfg["max_seq_len"]) 82 | 83 | train_dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port) 84 | train_sampler = DefaultSampler(train_dataset, shuffle=False) 85 | train_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=cfg["max_seq_len"], padding=True) 86 | 87 | eval_dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port, eval_mode=True) 88 | eval_sampleer = DefaultSampler(eval_dataset, shuffle=False) 89 | eval_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=cfg["max_seq_len"], padding=True) 90 | 91 | train_dataloader = DataLoader( 92 | dataset=train_dataset, 93 | batch_size=cfg["batch_size"], 94 | sampler=train_sampler, 95 | collate_fn=train_collate_fn, 96 | num_workers=cfg["dataloader_num_workers"], 97 | prefetch_factor=cfg["dataloader_prefetch_factor"] 98 | ) 99 | eval_dataloader = DataLoader( 100 | dataset=eval_dataset, 101 | batch_size=cfg["batch_size"], 102 | sampler=eval_sampleer, 103 | collate_fn=eval_collate_fn, 104 | num_workers=cfg["dataloader_num_workers"], 105 | prefetch_factor=cfg["dataloader_prefetch_factor"] 106 | ) 107 | 108 | learning_rate = cfg["batch_size"] * cfg["base_lr"] 109 | 110 | default_hooks=dict(checkpoint=dict( 111 | type='CheckpointHook', 112 | by_epoch=False, 113 | interval=cfg["save_checkpoint_iters"], 114 | max_keep_ckpts=cfg["max_keep_ckpts"], 115 | save_best='eval_ppl', rule='less', published_keys=['meta', 'state_dict'] 116 | )) 117 | runner = Runner( 118 | model=model, 119 | work_dir=f"exp/{expn}", 120 | train_dataloader=train_dataloader, 121 | val_dataloader=eval_dataloader, 122 | val_cfg=dict(), 123 | val_evaluator=dict(type=SVAE_PPL), 124 | train_cfg=dict(by_epoch=False, max_iters=cfg["max_iters"], val_interval=cfg["val_iters"]), 125 | optim_wrapper=dict(type="AmpOptimWrapper", optimizer=dict(type='AdamW', lr=learning_rate, 
weight_decay=0.01), clip_grad=dict(max_norm=1)), 126 | param_scheduler=[ 127 | dict(type='LinearLR', start_factor=1e-3, by_epoch=False, begin=0, end=cfg["warmup_iters"]), 128 | dict(type='CosineAnnealingLR', by_epoch=False, T_max=cfg["cosineannealinglr_tmax"]) 129 | ], 130 | visualizer=dict(type='Visualizer', vis_backends=[dict(type='TensorboardVisBackend')]), 131 | default_hooks=default_hooks, 132 | custom_hooks=[dict(type='EMAHook')], 133 | resume=cfg["resume_train"] 134 | ) 135 | runner.train() 136 | 137 | 138 | if __name__ == "__main__": 139 | args = make_parser().parse_args() 140 | main(args) 141 | --------------------------------------------------------------------------------