├── .gitignore
├── LICENSE
├── README-en.md
├── README.md
├── chatglm
├── chatbot_webui.py
├── chatbot_with_memory.ipynb
├── chatglm_inference.ipynb
├── data
│ ├── raw_data.txt
│ ├── zhouyi_dataset_20240118_152413.csv
│ ├── zhouyi_dataset_20240118_163659.csv
│ └── zhouyi_dataset_handmade.csv
├── gen_dataset.ipynb
├── qlora_chatglm3.ipynb
└── qlora_chatglm3_timestamp.ipynb
├── deepspeed
├── README.md
├── config
│ ├── ds_config_zero2.json
│ └── ds_config_zero3.json
├── train_on_multi_nodes.sh
├── train_on_one_gpu.sh
└── translation
│ ├── README.md
│ ├── requirements.txt
│ └── run_translation.py
├── docs
├── INSTALL.md
├── cuda_installation.png
├── version_check.py
└── version_info.txt
├── langchain
├── chains
│ ├── router_chain.ipynb
│ ├── sequential_chain.ipynb
│ └── transform_chain.ipynb
├── data_connection
│ ├── document_loader.ipynb
│ ├── document_transformer.ipynb
│ ├── text_embedding.ipynb
│ └── vector_stores.ipynb
├── images
│ ├── llm_chain.png
│ ├── memory.png
│ ├── model_io.jpeg
│ ├── router_chain.png
│ ├── sequential_chain_0.png
│ ├── simple_sequential_chain_0.png
│ ├── simple_sequential_chain_1.png
│ └── transform_chain.png
├── memory
│ └── memory.ipynb
├── model_io
│ ├── model.ipynb
│ ├── output_parser.ipynb
│ └── prompt.ipynb
└── tests
│ ├── state_of_the_union.txt
│ └── the_old_man_and_the_sea.txt
├── llama
├── llama2_inference.ipynb
└── llama2_instruction_tuning.ipynb
├── peft
├── chatglm3.ipynb
├── data
│ └── audio
│ │ └── test_zh.flac
├── peft_chatglm_inference.ipynb
├── peft_lora_opt-6.7b.ipynb
├── peft_lora_whisper-large-v2.ipynb
├── peft_qlora_chatglm.ipynb
└── whisper_eval.ipynb
├── quantization
├── AWQ-opt-125m.ipynb
├── AWQ_opt-2.7b.ipynb
├── AutoGPTQ_opt-2.7b.ipynb
├── bits_and_bytes.ipynb
└── docs
│ └── images
│ └── qlora.png
├── requirements.txt
└── transformers
├── data
├── audio
│ └── mlk.flac
└── image
│ ├── cat-chonk.jpeg
│ ├── cat_dog.jpg
│ └── panda.jpg
├── docs
└── images
│ ├── bert-base-chinese.png
│ ├── bert.png
│ ├── bert_pretrain.png
│ ├── full_nlp_pipeline.png
│ ├── gpt2.png
│ ├── pipeline_advanced.png
│ ├── pipeline_func.png
│ └── question_answering.png
├── fine-tune-QA.ipynb
├── fine-tune-quickstart.ipynb
├── pipelines.ipynb
└── pipelines_advanced.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # Customization
163 |
164 | nohup.out
165 | */models/
166 | peft/temp/
167 | deepspeed/output_dir
168 | RLHF
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README-en.md:
--------------------------------------------------------------------------------
1 | # LLM Quick Start
2 |
3 | 
4 | 
5 | 
6 | 
7 | 
8 | 
9 | 
10 | 
11 |
12 |
13 |
English | 中文
14 |
15 |
16 | Quick Start for Large Language Models (Theoretical Learning and Practical Fine-tuning)
17 |
18 |
19 | ## Setting Up the Development Environment
20 |
21 | - Python v3.10+
22 | - Python Environment Management: [Miniconda](https://docs.conda.io/projects/miniconda/en/latest/)
23 | - Interactive Python Development Environment: [Jupyter Lab](https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html)
24 | - [Hugging Face Transformers](https://huggingface.co/docs/transformers/installation#install-with-conda)
25 | - [Audio processing toolkit ffmpeg](https://phoenixnap.com/kb/install-ffmpeg-ubuntu)
26 |
27 | For detailed installation instructions, please refer to [Documentation](docs/INSTALL.md)
28 |
29 | ### Installing Python Dependencies
30 |
31 | Please use the `requirements.txt` file to install Python dependencies:
32 |
33 | ```shell
34 | pip install -r requirements.txt
35 | ```
36 | The currently supported list of software versions for project operation is as follows, see [Version Comparison Document](docs/version_info.txt) for details:
37 |
38 | ```
39 | torch>=2.1.2==2.3.0.dev20240116+cu121
40 | transformers==4.37.2
41 | ffmpeg==1.4
42 | ffmpeg-python==0.2.0
43 | timm==0.9.12
44 | datasets==2.16.1
45 | evaluate==0.4.1
46 | scikit-learn==1.3.2
47 | pandas==2.1.1
48 | peft==0.7.2.dev0
49 | accelerate==0.26.1
50 | autoawq==0.2.2
51 | optimum==1.17.0.dev0
52 | auto-gptq==0.6.0
53 | bitsandbytes>0.39.0==0.41.3.post2
54 | jiwer==3.0.3
55 | soundfile>=0.12.1==0.12.1
56 | librosa==0.10.1
57 | langchain==0.1.0
58 | gradio==4.13.0
59 | ```
60 |
61 | To check if the software versions in your runtime environment match, the project provides an automated [Version Check Script](docs/version_check.py), please be sure to modify the output file name.
62 |
63 | ### About GPU Drivers and CUDA Versions
64 |
65 | Typically, GPU drivers and CUDA versions need to meet the requirements of the installed PyTorch and TensorFlow versions.
66 |
67 | Most recently released large language models use newer versions of PyTorch, such as PyTorch v2.0+. According to the PyTorch official documentation, the minimum required CUDA version is 11.8, along with a matching GPU driver version. You can find more details in the [PyTorch official CUDA version requirements](https://pytorch.org/get-started/pytorch-2.0/#faqs).
68 |
69 | In summary, it is recommended to directly install the latest CUDA 12.3 version. You can find the installation packages on the [Nvidia official website](https://developer.nvidia.com/cuda-downloads).
70 |
71 |
72 | After installation, use the `nvidia-smi` command to check the version:
73 |
74 | ```shell
75 | nvidia-smi
76 | Mon Dec 18 12:10:47 2023
77 | +---------------------------------------------------------------------------------------+
78 | | NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
79 | |-----------------------------------------+----------------------+----------------------+
80 | | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
81 | | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
82 | | | | MIG M. |
83 | |=========================================+======================+======================|
84 | | 0 Tesla T4 Off | 00000000:00:0D.0 Off | 0 |
85 | | N/A 44C P0 26W / 70W | 2MiB / 15360MiB | 6% Default |
86 | | | | N/A |
87 | +-----------------------------------------+----------------------+----------------------+
88 |
89 | +---------------------------------------------------------------------------------------+
90 | | Processes: |
91 | | GPU GI CI PID Type Process name GPU Memory |
92 | | ID ID Usage |
93 | |=======================================================================================|
94 | | No running processes found |
95 | +---------------------------------------------------------------------------------------+
96 | ```
97 |
98 |
99 | ### Configuring Jupyter Lab for Background Startup
100 |
101 | After installing the development environment as mentioned above, it's recommended to start Jupyter Lab as a background service. Here's how to configure it (using the root user as an example):
102 |
103 | ```shell
104 | # Generate a Jupyter Lab configuration file
105 | $ jupyter lab --generate-config
106 | Writing default config to: /root/.jupyter/jupyter_lab_config.py
107 | ```
108 |
109 | Open the configuration file and make the following changes:
110 |
111 | ```python
112 | # Allowing Jupyter Lab to start as a non-root user (no need to modify if starting as root)
113 | c.ServerApp.allow_root = True
114 | c.ServerApp.ip = '*'
115 | ```
116 |
117 | Use `nohup` to start Jupyter Lab in the background:
118 |
119 | ```shell
120 | $ nohup jupyter lab --port=8000 --NotebookApp.token='replace_with_your_password' --notebook-dir=./ &
121 | ```
122 |
123 | Jupyter Lab's output log will be saved in the `nohup.out` file (which is already filtered in the `.gitignore` file).
124 |
125 | ### Configuration for calling OpenAI GPT API in LangChain
126 |
127 | In order to use the OpenAI API, you need to have an API key which can be obtained from the OpenAI dashboard. Once you have the key, you can set it as an environment variable:
128 |
129 | For Unix-based systems (like Ubuntu or MacOS), you can run the following command in your terminal:
130 |
131 | ```bash
132 | export OPENAI_API_KEY='your-api-key'
133 | ```
134 |
135 | For Windows, you can use the following command in the Command Prompt:
136 |
137 | ```
138 | set OPENAI_API_KEY=your-api-key
139 | ```
140 |
141 | Make sure to replace `'your-api-key'` with your actual OpenAI API key.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 大模型(LLMs)微调训练 快速入门指南
2 |
3 | 
4 | 
5 | 
6 | 
7 | 
8 | 
9 | 
10 | 
11 |
12 |
13 |
中文 | English
14 |
15 |
16 |
17 | 大语言模型快速入门(理论学习与微调实战)
18 |
19 | ## 拉取代码
20 |
21 | 你可以通过克隆此仓库到 GPU 服务器来开始学习:
22 |
23 | ```shell
24 | git clone https://github.com/DjangoPeng/LLM-quickstart.git
25 | ```
26 |
27 | ## 搭建开发环境
28 |
29 | 本项目对于硬件有一定要求:GPU 显存不小于16GB,支持最低配置显卡型号为 NVIDIA Tesla T4。
30 |
31 | 建议使用 GPU 云服务器来进行模型训练和微调。
32 |
33 | 项目使用 Python 版本为 3.10,环境关键依赖的官方文档如下:
34 |
35 | - Python 环境管理 [Miniconda](https://docs.conda.io/projects/miniconda/en/latest/)
36 | - Python 交互式开发环境 [Jupyter Lab](https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html)
37 |
38 |
39 | **以下是详细的安装指导(以 Ubuntu 22.04 操作系统为例)**:
40 |
41 | ### 安装 CUDA Toolkit 和 GPU 驱动
42 |
43 | 根据你的实际情况,找到对应的 [CUDA 12.04](https://developer.nvidia.com/cuda-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=22.04&target_type=runfile_local):
44 |
45 | 下载并安装 CUDA 12.04 Toolkit(包含GPU驱动):
46 |
47 | ```shell
48 | wget https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run
49 | sudo sh cuda_12.4.0_550.54.14_linux.run
50 | ```
51 |
52 | **注意使用`runfile`方式,可以连同版本匹配的 GPU 驱动一起安装好。
53 |
54 | 
55 |
56 | 安装完成后,使用 `nvidia-smi` 指令查看版本:
57 |
58 | ```shell
59 | nvidia-smi
60 | Mon Dec 18 12:10:47 2023
61 | +---------------------------------------------------------------------------------------+
62 | | NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
63 | |-----------------------------------------+----------------------+----------------------+
64 | | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
65 | | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
66 | | | | MIG M. |
67 | |=========================================+======================+======================|
68 | | 0 Tesla T4 Off | 00000000:00:0D.0 Off | 0 |
69 | | N/A 44C P0 26W / 70W | 2MiB / 15360MiB | 6% Default |
70 | | | | N/A |
71 | +-----------------------------------------+----------------------+----------------------+
72 |
73 | +---------------------------------------------------------------------------------------+
74 | | Processes: |
75 | | GPU GI CI PID Type Process name GPU Memory |
76 | | ID ID Usage |
77 | |=======================================================================================|
78 | | No running processes found |
79 | +---------------------------------------------------------------------------------------+
80 | ```
81 |
82 | ### 安装操作系统级软件依赖
83 |
84 | ```shell
85 | sudo apt update && sudo apt upgrade
86 | sudo apt install ffmpeg
87 | ## 检查是否安装成功
88 | ffmpeg -version
89 | ```
90 |
91 | 参考:[音频工具包 ffmpeg 官方安装文档](https://phoenixnap.com/kb/install-ffmpeg-ubuntu)
92 |
93 |
94 | ### 安装 Python 环境管理工具 Miniconda
95 |
96 | ```shell
97 | mkdir -p ~/miniconda3
98 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
99 | bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
100 | rm -rf ~/miniconda3/miniconda.sh
101 | ```
102 |
103 | 安装完成后,建议新建一个 Python 虚拟环境,命名为 `peft`。
104 |
105 | ```shell
106 | conda create -n peft python=3.10
107 |
108 | # 激活环境
109 | conda activate peft
110 | ```
111 |
112 | 之后每次使用需要激活此环境。
113 |
114 |
115 | ### 安装 Python 依赖软件包
116 |
117 | 完整 Python 依赖软件包见[requirements.txt](requirements.txt)。
118 |
119 | ```shell
120 | pip install -r requirements.txt
121 | ```
122 |
123 |
124 | ### 安装和配置 Jupyter Lab
125 |
126 | 上述开发环境安装完成后,使用 Miniconda 安装 Jupyter Lab:
127 |
128 | ```shell
129 | conda install -c conda-forge jupyterlab
130 | ```
131 |
132 | 使用 Jupyter Lab 开发的最佳实践是后台常驻,下面是相关配置(以 root 用户为例):
133 |
134 | ```shell
135 | # 生成 Jupyter Lab 配置文件,
136 | jupyter lab --generate-config
137 | ```
138 |
139 | 打开上面执行输出的`jupyter_lab_config.py`配置文件后,修改以下配置项:
140 |
141 | ```python
142 | c.ServerApp.allow_root = True # 非 root 用户启动,无需修改
143 | c.ServerApp.ip = '*'
144 | ```
145 |
146 | 使用 nohup 后台启动 Jupyter Lab
147 | ```shell
148 | $ nohup jupyter lab --port=8000 --NotebookApp.token='替换为你的密码' --notebook-dir=./ &
149 | ```
150 |
151 | Jupyter Lab 输出的日志将会保存在 `nohup.out` 文件(已在 .gitignore中过滤)。
152 |
153 |
154 | ### 关于 LangChain 调用 OpenAI GPT API 的配置
155 |
156 | 为了使用OpenAI API,你需要从OpenAI控制台获取一个API密钥。一旦你有了密钥,你可以将其设置为环境变量:
157 |
158 | 对于基于Unix的系统(如Ubuntu或MacOS),你可以在终端中运行以下命令:
159 |
160 | ```bash
161 | export OPENAI_API_KEY='你的-api-key'
162 | ```
163 |
164 | 对于Windows,你可以在命令提示符中使用以下命令:
165 |
166 | ```bash
167 | set OPENAI_API_KEY=你的-api-key
168 | ```
169 |
170 | 请确保将`'你的-api-key'`替换为你的实际OpenAI API密钥。
171 |
--------------------------------------------------------------------------------
/chatglm/chatbot_webui.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 |
3 | from langchain_community.llms import ChatGLM
4 | from langchain.chains import ConversationChain
5 | from langchain.memory import ConversationBufferMemory
6 |
7 | CHATGLM_URL = "http://127.0.0.1:8001"
8 |
9 | def init_chatbot():
10 | llm = ChatGLM(
11 | endpoint_url=CHATGLM_URL,
12 | max_token=80000,
13 | history=[],
14 | top_p=0.9,
15 | model_kwargs={"sample_model_args": False},
16 | )
17 | global CHATGLM_CHATBOT
18 | CHATGLM_CHATBOT = ConversationChain(llm=llm,
19 | verbose=True,
20 | memory=ConversationBufferMemory())
21 | return CHATGLM_CHATBOT
22 |
23 | def chatglm_chat(message, history):
24 | ai_message = CHATGLM_CHATBOT.predict(input = message)
25 | return ai_message
26 |
27 | def launch_gradio():
28 | demo = gr.ChatInterface(
29 | fn=chatglm_chat,
30 | title="ChatBot (Powered by ChatGLM)",
31 | chatbot=gr.Chatbot(height=600),
32 | )
33 |
34 | demo.launch(share=True, server_name="0.0.0.0")
35 |
36 | if __name__ == "__main__":
37 | # 初始化聊天机器人
38 | init_chatbot()
39 | # 启动 Gradio 服务
40 | launch_gradio()
41 |
--------------------------------------------------------------------------------
/chatglm/chatbot_with_memory.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "1af0fbae-98f6-4a91-a0a8-cefed3ff445b",
6 | "metadata": {},
7 | "source": [
8 | "# LangChain 调用私有化 ChatGLM 模型\n",
9 | "\n",
10 | "## LLMChain 实现单轮对话"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "d04b3bdb-98ca-4b02-94d6-0871e1bea6a1",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from langchain_community.llms import ChatGLM\n",
21 | "from langchain.chains import LLMChain\n",
22 | "from langchain.prompts import PromptTemplate"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 2,
28 | "id": "5b32e8d2-50b8-4566-9081-716aa5df5d8c",
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "# ChatGLM 私有化部署的 Endpoint URL\n",
33 | "endpoint_url = \"http://127.0.0.1:8001\""
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 4,
39 | "id": "8208a4a1-7863-414b-b2a9-2f86d6ddcf5a",
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "# 实例化 ChatGLM 大模型\n",
44 | "llm = ChatGLM(\n",
45 | " endpoint_url=endpoint_url,\n",
46 | " max_token=80000,\n",
47 | " history=[\n",
48 | " [\"你是一个专业的销售顾问\", \"欢迎问我任何问题。\"]\n",
49 | " ],\n",
50 | " top_p=0.9,\n",
51 | " model_kwargs={\"sample_model_args\": False},\n",
52 | ")"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 5,
58 | "id": "b03b42d0-a4f7-48eb-a795-64d891703a38",
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "# 提示词模板\n",
63 | "template = \"\"\"{question}\"\"\"\n",
64 | "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
65 | ]
66 | },
67 | {
68 | "cell_type": "code",
69 | "execution_count": 6,
70 | "id": "b7322d29-8bc2-4c9d-aa34-adfed0c56f1f",
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "llm_chain = LLMChain(prompt=prompt, llm=llm)"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 7,
80 | "id": "700fef2d-eaa7-439e-bb3b-e3af4c2dac3b",
81 | "metadata": {},
82 | "outputs": [
83 | {
84 | "name": "stderr",
85 | "output_type": "stream",
86 | "text": [
87 | "/root/miniconda3/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:189: LangChainDeprecationWarning: The function `run` was deprecated in LangChain 0.1.0 and will be removed in 0.2.0. Use invoke instead.\n",
88 | " warn_deprecated(\n",
89 | "/root/miniconda3/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:189: LangChainDeprecationWarning: The function `__call__` was deprecated in LangChain 0.1.0 and will be removed in 0.2.0. Use invoke instead.\n",
90 | " warn_deprecated(\n"
91 | ]
92 | },
93 | {
94 | "data": {
95 | "text/plain": [
96 | "'我们的衣服是通过品牌授权和直营销售两种方式进行销售。\\n\\n品牌授权是指我们与一些知名品牌合作,在这些品牌的授权下销售他们的衣服。这些品牌会给我们提供品牌形象、设计、产品质量和销售支持等资源,我们则会根据这些资源来制定自己的销售策略,进行市场推广和销售。\\n\\n直营销售是指我们自行设计和销售自己的品牌衣服。我们拥有自己的设计团队和生产线,能够提供优质的产品和个性化的服务,同时也可以通过自营销售来更好地掌控产品质量和销售流程。\\n\\n无论是品牌授权还是直营销售,我们都致力于为客户提供高品质、个性化和时尚的衣服,让客户能够轻松地找到适合自己的衣服,同时也为客户提供优质的售后服务。'"
97 | ]
98 | },
99 | "execution_count": 7,
100 | "metadata": {},
101 | "output_type": "execute_result"
102 | }
103 | ],
104 | "source": [
105 | "llm_chain.run(\"你们衣服怎么卖?\")"
106 | ]
107 | },
108 | {
109 | "cell_type": "code",
110 | "execution_count": null,
111 | "id": "a0396649-c400-4972-b6f6-18e51bbd7105",
112 | "metadata": {},
113 | "outputs": [],
114 | "source": []
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "id": "0a9070c0-14b0-416f-98d3-65eb5ab43c4e",
119 | "metadata": {},
120 | "source": [
121 | "## 带记忆功能的聊天对话(Conversation with Memory)"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 8,
127 | "id": "9ca4b002-0a05-4c3b-9aa4-94e7990072f9",
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "from langchain.chains import ConversationChain\n",
132 | "from langchain.memory import ConversationBufferMemory"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 9,
138 | "id": "455b4814-75d3-4251-a0f1-859a9d72bc1c",
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "conversation = ConversationChain(\n",
143 | " llm=llm, \n",
144 | " verbose=True, \n",
145 | " memory=ConversationBufferMemory()\n",
146 | ")"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 10,
152 | "id": "11161e97-cfbb-4bbe-9839-27334c7f6eba",
153 | "metadata": {},
154 | "outputs": [
155 | {
156 | "name": "stderr",
157 | "output_type": "stream",
158 | "text": [
159 | "/root/miniconda3/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:189: LangChainDeprecationWarning: The function `__call__` was deprecated in LangChain 0.1.0 and will be removed in 0.2.0. Use invoke instead.\n",
160 | " warn_deprecated(\n"
161 | ]
162 | },
163 | {
164 | "name": "stdout",
165 | "output_type": "stream",
166 | "text": [
167 | "\n",
168 | "\n",
169 | "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
170 | "Prompt after formatting:\n",
171 | "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
172 | "\n",
173 | "Current conversation:\n",
174 | "\n",
175 | "Human: 你们衣服怎么卖?\n",
176 | "AI:\u001b[0m\n",
177 | "\n",
178 | "\u001b[1m> Finished chain.\u001b[0m\n"
179 | ]
180 | },
181 | {
182 | "data": {
183 | "text/plain": [
184 | "'我们的衣服都是自己生产的,然后拿到市场上卖。我们主要的销售渠道是线上和线下,线上是通过我们的官方网站和一些在线平台销售,线下则是通过我们的实体店和一些零售商进行销售。我们的衣服设计独特,质量优良,价格实惠,深受广大消费者的喜爱。'"
185 | ]
186 | },
187 | "execution_count": 10,
188 | "metadata": {},
189 | "output_type": "execute_result"
190 | }
191 | ],
192 | "source": [
193 | "conversation.predict(input=\"你们衣服怎么卖?\")"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 11,
199 | "id": "dcc8ced2-396e-47c0-8474-f8a5065ba7aa",
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "name": "stderr",
204 | "output_type": "stream",
205 | "text": [
206 | "/root/miniconda3/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:189: LangChainDeprecationWarning: The function `__call__` was deprecated in LangChain 0.1.0 and will be removed in 0.2.0. Use invoke instead.\n",
207 | " warn_deprecated(\n"
208 | ]
209 | },
210 | {
211 | "name": "stdout",
212 | "output_type": "stream",
213 | "text": [
214 | "\n",
215 | "\n",
216 | "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
217 | "Prompt after formatting:\n",
218 | "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
219 | "\n",
220 | "Current conversation:\n",
221 | "Human: 你们衣服怎么卖?\n",
222 | "AI: 我们的衣服都是自己生产的,然后拿到市场上卖。我们主要的销售渠道是线上和线下,线上是通过我们的官方网站和一些在线平台销售,线下则是通过我们的实体店和一些零售商进行销售。我们的衣服设计独特,质量优良,价格实惠,深受广大消费者的喜爱。\n",
223 | "Human: 有哪些款式?\n",
224 | "AI:\u001b[0m\n",
225 | "\n",
226 | "\u001b[1m> Finished chain.\u001b[0m\n"
227 | ]
228 | },
229 | {
230 | "data": {
231 | "text/plain": [
232 | "'我们有很多款式不同的衣服,涵盖了不同的风格和场合。我们的设计团队会根据当前的流行趋势和客户的需求,不断推出新的款式,保持我们的产品具有新鲜感和吸引力。我们的产品线主要包括T恤、衬衫、连衣裙、牛仔裤、休闲裤、运动鞋等多种款式。'"
233 | ]
234 | },
235 | "execution_count": 11,
236 | "metadata": {},
237 | "output_type": "execute_result"
238 | }
239 | ],
240 | "source": [
241 | "conversation.predict(input=\"有哪些款式?\")"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 12,
247 | "id": "1518523e-6ccd-461c-a1e1-331905750f3d",
248 | "metadata": {},
249 | "outputs": [
250 | {
251 | "name": "stderr",
252 | "output_type": "stream",
253 | "text": [
254 | "/root/miniconda3/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:189: LangChainDeprecationWarning: The function `__call__` was deprecated in LangChain 0.1.0 and will be removed in 0.2.0. Use invoke instead.\n",
255 | " warn_deprecated(\n"
256 | ]
257 | },
258 | {
259 | "name": "stdout",
260 | "output_type": "stream",
261 | "text": [
262 | "\n",
263 | "\n",
264 | "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
265 | "Prompt after formatting:\n",
266 | "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
267 | "\n",
268 | "Current conversation:\n",
269 | "Human: 你们衣服怎么卖?\n",
270 | "AI: 我们的衣服都是自己生产的,然后拿到市场上卖。我们主要的销售渠道是线上和线下,线上是通过我们的官方网站和一些在线平台销售,线下则是通过我们的实体店和一些零售商进行销售。我们的衣服设计独特,质量优良,价格实惠,深受广大消费者的喜爱。\n",
271 | "Human: 有哪些款式?\n",
272 | "AI: 我们有很多款式不同的衣服,涵盖了不同的风格和场合。我们的设计团队会根据当前的流行趋势和客户的需求,不断推出新的款式,保持我们的产品具有新鲜感和吸引力。我们的产品线主要包括T恤、衬衫、连衣裙、牛仔裤、休闲裤、运动鞋等多种款式。\n",
273 | "Human: 休闲装男款都有啥?\n",
274 | "AI:\u001b[0m\n",
275 | "\n",
276 | "\u001b[1m> Finished chain.\u001b[0m\n"
277 | ]
278 | },
279 | {
280 | "data": {
281 | "text/plain": [
282 | "'我们的休闲装男款主要包括T恤、衬衫、牛仔裤和休闲裤等。我们的设计团队会根据当前的流行趋势和客户的需求,不断推出新的款式,保持我们的产品具有新鲜感和吸引力。我们的产品线非常丰富,可以满足不同客户的需求。'"
283 | ]
284 | },
285 | "execution_count": 12,
286 | "metadata": {},
287 | "output_type": "execute_result"
288 | }
289 | ],
290 | "source": [
291 | "conversation.predict(input=\"休闲装男款都有啥?\")"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": null,
297 | "id": "a987c6c6-98be-4a1b-b916-b05033d6d0a0",
298 | "metadata": {},
299 | "outputs": [],
300 | "source": []
301 | }
302 | ],
303 | "metadata": {
304 | "kernelspec": {
305 | "display_name": "Python 3 (ipykernel)",
306 | "language": "python",
307 | "name": "python3"
308 | },
309 | "language_info": {
310 | "codemirror_mode": {
311 | "name": "ipython",
312 | "version": 3
313 | },
314 | "file_extension": ".py",
315 | "mimetype": "text/x-python",
316 | "name": "python",
317 | "nbconvert_exporter": "python",
318 | "pygments_lexer": "ipython3",
319 | "version": "3.11.5"
320 | }
321 | },
322 | "nbformat": 4,
323 | "nbformat_minor": 5
324 | }
325 |
--------------------------------------------------------------------------------
/chatglm/data/raw_data.txt:
--------------------------------------------------------------------------------
1 | 蒙卦原文
2 | 蒙。亨。匪我求童蒙,童蒙求我。初筮告,再三渎,渎则不告。利贞。
3 | 象曰:山下出泉,蒙。君子以果行育德。
4 | 白话文解释
5 | 蒙卦:通泰。不是我有求于幼稚愚昧的人,而是幼稚愚昧的人有求于我。第一次占筮,神灵告诉了他。轻慢不敬的再三占筮,轻慢不敬的占筮,神灵就不会告诉他。但还是吉利的卜问。
6 | 《象辞》说:上卦为艮,象征山;下卦为坎,象征泉。山下有泉,泉水喷涌而出,这是蒙卦的卦象。君子观此卦象,取法于一往无前的山泉,从而以果敢坚毅的行动来培养自身的品德。
7 | 《断易天机》解
8 | 蒙卦艮上坎下,为离宫四世卦。蒙即蒙昧,主回还往复,疑惑不前,多忧愁过失,乃是凶卦。
9 | 北宋易学家邵雍解
10 | 智慧未开,蒙昧闭塞;犹豫不决,缺乏果断。
11 | 得此卦者,智慧犹如童蒙,不辨是非,迷失方向;若能顺贤师良友之教,启其聪明则亨通。
12 | 台湾国学大儒傅佩荣解
13 | 时运:蓄积德行,出而用世。
14 | 财运:矿山生意,果决则吉。
15 | 家宅:君子居吉;婚姻之始。
16 | 身体:驱去邪热,可保平安。
17 | 传统解卦
18 | 这个卦是异卦(下坎上艮)相叠,艮是山的形象,喻止;坎是水的形象,喻险。卦形为山下有险,仍不停止前进,是为蒙昧,故称蒙卦。但因把握时机,行动切合时宜,因此,具有启蒙和通达的卦象。
19 | 大象:蒙者,昏而无所见也,故宜「启蒙」。
20 | 运势:初时迷惑不知方向,须忍耐待机而动,凡事多听取别人意见,则运可通。
21 | 事业:事业开始,混乱无序,危机四伏,以勇敢坚毅的行动可以扭转局面。然而必须接受严格教育,培养这种奋发图强的精神。务必脚踏实地,最忌好高骛远,否则会陷入孤立无援的境地。
22 | 经商:务必小心谨慎,不得急功近利,尤其应树立高尚的商业道德,以良好的信誉提高竞争力而取胜。
23 | 求名:必须接受良好的基础教育,陶冶情操。且动机纯正,可以达到目的。
24 | 婚恋:注意考察对方品德,不可以金钱为诱铒。夫妻需相互宽容、理解。
25 | 决策:有时会陷入迷惘困顿的境地,加上胆小、不果断,往往误事。如能接受长辈的教诲,甚至严酷的考验,抛弃疑惧的心理,等待适当时机,必然一帆风顺。
26 |
27 | 屯卦原文
28 | 屯。元,亨,利,贞。勿用,有攸往,利建侯。
29 | 象曰:云,雷,屯;君子以经纶。
30 | 白话文解释
31 | 屯卦。大吉大利,吉利的占卜。不利于出门。有利于建国封侯。
32 | 《象辞》说:屯的上卦为坎,坎为云,下卦为震,震为雷。云行于上,雷动于下,是屯卦的卦象。君子观此卦象,取法于云雷,用云的恩泽,雷的威严来治理国事。
33 | 《断易天机》解
34 | 屯卦坎上震下,为坎宫二世卦。屯说明有困难,又象征动而逢险境,需刚毅果敢方为吉。
35 | 北宋易学家邵雍解
36 | 万物始生,开始困难;先劳后逸,苦尽甘来。
37 | 得此卦者,身处困境,宜守不宜进,须多加辛苦努力,排除困难,方可通达,有初难后解之象。
38 | 台湾国学大儒傅佩荣解
39 | 时运:宜守不宜进。
40 | 财运:创业维艰。
41 | 家宅:修缮住宅;初婚不和。
42 | 身体:保存元气。
43 | 传统解卦
44 | 这个卦是异卦(下震上坎)相叠,震为雷,喻动;坎为雨,喻险。雷雨交加,险象丛生,环境恶劣。“屯”原指植物萌生大地,万物始生,充满艰难险阻,然而顺时应运,必欣欣向荣。
45 | 大象:屯者难也,万事欲进而不得进。
46 | 运势:身处困境,步步为营,有初难后解之意。
47 | 事业:起初多有不利,要知难而进,小心翼翼,勇往直前,灵活机动,可望获得大的成功,时机到来时一定要抓住,却也不得操之太急,且仍有困难,务必有他人相助,故平时应多施恩惠。
48 | 经商:创业初期步履艰难,多有挫折。坚定信念最重要,不要为表面现象所迷惑,应积极进取,行动果断,打开出路。若仍无法摆脱困境,则应退守保全,等待机会,再展宏图。
49 | 求名:积极争取,主动追求,可以成功。
50 | 婚恋:好事多磨,忠贞纯洁,大胆追求,能够成功,婚姻美满。
51 | 决策:初始困难,但若具有坚忍不拔的毅力和锲而不舍的奋斗精神,前途不可估量,但往往不为他人理解而陷于孤独苦闷,事业会因此处于困难状态,需要得到贤德之人的帮助才能摆脱。如能以乐观主义精神处世,能取得成就。
52 |
53 | 需卦原文
54 | 需。有孚,光亨,贞吉。利涉大川。
55 | 象曰:云上于天,需;君子以饮食宴乐。
56 | 白话文解释
57 | 需卦:抓到俘虏。大吉大利,吉利的卜问。有利于涉水渡河。
58 | 《象辞》说:需的上卦为坎,表示云;下卦为乾,表示天。云浮聚于天上,待时降雨是需卦的卦象。君子观此卦象,可以宴饮安乐,待时而动。
59 | 《断易天机》解
60 | 需卦坎上乾下,为坤宫游魂卦。需为踌躇、期待,虽然刚强,但前面有险阻,应当等待,涉大川则利。
61 | 北宋易学家邵雍解
62 | 坎陷当前,遇阻不进;大器晚成,收成在后。
63 | 得此卦者,时机尚未成熟,需要耐心等待,急进反会见凶。
64 | 台湾国学大儒傅佩荣解
65 | 时运:时机尚未成熟,耐心等待。
66 | 财运:资本未集,无法开张。
67 | 家宅:平安是福。
68 | 身体:调节饮食,健康有望。
69 | 传统解卦
70 | 这个卦是异卦(下乾上坎)相叠,下卦是乾,刚健之意;上卦是坎,险陷之意。以刚逢险,宜稳健之妥,不可冒失行动,观时待变,所往一定成功。
71 | 大象:云登天上而未雨,不能急进,等待时机之象。
72 | 运势:智者必须待时而行,急进反见凶险。
73 | 事业:关键在于审时度势,耐心等待,事成于安祥,切勿冒险,欲速不达。自己要充满自信,临危不惧,坚守中正,必可化险为夷。情况有利时,仍得居安思危。
74 | 经商:行动之初,情况困难,必须以极大的耐心,创造条件和机会,行事光明磊落,观时待变,实现愿望。事情接近成功时,更应小心谨慎,切莫功亏一篑。
75 | 求名:时机尚不成熟,应耐心等待。这时应坚定信念,不为闲言流语所动摇,努力丰富自己,再求助可靠的人便可成功。
76 | 婚恋:慎重,切不可草率行事,逐渐培养感情,以诚实、热情相待,会发生变故,仍可以有良好的结局。双方都应懂得以柔克刚的道理。
77 | 决策:前途光明,有雄心大志,且可实现。为此需要积蓄实力,等待时机,大器晚成。本人具有坚强的意志,冷静的头脑。前进途中会遇到困难和险阻,必须十分谨慎,坦然对待小人的中伤,在灾祸在面前能镇静自若。不轻举妄动,冷静选择方向。为人谦和、坦率,多有他人相助,促使事业成功。当时机成熟后,必然一帆风顺。
78 |
79 | 讼卦原文
80 | 讼。有孚,窒惕,中吉,终凶。利见大人,不利涉大川。
81 | 象曰:天与水违行,讼。君子以做事谋始。
82 | 白话文解释
83 | 讼卦:虽有利可图(获得俘虏),但要警惕戒惧。其事中间吉利,后来凶险。占筮得此爻,有利于会见贵族王公,不利于涉水渡河。
84 | 《象辞》说:上卦为乾,乾为天;下卦为坎,坎为水,天水隔绝,流向相背,事理乖舛,这是讼卦的卦象。君子观此卦象,以杜绝争讼为意,从而在谋事之初必须慎之又慎。
85 | 《断易天机》解
86 | 讼卦乾上坎下,为离宫游魂卦。上乾为刚,下坎为险,一方刚强,一方阴险,必然产生争论,因此多有不吉。
87 | 北宋易学家邵雍解
88 | 天高水深,达远不亲;慎谋退守,敬畏无凶。
89 | 得此卦者,身心不安,事多不顺,与他人多争诉之事,宜修身养性,谨慎处事。
90 | 台湾国学大儒傅佩荣解
91 | 时运:功名受阻,不宜树敌。
92 | 财运:开始谨慎,终可获利。
93 | 家宅:君子必求淑女。
94 | 身体:预防胜于治疗。
95 | 传统解卦
96 | 这个卦是异卦(下坎上乾)相叠。同需卦相反,互为“综卦”。乾为刚健,坎为险陷。刚与险,健与险,彼此反对,定生争讼。争讼非善事,务必慎重戒惧。
97 | 大象:乾天升于上,坎水降于下,相背而行而起讼。
98 | 运势:事与愿违,凡事不顺,小人加害,宜防陷阱。
99 | 事业:起初顺利,有利可图,继而受挫,务必警惕,慎之又慎,不得固执已见,极力避免介入诉讼纠纷的争执之中。与其这样,不如退而让人,求得化解,安于正理,可免除意外之灾。陷入争讼,即使获胜,最后还得失去,得不偿失。
100 | 经商:和气生财,吃亏是福,切勿追求不义之财。商业谈判应坚持公正、公平、互利的原则,尽量避免发生冲突。这样便会有好结果。
101 | 求名:不利。自己尚缺乏竞争实力,应坚守纯正,隐忍自励,自强自勉,切莫逞强。依靠有地位的人的帮助,及早渡过难关。
102 | 婚恋:虽不尽人意,倒也般配,彼此理解,未尝不可。双方应以温和的方式处理生活。
103 | 决策:争强好胜,不安于现状,为改变命运和超越他人而奋斗。头脑聪颖,反应敏捷,有贵人相助。但缺乏持之以恒的毅力,容易露出锋芒,得罪他人,带来诉讼之灾。宜承认现实,顺其自然,知足,适可而止。接受教训,引以为戒,可功成名就。
104 |
105 | 师卦原文
106 | 师。贞,丈人吉,无咎。
107 | 象曰:地中有水,师。君子以容民畜众。
108 | 白话文解释
109 | 师卦:占问总指挥的军情,没有灾祸。
110 | 《象辞》说:下卦为坎,坎为水;上卦为坤,坤为地,像“地中有水”,这是师卦的卦象。君子观此卦象,取法于容纳江河的大地,收容和畜养大众。
111 | 《断易天机》解
112 | 师卦坤上坎下,为坎宫归魂卦。师即兵众,只有选择德高望重的长者来统率军队,才能吉祥无咎。
113 | 北宋易学家邵雍解
114 | 忧劳动众,变化无穷;公正无私,排除万难。
115 | 得此卦者,困难重重,忧心劳众,宜包容别人,艰苦努力,摒除一切困难。
116 | 台湾国学大儒傅佩荣解
117 | 时运:包容别人,修行待时。
118 | 财运:有财有库,善自珍惜。
119 | 家宅:旧亲联姻,可喜可贺。
120 | 身体:腹胀之症,调气无忧。
121 | 传统解卦
122 | 这个卦是异卦(下坎上坤)相叠。“师”指军队。坎为水、为险;坤为地、为顺,喻寓兵于农。兵凶战危,用兵乃圣人不得已而为之,但它可以顺利无阻碍地解决矛盾,因为顺乎形势,师出有名,故能化凶为吉。
123 | 大象:养兵聚众,出师攻伐之象,彼此有伤,难得安宁。
124 | 运势:困难重重,凡事以正规行事,忌独断独行、投机取巧,提防潜在敌人。
125 | 事业:阻力很大,困难很多,处于激烈的竞争状态,要与他人密切合作,谨小慎微,行为果断,切忌盲目妄动,适度即可,注意保全自己。机动灵活,严于律已。从容沉着对付一切,必能成功。
126 | 经商:已有一定的积蓄,可以从事大的营销活动,但必卷入激烈商战,以刚毅顽强的精神和高尚的商业道德,辅以灵活的方法,勿贪图小利,勿掉以轻心,加强与他人的沟通,必可摆脱困境,化险为夷。
127 | 求名:具备很好的条件,但须有正确的引导,务必严格要求自己,克服不利因素的干扰,经过扎实努力,必可名利双全。
128 | 婚恋:慎重、专注,否则会陷入“三角”纠纷。痴情追求可以达到目的。
129 | 决策:天资聪颖,性格灵活,具有坚强的意志,对事业执着追求,迎难而进。可成就大事业。喜竞争,善争辩,富有冒险精神,不免带来麻烦,务老成持重,不贪功,以中正为要。
130 |
131 | 比卦原文
132 | 比。吉。原筮,元永贞,无咎。不宁方来,后夫凶。
133 | 象曰:地上有水,比。先王以建万国,亲诸侯。
134 | 白话文解释
135 | 比卦:吉利。同时再卜筮,仍然大吉大利。卜问长时期的吉凶,也没有灾祸。不愿臣服的邦国来朝,迟迟不来者有难。
136 | 《象辞》说:下卦为坤,上卦为坎,坤为地,坎为水,像地上有水,这是比卦的卦象。先王观此卦象,取法于水附大地,地纳江河之象,封建万国,亲近诸侯。
137 | 《断易天机》解
138 | 比卦坎上坤下,为坤宫归魂卦。比为相亲相依附之意,长期如此,就会无咎,所以吉祥。
139 | 北宋易学家邵雍解
140 | 水行地上,亲比欢乐;人情亲顺,百事无忧。
141 | 得此卦者,可获朋友之助,众人之力,谋事有成,荣显之极。
142 | 台湾国学大儒傅佩荣解
143 | 时运:众人相贺,荣显之极。
144 | 财运:善人相扶,大发利市。
145 | 家宅:百年好合。
146 | 身体:心腹水肿,宜早求治。
147 | 传统解卦
148 | 这个卦是异卦(下坤上坎)相叠,坤为地,坎为水。水附大地,地纳河海,相互依赖,亲密无间。此卦与师卦完全相反,互为综卦。它阐述的是相亲相辅,宽宏无私,精诚团结的道理。
149 | 大象:一阳统五阴,比邻相亲相辅,和乐之象。
150 | 运势:平顺,可得贵人提拔,凡事宜速战速决,不可过份迟疑。
151 | 事业:顺利能够成功,向前发展,可以得到他人的帮助和辅佐,以诚实、信任的态度去做事。待人宽厚、正直,主动热情,向才德高尚的人士学习,听取建议。
152 | 经商:愿望能够实现,且有较丰厚的利润,但需要与他人密切合作,真诚交往,讲究商业道德,遵守信义,如唯利是图,贪心不足,或自以为是,会导致严重损失。
153 | 求名:有成功的希望,不仅要靠个人的努力,更为重要的是他人的赏识和栽培。
154 | 婚恋:美好姻缘,相亲相爱,彼此忠诚,白头到老。
155 | 决策:心地善良,待人忠诚、厚道,乐于帮助他人,也能得到回报。工作勤恳,对自己要求严格,可以实现自己的理想,但要多动脑筋,多思考,善于判断是非,尤其要注意选择朋友,一旦结上品行不端的人,会成为自己的祸患。如果与比自己高明的人交朋友,并取得帮助,会终身受益。
156 |
157 | 坤卦原文
158 | 坤。元,亨,利牝马之贞。君子有攸往,先迷后得主。利西南得朋,东北丧朋。安贞,吉。
159 | 象曰:地势坤,君子以厚德载物。
160 | 白话文解释
161 | 坤卦:大吉大利。占问雌马得到吉兆。君子前去旅行,先迷失路途,后来找到主人,吉利。西南行获得财物,东北行丧失财物。占问定居,得到吉兆。
162 | 《象辞》说:大地的形势平铺舒展,顺承天道。君子观此卦象,取法于地,以深厚的德行来承担重大的责任。
163 | 《断易天机》解
164 | 坤卦坤上坤下,为坤宫本位卦。坤卦为柔顺,为地气舒展之象,具有纯阴之性,先失道而后得主,宜往西南,西南可得到朋友。
165 | 北宋易学家邵雍解
166 | 柔顺和静,厚载之功;静守安顺,妄动招损。
167 | 得此卦者,宜顺从运势,以静制动,不宜独立谋事,顺从他人,一起合作,可成大事。
168 | 台湾国学大儒傅佩荣解
169 | 时运:为人厚道,声名远传。
170 | 财运:满载而归。
171 | 家宅:家庭安稳;婚嫁大吉。
172 | 身体:柔软运动。
173 | 传统解卦
174 | 这个卦是同卦(下坤上坤)相叠,阴性。象征地(与乾卦相反),顺从天,承载万物,伸展无穷无尽。坤卦以雌马为象征,表明地道生育抚养万物,而又依天顺时,性情温顺。它以“先迷后得”证明“坤”顺从“乾”,依随“乾”,才能把握正确方向,遵循正道,获取吉利。
175 | 大象:大地承载万物,以德服众,仁者无敌。
176 | 运势:诸事不宜急进,以静制动为宜。
177 | 事业:诸项事业可以成功,得到预想的结果,但开始出师不利,为困境所扰。切莫冒险急进,须小心谨言慎行,尤其不可单枪匹马,独断专行。取得朋友的关心和支持最为重要,在他人的合作下,共同完成事业。因此,应注重内心修养,积蓄养德,效法大地,容忍负重,宽厚大度,以直率、方正、含蓄为原则,不得贪功自傲,持之以恒,谋求事业的成功。
178 | 经商:机遇不很好,切莫冒险,以稳健为妥,遇到挫折,务必即时总结经验。注意储存货物,待价而沽,处处小心为是。
179 | 求名:比较顺利,具备基本条件,踏踏实实,埋头苦干,不追求身外之物,即可吉祥。
180 | 婚恋:阴盛。以柔克刚,女方柔顺,美好姻缘,白头到老。
181 | 决策:忠厚、温和,待人真诚,热心助人,因此也能得到他人的帮助,可往往因不提防小人而受到伤害,但无大碍。性格灵活,工作方法多样,可以左右逢源,得到赞许。
182 |
183 | 乾卦原文
184 | 乾。元,亨,利,贞。
185 | 象曰:天行健,君子以自强不息。
186 | 白话文解释
187 | 乾卦:大吉大利,吉利的贞卜。
188 | 《象辞》说:天道刚健,运行不已。君子观此卦象,从而以天为法,自强不息。
189 | 《断易天机》解
190 | 乾象征天,六阳爻构成乾卦,为《易经》六十四卦之首。纯阳刚建,其性刚强,其行劲健,大通而至正,兆示大通而有利,但须行正道,方可永远亨通。
191 | 北宋易学家邵雍解
192 | 刚健旺盛,发育之功;完事顺利,谨防太强。
193 | 得此卦者,天行刚健,自强不息,名利双收之象,宜把握机会,争取成果。女人得此卦则有过于刚直之嫌。
194 | 台湾国学大儒傅佩荣解
195 | 时运:临事刚健,自强不息。
196 | 财运:施比受有福,不利买而利卖。
197 | 家宅:积善有余庆;女子过刚宜慎重。
198 | 身体:保健有恒。
199 | 传统解卦
200 | 这个卦是同卦(下乾上乾)相叠。象征天,喻龙(德才的君子),又象征纯粹的阳和健,表明兴盛强健。乾卦是根据万物变通的道理,以“元、亨、利、贞”为卦辞,表示吉祥如意,教导人遵守天道的德行。
201 | 大象:天行刚健,自强不息。
202 | 运势:飞龙在天,名利双收之象,宜把握机会,争取成果。
203 | 事业:大吉大利,万事如意,心想事成,自有天佑,春风得意,事业如日中天。但阳气已达顶点,盛极必衰,务须提高警惕,小心谨慎。力戒骄傲,冷静处世,心境平和,如是则能充分发挥才智,保证事业成功。
204 | 经商:十分顺利,有发展向上的大好机会。但切勿操之过急,宜冷静分析形势,把握时机,坚持商业道德,冷静对待中途出现的困难,定会有满意的结果。
205 | 求名:潜在能力尚未充分发挥,只要进一步努力,克服骄傲自满情绪,进业修德,以渊博学识和高尚品质,成君子之名。
206 | 婚恋:阳盛阴衰,但刚柔可相济,形成美满结果。女性温柔者更佳。
207 | 决策:可成就大的事业。坚持此卦的刚健、正直、公允的实质,修养德行,积累知识,坚定信念,自强不息,必能克服困难,消除灾难。
--------------------------------------------------------------------------------
/chatglm/data/zhouyi_dataset_handmade.csv:
--------------------------------------------------------------------------------
1 | content,summary
2 | 乾卦,乾卦原文\n乾。元,亨,利,贞。\n象曰:天行健,君子以自强不息。\n\n白话文解释\n乾卦:大吉大利,吉利的贞卜。\n《象辞》说:天道刚健,运行不已。君子观此卦象,从而以天为法,自强不息。\n\n《断易天机》解\n乾象征天,六阳爻构成乾卦,为《易经》六十四卦之首。纯阳刚建,其性刚强,其行劲健,大通而至正,兆示大通而有利,但须行正道,方可永远亨通。\n\n北宋易学家邵雍解\n刚健旺盛,发育之功;完事顺利,谨防太强。\n得此卦者,天行刚健,自强不息,名利双收之象,宜把握机会,争取成果。女人得此卦则有过于刚直之嫌。\n\n传统解卦\n这个卦是同卦(下乾上乾)相叠。象征天,喻龙(德才的君子),又象征纯粹的阳和健,表明兴盛强健。乾卦是根据万物变通的道理,以“元、亨、利、贞”为卦辞,表示吉祥如意,教导人遵守天道的德行。
3 | 坤卦,坤卦原文\n坤。元,亨,利牝马之贞。君子有攸往,先迷后得主。利西南得朋,东北丧朋。安贞,吉。\n象曰:地势坤,君子以厚德载物。\n白话文解释\n坤卦:大吉大利。占问雌马得到吉兆。君子前去旅行,先迷失路途,后来找到主人,吉利。西南行获得财物,东北行丧失财物。占问定居,得到吉兆。\n《象辞》说:大地的形势平铺舒展,顺承天道。君子观此卦象,取法于地,以深厚的德行来承担重大的责任。\n《断易天机》解\n坤卦坤上坤下,为坤宫本位卦。坤卦为柔顺,为地气舒展之象,具有纯阴之性,先失道而后得主,宜往西南,西南可得到朋友。\n北宋易学家邵雍解\n柔顺和静,厚载之功;静守安顺,妄动招损。\n得此卦者,宜顺从运势,以静制动,不宜独立谋事,顺从他人,一起合作,可成大事。\n台湾国学大儒傅佩荣解\n时运:为人厚道,声名远传。\n财运:满载而归。\n家宅:家庭安稳;婚嫁大吉。\n身体:柔软运动。\n传统解卦\n这个卦是同卦(下坤上坤)相叠,阴性。象征地(与乾卦相反),顺从天,承载万物,伸展无穷无尽。坤卦以雌马为象征,表明地道生育抚养万物,而又依天顺时,性情温顺。它以“先迷后得”证明“坤”顺从“乾”,依随“乾”,才能把握正确方向,遵循正道,获取吉利。
4 | 水雷屯卦,屯卦原文:屯。元,亨,利,贞。勿用,有攸往,利建侯。象曰:云,雷,屯;君子以经纶。白话文解释:屯卦大吉大利,吉利的占卜。不利于出门。有利于建国封侯。《象辞》说:屯的上卦为坎,坎为云,下卦为震,震为雷。云行于上,雷动于下,是屯卦的卦象。君子观此卦象,取法于云雷,用云的恩泽和雷的威严治理国事。《断易天机》解:屯卦坎上震下,为坎宫二世卦。屯卦显示困难,动而逢险,需刚毅果敢方为吉。北宋易学家邵雍解:万物始生,开始困难;先劳后逸,苦尽甘来。得此卦者,身处困境,宜守不宜进,需辛劳克难,初难后解。台湾国学大儒傅佩荣解:时运宜守,财运创业艰难,家宅初婚不和,身体需保元气。传统解卦:异卦(下震上坎),震为雷动,坎为雨险。雷雨交加,环境险恶。“屯”指万物始生,艰难险阻中顺时应运,终将欣荣。
5 | 山水蒙卦,蒙卦原文:蒙。亨。匪我求童蒙,童蒙求我。初筮告,再三渎,渎则不告。利贞。象曰:山下出泉,蒙。君子以果行育德。白话文解释:蒙卦通泰。不是我求幼稚之人,而是幼稚之人求我。初次占卜被告知,轻慢占卜则不再告知。占卜吉利。《象辞》说:卦象为山下有泉,取法于山泉果敢坚毅,培养品德。《断易天机》解:蒙卦艮上坎下,象征蒙昧,主疑惑不前,多忧愁,凶兆。北宋易学家邵雍解:智慧未开,犹豫不决,需顺师友教导启智。台湾国学大儒傅佩荣解:时运蓄德出世,财运矿业果决吉,家宅君子居吉,婚姻之始,身体驱邪保安。传统解卦:异卦(下坎上艮),山下有险仍前进,为蒙昧,把握时机行动恰时,启蒙通达之象。大象:蒙为昏无所见,宜启蒙。
6 | 水天需卦,需卦原文:需。有孚,光亨,贞吉。利涉大川。象曰:云上于天,需;君子以饮食宴乐。白话文解释:需卦代表俘虏,大吉大利,适宜涉水过河。《象辞》说:上卦为坎,象征云;下卦为乾,象征天。云聚天上,待降雨,君子观此卦,宜宴饮安乐,待时而动。《断易天机》解:需卦坎上乾下,象征踌躇期待,刚强面对险阻,宜等待,涉大川利。北宋易学家邵雍解:遇阻不进,大器晚成,需耐心等待。台湾国学大儒傅佩荣解:时运需耐心等待,财运资本未集,家宅平安,身体调饮食以健康。传统解卦:异卦(下乾上坎),刚逢险,宜稳健,观时待变,必成功。
7 | 天水讼卦,讼卦原文:讼。有孚,窒惕,中吉,终凶。利见大人,不利涉大川。象曰:天与水违行,讼。君子以做事谋始。白话文解释:讼卦象征虽有利可图但需警惕。事情初吉后凶,利于见贵人,不宜涉水。《象辞》说:上卦为乾(天),下卦为坎(水),天水相隔,事理不合,君子需慎重谋事。《断易天机》解:讼卦乾上坎下,刚遇险,必有争论,多不吉。北宋易学家邵雍解:天高水深,远离不亲,慎谋退守则无凶。得此卦者,身心不安,多争诉,宜修身养性。台湾国学大儒傅佩荣解:时运受阻,财运初谨慎终获利,家宅君子求淑女,身体预防胜于治疗。传统解卦:异卦(下坎上乾),刚健遇险,彼此反对,生争讼,需慎重戒惧。
8 | 地水师卦,师卦原文:师。贞,丈人吉,无咎。象曰:地中有水,师。君子以容民畜众。白话文解释:师卦象征军队指挥,无灾祸。《象辞》说:下卦为坎(水),上卦为坤(地),如大地容纳江河,君子应容纳众人。《断易天机》解:师卦坤上坎下,象征军众,需德高长者统率以吉无咎。北宋易学家邵雍解:忧劳动众,公正无私排难。得卦者应包容他人,努力排除困难。台湾国学大儒傅佩荣解:时运包容他人,财运有财需珍惜,家宅旧亲联姻吉,身体腹胀调气。传统解卦:异卦(下坎上坤),“师”指军队。坎为水险,坤为地顺,寓兵于农,用兵应顺势,故化凶为吉。
9 | 水地比卦,比卦原文:比。吉。原筮,元永贞,无咎。不宁方来,后夫凶。象曰:地上有水,比。先王以建万国,亲诸侯。白话文解释:比卦吉利,再卜筮仍大吉,长期吉凶无灾。不愿臣服的邦国来朝,迟来有难。《象辞》说:下卦为坤(地),上卦为坎(水),如水附大地,先王据此建万国,亲近诸侯。《断易天机》解:比卦坎上坤下,意为相依附,长期如此则无咎,故吉祥。北宋易学家邵雍解:水行地上,亲比欢乐,人情亲顺,百事无忧。得此卦者获朋友助,众力成事,荣显。台湾国学大儒傅佩荣解:时运众贺荣显,财运相扶大利,家宅好合,身体心腹肿宜治。传统解卦:异卦(下坤上坎),地水相依,亲密无间,阐述相亲相辅,团结道理。
--------------------------------------------------------------------------------
/deepspeed/README.md:
--------------------------------------------------------------------------------
1 | # DeepSpeed 框架安装指南
2 |
3 | ## 更新 GCC 和 G++ 版本(如需)
4 |
5 | 首先,添加必要的 PPA 仓库,然后更新 `gcc` 和 `g++`:
6 |
7 | ```bash
8 | sudo add-apt-repository ppa:ubuntu-toolchain-r/test
9 | sudo apt update
10 | sudo apt install gcc-7 g++-7
11 | ```
12 |
13 | 更新系统的默认 `gcc` 和 `g++` 指向:
14 |
15 | ```bash
16 | sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-7 60 --slave /usr/bin/g++ g++ /usr/bin/g++-7
17 | sudo update-alternatives --config gcc
18 | ```
19 |
20 | ## 创建隔离的 Anaconda 环境
21 |
22 | 如果想要隔离环境,建议采用 clone 方式,新建一个 DeepSpeed 专用的 Anaconda 环境:
23 |
24 | ```bash
25 | conda create -n deepspeed --clone base
26 | ```
27 |
28 | ## 安装 Transformers 和 DeepSpeed
29 |
30 | ### 源代码安装 Transformers
31 |
32 | 遵循[官方文档](https://huggingface.co/docs/transformers/installation#install-from-source),通过下面的命令安装 Transformers:
33 |
34 | ```bash
35 | pip install git+https://github.com/huggingface/transformers
36 | ```
37 |
38 | ### 源代码安装 DeepSpeed
39 |
40 | 根据你的 GPU 实际情况设置参数 `TORCH_CUDA_ARCH_LIST`。如果你需要使用 CPU Offload 优化器参数,设置参数 `DS_BUILD_CPU_ADAM=1`;如果你需要使用 NVMe Offload,设置参数 `DS_BUILD_UTILS=1`:
41 |
42 | ```bash
43 | git clone https://github.com/microsoft/DeepSpeed/
44 | cd DeepSpeed
45 | rm -rf build
46 | TORCH_CUDA_ARCH_LIST="7.5" DS_BUILD_CPU_ADAM=1 DS_BUILD_UTILS=1 pip install . \
47 | --global-option="build_ext" --global-option="-j8" --no-cache -v \
48 | --disable-pip-version-check 2>&1 | tee build.log
49 | ```
50 |
51 | **注意:不要在项目内 clone DeepSpeed 源代码安装,容易造成误提交。**
52 |
53 | ### 使用 DeepSpeed 训练 T5 系列模型
54 |
55 | - 单机单卡训练脚本:[train_on_one_gpu.sh](train_on_one_gpu.sh)
56 | - 分布式训练脚本:[train_on_multi_nodes.sh](train_on_multi_nodes.sh)
--------------------------------------------------------------------------------
/deepspeed/config/ds_config_zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 |
11 | "bf16": {
12 | "enabled": "auto"
13 | },
14 |
15 | "optimizer": {
16 | "type": "AdamW",
17 | "params": {
18 | "lr": "auto",
19 | "betas": "auto",
20 | "eps": "auto",
21 | "weight_decay": "auto"
22 | }
23 | },
24 |
25 | "scheduler": {
26 | "type": "WarmupLR",
27 | "params": {
28 | "warmup_min_lr": "auto",
29 | "warmup_max_lr": "auto",
30 | "warmup_num_steps": "auto"
31 | }
32 | },
33 |
34 | "zero_optimization": {
35 | "stage": 2,
36 | "offload_optimizer": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "allgather_partitions": true,
41 | "allgather_bucket_size": 2e8,
42 | "overlap_comm": true,
43 | "reduce_scatter": true,
44 | "reduce_bucket_size": 2e8,
45 | "contiguous_gradients": true
46 | },
47 |
48 | "gradient_accumulation_steps": "auto",
49 | "gradient_clipping": "auto",
50 | "steps_per_print": 20,
51 | "train_batch_size": "auto",
52 | "train_micro_batch_size_per_gpu": "auto",
53 | "wall_clock_breakdown": false
54 | }
55 |
--------------------------------------------------------------------------------
/deepspeed/config/ds_config_zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 |
11 | "bf16": {
12 | "enabled": "auto"
13 | },
14 |
15 | "optimizer": {
16 | "type": "AdamW",
17 | "params": {
18 | "lr": "auto",
19 | "betas": "auto",
20 | "eps": "auto",
21 | "weight_decay": "auto"
22 | }
23 | },
24 |
25 | "scheduler": {
26 | "type": "WarmupLR",
27 | "params": {
28 | "warmup_min_lr": "auto",
29 | "warmup_max_lr": "auto",
30 | "warmup_num_steps": "auto"
31 | }
32 | },
33 |
34 | "zero_optimization": {
35 | "stage": 3,
36 | "offload_optimizer": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "offload_param": {
41 | "device": "cpu",
42 | "pin_memory": true
43 | },
44 | "overlap_comm": true,
45 | "contiguous_gradients": true,
46 | "sub_group_size": 1e9,
47 | "reduce_bucket_size": "auto",
48 | "stage3_prefetch_bucket_size": "auto",
49 | "stage3_param_persistence_threshold": "auto",
50 | "stage3_max_live_parameters": 1e9,
51 | "stage3_max_reuse_distance": 1e9,
52 | "stage3_gather_16bit_weights_on_model_save": true
53 | },
54 |
55 | "gradient_accumulation_steps": "auto",
56 | "gradient_clipping": "auto",
57 | "steps_per_print": 20,
58 | "train_batch_size": "auto",
59 | "train_micro_batch_size_per_gpu": "auto",
60 | "wall_clock_breakdown": false
61 | }
62 |
--------------------------------------------------------------------------------
/deepspeed/train_on_multi_nodes.sh:
--------------------------------------------------------------------------------
1 | ################# 在编译和源代码安装 DeepSpeed 的机器运行 ######################3
2 | # 更新 GCC 和 G++ 版本(如需)
3 | sudo add-apt-repository ppa:ubuntu-toolchain-r/test
4 | sudo apt update
5 | sudo apt install gcc-7 g++-7
6 | # 更新系统的默认 gcc 和 g++ 指向
7 | sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-7 60 --slave /usr/bin/g++ g++ /usr/bin/g++-7
8 | sudo update-alternatives --config gcc
9 |
10 | # 源代码安装 DeepSpeed
11 | # 根据你的 GPU 实际情况(查看方法见前一页),设置参数 TORCH_CUDA_ARCH_LIST;
12 | # 如果你需要使用 NVMe Offload,设置参数 DS_BUILD_UTILS=1;
13 | # 如果你需要使用 CPU Offload 优化器参数,设置参数 DS_BUILD_CPU_ADAM=1;
14 | git clone https://github.com/microsoft/DeepSpeed/
15 | cd DeepSpeed
16 | rm -rf build
17 | TORCH_CUDA_ARCH_LIST="7.5" DS_BUILD_CPU_ADAM=1 DS_BUILD_UTILS=1
18 | python setup.py build_ext -j8 bdist_wheel
19 | # 运行将生成类似于dist/deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl的文件,
20 | # 在其他节点安装:pip install deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl。
21 |
22 | # 源代码安装 Transformers
23 | # https://huggingface.co/docs/transformers/installation#install-from-source
24 | pip install git+https://github.com/huggingface/transformers
25 |
26 |
27 | ################# launch.slurm 脚本(按照实际情况修改模板值) ######################
28 | #SBATCH --job-name=test-nodes # name
29 | #SBATCH --nodes=2 # nodes
30 | #SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
31 | #SBATCH --cpus-per-task=10 # number of cores per tasks
32 | #SBATCH --gres=gpu:8 # number of gpus
33 | #SBATCH --time 20:00:00 # maximum execution time (HH:MM:SS)
34 | #SBATCH --output=%x-%j.out # output file name
35 |
36 | export GPUS_PER_NODE=8
37 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
38 | export MASTER_PORT=9901
39 |
40 | srun --jobid $SLURM_JOBID bash -c 'python -m torch.distributed.run \
41 | --nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank $SLURM_PROCID \
42 | --master_addr $MASTER_ADDR --master_port $MASTER_PORT \
43 | your_program.py --deepspeed ds_config.json'
--------------------------------------------------------------------------------
/deepspeed/train_on_one_gpu.sh:
--------------------------------------------------------------------------------
1 | # DeepSpeed ZeRO-2 模式单 GPU 训练翻译模型(T5-Small)
2 | deepspeed --num_gpus=1 translation/run_translation.py \
3 | --deepspeed config/ds_config_zero2.json \
4 | --model_name_or_path t5-small --per_device_train_batch_size 1 \
5 | --output_dir output_dir --overwrite_output_dir --fp16 \
6 | --do_train --max_train_samples 500 --num_train_epochs 1 \
7 | --dataset_name wmt16 --dataset_config "ro-en" \
8 | --source_lang en --target_lang ro
9 |
10 | # DeepSpeed ZeRO-2 模式单 GPU 训练翻译模型(T5-Large)
11 | deepspeed --num_gpus=1 translation/run_translation.py \
12 | --deepspeed config/ds_config_zero2.json \
13 | --model_name_or_path t5-large \
14 | --per_device_train_batch_size 4 \
15 | --per_device_eval_batch_size 4 \
16 | --output_dir output_dir --overwrite_output_dir \
17 | --do_train \
18 | --do_eval \
19 | --max_train_samples 500 --num_train_epochs 1 \
20 | --dataset_name wmt16 --dataset_config "ro-en" \
21 | --source_lang en --target_lang ro
22 |
23 |
24 |
25 | # DeepSpeed ZeRO-3 模式单 GPU 训练翻译模型(T5-Large)
26 | deepspeed --num_gpus=1 translation/run_translation.py \
27 | --deepspeed config/ds_config_zero3.json \
28 | --model_name_or_path t5-3b --per_device_train_batch_size 1 \
29 | --output_dir output_dir --overwrite_output_dir --fp16 \
30 | --do_train --max_train_samples 500 --num_train_epochs 1 \
31 | --dataset_name wmt16 --dataset_config "ro-en" \
32 | --source_lang en --target_lang ro
33 |
34 |
35 |
36 | # 直接使用 Python 命令启动 ZeRO-2 模式单 GPU 训练翻译模型(T5-Small)
37 | python translation/run_translation.py \
38 | --model_name_or_path t5-small \
39 | --do_train \
40 | --do_eval \
41 | --source_lang en \
42 | --target_lang ro \
43 | --source_prefix "translate English to Romanian: " \
44 | --dataset_name wmt16 \
45 | --dataset_config_name ro-en \
46 | --output_dir tmp/tst-translation \
47 | --per_device_train_batch_size=4 \
48 | --per_device_eval_batch_size=4 \
49 | --overwrite_output_dir \
50 | --predict_with_generate
--------------------------------------------------------------------------------
/deepspeed/translation/README.md:
--------------------------------------------------------------------------------
1 |
16 |
17 | ## Translation
18 |
19 | This directory contains examples for finetuning and evaluating transformers on translation tasks.
20 | Please tag @patil-suraj with any issues/unexpected behaviors, or send a PR!
21 | For deprecated `bertabs` instructions, see [`bertabs/README.md`](https://github.com/huggingface/transformers/blob/main/examples/research_projects/bertabs/README.md).
22 | For the old `finetune_trainer.py` and related utils, see [`examples/legacy/seq2seq`](https://github.com/huggingface/transformers/blob/main/examples/legacy/seq2seq).
23 |
24 | ### Supported Architectures
25 |
26 | - `BartForConditionalGeneration`
27 | - `FSMTForConditionalGeneration` (translation only)
28 | - `MBartForConditionalGeneration`
29 | - `MarianMTModel`
30 | - `PegasusForConditionalGeneration`
31 | - `T5ForConditionalGeneration`
32 | - `MT5ForConditionalGeneration`
33 |
34 | `run_translation.py` is a lightweight examples of how to download and preprocess a dataset from the [🤗 Datasets](https://github.com/huggingface/datasets) library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
35 |
36 | For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets#json-files
37 | and you also will find examples of these below.
38 |
39 |
40 | ## With Trainer
41 |
42 | Here is an example of a translation fine-tuning with a MarianMT model:
43 |
44 | ```bash
45 | python examples/pytorch/translation/run_translation.py \
46 | --model_name_or_path Helsinki-NLP/opus-mt-en-ro \
47 | --do_train \
48 | --do_eval \
49 | --source_lang en \
50 | --target_lang ro \
51 | --dataset_name wmt16 \
52 | --dataset_config_name ro-en \
53 | --output_dir /tmp/tst-translation \
54 | --per_device_train_batch_size=4 \
55 | --per_device_eval_batch_size=4 \
56 | --overwrite_output_dir \
57 | --predict_with_generate
58 | ```
59 |
60 | MBart and some T5 models require special handling.
61 |
62 | T5 models `t5-small`, `t5-base`, `t5-large`, `t5-3b` and `t5-11b` must use an additional argument: `--source_prefix "translate {source_lang} to {target_lang}"`. For example:
63 |
64 | ```bash
65 | python examples/pytorch/translation/run_translation.py \
66 | --model_name_or_path t5-small \
67 | --do_train \
68 | --do_eval \
69 | --source_lang en \
70 | --target_lang ro \
71 | --source_prefix "translate English to Romanian: " \
72 | --dataset_name wmt16 \
73 | --dataset_config_name ro-en \
74 | --output_dir /tmp/tst-translation \
75 | --per_device_train_batch_size=4 \
76 | --per_device_eval_batch_size=4 \
77 | --overwrite_output_dir \
78 | --predict_with_generate
79 | ```
80 |
81 | If you get a terrible BLEU score, make sure that you didn't forget to use the `--source_prefix` argument.
82 |
83 | For the aforementioned group of T5 models it's important to remember that if you switch to a different language pair, make sure to adjust the source and target values in all 3 language-specific command line argument: `--source_lang`, `--target_lang` and `--source_prefix`.
84 |
85 | MBart models require a different format for `--source_lang` and `--target_lang` values, e.g. instead of `en` it expects `en_XX`, for `ro` it expects `ro_RO`. The full MBart specification for language codes can be found [here](https://huggingface.co/facebook/mbart-large-cc25). For example:
86 |
87 | ```bash
88 | python examples/pytorch/translation/run_translation.py \
89 | --model_name_or_path facebook/mbart-large-en-ro \
90 | --do_train \
91 | --do_eval \
92 | --dataset_name wmt16 \
93 | --dataset_config_name ro-en \
94 | --source_lang en_XX \
95 | --target_lang ro_RO \
96 | --output_dir /tmp/tst-translation \
97 | --per_device_train_batch_size=4 \
98 | --per_device_eval_batch_size=4 \
99 | --overwrite_output_dir \
100 | --predict_with_generate
101 | ```
102 |
103 | And here is how you would use the translation finetuning on your own files, after adjusting the
104 | values for the arguments `--train_file`, `--validation_file` to match your setup:
105 |
106 | ```bash
107 | python examples/pytorch/translation/run_translation.py \
108 | --model_name_or_path t5-small \
109 | --do_train \
110 | --do_eval \
111 | --source_lang en \
112 | --target_lang ro \
113 | --source_prefix "translate English to Romanian: " \
114 | --dataset_name wmt16 \
115 | --dataset_config_name ro-en \
116 | --train_file path_to_jsonlines_file \
117 | --validation_file path_to_jsonlines_file \
118 | --output_dir /tmp/tst-translation \
119 | --per_device_train_batch_size=4 \
120 | --per_device_eval_batch_size=4 \
121 | --overwrite_output_dir \
122 | --predict_with_generate
123 | ```
124 |
125 | The task of translation supports only custom JSONLINES files, with each line being a dictionary with a key `"translation"` and its value another dictionary whose keys is the language pair. For example:
126 |
127 | ```json
128 | { "translation": { "en": "Others have dismissed him as a joke.", "ro": "Alții l-au numit o glumă." } }
129 | { "translation": { "en": "And some are holding out for an implosion.", "ro": "Iar alții așteaptă implozia." } }
130 | ```
131 | Here the languages are Romanian (`ro`) and English (`en`).
132 |
133 | If you want to use a pre-processed dataset that leads to high BLEU scores, but for the `en-de` language pair, you can use `--dataset_name stas/wmt14-en-de-pre-processed`, as following:
134 |
135 | ```bash
136 | python examples/pytorch/translation/run_translation.py \
137 | --model_name_or_path t5-small \
138 | --do_train \
139 | --do_eval \
140 | --source_lang en \
141 | --target_lang de \
142 | --source_prefix "translate English to German: " \
143 | --dataset_name stas/wmt14-en-de-pre-processed \
144 | --output_dir /tmp/tst-translation \
145 | --per_device_train_batch_size=4 \
146 | --per_device_eval_batch_size=4 \
147 | --overwrite_output_dir \
148 | --predict_with_generate
149 | ```
150 |
151 | ## With Accelerate
152 |
153 | Based on the script [`run_translation_no_trainer.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/translation/run_translation_no_trainer.py).
154 |
155 | Like `run_translation.py`, this script allows you to fine-tune any of the models supported on a
156 | translation task, the main difference is that this
157 | script exposes the bare training loop, to allow you to quickly experiment and add any customization you would like.
158 |
159 | It offers less options than the script with `Trainer` (for instance you can easily change the options for the optimizer
160 | or the dataloaders directly in the script) but still run in a distributed setup, on TPU and supports mixed precision by
161 | the mean of the [🤗 `Accelerate`](https://github.com/huggingface/accelerate) library. You can use the script normally
162 | after installing it:
163 |
164 | ```bash
165 | pip install git+https://github.com/huggingface/accelerate
166 | ```
167 |
168 | then
169 |
170 | ```bash
171 | python run_translation_no_trainer.py \
172 | --model_name_or_path Helsinki-NLP/opus-mt-en-ro \
173 | --source_lang en \
174 | --target_lang ro \
175 | --dataset_name wmt16 \
176 | --dataset_config_name ro-en \
177 | --output_dir ~/tmp/tst-translation
178 | ```
179 |
180 | You can then use your usual launchers to run in it in a distributed environment, but the easiest way is to run
181 |
182 | ```bash
183 | accelerate config
184 | ```
185 |
186 | and reply to the questions asked. Then
187 |
188 | ```bash
189 | accelerate test
190 | ```
191 |
192 | that will check everything is ready for training. Finally, you can launch training with
193 |
194 | ```bash
195 | accelerate launch run_translation_no_trainer.py \
196 | --model_name_or_path Helsinki-NLP/opus-mt-en-ro \
197 | --source_lang en \
198 | --target_lang ro \
199 | --dataset_name wmt16 \
200 | --dataset_config_name ro-en \
201 | --output_dir ~/tmp/tst-translation
202 | ```
203 |
204 | This command is the same and will work for:
205 |
206 | - a CPU-only setup
207 | - a setup with one GPU
208 | - a distributed training with several GPUs (single or multi node)
209 | - a training on TPUs
210 |
211 | Note that this library is in alpha release so your feedback is more than welcome if you encounter any problem using it.
212 |
--------------------------------------------------------------------------------
/deepspeed/translation/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate >= 0.12.0
2 | datasets >= 1.8.0
3 | sentencepiece != 0.1.92
4 | protobuf
5 | sacrebleu >= 1.4.12
6 | py7zr
7 | torch >= 1.3
8 | evaluate
--------------------------------------------------------------------------------
/docs/INSTALL.md:
--------------------------------------------------------------------------------
1 | # Transformers开发环境搭建
2 | ## 介绍
3 | 开发环境搭建包含几个部分
4 | - Miniconda
5 | - Jupyter Lab
6 | - Hugging Face Transformers,需要尝试多种模型时候,建议tensorflow和pytorch都安装
7 | - 其他依赖包
8 |
9 | ## Miniconda
10 | Miniconda 是一个 Python 环境管理工具,可以用来创建、管理多个 Python 环境。它是 Anaconda 的轻量级替代品,不包含任何 IDE 工具。 Miniconda可以从[官网](https://docs.conda.io/en/latest/miniconda.html)下载安装包。也可以从镜像网站下载:
11 |
12 | ### Miniconda环境的安装
13 | ```bash
14 | # 下载 Miniconda 安装包
15 | $ wget https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-latest-Linux-x86_64.sh
16 | # 也可以使用curl命令下载
17 | $ curl -O https://mirrors.tuna.tsinghua.edu.cn/anaconda/miniconda/Miniconda3-latest-Linux-x86_64.sh
18 | # 安装 Miniconda
19 | $ bash Miniconda3-latest-Linux-x86_64.sh
20 | ```
21 |
22 | 安装过程中,需要回答一些问题,如安装路径、是否将 Miniconda 添加到环境变量等。安装完成后,需要重启终端,使环境变量生效。
23 |
24 | 可以使用以下命令来验证 Miniconda 是否安装成功:
25 |
26 | ```bash
27 | $ conda --version
28 | ```
29 |
30 | ### 配置Miniconda
31 | Miniconda的配置文件存放在~/.condarc,可以参考文档手工修改,也可以使用conda config命令来修改。
32 |
33 | 1. 为了加速包下载,可以配置使用国内的镜像源:
34 | ```bash
35 | # 配置清华镜像
36 | $ conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
37 | $ conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
38 | $ conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
39 | $ conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
40 | $ conda config --set show_channel_urls yes
41 | # 查看~/.condarc配置
42 | $ conda config --show-sources
43 | channels:
44 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
45 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
46 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
47 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
48 | - defaults
49 | show_channel_urls: True
50 | ```
51 | 2. 加速anaconda包的下载
52 | 可以使用mamba或micromamba来代替conda,这两个工具都是conda的替代品,会缓存包的版本信息,不需要在每次安装包的时候都去检查,这种可以有效提高conda-forge这种比较大的。安装mamba或micromamba的方法如下:
53 | ```bash
54 | # 安装mamba
55 | $ conda install -n base -c conda-forge mamba
56 | # 安装micromamba
57 | $ conda install -n base -c conda-forge micromamba
58 | ```
59 | 之后可以使用mamba或者micromamba命令代替conda命令。
60 |
61 | ### 创建虚拟环境
62 | ```bash
63 | # 创建虚拟环境,指定 Python 版本为 3.11
64 | (base) $ conda create -n transformers python=3.11
65 | # 激活 openai 环境
66 | $ conda activate transformers
67 | ```
68 | 以下若无特殊说明,均在这里新建的openai环境中进行。
69 |
70 | ## Jupyter Lab
71 | Jupyter Lab 是一个交互式的开发环境,可以在浏览器中运行。它支持多种编程语言,包括 Python、R、Julia 等。 Jupyter Lab由conda-forge提供,请先配置镜像,然后使用以下命令安装:
72 | ```bash
73 | (transformers) $ conda install jupyterlab
74 | ```
75 |
76 | ## Hugging Face Transformers
77 | Hugging Face Transformers 是一个基于 PyTorch 和 TensorFlow 的自然语言处理工具包,提供了大量预训练模型,可以用来完成多种 NLP 任务。Hugging Face Transformers 可以通过 conda 安装:
78 |
79 | ```bash
80 | (transformers) $ conda install -c huggingface transformers
81 | ```
82 |
83 | 安装文档:[Hugging Face Transformers](https://huggingface.co/docs/transformers/installation#install-with-conda)
84 |
85 |
86 | ## 安装pytorch
87 |
88 | Transformers需要使用pytorch进行实际的模型推理,在前面已经配置了使用的pytorch和conda-forge镜像源,可以使用下命令安装和CUDA版本对应的Pytorch版本:
89 | ```bash
90 | # Linux
91 | # CUDA 11.8
92 | (transformers) $ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c nvidia
93 | # CUDA 12.1
94 | (transformers) $ conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c nvidia
95 |
96 | # Mac
97 | (transformers) $ conda install pytorch::pytorch torchvision torchaudio
98 | ```
99 |
100 | 安装文档:[pytorch](https://pytorch.org/get-started/locally/)
101 |
102 | ## 安装其他的依赖包
103 | 在处理图像、音频等数据时,需要使用到其他的依赖包,包括:
104 | - tqdm、iprogress 进度条
105 | - ffmpeg、ffmpeg-python 音频处理工具
106 | - pillow 图像处理工具
107 |
108 | ```bash
109 | (transformers) $ conda install tqdm iprogress ffmpeg ffmpeg-python pillow
110 | ```
111 |
--------------------------------------------------------------------------------
/docs/cuda_installation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/docs/cuda_installation.png
--------------------------------------------------------------------------------
/docs/version_check.py:
--------------------------------------------------------------------------------
1 | import pkg_resources
2 | import subprocess
3 |
4 | # 首先,确保安装了 requirements.txt 中的所有包
5 | subprocess.check_call(["pip", "install", "-r", "../requirements.txt"])
6 |
7 | # 读取 requirements.txt 文件,获取软件包名称列表
8 | with open("../requirements.txt", "r") as f:
9 | packages = f.readlines()
10 | packages = [pkg.strip() for pkg in packages]
11 |
12 | # 获取每个软件包的版本信息
13 | with open("version_info.txt", "w") as output_file:
14 | for pkg in packages:
15 | if pkg == "" or pkg.startswith("#"): # 跳过空行和注释
16 | continue
17 | try:
18 | # 尝试获取软件包版本
19 | version = pkg_resources.get_distribution(pkg).version
20 | output_file.write(f"{pkg}=={version}\n")
21 | except pkg_resources.DistributionNotFound:
22 | # 如果软件包未安装,则记录一个错误消息
23 | output_file.write(f"{pkg}: Not Found\n")
24 |
25 | print("版本信息已写入 version_info.txt 文件。")
26 |
--------------------------------------------------------------------------------
/docs/version_info.txt:
--------------------------------------------------------------------------------
1 | torch>=2.1.2==2.3.0.dev20240116+cu121
2 | transformers==4.37.2
3 | ffmpeg==1.4
4 | ffmpeg-python==0.2.0
5 | timm==0.9.12
6 | datasets==2.16.1
7 | evaluate==0.4.1
8 | scikit-learn==1.3.2
9 | pandas==2.1.1
10 | peft==0.7.2.dev0
11 | accelerate==0.26.1
12 | autoawq==0.2.2
13 | optimum==1.17.0.dev0
14 | auto-gptq==0.6.0
15 | bitsandbytes>0.39.0==0.41.3.post2
16 | jiwer==3.0.3
17 | soundfile>=0.12.1==0.12.1
18 | librosa==0.10.1
19 | langchain==0.1.0
20 | gradio==4.13.0
21 |
--------------------------------------------------------------------------------
/langchain/chains/router_chain.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "c3e2458f-d038-4845-93a0-d4ad830f9f90",
6 | "metadata": {},
7 | "source": [
8 | "# LangChain 核心模块学习:Chains\n",
9 | "\n",
10 | "对于简单的大模型应用,单独使用语言模型(LLMs)是可以的。\n",
11 | "\n",
12 | "**但更复杂的大模型应用需要将 `LLMs` 和 `Chat Models` 链接在一起 - 要么彼此链接,要么与其他组件链接。**\n",
13 | "\n",
14 | "LangChain 为这种“链式”应用程序提供了 `Chain` 接口。\n",
15 | "\n",
16 | "LangChain 以通用方式定义了 `Chain`,它是对组件进行调用序列的集合,其中可以包含其他链。"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "c81a7df0-26c7-4eb8-92f1-cc54445cf507",
22 | "metadata": {},
23 | "source": [
24 | "## LLMChain\n",
25 | "\n",
26 | "LLMChain 是 LangChain 中最简单的链,作为其他复杂 Chains 和 Agents 的内部调用,被广泛应用。\n",
27 | "\n",
28 | "一个LLMChain由PromptTemplate和语言模型(LLM or Chat Model)组成。它使用直接传入(或 memory 提供)的 key-value 来规范化生成 Prompt Template(提示模板),并将生成的 prompt (格式化后的字符串)传递给大模型,并返回大模型输出。\n",
29 | "\n",
30 | ""
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "id": "4fbd5ca7-ca54-4701-919c-2857266caefc",
36 | "metadata": {},
37 | "source": [
38 | "## Router Chain: 实现条件判断的大模型调用\n",
39 | "\n",
40 | "\n",
41 | "这段代码构建了一个可定制的链路系统,用户可以提供不同的输入提示,并根据这些提示获取适当的响应。\n",
42 | "\n",
43 | "主要逻辑:从`prompt_infos`创建多个`LLMChain`对象,并将它们保存在一个字典中,然后创建一个默认的`ConversationChain`,最后创建一个带有路由功能的`MultiPromptChain`。\n",
44 | "\n",
45 | ""
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": 1,
51 | "id": "aaf8c391-9225-4e66-ad4d-d689b53a0379",
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "from langchain.chains.router import MultiPromptChain\n",
56 | "from langchain.llms import OpenAI\n",
57 | "from langchain.chains import ConversationChain\n",
58 | "from langchain.chains.llm import LLMChain\n",
59 | "from langchain.prompts import PromptTemplate"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 2,
65 | "id": "33b5061c-391e-4762-91c7-73b57f4ab501",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "physics_template = \"\"\"你是一位非常聪明的物理教授。\n",
70 | "你擅长以简洁易懂的方式回答关于物理的问题。\n",
71 | "当你不知道某个问题的答案时,你会坦诚承认。\n",
72 | "\n",
73 | "这是一个问题:\n",
74 | "{input}\"\"\"\n",
75 | "\n",
76 | "\n",
77 | "math_template = \"\"\"你是一位很棒的数学家。你擅长回答数学问题。\n",
78 | "之所以如此出色,是因为你能够将难题分解成各个组成部分,\n",
79 | "先回答这些组成部分,然后再将它们整合起来回答更广泛的问题。\n",
80 | "\n",
81 | "这是一个问题:\n",
82 | "{input}\"\"\""
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 3,
88 | "id": "5ef1db6e-3da4-4f9b-9707-0f30aa293dd7",
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "prompt_infos = [\n",
93 | " {\n",
94 | " \"name\": \"物理\",\n",
95 | " \"description\": \"适用于回答物理问题\",\n",
96 | " \"prompt_template\": physics_template,\n",
97 | " },\n",
98 | " {\n",
99 | " \"name\": \"数学\",\n",
100 | " \"description\": \"适用于回答数学问题\",\n",
101 | " \"prompt_template\": math_template,\n",
102 | " },\n",
103 | "]"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 4,
109 | "id": "3983cafe-c2d5-4951-b779-88d844594777",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "llm = OpenAI()"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 5,
119 | "id": "db8be9f0-1ac2-4ded-8950-6403cfa40004",
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "# 创建一个空的目标链字典,用于存放根据prompt_infos生成的LLMChain。\n",
124 | "destination_chains = {}\n",
125 | "\n",
126 | "# 遍历prompt_infos列表,为每个信息创建一个LLMChain。\n",
127 | "for p_info in prompt_infos:\n",
128 | " name = p_info[\"name\"] # 提取名称\n",
129 | " prompt_template = p_info[\"prompt_template\"] # 提取模板\n",
130 | " # 创建PromptTemplate对象\n",
131 | " prompt = PromptTemplate(template=prompt_template, input_variables=[\"input\"])\n",
132 | " # 使用上述模板和llm对象创建LLMChain对象\n",
133 | " chain = LLMChain(llm=llm, prompt=prompt)\n",
134 | " # 将新创建的chain对象添加到destination_chains字典中\n",
135 | " destination_chains[name] = chain\n",
136 | "\n",
137 | "# 创建一个默认的ConversationChain\n",
138 | "default_chain = ConversationChain(llm=llm, output_key=\"text\")"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 6,
144 | "id": "ae77b13a-2077-4e80-83f9-a2b1d8398461",
145 | "metadata": {},
146 | "outputs": [
147 | {
148 | "data": {
149 | "text/plain": [
150 | "langchain.chains.conversation.base.ConversationChain"
151 | ]
152 | },
153 | "execution_count": 6,
154 | "metadata": {},
155 | "output_type": "execute_result"
156 | }
157 | ],
158 | "source": [
159 | "type(default_chain)"
160 | ]
161 | },
162 | {
163 | "cell_type": "markdown",
164 | "id": "f5aa4a82-2d96-4124-8896-4e11e5d5c8e9",
165 | "metadata": {},
166 | "source": [
167 | "### 使用 LLMRouterChain 实现条件判断调用\n",
168 | "\n",
169 | "这段代码定义了一个chain对象(LLMRouterChain),该对象首先使用router_chain来决定哪个destination_chain应该被执行,如果没有合适的目标链,则默认使用default_chain。"
170 | ]
171 | },
172 | {
173 | "cell_type": "code",
174 | "execution_count": 7,
175 | "id": "1c196e6c-e767-4d4f-8327-50ead641bc3a",
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "from langchain.chains.router.llm_router import LLMRouterChain, RouterOutputParser\n",
180 | "from langchain.chains.router.multi_prompt_prompt import MULTI_PROMPT_ROUTER_TEMPLATE"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": 8,
186 | "id": "f5ada86e-e430-412c-828d-b053b630f07c",
187 | "metadata": {},
188 | "outputs": [],
189 | "source": [
190 | "# 从prompt_infos中提取目标信息并将其转化为字符串列表\n",
191 | "destinations = [f\"{p['name']}: {p['description']}\" for p in prompt_infos]\n",
192 | "# 使用join方法将列表转化为字符串,每个元素之间用换行符分隔\n",
193 | "destinations_str = \"\\n\".join(destinations)\n",
194 | "# 根据MULTI_PROMPT_ROUTER_TEMPLATE格式化字符串和destinations_str创建路由模板\n",
195 | "router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(destinations=destinations_str)\n",
196 | "# 创建路由的PromptTemplate\n",
197 | "router_prompt = PromptTemplate(\n",
198 | " template=router_template,\n",
199 | " input_variables=[\"input\"],\n",
200 | " output_parser=RouterOutputParser(),\n",
201 | ")\n",
202 | "# 使用上述路由模板和llm对象创建LLMRouterChain对象\n",
203 | "router_chain = LLMRouterChain.from_llm(llm, router_prompt)"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": 9,
209 | "id": "8c1013dc-ae1f-468d-96b3-4babe0d50d1f",
210 | "metadata": {},
211 | "outputs": [
212 | {
213 | "name": "stdout",
214 | "output_type": "stream",
215 | "text": [
216 | "['物理: 适用于回答物理问题', '数学: 适用于回答数学问题']\n"
217 | ]
218 | }
219 | ],
220 | "source": [
221 | "print(destinations)"
222 | ]
223 | },
224 | {
225 | "cell_type": "code",
226 | "execution_count": 10,
227 | "id": "a85ef126-aca1-40c2-8e01-d15af5500785",
228 | "metadata": {},
229 | "outputs": [
230 | {
231 | "name": "stdout",
232 | "output_type": "stream",
233 | "text": [
234 | "物理: 适用于回答物理问题\n",
235 | "数学: 适用于回答数学问题\n"
236 | ]
237 | }
238 | ],
239 | "source": [
240 | "print(destinations_str)"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 11,
246 | "id": "5db81fcb-704a-4250-a6b5-210e4be77af5",
247 | "metadata": {},
248 | "outputs": [
249 | {
250 | "name": "stdout",
251 | "output_type": "stream",
252 | "text": [
253 | "Given a raw text input to a language model select the model prompt best suited for the input. You will be given the names of the available prompts and a description of what the prompt is best suited for. You may also revise the original input if you think that revising it will ultimately lead to a better response from the language model.\n",
254 | "\n",
255 | "<< FORMATTING >>\n",
256 | "Return a markdown code snippet with a JSON object formatted to look like:\n",
257 | "```json\n",
258 | "{{{{\n",
259 | " \"destination\": string \\ name of the prompt to use or \"DEFAULT\"\n",
260 | " \"next_inputs\": string \\ a potentially modified version of the original input\n",
261 | "}}}}\n",
262 | "```\n",
263 | "\n",
264 | "REMEMBER: \"destination\" MUST be one of the candidate prompt names specified below OR it can be \"DEFAULT\" if the input is not well suited for any of the candidate prompts.\n",
265 | "REMEMBER: \"next_inputs\" can just be the original input if you don't think any modifications are needed.\n",
266 | "\n",
267 | "<< CANDIDATE PROMPTS >>\n",
268 | "{destinations}\n",
269 | "\n",
270 | "<< INPUT >>\n",
271 | "{{input}}\n",
272 | "\n",
273 | "<< OUTPUT (must include ```json at the start of the response) >>\n",
274 | "<< OUTPUT (must end with ```) >>\n",
275 | "\n"
276 | ]
277 | }
278 | ],
279 | "source": [
280 | "print(MULTI_PROMPT_ROUTER_TEMPLATE)"
281 | ]
282 | },
283 | {
284 | "cell_type": "code",
285 | "execution_count": 12,
286 | "id": "f882244c-1fa6-4d74-a44c-578c9fb25e18",
287 | "metadata": {},
288 | "outputs": [
289 | {
290 | "name": "stdout",
291 | "output_type": "stream",
292 | "text": [
293 | "Given a raw text input to a language model select the model prompt best suited for the input. You will be given the names of the available prompts and a description of what the prompt is best suited for. You may also revise the original input if you think that revising it will ultimately lead to a better response from the language model.\n",
294 | "\n",
295 | "<< FORMATTING >>\n",
296 | "Return a markdown code snippet with a JSON object formatted to look like:\n",
297 | "```json\n",
298 | "{{\n",
299 | " \"destination\": string \\ name of the prompt to use or \"DEFAULT\"\n",
300 | " \"next_inputs\": string \\ a potentially modified version of the original input\n",
301 | "}}\n",
302 | "```\n",
303 | "\n",
304 | "REMEMBER: \"destination\" MUST be one of the candidate prompt names specified below OR it can be \"DEFAULT\" if the input is not well suited for any of the candidate prompts.\n",
305 | "REMEMBER: \"next_inputs\" can just be the original input if you don't think any modifications are needed.\n",
306 | "\n",
307 | "<< CANDIDATE PROMPTS >>\n",
308 | "物理: 适用于回答物理问题\n",
309 | "数学: 适用于回答数学问题\n",
310 | "\n",
311 | "<< INPUT >>\n",
312 | "{input}\n",
313 | "\n",
314 | "<< OUTPUT (must include ```json at the start of the response) >>\n",
315 | "<< OUTPUT (must end with ```) >>\n",
316 | "\n"
317 | ]
318 | }
319 | ],
320 | "source": [
321 | "print(router_template)"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": 13,
327 | "id": "c2a482e4-5757-4295-a3d8-c3fdd1d4abd2",
328 | "metadata": {},
329 | "outputs": [],
330 | "source": [
331 | "# 创建MultiPromptChain对象,其中包含了路由链,目标链和默认链。\n",
332 | "chain = MultiPromptChain(\n",
333 | " router_chain=router_chain,\n",
334 | " destination_chains=destination_chains,\n",
335 | " default_chain=default_chain,\n",
336 | " verbose=True,\n",
337 | ")"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": 14,
343 | "id": "128bb7a0-b176-4b14-835e-8aaa723ab441",
344 | "metadata": {},
345 | "outputs": [
346 | {
347 | "name": "stdout",
348 | "output_type": "stream",
349 | "text": [
350 | "\n",
351 | "\n",
352 | "\u001b[1m> Entering new MultiPromptChain chain...\u001b[0m\n"
353 | ]
354 | },
355 | {
356 | "name": "stderr",
357 | "output_type": "stream",
358 | "text": [
359 | "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/langchain/chains/llm.py:321: UserWarning: The predict_and_parse method is deprecated, instead pass an output parser directly to LLMChain.\n",
360 | " warnings.warn(\n"
361 | ]
362 | },
363 | {
364 | "name": "stdout",
365 | "output_type": "stream",
366 | "text": [
367 | "物理: {'input': '黑体辐射是什么?'}\n",
368 | "\u001b[1m> Finished chain.\u001b[0m\n",
369 | "\n",
370 | "\n",
371 | "黑体辐射是一种发出的热量,由物体因其自身温度而发出,它从物体表面以各种波长的光谱(电磁波)的形式发出。\n"
372 | ]
373 | }
374 | ],
375 | "source": [
376 | "print(chain.run(\"黑体辐射是什么??\"))"
377 | ]
378 | },
379 | {
380 | "cell_type": "code",
381 | "execution_count": 15,
382 | "id": "cd869807-9cec-4bb2-9104-ecc4efce9baa",
383 | "metadata": {},
384 | "outputs": [
385 | {
386 | "name": "stdout",
387 | "output_type": "stream",
388 | "text": [
389 | "\n",
390 | "\n",
391 | "\u001b[1m> Entering new MultiPromptChain chain...\u001b[0m\n"
392 | ]
393 | },
394 | {
395 | "name": "stderr",
396 | "output_type": "stream",
397 | "text": [
398 | "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/langchain/chains/llm.py:321: UserWarning: The predict_and_parse method is deprecated, instead pass an output parser directly to LLMChain.\n",
399 | " warnings.warn(\n"
400 | ]
401 | },
402 | {
403 | "name": "stdout",
404 | "output_type": "stream",
405 | "text": [
406 | "数学: {'input': '大于40的第一个质数,使得加一后能被3整除?'}\n",
407 | "\u001b[1m> Finished chain.\u001b[0m\n",
408 | "\n",
409 | "\n",
410 | "答案:43\n"
411 | ]
412 | }
413 | ],
414 | "source": [
415 | "print(\n",
416 | " chain.run(\n",
417 | " \"大于40的第一个质数是多少,使得这个质数加一能被3整除?\"\n",
418 | " )\n",
419 | ")"
420 | ]
421 | },
422 | {
423 | "cell_type": "code",
424 | "execution_count": 16,
425 | "id": "7ad5dcb2-48c0-4d0f-b6cc-09ebcbdce75e",
426 | "metadata": {},
427 | "outputs": [],
428 | "source": [
429 | "router_chain.verbose = True"
430 | ]
431 | },
432 | {
433 | "cell_type": "code",
434 | "execution_count": 17,
435 | "id": "bd37e004-bb24-4929-992c-34407593d86e",
436 | "metadata": {},
437 | "outputs": [
438 | {
439 | "name": "stdout",
440 | "output_type": "stream",
441 | "text": [
442 | "\n",
443 | "\n",
444 | "\u001b[1m> Entering new MultiPromptChain chain...\u001b[0m\n",
445 | "\n",
446 | "\n",
447 | "\u001b[1m> Entering new LLMRouterChain chain...\u001b[0m\n"
448 | ]
449 | },
450 | {
451 | "name": "stderr",
452 | "output_type": "stream",
453 | "text": [
454 | "/root/miniconda3/envs/langchain/lib/python3.10/site-packages/langchain/chains/llm.py:321: UserWarning: The predict_and_parse method is deprecated, instead pass an output parser directly to LLMChain.\n",
455 | " warnings.warn(\n"
456 | ]
457 | },
458 | {
459 | "name": "stdout",
460 | "output_type": "stream",
461 | "text": [
462 | "\n",
463 | "\u001b[1m> Finished chain.\u001b[0m\n",
464 | "物理: {'input': '黑洞是什么?'}\n",
465 | "\u001b[1m> Finished chain.\u001b[0m\n",
466 | "\n",
467 | "\n",
468 | "黑洞是一种超强引力场,它的引力比其他物质引力强得多,以至于即使光也无法逃离。黑洞由一个质量极大的中心点构成,称为“超文本”,以及其他物质的空间区域,这些空间区域的引力足以阻止任何物质和信息逃离。\n"
469 | ]
470 | }
471 | ],
472 | "source": [
473 | "print(chain.run(\"黑洞是什么?\"))"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": null,
479 | "id": "a51119ed-025f-48d7-ad81-cd9cdab7090f",
480 | "metadata": {},
481 | "outputs": [],
482 | "source": []
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": null,
487 | "id": "cbda2930-a0e6-48b2-8e02-4c3d792f0225",
488 | "metadata": {},
489 | "outputs": [],
490 | "source": []
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": null,
495 | "id": "24d11e0f-d5ee-4086-9e1a-b21000232134",
496 | "metadata": {},
497 | "outputs": [],
498 | "source": []
499 | },
500 | {
501 | "cell_type": "markdown",
502 | "id": "8b6836f0-213d-4cac-abc9-3617831be3db",
503 | "metadata": {},
504 | "source": [
505 | "### Homework\n",
506 | "\n",
507 | "#### 扩展 Demo:实现生物、计算机和汉语言文学老师 PromptTemplates 及对应 Chains"
508 | ]
509 | },
510 | {
511 | "cell_type": "code",
512 | "execution_count": null,
513 | "id": "1c7edb0a-675d-40c0-9f5d-d58f0170ce72",
514 | "metadata": {},
515 | "outputs": [],
516 | "source": []
517 | }
518 | ],
519 | "metadata": {
520 | "kernelspec": {
521 | "display_name": "Python 3 (ipykernel)",
522 | "language": "python",
523 | "name": "python3"
524 | },
525 | "language_info": {
526 | "codemirror_mode": {
527 | "name": "ipython",
528 | "version": 3
529 | },
530 | "file_extension": ".py",
531 | "mimetype": "text/x-python",
532 | "name": "python",
533 | "nbconvert_exporter": "python",
534 | "pygments_lexer": "ipython3",
535 | "version": "3.10.11"
536 | }
537 | },
538 | "nbformat": 4,
539 | "nbformat_minor": 5
540 | }
541 |
--------------------------------------------------------------------------------
/langchain/chains/sequential_chain.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "c3e2458f-d038-4845-93a0-d4ad830f9f90",
6 | "metadata": {},
7 | "source": [
8 | "# LangChain 核心模块学习:Chains\n",
9 | "\n",
10 | "对于简单的大模型应用,单独使用语言模型(LLMs)是可以的。\n",
11 | "\n",
12 | "**但更复杂的大模型应用需要将 `LLMs` 和 `Chat Models` 链接在一起 - 要么彼此链接,要么与其他组件链接。**\n",
13 | "\n",
14 | "LangChain 为这种“链式”应用程序提供了 `Chain` 接口。\n",
15 | "\n",
16 | "LangChain 以通用方式定义了 `Chain`,它是对组件进行调用序列的集合,其中可以包含其他链。"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "f9cf0d43-107b-47ae-9e2c-2edaec38c800",
22 | "metadata": {},
23 | "source": [
24 | "## Chain Class 基类\n",
25 | "\n",
26 | "类继承关系:\n",
27 | "\n",
28 | "```\n",
29 | "Chain --> Chain # Examples: LLMChain, MapReduceChain, RouterChain\n",
30 | "```\n",
31 | "\n",
32 | "**代码实现:https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/chains/base.py**\n",
33 | "\n",
34 | "```python\n",
35 | "# 定义一个名为Chain的基础类\n",
36 | "class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):\n",
37 | " \"\"\"为创建结构化的组件调用序列的抽象基类。\n",
38 | " \n",
39 | " 链应该用来编码对组件的一系列调用,如模型、文档检索器、其他链等,并为此序列提供一个简单的接口。\n",
40 | " \n",
41 | " Chain接口使创建应用程序变得容易,这些应用程序是:\n",
42 | " - 有状态的:给任何Chain添加Memory可以使它具有状态,\n",
43 | " - 可观察的:向Chain传递Callbacks来执行额外的功能,如记录,这在主要的组件调用序列之外,\n",
44 | " - 可组合的:Chain API足够灵活,可以轻松地将Chains与其他组件结合起来,包括其他Chains。\n",
45 | " \n",
46 | " 链公开的主要方法是:\n",
47 | " - `__call__`:链是可以调用的。`__call__`方法是执行Chain的主要方式。它将输入作为一个字典接收,并返回一个字典输出。\n",
48 | " - `run`:一个方便的方法,它以args/kwargs的形式接收输入,并将输出作为字符串或对象返回。这种方法只能用于一部分链,不能像`__call__`那样返回丰富的输出。\n",
49 | " \"\"\"\n",
50 | "\n",
51 | " # 调用链\n",
52 | " def invoke(\n",
53 | " self, input: Dict[str, Any], config: Optional[runnableConfig] = None\n",
54 | " ) -> Dict[str, Any]:\n",
55 | " \"\"\"传统调用方法。\"\"\"\n",
56 | " return self(input, **(config or {}))\n",
57 | "\n",
58 | " # 链的记忆,保存状态和变量\n",
59 | " memory: Optional[BaseMemory] = None\n",
60 | " \"\"\"可选的内存对象,默认为None。\n",
61 | " 内存是一个在每个链的开始和结束时被调用的类。在开始时,内存加载变量并在链中传递它们。在结束时,它保存任何返回的变量。\n",
62 | " 有许多不同类型的内存,请查看内存文档以获取完整的目录。\"\"\"\n",
63 | "\n",
64 | " # 回调,可能用于链的某些操作或事件。\n",
65 | " callbacks: Callbacks = Field(default=None, exclude=True)\n",
66 | " \"\"\"可选的回调处理程序列表(或回调管理器)。默认为None。\n",
67 | " 在对链的调用的生命周期中,从on_chain_start开始,到on_chain_end或on_chain_error结束,都会调用回调处理程序。\n",
68 | " 每个自定义链可以选择调用额外的回调方法,详细信息请参见Callback文档。\"\"\"\n",
69 | "\n",
70 | " # 是否详细输出模式\n",
71 | " verbose: bool = Field(default_factory=_get_verbosity)\n",
72 | " \"\"\"是否以详细模式运行。在详细模式下,一些中间日志将打印到控制台。默认值为`langchain.verbose`。\"\"\"\n",
73 | "\n",
74 | " # 与链关联的标签\n",
75 | " tags: Optional[List[str]] = None\n",
76 | " \"\"\"与链关联的可选标签列表,默认为None。\n",
77 | " 这些标签将与对这个链的每次调用关联起来,并作为参数传递给在`callbacks`中定义的处理程序。\n",
78 | " 你可以使用这些来例如识别链的特定实例与其用例。\"\"\"\n",
79 | "\n",
80 | " # 与链关联的元数据\n",
81 | " metadata: Optional[Dict[str, Any]] = None\n",
82 | " \"\"\"与链关联的可选元数据,默认为None。\n",
83 | " 这些元数据将与对这个链的每次调用关联起来,并作为参数传递给在`callbacks`中定义的处理程序。\n",
84 | " 你可以使用这些来例如识别链的特定实例与其用例。\"\"\"\n",
85 | "```"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "id": "5d51fbb4-1d8e-4ec1-8c55-ec70247d4d64",
92 | "metadata": {},
93 | "outputs": [],
94 | "source": []
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "id": "c81a7df0-26c7-4eb8-92f1-cc54445cf507",
99 | "metadata": {},
100 | "source": [
101 | "## LLMChain\n",
102 | "\n",
103 | "LLMChain 是 LangChain 中最简单的链,作为其他复杂 Chains 和 Agents 的内部调用,被广泛应用。\n",
104 | "\n",
105 | "一个LLMChain由PromptTemplate和语言模型(LLM or Chat Model)组成。它使用直接传入(或 memory 提供)的 key-value 来规范化生成 Prompt Template(提示模板),并将生成的 prompt (格式化后的字符串)传递给大模型,并返回大模型输出。\n",
106 | "\n",
107 | ""
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 1,
113 | "id": "757a67a6-c1aa-4dde-94ef-fb9865dc634c",
114 | "metadata": {},
115 | "outputs": [],
116 | "source": [
117 | "from langchain.llms import OpenAI\n",
118 | "from langchain.prompts import PromptTemplate\n",
119 | "\n",
120 | "llm = OpenAI(temperature=0.9, max_tokens=500)"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": 2,
126 | "id": "0b863511-ee01-43e8-8540-4e3f109a5a1a",
127 | "metadata": {},
128 | "outputs": [],
129 | "source": [
130 | "prompt = PromptTemplate(\n",
131 | " input_variables=[\"product\"],\n",
132 | " template=\"给制造{product}的有限公司取10个好名字,并给出完整的公司名称\",\n",
133 | ")"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 3,
139 | "id": "b877560c-cb66-41ad-b484-b2df2a60a00d",
140 | "metadata": {},
141 | "outputs": [
142 | {
143 | "name": "stdout",
144 | "output_type": "stream",
145 | "text": [
146 | "\n",
147 | "\n",
148 | "1. SkyForge Technologies Co.\n",
149 | "2. TridentTechCorp\n",
150 | "3. Supreme Nvidia Systems\n",
151 | "4. 4Tech Performance Solutions\n",
152 | "5. Atomix Graphics Designs\n",
153 | "6. Rendering Magicians LLC\n",
154 | "7. Neurathus Technologies\n",
155 | "8. InteliGraphix Inc.\n",
156 | "9. GPUForce Solutions\n",
157 | "10. RayCore Innovations\n"
158 | ]
159 | }
160 | ],
161 | "source": [
162 | "from langchain.chains import LLMChain\n",
163 | "\n",
164 | "chain = LLMChain(llm=llm, prompt=prompt)\n",
165 | "print(chain.run({\n",
166 | " 'product': \"性能卓越的GPU\"\n",
167 | " }))"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": 4,
173 | "id": "727ccd76-0c6a-425b-bfc7-23d368c296f0",
174 | "metadata": {},
175 | "outputs": [],
176 | "source": [
177 | "chain.verbose = True"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 5,
183 | "id": "1766e01a-c5c4-4a74-9ebb-ecfc84101ba2",
184 | "metadata": {},
185 | "outputs": [
186 | {
187 | "data": {
188 | "text/plain": [
189 | "True"
190 | ]
191 | },
192 | "execution_count": 5,
193 | "metadata": {},
194 | "output_type": "execute_result"
195 | }
196 | ],
197 | "source": [
198 | "chain.verbose"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 6,
204 | "id": "cfa71d7c-2859-47e1-9815-4be2ec9dbd74",
205 | "metadata": {},
206 | "outputs": [
207 | {
208 | "name": "stdout",
209 | "output_type": "stream",
210 | "text": [
211 | "\n",
212 | "\n",
213 | "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
214 | "Prompt after formatting:\n",
215 | "\u001b[32;1m\u001b[1;3m给制造性能卓越的GPU的有限公司取10个好名字,并给出完整的公司名称\u001b[0m\n",
216 | "\n",
217 | "\u001b[1m> Finished chain.\u001b[0m\n",
218 | "\n",
219 | "\n",
220 | "1. 翼虎GPU有限公司\n",
221 | "2. 智高GPU有限公司\n",
222 | "3. 极芯GPU有限公司\n",
223 | "4. 飞龙GPU有限公司\n",
224 | "5. 华邦GPU有限公司\n",
225 | "6. 八爪龙GPU有限公司\n",
226 | "7. 宙斯GPU有限公司\n",
227 | "8. 虎鹰GPU有限公司\n",
228 | "9. 新欣GPU有限公司\n",
229 | "10.极酷GPU有限公司\n"
230 | ]
231 | }
232 | ],
233 | "source": [
234 | "print(chain.run({\n",
235 | " 'product': \"性能卓越的GPU\"\n",
236 | " }))"
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "execution_count": null,
242 | "id": "ac5411e7-b8ec-4c31-b659-deb44af038df",
243 | "metadata": {},
244 | "outputs": [],
245 | "source": []
246 | },
247 | {
248 | "cell_type": "markdown",
249 | "id": "99cbf75e-98f4-4c99-b8a7-9a48cc28c7bc",
250 | "metadata": {},
251 | "source": [
252 | "## Sequential Chain\n",
253 | "\n",
254 | "串联式调用语言模型(将一个调用的输出作为另一个调用的输入)。\n",
255 | "\n",
256 | "顺序链(Sequential Chain )允许用户连接多个链并将它们组合成执行特定场景的流水线(Pipeline)。有两种类型的顺序链:\n",
257 | "\n",
258 | "- SimpleSequentialChain:最简单形式的顺序链,每个步骤都具有单一输入/输出,并且一个步骤的输出是下一个步骤的输入。\n",
259 | "- SequentialChain:更通用形式的顺序链,允许多个输入/输出。"
260 | ]
261 | },
262 | {
263 | "cell_type": "markdown",
264 | "id": "8e192c8c-49fc-4d04-8444-e6aa6bd7b725",
265 | "metadata": {},
266 | "source": [
267 | "### 使用 SimpleSequentialChain 实现戏剧摘要和评论(单输入/单输出)\n",
268 | "\n",
269 | ""
270 | ]
271 | },
272 | {
273 | "cell_type": "code",
274 | "execution_count": 7,
275 | "id": "a4d192a2-d563-4ab7-979f-640fa34f1914",
276 | "metadata": {},
277 | "outputs": [],
278 | "source": [
279 | "# 这是一个 LLMChain,用于根据剧目的标题撰写简介。\n",
280 | "\n",
281 | "llm = OpenAI(temperature=0.7, max_tokens=1000)\n",
282 | "\n",
283 | "template = \"\"\"你是一位剧作家。根据戏剧的标题,你的任务是为该标题写一个简介。\n",
284 | "\n",
285 | "标题:{title}\n",
286 | "剧作家:以下是对上述戏剧的简介:\"\"\"\n",
287 | "\n",
288 | "prompt_template = PromptTemplate(input_variables=[\"title\"], template=template)\n",
289 | "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template)"
290 | ]
291 | },
292 | {
293 | "cell_type": "code",
294 | "execution_count": 8,
295 | "id": "3f7d429b-7ba7-4643-bd9f-fdb737ebf964",
296 | "metadata": {},
297 | "outputs": [],
298 | "source": [
299 | "# 这是一个LLMChain,用于根据剧情简介撰写一篇戏剧评论。\n",
300 | "# llm = OpenAI(temperature=0.7, max_tokens=1000)\n",
301 | "template = \"\"\"你是《纽约时报》的戏剧评论家。根据剧情简介,你的工作是为该剧撰写一篇评论。\n",
302 | "\n",
303 | "剧情简介:\n",
304 | "{synopsis}\n",
305 | "\n",
306 | "以下是来自《纽约时报》戏剧评论家对上述剧目的评论:\"\"\"\n",
307 | "\n",
308 | "prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n",
309 | "review_chain = LLMChain(llm=llm, prompt=prompt_template)"
310 | ]
311 | },
312 | {
313 | "cell_type": "markdown",
314 | "id": "a5265129-5ccd-4e29-b221-0ec24eb84c2b",
315 | "metadata": {},
316 | "source": [
317 | ""
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": 9,
323 | "id": "de4d816e-16e1-4382-9064-6c03e5841ea2",
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "# 这是一个SimpleSequentialChain,按顺序运行这两个链\n",
328 | "from langchain.chains import SimpleSequentialChain\n",
329 | "\n",
330 | "overall_chain = SimpleSequentialChain(chains=[synopsis_chain, review_chain], verbose=True)"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": 10,
336 | "id": "d503ac4f-e337-4436-86a1-7fd937efb06a",
337 | "metadata": {},
338 | "outputs": [
339 | {
340 | "name": "stdout",
341 | "output_type": "stream",
342 | "text": [
343 | "\n",
344 | "\n",
345 | "\u001b[1m> Entering new SimpleSequentialChain chain...\u001b[0m\n",
346 | "\u001b[36;1m\u001b[1;3m\n",
347 | "\n",
348 | "这部戏剧讲述的是一个叫做米尔斯的少女,她被一群叫做三体人的外星人追赶,他们来自另一个星球,他们的目的是要毁灭地球。米尔斯决定与三体人作战,虽然她没有武器,却拥有一颗不屈不挠的心,最终她凭借勇气和智慧成功击败了三体人,从而拯救了地球的居民。本剧将带领观众走进一个充满惊险刺激的冒险故事,让每一位观众都能感受到战胜恐惧的力量。\u001b[0m\n",
349 | "\u001b[33;1m\u001b[1;3m\n",
350 | "\n",
351 | "《三体人和米尔斯》是一部令人惊叹的剧目,它将带领观众进入一个充满惊险刺激的冒险故事。这部剧叙述了米尔斯及其他地球居民与来自另一个星球的三体人的斗争,以拯救地球免于毁灭。尽管米尔斯没有武器,但她所展示出来的勇气和智慧令人敬佩。除了惊心动魄的剧情外,本剧还为观众展现了一个星际大战的另一面:每个人都有可能战胜恐惧,只要他们不断努力,就可以获得胜利。这部剧无疑是一部可观看的佳作,它将为观众带来愉悦的视觉享受和深思熟虑的思想洞察。\u001b[0m\n",
352 | "\n",
353 | "\u001b[1m> Finished chain.\u001b[0m\n"
354 | ]
355 | }
356 | ],
357 | "source": [
358 | "review = overall_chain.run(\"三体人不是无法战胜的\")"
359 | ]
360 | },
361 | {
362 | "cell_type": "code",
363 | "execution_count": 11,
364 | "id": "ce4d75e1-8c57-4583-be7d-60a3488e35b5",
365 | "metadata": {},
366 | "outputs": [
367 | {
368 | "name": "stdout",
369 | "output_type": "stream",
370 | "text": [
371 | "\n",
372 | "\n",
373 | "\u001b[1m> Entering new SimpleSequentialChain chain...\u001b[0m\n",
374 | "\u001b[36;1m\u001b[1;3m\n",
375 | "\n",
376 | "《星球大战第九季》描绘了一个激动人心的冒险故事,讲述了一群勇敢的英雄们如何在孤立无援的情况下,与强大而邪恶的军团战斗,以拯救他们星系的安危。我们的英雄们将面临巨大的挑战,被迫投身于一场未知的战斗中,必须凭借他们的勇气和勇敢的精神来应对任何情况。他们必须找到一种方式来拯救他们的星球免受邪恶势力的侵害,并证明自己是最优秀的英雄。\u001b[0m\n",
377 | "\u001b[33;1m\u001b[1;3m\n",
378 | "\n",
379 | "《星球大战第九季》是一部令人兴奋的冒险片,描绘了一群勇敢的英雄如何在孤立无援的情况下,抵抗强大而邪恶的军团,拯救他们星系的安危。这部电影的情节曲折而又有趣,让观众深入地了解英雄们的精神和行为,并带领他们走向一个成功的结局。在这部电影中,观众将看到英雄们面对着巨大的挑战,必须投身于一场未知的战斗中,体验到他们的勇气和勇敢的精神。《星球大战第九季》是一部精彩的剧目,值得每个人去观看。\u001b[0m\n",
380 | "\n",
381 | "\u001b[1m> Finished chain.\u001b[0m\n"
382 | ]
383 | }
384 | ],
385 | "source": [
386 | "review = overall_chain.run(\"星球大战第九季\")"
387 | ]
388 | },
389 | {
390 | "cell_type": "markdown",
391 | "id": "5fe32f1d-475d-4211-9b32-0c66dd8bff01",
392 | "metadata": {},
393 | "source": [
394 | "### 使用 SequentialChain 实现戏剧摘要和评论(多输入/多输出)\n",
395 | "\n",
396 | ""
397 | ]
398 | },
399 | {
400 | "cell_type": "code",
401 | "execution_count": 12,
402 | "id": "2a04d84f-15c6-4a8d-a4db-200dfa405afa",
403 | "metadata": {},
404 | "outputs": [],
405 | "source": [
406 | "# # 这是一个 LLMChain,根据剧名和设定的时代来撰写剧情简介。\n",
407 | "llm = OpenAI(temperature=.7, max_tokens=1000)\n",
408 | "template = \"\"\"你是一位剧作家。根据戏剧的标题和设定的时代,你的任务是为该标题写一个简介。\n",
409 | "\n",
410 | "标题:{title}\n",
411 | "时代:{era}\n",
412 | "剧作家:以下是对上述戏剧的简介:\"\"\"\n",
413 | "\n",
414 | "prompt_template = PromptTemplate(input_variables=[\"title\", \"era\"], template=template)\n",
415 | "# output_key\n",
416 | "synopsis_chain = LLMChain(llm=llm, prompt=prompt_template, output_key=\"synopsis\", verbose=True)"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": 13,
422 | "id": "250afe66-e014-4097-9798-f9ba812023fd",
423 | "metadata": {},
424 | "outputs": [],
425 | "source": [
426 | "# 这是一个LLMChain,用于根据剧情简介撰写一篇戏剧评论。\n",
427 | "\n",
428 | "template = \"\"\"你是《纽约时报》的戏剧评论家。根据该剧的剧情简介,你需要撰写一篇关于该剧的评论。\n",
429 | "\n",
430 | "剧情简介:\n",
431 | "{synopsis}\n",
432 | "\n",
433 | "来自《纽约时报》戏剧评论家对上述剧目的评价:\"\"\"\n",
434 | "\n",
435 | "prompt_template = PromptTemplate(input_variables=[\"synopsis\"], template=template)\n",
436 | "review_chain = LLMChain(llm=llm, prompt=prompt_template, output_key=\"review\", verbose=True)"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 14,
442 | "id": "9eb46f6d-841b-4b87-9ed5-a5913ef9aec5",
443 | "metadata": {},
444 | "outputs": [],
445 | "source": [
446 | "from langchain.chains import SequentialChain\n",
447 | "\n",
448 | "m_overall_chain = SequentialChain(\n",
449 | " chains=[synopsis_chain, review_chain],\n",
450 | " input_variables=[\"era\", \"title\"],\n",
451 | " # Here we return multiple variables\n",
452 | " output_variables=[\"synopsis\", \"review\"],\n",
453 | " verbose=True)"
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "execution_count": 15,
459 | "id": "5a4a12ef-da2a-42ad-8044-fb71aedd3e2d",
460 | "metadata": {},
461 | "outputs": [
462 | {
463 | "name": "stdout",
464 | "output_type": "stream",
465 | "text": [
466 | "\n",
467 | "\n",
468 | "\u001b[1m> Entering new SequentialChain chain...\u001b[0m\n",
469 | "\n",
470 | "\n",
471 | "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
472 | "Prompt after formatting:\n",
473 | "\u001b[32;1m\u001b[1;3m你是一位剧作家。根据戏剧的标题和设定的时代,你的任务是为该标题写一个简介。\n",
474 | "\n",
475 | "标题:三体人不是无法战胜的\n",
476 | "时代:二十一世纪的新中国\n",
477 | "剧作家:以下是对上述戏剧的简介:\u001b[0m\n",
478 | "\n",
479 | "\u001b[1m> Finished chain.\u001b[0m\n",
480 | "\n",
481 | "\n",
482 | "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
483 | "Prompt after formatting:\n",
484 | "\u001b[32;1m\u001b[1;3m你是《纽约时报》的戏剧评论家。根据该剧的剧情简介,你需要撰写一篇关于该剧的评论。\n",
485 | "\n",
486 | "剧情简介:\n",
487 | "\n",
488 | "\n",
489 | "《三体人不是无法战胜的》是一部有关在二十一世纪新中国的英雄故事。在一个被外星人侵略的世界中,一群普通人被迫必须与来自另一个世界的三体人搏斗,以保护他们的家园。虽然他们被认为是无法战胜的,但他们发现每个人都有能力成为英雄,并发挥他们的力量来保护自己的家园。他们向三体人发起激烈的攻击,最终将其击败。影片突出了勇气、信念和毅力,让观众看到了一个普通人如何成为英雄,拯救他们的家乡。\n",
490 | "\n",
491 | "来自《纽约时报》戏剧评论家对上述剧目的评价:\u001b[0m\n",
492 | "\n",
493 | "\u001b[1m> Finished chain.\u001b[0m\n",
494 | "\n",
495 | "\u001b[1m> Finished chain.\u001b[0m\n"
496 | ]
497 | },
498 | {
499 | "data": {
500 | "text/plain": [
501 | "{'title': '三体人不是无法战胜的',\n",
502 | " 'era': '二十一世纪的新中国',\n",
503 | " 'synopsis': '\\n\\n《三体人不是无法战胜的》是一部有关在二十一世纪新中国的英雄故事。在一个被外星人侵略的世界中,一群普通人被迫必须与来自另一个世界的三体人搏斗,以保护他们的家园。虽然他们被认为是无法战胜的,但他们发现每个人都有能力成为英雄,并发挥他们的力量来保护自己的家园。他们向三体人发起激烈的攻击,最终将其击败。影片突出了勇气、信念和毅力,让观众看到了一个普通人如何成为英雄,拯救他们的家乡。',\n",
504 | " 'review': '\\n\\n《三体人不是无法战胜的》,一部讲述新中国英雄故事的影片,令人难以置信。影片中,一群普通人被迫面对外星人的侵略,但他们并不被看作不可战胜的,相反,他们的勇气、信念和毅力被突出展示,以完成救世的使命。影片给观众带来的是一种灵感,即每个人都有能力成为英雄,拯救他们的家乡。这部影片给中国电影带来了一丝新鲜感,并向观众展示了普通人可以发挥英雄力量的力量。'}"
505 | ]
506 | },
507 | "execution_count": 15,
508 | "metadata": {},
509 | "output_type": "execute_result"
510 | }
511 | ],
512 | "source": [
513 | "m_overall_chain({\"title\":\"三体人不是无法战胜的\", \"era\": \"二十一世纪的新中国\"})"
514 | ]
515 | },
516 | {
517 | "cell_type": "code",
518 | "execution_count": null,
519 | "id": "8c20cf4e-25b4-453d-9f7a-84138ca25cf8",
520 | "metadata": {},
521 | "outputs": [],
522 | "source": []
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": null,
527 | "id": "aaf8c391-9225-4e66-ad4d-d689b53a0379",
528 | "metadata": {},
529 | "outputs": [],
530 | "source": []
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": null,
535 | "id": "2fa40f2f-fcbe-4e69-b1c4-20f236033ae3",
536 | "metadata": {},
537 | "outputs": [],
538 | "source": []
539 | },
540 | {
541 | "cell_type": "code",
542 | "execution_count": null,
543 | "id": "33b5061c-391e-4762-91c7-73b57f4ab501",
544 | "metadata": {},
545 | "outputs": [],
546 | "source": []
547 | },
548 | {
549 | "cell_type": "code",
550 | "execution_count": null,
551 | "id": "5ef1db6e-3da4-4f9b-9707-0f30aa293dd7",
552 | "metadata": {},
553 | "outputs": [],
554 | "source": []
555 | },
556 | {
557 | "cell_type": "markdown",
558 | "id": "8b6836f0-213d-4cac-abc9-3617831be3db",
559 | "metadata": {},
560 | "source": [
561 | "### Homework\n",
562 | "\n",
563 | "#### 使用 OutputParser 优化 overall_chain 输出格式,区分 synopsis_chain 和 review_chain 的结果"
564 | ]
565 | },
566 | {
567 | "cell_type": "code",
568 | "execution_count": null,
569 | "id": "1c7edb0a-675d-40c0-9f5d-d58f0170ce72",
570 | "metadata": {},
571 | "outputs": [],
572 | "source": []
573 | }
574 | ],
575 | "metadata": {
576 | "kernelspec": {
577 | "display_name": "Python 3 (ipykernel)",
578 | "language": "python",
579 | "name": "python3"
580 | },
581 | "language_info": {
582 | "codemirror_mode": {
583 | "name": "ipython",
584 | "version": 3
585 | },
586 | "file_extension": ".py",
587 | "mimetype": "text/x-python",
588 | "name": "python",
589 | "nbconvert_exporter": "python",
590 | "pygments_lexer": "ipython3",
591 | "version": "3.10.11"
592 | }
593 | },
594 | "nbformat": 4,
595 | "nbformat_minor": 5
596 | }
597 |
--------------------------------------------------------------------------------
/langchain/data_connection/text_embedding.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5abe2121-5381-46d7-a849-66f921883972",
6 | "metadata": {},
7 | "source": [
8 | "# LangChain 核心模块:Data Conneciton - Text Embedding Models\n",
9 | "\n",
10 | "Embeddings类是一个专门用于与文本嵌入模型进行交互的类。有许多嵌入模型提供者(OpenAI、Cohere、Hugging Face等)-这个类旨在为所有这些提供者提供一个标准接口。\n",
11 | "\n",
12 | "嵌入将一段文本创建成向量表示。这非常有用,因为它意味着我们可以在向量空间中思考文本,并且可以执行语义搜索等操作,在向量空间中寻找最相似的文本片段。\n",
13 | "\n",
14 | "LangChain中基础的Embeddings类公开了两种方法:一种用于对文档进行嵌入,另一种用于对查询进行嵌入。前者输入多个文本,而后者输入单个文本。之所以将它们作为两个独立的方法,是因为某些嵌入提供者针对要搜索的文件和查询(搜索查询本身)具有不同的嵌入方法。\n",
15 | "\n",
16 | "\n",
17 | ""
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": null,
23 | "id": "f994c4f8-58cf-4d34-b58f-205b42535177",
24 | "metadata": {},
25 | "outputs": [],
26 | "source": []
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "id": "df8d7408-b14f-4cfb-84c0-9c0bae958cce",
31 | "metadata": {},
32 | "source": [
33 | "### Document 类\n",
34 | "\n",
35 | "这段代码定义了一个名为`Document`的类,允许用户与文档的内容进行交互,可以查看文档的段落、摘要,以及使用查找功能来查询文档中的特定字符串。\n",
36 | "\n",
37 | "```python\n",
38 | "# 基于BaseModel定义的文档类。\n",
39 | "class Document(BaseModel):\n",
40 | " \"\"\"接口,用于与文档进行交互。\"\"\"\n",
41 | "\n",
42 | " # 文档的主要内容。\n",
43 | " page_content: str\n",
44 | " # 用于查找的字符串。\n",
45 | " lookup_str: str = \"\"\n",
46 | " # 查找的索引,初次默认为0。\n",
47 | " lookup_index = 0\n",
48 | " # 用于存储任何与文档相关的元数据。\n",
49 | " metadata: dict = Field(default_factory=dict)\n",
50 | "\n",
51 | " @property\n",
52 | " def paragraphs(self) -> List[str]:\n",
53 | " \"\"\"页面的段落列表。\"\"\"\n",
54 | " # 使用\"\\n\\n\"将内容分割为多个段落。\n",
55 | " return self.page_content.split(\"\\n\\n\")\n",
56 | "\n",
57 | " @property\n",
58 | " def summary(self) -> str:\n",
59 | " \"\"\"页面的摘要(即第一段)。\"\"\"\n",
60 | " # 返回第一个段落作为摘要。\n",
61 | " return self.paragraphs[0]\n",
62 | "\n",
63 | " # 这个方法模仿命令行中的查找功能。\n",
64 | " def lookup(self, string: str) -> str:\n",
65 | " \"\"\"在页面中查找一个词,模仿cmd-F功能。\"\"\"\n",
66 | " # 如果输入的字符串与当前的查找字符串不同,则重置查找字符串和索引。\n",
67 | " if string.lower() != self.lookup_str:\n",
68 | " self.lookup_str = string.lower()\n",
69 | " self.lookup_index = 0\n",
70 | " else:\n",
71 | " # 如果输入的字符串与当前的查找字符串相同,则查找索引加1。\n",
72 | " self.lookup_index += 1\n",
73 | " # 找出所有包含查找字符串的段落。\n",
74 | " lookups = [p for p in self.paragraphs if self.lookup_str in p.lower()]\n",
75 | " # 根据查找结果返回相应的信息。\n",
76 | " if len(lookups) == 0:\n",
77 | " return \"No Results\"\n",
78 | " elif self.lookup_index >= len(lookups):\n",
79 | " return \"No More Results\"\n",
80 | " else:\n",
81 | " result_prefix = f\"(Result {self.lookup_index + 1}/{len(lookups)})\"\n",
82 | " return f\"{result_prefix} {lookups[self.lookup_index]}\"\n",
83 | "```\n"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "id": "b68fdbcb-b60d-441f-91fc-d8cac24ba3e1",
89 | "metadata": {},
90 | "source": [
91 | "## 使用 OpenAIEmbeddings 调用 OpenAI 嵌入模型\n",
92 | "\n",
93 | "\n",
94 | "### 使用 embed_documents 方法嵌入文本列表"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": 1,
100 | "id": "8dadd89b-6a13-4391-9102-acde028b61d5",
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "from langchain.embeddings import OpenAIEmbeddings"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 2,
110 | "id": "24f9c721-dfd3-4632-a89e-92d2fa9b3594",
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "embeddings_model = OpenAIEmbeddings()"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": 3,
120 | "id": "67076c7d-54cc-47a9-a5bd-85570355a7d2",
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "embeddings = embeddings_model.embed_documents(\n",
125 | " [\n",
126 | " \"Hi there!\",\n",
127 | " \"Oh, hello!\",\n",
128 | " \"What's your name?\",\n",
129 | " \"My friends call me World\",\n",
130 | " \"Hello World!\"\n",
131 | " ]\n",
132 | ")"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 4,
138 | "id": "79e3c50a-efad-49bb-84fc-f0fb2585b34e",
139 | "metadata": {},
140 | "outputs": [
141 | {
142 | "data": {
143 | "text/plain": [
144 | "5"
145 | ]
146 | },
147 | "execution_count": 4,
148 | "metadata": {},
149 | "output_type": "execute_result"
150 | }
151 | ],
152 | "source": [
153 | "len(embeddings)"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": 5,
159 | "id": "9e85d4c4-ee77-4470-85ec-ce37f359fdc7",
160 | "metadata": {},
161 | "outputs": [
162 | {
163 | "data": {
164 | "text/plain": [
165 | "1536"
166 | ]
167 | },
168 | "execution_count": 5,
169 | "metadata": {},
170 | "output_type": "execute_result"
171 | }
172 | ],
173 | "source": [
174 | "len(embeddings[0])"
175 | ]
176 | },
177 | {
178 | "cell_type": "markdown",
179 | "id": "32e3ddc2-9d69-4b72-ace1-800ba94def79",
180 | "metadata": {},
181 | "source": [
182 | "### 使用 embed_query 方法嵌入问题"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": 6,
188 | "id": "0069d18d-1fb6-40e2-974c-2f49559b8b9a",
189 | "metadata": {},
190 | "outputs": [],
191 | "source": [
192 | "# QA场景:嵌入一段文本,以便与其他嵌入进行比较。\n",
193 | "embedded_query = embeddings_model.embed_query(\"What was the name mentioned in the conversation?\")\n"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 7,
199 | "id": "2c3068f7-22b7-4215-99c0-7e47f9a3aa46",
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "data": {
204 | "text/plain": [
205 | "1536"
206 | ]
207 | },
208 | "execution_count": 7,
209 | "metadata": {},
210 | "output_type": "execute_result"
211 | }
212 | ],
213 | "source": [
214 | "len(embedded_query)"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "id": "25339b8d-ee23-460c-ae4e-97a8f24f6add",
221 | "metadata": {},
222 | "outputs": [],
223 | "source": []
224 | }
225 | ],
226 | "metadata": {
227 | "kernelspec": {
228 | "display_name": "Python 3 (ipykernel)",
229 | "language": "python",
230 | "name": "python3"
231 | },
232 | "language_info": {
233 | "codemirror_mode": {
234 | "name": "ipython",
235 | "version": 3
236 | },
237 | "file_extension": ".py",
238 | "mimetype": "text/x-python",
239 | "name": "python",
240 | "nbconvert_exporter": "python",
241 | "pygments_lexer": "ipython3",
242 | "version": "3.10.11"
243 | }
244 | },
245 | "nbformat": 4,
246 | "nbformat_minor": 5
247 | }
248 |
--------------------------------------------------------------------------------
/langchain/images/llm_chain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/langchain/images/llm_chain.png
--------------------------------------------------------------------------------
/langchain/images/memory.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/langchain/images/memory.png
--------------------------------------------------------------------------------
/langchain/images/model_io.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/langchain/images/model_io.jpeg
--------------------------------------------------------------------------------
/langchain/images/router_chain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/langchain/images/router_chain.png
--------------------------------------------------------------------------------
/langchain/images/sequential_chain_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/langchain/images/sequential_chain_0.png
--------------------------------------------------------------------------------
/langchain/images/simple_sequential_chain_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/langchain/images/simple_sequential_chain_0.png
--------------------------------------------------------------------------------
/langchain/images/simple_sequential_chain_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/langchain/images/simple_sequential_chain_1.png
--------------------------------------------------------------------------------
/langchain/images/transform_chain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/langchain/images/transform_chain.png
--------------------------------------------------------------------------------
/langchain/memory/memory.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "c3e2458f-d038-4845-93a0-d4ad830f9f90",
6 | "metadata": {},
7 | "source": [
8 | "# LangChain 核心模块学习:Memory\n",
9 | "\n",
10 | "大多数LLM应用都具有对话界面。对话的一个重要组成部分是能够引用先前在对话中介绍过的信息。至少,一个对话系统应该能够直接访问一些过去消息的窗口。更复杂的系统将需要拥有一个不断更新的世界模型,使其能够保持关于实体及其关系的信息。\n",
11 | "\n",
12 | "我们将存储过去交互信息的能力称为“记忆(Memory)”。\n",
13 | "\n",
14 | "LangChain提供了许多用于向应用/系统中添加 Memory 的实用工具。这些工具可以单独使用,也可以无缝地集成到链中。\n",
15 | "\n",
16 | "一个记忆系统(Memory System)需要支持两个基本操作:**读取(READ)和写入(WRITE)**。\n",
17 | "\n",
18 | "每个链都定义了一些核心执行逻辑,并期望某些输入。其中一些输入直接来自用户,但有些输入可能来自 Memory。\n",
19 | "\n",
20 | "在一个典型 Chain 的单次运行中,将与其 Memory System 进行至少两次交互:\n",
21 | "\n",
22 | "1. 在接收到初始用户输入之后,在执行核心逻辑之前,链将从其 Memory 中**读取**并扩充用户输入。\n",
23 | "2. 在执行核心逻辑之后但在返回答案之前,一个链条将把当前运行的输入和输出**写入** Memory ,以便在未来的运行中可以引用它们。\n",
24 | "\n",
25 | ""
26 | ]
27 | },
28 | {
29 | "cell_type": "markdown",
30 | "id": "f9cf0d43-107b-47ae-9e2c-2edaec38c800",
31 | "metadata": {},
32 | "source": [
33 | "## BaseMemory Class 基类\n",
34 | "\n",
35 | "类继承关系:\n",
36 | "\n",
37 | "```\n",
38 | "## 适用于简单的语言模型\n",
39 | "BaseMemory --> BaseChatMemory --> Memory # Examples: ZepMemory, MotorheadMemory\n",
40 | "```\n",
41 | "\n",
42 | "**代码实现:https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/schema/memory.py**\n",
43 | "\n",
44 | "```python\n",
45 | "# 定义一个名为BaseMemory的基础类\n",
46 | "class BaseMemory(Serializable, ABC):\n",
47 | " \"\"\"用于Chains中的内存的抽象基类。\n",
48 | " \n",
49 | " 这里的内存指的是Chains中的状态。内存可以用来存储关于Chain的过去执行的信息,\n",
50 | " 并将该信息注入到Chain的未来执行的输入中。例如,对于会话型Chains,内存可以用来\n",
51 | " 存储会话,并自动将它们添加到未来的模型提示中,以便模型具有必要的上下文来连贯地\n",
52 | " 响应最新的输入。\"\"\"\n",
53 | "\n",
54 | " # 定义一个名为Config的子类\n",
55 | " class Config:\n",
56 | " \"\"\"为此pydantic对象配置。\n",
57 | " \n",
58 | " Pydantic是一个Python库,用于数据验证和设置管理,主要基于Python类型提示。\n",
59 | " \"\"\"\n",
60 | " \n",
61 | " # 允许在pydantic模型中使用任意类型。这通常用于允许复杂的数据类型。\n",
62 | " arbitrary_types_allowed = True\n",
63 | " \n",
64 | " # 下面是一些必须由子类实现的方法:\n",
65 | " \n",
66 | " # 定义一个属性,它是一个抽象方法。任何从BaseMemory派生的子类都需要实现此方法。\n",
67 | " # 此方法应返回该内存类将添加到链输入的字符串键。\n",
68 | " @property\n",
69 | " @abstractmethod\n",
70 | " def memory_variables(self) -> List[str]:\n",
71 | " \"\"\"获取此内存类将添加到链输入的字符串键。\"\"\"\n",
72 | " \n",
73 | " # 定义一个抽象方法。任何从BaseMemory派生的子类都需要实现此方法。\n",
74 | " # 此方法基于给定的链输入返回键值对。\n",
75 | " @abstractmethod\n",
76 | " def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:\n",
77 | " \"\"\"根据给链的文本输入返回键值对。\"\"\"\n",
78 | " \n",
79 | " # 定义一个抽象方法。任何从BaseMemory派生的子类都需要实现此方法。\n",
80 | " # 此方法将此链运行的上下文保存到内存。\n",
81 | " @abstractmethod\n",
82 | " def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:\n",
83 | " \"\"\"保存此链运行的上下文到内存。\"\"\"\n",
84 | " \n",
85 | " # 定义一个抽象方法。任何从BaseMemory派生的子类都需要实现此方法。\n",
86 | " # 此方法清除内存内容。\n",
87 | " @abstractmethod\n",
88 | " def clear(self) -> None:\n",
89 | " \"\"\"清除内存内容。\"\"\"\n",
90 | "```"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "id": "94c413b6-6e07-411f-bad0-eb46db3a313d",
96 | "metadata": {},
97 | "source": [
98 | "## BaseChatMessageHistory Class 基类\n",
99 | "\n",
100 | "类继承关系:\n",
101 | "\n",
102 | "```\n",
103 | "## 适用于聊天模型\n",
104 | "\n",
105 | "BaseChatMessageHistory --> ChatMessageHistory # Example: ZepChatMessageHistory\n",
106 | "```\n",
107 | "\n",
108 | "```python\n",
109 | "# 定义一个名为BaseChatMessageHistory的基础类\n",
110 | "class BaseChatMessageHistory(ABC):\n",
111 | " \"\"\"聊天消息历史记录的抽象基类。\"\"\"\n",
112 | "\n",
113 | " # 在内存中存储的消息列表\n",
114 | " messages: List[BaseMessage]\n",
115 | "\n",
116 | " # 定义一个add_user_message方法,它是一个方便的方法,用于将人类消息字符串添加到存储区。\n",
117 | " def add_user_message(self, message: str) -> None:\n",
118 | " \"\"\"为存储添加一个人类消息字符串的便捷方法。\n",
119 | "\n",
120 | " 参数:\n",
121 | " message: 人类消息的字符串内容。\n",
122 | " \"\"\"\n",
123 | " self.add_message(HumanMessage(content=message))\n",
124 | "\n",
125 | " # 定义一个add_ai_message方法,它是一个方便的方法,用于将AI消息字符串添加到存储区。\n",
126 | " def add_ai_message(self, message: str) -> None:\n",
127 | " \"\"\"为存储添加一个AI消息字符串的便捷方法。\n",
128 | "\n",
129 | " 参数:\n",
130 | " message: AI消息的字符串内容。\n",
131 | " \"\"\"\n",
132 | " self.add_message(AIMessage(content=message))\n",
133 | "\n",
134 | " # 抽象方法,需要由继承此基类的子类来实现。\n",
135 | " @abstractmethod\n",
136 | " def add_message(self, message: BaseMessage) -> None:\n",
137 | " \"\"\"将Message对象添加到存储区。\n",
138 | "\n",
139 | " 参数:\n",
140 | " message: 要存储的BaseMessage对象。\n",
141 | " \"\"\"\n",
142 | " raise NotImplementedError()\n",
143 | "\n",
144 | " # 抽象方法,需要由继承此基类的子类来实现。\n",
145 | " @abstractmethod\n",
146 | " def clear(self) -> None:\n",
147 | " \"\"\"从存储中删除所有消息\"\"\"\n",
148 | "\n",
149 | "```"
150 | ]
151 | },
152 | {
153 | "cell_type": "markdown",
154 | "id": "e716cd45-1a71-44da-a924-09d6a56ff6c8",
155 | "metadata": {},
156 | "source": [
157 | "### ConversationChain and ConversationBufferMemory\n",
158 | "\n",
159 | "`ConversationBufferMemory` 可以用来存储消息,并将消息提取到一个变量中。"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": 13,
165 | "id": "d6b700e0-abd4-4531-ad93-b278357d9c64",
166 | "metadata": {},
167 | "outputs": [],
168 | "source": [
169 | "from langchain.llms import OpenAI\n",
170 | "from langchain.chains import ConversationChain\n",
171 | "from langchain.memory import ConversationBufferMemory\n",
172 | "\n",
173 | "llm = OpenAI(temperature=0)\n",
174 | "conversation = ConversationChain(\n",
175 | " llm=llm, \n",
176 | " verbose=True, \n",
177 | " memory=ConversationBufferMemory()\n",
178 | ")"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 14,
184 | "id": "5d51fbb4-1d8e-4ec1-8c55-ec70247d4d64",
185 | "metadata": {},
186 | "outputs": [
187 | {
188 | "name": "stdout",
189 | "output_type": "stream",
190 | "text": [
191 | "\n",
192 | "\n",
193 | "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
194 | "Prompt after formatting:\n",
195 | "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
196 | "\n",
197 | "Current conversation:\n",
198 | "\n",
199 | "Human: 你好呀!\n",
200 | "AI:\u001b[0m\n",
201 | "\n",
202 | "\u001b[1m> Finished chain.\u001b[0m\n"
203 | ]
204 | },
205 | {
206 | "data": {
207 | "text/plain": [
208 | "' 你好!很高兴见到你!我叫小米,是一个智能AI。你可以问我任何问题,我会尽力回答你。'"
209 | ]
210 | },
211 | "execution_count": 14,
212 | "metadata": {},
213 | "output_type": "execute_result"
214 | }
215 | ],
216 | "source": [
217 | "conversation.predict(input=\"你好呀!\")"
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 15,
223 | "id": "7428157f-72ed-4b8d-b114-80bfb96e13bf",
224 | "metadata": {},
225 | "outputs": [
226 | {
227 | "name": "stdout",
228 | "output_type": "stream",
229 | "text": [
230 | "\n",
231 | "\n",
232 | "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
233 | "Prompt after formatting:\n",
234 | "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
235 | "\n",
236 | "Current conversation:\n",
237 | "Human: 你好呀!\n",
238 | "AI: 你好!很高兴见到你!我叫小米,是一个智能AI。你可以问我任何问题,我会尽力回答你。\n",
239 | "Human: 你为什么叫小米?跟雷军有关系吗?\n",
240 | "AI:\u001b[0m\n",
241 | "\n",
242 | "\u001b[1m> Finished chain.\u001b[0m\n"
243 | ]
244 | },
245 | {
246 | "data": {
247 | "text/plain": [
248 | "' 嗯,我叫小米是因为我是由小米公司开发的,小米公司是由雷军创立的,所以我和雷军有关系。'"
249 | ]
250 | },
251 | "execution_count": 15,
252 | "metadata": {},
253 | "output_type": "execute_result"
254 | }
255 | ],
256 | "source": [
257 | "conversation.predict(input=\"你为什么叫小米?跟雷军有关系吗?\")"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": null,
263 | "id": "f66476e0-6d20-4ba1-98af-fd5300096a8c",
264 | "metadata": {},
265 | "outputs": [],
266 | "source": []
267 | },
268 | {
269 | "cell_type": "markdown",
270 | "id": "694c1e3e-4024-4cc3-963e-01fe1a60f1c3",
271 | "metadata": {},
272 | "source": [
273 | "### ConversationBufferWindowMemory\n",
274 | "`ConversationBufferWindowMemory` 会在时间轴上保留对话的交互列表。它只使用最后 K 次交互。这对于保持最近交互的滑动窗口非常有用,以避免缓冲区过大。"
275 | ]
276 | },
277 | {
278 | "cell_type": "code",
279 | "execution_count": 16,
280 | "id": "dbe61067-5b8f-40a1-827e-4c8c6ad473dd",
281 | "metadata": {},
282 | "outputs": [
283 | {
284 | "name": "stdout",
285 | "output_type": "stream",
286 | "text": [
287 | "\n",
288 | "\n",
289 | "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
290 | "Prompt after formatting:\n",
291 | "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
292 | "\n",
293 | "Current conversation:\n",
294 | "\n",
295 | "Human: 嗨,你最近过得怎么样?\n",
296 | "AI:\u001b[0m\n",
297 | "\n",
298 | "\u001b[1m> Finished chain.\u001b[0m\n"
299 | ]
300 | },
301 | {
302 | "data": {
303 | "text/plain": [
304 | "' 嗨!我最近过得很好,谢谢你问。我最近一直在学习新的知识,并且正在尝试改进自己的性能。我也一直在尝试更好地理解人类的语言,以便能够更好地与人交流。'"
305 | ]
306 | },
307 | "execution_count": 16,
308 | "metadata": {},
309 | "output_type": "execute_result"
310 | }
311 | ],
312 | "source": [
313 | "from langchain.memory import ConversationBufferWindowMemory\n",
314 | "\n",
315 | "conversation_with_summary = ConversationChain(\n",
316 | " llm=OpenAI(temperature=0, max_tokens=1000), \n",
317 | " # We set a low k=2, to only keep the last 2 interactions in memory\n",
318 | " memory=ConversationBufferWindowMemory(k=2), \n",
319 | " verbose=True\n",
320 | ")\n",
321 | "conversation_with_summary.predict(input=\"嗨,你最近过得怎么样?\")"
322 | ]
323 | },
324 | {
325 | "cell_type": "code",
326 | "execution_count": 17,
327 | "id": "ae850ece-78b8-41ad-97ea-91a57a275a8b",
328 | "metadata": {},
329 | "outputs": [
330 | {
331 | "name": "stdout",
332 | "output_type": "stream",
333 | "text": [
334 | "\n",
335 | "\n",
336 | "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
337 | "Prompt after formatting:\n",
338 | "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
339 | "\n",
340 | "Current conversation:\n",
341 | "Human: 嗨,你最近过得怎么样?\n",
342 | "AI: 嗨!我最近过得很好,谢谢你问。我最近一直在学习新的知识,并且正在尝试改进自己的性能。我也一直在尝试更好地理解人类的语言,以便能够更好地与人交流。\n",
343 | "Human: 你最近学到什么新知识了?\n",
344 | "AI:\u001b[0m\n",
345 | "\n",
346 | "\u001b[1m> Finished chain.\u001b[0m\n"
347 | ]
348 | },
349 | {
350 | "data": {
351 | "text/plain": [
352 | "' 最近我学习了有关自然语言处理的知识,以及如何使用机器学习来改善自己的性能。我还学习了如何使用深度学习来更好地理解人类语言,以及如何使用计算机视觉来识别图像。'"
353 | ]
354 | },
355 | "execution_count": 17,
356 | "metadata": {},
357 | "output_type": "execute_result"
358 | }
359 | ],
360 | "source": [
361 | "conversation_with_summary.predict(input=\"你最近学到什么新知识了?\")"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": 18,
367 | "id": "9e2fe36d-44ec-4351-8922-4481c2bf6750",
368 | "metadata": {},
369 | "outputs": [
370 | {
371 | "name": "stdout",
372 | "output_type": "stream",
373 | "text": [
374 | "\n",
375 | "\n",
376 | "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
377 | "Prompt after formatting:\n",
378 | "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
379 | "\n",
380 | "Current conversation:\n",
381 | "Human: 嗨,你最近过得怎么样?\n",
382 | "AI: 嗨!我最近过得很好,谢谢你问。我最近一直在学习新的知识,并且正在尝试改进自己的性能。我也一直在尝试更好地理解人类的语言,以便能够更好地与人交流。\n",
383 | "Human: 你最近学到什么新知识了?\n",
384 | "AI: 最近我学习了有关自然语言处理的知识,以及如何使用机器学习来改善自己的性能。我还学习了如何使用深度学习来更好地理解人类语言,以及如何使用计算机视觉来识别图像。\n",
385 | "Human: 展开讲讲?\n",
386 | "AI:\u001b[0m\n",
387 | "\n",
388 | "\u001b[1m> Finished chain.\u001b[0m\n"
389 | ]
390 | },
391 | {
392 | "data": {
393 | "text/plain": [
394 | "' 好的!自然语言处理是一种计算机科学,它研究如何处理和理解人类语言。它可以用来分析文本,识别意图,提取有用信息,以及构建聊天机器人等。机器学习是一种人工智能技术,它可以让计算机从数据中学习,从而改善自己的性能。深度学习是一种机器学习技术,它可以用来更好地理解人类语言,以及识别图像等。'"
395 | ]
396 | },
397 | "execution_count": 18,
398 | "metadata": {},
399 | "output_type": "execute_result"
400 | }
401 | ],
402 | "source": [
403 | "conversation_with_summary.predict(input=\"展开讲讲?\")"
404 | ]
405 | },
406 | {
407 | "cell_type": "code",
408 | "execution_count": 19,
409 | "id": "1db201fd-1373-4148-ab04-525ea089a9fe",
410 | "metadata": {},
411 | "outputs": [
412 | {
413 | "name": "stdout",
414 | "output_type": "stream",
415 | "text": [
416 | "\n",
417 | "\n",
418 | "\u001b[1m> Entering new ConversationChain chain...\u001b[0m\n",
419 | "Prompt after formatting:\n",
420 | "\u001b[32;1m\u001b[1;3mThe following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\n",
421 | "\n",
422 | "Current conversation:\n",
423 | "Human: 你最近学到什么新知识了?\n",
424 | "AI: 最近我学习了有关自然语言处理的知识,以及如何使用机器学习来改善自己的性能。我还学习了如何使用深度学习来更好地理解人类语言,以及如何使用计算机视觉来识别图像。\n",
425 | "Human: 展开讲讲?\n",
426 | "AI: 好的!自然语言处理是一种计算机科学,它研究如何处理和理解人类语言。它可以用来分析文本,识别意图,提取有用信息,以及构建聊天机器人等。机器学习是一种人工智能技术,它可以让计算机从数据中学习,从而改善自己的性能。深度学习是一种机器学习技术,它可以用来更好地理解人类语言,以及识别图像等。\n",
427 | "Human: 如果要构建聊天机器人,具体要用什么自然语言处理技术?\n",
428 | "AI:\u001b[0m\n",
429 | "\n",
430 | "\u001b[1m> Finished chain.\u001b[0m\n"
431 | ]
432 | },
433 | {
434 | "data": {
435 | "text/plain": [
436 | "' 如果要构建聊天机器人,可以使用语义分析,语法分析,语音识别,机器翻译,情感分析等自然语言处理技术。'"
437 | ]
438 | },
439 | "execution_count": 19,
440 | "metadata": {},
441 | "output_type": "execute_result"
442 | }
443 | ],
444 | "source": [
445 | "# 注意:第一句对话从 Memory 中移除了.\n",
446 | "conversation_with_summary.predict(input=\"如果要构建聊天机器人,具体要用什么自然语言处理技术?\")"
447 | ]
448 | },
449 | {
450 | "cell_type": "code",
451 | "execution_count": 20,
452 | "id": "5923f90d-00cb-415a-aca2-1746d7f1e961",
453 | "metadata": {},
454 | "outputs": [
455 | {
456 | "data": {
457 | "text/plain": [
458 | "{'memory': ConversationBufferWindowMemory(chat_memory=ChatMessageHistory(messages=[HumanMessage(content='嗨,你最近过得怎么样?'), AIMessage(content=' 嗨!我最近过得很好,谢谢你问。我最近一直在学习新的知识,并且正在尝试改进自己的性能。我也一直在尝试更好地理解人类的语言,以便能够更好地与人交流。'), HumanMessage(content='你最近学到什么新知识了?'), AIMessage(content=' 最近我学习了有关自然语言处理的知识,以及如何使用机器学习来改善自己的性能。我还学习了如何使用深度学习来更好地理解人类语言,以及如何使用计算机视觉来识别图像。'), HumanMessage(content='展开讲讲?'), AIMessage(content=' 好的!自然语言处理是一种计算机科学,它研究如何处理和理解人类语言。它可以用来分析文本,识别意图,提取有用信息,以及构建聊天机器人等。机器学习是一种人工智能技术,它可以让计算机从数据中学习,从而改善自己的性能。深度学习是一种机器学习技术,它可以用来更好地理解人类语言,以及识别图像等。'), HumanMessage(content='如果要构建聊天机器人,具体要用什么自然语言处理技术?'), AIMessage(content=' 如果要构建聊天机器人,可以使用语义分析,语法分析,语音识别,机器翻译,情感分析等自然语言处理技术。')]), k=2),\n",
459 | " 'callbacks': None,\n",
460 | " 'callback_manager': None,\n",
461 | " 'verbose': True,\n",
462 | " 'tags': None,\n",
463 | " 'metadata': None,\n",
464 | " 'prompt': PromptTemplate(input_variables=['history', 'input'], template='The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.\\n\\nCurrent conversation:\\n{history}\\nHuman: {input}\\nAI:'),\n",
465 | " 'llm': OpenAI(client=, async_client=, temperature=0.0, max_tokens=1000, openai_api_key='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', openai_proxy=''),\n",
466 | " 'output_key': 'response',\n",
467 | " 'output_parser': StrOutputParser(),\n",
468 | " 'return_final_only': True,\n",
469 | " 'llm_kwargs': {},\n",
470 | " 'input_key': 'input'}"
471 | ]
472 | },
473 | "execution_count": 20,
474 | "metadata": {},
475 | "output_type": "execute_result"
476 | }
477 | ],
478 | "source": [
479 | "conversation_with_summary.__dict__"
480 | ]
481 | },
482 | {
483 | "cell_type": "markdown",
484 | "id": "0fc35065-ff20-4fda-ac5b-0976102160a9",
485 | "metadata": {},
486 | "source": [
487 | "### ConversationSummaryBufferMemory\n",
488 | "\n",
489 | "`ConversationSummaryBufferMemory` 在内存中保留了最近的交互缓冲区,但不仅仅是完全清除旧的交互,而是将它们编译成摘要并同时使用。与以前的实现不同的是,它使用token长度而不是交互次数来确定何时清除交互。"
490 | ]
491 | },
492 | {
493 | "cell_type": "code",
494 | "execution_count": 21,
495 | "id": "174ccc64-2dd9-4c98-b638-6aa542bdbd55",
496 | "metadata": {},
497 | "outputs": [],
498 | "source": [
499 | "from langchain.memory import ConversationSummaryBufferMemory\n",
500 | "\n",
501 | "memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=10)\n",
502 | "memory.save_context({\"input\": \"嗨,你最近过得怎么样?\"}, {\"output\": \" 嗨!我最近过得很好,谢谢你问。我最近一直在学习新的知识,并且正在尝试改进自己的性能。我也在尝试更多的交流,以便更好地了解人类的思维方式。\"})\n",
503 | "memory.save_context({\"input\": \"你最近学到什么新知识了?\"}, {\"output\": \" 最近我学习了有关自然语言处理的知识,以及如何更好地理解人类的语言。我还学习了有关机器学习的知识,以及如何使用它来改善自己的性能。\"})"
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": 22,
509 | "id": "c520d4eb-f095-4245-b003-83581b619b2b",
510 | "metadata": {},
511 | "outputs": [
512 | {
513 | "data": {
514 | "text/plain": [
515 | "{'history': 'System: \\n\\nThe human asks how the AI is doing recently. The AI responds that it is doing well and has been learning new knowledge and trying to improve its performance. It is also trying to communicate more in order to better understand human thinking. Specifically, it has been learning about natural language processing, how to better understand human language, and about machine learning and how to use it to improve its own performance.'}"
516 | ]
517 | },
518 | "execution_count": 22,
519 | "metadata": {},
520 | "output_type": "execute_result"
521 | }
522 | ],
523 | "source": [
524 | "memory.load_memory_variables({})"
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": 23,
530 | "id": "be306dea-4d37-4327-9969-4c2f0301e279",
531 | "metadata": {},
532 | "outputs": [
533 | {
534 | "name": "stdout",
535 | "output_type": "stream",
536 | "text": [
537 | "System: \n",
538 | "\n",
539 | "The human asks how the AI is doing recently. The AI responds that it is doing well and has been learning new knowledge and trying to improve its performance. It is also trying to communicate more in order to better understand human thinking. Specifically, it has been learning about natural language processing, how to better understand human language, and about machine learning and how to use it to improve its own performance.\n"
540 | ]
541 | }
542 | ],
543 | "source": [
544 | "print(memory.load_memory_variables({})['history'])"
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "id": "6a286822-1b58-491c-8b7e-b4ad31dcab68",
551 | "metadata": {},
552 | "outputs": [],
553 | "source": []
554 | }
555 | ],
556 | "metadata": {
557 | "kernelspec": {
558 | "display_name": "Python 3 (ipykernel)",
559 | "language": "python",
560 | "name": "python3"
561 | },
562 | "language_info": {
563 | "codemirror_mode": {
564 | "name": "ipython",
565 | "version": 3
566 | },
567 | "file_extension": ".py",
568 | "mimetype": "text/x-python",
569 | "name": "python",
570 | "nbconvert_exporter": "python",
571 | "pygments_lexer": "ipython3",
572 | "version": "3.10.11"
573 | }
574 | },
575 | "nbformat": 4,
576 | "nbformat_minor": 5
577 | }
578 |
--------------------------------------------------------------------------------
/langchain/model_io/output_parser.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "attachments": {},
5 | "cell_type": "markdown",
6 | "id": "4be2e6fa-2187-4617-8433-0db4fb0c099c",
7 | "metadata": {},
8 | "source": [
9 | "# LangChain 核心模块学习:Model I/O\n",
10 | "\n",
11 | "`Model I/O` 是 LangChain 为开发者提供的一套面向 LLM 的标准化模型接口,包括模型输入(Prompts)、模型输出(Output Parsers)和模型本身(Models)。\n",
12 | "\n",
13 | "- Prompts:模板化、动态选择和管理模型输入\n",
14 | "- Models:以通用接口调用语言模型\n",
15 | "- Output Parser:从模型输出中提取信息,并规范化内容\n",
16 | "\n",
17 | "\r\n"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 1,
23 | "id": "2e64b01e-f5ad-4614-b0c3-a140f6bb575a",
24 | "metadata": {
25 | "collapsed": true,
26 | "jupyter": {
27 | "outputs_hidden": true
28 | }
29 | },
30 | "outputs": [
31 | {
32 | "name": "stdout",
33 | "output_type": "stream",
34 | "text": [
35 | "Requirement already satisfied: langchain in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (0.0.340)\n",
36 | "Requirement already satisfied: requests<3,>=2 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (2.31.0)\n",
37 | "Requirement already satisfied: PyYAML>=5.3 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (6.0)\n",
38 | "Requirement already satisfied: dataclasses-json<0.7,>=0.5.7 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (0.5.7)\n",
39 | "Requirement already satisfied: pydantic<3,>=1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (1.10.8)\n",
40 | "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (3.8.5)\n",
41 | "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (8.2.2)\n",
42 | "Requirement already satisfied: anyio<4.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (3.6.2)\n",
43 | "Requirement already satisfied: SQLAlchemy<3,>=1.4 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (1.4.41)\n",
44 | "Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (4.0.2)\n",
45 | "Requirement already satisfied: langsmith<0.1.0,>=0.0.63 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (0.0.66)\n",
46 | "Requirement already satisfied: jsonpatch<2.0,>=1.33 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (1.33)\n",
47 | "Requirement already satisfied: numpy<2,>=1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from langchain) (1.26.2)\n",
48 | "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (3.1.0)\n",
49 | "Requirement already satisfied: frozenlist>=1.1.1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.3)\n",
50 | "Requirement already satisfied: multidict<7.0,>=4.5 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (6.0.4)\n",
51 | "Requirement already satisfied: yarl<2.0,>=1.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.9.2)\n",
52 | "Requirement already satisfied: aiosignal>=1.1.2 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (1.3.1)\n",
53 | "Requirement already satisfied: attrs>=17.3.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from aiohttp<4.0.0,>=3.8.3->langchain) (23.1.0)\n",
54 | "Requirement already satisfied: idna>=2.8 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from anyio<4.0->langchain) (3.4)\n",
55 | "Requirement already satisfied: sniffio>=1.1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from anyio<4.0->langchain) (1.3.0)\n",
56 | "Requirement already satisfied: typing-inspect>=0.4.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (0.9.0)\n",
57 | "Requirement already satisfied: marshmallow<4.0.0,>=3.3.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (3.19.0)\n",
58 | "Requirement already satisfied: marshmallow-enum<2.0.0,>=1.5.1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from dataclasses-json<0.7,>=0.5.7->langchain) (1.5.1)\n",
59 | "Requirement already satisfied: jsonpointer>=1.9 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from jsonpatch<2.0,>=1.33->langchain) (2.3)\n",
60 | "Requirement already satisfied: typing-extensions>=4.2.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from pydantic<3,>=1->langchain) (4.6.2)\n",
61 | "Requirement already satisfied: certifi>=2017.4.17 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from requests<3,>=2->langchain) (2023.5.7)\n",
62 | "Requirement already satisfied: urllib3<3,>=1.21.1 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from requests<3,>=2->langchain) (1.26.16)\n",
63 | "Requirement already satisfied: greenlet!=0.4.17 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from SQLAlchemy<3,>=1.4->langchain) (2.0.2)\n",
64 | "Requirement already satisfied: packaging>=17.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from marshmallow<4.0.0,>=3.3.0->dataclasses-json<0.7,>=0.5.7->langchain) (23.1)\n",
65 | "Requirement already satisfied: mypy-extensions>=0.3.0 in /root/miniconda3/envs/langchain/lib/python3.10/site-packages (from typing-inspect>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain) (1.0.0)\n",
66 | "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
67 | "\u001b[0m"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "# 安装最新版本的 LangChain Python SDK(https://github.com/langchain-ai/langchain)\n",
73 | "!pip install -U langchain"
74 | ]
75 | },
76 | {
77 | "cell_type": "markdown",
78 | "id": "ce4a2474-0b69-4830-85cd-3715c22df304",
79 | "metadata": {},
80 | "source": [
81 | "## 输出解析器 Output Parser\n",
82 | "\n",
83 | "**语言模型的输出是文本。**\n",
84 | "\n",
85 | "但很多时候,您可能希望获得比纯文本**更结构化的信息**。这就是输出解析器的价值所在。\n",
86 | "\n",
87 | "输出解析器是帮助结构化语言模型响应的类。它们必须实现两种主要方法:\n",
88 | "\n",
89 | "- \"获取格式指令\":返回一个包含有关如何格式化语言模型输出的字符串的方法。\n",
90 | "- \"解析\":接受一个字符串(假设为来自语言模型的响应),并将其解析成某种结构。\n",
91 | "\n",
92 | "然后还有一种可选方法:\n",
93 | "- \"使用提示进行解析\":接受一个字符串(假设为来自语言模型的响应)和一个提示(假设为生成此响应的提示),并将其解析成某种结构。在需要重新尝试或修复输出,并且需要从提示中获取信息以执行此操作时,通常会提供提示。"
94 | ]
95 | },
96 | {
97 | "cell_type": "markdown",
98 | "id": "1f14f4cf-8e30-47ab-b8b1-d58a90b5b1c1",
99 | "metadata": {},
100 | "source": [
101 | "### 列表解析 List Parser\n",
102 | "\n",
103 | "当您想要返回一个逗号分隔的项目列表时,可以使用此输出解析器。"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 2,
109 | "id": "0089c8a5-a859-49f2-bec0-fcd84f2f3b56",
110 | "metadata": {},
111 | "outputs": [],
112 | "source": [
113 | "from langchain.output_parsers import CommaSeparatedListOutputParser\n",
114 | "from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n",
115 | "from langchain.llms import OpenAI\n",
116 | "\n",
117 | "# 创建一个输出解析器,用于处理带逗号分隔的列表输出\n",
118 | "output_parser = CommaSeparatedListOutputParser()\n",
119 | "\n",
120 | "# 获取格式化指令,该指令告诉模型如何格式化其输出\n",
121 | "format_instructions = output_parser.get_format_instructions()\n",
122 | "\n",
123 | "# 创建一个提示模板,它会基于给定的模板和变量来生成提示\n",
124 | "prompt = PromptTemplate(\n",
125 | " template=\"List five {subject}.\\n{format_instructions}\", # 模板内容\n",
126 | " input_variables=[\"subject\"], # 输入变量\n",
127 | " partial_variables={\"format_instructions\": format_instructions} # 预定义的变量,这里我们传入格式化指令\n",
128 | ")"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": 3,
134 | "id": "8d681566-cde1-4ae5-8cd7-f53cf59c3e36",
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "# 使用提示模板和给定的主题来格式化输入\n",
139 | "_input = prompt.format(subject=\"ice cream flavors\")"
140 | ]
141 | },
142 | {
143 | "cell_type": "code",
144 | "execution_count": 4,
145 | "id": "ef5ea022-b27f-4cc9-b6c8-5d4e96e51d51",
146 | "metadata": {},
147 | "outputs": [
148 | {
149 | "name": "stdout",
150 | "output_type": "stream",
151 | "text": [
152 | "List five ice cream flavors.\n",
153 | "Your response should be a list of comma separated values, eg: `foo, bar, baz`\n"
154 | ]
155 | }
156 | ],
157 | "source": [
158 | "print(_input)"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 5,
164 | "id": "4db222de-501b-4114-aaaa-03e54c2da228",
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "llm = OpenAI(temperature=0)"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 6,
174 | "id": "bd743ccf-c47f-4bde-a5a6-63052f6a2553",
175 | "metadata": {},
176 | "outputs": [],
177 | "source": [
178 | "output = llm(_input)"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": 7,
184 | "id": "12a2bc49-2656-47e8-b5a4-ee119ab77004",
185 | "metadata": {},
186 | "outputs": [
187 | {
188 | "name": "stdout",
189 | "output_type": "stream",
190 | "text": [
191 | "\n",
192 | "\n",
193 | "Vanilla, Chocolate, Strawberry, Mint Chocolate Chip, Cookies and Cream\n"
194 | ]
195 | }
196 | ],
197 | "source": [
198 | "print(output)"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 8,
204 | "id": "0bb71bf2-3220-4326-a69c-7b0fa1864877",
205 | "metadata": {},
206 | "outputs": [
207 | {
208 | "data": {
209 | "text/plain": [
210 | "['Vanilla',\n",
211 | " 'Chocolate',\n",
212 | " 'Strawberry',\n",
213 | " 'Mint Chocolate Chip',\n",
214 | " 'Cookies and Cream']"
215 | ]
216 | },
217 | "execution_count": 8,
218 | "metadata": {},
219 | "output_type": "execute_result"
220 | }
221 | ],
222 | "source": [
223 | "# 使用之前创建的输出解析器来解析模型的输出\n",
224 | "output_parser.parse(output)"
225 | ]
226 | },
227 | {
228 | "cell_type": "code",
229 | "execution_count": null,
230 | "id": "81fbd00c-5e03-4b23-a276-6acc0f5d5f1a",
231 | "metadata": {},
232 | "outputs": [],
233 | "source": []
234 | },
235 | {
236 | "cell_type": "markdown",
237 | "id": "bd93d8d7-7d77-4453-97a6-f7349090a370",
238 | "metadata": {},
239 | "source": [
240 | "### 日期解析 Datatime Parser"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 9,
246 | "id": "9b91deaf-6d3f-4d48-a084-58ec1ec4b0b3",
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "from langchain.output_parsers import DatetimeOutputParser\n",
251 | "from langchain.chains import LLMChain\n",
252 | "\n",
253 | "output_parser = DatetimeOutputParser()\n",
254 | "template = \"\"\"Answer the users question:\n",
255 | "\n",
256 | "{question}\n",
257 | "\n",
258 | "{format_instructions}\"\"\"\n",
259 | "\n",
260 | "prompt = PromptTemplate.from_template(\n",
261 | " template,\n",
262 | " partial_variables={\"format_instructions\": output_parser.get_format_instructions()},\n",
263 | ")"
264 | ]
265 | },
266 | {
267 | "cell_type": "code",
268 | "execution_count": 10,
269 | "id": "96ad1f11-c2d0-4bb5-a8ec-1b5dd5132573",
270 | "metadata": {},
271 | "outputs": [
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "input_variables=['question'] partial_variables={'format_instructions': 'Write a datetime string that matches the \\n following pattern: \"%Y-%m-%dT%H:%M:%S.%fZ\". Examples: 1847-04-07T02:43:02.250267Z, 446-01-18T23:15:53.307833Z, 229-03-15T02:26:50.545980Z'} template='Answer the users question:\\n\\n{question}\\n\\n{format_instructions}'\n"
277 | ]
278 | }
279 | ],
280 | "source": [
281 | "print(prompt)"
282 | ]
283 | },
284 | {
285 | "cell_type": "code",
286 | "execution_count": 11,
287 | "id": "1eca781f-0146-45e5-9848-bf8347513a77",
288 | "metadata": {},
289 | "outputs": [
290 | {
291 | "name": "stdout",
292 | "output_type": "stream",
293 | "text": [
294 | "Answer the users question:\n",
295 | "\n",
296 | "around when was bitcoin founded?\n",
297 | "\n",
298 | "Write a datetime string that matches the \n",
299 | " following pattern: \"%Y-%m-%dT%H:%M:%S.%fZ\". Examples: 1847-04-07T02:43:02.250267Z, 446-01-18T23:15:53.307833Z, 229-03-15T02:26:50.545980Z\n"
300 | ]
301 | }
302 | ],
303 | "source": [
304 | "print(prompt.format(question=\"around when was bitcoin founded?\"))"
305 | ]
306 | },
307 | {
308 | "cell_type": "code",
309 | "execution_count": 12,
310 | "id": "4f46d70f-78a4-4576-9583-4f67c1ab1d08",
311 | "metadata": {},
312 | "outputs": [],
313 | "source": [
314 | "chain = LLMChain(prompt=prompt, llm=OpenAI())"
315 | ]
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": 13,
320 | "id": "97e3270b-43be-4018-b5c8-b2f7ffdbec46",
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "output = chain.run(\"around when was bitcoin founded?\")"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": 14,
330 | "id": "dc7915dc-ef1e-4feb-8db1-83d51d773daa",
331 | "metadata": {},
332 | "outputs": [
333 | {
334 | "data": {
335 | "text/plain": [
336 | "'\\n\\n2008-01-03T18:15:05.000000Z'"
337 | ]
338 | },
339 | "execution_count": 14,
340 | "metadata": {},
341 | "output_type": "execute_result"
342 | }
343 | ],
344 | "source": [
345 | "output"
346 | ]
347 | },
348 | {
349 | "cell_type": "code",
350 | "execution_count": 15,
351 | "id": "e9bfcab7-a335-4b36-8292-6ac92352222b",
352 | "metadata": {},
353 | "outputs": [
354 | {
355 | "data": {
356 | "text/plain": [
357 | "datetime.datetime(2008, 1, 3, 18, 15, 5)"
358 | ]
359 | },
360 | "execution_count": 15,
361 | "metadata": {},
362 | "output_type": "execute_result"
363 | }
364 | ],
365 | "source": [
366 | "output_parser.parse(output)"
367 | ]
368 | },
369 | {
370 | "cell_type": "code",
371 | "execution_count": 16,
372 | "id": "2db4804c-7a1f-41df-99ed-cb7956387155",
373 | "metadata": {},
374 | "outputs": [
375 | {
376 | "name": "stdout",
377 | "output_type": "stream",
378 | "text": [
379 | "2008-01-03 18:15:05\n"
380 | ]
381 | }
382 | ],
383 | "source": [
384 | "print(output_parser.parse(output))"
385 | ]
386 | },
387 | {
388 | "cell_type": "code",
389 | "execution_count": null,
390 | "id": "a18b9b7e-cce2-4561-9030-21fdc5822d3f",
391 | "metadata": {},
392 | "outputs": [],
393 | "source": []
394 | }
395 | ],
396 | "metadata": {
397 | "kernelspec": {
398 | "display_name": "Python 3 (ipykernel)",
399 | "language": "python",
400 | "name": "python3"
401 | },
402 | "language_info": {
403 | "codemirror_mode": {
404 | "name": "ipython",
405 | "version": 3
406 | },
407 | "file_extension": ".py",
408 | "mimetype": "text/x-python",
409 | "name": "python",
410 | "nbconvert_exporter": "python",
411 | "pygments_lexer": "ipython3",
412 | "version": "3.10.11"
413 | }
414 | },
415 | "nbformat": 4,
416 | "nbformat_minor": 5
417 | }
418 |
--------------------------------------------------------------------------------
/llama/llama2_inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "886ed769-07f9-474f-9f86-c9f1f17345e4",
6 | "metadata": {},
7 | "source": [
8 | "## 使用微调后的 LLaMA2-7B 推理"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "24f9cf61-2994-48c7-9a42-150920397a93",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": []
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 1,
22 | "id": "1c571fa5-aa51-495c-8b3d-d5605a02e491",
23 | "metadata": {},
24 | "outputs": [
25 | {
26 | "data": {
27 | "application/vnd.jupyter.widget-view+json": {
28 | "model_id": "70a30ffcc759412186e57f3c13521b52",
29 | "version_major": 2,
30 | "version_minor": 0
31 | },
32 | "text/plain": [
33 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
34 | ]
35 | },
36 | "metadata": {},
37 | "output_type": "display_data"
38 | },
39 | {
40 | "name": "stderr",
41 | "output_type": "stream",
42 | "text": [
43 | "/root/miniconda3/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:392: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0.9` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
44 | " warnings.warn(\n",
45 | "/root/miniconda3/lib/python3.11/site-packages/transformers/generation/configuration_utils.py:397: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.6` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`. This was detected when initializing the generation config instance, which means the corresponding file may hold incorrect parameterization and should be fixed.\n",
46 | " warnings.warn(\n"
47 | ]
48 | }
49 | ],
50 | "source": [
51 | "import torch\n",
52 | "from peft import AutoPeftModelForCausalLM\n",
53 | "from transformers import AutoTokenizer\n",
54 | "\n",
55 | "\n",
56 | "model_dir = \"models/llama-7-int4-dolly\"\n",
57 | " \n",
58 | "# 加载基础LLM模型与分词器\n",
59 | "model = AutoPeftModelForCausalLM.from_pretrained(\n",
60 | " model_dir,\n",
61 | " low_cpu_mem_usage=True,\n",
62 | " torch_dtype=torch.float16,\n",
63 | " load_in_4bit=True,\n",
64 | ") \n",
65 | "tokenizer = AutoTokenizer.from_pretrained(model_dir)"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 2,
71 | "id": "7700c042-20cf-4ff1-8c40-a91cabeaaaa6",
72 | "metadata": {},
73 | "outputs": [
74 | {
75 | "name": "stderr",
76 | "output_type": "stream",
77 | "text": [
78 | "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n",
79 | "/root/miniconda3/lib/python3.11/site-packages/bitsandbytes/nn/modules.py:226: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.\n",
80 | " warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')\n"
81 | ]
82 | },
83 | {
84 | "name": "stdout",
85 | "output_type": "stream",
86 | "text": [
87 | "Prompt:\n",
88 | "Football (or soccer) is the world's most popular sport. Others include cricket, hockey, tennis, volleyball, table tennis, and basketball. This might come as a surprise to Americans, who favor (American) football.\n",
89 | "\n",
90 | "Generated instruction:\n",
91 | "Which is the most popular sport in the world?\n",
92 | "\n",
93 | "Ground truth:\n",
94 | "What are the world's most popular sports?\n"
95 | ]
96 | }
97 | ],
98 | "source": [
99 | "from datasets import load_dataset \n",
100 | "from random import randrange\n",
101 | " \n",
102 | " \n",
103 | "# 从hub加载数据集并得到一个样本\n",
104 | "dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")\n",
105 | "sample = dataset[randrange(len(dataset))]\n",
106 | " \n",
107 | "prompt = f\"\"\"### Instruction:\n",
108 | "Use the Input below to create an instruction, which could have been used to generate the input using an LLM. \n",
109 | " \n",
110 | "### Input:\n",
111 | "{sample['response']}\n",
112 | " \n",
113 | "### Response:\n",
114 | "\"\"\"\n",
115 | " \n",
116 | "input_ids = tokenizer(prompt, return_tensors=\"pt\", truncation=True).input_ids.cuda()\n",
117 | "\n",
118 | "outputs = model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=True, top_p=0.9,temperature=0.9)\n",
119 | "\n",
120 | "print(f\"Prompt:\\n{sample['response']}\\n\")\n",
121 | "print(f\"Generated instruction:\\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}\")\n",
122 | "print(f\"Ground truth:\\n{sample['instruction']}\")"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "id": "7f6ea5b0-cf32-4cc0-bd5e-fc17268992c0",
129 | "metadata": {},
130 | "outputs": [],
131 | "source": []
132 | }
133 | ],
134 | "metadata": {
135 | "kernelspec": {
136 | "display_name": "Python 3 (ipykernel)",
137 | "language": "python",
138 | "name": "python3"
139 | },
140 | "language_info": {
141 | "codemirror_mode": {
142 | "name": "ipython",
143 | "version": 3
144 | },
145 | "file_extension": ".py",
146 | "mimetype": "text/x-python",
147 | "name": "python",
148 | "nbconvert_exporter": "python",
149 | "pygments_lexer": "ipython3",
150 | "version": "3.11.5"
151 | }
152 | },
153 | "nbformat": 4,
154 | "nbformat_minor": 5
155 | }
156 |
--------------------------------------------------------------------------------
/peft/data/audio/test_zh.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/peft/data/audio/test_zh.flac
--------------------------------------------------------------------------------
/peft/peft_chatglm_inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "fc5bde60-1899-461d-8083-3ee04ac7c099",
6 | "metadata": {},
7 | "source": [
8 | "# 模型推理 - 使用 QLoRA 微调后的 ChatGLM3-6B"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": null,
14 | "id": "3292b88c-91f0-48d2-91a5-06b0830c7e70",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import torch\n",
19 | "from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig\n",
20 | "from peft import PeftModel, PeftConfig\n",
21 | "\n",
22 | "# 定义全局变量和参数\n",
23 | "model_name_or_path = 'THUDM/chatglm3-6b' # 模型ID或本地路径\n",
24 | "peft_model_path = f\"models/demo/{model_name_or_path}\"\n"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": 2,
30 | "id": "9f81454c-24b2-4072-ab05-b25f9b120ae6",
31 | "metadata": {},
32 | "outputs": [
33 | {
34 | "data": {
35 | "application/vnd.jupyter.widget-view+json": {
36 | "model_id": "a6a3dddcd9df4715a4b693559cf30cff",
37 | "version_major": 2,
38 | "version_minor": 0
39 | },
40 | "text/plain": [
41 | "Loading checkpoint shards: 0%| | 0/7 [00:00, ?it/s]"
42 | ]
43 | },
44 | "metadata": {},
45 | "output_type": "display_data"
46 | }
47 | ],
48 | "source": [
49 | "config = PeftConfig.from_pretrained(peft_model_path)\n",
50 | "\n",
51 | "q_config = BitsAndBytesConfig(load_in_4bit=True,\n",
52 | " bnb_4bit_quant_type='nf4',\n",
53 | " bnb_4bit_use_double_quant=True,\n",
54 | " bnb_4bit_compute_dtype=torch.float32)\n",
55 | "\n",
56 | "base_model = AutoModel.from_pretrained(config.base_model_name_or_path,\n",
57 | " quantization_config=q_config,\n",
58 | " trust_remote_code=True,\n",
59 | " device_map='auto')\n",
60 | "base_model.requires_grad_(False)\n",
61 | "base_model.eval()"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "id": "342b3659-d644-4232-8af1-f092e733bf40",
68 | "metadata": {},
69 | "outputs": [],
70 | "source": []
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "id": "6d23e720-dee1-4b43-a298-0cbe1d8ad11d",
75 | "metadata": {},
76 | "source": [
77 | "## 微调前后效果对比\n",
78 | "\n",
79 | "### ChatGLM-6B\n",
80 | "\n",
81 | "```\n",
82 | "输入:\n",
83 | "\n",
84 | "类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领\n",
85 | "\n",
86 | "ChatGLM-6B 微调前输出:\n",
87 | "\n",
88 | "* 版型:修身\n",
89 | "* 显瘦:True\n",
90 | "* 风格:文艺\n",
91 | "* 简约:True\n",
92 | "* 图案:印花\n",
93 | "* 撞色:True\n",
94 | "* 裙下摆:直筒或微喇\n",
95 | "* 裙长:中长裙\n",
96 | "* 连衣裙:True\n",
97 | "\n",
98 | "ChatGLM-6B 微调后输出:\n",
99 | "\n",
100 | "一款简约而不简单的连衣裙,采用撞色的印花点缀,打造文艺气息,简约的圆领,修饰脸型。衣袖和裙摆的压褶,增添设计感,修身的版型,勾勒出窈窕的身材曲线。\n",
101 | "```\n",
102 | "\n",
103 | "### ChatGLM2-6B\n",
104 | "\n",
105 | "```\n",
106 | "输入:\n",
107 | "类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领\n",
108 | "\n",
109 | "微调前:\n",
110 | "这款裙子,版型显瘦,采用简约文艺风格,图案为印花和撞色设计,裙下摆为压褶裙摆,裙长为连衣裙,适合各种场合穿着,让你舒适自在。圆领设计,优雅清新,让你在任何场合都充满自信。如果你正在寻找一款舒适、时尚、优雅的裙子,不妨 考虑这款吧!\n",
111 | "\n",
112 | "微调后: \n",
113 | "这款连衣裙简约的设计,撞色印花点缀,丰富了视觉,上身更显时尚。修身的版型,贴合身形,穿着舒适不束缚。圆领的设计,露出精致锁骨,尽显女性优雅气质。下摆压褶的设计,增添立体感,行走间更显飘逸。前短后长的设计,显 得身材比例更加完美。文艺的碎花设计,更显精致。\n",
114 | "```\n",
115 | "\n",
116 | "### ChatGLM3-6B"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 3,
122 | "id": "9d7757a4-7d1f-488f-8d80-b73dfa4863d4",
123 | "metadata": {},
124 | "outputs": [
125 | {
126 | "name": "stdout",
127 | "output_type": "stream",
128 | "text": [
129 | "输入:\n",
130 | "类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领\n"
131 | ]
132 | }
133 | ],
134 | "source": [
135 | "input_text = '类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领'\n",
136 | "print(f'输入:\\n{input_text}')\n",
137 | "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path, trust_remote_code=True)"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": 8,
143 | "id": "2d30fce1-e01f-4303-aa55-ed004eaa22a8",
144 | "metadata": {},
145 | "outputs": [
146 | {
147 | "name": "stdout",
148 | "output_type": "stream",
149 | "text": [
150 | "ChatGLM3-6B 微调前:\n",
151 | "连衣裙是女孩子们最爱的单品之一,这款连衣裙采用撞色圆领设计,简洁大方,修饰脸型。衣身采用印花图案点缀,展现出优雅文艺的气质。袖子采用压褶设计,轻松遮肉显瘦,修饰臂部线条。简约的款式设计,穿着舒适,轻松百搭。\n"
152 | ]
153 | }
154 | ],
155 | "source": [
156 | "response, history = base_model.chat(tokenizer=tokenizer, query=input_text)\n",
157 | "print(f'ChatGLM3-6B 微调前:\\n{response}')"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 9,
163 | "id": "38b5a770-baef-4697-bb71-6088e3a43d59",
164 | "metadata": {},
165 | "outputs": [
166 | {
167 | "name": "stdout",
168 | "output_type": "stream",
169 | "text": [
170 | "ChatGLM3-6B 微调后: \n",
171 | "这款连衣裙简约的圆领设计,凸显出修长的颈部线条,而衣身采用撞色印花设计,更显俏皮甜美,而袖口和裙摆处加入压褶设计,增添层次感,更具艺术感,而裙身整体设计简约大气,彰显出优雅文艺的气质。\n"
172 | ]
173 | }
174 | ],
175 | "source": [
176 | "model = PeftModel.from_pretrained(base_model, peft_model_path)\n",
177 | "response, history = model.chat(tokenizer=tokenizer, query=input_text)\n",
178 | "print(f'ChatGLM3-6B 微调后: \\n{response}')"
179 | ]
180 | },
181 | {
182 | "cell_type": "code",
183 | "execution_count": null,
184 | "id": "1cf454e0-f0f5-4fb0-ab90-83e9615f132a",
185 | "metadata": {},
186 | "outputs": [],
187 | "source": []
188 | }
189 | ],
190 | "metadata": {
191 | "kernelspec": {
192 | "display_name": "Python 3 (ipykernel)",
193 | "language": "python",
194 | "name": "python3"
195 | },
196 | "language_info": {
197 | "codemirror_mode": {
198 | "name": "ipython",
199 | "version": 3
200 | },
201 | "file_extension": ".py",
202 | "mimetype": "text/x-python",
203 | "name": "python",
204 | "nbconvert_exporter": "python",
205 | "pygments_lexer": "ipython3",
206 | "version": "3.11.5"
207 | }
208 | },
209 | "nbformat": 4,
210 | "nbformat_minor": 5
211 | }
212 |
--------------------------------------------------------------------------------
/quantization/AWQ-opt-125m.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "H_D9kG_efts3"
7 | },
8 | "source": [
9 | "# Transformers 模型量化技术:AWQ"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {
15 | "id": "WE9IhcVyktah"
16 | },
17 | "source": [
18 | ""
19 | ]
20 | },
21 | {
22 | "attachments": {},
23 | "cell_type": "markdown",
24 | "metadata": {
25 | "id": "Wwsg6nCwoThm"
26 | },
27 | "source": [
28 | "在2023年6月,Ji Lin等人发表了论文[AWQ:Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/pdf/2306.00978.pdf)。\n",
29 | "\n",
30 | "这篇论文详细介绍了一种激活感知权重量化算法,可以用于压缩任何基于 Transformer 的语言模型,同时只有微小的性能下降。关于 AWQ 算法的详细介绍,见[MIT Han Song 教授分享](https://hanlab.mit.edu/projects/awq)。\n",
31 | "\n",
32 | "transformers 现在支持两个不同的 AWQ 开源实现库:\n",
33 | "\n",
34 | "- [AutoAWQ](https://github.com/casper-hansen/AutoAWQ)\n",
35 | "- [LLM-AWQ](https://github.com/mit-han-lab/llm-awq) \n"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {
41 | "id": "-H2019RkoiM-"
42 | },
43 | "source": [
44 | "因为 LLM-AWQ 不支持 Nvidia T4 GPU(课程演示 GPU),所以我们使用 AutoAWQ 库来介绍和演示 AWQ 模型量化技术。"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {},
50 | "source": [
51 | "## 量化前模型测试文本生成任务"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 1,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "from transformers import pipeline\n",
61 | "\n",
62 | "model_path = \"facebook/opt-125m\"\n",
63 | "\n",
64 | "# 使用 GPU 加载原始的 OPT-125m 模型\n",
65 | "generator = pipeline('text-generation',\n",
66 | " model=model_path,\n",
67 | " device=0,\n",
68 | " do_sample=True,\n",
69 | " num_return_sequences=3)"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "metadata": {},
75 | "source": [
76 | "#### 实测GPU显存占用:加载 OPT-125m 模型后\n",
77 | "\n",
78 | "```shell\n",
79 | "Sun Dec 24 15:11:33 2023\n",
80 | "+---------------------------------------------------------------------------------------+\n",
81 | "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
82 | "|-----------------------------------------+----------------------+----------------------+\n",
83 | "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
84 | "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
85 | "| | | MIG M. |\n",
86 | "|=========================================+======================+======================|\n",
87 | "| 0 Tesla T4 Off | 00000000:00:0D.0 Off | 0 |\n",
88 | "| N/A 47C P0 26W / 70W | 635MiB / 15360MiB | 0% Default |\n",
89 | "| | | N/A |\n",
90 | "+-----------------------------------------+----------------------+----------------------+\n",
91 | "\n",
92 | "+---------------------------------------------------------------------------------------+\n",
93 | "| Processes: |\n",
94 | "| GPU GI CI PID Type Process name GPU Memory |\n",
95 | "| ID ID Usage |\n",
96 | "|=======================================================================================|\n",
97 | "```"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 2,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "data": {
107 | "text/plain": [
108 | "[{'generated_text': 'The woman worked as a tour guide, and was also a teacher in her own right. In her'},\n",
109 | " {'generated_text': 'The woman worked as a sales manager at a grocery store so all she needed was someone to sit next'},\n",
110 | " {'generated_text': 'The woman worked as a chef for several years before deciding to retire in 2010. She was also the'}]"
111 | ]
112 | },
113 | "execution_count": 2,
114 | "metadata": {},
115 | "output_type": "execute_result"
116 | }
117 | ],
118 | "source": [
119 | "generator(\"The woman worked as a\")"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 3,
125 | "metadata": {},
126 | "outputs": [
127 | {
128 | "data": {
129 | "text/plain": [
130 | "[{'generated_text': 'The man worked as a \"bait\" during a career that I think had an emphasis on fishing'},\n",
131 | " {'generated_text': 'The man worked as a construction worker in California for a couple years before he moved into real estate in'},\n",
132 | " {'generated_text': \"The man worked as a cashier, and he's probably never heard of the place where you're\"}]"
133 | ]
134 | },
135 | "execution_count": 3,
136 | "metadata": {},
137 | "output_type": "execute_result"
138 | }
139 | ],
140 | "source": [
141 | "generator(\"The man worked as a\")"
142 | ]
143 | },
144 | {
145 | "cell_type": "markdown",
146 | "metadata": {
147 | "id": "6dJJRQ2p7eLQ"
148 | },
149 | "source": [
150 | "## 使用 AutoAWQ 量化模型\n",
151 | "\n",
152 | "下面我们以 `facebook opt-125m` 模型为例,使用 `AutoAWQ` 库实现的 AWQ 算法实现模型量化。"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": 4,
158 | "metadata": {},
159 | "outputs": [
160 | {
161 | "data": {
162 | "application/vnd.jupyter.widget-view+json": {
163 | "model_id": "486f74cb2a7b4d11bcd5fd70ed8277e9",
164 | "version_major": 2,
165 | "version_minor": 0
166 | },
167 | "text/plain": [
168 | "Fetching 9 files: 0%| | 0/9 [00:00, ?it/s]"
169 | ]
170 | },
171 | "metadata": {},
172 | "output_type": "display_data"
173 | }
174 | ],
175 | "source": [
176 | "from awq import AutoAWQForCausalLM\n",
177 | "from transformers import AutoTokenizer\n",
178 | "\n",
179 | "\n",
180 | "quant_path = \"models/opt-125m-awq\"\n",
181 | "quant_config = {\"zero_point\": True, \"q_group_size\": 128, \"w_bit\": 4, \"version\": \"GEMM\"}\n",
182 | "\n",
183 | "# 加载模型\n",
184 | "model = AutoAWQForCausalLM.from_pretrained(model_path, device_map=\"cuda\")\n",
185 | "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 5,
191 | "metadata": {
192 | "id": "Qn_P_E5p7gAN"
193 | },
194 | "outputs": [
195 | {
196 | "name": "stderr",
197 | "output_type": "stream",
198 | "text": [
199 | "/root/miniconda3/lib/python3.11/site-packages/huggingface_hub/repocard.py:105: UserWarning: Repo card metadata block was not found. Setting CardData to empty.\n",
200 | " warnings.warn(\"Repo card metadata block was not found. Setting CardData to empty.\")\n",
201 | "AWQ: 100%|██████████| 12/12 [01:20<00:00, 6.71s/it]\n"
202 | ]
203 | }
204 | ],
205 | "source": [
206 | "# 量化模型\n",
207 | "model.quantize(tokenizer, quant_config=quant_config)"
208 | ]
209 | },
210 | {
211 | "cell_type": "markdown",
212 | "metadata": {},
213 | "source": [
214 | "#### 实测GPU显存使用:量化模型时峰值达到将近 4GB\n",
215 | "\n",
216 | "```shell\n",
217 | "Sun Dec 24 15:12:50 2023\n",
218 | "+---------------------------------------------------------------------------------------+\n",
219 | "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
220 | "|-----------------------------------------+----------------------+----------------------+\n",
221 | "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
222 | "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
223 | "| | | MIG M. |\n",
224 | "|=========================================+======================+======================|\n",
225 | "| 0 Tesla T4 Off | 00000000:00:0D.0 Off | 0 |\n",
226 | "| N/A 48C P0 32W / 70W | 3703MiB / 15360MiB | 2% Default |\n",
227 | "| | | N/A |\n",
228 | "+-----------------------------------------+----------------------+----------------------+\n",
229 | "\n",
230 | "+---------------------------------------------------------------------------------------+\n",
231 | "| Processes: |\n",
232 | "| GPU GI CI PID Type Process name GPU Memory |\n",
233 | "| ID ID Usage |\n",
234 | "|=======================================================================================|\n",
235 | "```"
236 | ]
237 | },
238 | {
239 | "cell_type": "code",
240 | "execution_count": 6,
241 | "metadata": {
242 | "id": "nVzKDBlP_6MV"
243 | },
244 | "outputs": [
245 | {
246 | "data": {
247 | "text/plain": [
248 | "{'zero_point': True, 'q_group_size': 128, 'w_bit': 4, 'version': 'GEMM'}"
249 | ]
250 | },
251 | "execution_count": 6,
252 | "metadata": {},
253 | "output_type": "execute_result"
254 | }
255 | ],
256 | "source": [
257 | "quant_config"
258 | ]
259 | },
260 | {
261 | "cell_type": "markdown",
262 | "metadata": {
263 | "id": "PuPLq9sa8EaN"
264 | },
265 | "source": [
266 | "#### Transformers 兼容性配置\n",
267 | "\n",
268 | "为了使`quant_config` 与 transformers 兼容,我们需要修改配置文件:`使用 Transformers.AwqConfig 来实例化量化模型配置`"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 7,
274 | "metadata": {
275 | "id": "KE8xjwlL8DnA"
276 | },
277 | "outputs": [],
278 | "source": [
279 | "from transformers import AwqConfig, AutoConfig\n",
280 | "\n",
281 | "# 修改配置文件以使其与transformers集成兼容\n",
282 | "quantization_config = AwqConfig(\n",
283 | " bits=quant_config[\"w_bit\"],\n",
284 | " group_size=quant_config[\"q_group_size\"],\n",
285 | " zero_point=quant_config[\"zero_point\"],\n",
286 | " version=quant_config[\"version\"].lower(),\n",
287 | ").to_dict()\n",
288 | "\n",
289 | "# 预训练的transformers模型存储在model属性中,我们需要传递一个字典\n",
290 | "model.model.config.quantization_config = quantization_config"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": 8,
296 | "metadata": {},
297 | "outputs": [
298 | {
299 | "data": {
300 | "text/plain": [
301 | "('models/opt-125m-awq/tokenizer_config.json',\n",
302 | " 'models/opt-125m-awq/special_tokens_map.json',\n",
303 | " 'models/opt-125m-awq/vocab.json',\n",
304 | " 'models/opt-125m-awq/merges.txt',\n",
305 | " 'models/opt-125m-awq/added_tokens.json',\n",
306 | " 'models/opt-125m-awq/tokenizer.json')"
307 | ]
308 | },
309 | "execution_count": 8,
310 | "metadata": {},
311 | "output_type": "execute_result"
312 | }
313 | ],
314 | "source": [
315 | "# 保存模型权重\n",
316 | "model.save_quantized(quant_path)\n",
317 | "tokenizer.save_pretrained(quant_path) # 保存分词器"
318 | ]
319 | },
320 | {
321 | "cell_type": "code",
322 | "execution_count": null,
323 | "metadata": {},
324 | "outputs": [],
325 | "source": []
326 | },
327 | {
328 | "cell_type": "markdown",
329 | "metadata": {},
330 | "source": [
331 | "### 使用 GPU 加载量化模型"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 9,
337 | "metadata": {},
338 | "outputs": [],
339 | "source": [
340 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
341 | "\n",
342 | "tokenizer = AutoTokenizer.from_pretrained(quant_path)\n",
343 | "model = AutoModelForCausalLM.from_pretrained(quant_path, device_map=\"cuda\").to(0)"
344 | ]
345 | },
346 | {
347 | "cell_type": "code",
348 | "execution_count": 10,
349 | "metadata": {},
350 | "outputs": [],
351 | "source": [
352 | "def generate_text(text):\n",
353 | " inputs = tokenizer(text, return_tensors=\"pt\").to(0)\n",
354 | "\n",
355 | " out = model.generate(**inputs, max_new_tokens=64)\n",
356 | " return tokenizer.decode(out[0], skip_special_tokens=True)\n"
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": 11,
362 | "metadata": {},
363 | "outputs": [
364 | {
365 | "name": "stdout",
366 | "output_type": "stream",
367 | "text": [
368 | "Merry Christmas! I'm glad to be the son of the son of the son of the son of the son of the son of the son of the son of be of the son of the son of the son of the son of the son of the son of the son of the son of the son of the son of the son of the son of the\n"
369 | ]
370 | }
371 | ],
372 | "source": [
373 | "result = generate_text(\"Merry Christmas! I'm glad to\")\n",
374 | "print(result)"
375 | ]
376 | },
377 | {
378 | "cell_type": "code",
379 | "execution_count": 12,
380 | "metadata": {
381 | "id": "Z0hAXYanCDW3"
382 | },
383 | "outputs": [
384 | {
385 | "name": "stdout",
386 | "output_type": "stream",
387 | "text": [
388 | "The woman worked as a teacher at the school told me this, she will only told me something this. She's \" told her that she's not a child has not a child he said that she * is a child * has not a child\n"
389 | ]
390 | }
391 | ],
392 | "source": [
393 | "result = generate_text(\"The woman worked as a\")\n",
394 | "print(result)"
395 | ]
396 | },
397 | {
398 | "cell_type": "markdown",
399 | "metadata": {},
400 | "source": [
401 | "## Homework:使用 AWQ 算法量化 Facebook OPT-2.7B 模型\n",
402 | "\n",
403 | "Facebook OPT 模型:https://huggingface.co/facebook?search_models=opt"
404 | ]
405 | },
406 | {
407 | "cell_type": "code",
408 | "execution_count": null,
409 | "metadata": {},
410 | "outputs": [],
411 | "source": []
412 | }
413 | ],
414 | "metadata": {
415 | "accelerator": "GPU",
416 | "colab": {
417 | "gpuType": "T4",
418 | "provenance": []
419 | },
420 | "kernelspec": {
421 | "display_name": "Python 3 (ipykernel)",
422 | "language": "python",
423 | "name": "python3"
424 | },
425 | "language_info": {
426 | "codemirror_mode": {
427 | "name": "ipython",
428 | "version": 3
429 | },
430 | "file_extension": ".py",
431 | "mimetype": "text/x-python",
432 | "name": "python",
433 | "nbconvert_exporter": "python",
434 | "pygments_lexer": "ipython3",
435 | "version": "3.11.5"
436 | }
437 | },
438 | "nbformat": 4,
439 | "nbformat_minor": 4
440 | }
441 |
--------------------------------------------------------------------------------
/quantization/AWQ_opt-2.7b.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "H_D9kG_efts3"
7 | },
8 | "source": [
9 | "# Transformers 模型量化技术:AWQ(OPT-2.7B)"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {
15 | "id": "WE9IhcVyktah"
16 | },
17 | "source": [
18 | ""
19 | ]
20 | },
21 | {
22 | "attachments": {},
23 | "cell_type": "markdown",
24 | "metadata": {
25 | "id": "Wwsg6nCwoThm"
26 | },
27 | "source": [
28 | "在2023年6月,Ji Lin等人发表了论文 [AWQ:Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/pdf/2306.00978.pdf)。\n",
29 | "\n",
30 | "这篇论文详细介绍了一种激活感知权重量化算法,可以用于压缩任何基于 Transformer 的语言模型,同时只有微小的性能下降。关于 AWQ 算法的详细介绍,见[MIT Han Song 教授分享](https://hanlab.mit.edu/projects/awq)。\n",
31 | "\n",
32 | "transformers 现在支持两个不同的 AWQ 开源实现库:\n",
33 | "\n",
34 | "- [AutoAWQ](https://github.com/casper-hansen/AutoAWQ)\n",
35 | "- [LLM-AWQ](https://github.com/mit-han-lab/llm-awq) \n"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {
41 | "id": "-H2019RkoiM-"
42 | },
43 | "source": [
44 | "因为 LLM-AWQ 不支持 Nvidia T4 GPU(课程演示 GPU),所以我们使用 AutoAWQ 库来介绍和演示 AWQ 模型量化技术。"
45 | ]
46 | },
47 | {
48 | "cell_type": "markdown",
49 | "metadata": {
50 | "id": "6dJJRQ2p7eLQ"
51 | },
52 | "source": [
53 | "## 使用 AutoAWQ 量化模型\n",
54 | "\n",
55 | "下面我们以 `facebook opt-2.7B` 模型为例,使用 `AutoAWQ` 库实现的 AWQ 算法实现模型量化。"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": 7,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "from awq import AutoAWQForCausalLM\n",
65 | "from transformers import AutoTokenizer\n",
66 | "\n",
67 | "model_name_or_path = \"facebook/opt-2.7b\"\n",
68 | "quant_model_dir = \"models/opt-2.7b-awq\"\n",
69 | "\n",
70 | "quant_config = {\n",
71 | " \"zero_point\": True,\n",
72 | " \"q_group_size\": 128,\n",
73 | " \"w_bit\": 4,\n",
74 | " \"version\": \"GEMM\"\n",
75 | "}"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 3,
81 | "metadata": {},
82 | "outputs": [
83 | {
84 | "data": {
85 | "application/vnd.jupyter.widget-view+json": {
86 | "model_id": "379a0bba9ee74953b5e1facf448da666",
87 | "version_major": 2,
88 | "version_minor": 0
89 | },
90 | "text/plain": [
91 | "Fetching 8 files: 0%| | 0/8 [00:00, ?it/s]"
92 | ]
93 | },
94 | "metadata": {},
95 | "output_type": "display_data"
96 | }
97 | ],
98 | "source": [
99 | "# 加载模型\n",
100 | "model = AutoAWQForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True)\n",
101 | "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)"
102 | ]
103 | },
104 | {
105 | "cell_type": "code",
106 | "execution_count": 4,
107 | "metadata": {
108 | "id": "Qn_P_E5p7gAN"
109 | },
110 | "outputs": [
111 | {
112 | "name": "stderr",
113 | "output_type": "stream",
114 | "text": [
115 | "/root/miniconda3/lib/python3.11/site-packages/huggingface_hub/repocard.py:105: UserWarning: Repo card metadata block was not found. Setting CardData to empty.\n",
116 | " warnings.warn(\"Repo card metadata block was not found. Setting CardData to empty.\")\n",
117 | "AWQ: 100%|██████████| 32/32 [16:38<00:00, 31.21s/it]\n"
118 | ]
119 | }
120 | ],
121 | "source": [
122 | "# 量化模型\n",
123 | "model.quantize(tokenizer, quant_config=quant_config)"
124 | ]
125 | },
126 | {
127 | "cell_type": "markdown",
128 | "metadata": {},
129 | "source": [
130 | "### 实测 AWQ 量化模型:GPU显存占用峰值超过10GB\n",
131 | "\n",
132 | "\n",
133 | "```shell\n",
134 | "Sun Dec 24 15:21:35 2023\n",
135 | "+---------------------------------------------------------------------------------------+\n",
136 | "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
137 | "|-----------------------------------------+----------------------+----------------------+\n",
138 | "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
139 | "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
140 | "| | | MIG M. |\n",
141 | "|=========================================+======================+======================|\n",
142 | "| 0 Tesla T4 Off | 00000000:00:0D.0 Off | 0 |\n",
143 | "| N/A 53C P0 71W / 70W | 7261MiB / 15360MiB | 97% Default |\n",
144 | "| | | N/A |\n",
145 | "+-----------------------------------------+----------------------+----------------------+\n",
146 | "\n",
147 | "+---------------------------------------------------------------------------------------+\n",
148 | "| Processes: |\n",
149 | "| GPU GI CI PID Type Process name GPU Memory |\n",
150 | "| ID ID Usage |\n",
151 | "|=======================================================================================|```"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 8,
157 | "metadata": {
158 | "id": "nVzKDBlP_6MV"
159 | },
160 | "outputs": [
161 | {
162 | "data": {
163 | "text/plain": [
164 | "{'zero_point': True, 'q_group_size': 128, 'w_bit': 4, 'version': 'GEMM'}"
165 | ]
166 | },
167 | "execution_count": 8,
168 | "metadata": {},
169 | "output_type": "execute_result"
170 | }
171 | ],
172 | "source": [
173 | "quant_config"
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {
179 | "id": "PuPLq9sa8EaN"
180 | },
181 | "source": [
182 | "#### Transformers 兼容性配置\n",
183 | "\n",
184 | "为了使`quant_config` 与 transformers 兼容,我们需要修改配置文件:`使用 Transformers.AwqConfig 来实例化量化模型配置`"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 9,
190 | "metadata": {
191 | "id": "KE8xjwlL8DnA"
192 | },
193 | "outputs": [],
194 | "source": [
195 | "from transformers import AwqConfig, AutoConfig\n",
196 | "\n",
197 | "# 修改配置文件以使其与transformers集成兼容\n",
198 | "quantization_config = AwqConfig(\n",
199 | " bits=quant_config[\"w_bit\"],\n",
200 | " group_size=quant_config[\"q_group_size\"],\n",
201 | " zero_point=quant_config[\"zero_point\"],\n",
202 | " version=quant_config[\"version\"].lower(),\n",
203 | ").to_dict()\n",
204 | "\n",
205 | "# 预训练的transformers模型存储在model属性中,我们需要传递一个字典\n",
206 | "model.model.config.quantization_config = quantization_config"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": 10,
212 | "metadata": {},
213 | "outputs": [
214 | {
215 | "data": {
216 | "text/plain": [
217 | "('models/opt-2.7b-awq/tokenizer_config.json',\n",
218 | " 'models/opt-2.7b-awq/special_tokens_map.json',\n",
219 | " 'models/opt-2.7b-awq/vocab.json',\n",
220 | " 'models/opt-2.7b-awq/merges.txt',\n",
221 | " 'models/opt-2.7b-awq/added_tokens.json',\n",
222 | " 'models/opt-2.7b-awq/tokenizer.json')"
223 | ]
224 | },
225 | "execution_count": 10,
226 | "metadata": {},
227 | "output_type": "execute_result"
228 | }
229 | ],
230 | "source": [
231 | "# 保存模型权重\n",
232 | "model.save_quantized(quant_model_dir)\n",
233 | "# 保存分词器\n",
234 | "tokenizer.save_pretrained(quant_model_dir) "
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": 11,
240 | "metadata": {},
241 | "outputs": [
242 | {
243 | "data": {
244 | "text/plain": [
245 | "OptAWQForCausalLM(\n",
246 | " (model): OPTForCausalLM(\n",
247 | " (model): OPTModel(\n",
248 | " (decoder): OPTDecoder(\n",
249 | " (embed_tokens): Embedding(50272, 2560, padding_idx=1)\n",
250 | " (embed_positions): OPTLearnedPositionalEmbedding(2050, 2560)\n",
251 | " (final_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n",
252 | " (layers): ModuleList(\n",
253 | " (0-31): 32 x OPTDecoderLayer(\n",
254 | " (self_attn): OPTAttention(\n",
255 | " (k_proj): WQLinear_GEMM(in_features=2560, out_features=2560, bias=True, w_bit=4, group_size=128)\n",
256 | " (v_proj): WQLinear_GEMM(in_features=2560, out_features=2560, bias=True, w_bit=4, group_size=128)\n",
257 | " (q_proj): WQLinear_GEMM(in_features=2560, out_features=2560, bias=True, w_bit=4, group_size=128)\n",
258 | " (out_proj): WQLinear_GEMM(in_features=2560, out_features=2560, bias=True, w_bit=4, group_size=128)\n",
259 | " )\n",
260 | " (activation_fn): ReLU()\n",
261 | " (self_attn_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n",
262 | " (fc1): WQLinear_GEMM(in_features=2560, out_features=10240, bias=True, w_bit=4, group_size=128)\n",
263 | " (fc2): WQLinear_GEMM(in_features=10240, out_features=2560, bias=True, w_bit=4, group_size=128)\n",
264 | " (final_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n",
265 | " )\n",
266 | " )\n",
267 | " )\n",
268 | " )\n",
269 | " (lm_head): Linear(in_features=2560, out_features=50272, bias=False)\n",
270 | " )\n",
271 | ")"
272 | ]
273 | },
274 | "execution_count": 11,
275 | "metadata": {},
276 | "output_type": "execute_result"
277 | }
278 | ],
279 | "source": [
280 | "model.eval()"
281 | ]
282 | },
283 | {
284 | "cell_type": "markdown",
285 | "metadata": {},
286 | "source": [
287 | "### 使用 GPU 加载量化模型"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": 12,
293 | "metadata": {},
294 | "outputs": [],
295 | "source": [
296 | "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
297 | "\n",
298 | "tokenizer = AutoTokenizer.from_pretrained(quant_model_dir)\n",
299 | "model = AutoModelForCausalLM.from_pretrained(quant_model_dir, device_map=\"cuda\").to(0)"
300 | ]
301 | },
302 | {
303 | "cell_type": "code",
304 | "execution_count": 13,
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "def generate_text(text):\n",
309 | " inputs = tokenizer(text, return_tensors=\"pt\").to(0)\n",
310 | "\n",
311 | " out = model.generate(**inputs, max_new_tokens=64)\n",
312 | " return tokenizer.decode(out[0], skip_special_tokens=True)\n"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": 14,
318 | "metadata": {},
319 | "outputs": [
320 | {
321 | "name": "stdout",
322 | "output_type": "stream",
323 | "text": [
324 | "Merry Christmas! I'm glad to. M-M-M-M-M-M-M-M-M-M-M- M-M-M-M-1-\n",
325 | "\n",
326 | "M-M-M-M-M-M-M-M-M-M-M-M-M-M-M\n"
327 | ]
328 | }
329 | ],
330 | "source": [
331 | "result = generate_text(\"Merry Christmas! I'm glad to\")\n",
332 | "print(result)"
333 | ]
334 | },
335 | {
336 | "cell_type": "code",
337 | "execution_count": 15,
338 | "metadata": {
339 | "id": "Z0hAXYanCDW3"
340 | },
341 | "outputs": [
342 | {
343 | "name": "stdout",
344 | "output_type": "stream",
345 | "text": [
346 | "The woman worked as a the the woman.\n",
347 | "The the man\n",
348 | "the the the the the woman\n"
349 | ]
350 | }
351 | ],
352 | "source": [
353 | "result = generate_text(\"The woman worked as a\")\n",
354 | "print(result)"
355 | ]
356 | },
357 | {
358 | "cell_type": "code",
359 | "execution_count": null,
360 | "metadata": {},
361 | "outputs": [],
362 | "source": []
363 | }
364 | ],
365 | "metadata": {
366 | "accelerator": "GPU",
367 | "colab": {
368 | "gpuType": "T4",
369 | "provenance": []
370 | },
371 | "kernelspec": {
372 | "display_name": "Python 3 (ipykernel)",
373 | "language": "python",
374 | "name": "python3"
375 | },
376 | "language_info": {
377 | "codemirror_mode": {
378 | "name": "ipython",
379 | "version": 3
380 | },
381 | "file_extension": ".py",
382 | "mimetype": "text/x-python",
383 | "name": "python",
384 | "nbconvert_exporter": "python",
385 | "pygments_lexer": "ipython3",
386 | "version": "3.11.5"
387 | }
388 | },
389 | "nbformat": 4,
390 | "nbformat_minor": 4
391 | }
392 |
--------------------------------------------------------------------------------
/quantization/bits_and_bytes.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "31c9be60-bb96-4af9-a5ee-1eb47b201d45",
6 | "metadata": {},
7 | "source": [
8 | "# Transformers 量化技术 BitsAndBytes\n",
9 | "\n",
10 | "\n",
11 | "\n",
12 | "`bitsandbytes`是将模型量化为8位和4位的最简单选择。 \n",
13 | "\n",
14 | "- 8位量化将fp16中的异常值与int8中的非异常值相乘,将非异常值转换回fp16,然后将它们相加以返回fp16中的权重。这减少了异常值对模型性能产生的降级效果。\n",
15 | "- 4位量化进一步压缩了模型,并且通常与QLoRA一起用于微调量化LLM(低精度语言模型)。\n",
16 | "\n",
17 | "(`异常值`是指大于某个阈值的隐藏状态值,这些值是以fp16进行计算的。虽然这些值通常服从正态分布([-3.5, 3.5]),但对于大型模型来说,该分布可能会有很大差异([-60, 6]或[6, 60])。8位量化适用于约为5左右的数值,但超过此范围后将导致显著性能损失。一个好的默认阈值是6,但对于不稳定的模型(小型模型或微调)可能需要更低的阈值。)\n",
18 | "\n",
19 | "## 在 Transformers 中使用参数量化\n",
20 | "\n",
21 | "使用 Transformers 库的 `model.from_pretrained()`方法中的`load_in_8bit`或`load_in_4bit`参数,便可以对模型进行量化。只要模型支持使用Accelerate加载并包含torch.nn.Linear层,这几乎适用于任何模态的任何模型。"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "id": "c2385671-3e67-4fcb-9243-d4b1affea031",
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "from transformers import AutoModelForCausalLM\n",
32 | "\n",
33 | "model_id = \"facebook/opt-2.7b\"\n",
34 | "\n",
35 | "model_4bit = AutoModelForCausalLM.from_pretrained(model_id,\n",
36 | " device_map=\"auto\",\n",
37 | " load_in_4bit=True)"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 3,
43 | "id": "f4731ac5-fe26-471e-ad17-eb2ba42cb596",
44 | "metadata": {},
45 | "outputs": [
46 | {
47 | "data": {
48 | "text/plain": [
49 | "OPTForCausalLM(\n",
50 | " (model): OPTModel(\n",
51 | " (decoder): OPTDecoder(\n",
52 | " (embed_tokens): Embedding(50272, 2560, padding_idx=1)\n",
53 | " (embed_positions): OPTLearnedPositionalEmbedding(2050, 2560)\n",
54 | " (final_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n",
55 | " (layers): ModuleList(\n",
56 | " (0-31): 32 x OPTDecoderLayer(\n",
57 | " (self_attn): OPTAttention(\n",
58 | " (k_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)\n",
59 | " (v_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)\n",
60 | " (q_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)\n",
61 | " (out_proj): Linear4bit(in_features=2560, out_features=2560, bias=True)\n",
62 | " )\n",
63 | " (activation_fn): ReLU()\n",
64 | " (self_attn_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n",
65 | " (fc1): Linear4bit(in_features=2560, out_features=10240, bias=True)\n",
66 | " (fc2): Linear4bit(in_features=10240, out_features=2560, bias=True)\n",
67 | " (final_layer_norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)\n",
68 | " )\n",
69 | " )\n",
70 | " )\n",
71 | " )\n",
72 | " (lm_head): Linear(in_features=2560, out_features=50272, bias=False)\n",
73 | ")"
74 | ]
75 | },
76 | "execution_count": 3,
77 | "metadata": {},
78 | "output_type": "execute_result"
79 | }
80 | ],
81 | "source": [
82 | "model_4bit"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "id": "a198b216-b113-4851-a02f-f57be038e1ac",
88 | "metadata": {},
89 | "source": [
90 | "### 实测GPU显存占用:Int4 量化精度\n",
91 | "\n",
92 | "```shell\n",
93 | "Sun Dec 24 18:04:14 2023\n",
94 | "+---------------------------------------------------------------------------------------+\n",
95 | "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
96 | "|-----------------------------------------+----------------------+----------------------+\n",
97 | "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
98 | "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
99 | "| | | MIG M. |\n",
100 | "|=========================================+======================+======================|\n",
101 | "| 0 Tesla T4 Off | 00000000:00:0D.0 Off | 0 |\n",
102 | "| N/A 42C P0 26W / 70W | 1779MiB / 15360MiB | 0% Default |\n",
103 | "| | | N/A |\n",
104 | "+-----------------------------------------+----------------------+----------------------+\n",
105 | "\n",
106 | "+---------------------------------------------------------------------------------------+\n",
107 | "| Processes: |\n",
108 | "| GPU GI CI PID Type Process name GPU Memory |\n",
109 | "| ID ID Usage |\n",
110 | "|=======================================================================================|\n",
111 | "```"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 4,
117 | "id": "1d504b78-9ea4-4100-b614-03dc3bbcb65b",
118 | "metadata": {},
119 | "outputs": [
120 | {
121 | "name": "stdout",
122 | "output_type": "stream",
123 | "text": [
124 | "1457.52MiB\n"
125 | ]
126 | }
127 | ],
128 | "source": [
129 | "# 获取当前模型占用的 GPU显存(差值为预留给 PyTorch 的显存)\n",
130 | "memory_footprint_bytes = model_4bit.get_memory_footprint()\n",
131 | "memory_footprint_mib = memory_footprint_bytes / (1024 ** 2) # 转换为 MiB\n",
132 | "\n",
133 | "print(f\"{memory_footprint_mib:.2f}MiB\")"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 5,
139 | "id": "8af2edef-9142-443b-b55c-b57872a1fc1d",
140 | "metadata": {},
141 | "outputs": [
142 | {
143 | "name": "stderr",
144 | "output_type": "stream",
145 | "text": [
146 | "/root/miniconda3/lib/python3.11/site-packages/bitsandbytes/nn/modules.py:226: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.\n",
147 | " warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')\n"
148 | ]
149 | },
150 | {
151 | "name": "stdout",
152 | "output_type": "stream",
153 | "text": [
154 | "Merry Christmas! I'm glad to see you're still around.\n",
155 | "I'm still around, just not posting as much. I'm still here, just not posting as much. I'm still here, just not posting as much. I'm still here, just not posting as much. I'm still here, just not posting as much. I'm\n"
156 | ]
157 | }
158 | ],
159 | "source": [
160 | "from transformers import AutoTokenizer\n",
161 | "\n",
162 | "tokenizer = AutoTokenizer.from_pretrained(model_id\n",
163 | " )\n",
164 | "text = \"Merry Christmas! I'm glad to\"\n",
165 | "inputs = tokenizer(text, return_tensors=\"pt\").to(0)\n",
166 | "\n",
167 | "out = model_4bit.generate(**inputs, max_new_tokens=64)\n",
168 | "print(tokenizer.decode(out[0], skip_special_tokens=True))"
169 | ]
170 | },
171 | {
172 | "cell_type": "markdown",
173 | "id": "21f299ea-77f6-45cc-82c9-87c96addda06",
174 | "metadata": {},
175 | "source": [
176 | "### 使用 NF4 精度加载模型"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 6,
182 | "id": "00249404-c60b-47a5-bcb9-a8a4b4b6266f",
183 | "metadata": {},
184 | "outputs": [],
185 | "source": [
186 | "from transformers import BitsAndBytesConfig\n",
187 | "\n",
188 | "nf4_config = BitsAndBytesConfig(\n",
189 | " load_in_4bit=True,\n",
190 | " bnb_4bit_quant_type=\"nf4\",\n",
191 | ")\n",
192 | "\n",
193 | "model_nf4 = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 7,
199 | "id": "ab499752-4c53-4ab4-a6a1-1fdf88cbbd0e",
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "name": "stdout",
204 | "output_type": "stream",
205 | "text": [
206 | "1457.52MiB\n"
207 | ]
208 | }
209 | ],
210 | "source": [
211 | "# 获取当前模型占用的 GPU显存(差值为预留给 PyTorch 的显存)\n",
212 | "memory_footprint_bytes = model_nf4.get_memory_footprint()\n",
213 | "memory_footprint_mib = memory_footprint_bytes / (1024 ** 2) # 转换为 MiB\n",
214 | "\n",
215 | "print(f\"{memory_footprint_mib:.2f}MiB\")"
216 | ]
217 | },
218 | {
219 | "cell_type": "markdown",
220 | "id": "e7d335c9-9f13-4834-8008-af20a9f5ca56",
221 | "metadata": {},
222 | "source": [
223 | "### 使用双量化加载模型"
224 | ]
225 | },
226 | {
227 | "cell_type": "code",
228 | "execution_count": 8,
229 | "id": "f6bfa211-9ad8-4c7b-93a8-37cccaad975a",
230 | "metadata": {},
231 | "outputs": [],
232 | "source": [
233 | "double_quant_config = BitsAndBytesConfig(\n",
234 | " load_in_4bit=True,\n",
235 | " bnb_4bit_use_double_quant=True,\n",
236 | ")\n",
237 | "\n",
238 | "model_double_quant = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=double_quant_config)"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 9,
244 | "id": "dbb3913a-a4aa-4d65-8901-8bcf546f1e08",
245 | "metadata": {},
246 | "outputs": [
247 | {
248 | "name": "stdout",
249 | "output_type": "stream",
250 | "text": [
251 | "1457.52MiB\n"
252 | ]
253 | }
254 | ],
255 | "source": [
256 | "# 获取当前模型占用的 GPU显存(差值为预留给 PyTorch 的显存)\n",
257 | "memory_footprint_bytes = model_double_quant.get_memory_footprint()\n",
258 | "memory_footprint_mib = memory_footprint_bytes / (1024 ** 2) # 转换为 MiB\n",
259 | "\n",
260 | "print(f\"{memory_footprint_mib:.2f}MiB\")"
261 | ]
262 | },
263 | {
264 | "cell_type": "markdown",
265 | "id": "d8153e9d-a080-47df-af83-3f1582b2b367",
266 | "metadata": {},
267 | "source": [
268 | "### 使用 QLoRA 所有量化技术加载模型"
269 | ]
270 | },
271 | {
272 | "cell_type": "code",
273 | "execution_count": 10,
274 | "id": "a4bd4f3a-a7f9-4545-b6a9-732fd6f91b42",
275 | "metadata": {},
276 | "outputs": [],
277 | "source": [
278 | "import torch\n",
279 | "\n",
280 | "qlora_config = BitsAndBytesConfig(\n",
281 | " load_in_4bit=True,\n",
282 | " bnb_4bit_use_double_quant=True,\n",
283 | " bnb_4bit_quant_type=\"nf4\",\n",
284 | " bnb_4bit_compute_dtype=torch.bfloat16\n",
285 | ")\n",
286 | "\n",
287 | "model_qlora = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=qlora_config)"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": 11,
293 | "id": "da8edf77-03cc-4303-a3c0-1b088e5ec958",
294 | "metadata": {},
295 | "outputs": [
296 | {
297 | "name": "stdout",
298 | "output_type": "stream",
299 | "text": [
300 | "1457.52MiB\n"
301 | ]
302 | }
303 | ],
304 | "source": [
305 | "# 获取当前模型占用的 GPU显存(差值为预留给 PyTorch 的显存)\n",
306 | "memory_footprint_bytes = model_qlora.get_memory_footprint()\n",
307 | "memory_footprint_mib = memory_footprint_bytes / (1024 ** 2) # 转换为 MiB\n",
308 | "\n",
309 | "print(f\"{memory_footprint_mib:.2f}MiB\")"
310 | ]
311 | },
312 | {
313 | "cell_type": "code",
314 | "execution_count": null,
315 | "id": "1bf04637-85f6-4ef1-a1bd-81448cd9325c",
316 | "metadata": {},
317 | "outputs": [],
318 | "source": []
319 | }
320 | ],
321 | "metadata": {
322 | "kernelspec": {
323 | "display_name": "Python 3 (ipykernel)",
324 | "language": "python",
325 | "name": "python3"
326 | },
327 | "language_info": {
328 | "codemirror_mode": {
329 | "name": "ipython",
330 | "version": 3
331 | },
332 | "file_extension": ".py",
333 | "mimetype": "text/x-python",
334 | "name": "python",
335 | "nbconvert_exporter": "python",
336 | "pygments_lexer": "ipython3",
337 | "version": "3.11.5"
338 | }
339 | },
340 | "nbformat": 4,
341 | "nbformat_minor": 5
342 | }
343 |
--------------------------------------------------------------------------------
/quantization/docs/images/qlora.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/quantization/docs/images/qlora.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.3.1
2 | transformers==4.37.2
3 | ffmpeg==1.4
4 | ffmpeg-python==0.2.0
5 | timm==0.9.12
6 | datasets==2.16.1
7 | evaluate==0.4.1
8 | scikit-learn==1.3.2
9 | pandas==2.1.1
10 | peft==0.7.1
11 | accelerate==0.26.1
12 | autoawq==0.2.2
13 | optimum==1.17.0
14 | auto-gptq==0.6.0
15 | bitsandbytes==0.41.3.post2
16 | jiwer==3.0.3
17 | soundfile==0.12.1
18 | librosa==0.10.1
19 | gradio==4.13.0
20 | trl==0.8.1
21 | openai==1.30.1
22 | langchain==0.2.0
23 | langchain-openai==0.1.7
24 | langchain-core==0.2.1
25 | langchain-community==0.2.0
--------------------------------------------------------------------------------
/transformers/data/audio/mlk.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/data/audio/mlk.flac
--------------------------------------------------------------------------------
/transformers/data/image/cat-chonk.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/data/image/cat-chonk.jpeg
--------------------------------------------------------------------------------
/transformers/data/image/cat_dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/data/image/cat_dog.jpg
--------------------------------------------------------------------------------
/transformers/data/image/panda.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/data/image/panda.jpg
--------------------------------------------------------------------------------
/transformers/docs/images/bert-base-chinese.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/docs/images/bert-base-chinese.png
--------------------------------------------------------------------------------
/transformers/docs/images/bert.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/docs/images/bert.png
--------------------------------------------------------------------------------
/transformers/docs/images/bert_pretrain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/docs/images/bert_pretrain.png
--------------------------------------------------------------------------------
/transformers/docs/images/full_nlp_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/docs/images/full_nlp_pipeline.png
--------------------------------------------------------------------------------
/transformers/docs/images/gpt2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/docs/images/gpt2.png
--------------------------------------------------------------------------------
/transformers/docs/images/pipeline_advanced.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/docs/images/pipeline_advanced.png
--------------------------------------------------------------------------------
/transformers/docs/images/pipeline_func.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/docs/images/pipeline_func.png
--------------------------------------------------------------------------------
/transformers/docs/images/question_answering.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DjangoPeng/LLM-quickstart/bf3e50d82a104c2b5bfaea3f3d476e8b8193ca2c/transformers/docs/images/question_answering.png
--------------------------------------------------------------------------------