├── .dockerignore
├── .github
└── workflows
│ └── Build.yml
├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── Dockerfile
├── LICENSE
├── README.md
├── docker-compose.yaml
├── docs
├── benchmark.md
├── faq.md
├── fastllm_pytools.md
└── llama_cookbook.md
├── example
├── Android
│ └── LLMAssistant
│ │ ├── .gitignore
│ │ ├── .idea
│ │ ├── .gitignore
│ │ ├── .name
│ │ ├── compiler.xml
│ │ ├── dbnavigator.xml
│ │ ├── deploymentTargetDropDown.xml
│ │ ├── gradle.xml
│ │ ├── misc.xml
│ │ └── vcs.xml
│ │ ├── app
│ │ ├── .gitignore
│ │ ├── build.gradle
│ │ ├── libs
│ │ │ ├── arm64-v8a
│ │ │ │ └── libassistant.so
│ │ │ └── armeabi-v7a
│ │ │ │ └── libassistant.so
│ │ ├── proguard-rules.pro
│ │ ├── release
│ │ │ ├── app-arm64-v8a-release-unsigned.apk
│ │ │ ├── app-armeabi-v7a-release-unsigned.apk
│ │ │ ├── app-universal-release-unsigned.apk
│ │ │ └── app-x86-release-unsigned.apk
│ │ └── src
│ │ │ ├── androidTest
│ │ │ └── java
│ │ │ │ └── com
│ │ │ │ └── doujiao
│ │ │ │ └── xiaozhihuiassistant
│ │ │ │ └── ExampleInstrumentedTest.java
│ │ │ ├── main
│ │ │ ├── AndroidManifest.xml
│ │ │ ├── cpp
│ │ │ │ ├── CMakeLists.txt
│ │ │ │ ├── LLMChat.cpp
│ │ │ │ ├── LLMChat.h
│ │ │ │ ├── main.cpp
│ │ │ │ └── native-lib.cpp
│ │ │ ├── java
│ │ │ │ └── com
│ │ │ │ │ └── doujiao
│ │ │ │ │ ├── core
│ │ │ │ │ └── AssistantCore.java
│ │ │ │ │ └── xiaozhihuiassistant
│ │ │ │ │ ├── ChatMessage.java
│ │ │ │ │ ├── MainActivity.java
│ │ │ │ │ ├── adapter
│ │ │ │ │ ├── BaseViewHolder.java
│ │ │ │ │ └── MyAdapter.java
│ │ │ │ │ ├── utils
│ │ │ │ │ ├── PrefUtil.java
│ │ │ │ │ ├── StatusBarUtils.java
│ │ │ │ │ └── UriUtils.java
│ │ │ │ │ └── widget
│ │ │ │ │ ├── ChatPromptViewManager.java
│ │ │ │ │ ├── Location.java
│ │ │ │ │ ├── PromptView.java
│ │ │ │ │ ├── PromptViewHelper.java
│ │ │ │ │ └── location
│ │ │ │ │ ├── BottomCenterLocation.java
│ │ │ │ │ ├── ICalculateLocation.java
│ │ │ │ │ ├── TopCenterLocation.java
│ │ │ │ │ ├── TopLeftLocation.java
│ │ │ │ │ └── TopRightLocation.java
│ │ │ └── res
│ │ │ │ ├── drawable-v24
│ │ │ │ └── ic_launcher_foreground.xml
│ │ │ │ ├── drawable
│ │ │ │ ├── btnbg.xml
│ │ │ │ ├── editbg.xml
│ │ │ │ └── ic_launcher_background.xml
│ │ │ │ ├── layout
│ │ │ │ ├── activity_item_left.xml
│ │ │ │ ├── activity_item_right.xml
│ │ │ │ └── activity_main.xml
│ │ │ │ ├── mipmap-anydpi-v26
│ │ │ │ ├── ic_launcher.xml
│ │ │ │ └── ic_launcher_round.xml
│ │ │ │ ├── mipmap-hdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ │ ├── mipmap-mdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ │ ├── mipmap-xhdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ │ ├── mipmap-xxhdpi
│ │ │ │ ├── glm.png
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ ├── ic_launcher_round.webp
│ │ │ │ └── me.png
│ │ │ │ ├── mipmap-xxxhdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ │ ├── values-night
│ │ │ │ └── themes.xml
│ │ │ │ └── values
│ │ │ │ ├── colors.xml
│ │ │ │ ├── strings.xml
│ │ │ │ └── themes.xml
│ │ │ └── test
│ │ │ └── java
│ │ │ └── com
│ │ │ └── doujiao
│ │ │ └── xiaozhihuiassistant
│ │ │ └── ExampleUnitTest.java
│ │ ├── build.gradle
│ │ ├── gradle.properties
│ │ ├── gradle
│ │ └── wrapper
│ │ │ ├── gradle-wrapper.jar
│ │ │ └── gradle-wrapper.properties
│ │ ├── gradlew
│ │ ├── gradlew.bat
│ │ └── settings.gradle
├── Qui
│ ├── FastLLM.cpp
│ ├── bin
│ │ ├── Qt5Core.dll
│ │ ├── Qt5Gui.dll
│ │ ├── Qt5Widgets.dll
│ │ ├── Qui.exe
│ │ ├── fastllm_cpu.exe
│ │ ├── fastllm_cuda.exe
│ │ ├── path.txt
│ │ ├── platforms
│ │ │ └── qwindows.dll
│ │ ├── qui_cn.qm
│ │ └── styles
│ │ │ └── qwindowsvistastyle.dll
│ └── src
│ │ ├── Qui.cpp
│ │ ├── Qui.h
│ │ ├── Qui.pro
│ │ ├── Qui.ui
│ │ ├── main.cpp
│ │ └── qui_cn.ts
├── README.md
├── Win32Demo
│ ├── StringUtils.h
│ ├── Win32Demo.cpp
│ ├── Win32Demo.sln
│ ├── Win32Demo.vcxproj
│ ├── bin
│ │ └── web
│ │ │ ├── css
│ │ │ ├── github-markdown-light.min.css
│ │ │ ├── github.min.css
│ │ │ ├── katex.min.css
│ │ │ └── texmath.css
│ │ │ ├── index.html
│ │ │ └── js
│ │ │ ├── highlight.min.js
│ │ │ ├── katex.min.js
│ │ │ ├── markdown-it-link-attributes.min.js
│ │ │ ├── markdown-it.min.js
│ │ │ └── texmath.js
│ ├── fastllm-gpu.vcxproj
│ ├── fastllm-gpu.vcxproj.filters
│ ├── fastllm.vcxproj
│ ├── fastllm.vcxproj.filters
│ └── httplib.h
├── apiserver
│ ├── apiserver.cpp
│ ├── json11.cpp
│ └── json11.hpp
├── benchmark
│ ├── benchmark.cpp
│ └── prompts
│ │ ├── beijing.txt
│ │ └── hello.txt
└── webui
│ ├── httplib.h
│ ├── web
│ ├── css
│ │ ├── github-markdown-light.min.css
│ │ ├── github.min.css
│ │ ├── katex.min.css
│ │ └── texmath.css
│ ├── index.html
│ └── js
│ │ ├── highlight.min.js
│ │ ├── katex.min.js
│ │ ├── markdown-it-link-attributes.min.js
│ │ ├── markdown-it.min.js
│ │ └── texmath.js
│ └── webui.cpp
├── include
├── device.h
├── devices
│ ├── cpu
│ │ ├── cpudevice.h
│ │ └── cputhreadpool.h
│ └── cuda
│ │ ├── cudadevice.h
│ │ └── fastllm-cuda.cuh
├── executor.h
├── fastllm.h
├── model.h
├── models
│ ├── basellm.h
│ ├── chatglm.h
│ ├── factoryllm.h
│ ├── glm.h
│ ├── internlm2.h
│ ├── llama.h
│ ├── minicpm.h
│ ├── moss.h
│ └── qwen.h
└── utils
│ ├── armMath.h
│ └── utils.h
├── main.cpp
├── pyfastllm
├── README.md
├── examples
│ ├── cli_low_level.py
│ ├── cli_simple.py
│ ├── convert_model.py
│ ├── test_chatglm2.py
│ ├── test_chatglm2_cpp.py
│ ├── test_chatglm2_func.py
│ ├── test_ops.py
│ ├── web_api.py
│ └── web_api_client.py
├── fastllm
│ ├── __init__.py
│ ├── convert.py
│ ├── functions
│ │ ├── __init__.py
│ │ ├── custom_ops.py
│ │ ├── fastllm_ops.py
│ │ ├── numpy_ops.py
│ │ └── util.py
│ ├── hub
│ │ ├── __init__.py
│ │ └── chatglm2.py
│ ├── models.py
│ ├── nn
│ │ ├── __init__.py
│ │ ├── base_module.py
│ │ └── modules.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── converter.py
│ │ ├── quantizer.py
│ │ └── writer.py
├── install.sh
└── setup.py
├── src
├── device.cpp
├── devices
│ ├── cpu
│ │ ├── cpudevice.cpp
│ │ └── cpudevicebatch.cpp
│ └── cuda
│ │ ├── cudadevice.cpp
│ │ ├── cudadevicebatch.cpp
│ │ └── fastllm-cuda.cu
├── executor.cpp
├── fastllm.cpp
├── model.cpp
├── models
│ ├── basellm.cpp
│ ├── chatglm.cpp
│ ├── glm.cpp
│ ├── internlm2.cpp
│ ├── llama.cpp
│ ├── minicpm.cpp
│ ├── moss.cpp
│ └── qwen.cpp
└── pybinding.cpp
├── test
├── cmmlu
│ ├── README.md
│ ├── baichuan.py
│ ├── categories.py
│ ├── chatglm.py
│ ├── eval.py
│ └── qwen.py
└── ops
│ └── cppOps.cpp
└── tools
├── fastllm_pytools
├── __init__.py
├── hf_model.py
├── llm.py
└── torch2flm.py
├── scripts
├── alpaca2flm.py
├── baichuan2_2flm.py
├── baichuan2flm.py
├── chatglm_export.py
├── cli_demo.py
├── glm_export.py
├── llamalike2flm.py
├── minicpm2flm.py
├── moss_export.py
├── qwen2flm.py
├── setup.py
└── web_demo.py
└── src
├── pytools.cpp
└── quant.cpp
/.dockerignore:
--------------------------------------------------------------------------------
1 | ./models
2 | ./build/
--------------------------------------------------------------------------------
/.github/workflows/Build.yml:
--------------------------------------------------------------------------------
1 | name: Action Build
2 | on: [push]
3 |
4 | jobs:
5 | build:
6 | runs-on: ubuntu-latest
7 |
8 | steps:
9 | - name: Set up JDK
10 | uses: actions/setup-java@v1
11 | with:
12 | java-version: '11'
13 |
14 | - name: Checkout code
15 | uses: actions/checkout@v2
16 |
17 | - name: Build with arm64-v8a
18 | run: |
19 | wget https://dl.google.com/android/repository/android-ndk-r21d-linux-x86_64.zip
20 | unzip android-ndk-r21d-linux-x86_64.zip
21 | export NDK=$GITHUB_WORKSPACE/android-ndk-r21d
22 | mkdir build-android
23 | cd build-android
24 | #ls ${NDK}/build/cmake/android.toolchain.cmake
25 | cmake -DCMAKE_MAKE_PROGRAM=/usr/bin/make -DCMAKE_CXX_COMPILER=/usr/bin/g++ -DCMAKE_TOOLCHAIN_FILE=${NDK}/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_CXX_FLAGS=-march=armv8.2a+dotprod ..
26 | make -j -B
27 | cp main fastllm-main-android
28 |
29 | - name: Build with x86
30 | run: |
31 | mkdir build-x86
32 | cd build-x86
33 | cmake .. -DUSE_CUDA=OFF
34 | make -j
35 | cp main fastllm-main-x86_64
36 |
37 | - name: Export and Upload Artifact
38 | uses: actions/upload-artifact@v2
39 | with:
40 | name: Output
41 | path: |
42 | build-android/fastllm-main-android
43 | build-x86/fastllm-main-x86_64
44 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.log
2 | *.pyc
3 | token
4 | /cmake-build-debug/
5 | /build-tfacc/
6 | /build-android/
7 | /build-py/
8 | /build/
9 | /pyfastllm/build/
10 | /pyfastllm/dist/
11 | /.idea/
12 | /.vscode/
13 | /example/Win32Demo/bin/*.*
14 | /example/Win32Demo/Win32
15 | /example/Win32Demo/x64
16 | /example/Win32Demo/*.filters
17 | /example/Win32Demo/*.user
18 | /example/Win32Demo/.vs
19 | /example/Android/LLMAssistant/*.iml
20 | /example/Android/LLMAssistant/.gradle
21 | /example/Android/LLMAssistant/local.properties
22 | /example/Android/LLMAssistant/.idea/caches
23 | /example/Android/LLMAssistant/.idea/libraries
24 | /example/Android/LLMAssistant/.idea/modules.xml
25 | /example/Android/LLMAssistant/.idea/workspace.xml
26 | /example/Android/LLMAssistant/.idea/navEditor.xml
27 | /example/Android/LLMAssistant/.idea/assetWizardSettings.xml
28 | /example/Android/LLMAssistant/.DS_Store
29 | /example/Android/LLMAssistant/build
30 | /example/Android/LLMAssistant/captures
31 | /example/Android/LLMAssistant/.externalNativeBuild
32 | /example/Android/LLMAssistant/.cxx
33 | /example/Android/LLMAssistant/local.properties
34 | /test/cmmlu/results/
35 | /models/
36 | /localtest/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/pybind11"]
2 | path = third_party/pybind11
3 | url = https://github.com/pybind/pybind11.git
4 | branch = v2.10.5
5 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker/dockerfile:1-labs
2 | FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
3 |
4 | # Update Apt repositories
5 | RUN apt-get update
6 |
7 | # Install and configure Python
8 | RUN apt-get -y --no-install-recommends install wget build-essential python3.10 python3-pip
9 | RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
10 | RUN pip install setuptools streamlit-chat
11 |
12 | ENV WORKDIR /fastllm
13 |
14 | # Install cmake
15 | RUN wget -c https://cmake.org/files/LatestRelease/cmake-3.28.3-linux-x86_64.sh && bash ./cmake-3.28.3-linux-x86_64.sh --skip-license --prefix=/usr/
16 |
17 | WORKDIR $WORKDIR
18 | ADD . $WORKDIR/
19 |
20 | RUN mkdir $WORKDIR/build && cd build && cmake .. -DUSE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native && make -j && cd tools && python setup.py install
21 |
22 | CMD /fastllm/build/webui -p /models/chatglm2-6b-int8.flm
23 |
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 | services:
3 | fastllm:
4 | build:
5 | context: .
6 | args:
7 | DOCKER_BUILDKIT: 0
8 | # privileged: true
9 | platforms:
10 | - "linux/amd64"
11 | tags:
12 | - "fastllm:v0.9"
13 | restart: always
14 | ports:
15 | - 11234:8081
16 | volumes:
17 | - ./models/:/models/
18 | command: /fastllm/build/webui -p /models/chatglm2-6b-int8.flm -w ./example/webui/web
19 |
20 |
--------------------------------------------------------------------------------
/docs/benchmark.md:
--------------------------------------------------------------------------------
1 | ## Inference Speed
2 | 
3 | You can use the benchmark program to measure inference speed; throughput varies somewhat with hardware configuration and input.
4 | 
5 | For example:
6 |
7 | ``` sh
8 | ./benchmark -p ~/chatglm-6b-int4.flm -f ../example/benchmark/prompts/beijing.txt -b 1
9 | ./benchmark -p ~/chatglm-6b-int8.flm -f ../example/benchmark/prompts/beijing.txt -b 1
10 | ./benchmark -p ~/chatglm-6b-fp16.flm -f ../example/benchmark/prompts/hello.txt -b 512 -l 18
11 | ```
12 |
13 | | Model | Data precision | Platform | Batch | Max inference speed (token/s) |
14 | |-----------------:|---------|--------------------|-----------|---------------------:|
15 | | ChatGLM-6b-int4 | float32 | RTX 4090 | 1 | 176 |
16 | | ChatGLM-6b-int8 | float32 | RTX 4090 | 1 | 121 |
17 | | ChatGLM-6b-fp16 | float32 | RTX 4090 | 64 | 2919 |
18 | | ChatGLM-6b-fp16 | float32 | RTX 4090 | 256 | 7871 |
19 | | ChatGLM-6b-fp16 | float32 | RTX 4090 | 512 | 10209 |
20 | | ChatGLM-6b-int4 | float32 | Xiaomi 10 Pro - 4 Threads | 1 | 4 ~ 5 |
21 |
--------------------------------------------------------------------------------
/docs/faq.md:
--------------------------------------------------------------------------------
1 | # Frequently Asked Questions
2 |
3 | ## CMAKE
4 |
5 | ### CMAKE_CUDA_ARCHITECTURES must be non-empty if set.
6 |
7 | **Symptom:**
8 |
9 | > CMake Error at cmake/Modules/CMakeDetermineCUDACompiler.cmake:277 (message):
10 | > CMAKE_CUDA_ARCHITECTURES must be non-empty if set.
11 | > Call Stack (most recent call first):
12 | > CMakeLists.txt:39 (enable_language)
13 |
14 | **Solution:**
15 | 
16 | Some CMake versions have this issue; specify `CMAKE_CUDA_ARCHITECTURES` manually. Run:
17 |
18 | ```shell
19 | cmake .. -DUSE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native
20 | ```
21 |
22 | ### Unsupported gpu architecture 'compute_native'
23 |
24 | **Symptom:**
25 |
26 | > nvcc fatal : Unsupported gpu architecture 'compute_native'
27 |
28 | **Solution:**
29 | 
30 | Edit CMakeLists.txt and set the GPU's [Compute Capability](https://developer.nvidia.com/cuda-gpus) explicitly for your GPU model. For example:
31 |
32 | ``` diff
33 | --- a/CMakeLists.txt
34 | +++ b/CMakeLists.txt
35 | @@ -52,7 +52,7 @@
36 | #message(${CMAKE_CUDA_IMPLICIT_LINK_DIRECTORIES})
37 | set(FASTLLM_CUDA_SOURCES src/devices/cuda/cudadevice.cpp src/devices/cuda/cudadevicebatch.cpp src/devices/cuda/fastllm-cuda.cu)
38 | set(FASTLLM_LINKED_LIBS ${FASTLLM_LINKED_LIBS} cublas)
39 | - set(CMAKE_CUDA_ARCHITECTURES "native")
40 | + set(CMAKE_CUDA_ARCHITECTURES 61 75 86 89)
41 | endif()
42 |
43 | if (PY_API)
44 | ```
45 |
46 | ### identifier "__hdiv" is undefined
47 |
48 | **Symptom:**
49 |
50 | > src/devices/cuda/fastllm-cuda.cu(247): error: identifier "hexp" is undefined
51 | > src/devices/cuda/fastllm-cuda.cu(247): error: identifier "__hdiv" is undefined
52 | > ...
53 |
54 | **Cause:** GPUs with [Compute Capability](https://developer.nvidia.com/cuda-gpus) <= 5.3 do not support half-precision arithmetic.
55 | 
56 | **Solution:** To support these GPUs, pass the `CUDA_NO_TENSOR_CORE` option when running cmake:
57 |
58 | ```shell
59 | cmake .. -DUSE_CUDA=ON -DCUDA_NO_TENSOR_CORE=ON
60 | ```
61 |
62 | ## Windows
63 |
64 | ### fastllm.h error
65 |
66 | **Symptom:**
67 |
68 | > include\fastllm.h(50): error : identifier "top_k" is undefined
69 | > include\fastllm.h(172): error : expected a "}"
70 | > include\fastllm.h(234): error : identifier "DataDevice" is undefined
71 | > ....
72 |
73 | **Solution:** See [example\README.md](/example/README.md). After checking out the code, **edit include/fastllm.h**: in Visual Studio choose "File" -> "Advanced Save Options" and select the encoding "Unicode (UTF-8 **with signature**) - Codepage 65001", or convert the file to "UTF-8 with BOM" in another text editor. (gcc on Linux does not accept a BOM while MSVC relies on it to detect the file encoding, so this change can only be made by hand.)
74 | 
75 | ### main.exe cannot read Chinese input
76 | 
77 | **Cause:** cmd on Windows does not support UTF-8 input.
78 | 
79 | **Solution:** Build [Win32Demo](/example/README.md#win32demo-windows平台) or use the [WebUI](/example/README.md#web-ui) instead.
80 | 
81 | ### Garbled int4 output when building with MSVC on Windows
82 | 
83 | **Cause:** The MSVC optimization options "`/Ob2`" and "`/Ob3`" conflict with the existing code.
84 | 
85 | **Solution:** When building, open "Properties" -> "C/C++" -> "Optimization" -> "Inline Function Expansion" and select "Only __inline (/Ob1)".
86 |
87 | ### FileNotFoundError on import
88 | 
89 | **Symptom:**
90 |
91 | > File "...Python\lib\ctypes\_\_init\_\_.py", line 374, in \_\_init\_\_
92 | > self._handle = _dlopen(self._name, mode)
93 | > FileNotFoundError: Could not find module 'tools\fastllm_pytools\fastllm_tools.dll' (or one of its dependencies). Try using the full path with constructor syntax.
94 |
95 | **Solution:** Some Python versions hit this issue with non-CPU (GPU) builds.
96 | 
97 | For GPU builds, copy the cudart and cublas DLLs that match your CUDA version into the same directory as fastllm_tools, for example:
98 |
99 | * CUDA 9.2
100 | * %CUDA_PATH%\bin\cublas64_92.dll
101 | * %CUDA_PATH%\bin\cudart64_92.dll
102 | * CUDA 11.x
103 | * %CUDA_PATH%\bin\cudart64_110.dll
104 | * %CUDA_PATH%\bin\cublas64_11.dll
105 | * %CUDA_PATH%\bin\cublasLt64_11.dll
106 | * CUDA 12.x
107 | * %CUDA_PATH%\bin\cudart64_12.dll
108 | * %CUDA_PATH%\bin\cublas64_12.dll
109 | * %CUDA_PATH%\bin\cublasLt64_12.dll
110 |
111 | ## fastllm_pytools
112 |
113 | ### Error when releasing memory: CUDA error when release memory
114 | 
115 | **Symptom:**
116 | The following error is printed on exit:
117 | > Error: CUDA error when release memory!
118 | > CUDA error = 4, cudaErrorCudartUnloading at fastllm/src/devices/cuda/fastllm-cuda.cu:1493
119 | > 'driver shutting down'
120 |
121 | **Cause:** When the Python interpreter terminates, it usually ends its own process first without destructing the third-party libraries it loaded, so by the time fastllm tries to free GPU memory the CUDA Runtime has already shut down and the free fails. Since the memory has usually been released by that point anyway, this normally causes no harm.
122 | 
123 | **Solution:** Explicitly call `llm.release_memory()` before the Python program exits.
124 |
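125 | For example, a minimal sketch (assuming the `llm.model()` / `model.response()` loading and chat calls documented for fastllm_pytools; the model path is only an example):
126 | 
127 | ```python
128 | from fastllm_pytools import llm
129 | 
130 | model = llm.model("chatglm2-6b-int8.flm")  # load a converted .flm model
131 | print(model.response("Hello"))
132 | 
133 | # Free GPU memory explicitly while the CUDA Runtime is still alive,
134 | # instead of relying on interpreter shutdown to do it.
135 | llm.release_memory()
136 | ```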
--------------------------------------------------------------------------------
/docs/fastllm_pytools.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/docs/fastllm_pytools.md
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | .gradle
3 | /local.properties
4 | /.idea/caches
5 | /.idea/libraries
6 | /.idea/modules.xml
7 | /.idea/workspace.xml
8 | /.idea/navEditor.xml
9 | /.idea/assetWizardSettings.xml
10 | .DS_Store
11 | /build
12 | /captures
13 | .externalNativeBuild
14 | .cxx
15 | local.properties
16 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/.name:
--------------------------------------------------------------------------------
1 | XiaoZhihuiAssistant
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/compiler.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/deploymentTargetDropDown.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/gradle.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
19 |
20 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/.gitignore:
--------------------------------------------------------------------------------
1 | /build
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/build.gradle:
--------------------------------------------------------------------------------
1 | plugins {
2 | id 'com.android.application'
3 | }
4 |
5 | android {
6 | compileSdk 30
7 |
8 | defaultConfig {
9 | applicationId "com.doujiao.xiaozhihuiassistant"
10 | minSdk 21
11 | targetSdk 26
12 | versionCode 1
13 | versionName "1.0"
14 |
15 | testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
16 | externalNativeBuild {
17 | cmake {
18 | cppFlags '-std=c++11'
19 | }
20 | }
21 | ndk {
22 | abiFilters 'arm64-v8a','armeabi-v7a','x86'
23 | }
24 | }
25 |
26 | buildTypes {
27 | release {
28 | minifyEnabled false
29 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
30 | }
31 | }
32 | compileOptions {
33 | sourceCompatibility JavaVersion.VERSION_1_8
34 | targetCompatibility JavaVersion.VERSION_1_8
35 | }
36 | externalNativeBuild {
37 | cmake {
38 | path file('src/main/cpp/CMakeLists.txt')
39 | version '3.18.1'
40 | }
41 | }
42 | // sourceSets {
43 | // main {
44 | // // jnilib
45 | // jniLibs.srcDirs = ['libs']
46 | // }
47 | // }
48 | splits {
49 | abi {
50 | enable true
51 | reset()
52 | include 'arm64-v8a', 'armeabi-v7a','x86'
53 | universalApk true
54 | }
55 | }
56 | buildFeatures {
57 | viewBinding true
58 | }
59 | }
60 |
61 | dependencies {
62 |
63 | implementation 'com.android.support:appcompat-v7:28.0.0'
64 | implementation 'com.android.support:recyclerview-v7:28.0.0'
65 | implementation 'com.android.support.constraint:constraint-layout:2.0.4'
66 | testImplementation 'junit:junit:4.13.2'
67 | androidTestImplementation 'com.android.support.test:runner:1.0.2'
68 | androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
69 | }
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/libs/arm64-v8a/libassistant.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/libs/arm64-v8a/libassistant.so
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/libs/armeabi-v7a/libassistant.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/libs/armeabi-v7a/libassistant.so
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/proguard-rules.pro:
--------------------------------------------------------------------------------
1 | # Add project specific ProGuard rules here.
2 | # You can control the set of applied configuration files using the
3 | # proguardFiles setting in build.gradle.
4 | #
5 | # For more details, see
6 | # http://developer.android.com/guide/developing/tools/proguard.html
7 |
8 | # If your project uses WebView with JS, uncomment the following
9 | # and specify the fully qualified class name to the JavaScript interface
10 | # class:
11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview {
12 | # public *;
13 | #}
14 |
15 | # Uncomment this to preserve the line number information for
16 | # debugging stack traces.
17 | #-keepattributes SourceFile,LineNumberTable
18 |
19 | # If you keep the line number information, uncomment this to
20 | # hide the original source file name.
21 | #-renamesourcefileattribute SourceFile
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/release/app-arm64-v8a-release-unsigned.apk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/release/app-arm64-v8a-release-unsigned.apk
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/release/app-armeabi-v7a-release-unsigned.apk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/release/app-armeabi-v7a-release-unsigned.apk
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/release/app-universal-release-unsigned.apk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/release/app-universal-release-unsigned.apk
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/release/app-x86-release-unsigned.apk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/release/app-x86-release-unsigned.apk
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/androidTest/java/com/doujiao/xiaozhihuiassistant/ExampleInstrumentedTest.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant;
2 |
3 | import android.content.Context;
4 | import android.support.test.InstrumentationRegistry;
5 | import android.support.test.runner.AndroidJUnit4;
6 |
7 | import org.junit.Test;
8 | import org.junit.runner.RunWith;
9 |
10 | import static org.junit.Assert.*;
11 |
12 | /**
13 | * Instrumented test, which will execute on an Android device.
14 | *
15 | * @see Testing documentation
16 | */
17 | @RunWith(AndroidJUnit4.class)
18 | public class ExampleInstrumentedTest {
19 | @Test
20 | public void useAppContext() {
21 | // Context of the app under test.
22 | Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
23 | assertEquals("com.doujiao.xiaozhihuiassistant", appContext.getPackageName());
24 | }
25 | }
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/AndroidManifest.xml:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
6 |
7 |
8 |
16 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # For more information about using CMake with Android Studio, read the
2 | # documentation: https://d.android.com/studio/projects/add-native-code.html
3 |
4 | # Sets the minimum version of CMake required to build the native library.
5 |
6 | cmake_minimum_required(VERSION 3.10.2)
7 |
8 | # Declares and names the project.
9 |
10 | project("assistant")
11 | set(CMAKE_BUILD_TYPE "Release")
12 |
13 | option(USE_CUDA "use cuda" OFF)
14 |
15 | option(PY_API "python api" OFF)
16 |
17 | # The optimization option below can be commented out
18 | #add_definitions(${CMAKE_CXX_FLAGS} "${CMAKE_CXX_FLAGS} -march=armv8.2a+dotprod")
19 | #
20 | #file(GLOB_RECURSE NANODET_SOURCE ../../../../../../../src/*.cpp
21 | # ../../../../../../../src/devices/cpu/*.cpp
22 | # ../../../../../../../src/models/*.cpp)
23 | #
24 | #set(PROJECT_SOURCE
25 | # ${NANODET_SOURCE}
26 | # )
27 |
28 | set(PROJECT_SOURCE
29 | ../../../../../../../src/fastllm.cpp
30 | ../../../../../../../src/device.cpp
31 | ../../../../../../../src/model.cpp
32 | ../../../../../../../src/executor.cpp
33 | ../../../../../../../src/devices/cpu/cpudevice.cpp
34 | ../../../../../../../src/devices/cpu/cpudevicebatch.cpp
35 | ../../../../../../../src/models/basellm.cpp
36 | ../../../../../../../src/models/chatglm.cpp
37 | ../../../../../../../src/models/moss.cpp
38 | ../../../../../../../src/models/llama.cpp
39 | ../../../../../../../src/models/qwen.cpp
40 | ../../../../../../../src/models/glm.cpp
41 | ../../../../../../../src/models/minicpm.cpp
42 | )
43 |
44 | include_directories(
45 | ./
46 | ../../../../../../../include
47 | ../../../../../../../include/models
48 | ../../../../../../../include/utils
49 | ../../../../../../../include/devices/cpu)
50 |
51 | add_library( # Sets the name of the library.
52 | assistant
53 | # Sets the library as a shared library.
54 | SHARED
55 | # Provides a relative path to your source file(s).
56 | ${PROJECT_SOURCE} LLMChat.cpp native-lib.cpp)
57 |
58 | # Searches for a specified prebuilt library and stores the path as a
59 | # variable. Because CMake includes system libraries in the search path by
60 | # default, you only need to specify the name of the public NDK library
61 | # you want to add. CMake verifies that the library exists before
62 | # completing its build.
63 |
64 | find_library( # Sets the name of the path variable.
65 | log-lib
66 | # Specifies the name of the NDK library that
67 | # you want CMake to locate.
68 | log)
69 |
70 | # Specifies libraries CMake should link to your target library. You
71 | # can link multiple libraries, such as libraries you define in this
72 | # build script, prebuilt third-party libraries, or system libraries.
73 |
74 | target_link_libraries( # Specifies the target library.
75 | assistant
76 | # Links the target library to the log library
77 | # included in the NDK.
78 | ${log-lib})
79 |
80 | #add_executable(main main.cpp ../../../../../../../src/fastllm.cpp
81 | # ../../../../../../../src/chatglm.cpp
82 | # ../../../../../../../src/moss.cpp)
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/cpp/LLMChat.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "LLMChat.h"
6 |
7 | #include "model.h"
8 | //void(^ __nonnull RuntimeChat)(int index,const char* _Nonnull content) = NULL; // real-time callback
9 |
10 | static int modeltype = 0;
11 | static char* modelpath = NULL;
12 | static std::unique_ptr<fastllm::basellm> chatGlm = NULL;
13 | static int sRound = 0;
14 | static std::string history;
15 | static RuntimeResultMobile g_callback = NULL;
16 |
17 | std::string initGptConf(const char* modelPath,int threads) {
18 | fastllm::SetThreads(threads);
19 | LOG_Debug("@@init llmpath:%s\n",modelPath);
20 | chatGlm = fastllm::CreateLLMModelFromFile(modelPath);
21 | if(chatGlm != NULL)
22 | {
23 | std::string modelName = chatGlm->model_type;
24 | LOG_Debug("@@model name:%s\n",modelName.c_str());
25 | return modelName;
26 | }
27 | LOG_Debug("@@CreateLLMModelFromFile failed.");
28 | return "";
29 | }
30 |
31 | int chat(const char* prompt, RuntimeResultMobile chatCallback) {
32 | std::string ret = "";
33 | g_callback = chatCallback;
34 | LOG_Debug("@@init llm:type:%d,prompt:%s\n",modeltype,prompt);
35 | std::string input(prompt);
36 |
37 | if (input == "reset") {
38 | history = "";
39 | sRound = 0;
40 | g_callback(0,"Done!");
41 | g_callback(-1,"");
42 | return 0;
43 | }
44 |
45 | ret = chatGlm->Response(chatGlm->MakeInput(history, sRound, input), [](int index, const char* content) {
46 | g_callback(index,content);
47 | });
48 | history = chatGlm->MakeHistory(history, sRound, input, ret);
49 | sRound++;
50 |
51 | long len = ret.length();
52 | return len;
53 | }
54 |
55 | void uninitLLM()
56 | {
57 | }
58 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/cpp/LLMChat.h:
--------------------------------------------------------------------------------
1 | //
2 | // LLMChat.h
3 | // LLMChat
4 | //
5 | // Created by 胡其斌 on 2023/5/18.
6 | //
7 |
8 | #ifdef __cplusplus
9 | extern "C" {
10 | #endif
11 |
12 | #include <android/log.h>
13 | #define LOG_Debug(...) __android_log_print(ANDROID_LOG_DEBUG, "Assistant", __VA_ARGS__)
14 |
15 | typedef void(* RuntimeResultMobile)(int index,const char* content);
16 |
17 | std::string initGptConf(const char* modelPath,int threads);
18 | int chat(const char* prompt, RuntimeResultMobile chatCallback);
19 | void uninitLLM();
20 |
21 | #ifdef __cplusplus
22 | }
23 | #endif
24 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/cpp/main.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 |
6 | #include "factoryllm.h"
7 |
8 | static factoryllm fllm;
9 | static int modeltype = 0;
10 | static char* modelpath = NULL;
11 | static fastllm::basellm* chatGlm = fllm.createllm(LLM_TYPE_CHATGLM);
12 | static fastllm::basellm* moss = fllm.createllm(LLM_TYPE_MOSS);
13 | static int sRound = 0;
14 | static std::string history;
15 |
16 | struct RunConfig {
17 | int model = LLM_TYPE_CHATGLM; // model type: LLM_TYPE_CHATGLM: chatglm, LLM_TYPE_MOSS: moss
18 | std::string path = "/sdcard/chatglm-6b-int4.bin"; // path to the model file
19 | int threads = 6; // number of threads to use
20 | };
21 |
22 | static struct option long_options[] = {
23 | {"help", no_argument, nullptr, 'h'},
24 | {"model", required_argument, nullptr, 'm'},
25 | {"path", required_argument, nullptr, 'p'},
26 | {"threads", required_argument, nullptr, 't'},
27 | {nullptr, 0, nullptr, 0},
28 | };
29 |
30 | void Usage() {
31 | std::cout << "Usage:" << std::endl;
32 | std::cout << "[-h|--help]: 显示帮助" << std::endl;
33 | std::cout << "<-m|--model> : 模型类型,默认为chatglm, 可以设置为0, moss:1" << std::endl;
34 | std::cout << "<-p|--path> : 模型文件的路径" << std::endl;
35 | std::cout << "<-t|--threads> : 使用的线程数量" << std::endl;
36 | }
37 |
38 | void ParseArgs(int argc, char **argv, RunConfig &config) {
39 | int opt;
40 | int option_index = 0;
41 | const char *opt_string = "h:m:p:t:";
42 |
43 | while ((opt = getopt_long_only(argc, argv, opt_string, long_options, &option_index)) != -1) {
44 | switch (opt) {
45 | case 'h':
46 | Usage();
47 | exit (0);
48 | case 'm':
49 | config.model = atoi(argv[optind - 1]);
50 | break;
51 | case 'p':
52 | config.path = argv[optind - 1];
53 | break;
54 | case 't':
55 | config.threads = atoi(argv[optind - 1]);
56 | break;
57 | default:
58 | Usage();
59 | exit (-1);
60 | }
61 | }
62 | }
63 |
64 | int initLLMConf(int model,const char* modelPath,int threads) {
65 | fastllm::SetThreads(threads);
66 | modeltype = model;
67 | // printf("@@init llm:type:%d,path:%s\n",model,modelPath);
68 | if (modeltype == 0) {
69 | chatGlm->LoadFromFile(modelPath);
70 | }
71 | if (modeltype == 1) {
72 | moss->LoadFromFile(modelPath);
73 | }
74 | return 0;
75 | }
76 |
77 | int chat(const char* prompt) {
78 | std::string ret = "";
79 | //printf("@@init llm:type:%d,prompt:%s\n",modeltype,prompt);
80 | std::string input(prompt);
81 | if (modeltype == 0) {
82 | if (input == "reset") {
83 | history = "";
84 | sRound = 0;
85 | return 0;
86 | }
87 | history += ("[Round " + std::to_string(sRound++) + "]\n问:" + input);
88 | auto prompt = sRound > 1 ? history : input;
89 | ret = chatGlm->Response(prompt,[](int index,const char* content){
90 |
91 | if(index == 0) {
92 | printf("ChatGLM:");
93 | }
94 | printf("%s", content);
95 | if (index == -1) {
96 | printf("\n");
97 | }
98 |
99 | });
100 | history += ("\n答:" + ret + "\n");
101 | }
102 |
103 | if (modeltype == 1) {
104 | auto prompt = "You are an AI assistant whose name is MOSS. <|Human|>: " + input + "";
105 | ret = moss->Response(prompt,[](int index,const char* content){
106 | if(index == 0) {
107 | printf("MOSS:");
108 | }
109 | printf("%s", content);
110 | if (index == -1) {
111 | printf("\n");
112 | }
113 | });
114 | }
115 | long len = ret.length();
116 | return len;
117 | }
118 |
119 | void uninitLLM()
120 | {
121 | if (chatGlm)
122 | {
123 | delete chatGlm;
124 | chatGlm = NULL;
125 | }
126 | if (moss)
127 | {
128 | delete moss;
129 | moss = NULL;
130 | }
131 | }
132 |
133 |
134 | int main(int argc, char **argv) {
135 | RunConfig config;
136 | ParseArgs(argc, argv, config);
137 |
138 | initLLMConf(config.model, config.path.c_str(), config.threads);
139 |
140 | if (config.model == LLM_TYPE_MOSS) {
141 |
142 | while (true) {
143 | printf("用户: ");
144 | std::string input;
145 | std::getline(std::cin, input);
146 | if (input == "stop") {
147 | break;
148 | }
149 | chat(input.c_str());
150 | }
151 | } else if (config.model == LLM_TYPE_CHATGLM) {
152 | while (true) {
153 | printf("用户: ");
154 | std::string input;
155 | std::getline(std::cin, input);
156 | if (input == "stop") {
157 | break;
158 | }
159 | chat(input.c_str());
160 | }
161 |
162 | } else {
163 | Usage();
164 | exit(-1);
165 | }
166 |
167 | return 0;
168 | }
169 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/cpp/native-lib.cpp:
--------------------------------------------------------------------------------
1 | #include <jni.h>
2 | #include <string>
3 | #include "LLMChat.h"
4 |
5 |
6 | JavaVM *g_javaVM = NULL;
7 | jobject g_obj;
8 |
9 | void initGvm(JNIEnv *env,jobject thiz) {
10 | if(g_javaVM == NULL) {
11 | env->GetJavaVM(&g_javaVM);
12 | g_obj = env->NewGlobalRef(thiz);
13 | }
14 | }
15 |
16 | void chatCb(int index,const char* content) {
17 | JNIEnv *env = NULL;
18 | int mNeedDetach = 0;
19 | // check whether the current native thread is already attached to the JVM
20 | int getEnvStat = g_javaVM->GetEnv((void **)&env,JNI_VERSION_1_6);
21 | if (getEnvStat == JNI_EDETACHED) {
22 | // if not, attach it to the JVM to obtain an env
23 | if (g_javaVM->AttachCurrentThread( &env, NULL) != 0) {
24 | LOG_Debug("Unable to AttachCurrentThread");
25 | return;
26 | }
27 | mNeedDetach = 1;
28 | }
29 | // get the class to call back into via the global reference g_obj
30 | jclass javaClass = env->GetObjectClass(g_obj);//env->FindClass("com/doujiao/core/AssistantCore");//
31 | if (javaClass == 0) {
32 | LOG_Debug("Unable to find class");
33 | if(mNeedDetach) {
34 | g_javaVM->DetachCurrentThread();
35 | }
36 | return;
37 | }
38 | jmethodID jgetDBpathMethod = env->GetMethodID(javaClass, "reportChat", "(Ljava/lang/String;I)V");
39 | if (jgetDBpathMethod == NULL) {
40 | LOG_Debug("Unable to find method:jgetDBpathMethod");
41 | return;
42 | }
43 | jobject bb = env->NewDirectByteBuffer((void *) content, strlen(content));
44 | jclass cls_Charset = env->FindClass("java/nio/charset/Charset");
45 | jmethodID mid_Charset_forName = env->GetStaticMethodID(cls_Charset, "forName", "(Ljava/lang/String;)Ljava/nio/charset/Charset;");
46 | jobject charset = env->CallStaticObjectMethod(cls_Charset, mid_Charset_forName, env->NewStringUTF("UTF-8"));
47 |
48 | jmethodID mid_Charset_decode = env->GetMethodID(cls_Charset, "decode", "(Ljava/nio/ByteBuffer;)Ljava/nio/CharBuffer;");
49 | jobject cb = env->CallObjectMethod(charset, mid_Charset_decode, bb);
50 | env->DeleteLocalRef(bb);
51 |
52 | jclass cls_CharBuffer = env->FindClass("java/nio/CharBuffer");
53 | jmethodID mid_CharBuffer_toString = env->GetMethodID(cls_CharBuffer, "toString", "()Ljava/lang/String;");
54 | jstring str = static_cast<jstring>(env->CallObjectMethod(cb, mid_CharBuffer_toString));
55 | env->CallVoidMethod(g_obj, jgetDBpathMethod,str,index);
56 | env->DeleteLocalRef(javaClass);
57 | // detach the current thread
58 | if(mNeedDetach) {
59 | g_javaVM->DetachCurrentThread();
60 | }
61 | env = NULL;
62 | }
63 |
64 | extern "C" JNIEXPORT jstring JNICALL
65 | Java_com_doujiao_core_AssistantCore_initLLMConfig(
66 | JNIEnv* env,
67 | jobject obj,
68 | jstring modelpath,
69 | jint threads) {
70 | initGvm(env,obj);
71 | const char *path = env->GetStringUTFChars(modelpath, NULL);
72 | std::string ret = initGptConf(path,threads);
73 | LOG_Debug("@@@initLLMConfig:%s",ret.c_str());
74 | env->ReleaseStringUTFChars( modelpath, path);
75 | return env->NewStringUTF(ret.c_str());
76 | }
77 |
78 | extern "C" JNIEXPORT jint JNICALL
79 | Java_com_doujiao_core_AssistantCore_chat(
80 | JNIEnv* env,
81 | jobject obj,
82 | jstring prompt) {
83 | initGvm(env,obj);
84 | const char *question = env->GetStringUTFChars(prompt, NULL);
85 | chat(question,[](int index,const char* content){
86 | chatCb(index,content);
87 | });
88 | // chatCb(1,"content");
89 | return 0;
90 | }
91 |
92 | extern "C" JNIEXPORT void JNICALL
93 | Java_com_doujiao_core_AssistantCore_uninitLLM(
94 | JNIEnv* env,
95 | jobject /* this */) {
96 | uninitLLM();
97 | }
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/core/AssistantCore.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.core;
2 |
3 | import android.support.annotation.Keep;
4 | import android.util.Log;
5 |
6 | public class AssistantCore {
7 |
8 | private static AssistantCore instance = null;
9 | private static runtimeResult mRuntimeRes = null;
10 |
11 | static {
12 | System.loadLibrary("assistant");
13 | }
14 |
15 | /* singleton instance */
16 | public static AssistantCore getInstance(){
17 | if(instance == null){
18 | instance = new AssistantCore();
19 | }
20 |
21 | return instance;
22 | }
23 |
24 | public String initLLM(String path,runtimeResult callback) {
25 | mRuntimeRes = callback;
26 | return initLLMConfig(path,8);
27 | }
28 |
29 | @Keep
30 | public void reportChat(String content,int index) {
31 | Log.d("@@@","recv:"+content+",index:"+index);
32 | if (mRuntimeRes != null) {
33 | mRuntimeRes.callbackResult(index,content);
34 | }
35 | }
36 |
37 | public interface runtimeResult {
38 | void callbackResult(int index,String content);
39 | }
40 |
41 | private native String initLLMConfig(String path,int threads);
42 | public native int chat(String prompt);
43 | public native int uninitLLM();
44 | }
45 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/ChatMessage.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant;
2 |
3 | /**
4 | * Created by chenpengfei on 2016/10/27.
5 | */
6 | public class ChatMessage {
7 |
8 | private String content;
9 |
10 | private int type;
11 |
12 | public ChatMessage(String content, int type) {
13 | this.content = content;
14 | this.type = type;
15 | }
16 |
17 | public ChatMessage(String content) {
18 | this(content, 1);
19 | }
20 |
21 |
22 | public String getContent() {
23 | return content;
24 | }
25 |
26 | public void setContent(String content) {
27 | this.content = content;
28 | }
29 |
30 | public int getType() {
31 | return type;
32 | }
33 |
34 | public void setType(int type) {
35 | this.type = type;
36 | }
37 |
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/adapter/BaseViewHolder.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.adapter;
2 |
3 | import android.support.v7.widget.RecyclerView;
4 | import android.view.View;
5 |
6 | /**
7 | * Created by chenpengfei on 2016/10/27.
8 | */
9 | public class BaseViewHolder extends RecyclerView.ViewHolder {
10 |
11 | private View iv;
12 |
13 | public BaseViewHolder(View itemView) {
14 | super(itemView);
15 | iv = itemView;
16 | }
17 |
18 | public View findViewById(int id) {
19 | return iv.findViewById(id);
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/adapter/MyAdapter.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.adapter;
2 |
3 | import android.app.Activity;
4 | import android.support.v7.widget.RecyclerView;
5 | import android.view.View;
6 | import android.view.ViewGroup;
7 | import android.widget.TextView;
8 | import android.widget.Toast;
9 |
10 |
11 | import com.doujiao.xiaozhihuiassistant.ChatMessage;
12 | import com.doujiao.xiaozhihuiassistant.R;
13 | import com.doujiao.xiaozhihuiassistant.utils.StatusBarUtils;
14 | import com.doujiao.xiaozhihuiassistant.widget.ChatPromptViewManager;
15 | import com.doujiao.xiaozhihuiassistant.widget.Location;
16 | import com.doujiao.xiaozhihuiassistant.widget.PromptViewHelper;
17 |
18 | import java.util.List;
19 |
20 | public class MyAdapter extends RecyclerView.Adapter<BaseViewHolder> {
21 |
22 | private List<ChatMessage> mChatMessageList = null;
23 | private Activity mActivity;
24 |
25 | public MyAdapter(Activity activity) {
26 | mActivity = activity;
27 | }
28 |
29 | public void setMessages(List<ChatMessage> chatMessageList) {
30 | mChatMessageList = chatMessageList;
31 | }
32 |
33 | @Override
34 | public BaseViewHolder onCreateViewHolder(ViewGroup parent, int viewType) {
35 | if(viewType == 1) {
36 | return new LeftViewHolder(View.inflate(mActivity, R.layout.activity_item_left, null));
37 | } else {
38 | return new RightViewHolder(View.inflate(mActivity, R.layout.activity_item_right, null));
39 | }
40 | }
41 |
42 | @Override
43 | public int getItemCount() {
44 | return mChatMessageList.size();
45 | }
46 |
47 | @Override
48 | public int getItemViewType(int position) {
49 | return mChatMessageList.get(position).getType();
50 | }
51 |
52 | @Override
53 | public void onBindViewHolder(BaseViewHolder holder, int position) {
54 | PromptViewHelper pvHelper = new PromptViewHelper(mActivity);
55 | ChatMessage chatMessage = mChatMessageList.get(position);
56 | if(holder instanceof LeftViewHolder) {
57 | LeftViewHolder leftViewHolder = (LeftViewHolder) holder;
58 | leftViewHolder.tv.setText(chatMessage.getContent());
59 | pvHelper.setPromptViewManager(new ChatPromptViewManager(mActivity));
60 | }
61 | if(holder instanceof RightViewHolder) {
62 | RightViewHolder rightViewHolder = (RightViewHolder) holder;
63 | rightViewHolder.tv.setText(chatMessage.getContent());
64 | pvHelper.setPromptViewManager(new ChatPromptViewManager(mActivity, Location.TOP_RIGHT));
65 | }
66 | pvHelper.addPrompt(holder.itemView.findViewById(R.id.textview_content));
67 | pvHelper.setOnItemClickListener(new PromptViewHelper.OnItemClickListener() {
68 | @Override
69 | public void onItemClick(int position) {
70 | String str = "";
71 | switch (position) {
72 | case 0:
73 | str = "已复制到剪贴板!";
74 | TextView tv = holder.itemView.findViewById(R.id.textview_content);
75 | StatusBarUtils.copyStr2ClibBoard(mActivity.getApplicationContext(), tv.getText().toString());
76 | break;
77 | }
78 | Toast.makeText(mActivity, str, Toast.LENGTH_SHORT).show();
79 | }
80 | });
81 | }
82 |
83 | class LeftViewHolder extends BaseViewHolder {
84 |
85 | TextView tv;
86 |
87 | public LeftViewHolder(View view) {
88 | super(view);
89 | tv = (TextView) findViewById(R.id.textview_content);
90 | }
91 | }
92 |
93 | class RightViewHolder extends BaseViewHolder {
94 |
95 | TextView tv;
96 |
97 | public RightViewHolder(View view) {
98 | super(view);
99 | tv = (TextView) findViewById(R.id.textview_content);
100 | }
101 | }
102 | }
103 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/utils/PrefUtil.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.utils;
2 |
3 | import android.content.Context;
4 | import android.content.SharedPreferences;
5 |
6 | public class PrefUtil {
7 | private static final String SF_NAME = "com.doujiao.llm.config";
8 | private static final String MOLE_PATH = "llm_path";
9 | private static SharedPreferences mPref;
10 |
11 | public static void initPref(Context context) {
12 | if (mPref == null) {
13 | mPref = context.getSharedPreferences(SF_NAME, Context.MODE_PRIVATE);
14 | }
15 | }
16 |
17 | public static void setModelPath(String path) {
18 | if (mPref != null) {
19 | mPref.edit().putString(MOLE_PATH,path).apply();
20 | }
21 | }
22 |
23 | public static String getModelPath() {
24 | if (mPref != null) {
25 | return mPref.getString(MOLE_PATH,"");
26 | }
27 | return "";
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/utils/StatusBarUtils.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.utils;
2 |
3 | import android.content.ClipData;
4 | import android.content.ClipboardManager;
5 | import android.content.Context;
6 | import android.graphics.Color;
7 | import android.os.Build;
8 | import android.support.v7.app.ActionBar;
9 | import android.support.v7.app.AppCompatActivity;
10 | import android.view.View;
11 | import android.view.Window;
12 | import android.view.WindowManager;
13 |
14 | public class StatusBarUtils {
15 | public static void setTranslucentStatus(AppCompatActivity activity) {
16 | if (Build.VERSION.SDK_INT >= 21) {
17 | View decorView = activity.getWindow().getDecorView();
18 | int option = View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN
19 | | View.SYSTEM_UI_FLAG_LAYOUT_STABLE;
20 | decorView.setSystemUiVisibility(option);
21 | activity.getWindow().setStatusBarColor(Color.TRANSPARENT);
22 | }
23 | ActionBar actionBar = activity.getSupportActionBar();
24 | actionBar.hide();
25 | }
26 |
27 | public static void hideStatusBar(Window window, boolean darkText) {
28 | window.clearFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS);
29 | window.addFlags(WindowManager.LayoutParams.FLAG_DRAWS_SYSTEM_BAR_BACKGROUNDS);
30 | window.setStatusBarColor(Color.TRANSPARENT);
31 |
32 | int flag = View.SYSTEM_UI_FLAG_LAYOUT_STABLE;
33 | if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M && darkText) {
34 | flag = View.SYSTEM_UI_FLAG_LIGHT_STATUS_BAR;
35 | }
36 |
37 | window.getDecorView().setSystemUiVisibility(flag |
38 | View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN);
39 | }
40 |
41 | public static boolean copyStr2ClibBoard(Context context, String copyStr) {
42 | try {
43 | // get the clipboard manager
44 | ClipboardManager cm = (ClipboardManager) context.getSystemService(Context.CLIPBOARD_SERVICE);
45 | // create a plain-text ClipData
46 | ClipData mClipData = ClipData.newPlainText("Label", copyStr);
47 | // put the ClipData onto the system clipboard
48 | cm.setPrimaryClip(mClipData);
49 | return true;
50 | } catch (Exception e) {
51 | return false;
52 | }
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/ChatPromptViewManager.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget;
2 |
3 | import android.app.Activity;
4 | import android.view.View;
5 |
6 | /**
7 | * Created by chenpengfei on 2016/11/2.
8 | */
9 | public class ChatPromptViewManager extends PromptViewHelper.PromptViewManager {
10 |
11 | public ChatPromptViewManager(Activity activity, String[] dataArray, Location location) {
12 | super(activity, dataArray, location);
13 | }
14 |
15 | public ChatPromptViewManager(Activity activity) {
16 | this(activity, new String[]{"复制"}, Location.TOP_LEFT);
17 | }
18 |
19 | public ChatPromptViewManager(Activity activity, Location location) {
20 | this(activity, new String[]{"复制"}, location);
21 | }
22 |
23 |
24 | @Override
25 | public View inflateView() {
26 | return new PromptView(activity);
27 | }
28 |
29 | @Override
30 | public void bindData(View view, String[] dataArray) {
31 | if(view instanceof PromptView) {
32 | PromptView promptView = (PromptView) view;
33 | promptView.setContentArray(dataArray);
34 | promptView.setOnItemClickListener(new PromptView.OnItemClickListener() {
35 | @Override
36 | public void onItemClick(int position) {
37 | if(onItemClickListener != null) onItemClickListener.onItemClick(position);
38 | }
39 | });
40 | }
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/Location.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget;
2 |
3 | import com.doujiao.xiaozhihuiassistant.widget.location.BottomCenterLocation;
4 | import com.doujiao.xiaozhihuiassistant.widget.location.ICalculateLocation;
5 | import com.doujiao.xiaozhihuiassistant.widget.location.TopCenterLocation;
6 | import com.doujiao.xiaozhihuiassistant.widget.location.TopLeftLocation;
7 | import com.doujiao.xiaozhihuiassistant.widget.location.TopRightLocation;
8 |
9 | /**
10 | * Created by chenpengfei on 2016/11/2.
11 | */
12 | public enum Location {
13 |
14 | TOP_LEFT(1),
15 | TOP_CENTER(2),
16 | TOP_RIGHT(3),
17 | BOTTOM_LEFT(4),
18 | BOTTOM_CENTER(5),
19 | BOTTOM_RIGHT(6);
20 |
21 | ICalculateLocation calculateLocation;
22 |
23 | private Location(int type) {
24 | switch (type) {
25 | case 1:
26 | calculateLocation = new TopLeftLocation();
27 | break;
28 | case 2:
29 | calculateLocation = new TopCenterLocation();
30 | break;
31 | case 3:
32 | calculateLocation = new TopRightLocation();
33 | break;
34 | case 4:
35 | calculateLocation = new TopLeftLocation();
36 | break;
37 | case 5:
38 | calculateLocation = new BottomCenterLocation();
39 | break;
40 | case 6:
41 | calculateLocation = new TopLeftLocation();
42 | break;
43 | }
44 | }
45 |
46 | }
47 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/PromptViewHelper.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget;
2 |
3 | import android.app.Activity;
4 | import android.graphics.Color;
5 | import android.graphics.drawable.ColorDrawable;
6 | import android.view.Gravity;
7 | import android.view.View;
8 | import android.view.ViewGroup;
9 | import android.view.ViewTreeObserver;
10 | import android.widget.PopupWindow;
11 |
12 | /**
13 | * Created by chenpengfei on 2016/11/2.
14 | */
15 | public class PromptViewHelper {
16 |
17 | private PromptViewManager promptViewManager;
18 | private Activity activity;
19 | private PopupWindow popupWindow;
20 | private boolean isShow;
21 | private OnItemClickListener onItemClickListener;
22 |
23 | public PromptViewHelper(Activity activity) {
24 | this.activity = activity;
25 | }
26 |
27 | public void setPromptViewManager(PromptViewManager promptViewManager) {
28 | this.promptViewManager = promptViewManager;
29 | this.promptViewManager.setOnItemClickListener(new OnItemClickListener() {
30 | @Override
31 | public void onItemClick(int position) {
32 | if(onItemClickListener != null && popupWindow != null) {
33 | onItemClickListener.onItemClick(position);
34 | popupWindow.dismiss();
35 | }
36 | }
37 | });
38 | }
39 |
40 | public void addPrompt(View srcView) {
41 | srcView.setOnLongClickListener(new View.OnLongClickListener() {
42 | @Override
43 | public boolean onLongClick(View v) {
44 | createPrompt(v);
45 | return true;
46 | }
47 | });
48 | }
49 |
50 | private void createPrompt(final View srcView) {
51 | final View promptView = promptViewManager.getPromptView();
52 | if(popupWindow == null)
53 | popupWindow = new PopupWindow(activity);
54 | popupWindow.setWindowLayoutMode(ViewGroup.LayoutParams.WRAP_CONTENT, ViewGroup.LayoutParams.WRAP_CONTENT);
55 | popupWindow.setTouchable(true);
56 | popupWindow.setOutsideTouchable(true);
57 | popupWindow.setBackgroundDrawable( new ColorDrawable(Color.TRANSPARENT));
58 | popupWindow.setContentView(promptView);
59 | final int[] location = new int[2];
60 | promptView.getViewTreeObserver().addOnGlobalLayoutListener(new ViewTreeObserver.OnGlobalLayoutListener() {
61 | @Override
62 | public void onGlobalLayout() {
63 | if(!isShow && popupWindow.isShowing()) {
64 | popupWindow.dismiss();
65 | show(srcView, promptView, location);
66 | isShow = true;
67 | }
68 | }
69 | });
70 | srcView.getLocationOnScreen(location);
71 | show(srcView, promptView, location);
72 | }
73 |
74 | public void show(View srcView, View promptView, int[] srcViewLocation) {
75 | int[] xy = promptViewManager.getLocation().calculateLocation.calculate(srcViewLocation, srcView, promptView);
76 | popupWindow.showAtLocation(srcView, Gravity.NO_GRAVITY, xy[0], xy[1]);
77 | }
78 |
79 |
80 | public static abstract class PromptViewManager {
81 |
82 | private View promptView;
83 | protected Activity activity;
84 | private String[] dataArray;
85 | private Location location;
86 | public OnItemClickListener onItemClickListener;
87 |
88 | public PromptViewManager(Activity activity, String[] dataArray, Location location) {
89 | this.activity = activity;
90 | this.dataArray = dataArray;
91 | this.location = location;
92 | init();
93 | }
94 |
95 | public void setOnItemClickListener(OnItemClickListener onItemClickListener) {
96 | this.onItemClickListener = onItemClickListener;
97 | }
98 |
99 | public void init() {
100 | promptView = inflateView();
101 | bindData(promptView, dataArray);
102 | }
103 |
104 | public abstract View inflateView();
105 |
106 | public abstract void bindData(View view, String[] dataArray);
107 |
108 | public View getPromptView() {
109 | return promptView;
110 | }
111 |
112 | public Location getLocation() {
113 | return location;
114 | }
115 | }
116 |
117 | public void setOnItemClickListener(OnItemClickListener onItemClickListener) {
118 | this.onItemClickListener = onItemClickListener;
119 | }
120 |
121 | public interface OnItemClickListener {
122 | void onItemClick(int position);
123 | }
124 | }
125 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/BottomCenterLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public class BottomCenterLocation implements ICalculateLocation {
9 |
10 | @Override
11 | public int[] calculate(int[] srcViewLocation, View srcView, View promptView) {
12 | int[] location = new int[2];
13 | int offset = (promptView.getWidth() - srcView.getWidth()) / 2;
14 | location[0] = srcViewLocation[0] - offset;
15 | location[1] = srcViewLocation[1] + promptView.getHeight();
16 | return location;
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/ICalculateLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public interface ICalculateLocation {
9 |
10 | int[] calculate(int[] srcViewLocation, View srcView, View promptView);
11 |
12 | }
13 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/TopCenterLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public class TopCenterLocation implements ICalculateLocation {
9 |
10 | @Override
11 | public int[] calculate(int[] srcViewLocation, View srcView, View promptView) {
12 | int[] location = new int[2];
13 | int offset = (promptView.getWidth() - srcView.getWidth()) / 2;
14 | location[0] = srcViewLocation[0] - offset;
15 | location[1] = srcViewLocation[1] - promptView.getHeight();
16 | return location;
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/TopLeftLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public class TopLeftLocation implements ICalculateLocation {
9 |
10 | @Override
11 | public int[] calculate(int[] srcViewLocation, View srcView, View promptView) {
12 | int[] location = new int[2];
13 | location[0] = srcViewLocation[0];
14 | location[1] = srcViewLocation[1] - promptView.getHeight();
15 | return location;
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/TopRightLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public class TopRightLocation implements ICalculateLocation {
9 |
10 | @Override
11 | public int[] calculate(int[] srcViewLocation, View srcView, View promptView) {
12 | int[] location = new int[2];
13 | int offset = promptView.getWidth() - srcView.getWidth();
14 | location[0] = srcViewLocation[0] - offset;
15 | location[1] = srcViewLocation[1] - promptView.getHeight();
16 | return location;
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/drawable-v24/ic_launcher_foreground.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/drawable/btnbg.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/drawable/editbg.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/layout/activity_item_left.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/layout/activity_item_right.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/layout/activity_main.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-hdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-hdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-mdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-mdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-xhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/glm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/glm.png
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/me.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/me.png
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/values-night/themes.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/values/colors.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | #FFBB86FC
4 | #FF6200EE
5 | #FF3700B3
6 | #FF03DAC5
7 | #FF018786
8 | #FF000000
9 | #FFFFFFFF
10 | #FFC0C0C0
11 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/values/strings.xml:
--------------------------------------------------------------------------------
1 |
2 | FastLLM
3 | 选择 (Select)
4 | 模型 (Model)
5 | 请从手机存储中选择llm模型 (Please pick an llm model file from the phone's storage)
6 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/values/themes.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/test/java/com/doujiao/xiaozhihuiassistant/ExampleUnitTest.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant;
2 |
3 | import org.junit.Test;
4 |
5 | import static org.junit.Assert.*;
6 |
7 | /**
8 | * Example local unit test, which will execute on the development machine (host).
9 | *
10 | * @see Testing documentation
11 | */
12 | public class ExampleUnitTest {
13 | @Test
14 | public void addition_isCorrect() {
15 | assertEquals(4, 2 + 2);
16 | }
17 | }
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/build.gradle:
--------------------------------------------------------------------------------
1 | // Top-level build file where you can add configuration options common to all sub-projects/modules.
2 | plugins {
3 | id 'com.android.application' version '7.1.2' apply false
4 | id 'com.android.library' version '7.1.2' apply false
5 | }
6 |
7 | task clean(type: Delete) {
8 | delete rootProject.buildDir
9 | }
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/gradle.properties:
--------------------------------------------------------------------------------
1 | # Project-wide Gradle settings.
2 | # IDE (e.g. Android Studio) users:
3 | # Gradle settings configured through the IDE *will override*
4 | # any settings specified in this file.
5 | # For more details on how to configure your build environment visit
6 | # http://www.gradle.org/docs/current/userguide/build_environment.html
7 | # Specifies the JVM arguments used for the daemon process.
8 | # The setting is particularly useful for tweaking memory settings.
9 | org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
10 | # When configured, Gradle will run in incubating parallel mode.
11 | # This option should only be used with decoupled projects. More details, visit
12 | # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
13 | # org.gradle.parallel=true
14 | # Enables namespacing of each library's R class so that its R class includes only the
15 | # resources declared in the library itself and none from the library's dependencies,
16 | # thereby reducing the size of the R class for that library
17 | android.nonTransitiveRClass=true
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/gradle/wrapper/gradle-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Android/LLMAssistant/gradle/wrapper/gradle-wrapper.jar
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/gradle/wrapper/gradle-wrapper.properties:
--------------------------------------------------------------------------------
1 | #Wed May 24 16:20:51 CST 2023
2 | distributionBase=GRADLE_USER_HOME
3 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip
4 | distributionPath=wrapper/dists
5 | zipStorePath=wrapper/dists
6 | zipStoreBase=GRADLE_USER_HOME
7 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/gradlew.bat:
--------------------------------------------------------------------------------
1 | @rem
2 | @rem Copyright 2015 the original author or authors.
3 | @rem
4 | @rem Licensed under the Apache License, Version 2.0 (the "License");
5 | @rem you may not use this file except in compliance with the License.
6 | @rem You may obtain a copy of the License at
7 | @rem
8 | @rem https://www.apache.org/licenses/LICENSE-2.0
9 | @rem
10 | @rem Unless required by applicable law or agreed to in writing, software
11 | @rem distributed under the License is distributed on an "AS IS" BASIS,
12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | @rem See the License for the specific language governing permissions and
14 | @rem limitations under the License.
15 | @rem
16 |
17 | @if "%DEBUG%" == "" @echo off
18 | @rem ##########################################################################
19 | @rem
20 | @rem Gradle startup script for Windows
21 | @rem
22 | @rem ##########################################################################
23 |
24 | @rem Set local scope for the variables with windows NT shell
25 | if "%OS%"=="Windows_NT" setlocal
26 |
27 | set DIRNAME=%~dp0
28 | if "%DIRNAME%" == "" set DIRNAME=.
29 | set APP_BASE_NAME=%~n0
30 | set APP_HOME=%DIRNAME%
31 |
32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter.
33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
34 |
35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
37 |
38 | @rem Find java.exe
39 | if defined JAVA_HOME goto findJavaFromJavaHome
40 |
41 | set JAVA_EXE=java.exe
42 | %JAVA_EXE% -version >NUL 2>&1
43 | if "%ERRORLEVEL%" == "0" goto execute
44 |
45 | echo.
46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
47 | echo.
48 | echo Please set the JAVA_HOME variable in your environment to match the
49 | echo location of your Java installation.
50 |
51 | goto fail
52 |
53 | :findJavaFromJavaHome
54 | set JAVA_HOME=%JAVA_HOME:"=%
55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe
56 |
57 | if exist "%JAVA_EXE%" goto execute
58 |
59 | echo.
60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
61 | echo.
62 | echo Please set the JAVA_HOME variable in your environment to match the
63 | echo location of your Java installation.
64 |
65 | goto fail
66 |
67 | :execute
68 | @rem Setup the command line
69 |
70 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
71 |
72 |
73 | @rem Execute Gradle
74 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
75 |
76 | :end
77 | @rem End local scope for the variables with windows NT shell
78 | if "%ERRORLEVEL%"=="0" goto mainEnd
79 |
80 | :fail
81 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
82 | rem the _cmd.exe /c_ return code!
83 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
84 | exit /b 1
85 |
86 | :mainEnd
87 | if "%OS%"=="Windows_NT" endlocal
88 |
89 | :omega
90 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/settings.gradle:
--------------------------------------------------------------------------------
1 | pluginManagement {
2 | repositories {
3 | gradlePluginPortal()
4 | google()
5 | mavenCentral()
6 | }
7 | }
8 | dependencyResolutionManagement {
9 | repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
10 | repositories {
11 | google()
12 | mavenCentral()
13 | }
14 | }
15 | rootProject.name = "XiaoZhihuiAssistant"
16 | include ':app'
17 |
--------------------------------------------------------------------------------
/example/Qui/bin/Qt5Core.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Qui/bin/Qt5Core.dll
--------------------------------------------------------------------------------
/example/Qui/bin/Qt5Gui.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Qui/bin/Qt5Gui.dll
--------------------------------------------------------------------------------
/example/Qui/bin/Qt5Widgets.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Qui/bin/Qt5Widgets.dll
--------------------------------------------------------------------------------
/example/Qui/bin/Qui.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Qui/bin/Qui.exe
--------------------------------------------------------------------------------
/example/Qui/bin/fastllm_cpu.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Qui/bin/fastllm_cpu.exe
--------------------------------------------------------------------------------
/example/Qui/bin/fastllm_cuda.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Qui/bin/fastllm_cuda.exe
--------------------------------------------------------------------------------
/example/Qui/bin/path.txt:
--------------------------------------------------------------------------------
1 | C:/DEV/Prj/fastllm/fastllm-bin/Release
2 | C:/DEV/Prj/fastllm/fastllm/example/Qui/bin/fastllm_cuda.exe
3 |
--------------------------------------------------------------------------------
/example/Qui/bin/platforms/qwindows.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Qui/bin/platforms/qwindows.dll
--------------------------------------------------------------------------------
/example/Qui/bin/qui_cn.qm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Qui/bin/qui_cn.qm
--------------------------------------------------------------------------------
/example/Qui/bin/styles/qwindowsvistastyle.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/example/Qui/bin/styles/qwindowsvistastyle.dll
--------------------------------------------------------------------------------
/example/Qui/src/Qui.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include "ui_Qui.h"
9 | #include
10 | #include
11 |
12 | class Qui : public QWidget
13 | {
14 | Q_OBJECT
15 |
16 | public:
17 | Qui(QWidget *parent = nullptr);
18 | ~Qui() {};
19 |
20 | private slots:
21 | void clearChat();
22 | void wirteToFile();
23 | void sendAI();
24 | void onReset();
25 |
26 | void onPathSelectModel();
27 | void onPathSelectFlm();
28 |
29 | void onDeviceCheck();
30 |
31 | void runModelWithSetting();
32 | void stopModelRuning();
33 |
34 | bool eventFilter(QObject *obj, QEvent *event) override;
35 |
36 | void readData();
37 | void finishedProcess();
38 | void errorProcess();
39 |
40 | private:
41 | QProcess *process;
42 | void closeEvent(QCloseEvent *event) override;
43 |
44 | private:
45 | Ui::QuiClass ui;
46 |
47 | QString modelPath = "";
48 | QString flmPath = "";
49 | void updateModelList();
50 | bool inited = false;
51 | };
52 |
--------------------------------------------------------------------------------
/example/Qui/src/Qui.pro:
--------------------------------------------------------------------------------
1 | QT += core gui svg widgets network
2 | CONFIG += c++17
3 |
4 | APP_NAME = Qui
5 | DESTDIR = ../bin
6 |
7 |
8 | SOURCES += \
9 | Qui.cpp \
10 | main.cpp
11 |
12 | HEADERS += \
13 | Qui.h \
14 | resource.h
15 |
16 | FORMS += \
17 | Qui.ui
18 |
19 | TRANSLATIONS = qui_cn.ts
20 |
--------------------------------------------------------------------------------
/example/Qui/src/main.cpp:
--------------------------------------------------------------------------------
1 | #include "Qui.h"
2 | #include
3 |
4 | int
5 | main(int argc, char *argv[])
6 | {
7 | QApplication a(argc, argv);
8 |
9 | QTranslator tran;
10 |
11 | if (tran.load(QString("qui_cn.qm")))
12 | {
13 | a.installTranslator(&tran);
14 | }
15 |
16 | Qui w;
17 | w.show();
18 | return a.exec();
19 | }
20 |
--------------------------------------------------------------------------------
/example/README.md:
--------------------------------------------------------------------------------
1 | # example projects
2 |
3 | ## Benchmark
4 |
5 | A speed-testing example program that makes it easy to measure inference performance on different hardware and software. The speeds measured by the author are listed [here](doc/benchmark.md).
6 |
7 | Because real usage rarely satisfies the batching conditions and does not use greedy decoding, these numbers differ somewhat from the speed seen in practice.
8 |
9 | ### Usage:
10 |
11 | CPU:
12 |
13 | `./benchmark -p chatglm-6b-int4.flm -f prompts.txt -t [threads] --batch [batch size]`
14 |
15 | GPU:
16 |
17 | `./benchmark -p chatglm-6b-int4.flm -f prompts.txt --batch [batch size]`
18 |
19 |
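The file passed with `-f` holds the prompts used for the speed test. Its exact format is not spelled out in this section; as an assumption for illustration only, a plain text file with one prompt per line is enough for a first measurement:

```
Hello, please introduce yourself.
Please write a short introduction to large language models.
What is the capital of France?
```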
20 |
21 | ## Web UI
22 |
23 | Contributed by Jacques CHEN, many thanks!
24 |
25 | ## Win32Demo (Windows)
26 |
27 | Win32Demo is a Visual Studio project for running the FastLLM program on Windows.
28 | Because the Windows console defaults to the ANSI code page (GBK for Chinese, code page 936) while FastLLM reads and writes UTF-8, this version differs slightly from `main` and is provided separately. To avoid printing half a character when a token covers only part of one (as can happen with BPE encoding), consecutive Chinese characters are currently printed together.
29 |
30 | The generated exe is located at `Win32Demo\bin\Win32Demo.exe`.
31 |
32 | Please build the Release configuration whenever possible; it is much faster!
33 |
34 | In addition, a .vcproj file for fastllm with GPU support is provided; the project builds with Visual Studio 2015 Update 3 or later.
35 | (However, **building pyfastllm requires at least MSVC 2017**.)
36 |
37 | ### Building
38 |
39 | The fastllm project currently comes in a CPU build and a GPU build. To keep things simple when cmake is not available, the project can be built from the Visual Studio project files, with optional features switched on through preprocessor definitions. The CPU build is used by default.
40 |
41 | After checking out the code, **edit include/fastllm.h**: in Visual Studio choose "File" -> "Advanced Save Options" and pick the encoding "Unicode (UTF-8 **with signature**) - Codepage 65001", or convert the file to "UTF-8 BOM" encoding in another text editor. (Because gcc on Linux does not recognize the BOM header, this change has to be done by hand.)
42 |
43 | * **CPU build**:
44 |   * If CUDA is not installed on this machine, open "Properties" on the Win32Demo project, go to "Linker" -> "Input" -> "Additional Dependencies", and click "Inherit from parent or project defaults".
45 |
46 | * **GPU build**:
47 |   - CUDA and its Visual Studio Integration must be installed correctly;
48 |   - Set the CUDA_PATH environment variable so that it points to the CUDA version to build against;
49 |   - In Solution Explorer, remove fastllm.vcproj and add fastllm-gpu.vcproj;
50 |   - For the fastllm-gpu project, under "Build Dependencies" -> "Build Customizations", manually add the customization files of the installed CUDA version;
51 |   - For the fastllm-gpu project, open "Properties" and, under "CUDA C/C++" -> "Device" -> "Code Generation", configure the [GPU compute capabilities](https://developer.nvidia.com/cuda-gpus#compute) to target;
52 |   - On the Win32Demo project, choose "Add" -> "Reference" and tick the fastllm-gpu project;
53 |   - On the Win32Demo project, add the preprocessor definition "USE_CUDA".
54 |
55 | ### Usage:
56 |
57 | 1. Open a command prompt (cmd);
58 |
59 | 2. `cd example\Win32Demo\bin` ;
60 |
61 | 3. The runtime arguments are basically the same as for `main`, with one extra flag `-w`, which starts the web UI; without it the program runs in the console. For example:
62 |
63 | `Win32Demo.exe -p c:\chatglm-6b-v1.1-int4.flm -w`
64 |
65 | ## Android (Android platform)
66 | An example of running an LLM program on the Android platform, built with Android Studio.
67 |
68 | ### Usage:
69 |
70 | 1. Open the project directly in Android Studio and run it.
71 |
72 | 2. Or simply download the apk from the release directory and try it out.
73 |
74 | 3. You can also build the main binary with the CMake toolchain (see the readme on the project home page for the exact steps) and run it through adb shell, as sketched below:
75 |
76 |     1. `adb push main /data/local/tmp` to put the main binary into the phone's tmp directory,
77 |     2. `adb shell` ,
78 |     3. `cd /data/local/tmp`
79 |     4. `./main` to run it.
80 |
81 | Note: the demo apk copies the model file into the app's data directory so that native code can read it, so the device needs free space of at least twice the model size.
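Putting step 3 together, a complete session from the host might look like the sketch below. The model file name `chatglm-6b-int4.flm`, the extra push of the model file, and the `chmod` step are illustrative assumptions, not part of the original steps:

```
# push the binary and (assumed) model file to the device
adb push main /data/local/tmp
adb push chatglm-6b-int4.flm /data/local/tmp

# run it on the device
adb shell
cd /data/local/tmp
chmod +x main          # may be needed depending on how adb push sets permissions
./main -p chatglm-6b-int4.flm
```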
--------------------------------------------------------------------------------
/example/Win32Demo/StringUtils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <Windows.h>
3 | #include <string>
4 |
5 | std::string utf2Gb(const char* szU8)
6 | {
7 | int wcsLen = ::MultiByteToWideChar(CP_UTF8, NULL, szU8, strlen(szU8), NULL, 0);
8 | wchar_t* wszString = new wchar_t[wcsLen + 1];
9 | ::MultiByteToWideChar(CP_UTF8, NULL, szU8, strlen(szU8), wszString, wcsLen);
10 | wszString[wcsLen] = '\0';
11 |
12 | wcsLen = WideCharToMultiByte(CP_ACP, 0, wszString, -1, NULL, 0, NULL, NULL);
13 | char* gb2312 = new char[wcsLen + 1];
14 | memset(gb2312, 0, wcsLen + 1);
15 | WideCharToMultiByte(CP_ACP, 0, wszString, -1, gb2312, wcsLen, NULL, NULL);
16 |
17 | if (wszString)
18 | delete[] wszString;
19 | std::string gbstr(gb2312);
20 | if (gb2312)
21 | delete[] gb2312;
22 | return gbstr;
23 | }
24 |
25 | std::string Gb2utf(std::string ws)
26 | {
27 | //int dwNum = WideCharToMultiByte(CP_UTF8, 0, ws.c_str(), -1, 0, 0, 0, 0);
28 | int dwNum = MultiByteToWideChar(CP_ACP, 0, ws.c_str(), -1, NULL, 0);
29 |
30 | wchar_t* wstr = new wchar_t[dwNum + 1];
31 | memset(wstr, 0, dwNum + 1);
32 | MultiByteToWideChar(CP_ACP, 0, ws.c_str(), -1, wstr, dwNum);
33 |
34 | dwNum = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, NULL, 0, NULL, NULL);
35 | char* utf8 = new char[dwNum + 1];
36 | memset(utf8, 0, dwNum + 1);
37 | WideCharToMultiByte(CP_UTF8, 0, wstr, -1, utf8, dwNum, NULL, NULL);
38 | if (wstr)
39 | delete[] wstr;
40 | std::string str(utf8);
41 | if (utf8)
42 | delete[] utf8;
43 | return str;
44 | }
--------------------------------------------------------------------------------
/example/Win32Demo/Win32Demo.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 14
4 | VisualStudioVersion = 14.0.25420.1
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Win32Demo", "Win32Demo.vcxproj", "{B560ABA3-CCC2-42C6-8A99-43D7F811FEED}"
7 | EndProject
8 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "fastllm", "fastllm.vcxproj", "{BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}"
9 | EndProject
10 | Global
11 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
12 | Debug|x64 = Debug|x64
13 | Debug|x86 = Debug|x86
14 | Release|x64 = Release|x64
15 | Release|x86 = Release|x86
16 | EndGlobalSection
17 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
18 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Debug|x64.ActiveCfg = Debug|x64
19 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Debug|x64.Build.0 = Debug|x64
20 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Debug|x86.ActiveCfg = Debug|Win32
21 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Debug|x86.Build.0 = Debug|Win32
22 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Release|x64.ActiveCfg = Release|x64
23 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Release|x64.Build.0 = Release|x64
24 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Release|x86.ActiveCfg = Release|Win32
25 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Release|x86.Build.0 = Release|Win32
26 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Debug|x64.ActiveCfg = Debug|x64
27 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Debug|x64.Build.0 = Debug|x64
28 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Debug|x86.ActiveCfg = Debug|Win32
29 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Debug|x86.Build.0 = Debug|Win32
30 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Release|x64.ActiveCfg = Release|x64
31 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Release|x64.Build.0 = Release|x64
32 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Release|x86.ActiveCfg = Release|Win32
33 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Release|x86.Build.0 = Release|Win32
34 | EndGlobalSection
35 | GlobalSection(SolutionProperties) = preSolution
36 | HideSolutionNode = FALSE
37 | EndGlobalSection
38 | GlobalSection(ExtensibilityGlobals) = postSolution
39 | SolutionGuid = {48072E1E-35C1-46A9-9F40-FF4CDF872E08}
40 | EndGlobalSection
41 | EndGlobal
42 |
--------------------------------------------------------------------------------
/example/Win32Demo/bin/web/css/github.min.css:
--------------------------------------------------------------------------------
1 | pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}/*!
2 | Theme: GitHub
3 | Description: Light theme as seen on github.com
4 | Author: github.com
5 | Maintainer: @Hirse
6 | Updated: 2021-05-15
7 |
8 | Outdated base version: https://github.com/primer/github-syntax-light
9 | Current colors taken from GitHub's CSS
10 | */.hljs{color:#24292e;background:#fff}.hljs-doctag,.hljs-keyword,.hljs-meta .hljs-keyword,.hljs-template-tag,.hljs-template-variable,.hljs-type,.hljs-variable.language_{color:#d73a49}.hljs-title,.hljs-title.class_,.hljs-title.class_.inherited__,.hljs-title.function_{color:#6f42c1}.hljs-attr,.hljs-attribute,.hljs-literal,.hljs-meta,.hljs-number,.hljs-operator,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-variable{color:#005cc5}.hljs-meta .hljs-string,.hljs-regexp,.hljs-string{color:#032f62}.hljs-built_in,.hljs-symbol{color:#e36209}.hljs-code,.hljs-comment,.hljs-formula{color:#6a737d}.hljs-name,.hljs-quote,.hljs-selector-pseudo,.hljs-selector-tag{color:#22863a}.hljs-subst{color:#24292e}.hljs-section{color:#005cc5;font-weight:700}.hljs-bullet{color:#735c0f}.hljs-emphasis{color:#24292e;font-style:italic}.hljs-strong{color:#24292e;font-weight:700}.hljs-addition{color:#22863a;background-color:#f0fff4}.hljs-deletion{color:#b31d28;background-color:#ffeef0}
--------------------------------------------------------------------------------
/example/Win32Demo/bin/web/css/texmath.css:
--------------------------------------------------------------------------------
1 | /* style for html inside of browsers */
2 | .katex { font-size: 1em !important; } /* align KaTeX font-size to surrounding text */
3 |
4 | eq { display: inline-block; }
5 | eqn { display: block}
6 | section.eqno {
7 | display: flex;
8 | flex-direction: row;
9 | align-content: space-between;
10 | align-items: center;
11 | }
12 | section.eqno > eqn {
13 | width: 100%;
14 | margin-left: 3em;
15 | }
16 | section.eqno > span {
17 | width:3em;
18 | text-align:right;
19 | }
20 |
--------------------------------------------------------------------------------
/example/Win32Demo/bin/web/js/markdown-it-link-attributes.min.js:
--------------------------------------------------------------------------------
1 | (function(f){if(typeof exports==="object"&&typeof module!=="undefined"){module.exports=f()}else if(typeof define==="function"&&define.amd){define([],f)}else{var g;if(typeof window!=="undefined"){g=window}else if(typeof global!=="undefined"){g=global}else if(typeof self!=="undefined"){g=self}else{g=this}g.markdownitLinkAttributes=f()}})(function(){var define,module,exports;return function(){function r(e,n,t){function o(i,f){if(!n[i]){if(!e[i]){var c="function"==typeof require&&require;if(!f&&c)return c(i,!0);if(u)return u(i,!0);var a=new Error("Cannot find module '"+i+"'");throw a.code="MODULE_NOT_FOUND",a}var p=n[i]={exports:{}};e[i][0].call(p.exports,function(r){var n=e[i][1][r];return o(n||r)},p,p.exports,r,e,n,t)}return n[i].exports}for(var u="function"==typeof require&&require,i=0;i eqn {
--------------------------------------------------------------------------------
/example/webui/web/css/texmath.css:
--------------------------------------------------------------------------------
12 | section.eqno > eqn {
13 | width: 100%;
14 | margin-left: 3em;
15 | }
16 | section.eqno > span {
17 | width:3em;
18 | text-align:right;
19 | }
20 |
--------------------------------------------------------------------------------
/example/webui/web/js/markdown-it-link-attributes.min.js:
--------------------------------------------------------------------------------
1 | (function(f){if(typeof exports==="object"&&typeof module!=="undefined"){module.exports=f()}else if(typeof define==="function"&&define.amd){define([],f)}else{var g;if(typeof window!=="undefined"){g=window}else if(typeof global!=="undefined"){g=global}else if(typeof self!=="undefined"){g=self}else{g=this}g.markdownitLinkAttributes=f()}})(function(){var define,module,exports;return function(){function r(e,n,t){function o(i,f){if(!n[i]){if(!e[i]){var c="function"==typeof require&&require;if(!f&&c)return c(i,!0);if(u)return u(i,!0);var a=new Error("Cannot find module '"+i+"'");throw a.code="MODULE_NOT_FOUND",a}var p=n[i]={exports:{}};e[i][0].call(p.exports,function(r){var n=e[i][1][r];return o(n||r)},p,p.exports,r,e,n,t)}return n[i].exports}for(var u="function"==typeof require&&require,i=0;i
--------------------------------------------------------------------------------
/example/webui/webui.cpp:
--------------------------------------------------------------------------------
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 |
15 | struct WebConfig {
16 | std::string path = "chatglm-6b-int4.bin"; // 模型文件路径
17 | std::string webPath = "web"; // 网页文件路径
18 | int threads = 4; // 使用的线程数
19 | bool lowMemMode = false; // 是否使用低内存模式
20 | int port = 8081; // 端口号
21 | };
22 |
23 | void Usage() {
24 | std::cout << "Usage:" << std::endl;
25 | std::cout << "[-h|--help]: 显示帮助" << std::endl;
26 | std::cout << "<-p|--path> : 模型文件的路径" << std::endl;
27 | std::cout << "<-w|--web> : 网页文件的路径" << std::endl;
28 | std::cout << "<-t|--threads> : 使用的线程数量" << std::endl;
29 | std::cout << "<-l|--low>: 使用低内存模式" << std::endl;
30 | std::cout << "<--port> : 网页端口号" << std::endl;
31 | }
32 |
33 | void ParseArgs(int argc, char **argv, WebConfig &config) {
34 | std::vector <std::string> sargv;
35 | for (int i = 0; i < argc; i++) {
36 | sargv.push_back(std::string(argv[i]));
37 | }
38 | for (int i = 1; i < argc; i++) {
39 | if (sargv[i] == "-h" || sargv[i] == "--help") {
40 | Usage();
41 | exit(0);
42 | } else if (sargv[i] == "-p" || sargv[i] == "--path") {
43 | config.path = sargv[++i];
44 | } else if (sargv[i] == "-t" || sargv[i] == "--threads") {
45 | config.threads = atoi(sargv[++i].c_str());
46 | } else if (sargv[i] == "-l" || sargv[i] == "--low") {
47 | config.lowMemMode = true;
48 | } else if (sargv[i] == "-w" || sargv[i] == "--web") {
49 | config.webPath = sargv[++i];
50 | } else if (sargv[i] == "--port") {
51 | config.port = atoi(sargv[++i].c_str());
52 | } else {
53 | Usage();
54 | exit(-1);
55 | }
56 | }
57 | }
58 |
59 | struct ChatSession {
60 | std::string history = "";
61 | std::string input = "";
62 | std::string output = "";
63 | int round = 0;
64 | int status = 0; // 0: 空闲 1: 结果生成好了 2: 已经写回了
65 | };
66 |
67 | std::map <std::string, ChatSession*> sessions;
68 | std::mutex locker;
69 |
70 | int main(int argc, char** argv) {
71 | WebConfig config;
72 | ParseArgs(argc, argv, config);
73 |
74 | fastllm::SetThreads(config.threads);
75 | fastllm::SetLowMemMode(config.lowMemMode);
76 | auto model = fastllm::CreateLLMModelFromFile(config.path);
77 |
78 | httplib::Server svr;
79 | auto chat = [&](ChatSession *session, const std::string input) {
80 | if (input == "reset" || input == "stop") {
81 | session->history = "";
82 | session->round = 0;
83 | session->output = "\n";
84 | session->status = 2;
85 | } else {
86 | auto prompt = model->MakeInput(session->history, session->round, input);
87 | auto inputs = model->weight.tokenizer.Encode(prompt);
88 |
89 | std::vector <int> tokens;
90 | for (int i = 0; i < inputs.Count(0); i++) {
91 | tokens.push_back(((float *) inputs.cpuData)[i]);
92 | }
93 |
94 | int handleId = model->LaunchResponseTokens(tokens);
95 | std::vector <float> results;
96 | while (true) {
97 | int result = model->FetchResponseTokens(handleId);
98 | if (result == -1) {
99 | break;
100 | } else {
101 | results.clear();
102 | results.push_back(result);
103 | session->output += model->weight.tokenizer.Decode(fastllm::Data (fastllm::DataType::FLOAT32, {(int)results.size()}, results));
104 | }
105 | if (session->status == 2) {
106 | break;
107 | }
108 | }
109 | session->history = model->MakeHistory(session->history, session->round++, input, session->output);
110 | session->output += "\n";
111 | session->status = 2;
112 | }
113 | };
114 |
115 | svr.Post("/chat", [&](const httplib::Request &req, httplib::Response &res) {
116 | const std::string uuid = req.get_header_value("uuid");
117 | locker.lock();
118 | if (sessions.find(uuid) == sessions.end()) {
119 | sessions[uuid] = new ChatSession();
120 | }
121 | auto *session = sessions[uuid];
122 | locker.unlock();
123 |
124 | if (session->status != 0) {
125 | res.set_content(session->output, "text/plain");
126 | if (session->status == 2) {
127 | session->status = 0;
128 | }
129 | } else {
130 | session->output = "";
131 | session->status = 1;
132 | std::thread chat_thread(chat, session, req.body);
133 | chat_thread.detach();
134 | }
135 | });
136 |
137 | svr.set_mount_point("/", config.webPath);
138 | std::cout << ">>> please open http://127.0.0.1:" + std::to_string(config.port) + "\n";
139 | svr.listen("0.0.0.0", config.port);
140 |
141 | return 0;
142 | }
143 |
--------------------------------------------------------------------------------
/include/device.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/13/23.
3 | //
4 |
5 | #ifndef FASTLLM_DEVICE_H
6 | #define FASTLLM_DEVICE_H
7 |
8 | #include "fastllm.h"
9 |
10 | namespace fastllm {
11 | typedef std::map <std::string, Data*> DataDict;
12 | typedef std::map <std::string, float> FloatDict;
13 | typedef std::map <std::string, int> IntDict;
14 |
15 | class BaseOperator {
16 | public:
17 | // 是否可以运行某一个算子
18 | virtual bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
19 |
20 | // 对某一个算子进行形状推理
21 | virtual void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
22 |
23 | // 对某一个算子进行推理
24 | virtual void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams) = 0;
25 | };
26 |
27 | class BaseBatchOperator : BaseOperator {
28 | public:
29 | // 对某一个算子进行形状推理
30 | virtual void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams,
31 | const IntDict &intParams);
32 | };
33 |
34 | class BaseDevice {
35 | public:
36 | virtual bool Malloc (void **ret, size_t size) = 0; // 分配尺寸为size的空间
37 | virtual bool Malloc (void **ret, Data &data); // 分配形状为dims的空间
38 | virtual bool Free(void *ret) = 0; // 释放ret
39 |
40 | virtual bool CopyDataToCPU(void *dst, void *src, size_t size) = 0; // device上的src拷贝到cpu上的dst
41 | virtual bool CopyDataToCPU(Data &data); // data数据从该device移动到CPU
42 |
43 | virtual bool CopyDataFromCPU(void *dst, void *src, size_t size) = 0; // cpu上的src拷贝到device上的dst
44 | virtual bool CopyDataFromCPU(Data &data); // data数据从CPU移动到该device
45 |
46 | // 是否可以运行某一个算子
47 | virtual bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
48 |
49 | // 对某一个算子进行形状推理
50 | virtual void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
51 |
52 | // 对某一个算子进行推理
53 | virtual void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
54 |
55 | std::string deviceType;
56 | std::string deviceName;
57 | std::vector <int> deviceIds;
58 |
59 | std::map <std::string, BaseOperator*> ops;
60 | };
61 | }
62 |
63 | #endif //FASTLLM_DEVICE_H
64 |
--------------------------------------------------------------------------------
/include/devices/cpu/cputhreadpool.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 7/5/23.
3 | //
4 |
5 | #ifndef FASTLLCPUTHREADPOOL_H
6 | #define FASTLLCPUTHREADPOOL_H
7 |
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 | #include
14 | #include
15 |
16 | namespace fastllm {
17 | template <typename T>
18 | class TaskQueue {
19 | private:
20 | std::queue<T> q;
21 | std::mutex locker;
22 | public:
23 | TaskQueue() {}
24 |
25 | ~TaskQueue() {}
26 |
27 | bool Empty() {
28 | std::unique_lock<std::mutex> lock(locker);
29 | return q.empty();
30 | }
31 |
32 | int Size() {
33 | std::unique_lock<std::mutex> lock(locker);
34 | return q.size();
35 | }
36 |
37 | void Push(T &t) {
38 | std::unique_lock<std::mutex> lock(locker);
39 | q.emplace(t);
40 | }
41 |
42 | bool Pop(T &t) {
43 | std::unique_lock<std::mutex> lock(locker);
44 | if (q.empty()) {
45 | return false;
46 | }
47 | t = std::move(q.front());
48 | q.pop();
49 | return true;
50 | }
51 | };
52 |
53 | class ThreadPool {
54 | private:
55 | class ThreadWorker
56 | {
57 | private:
58 | int id;
59 | ThreadPool *pool;
60 | public:
61 | ThreadWorker(ThreadPool *pool, const int id) : pool(pool), id(id) {}
62 |
63 | void operator()() {
64 | std::function<void()> func;
65 | bool dequeued;
66 |
67 | while (!pool->shutdown) {
68 | {
69 | std::unique_lock<std::mutex> lock(pool->locker);
70 | if (pool->queue.Empty()) {
71 | pool->cv.wait(lock);
72 | }
73 |
74 | dequeued = pool->queue.Pop(func);
75 | }
76 | if (dequeued) {
77 | func();
78 | }
79 | }
80 | }
81 | };
82 |
83 | bool shutdown = false;
84 | TaskQueue<std::function<void()>> queue;
85 | std::vector<std::thread> threads;
86 | std::mutex locker;
87 | std::condition_variable cv;
88 | public:
89 | ThreadPool(const int t = 4) : threads(std::vector<std::thread>(t)) {
90 | for (int i = 0; i < threads.size(); ++i) {
91 | threads[i] = std::thread(ThreadWorker(this, i));
92 | }
93 | }
94 | void Shutdown() {
95 | shutdown = true;
96 | cv.notify_all();
97 | for (int i = 0; i < threads.size(); ++i) {
98 | if (threads[i].joinable()) {
99 | threads[i].join();
100 | }
101 | }
102 | }
103 |
104 | template <typename F, typename ...Args>
105 | auto Submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> {
106 | std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
107 | auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
108 | std::function<void()> warpper_func = [task_ptr]() {
109 | (*task_ptr)();
110 | };
111 | queue.Push(warpper_func);
112 | cv.notify_one();
113 | return task_ptr->get_future();
114 | }
115 | };
116 | }
117 |
118 | #endif //FASTLLCPUTHREADPOOL_H
119 |
--------------------------------------------------------------------------------
/include/executor.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/13/23.
3 | //
4 |
5 | #ifndef FASTLLM_EXECUTOR_H
6 | #define FASTLLM_EXECUTOR_H
7 |
8 | #include "device.h"
9 |
10 | namespace fastllm {
11 | class Executor {
12 | private:
13 | std::vector <BaseDevice*> devices;
14 | std::map <std::string, float> profiler;
15 |
16 | public:
17 | Executor (); // 创建默认的Executor
18 |
19 | ~Executor(); // 析构
20 |
21 | void ClearDevices(); // 清空 devices
22 |
23 | void AddDevice(BaseDevice *device); // 增加一个device
24 |
25 | void SetFirstDevice(const std::string &device); // 设定优先的device
26 |
27 | std::vector <int> GetDeviceIds(const std::string &device); // 获取指定device的deviceIds
28 |
29 | // 运行一个op
30 | void Run(const std::string &opType, const fastllm::DataDict &datas, const fastllm::FloatDict &floatParams,
31 | const fastllm::IntDict &intParams);
32 |
33 | void ClearProfiler();
34 |
35 | void PrintProfiler();
36 | };
37 | }
38 |
39 | #endif //FASTLLM_EXECUTOR_H
40 |
--------------------------------------------------------------------------------
/include/model.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/20/23.
3 | //
4 |
5 | #ifndef FASTLLM_MODEL_H
6 | #define FASTLLM_MODEL_H
7 |
8 | #include "basellm.h"
9 |
10 | namespace fastllm {
11 | std::unique_ptr CreateLLMModelFromFile(const std::string &fileName);
12 |
13 | std::unique_ptr CreateEmptyLLMModel(const std::string &modelType);
14 | }
15 |
16 | #endif //FASTLLM_MODEL_H
17 |
--------------------------------------------------------------------------------
/include/models/chatglm.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/11/23.
3 | //
4 |
5 | #ifndef FASTLLM_CHATGLM_H
6 | #define FASTLLM_CHATGLM_H
7 |
8 | #include "basellm.h"
9 | #include "cmath"
10 |
11 | #include
12 |
13 | namespace fastllm {
14 | class ChatGLMModel: public basellm {
15 | public:
16 | ChatGLMModel (); // 构造函数
17 |
18 | virtual void InitParams(); // 初始化参数信息
19 |
20 | // 推理
21 | virtual int Forward(
22 | const Data &inputIds,
23 | const Data &attentionMask,
24 | const Data &positionIds,
25 | std::vector <std::pair <Data, Data> > &pastKeyValues,
26 | const GenerationConfig &generationConfig = GenerationConfig(),
27 | const LastTokensManager &lastTokens = LastTokensManager(),
28 | std::vector <float> *logits = nullptr);
29 |
30 | std::vector <int> ForwardBatch(
31 | int batch,
32 | const Data &inputIds,
33 | const Data &attentionMask,
34 | const Data &positionIds,
35 | std::vector <std::pair <Data, Data> > &pastKeyValues,
36 | const GenerationConfig &generationConfig = GenerationConfig(),
37 | const LastTokensManager &lastTokens = LastTokensManager(),
38 | std::vector <std::vector <float>*> *retLogits = nullptr);
39 |
40 | std::vector <int> ForwardBatch(
41 | int batch,
42 | const Data &inputIds,
43 | const std::vector <Data*> &attentionMask,
44 | const std::vector <Data*> &positionIds,
45 | const std::vector <int> &seqLens,
46 | std::vector <std::pair <Data, Data> > &pastKeyValues,
47 | const std::vector <GenerationConfig> &generationConfigs,
48 | const LastTokensManager &lastTokens = LastTokensManager(),
49 | std::vector <std::vector <float>*> *logits = nullptr);
50 |
51 | // 根据输入的tokens生成LLM推理的输入
52 | virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
53 | const std::map <std::string, int> &params,
54 | Data &inputIds, Data &attentionMask, Data &positionIds);
55 |
56 | // 根据输入的tokens生成LLM推理的输入
57 | virtual void FillLLMInputsBatch(std::vector <std::vector <float> > &inputTokens,
58 | const std::vector <std::map <std::string, int> > &params,
59 | Data &inputIds, Data &attentionMask, Data &positionIds);
60 |
61 | virtual void WarmUp(); // 预热
62 |
63 | virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // 根据历史信息和当前输入生成prompt
64 |
65 | virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // 根据当前回复更新history
66 |
67 | int GetVersion();
68 |
69 | void UpdateRotaryPosEmb(float rope_factor);
70 |
71 | int gmask_token_id;
72 | private:
73 | virtual void CausalMask(Data &data, int start) {}; // 因果mask?
74 |
75 | float rope_factor = 1.0f;
76 | };
77 | }
78 |
79 | #endif //FASTLLM_CHATGLM_H
80 |
--------------------------------------------------------------------------------
/include/models/factoryllm.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "chatglm.h"
3 | #include "moss.h"
4 | #include "basellm.h"
5 | #include "llama.h"
6 | #include "qwen.h"
7 | #include "fastllm.h"
8 |
9 | enum LLM_TYPE {
10 | LLM_TYPE_CHATGLM = 0,
11 | LLM_TYPE_MOSS = 1,
12 | LLM_TYPE_VICUNA = 2,
13 | LLM_TYPE_BAICHUAN = 3,
14 | LLM_TYPE_QWEN = 4
15 | };
16 |
17 | class factoryllm {
18 | public:
19 | factoryllm() {};
20 |
21 | ~factoryllm() {};
22 |
23 | fastllm::basellm *createllm(LLM_TYPE type) {
24 | fastllm::basellm *pLLM = NULL;
25 | switch (type) {
26 | case LLM_TYPE_CHATGLM:
27 | pLLM = new fastllm::ChatGLMModel();
28 | break;
29 | case LLM_TYPE_MOSS:
30 | pLLM = new fastllm::MOSSModel();
31 | break;
32 | case LLM_TYPE_VICUNA:
33 | pLLM = new fastllm::LlamaModel();
34 | break;
35 | default:
36 | pLLM = new fastllm::QWenModel();
37 | break;
38 | }
39 | return pLLM;
40 | };
41 | };
--------------------------------------------------------------------------------
/include/models/glm.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/11/23.
3 | //
4 |
5 | #ifndef FASTLLM_GLM_H
6 | #define FASTLLM_GLM_H
7 |
8 | #include "basellm.h"
9 | #include "cmath"
10 |
11 | #include
12 |
13 | namespace fastllm {
14 | class GLMModel: public basellm {
15 | public:
16 | GLMModel (); // 构造函数
17 |
18 | // 推理
19 | virtual int Forward(
20 | const Data &inputIds,
21 | const Data &attentionMask,
22 | const Data &positionIds,
23 | std::vector <std::pair <Data, Data> > &pastKeyValues,
24 | const GenerationConfig &generationConfig = GenerationConfig(),
25 | const LastTokensManager &lastTokens = LastTokensManager(),
26 | std::vector <float> *logits = nullptr);
27 |
28 | std::vector <int> ForwardBatch(
29 | int batch,
30 | const Data &inputIds,
31 | const Data &attentionMask,
32 | const Data &positionIds,
33 | std::vector <std::pair <Data, Data> > &pastKeyValues,
34 | const GenerationConfig &generationConfig = GenerationConfig(),
35 | const LastTokensManager &lastTokens = LastTokensManager(),
36 | std::vector <std::vector <float>*> *retLogits = nullptr);
37 |
38 | // 根据输入的tokens生成LLM推理的输入
39 | virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
40 | const std::map <std::string, int> &params,
41 | Data &inputIds, Data &attentionMask, Data &positionIds);
42 |
43 | virtual void InitParams();
44 | virtual void WarmUp(); // 预热
45 |
46 | virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // 根据历史信息和当前输入生成prompt
47 |
48 | virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // 根据当前回复更新history
49 |
50 | private:
51 |
52 | float scale_attn_1;
53 |
54 | static constexpr int eot_token_id = 50000;//<|endoftext|>
55 | static constexpr int cls_token_id = 50002;//[CLS]
56 | static constexpr int mask_token_id = 50003;//[MASK]
57 | static constexpr int smask_token_id = 50008;//[sMASK]
58 | static constexpr int gmask_token_id = 50009;//[gMASK]
59 | };
60 | }
61 |
62 | #endif //FASTLLM_GLM_H
63 |
--------------------------------------------------------------------------------
/include/models/internlm2.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by tylunasli on 3/14/24.
3 | //
4 |
5 | #ifndef FASTLLM_INTERNLM2_H
6 | #define FASTLLM_INTERNLM2_H
7 |
8 | #include "basellm.h"
9 | #include "llama.h"
10 | #include "cmath"
11 |
12 | #include
13 |
14 | namespace fastllm {
15 | class Internlm2Model : public LlamaModel {
16 | public:
17 | Internlm2Model(); // 构造函数
18 |
19 | virtual void InitParams(); // 初始化参数信息
20 |
21 | // 推理
22 | virtual int Forward(
23 | const Data &inputIds,
24 | const Data &attentionMask,
25 | const Data &positionIds,
26 | std::vector <std::pair <Data, Data> > &pastKeyValues,
27 | const GenerationConfig &generationConfig = GenerationConfig(),
28 | const LastTokensManager &lastTokens = LastTokensManager(),
29 | std::vector <float> *logits = nullptr);
30 |
31 | std::vector <int> ForwardBatch(
32 | int batch,
33 | const Data &inputIds,
34 | const Data &attentionMask,
35 | const Data &positionIds,
36 | std::vector <std::pair <Data, Data> > &pastKeyValues,
37 | const GenerationConfig &generationConfig = GenerationConfig(),
38 | const LastTokensManager &lastTokens = LastTokensManager(),
39 | std::vector <std::vector <float>*> *logits = nullptr);
40 |
41 | std::vector <int> ForwardBatch(
42 | int batch,
43 | const Data &inputIds,
44 | const std::vector <Data*> &attentionMask,
45 | const std::vector <Data*> &positionIds,
46 | const std::vector <int> &seqLens,
47 | std::vector <std::pair <Data, Data> > &pastKeyValues,
48 | const std::vector <GenerationConfig> &generationConfigs,
49 | const LastTokensManager &lastTokens = LastTokensManager(),
50 | std::vector <std::vector <float>*> *logits = nullptr);
51 | };
52 | }
53 |
54 | #endif //FASTLLM_INTERNLM2_H
55 |
--------------------------------------------------------------------------------
/include/models/llama.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/1/23.
3 | //
4 |
5 | #ifndef FASTLLM_LLAMA_H
6 | #define FASTLLM_LLAMA_H
7 |
8 | #include "basellm.h"
9 | #include "cmath"
10 |
11 | #include
12 |
13 | namespace fastllm {
14 |
15 | enum RoPEType { // 位置编码外推类型
16 | BASE = 0,
17 | LINEAR_SCALE = 1,
18 | STATIC_NTK = 2,
19 | DYMAMIC_NTK = 3
20 | };
21 |
22 | class LlamaModel: public basellm {
23 | public:
24 | LlamaModel (); // 构造函数
25 |
26 | virtual void InitParams(); // 初始化参数信息
27 |
28 | // 推理
29 | virtual int Forward(
30 | const Data &inputIds,
31 | const Data &attentionMask,
32 | const Data &positionIds,
33 | std::vector <std::pair <Data, Data> > &pastKeyValues,
34 | const GenerationConfig &generationConfig = GenerationConfig(),
35 | const LastTokensManager &lastTokens = LastTokensManager(),
36 | std::vector <float> *logits = nullptr);
37 |
38 | std::vector <int> ForwardBatch(
39 | int batch,
40 | const Data &inputIds,
41 | const Data &attentionMask,
42 | const Data &positionIds,
43 | std::vector <std::pair <Data, Data> > &pastKeyValues,
44 | const GenerationConfig &generationConfig = GenerationConfig(),
45 | const LastTokensManager &lastTokens = LastTokensManager(),
46 | std::vector <std::vector <float>*> *logits = nullptr);
47 |
48 | std::vector <int> ForwardBatch(
49 | int batch,
50 | const Data &inputIds,
51 | const std::vector <Data*> &attentionMask,
52 | const std::vector <Data*> &positionIds,
53 | const std::vector <int> &seqLens,
54 | std::vector <std::pair <Data, Data> > &pastKeyValues,
55 | const std::vector <GenerationConfig> &generationConfigs,
56 | const LastTokensManager &lastTokens = LastTokensManager(),
57 | std::vector <std::vector <float>*> *logits = nullptr);
58 |
59 | virtual std::string Response(const std::string& input,
60 | RuntimeResult retCb,
61 | const GenerationConfig &generationConfig = GenerationConfig()); // 根据给出的内容回复
62 |
63 | virtual void ResponseBatch(const std::vector <std::string> &inputs,
64 | std::vector <std::string> &outputs,
65 | RuntimeResultBatch retCb,
66 | const GenerationConfig &generationConfig = GenerationConfig());
67 |
68 | virtual int LaunchResponseTokens(const std::vector <int> &inputTokens,
69 | const GenerationConfig &generationConfig = GenerationConfig()); // 启动一个response任务,返回分配的handleId
70 |
71 | virtual int FetchResponseTokens(int handelId); // 获取指定handle的输出, -1代表输出结束了
72 |
73 | virtual void WarmUp(); // 预热
74 |
75 | virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // 根据历史信息和当前输入生成prompt
76 |
77 | virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // 根据当前回复更新history
78 |
79 | std::pair, std::vector> UpdateRotaryPosEmb(float base, float factor, int seqLen = 0); // 更新位置编码
80 |
81 | protected:
82 | RoPEType rope_type = RoPEType::BASE;
83 |
84 | float rope_base = 10000.f;
85 |
86 | float rope_factor = 1.f;
87 |
88 | int num_key_value_heads = num_attention_heads;
89 |
90 | float rms_norm_eps = 1e-6;
91 | };
92 | }
93 |
94 | #endif //FASTLLM_LLAMA_H
95 |
--------------------------------------------------------------------------------
/include/models/minicpm.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/1/23.
3 | //
4 |
5 | #ifndef FASTLLM_MINICPM_H
6 | #define FASTLLM_MINICPM_H
7 |
8 | #include "basellm.h"
9 | #include "llama.h"
10 | #include "cmath"
11 |
12 | #include <iostream>
13 |
14 | namespace fastllm {
15 | class MiniCpmModel: public LlamaModel {
16 | public:
17 |         MiniCpmModel(); // constructor
18 |
19 |         virtual void InitParams(); // initialize parameter info
20 |
21 |         // inference
22 |         virtual int Forward(
23 |                 const Data &inputIds,
24 |                 const Data &attentionMask,
25 |                 const Data &positionIds,
26 |                 std::vector <std::pair <Data, Data> > &pastKeyValues,
27 |                 const GenerationConfig &generationConfig = GenerationConfig(),
28 |                 const LastTokensManager &lastTokens = LastTokensManager(),
29 |                 std::vector <float> *logits = nullptr);
30 |
31 |         std::vector <int> ForwardBatch(
32 |                 int batch,
33 |                 const Data &inputIds,
34 |                 const Data &attentionMask,
35 |                 const Data &positionIds,
36 |                 std::vector <std::pair <Data, Data> > &pastKeyValues,
37 |                 const GenerationConfig &generationConfig = GenerationConfig(),
38 |                 const LastTokensManager &lastTokens = LastTokensManager(),
39 |                 std::vector <std::vector <float>*> *logits = nullptr);
40 |
41 |         std::vector <int> ForwardBatch(
42 |                 int batch,
43 |                 const Data &inputIds,
44 |                 const std::vector <Data*> &attentionMask,
45 |                 const std::vector <Data*> &positionIds,
46 |                 const std::vector <int> &seqLens,
47 |                 std::vector <std::pair <Data*, Data*> > &pastKeyValues,
48 |                 const std::vector <GenerationConfig> &generationConfigs,
49 |                 const LastTokensManager &lastTokens = LastTokensManager(),
50 |                 std::vector <std::vector <float>*> *logits = nullptr);
51 |
52 | private:
53 | float embed_scale = 1.f;
54 |
55 | float attention_scale = 1.f / std::sqrt(block_cnt);
56 |
57 | float rms_scale = 1.f / 4096.f;
58 | };
59 | }
60 |
61 | #endif //FASTLLM_MINICPM_H
62 |
--------------------------------------------------------------------------------
/include/models/moss.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/12/23.
3 | //
4 |
5 | #ifndef TEST_MOSS_H
6 | #define TEST_MOSS_H
7 |
8 | #include "basellm.h"
9 | #include "cmath"
10 |
11 | namespace fastllm {
12 | class MOSSModel: public basellm {
13 | public:
14 |         MOSSModel(); // constructor
15 |
16 |         // inference
17 |         virtual int Forward(
18 |                 const Data &inputIds,
19 |                 const Data &attentionMask,
20 |                 const Data &positionIds,
21 |                 std::vector <std::pair <Data, Data> > &pastKeyValues,
22 |                 const GenerationConfig &generationConfig = GenerationConfig(),
23 |                 const LastTokensManager &lastTokens = LastTokensManager(),
24 |                 std::vector <float> *logits = nullptr);
25 |
26 |         virtual std::string Response(const std::string &input, RuntimeResult retCb,
27 |                                      const GenerationConfig &generationConfig = GenerationConfig()); // generate a reply for the given input
28 |
29 |         virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // build the prompt from the history and the current input
30 |
31 |         virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // update the history with the current reply
32 |
33 |         virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
34 |                                    const std::map <std::string, int> &params,
35 |                                    Data &inputIds, Data &attentionMask, Data &positionIds);
36 |
37 |         virtual void WarmUp();
38 |     private:
39 |         virtual void RotatePosition2D(Data &data, const Data &positionIds); // 2D rotary position embedding
40 |
41 |         virtual void CausalMask(Data &data, int start); // causal mask
42 | };
43 | }
44 |
45 | #endif //TEST_MOSS_H
46 |
--------------------------------------------------------------------------------
/include/models/qwen.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by siemon on 8/9/23.
3 | //
4 |
5 | #ifndef TEST_QWEN_H
6 | #define TEST_QWEN_H
7 |
8 | #include "basellm.h"
9 |
10 | namespace fastllm {
11 | class QWenModel : public basellm {
12 | public:
13 | QWenModel();
14 |
15 |         // inference
16 |         virtual int Forward(
17 |                 const Data &inputIds,
18 |                 const Data &attentionMask,
19 |                 const Data &positionIds,
20 |                 std::vector <std::pair <Data, Data> > &pastKeyValues,
21 |                 const GenerationConfig &generationConfig = GenerationConfig(),
22 |                 const LastTokensManager &lastTokens = LastTokensManager(),
23 |                 std::vector <float> *logits = nullptr);
24 |
25 |         std::vector <int> ForwardBatch(
26 |                 int batch,
27 |                 const Data &inputIds,
28 |                 const Data &attentionMask,
29 |                 const Data &positionIds,
30 |                 std::vector <std::pair <Data, Data> > &pastKeyValues,
31 |                 const GenerationConfig &generationConfig = GenerationConfig(),
32 |                 const LastTokensManager &lastTokens = LastTokensManager(),
33 |                 std::vector <std::vector <float>*> *logits = nullptr);
34 |
35 |         std::vector <int> ForwardBatch(
36 |                 int batch,
37 |                 const Data &inputIds,
38 |                 const std::vector <Data*> &attentionMask,
39 |                 const std::vector <Data*> &positionIds,
40 |                 const std::vector <int> &seqLens,
41 |                 std::vector <std::pair <Data*, Data*> > &pastKeyValues,
42 |                 const std::vector <GenerationConfig> &generationConfigs,
43 |                 const LastTokensManager &lastTokens = LastTokensManager(),
44 |                 std::vector <std::vector <float>*> *retLogits = nullptr);
45 |
46 |         virtual std::string MakeInput(const std::string &history, int round, const std::string &input);
47 |
48 |         virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output);
49 |
50 |         virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
51 |                                    const std::map <std::string, int> &params,
52 |                                    Data &inputIds, Data &attentionMask, Data &positionIds);
53 |
54 |         virtual void FillLLMInputsBatch(std::vector <std::vector <float> > &inputTokens,
55 |                                         const std::vector <std::map <std::string, int> > &params,
56 |                                         Data &inputIds, Data &attentionMask, Data &positionIds);
57 |
58 | virtual void WarmUp();
59 |
60 | void UpdateRotaryPosEmb(float ntk_alpha);
61 |
62 | int seq_length;
63 | float ntk_alpha;
64 |
65 | bool use_log_attn;
66 | Data logn_list;
67 |
68 | private:
69 | std::string im_start = "<|im_start|>";
70 | std::string im_end = "<|im_end|>";
71 | };
72 | }
73 |
74 | #endif //TEST_QWEN_H
--------------------------------------------------------------------------------
/main.cpp:
--------------------------------------------------------------------------------
1 | #include "model.h"
2 |
3 | struct RunConfig {
4 |     std::string path = "chatglm-6b-int4.bin"; // model file path
5 |     int threads = 4; // number of threads to use
6 |     bool lowMemMode = false; // whether to use low-memory mode
7 | };
8 |
9 | void Usage() {
10 | std::cout << "Usage:" << std::endl;
11 |     std::cout << "[-h|--help]: show this help" << std::endl;
12 |     std::cout << "<-p|--path> : path to the model file" << std::endl;
13 |     std::cout << "<-t|--threads> : number of threads to use" << std::endl;
14 |     std::cout << "<-l|--low>: use low-memory mode" << std::endl;
15 |     std::cout << "<--top_p> : sampling parameter top_p" << std::endl;
16 |     std::cout << "<--top_k> : sampling parameter top_k" << std::endl;
17 |     std::cout << "<--temperature> : sampling temperature; higher values give more varied results" << std::endl;
18 |     std::cout << "<--repeat_penalty> : sampling repetition penalty" << std::endl;
19 | }
20 |
21 | void ParseArgs(int argc, char **argv, RunConfig &config, fastllm::GenerationConfig &generationConfig) {
22 |     std::vector <std::string> sargv;
23 | for (int i = 0; i < argc; i++) {
24 | sargv.push_back(std::string(argv[i]));
25 | }
26 | for (int i = 1; i < argc; i++) {
27 | if (sargv[i] == "-h" || sargv[i] == "--help") {
28 | Usage();
29 | exit(0);
30 | } else if (sargv[i] == "-p" || sargv[i] == "--path") {
31 | config.path = sargv[++i];
32 | } else if (sargv[i] == "-t" || sargv[i] == "--threads") {
33 | config.threads = atoi(sargv[++i].c_str());
34 | } else if (sargv[i] == "-l" || sargv[i] == "--low") {
35 | config.lowMemMode = true;
36 | } else if (sargv[i] == "-m" || sargv[i] == "--model") {
37 | i++;
38 | } else if (sargv[i] == "--top_p") {
39 | generationConfig.top_p = atof(sargv[++i].c_str());
40 | } else if (sargv[i] == "--top_k") {
41 | generationConfig.top_k = atof(sargv[++i].c_str());
42 | } else if (sargv[i] == "--temperature") {
43 | generationConfig.temperature = atof(sargv[++i].c_str());
44 | } else if (sargv[i] == "--repeat_penalty") {
45 | generationConfig.repeat_penalty = atof(sargv[++i].c_str());
46 | } else {
47 | Usage();
48 | exit(-1);
49 | }
50 | }
51 | }
52 |
53 | int main(int argc, char **argv) {
54 | int round = 0;
55 | std::string history = "";
56 |
57 | RunConfig config;
58 | fastllm::GenerationConfig generationConfig;
59 | ParseArgs(argc, argv, config, generationConfig);
60 |
61 | fastllm::PrintInstructionInfo();
62 | fastllm::SetThreads(config.threads);
63 | fastllm::SetLowMemMode(config.lowMemMode);
64 | auto model = fastllm::CreateLLMModelFromFile(config.path);
65 |
66 | static std::string modelType = model->model_type;
67 | printf("欢迎使用 %s 模型. 输入内容对话,reset清空历史记录,stop退出程序.\n", model->model_type.c_str());
68 | while (true) {
69 | printf("用户: ");
70 | std::string input;
71 | std::getline(std::cin, input);
72 | if (input == "reset") {
73 | history = "";
74 | round = 0;
75 | continue;
76 | }
77 | if (input == "stop") {
78 | break;
79 | }
80 | std::string ret = model->Response(model->MakeInput(history, round, input), [](int index, const char* content) {
81 | if (index == 0) {
82 | printf("%s:%s", modelType.c_str(), content);
83 | fflush(stdout);
84 | }
85 | if (index > 0) {
86 | printf("%s", content);
87 | fflush(stdout);
88 | }
89 | if (index == -1) {
90 | printf("\n");
91 | }
92 | }, generationConfig);
93 | history = model->MakeHistory(history, round, input, ret);
94 | round++;
95 | }
96 |
97 | return 0;
98 | }
--------------------------------------------------------------------------------
/pyfastllm/examples/cli_low_level.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import sys
3 | import platform
4 | import logging
5 | import argparse
6 | import fastllm
7 |
8 | logging.info(f"python compiler: {platform.python_compiler()}")
9 |
10 | def args_parser():
11 | parser = argparse.ArgumentParser(description='fastllm')
12 |     parser.add_argument('-m', '--model', type=int, required=False, default=0, help='model type, default 0; may be 0 (chatglm), 1 (moss), 2 (vicuna), 3 (baichuan)')
13 |     parser.add_argument('-p', '--path', type=str, required=True, default='', help='path to the model file')
14 |     parser.add_argument('-t', '--threads', type=int, default=4, help='number of threads to use')
15 |     parser.add_argument('-l', '--low', action='store_true', help='use low-memory mode')
16 | args = parser.parse_args()
17 | return args
18 |
19 | # Use this function with caution: it still has bugs and exists only as a low-level API example; do not use it in production
20 | def response(model, prompt_input:str, stream_output:bool=False):
21 | gmask_token_id = 130001
22 | bos_token_id = 130004
23 | eos_token_id = model.eos_token_id
24 |
25 | input_ids = model.weight.tokenizer.encode(prompt_input)
26 | if model.model_type == "chatglm":
27 | gmask_token_id = model.gmask_token_id
28 | bos_token_id = model.bos_token_id
29 | gmask_bos = fastllm.Tensor(fastllm.float32, [1, 2], [gmask_token_id, bos_token_id])
30 | input_ids = fastllm.cat([gmask_bos, input_ids], 0)
31 |
32 | seq_len = input_ids.count(0)
33 | vmask = [0] * (seq_len * seq_len)
34 | vpids = [0] * (seq_len * 2)
35 | for i in range(seq_len-1):
36 | vmask[i*seq_len + seq_len -1] = 1
37 | vpids[i] = i
38 | vpids[seq_len - 1] = seq_len - 2
39 | vpids[seq_len * 2 - 1] = 1
40 | attention_mask = fastllm.Tensor(fastllm.float32, [seq_len, seq_len], vmask)
41 | position_ids = fastllm.Tensor(fastllm.float32, [2, seq_len], vpids)
42 |
43 | pastKeyValues = []
44 | for _ in range(model.block_cnt):
45 | pastKeyValues.append([fastllm.Tensor(fastllm.float32), fastllm.Tensor(fastllm.float32)])
46 |
47 | ret_str = ""
48 | ret_len = 1
49 | mask_ids = -1
50 | output_tokens = []
51 | penalty_factor = fastllm.Tensor()
52 |
53 | while len(output_tokens) < 2048: # config.max_seq_len
54 | ret, pastKeyValues = model.forward(input_ids, attention_mask, position_ids, penalty_factor, pastKeyValues)
55 | if ret == eos_token_id:
56 | break
57 |
58 | output_tokens.append(ret)
59 | cur_str = model.weight.tokenizer.decode(fastllm.Tensor(fastllm.float32, [len(output_tokens)], output_tokens))
60 | ret_str += cur_str
61 |
62 | print(cur_str, end="")
63 | sys.stdout.flush()
64 | if stream_output:
65 | yield cur_str
66 |
67 | ret_len += 1
68 | output_tokens = []
69 |
70 | if mask_ids == -1:
71 | mask_ids = seq_len - 2
72 |
73 | input_ids = fastllm.Tensor(fastllm.float32, [1, 1], [ret])
74 | attention_mask = fastllm.Tensor()
75 | position_ids = fastllm.Tensor(fastllm.float32, [2, 1], [mask_ids, ret_len])
76 |
77 | print()
78 | return ret_str
79 |
80 |
81 | def run_with_low_level(args):
82 | model_path = args.path
83 | llm_type = fastllm.get_llm_type(model_path)
84 | print(f"llm model: {llm_type}")
85 | model = fastllm.create_llm(model_path)
86 |
87 | prompt = ""
88 | while prompt != "stop":
89 | prompt = input("User: ")
90 | outputs = response(model, prompt_input=model.make_input("", 0, prompt))
91 | for output in outputs:
92 | print(output)
93 | sys.stdout.flush()
94 |
95 | if __name__ == "__main__":
96 | args = args_parser()
97 | run_with_low_level(args)
98 |
--------------------------------------------------------------------------------
/pyfastllm/examples/cli_simple.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import sys, os
3 | import platform
4 | import logging
5 | import argparse
6 | import fastllm
7 |
8 | logging.info(f"python compiler: {platform.python_compiler()}")
9 |
10 | def args_parser():
11 | parser = argparse.ArgumentParser(description='fastllm')
12 |     parser.add_argument('-m', '--model', type=int, required=False, default=0, help='model type, default 0; may be 0 (chatglm), 1 (moss), 2 (vicuna), 3 (baichuan)')
13 |     parser.add_argument('-p', '--path', type=str, required=True, default='', help='path to the model file')
14 |     parser.add_argument('-t', '--threads', type=int, default=4, help='number of threads to use')
15 |     parser.add_argument('-l', '--low', action='store_true', help='use low-memory mode')
16 | args = parser.parse_args()
17 | return args
18 |
19 |
20 | def response(model, prompt_input:str, stream_output:bool=False):
21 |
22 | input_ids = model.weight.tokenizer.encode(prompt_input)
23 | input_ids = input_ids.to_list()
24 | input_ids = [int(v) for v in input_ids]
25 | if model.model_type == "chatglm":
26 | input_ids = [model.gmask_token_id, model.bos_token_id] + input_ids
27 | # print(input_ids)
28 |
29 | handle = model.launch_response(input_ids, fastllm.GenerationConfig())
30 | continue_token = True
31 |
32 | ret_byte = b""
33 | ret_str = ""
34 |
35 | while continue_token:
36 | resp_token = model.fetch_response(handle)
37 | continue_token = (resp_token != -1)
38 |
39 | content = model.weight.tokenizer.decode_byte([resp_token])
40 | ret_byte += content
41 | ret_str = ret_byte.decode(errors='ignore')
42 |
43 | if stream_output:
44 | yield ret_str
45 |
46 | return ret_str
47 |
48 | def run_with_response(args):
49 | model_path = args.path
50 | OLD_API = False
51 | if OLD_API:
52 | model = fastllm.ChatGLMModel()
53 | model.load_weights(model_path)
54 | model.warmup()
55 | else:
56 | fastllm.set_threads(args.threads)
57 | fastllm.set_low_memory(args.low)
58 | if not os.path.exists(model_path):
59 | print(f"模型文件{args.path}不存在!")
60 | exit(-1)
61 | model = fastllm.create_llm(model_path)
62 | print(f"llm model: {model.model_type}")
63 | print(f"欢迎使用 {model.model_type} 模型. 输入内容对话,reset清空历史记录,stop退出程序");
64 |
65 | input_text = ""
66 | history = ""
67 | dialog_round = 0
68 | while input_text != "stop":
69 | input_text = input("User: ")
70 | if 'stop' == input_text:
71 | break
72 | if 'reset' == input_text:
73 | history = ''
74 | continue
75 | prompt = model.make_input(history, dialog_round, input_text)
76 |
77 | outputs = response(model, prompt_input=prompt, stream_output=True)
78 |
79 | print(f"{model.model_type}:", end=' ')
80 | past_len = 0
81 | for output in outputs:
82 | print(output[past_len:], end='', flush=True)
83 | past_len = len(output)
84 | print()
85 |         history = model.make_history(history, dialog_round, input_text, output)
86 | dialog_round += 1
87 |
88 |
89 | def run_with_callback(args):
90 | model_path = args.path
91 | OLD_API = False
92 | LLM_TYPE = ""
93 | if OLD_API:
94 | model = fastllm.ChatGLMModel()
95 | model.load_weights(model_path)
96 | model.warmup()
97 | else:
98 | fastllm.set_threads(args.threads)
99 | fastllm.set_low_memory(args.low)
100 | if not os.path.exists(model_path):
101 | print(f"模型文件{args.path}不存在!")
102 | exit(-1)
103 | LLM_TYPE = fastllm.get_llm_type(model_path)
104 | model = fastllm.create_llm(model_path)
105 |
106 | def print_back(idx:int, content: bytearray):
107 | content = content.decode(encoding="utf-8", errors="replace")
108 | if idx >= 0:
109 | print(f"\r{LLM_TYPE}:{content}", end='', flush=True)
110 | elif idx == -1:
111 | print()
112 | sys.stdout.flush()
113 |
114 | print(f"欢迎使用 {LLM_TYPE} 模型. 输入内容对话,reset清空历史记录,stop退出程序");
115 | prompt = ""
116 | while prompt != "stop":
117 | prompt = input("User: ")
118 | config = fastllm.GenerationConfig()
119 | model.response(model.make_input("", 0, prompt), print_back, config)
120 | print()
121 | sys.stdout.flush()
122 |
123 |
124 | if __name__ == "__main__":
125 | args = args_parser()
126 | # run_with_callback(args)
127 | run_with_response(args)
128 |
--------------------------------------------------------------------------------
/pyfastllm/examples/convert_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from transformers import AutoTokenizer, AutoModel
3 |
4 | import fastllm
5 |
6 | def export():
7 |     model_path = '/public/Models/chatglm-6b' # only fp32 model loading is supported
8 | export_path = "chatglm-6b-fp32.flm"
9 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
10 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float()
11 | model = model.eval()
12 |
13 | fastllm.utils.convert(model=model, tokenizer=tokenizer, output_path=export_path, verbose=True)
14 |
15 | def response(model, prompt_input:str, stream_output:bool=False):
16 | gmask_token_id = 130001
17 | bos_token_id = 130004
18 |
19 | input_ids = model.weight.tokenizer.encode(prompt_input)
20 | input_ids = input_ids.to_list()
21 | input_ids.extend([gmask_token_id, bos_token_id])
22 | input_ids = [int(v) for v in input_ids]
23 |
24 | handle = model.launch_response(input_ids)
25 | continue_token = True
26 |
27 | ret_byte = b""
28 | ret_str = ""
29 |
30 | while continue_token:
31 | resp_token = model.fetch_response(handle)
32 | continue_token = (resp_token != -1)
33 |
34 | content = model.weight.tokenizer.decode_byte([resp_token])
35 | ret_byte += content
36 | ret_str = ret_byte.decode(errors='ignore')
37 |
38 | if stream_output:
39 | yield ret_str
40 |
41 | return ret_str
42 |
43 | def infer():
44 | model_path = "chatglm-6b-fp32.flm"
45 | model = fastllm.create_llm(model_path)
46 |
47 | prompt = "你好"
48 | outputs = response(model, prompt_input=prompt, stream_output=True)
49 | for output in outputs:
50 | print('\r LLM:' + output, end='', flush=True)
51 |
52 | print()
53 |
54 |
55 | if __name__ == "__main__":
56 | # export()
57 | infer()
--------------------------------------------------------------------------------
/pyfastllm/examples/test_chatglm2.py:
--------------------------------------------------------------------------------
1 | import fastllm
2 | from fastllm.hub.chatglm2 import ChatGLM2, ChatGLMConfig
3 | import fastllm.functions as ops
4 |
5 | from transformers import AutoTokenizer
6 |
7 | def load_weights():
8 | file = "/home/pan/Public/Models/models-flm/chatglm2-6b.flm"
9 | state_dict = ops.load(file)
10 | return state_dict
11 |
12 | def run():
13 | # fastllm.set_device_map({"cuda:0": 28})
14 | state_dict = load_weights()
15 | cfg = ChatGLMConfig()
16 | model = ChatGLM2(cfg)
17 | model.set_weights(state_dict)
18 | print("model loaded!!!")
19 |
20 | model_path = "/home/pan/Public/Models/models-hf/chatglm2-6b"
21 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
22 | # model.warmup()
23 | res = ""
24 | for output in model.stream_chat(query="飞机为什么会飞", tokenizer=tokenizer):
25 | res = output
26 |
27 | print("最终问答", res)
28 |
29 | if __name__ == "__main__":
30 | run()
31 |
32 |
--------------------------------------------------------------------------------
/pyfastllm/examples/test_ops.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pytest
3 | import numpy as np
4 | import fastllm
5 |
6 | import pyfastllm
7 | import gc
8 |
9 | import np_ops
10 | import ops as flm_ops
11 |
12 | # from fastllm import ops as flm_ops
13 | # from fastllm import np_ops
14 |
15 | np.random.seed(42)
16 |
17 | def diff(dataA, dataB):
18 | # print(dataA)
19 | # print(dataB)
20 | mae = np.max(np.abs(dataA - dataB))
21 | print('max abs err is ', mae)
22 | return mae
23 |
24 | def to_tensor(data):
25 | return pyfastllm.from_numpy(data)
26 |
27 | def to_numpy(data):
28 | # return data.numpy()
29 | return np.array(data, copy=False, order='C')
30 |
31 | def test_rms_norm(inputs=None, weights=None, eps=1e-6):
32 | if not inputs:
33 | inputs = np.random.random(size=[1, 256])
34 | weights = np.random.random(size=[1, 256])
35 |
36 | np_out = np_ops.rms_norm(inputs, weights, eps)
37 | flm_out = flm_ops.rms_norm(to_tensor(inputs), to_tensor(weights), eps)
38 | mae = diff(np_out, to_numpy(flm_out))
39 | assert mae <= 1e-6
40 | return flm_out
41 |
42 | def test_swiglu(inputs=None):
43 | if not inputs:
44 | inputs = np.random.random(size=[1, 256])
45 |
46 | np_out = np_ops.swiglu(inputs)
47 | out = flm_ops.activation(inputs=to_tensor(inputs), activate_type="swiglu")
48 | mae = diff(np_out, to_numpy(out))
49 | assert mae <= 1e-6
50 | return out
51 |
52 | def test_attention(q=None, k=None, v=None, mask=None, group=1, scale=1.0):
53 | if q is None:
54 | q = np.random.random(size=[12, 1, 4096])
55 | k = np.random.random(size=[12, 1, 4096])
56 | v = np.random.random(size=[12, 1, 4096])
57 | scale = 1 / np.sqrt(q.shape[-1])
58 |
59 | np_out = np_ops.attention(q, k, v, scale=scale)
60 |
61 | mask = fastllm.Tensor()
62 | flm_out = flm_ops.attention(to_tensor(q), to_tensor(k), to_tensor(v), mask, group=group, scale=scale, attentionType=0)
63 |
64 | mae = diff(np_out, to_numpy(flm_out))
65 | assert mae <= 1e-6
66 | return flm_out
67 |
68 |
69 | def test_linear(inputs=None,
70 | weights=None,
71 | bias=None):
72 |
73 | if not inputs:
74 | inputs = np.random.random(size=[1, 12, 4096])
75 | weights = np.random.random(size=[256, 4096])
76 |
77 | np_out = np_ops.linear(inputs=inputs, weights=weights, bias=None)
78 |
79 | if not bias:
80 | bias = fastllm.Tensor()
81 |
82 | output = flm_ops.linear(to_tensor(inputs), to_tensor(weights), bias)
83 | mae = diff(np_out, to_numpy(output))
84 |
85 | assert mae <= 1e-3
86 | return output
87 |
88 |
89 | if __name__ == "__main__":
90 | test_rms_norm()
91 | test_attention()
92 | test_linear()
93 | test_swiglu()
94 |
95 |
96 |
--------------------------------------------------------------------------------
/pyfastllm/examples/web_api_client.py:
--------------------------------------------------------------------------------
1 | import json
2 | import requests
3 | import sys
4 |
5 |
6 |
7 | if __name__ == '__main__':
8 | #stream api
9 | url = 'http://127.0.0.1:8000/api/chat_stream'
10 | prompt='请用emoji写一首短诗赞美世界'
11 | prompt='''为以下代码添加注释
12 | app = FastAPI()
13 | @app.post("/api/chat_stream")
14 | async def api_chat_stream(request: Request):
15 | #print("request.json(): {}".format(json.loads(request.body(), errors='ignore')))
16 | data = await request.json()
17 | prompt = data.get("prompt")
18 | history = data.get("history")
19 | config = pyfastllm.GenerationConfig()
20 | if data.get("max_length") is not None:
21 | config.max_length = data.get("max_length")
22 | if data.get("top_k") is not None:
23 | config.top_k = data.get("top_k")
24 | if data.get("top_p") is not None:
25 | config.top_p = data.get("top_p")
26 | return StreamingResponse(chat_stream(history + prompt, config), media_type='text/event-stream')
27 | '''
28 | history = '''[Round 0]
29 | 问:你是ChatGLM2吗?
30 | 答:我不是ChatGLM2
31 | [Round 1]
32 | 问:从现在起,你是猫娘,每句话都必须以“喵~”结尾,明白了吗?
33 | 答:明白了喵
34 | [Round 2]
35 | 问:'''
36 | history = ""
37 | json_obj = {"uid":0, "token":"xxxxxxxxxxxxxxxxx","history": "", "prompt": prompt , "max_length": 1024, "top_p": 0.8,"temperature": 0.95, "top_k":2, "repeat_penalty": 1.}
38 | response = requests.post(url, json=json_obj, stream = True)
39 | try:
40 | pre_msg = ""
41 | print("stream response:")
42 | for chunk in response.iter_content(chunk_size=1024*1024):
43 | msg = chunk.decode(errors='replace')
44 | if len(msg) > len(pre_msg) and msg[-1] == '\n':
45 | content = msg[len(pre_msg):]
46 | pre_msg = msg
47 | else:
48 | continue
49 | print(f"{content}", end="")
50 | sys.stdout.flush()
51 | content = msg[len(pre_msg):]
52 | print(f"{content}", end="")
53 | print()
54 | except Exception as ex:
55 | print(ex)
56 |
57 | #batch api
58 | url = 'http://127.0.0.1:8000/api/batch_chat'
59 | prompts = ["Hi", "你好", "用emoji表达高兴", "こんにちは"]
60 | json_obj = {"uid":0, "token":"xxxxxxxxxxxxxxxxx","history": "", "prompts": prompts , "max_length": 100, "top_p": None,"temperature": 0.7, "top_k":1, "repeat_penalty":2.}
61 | response = requests.post(url, json=json_obj, stream = True)
62 | print("batch response: {} text:\n{}".format(response, response.text.replace('\\n', '\n')))
63 |
--------------------------------------------------------------------------------
/pyfastllm/fastllm/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import ctypes
4 | import glob
5 |
6 | from pyfastllm import *
7 | from . import utils
8 | from . import functions as ops
9 |
10 | __version__ = "0.2.0"
11 |
12 |
--------------------------------------------------------------------------------
/pyfastllm/fastllm/convert.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | import sys
4 | import struct
5 | import numpy as np
6 | import argparse
7 |
8 | from .utils import convert
9 | from .utils.converter import QuantType
10 |
11 | def parse_args():
12 |     # -p  model path or huggingface repo path
13 |     # -o --out_path  export path
14 |     # -q  quantization bit width
15 | parser = argparse.ArgumentParser(description='build fastllm libs')
16 | parser.add_argument('-o', dest='export_path', default=None,
17 | help='output export path')
18 | parser.add_argument('-p', dest='model_path', type=str, default='',
19 | help='the model path or huggingface path, such as: -p THUDM/chatglm-6b')
20 | parser.add_argument('--lora', dest='lora_path', default='',
21 | help='lora model path')
22 |     parser.add_argument('-m', dest='model', default='chatglm6B',
23 |                         help='model name (one of: alpaca, baichuan7B, chatglm6B, moss)')
24 | parser.add_argument('-q', dest='q_bit', type=int,
25 | help='model quantization bit')
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | def main(args=None):
31 | if not args: args = parse_args()
32 |
33 | quant_type_to_qbit = {
34 | QuantType.FP32: 32,
35 | QuantType.FP16: 16,
36 | QuantType.INT8: 8,
37 | QuantType.INT4: 4,
38 | }
39 | qbit_to_quant_type = {v: k for k, v in quant_type_to_qbit.items()}
40 | q_type = qbit_to_quant_type[args.q_bit]
41 | convert(args.model_path, args.export_path, q_type=q_type)
42 |
43 | if __name__ == "__main__":
44 | args = parse_args()
45 | main(args)
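A minimal sketch of driving the converter above programmatically instead of through real command-line flags. The model path, output path, and bit width are placeholder values; `main()` only needs the attributes that `parse_args()` would normally populate.

```python
# Hypothetical invocation of fastllm.convert.main() without a real command line.
# The paths below are placeholders; q_bit must be 32, 16, 8 or 4, matching the
# qbit_to_quant_type mapping defined in main() above.
from argparse import Namespace
from fastllm import convert

args = Namespace(
    model_path="THUDM/chatglm-6b",      # placeholder HF repo or local path
    export_path="chatglm-6b-int8.flm",  # placeholder output file
    lora_path="",
    model="chatglm6B",
    q_bit=8,
)
convert.main(args)
```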
--------------------------------------------------------------------------------
/pyfastllm/fastllm/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .fastllm_ops import *
2 | from . import util
--------------------------------------------------------------------------------
/pyfastllm/fastllm/functions/custom_ops.py:
--------------------------------------------------------------------------------
1 | import triton.language as tl
--------------------------------------------------------------------------------
/pyfastllm/fastllm/functions/fastllm_ops.py:
--------------------------------------------------------------------------------
1 | import pyfastllm
2 |
3 | def embedding(inputs: pyfastllm.Tensor, embedding_weights:pyfastllm.Tensor):
4 | output = pyfastllm.Tensor()
5 | pyfastllm.embedding(inputs, embedding_weights, output)
6 | return output
7 |
8 | def rms_norm(inputs:pyfastllm.Tensor, weights: pyfastllm.Tensor, eps: float=1e-5):
9 | output = pyfastllm.Tensor()
10 | pyfastllm.rms_norm(inputs, weights, eps, output)
11 | return output
12 |
13 | def layer_norm(inputs: pyfastllm.Tensor,
14 | gamma: pyfastllm.Tensor,
15 | beta: pyfastllm.Tensor,
16 | axis:int=-1 ):
17 | output = pyfastllm.Tensor()
18 | pyfastllm.layer_norm(inputs, gamma, beta,axis, output)
19 | return output
20 |
21 | def linear(inputs: pyfastllm.Tensor,
22 | weights: pyfastllm.Tensor,
23 | bias: pyfastllm.Tensor=None):
24 | output = pyfastllm.Tensor()
25 | # print(weights)
26 | if not bias:
27 | bias = pyfastllm.Tensor()
28 |
29 | pyfastllm.linear(inputs, weights, bias, output)
30 | return output
31 |
32 | def matmul(inputs0: pyfastllm.Tensor,
33 | inputs1: pyfastllm.Tensor,
34 | alpha: pyfastllm.Tensor):
35 | output = pyfastllm.Tensor()
36 | pyfastllm.matmul(inputs0, inputs1, alpha, output)
37 | return output
38 |
39 | def attention(q: pyfastllm.Tensor,
40 | k: pyfastllm.Tensor,
41 | v: pyfastllm.Tensor,
42 | mask: pyfastllm.Tensor,
43 | group: int,
44 | scale: float,
45 | attentionType:int = 0):
46 | output = pyfastllm.Tensor()
47 | pyfastllm.attention(q, k, v, mask, group, scale, attentionType, output)
48 | return output
49 |
50 | def activation(inputs: pyfastllm.Tensor, axis=-1, activate_type="silu"):
51 | assert activate_type in ("softmax", "silu", "gelu", "swiglu")
52 | func = getattr(pyfastllm, activate_type)
53 |
54 | output = pyfastllm.Tensor()
55 | if activate_type == "softmax":
56 | func(inputs, axis, output)
57 | else:
58 | func(inputs, output)
59 | return output
60 |
61 | def cat_(inputs, cur_data, axis=1):
62 | pyfastllm.cat_direct(inputs, cur_data, axis)
63 |
64 | def mul(inputs: pyfastllm.Tensor, v: int):
65 | output = pyfastllm.Tensor()
66 | pyfastllm.mul(inputs, v, output)
67 | return output
68 |
69 | def add(input0: pyfastllm.Tensor, input1: pyfastllm.Tensor, v:int=1.0):
70 | output = pyfastllm.Tensor()
71 | output = pyfastllm.add(input0, input1, v)
72 | return output
73 |
74 | def permute(inputs: pyfastllm.Tensor, dims=None):
75 | output = pyfastllm.Tensor()
76 | pyfastllm.permute(inputs, dims, output)
77 | # pyfastllm.permute_(inputs, dims)
78 | return output
79 |
80 | def split(inputs: pyfastllm.Tensor, axis:int, start:int, end:int):
81 | output = pyfastllm.Tensor()
82 | pyfastllm.split(inputs, axis, start, end, output)
83 | return output
84 |
85 | def topk(logits:pyfastllm.Tensor, axis:int = 1):
86 | output = pyfastllm.Tensor()
87 | pyfastllm.topk(logits, axis, output)
88 | return output
89 |
90 | def load(filepath):
91 | state_dict = pyfastllm.WeightMap()
92 | state_dict.load(filepath)
93 | return state_dict
94 |
95 | def AttentionMask():
96 | pass
97 |
98 | def AlibiMask():
99 | pass
100 |
101 | def RotatePosition2D(data, pos_id, sin_data, cos_data, rotary_dim):
102 | return pyfastllm.rotateposition2D(data, pos_id, sin_data, cos_data, rotary_dim)
103 |
104 | def NearlyRotatePosition2D(data, pos_id, sin_data, cos_data, rotary_dim):
105 | return pyfastllm.nearlyrotateposition2D(data, pos_id, sin_data, cos_data, rotary_dim)
106 |
107 | def LlamaRotatePosition2D():
108 | pass
109 |
110 | def RepeatPenalty():
111 | pass
112 |
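A small sketch chaining the wrappers above on random data. It assumes the compiled pyfastllm extension is importable and that `pyfastllm.from_numpy` is available (as used by the repo's test_ops.py); all shapes and values here are made up.

```python
# Sketch only: compose rms_norm -> linear -> swiglu activation with random data.
import numpy as np
import pyfastllm
from fastllm.functions import fastllm_ops as F

x  = pyfastllm.from_numpy(np.random.random(size=(1, 4096)).astype(np.float32))
w1 = pyfastllm.from_numpy(np.random.random(size=(4096,)).astype(np.float32))      # rms_norm weights
w2 = pyfastllm.from_numpy(np.random.random(size=(512, 4096)).astype(np.float32))  # linear weights (out, in)

h = F.rms_norm(x, w1, eps=1e-5)               # normalize
h = F.linear(h, w2)                           # project to 512 (bias defaults to an empty Tensor)
h = F.activation(h, activate_type="swiglu")   # swiglu halves the last dimension
```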
--------------------------------------------------------------------------------
/pyfastllm/fastllm/functions/numpy_ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numba import cuda
3 | from numba import jit
4 |
5 | @jit(nopython=True)
6 | def rms_norm(inputs, weights, eps):
7 | channel = inputs.shape[-1]
8 | sqrt_mean = np.sqrt(np.sum(inputs**2)/channel + eps)
9 | return inputs / sqrt_mean *weights
10 |
11 | @jit(nopython=True)
12 | def layer_norm(inputs, gamma, beta, axis=-1):
13 |     assert axis < len(inputs.shape), "axis should be less than the number of input dims"
14 | channel = inputs.shape[axis]
15 | mean = np.mean(inputs, axis=axis)
16 | var = np.var(inputs, axis=axis)
17 |
18 | output = (inputs - mean) / var * gamma + beta
19 | return output
20 |
21 | # @jit
22 | def softmax(inputs, axis=None):
23 | maxv = inputs.max(axis, keepdims=True)
24 | exp_v = np.exp(inputs - maxv)
25 |     exp_sum = np.sum(exp_v, axis=axis, keepdims=True)
26 | return exp_v / exp_sum
27 |
28 | @jit(nopython=True)
29 | def silu(inputs, ):
30 | return inputs / (1 + np.exp(-inputs))
31 |
32 | @jit
33 | def swiglu(inputs, ):
34 |     dim = inputs.shape[-1] // 2
35 |     gate, up = inputs[..., :dim], inputs[..., dim:]
36 |     return gate / (1 + np.exp(-gate)) * up
37 |
38 | # @jit
39 | def linear(inputs, weights, bias):
40 | if len(inputs.shape) == 2:
41 | inputs = inputs[None, :]
42 |         # weights stays 2-D: (out_dim, in_dim)
43 |
44 |     output = np.zeros(shape=[inputs.shape[0], inputs.shape[1], weights.shape[0]])
45 |     for batch in range(inputs.shape[0]):
46 |         output[batch] = np.matmul(inputs[batch], weights.T)
47 |
48 |         if bias is not None:
49 |             output[batch] += bias
50 |
51 | return output
52 |
53 | # @jit
54 | def attention(q, k, v, mask=None, group=None, scale=None):
55 | print("shape:", q.shape)
56 | if len(q.shape) == 2:
57 | q = q[None, :]
58 | k = k[None, :]
59 | v = v[None, :]
60 | # mask = mask[None, :]
61 |
62 | attn = np.zeros_like(q)
63 | for batch in range(q.shape[0]):
64 | qk = softmax(q[batch] @ k[batch].T * scale, axis=-1)
65 | attn[batch, :, :] = qk @ v[batch]
66 | return attn
--------------------------------------------------------------------------------
/pyfastllm/fastllm/functions/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pyfastllm
3 |
4 | def diff(dataA, dataB):
5 | mae = np.max(np.abs(dataA - dataB))
6 | print('max abs err is ', mae)
7 | return mae
8 |
9 | def to_tensor(data):
10 | if not isinstance(data, np.ndarray):
11 | return None
12 | return pyfastllm.from_numpy(data)
13 |
14 | def to_numpy(data):
15 | if not isinstance(data, pyfastllm.Tensor):
16 | return None
17 |
18 | return np.array(data, copy=False)
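A tiny round-trip sketch using the helpers above; the array contents are arbitrary and the check simply confirms that converting to a pyfastllm.Tensor and back preserves the values.

```python
# Sketch only: numpy -> pyfastllm.Tensor -> numpy round trip with the helpers above.
import numpy as np

a = np.random.random(size=(2, 8)).astype(np.float32)
t = to_tensor(a)            # pyfastllm.from_numpy under the hood
b = to_numpy(t)             # copy=False view back into numpy
assert diff(a, b) < 1e-6    # prints the max abs error, which should be ~0
```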
--------------------------------------------------------------------------------
/pyfastllm/fastllm/hub/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TylunasLi/fastllm/f3cfc63dc9efb995670461b6e0d64d2c55cfa0a3/pyfastllm/fastllm/hub/__init__.py
--------------------------------------------------------------------------------
/pyfastllm/fastllm/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_module import Module
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/pyfastllm/fastllm/nn/base_module.py:
--------------------------------------------------------------------------------
1 | import pyfastllm
2 | from typing import Any
3 | from abc import abstractmethod
4 |
5 | class Module():
6 | def __init__(self) -> None:
7 | pass
8 |
9 | def __call__(self, *args: Any, **kwds: Any) -> Any:
10 | return self.forward(*args, **kwds)
11 |
12 | @abstractmethod
13 | def forward(self, ):
14 | pass
15 |
16 | def _init_weight(self, ):
17 | pass
18 |
19 | import numpy as np
20 | from typing import Union, Sequence
21 | from pyfastllm import Tensor
22 | from ..functions import util
23 |
24 | class Parameter(object):
25 | _DEFAULT_DTYPE = pyfastllm.float32
26 |
27 | def __init__(self,
28 | value: Union[np.ndarray, None] = None,
29 | shape: Sequence[int] = None,
30 | dtype: Union[pyfastllm.DataType, None] = None):
31 | dtype = self._DEFAULT_DTYPE if dtype is None else dtype
32 | if value is None:
33 | assert isinstance(shape, (list, tuple))
34 | self._value = pyfastllm.Tensor()
35 | """
36 | value = np.zeros(shape=shape, dtype=np.float32)
37 |
38 | if len(shape) == 2:
39 | v_range = np.sqrt(6) / np.sqrt(shape[0] + shape[1])
40 | else:
41 | v_range = 0.1
42 |
43 | # value ~ U[-1, 1]
44 | value = np.random.random(size=shape) * 2 - 1
45 | value = np.array(value, dtype=np.float32)
46 | # value ~ U[-v_range, v_range]
47 | value *= v_range
48 | """
49 | else:
50 | self._value = util.to_tensor(value)
51 |
52 | @property
53 | def value(self) -> Tensor:
54 | if isinstance(self._value, np.ndarray):
55 | self._value = util.to_tensor(self._value)
56 |
57 | return self._value
58 |
59 | @value.setter
60 | def value(self, v: np.ndarray):
61 | assert isinstance(v, np.ndarray) or isinstance(v, pyfastllm.Tensor)
62 | # assert v.shape == self._value.shape, \
63 | # ('The value updated is not the same shape as the original. ', \
64 | # f'Updated: {v.shape}, original: {self._value.shape}')
65 | self._value = v
66 |
--------------------------------------------------------------------------------
/pyfastllm/fastllm/nn/modules.py:
--------------------------------------------------------------------------------
1 | from .base_module import Module
2 | from .base_module import Parameter
3 | from ..functions import fastllm_ops as F
4 | from ..functions import util
5 | import numpy as np
6 |
7 | class Linear(Module):
8 | def __init__(self, in_dim, out_dim, bias=False) -> None:
9 | self.has_bias = bias
10 | self.weights = Parameter(shape=(out_dim, in_dim))
11 | self.bias = None
12 |
13 | if bias:
14 | self.bias = Parameter(shape=(out_dim, ))
15 |
16 | super().__init__()
17 |
18 | def forward(self, x):
19 | if self.has_bias:
20 | return F.linear(x, self.weights.value, self.bias.value)
21 |
22 | return F.linear(x, self.weights.value)
23 |
24 | class SiLU(Module):
25 | def __init__(self) -> None:
26 | super().__init__()
27 |
28 | def forward(self, x, axis=-1):
29 | return F.activation(x, axis=axis, activate_type='silu')
30 |
31 | class SwiGLU(Module):
32 | def __init__(self) -> None:
33 | super().__init__()
34 |
35 | def forward(self, inputs):
36 | return F.activation(inputs=inputs, activate_type="swiglu")
37 |
38 | class Embedding(Module):
39 | def __init__(self, vocab_size, embed_dim) -> None:
40 | super().__init__()
41 | self.vocab_size = vocab_size
42 | self.embed_dim = embed_dim
43 | self.weights = Parameter(shape=[vocab_size, embed_dim])
44 |
45 | def forward(self, inputs):
46 | return F.embedding(inputs, self.weights.value)
47 |
48 | class RMSNorm(Module):
49 | def __init__(self, dim=4096, eps=1e-5) -> None:
50 | super().__init__()
51 | self.weights = Parameter(shape=[dim, ])
52 | self.eps = eps
53 |
54 | def forward(self, inputs):
55 | return F.rms_norm(inputs, self.weights.value, eps=self.eps)
56 |
57 | class Attention(Module):
58 | def __init__(self) -> None:
59 | super().__init__()
60 |
61 | def forward(self, q, k, v, mask, group, scale):
62 | return F.attention(q, k, v, mask, group=group, scale=scale, attentionType=0)
63 |
64 | class RoPE(Module):
65 | def __init__(self, rotary_dim=128) -> None:
66 | super().__init__()
67 | self.rotary_dim = rotary_dim
68 | self.sin_data, self.cos_data = self._get_sin_cos_data()
69 | self.sin_data = util.to_tensor(self.sin_data)
70 | self.cos_data = util.to_tensor(self.cos_data)
71 |
72 | def _get_sin_cos_data(self, base=1e4, seq_len=32768, dim=128):
73 | inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
74 | t = np.arange(0, seq_len)
75 | freqs = np.einsum('i,j->ij', t, inv_freq)
76 | emb = np.concatenate((freqs, freqs), axis=-1)
77 | return np.sin(emb), np.cos(emb)
78 |
79 | def forward(self, data, pos_id):
80 | return F.RotatePosition2D(data, pos_id, self.sin_data, self.cos_data, self.rotary_dim)
81 |
82 | class NearlyRoPE(RoPE):
83 | def __init__(self, rotary_dim=64) -> None:
84 | super().__init__(rotary_dim)
85 |
86 | def forward(self, data, pos_id):
87 | outputs = F.NearlyRotatePosition2D(data, pos_id, self.sin_data, self.cos_data, self.rotary_dim)
88 | return outputs
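As a usage note, the Module/Parameter pattern above composes in the usual way. The following is a hypothetical two-layer gated feed-forward block built only from the classes defined in this file; the dimensions are arbitrary and the weights are uninitialized Parameters, exactly as in Linear.

```python
# Hypothetical composition of the modules defined above: a small gated MLP.
class FeedForward(Module):
    def __init__(self, dim=4096, hidden=11008) -> None:
        super().__init__()
        self.norm = RMSNorm(dim=dim)
        self.up   = Linear(dim, 2 * hidden)   # produces [gate | up] halves for SwiGLU
        self.act  = SwiGLU()
        self.down = Linear(hidden, dim)

    def forward(self, x):
        h = self.norm(x)
        h = self.act(self.up(h))
        return self.down(h)
```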
--------------------------------------------------------------------------------
/pyfastllm/fastllm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
2 |
3 | from .quantizer import QuantType
4 | from .converter import ChatglmConverter, BaichuanConverter, QwenConverter, MossConverter
5 |
6 | def convert(hf_model_name_or_path:str, save_path:str, q_type=QuantType.INT4):
7 | config = AutoConfig.from_pretrained(hf_model_name_or_path, trust_remote_code=True)
8 | tokenizer = AutoTokenizer.from_pretrained(hf_model_name_or_path, trust_remote_code=True)
9 |
10 | if "Baichuan" in config.architectures:
11 | model = AutoModelForCausalLM.from_pretrained(hf_model_name_or_path, trust_remote_code=True).cpu().eval()
12 | converter = BaichuanConverter(model=model, tokenizer=tokenizer, q_type=q_type)
13 | elif "ChatGLM" in config.architectures:
14 | model = AutoModel.from_pretrained(hf_model_name_or_path, trust_remote_code=True).cpu().eval()
15 | converter = ChatglmConverter(model=model, tokenizer=tokenizer, q_type=q_type)
16 | elif "Qwen" in config.architectures:
17 | model = AutoModelForCausalLM.from_pretrained(hf_model_name_or_path, trust_remote_code=True, fp16=True).cpu().eval()
18 | converter = QwenConverter(model=model, tokenizer=tokenizer, q_type=q_type)
19 | elif "Moss" in config.architectures:
20 | model = AutoModelForCausalLM.from_pretrained(hf_model_name_or_path, trust_remote_code=True).cpu().eval()
21 | converter = MossConverter(model=model, tokenizer=tokenizer, q_type=q_type)
22 | else:
23 | raise NotImplementedError(f"Unsupport model: {config.architectures}")
24 |
25 | converter.dump(save_path)
26 |
--------------------------------------------------------------------------------
/pyfastllm/fastllm/utils/quantizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from enum import Enum
3 | from .writer import Writer
4 |
5 | class QuantType(Enum):
6 | FP32 = 0
7 | FP16 = 7
8 | INT8 = 3
9 | INT4 = 8
10 |
11 | class Quantizer():
12 | quant_bit = {QuantType.FP16: 16, QuantType.INT8: 8, QuantType.INT4: 4}
13 |
14 | def __init__(self, quant_type:QuantType, symmetry=True) -> None:
15 | self.quant_type = quant_type
16 | self.q_bit = self.quant_bit[quant_type]
17 |
18 | self.up_bound = (2**(self.q_bit-1)) -1
19 | self.low_bound = -(2 ** (self.q_bit-1))
20 |
21 | self.symmetry = symmetry
22 |
23 |     # smaller range and higher per-value precision; suited to tightly concentrated distributions
24 | def asymquantize(self, data:np.ndarray):
25 | c_min = np.expand_dims(data.min(axis=-1), -1)
26 | c_max = np.expand_dims(data.max(axis=-1), -1)
27 | c_scale = (c_max - c_min) / (self.up_bound - self.low_bound)
28 | c_zero = np.round(0.0 - c_min / c_scale).clip(0, self.up_bound - self.low_bound)
29 | c_min = -c_scale * c_zero
30 |
31 | q_data = (data - c_min)/ c_scale
32 |
33 | if self.quant_type == QuantType.FP32:
34 | q_data = data.astype(np.float32)
35 | elif self.quant_type == QuantType.FP16:
36 | q_data = data.astype(np.float16)
37 | elif self.quant_type == QuantType.INT8:
38 | q_data = (q_data + 0.5).astype(np.int8).clip(0, 255).astype(np.uint8)
39 | elif self.quant_type == QuantType.INT4:
40 | q_data = (q_data + 0.5).astype(np.int8).clip(0, 15).astype(np.uint8)
41 | q_data = q_data[:, 0::2] * 16 + q_data[:, 1::2]
42 | else:
43 |             raise NotImplementedError("unsupported quant type")
44 |
45 | self.c_min = c_min
46 | self.c_max = c_max
47 | self.c_scale = c_scale
48 | self.c_zero = c_zero
49 | self.quant_data = q_data
50 |
51 | return q_data
52 |
53 |     # larger range and lower per-value precision; suited to more spread-out distributions
54 | def symquantize(self, data:np.ndarray):
55 | c_min = np.expand_dims(-np.abs(data).max(axis = -1), -1)
56 | c_max = np.expand_dims(np.abs(data).max(axis = -1), -1)
57 | c_scale = c_max / self.up_bound
58 | c_min = c_scale * self.low_bound
59 |
60 | q_data = (data - c_min) / c_scale
61 |
62 | if self.quant_type == QuantType.FP32:
63 | q_data = data.astype(np.float32)
64 | elif self.quant_type == QuantType.FP16:
65 | q_data = data.astype(np.float16)
66 | elif self.quant_type == QuantType.INT8:
67 | q_data = (q_data + 0.5).astype(np.int8).clip(1, 255).astype(np.uint8)
68 | elif self.quant_type == QuantType.INT4:
69 | q_data = (q_data + 0.5).astype(np.int8).clip(0, 15).astype(np.uint8)
70 | q_data = q_data[:, 0::2] * 16 + q_data[:, 1::2]
71 | else:
72 |             raise NotImplementedError("unsupported quant type")
73 |
74 | self.c_min = c_min
75 | self.c_max = c_max
76 | self.c_scale = c_scale
77 | self.quant_data = q_data
78 |
79 | return q_data
80 |
81 | def quantize(self, data:np.ndarray):
82 | if self.symmetry:
83 | return self.symquantize(data)
84 | else:
85 | return self.asymquantize(data)
86 |
87 | def dequantize(self, ):
88 |         if getattr(self, "c_scale", None) is None:
89 | raise ValueError
90 |
91 | data = self.quant_data * self.c_scale + self.c_min
92 | data = data.astype(np.float32)
93 |
94 | return data
95 |
96 | def dump(self, wt:Writer):
97 | wt.write(self.quant_type.value)
98 | if self.quant_type in (QuantType.INT4, QuantType.INT8):
99 | wt.write(0)
100 | for i in range(self.c_min.shape[0]):
101 | wt.write(float(self.c_min[i][0]))
102 | wt.write(float(self.c_max[i][0]))
103 |
104 | wt.fd.write(self.quant_data.data)
105 |
106 |
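A short round-trip sketch of the Quantizer above on random data. INT8 is used because dequantize() operates on the unpacked uint8 buffer (the INT4 path packs two values per byte); the values are arbitrary and this only exercises the API, it makes no claim about reconstruction accuracy.

```python
# Sketch only: quantize a random weight matrix to INT8 and reconstruct it.
import numpy as np

w = np.random.randn(128, 256).astype(np.float32)

quantizer = Quantizer(QuantType.INT8, symmetry=True)
q = quantizer.quantize(w)          # uint8 buffer plus per-row c_min/c_max/c_scale
w_hat = quantizer.dequantize()     # float32 reconstruction from the stored scales

print(q.dtype, q.shape)
print("max abs reconstruction error:", np.abs(w - w_hat).max())
```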
--------------------------------------------------------------------------------
/pyfastllm/fastllm/utils/writer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import struct
3 | from enum import Enum
4 |
5 | class QuantType(Enum):
6 | FP32 = 0
7 | FP16 = 7
8 | INT8 = 3
9 | INT4 = 8
10 |
11 | def write_int8(fo, v):
12 | c_max = np.expand_dims(np.abs(v).max(axis = -1), -1).clip(0.1, 1e100)
13 | c_scale = c_max / 127.0
14 | v = (v / c_scale + 128.5).clip(1, 255).astype(np.uint8)
15 | fo.write(struct.pack('i', 3))
16 | fo.write(struct.pack('i', 0))
17 | for i in range(c_max.shape[0]):
18 | fo.write(struct.pack('f', -c_max[i][0]))
19 | fo.write(struct.pack('f', c_max[i][0]))
20 | fo.write(v.data)
21 |
22 | def write_int4(fo, v):
23 | c_min = np.expand_dims(-np.abs(v).max(axis = -1), -1)
24 | c_max = np.expand_dims(np.abs(v).max(axis = -1), -1)
25 | c_scale = c_max / 7.0
26 | c_min = c_scale * -8.0
27 | v = (v - c_min) / c_scale
28 | v = (v + 0.5).astype(np.int8).clip(0, 15).astype(np.uint8)
29 | v = v[:, 0::2] * 16 + v[:, 1::2]
30 | fo.write(struct.pack('i', 8))
31 | fo.write(struct.pack('i', 0))
32 | for i in range(c_min.shape[0]):
33 | fo.write(struct.pack('f', c_min[i][0]))
34 | fo.write(struct.pack('f', c_max[i][0]))
35 | fo.write(v.data)
36 |
37 | class Writer():
38 | def __init__(self, outpath) -> None:
39 | self.fd = open(outpath, 'wb')
40 |
41 | def __del__(self, ):
42 | if not self.fd.closed:
43 | self.fd.close()
44 |
45 | def write(self, value):
46 | if isinstance(value, int):
47 | self.fd.write(struct.pack('i', value))
48 | elif isinstance(value, float):
49 | self.fd.write(struct.pack('f', value))
50 | elif isinstance(value, str):
51 | self.write_str(value)
52 | elif isinstance(value, bytes):
53 | self.write_bytes(value)
54 | elif isinstance(value, list):
55 | self.write_list(value)
56 | elif isinstance(value, dict):
57 | self.write_dict(value)
58 | elif isinstance(value, np.ndarray):
59 | self.write_tensor(value)
60 | else:
61 | raise NotImplementedError(f"Unsupport data type: {type(value)}")
62 |
63 | def write_str(self, s):
64 | self.write(len(s))
65 | self.fd.write(s.encode())
66 |
67 | def write_bytes(self, s):
68 | self.write(len(s))
69 | for c in s: self.write(int(c))
70 |
71 | def write_list(self, data):
72 | self.write(len(data))
73 | for d in data: self.write(d)
74 |
75 | def write_dict(self, data):
76 | self.write(len(data))
77 | for key in data:
78 | self.write_str(key)
79 | self.write(data[key])
80 |
81 | def write_tensor(self, data, data_type:QuantType=QuantType.FP32):
82 | self.write(list(data.shape))
83 | if data_type == QuantType.INT4:
84 | write_int4(self.fd, data)
85 | elif data_type == QuantType.INT8:
86 | write_int8(self.fd, data)
87 | else:
88 | self.write(int(data_type.value))
89 | self.fd.write(data.data)
90 |
91 |
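A minimal sketch of the Writer above producing a tiny file: an int field, a dict of string metadata, and one fp32 tensor. The file name is a placeholder, and the field order is only illustrative of what write()/write_tensor() emit, not a statement about the full .flm layout.

```python
# Sketch only: exercise Writer.write() dispatch on a few value types.
import numpy as np

wt = Writer("demo.bin")                               # placeholder output path
wt.write(2)                                           # int -> struct.pack('i', ...)
wt.write({"model_type": "demo", "version": "1"})      # dict -> count + key/value strings
wt.write(np.random.randn(4, 8).astype(np.float32))    # ndarray -> shape + FP32 payload
wt.fd.close()
```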
--------------------------------------------------------------------------------
/pyfastllm/install.sh:
--------------------------------------------------------------------------------
1 | rm -rf build/ && rm -rf dist/
2 | python3 setup.py sdist bdist_wheel
3 | pip install dist/*.whl --force-reinstall
4 | # python3 examples/test_ops.py # coredump when run with cuda backend
--------------------------------------------------------------------------------
/src/device.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/13/23.
3 | //
4 |
5 | #include "utils.h"
6 | #include "device.h"
7 |
8 | namespace fastllm {
9 | bool BaseDevice::Malloc(void **ret, Data &data) {
10 | return Malloc(ret, data.expansionBytes);
11 | }
12 |
13 | bool BaseDevice::CopyDataFromCPU(Data &data) {
14 | AssertInFastLLM(data.cpuData != nullptr, "Copy data to " + this->deviceName + " from cpu failed: cpu's data is null.\n");
15 | AssertInFastLLM(data.deviceData == nullptr, "Copy data to " + this->deviceName + " from cpu failed: device's data is not null.\n");
16 | Malloc(&data.deviceData, data.expansionBytes);
17 | bool ret = CopyDataFromCPU(data.cudaData, data.cpuData, data.expansionBytes);
18 | delete[] data.cpuData;
19 | data.cpuData = nullptr;
20 | return ret;
21 | }
22 |
23 | bool BaseDevice::CopyDataToCPU(Data &data) {
24 | AssertInFastLLM(data.cpuData == nullptr, "Copy data from " + this->deviceName + " to cpu failed: cpu's data is not null.\n");
25 | AssertInFastLLM(data.deviceData != nullptr, "Copy data from " + this->deviceName + " to cpu failed: device's data is null.\n");
26 | data.cpuData = new uint8_t [data.expansionBytes];
27 | bool ret = CopyDataToCPU(data.cpuData, data.deviceData, data.expansionBytes);
28 | this->Free(data.deviceData);
29 | data.deviceData = nullptr;
30 | return ret;
31 | }
32 |
33 | bool BaseDevice::CanRun(const std::string &opType, const fastllm::DataDict &datas,
34 | const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
35 | if (this->ops.find(opType) == this->ops.end()) {
36 | return false;
37 | }
38 | return this->ops[opType]->CanRun(opType, datas, floatParams, intParams);
39 | }
40 |
41 | void BaseDevice::Reshape(const std::string &opType, const fastllm::DataDict &datas,
42 | const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
43 | this->ops[opType]->Reshape(opType, datas, floatParams, intParams);
44 | }
45 |
46 | void BaseDevice::Run(const std::string &opType, const fastllm::DataDict &datas,
47 | const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
48 | this->ops[opType]->Run(opType, datas, floatParams, intParams);
49 | }
50 |
51 | bool BaseOperator::CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams,
52 | const IntDict &intParams) {
53 | return true;
54 | }
55 |
56 | void BaseOperator::Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams,
57 | const IntDict &intParams) {
58 | if (datas.find("output") == datas.end()) {
59 | return;
60 | }
61 |         // default Reshape: give the output the same data type and shape as the input
62 | Data *inputs = (datas.find("input")->second);
63 | Data *outputs = (datas.find("output")->second);
64 | if (inputs == outputs) {
65 | return;
66 | }
67 | outputs[0].dataType = inputs[0].dataType;
68 | outputs[0].Resize(inputs[0].dims);
69 | }
70 |
71 | void BaseBatchOperator::Reshape(const std::string &opType, const fastllm::DataDict &datas,
72 | const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
73 | if (datas.find("output") == datas.end()) {
74 | return;
75 | }
76 |         // default Reshape: give each output the same data type and shape as the corresponding input
77 | Data **inputs = (Data**)(datas.find("input")->second);
78 | Data **outputs = (Data**)(datas.find("output")->second);
79 | if (inputs == outputs) {
80 | return;
81 | }
82 |
83 | int batch = 1;
84 | if (intParams.find("input___batch") != intParams.end()) {
85 | batch = intParams.find("input___batch")->second;
86 | }
87 |
88 | for (int i = 0; i < batch; i++) {
89 | outputs[i]->dataType = inputs[i]->dataType;
90 | outputs[i]->Resize(inputs[i]->dims);
91 | }
92 | }
93 | }
--------------------------------------------------------------------------------
/src/executor.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/13/23.
3 | //
4 |
5 | #include "utils.h"
6 |
7 | #include "executor.h"
8 |
9 | #include "devices/cpu/cpudevice.h"
10 |
11 | #ifdef USE_CUDA
12 | #include "devices/cuda/cudadevice.h"
13 | #include "devices/cuda/fastllm-cuda.cuh"
14 | #endif
15 |
16 | namespace fastllm {
17 | Executor::Executor() {
18 | this->devices.clear();
19 | #ifdef USE_CUDA
20 | this->devices.push_back((BaseDevice*) new CudaDevice());
21 | #endif
22 | this->devices.push_back((BaseDevice*) new CpuDevice());
23 | }
24 |
25 | Executor::~Executor() {
26 | for (int i = 0; i < devices.size(); i++) {
27 | delete devices[i];
28 | }
29 | }
30 |
31 | void Executor::ClearDevices() {
32 | this->devices.clear();
33 | }
34 |
35 | void Executor::AddDevice(fastllm::BaseDevice *device) {
36 | this->devices.push_back(device);
37 | }
38 |
39 | void Executor::SetFirstDevice(const std::string &device) {
40 | auto temp = this->devices;
41 | this->devices.clear();
42 | for (int i = 0; i < temp.size(); i++) {
43 | if (StartWith(device, temp[i]->deviceType)) {
44 | this->devices.push_back(temp[i]);
45 | this->devices.back()->deviceIds = ParseDeviceIds(device, temp[i]->deviceType);
46 | }
47 | }
48 | for (int i = 0; i < temp.size(); i++) {
49 | if (!StartWith(device, temp[i]->deviceType)) {
50 | this->devices.push_back(temp[i]);
51 | }
52 | }
53 | }
54 |
55 |     std::vector <int> Executor::GetDeviceIds(const std::string &device) {
56 | for (int i = 0; i < devices.size(); i++) {
57 | if (StartWith(devices[i]->deviceType, device)) {
58 | return devices[i]->deviceIds;
59 | }
60 | }
61 | return {0};
62 | }
63 |
64 | void Executor::Run(const std::string &opType, const fastllm::DataDict &datas, const fastllm::FloatDict &floatParams,
65 | const fastllm::IntDict &intParams) {
66 | auto st = std::chrono::system_clock::now();
67 | bool lockInCPU = false;
68 | for (auto &it: datas) {
69 | if (intParams.find(it.first + "___batch") != intParams.end()) {
70 | int batch = intParams.find(it.first + "___batch")->second;
71 | for (int i = 0; i < batch; i++) {
72 | lockInCPU |= (((Data**)it.second)[i] && ((Data**)it.second)[i]->lockInCPU);
73 | }
74 | } else {
75 | lockInCPU |= (it.second && it.second->lockInCPU);
76 | }
77 | }
78 | for (auto device: devices) {
79 | if (lockInCPU && device->deviceType != "cpu") {
80 | continue;
81 | }
82 | if (device->CanRun(opType, datas, floatParams, intParams)) {
83 | #ifdef USE_CUDA
84 | if (device->deviceType == "cuda" && device->deviceIds.size() > 0) {
85 | FastllmCudaSetDevice(device->deviceIds[0]);
86 | }
87 | #endif
88 | for (auto &it: datas) {
89 | if (intParams.find(it.first + "___batch") != intParams.end()) {
90 | int batch = intParams.find(it.first + "___batch")->second;
91 | for (int i = 0; i < batch; i++) {
92 | if (((Data**)it.second)[i]) {
93 | ((Data**)it.second)[i]->ToDevice((void *) device);
94 | }
95 | }
96 | } else {
97 | if (it.second) {
98 | it.second->ToDevice((void *) device);
99 | }
100 | }
101 | }
102 | device->Reshape(opType, datas, floatParams, intParams);
103 | device->Run(opType, datas, floatParams, intParams);
104 | break;
105 | }
106 | }
107 | float spend = GetSpan(st, std::chrono::system_clock::now());
108 | profiler[opType] += spend;
109 | }
110 |
111 | void Executor::ClearProfiler() {
112 | profiler.clear();
113 | }
114 |
115 | void Executor::PrintProfiler() {
116 | float sum = 0.0;
117 | for (auto &it : profiler) {
118 | printf("%s spend %f\n", it.first.c_str(), it.second);
119 | sum += it.second;
120 | }
121 | printf("total spend %f\n", sum);
122 | }
123 | }
--------------------------------------------------------------------------------
/test/cmmlu/README.md:
--------------------------------------------------------------------------------
1 | CMMLU is a comprehensive Chinese evaluation benchmark built to measure a language model's knowledge and reasoning ability in a Chinese-language context.
2 |
3 | Project homepage: https://github.com/haonan-li/CMMLU
4 |
5 | The chatglm.py script in this directory runs the evaluation through the fastllm framework.
6 |
7 | The test procedure is as follows:
8 |
9 | - 1. Clone the CMMLU repository
10 |
11 | ``` sh
12 | git clone https://github.com/haonan-li/CMMLU
13 | ```
14 |
15 | - 2. Run the tests
16 |
17 | ```
18 | # ChatGLM test script
19 | # model_name_or_path can be the official ChatGLM2-6b model or its int4 variant; dtype supports float16, int8, int4
20 | python3 chatglm.py --model_name_or_path <model path> --save_dir <result save path> --dtype float16
21 |
22 | # Baichuan-13B test script
23 | # model_name_or_path can be the official Baichuan13B-Base or Baichuan13B-Chat model; dtype supports float16, int8, int4
24 | python3 baichuan.py --model_name_or_path <model path> --save_dir <result save path> --dtype float16
25 | ```
26 |
27 | The test set is large and a full run takes a long time; the scores of the subjects finished so far can be checked at any time with
28 |
29 | ```
30 | python3 eval.py <result save path>
31 | ```
32 |
33 | - 3. Reference results
34 |
35 | | Model | Data precision | Shot | CMMLU score |
36 | |-----------------------: |-------- |----------|-----------|
37 | | ChatGLM2-6b-fp16 | float32 |0 | 50.16 |
38 | | ChatGLM2-6b-int8 | float32 |0 | 50.14 |
39 | | ChatGLM2-6b-int4 | float32 |0 | 49.63 |
40 | | QWen-7b-Base-fp16 | float32 |0 | 57.43 |
41 | | QWen-7b-Chat-fp16 | float32 |0 | 54.82 |
42 | | Baichuan-13b-Base-int8 | float32 |5 | 55.12 |
43 | | Baichuan-13b-Base-int4 | float32 |5 | 52.22 |
44 |
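45 | The fastllm integration in these scripts is small: the Hugging Face model is first loaded on CPU, converted with `llm.from_hf`, and the evaluation loop then queries the converted model through `chat` (chatglm.py / qwen.py) or `response_logits` (baichuan.py). A minimal sketch, assuming a local ChatGLM2-6B checkpoint (use whatever path you would pass as `--model_name_or_path`):
46 |
47 | ``` python
48 | from transformers import AutoModel, AutoTokenizer
49 | from fastllm_pytools import llm
50 |
51 | # Example path only; any ChatGLM2-6B checkpoint works here.
52 | path = "THUDM/chatglm2-6b"
53 | tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
54 | hf_model = AutoModel.from_pretrained(path, trust_remote_code=True).cpu()
55 |
56 | # Convert the torch weights into a fastllm model; dtype may be "float16", "int8" or "int4".
57 | model = llm.from_hf(hf_model, tokenizer, dtype="float16")
58 |
59 | # The converted model is queried the same way the scripts below query it.
60 | answer, history = model.chat(tokenizer, "你好", history=[])
61 | print(answer)
62 | ```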
--------------------------------------------------------------------------------
/test/cmmlu/baichuan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import argparse
5 | from CMMLU.src.mp_utils import choices, format_example, gen_prompt, softmax, run_eval
6 |
7 | from peft import PeftModel
8 | from transformers import LlamaForCausalLM, LlamaTokenizer
9 | from transformers import AutoModelForCausalLM, AutoTokenizer
10 |
11 | def eval(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot):
12 | choice_ids = [tokenizer.convert_tokens_to_ids(choice) for choice in choices]
13 | cors = []
14 | all_conf = []
15 | all_preds = []
16 | answers = choices[: test_df.shape[1] - 2]
17 |
18 | for i in range(test_df.shape[0]):
19 | prompt_end = format_example(test_df, i, subject, include_answer=False)
20 | prompt = gen_prompt(dev_df=dev_df,
21 | subject=subject,
22 | prompt_end=prompt_end,
23 | num_few_shot=num_few_shot,
24 | tokenizer=tokenizer,
25 | max_length=max_length)
26 | label = test_df.iloc[i, test_df.shape[1] - 1]
27 | logits = model.response_logits(prompt, tokenizer = tokenizer);
28 | sel = 0;
29 | for j in range(4):
30 | if (logits[choice_ids[j]] > logits[choice_ids[sel]]):
31 | sel = j;
32 | pred = choices[sel];
33 | conf = [logits[choice_ids[j]] for j in range(4)]
34 | all_preds += pred
35 | all_conf.append(conf)
36 | cors.append(pred == label)
37 | print(i, np.mean(cors))
38 |
39 | acc = np.mean(cors)
40 | print("Average accuracy {:.3f} - {}".format(acc, subject))
41 | return acc, all_preds, all_conf
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--model_name_or_path", type=str, default="")
47 | parser.add_argument("--lora_weights", type=str, default="")
48 | parser.add_argument("--data_dir", type=str, default="./CMMLU/data")
49 | parser.add_argument("--save_dir", type=str, default="../results/not_specified")
50 | parser.add_argument("--num_few_shot", type=int, default=0)
51 | parser.add_argument("--max_length", type=int, default=2048)
52 | parser.add_argument("--load_in_8bit", action='store_true')
53 | parser.add_argument("--dtype", type=str, default="float16")
54 | parser.add_argument("--with_conf", action='store_true')
55 | parser.add_argument("--cot", action='store_true')
56 | args = parser.parse_args()
57 |
58 | # TODO: better handle
59 | tokenizer_class = LlamaTokenizer if 'llama' in args.model_name_or_path else AutoTokenizer
60 | model_class = LlamaForCausalLM if 'llama' in args.model_name_or_path else AutoModelForCausalLM
61 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, trust_remote_code=True)
62 | model = model_class.from_pretrained(args.model_name_or_path,
63 | trust_remote_code=True,
64 | load_in_8bit=args.load_in_8bit,
65 | torch_dtype=torch.float16,
66 | device_map="cpu"
67 | )
68 | if args.lora_weights != "":
69 | model = PeftModel.from_pretrained(
70 | model,
71 | args.lora_weights,
72 | torch_dtype=torch.float16,
73 | )
74 |
75 | from fastllm_pytools import llm;
76 | model = llm.from_hf(model, tokenizer, dtype = args.dtype);
77 | model.direct_query = True;
78 |
79 | run_eval(model, tokenizer, eval, args)
80 |
--------------------------------------------------------------------------------
/test/cmmlu/chatglm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import argparse
5 | from CMMLU.src.mp_utils import choices, format_example, gen_prompt, softmax, run_eval
6 | from transformers import AutoModel, AutoTokenizer
7 | import threading
8 |
9 | def chat(model, tokenizer, prompt, output_list, idx):
10 | pred, history = model.chat(tokenizer, prompt, history=[], max_length = 5);
11 | if pred[0] not in choices:
12 | pred, history = model.chat(tokenizer, prompt, history=[], max_length = 1000);
13 | output_list[idx] = pred;
14 |
15 | def eval_chat_multithread(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot):
16 | cors = []
17 | all_preds = []
18 | answers = choices[: test_df.shape[1] - 2]
19 |
20 | batch_num = 64;
21 | output_list = ["" for i in range(test_df.shape[0])];
22 | ths = [None for i in range(test_df.shape[0])];
23 |
24 | for j in range(0, test_df.shape[0], batch_num):
25 | cur_len = min(test_df.shape[0] - j, batch_num);
26 | for i in range(j, j + cur_len):
27 | prompt_end = format_example(test_df, i, subject, include_answer=False, cot=cot)
28 | prompt = gen_prompt(dev_df=dev_df,
29 | subject=subject,
30 | prompt_end=prompt_end,
31 | num_few_shot=num_few_shot,
32 | tokenizer=tokenizer,
33 | max_length=max_length,
34 | cot=cot)
35 | ths[i] = threading.Thread(target = chat, args=(model, tokenizer, prompt, output_list, i));
36 | ths[i].start();
37 | for i in range(j, j + cur_len):
38 | ths[i].join();
39 | pred = output_list[i];
40 | label = test_df.iloc[i, test_df.shape[1] - 1]
41 | if pred and pred[0] in choices:
42 | cors.append(pred[0] == label);
43 | all_preds.append(pred.replace("\n", ""))
44 | print(i, test_df.shape[0], np.mean(cors))
45 | acc = np.mean(cors)
46 | print("Average accuracy {:.3f} - {}".format(acc, subject))
47 | print("{} results, {} inappropriate formated answers.".format(len(cors), len(all_preds)-len(cors)))
48 | return acc, all_preds, None
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument("--model_name_or_path", type=str, default="")
53 | parser.add_argument("--lora_weights", type=str, default="")
54 | parser.add_argument("--data_dir", type=str, default="./CMMLU/data")
55 | parser.add_argument("--save_dir", type=str, default="./results/ChatGLM2-6B")
56 | parser.add_argument("--num_few_shot", type=int, default=0)
57 | parser.add_argument("--max_length", type=int, default=2048)
58 | parser.add_argument("--dtype", type=str, default="float16")
59 | parser.add_argument("--cot", action='store_true')
60 | args = parser.parse_args()
61 |
62 | # Initialize models
63 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True,)
64 | model = AutoModel.from_pretrained(args.model_name_or_path, trust_remote_code=True).cpu()
65 |
66 | from fastllm_pytools import llm;
67 | model = llm.from_hf(model, tokenizer, dtype = args.dtype);
68 |
69 | # model.save("/root/test.flm");
70 | # Always use Chat-style evaluation
71 | run_eval(model, tokenizer, eval_chat_multithread, args)
72 |
--------------------------------------------------------------------------------
/test/cmmlu/eval.py:
--------------------------------------------------------------------------------
1 | import CMMLU.src.mp_utils as mp
2 | import sys
3 | print(mp.get_results(sys.argv[1]))
4 |
--------------------------------------------------------------------------------
/test/cmmlu/qwen.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import argparse
5 | import threading
6 | from CMMLU.src.mp_utils import choices, format_example, gen_prompt, softmax, run_eval
7 |
8 | from peft import PeftModel
9 | from transformers import AutoModelForCausalLM, AutoTokenizer
10 | from transformers.generation import GenerationConfig
11 |
12 |
13 | def chat(model, tokenizer, prompt, output_list, idx):
14 | pred, history = model.chat(tokenizer, prompt, history=[], max_length = 5)
15 | if pred[0] not in choices:
16 | pred, history = model.chat(tokenizer, prompt, history=[], max_length = 1000)
17 | output_list[idx] = pred
18 |
19 | def eval_chat_multithread(model, tokenizer, subject, dev_df, test_df, num_few_shot, max_length, cot):
20 | cors = []
21 | all_preds = []
22 | answers = choices[: test_df.shape[1] - 2]
23 |
24 | batch_num = 1
25 | output_list = ["" for i in range(test_df.shape[0])]
26 | ths = [None for i in range(test_df.shape[0])]
27 |
28 | for j in range(0, test_df.shape[0], batch_num):
29 | cur_len = min(test_df.shape[0] - j, batch_num)
30 | for i in range(j, j + cur_len):
31 | prompt_end = format_example(test_df, i, subject, include_answer=False, cot=cot)
32 | prompt = gen_prompt(dev_df=dev_df,
33 | subject=subject,
34 | prompt_end=prompt_end,
35 | num_few_shot=num_few_shot,
36 | tokenizer=tokenizer,
37 | max_length=max_length,
38 | cot=cot)
39 | ths[i] = threading.Thread(target = chat, args=(model, tokenizer, prompt, output_list, i))
40 | ths[i].start()
41 | for i in range(j, j + cur_len):
42 | ths[i].join()
43 | pred = output_list[i]
44 | label = test_df.iloc[i, test_df.shape[1] - 1]
45 | if pred and pred[0] in choices:
46 | cors.append(pred[0] == label)
47 | all_preds.append(pred.replace("\n", ""))
48 | print(i, test_df.shape[0], np.mean(cors))
49 | acc = np.mean(cors)
50 | print("Average accuracy {:.3f} - {}".format(acc, subject))
51 | print("{} results, {} inappropriate formated answers.".format(len(cors), len(all_preds)-len(cors)))
52 | return acc, all_preds, None
53 |
54 |
55 | if __name__ == "__main__":
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument("--model_name_or_path", type=str, default="")
58 | parser.add_argument("--lora_weights", type=str, default="")
59 | parser.add_argument("--data_dir", type=str, default="./CMMLU/data")
60 | parser.add_argument("--save_dir", type=str, default="../results/not_specified")
61 | parser.add_argument("--num_few_shot", type=int, default=0)
62 | parser.add_argument("--max_length", type=int, default=2048)
63 | parser.add_argument("--load_in_8bit", action='store_true')
64 | parser.add_argument("--dtype", type=str, default="float16")
65 | parser.add_argument("--with_conf", action='store_true')
66 | parser.add_argument("--cot", action='store_true')
67 | args = parser.parse_args()
68 |
69 | # TODO: better handle
70 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
71 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, device_map="cpu", trust_remote_code=True, fp16=True).eval()
72 | model.generation_config = GenerationConfig.from_pretrained(args.model_name_or_path, trust_remote_code=True)
73 | if args.lora_weights != "":
74 | model = PeftModel.from_pretrained(
75 | model,
76 | args.lora_weights,
77 | torch_dtype=torch.float16,
78 | )
79 |
80 | from fastllm_pytools import llm;
81 | model = llm.from_hf(model, tokenizer, dtype = args.dtype)
82 | model.direct_query = True
83 |
84 | run_eval(model, tokenizer, eval_chat_multithread, args)
--------------------------------------------------------------------------------
/test/ops/cppOps.cpp:
--------------------------------------------------------------------------------
1 | #include "fastllm.h"
2 |
3 | void callBaseOp(int optype=0){
4 | fastllm::Data inputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 5});
5 | fastllm::Data outputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {3, 4});
6 |
7 | switch (optype)
8 | {
9 | case 0:
10 | fastllm::AddTo(inputs, outputs, 1);
11 | break;
12 | case 1:
13 | fastllm::Cat(inputs, inputs, 0, outputs);
14 | break;
15 | case 2:
16 | fastllm::Mul(inputs, 2, outputs);
17 | break;
18 | case 3:
19 | fastllm::Permute(inputs, {1, 0}, outputs);
20 | break;
21 | case 4:
22 | fastllm::Split(inputs, 0, 0, 1, outputs);
23 | break;
24 | case 5:
25 | fastllm::Permute(inputs, {1, 0}, outputs);
26 | fastllm::MatMul(inputs, outputs, outputs);
27 | break;
28 | default:
29 | break;
30 | }
31 | outputs.ToDevice(fastllm::DataDevice::CPU);
32 | outputs.Print();
33 | }
34 |
35 | void callNormOp(int normType=0){
36 | fastllm::Data inputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 5});
37 | fastllm::Data weights = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 2});
38 | fastllm::Data gamma = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 1});
39 | fastllm::Data beta = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {0, 0});
40 | fastllm::Data outputs;
41 |
42 | switch (normType)
43 | {
44 | case 0:
45 | fastllm::LayerNorm(inputs, gamma, beta, -1, outputs);
46 | break;
47 | case 1:
48 | fastllm::RMSNorm(inputs, weights, 1e-5, outputs);
49 | break;
50 | default:
51 | break;
52 | }
53 | outputs.ToDevice(fastllm::DataDevice::CPU);
54 | outputs.Print();
55 | }
56 |
57 |
58 | void callLinearOp(){
59 | fastllm::Data inputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 2});
60 | fastllm::Data weights = fastllm::Data(fastllm::DataType::FLOAT32, {3, 2}, {3, 4, 5, 5, 6, 7});
61 | fastllm::Data bias = fastllm::Data(fastllm::DataType::FLOAT32, {1, 3}, {0, 1, 1});
62 | fastllm::Data outputs;
63 | fastllm::Linear(inputs, weights, bias, outputs);
64 | outputs.ToDevice(fastllm::DataDevice::CPU);
65 | outputs.Print();
66 | }
67 |
68 | void callActivationOp(int activateType=0){
69 | fastllm::Data inputs = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2}, {1, 5});
70 | fastllm::Data outputs;
71 | switch (activateType)
72 | {
73 | case 0:
74 | fastllm::Silu(inputs, outputs);
75 | break;
76 | case 1:
77 | fastllm::Softmax(inputs, outputs, -1);
78 | break;
79 | case 2:
80 | fastllm::GeluNew(inputs, outputs);
81 | break;
82 | case 3:
83 | fastllm::Swiglu(inputs, outputs);
84 | break;
85 | default:
86 | break;
87 | }
88 | outputs.ToDevice(fastllm::DataDevice::CPU);
89 | outputs.Print();
90 | }
91 |
92 | void callAttentionOp(int group=1, int attentionType=0){
93 | const fastllm::Data q = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2, 3}, {1, 2, 3, 4, 5, 6});
94 | const fastllm::Data k = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2, 3}, {5, 6, 7, 8, 9, 10});
95 | const fastllm::Data v = fastllm::Data(fastllm::DataType::FLOAT32, {1, 2, 3}, {1, 1, 1, 2, 1, 3});
96 | const fastllm::Data mask = fastllm::Data();
97 | int dims = q.dims.back();
98 | float scale = 1/sqrt(dims);
99 | fastllm::Data output;
100 |
101 | fastllm::Attention(q, k, v, mask, output, group, scale, attentionType);
102 | }
103 |
104 | void testBase(){
105 | printf("testing BaseOp...\n");
106 | for (int i=0;i<6;i++){
107 | callBaseOp(i);
108 | }
109 | printf("test BaseOp finished!\n");
110 | }
111 |
112 | void testActivation(){
113 | printf("testing ActivationOp...\n");
114 | for (int i=0;i<4;i++){
115 | callActivationOp(i);
116 | }
117 | printf("test ActivationOp finished!\n");
118 | }
119 |
120 | void testAttention(){
121 | printf("testing AttentionOp...\n");
122 | callAttentionOp();
123 | printf("test AttentionOp finished!\n");
124 | }
125 |
126 | void testLinear(){
127 | printf("testing LinearOp...\n");
128 | callLinearOp();
129 | printf("test LinearOp finished!\n");
130 | }
131 |
132 | void testNorm(){
133 | printf("testing NormOp...\n");
134 | for (int i=0;i<2;i++){
135 | callNormOp(i);
136 | }
137 | printf("test NormOp finished!\n");
138 | }
139 |
140 | void testAll(){
141 | testBase();
142 | testActivation();
143 | testAttention();
144 | testNorm();
145 | testLinear();
146 | }
147 |
148 |
149 | int main(){
150 | testAll();
151 | }
--------------------------------------------------------------------------------
/tools/fastllm_pytools/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ["llm"]
--------------------------------------------------------------------------------
/tools/scripts/alpaca2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoTokenizer, LlamaForCausalLM
4 | from fastllm_pytools import torch2flm
5 |
6 | if __name__ == "__main__":
7 | model_name = sys.argv[3] if len(sys.argv) >= 4 else 'minlik/chinese-alpaca-33b-merged'
8 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
9 | # `torch_dtype=torch.float16` is used here by default; if it will not cause an OOM error, you can load the model in float32 instead.
10 | model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
11 | conf = model.config.__dict__
12 | conf["model_type"] = "llama"
13 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
14 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "alpaca-33b-" + dtype + ".flm"
15 | # add custom code here
16 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
17 |
--------------------------------------------------------------------------------
/tools/scripts/baichuan2_2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from transformers.generation.utils import GenerationConfig
5 | from fastllm_pytools import torch2flm
6 |
7 | if __name__ == "__main__":
8 | modelpath = "baichuan-inc/Baichuan2-7B-Chat"
9 | tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False, trust_remote_code=True)
10 | model = AutoModelForCausalLM.from_pretrained(modelpath, device_map="auto", torch_dtype=torch.float32, trust_remote_code=True)
11 |
12 | # normalize lm_head
13 | state_dict = model.state_dict()
14 | state_dict['lm_head.weight'] = torch.nn.functional.normalize(state_dict['lm_head.weight'])
15 | model.load_state_dict(state_dict)
16 |
17 | try:
18 | model.generation_config = GenerationConfig.from_pretrained(modelpath)
19 | except:
20 | pass
21 |
22 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
23 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan2-7b-" + dtype + ".flm"
24 | torch2flm.tofile(exportPath, model.to('cpu'), tokenizer, dtype=dtype)
--------------------------------------------------------------------------------
/tools/scripts/baichuan2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from transformers.generation.utils import GenerationConfig
5 | from fastllm_pytools import torch2flm
6 |
7 | if __name__ == "__main__":
8 | modelpath = "baichuan-inc/baichuan-13B-Chat"
9 | tokenizer = AutoTokenizer.from_pretrained(modelpath, trust_remote_code=True)
10 | model = AutoModelForCausalLM.from_pretrained(modelpath, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
11 | model.to("cpu")
12 | try:
13 | model.generation_config = GenerationConfig.from_pretrained(modelpath)
14 | except:
15 | pass
16 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
17 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan-13b-" + dtype + ".flm"
18 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
19 |
--------------------------------------------------------------------------------
/tools/scripts/chatglm_export.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from transformers import AutoTokenizer, AutoModel
3 | from fastllm_pytools import torch2flm
4 |
5 | if __name__ == "__main__":
6 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
7 | model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
8 | model = model.eval()
9 |
10 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
11 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "chatglm-6b-" + dtype + ".flm"
12 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
13 |
--------------------------------------------------------------------------------
/tools/scripts/cli_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from fastllm_pytools import llm
3 |
4 | def args_parser():
5 | parser = argparse.ArgumentParser(description = 'fastllm_chat_demo')
6 | parser.add_argument('-p', '--path', type = str, required = True, default = '', help = '模型文件的路径')
7 | parser.add_argument('-t', '--threads', type=int, default=4, help='使用的线程数量')
8 | parser.add_argument('-l', '--low', action='store_true', help='使用低内存模式')
9 | args = parser.parse_args()
10 | return args
11 |
12 | if __name__ == "__main__":
13 | args = args_parser()
14 | llm.set_cpu_threads(args.threads)
15 | llm.set_cpu_low_mem(args.low)
16 | model = llm.model(args.path)
17 |
18 | history = []
19 | print("输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
20 | while True:
21 | query = input("\n用户:")
22 | if query.strip() == "stop":
23 | break
24 | if query.strip() == "clear":
25 | history = []
26 | print("输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
27 | continue
28 | print("AI:", end = "");
29 | curResponse = "";
30 | for response in model.stream_response(query, history = history):
31 | curResponse += response;
32 | print(response, flush = True, end = "")
33 | history.append((query, curResponse))
34 | model.release_memory()
--------------------------------------------------------------------------------
/tools/scripts/glm_export.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import struct
3 | import numpy as np
4 | import torch
5 | import binascii
6 | from transformers import AutoTokenizer, AutoModel
7 | from fastllm_pytools import torch2flm
8 |
9 | def glmtofile(exportPath,
10 | model,
11 | tokenizer = None,
12 | dtype = "float16"):
13 | if (dtype not in torch2flm.fastllm_data_type_dict):
14 | print("dtype should in ", list(torch2flm.fastllm_data_type_dict.keys()))
15 | exit(0)
16 |
17 | dict = model.state_dict()
18 | fo = open(exportPath, "wb")
19 |
20 | # 0. version id
21 | fo.write(struct.pack('i', 2))
22 |
23 | # 0.1 model info
24 | modelInfo = model.config.__dict__
25 | if model.generation_config is not None:
26 | modelInfo.update(model.generation_config.__dict__)
27 | if ("model_type" not in modelInfo):
28 | print("unknown model_type.")
29 | exit(0)
30 |
31 | modelInfo["tokenizer_use_score"] = "1" # 分词带分数
32 | modelInfo["tokenizer_serialized"]=binascii.hexlify(tokenizer.sp_model.serialized_model_proto()).decode("latin-1") # sentencepiece分词器序列化存储
33 |
34 | if hasattr(model, "peft_config"):
35 | adapter_size = len(model.peft_config)
36 | modelInfo["peft_size"] = adapter_size
37 |
38 | fo.write(struct.pack('i', len(modelInfo)))
39 | for it in modelInfo.keys():
40 | torch2flm.writeKeyValue(fo, str(it), str(modelInfo[it]))
41 |
42 | if hasattr(model, "peft_config"):
43 | for adapter_name in model.peft_config.keys():
44 | adapter_dict = model.peft_config[adapter_name].__dict__
45 | torch2flm.writeString(fo, adapter_name)
46 | fo.write(struct.pack('i', len(adapter_dict)))
47 | for it in adapter_dict.keys():
48 | torch2flm.writeKeyValue(fo, str(it), str(adapter_dict[it]))
49 |
50 | # 1. vocab
51 | if (tokenizer):
52 | if (hasattr(tokenizer, "tokenizer")):
53 | tokenizer = tokenizer.tokenizer
54 | if (hasattr(tokenizer, "sp_model")):
55 | piece_size = tokenizer.sp_model.piece_size()
56 | fo.write(struct.pack('i', piece_size))
57 | for i in range(piece_size):
58 | s = tokenizer.sp_model.id_to_piece(i).encode()
59 | fo.write(struct.pack('i', len(s)))
60 | for c in s:
61 | fo.write(struct.pack('i', c))
62 | fo.write(struct.pack('i', i))
63 | fo.write(struct.pack('f', float(tokenizer.sp_model.get_score(i))))
64 | else:
65 | vocab = tokenizer.get_vocab()
66 | fo.write(struct.pack('i', len(vocab)))
67 | for v in vocab.keys():
68 | s = v.encode()
69 | fo.write(struct.pack('i', len(s)))
70 | for c in s:
71 | fo.write(struct.pack('i', c))
72 | fo.write(struct.pack('i', vocab[v]))
73 | fo.write(struct.pack('f', 1.0))
74 | else:
75 | fo.write(struct.pack('i', 0))
76 |
77 | weight_type_dict = {}
78 | module_dict = {}
79 | for key, m in model.named_modules():
80 | if (isinstance(m, torch.nn.Linear)):
81 | weight_type_dict[key + ".weight"] = "linear"
82 | module_dict[key + ".weight"] = m
83 | if (isinstance(m, torch.nn.Embedding)):
84 | weight_type_dict[key] = "embedding"
85 |
86 | # 2. weight
87 | fo.write(struct.pack('i', len(dict)))
88 | tot = 0
89 | for key in dict:
90 | ori_data_type = 0
91 | ori_np_data_type = np.float32
92 | cur_weight_type = 0
93 | if (key in weight_type_dict and weight_type_dict[key] in torch2flm.fastllm_weight_type_dict):
94 | cur_weight_type = torch2flm.fastllm_weight_type_dict[weight_type_dict[key]]
95 | to_data_type = 0
96 | if (cur_weight_type == 1):
97 | to_data_type = torch2flm.fastllm_data_type_dict[dtype]
98 | if (to_data_type == 7):
99 | ori_data_type = 7
100 | ori_np_data_type = np.float16
101 |
102 | cur = dict[key].numpy().astype(ori_np_data_type)
103 |
104 | if hasattr(model, "peft_config"):
105 | weight_name = key.replace('base_model.model.', '')
106 | fo.write(struct.pack('i', len(weight_name)))
107 | fo.write(weight_name.encode())
108 | else:
109 | fo.write(struct.pack('i', len(key)))
110 | fo.write(key.encode())
111 | fo.write(struct.pack('i', len(cur.shape)))
112 | for i in cur.shape:
113 | fo.write(struct.pack('i', i))
114 | if (to_data_type == 3):
115 | torch2flm.write_int8(fo, cur)
116 | elif (to_data_type == 8):
117 | torch2flm.write_int4(fo, cur)
118 | else:
119 | fo.write(struct.pack('i', to_data_type))
120 | fo.write(cur.data)
121 | tot += 1
122 | print("output (", tot, "/", len(dict), end = " )\r")
123 | print("\nfinish.")
124 | fo.close()
125 |
126 | if __name__ == "__main__":
127 | tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-large-chinese", trust_remote_code=True)
128 | model = AutoModel.from_pretrained("THUDM/glm-large-chinese", trust_remote_code=True)
129 | model = model.eval()
130 |
131 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float32"
132 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "glm-" + dtype + ".flm"
133 | glmtofile(exportPath, model, tokenizer, dtype = dtype)
134 |
--------------------------------------------------------------------------------
/tools/scripts/llamalike2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoTokenizer, AutoModelForCausalLM
4 | from fastllm_pytools import torch2flm
5 |
6 | if __name__ == "__main__":
7 | modelNameOrPath = sys.argv[3] if len(sys.argv) >= 4 else 'qwen/Qwen1.5-7B-Chat'
8 | tokenizer = AutoTokenizer.from_pretrained(modelNameOrPath, trust_remote_code=True);
9 | # `torch_dtype=torch.float16` is used here by default; if it will not cause an OOM error, you can load the model in float32 instead.
10 | model = AutoModelForCausalLM.from_pretrained(modelNameOrPath, trust_remote_code=True, torch_dtype=torch.float16)
11 | model = model.eval()
12 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
13 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else model.config.model_type + "-7b-" + dtype + ".flm"
14 | if model.config.model_type == "internlm":
15 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "",
16 | user_role = "<|User|>:", bot_role = "\n<|Bot|>:",
17 | history_sep = "\n", dtype = dtype)
18 | elif model.config.model_type == "internlm2":
19 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="<|im_start|>system\nYou are an AI assistant whose name is InternLM (书生·浦语).\n<|im_end|>",
20 | user_role="<|im_start|>user\n", bot_role="<|im_end|><|im_start|>assistant\n", history_sep="<|im_end|>\n", dtype = dtype)
21 | elif model.config.model_type == "qwen2":
22 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>", user_role="<|im_start|>user\n",
23 | bot_role="<|im_end|><|im_start|>assistant\n", history_sep="<|im_end|>\n", dtype = dtype)
24 | # add custom code here
25 | else:
26 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "", user_role = "",
27 | bot_role = "", history_sep = "", dtype = dtype)
28 |
--------------------------------------------------------------------------------
/tools/scripts/minicpm2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoTokenizer, AutoModelForCausalLM
4 | from fastllm_pytools import torch2flm
5 |
6 | if __name__ == "__main__":
7 | modelNameOrPath = sys.argv[3] if len(sys.argv) >= 4 else "openbmb/MiniCPM-2B-dpo-fp16"
8 | tokenizer = AutoTokenizer.from_pretrained(modelNameOrPath, use_fast=False, trust_remote_code=True)
9 | # `torch_dtype=torch.float16` is used here by default; if it will not cause an OOM error, you can load the model in float32 instead.
10 | model = AutoModelForCausalLM.from_pretrained(modelNameOrPath, trust_remote_code=True, torch_dtype=torch.float16)
11 | model = model.eval()
12 |
13 | model.config.__dict__['model_type'] = 'minicpm'
14 |
15 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
16 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "minicpm-2b-" + dtype + ".flm"
17 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "",
18 | user_role = "<用户>", bot_role = "",
19 | history_sep = "", dtype = dtype)
20 |
--------------------------------------------------------------------------------
/tools/scripts/moss_export.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 | from fastllm_pytools import torch2flm
4 |
5 | tokenizer = AutoTokenizer.from_pretrained("fnlp/moss-moon-003-sft", trust_remote_code=True);
6 | model = AutoModelForCausalLM.from_pretrained("fnlp/moss-moon-003-sft", trust_remote_code=True).float();
7 | model = model.eval();
8 |
9 | if __name__ == "__main__":
10 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
11 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "moss-" + dtype + ".flm"
12 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
--------------------------------------------------------------------------------
/tools/scripts/qwen2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 | from transformers.generation import GenerationConfig
4 | from fastllm_pytools import torch2flm
5 |
6 | if __name__ == "__main__":
7 | tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
8 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True, fp32=True).eval()
9 | model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True) # generation hyperparameters such as max length and top_p can be adjusted here
10 |
11 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
12 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "qwen-7b-" + dtype + ".flm"
13 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
--------------------------------------------------------------------------------
/tools/scripts/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup (
4 | name = "fastllm_pytools",
5 | version = "0.0.1",
6 | author = "huangyuyang",
7 | author_email = "ztxz16@foxmail.com",
8 | description = "Fastllm pytools",
9 | url = "https://github.com/ztxz16/fastllm",
10 | packages = ['fastllm_pytools'],
11 |
12 | package_data = {
13 | '': ['*.dll', '*.so', '*.dylib']
14 | }
15 | )
16 |
--------------------------------------------------------------------------------
/tools/scripts/web_demo.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from streamlit_chat import message
3 | from fastllm_pytools import llm
4 | import sys
5 |
6 | st.set_page_config(
7 | page_title="fastllm web demo",
8 | page_icon=":robot:"
9 | )
10 |
11 | @st.cache_resource
12 | def get_model():
13 | model = llm.model(sys.argv[1])
14 | return model
15 |
16 | if "messages" not in st.session_state:
17 | st.session_state.messages = []
18 |
19 | for i, (prompt, response) in enumerate(st.session_state.messages):
20 | with st.chat_message("user"):
21 | st.markdown(prompt)
22 | with st.chat_message("assistant"):
23 | st.markdown(response)
24 |
25 | if prompt := st.chat_input("请开始对话"):
26 | model = get_model()
27 | with st.chat_message("user"):
28 | st.markdown(prompt)
29 |
30 | with st.chat_message("assistant"):
31 | message_placeholder = st.empty()
32 | full_response = ""
33 | for chunk in model.stream_response(prompt, st.session_state.messages, one_by_one = True):
34 | full_response += chunk
35 | message_placeholder.markdown(full_response + "▌")
36 | message_placeholder.markdown(full_response)
37 | st.session_state.messages.append((prompt, full_response))
38 |
--------------------------------------------------------------------------------
/tools/src/quant.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/13/23.
3 | //
4 |
5 | #include <iostream>
6 | #include "model.h"
7 |
8 | struct QuantConfig {
9 | std::string path; // path to the input model file
10 | std::string output; // path of the output file
11 | int bits; // quantization bit width
12 | };
13 |
14 | void Usage() {
15 | std::cout << "Usage:" << std::endl;
16 | std::cout << "[-h|--help]: 显示帮助" << std::endl;
17 | std::cout << "<-p|--path> : 模型文件的路径" << std::endl;
18 | std::cout << "<-b|--bits> : 量化位数, 4 = int4, 8 = int8, 16 = fp16" << std::endl;
19 | std::cout << "<-o|--output> : 输出文件路径" << std::endl;
20 | }
21 |
22 | void ParseArgs(int argc, char **argv, QuantConfig &config) {
23 | std::vector <std::string> sargv;
24 | for (int i = 0; i < argc; i++) {
25 | sargv.push_back(std::string(argv[i]));
26 | }
27 | for (int i = 1; i < argc; i++) {
28 | if (sargv[i] == "-h" || sargv[i] == "--help") {
29 | Usage();
30 | exit(0);
31 | } else if (sargv[i] == "-p" || sargv[i] == "--path") {
32 | config.path = sargv[++i];
33 | } else if (sargv[i] == "-b" || sargv[i] == "--bits") {
34 | config.bits = atoi(sargv[++i].c_str());
35 | } else if (sargv[i] == "-o" || sargv[i] == "--output") {
36 | config.output = sargv[++i];
37 | } else if (sargv[i] == "-m" || sargv[i] == "--model") {
38 | i++;
39 | } else {
40 | Usage();
41 | exit(-1);
42 | }
43 | }
44 | }
45 |
46 | int main(int argc, char **argv) {
47 | QuantConfig config;
48 | ParseArgs(argc, argv, config);
49 | auto model = fastllm::CreateLLMModelFromFile(config.path);
50 | model->SaveLowBitModel(config.output, config.bits);
51 | return 0;
52 | }
--------------------------------------------------------------------------------