├── .github
│   └── FUNDING.yml
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── docs
│   ├── SEND_POINT_README.md
│   ├── load_hf_dataset.md
│   ├── triton_deploy_trt-llm.md
│   └── trt_llm_deploy_langchain.md
├── examples
│   ├── qwen-vl
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── api.py
│   │   ├── build.py
│   │   ├── client
│   │   │   ├── openai_normal_client.py
│   │   │   └── openai_stream_client.py
│   │   ├── default_config.py
│   │   ├── gptq_convert.py
│   │   ├── model.py
│   │   ├── requirements.txt
│   │   ├── run.py
│   │   ├── run_chat.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── quantization.py
│   │   │   └── utils.py
│   │   ├── vit_onnx_trt.py
│   │   ├── web_demo.py
│   │   └── weight.py
│   ├── qwen
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── api.py
│   │   ├── benchmark.py
│   │   ├── build.py
│   │   ├── cli_chat.py
│   │   ├── client
│   │   │   ├── async_client.py
│   │   │   ├── normal_client.py
│   │   │   ├── openai_function_call.py
│   │   │   ├── openai_normal_client.py
│   │   │   └── openai_stream_client.py
│   │   ├── default_config.py
│   │   ├── gptq_convert.py
│   │   ├── hf_qwen_convert.py
│   │   ├── model.py
│   │   ├── quantize.py
│   │   ├── requirements.txt
│   │   ├── run.py
│   │   ├── smoothquant.py
│   │   ├── summarize.py
│   │   ├── test
│   │   │   ├── test_dynamic_ntk.py
│   │   │   ├── test_logn.py
│   │   │   ├── test_rms_norm.py
│   │   │   └── test_smooth_quant_rms_norm.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── convert.py
│   │   │   ├── quantization.py
│   │   │   └── utils.py
│   │   ├── web_demo.py
│   │   └── weight.py
│   └── qwen2
│       ├── .gitignore
│       ├── README.md
│       ├── api.py
│       ├── benchmark.py
│       ├── build.py
│       ├── cli_chat.py
│       ├── default_config.py
│       ├── gptq_convert.py
│       ├── hf_qwen_convert.py
│       ├── model.py
│       ├── pytorch_test.py
│       ├── quantize.py
│       ├── requirements.txt
│       ├── run.py
│       ├── run_old.py
│       ├── smoothquant.py
│       ├── summarize.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── convert.py
│       │   ├── quantization.py
│       │   └── utils.py
│       ├── web_demo.py
│       └── weight.py
├── images
│   ├── course.png
│   ├── function_call_001.jpg
│   ├── function_call_002.jpg
│   ├── langchain-chatchat.jpg
│   ├── rmsnormplugin.jpeg
│   ├── rope_inside.jpeg
│   ├── rope_outside.jpeg
│   ├── tensorrt_rmsnorm_op.jpeg
│   └── triton_trt_llm.png
├── triton_client
│   └── inflight_batcher_llm_client.py
└── triton_model_repo
    ├── ensemble
    │   ├── 1
    │   │   └── .tmp
    │   └── config.pbtxt
    ├── postprocessing
    │   ├── 1
    │   │   └── model.py
    │   └── config.pbtxt
    ├── preprocessing
    │   ├── 1
    │   │   └── model.py
    │   └── config.pbtxt
    ├── tensorrt_llm
    │   ├── 1
    │   │   ├── .gitkeep
    │   │   └── .tmp
    │   └── config.pbtxt
    └── tensorrt_llm_bls
        ├── 1
        │   └── model.py
        └── config.pbtxt
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: tlntin # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | polar: # Replace with a single Polar username
14 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 | *.whl
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.obj
9 |
10 | # Precompiled Headers
11 | *.gch
12 | *.pch
13 |
14 | # Compiled Dynamic libraries
15 | *.so
16 | *.dylib
17 | *.dll
18 |
19 | # Fortran module files
20 | *.mod
21 | *.smod
22 |
23 | # Compiled Static libraries
24 | *.lai
25 | *.la
26 | *.lib
27 |
28 | # Executables
29 | *.exe
30 | *.out
31 | *.app
32 |
33 | # Byte-compiled / optimized / DLL files
34 | __pycache__/
35 | *.py[cod]
36 | *$py.class
37 |
38 | # C extensions
39 | *.so
40 |
41 | # Distribution / packaging
42 | .Python
43 | build/
44 | develop-eggs/
45 | dist/
46 | downloads/
47 | eggs/
48 | .eggs/
49 | lib/
50 | lib64/
51 | parts/
52 | sdist/
53 | var/
54 | wheels/
55 | share/python-wheels/
56 | *.egg-info/
57 | .installed.cfg
58 | *.egg
59 | MANIFEST
60 |
61 | # PyInstaller
62 | # Usually these files are written by a python script from a template
63 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
64 | *.manifest
65 | *.spec
66 |
67 | # Installer logs
68 | pip-log.txt
69 | pip-delete-this-directory.txt
70 |
71 | # Unit test / coverage reports
72 | htmlcov/
73 | .tox/
74 | .nox/
75 | .coverage
76 | .coverage.*
77 | .cache
78 | nosetests.xml
79 | coverage.xml
80 | *.cover
81 | *.py,cover
82 | .hypothesis/
83 | .pytest_cache/
84 | cover/
85 |
86 | # Translations
87 | *.mo
88 | *.pot
89 |
90 | # Django stuff:
91 | *.log
92 | local_settings.py
93 | db.sqlite3
94 | db.sqlite3-journal
95 |
96 | # Flask stuff:
97 | instance/
98 | .webassets-cache
99 |
100 | # Scrapy stuff:
101 | .scrapy
102 |
103 | # Sphinx documentation
104 | docs/_build/
105 |
106 | # PyBuilder
107 | .pybuilder/
108 | target/
109 |
110 | # Jupyter Notebook
111 | .ipynb_checkpoints
112 |
113 | # IPython
114 | profile_default/
115 | ipython_config.py
116 |
117 | # pyenv
118 | # For a library or package, you might want to ignore these files since the code is
119 | # intended to run in multiple environments; otherwise, check them in:
120 | # .python-version
121 |
122 | # pipenv
123 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
124 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
125 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
126 | # install all needed dependencies.
127 | #Pipfile.lock
128 |
129 | # poetry
130 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
131 | # This is especially recommended for binary packages to ensure reproducibility, and is more
132 | # commonly ignored for libraries.
133 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
134 | #poetry.lock
135 |
136 | # pdm
137 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
138 | #pdm.lock
139 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
140 | # in version control.
141 | # https://pdm.fming.dev/#use-with-ide
142 | .pdm.toml
143 |
144 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
145 | __pypackages__/
146 |
147 | # Celery stuff
148 | celerybeat-schedule
149 | celerybeat.pid
150 |
151 | # SageMath parsed files
152 | *.sage.py
153 |
154 | # Environments
155 | .env
156 | .venv
157 | env/
158 | venv/
159 | ENV/
160 | env.bak/
161 | venv.bak/
162 |
163 | # Spyder project settings
164 | .spyderproject
165 | .spyproject
166 |
167 | # Rope project settings
168 | .ropeproject
169 |
170 | # mkdocs documentation
171 | /site
172 |
173 | # mypy
174 | .mypy_cache/
175 | .dmypy.json
176 | dmypy.json
177 |
178 | # Pyre type checker
179 | .pyre/
180 |
181 | # pytype static type analyzer
182 | .pytype/
183 |
184 | # Cython debug symbols
185 | cython_debug/
186 |
187 | # PyCharm
188 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
189 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
190 | # and can be added to the global gitignore or merged into this file. For a more nuclear
191 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
192 | #.idea/
193 | kineto/
194 | .vscode/
195 | *.tar.gz
196 | tmp/
197 | .idea/
198 | *.jpeg
199 | examples/qwen2/CodeQwen1.5*/
200 | examples/qwen2/Qwen1.5*/
201 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 |
2 | [submodule "TensorRT-LLM"]
3 | path = TensorRT-LLM
4 | url = https://github.com/NVIDIA/TensorRT-LLM.git
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Tlntin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/SEND_POINT_README.md:
--------------------------------------------------------------------------------
1 | ### Steps for the free-score (送分题) tasks
2 | ##### Preparation
3 | 1. Enter the examples/gpt directory
4 | ```bash
5 | cd /app/tensorrt_llm/examples/gpt
6 | ```
7 |
8 | 2. Install these three basic Python modules first, otherwise errors will occur.
9 | ```bash
10 | pip install datasets nltk rouge_score
11 | ```
12 | 3. Download the model from Hugging Face to the server, move it into the examples/gpt directory, and rename it to gpt2
13 | ```bash
14 | git lfs install
15 | git clone https://huggingface.co/gpt2-medium
16 | mv gpt2-medium /app/tensorrt_llm/examples/gpt/gpt2
17 | ```
18 |
19 | 4. Users with a `poor network connection` can download the corresponding dataset from Baidu Netdisk instead, then follow the instructions inside to extract it into the Hugging Face cache path.
20 | - Baidu Netdisk link: https://pan.baidu.com/s/1aJrE3c6aMi7Qsc5zXk_amw?pwd=apfd  Extraction code: apfd
21 |
22 |
23 | ##### Steps for free-score task 1
24 | 1. Convert the HuggingFace model to FT format
25 | ```bash
26 | python3 hf_gpt_convert.py -i gpt2 -o ./c-model/gpt2 --tensor-parallelism 1 --storage-type float16
27 | ```
28 |
29 | 2. Build the FT-format model into a TensorRT engine
30 | ```bash
31 | python3 build.py --model_dir=./c-model/gpt2/1-gpu --use_gpt_attention_plugin
32 | ```
33 |
34 | 3. Run inference and check the output
35 | ```bash
36 | python3 run.py --max_output_len=8
37 | ```
38 |
39 |
40 | ##### Steps for free-score task 2
41 | 1. Convert the HuggingFace model to FT format
42 | ```bash
43 | python3 hf_gpt_convert.py -i gpt2 -o ./c-model/gpt2/fp16 --tensor-parallelism 1 --storage-type float16
44 | ```
45 |
46 | 2. Build the FT-format model into a TensorRT engine
47 | ```bash
48 | python3 build.py --model_dir=./c-model/gpt2/fp16/1-gpu \
49 | --use_gpt_attention_plugin \
50 | --use_gemm_plugin \
51 | --use_layernorm_plugin \
52 | --max_batch_size 8 \
53 | --max_input_len 924 \
54 | --max_output_len 100 \
55 | --output_dir trt_engine/gpt2/fp16/1-gpu/ \
56 | --hidden_act gelu
57 | ```
58 | 3. Run the last command to compute the `rouge_score` of the PyTorch and TRT-LLM versions
59 | ```bash
60 | python3 summarize.py --engine_dir trt_engine/gpt2/fp16/1-gpu \
61 | --test_hf \
62 | --batch_size 1 \
63 | --test_trt_llm \
64 | --hf_model_location=gpt2 \
65 | --check_accuracy
66 | ```
67 |
--------------------------------------------------------------------------------
/docs/load_hf_dataset.md:
--------------------------------------------------------------------------------
1 | # Loading Hugging Face datasets offline with the datasets library
2 |
3 | ### Use cases
4 | - The server can reach the domestic (Chinese) network but not the international internet, e.g. a domestic Alibaba Cloud server.
5 | - Or the server has no internet access at all (though files can be uploaded), e.g. an air-gapped LAN server with confidentiality requirements.
6 |
7 | ### Method 1
8 | - Prerequisite: your local machine can reach the international internet (if it cannot, check whether a third-party mirror site hosts the dataset).
9 | - Idea: load the dataset online on the local machine, export it to disk, and finally load it on the server.
10 | - Recommendation: 5 stars
11 | 1. Load the dataset online and export it to a local path
12 | ```python
13 | import os.path
14 | from datasets import load_dataset
15 |
16 | now_dir = os.path.dirname(os.path.abspath(__file__))
17 | target_dir_path = os.path.join(now_dir, "my_cnn_dailymail")
18 | dataset = load_dataset("ccdv/cnn_dailymail", name="3.0.0")
19 | dataset.save_to_disk(target_dir_path)
20 | ```
21 | 2. Inspect the resulting folder layout
22 | ```bash
23 | $ tree my_cnn_dailymail
24 |
25 | my_cnn_dailymail
26 | ├── dataset_dict.json
27 | ├── test
28 | │ ├── data-00000-of-00001.arrow
29 | │ ├── dataset_info.json
30 | │ └── state.json
31 | ├── train
32 | │ ├── data-00000-of-00003.arrow
33 | │ ├── data-00001-of-00003.arrow
34 | │ ├── data-00002-of-00003.arrow
35 | │ ├── dataset_info.json
36 | │ └── state.json
37 | └── validation
38 | ├── data-00000-of-00001.arrow
39 | ├── dataset_info.json
40 | └── state.json
41 |
42 | ```
43 |
44 | 3. Load the dataset on the server
45 | ```python
46 | import os.path
47 | from datasets import load_from_disk
48 |
49 | now_dir = os.path.dirname(os.path.abspath(__file__))
50 | target_dir_path = os.path.join(now_dir, "my_cnn_dailymail")
51 | dataset = load_from_disk(target_dir_path)
52 | ```
53 |
54 | ### Method 2
55 | - Prerequisite: your local machine can reach the international internet (otherwise, check whether a third-party mirror site hosts the dataset).
56 | - Idea: load the dataset online on the local machine; it will land in the cache path (on Linux, `~/.cache/huggingface`). Clear that directory first, load the dataset online, compress the directory, then extract it to the same path on the target server so it can be loaded normally. A minimal sketch is shown below.
57 | - Limitations: the Python and datasets versions must match, and datasets will still try to fetch the dataset online at load time, which easily corrupts the local copy; set the environment variables `HF_DATASETS_OFFLINE=1` and `TRANSFORMERS_OFFLINE=1` to block online loading.
58 | - Recommendation: 2 stars
59 |
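Below is a minimal sketch of method 2, assuming a Linux shell on both machines; `my-server`, the archive name, and the dataset name are illustrative only.

```bash
# On the machine with internet access: start from a clean cache, then load the dataset once
rm -rf ~/.cache/huggingface
python3 -c 'from datasets import load_dataset; load_dataset("ccdv/cnn_dailymail", name="3.0.0")'
tar czf hf_cache.tar.gz -C ~ .cache/huggingface

# Copy the archive to the offline server (scp is just one option)
scp hf_cache.tar.gz my-server:~

# On the offline server: restore the cache and force offline mode when loading
tar xzf hf_cache.tar.gz -C ~
export HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1
python3 -c 'from datasets import load_dataset; print(load_dataset("ccdv/cnn_dailymail", name="3.0.0"))'
```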
60 | ### Method 3
61 | - Prerequisite: the local machine just needs some internet access. With international access, download from huggingface; otherwise use a third-party mirror such as hf-mirror.com or ai.gitee.com, or simply search for the dataset.
62 | - Idea: download the dataset locally and read it directly. Different dataset types are read in different ways, but generally you can load one by pointing at the local dataset's absolute path, much like loading a model offline.
63 | - Limitations: you may need to modify files, so there is some learning curve, but I personally prefer this method because it reveals how loading works internally.
64 | - Recommendation: 4 stars
65 | - [See the official huggingface tutorial](https://huggingface.co/docs/datasets/main/en/dataset_script)
66 | 1. First download the dataset with git. The example below uses the [ccdv/cnn_dailymail](https://huggingface.co/datasets/ccdv/cnn_dailymail) dataset; without international access it can also be downloaded from this domestic [mirror](https://www.atyun.com/datasets/files/ccdv/cnn_dailymail.html).
67 | 2. After downloading, the dataset looks like this
68 | ```bash
69 | $ tree cnn_dailymail
70 |
71 | cnn_dailymail
72 | ├── cnn_dailymail.py
73 | ├── cnn_stories.tgz
74 | ├── dailymail_stories.tgz
75 | └── README.md
76 | ```
77 | 3. First load the dataset in the usual way. A relative path also works, because the code checks local paths before online ones (an absolute local path is still recommended). Since this is a local load and the folder contains a .py script, `trust_remote_code=True` is required to trust the script.
78 | ```python
79 | import os.path
80 |
81 | from datasets import load_dataset
82 |
83 |
84 | now_dir = os.path.dirname(os.path.abspath(__file__))
85 | dataset_dir = os.path.join(now_dir, "cnn_dailymail")
86 | dataset = load_dataset(dataset_dir, trust_remote_code=True)
87 | ```
88 | - Loading fails with the following error:
89 | ```bash
90 | ValueError: Config name is missing.
91 | Please pick one among the available configs: ['3.0.0', '1.0.0', '2.0.0']
92 | Example of usage:
93 | `load_dataset('cnn_dailymail', '3.0.0')`
94 | ```
95 | - It means the dataset has three configs (versions) and the version must be specified.
96 | - Add the version number and try again
97 | ```python
98 | import os.path
99 | from datasets import load_dataset
100 |
101 |
102 | now_dir = os.path.dirname(os.path.abspath(__file__))
103 | dataset_dir = os.path.join(now_dir, "cnn_dailymail")
104 | dataset = load_dataset(dataset_dir, name="3.0.0", trust_remote_code=True)
105 | ```
106 | - It loads, but the log shows download operations, three downloads in total.
107 | ```bash
108 | Downloading data: 2.11MB [00:00, 3.27MB/s]
109 | Downloading data: 46.4MB [00:02, 15.9MB/s]
110 | Downloading data: 2.43MB [00:00, 2.69MB/s]
111 | Generating train split: 287113 examples [00:29, 9655.52 examples/s]
112 | Generating validation split: 13368 examples [00:01, 9698.20 examples/s]
113 | Generating test split: 11490 examples [00:01, 9748.14 examples/s]
114 | ```
115 | - Debugging shows that it loads the .py file named after the dataset, i.e. `cnn_dailymail.py`
116 | 4. Open `cnn_dailymail.py`. At the bottom it defines the concrete dataset class `class CnnDailymail(datasets.GeneratorBasedBuilder):`
117 |     - `_info` describes the dataset and the fields it contains
118 |     - `_vocab_text_gen` appears to call `_generate_examples` to build a sample iterator.
119 |     - `_split_generators` extracts/loads the archives inside the dataset folder and returns the `train`/`valid`/`test` splits.
120 | ```python
121 | def _split_generators(self, dl_manager):
122 | dl_paths = dl_manager.download_and_extract(_DL_URLS)
123 | train_files = _subset_filenames(dl_paths, datasets.Split.TRAIN)
124 | # Generate shared vocabulary
125 |
126 | return [
127 | datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"files": train_files}),
128 | datasets.SplitGenerator(
129 | name=datasets.Split.VALIDATION,
130 | gen_kwargs={"files": _subset_filenames(dl_paths, datasets.Split.VALIDATION)},
131 | ),
132 | datasets.SplitGenerator(
133 | name=datasets.Split.TEST, gen_kwargs={"files": _subset_filenames(dl_paths, datasets.Split.TEST)}
134 | ),
135 | ]
136 | ```
137 | - Note the line `dl_paths = dl_manager.download_and_extract(_DL_URLS)`: it downloads and extracts whatever `_DL_URLS` points to. Let's look at `_DL_URLS`.
138 | ```python
139 | _DL_URLS = {
140 | # pylint: disable=line-too-long
141 | "cnn_stories": "cnn_stories.tgz",
142 | "dm_stories": "dailymail_stories.tgz",
143 | "test_urls": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt",
144 | "train_urls": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt",
145 | "val_urls": "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt",
146 | # pylint: enable=line-too-long
147 | }
148 | ```
149 | - As you can see, it contains two archives that ship with the dataset plus three online files, which is why the log showed three downloads. For fully offline loading, download those online files into the dataset folder and replace the URLs with the file names. If a GitHub file cannot be downloaded directly, a third-party proxy prefix can speed things up: for `https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt`, prepend `https://ghproxy.net/` to get `https://ghproxy.net/https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt`, then open that URL in a browser to download.
150 | 5. Complete the files. Download the three files above, drop them into the dataset folder, and edit `_DL_URLS` to replace the URLs with file names. The modified `_DL_URLS` looks like this:
151 | ```python
152 | _DL_URLS = {
153 | # pylint: disable=line-too-long
154 | "cnn_stories": "cnn_stories.tgz",
155 | "dm_stories": "dailymail_stories.tgz",
156 | "test_urls": "all_test.txt",
157 | "train_urls": "all_train.txt",
158 | "val_urls": "all_val.txt",
159 | # pylint: enable=line-too-long
160 | }
161 | ```
162 | - The dataset directory then looks like this:
163 | ```bash
164 | $ tree cnn_dailymail
165 |
166 | cnn_dailymail
167 | ├── all_test.txt
168 | ├── all_train.txt
169 | ├── all_val.txt
170 | ├── cnn_dailymail.py
171 | ├── cnn_stories.tgz
172 | ├── dailymail_stories.tgz
173 | └── README.md
174 | ```
175 | 6. Test it. Use a fresh machine or clear `~/.cache/huggingface` to prevent interference from old data.
176 | ```bash
177 | rm -rf ~/.cache/huggingface
178 | ```
179 | - Then try loading with the earlier script again.
180 | ```python
181 | import os.path
182 | from datasets import load_dataset
183 |
184 |
185 | now_dir = os.path.dirname(os.path.abspath(__file__))
186 | dataset_dir = os.path.join(now_dir, "cnn_dailymail")
187 | dataset = load_dataset(dataset_dir, name="3.0.0", trust_remote_code=True)
188 | print(dataset)
189 | ```
190 | - The log shows no downloads and the dataset imports correctly, so the problem is solved.
191 | ```bash
192 | Generating train split: 287113 examples [00:29, 9608.45 examples/s]
193 | Generating validation split: 13368 examples [00:01, 9722.08 examples/s]
194 | Generating test split: 11490 examples [00:01, 9927.94 examples/s]
195 | DatasetDict({
196 | train: Dataset({
197 | features: ['article', 'highlights', 'id'],
198 | num_rows: 287113
199 | })
200 | validation: Dataset({
201 | features: ['article', 'highlights', 'id'],
202 | num_rows: 13368
203 | })
204 | test: Dataset({
205 | features: ['article', 'highlights', 'id'],
206 | num_rows: 11490
207 | })
208 | })
209 | ```
210 |
211 | ### Summary
212 | 1. With international internet access, method 1 is the most convenient and should be preferred.
213 | 2. Without such access, if third-party mirrors (e.g. hf-mirror.com) do not host the dataset but a git clone of the data can be found, use method 3.
214 | 3. Method 3 is also recommended if you want to understand how a dataset is actually loaded.
215 | 4. Method 3 is also recommended if you want to load data directly on a server that has no external access, without going through ftp/sftp.
216 | 5. Method 2 is listed only for reference and is not really recommended.
217 |
--------------------------------------------------------------------------------
/docs/trt_llm_deploy_langchain.md:
--------------------------------------------------------------------------------
1 | ### Deploying TensorRT-LLM with Langchain
2 |
3 | 1. Deploy Qwen-7B-Chat-TensorRT-LLM following this project: https://github.com/Tlntin/Qwen-7B-Chat-TensorRT-LLM ; the API service must be deployed.
4 |
5 |
6 | 2. Download Langchain-Chatchat (latest version 0.2.7 at the time of writing). The modified fork below is recommended, as it is easier to use
7 | ```bash
8 | git clone https://github.com/Tlntin/Langchain-Chatchat
9 | ```
10 | - For environment setup, just follow the README.
11 | - Model downloading can be skipped; with a good network connection, models can be downloaded online.
12 | - For initial configuration, follow the README as well.
13 | ```bash
14 | python copy_config_example.py
15 | ```
16 |
17 | 3. Edit the model config file `configs/model_config.py` and set `LLM_MODELS` to `["qwen-trt-llm"]`; to add more APIs, simply append them to the list.
18 | ```python
19 | # LLM names
20 | LLM_MODELS = ["qwen-trt-llm"]
21 | ```
22 |
23 | 4. In the same `configs/model_config.py`, change the URL to the address where your TensorRT-LLM API is deployed. The default should be 127.0.0.1:8000; my port 8000 was taken, so I switched to 5540. Use your own IP and port. (A quick way to sanity-check the endpoint is shown after the snippet.)
24 | ```python
25 | "qwen-trt-llm": {
26 | "api_base_url": "http://127.0.0.1:5540/v1",
27 | "api_key": "no key",
28 | "version": "qwen-trt-llm",
29 | "provider": "QwenTRTLLMWorker",
30 | },
31 | ```
32 |
33 | 5. Initialize the knowledge-base data
34 | ```bash
35 | python init_database.py --recreate-vs
36 | ```
37 |
38 | 6. Start Langchain-Chatchat; a browser window opens automatically
39 | ```bash
40 | python startup.py -a
41 | ```
42 |
43 | 7. In the LLM model section, select `OpenAI (Running)`, and you can start chatting.
44 |
45 | 8. For knowledge-base Q&A:
46 | - First open `知识库管理` (Knowledge Base Management), create a new knowledge base, upload any document, and preferably click `根据源文件重建向量库` (rebuild the vector store from source files).
47 | - Go back to the chat, set the dialogue mode to `知识库问答` (Knowledge Base Q&A), pick the knowledge base you just created at the bottom, and you can start asking questions on the right.
48 |
49 | 9. Final screenshots
50 |
--------------------------------------------------------------------------------
/examples/qwen-vl/.gitignore:
--------------------------------------------------------------------------------
1 | qwen*
2 | Qwen*
3 | *.log
4 | c-model
5 | ccdv
6 | trt_engines
7 | hg_test.py
8 | rouge.tar.xz
9 | rouge
10 | ccdv___cnn_dailymail.tar.xz
11 | ccdv___cnn_dailymail
12 | lambada.tar.xz
13 | *.json
14 | .idea
15 | *.ttf
16 | plan
17 | onnx
18 | input_pt
19 |
20 |
--------------------------------------------------------------------------------
/examples/qwen-vl/README.md:
--------------------------------------------------------------------------------
1 | # Guide to QWen-VL pipeline
2 | 1. Download Qwen-VL-Chat
3 | ```bash
4 | git lfs install
5 | git clone https://huggingface.co/Qwen/Qwen-VL-Chat
6 | ```
7 | 2. ViT
8 | - Generate ONNX model and TRT engine for ViT
9 | ```bash
10 | python vit_onnx_trt.py --pretrained_model_path ./Qwen-VL-Chat
11 | ```
12 | The exported ONNX files lie in `./onnx/visual_encoder` and the built engine lies in `./plan/visual_encoder`. If you already have the ONNX files and only need to build the TRT engine, use:
13 | ```bash
14 | python vit_onnx_trt.py --pretrained_model_path ./Qwen-VL-Chat --only_trt
15 | ```
16 | Moreover, it saves the test image tensor to `image.pt` and the visual query tokens to `query_tokens.pt` for later pipeline inference.
17 |
18 | 3. QwenVL(fp16)
19 |
20 | - Build TRT-LLM engines (only need to add --max_prompt_embedding_table_size)
21 |
22 | **NOTE:** `max_prompt_embedding_table_size = query_token_num * max_batch_size`, so if you change max_batch_size, the prompt table size must be adjusted accordingly.
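For example (illustrative numbers only, not necessarily Qwen-VL's actual token count): if the visual encoder produced 512 query tokens per image and `max_batch_size` were 4, the table would need at least 512 * 4 = 2048 entries, which is the value passed below.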
23 | ```bash
24 | python3 build.py \
25 | --hf_model_dir=./Qwen-VL-Chat \
26 | --dtype float16 --max_batch_size 4 \
27 | --remove_input_padding \
28 | --use_gpt_attention_plugin float16 \
29 | --use_gemm_plugin float16 --enable_context_fmha \
30 | --use_rmsnorm_plugin --log_level error \
31 | --use_lookup_plugin float16 \
32 | --max_prompt_embedding_table_size 2048 \
33 | --output_dir=trt_engines/Qwen-VL-7B-fp16
34 | ```
35 | The built Qwen engines lie in `./trt_engines/Qwen-VL-7B-fp16`.
36 |
37 | 4. Qwen-VL(int8 weight only)
38 | **NOTE:** `max_prompt_embedding_table_size = query_token_num * max_batch_size`, so if you change max_batch_size, the prompt table size must be adjusted accordingly.
39 | ```bash
40 | python3 build.py \
41 | --hf_model_dir=./Qwen-VL-Chat \
42 | --dtype float16 --max_batch_size 4 \
43 | --remove_input_padding \
44 | --use_gpt_attention_plugin float16 \
45 | --use_gemm_plugin float16 --enable_context_fmha \
46 | --use_rmsnorm_plugin --log_level error \
47 | --use_lookup_plugin float16 \
48 | --max_prompt_embedding_table_size 2048 \
49 | --use_weight_only --weight_only_precision int8 \
50 | --output_dir=trt_engines/Qwen-VL-7B-int8
51 | ```
52 | - The built Qwen engines lie in `./trt_engines/Qwen-VL-7B-int8`.
53 |
54 | 5. Qwen-VL(int4 weight only)
55 | **NOTE:** `max_prompt_embedding_table_size = query_token_num * max_batch_size`, so if you change max_batch_size, the prompt table size must be adjusted accordingly.
56 | ```bash
57 | python3 build.py \
58 | --hf_model_dir=./Qwen-VL-Chat \
59 | --dtype float16 --max_batch_size 4 \
60 | --remove_input_padding \
61 | --use_gpt_attention_plugin float16 \
62 | --use_gemm_plugin float16 --enable_context_fmha \
63 | --use_rmsnorm_plugin --log_level error \
64 | --use_lookup_plugin float16 \
65 | --max_prompt_embedding_table_size 2048 \
66 | --use_weight_only --weight_only_precision int4 \
67 | --output_dir=trt_engines/Qwen-VL-7B-int4
68 | ```
69 | - The built Qwen engines lie in `./trt_engines/Qwen-VL-7B-int4`.
70 |
71 | 6. Qwen-VL(gptq-int4)
72 | **NOTE:** `max_prompt_embedding_table_size = query_token_num * max_batch_size`, so if you change max_batch_size, the prompt table size must be adjusted accordingly.
73 | - Install the required Python packages
74 | ```bash
75 | pip install auto-gptq optimum
76 | pip install transformers -U
77 | ```
78 |
79 | - Convert the weights to int4-gptq
80 | ```bash
81 | python3 gptq_convert.py --hf_model_dir ./Qwen-VL-Chat --tokenizer_dir ./Qwen-VL-Chat --quant_ckpt_path ./Qwen-VL-Chat-My-Int4
82 | ```
83 |
84 | - Build the engine
85 | ```bash
86 | python3 build.py \
87 | --hf_model_dir=./Qwen-VL-Chat \
88 | --dtype float16 --max_batch_size 4 \
89 | --remove_input_padding \
90 | --use_gpt_attention_plugin float16 \
91 | --use_gemm_plugin float16 --enable_context_fmha \
92 | --use_rmsnorm_plugin --log_level error \
93 | --use_lookup_plugin float16 \
94 | --max_prompt_embedding_table_size 2048 \
95 | --use_weight_only \
96 | --weight_only_precision int4_gptq \
97 | --per_group \
98 | --quant_ckpt_path ./Qwen-VL-Chat-My-Int4/gptq_model-4bit-128g.safetensors \
99 | --output_dir=trt_engines/Qwen-VL-7B-int4-gptq
100 | ```
101 |
102 | 7. Qwen-VL-Int4(raw official gptq-int4)
103 | **NOTE:** `max_prompt_embedding_table_size = query_token_num * max_batch_size`, so if you change max_batch_size, the prompt table size must be adjusted accordingly.
104 | - Install the required Python packages
105 | ```bash
106 | pip install auto-gptq optimum
107 | pip install transformers -U
108 | ```
109 |
110 | - Build the engine
111 | ```bash
112 | python3 build.py \
113 | --hf_model_dir=./Qwen-VL-Chat-Int4 \
114 | --quant_ckpt_path=./Qwen-VL-Chat-Int4 \
115 | --dtype float16 --max_batch_size 4 \
116 | --remove_input_padding \
117 | --use_gpt_attention_plugin float16 \
118 | --use_gemm_plugin float16 --enable_context_fmha \
119 | --use_rmsnorm_plugin --log_level error \
120 | --use_lookup_plugin float16 \
121 | --max_prompt_embedding_table_size 2048 \
122 | --use_weight_only \
123 | --weight_only_precision int4_gptq \
124 | --per_group \
125 | --output_dir=trt_engines/Qwen-VL-7B-int4-gptq
126 | ```
127 |
128 | 8. Run Qwen-VL pipeline
129 | - fp16 run
130 | ```bash
131 | python run.py \
132 | --tokenizer_dir=./Qwen-VL-Chat \
133 | --qwen_engine_dir=./trt_engines/Qwen-VL-7B-fp16/ \
134 | --vit_engine_dir=./plan/
135 | ```
136 |
137 | - int8 weight only run
138 | ```bash
139 | python run.py \
140 | --tokenizer_dir=./Qwen-VL-Chat \
141 | --qwen_engine_dir=trt_engines/Qwen-VL-7B-int8 \
142 | --vit_engine_dir=./plan/
143 | ```
144 |
145 | - int4 weight only run
146 | ```bash
147 | python run.py \
148 | --tokenizer_dir=./Qwen-VL-Chat \
149 | --qwen_engine_dir=trt_engines/Qwen-VL-7B-int4 \
150 | --vit_engine_dir=./plan/
151 | ```
152 |
153 | - int4 gptq run
154 | ```bash
155 | python run.py \
156 | --tokenizer_dir=./Qwen-VL-Chat \
157 | --qwen_engine_dir=trt_engines/Qwen-VL-7B-int4-gptq \
158 | --vit_engine_dir=./plan/
159 | ```
160 |
161 | - raw official int4 gptq run
162 | ```bash
163 | python run.py \
164 | --tokenizer_dir=./Qwen-VL-Chat-Int4 \
165 | --qwen_engine_dir=trt_engines/Qwen-VL-7B-int4-gptq \
166 | --vit_engine_dir=./plan/
167 | ```
168 |
--------------------------------------------------------------------------------
/examples/qwen-vl/client/openai_normal_client.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 |
3 | client = OpenAI(
4 | base_url="http://localhost:8000/v1",
5 | api_key="no api"
6 | )
7 |
8 | messages = [{"role": "system", "content": "You are a helpful assistant."}]
9 | print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录")
10 | while True:
11 | prompt = input('Human:')
12 | if prompt == 'exit':
13 | break
14 | if prompt == 'clear':
15 | messages = messages[:1]
16 | continue
17 | messages.append({"role": "user", "content": prompt})
18 | completion = client.chat.completions.create(
19 | model="gpt-3.5-turbo",
20 | messages=messages,
21 | top_p=0.5,
22 | temperature=0,
23 | n=1,
24 | max_tokens=4096,
25 | stream=False,
26 | )
27 | message = completion.choices[0].message
28 | response_text = message.content
29 | print('ChatBot: {}'.format(response_text))
30 | messages.append({"role": "assistant", "content": response_text})
--------------------------------------------------------------------------------
/examples/qwen-vl/client/openai_stream_client.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 |
3 | client = OpenAI(
4 | base_url="http://localhost:8000/v1",
5 | api_key="no api"
6 | )
7 |
8 |
9 | messages = [{"role": "system", "content": "You are a helpful assistant."}]
10 | print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录")
11 | while True:
12 | prompt = input('Human:')
13 | if prompt == 'exit':
14 | break
15 | if prompt == 'clear':
16 | messages = messages[:1]
17 | continue
18 | messages.append({"role": "user", "content": prompt})
19 | response = client.chat.completions.create(
20 | model="gpt-3.5-turbo",
21 | messages=messages,
22 | top_p=0.5,
23 | temperature=0,
24 | n=1,
25 | max_tokens=4096,
26 | stream=True,
27 | )
28 | print("ChatBot:", end='', flush=True)
29 | response_text = ""
30 | for event in response:
31 | event_text = event.choices[0].delta.content # extract the text
32 | if event_text is None:
33 | event_text = ""
34 | response_text += event_text
35 | print(event_text, end='', flush=True)
36 | messages.append({"role": "assistant", "content": response_text})
37 | print("")
38 |
39 |
--------------------------------------------------------------------------------
/examples/qwen-vl/default_config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | class DefaultConfig:
5 | now_dir = os.path.dirname(os.path.abspath(__file__))
6 | hf_model_dir = os.path.join(now_dir, "Qwen-VL-Chat")
7 | tokenizer_dir = os.path.join(now_dir, "Qwen-VL-Chat")
8 | int4_gptq_model_dir = os.path.join(now_dir, "qwen_7b_vl_chat_int4")
9 | ft_dir_path = os.path.join(now_dir, "c-model", "Qwen-VL-Chat")
10 | qwen_engine_dir = os.path.join(now_dir, "trt_engines", "Qwen-VL-7B-int8")
11 | vit_engine_dir = os.path.join(now_dir, "plan")
12 |
13 | # Maximum batch size for HF backend.
14 | hf_max_batch_size = 1
15 |
16 | # Maximum batch size for TRT-LLM backend.
17 | trt_max_batch_size = 4
18 |
19 |     # choose the prompt format (chat model vs. base model)
20 | # choices=["chatml", "raw"],
21 | chat_format = "chatml"
22 |
23 | # Maximum input length.
24 | max_input_len = 1024 * 6
25 |
26 | # Maximum number of generate new tokens.
27 | max_new_tokens = 1024 * 2
28 |
29 | # Top p for sampling.
30 | top_p = 0.8
31 |
32 | # Top k for sampling.
33 | top_k = 0
34 |
35 | # Temperature for sampling.
36 | temperature = 1.0
37 |
38 |
39 | default_config = DefaultConfig()
40 |
--------------------------------------------------------------------------------
/examples/qwen-vl/gptq_convert.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
3 | from argparse import ArgumentParser
4 | import os
5 | from datasets import load_dataset
6 | from tqdm import tqdm
7 | import sys
8 | import logging
9 |
10 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
11 | from utils.utils import make_context
12 |
13 |
14 | logging.basicConfig(
15 | level=logging.INFO,
16 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
17 | )
18 |
19 |
20 | parser = ArgumentParser()
21 | parser.add_argument(
22 | "--hf_model_dir",
23 | type=str,
24 | default=None,
25 | )
26 | parser.add_argument(
27 | '--tokenizer_dir',
28 | type=str,
29 | default=None,
30 | help="Directory containing the tokenizer.model."
31 | )
32 | parser.add_argument(
33 | "--quant_ckpt_path",
34 | type=str,
35 | default=None,
36 | )
37 | parser.add_argument(
38 | "--device",
39 | type=str,
40 | default="cuda",
41 | choices=["cuda", "cpu"],
42 | )
43 | parser.add_argument(
44 | "--num_samples",
45 | type=int,
46 | default=512,
47 | )
48 |
49 |
50 | args = parser.parse_args()
51 | # model_id_or_path = default_config.hf_model_dir
52 | # quantized_model_dir = default_config.int4_gptq_model_dir
53 | tokenizer = AutoTokenizer.from_pretrained(
54 | args.tokenizer_dir, use_fast=True, trust_remote_code=True
55 | )
56 |
57 |
58 | dataset_cnn = load_dataset(
59 | "ccdv/cnn_dailymail",
60 | "3.0.0"
61 | )
62 | dataset = dataset_cnn["test"]
63 |
64 | num_samples = min(args.num_samples, len(dataset))
65 | examples = []
66 | for i in tqdm(range(num_samples), desc="tokenizing datasets"):
67 | line = dataset[i]["article"]
68 | line = line + ' TL;DR: '
69 | line = line.strip()
70 | line = line.replace(" n't", "n't")
71 | # use make_content to generate prompt
72 | raw_text, _ = make_context(
73 | tokenizer=tokenizer,
74 | query=line,
75 | history=[],
76 | )
77 | example = tokenizer(raw_text)
78 | examples.append(example)
79 |
80 | quantize_config = BaseQuantizeConfig(
81 | bits=4, # quantize model to 4-bit
82 | group_size=128, # it is recommended to set the value to 128
83 |     desc_act=False,  # setting this to False significantly speeds up inference, but perplexity may be slightly worse
84 | true_sequential=True,
85 | )
86 |
87 | print("model_path", args.hf_model_dir)
88 | model = (
89 | AutoGPTQForCausalLM.from_pretrained(
90 | args.hf_model_dir,
91 | quantize_config,
92 | trust_remote_code=True,
93 | use_flash_attn=False
94 | )
95 | .eval()
96 | # .cuda()
97 | )
98 | if args.device == "cuda":
99 | model.cuda()
100 | else:
101 |     print("CPU-only mode is supported only on Qwen-7B v1.0, not on Qwen-7B v1.1 / Qwen-14B")
102 | print("loading model to run gptq, this may take a few minutes...")
103 | # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
104 | model.quantize(examples, cache_examples_on_gpu=False)
105 | print("quantized ok!")
106 |
107 | # save quantized model
108 | model.save_quantized(args.quant_ckpt_path, use_safetensors=True)
--------------------------------------------------------------------------------
/examples/qwen-vl/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets~=2.3.2
2 | rouge_score~=0.1.2
3 | transformers~=4.31.0
4 | transformers-stream-generator
5 | sentencepiece~=0.1.99
6 | tiktoken
7 | einops
8 |
9 | # optional dependencies
10 | gradio==3.40.1
11 | mdtex2html
12 | sse_starlette
13 | aiohttp_sse_client
14 | openai
15 |
--------------------------------------------------------------------------------
/examples/qwen-vl/run_chat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from run import QWenInfer, Vit, parse_arguments
3 | from vit_onnx_trt import Preprocss
4 |
5 |
6 | if __name__ == '__main__':
7 | args = parse_arguments()
8 | # load vit with custom image
9 | """
10 | image_preprocess = Preprocss(image_size=448)
11 | image_paths = ["demo.jpeg"]
12 | images = image_preprocess.encode(image_paths)
13 | image_paths = [{"image": image} for image in image_paths]
14 | vit = Vit(args.vit_engine_dir, args.log_level)
15 | input_vit = vit.run(images=images)
16 | """
17 | # otherwise
18 | input_vit = None
19 | image_paths = []
20 | qinfer = QWenInfer(args.tokenizer_dir,args.qwen_engine_dir, args.log_level)
21 | qinfer.qwen_model_init()
22 |
23 | history = []
24 | while True:
25 | input_text = None
26 | try:
27 | input_text = input("Text (or 'q' to quit): ")
28 | except:
29 | continue
30 |
31 | if input_text == "clear history":
32 | history = []
33 | continue
34 |
35 | if input_text.lower() == 'q':
36 | break
37 |
38 | # content_list = args.images_path
39 | if len(history) == 0:
40 | content_list = image_paths + [{'text': input_text}]
41 | query = qinfer.tokenizer.from_list_format(content_list)
42 | else:
43 | query = input_text
44 |
45 | response = ""
46 | for new_text in qinfer.qwen_infer_stream(
47 | input_vit=input_vit,
48 | input_text=query,
49 | max_new_tokens=args.max_new_tokens,
50 | history=history
51 | ):
52 | print(new_text, end='', flush=True)
53 | response += new_text
54 | print("")
55 | history.append((input_text, response))
56 |
57 |
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
/examples/qwen-vl/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/examples/qwen-vl/utils/__init__.py
--------------------------------------------------------------------------------
/examples/qwen-vl/utils/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import PreTrainedTokenizer
2 | from typing import List, Tuple
3 |
4 |
5 | def make_context(
6 | tokenizer: PreTrainedTokenizer,
7 | query: str,
8 | history: List[Tuple[str, str]] = None,
9 | system: str = "You are a helpful assistant.",
10 | max_input_length: int = 2048, # if you want to change this, you need to change the max_input_len in tensorrt_llm_july-release-v1/examples/qwen/build.py
11 | max_window_size: int = 6144,
12 | chat_format: str = "chatml",
13 | ):
14 | if history is None:
15 | history = []
16 |
17 | if chat_format == "chatml":
18 | im_start, im_end = "<|im_start|>", "<|im_end|>"
19 | im_start_tokens = [tokenizer.im_start_id]
20 | im_end_tokens = [tokenizer.im_end_id]
21 | nl_tokens = tokenizer.encode("\n")
22 |
23 | def _tokenize_str(role, content):
24 | return (
25 | f"{role}\n{content}",
26 | tokenizer.encode(
27 | role,
28 | allowed_special=set(),
29 | ) + nl_tokens + tokenizer.encode(
30 | content,
31 | allowed_special=set(),
32 | )
33 | )
34 |
35 | system_text, system_tokens_part = _tokenize_str("system", system)
36 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
37 | raw_text = ""
38 | context_tokens = []
39 |
40 | for turn_query, turn_response in reversed(history):
41 | query_text, query_tokens_part = _tokenize_str("user", turn_query)
42 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
43 |
44 | response_text, response_tokens_part = _tokenize_str(
45 | "assistant", turn_response
46 | )
47 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
48 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
49 | prev_chat = (
50 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
51 | )
52 |
53 | current_context_size = (
54 | len(system_tokens) + len(next_context_tokens) + len(context_tokens)
55 | )
56 | if current_context_size < max_window_size:
57 | context_tokens = next_context_tokens + context_tokens
58 | raw_text = prev_chat + raw_text
59 | else:
60 | break
61 |
62 | context_tokens = system_tokens + context_tokens
63 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text
64 | context_tokens += (
65 | nl_tokens
66 | + im_start_tokens
67 | + _tokenize_str("user", query)[1]
68 | + im_end_tokens
69 | + nl_tokens
70 | + im_start_tokens
71 | + tokenizer.encode("assistant")
72 | + nl_tokens
73 | )
74 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
75 |
76 | elif chat_format == "raw":
77 | raw_text = query
78 | context_tokens = tokenizer.encode(raw_text)
79 | else:
80 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
81 | # truncate to max_input_length, truncate from the front
82 | return raw_text, context_tokens[-max_input_length: ]
83 |
84 |
85 | def _decode_chatml(
86 | tokens: List[int],
87 | stop_words: List[str],
88 | eod_token_ids: List[int],
89 | tokenizer: PreTrainedTokenizer,
90 | raw_text_len: int,
91 | context_length: int,
92 | verbose: bool = False,
93 | return_end_reason: bool = False,
94 | errors: str='replace'
95 | ):
96 | end_reason = f"Gen length {len(tokens)}"
97 | eod_token_idx = context_length
98 | for eod_token_idx in range(context_length, len(tokens)):
99 | if tokens[eod_token_idx] in eod_token_ids:
100 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
101 | break
102 |
103 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
104 | if verbose:
105 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
106 | print("\nRaw Generate:", trim_decode_tokens)
107 | print("\nEnd Reason:", end_reason)
108 | for stop_word in stop_words:
109 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
110 | trim_decode_tokens = trim_decode_tokens.strip()
111 | if verbose:
112 | print("\nGenerate:", trim_decode_tokens)
113 |
114 | if return_end_reason:
115 | return trim_decode_tokens, end_reason
116 | else:
117 | return trim_decode_tokens
118 |
119 |
120 | def get_stop_words_ids(chat_format, tokenizer):
121 | if chat_format == "raw":
122 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
123 | elif chat_format == "chatml":
124 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
125 | else:
126 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
127 | return stop_words_ids
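A minimal usage sketch for `make_context` and `get_stop_words_ids` above, assuming it is run from the `examples/qwen-vl` directory and that the Qwen-VL tokenizer is available in a local `./Qwen-VL-Chat` checkout loaded with `trust_remote_code=True`; the path, prompt, and history are illustrative only and are not part of the repository file.

```python
from transformers import AutoTokenizer
from utils.utils import make_context, get_stop_words_ids

# Illustrative path; point this at your local Qwen-VL-Chat checkout.
tokenizer = AutoTokenizer.from_pretrained("./Qwen-VL-Chat", trust_remote_code=True)

# Build a ChatML prompt with one prior turn of history.
raw_text, context_tokens = make_context(
    tokenizer=tokenizer,
    query="Describe this picture.",
    history=[("Hello", "Hi, how can I help you?")],
    chat_format="chatml",
)
stop_words_ids = get_stop_words_ids("chatml", tokenizer)
print(raw_text)
print(len(context_tokens), stop_words_ids)
```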
--------------------------------------------------------------------------------
/examples/qwen-vl/vit_onnx_trt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import AutoModelForCausalLM,AutoTokenizer
3 | from torchvision import transforms
4 | from transformers import AutoConfig
5 | from typing import List
6 | from torchvision.transforms import InterpolationMode
7 | from PIL import Image
8 | import requests
9 | import os
10 | import tensorrt as trt
11 | import argparse
12 |
13 | from tensorrt_llm._utils import str_dtype_to_torch
14 |
15 | import tensorrt as trt
16 | from itertools import tee
17 |
18 | from polygraphy.backend.trt import (
19 | network_from_onnx_path,
20 | engine_from_network,
21 | save_engine,
22 | Profile,
23 | )
24 |
25 | from polygraphy.backend.trt import CreateConfig
26 | from tensorrt import MemoryPoolType
27 |
28 | class Preprocss:
29 | def __init__(self,
30 | image_size:int,
31 | ):
32 | mean = (0.48145466, 0.4578275, 0.40821073)
33 | std = (0.26862954, 0.26130258, 0.27577711)
34 | self.image_transform = transforms.Compose([
35 | transforms.Resize(
36 | (image_size,image_size),
37 | interpolation = InterpolationMode.BICUBIC
38 | ),
39 | transforms.ToTensor(),
40 | transforms.Normalize(mean=mean,std=std),
41 |
42 | ])
43 |
44 | def encode(self,image_paths: List[str]):
45 | images = []
46 | for image_path in image_paths:
47 | if image_path.startswith("http://") or image_path.startswith("https://"):
48 | image = Image.open(requests.get(image_path,stream=True).raw)
49 | else:
50 | image = Image.open(image_path)
51 | image = image.convert("RGB")
52 | images.append(self.image_transform(image))
53 | images = torch.stack(images, dim=0)
54 | return images
55 |
56 | class ONNX_TRT:
57 | def __init__(self,image_size):
58 | self.image_size = image_size
59 | def export_onnx(self,onnx_file_path,pretrained_model_path):
60 |
61 | image_pre_obj = Preprocss(self.image_size)
62 | torch_dtype = str_dtype_to_torch("float32")
63 | model = AutoModelForCausalLM.from_pretrained(
64 | pretrained_model_path,
65 | device_map="cpu",
66 | torch_dtype=torch_dtype,
67 | fp32=True,
68 | trust_remote_code=True
69 | ).eval()
70 | image_url = ['https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg']
71 | device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
72 | image = image_pre_obj.encode(image_url).to(device)
73 | if not os.path.exists('./input_pt'):
74 | os.mkdir('./input_pt')
75 | torch.save(image, './input_pt/image.pt')
76 | #model_visual = model.transformer.visual.to(device).to(torch_dtype)
77 | model_visual = model.transformer.visual
78 | model_visual.eval()
79 |
80 | torch.onnx.export(model_visual,
81 | image.to('cuda'),
82 | onnx_file_path,
83 | opset_version=17,
84 | input_names=['input'],
85 | output_names = ['output'],
86 | dynamic_axes = {
87 | 'input':{0:'batch'}
88 | }
89 | )
90 | def generate_trt_engine(self,onnxFile,planFile,use_polygraph,minBS=1,optBS=2,maxBS=4):
91 | import tensorrt as trt
92 | from time import time
93 |
94 | ## There are two ways to convert an engine
95 |         ## 1. the first is to use the polygraphy tool, which can use fp16;
96 |         ## 2. the second is to use the native TRT API, which must use fp32; with fp16 the accuracy loss is large
97 | ##
98 | ## todo: the difference between the two ways!!
99 | if use_polygraph:
100 | print("we are using polygraph tools get engine file !!!")
101 | #preview_features = [trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805]
102 | preview_features = []
103 |
104 | profiles = [Profile().add(
105 | "input",
106 | min=(minBS, 3, self.image_size, self.image_size ),
107 | opt=(optBS, 3, self.image_size, self.image_size ), # Optimized based on the inputs.
108 | max=(maxBS, 3, self.image_size, self.image_size ),
109 | )]
110 | trt_inference_config = CreateConfig(
111 | fp16=True,
112 | memory_pool_limits = {MemoryPoolType.WORKSPACE: 2048 * 1024 * 1024},
113 | profiles=profiles,
114 | precision_constraints=("obey"),
115 | builder_optimization_level=3,
116 | preview_features=preview_features
117 | )
118 |
119 | onnx_network = network_from_onnx_path(onnxFile)
120 |
121 | trt_engine = engine_from_network(onnx_network, trt_inference_config)
122 |
123 | save_engine(trt_engine, planFile)
124 |
125 | else:
126 | print("we are using tensorrt api get engine file !!!")
127 | logger = trt.Logger(trt.Logger.INFO)
128 | builder = trt.Builder(logger)
129 | network = builder.create_network(
130 | 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
131 | profile = builder.create_optimization_profile()
132 | config = builder.create_builder_config()
133 | # breakpoint()
134 | #config.set_flag(trt.BuilderFlag.FP16)
135 | #config.set_flag(trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS)
136 |
137 | parser = trt.OnnxParser(network, logger)
138 | print("======onnxFile",onnxFile)
139 |
140 | with open(onnxFile, 'rb') as model:
141 | if not parser.parse(model.read(), "/".join(onnxFile.split("/"))):
142 | print("Failed parsing %s" % onnxFile)
143 | for error in range(parser.num_errors):
144 | print(parser.get_error(error))
145 | print("Succeeded parsing %s" % onnxFile)
146 | print("Begin convert onnx to TensorRT engine, need wait a few minutes")
147 |
148 | nBS = -1
149 | nMinBS = minBS
150 | nOptBS = optBS
151 | nMaxBS = maxBS
152 | inputT = network.get_input(0)
153 | inputT.shape = [nBS, 3, self.image_size, self.image_size]
154 | profile.set_shape(inputT.name, [nMinBS, 3, self.image_size, self.image_size],
155 | [nOptBS, 3, self.image_size, self.image_size], [nMaxBS, 3, self.image_size, self.image_size])
156 |
157 | config.add_optimization_profile(profile)
158 |
159 | t0 = time()
160 | engineString = builder.build_serialized_network(network, config)
161 | t1 = time()
162 | if engineString == None:
163 | print("Failed building %s" % planFile)
164 | else:
165 | print("Succeeded building %s in %d s" % (planFile, t1 - t0))
166 | print("plan file is",planFile)
167 | with open(planFile, 'wb') as f:
168 | f.write(engineString)
169 |
170 | def parse_arguments():
171 | parser = argparse.ArgumentParser()
172 | parser.add_argument('--onnxFile',type=str, default='./onnx/visual_encoder/visual_encoder.onnx',help='')#onnx/visual_encoder
173 | parser.add_argument('--pretrained_model_path',type=str, default='./Qwen-VL-Chat',help='')
174 | parser.add_argument('--planFile',type=str, default='./plan/visual_encoder/visual_encoder_fp16.plan',help='')
175 | parser.add_argument('--only_trt', action='store_true', help='Run only convert the onnx to TRT engine.')
176 | parser.add_argument('--minBS',type=int, default=1)
177 | parser.add_argument('--optBS',type=int, default=1)
178 | parser.add_argument('--maxBS',type=int, default=4)
179 | parser.add_argument('--use_polygraph', action='store_true', help='if use polygraph tools get engine.')
180 | args = parser.parse_args()
181 | return args
182 |
183 |
184 | if __name__ == '__main__':
185 |
186 | args = parse_arguments()
187 | onnx_file_dir = os.path.dirname(args.onnxFile)
188 | if not os.path.exists(onnx_file_dir):
189 | os.makedirs(onnx_file_dir)
190 | plan_file_dir = os.path.dirname(args.planFile)
191 | if not os.path.exists(plan_file_dir):
192 | os.makedirs(plan_file_dir)
193 | if True:
194 | onnx_trt_obj = ONNX_TRT(448)
195 | else:
196 | onnx_trt_obj = ONNX_TRT(config.visual['image_size'])
197 |
198 | if args.only_trt:
199 |         onnx_trt_obj.generate_trt_engine(args.onnxFile, args.planFile, args.use_polygraph, args.minBS, args.optBS, args.maxBS)
200 | else:
201 | onnx_trt_obj.export_onnx(args.onnxFile,args.pretrained_model_path)
202 | onnx_trt_obj.generate_trt_engine(args.onnxFile,args.planFile,args.use_polygraph,args.minBS,args.optBS,args.maxBS)
203 |
204 |
205 |
206 |
207 |
208 |
209 |
--------------------------------------------------------------------------------
/examples/qwen-vl/web_demo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """A simple web interactive chat demo based on gradio."""
7 |
8 | from argparse import ArgumentParser
9 | from pathlib import Path
10 | import copy
11 | import gradio as gr
12 | import os
13 | import re
14 | import secrets
15 | import tempfile
16 | from default_config import default_config
17 | from transformers import AutoTokenizer
18 | from openai import OpenAI
19 |
20 | BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
21 | PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
22 |
23 |
24 | def parse_args():
25 | parser = ArgumentParser()
26 | parser.add_argument(
27 | '--tokenizer_dir',
28 | type=str,
29 | default=default_config.tokenizer_dir,
30 | help="Directory containing the tokenizer.model."
31 |
32 | )
33 | parser.add_argument("--share", action="store_true", default=False,
34 | help="Create a publicly shareable link for the interface.")
35 | parser.add_argument("--inbrowser", action="store_true", default=False,
36 | help="Automatically launch the interface in a new tab on the default browser.")
37 | parser.add_argument("--server-port", type=int, default=7860,
38 | help="Demo server port.")
39 | parser.add_argument("--server-name", type=str, default="127.0.0.1",
40 | help="Demo server name.")
41 | args = parser.parse_args()
42 | return args
43 |
44 |
45 | args = parse_args()
46 | client = OpenAI(
47 | base_url="http://localhost:8000/v1",
48 | api_key="no api"
49 | )
50 | tokenizer = AutoTokenizer.from_pretrained(
51 | args.tokenizer_dir,
52 | legacy=False,
53 | trust_remote_code=True,
54 | )
55 |
56 |
57 | def _parse_text(text):
58 | lines = text.split("\n")
59 | lines = [line for line in lines if line != ""]
60 | count = 0
61 | for i, line in enumerate(lines):
62 | if "```" in line:
63 | count += 1
64 | items = line.split("`")
65 | if count % 2 == 1:
66 |                 lines[i] = f'<pre><code class="language-{items[-1]}">'
67 |             else:
68 |                 lines[i] = f"<br></code></pre>"
69 | else:
70 | if i > 0:
71 | if count % 2 == 1:
72 | line = line.replace("`", r"\`")
73 | line = line.replace("<", "<")
74 | line = line.replace(">", ">")
75 | line = line.replace(" ", " ")
76 | line = line.replace("*", "*")
77 | line = line.replace("_", "_")
78 | line = line.replace("-", "-")
79 | line = line.replace(".", ".")
80 | line = line.replace("!", "!")
81 | line = line.replace("(", "(")
82 | line = line.replace(")", ")")
83 | line = line.replace("$", "$")
84 | lines[i] = "
" + line
85 | text = "".join(lines)
86 | return text
87 |
88 |
89 | def _remove_image_special(text):
90 | text = text.replace('[', '').replace(']', '')
91 |     return re.sub(r'<box>.*?(</box>|$)', '', text)
92 |
93 |
94 | def _launch_demo(args):
95 | uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
96 | Path(tempfile.gettempdir()) / "gradio"
97 | )
98 |
99 | def predict(_chatbot, task_history):
100 | chat_query = _chatbot[-1][0]
101 | query = task_history[-1][0]
102 | # print("User: " + _parse_text(query))
103 | history_cp = copy.deepcopy(task_history)
104 | full_response = ""
105 |
106 | history_filter = []
107 | pic_idx = 1
108 | pre = ""
109 | image_list = []
110 | for i, (q, a) in enumerate(history_cp):
111 | if isinstance(q, (tuple, list)):
112 | image_list.append(q[0])
113 |                 q = f'Picture {pic_idx}: <img>{q[0]}</img>'
114 | pre += q + '\n'
115 | pic_idx += 1
116 | else:
117 | pre += q
118 | history_filter.append((pre, a))
119 | pre = ""
120 | history, message = history_filter[:-1], history_filter[-1][0]
121 | messages = [
122 | {"role": "system", "content": "You are a helpful assistant."},
123 | ]
124 | for (query1, response1) in history:
125 | messages.append({"role": "user", "content": query1})
126 | messages.append({"role": "assistant", "content": response1})
127 |
128 | message_dict = {"role": "user", "content": message}
129 | if len(image_list) > 0:
130 | message_dict["images"] = image_list
131 | messages.append(message_dict)
132 | # print("Image list: ", image_list)
133 |
134 | response = client.chat.completions.create(
135 | model="gpt-3.5-turbo",
136 | messages=messages,
137 | # top_p=top_p,
138 | # temperature=temperature,
139 | n=1,
140 | # max_tokens=max_generate_length,
141 | stream=True,
142 | )
143 | response_text = ""
144 | for event in response:
145 | event_text = event.choices[0].delta.content # extract the text
146 | if event_text is None:
147 | event_text = ""
148 | # print(event_text)
149 | response_text += event_text
150 | _chatbot[-1] = (_parse_text(chat_query),
151 | _remove_image_special(_parse_text(response_text)))
152 |
153 | yield _chatbot
154 | full_response = _parse_text(response_text)
155 |
156 | response = full_response
157 | # print("response", response)
158 | history.append((message, response_text))
159 | image = tokenizer.draw_bbox_on_latest_picture(response, history)
160 | if image is not None:
161 | temp_dir = secrets.token_hex(20)
162 | temp_dir = Path(uploaded_file_dir) / temp_dir
163 | temp_dir.mkdir(exist_ok=True, parents=True)
164 | name = f"tmp{secrets.token_hex(5)}.jpg"
165 | filename = temp_dir / name
166 | image.save(str(filename))
167 | _chatbot.append((None, (str(filename),)))
168 | else:
169 | _chatbot[-1] = (_parse_text(chat_query), response)
170 | # full_response = _parse_text(response)
171 |
172 | task_history[-1] = (query, full_response)
173 | # print("Qwen-VL-Chat: " + _parse_text(full_response))
174 | yield _chatbot
175 |
176 | def regenerate(_chatbot, task_history):
177 | if not task_history:
178 | return _chatbot
179 | item = task_history[-1]
180 | if item[1] is None:
181 | return _chatbot
182 | task_history[-1] = (item[0], None)
183 | chatbot_item = _chatbot.pop(-1)
184 | if chatbot_item[0] is None:
185 | _chatbot[-1] = (_chatbot[-1][0], None)
186 | else:
187 | _chatbot.append((chatbot_item[0], None))
188 | return predict(_chatbot, task_history)
189 |
190 | def add_text(history, task_history, text):
191 | task_text = text
192 | if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
193 | task_text = text[:-1]
194 | history = history + [(_parse_text(text), None)]
195 | task_history = task_history + [(task_text, None)]
196 | return history, task_history, ""
197 |
198 | def add_file(history, task_history, file):
199 | history = history + [((file.name,), None)]
200 | task_history = task_history + [((file.name,), None)]
201 | return history, task_history
202 |
203 | def reset_user_input():
204 | return gr.update(value="")
205 |
206 | def reset_state(task_history):
207 | task_history.clear()
208 | return []
209 |
210 | with gr.Blocks() as demo:
211 | gr.Markdown("""Qwen-VL-Chat Bot""")
212 | chatbot = gr.Chatbot(label='Qwen-VL-Chat', elem_classes="control-height", height=550)
213 | query = gr.Textbox(lines=2, label='Input')
214 | task_history = gr.State([])
215 |
216 | with gr.Row():
217 | empty_bin = gr.Button("🧹 Clear History (清除历史)")
218 | submit_btn = gr.Button("🚀 Submit (发送)")
219 | regen_btn = gr.Button("🤔️ Regenerate (重试)")
220 | addfile_btn = gr.UploadButton("📁 Upload (上传文件)", file_types=["image"])
221 |
222 | submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
223 | predict, [chatbot, task_history], [chatbot], show_progress=True
224 | )
225 | submit_btn.click(reset_user_input, [], [query])
226 | empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
227 | regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
228 | addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
229 |
230 | gr.Markdown("""\
231 | Note: This demo is governed by the original license of Qwen-VL. \
232 | We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, \
233 | including hate speech, violence, pornography, deception, etc. \
234 | (注:本演示受Qwen-VL的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,\
235 | 包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)""")
236 |
237 | demo.queue().launch(
238 | share=args.share,
239 | inbrowser=args.inbrowser,
240 | server_port=args.server_port,
241 | server_name=args.server_name,
242 | )
243 |
244 |
245 | if __name__ == '__main__':
246 | _launch_demo(args)
247 |
--------------------------------------------------------------------------------
/examples/qwen/.gitignore:
--------------------------------------------------------------------------------
1 | qwen*
2 | Qwen*
3 | *.log
4 | c-model
5 | ccdv
6 | trt_engines
7 | hg_test.py
8 | rouge.tar.xz
9 | rouge
10 | ccdv___cnn_dailymail.tar.xz
11 | ccdv___cnn_dailymail
12 | lambada.tar.xz
13 | *.json
14 | .idea
15 |
--------------------------------------------------------------------------------
/examples/qwen/cli_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from run import get_model
4 | from run import QWenForCausalLMGenerationSession
5 | from default_config import default_config
6 |
7 | now_dir = os.path.dirname(os.path.abspath(__file__))
8 |
9 |
10 | def parse_arguments():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--max_new_tokens', type=int, default=default_config.max_new_tokens)
13 | parser.add_argument('--log_level', type=str, default='error')
14 | parser.add_argument(
15 | '--engine_dir',
16 | type=str,
17 | default=default_config.engine_dir,
18 | )
19 | parser.add_argument(
20 | '--tokenizer_dir',
21 | type=str,
22 | default=default_config.tokenizer_dir,
23 | help="Directory containing the tokenizer.model."
24 | )
25 | parser.add_argument(
26 | '--stream',
27 | type=bool,
28 | default=True,
29 |         help="stream the output text")
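    # Note: argparse's type=bool converts the raw string with bool(), so any non-empty
    # value (including "False") parses as True; a store_true flag or a small str2bool
    # helper is the usual workaround when a real on/off switch is needed.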
30 | return parser.parse_args()
31 |
32 |
33 | if __name__ == "__main__":
34 | # get model info
35 | args = parse_arguments()
36 | (
37 | model_config, sampling_config, runtime_mapping, runtime_rank,
38 | serialize_path, remove_input_padding,
39 | tokenizer, eos_token_id, pad_token_id
40 | ) = get_model(args.tokenizer_dir, args.engine_dir, args.log_level)
41 | with open(serialize_path, 'rb') as f:
42 | engine_buffer = f.read()
43 | decoder = QWenForCausalLMGenerationSession(
44 | model_config,
45 | engine_buffer,
46 | runtime_mapping,
47 | )
48 | history = []
49 | response = ''
50 | print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录")
51 | while True:
52 | input_text = input("Input: ")
53 | if input_text in ["exit", "quit", "exit()", "quit()"]:
54 | break
55 | if input_text == 'clear':
56 | history = []
57 | continue
58 | if not args.stream:
59 | response = decoder.chat(
60 | tokenizer=tokenizer,
61 | sampling_config=sampling_config,
62 | input_text=input_text,
63 | history=history,
64 | max_new_tokens=args.max_new_tokens,
65 | )
66 | print(f'Output: {response[0]}')
67 | else:
68 | print("Output: ", end='')
69 |
70 | response = ""
71 | for new_text in decoder.chat_stream(
72 | tokenizer=tokenizer,
73 | sampling_config=sampling_config,
74 | input_text=input_text,
75 | history=history,
76 | max_new_tokens=args.max_new_tokens,
77 | ):
78 | print(new_text[0], end='', flush=True)
79 | response += new_text[0]
80 | print("")
81 | history.append((input_text, response))
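# A minimal invocation sketch (paths are the defaults from default_config.py and can be
# overridden on the command line):
#
#     python3 cli_chat.py \
#         --tokenizer_dir=./qwen_7b_chat \
#         --engine_dir=./trt_engines/fp16/1-gpu \
#         --max_new_tokens=2048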
--------------------------------------------------------------------------------
/examples/qwen/client/async_client.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import aiohttp_sse_client.client
4 | from aiohttp import ClientSession
5 | from aiohttp_sse_client import client as sseclient
6 |
7 |
8 | async def handle_event(event: aiohttp_sse_client.client.MessageEvent, event_source):
9 |     # callback that handles an SSE event
10 | data = json.loads(event.data)
11 | # print("data", data)
12 | if event.type == "finish":
13 | try:
14 | await event_source.close()
15 | except Exception as err:
16 | print("close with error", err)
17 | return data["response"], event.type
18 |
19 |
20 | async def listen_sse(query, history=None, max_new_tokens=4096, top_p=0.5, temperature=0):
21 | if history is None:
22 | history = []
23 | async with ClientSession() as session:
24 | url = 'http://127.0.0.1:8000/stream_chat/'
25 | data = {
26 | "query": query,
27 | "history": history,
28 | "max_new_tokens": max_new_tokens,
29 | "top_p": top_p,
30 | "temperature": temperature,
31 | }
32 | headers = {'Content-Type': 'application/json'}
33 | response = ""
34 | if history is None:
35 | history = []
36 | print("Chatbox: ", end='', flush=True)
37 | async with sseclient.EventSource(url, json=data, headers=headers, session=session) as event_source:
38 | try:
39 | async for event in event_source:
40 |                     # hand the event to the callback for processing
41 | new_text, e_type = await handle_event(event, event_source)
42 | print(new_text, end='', flush=True)
43 | response += new_text
44 | if e_type == "finish":
45 | break
46 | except Exception as err:
47 | print("event close", err)
48 | print("")
49 | history.append((query, response))
50 | return response, history
51 |
52 |
53 | if __name__ == "__main__":
54 | history1 = []
55 | print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录")
56 | while True:
57 | query = input("Human: ")
58 | if query == 'exit':
59 | break
60 | if query == 'clear':
61 | history1 = []
62 | continue
63 | _, history1 = asyncio.run(listen_sse(query, history1))
64 |
--------------------------------------------------------------------------------
/examples/qwen/client/normal_client.py:
--------------------------------------------------------------------------------
1 | import json
2 | import requests
3 |
4 |
5 | def chat(query, history=None, max_new_tokens=4096, top_p=0.5, temperature=0):
6 | if history is None:
7 | history = []
8 | url = 'http://127.0.0.1:8000/chat/'
9 | data = {
10 | "query": query,
11 | "history": history,
12 | "max_new_tokens": max_new_tokens,
13 | "top_p": top_p,
14 | "temperature": temperature,
15 | }
16 | headers = {'Content-Type': 'application/json'}
17 | res = requests.post(url=url, data=json.dumps(data), headers=headers)
18 | if res.status_code == 200:
19 | data = res.json()
20 | if data["status"] == 200:
21 | return data["response"], data["history"]
22 | else:
23 | print("Error: ", data)
24 | return "", history
25 | else:
26 | print("Error: ", res.status_code)
27 | return "", history
28 |
29 |
30 |
31 | if __name__ == "__main__":
32 | history1 = []
33 | print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录")
34 | while True:
35 | query = input("Human: ")
36 | if query == 'exit':
37 | break
38 | if query == 'clear':
39 | history1 = []
40 | continue
41 | response, history1 = chat(query, history1)
42 | print("ChatBot: {}".format(response))
43 |
--------------------------------------------------------------------------------
/examples/qwen/client/openai_function_call.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 | import requests
3 | import urllib3
4 | import time
5 | import random
6 | import json
7 |
8 |
9 | urllib3.disable_warnings()
10 |
11 | client = OpenAI(
12 | base_url="http://localhost:8000/v1",
13 | api_key="no api"
14 | )
15 |
16 | # get api from here https://dev.qweather.com/
17 | weather_key = ""
18 | assert len(weather_key) > 0, "please apply for a weather query API key at https://dev.qweather.com/"
19 |
20 |
21 | class Weather:
22 | def __init__(self, api_key):
23 | self.api_key = api_key
24 |
25 | def get_location_from_api(self, location, adm=None,
26 | location_range="world", lang="zh"):
27 | """
28 |         Look up a location via the GeoAPI at https://dev.qweather.com
29 |         :param location: the location to be queried
30 |         :param adm: superior region; for example, the superior region of Yuexiu is Guangzhou
31 |         :param location_range: query range, default "world"; supports cn: China, us: United States, fr: France,
32 |         uk: United Kingdom, etc.; see the ISO 3166 standard for more information
33 |         :param lang: language, default "zh", also supports "en"
34 | """
35 | url = "https://geoapi.qweather.com/v2/city/lookup?"
36 | params = {
37 | "key": self.api_key,
38 | "location": location,
39 | "range": location_range,
40 | "lang": lang,
41 | }
42 | if adm is not None:
43 | if len(adm) > 0:
44 | params["adm"] = adm
45 | session = requests.session()
46 | try:
47 | res2 = session.get(url, params=params, verify=False, timeout=15)
48 | if res2.status_code == 200:
49 | data = res2.json()
50 | if data.get("code", None) == '200':
51 | return data.get("location", [])
52 | else:
53 | print(data)
54 | else:
55 | print(res2)
56 | time.sleep(1 + random.random())
57 | session.close()
58 | except Exception as err:
59 | print("request error", err)
60 | time.sleep(3 + random.random())
61 | session.close()
62 | return []
63 |
64 | def get_weather_from_api(self, location: str):
65 | """
66 |         Get weather information from the QWeather (和风天气) API
67 | :param location: location information, which can be location_id or a latitude and longitude (format: "longitude, latitude")
68 | """
69 | url = "https://devapi.qweather.com/v7/weather/3d?"
70 | params = {
71 | "location": location,
72 | "key": self.api_key
73 | }
74 | session = requests.session()
75 | try:
76 | res1 = session.get(url, params=params, verify=False, timeout=15)
77 | if res1.status_code == 200:
78 | data = res1.json()
79 | if data.get("code", "") == "200":
80 | return data.get("daily", [])
81 | else:
82 | print(data)
83 | else:
84 | print(res1)
85 | time.sleep(1 + random.random())
86 | session.close()
87 | except Exception as err:
88 | print("get api error,", err)
89 | time.sleep(3 + random.random())
90 | session.close()
91 | return []
92 |
93 |
94 | def get_current_weather(location: str):
95 | weather = Weather(weather_key)
96 | location_data = weather.get_location_from_api(location)
97 | if len(location_data) > 0:
98 | location_dict = location_data[0]
99 | city_id = location_dict["id"]
100 | weather_res = weather.get_weather_from_api(city_id)
101 | n_day = len(weather_res)
102 | return f"查询到最近{n_day}天的天气。" + json.dumps(weather_res, ensure_ascii=False)
103 | else:
104 | return ""
105 |
106 | def call_qwen(messages, functions=None):
107 | # print(messages)
108 | if functions:
109 | response = client.chat.completions.create(
110 | model="Qwen", messages=messages, functions=functions
111 | )
112 | else:
113 | response = client.chat.completions.create(
114 | model="Qwen", messages=messages
115 | )
116 | # print(response)
117 | # print(response.choices[0].message.content)
118 | return response
119 |
120 |
121 | def chat(query: str):
122 | functions = [
123 | {
124 | "name": "get_current_weather",
125 | "description": "Get the current weather in a given location.",
126 | "parameters": {
127 | "type": "object",
128 | "properties": {
129 | "location": {
130 | "type": "string",
131 | "description": "The city and state, e.g. San Francisco, CA",
132 | },
133 | "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
134 | },
135 | "required": ["location"],
136 | },
137 | }
138 | ]
139 |
140 | messages = [
141 | {
142 | "role": "user",
143 | # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts,
144 |             # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collection.
145 | "content": query,
146 | }
147 | ]
148 | response = call_qwen(messages, functions)
149 | res = response.choices[0].message
150 | message_dict = {
151 | "role": res.role,
152 | "content": res.content,
153 | "function_call": res.function_call,
154 | }
155 | messages.append(message_dict)
156 | # --- call function --- #
157 | if res.function_call is not None:
158 | function_call = res.function_call
159 | function_name = function_call.name
160 | try:
161 | function_params = json.loads(function_call.arguments)
162 |         except Exception:
163 | print(f"{function_name}解析对应参数失败,请检查, 参数信息:", function_call)
164 | return
165 | for temp_dict in functions:
166 | if temp_dict["name"] == function_name:
167 | require_params = temp_dict["parameters"]["required"]
168 | # require_params.sort()
169 | had_params = list(function_params.keys())
170 | # had_params.sort()
171 | for param in had_params:
172 | if param not in require_params:
173 | del function_params[param]
174 | # recompute
175 | had_params = list(function_params.keys())
176 | if len(had_params) != len(require_params):
177 |                     raise Exception("ERROR: required parameters are missing, please fill them in")
178 |
179 |
180 | response = eval(function_name)(**function_params)
181 | message = {
182 | "role": "function",
183 | "name": function_name,
184 | }
185 | if len(response) > 0:
186 | message["content"] = response
187 | else:
188 | message["content"] = "未找到任何信息"
189 | messages.append(message)
190 | response = call_qwen(messages, functions)
191 | return response
192 |
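# The round trip above follows the OpenAI function-calling flow: the first
# chat.completions.create() call returns a function_call, the client executes
# get_current_weather() locally, appends the result as a {"role": "function"} message,
# and a second chat.completions.create() call turns the tool output into the final answer.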
193 |
194 | messages = [{"role": "system", "content": "You are a helpful assistant."}]
195 | print("=" * 20)
196 | print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录")
197 | print("目前已支持天气查询插件")
198 | print("=" * 20)
199 | query = "北京天气如何?穿短袖会不会冷?"
200 | print("用户输入:", query)
201 | res = chat(query)
202 | print("回答结果:", res.choices[0].message.content)
203 |
--------------------------------------------------------------------------------
/examples/qwen/client/openai_normal_client.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 |
3 | client = OpenAI(
4 | base_url="http://localhost:8000/v1",
5 | api_key="no api"
6 | )
7 |
8 | messages = [{"role": "system", "content": "You are a helpful assistant."}]
9 | print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录")
10 | while True:
11 | prompt = input('Human:')
12 | if prompt == 'exit':
13 | break
14 | if prompt == 'clear':
15 | messages = messages[:1]
16 | continue
17 | messages.append({"role": "user", "content": prompt})
18 | completion = client.chat.completions.create(
19 | model="gpt-3.5-turbo",
20 | messages=messages,
21 | top_p=0.5,
22 | temperature=0,
23 | n=1,
24 | max_tokens=4096,
25 | stream=False,
26 | )
27 | message = completion.choices[0].message
28 | response_text = message.content
29 | print('ChatBot: {}'.format(response_text))
30 | messages.append({"role": "assistant", "content": response_text})
--------------------------------------------------------------------------------
/examples/qwen/client/openai_stream_client.py:
--------------------------------------------------------------------------------
1 | from openai import OpenAI
2 |
3 | client = OpenAI(
4 | base_url="http://localhost:8000/v1",
5 | api_key="no api"
6 | )
7 |
8 |
9 | messages = [{"role": "system", "content": "You are a helpful assistant."}]
10 | print("欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录")
11 | while True:
12 | prompt = input('Human:')
13 | if prompt == 'exit':
14 | break
15 | if prompt == 'clear':
16 | messages = messages[:1]
17 | continue
18 | messages.append({"role": "user", "content": prompt})
19 | response = client.chat.completions.create(
20 | model="gpt-3.5-turbo",
21 | messages=messages,
22 | top_p=0.5,
23 | temperature=0,
24 | n=1,
25 | max_tokens=4096,
26 | stream=True,
27 | )
28 | print("ChatBot:", end='', flush=True)
29 | response_text = ""
30 | for event in response:
31 | # print(event)
32 | event_text = event.choices[0].delta.content # extract the text
33 | if event_text is None:
34 | event_text = ""
35 | response_text += event_text
36 | print(event_text, end='', flush=True)
37 | messages.append({"role": "assistant", "content": response_text})
38 | print("")
39 |
40 |
--------------------------------------------------------------------------------
/examples/qwen/default_config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | class DefaultConfig:
5 | now_dir = os.path.dirname(os.path.abspath(__file__))
6 | hf_model_dir = os.path.join(now_dir, "qwen_7b_chat")
7 | tokenizer_dir = os.path.join(now_dir, "qwen_7b_chat")
8 | int4_gptq_model_dir = os.path.join(now_dir, "qwen_7b_chat_int4")
9 | ft_dir_path = os.path.join(now_dir, "c-model", "qwen_7b_chat")
10 | engine_dir=os.path.join(now_dir, "trt_engines", "fp16", "1-gpu")
11 |
12 | # Maximum batch size for HF backend.
13 | hf_max_batch_size = 1
14 |
15 | # Maximum batch size for TRT-LLM backend.
16 | trt_max_batch_size = 1
17 |
18 |     # choose the chat format: "chatml" for chat models, "raw" for base models
19 | # choices=["chatml", "raw"],
20 | chat_format = "chatml"
21 |
22 | # Maximum input length.
23 | max_input_len = 1024 * 6
24 |
25 |     # Maximum number of new tokens to generate.
26 | max_new_tokens = 2048
27 |
28 | # Top p for sampling.
29 | top_p = 0.8
30 |
31 |
32 | # Top k for sampling.
33 | top_k = 0
34 |
35 | # Temperature for sampling.
36 | temperature = 1.0
37 |
38 |
39 | default_config = DefaultConfig()
40 |
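# A minimal usage sketch: the other scripts in this example import this shared singleton
# instead of hard-coding paths, e.g.
#
#     from default_config import default_config
#     engine_dir = default_config.engine_dir          # <now_dir>/trt_engines/fp16/1-gpu
#     max_new_tokens = default_config.max_new_tokens  # 2048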
--------------------------------------------------------------------------------
/examples/qwen/gptq_convert.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
3 | from default_config import default_config
4 | from argparse import ArgumentParser
5 | import os
6 | from datasets import load_dataset
7 | from tqdm import tqdm
8 | import sys
9 | import logging
10 |
11 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
12 | from utils.utils import make_context
13 |
14 |
15 | logging.basicConfig(
16 | level=logging.INFO,
17 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
18 | )
19 |
20 |
21 | parser = ArgumentParser()
22 | parser.add_argument(
23 | "--hf_model_dir",
24 | type=str,
25 | default=default_config.hf_model_dir,
26 | )
27 | parser.add_argument(
28 | '--tokenizer_dir',
29 | type=str,
30 | default=default_config.tokenizer_dir,
31 | help="Directory containing the tokenizer.model."
32 | )
33 | parser.add_argument(
34 | "--quant_ckpt_path",
35 | type=str,
36 | default=os.path.join(
37 | default_config.int4_gptq_model_dir,
38 | ),
39 | )
40 | parser.add_argument(
41 | "--device",
42 | type=str,
43 | default="cuda",
44 | choices=["cuda", "cpu"],
45 | )
46 | parser.add_argument(
47 | "--num_samples",
48 | type=int,
49 | default=512,
50 | )
51 |
52 |
53 | args = parser.parse_args()
54 | # model_id_or_path = default_config.hf_model_dir
55 | # quantized_model_dir = default_config.int4_gptq_model_dir
56 | tokenizer = AutoTokenizer.from_pretrained(
57 | args.tokenizer_dir, use_fast=True, trust_remote_code=True
58 | )
59 |
60 |
61 | dataset_cnn = load_dataset(
62 | "ccdv/cnn_dailymail",
63 | "3.0.0"
64 | )
65 | dataset = dataset_cnn["test"]
66 |
67 | num_samples = min(args.num_samples, len(dataset))
68 | examples = []
69 | for i in tqdm(range(num_samples), desc="tokenizing datasets"):
70 | line = dataset[i]["article"]
71 | line = line + ' TL;DR: '
72 | line = line.strip()
73 | line = line.replace(" n't", "n't")
74 | # use make_content to generate prompt
75 | raw_text, _ = make_context(
76 | tokenizer=tokenizer,
77 | query=line,
78 | history=[],
79 | )
80 | example = tokenizer(raw_text)
81 | examples.append(example)
82 |
83 | quantize_config = BaseQuantizeConfig(
84 | bits=4, # quantize model to 4-bit
85 | group_size=128, # it is recommended to set the value to 128
86 |     desc_act=False,  # setting this to False can significantly speed up inference, but perplexity may be slightly worse
87 | true_sequential=True,
88 | )
89 |
90 | print("model_path", args.hf_model_dir)
91 | model = (
92 | AutoGPTQForCausalLM.from_pretrained(
93 | args.hf_model_dir,
94 | quantize_config,
95 | trust_remote_code=True,
96 | use_flash_attn=False
97 | )
98 | .eval()
99 | # .cuda()
100 | )
101 | if args.device == "cuda":
102 | model.cuda()
103 | else:
104 |     print("CPU mode is only supported for Qwen-7B v1.0, not for Qwen-7B v1.1 / Qwen-14B")
105 | print("loading model to run gptq, this may take a few minutes...")
106 | # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
107 | model.quantize(examples, cache_examples_on_gpu=False)
108 | print("quantized ok!")
109 |
110 | # save quantized model
111 | model.save_quantized(args.quant_ckpt_path, use_safetensors=True)
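# A minimal load-back sketch for the checkpoint written above (keyword arguments may
# vary slightly between auto_gptq releases):
#
#     from auto_gptq import AutoGPTQForCausalLM
#     quantized = AutoGPTQForCausalLM.from_quantized(
#         args.quant_ckpt_path,
#         device="cuda:0",
#         use_safetensors=True,
#         trust_remote_code=True,
#     )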
--------------------------------------------------------------------------------
/examples/qwen/quantize.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | Adapted from examples/quantization/hf_ptq.py
17 | """
18 |
19 | import argparse
20 | import random
21 | import numpy as np
22 | import torch
23 | from datasets import load_dataset
24 | from torch.utils.data import DataLoader
25 | from transformers import AutoModelForCausalLM, AutoTokenizer
26 |
27 | from tensorrt_llm._utils import str_dtype_to_torch
28 | from tensorrt_llm.logger import logger
29 | from tensorrt_llm.models.quantized.ammo import quantize_and_export
30 | import os
31 | import sys
32 |
33 | now_dir = os.path.dirname(os.path.abspath(__file__))
34 | sys.path.append(now_dir)
35 | from default_config import default_config
36 | from utils.utils import make_context
37 |
38 |
39 |
40 |
41 | def get_calib_dataloader(data="ccdv/cnn_dailymail",
42 | tokenizer=None,
43 | batch_size=1,
44 | calib_size=512,
45 | block_size=512):
46 | print("Loading calibration dataset")
47 | if data == "pileval":
48 | dataset = load_dataset(
49 | "json",
50 | data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
51 | split="train")
52 | dataset = dataset["text"][:calib_size]
53 | elif data == "ccdv/cnn_dailymail":
54 | dataset = load_dataset("ccdv/cnn_dailymail", name="3.0.0", split="train")
55 | dataset = dataset["article"][:calib_size]
56 | else:
57 | raise NotImplementedError
58 |
59 | tokenizer.pad_token_id = tokenizer.im_end_id
60 | # use this prompt to make chat model do summarize
61 | system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user."
62 |
63 | # line_encoded = []
64 | new_dataset = []
65 | for i in range(len(dataset)):
66 | dataset[i] = dataset[i] + ' TL;DR: '
67 | dataset[i] = dataset[i].strip()
68 | dataset[i] = dataset[i].replace(" n't", "n't")
69 | # use make_content to generate prompt
70 | raw_text, input_id_list = make_context(
71 | tokenizer=tokenizer,
72 | query=dataset[i],
73 | history=[],
74 | system=system_prompt,
75 | )
76 | # input_id = torch.from_numpy(
77 | # np.array(input_id_list, dtype=np.int32)
78 | # ).type(torch.int32).unsqueeze(0)
79 | # input_id = input_id[:, -max_input_len:]
80 | # line_encoded.append(input_id)
81 | new_dataset.append(raw_text)
82 |     batch_encoded = tokenizer.batch_encode_plus(
83 |         new_dataset,  # the chatml-formatted prompts built above
84 |         return_tensors="pt",
85 |         padding=True,
86 |         max_length=block_size, truncation=True
87 |     )
88 | batch_encoded = batch_encoded["input_ids"]
89 | batch_encoded = batch_encoded.cuda()
90 | calib_dataloader = DataLoader(batch_encoded,
91 | batch_size=batch_size,
92 | shuffle=False)
93 |
94 | return calib_dataloader
95 |
96 |
97 | def get_tokenizer(ckpt_path, **kwargs):
98 | logger.info(f"Loading tokenizer from {ckpt_path}")
99 | tokenizer = AutoTokenizer.from_pretrained(
100 | ckpt_path,
101 | padding_side="left",
102 | trust_remote_code=True,
103 | **kwargs
104 | )
105 | if tokenizer.pad_token is None:
106 | tokenizer.pad_token = tokenizer.eos_token
107 | return tokenizer
108 |
109 |
110 | def get_model(ckpt_path, dtype="float16"):
111 | logger.info(f"Loading model from {ckpt_path}")
112 | torch_dtype = str_dtype_to_torch(dtype)
113 | model = AutoModelForCausalLM.from_pretrained(
114 | ckpt_path,
115 | device_map="auto",
116 | trust_remote_code=True,
117 | torch_dtype=torch_dtype,
118 | )
119 | model.eval()
120 | model = model.to(memory_format=torch.channels_last)
121 | return model
122 |
123 |
124 | def get_args():
125 | parser = argparse.ArgumentParser(description=__doc__)
126 | parser.add_argument("--model_dir",
127 | type=str,
128 | required=False,
129 | default=default_config.hf_model_dir,
130 | help="Directory of a HF model checkpoint")
131 | parser.add_argument("--dtype", help="Model data type.", default="float16")
132 | parser.add_argument(
133 | "--qformat",
134 | type=str,
135 | choices=['fp8', 'int4_awq'],
136 | default='int4_awq',
137 |         help='Quantization format: fp8 or int4_awq. '
138 |         'For int8 smoothquant, use smoothquant.py instead. ')
139 | parser.add_argument("--calib_size",
140 | type=int,
141 | default=32,
142 | help="Number of samples for calibration.")
143 | parser.add_argument("--export_path", default=os.path.join(now_dir, "qwen_7b_4bit_gs128_awq.pt"))
144 | parser.add_argument('--seed', type=int, default=None, help='Random seed')
145 | args = parser.parse_args()
146 | return args
147 |
148 |
149 | def main():
150 | if not torch.cuda.is_available():
151 | raise EnvironmentError("GPU is required for inference.")
152 |
153 | args = get_args()
154 |
155 | if args.seed is not None:
156 | random.seed(args.seed)
157 | np.random.seed(args.seed)
158 |
159 | tokenizer = get_tokenizer(args.model_dir)
160 | model = get_model(args.model_dir, args.dtype)
161 |
162 | calib_dataloader = get_calib_dataloader(tokenizer=tokenizer,
163 | calib_size=args.calib_size)
164 | model = quantize_and_export(model,
165 | qformat=args.qformat,
166 | calib_dataloader=calib_dataloader,
167 | export_path=args.export_path)
168 |
169 |
170 | if __name__ == "__main__":
171 | main()
172 |
--------------------------------------------------------------------------------
/examples/qwen/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets~=2.3.2
2 | rouge_score~=0.1.2
3 | # transformers~=4.31.0 # tensorrt-llm has installed
4 | transformers-stream-generator
5 | sentencepiece~=0.1.99
6 | tiktoken
7 | einops
8 |
9 | # optional dependencies
10 | uvicorn
11 | gradio==3.40.1
12 | mdtex2html
13 | sse_starlette
14 | aiohttp_sse_client
15 | openai==1.1.1
16 |
--------------------------------------------------------------------------------
/examples/qwen/smoothquant.py:
--------------------------------------------------------------------------------
1 | '''
2 | Utilities for SmoothQuant models
3 | '''
4 |
5 | import functools
6 | from collections import defaultdict
7 |
8 | import torch
9 | import torch.nn as nn
10 | from tqdm import tqdm
11 | from transformers.pytorch_utils import Conv1D
12 | import numpy as np
13 | import os
14 | import sys
15 | project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
16 | sys.path.append(project_dir)
17 | from utils.utils import make_context
18 |
19 |
20 | @torch.no_grad()
21 | def apply_smoothing(
22 | scales,
23 | gemm_weights,
24 | rmsnorm_weights=None,
25 | dtype=torch.float32,
26 | rmsnorm_1p=False
27 | ):
28 | if not isinstance(gemm_weights, list):
29 | gemm_weights = [gemm_weights]
30 |
31 | if rmsnorm_weights is not None:
32 | assert rmsnorm_weights.numel() == scales.numel()
33 | rmsnorm_weights.div_(scales).to(dtype)
34 | if rmsnorm_1p:
35 | rmsnorm_weights += (1 / scales) - 1
36 |
37 | for gemm in gemm_weights:
38 | gemm.mul_(scales.view(1, -1)).to(dtype)
39 |
40 |
41 | @torch.no_grad()
42 | def smooth_gemm(gemm_weights,
43 | act_scales,
44 | rmsnorm_weights=None,
45 | alpha=0.5,
46 | weight_scales=None):
47 | if not isinstance(gemm_weights, list):
48 | gemm_weights = [gemm_weights]
49 | orig_dtype = gemm_weights[0].dtype
50 |
51 | for gemm in gemm_weights:
52 | # gemm_weights are expected to be transposed
53 | assert gemm.shape[1] == act_scales.numel()
54 |
55 | if weight_scales is None:
56 | weight_scales = torch.cat(
57 | [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
58 | dim=0)
59 | weight_scales = weight_scales.max(dim=0)[0]
60 |         weight_scales = weight_scales.to(float).clamp(min=1e-5)
61 | scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
62 | weight_scales.pow(1 - alpha)).clamp(min=1e-5)
63 |
64 | apply_smoothing(scales, gemm_weights, rmsnorm_weights, orig_dtype)
65 |
66 | return scales
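# The per-channel factor above is the standard SmoothQuant rule
# s_j = max|X_j|**alpha / max|W_j|**(1 - alpha): apply_smoothing() divides the preceding
# RMSNorm weight by s (shrinking the activations) and multiplies the matching GEMM input
# channels by s, so the product X @ W^T is unchanged while activation outliers are
# flattened before int8 quantization.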
67 |
68 |
69 | @torch.no_grad()
70 | def smooth_gemm_mlp(
71 | w1_weights,
72 | w2_weights,
73 | act_scales,
74 | rmsnorm_weights=None,
75 | alpha=0.5,
76 | weight_scales=None
77 | ):
78 | gemm_weights = []
79 | if not isinstance(w1_weights, list):
80 | w1_weights = [w1_weights]
81 | if not isinstance(w2_weights, list):
82 | w2_weights = [w2_weights]
83 |
84 | for i in range(len(w1_weights)):
85 | gemm_weight = torch.cat([w1_weights[i], w2_weights[i]], dim=0)
86 | gemm_weights.append(gemm_weight)
87 |
88 | orig_dtype = gemm_weights[0].dtype
89 |
90 | for gemm in gemm_weights:
91 | # gemm_weights are expected to be transposed
92 | assert gemm.shape[1] == act_scales.numel()
93 |
94 | if weight_scales is None:
95 | weight_scales = torch.cat(
96 | [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
97 | dim=0)
98 | weight_scales = weight_scales.max(dim=0)[0]
99 |         weight_scales = weight_scales.to(float).clamp(min=1e-5)
100 | scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
101 | weight_scales.pow(1 - alpha)).clamp(min=1e-5)
102 |
103 | apply_smoothing(scales, w1_weights + w2_weights, rmsnorm_weights, orig_dtype)
104 |
105 | return scales
106 |
107 |
108 | @torch.no_grad()
109 | def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
110 | if not isinstance(fcs, list):
111 | fcs = [fcs]
112 | for fc in fcs:
113 | assert isinstance(fc, nn.Linear)
114 | assert ln.weight.numel() == fc.in_features == act_scales.numel()
115 |
116 | device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
117 | act_scales = act_scales.to(device=device, dtype=dtype)
118 | weight_scales = torch.cat(
119 | [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
120 | weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
121 |
122 | scales = (act_scales.pow(alpha) /
123 | weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
124 |
125 | if ln is not None:
126 | ln.weight.div_(scales)
127 | ln.bias.div_(scales)
128 |
129 | for fc in fcs:
130 | fc.weight.mul_(scales.view(1, -1))
131 | return scales
132 |
133 |
134 | @torch.no_grad()
135 | def capture_activation_range(
136 | model,
137 | tokenizer,
138 | dataset,
139 | system_prompt,
140 | chat_format,
141 | max_input_len,
142 | num_samples=512,
143 | ):
144 | model.eval()
145 | device = next(model.parameters()).device
146 | act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
147 |
148 | def stat_tensor(name, tensor, act_scales, key):
149 | hidden_dim = tensor.shape[-1]
150 | tensor = tensor.view(-1, hidden_dim).abs().detach()
151 | comming_max = torch.max(tensor, dim=0)[0].float()
152 |
153 | if act_scales[name][key] is None:
154 | act_scales[name][key] = comming_max
155 | else:
156 | act_scales[name][key] = torch.max(act_scales[name][key],
157 | comming_max)
158 |
159 | def stat_input_hook(m, x, y, name):
160 | if isinstance(x, tuple):
161 | x = x[0]
162 | stat_tensor(name, x, act_scales, "x")
163 | stat_tensor(name, y, act_scales, "y")
164 |
165 | if act_scales[name]["w"] is None:
166 | act_scales[name]["w"] = m.weight.abs().clip(1e-8,
167 | None).max(dim=1)[0]
168 |
169 | hooks = []
170 | for name, m in model.named_modules():
171 | if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
172 | hooks.append(
173 | m.register_forward_hook(
174 | functools.partial(stat_input_hook, name=name)))
175 | num_samples = min(num_samples, len(dataset))
176 | for i in tqdm(range(num_samples), desc="calibrating model"):
177 | line = dataset[i]["article"]
178 | line = line + ' TL;DR: '
179 | line = line.strip()
180 | line = line.replace(" n't", "n't")
181 | # use make_content to generate prompt
182 | _, input_id_list = make_context(
183 | tokenizer=tokenizer,
184 | query=line,
185 | history=[],
186 | system=system_prompt,
187 | chat_format=chat_format,
188 | max_input_length=max_input_len
189 | )
190 | line_encoded = torch.from_numpy(
191 | np.array(input_id_list, dtype=np.int32)
192 | ).type(torch.int32).unsqueeze(0)
193 | line_encoded = line_encoded.to(device)
194 | # input_ids = tokenizer(dataset[i]["text"],
195 | # return_tensors="pt",
196 | # max_length=seq_len,
197 | # truncation=True).input_ids.to(device)
198 | # model(input_ids)
199 | model(line_encoded)
200 |
201 | for h in hooks:
202 | h.remove()
203 |
204 | return act_scales
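    # act_scales maps each nn.Linear / Conv1D module name to running absolute maxima:
    # "x" per input channel of the activations, "y" per output channel of the
    # activations, and "w" per output row of the weight (clipped to at least 1e-8).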
205 |
--------------------------------------------------------------------------------
/examples/qwen/test/test_dynamic_ntk.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from collections import OrderedDict
3 | import numpy as np
4 | import torch
5 | from polygraphy.backend.trt import EngineFromNetwork, TrtRunner, CreateConfig, Profile
6 | import tensorrt_llm
7 | from tensorrt_llm import Tensor
8 | import math
9 | import tensorrt as trt
10 | import numpy as np
11 | from tensorrt_llm.layers import Embedding
12 | from tensorrt_llm import str_dtype_to_trt
13 | from parameterized import parameterized
14 | from tensorrt_llm.functional import (
15 | Tensor, shape, concat, constant, arange, outer, unary,
16 | partial, expand, elementwise_binary, shape, pow, cos, sin, slice, maximum
17 | )
18 | log = partial(unary, op=trt.UnaryOperation.LOG)
19 | ceil = partial(unary, op=trt.UnaryOperation.CEIL)
20 | div = partial(elementwise_binary, op=trt.ElementWiseOperation.DIV)
21 | gt = partial(elementwise_binary, op=trt.ElementWiseOperation.GREATER)
22 |
23 |
24 |
25 | class RotaryEmbedding(tensorrt_llm.Module):
26 | def __init__(self, per_head_dim=128, seq_length=8192, base=10000.0) -> None:
27 | self.per_head_dim = per_head_dim
28 | self.seq_length = seq_length
29 | self.base = base
30 | super().__init__()
31 | # self.position_embedding_cos = Embedding(
32 | # seq_length,
33 | # per_head_dim,
34 | # dtype=trt.float32
35 | # )
36 | # self.position_embedding_sin = Embedding(
37 | # seq_length,
38 | # per_head_dim,
39 | # dtype=trt.float32
40 | # )
41 |
42 | def forward(self, input_ids):
43 |         # implementation of the old (original) rotary embedding
44 | batch_size = shape(input_ids, 0)
45 | input_len = shape(input_ids, 1)
46 | # pytorch impl
47 | # context_value = math.log(true_seq_len / self.seq_length, 2) + 1
48 | # ntk_alpha = 2 ** math.ceil(context_value) - 1
49 | # ntk_alpha = max(ntk_alpha, 1)
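        # Worked example (with the defaults seq_length=8192, base=10000.0) for a
        # 9886-token prompt: context_value = log2(9886/8192) + 1 ≈ 1.27, ceil -> 2,
        # ntk_alpha = 2**2 - 1 = 3, so the RoPE base is stretched to
        # 10000 * 3**(128/126) ≈ 3.05e4; the TRT graph below reproduces the same math.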
50 |
51 | # trt impl
52 | # with tensorrt_llm.precision("float32"):
53 | context_value = log(input_len.cast(trt.float32) / float(self.seq_length)) / math.log(2) + 1.0
54 | ntk_alpha = pow(constant(np.array(2, dtype=np.float32)), ceil(context_value)) - 1.0
55 |
56 | ntk_alpha = maximum(ntk_alpha, constant(np.array(1.0, dtype=np.float32)))
57 | base = constant(np.array(self.base, dtype=np.float32))
58 | base = base * pow(ntk_alpha, (self.per_head_dim / (self.per_head_dim - 2)))
59 | temp1 = constant(np.arange(0, self.per_head_dim, 2, dtype=np.float32) / self.per_head_dim)
60 | temp2 = pow(base, temp1)
61 | inv_freq = div(
62 | constant(np.array(1, dtype=np.float32)),
63 | temp2
64 | )
65 | # temp_length = f_max(2 * input_len, 16)
66 | seq = arange(constant(np.array(0, dtype=np.int32)), input_len * 2, dtype="int32")
67 | # with tensorrt_llm.precision("float32"):
68 | freqs = outer(seq.cast(trt.float32), inv_freq)
69 | emb = concat([freqs, freqs], dim=1)
70 | # emb = rearrange(emb, "n d -> 1 n 1 d")
71 | emb = emb.view(concat([1, input_len * 2, 1, self.per_head_dim]))
72 | emb = expand(emb, concat([batch_size, input_len * 2, 1, self.per_head_dim]))
73 |
74 | # with tensorrt_llm.precision("float32"):
75 | # cos, sin = emb.cos(), emb.sin()
76 | cos_res = cos(emb)
77 | sin_res = sin(emb)
78 | # position_embedding_cos = cos[:, :input_len]
79 | # position_embedding_sin = sin[:, :input_len]
80 | position_embedding_cos = slice(
81 | input=cos_res,
82 | starts=concat([0, 0, 0, 0]),
83 | sizes=concat([batch_size, input_len, 1, self.per_head_dim]),
84 | )
85 | position_embedding_sin = slice(
86 | input=sin_res,
87 | starts=concat([0, 0, 0, 0]),
88 | sizes=concat([batch_size, input_len, 1, self.per_head_dim]),
89 | )
90 |
91 | # self.register_network_output("my_cos", identity_op(position_embedding_cos))
92 | # self.register_network_output("my_sin", identity_op(position_embedding_sin))
93 | # expand_dims(position_embedding_cos, [batch_size, 1, 1, 1])
94 | rotary_pos_emb = [
95 | (position_embedding_cos, position_embedding_sin),
96 | (position_embedding_cos, position_embedding_sin),
97 | ]
98 | return rotary_pos_emb
99 |
100 |
101 |
102 | class TestFunctional(unittest.TestCase):
103 |
104 | per_head_dim = 128
105 | seq_length = 8192
106 | base = 10000.0
107 | vocab_size = 151936
108 |
109 | def setUp(self):
110 | tensorrt_llm.logger.set_level('error')
111 |
112 | @parameterized.expand([('float32', 9886), ('float32', 1886), ('float16', 1886), ('float16', 9886)])
113 | def test_case(self, dtype, input_length):
114 |
115 |
116 | def test_trt(feed_dict: dict):
117 | # construct trt network
118 | builder = tensorrt_llm.Builder()
119 | net = builder.create_network()
120 | with tensorrt_llm.net_guard(net):
121 | input_ids = Tensor(
122 | name='input_ids',
123 | shape=[-1, -1],
124 | dtype=trt.int32,
125 | dim_range=OrderedDict([
126 | ("batch_size", [[1, 1, 1]]),
127 | ("seq_length", [[1, 10 * 1024, 32 * 1024]])
128 | ])
129 | )
130 | # position_ids = Tensor(
131 | # name='position_ids',
132 | # shape=[-1, -1],
133 | # dtype=trt.int32,
134 | # dim_range=OrderedDict([
135 | # ("batch_size", [[1, 1, 1]]),
136 | # ("seq_length", [[1, 10 * 1024, 32 * 1024]])
137 | # ])
138 | # )
139 | model = RotaryEmbedding(per_head_dim=self.per_head_dim, seq_length=self.seq_length)
140 | outputs = model.forward(input_ids=input_ids)
141 | # net._mark_output(outputs[0][0], 'cos', tensorrt_llm.str_dtype_to_trt(dtype))
142 | # net._mark_output(outputs[0][1], 'sin', tensorrt_llm.str_dtype_to_trt(dtype))
143 | net._mark_output(outputs[0][0], 'cos', trt.float32)
144 | net._mark_output(outputs[0][1], 'sin', trt.float32)
145 |
146 | for k, v in model.named_network_outputs():
147 | # net._mark_output(v, k, tensorrt_llm.str_dtype_to_trt(dtype))
148 | net._mark_output(v, k, trt.float32)
149 | # for build and run
150 | profile = Profile().add(
151 | "input_ids", min=(1, 1), opt=(1, 1), max=(2, 16 * 1024)
152 | )
153 | build_engine = EngineFromNetwork(
154 | (builder.trt_builder, net.trt_network),
155 | config=CreateConfig(
156 | fp16=(dtype == 'float16'),
157 | precision_constraints="obey",
158 | profiles=[profile]
159 | )
160 | )
161 | with TrtRunner(build_engine) as runner:
162 | outputs = runner.infer(feed_dict=feed_dict)
163 | return outputs
164 |
165 | def test_pytorch(input_tensor: torch.tensor):
166 | pt_input_len = input_tensor.shape[1]
167 |             # the uncommented lines mirror the original (old) implementation and use Python's float64 math
168 |             # the commented lines are a pure-PyTorch alternative kept for fp32 consistency checks
169 | pt_context_value = math.log(pt_input_len / self.seq_length, 2) + 1
170 | # pt_context_value = torch.log(torch.Tensor([input_seq_len * 1. / self.seq_length]).cuda()) / torch.log(torch.Tensor([2.]).cuda()) + 1
171 |
172 | pt_ntk_alpha = 2 ** math.ceil(pt_context_value) - 1
173 | # pt_ntk_alpha = torch.Tensor([2]).cuda() ** torch.ceil(pt_context_value) - 1
174 |
175 | pt_ntk_alpha = max(pt_ntk_alpha, 1.0)
176 |
177 | pt_ntk_alpha = pt_ntk_alpha ** (self.per_head_dim / (self.per_head_dim - 2))
178 |
179 | pt_base = torch.Tensor([self.base]).cuda()
180 | pt_base = pt_base * pt_ntk_alpha
181 | pt_temp1 = (torch.arange(0, self.per_head_dim, 2).float() / self.per_head_dim).cuda()
182 | pt_temp2 = torch.pow(pt_base, pt_temp1) # base ** temp1
183 | pt_inv_freq = 1.0 / pt_temp2
184 | pt_seq = torch.arange(0, pt_input_len * 2).int().cuda()
185 | pt_freqs = torch.outer(pt_seq.type_as(pt_inv_freq), pt_inv_freq)
186 | pt_emb = torch.cat((pt_freqs, pt_freqs), dim=-1)
187 | # emb = rearrange(emb, "n d -> 1 n 1 d")
188 | pt_emb = pt_emb.unsqueeze(0).unsqueeze(2)
189 | pt_cos, pt_sin = pt_emb.cos(), pt_emb.sin()
190 | pt_cos = pt_cos[:, :pt_input_len]
191 | pt_sin = pt_sin[:, :pt_input_len]
192 |             print("pt_cos shape/mean/sum/dtype", pt_cos.shape, pt_cos.mean(), pt_cos.sum(), pt_cos.dtype)
193 |             print("pt_sin shape/mean/sum/dtype", pt_sin.shape, pt_sin.mean(), pt_sin.sum(), pt_sin.dtype)
194 | return pt_cos, pt_sin
195 |
196 |
197 |
198 | pt_batch_size = 1
199 | # pt_input_len = 9886
200 | pt_input_len = input_length
201 | print("\ndtype", dtype, "input_length", input_length)
202 | input_tensor = torch.randint(1, self.vocab_size, [pt_batch_size, pt_input_len], dtype=torch.int32)
203 | # position_tensor = torch.arange(0, pt_input_len, dtype=torch.int32).unsqueeze(0).expand([pt_batch_size, pt_input_len])
204 | # print("position_tensor shape", position_tensor.shape)
205 | pt_cos, pt_sin = test_pytorch(input_tensor)
206 | outputs = test_trt(
207 | feed_dict={
208 | "input_ids": input_tensor.numpy(),
209 | }
210 | )
211 |
212 | # import pdb; pdb.set_trace()
213 |
214 | # np.testing.assert_allclose(ntk_alpha.cpu().numpy(), outputs['ntk_alpha'], rtol=0, atol=0)
215 | # np.testing.assert_allclose(base.cpu().numpy(), outputs['base'], rtol=0, atol=0)
216 | # np.testing.assert_allclose(temp1.cpu().numpy(), outputs['temp1'], rtol=0, atol=0)
217 | # np.testing.assert_allclose(temp2.cpu().numpy(), outputs['temp2'], rtol=0, atol=0)
218 | # np.testing.assert_allclose(seq.cpu().numpy(), outputs['seq'], rtol=1e-9, atol=1e-9)
219 | # np.testing.assert_allclose(inv_freq.cpu().numpy(), outputs['inv_freq'], rtol=1e-9, atol=1e-9)
220 | # np.testing.assert_allclose(pt_freqs.cpu().numpy(), outputs['freqs'], rtol=1e-9, atol=1e-9)
221 |         print("cos shape/mean/sum/dtype", outputs["cos"].shape, outputs["cos"].mean(), outputs["cos"].sum(), outputs["cos"].dtype)
222 |         print("sin shape/mean/sum/dtype", outputs["sin"].shape, outputs["sin"].mean(), outputs["sin"].sum(), outputs["sin"].dtype)
223 | np.testing.assert_allclose(pt_cos.cpu().numpy(), outputs['cos'], rtol=1e-5, atol=1e-5)
224 | np.testing.assert_allclose(pt_sin.cpu().numpy(), outputs['sin'], rtol=1e-5, atol=1e-5)
225 |
226 | if __name__ == "__main__":
227 | unittest.main()
--------------------------------------------------------------------------------
/examples/qwen/test/test_logn.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import torch
5 | from polygraphy.backend.trt import EngineFromNetwork, TrtRunner, CreateConfig
6 | import tensorrt_llm
7 | from tensorrt_llm import Tensor
8 | import math
9 | import tensorrt as trt
10 | import numpy as np
11 | from parameterized import parameterized
12 | from tensorrt_llm.parameter import Parameter
13 | from tensorrt_llm.functional import (
14 | Tensor, shape, concat, constant, arange, outer, unary,
15 | partial, expand, elementwise_binary, shape, pow, cos, sin, slice, expand_dims_like, repeat_interleave, str_dtype_to_trt
16 | )
17 | log = partial(unary, op=trt.UnaryOperation.LOG)
18 | ceil = partial(unary, op=trt.UnaryOperation.CEIL)
19 | div = partial(elementwise_binary, op=trt.ElementWiseOperation.DIV)
20 |
21 |
22 | class MyLogn(tensorrt_llm.Module):
23 | def __init__(self, dtype, seq_length, head_size, per_head_dim) -> None:
24 | super().__init__()
25 | self.dtype = dtype
26 | self.seq_length = seq_length
27 | self.head_size = head_size
28 | self.per_head_dim = per_head_dim
29 | logn_array = np.array([
30 | np.log(i) / np.log(self.seq_length) if i > self.seq_length else 1
31 | for i in range(1, 32768)
32 | ],
33 | dtype=np.float32
34 | ).reshape(1, -1, 1, 1)
35 | self.logn_tensor = Parameter(
36 | value=logn_array,
37 | dtype=trt.float32,
38 | shape=[1, 32767, 1, 1],
39 | )
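        # The table above implements Qwen's logn attention scaling: positions beyond the
        # trained seq_length scale the query by log(position) / log(seq_length), e.g.
        # position 9886 with seq_length 8192 gives a factor of roughly 1.02, while
        # positions inside the trained window keep a factor of exactly 1.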
40 |
41 | def forward(self, key, query):
42 | seq_start = slice(shape(key), [1], [1]) - slice(shape(query), [1], [1])
43 | seq_end = slice(shape(key), [1], [1])
44 |
45 | logn_shape = self.logn_tensor.value.shape
46 | logn_tensor = slice(
47 | input=self.logn_tensor.value,
48 | starts=concat([0, seq_start, 0, 0]),
49 | sizes=concat([logn_shape[0], seq_end - seq_start, logn_shape[2], logn_shape[3]]),
50 | )
51 | # logn_tensor2 = repeat_interleave(logn_tensor, self.head_size, 2)
52 | # logn_tensor2 = repeat_interleave(logn_tensor2, self.per_head_dim, 3)
53 | logn_tensor2 = expand(
54 | logn_tensor,
55 | concat([logn_shape[0], seq_end - seq_start, self.head_size, self.per_head_dim])
56 | )
57 | query2 = query.cast(trt.float32) * logn_tensor2
58 | query2 = query2.cast(self.dtype)
59 | return [logn_tensor2, query2]
60 |
61 |
62 |
63 |
64 | class TestFunctional(unittest.TestCase):
65 |
66 | head_size = 16
67 | per_head_dim = 128
68 | seq_length = 8192
69 | base = 10000.0
70 | dtype = 'float16'
71 |
72 |
73 | def setUp(self):
74 | tensorrt_llm.logger.set_level('error')
75 |
76 | @parameterized.expand([('float32', 9886), ('float32', 1886), ("float16", 9886), ("float16", 1886)])
77 | def test_case(self, dtype, input_length):
78 | self.dtype = dtype
79 | batch_size = 1
80 | # input_seq_len = 13727
81 | input_seq_len = input_length
82 | print("\ndtype", dtype, "input_length", input_length)
83 | if dtype == "float32":
84 | pt_key = torch.rand(
85 | [batch_size, input_seq_len, self.head_size, self.per_head_dim],
86 | dtype=torch.float32
87 | )
88 | pt_query = torch.rand(
89 | [batch_size, input_seq_len, self.head_size, self.per_head_dim],
90 | dtype=torch.float32
91 | )
92 | else:
93 | pt_key = torch.rand(
94 | [batch_size, input_seq_len, self.head_size, self.per_head_dim],
95 | dtype=torch.float16
96 | )
97 | pt_query = torch.rand(
98 | [batch_size, input_seq_len, self.head_size, self.per_head_dim],
99 | dtype=torch.float16
100 | )
101 |
102 |
103 | def test_trt(feed_dict: dict):
104 | builder = tensorrt_llm.Builder()
105 | net = builder.create_network()
106 | with tensorrt_llm.net_guard(net):
107 | key = Tensor(name='key',
108 | shape=pt_key.shape,
109 | dtype=tensorrt_llm.str_dtype_to_trt(self.dtype))
110 |
111 | query = Tensor(name='query',
112 | shape=pt_query.shape,
113 | dtype=tensorrt_llm.str_dtype_to_trt(self.dtype))
114 | model = MyLogn(
115 | dtype=dtype,
116 | seq_length=self.seq_length,
117 | head_size=self.head_size,
118 | per_head_dim=self.per_head_dim,
119 | )
120 | outputs = model.forward(query=query, key=key)
121 | net._mark_output(outputs[0], 'logn', str_dtype_to_trt(dtype))
122 | net._mark_output(outputs[1], 'query_output', str_dtype_to_trt(dtype))
123 | # net._mark_output(outputs[0], 'logn', trt.float32)
124 | # net._mark_output(outputs[1], 'query_output', trt.float32)
125 |
126 | for k, v in model.named_network_outputs():
127 | net._mark_output(v, k, tensorrt_llm.str_dtype_to_trt(dtype))
128 | # net._mark_output(v, k, trt.float32)
129 |             # build and run the engine
130 | build_engine = EngineFromNetwork(
131 | (builder.trt_builder, net.trt_network),
132 | config=CreateConfig(
133 | fp16=(dtype == 'float16'),
134 | precision_constraints="obey",
135 | )
136 | )
137 | with TrtRunner(build_engine) as runner:
138 | outputs = runner.infer(feed_dict=feed_dict)
139 | # {"key": pt_key.numpy(), "query": pt_query.numpy()}
140 | return outputs
141 |
142 | def test_pytorch(pt_query, pt_key):
143 | # torch impl
144 | pt_logn_list = [
145 | math.log(i, self.seq_length) if i > self.seq_length else 1
146 | for i in range(1, 32768)
147 | ]
148 | pt_logn_tensor = torch.tensor(pt_logn_list, dtype=torch.float32)[None, :, None, None]
149 | pt_seq_start = pt_key.size(1) - pt_query.size(1)
150 | pt_seq_end = pt_key.size(1)
151 | pt_logn_tensor = pt_logn_tensor[:, pt_seq_start: pt_seq_end, :, :].type_as(pt_query)
152 | pt_logn_tensor2 = pt_logn_tensor.expand_as(pt_query)
153 | pt_logn_tensor2 = pt_logn_tensor2.to(torch.float32)
154 | raw_type = pt_query.dtype
155 | pt_query2 = pt_query.to(torch.float32) * pt_logn_tensor2
156 | pt_logn_tensor2 = pt_logn_tensor2.to(raw_type)
157 | pt_query2 = pt_query2.to(raw_type)
158 | print(
159 |                 "pt_logn2 shape/mean/sum/dtype",
160 | pt_logn_tensor2.shape,
161 | pt_logn_tensor2.to(torch.float32).mean().item(),
162 | pt_logn_tensor2.to(torch.float32).sum().item(),
163 | pt_logn_tensor2.dtype
164 | )
165 | print(
166 |                 "pt_query2 shape/mean/sum/dtype",
167 | pt_query2.shape,
168 | pt_query2.to(torch.float32).mean(),
169 | pt_query2.to(torch.float32).sum(),
170 | pt_query2.dtype
171 | )
172 | return [pt_logn_tensor2, pt_query2]
173 |
174 |
175 | (pt_logn2, pt_query2) = test_pytorch(pt_query=pt_query, pt_key=pt_key)
176 | outputs = test_trt(feed_dict={"key": pt_key.numpy(), "query": pt_query.numpy()})
177 | rtol = atol = 1e-9
178 | print(
179 |             "logn shape/mean/sum/dtype",
180 | outputs['logn'].shape,
181 | outputs['logn'].astype(np.float32).mean(),
182 | outputs['logn'].astype(np.float32).sum(),
183 | outputs['logn'].dtype
184 | )
185 | print(
186 |             "query_output shape/mean/sum/dtype",
187 | outputs['query_output'].shape,
188 | outputs['query_output'].astype(np.float32).mean(),
189 | outputs['query_output'].astype(np.float32).sum(),
190 | outputs['query_output'].dtype
191 | )
192 | np.testing.assert_allclose(pt_logn2.cpu().numpy(), outputs['logn'], rtol=rtol, atol=atol)
193 | np.testing.assert_allclose(pt_query2.cpu().numpy(), outputs['query_output'], rtol=rtol, atol=atol)
194 |
195 | if __name__ == "__main__":
196 | unittest.main()
--------------------------------------------------------------------------------
/examples/qwen/test/test_rms_norm.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import torch
5 | from parameterized import parameterized
6 | from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner
7 | from transformers.models.llama.modeling_llama import LlamaRMSNorm
8 |
9 | import tensorrt_llm
10 | from tensorrt_llm import Tensor
11 | # from tensorrt_llm.quantization.functional import smooth_quant_rms_norm
12 | from model import rms_norm_op
13 |
14 |
15 | class TestFunctional(unittest.TestCase):
16 |
17 | def setUp(self):
18 | tensorrt_llm.logger.set_level('error')
19 |
20 | @parameterized.expand([('float16',), ('float32',)])
21 | def test_rms_norm_plugin(self, dtype):
22 |         print("test rms norm plugin")
23 | test_shape = [2, 5, 10, 10]
24 |
25 | x_data = torch.randn(
26 | *test_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
27 |
28 | m = LlamaRMSNorm(test_shape[-1]) # LlamaRMSNorm only supports last dim
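        # Reference semantics: RMSNorm(x) = x / sqrt(mean(x**2, dim=-1) + eps) * weight,
        # i.e. LayerNorm without mean subtraction or bias; the plugin output below is
        # checked against this HF implementation.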
29 |
30 | with torch.no_grad():
31 | # pytorch run
32 | with torch.no_grad():
33 | ref = m(x_data).to(dtype=torch.float32)
34 |
35 | # construct trt network
36 | builder = tensorrt_llm.Builder()
37 | net = builder.create_network()
38 | # net.plugin_config.set_rmsnorm_quantization_plugin(dtype)
39 | with tensorrt_llm.net_guard(net):
40 | network = tensorrt_llm.default_trtnet()
41 | x = Tensor(name='x',
42 | shape=x_data.shape,
43 | dtype=tensorrt_llm.str_dtype_to_trt(dtype))
44 |
45 | output = rms_norm_op(
46 | x,
47 | dtype,
48 | test_shape[-1],
49 | weight=tensorrt_llm.constant(m.weight.detach().cpu().numpy()),
50 | eps=m.variance_epsilon,
51 | )
52 | output = output.trt_tensor
53 | output.name = 'output'
54 | network.mark_output(output)
55 | # output.dtype = tensorrt_llm.str_dtype_to_trt('int8')
56 |
57 | # trt run
58 | build_engine = EngineFromNetwork(
59 | (builder.trt_builder, net.trt_network),
60 | config=CreateConfig(fp16=(dtype == 'float16'),
61 | precision_constraints="obey"))
62 | assert build_engine is not None, "Build engine failed"
63 | with TrtRunner(build_engine) as runner:
64 | outputs = runner.infer(feed_dict={'x': x_data.cpu().numpy()})
65 |
66 | # compare diff of quantized output
67 | # Set absolute tolerance to 1 to mitigate some rounding error
68 | np.testing.assert_allclose(ref.cpu().numpy(),
69 | outputs['output'],
70 | atol=1,
71 | rtol=0)
72 |
73 | # compare diff of dynamic activation scales
74 | print("max diff", np.max(np.abs(ref.cpu().numpy() - outputs["output"])))
75 |
76 |
77 | if __name__ == '__main__':
78 | unittest.main()
79 |
--------------------------------------------------------------------------------
/examples/qwen/test/test_smooth_quant_rms_norm.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import torch
5 | from parameterized import parameterized
6 | from polygraphy.backend.trt import CreateConfig, EngineFromNetwork, TrtRunner
7 | from transformers.models.llama.modeling_llama import LlamaRMSNorm
8 |
9 | import tensorrt_llm
10 | from tensorrt_llm import Parameter, Tensor
11 | # from tensorrt_llm.quantization.functional import smooth_quant_rms_norm
12 | from utils.quantization import smooth_quant_rms_norm_op
13 |
14 |
15 | class TestFunctional(unittest.TestCase):
16 |
17 | def setUp(self):
18 | tensorrt_llm.logger.set_level('error')
19 |
20 | @parameterized.expand([('float16', False), ('float16', True),
21 | ('float32', False), ('float32', True)])
22 | def test_smooth_quant_rms_norm_plugin(self, dtype, dynamic_act_scaling):
23 | print("test smooth quant rms norm plugin")
24 | test_shape = [2, 5, 10, 10]
25 |
26 | x_data = torch.randn(
27 | *test_shape, dtype=tensorrt_llm._utils.str_dtype_to_torch(dtype))
28 |
29 | m = LlamaRMSNorm(test_shape[-1]) # LlamaRMSNorm only supports last dim
30 |
31 | scale_data = torch.randint(2, 32, (1, ), dtype=torch.float32)
32 |
33 | with torch.no_grad():
34 |
35 | def cast_to_int8_with_sat(tensor):
36 | return tensor.round().clip(-128, 127).to(dtype=torch.int8)
37 |
38 | # pytorch run
39 | with torch.no_grad():
40 | ref = m(x_data).to(dtype=torch.float32)
41 | if dynamic_act_scaling:
42 | abs_max_f, _ = ref.abs().max(dim=-1, keepdim=True)
43 | dynamic_scale = abs_max_f / 127.0
44 | ref_quantized = cast_to_int8_with_sat(ref *
45 | (127.0 / abs_max_f))
46 | else:
47 | ref_quantized = cast_to_int8_with_sat(ref * scale_data)
48 |
49 | # construct trt network
50 | builder = tensorrt_llm.Builder()
51 | net = builder.create_network()
52 | # net.plugin_config.set_rmsnorm_quantization_plugin(dtype)
53 | with tensorrt_llm.net_guard(net):
54 | network = tensorrt_llm.default_trtnet()
55 | x = Tensor(name='x',
56 | shape=x_data.shape,
57 | dtype=tensorrt_llm.str_dtype_to_trt(dtype))
58 |
59 | output = smooth_quant_rms_norm_op(
60 | x,
61 | dtype,
62 | test_shape[-1],
63 | weight=tensorrt_llm.constant(m.weight.detach().cpu().numpy()),
64 | scale=Parameter(scale_data.cpu().numpy()).value,
65 | eps=m.variance_epsilon,
66 | dynamic_act_scaling=dynamic_act_scaling)
67 |
68 | if dynamic_act_scaling:
69 | output, dynamic_scales = output
70 | dynamic_scales = dynamic_scales.trt_tensor
71 | dynamic_scales.name = 'dynamic_scales'
72 | network.mark_output(dynamic_scales)
73 | dynamic_scales.dtype = tensorrt_llm.str_dtype_to_trt('float32')
74 |
75 | output = output.trt_tensor
76 | output.name = 'output'
77 | network.mark_output(output)
78 | output.dtype = tensorrt_llm.str_dtype_to_trt('int8')
79 |
80 | # trt run
81 | build_engine = EngineFromNetwork(
82 | (builder.trt_builder, net.trt_network),
83 | config=CreateConfig(int8=True,
84 | fp16=(dtype == 'float16'),
85 | precision_constraints="obey"))
86 | assert build_engine is not None, "Build engine failed"
87 | with TrtRunner(build_engine) as runner:
88 | outputs = runner.infer(feed_dict={'x': x_data.cpu().numpy()})
89 |
90 | # compare diff of quantized output
91 | # Set absolute tolerance to 1 to mitigate some rounding error
92 | np.testing.assert_allclose(ref_quantized.cpu().numpy(),
93 | outputs['output'],
94 | atol=1,
95 | rtol=0)
96 |
97 | # compare diff of dynamic activation scales
98 | if dynamic_act_scaling:
99 | np.testing.assert_allclose(dynamic_scale.cpu().numpy(),
100 | outputs['dynamic_scales'],
101 | atol=1e-2)
102 | print("max diff", np.max(np.abs(ref_quantized.cpu().numpy() - outputs["output"])))
103 |
104 | def test_sq_rms_norm_no_plugin(self):
105 | print("test seq rms norm no plugin")
106 | # Create builder
107 | builder = tensorrt_llm.Builder()
108 | # Create empty network
109 | net = builder.create_network()
110 | with tensorrt_llm.net_guard(net):
111 | tensorrt_llm.default_trtnet()
112 | # Get output tensor for SQ gemm
113 | with self.assertRaisesRegex(AssertionError, 'Unsupported dtype: 0'):
114 | smooth_quant_rms_norm_op(None, 0, None, None, None, 0)
115 |
116 |
117 | if __name__ == '__main__':
118 | unittest.main()
119 |
--------------------------------------------------------------------------------
/examples/qwen/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/examples/qwen/utils/__init__.py
--------------------------------------------------------------------------------
/examples/qwen/utils/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import PreTrainedTokenizer
2 | from typing import List, Tuple
3 |
4 |
5 | def make_context(
6 | tokenizer: PreTrainedTokenizer,
7 | query: str,
8 | history: List[Tuple[str, str]] = None,
9 | system: str = "You are a helpful assistant.",
10 | max_input_length: int = 2048, # if you want to change this, you need to change the max_input_len in tensorrt_llm_july-release-v1/examples/qwen/build.py
11 | max_window_size: int = 6144,
12 | chat_format: str = "chatml",
13 | ):
14 | if history is None:
15 | history = []
16 |
17 | if chat_format == "chatml":
18 | im_start, im_end = "<|im_start|>", "<|im_end|>"
19 | im_start_tokens = [tokenizer.im_start_id]
20 | im_end_tokens = [tokenizer.im_end_id]
21 | nl_tokens = tokenizer.encode("\n")
22 |
23 | def _tokenize_str(role, content):
24 | return (
25 | f"{role}\n{content}",
26 | tokenizer.encode(
27 | role,
28 | allowed_special=set(),
29 | ) + nl_tokens + tokenizer.encode(
30 | content,
31 | allowed_special=set(),
32 | )
33 | )
34 |
35 | system_text, system_tokens_part = _tokenize_str("system", system)
36 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
37 | raw_text = ""
38 | context_tokens = []
39 |
40 | for turn_query, turn_response in reversed(history):
41 | query_text, query_tokens_part = _tokenize_str("user", turn_query)
42 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
43 |
44 | response_text, response_tokens_part = _tokenize_str(
45 | "assistant", turn_response
46 | )
47 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
48 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
49 | prev_chat = (
50 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
51 | )
52 |
53 | current_context_size = (
54 | len(system_tokens) + len(next_context_tokens) + len(context_tokens)
55 | )
56 | if current_context_size < max_window_size:
57 | context_tokens = next_context_tokens + context_tokens
58 | raw_text = prev_chat + raw_text
59 | else:
60 | break
61 |
62 | context_tokens = system_tokens + context_tokens
63 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text
64 | context_tokens += (
65 | nl_tokens
66 | + im_start_tokens
67 | + _tokenize_str("user", query)[1]
68 | + im_end_tokens
69 | + nl_tokens
70 | + im_start_tokens
71 | + tokenizer.encode("assistant")
72 | + nl_tokens
73 | )
74 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
75 |
76 | elif chat_format == "raw":
77 | raw_text = query
78 | context_tokens = tokenizer.encode(raw_text)
79 | else:
80 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
81 | # truncate to max_input_length, truncate from the front
82 | return raw_text, context_tokens[-max_input_length: ]
83 |
84 |
85 | def _decode_chatml(
86 | tokens: List[int],
87 | stop_words: List[str],
88 | eod_token_ids: List[int],
89 | tokenizer: PreTrainedTokenizer,
90 | raw_text_len: int,
91 | context_length: int,
92 | verbose: bool = False,
93 | return_end_reason: bool = False,
94 | errors: str='replace'
95 | ):
96 | end_reason = f"Gen length {len(tokens)}"
97 | eod_token_idx = context_length
98 | for eod_token_idx in range(context_length, len(tokens)):
99 | if tokens[eod_token_idx] in eod_token_ids:
100 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
101 | break
102 |
103 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
104 | if verbose:
105 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
106 | print("\nRaw Generate:", trim_decode_tokens)
107 | print("\nEnd Reason:", end_reason)
108 | for stop_word in stop_words:
109 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
110 | trim_decode_tokens = trim_decode_tokens.strip()
111 | if verbose:
112 | print("\nGenerate:", trim_decode_tokens)
113 |
114 | if return_end_reason:
115 | return trim_decode_tokens, end_reason
116 | else:
117 | return trim_decode_tokens
118 |
119 |
120 | def get_stop_words_ids(chat_format, tokenizer):
121 | if chat_format == "raw":
122 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
123 | elif chat_format == "chatml":
124 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
125 | else:
126 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
127 | return stop_words_ids
--------------------------------------------------------------------------------
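The `make_context` helper above assembles the ChatML prompt (`<|im_start|>…<|im_end|>` blocks) that the Qwen runtime scripts feed to the engine. A minimal usage sketch follows; the checkpoint path and the `utils.utils` import layout (running from `examples/qwen`) are assumptions for illustration, not part of the files above.

```python
# Illustrative only: paths and the import layout are assumptions.
from transformers import AutoTokenizer
from utils.utils import make_context, get_stop_words_ids

tokenizer = AutoTokenizer.from_pretrained(
    "./qwen_7b_chat",          # assumed local Qwen(v1) checkpoint
    trust_remote_code=True,    # Qwen v1 ships its own tokenizer implementation
)

raw_text, context_tokens = make_context(
    tokenizer,
    query="What is TensorRT-LLM?",
    history=[("Hello", "Hi! How can I help you?")],
    system="You are a helpful assistant.",
    chat_format="chatml",
)
stop_words_ids = get_stop_words_ids("chatml", tokenizer)

print(raw_text)             # ChatML string: system block, history turns, and an open assistant turn
print(len(context_tokens))  # token ids (truncated to max_input_length) for the TRT-LLM session
print(stop_words_ids)       # [[im_end_id], [im_start_id]] used to stop generation
```
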
/examples/qwen/web_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gradio as gr
3 | import mdtex2html
4 | from default_config import default_config
5 | from openai import OpenAI
6 |
7 |
8 | client = OpenAI(
9 | base_url="http://localhost:8000/v1",
10 | api_key="no api"
11 | )
12 |
13 | now_dir = os.path.dirname(os.path.abspath(__file__))
14 |
15 |
16 | """Override Chatbot.postprocess"""
17 | def postprocess(self, y):
18 | if y is None:
19 | return []
20 | for i, (message, response) in enumerate(y):
21 | y[i] = [
22 | None if message is None else mdtex2html.convert((message)),
23 | None if response is None else mdtex2html.convert(response),
24 | ]
25 | return y
26 |
27 |
28 | gr.Chatbot.postprocess = postprocess
29 |
30 |
31 | def parse_text(text):
32 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
33 | lines = text.split("\n")
34 | lines = [line for line in lines if line != ""]
35 | count = 0
36 | for i, line in enumerate(lines):
37 | if "```" in line:
38 | count += 1
39 | items = line.split('`')
40 | if count % 2 == 1:
41 |                 lines[i] = f'<pre><code class="language-{items[-1]}">'
42 |             else:
43 |                 lines[i] = f'<br></code></pre>'
44 |         else:
45 |             if i > 0:
46 |                 if count % 2 == 1:
47 |                     line = line.replace("`", "\`")
48 |                     line = line.replace("<", "&lt;")
49 |                     line = line.replace(">", "&gt;")
50 |                     line = line.replace(" ", "&nbsp;")
51 |                     line = line.replace("*", "&ast;")
52 |                     line = line.replace("_", "&lowbar;")
53 |                     line = line.replace("-", "&#45;")
54 |                     line = line.replace(".", "&#46;")
55 |                     line = line.replace("!", "&#33;")
56 |                     line = line.replace("(", "&#40;")
57 |                     line = line.replace(")", "&#41;")
58 |                     line = line.replace("$", "&#36;")
59 |                 lines[i] = "<br>"+line
60 | text = "".join(lines)
61 | return text
62 |
63 |
64 | def predict(input_text, chatbot, top_p, temperature, max_generate_length, history):
65 | messages = [
66 | {"role": "system", "content": "You are a helpful assistant."},
67 | ]
68 | for (message, response) in history:
69 | messages.append({"role": "user", "content": message})
70 | messages.append({"role": "assistant", "content": response})
71 | messages.append({"role": "user", "content": input_text})
72 | chatbot.append((parse_text(input_text), ""))
73 | history.append((input_text, ""))
74 |
75 | response = client.chat.completions.create(
76 | model="gpt-3.5-turbo",
77 | messages=messages,
78 | top_p=top_p,
79 | temperature=temperature,
80 | n=1,
81 | max_tokens=max_generate_length,
82 | stream=True,
83 | )
84 | response_text = ""
85 | for event in response:
86 | event_text = event.choices[0].delta.content # extract the text
87 | if event_text is None:
88 | event_text = ""
89 | response_text += event_text
90 | chatbot[-1] = (parse_text(input_text), parse_text(response_text))
91 | history[-1] = (input_text, response_text)
92 | yield chatbot, history
93 | messages.append({"role": "assistant", "content": response_text})
94 |
95 |
96 | def reset_user_input():
97 | return gr.update(value='')
98 |
99 |
100 | def reset_state():
101 | return [], []
102 |
103 |
104 | with gr.Blocks() as demo:
105 | gr.HTML("""Qwen-7B-Chat (Power By TensorRT-LLM)
""")
106 |
107 | chatbot = gr.Chatbot()
108 | with gr.Row():
109 | with gr.Column(scale=4):
110 | with gr.Column(scale=12):
111 | user_input = gr.Textbox(
112 | show_label=False,
113 | placeholder="Input...",
114 | lines=10,
115 | container=False
116 | )
117 | with gr.Column(min_width=32, scale=1):
118 | submitBtn = gr.Button("Submit", variant="primary")
119 | with gr.Column(scale=1):
120 | emptyBtn = gr.Button("Clear History")
121 | top_p = gr.Slider(
122 | minimum=0,
123 | maximum=1,
124 | value=0.8,
125 | step=0.1,
126 | label="top-p",
127 | interactive=True
128 | )
129 | temperature = gr.Slider(
130 | minimum=0,
131 | maximum=1,
132 | value=1,
133 | step=0.1,
134 | label="temperature",
135 | interactive=True
136 | )
137 | max_generate_length = gr.Slider(
138 | 0,
139 | default_config.max_new_tokens,
140 | value=default_config.max_new_tokens // 2,
141 | step=1.0,
142 | label="Maximum generate length", interactive=True
143 | )
144 |
145 | history = gr.State([])
146 |
147 | submitBtn.click(
148 | predict, # call function
149 | [user_input, chatbot, top_p, temperature, max_generate_length, history], # inputs
150 | [chatbot, history], # outputs
151 | show_progress=True,
152 | )
153 | # reset input
154 | submitBtn.click(reset_user_input, [], [user_input])
155 |
156 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
157 |
158 | # demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=False)
159 | demo.queue().launch(server_name="localhost", share=False, inbrowser=False)
160 |
--------------------------------------------------------------------------------
/examples/qwen2/.gitignore:
--------------------------------------------------------------------------------
1 | qwen*
2 | Qwen*
3 | *.log
4 | c-model
5 | ccdv
6 | trt_engines
7 | hg_test.py
8 | rouge.tar.xz
9 | rouge
10 | ccdv___cnn_dailymail.tar.xz
11 | ccdv___cnn_dailymail
12 | lambada.tar.xz
13 | *.json
14 | .idea
15 |
--------------------------------------------------------------------------------
/examples/qwen2/cli_chat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from run_old import get_model
4 | from run_old import Qwen2ForCausalLMGenerationSession
5 | from default_config import default_config
6 | import tensorrt_llm
7 |
8 | now_dir = os.path.dirname(os.path.abspath(__file__))
9 |
10 |
11 | def parse_arguments():
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--max_output_len', type=int, default=default_config.max_new_tokens)
14 | parser.add_argument('--log_level', type=str, default='error')
15 | parser.add_argument(
16 | '--engine_dir',
17 | type=str,
18 | default=default_config.engine_dir,
19 | )
20 | parser.add_argument(
21 | '--tokenizer_dir',
22 | type=str,
23 | default=default_config.tokenizer_dir,
24 | help="Directory containing the tokenizer.model."
25 | )
26 | parser.add_argument(
27 | '--stream',
28 | type=bool,
29 | default=True,
30 | help="return text with stream")
31 | return parser.parse_args()
32 |
33 |
34 | if __name__ == "__main__":
35 | # get model info
36 | args = parse_arguments()
37 | runtime_rank = tensorrt_llm.mpi_rank()
38 | (
39 |
40 | engine, model_config, sampling_config, runtime_mapping,
41 | tokenizer, eos_token_id, pad_token_id, stop_token_ids
42 | ) = get_model(args.tokenizer_dir, args.engine_dir, args.log_level, rank=runtime_rank)
43 | engine_buffer = engine.engine
44 | decoder = Qwen2ForCausalLMGenerationSession(
45 | model_config,
46 | engine_buffer,
47 | runtime_mapping,
48 | )
49 | history = []
50 | response = ''
51 | print("\n欢迎使用Qwen聊天机器人,输入exit退出,输入clear清空历史记录")
52 | while True:
53 | input_text = input("Input: ")
54 | if input_text in ["exit", "quit", "exit()", "quit()"]:
55 | break
56 | if input_text == 'clear':
57 | history = []
58 | continue
59 | if not args.stream:
60 | response = decoder.chat(
61 | pad_token_id=pad_token_id,
62 | tokenizer=tokenizer,
63 | sampling_config=sampling_config,
64 | input_text=input_text,
65 | history=history,
66 | max_new_tokens=args.max_output_len,
67 | )[0]
68 | print(f'Output: {response}')
69 | else:
70 | print("Output: ", end='')
71 |
72 | response = ""
73 | for new_text in decoder.chat_stream(
74 | stop_token_ids=stop_token_ids,
75 | pad_token_id=pad_token_id,
76 | tokenizer=tokenizer,
77 | sampling_config=sampling_config,
78 | input_text=input_text,
79 | history=history,
80 | max_new_tokens=args.max_output_len,
81 | ):
82 | print(new_text[0], end='', flush=True)
83 | response += new_text[0]
84 | print("")
85 | history.append((input_text, response))
--------------------------------------------------------------------------------
/examples/qwen2/default_config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | class DefaultConfig:
5 | now_dir = os.path.dirname(os.path.abspath(__file__))
6 | hf_model_dir = os.path.join(now_dir, "qwen1.5_7b_chat")
7 | tokenizer_dir = os.path.join(now_dir, "qwen1.5_7b_chat")
8 | int4_gptq_model_dir = os.path.join(now_dir, "qwen1.5_7b_chat_int4")
9 | ft_dir_path = os.path.join(now_dir, "c-model", "qwen1.5_7b_chat")
10 | engine_dir = os.path.join(now_dir, "trt_engines", "fp16", "1-gpu")
11 |
12 | # Maximum batch size for HF backend.
13 | hf_max_batch_size = 1
14 |
15 | # Maximum batch size for TRT-LLM backend.
16 | trt_max_batch_size = 1
17 |
18 |     # choose the model format: base or chat
19 | # choices=["chatml", "raw"],
20 | chat_format = "chatml"
21 |
22 | # Maximum input length.
23 | max_input_len = 1024 * 6
24 |
25 |     # Maximum number of new tokens to generate.
26 | max_new_tokens = 2048
27 |
28 | max_output_len = max_new_tokens
29 |
30 | # Top p for sampling.
31 | top_p = 0.8
32 |
33 | # Top k for sampling.
34 | top_k = 50
35 |
36 | # Temperature for sampling.
37 | temperature = 1.0
38 |
39 |
40 | default_config = DefaultConfig()
41 |
--------------------------------------------------------------------------------
/examples/qwen2/gptq_convert.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
3 | from default_config import default_config
4 | from argparse import ArgumentParser
5 | import os
6 | from datasets import load_dataset
7 | from tqdm import tqdm
8 | import sys
9 | import logging
10 |
11 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))
12 |
13 |
14 | logging.basicConfig(
15 | level=logging.INFO,
16 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
17 | )
18 |
19 |
20 | parser = ArgumentParser()
21 | parser.add_argument(
22 | "--hf_model_dir",
23 | type=str,
24 | default=default_config.hf_model_dir,
25 | )
26 | parser.add_argument(
27 | '--tokenizer_dir',
28 | type=str,
29 | default=default_config.tokenizer_dir,
30 | help="Directory containing the tokenizer.model."
31 | )
32 | parser.add_argument(
33 | "--quant_ckpt_path",
34 | type=str,
35 | default=os.path.join(
36 | default_config.int4_gptq_model_dir,
37 | ),
38 | )
39 | parser.add_argument(
40 | "--device",
41 | type=str,
42 | default="cuda",
43 | choices=["cuda", "cpu"],
44 | )
45 | parser.add_argument(
46 | "--num_samples",
47 | type=int,
48 | default=512,
49 | )
50 |
51 |
52 | args = parser.parse_args()
53 | # model_id_or_path = default_config.hf_model_dir
54 | # quantized_model_dir = default_config.int4_gptq_model_dir
55 | tokenizer = AutoTokenizer.from_pretrained(
56 | args.tokenizer_dir, use_fast=True, trust_remote_code=True
57 | )
58 |
59 | dataset_cnn = load_dataset(
60 | "ccdv/cnn_dailymail",
61 | "3.0.0"
62 | )
63 | dataset = dataset_cnn["test"]
64 |
65 | num_samples = min(args.num_samples, len(dataset))
66 | examples = []
67 | system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user."
68 | for i in tqdm(range(num_samples), desc="tokenizing datasets"):
69 | line = dataset[i]["article"]
70 | line = line + ' TL;DR: '
71 | line = line.strip()
72 | line = line.replace(" n't", "n't")
73 |     # build the prompt with the tokenizer's chat template
74 | messages = [
75 | {"role": "system", "content": system_prompt},
76 | {"role": "user", "content": line}
77 | ]
78 | raw_text = tokenizer.apply_chat_template(
79 | messages,
80 | tokenize=False,
81 | add_generation_prompt=True
82 | )
83 | example = tokenizer(raw_text)
84 | examples.append(example)
85 |
86 | quantize_config = BaseQuantizeConfig(
87 | bits=4, # quantize model to 4-bit
88 | group_size=128, # it is recommended to set the value to 128
89 |     desc_act=False,  # False significantly speeds up inference, at a small cost in perplexity
90 | true_sequential=True,
91 | )
92 |
93 | print("model_path", args.hf_model_dir)
94 | model = (
95 | AutoGPTQForCausalLM.from_pretrained(
96 | args.hf_model_dir,
97 | quantize_config,
98 | )
99 | .eval()
100 | # .cuda()
101 | )
102 | if args.device == "cuda":
103 | model.cuda()
104 | else:
105 | print("using cpu only support on Qwen 7b v1.0, not support on Qwen 7b v1.1 / Qwen 14b")
106 | print("loading model to run gptq, may need few minute...")
107 | # quantize model, the examples should be list of dict whose keys can only be "input_ids" and "attention_mask"
108 | model.quantize(examples, cache_examples_on_gpu=False)
109 | print("quantized ok!")
110 |
111 | # save quantized model
112 | model.save_quantized(args.quant_ckpt_path, use_safetensors=True)
--------------------------------------------------------------------------------
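For a quick sanity check of the checkpoint written by `model.save_quantized(...)`, something like the sketch below should work. It is not part of the repo; the checkpoint and tokenizer paths, device, and generation settings are assumptions.

```python
# Illustrative only: load the GPTQ checkpoint produced above with auto_gptq and run one generation.
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM

hf_dir = "./qwen1.5_7b_chat"            # assumed original HF checkpoint (for the tokenizer)
quant_dir = "./qwen1.5_7b_chat_int4"    # default_config.int4_gptq_model_dir by default

tokenizer = AutoTokenizer.from_pretrained(hf_dir, trust_remote_code=True)
model = AutoGPTQForCausalLM.from_quantized(quant_dir, device="cuda:0", use_safetensors=True)

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Give a one-sentence summary of TensorRT-LLM."}],
    tokenize=False,
    add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
output_ids = model.generate(**inputs, max_new_tokens=64)
# strip the prompt tokens before decoding
print(tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
```
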
/examples/qwen2/pytorch_test.py:
--------------------------------------------------------------------------------
1 | # from transformers import AutoModelForCausalLM, AutoTokenizer
2 | from transformers.models.qwen2 import Qwen2ForCausalLM, Qwen2Tokenizer
3 | from default_config import default_config
4 | device = "cuda" # the device to load the model onto
5 |
6 |
7 | model = Qwen2ForCausalLM.from_pretrained(
8 | # "Qwen/Qwen1.5-72B-Chat",
9 | default_config.hf_model_dir,
10 | device_map="auto"
11 | ).half()
12 | tokenizer = Qwen2Tokenizer.from_pretrained(default_config.hf_model_dir)
13 |
14 | messages = [
15 | {"role": "system", "content": "You are a helpful assistant."},
16 | {"role": "user", "content": "你好,请问你叫什么?"}
17 | ]
18 | text = tokenizer.apply_chat_template(
19 | messages,
20 | tokenize=False,
21 | add_generation_prompt=True
22 | )
23 |
24 | print("Input Text: ", text)
25 | input_ids = tokenizer([text], return_tensors="pt").to(device).input_ids
26 | print("Input Shape: ", input_ids.shape)
27 |
28 | generated_ids = model.generate(
29 | input_ids,
30 | max_new_tokens=512
31 | )
32 | generated_ids = [
33 | output_ids[len(input_ids):]
34 | for input_ids, output_ids in zip(input_ids, generated_ids)
35 | ]
36 |
37 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
38 | print("Response: ", response)
--------------------------------------------------------------------------------
/examples/qwen2/quantize.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | Adapted from examples/quantization/hf_ptq.py
17 | """
18 |
19 | import argparse
20 | import random
21 | import numpy as np
22 | import torch
23 | from datasets import load_dataset
24 | from torch.utils.data import DataLoader
25 | from transformers import AutoModelForCausalLM, AutoTokenizer
26 |
27 | from tensorrt_llm._utils import str_dtype_to_torch
28 | from tensorrt_llm.logger import logger
29 | from tensorrt_llm.models.quantized.ammo import quantize_and_export
30 | import os
31 | import sys
32 |
33 | now_dir = os.path.dirname(os.path.abspath(__file__))
34 | sys.path.append(now_dir)
35 | from default_config import default_config
36 |
37 |
38 | def get_calib_dataloader(data="ccdv/cnn_dailymail",
39 | tokenizer=None,
40 | batch_size=1,
41 | calib_size=512,
42 | block_size=512):
43 | print("Loading calibration dataset")
44 | if data == "pileval":
45 | dataset = load_dataset(
46 | "json",
47 | data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst",
48 | split="train")
49 | dataset = dataset["text"][:calib_size]
50 | elif data == "ccdv/cnn_dailymail":
51 | dataset = load_dataset(
52 | "ccdv/cnn_dailymail", name="3.0.0", split="train", trust_remote_code=True
53 | )
54 | dataset = dataset["article"][:calib_size]
55 | else:
56 | raise NotImplementedError
57 |
58 |     # use this system prompt to make the chat model produce summaries
59 | system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user."
60 |
61 | # line_encoded = []
62 | new_dataset = []
63 | for i in range(len(dataset)):
64 | dataset[i] = dataset[i] + ' TL;DR: '
65 | dataset[i] = dataset[i].strip()
66 | dataset[i] = dataset[i].replace(" n't", "n't")
67 |         # build the prompt with the tokenizer's chat template
68 | messages = [
69 | {"role": "system", "content": system_prompt},
70 | {"role": "user", "content": dataset[i]}
71 | ]
72 | raw_text = tokenizer.apply_chat_template(
73 | messages,
74 | tokenize=False,
75 | add_generation_prompt=True
76 | )
77 | new_dataset.append(raw_text)
78 | batch_encoded = tokenizer.batch_encode_plus(
79 |         new_dataset,
80 |         return_tensors="pt",
81 |         padding=True,
82 |         truncation=True, max_length=block_size
83 | )
84 | batch_encoded = batch_encoded["input_ids"]
85 | batch_encoded = batch_encoded.cuda()
86 | calib_dataloader = DataLoader(batch_encoded,
87 | batch_size=batch_size,
88 | shuffle=False)
89 |
90 | return calib_dataloader
91 |
92 |
93 | def get_tokenizer(ckpt_path, **kwargs):
94 | logger.info(f"Loading tokenizer from {ckpt_path}")
95 | tokenizer = AutoTokenizer.from_pretrained(
96 | ckpt_path,
97 | padding_side="left",
98 | trust_remote_code=True,
99 | **kwargs
100 | )
101 | if tokenizer.pad_token is None:
102 | tokenizer.pad_token = tokenizer.eos_token
103 | return tokenizer
104 |
105 |
106 | def get_model(ckpt_path, dtype="float16"):
107 | logger.info(f"Loading model from {ckpt_path}")
108 | torch_dtype = str_dtype_to_torch(dtype)
109 | model = AutoModelForCausalLM.from_pretrained(
110 | ckpt_path,
111 | # device_map="auto",
112 | # torch_dtype=torch_dtype,
113 | ).to(torch_dtype).cuda()
114 | model.eval()
115 | model = model.to(memory_format=torch.channels_last)
116 | return model
117 |
118 |
119 | def get_args():
120 | parser = argparse.ArgumentParser(description=__doc__)
121 | parser.add_argument("--model_dir",
122 | type=str,
123 | required=False,
124 | default=default_config.hf_model_dir,
125 | help="Directory of a HF model checkpoint")
126 | parser.add_argument("--dtype", help="Model data type.", default="float16")
127 | parser.add_argument(
128 | "--qformat",
129 | type=str,
130 | choices=['fp8', 'int4_awq'],
131 | default='int4_awq',
132 |         help='Quantization format: fp8 or int4_awq. '
133 |         'For int8 SmoothQuant, use smoothquant.py instead.')
134 | parser.add_argument("--calib_size",
135 | type=int,
136 | default=32,
137 | help="Number of samples for calibration.")
138 | parser.add_argument(
139 | "--export_path",
140 | default=os.path.join(now_dir, "qwen2_7b_4bit_gs128_awq.pt")
141 | )
142 | parser.add_argument('--seed', type=int, default=None, help='Random seed')
143 | args = parser.parse_args()
144 | return args
145 |
146 |
147 | def main():
148 | if not torch.cuda.is_available():
149 | raise EnvironmentError("GPU is required for inference.")
150 |
151 | args = get_args()
152 |
153 | if args.seed is not None:
154 | random.seed(args.seed)
155 | np.random.seed(args.seed)
156 |
157 | tokenizer = get_tokenizer(args.model_dir)
158 | model = get_model(args.model_dir, args.dtype)
159 |
160 | calib_dataloader = get_calib_dataloader(tokenizer=tokenizer,
161 | calib_size=args.calib_size)
162 | model = quantize_and_export(model,
163 | qformat=args.qformat,
164 | calib_dataloader=calib_dataloader,
165 | export_path=args.export_path)
166 |
167 |
168 | if __name__ == "__main__":
169 | main()
170 |
--------------------------------------------------------------------------------
/examples/qwen2/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets~=2.3.2
2 | rouge_score~=0.1.2
3 | # transformers~=4.37.0 # tensorrt-llm has installed
4 | transformers-stream-generator
5 | sentencepiece~=0.1.99
6 | tiktoken
7 | einops
8 | #tensorrt_llm==0.8.0
9 | # optional dependencies
10 | uvicorn
11 | gradio==3.40.1
12 | mdtex2html
13 | sse_starlette==1.6.5
14 | aiohttp_sse_client
15 | openai==1.1.1
16 |
--------------------------------------------------------------------------------
/examples/qwen2/smoothquant.py:
--------------------------------------------------------------------------------
1 | '''
2 | Utilities for SmoothQuant models
3 | '''
4 |
5 | import functools
6 | from collections import defaultdict
7 |
8 | import torch
9 | import torch.nn as nn
10 | from tqdm import tqdm
11 | from transformers.pytorch_utils import Conv1D
12 |
13 |
14 | @torch.no_grad()
15 | def apply_smoothing(
16 | scales,
17 | gemm_weights,
18 | rmsnorm_weights=None,
19 | dtype=torch.float32,
20 | rmsnorm_1p=False
21 | ):
22 | if not isinstance(gemm_weights, list):
23 | gemm_weights = [gemm_weights]
24 |
25 | if rmsnorm_weights is not None:
26 | assert rmsnorm_weights.numel() == scales.numel()
27 | rmsnorm_weights.div_(scales).to(dtype)
28 | if rmsnorm_1p:
29 | rmsnorm_weights += (1 / scales) - 1
30 |
31 | for gemm in gemm_weights:
32 | gemm.mul_(scales.view(1, -1)).to(dtype)
33 |
34 |
35 | @torch.no_grad()
36 | def smooth_gemm(gemm_weights,
37 | act_scales,
38 | rmsnorm_weights=None,
39 | alpha=0.5,
40 | weight_scales=None):
41 | if not isinstance(gemm_weights, list):
42 | gemm_weights = [gemm_weights]
43 | orig_dtype = gemm_weights[0].dtype
44 |
45 | for gemm in gemm_weights:
46 | # gemm_weights are expected to be transposed
47 | assert gemm.shape[1] == act_scales.numel()
48 |
49 | if weight_scales is None:
50 | weight_scales = torch.cat(
51 | [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
52 | dim=0)
53 | weight_scales = weight_scales.max(dim=0)[0]
54 | weight_scales.to(float).clamp(min=1e-5)
55 | scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
56 | weight_scales.pow(1 - alpha)).clamp(min=1e-5)
57 |
58 | apply_smoothing(scales, gemm_weights, rmsnorm_weights, orig_dtype)
59 |
60 | return scales
61 |
62 |
63 | @torch.no_grad()
64 | def smooth_gemm_mlp(
65 | w1_weights,
66 | w2_weights,
67 | act_scales,
68 | rmsnorm_weights=None,
69 | alpha=0.5,
70 | weight_scales=None
71 | ):
72 | gemm_weights = []
73 | if not isinstance(w1_weights, list):
74 | w1_weights = [w1_weights]
75 | if not isinstance(w2_weights, list):
76 | w2_weights = [w2_weights]
77 |
78 | for i in range(len(w1_weights)):
79 | gemm_weight = torch.cat([w1_weights[i], w2_weights[i]], dim=0)
80 | gemm_weights.append(gemm_weight)
81 |
82 | orig_dtype = gemm_weights[0].dtype
83 |
84 | for gemm in gemm_weights:
85 | # gemm_weights are expected to be transposed
86 | assert gemm.shape[1] == act_scales.numel()
87 |
88 | if weight_scales is None:
89 | weight_scales = torch.cat(
90 | [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
91 | dim=0)
92 | weight_scales = weight_scales.max(dim=0)[0]
93 | weight_scales.to(float).clamp(min=1e-5)
94 | scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
95 | weight_scales.pow(1 - alpha)).clamp(min=1e-5)
96 |
97 | apply_smoothing(scales, w1_weights + w2_weights, rmsnorm_weights, orig_dtype)
98 |
99 | return scales
100 |
101 |
102 | @torch.no_grad()
103 | def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
104 | if not isinstance(fcs, list):
105 | fcs = [fcs]
106 | for fc in fcs:
107 | assert isinstance(fc, nn.Linear)
108 | assert ln.weight.numel() == fc.in_features == act_scales.numel()
109 |
110 | device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
111 | act_scales = act_scales.to(device=device, dtype=dtype)
112 | weight_scales = torch.cat(
113 | [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
114 | weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)
115 |
116 | scales = (act_scales.pow(alpha) /
117 | weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)
118 |
119 | if ln is not None:
120 | ln.weight.div_(scales)
121 | ln.bias.div_(scales)
122 |
123 | for fc in fcs:
124 | fc.weight.mul_(scales.view(1, -1))
125 | return scales
126 |
127 |
128 | @torch.no_grad()
129 | def capture_activation_range(
130 | model,
131 | tokenizer,
132 | dataset,
133 | system_prompt,
134 | max_input_len,
135 | num_samples=512,
136 | ):
137 | model.eval()
138 | device = next(model.parameters()).device
139 | act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
140 |
141 | def stat_tensor(name, tensor, act_scales, key):
142 | hidden_dim = tensor.shape[-1]
143 | tensor = tensor.view(-1, hidden_dim).abs().detach()
144 |         coming_max = torch.max(tensor, dim=0)[0].float()
145 |
146 |         if act_scales[name][key] is None:
147 |             act_scales[name][key] = coming_max
148 |         else:
149 |             act_scales[name][key] = torch.max(act_scales[name][key],
150 |                                               coming_max)
151 |
152 | def stat_input_hook(m, x, y, name):
153 | if isinstance(x, tuple):
154 | x = x[0]
155 | stat_tensor(name, x, act_scales, "x")
156 | stat_tensor(name, y, act_scales, "y")
157 |
158 | if act_scales[name]["w"] is None:
159 | act_scales[name]["w"] = m.weight.abs().clip(1e-8,
160 | None).max(dim=1)[0]
161 |
162 | hooks = []
163 | for name, m in model.named_modules():
164 | if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
165 | hooks.append(
166 | m.register_forward_hook(
167 | functools.partial(stat_input_hook, name=name)))
168 | num_samples = min(num_samples, len(dataset))
169 | for i in tqdm(range(num_samples), desc="calibrating model"):
170 | line = dataset[i]["article"]
171 | line = line + ' TL;DR: '
172 | line = line.strip()
173 | line = line.replace(" n't", "n't")
174 |         # build the prompt with the tokenizer's chat template
175 |
176 | messages = [
177 | {"role": "system", "content": system_prompt},
178 | {"role": "user", "content": line}
179 | ]
180 | text = tokenizer.apply_chat_template(
181 | messages,
182 | tokenize=False,
183 | add_generation_prompt=True,
184 | truncation=True,
185 | max_length=max_input_len,
186 | )
187 | input_ids = tokenizer([text], return_tensors="pt").input_ids
188 | input_ids = input_ids.to(device)
189 | # input_ids = tokenizer(dataset[i]["text"],
190 | # return_tensors="pt",
191 | # max_length=seq_len,
192 | # truncation=True).input_ids.to(device)
193 | # model(input_ids)
194 | model(input_ids)
195 |
196 | for h in hooks:
197 | h.remove()
198 |
199 | return act_scales
200 |
--------------------------------------------------------------------------------
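These helpers are driven by the conversion script; the sketch below only illustrates how `capture_activation_range` and `smooth_gemm` fit together on a single attention projection. The module paths, the float32 load, and the tiny calibration size are illustrative assumptions, not the converter's actual settings.

```python
# Illustrative wiring only: capture activation ranges, then smooth one layer's QKV projections.
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from smoothquant import capture_activation_range, smooth_gemm

model_dir = "./qwen1.5_7b_chat"   # assumed local checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# float32 keeps the in-place weight scaling numerically safe; it is memory-hungry for a 7B model.
model = AutoModelForCausalLM.from_pretrained(model_dir, torch_dtype=torch.float32).cuda()

dataset = load_dataset("ccdv/cnn_dailymail", "3.0.0", split="test")
act_scales = capture_activation_range(
    model, tokenizer, dataset,
    system_prompt="You are a useful assistant, please directly output the corresponding "
                  "summary according to the article entered by the user.",
    max_input_len=2048,
    num_samples=8,                            # tiny sample count, illustration only
)

attn = model.model.layers[0].self_attn        # assumed Qwen2 HF module layout
smooth_gemm(
    [attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight],
    act_scales["model.layers.0.self_attn.q_proj"]["x"],
    rmsnorm_weights=model.model.layers[0].input_layernorm.weight,
    alpha=0.5,
)
```
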
/examples/qwen2/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/examples/qwen2/utils/__init__.py
--------------------------------------------------------------------------------
/examples/qwen2/utils/utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import json
19 | from pathlib import Path
20 | from typing import Optional
21 |
22 | from transformers import AutoTokenizer, T5Tokenizer
23 |
24 | import tensorrt_llm
25 |
26 | # TODO(enweiz): Update for refactored models
27 | DEFAULT_HF_MODEL_DIRS = {
28 | 'BaichuanForCausalLM': 'baichuan-inc/Baichuan-13B-Chat',
29 | 'BloomForCausalLM': 'bigscience/bloom-560m',
30 | 'ChatGLMForCausalLM': 'THUDM/chatglm3-6b',
31 | 'FalconForCausalLM': 'tiiuae/falcon-rw-1b',
32 | 'gpt': 'gpt2-medium',
33 | 'GPTJForCausalLM': 'EleutherAI/gpt-j-6b',
34 | 'GPTNeoXForCausalLM': 'EleutherAI/gpt-neox-20b',
35 | 'InternLMForCausalLM': 'internlm/internlm-chat-7b',
36 | 'LlamaForCausalLM': 'meta-llama/Llama-2-7b-hf',
37 | 'MPTForCausalLM': 'mosaicml/mpt-7b',
38 | 'PhiForCausalLM': 'microsoft/phi-2',
39 | 'OPTForCausalLM': 'facebook/opt-350m',
40 | 'qwen': 'Qwen/Qwen-7B',
41 | }
42 |
43 | DEFAULT_PROMPT_TEMPLATES = {
44 | 'InternLMForCausalLM':
45 | "<|User|>:{input_text}\n<|Bot|>:",
46 | 'qwen':
47 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n",
48 | 'Qwen2ForCausalLM':
49 | "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_text}<|im_end|>\n<|im_start|>assistant\n",
50 | }
51 |
52 |
53 | def read_model_name(engine_dir: str):
54 | engine_version = tensorrt_llm.runtime.engine.get_engine_version(engine_dir)
55 |
56 | with open(Path(engine_dir) / "config.json", 'r') as f:
57 | config = json.load(f)
58 |
59 | if engine_version is None:
60 | return config['builder_config']['name'], None
61 |
62 | model_arch = config['pretrained_config']['architecture']
63 | model_version = None
64 | if model_arch == 'ChatGLMForCausalLM':
65 | model_version = config['pretrained_config']['chatglm_version']
66 | return model_arch, model_version
67 |
68 |
69 | def throttle_generator(generator, stream_interval):
70 | for i, out in enumerate(generator):
71 | if not i % stream_interval:
72 | yield out
73 |
74 | if i % stream_interval:
75 | yield out
76 |
77 |
78 | def load_tokenizer(tokenizer_dir: Optional[str] = None,
79 | vocab_file: Optional[str] = None,
80 | model_name: str = 'gpt',
81 | model_version: Optional[str] = None,
82 | tokenizer_type: Optional[str] = None):
83 | if vocab_file is None:
84 | use_fast = True
85 | if tokenizer_type is not None and tokenizer_type == "llama":
86 | use_fast = False
87 | # Should set both padding_side and truncation_side to be 'left'
88 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir,
89 | legacy=False,
90 | padding_side='left',
91 | truncation_side='left',
92 | trust_remote_code=True,
93 | tokenizer_type=tokenizer_type,
94 | use_fast=use_fast)
95 | else:
96 | # For gpt-next, directly load from tokenizer.model
97 | assert model_name == 'gpt'
98 | tokenizer = T5Tokenizer(vocab_file=vocab_file,
99 | padding_side='left',
100 | truncation_side='left')
101 |
102 | if model_name == 'qwen':
103 | with open(Path(tokenizer_dir) / "generation_config.json") as f:
104 | gen_config = json.load(f)
105 | chat_format = gen_config['chat_format']
106 | if chat_format == 'raw':
107 | pad_id = gen_config['pad_token_id']
108 | end_id = gen_config['eos_token_id']
109 | elif chat_format == 'chatml':
110 | pad_id = tokenizer.im_end_id
111 | end_id = tokenizer.im_end_id
112 | else:
113 | raise Exception(f"unknown chat format: {chat_format}")
114 | elif model_name == "Qwen2ForCausalLM":
115 | gen_config_path = os.path.join(tokenizer_dir, 'generation_config.json')
116 | with open(gen_config_path, 'r') as f:
117 | gen_config = json.load(f)
118 |
119 |         # chat models: eos_token_id is a list; use its first entry for both pad and end
120 |         if isinstance(gen_config["eos_token_id"], list):
121 |             pad_id = end_id = gen_config["eos_token_id"][0]
122 |         # base models: a single eos_token_id
123 |         else:
124 | pad_id = gen_config["bos_token_id"]
125 | end_id = gen_config["eos_token_id"]
126 | elif model_name == 'ChatGLMForCausalLM' and model_version == 'glm':
127 | pad_id = tokenizer.pad_token_id
128 | end_id = tokenizer.eop_token_id
129 | else:
130 | if tokenizer.pad_token_id is None:
131 | tokenizer.pad_token_id = tokenizer.eos_token_id
132 | pad_id = tokenizer.pad_token_id
133 | end_id = tokenizer.eos_token_id
134 |
135 | return tokenizer, pad_id, end_id
136 |
--------------------------------------------------------------------------------
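A hedged usage sketch for the helpers above: resolve the architecture recorded in a built engine directory, then load the tokenizer together with the pad/end ids the runtime scripts expect. The directory paths are illustrative assumptions.

```python
# Illustrative only: engine and checkpoint paths are assumptions.
from utils.utils import read_model_name, load_tokenizer, DEFAULT_PROMPT_TEMPLATES

engine_dir = "./trt_engines/fp16/1-gpu"     # assumed build output (see default_config.engine_dir)
model_name, model_version = read_model_name(engine_dir)

tokenizer, pad_id, end_id = load_tokenizer(
    tokenizer_dir="./qwen1.5_7b_chat",      # assumed HF checkpoint with generation_config.json
    model_name=model_name,                  # e.g. "Qwen2ForCausalLM"
    model_version=model_version,
)

prompt = DEFAULT_PROMPT_TEMPLATES[model_name].format(input_text="What is TensorRT-LLM?")
input_ids = tokenizer.encode(prompt)
print(model_name, pad_id, end_id, len(input_ids))
```
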
/examples/qwen2/web_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gradio as gr
3 | import mdtex2html
4 | from default_config import default_config
5 | from openai import OpenAI
6 |
7 |
8 | client = OpenAI(
9 | base_url="http://localhost:8000/v1",
10 | api_key="no api"
11 | )
12 |
13 | now_dir = os.path.dirname(os.path.abspath(__file__))
14 |
15 |
16 | """Override Chatbot.postprocess"""
17 | def postprocess(self, y):
18 | if y is None:
19 | return []
20 | for i, (message, response) in enumerate(y):
21 | y[i] = [
22 | None if message is None else mdtex2html.convert((message)),
23 | None if response is None else mdtex2html.convert(response),
24 | ]
25 | return y
26 |
27 |
28 | gr.Chatbot.postprocess = postprocess
29 |
30 |
31 | def parse_text(text):
32 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
33 | lines = text.split("\n")
34 | lines = [line for line in lines if line != ""]
35 | count = 0
36 | for i, line in enumerate(lines):
37 | if "```" in line:
38 | count += 1
39 | items = line.split('`')
40 | if count % 2 == 1:
41 |                 lines[i] = f'<pre><code class="language-{items[-1]}">'
42 |             else:
43 |                 lines[i] = f'<br></code></pre>'
44 |         else:
45 |             if i > 0:
46 |                 if count % 2 == 1:
47 |                     line = line.replace("`", "\`")
48 |                     line = line.replace("<", "&lt;")
49 |                     line = line.replace(">", "&gt;")
50 |                     line = line.replace(" ", "&nbsp;")
51 |                     line = line.replace("*", "&ast;")
52 |                     line = line.replace("_", "&lowbar;")
53 |                     line = line.replace("-", "&#45;")
54 |                     line = line.replace(".", "&#46;")
55 |                     line = line.replace("!", "&#33;")
56 |                     line = line.replace("(", "&#40;")
57 |                     line = line.replace(")", "&#41;")
58 |                     line = line.replace("$", "&#36;")
59 |                 lines[i] = "<br>"+line
60 | text = "".join(lines)
61 | return text
62 |
63 |
64 | def predict(input_text, chatbot, top_p, temperature, max_generate_length, history):
65 | messages = [
66 | {"role": "system", "content": "You are a helpful assistant."},
67 | ]
68 | for (message, response) in history:
69 | messages.append({"role": "user", "content": message})
70 | messages.append({"role": "assistant", "content": response})
71 | messages.append({"role": "user", "content": input_text})
72 | chatbot.append((parse_text(input_text), ""))
73 | history.append((input_text, ""))
74 |
75 | response = client.chat.completions.create(
76 | model="gpt-3.5-turbo",
77 | messages=messages,
78 | top_p=top_p,
79 | temperature=temperature,
80 | n=1,
81 | max_tokens=max_generate_length,
82 | stream=True,
83 | )
84 | response_text = ""
85 | for event in response:
86 | event_text = event.choices[0].delta.content # extract the text
87 | if event_text is None:
88 | event_text = ""
89 | response_text += event_text
90 | chatbot[-1] = (parse_text(input_text), parse_text(response_text))
91 | history[-1] = (input_text, response_text)
92 | yield chatbot, history
93 | messages.append({"role": "assistant", "content": response_text})
94 |
95 |
96 | def reset_user_input():
97 | return gr.update(value='')
98 |
99 |
100 | def reset_state():
101 | return [], []
102 |
103 |
104 | with gr.Blocks() as demo:
105 | gr.HTML("""Qwen1.5-Chat (Power By TensorRT-LLM)
""")
106 |
107 | chatbot = gr.Chatbot()
108 | with gr.Row():
109 | with gr.Column(scale=4):
110 | with gr.Column(scale=12):
111 | user_input = gr.Textbox(
112 | show_label=False,
113 | placeholder="Input...",
114 | lines=10,
115 | container=False
116 | )
117 | with gr.Column(min_width=32, scale=1):
118 | submitBtn = gr.Button("Submit", variant="primary")
119 | with gr.Column(scale=1):
120 | emptyBtn = gr.Button("Clear History")
121 | top_p = gr.Slider(
122 | minimum=0,
123 | maximum=1,
124 | value=0.8,
125 | step=0.1,
126 | label="top-p",
127 | interactive=True
128 | )
129 | temperature = gr.Slider(
130 | minimum=0,
131 | maximum=1,
132 | value=1,
133 | step=0.1,
134 | label="temperature",
135 | interactive=True
136 | )
137 | max_generate_length = gr.Slider(
138 | 0,
139 | default_config.max_new_tokens,
140 | value=default_config.max_new_tokens // 2,
141 | step=1.0,
142 | label="Maximum generate length", interactive=True
143 | )
144 |
145 | history = gr.State([])
146 |
147 | submitBtn.click(
148 | predict, # call function
149 | [user_input, chatbot, top_p, temperature, max_generate_length, history], # inputs
150 | [chatbot, history], # outputs
151 | show_progress=True,
152 | )
153 | # reset input
154 | submitBtn.click(reset_user_input, [], [user_input])
155 |
156 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
157 |
158 | demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=False)
159 | # demo.queue().launch(server_name="localhost", share=False, inbrowser=False)
160 |
--------------------------------------------------------------------------------
/images/course.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/images/course.png
--------------------------------------------------------------------------------
/images/function_call_001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/images/function_call_001.jpg
--------------------------------------------------------------------------------
/images/function_call_002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/images/function_call_002.jpg
--------------------------------------------------------------------------------
/images/langchain-chatchat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/images/langchain-chatchat.jpg
--------------------------------------------------------------------------------
/images/rmsnormplugin.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/images/rmsnormplugin.jpeg
--------------------------------------------------------------------------------
/images/rope_inside.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/images/rope_inside.jpeg
--------------------------------------------------------------------------------
/images/rope_outside.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/images/rope_outside.jpeg
--------------------------------------------------------------------------------
/images/tensorrt_rmsnorm_op.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/images/tensorrt_rmsnorm_op.jpeg
--------------------------------------------------------------------------------
/images/triton_trt_llm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/images/triton_trt_llm.png
--------------------------------------------------------------------------------
/triton_model_repo/ensemble/1/.tmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/triton_model_repo/ensemble/1/.tmp
--------------------------------------------------------------------------------
/triton_model_repo/ensemble/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions
5 | # are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of NVIDIA CORPORATION nor the names of its
12 | # contributors may be used to endorse or promote products derived
13 | # from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 | name: "ensemble"
28 | platform: "ensemble"
29 | max_batch_size: 1
30 | input [
31 | {
32 | name: "text_input"
33 | data_type: TYPE_STRING
34 | dims: [ -1 ]
35 | },
36 | {
37 | name: "max_tokens"
38 | data_type: TYPE_INT32
39 | dims: [ -1 ]
40 | },
41 | {
42 | name: "bad_words"
43 | data_type: TYPE_STRING
44 | dims: [ -1 ]
45 | optional: true
46 | },
47 | {
48 | name: "stop_words"
49 | data_type: TYPE_STRING
50 | dims: [ -1 ]
51 | optional: true
52 | },
53 | {
54 | name: "end_id"
55 | data_type: TYPE_INT32
56 | dims: [ 1 ]
57 | optional: true
58 | },
59 | {
60 | name: "pad_id"
61 | data_type: TYPE_INT32
62 | dims: [ 1 ]
63 | optional: true
64 | },
65 | {
66 | name: "top_k"
67 | data_type: TYPE_INT32
68 | dims: [ 1 ]
69 | optional: true
70 | },
71 | {
72 | name: "top_p"
73 | data_type: TYPE_FP32
74 | dims: [ 1 ]
75 | optional: true
76 | },
77 | {
78 | name: "temperature"
79 | data_type: TYPE_FP32
80 | dims: [ 1 ]
81 | optional: true
82 | },
83 | {
84 | name: "length_penalty"
85 | data_type: TYPE_FP32
86 | dims: [ 1 ]
87 | optional: true
88 | },
89 | {
90 | name: "repetition_penalty"
91 | data_type: TYPE_FP32
92 | dims: [ 1 ]
93 | optional: true
94 | },
95 | {
96 | name: "min_length"
97 | data_type: TYPE_INT32
98 | dims: [ 1 ]
99 | optional: true
100 | },
101 | {
102 | name: "presence_penalty"
103 | data_type: TYPE_FP32
104 | dims: [ 1 ]
105 | optional: true
106 | },
107 | {
108 | name: "frequency_penalty"
109 | data_type: TYPE_FP32
110 | dims: [ 1 ]
111 | optional: true
112 | },
113 | {
114 | name: "random_seed"
115 | data_type: TYPE_UINT64
116 | dims: [ 1 ]
117 | optional: true
118 | },
119 | {
120 | name: "return_log_probs"
121 | data_type: TYPE_BOOL
122 | dims: [ 1 ]
123 | optional: true
124 | },
125 | {
126 | name: "return_context_logits"
127 | data_type: TYPE_BOOL
128 | dims: [ 1 ]
129 | optional: true
130 | },
131 | {
132 | name: "return_generation_logits"
133 | data_type: TYPE_BOOL
134 | dims: [ 1 ]
135 | optional: true
136 | },
137 | {
138 | name: "beam_width"
139 | data_type: TYPE_INT32
140 | dims: [ 1 ]
141 | optional: true
142 | },
143 | {
144 | name: "stream"
145 | data_type: TYPE_BOOL
146 | dims: [ 1 ]
147 | optional: true
148 | },
149 | {
150 | name: "prompt_embedding_table"
151 | data_type: TYPE_FP16
152 | dims: [ -1, -1 ]
153 | optional: true
154 | },
155 | {
156 | name: "prompt_vocab_size"
157 | data_type: TYPE_INT32
158 | dims: [ 1 ]
159 | optional: true
160 | },
161 | {
162 | name: "embedding_bias_words"
163 | data_type: TYPE_STRING
164 | dims: [ -1 ]
165 | optional: true
166 | },
167 | {
168 | name: "embedding_bias_weights"
169 | data_type: TYPE_FP32
170 | dims: [ -1 ]
171 | optional: true
172 | }
173 | ]
174 | output [
175 | {
176 | name: "text_output"
177 | data_type: TYPE_STRING
178 | dims: [ -1 ]
179 | },
180 | {
181 | name: "cum_log_probs"
182 | data_type: TYPE_FP32
183 | dims: [ -1 ]
184 | },
185 | {
186 | name: "output_log_probs"
187 | data_type: TYPE_FP32
188 | dims: [ -1, -1 ]
189 | },
190 | {
191 | name: "context_logits"
192 | data_type: TYPE_FP32
193 | dims: [ -1, -1 ]
194 | },
195 | {
196 | name: "generation_logits"
197 | data_type: TYPE_FP32
198 | dims: [ -1, -1, -1 ]
199 | }
200 | ]
201 | ensemble_scheduling {
202 | step [
203 | {
204 | model_name: "preprocessing"
205 | model_version: -1
206 | input_map {
207 | key: "QUERY"
208 | value: "text_input"
209 | }
210 | input_map {
211 | key: "REQUEST_OUTPUT_LEN"
212 | value: "max_tokens"
213 | }
214 | input_map {
215 | key: "BAD_WORDS_DICT"
216 | value: "bad_words"
217 | }
218 | input_map {
219 | key: "STOP_WORDS_DICT"
220 | value: "stop_words"
221 | }
222 | input_map {
223 | key: "EMBEDDING_BIAS_WORDS"
224 | value: "embedding_bias_words"
225 | }
226 | input_map {
227 | key: "EMBEDDING_BIAS_WEIGHTS"
228 | value: "embedding_bias_weights"
229 | }
230 | input_map {
231 | key: "END_ID"
232 | value: "end_id"
233 | }
234 | input_map {
235 | key: "PAD_ID"
236 | value: "pad_id"
237 | }
238 | output_map {
239 | key: "REQUEST_INPUT_LEN"
240 | value: "_REQUEST_INPUT_LEN"
241 | }
242 | output_map {
243 | key: "INPUT_ID"
244 | value: "_INPUT_ID"
245 | }
246 | output_map {
247 | key: "REQUEST_OUTPUT_LEN"
248 | value: "_REQUEST_OUTPUT_LEN"
249 | }
250 | output_map {
251 | key: "STOP_WORDS_IDS"
252 | value: "_STOP_WORDS_IDS"
253 | }
254 | output_map {
255 | key: "BAD_WORDS_IDS"
256 | value: "_BAD_WORDS_IDS"
257 | }
258 | output_map {
259 | key: "EMBEDDING_BIAS"
260 | value: "_EMBEDDING_BIAS"
261 | }
262 | output_map {
263 | key: "OUT_END_ID"
264 | value: "_PREPROCESSOR_END_ID"
265 | }
266 | output_map {
267 | key: "OUT_PAD_ID"
268 | value: "_PREPROCESSOR_PAD_ID"
269 | }
270 | },
271 | {
272 | model_name: "tensorrt_llm"
273 | model_version: -1
274 | input_map {
275 | key: "input_ids"
276 | value: "_INPUT_ID"
277 | }
278 | input_map {
279 | key: "input_lengths"
280 | value: "_REQUEST_INPUT_LEN"
281 | }
282 | input_map {
283 | key: "request_output_len"
284 | value: "_REQUEST_OUTPUT_LEN"
285 | }
286 | input_map {
287 | key: "end_id"
288 | value: "_PREPROCESSOR_END_ID"
289 | }
290 | input_map {
291 | key: "pad_id"
292 | value: "_PREPROCESSOR_PAD_ID"
293 | }
294 | input_map {
295 | key: "embedding_bias"
296 | value: "_EMBEDDING_BIAS"
297 | }
298 | input_map {
299 | key: "runtime_top_k"
300 | value: "top_k"
301 | }
302 | input_map {
303 | key: "runtime_top_p"
304 | value: "top_p"
305 | }
306 | input_map {
307 | key: "temperature"
308 | value: "temperature"
309 | }
310 | input_map {
311 | key: "len_penalty"
312 | value: "length_penalty"
313 | }
314 | input_map {
315 | key: "repetition_penalty"
316 | value: "repetition_penalty"
317 | }
318 | input_map {
319 | key: "min_length"
320 | value: "min_length"
321 | }
322 | input_map {
323 | key: "presence_penalty"
324 | value: "presence_penalty"
325 | }
326 | input_map {
327 | key: "frequency_penalty"
328 | value: "frequency_penalty"
329 | }
330 | input_map {
331 | key: "random_seed"
332 | value: "random_seed"
333 | }
334 | input_map {
335 | key: "return_log_probs"
336 | value: "return_log_probs"
337 | }
338 | input_map {
339 | key: "return_context_logits"
340 | value: "return_context_logits"
341 | }
342 | input_map {
343 | key: "return_generation_logits"
344 | value: "return_generation_logits"
345 | }
346 | input_map {
347 | key: "beam_width"
348 | value: "beam_width"
349 | }
350 | input_map {
351 | key: "streaming"
352 | value: "stream"
353 | }
354 | input_map {
355 | key: "prompt_embedding_table"
356 | value: "prompt_embedding_table"
357 | }
358 | input_map {
359 | key: "prompt_vocab_size"
360 | value: "prompt_vocab_size"
361 | }
362 | input_map {
363 | key: "stop_words_list"
364 | value: "_STOP_WORDS_IDS"
365 | }
366 | input_map {
367 | key: "bad_words_list"
368 | value: "_BAD_WORDS_IDS"
369 | }
370 | output_map {
371 | key: "output_ids"
372 | value: "_TOKENS_BATCH"
373 | }
374 | output_map {
375 | key: "sequence_length"
376 | value: "_SEQUENCE_LENGTH"
377 | },
378 | output_map {
379 | key: "cum_log_probs"
380 | value: "_CUM_LOG_PROBS"
381 | }
382 | output_map {
383 | key: "output_log_probs"
384 | value: "_OUTPUT_LOG_PROBS"
385 | },
386 | output_map {
387 | key: "context_logits"
388 | value: "_CONTEXT_LOGITS"
389 | },
390 | output_map {
391 | key: "generation_logits"
392 | value: "_GENERATION_LOGITS"
393 | }
394 | },
395 | {
396 | model_name: "postprocessing"
397 | model_version: -1
398 | input_map {
399 | key: "TOKENS_BATCH"
400 | value: "_TOKENS_BATCH"
401 | }
402 | input_map {
403 | key: "CUM_LOG_PROBS"
404 | value: "_CUM_LOG_PROBS"
405 | }
406 | input_map {
407 | key: "OUTPUT_LOG_PROBS"
408 | value: "_OUTPUT_LOG_PROBS"
409 | }
410 | input_map {
411 | key: "CONTEXT_LOGITS"
412 | value: "_CONTEXT_LOGITS"
413 | }
414 | input_map {
415 | key: "GENERATION_LOGITS"
416 | value: "_GENERATION_LOGITS"
417 | }
418 | input_map {
419 | key: "SEQUENCE_LENGTH"
420 | value: "_SEQUENCE_LENGTH"
421 | }
422 | output_map {
423 | key: "OUTPUT"
424 | value: "text_output"
425 | }
426 | output_map {
427 | key: "OUT_OUTPUT_LOG_PROBS"
428 | value: "output_log_probs"
429 | }
430 | output_map {
431 | key: "OUT_CUM_LOG_PROBS"
432 | value: "cum_log_probs"
433 | }
434 | output_map {
435 | key: "OUT_CONTEXT_LOGITS"
436 | value: "context_logits"
437 | }
438 | output_map {
439 | key: "OUT_GENERATION_LOGITS"
440 | value: "generation_logits"
441 | }
442 | }
443 | ]
444 | }
445 |
--------------------------------------------------------------------------------
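Once Triton serves this model repository, the ensemble can be smoke-tested from Python through the HTTP `generate` endpoint. This is a hedged sketch, not part of the repo: endpoint availability depends on the Triton version, the `end_id`/`pad_id` values are assumptions for a Qwen1.5 chat checkpoint, and whether ChatML wrapping must be applied client-side depends on this repo's `preprocessing/1/model.py`.

```python
# Illustrative only: assumes Triton is listening on localhost:8000 with the generate extension enabled.
import json
import requests

payload = {
    "text_input": "What is TensorRT-LLM?",  # may need ChatML wrapping, depending on preprocessing
    "max_tokens": 64,
    "end_id": 151645,       # assumed <|im_end|> id for Qwen1.5 chat checkpoints
    "pad_id": 151645,
    "top_p": 0.8,
    "temperature": 1.0,
    "stream": False,
}
resp = requests.post(
    "http://localhost:8000/v2/models/ensemble/generate",
    data=json.dumps(payload),
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["text_output"])
```
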
/triton_model_repo/postprocessing/1/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions
5 | # are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of NVIDIA CORPORATION nor the names of its
12 | # contributors may be used to endorse or promote products derived
13 | # from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 | import json
28 | import os
29 | import numpy as np
30 | import triton_python_backend_utils as pb_utils
31 | from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer
32 |
33 |
34 | class TritonPythonModel:
35 | """Your Python model must use the same class name. Every Python model
36 | that is created must have "TritonPythonModel" as the class name.
37 | """
38 |
39 | def initialize(self, args):
40 | """`initialize` is called only once when the model is being loaded.
41 | Implementing `initialize` function is optional. This function allows
42 | the model to initialize any state associated with this model.
43 | Parameters
44 | ----------
45 | args : dict
46 | Both keys and values are strings. The dictionary keys and values are:
47 | * model_config: A JSON string containing the model configuration
48 | * model_instance_kind: A string containing model instance kind
49 | * model_instance_device_id: A string containing model instance device ID
50 | * model_repository: Model repository path
51 | * model_version: Model version
52 | * model_name: Model name
53 | """
54 | # Parse model configs
55 | model_config = json.loads(args['model_config'])
56 | tokenizer_dir = model_config['parameters']['tokenizer_dir'][
57 | 'string_value']
58 | tokenizer_type = model_config['parameters']['tokenizer_type'][
59 | 'string_value']
60 | self.skip_special_tokens = model_config['parameters'].get(
61 | 'skip_special_tokens',
62 | {'string_value': "true"})['string_value'].lower() in [
63 | 'true', '1', 't', 'y', 'yes'
64 | ]
65 |
66 | if tokenizer_type == 't5':
67 | self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir,
68 | padding_side='left')
69 | elif tokenizer_type == 'auto':
70 | self.tokenizer = AutoTokenizer.from_pretrained(
71 | tokenizer_dir, padding_side='left', trust_remote_code=True)
72 | elif tokenizer_type == 'llama':
73 | self.tokenizer = LlamaTokenizer.from_pretrained(
74 | tokenizer_dir, legacy=False, padding_side='left')
75 | else:
76 | raise AttributeError(
77 | f'Unexpected tokenizer type: {tokenizer_type}')
78 | gen_config_path = os.path.join(tokenizer_dir, 'generation_config.json')
79 | with open(gen_config_path, 'r') as f:
80 | gen_config = json.load(f)
81 | if isinstance(gen_config["eos_token_id"], list):
82 | pad_id = end_id = gen_config["eos_token_id"][0]
83 | # if the model is a base (non-chat) model, run this branch
84 | else:
85 | pad_id = gen_config["bos_token_id"]
86 | end_id = gen_config["eos_token_id"]
87 | self.tokenizer_pad_id = pad_id
88 | self.tokenizer_end_id = end_id
89 | eos_token = self.tokenizer.decode(end_id)
90 | self.tokenizer.eos_token = self.tokenizer.pad_token = eos_token
91 |
92 | # Parse model output configs
93 | output_config = pb_utils.get_output_config_by_name(
94 | model_config, "OUTPUT")
95 |
96 | # Convert Triton types to numpy types
97 | self.output_dtype = pb_utils.triton_string_to_numpy(
98 | output_config['data_type'])
99 |
100 | def execute(self, requests):
101 | """`execute` must be implemented in every Python model. `execute`
102 | function receives a list of pb_utils.InferenceRequest as the only
103 | argument. This function is called when an inference is requested
104 | for this model. Depending on the batching configuration (e.g. Dynamic
105 | Batching) used, `requests` may contain multiple requests. Every
106 | Python model must create one pb_utils.InferenceResponse for every
107 | pb_utils.InferenceRequest in `requests`. If there is an error, you can
108 | set the error argument when creating a pb_utils.InferenceResponse.
109 | Parameters
110 | ----------
111 | requests : list
112 | A list of pb_utils.InferenceRequest
113 | Returns
114 | -------
115 | list
116 | A list of pb_utils.InferenceResponse. The length of this list must
117 | be the same as `requests`
118 | """
119 |
120 | responses = []
121 |
122 | # Every Python backend must iterate over every one of the requests
123 | # and create a pb_utils.InferenceResponse for each of them.
124 | for idx, request in enumerate(requests):
125 | # Get input tensors
126 | tokens_batch = pb_utils.get_input_tensor_by_name(
127 | request, 'TOKENS_BATCH').as_numpy()
128 |
129 | # Get sequence length
130 | sequence_lengths = pb_utils.get_input_tensor_by_name(
131 | request, 'SEQUENCE_LENGTH').as_numpy()
132 |
133 | # Get cum log probs
134 | cum_log_probs = pb_utils.get_input_tensor_by_name(
135 | request, 'CUM_LOG_PROBS').as_numpy()
136 |
137 | # Get output log probs
138 | output_log_probs = pb_utils.get_input_tensor_by_name(
139 | request, 'OUTPUT_LOG_PROBS').as_numpy()
140 |
141 | # Get context logits
142 | context_logits = pb_utils.get_input_tensor_by_name(
143 | request, 'CONTEXT_LOGITS').as_numpy()
144 |
145 | # Get generation logits
146 | generation_logits = pb_utils.get_input_tensor_by_name(
147 | request, 'GENERATION_LOGITS').as_numpy()
148 |
149 | # Reshape Input
150 | # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]])
151 | # tokens_batch = tokens_batch.T
152 |
153 | # Postprocessing output data.
154 | outputs = self._postprocessing(tokens_batch, sequence_lengths)
155 |
156 | # Create output tensors. You need pb_utils.Tensor
157 | # objects to create pb_utils.InferenceResponse.
158 | output_tensor = pb_utils.Tensor(
159 | 'OUTPUT',
160 | np.array(outputs).astype(self.output_dtype))
161 |
162 | out_cum_log_probs = pb_utils.Tensor('OUT_CUM_LOG_PROBS',
163 | cum_log_probs)
164 |
165 | out_output_log_probs = pb_utils.Tensor('OUT_OUTPUT_LOG_PROBS',
166 | output_log_probs)
167 |
168 | out_context_logits = pb_utils.Tensor('OUT_CONTEXT_LOGITS',
169 | context_logits)
170 |
171 | out_generation_logits = pb_utils.Tensor('OUT_GENERATION_LOGITS',
172 | generation_logits)
173 |
174 | # Create InferenceResponse. You can set an error here in case
175 | # there was a problem with handling this inference request.
176 | # Below is an example of how you can set errors in inference
177 | # response:
178 | #
179 | # pb_utils.InferenceResponse(
180 | # output_tensors=..., TritonError("An error occurred"))
181 | inference_response = pb_utils.InferenceResponse(output_tensors=[
182 | output_tensor, out_cum_log_probs, out_output_log_probs,
183 | out_context_logits, out_generation_logits
184 | ])
185 | responses.append(inference_response)
186 |
187 | # You should return a list of pb_utils.InferenceResponse. Length
188 | # of this list must match the length of `requests` list.
189 | return responses
190 |
191 | def finalize(self):
192 | """`finalize` is called only once when the model is being unloaded.
193 | Implementing `finalize` function is optional. This function allows
194 | the model to perform any necessary clean ups before exit.
195 | """
196 | print('Cleaning up...')
197 |
198 | def _postprocessing(self, tokens_batch, sequence_lengths):
199 | outputs = []
200 | for batch_idx, beam_tokens in enumerate(tokens_batch):
201 | for beam_idx, tokens in enumerate(beam_tokens):
202 | seq_len = sequence_lengths[batch_idx][beam_idx]
203 | output = self.tokenizer.decode(
204 | tokens[:seq_len],
205 | skip_special_tokens=self.skip_special_tokens)
206 | outputs.append(output.encode('utf8'))
207 | return outputs
208 |
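For reference, a standalone sketch of the two generation_config.json shapes that the initialize() branch above distinguishes; the token ids below are illustrative placeholders rather than values from a specific checkpoint.

# chat-style config: "eos_token_id" is a list, so the first id serves as both pad and end
chat_style = {"eos_token_id": [151645, 151643]}
# base-style config: scalar ids, so pad falls back to bos_token_id
base_style = {"bos_token_id": 151643, "eos_token_id": 151643}

for gen_config in (chat_style, base_style):
    if isinstance(gen_config["eos_token_id"], list):
        pad_id = end_id = gen_config["eos_token_id"][0]
    else:
        pad_id = gen_config["bos_token_id"]
        end_id = gen_config["eos_token_id"]
    print(pad_id, end_id)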
--------------------------------------------------------------------------------
/triton_model_repo/postprocessing/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions
5 | # are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of NVIDIA CORPORATION nor the names of its
12 | # contributors may be used to endorse or promote products derived
13 | # from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 | name: "postprocessing"
28 | backend: "python"
29 | max_batch_size: 2
30 | input [
31 | {
32 | name: "TOKENS_BATCH"
33 | data_type: TYPE_INT32
34 | dims: [ -1, -1 ]
35 | },
36 | {
37 | name: "SEQUENCE_LENGTH"
38 | data_type: TYPE_INT32
39 | dims: [ -1 ]
40 | },
41 | {
42 | name: "CUM_LOG_PROBS"
43 | data_type: TYPE_FP32
44 | dims: [ -1 ]
45 | },
46 | {
47 | name: "OUTPUT_LOG_PROBS"
48 | data_type: TYPE_FP32
49 | dims: [ -1, -1 ]
50 | },
51 | {
52 | name: "CONTEXT_LOGITS"
53 | data_type: TYPE_FP32
54 | dims: [ -1, -1 ]
55 | optional: true
56 | },
57 | {
58 | name: "GENERATION_LOGITS"
59 | data_type: TYPE_FP32
60 | dims: [ -1, -1, -1 ]
61 | optional: true
62 | }
63 | ]
64 | output [
65 | {
66 | name: "OUTPUT"
67 | data_type: TYPE_STRING
68 | dims: [ -1 ]
69 | },
70 | {
71 | name: "OUT_CUM_LOG_PROBS"
72 | data_type: TYPE_FP32
73 | dims: [ -1 ]
74 | },
75 | {
76 | name: "OUT_OUTPUT_LOG_PROBS"
77 | data_type: TYPE_FP32
78 | dims: [ -1, -1 ]
79 | },
80 | {
81 | name: "OUT_CONTEXT_LOGITS"
82 | data_type: TYPE_FP32
83 | dims: [ -1, -1 ]
84 | },
85 | {
86 | name: "OUT_GENERATION_LOGITS"
87 | data_type: TYPE_FP32
88 | dims: [ -1, -1, -1 ]
89 | }
90 | ]
91 |
92 | parameters {
93 | key: "tokenizer_dir"
94 | value: {
95 | string_value: "/tensorrtllm_backend/triton_model_repo/tensorrt_llm/qwen1.5_7b_chat"
96 | }
97 | }
98 |
99 | parameters {
100 | key: "tokenizer_type"
101 | value: {
102 | string_value: "auto"
103 | }
104 | }
105 |
106 | parameters {
107 | key: "skip_special_tokens"
108 | value: {
109 | string_value: "True"
110 | }
111 | }
112 |
113 | instance_group [
114 | {
115 | count: 4
116 | kind: KIND_CPU
117 | }
118 | ]
119 |
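The parameters block above reaches model.py through args['model_config'], which Triton passes as a JSON string; a hedged excerpt of the parsed shape (only the fields the postprocessing model reads), using the values configured above.

import json

model_config = json.loads("""{
  "parameters": {
    "tokenizer_dir": {"string_value": "/tensorrtllm_backend/triton_model_repo/tensorrt_llm/qwen1.5_7b_chat"},
    "tokenizer_type": {"string_value": "auto"},
    "skip_special_tokens": {"string_value": "True"}
  }
}""")

# mirrors the lookups in initialize()
print(model_config["parameters"]["tokenizer_dir"]["string_value"])
print(model_config["parameters"]["skip_special_tokens"]["string_value"].lower() in ["true", "1", "t", "y", "yes"])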
--------------------------------------------------------------------------------
/triton_model_repo/preprocessing/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions
5 | # are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of NVIDIA CORPORATION nor the names of its
12 | # contributors may be used to endorse or promote products derived
13 | # from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 | name: "preprocessing"
28 | backend: "python"
29 | max_batch_size: 2
30 | input [
31 | {
32 | name: "QUERY"
33 | data_type: TYPE_STRING
34 | dims: [ -1 ]
35 | },
36 | {
37 | name: "REQUEST_OUTPUT_LEN"
38 | data_type: TYPE_INT32
39 | dims: [ -1 ]
40 | },
41 | {
42 | name: "BAD_WORDS_DICT"
43 | data_type: TYPE_STRING
44 | dims: [ -1 ]
45 | optional: true
46 | },
47 | {
48 | name: "STOP_WORDS_DICT"
49 | data_type: TYPE_STRING
50 | dims: [ -1 ]
51 | optional: true
52 | },
53 | {
54 | name: "EMBEDDING_BIAS_WORDS"
55 | data_type: TYPE_STRING
56 | dims: [ -1 ]
57 | optional: true
58 | },
59 | {
60 | name: "EMBEDDING_BIAS_WEIGHTS"
61 | data_type: TYPE_FP32
62 | dims: [ -1 ]
63 | optional: true
64 | },
65 | {
66 | name: "END_ID"
67 | data_type: TYPE_INT32
68 | dims: [ -1 ]
69 | optional: true
70 | },
71 | {
72 | name: "PAD_ID"
73 | data_type: TYPE_INT32
74 | dims: [ -1 ]
75 | optional: true
76 | }
77 | ]
78 | output [
79 | {
80 | name: "INPUT_ID"
81 | data_type: TYPE_INT32
82 | dims: [ -1 ]
83 | },
84 | {
85 | name: "REQUEST_INPUT_LEN"
86 | data_type: TYPE_INT32
87 | dims: [ 1 ]
88 | },
89 | {
90 | name: "BAD_WORDS_IDS"
91 | data_type: TYPE_INT32
92 | dims: [ 2, -1 ]
93 | },
94 | {
95 | name: "STOP_WORDS_IDS"
96 | data_type: TYPE_INT32
97 | dims: [ 2, -1 ]
98 | },
99 | {
100 | name: "EMBEDDING_BIAS"
101 | data_type: TYPE_FP32
102 | dims: [ -1 ]
103 | },
104 | {
105 | name: "REQUEST_OUTPUT_LEN"
106 | data_type: TYPE_INT32
107 | dims: [ -1 ]
108 | },
109 | {
110 | name: "OUT_END_ID"
111 | data_type: TYPE_INT32
112 | dims: [ -1 ]
113 | },
114 | {
115 | name: "OUT_PAD_ID"
116 | data_type: TYPE_INT32
117 | dims: [ -1 ]
118 | }
119 | ]
120 |
121 | parameters {
122 | key: "tokenizer_dir"
123 | value: {
124 | string_value: "/tensorrtllm_backend/triton_model_repo/tensorrt_llm/qwen1.5_7b_chat"
125 | }
126 | }
127 |
128 | parameters {
129 | key: "tokenizer_type"
130 | value: {
131 | string_value: "auto"
132 | }
133 | }
134 |
135 | parameters {
136 | key: "add_special_tokens"
137 | value: {
138 | string_value: "False"
139 | }
140 | }
141 |
142 | instance_group [
143 | {
144 | count: 4
145 | kind: KIND_CPU
146 | }
147 | ]
148 |
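The preprocessing Python model behind this config turns each QUERY string into INPUT_ID / REQUEST_INPUT_LEN tensors; a minimal sketch of that mapping, assuming the Qwen tokenizer at the tokenizer_dir configured above and honoring add_special_tokens = "False".

import numpy as np
from transformers import AutoTokenizer

# path is the tokenizer_dir parameter above (assumed to hold a Qwen checkpoint)
tokenizer = AutoTokenizer.from_pretrained(
    "/tensorrtllm_backend/triton_model_repo/tensorrt_llm/qwen1.5_7b_chat",
    padding_side="left", trust_remote_code=True)

query = "What is the capital of France?"
# add_special_tokens mirrors the "add_special_tokens" parameter ("False" above)
input_id = np.array(tokenizer.encode(query, add_special_tokens=False), dtype=np.int32)
request_input_len = np.array([input_id.shape[0]], dtype=np.int32)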
--------------------------------------------------------------------------------
/triton_model_repo/tensorrt_llm/1/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/triton_model_repo/tensorrt_llm/1/.gitkeep
--------------------------------------------------------------------------------
/triton_model_repo/tensorrt_llm/1/.tmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Tlntin/Qwen-TensorRT-LLM/7da636fe7d55f42cebf3f2a43931dd0f1619efee/triton_model_repo/tensorrt_llm/1/.tmp
--------------------------------------------------------------------------------
/triton_model_repo/tensorrt_llm/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions
5 | # are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of NVIDIA CORPORATION nor the names of its
12 | # contributors may be used to endorse or promote products derived
13 | # from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 | name: "tensorrt_llm"
28 | backend: "tensorrtllm"
29 | max_batch_size: 1
30 |
31 | model_transaction_policy {
32 | decoupled: True
33 | }
34 |
35 | dynamic_batching {
36 | preferred_batch_size: [ 1 ]
37 | max_queue_delay_microseconds: 600
38 | }
39 |
40 | input [
41 | {
42 | name: "input_ids"
43 | data_type: TYPE_INT32
44 | dims: [ -1 ]
45 | allow_ragged_batch: true
46 | },
47 | {
48 | name: "input_lengths"
49 | data_type: TYPE_INT32
50 | dims: [ 1 ]
51 | reshape: { shape: [ ] }
52 | },
53 | {
54 | name: "request_output_len"
55 | data_type: TYPE_INT32
56 | dims: [ 1 ]
57 | },
58 | {
59 | name: "draft_input_ids"
60 | data_type: TYPE_INT32
61 | dims: [ -1 ]
62 | optional: true
63 | allow_ragged_batch: true
64 | },
65 | {
66 | name: "end_id"
67 | data_type: TYPE_INT32
68 | dims: [ 1 ]
69 | reshape: { shape: [ ] }
70 | optional: true
71 | },
72 | {
73 | name: "pad_id"
74 | data_type: TYPE_INT32
75 | dims: [ 1 ]
76 | reshape: { shape: [ ] }
77 | optional: true
78 | },
79 | {
80 | name: "stop_words_list"
81 | data_type: TYPE_INT32
82 | dims: [ 2, -1 ]
83 | optional: true
84 | allow_ragged_batch: true
85 | },
86 | {
87 | name: "bad_words_list"
88 | data_type: TYPE_INT32
89 | dims: [ 2, -1 ]
90 | optional: true
91 | allow_ragged_batch: true
92 | },
93 | {
94 | name: "embedding_bias"
95 | data_type: TYPE_FP32
96 | dims: [ -1 ]
97 | optional: true
98 | allow_ragged_batch: true
99 | },
100 | {
101 | name: "beam_width"
102 | data_type: TYPE_INT32
103 | dims: [ 1 ]
104 | reshape: { shape: [ ] }
105 | optional: true
106 | },
107 | {
108 | name: "temperature"
109 | data_type: TYPE_FP32
110 | dims: [ 1 ]
111 | reshape: { shape: [ ] }
112 | optional: true
113 | },
114 | {
115 | name: "runtime_top_k"
116 | data_type: TYPE_INT32
117 | dims: [ 1 ]
118 | reshape: { shape: [ ] }
119 | optional: true
120 | },
121 | {
122 | name: "runtime_top_p"
123 | data_type: TYPE_FP32
124 | dims: [ 1 ]
125 | reshape: { shape: [ ] }
126 | optional: true
127 | },
128 | {
129 | name: "len_penalty"
130 | data_type: TYPE_FP32
131 | dims: [ 1 ]
132 | reshape: { shape: [ ] }
133 | optional: true
134 | },
135 | {
136 | name: "repetition_penalty"
137 | data_type: TYPE_FP32
138 | dims: [ 1 ]
139 | reshape: { shape: [ ] }
140 | optional: true
141 | },
142 | {
143 | name: "min_length"
144 | data_type: TYPE_INT32
145 | dims: [ 1 ]
146 | reshape: { shape: [ ] }
147 | optional: true
148 | },
149 | {
150 | name: "presence_penalty"
151 | data_type: TYPE_FP32
152 | dims: [ 1 ]
153 | reshape: { shape: [ ] }
154 | optional: true
155 | },
156 | {
157 | name: "frequency_penalty"
158 | data_type: TYPE_FP32
159 | dims: [ 1 ]
160 | reshape: { shape: [ ] }
161 | optional: true
162 | },
163 | {
164 | name: "random_seed"
165 | data_type: TYPE_UINT64
166 | dims: [ 1 ]
167 | reshape: { shape: [ ] }
168 | optional: true
169 | },
170 | {
171 | name: "return_log_probs"
172 | data_type: TYPE_BOOL
173 | dims: [ 1 ]
174 | reshape: { shape: [ ] }
175 | optional: true
176 | },
177 | {
178 | name: "return_context_logits"
179 | data_type: TYPE_BOOL
180 | dims: [ 1 ]
181 | reshape: { shape: [ ] }
182 | optional: true
183 | },
184 | {
185 | name: "return_generation_logits"
186 | data_type: TYPE_BOOL
187 | dims: [ 1 ]
188 | reshape: { shape: [ ] }
189 | optional: true
190 | },
191 | {
192 | name: "stop"
193 | data_type: TYPE_BOOL
194 | dims: [ 1 ]
195 | optional: true
196 | },
197 | {
198 | name: "streaming"
199 | data_type: TYPE_BOOL
200 | dims: [ 1 ]
201 | optional: true
202 | },
203 | {
204 | name: "prompt_embedding_table"
205 | data_type: TYPE_FP16
206 | dims: [ -1, -1 ]
207 | optional: true
208 | allow_ragged_batch: true
209 | },
210 | {
211 | name: "prompt_vocab_size"
212 | data_type: TYPE_INT32
213 | dims: [ 1 ]
214 | reshape: { shape: [ ] }
215 | optional: true
216 | },
217 | # weights for a lora adapter, shape [ num_lora_modules_layers, D x Hi + Ho x D ]
218 | # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
219 | # each of the in / out tensors is first flattened, then the two are concatenated together in the format above.
220 | # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out.
221 | {
222 | name: "lora_weights"
223 | data_type: TYPE_FP16
224 | dims: [ -1, -1 ]
225 | optional: true
226 | allow_ragged_batch: true
227 | },
228 | # module identifier (same size as the first dimension of lora_weights)
229 | # See LoraModule::ModuleType for module id mapping
230 | #
231 | # "attn_qkv": 0 # compbined qkv adapter
232 | # "attn_q": 1 # q adapter
233 | # "attn_k": 2 # k adapter
234 | # "attn_v": 3 # v adapter
235 | # "attn_dense": 4 # adapter for the dense layer in attention
236 | # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection
237 | # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection
238 | # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate
239 | #
240 | # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ]
241 | {
242 | name: "lora_config"
243 | data_type: TYPE_INT32
244 | dims: [ -1, 3 ]
245 | optional: true
246 | allow_ragged_batch: true
247 | }
248 | ]
249 | output [
250 | {
251 | name: "output_ids"
252 | data_type: TYPE_INT32
253 | dims: [ -1, -1 ]
254 | },
255 | {
256 | name: "sequence_length"
257 | data_type: TYPE_INT32
258 | dims: [ -1 ]
259 | },
260 | {
261 | name: "cum_log_probs"
262 | data_type: TYPE_FP32
263 | dims: [ -1 ]
264 | },
265 | {
266 | name: "output_log_probs"
267 | data_type: TYPE_FP32
268 | dims: [ -1, -1 ]
269 | },
270 | {
271 | name: "context_logits"
272 | data_type: TYPE_FP32
273 | dims: [ -1, -1 ]
274 | },
275 | {
276 | name: "generation_logits"
277 | data_type: TYPE_FP32
278 | dims: [ -1, -1, -1 ]
279 | }
280 | ]
281 | instance_group [
282 | {
283 | count: 1
284 | kind : KIND_CPU
285 | }
286 | ]
287 | parameters: {
288 | key: "max_beam_width"
289 | value: {
290 | string_value: "1"
291 | }
292 | }
293 | parameters: {
294 | key: "FORCE_CPU_ONLY_INPUT_TENSORS"
295 | value: {
296 | string_value: "no"
297 | }
298 | }
299 | parameters: {
300 | key: "gpt_model_type"
301 | value: {
302 | string_value: "inflight_batching"
303 | }
304 | }
305 | parameters: {
306 | key: "gpt_model_path"
307 | value: {
308 | string_value: "/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1"
309 | }
310 | }
311 | parameters: {
312 | key: "max_tokens_in_paged_kv_cache"
313 | value: {
314 | string_value: "${max_tokens_in_paged_kv_cache}"
315 | }
316 | }
317 | parameters: {
318 | key: "max_attention_window_size"
319 | value: {
320 | string_value: "6144"
321 | }
322 | }
323 | parameters: {
324 | key: "batch_scheduler_policy"
325 | value: {
326 | string_value: "${batch_scheduler_policy}"
327 | }
328 | }
329 | parameters: {
330 | key: "kv_cache_free_gpu_mem_fraction"
331 | value: {
332 | string_value: "0.9"
333 | }
334 | }
335 | parameters: {
336 | key: "enable_trt_overlap"
337 | value: {
338 | string_value: "${enable_trt_overlap}"
339 | }
340 | }
341 | parameters: {
342 | key: "exclude_input_in_output"
343 | value: {
344 | string_value: "True"
345 | }
346 | }
347 | parameters: {
348 | key: "enable_kv_cache_reuse"
349 | value: {
350 | string_value: "False"
351 | }
352 | }
353 | parameters: {
354 | key: "normalize_log_probs"
355 | value: {
356 | string_value: "${normalize_log_probs}"
357 | }
358 | }
359 | parameters: {
360 | key: "enable_chunked_context"
361 | value: {
362 | string_value: "${enable_chunked_context}"
363 | }
364 | }
365 | parameters: {
366 | key: "gpu_device_ids"
367 | value: {
368 | string_value: "0"
369 | }
370 | }
371 |
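A hedged numpy sketch of the packed lora_weights / lora_config layout described in the comments above; the rank and hidden sizes are illustrative and not tied to any particular checkpoint.

import numpy as np

D, Hi, Ho = 8, 4096, 4096                            # adapter rank and in/out hidden sizes
w_in = np.random.randn(D, Hi).astype(np.float16)     # "in" adapter weights
w_out = np.random.randn(Ho, D).astype(np.float16)    # "out" adapter weights

# one row per (module, layer): flatten each tensor, then concatenate -> D*Hi + Ho*D values
row = np.concatenate([w_in.reshape(-1), w_out.reshape(-1)])
lora_weights = row[np.newaxis, :]                     # [ num_lora_modules_layers, D*Hi + Ho*D ]

# matching lora_config row: [ module_id, layer_idx, adapter_size ]
lora_config = np.array([[0, 0, D]], dtype=np.int32)   # module_id 0 = "attn_qkv"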
--------------------------------------------------------------------------------
/triton_model_repo/tensorrt_llm_bls/config.pbtxt:
--------------------------------------------------------------------------------
1 | # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | #
3 | # Redistribution and use in source and binary forms, with or without
4 | # modification, are permitted provided that the following conditions
5 | # are met:
6 | # * Redistributions of source code must retain the above copyright
7 | # notice, this list of conditions and the following disclaimer.
8 | # * Redistributions in binary form must reproduce the above copyright
9 | # notice, this list of conditions and the following disclaimer in the
10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of NVIDIA CORPORATION nor the names of its
12 | # contributors may be used to endorse or promote products derived
13 | # from this software without specific prior written permission.
14 | #
15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
27 | name: "tensorrt_llm_bls"
28 | backend: "python"
29 | max_batch_size: 2
30 |
31 | model_transaction_policy {
32 | decoupled: True
33 | }
34 |
35 | input [
36 | {
37 | name: "text_input"
38 | data_type: TYPE_STRING
39 | dims: [ -1 ]
40 | },
41 | {
42 | name: "max_tokens"
43 | data_type: TYPE_INT32
44 | dims: [ -1 ]
45 | },
46 | {
47 | name: "bad_words"
48 | data_type: TYPE_STRING
49 | dims: [ -1 ]
50 | optional: true
51 | },
52 | {
53 | name: "stop_words"
54 | data_type: TYPE_STRING
55 | dims: [ -1 ]
56 | optional: true
57 | },
58 | {
59 | name: "end_id"
60 | data_type: TYPE_INT32
61 | dims: [ 1 ]
62 | optional: true
63 | },
64 | {
65 | name: "pad_id"
66 | data_type: TYPE_INT32
67 | dims: [ 1 ]
68 | optional: true
69 | },
70 | {
71 | name: "top_k"
72 | data_type: TYPE_INT32
73 | dims: [ 1 ]
74 | optional: true
75 | },
76 | {
77 | name: "top_p"
78 | data_type: TYPE_FP32
79 | dims: [ 1 ]
80 | optional: true
81 | },
82 | {
83 | name: "temperature"
84 | data_type: TYPE_FP32
85 | dims: [ 1 ]
86 | optional: true
87 | },
88 | {
89 | name: "length_penalty"
90 | data_type: TYPE_FP32
91 | dims: [ 1 ]
92 | optional: true
93 | },
94 | {
95 | name: "repetition_penalty"
96 | data_type: TYPE_FP32
97 | dims: [ 1 ]
98 | optional: true
99 | },
100 | {
101 | name: "min_length"
102 | data_type: TYPE_INT32
103 | dims: [ 1 ]
104 | optional: true
105 | },
106 | {
107 | name: "presence_penalty"
108 | data_type: TYPE_FP32
109 | dims: [ 1 ]
110 | optional: true
111 | },
112 | {
113 | name: "frequency_penalty"
114 | data_type: TYPE_FP32
115 | dims: [ 1 ]
116 | optional: true
117 | },
118 | {
119 | name: "random_seed"
120 | data_type: TYPE_UINT64
121 | dims: [ 1 ]
122 | optional: true
123 | },
124 | {
125 | name: "return_log_probs"
126 | data_type: TYPE_BOOL
127 | dims: [ 1 ]
128 | optional: true
129 | },
130 | {
131 | name: "return_context_logits"
132 | data_type: TYPE_BOOL
133 | dims: [ 1 ]
134 | reshape: { shape: [ ] }
135 | optional: true
136 | },
137 | {
138 | name: "return_generation_logits"
139 | data_type: TYPE_BOOL
140 | dims: [ 1 ]
141 | reshape: { shape: [ ] }
142 | optional: true
143 | },
144 | {
145 | name: "beam_width"
146 | data_type: TYPE_INT32
147 | dims: [ 1 ]
148 | optional: true
149 | },
150 | {
151 | name: "stream"
152 | data_type: TYPE_BOOL
153 | dims: [ 1 ]
154 | optional: true
155 | },
156 | {
157 | name: "prompt_embedding_table"
158 | data_type: TYPE_FP16
159 | dims: [ -1, -1 ]
160 | optional: true
161 | },
162 | {
163 | name: "prompt_vocab_size"
164 | data_type: TYPE_INT32
165 | dims: [ 1 ]
166 | optional: true
167 | },
168 | {
169 | name: "embedding_bias_words"
170 | data_type: TYPE_STRING
171 | dims: [ -1 ]
172 | optional: true
173 | },
174 | {
175 | name: "embedding_bias_weights"
176 | data_type: TYPE_FP32
177 | dims: [ -1 ]
178 | optional: true
179 | }
180 | ]
181 | output [
182 | {
183 | name: "text_output"
184 | data_type: TYPE_STRING
185 | dims: [ -1 ]
186 | },
187 | {
188 | name: "cum_log_probs"
189 | data_type: TYPE_FP32
190 | dims: [ -1 ]
191 | },
192 | {
193 | name: "output_log_probs"
194 | data_type: TYPE_FP32
195 | dims: [ -1, -1 ]
196 | },
197 | {
198 | name: "context_logits"
199 | data_type: TYPE_FP32
200 | dims: [ -1, -1 ]
201 | },
202 | {
203 | name: "generation_logits"
204 | data_type: TYPE_FP32
205 | dims: [ -1, -1, -1 ]
206 | }
207 | ]
208 |
209 | parameters: {
210 | key: "accumulate_tokens"
211 | value: {
212 | string_value: "True"
213 | }
214 | }
215 |
216 | instance_group [
217 | {
218 | count: 4
219 | kind : KIND_CPU
220 | }
221 | ]
222 |
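Because this model is decoupled, a streaming client talks to it through the gRPC stream API; a minimal sketch, assuming the gRPC endpoint is localhost:8001 and using the input names declared above (the prompt is illustrative).

import queue
import numpy as np
import tritonclient.grpc as grpcclient

responses = queue.Queue()

def callback(result, error):
    # each callback delivers one partial response (or an error) from the decoupled model
    responses.put(error if error is not None else result)

def make_input(name, dtype, value):
    tensor = grpcclient.InferInput(name, [1, 1], dtype)
    tensor.set_data_from_numpy(value)
    return tensor

inputs = [
    make_input("text_input", "BYTES", np.array([["What is the capital of France?"]], dtype=object)),
    make_input("max_tokens", "INT32", np.array([[128]], dtype=np.int32)),
    make_input("stream", "BOOL", np.array([[True]], dtype=bool)),
]

client = grpcclient.InferenceServerClient("localhost:8001")
client.start_stream(callback=callback)
client.async_stream_infer("tensorrt_llm_bls", inputs)
client.stop_stream()  # closes the stream after outstanding responses are handed to the callback

while not responses.empty():
    item = responses.get()
    if not isinstance(item, Exception):
        print(item.as_numpy("text_output"))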
--------------------------------------------------------------------------------