├── .gitignore
├── LICENSE
├── README.md
├── README_zh-CN.md
├── assets
├── logo.png
├── scaling_law.svg
├── sentence_vae.svg
└── sllm.svg
├── config
├── SLLM
│ ├── SLLM-1.3b
│ │ ├── sllm_1.3b_base.yaml
│ │ ├── sllm_1.3b_h1_all.yaml
│ │ ├── sllm_1.3b_h2_all.yaml
│ │ └── sllm_1.3b_h4_all.yaml
│ ├── SLLM-125m
│ │ ├── sllm_125m_base.yaml
│ │ ├── sllm_125m_h1_all.yaml
│ │ ├── sllm_125m_h2_all.yaml
│ │ └── sllm_125m_h4_all.yaml
│ ├── SLLM-350m
│ │ ├── sllm_350m_base.yaml
│ │ ├── sllm_350m_h1_all.yaml
│ │ ├── sllm_350m_h2_all.yaml
│ │ └── sllm_350m_h4_all.yaml
│ └── sllm_base.yaml
└── SVAE
│ ├── SVAE-1024
│ ├── svae_1024_base.yaml
│ ├── svae_1024_h1.yaml
│ ├── svae_1024_h2.yaml
│ └── svae_1024_h4.yaml
│ ├── SVAE-2048
│ ├── svae_2048_base.yaml
│ ├── svae_2048_h1.yaml
│ ├── svae_2048_h2.yaml
│ └── svae_2048_h4.yaml
│ ├── SVAE-768
│ ├── svae_768_base.yaml
│ ├── svae_768_h1.yaml
│ ├── svae_768_h2.yaml
│ └── svae_768_h4.yaml
│ └── svae_base.yaml
├── requirements.txt
├── sentence_vae
├── __init__.py
├── data
│ ├── __init__.py
│ ├── data_collate.py
│ ├── data_eval.py
│ └── tele_ds_dataset.py
├── models
│ ├── __init__.py
│ ├── focal_loss.py
│ ├── positional_encoding.py
│ ├── sentence_decoder.py
│ ├── sentence_encoder.py
│ ├── sentence_llm_model.py
│ └── sentence_vae_model.py
└── utils
│ ├── __init__.py
│ ├── config.py
│ ├── http.py
│ ├── llm.py
│ └── weights.py
├── setup.py
└── tools
├── analysis
├── analysis_llm.py
├── analysis_sllm.py
├── analysis_svae.py
├── statistic_dataset.py
└── statistic_vocab.py
├── benchmark
├── benchmark_llm.py
└── benchmark_sllm.py
├── demo
├── demo_llm.py
├── demo_sllm.py
└── demo_svae.py
├── eval
├── eval_llm.py
├── eval_sllm.py
└── eval_svae.py
└── train
├── train_sllm.py
└── train_svae.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Linux ###
2 | *~
3 |
4 | # user experiments directory
5 | model_repo*
6 | datasets/
7 | exp*
8 | test*
9 | temp_config/
10 |
11 | # temporary files which can be created if a process still has a handle open of a deleted file
12 | .fuse_hidden*
13 |
14 | # KDE directory preferences
15 | .directory
16 |
17 | # Linux trash folder which might appear on any partition or disk
18 | .Trash-*
19 |
20 | # .nfs files are created when an open file is removed but is still being accessed
21 | .nfs*
22 |
23 | ### PyCharm ###
24 | # User-specific stuff
25 | .idea
26 |
27 | # CMake
28 | cmake-build-*/
29 |
30 | # Mongo Explorer plugin
31 | .idea/**/mongoSettings.xml
32 |
33 | # File-based project format
34 | *.iws
35 |
36 | # IntelliJ
37 | out/
38 |
39 | # mpeltonen/sbt-idea plugin
40 | .idea_modules/
41 |
42 | # JIRA plugin
43 | atlassian-ide-plugin.xml
44 |
45 | # Cursive Clojure plugin
46 | .idea/replstate.xml
47 |
48 | # Crashlytics plugin (for Android Studio and IntelliJ)
49 | com_crashlytics_export_strings.xml
50 | crashlytics.properties
51 | crashlytics-build.properties
52 | fabric.properties
53 |
54 | # Editor-based Rest Client
55 | .idea/httpRequests
56 |
57 | # Android studio 3.1+ serialized cache file
58 | .idea/caches/build_file_checksums.ser
59 |
60 | # JetBrains templates
61 | **___jb_tmp___
62 |
63 | ### Python ###
64 | # Byte-compiled / optimized / DLL files
65 | __pycache__/
66 | *.py[cod]
67 | *$py.class
68 |
69 | # C extensions
70 | *.so
71 |
72 | # Distribution / packaging
73 | .Python
74 | build/
75 | develop-eggs/
76 | dist/
77 | downloads/
78 | eggs/
79 | .eggs/
80 | lib/
81 | lib64/
82 | parts/
83 | sdist/
84 | var/
85 | wheels/
86 | pip-wheel-metadata/
87 | share/python-wheels/
88 | *.egg-info/
89 | .installed.cfg
90 | *.egg
91 | MANIFEST
92 |
93 | # PyInstaller
94 | # Usually these files are written by a python script from a template
95 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
96 | *.manifest
97 | *.spec
98 |
99 | # Installer logs
100 | pip-log.txt
101 | pip-delete-this-directory.txt
102 |
103 | # Unit test / coverage reports
104 | htmlcov/
105 | .tox/
106 | .nox/
107 | .coverage
108 | .coverage.*
109 | .cache
110 | nosetests.xml
111 | coverage.xml
112 | *.cover
113 | .hypothesis/
114 | .pytest_cache/
115 |
116 | # Translations
117 | *.mo
118 | *.pot
119 |
120 | # Django stuff:
121 | *.log
122 | local_settings.py
123 | db.sqlite3
124 |
125 | # Flask stuff:
126 | instance/
127 | .webassets-cache
128 |
129 | # Scrapy stuff:
130 | .scrapy
131 |
132 | # Sphinx documentation
133 | docs/_build/
134 | docs/build/
135 |
136 | # PyBuilder
137 | target/
138 |
139 | # Jupyter Notebook
140 | .ipynb_checkpoints
141 |
142 | # IPython
143 | profile_default/
144 | ipython_config.py
145 |
146 | # pyenv
147 | .python-version
148 |
149 | # pipenv
150 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
151 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
152 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
153 | # install all needed dependencies.
154 | #Pipfile.lock
155 |
156 | # celery beat schedule file
157 | celerybeat-schedule
158 |
159 | # SageMath parsed files
160 | *.sage.py
161 |
162 | # Environments
163 | .env
164 | .venv
165 | env/
166 | venv/
167 | ENV/
168 | env.bak/
169 | venv.bak/
170 |
171 | # Spyder project settings
172 | .spyderproject
173 | .spyproject
174 |
175 | # Rope project settings
176 | .ropeproject
177 |
178 | # mkdocs documentation
179 | /site
180 |
181 | # mypy
182 | .mypy_cache/
183 | .dmypy.json
184 | dmypy.json
185 |
186 | # Pyre type checker
187 | .pyre/
188 |
189 | ### Vim ###
190 | # Swap
191 | [._]*.s[a-v][a-z]
192 | [._]*.sw[a-p]
193 | [._]s[a-rt-v][a-z]
194 | [._]ss[a-gi-z]
195 | [._]sw[a-p]
196 |
197 | # Session
198 | Session.vim
199 |
200 | # Temporary
201 | .netrwhist
202 | # Auto-generated tag files
203 | tags
204 | # Persistent undo
205 | [._]*.un~
206 |
207 | # output
208 | docs/api
209 | .code-workspace.code-workspace
210 | *.pkl
211 | *.npy
212 | *.pth
213 | *.onnx
214 | *.engine
215 | events.out.tfevents*
216 |
217 | # vscode
218 | *.code-workspace
219 | .vscode
220 |
221 | # vim
222 | .vim
223 |
224 | # OS generated files
225 | .DS_Store
226 | .DS_Store?
227 | .Trashes
228 | ehthumbs.db
229 | Thumbs.db
230 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT LICENSE
2 |
3 | Copyright (C) 2024 School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), Northwestern PolyTechnical University, and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 |
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 |
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |   [](https://arxiv.org/abs/2408.00655)
5 |
6 | Hongjun An<sup>1,2*</sup>, Yifan Chen<sup>1,2*</sup>, Zhe Sun<sup>1,2✉</sup> & Xuelong Li<sup>1,2✉</sup>
7 |
8 |
1School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), Northwestern PolyTechnical University
9 |
10 |
2Institute of Artificial Intelligence (TeleAI), China Telecom
11 |
12 |
13 | English | [简体中文](README_zh-CN.md)
14 |
15 |
16 |
17 | # 1. Introduction
18 |
19 | Current large language models (LLMs) primarily utilize next-token prediction method for inference, which significantly impedes their processing speed. In this [paper](https://arxiv.org/abs/2408.00655), we introduce a novel inference methodology termed next-sentence prediction, aiming at enhancing the inference efficiency of LLMs. We present Sentence Variational Autoencoder (SentenceVAE), which includes a Sentence Encoder to compress multiple tokens in a sentence into a single token, and a Sentence Decoder to reconstruct it.
20 |
21 |
22 |

23 |
Fig. 1. The schematic form of SentenceVAE.
24 |
25 |
26 |
27 | By integrating SentenceVAE into the input and output layers of LLMs, we develop Sentence-level LLMs (SLLMs) that employ a sentence-by-sentence inference method.
28 |
29 |
30 |

31 |
Fig. 2. (a) The schematic form of published LLMs. (b) The schematic form of SLLMs, which embedded with SentenceVAEs.
32 |
33 |
34 | The SLLMs can maintain the integrity of the original semantic content by segmenting the context into sentences, thereby improving accuracy while boosting inference speed. Moreover, compared to previous LLMs, SLLMs process fewer tokens over equivalent context length, significantly reducing memory demands for self-attention computation and facilitating the handling of longer context. Extensive experiments on [Wanjuan dataset](https://github.com/opendatalab/WanJuan1.0/) have revealed that the proposed method can accelerate inference speed by 204 ~ 365%, reduce perplexity (PPL) to 46 ~ 75% of its original metric, and decrease memory overhead by 86 ~ 91% for the equivalent context length, compared to previous token-by-token methods.
35 |
36 |
37 |
38 |
39 | Model |
40 | Total Params |
41 | Average PPL |
42 | Mean output throughput (toks/s) |
43 | Mean GPU memory (KB/token) |
44 |
45 |
46 | OPT↓ |
47 | SLLM↓ |
48 | Δ↓ |
49 | OPT↑ |
50 | SLLM↑ |
51 | Δ↑ |
52 | OPT↓ |
53 | SLLM↓ |
54 | Δ↓ |
55 |
56 |
57 | SLLM-125M-H1 |
58 | 214M |
59 | 26.75 |
60 | 31.68 |
61 | +18.4% |
62 | 214.57 |
63 | 652.78 |
64 | +204.2% |
65 | 73.15 |
66 | 12.03 |
67 | -83.6% |
68 |
69 |
70 | SLLM-125M-H2 |
71 | 226M |
72 | 44.60 |
73 | +66.7% |
74 | 539.80 |
75 | +151.6% |
76 | 7.08 |
77 | -90.3% |
78 |
79 |
80 | SLLM-125M-H4 |
81 | 250M |
82 | 14.32 |
83 | -46.5% |
84 | 332.12 |
85 | +54.8% |
86 | 10.00 |
87 | -86.3% |
88 |
89 |
90 | SLLM-350M-H1 |
91 | 429M |
92 | 25.18 |
93 | 24.84 |
94 | -1.4% |
95 | 144.33 |
96 | 481.39 |
97 | +233.5% |
98 | 197.59 |
99 | 29.98 |
100 | -84.8% |
101 |
102 |
103 | SLLM-350M-H2 |
104 | 450M |
105 | 14.81 |
106 | -41.2% |
107 | 442.23 |
108 | +206.4% |
109 | 26.78 |
110 | -86.4% |
111 |
112 |
113 | SLLM-350M-H4 |
114 | 492M |
115 | 10.17 |
116 | -59.6% |
117 | 315.61 |
118 | +118.7% |
119 | 17.73 |
120 | -91.0% |
121 |
122 |
123 | SLLM-1.3B-H1 |
124 | 1.61B |
125 | 15.95 |
126 | 8.76 |
127 | -45.1% |
128 | 119.07 |
129 | 479.71 |
130 | +302.9% |
131 | 400.01 |
132 | 57.07 |
133 | -85.7% |
134 |
135 |
136 | SLLM-1.3B-H2 |
137 | 1.69B |
138 | 3.84 |
139 | -75.9% |
140 | 553.95 |
141 | +365.2% |
142 | 55.14 |
143 | -86.2% |
144 |
145 |
146 |
147 |
148 | In addition, by corroborating the Scaling Law, we extrapolated the feasibility of our methodologies to larger-scale models.
149 |
150 |
151 |

152 |
Fig. 3. Scaling Law of (a) SLLMs and (b) SVAEs.
153 |
154 |
155 | # 2. Quick Start
156 |
157 |
158 | Installation
159 |
160 | Step1. Install SentenceVAE from source.
161 |
162 | ```sh
163 | git clone https://github.com/BestAnHongjun/SentenceVAE.git
164 | cd SentenceVAE
165 | pip3 install -e . # or python3 setup.py develop
166 | ```
167 |
168 |
169 |
170 |
171 | Prepare OPT models
172 |
173 | Step1. Create a folder named `model_repo` under `SentenceVAE` to save OPT series models.
174 |
175 | ```sh
176 | cd SentenceVAE
177 | mkdir -p model_repo
178 | ```
179 |
180 | Step2. Navigate to the `model_repo` directory with `cd` and initialize [`git-lfs`](https://git-lfs.com).
181 |
182 | ```sh
183 | cd model_repo
184 | git lfs install
185 | ```
186 |
187 | Step3. Download [OPT-125M](https://huggingface.co/facebook/opt-125m) model for SentenceVAE-768 series and SLLM-125M series.
188 |
189 | ```sh
190 | git clone https://huggingface.co/facebook/opt-125m
191 | ```
192 |
193 | Step4. Download [OPT-350M](https://huggingface.co/facebook/opt-350m) model for SentenceVAE-1024 series and SLLM-350M series.
194 |
195 | ```sh
196 | git clone https://huggingface.co/facebook/opt-350m
197 | ```
198 |
199 | Step5. Download [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b) model for SentenceVAE-2048 series and SLLM-1.3B series.
200 |
201 | ```sh
202 | git clone https://huggingface.co/facebook/opt-1.3b
203 | ```
204 |
205 |
206 |
207 |
208 | SentenceVAE Demo
209 |
210 | Step1. Download a pretrained model from table below.
211 |
212 |
213 |
214 | |Model|Hidden Size|Hidden Layers|Loss↓|PPL↓|Download Link|
215 | |:-:|:-:|:-:|:-:|:-:|:-:|
216 | |SVAE-768-H1|768|1|1.339|3.605|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-768-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-768-H1.pth&sign=d8239cd4b6979b61ee0b969ef54f1a78&nonce=1723800234481)|
217 | |SVAE-768-H2|768|2|1.019|2.588|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-768-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-768-H2.pth&sign=c11ca77f7934d4b441e7a6ae5359157f&nonce=1723800264673)|
218 | |SVAE-768-H4|768|4|**0.5598**|**1.649**|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-768-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-768-H4.pth&sign=829995892eba42caf3f28a0f77a28d9e&nonce=1723800281621)|
219 | |SVAE-1024-H1|1024|1|0.9266|2.406|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-1024-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-1024-H1.pth&sign=b3d5202c64d117389131def1b35e2f33&nonce=1723800301123)|
220 | |SVAE-1024-H2|1024|2|0.6610|1.845|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-1024-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-1024-H2.pth&sign=f4ba19e8f474068598f8186be14a7ab4&nonce=1723800319623)|
221 | |SVAE-1024-H4|1024|4|**0.3704**|**1.384**|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-1024-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-1024-H4.pth&sign=6ff668d782383e4bf01d1337a98910b3&nonce=1723800343431)|
222 | |SVAE-2048-H1|2048|1|0.5165|1.622|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-2048-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-2048-H1.pth&sign=4d1ef8d0d0cf0f48e406eb73d74bb5cf&nonce=1723800363566)|
223 | |SVAE-2048-H2|2048|2|0.2845|1.292|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-2048-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-2048-H2.pth&sign=7cc09034413bcfdb0fcbd875e6ad4be4&nonce=1723800379541)|
224 | |SVAE-2048-H4|2048|4|**0.1270**|**1.115**|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceVAE/resolve/master/SVAE-2048-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceVAE/main?filepath=SVAE-2048-H4.pth&sign=8d09686dbcd6aaf0aeaf70a537de1836&nonce=1723800393625)|
225 |
226 |
227 |
228 | Step2. Run demo script under `tools/demo` folder. Here's an example:
229 |
230 | ```sh
231 | cd SentenceVAE
232 |
233 | python3 tools/demo/demo_svae.py \
234 | -c config/SVAE/SVAE-768/svae_768_h4.yaml \
235 | --checkpoint /path/to/pretrained/checkpoint \
236 | --input "What's your name?"
237 | ```
238 |
239 | **Arguments**:
240 | * `-c`,`--config`: path to the corresponding configuration file, please reference [this folder](config/SVAE/).
241 | * `--checkpoint`: path to the checkpoint file you just downloaded.
242 | * `--input`: A sentence you want to test.
243 | * It must be a separate sentence ending with punctuation marks such as commas, periods, etc. Please refer to the [paper](https://arxiv.org/abs/2408.00655) for specific reasons.
244 | * Currently, only English is supported.
245 |
246 | The model will compress this sentence into a single vector, decode and restore it for output. In an ideal state, the output and input should be consistent.
247 |
248 |
249 |
250 |
251 |
252 | SentenceLLM Demo
253 |
254 | **Notice**: Please be aware that, as SFT datasets are typically commercial secrets and difficult for us to access, all the models listed below are **pre-trained models**, not general-purpose conversation models. Therefore, the **PPL** (Perplexity) metric should be used to assess model quality, not conversational performance. If you treat them as Q&A models, you're likely to get gibberish outputs (***in fact, even our baseline OPT model will output gibberish***). We recommend fine-tuning these models on private SFT datasets to explore their potential as general-purpose conversation models.
255 |
256 | Step1. Download a pretrained model from table below.
257 |
258 |
259 |
260 | |Model|Download Link|
261 | |:-:|:-:|
262 | |SLLM-125M-H1|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-125M-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-125M-H1.pth&sign=b170645b08b03ee1240b95267c7454ca&nonce=1723800424069)|
263 | |SLLM-125M-H2|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-125M-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-125M-H2.pth&sign=fb3e59ef4c30dcea0d732af741d183b2&nonce=1723800472802)|
264 | |SLLM-125M-H4|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-125M-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-125M-H4.pth&sign=d8d1c1aa1d516e26fe21cb1ca3220e62&nonce=1723800484902)|
265 | |SLLM-350M-H1|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-350M-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-350M-H1.pth&sign=bc5e5e777a1bc41a1a564a7e52a2bf94&nonce=1723800502532)|
266 | |SLLM-350M-H2|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-350M-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-350M-H2.pth&sign=6c8ba81b806649366df99c96fbe3e4ed&nonce=1723800517836)|
267 | |SLLM-350M-H4|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-350M-H4.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-350M-H4.pth&sign=e5722a48a5ff516e61cf9efdc1ee8230&nonce=1723800534148)|
268 | |SLLM-1.3B-H1|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-1.3B-H1.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-1.3B-H1.pth&sign=77edd326c8e46eebd98f7f545f4d4e0c&nonce=1723800549084)|
269 | |SLLM-1.3B-H2|[ModelScope](https://modelscope.cn/models/CoderAN/SentenceLLM/resolve/master/SLLM-1.3B-H2.pth)
[OpenXLab](https://download.openxlab.org.cn/repos/file/Coder-AN/SentenceLLM/main?filepath=SLLM-1.3B-H2.pth&sign=54dc841b8a067afe7a2fbd16a6c0a2e5&nonce=1723800565365)|
270 |
271 |
272 |
273 | Step2. Run demo script under `tools/demo` folder. Here's an example:
274 |
275 | ```sh
276 | cd SentenceVAE
277 |
278 | python3 tools/demo/demo_sllm.py \
279 | -c config/SLLM/SLLM-125m/sllm_125m_h4_all.yaml \
280 | --checkpoint /path/to/pretrained/checkpoint \
281 | --input "What's your name?"
282 | ```
283 |
284 | **Arguments**:
285 | * `-c`,`--config`: path to the corresponding configuration file, please reference [this folder](config/SLLM/).
286 | * `--checkpoint`: path to the checkpoint file you just downloaded.
287 | * `--input`: Your input sentence.
288 |
289 |
290 |
291 |
292 | # 3. Tutorials
293 |
294 | Under writing...
295 |
296 |
297 | Train Models
298 |
299 | * [Prepare Datasets](#)
300 | * [Train SentenceVAEs](#)
301 | * [Train SentenceLLMs](#)
302 |
303 |
304 |
305 |
306 | Eval Models
307 |
308 | * [Eval OPT models (baseline)](#)
309 | * [Eval SentenceVAEs](#)
310 | * [Eval SentenceLLMs](#)
311 |
312 |
313 |
314 |
315 | Test Benchmarks
316 |
317 | * [Test benchmarks of SentenceVAEs](#)
318 | * [Test benchmarks of SentenceLLMs](#)
319 |
320 |
321 |
322 | # 4. Cite SentenceVAE
323 |
324 | If you use SentenceVAE in your research, please cite our work by using the following BibTeX entry:
325 |
326 | ```bibtex
327 | @article{an2024sentencevae,
328 | title={SentenceVAE: Enable Next-sentence Prediction for Large Language Models with Faster Speed, Higher Accuracy and Longer Context},
329 | author={An, Hongjun and Chen, Yifan and Sun, Zhe and Li, Xuelong},
330 | journal={arXiv preprint arXiv:2408.00655},
331 | year={2024}
332 | }
333 | ```
--------------------------------------------------------------------------------
/README_zh-CN.md:
--------------------------------------------------------------------------------
1 | TODO.
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BestAnHongjun/SentenceVAE/8e323e4a5db3b9f36ccc3435d88804e60dfabf69/assets/logo.png
--------------------------------------------------------------------------------
/config/SLLM/SLLM-1.3b/sllm_1.3b_base.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ../sllm_base.yaml
2 |
3 | svae:
4 | ref_model_dir: "model_repo/opt-1.3b"
5 |
6 | llm:
7 | ref_model_dir: "model_repo/opt-1.3b"
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-1.3b/sllm_1.3b_h1_all.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./sllm_1.3b_base.yaml
2 |
3 | svae:
4 | num_hidden_layers: 1
5 |
6 | llm:
7 | finetune_layers: -1
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-1.3b/sllm_1.3b_h2_all.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./sllm_1.3b_base.yaml
2 |
3 | svae:
4 | num_hidden_layers: 2
5 |
6 | llm:
7 | finetune_layers: -1
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-1.3b/sllm_1.3b_h4_all.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./sllm_1.3b_base.yaml
2 |
3 | svae:
4 | num_hidden_layers: 4
5 |
6 | llm:
7 | finetune_layers: -1
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-125m/sllm_125m_base.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ../sllm_base.yaml
2 |
3 | svae:
4 | ref_model_dir: "model_repo/opt-125m"
5 |
6 | llm:
7 | ref_model_dir: "model_repo/opt-125m"
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-125m/sllm_125m_h1_all.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./sllm_125m_base.yaml
2 |
3 | svae:
4 | num_hidden_layers: 1
5 |
6 | llm:
7 | finetune_layers: -1
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-125m/sllm_125m_h2_all.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./sllm_125m_base.yaml
2 |
3 | svae:
4 | num_hidden_layers: 2
5 |
6 | llm:
7 | finetune_layers: -1
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-125m/sllm_125m_h4_all.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./sllm_125m_base.yaml
2 |
3 | svae:
4 | num_hidden_layers: 4
5 |
6 | llm:
7 | finetune_layers: -1
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-350m/sllm_350m_base.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ../sllm_base.yaml
2 |
3 | svae:
4 | ref_model_dir: "model_repo/opt-350m"
5 |
6 | llm:
7 | ref_model_dir: "model_repo/opt-350m"
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-350m/sllm_350m_h1_all.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./sllm_350m_base.yaml
2 |
3 | svae:
4 | num_hidden_layers: 1
5 |
6 | llm:
7 | finetune_layers: -1
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-350m/sllm_350m_h2_all.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./sllm_350m_base.yaml
2 |
3 | svae:
4 | num_hidden_layers: 2
5 |
6 | llm:
7 | finetune_layers: -1
8 |
--------------------------------------------------------------------------------
/config/SLLM/SLLM-350m/sllm_350m_h4_all.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./sllm_350m_base.yaml
2 |
3 | svae:
4 | num_hidden_layers: 4
5 |
6 | llm:
7 | finetune_layers: -1
8 |
--------------------------------------------------------------------------------
/config/SLLM/sllm_base.yaml:
--------------------------------------------------------------------------------
1 | device: "cuda"
2 | dtype: "fp32"
3 |
4 | svae:
5 | learnable_add: false
6 | load_ref_model: false
7 | ref_model_dir: "model_repo/opt-125m"
8 | ref_model_dtype: null
9 | finetune_embedding: true
10 | model_path: null
11 |
12 | llm:
13 | ref_model_dir: "model_repo/opt-125m"
14 | ref_model_dtype: null
15 | finetune_layers: -1
16 |
17 | finetune_svae: true
18 | max_sen_len: 64
19 | max_sen_num: 64
20 | batch_size: 1
21 | base_lr: 0.000001
22 | resume_train: true
23 | dataloader_num_workers: 32
24 | dataloader_prefetch_factor: 20
25 | save_checkpoint_iters: 5000
26 | max_iters: 1600000
27 | warmup_iters: 5000
28 | val_iters: 5000
29 | cosineannealinglr_tmax: 20000
30 | max_keep_ckpts: 2
31 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-1024/svae_1024_base.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ../svae_base.yaml
2 |
3 | ref_model_dir: "model_repo/opt-350m"
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-1024/svae_1024_h1.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./svae_1024_base.yaml
2 |
3 | num_hidden_layers: 1
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-1024/svae_1024_h2.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./svae_1024_base.yaml
2 |
3 | num_hidden_layers: 2
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-1024/svae_1024_h4.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./svae_1024_base.yaml
2 |
3 | num_hidden_layers: 4
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-2048/svae_2048_base.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ../svae_base.yaml
2 |
3 | ref_model_dir: "model_repo/opt-1.3b"
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-2048/svae_2048_h1.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./svae_2048_base.yaml
2 |
3 | num_hidden_layers: 1
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-2048/svae_2048_h2.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./svae_2048_base.yaml
2 |
3 | num_hidden_layers: 2
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-2048/svae_2048_h4.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./svae_2048_base.yaml
2 |
3 | num_hidden_layers: 4
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-768/svae_768_base.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ../svae_base.yaml
2 |
3 | ref_model_dir: "model_repo/opt-125m"
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-768/svae_768_h1.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./svae_768_base.yaml
2 |
3 | num_hidden_layers: 1
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-768/svae_768_h2.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./svae_768_base.yaml
2 |
3 | num_hidden_layers: 2
4 |
--------------------------------------------------------------------------------
/config/SVAE/SVAE-768/svae_768_h4.yaml:
--------------------------------------------------------------------------------
1 | __base__: !include ./svae_768_base.yaml
2 |
3 | num_hidden_layers: 4
4 |
--------------------------------------------------------------------------------
/config/SVAE/svae_base.yaml:
--------------------------------------------------------------------------------
1 | device: "cuda"
2 | dtype: "fp32"
3 | learnable_add: false
4 | load_ref_model: false
5 | ref_model_dir: "model_repo/opt-125m"
6 | ref_model_dtype: null
7 | num_hidden_layers: 1
8 | finetune_embedding: true
9 | model_path: null
10 | max_seq_len: 64
11 | batch_size: 128
12 | base_lr: 0.0000001
13 | resume_train: true
14 | dataloader_num_workers: 32
15 | dataloader_prefetch_factor: 20
16 | save_checkpoint_iters: 5000
17 | max_iters: 300000
18 | warmup_iters: 5000
19 | val_iters: 5000
20 | cosineannealinglr_tmax: 20000
21 | max_keep_ckpts: 2
22 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | mmengine
2 | torch
3 | transformers
4 | tensorboard
5 | accelerate
6 | pyyaml
7 | pyyaml-include
8 | tqdm
9 | numpy
10 |
--------------------------------------------------------------------------------
/sentence_vae/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
# Package version string, exposed as ``sentence_vae.__version__``.
__version__ = "0.0.1"
30 |
--------------------------------------------------------------------------------
/sentence_vae/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | from .tele_ds_dataset import TeleDSDataset
30 | from .data_collate import SentenceCollate, PassageCollate
31 | from .data_eval import SVAE_PPL, SLLM_PPL
32 |
--------------------------------------------------------------------------------
/sentence_vae/data/data_collate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import re
30 | import torch
31 |
32 |
class SentenceCollate:
    """Collate callable that tokenizes a batch of raw sentence strings.

    Wraps a HuggingFace-style tokenizer; when ``fix_len`` is set, every
    batch is right-padded to exactly ``max_len`` tokens so all batches
    share a single fixed shape.
    """

    def __init__(
        self,
        tokenizer,
        max_len=1024,
        padding=True,
        fix_len=True
    ):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.padding = padding
        self.fix_len = fix_len

    def __call__(self, texts):
        """Tokenize ``texts`` and return ``(input_ids, attention_mask)`` tensors."""
        batch = self.tokenizer.batch_encode_plus(
            texts,
            padding=self.padding,
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )
        input_ids, attention_mask = batch['input_ids'], batch['attention_mask']

        if self.fix_len:
            n_rows, seq_len = input_ids.shape
            missing = self.max_len - seq_len
            if missing > 0:
                # Right-pad the ids with pad_token_id and the mask with zeros
                # so every batch reaches exactly max_len columns.
                extra_ids = torch.full(
                    (n_rows, missing),
                    self.tokenizer.pad_token_id,
                    dtype=input_ids.dtype,
                    device=input_ids.device
                )
                extra_mask = torch.zeros(
                    (n_rows, missing),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device
                )
                input_ids = torch.cat((input_ids, extra_ids), dim=1)
                attention_mask = torch.cat((attention_mask, extra_mask), dim=1)

        return input_ids, attention_mask
76 |
77 |
class PassageCollate:
    """Collate callable that splits each passage into sentences, then tokenizes.

    For every text in the batch it yields a fixed-size grid of
    ``max_sentence_num`` sentences, each padded/truncated to
    ``max_sentence_len`` tokens, plus masks marking the real
    sentences and real tokens.
    """

    def __init__(
        self,
        tokenizer,
        max_sentence_len=512,
        max_sentence_num=512,
        padding=True,
        fix_len=True
    ):
        # tokenizer: HuggingFace-style tokenizer exposing batch_encode_plus.
        self.tokenizer = tokenizer
        self.max_sentence_len = max_sentence_len
        self.max_sentence_num = max_sentence_num
        self.padding = padding
        # NOTE(review): fix_len is stored but never consulted; __call__
        # always pads to max_sentence_len -- confirm before relying on it.
        self.fix_len = fix_len

        # Regex over Chinese + ASCII sentence delimiters (incl. closing
        # quotes and ellipses).
        # NOTE(review): compiled but unused in this class -- splitting is
        # done by the manual scan in cut_sentence_func; confirm before
        # removing.
        self._re_sentence = re.compile(
            '([,。!?;:\?])([^”’])|' + '([,.!?;:\?])([^"\'])|' +
            '(\…{2})([^”’])|' + '(\.{6})([^”’])|' +
            '([。!?\?][”’])([^,。!?\?])|' + '([.!?\?]["\'])([^,。!?\?])'
        )


    def cut_sentence_func(self, para):
        """Split ``para`` into at most ``max_sentence_num`` sentence strings.

        Scans character by character: two-character closers (ellipsis,
        punctuation followed by a quote, or the reverse) end a sentence
        after both characters; single delimiters end it after one.
        Trailing spaces/newlines are swallowed into the sentence.

        Returns:
            (sentences, sentence_mask): ``sentences`` padded with empty
            strings up to ``max_sentence_num``; ``sentence_mask`` is an
            int64 tensor with 1 for real sentences, 0 for padding.

        NOTE(review): text after the final delimiter is never appended
        (the last_p..end tail is dropped when the loop exits) -- confirm
        that this truncation is intended.
        """
        p, last_p, length = 0, 0, len(para)
        sentences = []
        sentence_mask = []

        while p < length and len(sentences) < self.max_sentence_num:
            # Two-character sentence closers: consume both characters.
            if p + 1 < length and (
                para[p:p+2] == "……" or
                para[p:p+2] == ",”" or
                para[p:p+2] == "”," or
                para[p:p+2] == "。”" or
                para[p:p+2] == "”。" or
                para[p:p+2] == "!”" or
                para[p:p+2] == "”!" or
                para[p:p+2] == "?”" or
                para[p:p+2] == "”?" or
                para[p:p+2] == ',"' or
                para[p:p+2] == '",' or
                para[p:p+2] == '."' or
                para[p:p+2] == '".' or
                para[p:p+2] == '!"' or
                para[p:p+2] == '"!' or
                para[p:p+2] == '?"' or
                para[p:p+2] == '"?'
            ):
                p += 2
                # cut
            # Single-character delimiters (full-width and ASCII).
            elif para[p] in [
                ',', '。', '”', '!', '?', ';', ':',
                ',', '.', '!', '?', ';', ':', '"', '\n', '\r'
            ]:
                p += 1
                # cut
            else:
                # Ordinary character: keep scanning, no cut here.
                p += 1
                continue

            # Swallow trailing whitespace/newlines into the sentence.
            while p < length and para[p] in [' ', '\n', '\r']: p += 1
            sentences.append(para[last_p:p])
            sentence_mask.append(1)
            last_p = p

        # Pad with empty sentences so every passage yields a fixed count.
        delta = self.max_sentence_num - len(sentences)
        if delta:
            sentences.extend(["" for _ in range(delta)])
            sentence_mask.extend([0 for _ in range(delta)])

        sentence_mask = torch.tensor(sentence_mask, dtype=torch.int64)

        return sentences, sentence_mask


    def __call__(self, texts):
        """Collate a batch of passages.

        Returns:
            batch_sentence_mask: (batch, max_sentence_num) int64; 1 = real sentence.
            batch_sentence_toks: (batch, max_sentence_num, max_sentence_len) int64 token ids.
            batch_tok_mask:      same shape; 1 = real token, 0 = padding.
        """
        batch_size = len(texts)

        batch_sentence_mask = torch.zeros(
            (batch_size, self.max_sentence_num),
            dtype=torch.int64
        )
        batch_sentence_toks = torch.zeros(
            (batch_size, self.max_sentence_num, self.max_sentence_len),
            dtype=torch.int64
        )
        batch_tok_mask = torch.zeros(
            (batch_size, self.max_sentence_num, self.max_sentence_len),
            dtype=torch.int64
        )

        for i, text in enumerate(texts):
            # Padding sentences ("") are tokenized too, so every row of the
            # grid is overwritten below.
            sentences, sentence_mask = self.cut_sentence_func(text)
            encoded = self.tokenizer.batch_encode_plus(
                sentences,
                padding=self.padding,
                truncation=True,
                max_length=self.max_sentence_len,
                return_tensors="pt"
            )

            sentence_toks, tok_mask = encoded['input_ids'], encoded['attention_mask']
            sentence_num, seq_len = sentence_toks.shape
            # Right-pad token grids to the fixed max_sentence_len width.
            if seq_len < self.max_sentence_len:
                pad_ids = torch.zeros(
                    (sentence_num, self.max_sentence_len - seq_len),
                    dtype=sentence_toks.dtype,
                    device=sentence_toks.device
                ).fill_(self.tokenizer.pad_token_id)
                pad_mask = torch.zeros(
                    (sentence_num, self.max_sentence_len - seq_len),
                    dtype=tok_mask.dtype,
                    device=tok_mask.device
                )
                sentence_toks = torch.concat((sentence_toks, pad_ids), dim=1)
                tok_mask = torch.concat((tok_mask, pad_mask), dim=1)

            batch_sentence_mask[i] = sentence_mask
            batch_sentence_toks[i] = sentence_toks
            batch_tok_mask[i] = tok_mask

        return batch_sentence_mask, batch_sentence_toks, batch_tok_mask
199 |
--------------------------------------------------------------------------------
/sentence_vae/data/data_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import numpy as np
30 | import torch.nn.functional as F
31 | from mmengine.evaluator import BaseMetric
32 |
33 |
class SVAE_PPL(BaseMetric):
    """mmengine metric: token-level cross-entropy and perplexity for SentenceVAE."""

    def process(self, data_batch, data_samples):
        # Each sample carries logits, a boolean token mask, and target ids;
        # targets are shifted by one position (next-token prediction).
        output, attention_mask, tgt_ids = data_samples
        ce = F.cross_entropy(output[attention_mask], tgt_ids[:, 1:][attention_mask])
        self.results.append({'loss': ce.item()})

    def compute_metrics(self, results):
        # Perplexity is the exponential of the mean per-batch loss.
        mean_loss = np.mean([item['loss'] for item in results])
        return dict(eval_loss=mean_loss, eval_ppl=np.exp(mean_loss))
45 |
46 |
class SLLM_PPL(BaseMetric):
    """mmengine metric tracking stop-prediction loss and LM perplexity for SentenceLLM."""

    def process(self, data_batch, data_sample):
        stop_loss, ppl_loss = data_sample
        self.results.append(
            {'stop_loss': stop_loss.item(), 'ppl_loss': ppl_loss.item()}
        )

    def compute_metrics(self, results):
        mean_stop = np.mean([item['stop_loss'] for item in results])
        mean_lm = np.mean([item['ppl_loss'] for item in results])
        # Perplexity derives from the language-modeling loss only.
        return dict(
            eval_stop_loss=mean_stop,
            eval_loss=mean_lm,
            eval_ppl=np.exp(mean_lm),
        )
58 |
--------------------------------------------------------------------------------
/sentence_vae/data/tele_ds_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | from time import sleep
30 | from torch.utils.data import Dataset
31 |
32 | from sentence_vae.utils import fetch_text_with_retry
33 |
34 |
class TeleDSDataset(Dataset):
    """Dataset streaming text samples from a TeleDS HTTP server.

    The first ``eval_samples`` server indices are reserved for evaluation;
    training mode skips past them.
    """

    def __init__(
        self,
        server_ip="127.0.0.1",
        server_port=8000,
        eval_mode=False,
        eval_samples=1000
    ):
        self.server_url = f"http://{server_ip}:{server_port}"
        self.eval_mode = eval_mode
        self.eval_samples = eval_samples

    def __len__(self):
        # Block until the server reports readiness, polling once per second.
        while fetch_text_with_retry(f"{self.server_url}/status").strip() != "ready.":
            print("TeleDS Server is not ready, retrying.")
            sleep(1)
        total = int(fetch_text_with_retry(f"{self.server_url}/count").strip())
        # There must be more samples than the reserved eval split.
        assert total > self.eval_samples
        return self.eval_samples if self.eval_mode else total - self.eval_samples

    def __getitem__(self, idx):
        # Training indices start after the reserved evaluation prefix.
        if not self.eval_mode:
            idx += self.eval_samples
        return fetch_text_with_retry(f"{self.server_url}/data/{idx}")
68 |
--------------------------------------------------------------------------------
/sentence_vae/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | from .positional_encoding import PositionalEncodding
30 | from .sentence_encoder import SentenceEncoder
31 | from .sentence_decoder import SentenceDecoder
32 | from .sentence_vae_model import SentenceVAE
33 | from .sentence_llm_model import SentenceLLM
34 | from .focal_loss import FocalLoss
35 |
--------------------------------------------------------------------------------
/sentence_vae/models/focal_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | import torch.nn as nn
31 | import torch.nn.functional as F
32 |
33 |
class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., "Focal Loss for Dense Object Detection").

    Down-weights well-classified examples by ``(1 - p_t) ** gamma`` so that
    training concentrates on hard examples.

    Args:
        gamma: focusing parameter; ``gamma=0`` reduces to cross-entropy.
        reduction: ``'mean'``, ``'sum'``, or anything else for no reduction.
        eps: small constant added inside the log for numerical stability.
    """

    def __init__(self, gamma=2.0, reduction='mean', eps=1e-6):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.eps = eps

    def forward(self, logits, targets):
        """Compute the focal loss.

        Args:
            logits: float tensor of shape (..., num_classes), unnormalized scores.
            targets: int64 tensor of shape (...) with class indices.

        Returns:
            Scalar tensor for 'mean'/'sum' reduction, else the per-element loss.
        """
        # Numerically stable log-softmax: subtract the max before exponentiating.
        max_logits = torch.max(logits, dim=-1, keepdim=True)[0]
        log_probs = logits - max_logits
        log_probs = log_probs - torch.log(torch.sum(torch.exp(log_probs), dim=-1, keepdim=True))
        probs = torch.exp(log_probs)

        # Probability assigned to the target class of each example.
        targets_one_hot = F.one_hot(targets, num_classes=logits.size(-1))
        targets_probs = torch.sum(probs * targets_one_hot, dim=-1)

        focal_weight = (1 - targets_probs) ** self.gamma
        log_targets_probs = torch.log(targets_probs + self.eps)
        loss = -focal_weight * log_targets_probs

        # Drop NaN entries (e.g. from fully-masked/NaN inputs) before reducing.
        loss = loss[~torch.isnan(loss)]
        if loss.size(0) == 0:
            # Bug fix: the zero fallback must be a *floating-point* tensor.
            # torch.tensor(0, requires_grad=True) raised a RuntimeError because
            # integer tensors cannot require gradients.
            return torch.zeros((), dtype=logits.dtype, device=targets.device, requires_grad=True)

        if self.reduction == 'mean':
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)

        return loss
64 |
--------------------------------------------------------------------------------
/sentence_vae/models/positional_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import math
30 | import torch
31 | import torch.nn as nn
32 |
33 |
class PositionalEncodding(nn.Module):
    """Fixed sinusoidal positional encoding ("Attention Is All You Need").

    Precomputes a (1, max_len, hidden_size) table of interleaved sin/cos
    values; ``forward`` returns the prefix for a requested length.
    """

    def __init__(
        self,
        hidden_size: int,
        max_len: int = 4096,
        device: torch.device = None,
        dtype: torch.dtype = None
    ):
        super(PositionalEncodding, self).__init__()
        # sin/cos come in pairs, so the channel count must be even.
        # (Fix: error message said "go size"; corrected to "got size".)
        assert hidden_size % 2 == 0, \
            f"Cannot use sin/cos positional encoding with odd hidden_size (got size={hidden_size})."

        device = device if device is not None else torch.device('cuda')
        dtype = dtype if dtype is not None else torch.float16

        pe = torch.zeros(max_len, hidden_size)
        position = torch.arange(0, max_len).unsqueeze(1)
        # Geometric frequency progression: 10000^(-2i / hidden_size).
        div_term = torch.exp((torch.arange(0, hidden_size, 2) * -(math.log(10000.0) / hidden_size)))

        pe[:, 0::2] = torch.sin(position * div_term)  # even channels
        pe[:, 1::2] = torch.cos(position * div_term)  # odd channels

        # Kept as a plain attribute (not a registered buffer) so it stays out
        # of state_dict; the table is deterministic and cheap to rebuild.
        self.pe = pe.unsqueeze(0)
        self.pe.requires_grad = False

        self.pe = self.pe.to(dtype)
        self.pe = self.pe.to(device)


    def forward(self, seq_len):
        """Return positional encodings of shape (1, seq_len, hidden_size)."""
        return self.pe[:, :seq_len, :]
65 |
--------------------------------------------------------------------------------
/sentence_vae/models/sentence_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | import torch.nn as nn
31 | from typing import Union
32 |
33 | from sentence_vae.models import PositionalEncodding
34 | from sentence_vae.utils.llm import get_model
35 | from sentence_vae.utils.weights import load_embedding_state_dict
36 |
37 |
class SentenceDecoder(nn.Module):
    """Transformer decoder that reconstructs token sequences from a sentence embedding.

    The sentence embedding serves as the cross-attention memory; tokens are
    produced autoregressively under a causal mask.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        device: torch.device = None,
        dtype: torch.dtype = None,
        load_ref_model: Union[bool, nn.Module] = False,
        ref_model_dir: str = None,
        ref_model_dtype: torch.dtype = None,
        finetune_embedding: bool = True,
        word_embed_proj_dim: int = None,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 1,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
        pad_id=1,
    ):
        """Build the decoder.

        Args:
            hidden_size: transformer model width.
            vocab_size: output vocabulary size.
            device/dtype: default to cuda / float16 when None.
            load_ref_model: a ready nn.Module whose embedding weights are
                copied in, or True to load one from ``ref_model_dir``.
            ref_model_dir: path used when ``load_ref_model`` is True.
            ref_model_dtype: dtype for the reference model (defaults to dtype).
            finetune_embedding: when False, the embedding weights are frozen.
            word_embed_proj_dim: embedding width; projected up when it
                differs from ``hidden_size``.
            pad_id: padding token index for the embedding.
        """
        super(SentenceDecoder, self).__init__()
        word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size

        self.device = device if device is not None else torch.device('cuda')
        self.dtype = dtype if dtype is not None else torch.float16

        self.embed_token = nn.Embedding(vocab_size, word_embed_proj_dim, padding_idx=pad_id)
        self.embed_positions = PositionalEncodding(hidden_size, max_seq_len, dtype=self.dtype, device=self.device)

        # Optional projection when the embedding width differs from the model width.
        if word_embed_proj_dim != hidden_size:
            self.project_in = nn.Linear(word_embed_proj_dim, hidden_size, bias=False)
        else:
            self.project_in = None

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            dim_feedforward=hidden_size * 2,
            dropout=dropout,
            batch_first=True,
        )

        self.decoder = nn.TransformerDecoder(decoder_layer, num_hidden_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)

        # Optionally initialize the token embedding from a reference LLM.
        if isinstance(load_ref_model, nn.Module):
            ref_model_dtype = ref_model_dtype if ref_model_dtype is not None else self.dtype
            ref_model = load_ref_model
            ref_model_state_dict = load_embedding_state_dict(ref_model)
            assert ref_model_state_dict is not None, f"Model {ref_model_dir} does not have an Embedding layer."
            self.embed_token.load_state_dict(ref_model_state_dict)
        elif load_ref_model and ref_model_dir is not None:
            ref_model_dtype = ref_model_dtype if ref_model_dtype is not None else self.dtype
            ref_model = get_model(ref_model_dir, ref_model_dtype, 'cpu')
            ref_model_state_dict = load_embedding_state_dict(ref_model)
            assert ref_model_state_dict is not None, f"Model {ref_model_dir} does not have an Embedding layer."
            self.embed_token.load_state_dict(ref_model_state_dict)
        if not finetune_embedding:
            # Freeze the (possibly pre-loaded) embedding weights.
            self.embed_token.weight.requires_grad = False

        self.to(self.dtype)
        self.to(self.device)


    def forward(self, input_ids, sentence_embed, attention_mask=None):
        """Decode logits for ``input_ids`` conditioned on ``sentence_embed``.

        Args:
            input_ids: (batch, seq_len) int64 token ids.
            sentence_embed: cross-attention memory for the decoder.
            attention_mask: optional padding mask; nonzero = real token.

        Returns:
            (batch, seq_len, vocab_size) logits.
        """
        _, seq_len = input_ids.shape
        # TransformerDecoder expects True = "ignore this position", so the
        # incoming keep-mask is inverted.
        if attention_mask is not None and attention_mask.dtype is not torch.bool:
            attention_mask = ~attention_mask.to(torch.bool)

        inputs_emb = self.embed_token(input_ids)
        pos_emb = self.embed_positions(seq_len)
        if self.project_in is not None:
            inputs_emb = self.project_in(inputs_emb)
        inputs_emb = inputs_emb + pos_emb

        # Boolean causal mask: True above the diagonal (future positions).
        future_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=self.device) == -torch.inf

        hidden_state = self.decoder(
            inputs_emb,
            sentence_embed,
            tgt_mask=future_mask,
            tgt_key_padding_mask=attention_mask,
            tgt_is_causal=True
        )
        output = self.linear(hidden_state)

        return output


    def streaming_generate(
        self, sentence_embed,
        max_output_len:int=64,
        bos_token_id: int=2,
        eos_token_id: int=2
    ):
        """Greedily yield token ids one at a time until EOS or the length cap.

        The BOS token counts toward ``max_output_len`` but is not yielded.
        """
        output_ids = torch.tensor([[bos_token_id]], dtype=torch.long, device=self.device)
        # Bug fix: compare the sequence length (dim 1). The original used
        # len(output_ids), which is the batch dimension (always 1), so
        # max_output_len never capped generation and the loop could run
        # forever when EOS was not produced.
        while output_ids.shape[1] < max_output_len:
            logits = self.forward(output_ids, sentence_embed)
            new_id = torch.argmax(logits[:, -1:], dim=-1)
            output_ids = torch.concat((output_ids, new_id), dim=1)
            yield new_id.item()
            if new_id.item() == eos_token_id:
                break


    def generate(
        self, sentence_embed,
        max_output_len:int=64,
        bos_token_id: int=2,
        eos_token_id: int=2
    ):
        """Greedy decoding; returns the list of generated token ids."""
        output_ids = []
        for output_id in self.streaming_generate(sentence_embed, max_output_len, bos_token_id, eos_token_id):
            output_ids.append(output_id)
        return output_ids
151 |
--------------------------------------------------------------------------------
/sentence_vae/models/sentence_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | import torch.nn as nn
31 | from typing import Union
32 |
33 | from sentence_vae.models import PositionalEncodding
34 | from sentence_vae.utils.llm import get_model
35 | from sentence_vae.utils.weights import load_embedding_state_dict
36 |
37 |
class SentenceEncoder(nn.Module):
    """Compress a token sequence into a single sentence-level embedding.

    Tokens are embedded, position-encoded and run through a small
    ``nn.TransformerEncoder``; the per-token hidden states are pooled into
    one vector per sequence — either a learned weighted sum (``learnable_add``)
    or a plain sum — and normalized with LayerNorm.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        device: torch.device = None,
        dtype: torch.dtype = None,
        learnable_add: bool = True,
        load_ref_model: Union[bool, nn.Module] = False,
        ref_model_dir: str = None,
        ref_model_dtype: torch.dtype = None,
        finetune_embedding: bool = True,
        word_embed_proj_dim: int = None,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 1,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
        pad_id=1,
    ):
        """
        Args:
            hidden_size: width of the transformer encoder.
            vocab_size: size of the token embedding table.
            device: target device; defaults to CUDA.
            dtype: parameter dtype; defaults to float16.
            learnable_add: pool hidden states with a learned per-token scalar
                weight instead of a plain sum.
            load_ref_model: either a reference ``nn.Module`` whose embedding
                weights are copied into ``embed_token``, or ``True`` to load
                one from ``ref_model_dir``.
            ref_model_dir: checkpoint path of the reference model (used only
                when ``load_ref_model`` is ``True``).
            ref_model_dtype: dtype for loading the reference model; defaults
                to ``dtype``.
            finetune_embedding: if ``False``, freeze the token embedding.
            word_embed_proj_dim: token-embedding width; projected up to
                ``hidden_size`` when they differ (OPT-style ``project_in``).
            num_attention_heads / num_hidden_layers: encoder geometry.
            max_seq_len: maximum sequence length for positional encodings.
            dropout: encoder dropout probability.
            pad_id: padding token id (also the embedding's ``padding_idx``).
        """
        super().__init__()
        word_embed_proj_dim = word_embed_proj_dim if word_embed_proj_dim is not None else hidden_size

        self.device = device if device is not None else torch.device('cuda')
        self.dtype = dtype if dtype is not None else torch.float16
        self.learnable_add = learnable_add

        self.embed_token = nn.Embedding(vocab_size, word_embed_proj_dim, padding_idx=pad_id)
        self.embed_positions = PositionalEncodding(hidden_size, max_seq_len, dtype=self.dtype, device=self.device)

        # Project token embeddings up to hidden_size when the widths differ.
        if word_embed_proj_dim != hidden_size:
            self.project_in = nn.Linear(word_embed_proj_dim, hidden_size, bias=False)
        else:
            self.project_in = None

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            dim_feedforward=hidden_size * 2,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_hidden_layers)
        if self.learnable_add:
            # Produces one scalar weight per token position for pooling.
            self.la = nn.Linear(hidden_size, 1)
        self.lnorm = nn.LayerNorm(hidden_size, device=self.device, dtype=self.dtype)

        # Optionally initialize the token embedding from a reference LLM:
        # either an already-loaded module, or a checkpoint directory.
        if isinstance(load_ref_model, nn.Module):
            ref_model_state_dict = load_embedding_state_dict(load_ref_model)
            assert ref_model_state_dict is not None, f"Model {ref_model_dir} does not have an Embedding layer."
            self.embed_token.load_state_dict(ref_model_state_dict)
        elif load_ref_model and ref_model_dir is not None:
            ref_model_dtype = ref_model_dtype if ref_model_dtype is not None else self.dtype
            ref_model = get_model(ref_model_dir, ref_model_dtype, 'cpu')
            ref_model_state_dict = load_embedding_state_dict(ref_model)
            assert ref_model_state_dict is not None, f"Model {ref_model_dir} does not have an Embedding layer."
            self.embed_token.load_state_dict(ref_model_state_dict)
        if not finetune_embedding:
            self.embed_token.weight.requires_grad = False

        self.to(self.dtype)
        self.to(self.device)

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` into one embedding per sequence.

        Args:
            input_ids: (batch, seq_len) token ids.
            attention_mask: optional (batch, seq_len) mask, 1/True = real
                token, 0/False = padding. ``None`` means no padding.

        Returns:
            (batch, 1, hidden_size) LayerNorm'ed sentence embeddings.
        """
        _, seq_len = input_ids.shape
        # Convert a 1/0 keep-mask into the bool padding-mask convention used
        # by nn.TransformerEncoder (True = position is padding).
        if attention_mask is not None and attention_mask.dtype is not torch.bool:
            attention_mask = ~attention_mask.to(torch.bool)

        inputs_emb = self.embed_token(input_ids)
        pos_emb = self.embed_positions(seq_len)
        if self.project_in is not None:
            inputs_emb = self.project_in(inputs_emb)
        inputs_emb = inputs_emb + pos_emb

        hidden_state = self.encoder(inputs_emb, src_key_padding_mask=attention_mask)
        # Zero padding positions so they do not contribute to the pooled sum.
        # BUG FIX: previously this line ran unconditionally, and with
        # attention_mask=None the expression `hidden_state[None] = 0` zeroed
        # the ENTIRE hidden state (None indexing adds a dim over the whole
        # tensor). That path is hit when the encoder is called without a mask.
        if attention_mask is not None:
            hidden_state[attention_mask] = 0  # (batch, seq_len, hidden)

        if self.learnable_add:
            # Weighted pooling: one learned scalar weight per token position.
            alpha = self.la(hidden_state)
            sentence_emb = torch.sum(hidden_state * alpha, dim=-2, keepdim=True)
        else:
            sentence_emb = torch.sum(hidden_state, dim=-2, keepdim=True)
        sentence_emb_norm = self.lnorm(sentence_emb)

        return sentence_emb_norm
130 |
--------------------------------------------------------------------------------
/sentence_vae/models/sentence_llm_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | import torch.nn as nn
31 | import torch.nn.functional as F
32 |
33 | from mmengine.model import BaseModel
34 |
35 | from .focal_loss import FocalLoss
36 | from .sentence_vae_model import SentenceVAE
37 | from sentence_vae.utils.llm import get_model
38 |
39 |
class SentenceLLM(BaseModel):
    """Sentence-level language model.

    Pipeline: each sentence's tokens are pooled into one embedding by the
    SentenceVAE encoder; a pretrained token-LLM backbone (positional
    embeddings + decoder layers borrowed from an OPT-style model exposing
    ``model.decoder``) predicts the next sentence embedding; a binary head
    (``fc``) predicts whether generation should continue; and the SentenceVAE
    decoder expands each predicted sentence embedding back into tokens.
    """
    def __init__(
        self,
        svae_hidden_size: int,
        svae_vocab_size: int,
        svae_learnable_add: bool = True,
        svae_load_ref_model: bool = False,
        svae_ref_model_dir: str = None,
        svae_ref_model_dtype: torch.dtype = None,
        svae_finetune_embedding: bool = True,
        svae_word_embed_proj_dim: int = None,
        svae_num_attention_heads: int = 16,
        svae_num_hidden_layers: int = 1,
        svae_model_path: str = None,
        llm_ref_model_dir: str = None,
        llm_ref_model_dtype: torch.dtype = None,
        llm_finetune_layers: int = -1,
        finetune_svae: bool = True,
        max_sentence_len: int = 512,
        max_sentence_num: int = 512,
        dropout: float = 0.1,
        bos_id=2,
        pad_id=1,
        end_id=2,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        """Build the SVAE + LLM-backbone stack.

        ``svae_*`` args are forwarded to :class:`SentenceVAE`;
        ``llm_ref_model_dir`` names the pretrained backbone checkpoint;
        ``llm_finetune_layers``: if >= 0, only the first and last N decoder
        layers stay trainable and the middle layers are frozen;
        ``svae_model_path``: optional SVAE checkpoint to restore.
        """
        super().__init__()
        self.bos_token_id = bos_id
        self.pad_token_id = pad_id
        self.eos_token_id = end_id
        self.hidden_size = svae_hidden_size
        self.llm_finetune_layers = llm_finetune_layers
        self.max_sentence_len = max_sentence_len
        self.max_sentence_num = max_sentence_num

        self.device = device if device is not None else torch.device("cuda")
        self.dtype = dtype if dtype is not None else torch.float16

        self.svae = SentenceVAE(
            hidden_size=svae_hidden_size, vocab_size=svae_vocab_size, device=self.device, dtype=self.dtype,
            learnable_add=svae_learnable_add, load_ref_model=svae_load_ref_model, ref_model_dir=svae_ref_model_dir, ref_model_dtype=svae_ref_model_dtype,
            finetune_embedding=svae_finetune_embedding, word_embed_proj_dim=svae_word_embed_proj_dim,
            num_attention_heads=svae_num_attention_heads, num_hidden_layers=svae_num_hidden_layers, max_seq_len=max_sentence_len, dropout=dropout,
            bos_id=bos_id, pad_id=pad_id, end_id=end_id
        )

        # Reuse only the positional embedding and decoder layers of the
        # pretrained backbone; assumes an OPT-style layout
        # (llm.model.decoder.embed_positions / .layers).
        llm_ref_model_dtype = llm_ref_model_dtype if llm_ref_model_dtype else self.dtype
        llm = get_model(llm_ref_model_dir, llm_ref_model_dtype, self.device, True)
        self.llm_pe = llm.model.decoder.embed_positions
        self.llm_layers = llm.model.decoder.layers

        # Binary stop head over the last hidden state: class 1 = another
        # sentence follows, class 0 = stop (see forward / streaming_generate).
        self.fc = nn.Linear(svae_hidden_size, 2)
        self.focal_loss = FocalLoss()

        if svae_model_path is not None:
            print(f"Loading {svae_model_path}")
            ckpt = torch.load(svae_model_path)
            self.svae.load_state_dict(ckpt['state_dict'])

        self.to(self.dtype)
        self.to(self.device)

        if not finetune_svae:
            self._freeze_model(self.svae)
        # Freeze the middle of the backbone, keeping the outermost
        # llm_finetune_layers layers at each end trainable.
        if llm_finetune_layers >= 0 and llm_finetune_layers * 2 < len(self.llm_layers):
            for i in range(llm_finetune_layers, len(self.llm_layers) - llm_finetune_layers):
                self._freeze_model(self.llm_layers[i])


    def _freeze_model(self, model: nn.Module):
        # Disable gradients for every parameter of `model`.
        for param in model.parameters():
            param.requires_grad = False


    def forward(self, sentence_mask, sentence_toks, tok_mask, mode='loss'):
        """Training/eval forward pass.

        Args:
            sentence_mask: (batch, sen_num) — 1 for real sentences, 0 for padding.
            sentence_toks: (batch, sen_num, seq_len) token ids per sentence.
            tok_mask: (batch, sen_num, seq_len) — 1 for real tokens, 0 for padding.
            mode: 'loss' returns a loss dict; 'predict' returns
                (stop_loss, ppl_loss).
        """
        sentence_mask = sentence_mask.to(self.device)
        sentence_toks = sentence_toks.to(self.device)
        tok_mask = tok_mask.to(self.device)
        batch_size, sen_num, seq_len = sentence_toks.shape

        sentence_embedding = torch.zeros(
            (batch_size, sen_num, self.hidden_size),
            dtype=self.dtype, device=self.device
        )

        # encoder: pool each sentence's tokens into one embedding, then zero
        # out the padded sentence slots.
        for b in range(batch_size):
            sentence_embedding[b] = self.svae.encoder(sentence_toks[b], tok_mask[b]).view(sen_num, self.hidden_size)
            sentence_embedding[b][~sentence_mask[b].bool()] = 0

        # llm: run the sentence-embedding sequence through the backbone.
        pos_emb = self.llm_pe(sentence_mask, 0)
        hidden_state = sentence_embedding + pos_emb
        for layer in self.llm_layers:
            hidden_state = layer(hidden_state)[0]
        stop_flag = self.fc(hidden_state)

        # Stop-flag target at position i is sentence_mask[i+1]: 1 means a
        # real sentence follows, 0 means generation should stop there.
        sentence_pad = torch.zeros((batch_size, 1), dtype=sentence_mask.dtype, device=sentence_mask.device)
        tgt_stop_flag = torch.concat((sentence_mask, sentence_pad), dim=1)[:, 1:]
        sen_lens = torch.sum(sentence_mask, dim=1)
        stop_loss = 0
        for b in range(batch_size):
            sen_len = sen_lens[b]
            stop_loss += self.focal_loss(stop_flag[b, :sen_len], tgt_stop_flag[b, :sen_len])
        stop_loss /= batch_size

        # decoder: teacher-forced reconstruction — hidden state at position i
        # conditions the decoding of sentence i+1.
        hidden_state = hidden_state.view(batch_size, sen_num, 1, self.hidden_size)
        decode_loss = 0
        ppl_loss = 0
        for b in range(batch_size):
            sen_len = sen_lens[b]
            input_ids = sentence_toks[b, 1:sen_len]
            attention_mask = tok_mask[b, :sen_len-1]
            sentence_embd = hidden_state[b, :sen_len-1]
            output = self.svae.decoder(input_ids, sentence_embd, attention_mask)
            # Target = input shifted left by one, with EOS scattered in at
            # each sentence's true length (pad column appended first so the
            # scatter index is always in range).
            pad_ids = torch.zeros((sen_len-1, 1), device=self.device, dtype=input_ids.dtype).fill_(self.pad_token_id)
            tgt_ids = torch.cat((input_ids, pad_ids), dim=1)
            seq_lens = torch.sum(attention_mask, dim=1, keepdim=True)
            tgt_ids.scatter_(1, seq_lens, self.eos_token_id)
            attention_mask = attention_mask.bool()
            decode_loss += self.focal_loss(output[attention_mask], tgt_ids[:, 1:][attention_mask])
            # Cross-entropy kept separately for perplexity reporting.
            ppl_loss += F.cross_entropy(output[attention_mask], tgt_ids[:, 1:][attention_mask])
            # Free per-sentence intermediates early to limit peak memory.
            del input_ids
            del attention_mask
            del sentence_embd
            del output
            del pad_ids
            del tgt_ids
        decode_loss /= batch_size
        ppl_loss /= batch_size

        if mode == 'loss':
            # Once the stop head is essentially solved, train only the decoder.
            if stop_loss < 1e-2:
                return {"decode_loss": decode_loss}
            return {"stop_loss": stop_loss, "decode_loss": decode_loss}
        elif mode == 'predict':
            return stop_loss, ppl_loss


    def streaming_generate(self, sentence_mask, sentence_toks, tok_mask, max_sentence_num=64, token_level=False):
        """Generator yielding one decoded sentence (list of token ids) at a
        time, continuing until the stop head predicts class 0 or
        ``max_sentence_num`` sentences have been processed. Batch size must
        be 1. ``token_level`` is currently unused.
        """
        batch_size, sen_num, seq_len = sentence_toks.shape
        assert batch_size == 1

        # Trim the prompt to its real number of sentences.
        sentence_mask = sentence_mask[:batch_size].to(self.device)
        sen_num = torch.sum(sentence_mask)
        sentence_mask = sentence_mask[:batch_size, :sen_num]
        sentence_toks = sentence_toks[:batch_size, :sen_num].to(self.device)
        tok_mask = tok_mask[:batch_size, :sen_num].to(self.device)

        sentence_embedding = torch.zeros(
            (batch_size, sen_num, self.hidden_size),
            dtype=self.dtype, device=self.device
        )

        # encoder: embed every prompt sentence.
        for b in range(batch_size):
            sentence_embedding[b] = self.svae.encoder(sentence_toks[b], tok_mask[b]).view(sen_num, self.hidden_size)
            sentence_embedding[b][~sentence_mask[b].bool()] = 0

        # llm: prefill pass over the prompt, caching per-layer KV states.
        pos_emb = self.llm_pe(sentence_mask, 0)
        hidden_state = sentence_embedding + pos_emb
        past_key_values = []
        past_key_values_length = sen_num
        for layer in self.llm_layers:
            hidden_state, kv_cache = layer(hidden_state, use_cache=True)
            past_key_values.append(kv_cache)

        while past_key_values_length < max_sentence_num:
            # Stop head on the last position: class 0 = stop generating.
            stop_flag = torch.argmax(self.fc(hidden_state[:batch_size, -1:, :]), dim=-1)
            if stop_flag.item() == 0:
                break
            new_sentence = self.svae.decoder.generate(
                hidden_state[:batch_size, -1:, :],
                self.max_sentence_len,
                self.bos_token_id,
                self.eos_token_id
            )

            yield new_sentence

            # Re-encode the generated sentence and advance the backbone by
            # one position, reusing the cached KV states.
            # NOTE(review): the encoder is called without an attention mask
            # here — verify the encoder handles a missing mask correctly.
            new_sentence = torch.tensor([new_sentence], dtype=torch.long, device=self.device)
            new_sentence_embedding = self.svae.encoder(new_sentence)
            pos_emb = self.llm_pe(torch.ones((1, past_key_values_length + 1), dtype=torch.long, device=self.device), past_key_values_length)
            past_key_values_length += 1
            hidden_state = new_sentence_embedding + pos_emb
            for i, layer in enumerate(self.llm_layers):
                hidden_state, kv_cache = layer(hidden_state, past_key_value=past_key_values[i], use_cache=True)
                past_key_values[i] = kv_cache
231 |
--------------------------------------------------------------------------------
/sentence_vae/models/sentence_vae_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 |
31 | from mmengine.model import BaseModel
32 |
33 | from .focal_loss import FocalLoss
34 | from .sentence_encoder import SentenceEncoder
35 | from .sentence_decoder import SentenceDecoder
36 | from sentence_vae.utils import get_model
37 |
38 |
class SentenceVAE(BaseModel):
    """Sentence auto-encoder: compresses a token sequence to a single
    sentence embedding (:class:`SentenceEncoder`) and reconstructs the
    tokens from it (:class:`SentenceDecoder`). Trained with a focal
    reconstruction loss.
    """
    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        device: torch.device = None,
        dtype: torch.dtype = None,
        learnable_add: bool = True,
        load_ref_model: bool = False,
        ref_model_dir: str = None,
        ref_model_dtype: torch.dtype = None,
        finetune_embedding: bool = True,
        word_embed_proj_dim: int = None,
        num_attention_heads: int = 16,
        num_hidden_layers: int = 1,
        max_seq_len: int = 512,
        dropout: float = 0.1,
        bos_id=2,
        pad_id=1,
        end_id=2,
    ):
        """Most arguments are forwarded verbatim to the encoder and decoder;
        ``load_ref_model=True`` with ``ref_model_dir`` loads a reference LLM
        once (on CPU) whose embedding weights seed both submodules.
        """
        super().__init__()
        self.bos_token_id = bos_id
        self.pad_token_id = pad_id
        self.eos_token_id = end_id
        self.max_seq_len = max_seq_len

        self.device = device if device is not None else torch.device("cuda")
        self.dtype = dtype if dtype is not None else torch.float16

        # Re-bind load_ref_model to the loaded module (or False) so both the
        # encoder and decoder share one reference model instance.
        if load_ref_model and ref_model_dir is not None:
            ref_model_dtype = ref_model_dtype if ref_model_dtype is not None else self.dtype
            load_ref_model = get_model(ref_model_dir, ref_model_dtype, 'cpu')
        else:
            load_ref_model = False

        self.encoder = SentenceEncoder(
            hidden_size, vocab_size, device, dtype,
            learnable_add, load_ref_model, ref_model_dir, ref_model_dtype,
            finetune_embedding, word_embed_proj_dim,
            num_attention_heads, num_hidden_layers, max_seq_len,
            dropout, pad_id
        )
        self.decoder = SentenceDecoder(
            hidden_size, vocab_size, device, dtype,
            load_ref_model, ref_model_dir, ref_model_dtype,
            finetune_embedding, word_embed_proj_dim,
            num_attention_heads, num_hidden_layers, max_seq_len,
            dropout, pad_id
        )

        self.focal_loss = FocalLoss()

    def forward(self, input_ids, attention_mask=None, mode='loss'):
        """Encode then reconstruct ``input_ids``.

        mode='loss' returns {'total_loss': focal loss}; mode='predict'
        returns (logits, bool mask, shifted targets); anything else returns
        the raw decoder logits.
        """
        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        else:
            # No mask supplied: treat every position as a real token.
            attention_mask = torch.ones(input_ids.shape, dtype=torch.int64, device=self.device)

        sentence_embd = self.encoder(input_ids, attention_mask)

        output = self.decoder(input_ids, sentence_embd, attention_mask)
        # Target = input shifted left by one, with EOS scattered in at each
        # sequence's true length (pad column appended so the index is in range).
        batch, _ = input_ids.shape
        pad_ids = torch.zeros((batch, 1), device=self.device, dtype=input_ids.dtype).fill_(self.pad_token_id)
        tgt_ids = torch.cat((input_ids, pad_ids), dim=1)
        seq_lens = torch.sum(attention_mask, dim=1, keepdim=True)
        tgt_ids.scatter_(1, seq_lens, self.eos_token_id)
        attention_mask = attention_mask.bool()
        if mode == 'loss':
            # Loss is computed over real (non-pad) positions only.
            loss = self.focal_loss(output[attention_mask], tgt_ids[:, 1:][attention_mask])
            return {'total_loss': loss}
        elif mode == 'predict':
            return output, attention_mask, tgt_ids
        else:
            return output

    def streaming_generate(self, input_ids, attention_mask=None, max_output_len=64):
        """Generator: encode the (batch of 1) input, then yield reconstructed
        token ids one at a time from the decoder.
        """
        batch_size = input_ids.size(0)
        assert batch_size == 1

        input_ids = input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        else:
            attention_mask = torch.ones(input_ids.shape, dtype=torch.int64, device=self.device)

        sentence_embd = self.encoder(input_ids, attention_mask)

        for output_id in self.decoder.streaming_generate(
            sentence_embd,
            max_output_len,
            self.bos_token_id,
            self.eos_token_id
        ):
            yield output_id

    def generate(self, input_ids, attention_mask=None, max_output_len=64):
        """Non-streaming wrapper: collect streaming_generate's output into a
        plain list of ints.
        """
        output_ids = []
        for output_id in self.streaming_generate(input_ids, attention_mask, max_output_len):
            output_ids.append(output_id.item())
        return output_ids
--------------------------------------------------------------------------------
/sentence_vae/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | from .http import fetch_text_with_retry
30 | from .llm import get_model, get_tokenizer, get_config, get_dtype
31 | from .weights import init_model_weights, load_embedding_state_dict
32 | from .config import load_yaml
33 |
--------------------------------------------------------------------------------
/sentence_vae/utils/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import os
30 |
31 |
def recursive_update(d, u):
    """Recursively merge dict *u* into dict *d* (in place) and return *d*.

    Nested dicts are merged key-by-key; scalar values in *u* overwrite.
    Robustness fix: if *u* holds a dict where *d* holds a non-dict, the
    non-dict value is replaced by the merged dict instead of raising
    (previously this crashed on e.g. recursive_update({"a": 5}, {"a": {...}})).
    """
    for k, v in u.items():
        if isinstance(v, dict):
            base = d.get(k)
            # Only merge into an existing dict; otherwise start fresh.
            d[k] = recursive_update(base if isinstance(base, dict) else {}, v)
        else:
            d[k] = v
    return d
39 |
40 |
def load_yaml(yaml_path):
    """Load a YAML config, resolving ``!include`` tags relative to the file
    and folding any chain of ``__base__`` parent configs (child keys win).
    """
    import yaml
    import yaml_include

    base_dir = os.path.dirname(os.path.abspath(yaml_path))
    yaml.add_constructor("!include", yaml_include.Constructor(base_dir=base_dir))

    with open(yaml_path, "r") as stream:
        config = yaml.full_load(stream)

    # Each "__base__" entry is a parent config; merge the child over it,
    # repeating while the merged result still declares a base.
    while "__base__" in config:
        parent = config.pop("__base__")
        config = recursive_update(parent, config)

    return config
52 |
--------------------------------------------------------------------------------
/sentence_vae/utils/http.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import requests
30 | from time import sleep
31 |
32 |
def fetch_text_with_retry(url, retries=3, timeout=10, backoff_factor=0.3):
    """
    Fetch text from the given URL with retries on failure.

    Parameters:
        url (str): The URL to fetch the text from.
        retries (int): Number of retries before giving up.
        timeout (int): Timeout in seconds for the HTTP request.
        backoff_factor (float): Factor for exponential backoff between retries.

    Returns:
        str: The fetched text (decoded as UTF-8), or a fixed fallback
        greeting string if every attempt fails — callers always receive
        usable text, never None.
    """
    attempt = 0

    while attempt < retries:
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            # Force UTF-8 decoding regardless of the response headers.
            response.encoding = 'utf-8'
            return response.text
        except requests.exceptions.RequestException as e:
            attempt += 1
            # Exponential backoff: 0.6s, 1.2s, 2.4s, ... at the default factor.
            wait = backoff_factor * (2 ** attempt)
            # print(f"Attempt {attempt} failed: {e}. Retrying in {wait:.1f} seconds...")
            sleep(wait)

    print(f"All attempts to fetch the URL have failed when load {url}.")
    # Deliberate best-effort fallback: return placeholder text rather than
    # propagating the failure, so downstream dataset code keeps running.
    return "Hello, welcome to TeleDS! Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN), Northwestern PolyTechnical University, and Institute of Artificial Intelligence (TeleAI), China Telecom."
62 |
--------------------------------------------------------------------------------
/sentence_vae/utils/llm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
31 |
32 |
def get_config(ckpt_path):
    """Load and return the HuggingFace model config stored at *ckpt_path*."""
    print(f"Loading config from {ckpt_path}")
    return AutoConfig.from_pretrained(ckpt_path)
37 |
38 |
def get_tokenizer(ckpt_path, max_seq_len):
    """Build a right-padding tokenizer from *ckpt_path*, capped at
    *max_seq_len* tokens per sequence.
    """
    print(f"Initializaing tokenizer from {ckpt_path}")
    return AutoTokenizer.from_pretrained(
        ckpt_path,
        model_max_length=max_seq_len,
        padding_side="right",
        trust_remote_code=True,
    )
47 | return tokenizer
48 |
49 |
def get_dtype(dtype):
    """Normalize a dtype spec — an actual ``torch.dtype`` or a string alias
    ("bf16"/"bfloat16", "fp16"/"float16", "fp32"/"float32") — to a
    ``torch.dtype``. Raises NotImplementedError for anything else.
    """
    if isinstance(dtype, torch.dtype):
        return dtype

    if dtype in ("bf16", "bfloat16"):
        return torch.bfloat16
    if dtype in ("fp16", "float16"):
        return torch.float16
    if dtype in ("fp32", "float32"):
        return torch.float32

    raise NotImplementedError(f"Unknown dtype {dtype}")
64 |
65 |
def is_model_on_gpu(model) -> bool:
    """Returns if the model is fully loaded on GPUs."""
    for param in model.parameters():
        if "cuda" not in str(param.device):
            return False
    return True
69 |
70 |
def get_model(ckpt_path, dtype="fp16", device="cuda", dist=False):
    """Load a causal LM checkpoint in eval mode.

    *dtype* accepts anything get_dtype understands; *device* is "cuda" or
    "cpu"; *dist* pins the whole model to *device* (for distributed runs)
    and takes precedence over the cpu/auto placement.
    """
    print(f"Initializaing model from {ckpt_path}")
    dtype = get_dtype(dtype)

    # Placement: dist pins everything to the given device; plain cpu maps to
    # cpu; otherwise let HF shard across available devices.
    if dist:
        device_map = {"": device}
    elif device == "cpu":
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        ckpt_path, device_map=device_map, trust_remote_code=True, torch_dtype=dtype
    )
    model.eval()

    if device == "cuda" and not is_model_on_gpu(model):
        print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

    return model
--------------------------------------------------------------------------------
/sentence_vae/utils/weights.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch.nn as nn
30 |
31 |
def init_model_weights(model: nn.Module, std: float = 1e-5):
    """Re-initialize every Linear/Embedding in *model* with N(0, std) weights;
    Linear biases and the Embedding row at padding_idx are zeroed.
    """
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0, std=std)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
        if isinstance(module, nn.Embedding) and module.padding_idx is not None:
            # Keep the padding embedding at exactly zero.
            module.weight.data[module.padding_idx].zero_()
42 |
43 |
def load_embedding_state_dict(model: nn.Module):
    """Return the state dict of the first ``nn.Embedding`` found in *model*,
    or ``None`` when the model contains no embedding layer."""
    embeddings = (m for m in model.modules() if isinstance(m, nn.Embedding))
    first = next(embeddings, None)
    return None if first is None else first.state_dict()
49 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import re
30 | import setuptools
31 |
32 |
def get_install_requirements():
    """Read requirements.txt and return the list of dependency specifiers.

    Comment lines (starting with '#') and blank lines are skipped; the
    original kept blank lines, which produced empty-string entries in
    ``install_requires``.
    """
    with open("requirements.txt", "r", encoding="utf-8") as f:
        reqs = [x.strip() for x in f.read().splitlines()]
    return [x for x in reqs if x and not x.startswith("#")]
38 |
39 |
def get_sentence_vae_version():
    """Extract ``__version__`` from ``sentence_vae/__init__.py``.

    Opens the file with an explicit UTF-8 encoding so the result does not
    depend on the platform's default locale encoding.

    Raises:
        AttributeError: if no ``__version__`` assignment is found.
    """
    with open("sentence_vae/__init__.py", "r", encoding="utf-8") as f:
        version = re.search(
            r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
            f.read(), re.MULTILINE
        ).group(1)
    return version
47 |
48 |
def get_long_description():
    """Return the contents of README.md, used as the PyPI long description."""
    with open("README.md", "r", encoding="utf-8") as f:
        return f.read()
53 |
54 |
setuptools.setup(
    name="sentence_vae",
    version=get_sentence_vae_version(),
    author="Coder.AN",
    url="https://github.com/BestAnHongjun/SentenceVAE",
    # BUGFIX: exclude must be a tuple/list of patterns; the original
    # ("tools") is a plain string, which find_packages iterates
    # per-character ("t", "o", ...) and therefore excludes nothing.
    # "tools.*" additionally drops sub-packages of tools/.
    packages=setuptools.find_packages(exclude=("tools", "tools.*")),
    python_requires=">=3.8",
    install_requires=get_install_requirements(),
    setup_requires=["wheel"],  # avoid building error when pip is not updated
    long_description=get_long_description(),
    long_description_content_type="text/markdown",
    include_package_data=True,  # include files in MANIFEST.in
    classifiers=[
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
        # BUGFIX: the official trove classifier is "MIT License";
        # "MIT Software License" is not a valid classifier name.
        "License :: OSI Approved :: MIT License",
    ],
    project_urls={
        "paper": "https://arxiv.org/abs/2408.00655",
        "Source": "https://github.com/BestAnHongjun/SentenceVAE",
    },
)
--------------------------------------------------------------------------------
/tools/analysis/analysis_llm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | import argparse
31 | from mmengine.analysis import get_model_complexity_info
32 | from sentence_vae.utils import get_model
33 |
34 |
def make_parser():
    """Return the command-line parser for LLM complexity analysis."""
    p = argparse.ArgumentParser("SentenceVAE analysis parser.")
    p.add_argument("--model_dir", type=str, required=True)
    p.add_argument("--model_dtype", type=str, default="fp16")
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--max_seq_len", type=int, default=1024)
    p.add_argument("--batch_size", type=int, default=1)
    return p
43 |
44 |
def main(args):
    """Load an LLM checkpoint and print its complexity (params/FLOPs) table."""
    device = torch.device(args.device)
    model = get_model(args.model_dir, args.model_dtype, args.device).eval()
    # Dummy all-ones token ids are sufficient for a complexity sweep.
    dummy_ids = torch.ones(
        (args.batch_size, args.max_seq_len), dtype=torch.long, device=device
    )
    report = get_model_complexity_info(model, inputs=dummy_ids)
    print(report['out_table'])


if __name__ == "__main__":
    main(make_parser().parse_args())
55 |
--------------------------------------------------------------------------------
/tools/analysis/analysis_sllm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | import argparse
31 |
32 | from mmengine.analysis import get_model_complexity_info
33 |
34 | from sentence_vae.utils import get_config, get_dtype, load_yaml
35 | from sentence_vae.models import SentenceLLM
36 |
37 |
def make_parser():
    """Return the command-line parser for SentenceLLM complexity analysis.

    The parser prog string previously read "train parser" although this
    script only analyses model complexity.
    """
    parser = argparse.ArgumentParser("SentenceLLM analysis parser.")
    parser.add_argument("-c", "--config", type=str, required=True)
    parser.add_argument("--device", type=str, default='cuda')
    # Number of accelerator cards; main() multiplies batch_size by this.
    parser.add_argument("--cards", type=int, default=4)
    return parser
44 |
45 |
def main(args):
    """Instantiate a SentenceLLM from a YAML config and print its complexity table."""
    cfg = load_yaml(args.config)
    ref_cfg = get_config(cfg["svae"]["ref_model_dir"])
    svae_cfg, llm_cfg = cfg["svae"], cfg["llm"]

    model = SentenceLLM(
        svae_hidden_size=ref_cfg.hidden_size,
        svae_vocab_size=ref_cfg.vocab_size,
        svae_learnable_add=svae_cfg["learnable_add"],
        svae_load_ref_model=svae_cfg["load_ref_model"],
        svae_ref_model_dir=svae_cfg["ref_model_dir"],
        svae_ref_model_dtype=ref_cfg.torch_dtype,
        svae_finetune_embedding=svae_cfg["finetune_embedding"],
        svae_word_embed_proj_dim=ref_cfg.word_embed_proj_dim,
        svae_num_attention_heads=ref_cfg.num_attention_heads,
        svae_num_hidden_layers=svae_cfg["num_hidden_layers"],
        svae_model_path=svae_cfg["model_path"],
        llm_ref_model_dir=llm_cfg["ref_model_dir"],
        llm_ref_model_dtype=ref_cfg.torch_dtype,
        llm_finetune_layers=llm_cfg["finetune_layers"],
        finetune_svae=cfg["finetune_svae"],
        max_sentence_len=cfg["max_sen_len"],
        max_sentence_num=cfg["max_sen_num"],
        dropout=ref_cfg.dropout,
        bos_id=ref_cfg.bos_token_id,
        pad_id=ref_cfg.pad_token_id,
        end_id=ref_cfg.eos_token_id,
        # NOTE(review): the model is placed on cfg["device"] while the dummy
        # inputs below use args.device; this mirrors the original behaviour.
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"])
    ).eval()

    n = cfg["batch_size"] * args.cards
    dev = torch.device(args.device)
    shape3 = (n, cfg["max_sen_num"], cfg["max_sen_len"])
    sen_mask = torch.ones((n, cfg["max_sen_num"]), dtype=torch.long, device=dev)
    sen_toks = torch.ones(shape3, dtype=torch.long, device=dev)
    tok_mask = torch.ones(shape3, dtype=torch.long, device=dev)

    report = get_model_complexity_info(model, inputs=(sen_mask, sen_toks, tok_mask))
    print(report['out_table'])


if __name__ == "__main__":
    main(make_parser().parse_args())
88 |
--------------------------------------------------------------------------------
/tools/analysis/analysis_svae.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | import argparse
31 |
32 | from mmengine.analysis import get_model_complexity_info
33 |
34 | from sentence_vae.utils import get_config, get_dtype, load_yaml
35 | from sentence_vae.models import SentenceVAE
36 |
37 |
def make_parser():
    """Return the command-line parser for SentenceVAE complexity analysis.

    The parser prog string previously read "SentenceLLM train parser"
    although this script analyses a SentenceVAE model.
    """
    parser = argparse.ArgumentParser("SentenceVAE analysis parser.")
    parser.add_argument("-c", "--config", type=str, required=True)
    parser.add_argument("--device", type=str, default='cuda')
    # Number of accelerator cards; main() multiplies batch_size by this.
    parser.add_argument("--cards", type=int, default=1)
    return parser
44 |
45 |
def main(args):
    """Instantiate a SentenceVAE from a YAML config and print its complexity table."""
    cfg = load_yaml(args.config)
    ref_cfg = get_config(cfg["ref_model_dir"])

    raw_ref_dtype = cfg["ref_model_dtype"]
    model = SentenceVAE(
        hidden_size=ref_cfg.hidden_size,
        vocab_size=ref_cfg.vocab_size,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"]),
        learnable_add=cfg["learnable_add"],
        load_ref_model=cfg["load_ref_model"],
        ref_model_dir=cfg["ref_model_dir"],
        ref_model_dtype=None if raw_ref_dtype is None else get_dtype(raw_ref_dtype),
        finetune_embedding=cfg["finetune_embedding"],
        num_attention_heads=ref_cfg.num_attention_heads,
        num_hidden_layers=cfg["num_hidden_layers"],
        max_seq_len=cfg["max_seq_len"],
        dropout=ref_cfg.dropout,
        bos_id=ref_cfg.bos_token_id,
        pad_id=ref_cfg.pad_token_id,
        end_id=ref_cfg.eos_token_id
    ).eval()

    n = cfg["batch_size"] * args.cards
    dev = torch.device(args.device)
    shape = (n, cfg["max_seq_len"])
    sentences = torch.ones(shape, dtype=torch.long, device=dev)
    sentence_mask = torch.ones(shape, dtype=torch.long, device=dev)

    report = get_model_complexity_info(model, inputs=(sentences, sentence_mask))
    print(report['out_table'])


if __name__ == "__main__":
    main(make_parser().parse_args())
80 |
--------------------------------------------------------------------------------
/tools/analysis/statistic_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import argparse
30 | from tqdm import tqdm
31 |
32 | import torch
33 | from torch.utils.data import DataLoader
34 |
35 | from mmengine.dataset import DefaultSampler
36 |
37 | from sentence_vae.utils import get_tokenizer
38 | from sentence_vae.data import TeleDSDataset, SentenceCollate, PassageCollate
39 |
40 |
def make_parser():
    """Return the command-line parser for the dataset token-statistics tool."""
    p = argparse.ArgumentParser("SentenceVAE statistic parser.")
    p.add_argument("--tokenizer_dir", type=str)
    p.add_argument("--max_sen_len", type=int, default=64)
    p.add_argument("--max_sen_num", type=int, default=64)
    p.add_argument("--mode", choices=["sentence", "passage"], default='sentence')
    p.add_argument("--card_size", type=int, default=4)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--teleds_ip", type=str, default="127.0.0.1")
    p.add_argument("--teleds_port", type=int, default=8000)
    p.add_argument("--num_workers", type=int, default=16)
    p.add_argument("--prefetch_factor", type=int, default=5)
    p.add_argument("--max_iters", type=int, default=300000)
    return p
55 |
56 |
def main(args):
    """Count the total number of tokens served by a TeleDS dataset.

    Iterates at most ``args.max_iters`` batches and sums attention-mask
    entries (each 1 marks one real token); in "passage" mode only tokens
    belonging to valid (non-padded) sentences are counted.
    """
    tokenizer = get_tokenizer(ckpt_path=args.tokenizer_dir, max_seq_len=args.max_sen_len)

    dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port)
    sampler = DefaultSampler(dataset, shuffle=False)
    if args.mode == "sentence":
        collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=args.max_sen_len, padding=True)
    else:
        collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=args.max_sen_len, max_sentence_num=args.max_sen_num, padding=True)

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size * args.card_size,
        sampler=sampler,
        collate_fn=collate_fn,
        num_workers=args.num_workers,
        prefetch_factor=args.prefetch_factor
    )

    all_tokens = 0
    total_len = min(len(dataloader), args.max_iters)

    pbar = tqdm(range(total_len))
    for i, data in enumerate(dataloader):
        if i >= total_len:
            break
        if args.mode == 'sentence':
            input_ids, attention_mask = data
            all_tokens += torch.sum(attention_mask).item()
        else:
            batch_sentence_mask, batch_sentence_toks, batch_tok_mask = data
            batch_size = batch_sentence_mask.size(0)
            for b in range(batch_size):
                # BUGFIX: the sentence mask is an integer (0/1) tensor;
                # indexing with it directly performs integer gather (rows 0
                # and 1 over and over) instead of boolean row selection, so
                # convert it to bool first.
                attention_mask = batch_tok_mask[b][batch_sentence_mask[b].bool()]
                all_tokens += torch.sum(attention_mask).item()
        pbar.set_postfix(tokens=all_tokens)
        pbar.update(1)
    pbar.close()
    print("Total tokens:", all_tokens)


if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
107 |
--------------------------------------------------------------------------------
/tools/analysis/statistic_vocab.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | import argparse
31 |
32 | from sentence_vae.utils import get_model, get_tokenizer
33 | from sentence_vae.data import SentenceCollate
34 |
35 |
def make_parser():
    """Return the command-line parser for the vocabulary-statistics tool."""
    p = argparse.ArgumentParser("SentenceVAE analysis parser.")
    p.add_argument("--model_dir", type=str, required=True)
    p.add_argument("--max_seq_len", type=int, default=1024)
    return p
41 |
42 |
def main(args):
    """Report the mean UTF-8 byte length of tokens in the model's vocabulary."""
    tokenizer = get_tokenizer(ckpt_path=args.model_dir, max_seq_len=args.max_seq_len)
    vocab = tokenizer.get_vocab()

    total_bytes = sum(len(token.encode('utf-8')) for token in vocab)
    average_bytes = total_bytes / len(vocab)
    print(f"Average bytes per token: {average_bytes:.2f}")


if __name__ == "__main__":
    main(make_parser().parse_args())
59 |
--------------------------------------------------------------------------------
/tools/benchmark/benchmark_llm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import time
30 | import argparse
31 | import numpy as np
32 | from tqdm import tqdm
33 |
34 | import torch
35 | from torch.utils.data import DataLoader
36 |
37 | from mmengine.dataset import DefaultSampler
38 |
39 | from sentence_vae.utils import get_model, get_tokenizer
40 | from sentence_vae.data import TeleDSDataset, SentenceCollate
41 |
42 |
def make_parser():
    """Return the command-line parser for the LLM benchmark tool."""
    p = argparse.ArgumentParser("SentenceVAE benchmark parser.")
    p.add_argument("--server", type=str, default="127.0.0.1")
    p.add_argument("--port", type=int, default=8001)
    p.add_argument("--model_dir", type=str, required=True)
    p.add_argument("--model_dtype", type=str, default="fp16")
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("--max_seq_len", type=int, default=1024)
    p.add_argument("--max_eval_samples", type=int, default=10)
    return p
53 |
54 |
def main(args):
    """Benchmark prefill and decode throughput of a causal LLM.

    The first half of each evaluation sample's real tokens is used as the
    prompt; the model then decodes greedily (argmax) until EOS or
    ``args.max_seq_len``.  Prints a linear fit of GPU memory versus
    sequence length and the input/output token throughput.
    """
    model = get_model(args.model_dir, args.model_dtype, args.device).eval()
    tokenizer = get_tokenizer(ckpt_path=args.model_dir, max_seq_len=args.max_seq_len)

    eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True, eval_samples=args.max_eval_samples)
    eval_sampler = DefaultSampler(eval_dataset, shuffle=False)
    eval_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=args.max_seq_len, padding=True, fix_len=False)

    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=1,
        sampler=eval_sampler,
        collate_fn=eval_collate_fn,
        num_workers=8,
        prefetch_factor=20
    )

    input_toks, output_toks = 0, 0
    input_times, output_times = 0, 0
    mem_usage = []  # [seq_len, allocated KiB] samples for the linear fit below
    device = torch.device(args.device)

    with torch.no_grad():
        for data in tqdm(eval_dataloader):
            input_ids, attention_mask = data
            # Use the first half of the real (unmasked) tokens as the prompt.
            seq_len = torch.sum(attention_mask) // 2
            if seq_len < 1:
                continue
            input_ids = input_ids[:1, :seq_len].to(device)

            # Prefill phase
            start_time = time.perf_counter()
            output = model(input_ids)
            input_time = time.perf_counter()

            input_toks += seq_len
            input_times += input_time - start_time

            # Decode phase: greedy argmax, reusing the KV cache each step.
            while True:
                logits = output.logits
                past_key_values = output.past_key_values
                new_id = torch.argmax(logits[:1, -1:], dim=-1)
                input_ids = torch.concat((input_ids, new_id), dim=1)
                if new_id.item() == tokenizer.eos_token_id:
                    break
                if input_ids.size(1) >= args.max_seq_len:
                    break
                output = model(new_id, past_key_values=past_key_values)
                mem_usage.append([input_ids.size(1), torch.cuda.memory_allocated() / 1024])

            output_time = time.perf_counter()
            output_toks += input_ids.size(1) - seq_len
            output_times += output_time - input_time

    # Linear fit: memory(KiB) ~= k * tokens + b.
    mem_usage = np.array(mem_usage)
    k, b = np.polyfit(mem_usage[:, 0], mem_usage[:, 1], 1)
    print(f"显存占用: {b / 1024:.2f}MB + {k:.2f}KB/token")
    print(f"Input: {input_toks / input_times:.2f} tokens/s")
    print(f"Output: {output_toks / output_times:.2f} tokens/s")


if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
120 |
--------------------------------------------------------------------------------
/tools/benchmark/benchmark_sllm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import os
30 | import time
31 | import argparse
32 | import numpy as np
33 | from tqdm import tqdm
34 |
35 | import torch
36 | from torch.utils.data import DataLoader
37 |
38 | from mmengine.dataset import DefaultSampler
39 |
40 | from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml
41 | from sentence_vae.models import SentenceLLM
42 | from sentence_vae.data import TeleDSDataset, PassageCollate
43 |
44 |
def make_parser():
    """Return the command-line parser for the SentenceLLM benchmark tool."""
    p = argparse.ArgumentParser("SentenceLLM benchmark parser.")
    p.add_argument("--server", type=str, default="127.0.0.1")
    p.add_argument("--port", type=int, default=8001)
    p.add_argument("--device", type=str, default="cuda")
    p.add_argument("-c", "--config", type=str, required=True)
    p.add_argument("--checkpoint", type=str, default=None)
    p.add_argument("--max_eval_samples", type=int, default=10)
    return p
54 |
55 |
56 | def main(args):
57 | cfg = load_yaml(args.config)
58 | expn = os.path.splitext(os.path.basename(args.config))[0]
59 | svae_ref_model_cfg = get_config(cfg["svae"]["ref_model_dir"])
60 |
61 | model = SentenceLLM(
62 | svae_hidden_size=svae_ref_model_cfg.hidden_size,
63 | svae_vocab_size=svae_ref_model_cfg.vocab_size,
64 | svae_learnable_add=cfg["svae"]["learnable_add"],
65 | svae_load_ref_model=cfg["svae"]["load_ref_model"],
66 | svae_ref_model_dir=cfg["svae"]["ref_model_dir"],
67 | svae_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
68 | svae_finetune_embedding=cfg["svae"]["finetune_embedding"],
69 | svae_word_embed_proj_dim=svae_ref_model_cfg.word_embed_proj_dim,
70 | svae_num_attention_heads=svae_ref_model_cfg.num_attention_heads,
71 | svae_num_hidden_layers=cfg["svae"]["num_hidden_layers"],
72 | svae_model_path=cfg["svae"]["model_path"],
73 | llm_ref_model_dir=cfg["llm"]["ref_model_dir"],
74 | llm_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
75 | llm_finetune_layers=cfg["llm"]["finetune_layers"],
76 | finetune_svae=cfg["finetune_svae"],
77 | max_sentence_len=cfg["max_sen_len"],
78 | max_sentence_num=cfg["max_sen_num"],
79 | dropout=svae_ref_model_cfg.dropout,
80 | bos_id=svae_ref_model_cfg.bos_token_id,
81 | pad_id=svae_ref_model_cfg.pad_token_id,
82 | end_id=svae_ref_model_cfg.eos_token_id,
83 | device=torch.device(cfg["device"]),
84 | dtype=get_dtype(cfg["dtype"])
85 | )
86 |
87 | exp_dir = f"exp/SentenceVAE-{expn}"
88 | ckpt_list = os.listdir(exp_dir)
89 | ckpt = args.checkpoint
90 | if ckpt is None:
91 | for ckpt_path in ckpt_list:
92 | if "best" in ckpt_path:
93 | ckpt_path = os.path.join(exp_dir, ckpt_path)
94 | ckpt = torch.load(ckpt_path)['state_dict']
95 | assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}."
96 | else:
97 | assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found."
98 | ckpt = torch.load(ckpt)["state_dict"]
99 | model.load_state_dict(ckpt)
100 | model.eval()
101 |
102 | tokenizer = get_tokenizer(ckpt_path=cfg["svae"]["ref_model_dir"], max_seq_len=cfg["max_sen_len"])
103 |
104 | eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True, eval_samples=args.max_eval_samples)
105 | eval_sampler = DefaultSampler(eval_dataset, shuffle=False)
106 | eval_collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True)
107 |
108 | eval_dataloader = DataLoader(
109 | dataset=eval_dataset,
110 | batch_size=1,
111 | sampler=eval_sampler,
112 | collate_fn=eval_collate_fn,
113 | num_workers=cfg["dataloader_num_workers"],
114 | prefetch_factor=cfg["dataloader_prefetch_factor"]
115 | )
116 |
117 | input_toks, output_toks = 0, 0
118 | input_times, output_times = 0, 0
119 | mem_usage = []
120 | device = torch.device(args.device)
121 |
122 | with torch.no_grad():
123 | for data in tqdm(eval_dataloader):
124 | batch_sentence_mask, batch_sentence_toks, batch_tok_mask = data
125 | sentence_num = torch.sum(batch_sentence_mask).item() // 2
126 | if sentence_num < 1:
127 | continue
128 | seq_len = torch.sum(batch_tok_mask[0, :sentence_num]).item()
129 | input_toks += seq_len
130 |
131 | batch_sentence_mask = batch_sentence_mask[:1, :seq_len].to(device)
132 | batch_sentence_toks = batch_sentence_toks[:1, :seq_len, :].to(device)
133 | batch_tok_mask = batch_tok_mask[:1, :seq_len, :].to(device)
134 |
135 | # 预填充阶段
136 | tokens_i = seq_len
137 | input_time = -1
138 | start_time = time.perf_counter()
139 |
140 | for i, new_tokens in enumerate(model.streaming_generate(
141 | batch_sentence_mask,
142 | batch_sentence_toks,
143 | batch_tok_mask
144 | )):
145 | if i == 0:
146 | input_time = time.perf_counter()
147 | tokens_i += len(new_tokens)
148 | mem_usage.append([tokens_i, torch.cuda.memory_allocated() / 1024])
149 |
150 | if input_time < 0:
151 | continue
152 |
153 | output_time = time.perf_counter()
154 | output_toks += tokens_i - seq_len
155 | output_times += output_time - input_time
156 | input_times += input_time - start_time
157 |
158 | mem_usage = np.array(mem_usage)
159 | k, b = np.polyfit(mem_usage[:, 0], mem_usage[:, 1], 1)
160 | print(f"显存占用: {b / 1024:.2f}MB + {k:.2f}KB/token")
161 | print(f"Input: {input_toks / input_times:.2f} tokens/s")
162 | print(f"Output: {output_toks / output_times:.2f} tokens/s")
163 |
164 |
# Script entry point: parse CLI arguments and run the benchmark.
if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
168 |
--------------------------------------------------------------------------------
/tools/demo/demo_llm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import torch
30 | import argparse
31 |
32 | from sentence_vae.utils import get_model, get_tokenizer
33 | from sentence_vae.data import SentenceCollate
34 |
35 |
def make_parser():
    """Build the CLI parser for the plain-LLM demo script.

    Returns:
        argparse.ArgumentParser: parser exposing model path/dtype, device,
        max sequence length and the demo prompt.
    """
    # Fixed: the description previously said "benchmark parser", but this is
    # the demo script (consistent with demo_sllm.py / demo_svae.py).
    parser = argparse.ArgumentParser("SentenceVAE demo parser.")
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--model_dtype", type=str, default="fp16")
    parser.add_argument("--device", type=str, default='cuda')
    parser.add_argument("--max_seq_len", type=int, default=1024)
    parser.add_argument("--input", type=str, default="Hello,")
    return parser
44 |
45 |
def main(args):
    """Greedy-decoding demo: encode ``args.input`` and stream tokens until
    EOS or ``args.max_seq_len`` is reached, printing words as they decode.

    Args:
        args: parsed namespace from ``make_parser()``.
    """
    model = get_model(args.model_dir, args.model_dtype, args.device).eval()
    tokenizer = get_tokenizer(ckpt_path=args.model_dir, max_seq_len=args.max_seq_len)

    # Fixed: a SentenceCollate instance was constructed here but never used;
    # the prompt is tokenized directly (batch of one).
    input_ids = tokenizer.batch_encode_plus([args.input], return_tensors="pt")['input_ids']

    device = torch.device(args.device)
    input_ids = input_ids.to(device)

    print("Input:", args.input)
    print("Output:")
    # Pure inference: disable autograd so generation does not accumulate
    # gradient state.
    with torch.no_grad():
        output = model(input_ids)
        while True:
            logits = output.logits
            past_key_values = output.past_key_values
            # Greedy choice of the next token at the last position.
            new_id = torch.argmax(logits[:1, -1:], dim=-1)
            input_ids = torch.concat((input_ids, new_id), dim=1)
            if new_id.item() == tokenizer.eos_token_id:
                break
            if input_ids.size(1) >= args.max_seq_len:
                break
            output_word = tokenizer.decode(new_id.item(), skip_special_tokens=True)
            print(output_word, end="", flush=True)  # stream words immediately
            # Feed only the new token back, reusing the KV cache.
            output = model(new_id, past_key_values=past_key_values)
    print()
72 |
73 |
# Script entry point: parse CLI arguments and run the demo.
if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
77 |
--------------------------------------------------------------------------------
/tools/demo/demo_sllm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import os
30 | import torch
31 | import argparse
32 |
33 | from sentence_vae.models import SentenceLLM
34 | from sentence_vae.data import PassageCollate
35 | from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml
36 |
37 |
def make_parser():
    """Build the CLI parser for the SentenceLLM demo script."""
    parser = argparse.ArgumentParser("SentenceVAE demo parser.")
    add = parser.add_argument
    add("-c", "--config", type=str, required=True)
    add("--checkpoint", type=str, default=None)
    add("--input", type=str, default="Hello,")
    return parser
44 |
45 |
def main(args):
    """Interactive demo for a trained SentenceLLM.

    Builds the model from the YAML config, restores the best (or an
    explicitly given) checkpoint, then streams generated sentences for the
    prompt passed via ``--input``.
    """
    cfg = load_yaml(args.config)
    # Experiment name is the config file name without its extension.
    expn = os.path.splitext(os.path.basename(args.config))[0]
    svae_ref_model_cfg = get_config(cfg["svae"]["ref_model_dir"])

    # Hyper-parameters come partly from the YAML config and partly from the
    # reference model's own config object.
    model = SentenceLLM(
        svae_hidden_size=svae_ref_model_cfg.hidden_size,
        svae_vocab_size=svae_ref_model_cfg.vocab_size,
        svae_learnable_add=cfg["svae"]["learnable_add"],
        svae_load_ref_model=cfg["svae"]["load_ref_model"],
        svae_ref_model_dir=cfg["svae"]["ref_model_dir"],
        svae_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
        svae_finetune_embedding=cfg["svae"]["finetune_embedding"],
        svae_word_embed_proj_dim=svae_ref_model_cfg.word_embed_proj_dim,
        svae_num_attention_heads=svae_ref_model_cfg.num_attention_heads,
        svae_num_hidden_layers=cfg["svae"]["num_hidden_layers"],
        svae_model_path=cfg["svae"]["model_path"],
        llm_ref_model_dir=cfg["llm"]["ref_model_dir"],
        # NOTE(review): the LLM dtype reuses the SVAE reference config's
        # torch_dtype — confirm this is intended rather than an oversight.
        llm_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
        llm_finetune_layers=cfg["llm"]["finetune_layers"],
        finetune_svae=cfg["finetune_svae"],
        max_sentence_len=cfg["max_sen_len"],
        max_sentence_num=cfg["max_sen_num"],
        dropout=svae_ref_model_cfg.dropout,
        bos_id=svae_ref_model_cfg.bos_token_id,
        pad_id=svae_ref_model_cfg.pad_token_id,
        end_id=svae_ref_model_cfg.eos_token_id,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"])
    )

    # Resolve the checkpoint: an explicit --checkpoint wins; otherwise scan
    # the experiment directory for a file whose name contains "best".
    exp_dir = f"exp/SentenceVAE-{expn}"
    ckpt_list = os.listdir(exp_dir)
    ckpt = args.checkpoint
    if ckpt is None:
        for ckpt_path in ckpt_list:
            if "best" in ckpt_path:
                ckpt_path = os.path.join(exp_dir, ckpt_path)
                ckpt = torch.load(ckpt_path)['state_dict']
        assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}."
    else:
        assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found."
        ckpt = torch.load(ckpt)["state_dict"]
    model.load_state_dict(ckpt)
    model.eval()

    # Tokenizer comes from the SVAE reference model; sentences are capped at
    # max_sen_len tokens each.
    tokenizer = get_tokenizer(ckpt_path=cfg["svae"]["ref_model_dir"], max_seq_len=cfg["max_sen_len"])
    # PassageCollate turns the raw prompt into (sentence mask, per-sentence
    # token ids, token mask) — presumably one row per sentence; verify against
    # its implementation in sentence_vae/data.
    collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True)
    batch_sentence_mask, batch_sentence_toks, batch_tok_mask = collate_fn([args.input])

    print(f"Input: {args.input}")
    print("Output:")

    # Each yielded item is a sequence of token ids that decodes to one
    # generated sentence.
    for output_ids in model.streaming_generate(
        batch_sentence_mask,
        batch_sentence_toks,
        batch_tok_mask
    ):
        new_sentence = tokenizer.decode(output_ids, skip_special_tokens=True)
        print(new_sentence)
106 |
107 |
# Script entry point: parse CLI arguments and run the demo.
if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
111 |
--------------------------------------------------------------------------------
/tools/demo/demo_svae.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import os
30 | import torch
31 | import argparse
32 |
33 | from sentence_vae.models import SentenceVAE
34 | from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml
35 |
36 |
def make_parser():
    """Build the CLI parser for the SentenceVAE demo script."""
    parser = argparse.ArgumentParser("SentenceVAE demo parser.")
    add = parser.add_argument
    add("-c", "--config", type=str, required=True)
    add("--checkpoint", type=str, default=None)
    add("--input", type=str, default=None)
    return parser
43 |
44 |
def main(args):
    """Demo for SentenceVAE: restore the best (or given) checkpoint and
    stream token-by-token reconstructions of the input sentence(s).

    Args:
        args: parsed namespace with ``config``, ``checkpoint`` and ``input``.
    """
    cfg = load_yaml(args.config)
    # Experiment name is the config file name without its extension.
    expn = os.path.splitext(os.path.basename(args.config))[0]
    ref_model_cfg = get_config(cfg["ref_model_dir"])

    model = SentenceVAE(
        hidden_size=ref_model_cfg.hidden_size,
        vocab_size=ref_model_cfg.vocab_size,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"]),
        learnable_add=cfg["learnable_add"],
        load_ref_model=cfg["load_ref_model"],
        ref_model_dir=cfg["ref_model_dir"],
        ref_model_dtype=get_dtype(cfg["ref_model_dtype"]) if cfg["ref_model_dtype"] is not None else None,
        finetune_embedding=cfg["finetune_embedding"],
        num_attention_heads=ref_model_cfg.num_attention_heads,
        num_hidden_layers=cfg["num_hidden_layers"],
        max_seq_len=cfg["max_seq_len"],
        dropout=ref_model_cfg.dropout,
        bos_id=ref_model_cfg.bos_token_id,
        pad_id=ref_model_cfg.pad_token_id,
        end_id=ref_model_cfg.eos_token_id
    )

    # Resolve the checkpoint: an explicit --checkpoint wins; otherwise scan
    # the experiment directory for a file whose name contains "best".
    exp_dir = f"exp/{expn}"
    ckpt = args.checkpoint
    if ckpt is None:
        for ckpt_name in os.listdir(exp_dir):
            if "best" in ckpt_name:
                ckpt = torch.load(os.path.join(exp_dir, ckpt_name))['state_dict']
        assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}."
    else:
        assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found."
        ckpt = torch.load(ckpt)["state_dict"]
    model.load_state_dict(ckpt)
    model.eval()

    tokenizer = get_tokenizer(ckpt_path=cfg["ref_model_dir"], max_seq_len=cfg["max_seq_len"])

    if args.input is None:
        # Default demo sentences used when no --input is supplied.
        input_texts = [
            'I love China.',
            'We come from Northwestern Polytechnical University.',
            "Hello,",
            "Welcome to TeleAI!",
            "What's your name?",
            "What's your problem?",
            "Hello, my dear friend",
            "Today is Friday.",
            "There is Institute of Artificial Intelligence (TeleAI), China Telecom.",
            "One two three four five six seven eight nine ten~",
            "Hahaha... and you?",
            "Yao yao ling xian!"
        ]
    else:
        input_texts = [args.input]

    # Fixed: removed an unused batch_encode_plus() pass (its input_ids /
    # attention_mask were discarded) and the commented-out batch 'predict'
    # path; each sentence is encoded individually below.
    for idx, input_text in enumerate(input_texts):
        print("--------------------")
        print(f"[{idx}]\tTest Input: ", end="")
        print(input_text)
        input_ids = tokenizer.encode(input_text, return_tensors='pt')
        print(f"\tTokens:{input_ids.size(1)}")
        print("\tVAE Output: ", end="")
        # Inference only: disable autograd bookkeeping during generation.
        with torch.no_grad():
            for output_id in model.streaming_generate(input_ids):
                output_word = tokenizer.decode(output_id, skip_special_tokens=True)
                print(output_word, end='')
        print()
135 |
136 |
# Script entry point: parse CLI arguments and run the demo.
if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
140 |
--------------------------------------------------------------------------------
/tools/eval/eval_llm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import argparse
30 | import numpy as np
31 | from tqdm import tqdm
32 |
33 | import torch
34 | import torch.nn.functional as F
35 | from torch.utils.data import DataLoader
36 |
37 | from mmengine.dataset import DefaultSampler
38 |
39 | from sentence_vae.utils import get_model, get_tokenizer
40 | from sentence_vae.data import TeleDSDataset, SentenceCollate
41 |
42 |
def make_parser():
    """Build the CLI parser for the plain-LLM evaluation script."""
    parser = argparse.ArgumentParser("SentenceVAE eval parser.")
    add = parser.add_argument
    add("--server", type=str, default="127.0.0.1")
    add("--port", type=int, default=8001)
    add("--model_dir", type=str, required=True)
    add("--model_dtype", type=str, default="fp16")
    add("--device", type=str, default='cuda')
    add("--max_seq_len", type=int, default=2048)
    add("--batch_size", type=int, default=8)
    add("--max_eval_samples", type=int, default=1000)
    return parser
54 |
55 |
def main(args):
    """Evaluate token-level perplexity of a causal LM on the TeleDS eval set.

    Fixed: the loss is now accumulated as a *sum* over valid tokens and
    normalised by the total token count at the end. The previous mean of
    per-batch means weighted every batch equally regardless of how many
    valid tokens it contained, biasing the reported PPL.
    """
    model = get_model(args.model_dir, args.model_dtype, args.device).eval()
    tokenizer = get_tokenizer(ckpt_path=args.model_dir, max_seq_len=args.max_seq_len)

    eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True, eval_samples=args.max_eval_samples)
    eval_sampler = DefaultSampler(eval_dataset, shuffle=False)
    eval_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=args.max_seq_len, padding=True)

    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=args.batch_size,
        sampler=eval_sampler,
        collate_fn=eval_collate_fn,
        num_workers=8,
        prefetch_factor=20
    )

    total_loss, total_tokens = 0.0, 0
    device = torch.device(args.device)
    for data in tqdm(eval_dataloader):
        input_ids, attention_mask = data
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        with torch.no_grad():
            logits = model(input_ids, attention_mask).logits
        # Build shifted targets: append one PAD column, then write EOS at each
        # sequence's end so the model is also scored on predicting EOS.
        batch, _ = input_ids.shape
        pad_ids = torch.zeros((batch, 1), device=device, dtype=input_ids.dtype).fill_(tokenizer.pad_token_id)
        tgt_ids = torch.cat((input_ids, pad_ids), dim=1)
        seq_lens = torch.sum(attention_mask, dim=1, keepdim=True)
        tgt_ids.scatter_(1, seq_lens, tokenizer.eos_token_id)
        attention_mask = attention_mask.bool()
        # Sum (not mean) so each valid token contributes equally overall.
        loss_sum = F.cross_entropy(
            logits[attention_mask], tgt_ids[:, 1:][attention_mask], reduction='sum'
        ).item()
        total_loss += loss_sum
        total_tokens += int(attention_mask.sum().item())
    mean_loss = total_loss / max(total_tokens, 1)  # guard empty eval set
    ppl = np.exp(mean_loss)

    print("Exp:", args.model_dir)
    print("Best PPL:", ppl)
94 |
95 |
# Script entry point: parse CLI arguments and run the evaluation.
if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
99 |
--------------------------------------------------------------------------------
/tools/eval/eval_sllm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import os
30 | import argparse
31 | from tqdm import tqdm
32 |
33 | import torch
34 | from torch.utils.data import DataLoader
35 |
36 | from mmengine.dataset import DefaultSampler
37 |
38 | from sentence_vae.models import SentenceLLM
39 | from sentence_vae.utils import get_tokenizer, get_config, get_dtype, load_yaml
40 | from sentence_vae.data import TeleDSDataset, PassageCollate, SLLM_PPL
41 |
42 |
def make_parser():
    """Build the CLI parser for the SentenceLLM evaluation script."""
    parser = argparse.ArgumentParser("SentenceVAE eval parser.")
    add = parser.add_argument
    add("--server", type=str, default="127.0.0.1")
    add("--port", type=int, default=8001)
    add("-c", "--config", type=str, required=True)
    add("--checkpoint", type=str, default=None)
    add("--device", type=str, default='cuda')
    add("--max_eval_samples", type=int, default=1000)
    return parser
52 |
53 |
def main(args):
    """Evaluate a trained SentenceLLM checkpoint (perplexity) on TeleDS.

    Fixes:
      * ``--max_eval_samples`` was parsed but never forwarded to
        ``TeleDSDataset`` (the benchmark and eval_llm scripts both pass
        ``eval_samples=``); it is now honoured.
      * The eval loop now runs under ``torch.no_grad()`` — the model is in
        eval mode but gradients were still being tracked.
    """
    cfg = load_yaml(args.config)
    # Experiment name is the config file name without its extension.
    expn = os.path.splitext(os.path.basename(args.config))[0]
    svae_ref_model_cfg = get_config(cfg["svae"]["ref_model_dir"])

    model = SentenceLLM(
        svae_hidden_size=svae_ref_model_cfg.hidden_size,
        svae_vocab_size=svae_ref_model_cfg.vocab_size,
        svae_learnable_add=cfg["svae"]["learnable_add"],
        svae_load_ref_model=cfg["svae"]["load_ref_model"],
        svae_ref_model_dir=cfg["svae"]["ref_model_dir"],
        svae_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
        svae_finetune_embedding=cfg["svae"]["finetune_embedding"],
        svae_word_embed_proj_dim=svae_ref_model_cfg.word_embed_proj_dim,
        svae_num_attention_heads=svae_ref_model_cfg.num_attention_heads,
        svae_num_hidden_layers=cfg["svae"]["num_hidden_layers"],
        svae_model_path=cfg["svae"]["model_path"],
        llm_ref_model_dir=cfg["llm"]["ref_model_dir"],
        llm_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
        llm_finetune_layers=cfg["llm"]["finetune_layers"],
        finetune_svae=cfg["finetune_svae"],
        max_sentence_len=cfg["max_sen_len"],
        max_sentence_num=cfg["max_sen_num"],
        dropout=svae_ref_model_cfg.dropout,
        bos_id=svae_ref_model_cfg.bos_token_id,
        pad_id=svae_ref_model_cfg.pad_token_id,
        end_id=svae_ref_model_cfg.eos_token_id,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"])
    )

    # Resolve the checkpoint: an explicit --checkpoint wins; otherwise scan
    # the experiment directory for a file whose name contains "best".
    exp_dir = f"exp/SentenceVAE-{expn}"
    ckpt = args.checkpoint
    if ckpt is None:
        for ckpt_name in os.listdir(exp_dir):
            if "best" in ckpt_name:
                ckpt = torch.load(os.path.join(exp_dir, ckpt_name))['state_dict']
        assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}."
    else:
        assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found."
        ckpt = torch.load(ckpt)["state_dict"]
    model.load_state_dict(ckpt)
    model.eval()

    tokenizer = get_tokenizer(ckpt_path=cfg["svae"]["ref_model_dir"], max_seq_len=cfg["max_sen_len"])

    eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True, eval_samples=args.max_eval_samples)
    eval_sampler = DefaultSampler(eval_dataset, shuffle=False)
    eval_collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True)

    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=cfg["batch_size"],
        sampler=eval_sampler,
        collate_fn=eval_collate_fn,
        num_workers=cfg["dataloader_num_workers"],
        prefetch_factor=cfg["dataloader_prefetch_factor"]
    )

    metric = SLLM_PPL()
    device = torch.device(args.device)
    with torch.no_grad():
        for data in tqdm(eval_dataloader):
            batch_sentence_mask, batch_sentence_toks, batch_tok_mask = data
            batch_sentence_mask = batch_sentence_mask.to(device)
            batch_sentence_toks = batch_sentence_toks.to(device)
            batch_tok_mask = batch_tok_mask.to(device)
            output = model(batch_sentence_mask, batch_sentence_toks, batch_tok_mask, mode='predict')
            metric.process(batch_sentence_mask, output)
    results = metric.compute_metrics(metric.results)

    print("Exp:", expn)
    print("Best PPL:", results["eval_ppl"])
129 |
130 |
# Script entry point: parse CLI arguments and run the evaluation.
if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
134 |
--------------------------------------------------------------------------------
/tools/eval/eval_svae.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import os
30 | import argparse
31 | from tqdm import tqdm
32 |
33 | import torch
34 | from torch.utils.data import DataLoader
35 |
36 | from mmengine.dataset import DefaultSampler
37 |
38 | from sentence_vae.models import SentenceVAE
39 | from sentence_vae.utils import get_tokenizer, get_config, get_dtype, load_yaml
40 | from sentence_vae.data import TeleDSDataset, SentenceCollate, SVAE_PPL
41 |
42 |
def make_parser():
    """Build the CLI parser for the SentenceVAE evaluation script."""
    parser = argparse.ArgumentParser("SentenceVAE eval parser.")
    add = parser.add_argument
    add("--server", type=str, default="127.0.0.1")
    add("--port", type=int, default=8000)
    add("-c", "--config", type=str, required=True)
    add("--checkpoint", type=str, default=None)
    add("--device", type=str, default='cuda')
    add("--max_eval_samples", type=int, default=1000)
    return parser
52 |
53 |
def main(args):
    """Evaluate a trained SentenceVAE checkpoint (perplexity) on TeleDS.

    Fixes:
      * ``--max_eval_samples`` was parsed but never forwarded to
        ``TeleDSDataset``; it is now passed as ``eval_samples=`` for
        consistency with the benchmark and eval_llm scripts.
      * The eval loop now runs under ``torch.no_grad()`` — the model is in
        eval mode but gradients were still being tracked.
    """
    cfg = load_yaml(args.config)
    # Experiment name is the config file name without its extension.
    expn = os.path.splitext(os.path.basename(args.config))[0]
    ref_model_cfg = get_config(cfg["ref_model_dir"])

    model = SentenceVAE(
        hidden_size=ref_model_cfg.hidden_size,
        vocab_size=ref_model_cfg.vocab_size,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"]),
        learnable_add=cfg["learnable_add"],
        load_ref_model=cfg["load_ref_model"],
        ref_model_dir=cfg["ref_model_dir"],
        ref_model_dtype=get_dtype(cfg["ref_model_dtype"]) if cfg["ref_model_dtype"] is not None else None,
        finetune_embedding=cfg["finetune_embedding"],
        num_attention_heads=ref_model_cfg.num_attention_heads,
        num_hidden_layers=cfg["num_hidden_layers"],
        max_seq_len=cfg["max_seq_len"],
        dropout=ref_model_cfg.dropout,
        bos_id=ref_model_cfg.bos_token_id,
        pad_id=ref_model_cfg.pad_token_id,
        end_id=ref_model_cfg.eos_token_id
    )

    # Resolve the checkpoint: an explicit --checkpoint wins; otherwise scan
    # the experiment directory for a file whose name contains "best".
    exp_dir = f"exp/{expn}"
    ckpt = args.checkpoint
    if ckpt is None:
        for ckpt_name in os.listdir(exp_dir):
            if "best" in ckpt_name:
                ckpt = torch.load(os.path.join(exp_dir, ckpt_name))['state_dict']
        assert ckpt is not None, f"Not found the best checkpoint under {exp_dir}."
    else:
        assert os.path.exists(ckpt), f"Checkpoint {ckpt} not found."
        ckpt = torch.load(ckpt)["state_dict"]
    model.load_state_dict(ckpt)
    model.eval()

    tokenizer = get_tokenizer(ckpt_path=cfg["ref_model_dir"], max_seq_len=cfg["max_seq_len"])

    eval_dataset = TeleDSDataset(server_ip=args.server, server_port=args.port, eval_mode=True, eval_samples=args.max_eval_samples)
    eval_sampler = DefaultSampler(eval_dataset, shuffle=False)
    eval_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=cfg["max_seq_len"], padding=True)

    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=cfg["batch_size"],
        sampler=eval_sampler,
        collate_fn=eval_collate_fn,
        num_workers=cfg["dataloader_num_workers"],
        prefetch_factor=cfg["dataloader_prefetch_factor"]
    )

    metric = SVAE_PPL()
    device = torch.device(args.device)
    with torch.no_grad():
        for data in tqdm(eval_dataloader):
            input_ids, attention_mask = data
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            output = model(input_ids, attention_mask, mode='predict')
            metric.process(input_ids, output)
    results = metric.compute_metrics(metric.results)

    print("Exp:", expn)
    print("Best PPL:", results["eval_ppl"])
119 |
120 |
# Script entry point: parse CLI arguments and run the evaluation.
if __name__ == "__main__":
    args = make_parser().parse_args()
    main(args)
124 |
--------------------------------------------------------------------------------
/tools/train/train_sllm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import os
30 | import argparse
31 |
32 | import torch
33 | from torch.utils.data import DataLoader
34 |
35 | from mmengine.runner import Runner
36 | from mmengine.dataset import DefaultSampler
37 | from mmengine.dist.utils import init_dist
38 |
39 | from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml
40 | from sentence_vae.models import SentenceLLM
41 | from sentence_vae.data import TeleDSDataset, PassageCollate, SLLM_PPL
42 |
43 | torch.multiprocessing.set_sharing_strategy('file_system')
44 |
45 |
def make_parser():
    """Build the command-line parser for SentenceLLM training.

    Options: -c/--config (required YAML path), --local-rank/--local_rank
    (set by distributed launchers), --launcher (job launcher backend),
    --teleds_ip and --teleds_port (TeleDS data-server address).
    """
    parser = argparse.ArgumentParser("SentenceLLM train parser.")
    arg_specs = [
        (("-c", "--config"), dict(type=str, required=True)),
        (("--local-rank", "--local_rank"), dict(type=int, default=0)),
        (("--launcher",), dict(choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher')),
        (("--teleds_ip",), dict(type=str, default="127.0.0.1")),
        (("--teleds_port",), dict(type=int, default=8001)),
    ]
    for flags, options in arg_specs:
        parser.add_argument(*flags, **options)
    return parser
54 |
55 |
def main(args):
    """Train a SentenceLLM model according to a YAML experiment config.

    Builds the model from the experiment config plus the SVAE reference-model
    config, creates TeleDS train/eval dataloaders, and drives training through
    an mmengine ``Runner`` (optionally distributed via ``args.launcher``).
    """
    cfg = load_yaml(args.config)
    # Experiment name = config file name without directory or extension.
    expn = os.path.splitext(os.path.basename(args.config))[0]
    # Reference SVAE model config supplies sizes, dtype and special token ids.
    svae_ref_model_cfg = get_config(cfg["svae"]["ref_model_dir"])

    # Set up the distributed environment only when a launcher is requested.
    if args.launcher != 'none':
        init_dist(args.launcher)

    model = SentenceLLM(
        svae_hidden_size=svae_ref_model_cfg.hidden_size,
        svae_vocab_size=svae_ref_model_cfg.vocab_size,
        svae_learnable_add=cfg["svae"]["learnable_add"],
        svae_load_ref_model=cfg["svae"]["load_ref_model"],
        svae_ref_model_dir=cfg["svae"]["ref_model_dir"],
        svae_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
        svae_finetune_embedding=cfg["svae"]["finetune_embedding"],
        svae_word_embed_proj_dim=svae_ref_model_cfg.word_embed_proj_dim,
        svae_num_attention_heads=svae_ref_model_cfg.num_attention_heads,
        svae_num_hidden_layers=cfg["svae"]["num_hidden_layers"],
        svae_model_path=cfg["svae"]["model_path"],
        llm_ref_model_dir=cfg["llm"]["ref_model_dir"],
        # NOTE(review): the LLM dtype is taken from the SVAE reference config,
        # not from cfg["llm"] — presumably a deliberately shared dtype; confirm.
        llm_ref_model_dtype=svae_ref_model_cfg.torch_dtype,
        llm_finetune_layers=cfg["llm"]["finetune_layers"],
        finetune_svae=cfg["finetune_svae"],
        max_sentence_len=cfg["max_sen_len"],
        max_sentence_num=cfg["max_sen_num"],
        dropout=svae_ref_model_cfg.dropout,
        bos_id=svae_ref_model_cfg.bos_token_id,
        pad_id=svae_ref_model_cfg.pad_token_id,
        end_id=svae_ref_model_cfg.eos_token_id,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"])
    )

    # Tokenizer shared by both collate functions, capped at the sentence length.
    tokenizer = get_tokenizer(ckpt_path=cfg["svae"]["ref_model_dir"], max_seq_len=cfg["max_sen_len"])

    # Train and eval pipelines differ only in the dataset's eval_mode flag.
    train_dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port)
    train_sampler = DefaultSampler(train_dataset, shuffle=False)
    train_collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True)

    eval_dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port, eval_mode=True)
    eval_sampler = DefaultSampler(eval_dataset, shuffle=False)
    eval_collate_fn = PassageCollate(tokenizer=tokenizer, max_sentence_len=cfg["max_sen_len"], max_sentence_num=cfg["max_sen_num"], padding=True)

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=cfg["batch_size"],
        sampler=train_sampler,
        collate_fn=train_collate_fn,
        num_workers=cfg["dataloader_num_workers"],
        prefetch_factor=cfg["dataloader_prefetch_factor"]
    )
    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=cfg["batch_size"],
        sampler=eval_sampler,
        collate_fn=eval_collate_fn,
        num_workers=cfg["dataloader_num_workers"],
        prefetch_factor=cfg["dataloader_prefetch_factor"]
    )

    # Linear learning-rate scaling with batch size.
    learning_rate = cfg["batch_size"] * cfg["base_lr"]

    # Checkpoint periodically by iteration; also publish the checkpoint with
    # the best (lowest) eval_ppl.
    default_hooks=dict(checkpoint=dict(
        type='CheckpointHook',
        by_epoch=False,
        interval=cfg["save_checkpoint_iters"],
        max_keep_ckpts=cfg["max_keep_ckpts"],
        save_best='eval_ppl', rule='less', published_keys=['meta', 'state_dict']
    ))
    # NOTE(review): work_dir is prefixed "SentenceVAE-" although this script
    # trains SentenceLLM (train_svae.py uses plain f"exp/{expn}") — looks like
    # a copy-paste leftover, but changing it would relocate existing
    # checkpoints and break resume; confirm before renaming.
    runner = Runner(
        model=model,
        work_dir=f"exp/SentenceVAE-{expn}",
        train_dataloader=train_dataloader,
        val_dataloader=eval_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=SLLM_PPL),
        train_cfg=dict(by_epoch=False, max_iters=cfg["max_iters"], val_interval=cfg["val_iters"]),
        optim_wrapper=dict(type="AmpOptimWrapper", optimizer=dict(type='AdamW', lr=learning_rate, weight_decay=0.01), clip_grad=dict(max_norm=1)),
        param_scheduler=[
            # Linear warmup followed by cosine annealing, both iteration-based.
            dict(type='LinearLR', start_factor=1e-3, by_epoch=False, begin=0, end=cfg["warmup_iters"]),
            dict(type='CosineAnnealingLR', by_epoch=False, T_max=cfg["cosineannealinglr_tmax"])
        ],
        visualizer=dict(type='Visualizer', vis_backends=[dict(type='TensorboardVisBackend')]),
        default_hooks=default_hooks,
        custom_hooks=[dict(type='EMAHook')],
        resume=cfg["resume_train"]
    )
    runner.train()
145 |
146 |
if __name__ == "__main__":
    # Entry point: parse CLI arguments and start training.
    main(make_parser().parse_args())
150 |
--------------------------------------------------------------------------------
/tools/train/train_svae.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024, School of Artificial Intelligence, OPtics and ElectroNics(iOPEN),
2 | # Northwestern PolyTechnical University,
3 | # and Institute of Artificial Intelligence (TeleAI), China Telecom.
4 | #
5 | # Author: Coder.AN (an.hongjun@foxmail.com)
6 | # Huasen Chen (chenyifan1@mail.nwpu.edu.cn)
7 | #
8 | #
9 | # This software is licensed under the MIT License.
10 | #
11 | # Permission is hereby granted, free of charge, to any person obtaining a copy
12 | # of this software and associated documentation files (the "Software"), to deal
13 | # in the Software without restriction, including without limitation the rights
14 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15 | # copies of the Software, and to permit persons to whom the Software is
16 | # furnished to do so, subject to the following conditions:
17 | #
18 | # The above copyright notice and this permission notice shall be included in
19 | # all copies or substantial portions of the Software.
20 | #
21 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
27 | # THE SOFTWARE.
28 |
29 | import os
30 | import argparse
31 |
32 | import torch
33 | from torch.utils.data import DataLoader
34 |
35 | from mmengine.runner import Runner
36 | from mmengine.dataset import DefaultSampler
37 | from mmengine.dist.utils import init_dist
38 |
39 | from sentence_vae.utils import get_config, get_tokenizer, get_dtype, load_yaml
40 | from sentence_vae.models import SentenceVAE
41 | from sentence_vae.data import TeleDSDataset, SentenceCollate, SVAE_PPL
42 |
43 |
def make_parser():
    """Build the command-line parser for SentenceVAE training.

    Options: -c/--config (required YAML path), --local-rank/--local_rank
    (set by distributed launchers), --launcher (job launcher backend),
    --teleds_ip and --teleds_port (TeleDS data-server address).
    """
    parser = argparse.ArgumentParser("SentenceVAE train parser.")
    arg_specs = [
        (("-c", "--config"), dict(type=str, required=True)),
        (("--local-rank", "--local_rank"), dict(type=int, default=0)),
        (("--launcher",), dict(choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher')),
        (("--teleds_ip",), dict(type=str, default="127.0.0.1")),
        (("--teleds_port",), dict(type=int, default=8000)),
    ]
    for flags, options in arg_specs:
        parser.add_argument(*flags, **options)
    return parser
52 |
53 |
def main(args):
    """Train a SentenceVAE model according to a YAML experiment config.

    Builds the model from the experiment config plus the reference-model
    config, creates TeleDS train/eval dataloaders, and drives training through
    an mmengine ``Runner`` (optionally distributed via ``args.launcher``),
    checkpointing on the best eval perplexity.
    """
    cfg = load_yaml(args.config)
    # Experiment name = config file name without directory or extension.
    expn = os.path.splitext(os.path.basename(args.config))[0]
    # Reference-model config supplies sizes, dropout and special token ids.
    ref_model_cfg = get_config(cfg["ref_model_dir"])

    # Set up the distributed environment only when a launcher is requested.
    if args.launcher != 'none':
        init_dist(args.launcher)

    model = SentenceVAE(
        hidden_size=ref_model_cfg.hidden_size,
        vocab_size=ref_model_cfg.vocab_size,
        device=torch.device(cfg["device"]),
        dtype=get_dtype(cfg["dtype"]),
        learnable_add=cfg["learnable_add"],
        load_ref_model=cfg["load_ref_model"],
        ref_model_dir=cfg["ref_model_dir"],
        ref_model_dtype=get_dtype(cfg["ref_model_dtype"]) if cfg["ref_model_dtype"] is not None else None,
        finetune_embedding=cfg["finetune_embedding"],
        num_attention_heads=ref_model_cfg.num_attention_heads,
        num_hidden_layers=cfg["num_hidden_layers"],
        max_seq_len=cfg["max_seq_len"],
        dropout=ref_model_cfg.dropout,
        bos_id=ref_model_cfg.bos_token_id,
        pad_id=ref_model_cfg.pad_token_id,
        end_id=ref_model_cfg.eos_token_id
    )

    # Tokenizer shared by both collate functions, capped at max_seq_len.
    tokenizer = get_tokenizer(ckpt_path=cfg["ref_model_dir"], max_seq_len=cfg["max_seq_len"])

    # Train and eval pipelines differ only in the dataset's eval_mode flag.
    train_dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port)
    train_sampler = DefaultSampler(train_dataset, shuffle=False)
    train_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=cfg["max_seq_len"], padding=True)

    eval_dataset = TeleDSDataset(server_ip=args.teleds_ip, server_port=args.teleds_port, eval_mode=True)
    # Fixed local-name typo ("eval_sampleer") for consistency with
    # train_sampler and with train_sllm.py.
    eval_sampler = DefaultSampler(eval_dataset, shuffle=False)
    eval_collate_fn = SentenceCollate(tokenizer=tokenizer, max_len=cfg["max_seq_len"], padding=True)

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=cfg["batch_size"],
        sampler=train_sampler,
        collate_fn=train_collate_fn,
        num_workers=cfg["dataloader_num_workers"],
        prefetch_factor=cfg["dataloader_prefetch_factor"]
    )
    eval_dataloader = DataLoader(
        dataset=eval_dataset,
        batch_size=cfg["batch_size"],
        sampler=eval_sampler,
        collate_fn=eval_collate_fn,
        num_workers=cfg["dataloader_num_workers"],
        prefetch_factor=cfg["dataloader_prefetch_factor"]
    )

    # Linear learning-rate scaling with batch size.
    learning_rate = cfg["batch_size"] * cfg["base_lr"]

    # Checkpoint periodically by iteration; also publish the checkpoint with
    # the best (lowest) eval_ppl.
    default_hooks = dict(checkpoint=dict(
        type='CheckpointHook',
        by_epoch=False,
        interval=cfg["save_checkpoint_iters"],
        max_keep_ckpts=cfg["max_keep_ckpts"],
        save_best='eval_ppl', rule='less', published_keys=['meta', 'state_dict']
    ))
    runner = Runner(
        model=model,
        work_dir=f"exp/{expn}",
        train_dataloader=train_dataloader,
        val_dataloader=eval_dataloader,
        val_cfg=dict(),
        val_evaluator=dict(type=SVAE_PPL),
        train_cfg=dict(by_epoch=False, max_iters=cfg["max_iters"], val_interval=cfg["val_iters"]),
        optim_wrapper=dict(type="AmpOptimWrapper", optimizer=dict(type='AdamW', lr=learning_rate, weight_decay=0.01), clip_grad=dict(max_norm=1)),
        param_scheduler=[
            # Linear warmup followed by cosine annealing, both iteration-based.
            dict(type='LinearLR', start_factor=1e-3, by_epoch=False, begin=0, end=cfg["warmup_iters"]),
            dict(type='CosineAnnealingLR', by_epoch=False, T_max=cfg["cosineannealinglr_tmax"])
        ],
        visualizer=dict(type='Visualizer', vis_backends=[dict(type='TensorboardVisBackend')]),
        default_hooks=default_hooks,
        custom_hooks=[dict(type='EMAHook')],
        resume=cfg["resume_train"]
    )
    runner.train()
136 |
137 |
if __name__ == "__main__":
    # Entry point: parse CLI arguments and start training.
    main(make_parser().parse_args())
141 |
--------------------------------------------------------------------------------