├── .dockerignore
├── .github
└── workflows
│ └── Build.yml
├── .gitignore
├── .gitmodules
├── CMakeLists.txt
├── Dockerfile
├── LICENSE
├── README.md
├── README_EN.md
├── docker-compose.yaml
├── docs
├── benchmark.md
├── custom.md
├── custom_op.md
├── deepseek.md
├── demo_arguments.md
├── english_custom.md
├── english_demo_arguments.md
├── faq.md
├── ftllm.md
├── llama_cookbook.md
├── mixforward.md
├── models.md
├── qwen3.md
├── rocm.md
├── tfacc.md
├── version.md
└── wechat_group0.jpg
├── example
├── Android
│ └── LLMAssistant
│ │ ├── .gitignore
│ │ ├── .idea
│ │ ├── .gitignore
│ │ ├── .name
│ │ ├── compiler.xml
│ │ ├── dbnavigator.xml
│ │ ├── deploymentTargetDropDown.xml
│ │ ├── gradle.xml
│ │ ├── misc.xml
│ │ └── vcs.xml
│ │ ├── app
│ │ ├── .gitignore
│ │ ├── build.gradle
│ │ ├── libs
│ │ │ ├── arm64-v8a
│ │ │ │ └── libassistant.so
│ │ │ └── armeabi-v7a
│ │ │ │ └── libassistant.so
│ │ ├── proguard-rules.pro
│ │ ├── release
│ │ │ ├── app-arm64-v8a-release-unsigned.apk
│ │ │ ├── app-armeabi-v7a-release-unsigned.apk
│ │ │ ├── app-universal-release-unsigned.apk
│ │ │ └── app-x86-release-unsigned.apk
│ │ └── src
│ │ │ ├── androidTest
│ │ │ └── java
│ │ │ │ └── com
│ │ │ │ └── doujiao
│ │ │ │ └── xiaozhihuiassistant
│ │ │ │ └── ExampleInstrumentedTest.java
│ │ │ ├── main
│ │ │ ├── AndroidManifest.xml
│ │ │ ├── cpp
│ │ │ │ ├── CMakeLists.txt
│ │ │ │ ├── LLMChat.cpp
│ │ │ │ ├── LLMChat.h
│ │ │ │ ├── main.cpp
│ │ │ │ └── native-lib.cpp
│ │ │ ├── java
│ │ │ │ └── com
│ │ │ │ │ └── doujiao
│ │ │ │ │ ├── core
│ │ │ │ │ └── AssistantCore.java
│ │ │ │ │ └── xiaozhihuiassistant
│ │ │ │ │ ├── ChatMessage.java
│ │ │ │ │ ├── MainActivity.java
│ │ │ │ │ ├── adapter
│ │ │ │ │ ├── BaseViewHolder.java
│ │ │ │ │ └── MyAdapter.java
│ │ │ │ │ ├── utils
│ │ │ │ │ ├── PrefUtil.java
│ │ │ │ │ ├── StatusBarUtils.java
│ │ │ │ │ └── UriUtils.java
│ │ │ │ │ └── widget
│ │ │ │ │ ├── ChatPromptViewManager.java
│ │ │ │ │ ├── Location.java
│ │ │ │ │ ├── PromptView.java
│ │ │ │ │ ├── PromptViewHelper.java
│ │ │ │ │ └── location
│ │ │ │ │ ├── BottomCenterLocation.java
│ │ │ │ │ ├── ICalculateLocation.java
│ │ │ │ │ ├── TopCenterLocation.java
│ │ │ │ │ ├── TopLeftLocation.java
│ │ │ │ │ └── TopRightLocation.java
│ │ │ └── res
│ │ │ │ ├── drawable-v24
│ │ │ │ └── ic_launcher_foreground.xml
│ │ │ │ ├── drawable
│ │ │ │ ├── btnbg.xml
│ │ │ │ ├── editbg.xml
│ │ │ │ └── ic_launcher_background.xml
│ │ │ │ ├── layout
│ │ │ │ ├── activity_item_left.xml
│ │ │ │ ├── activity_item_right.xml
│ │ │ │ └── activity_main.xml
│ │ │ │ ├── mipmap-anydpi-v26
│ │ │ │ ├── ic_launcher.xml
│ │ │ │ └── ic_launcher_round.xml
│ │ │ │ ├── mipmap-hdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ │ ├── mipmap-mdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ │ ├── mipmap-xhdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ │ ├── mipmap-xxhdpi
│ │ │ │ ├── glm.png
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ ├── ic_launcher_round.webp
│ │ │ │ └── me.png
│ │ │ │ ├── mipmap-xxxhdpi
│ │ │ │ ├── ic_launcher.webp
│ │ │ │ └── ic_launcher_round.webp
│ │ │ │ ├── values-night
│ │ │ │ └── themes.xml
│ │ │ │ └── values
│ │ │ │ ├── colors.xml
│ │ │ │ ├── strings.xml
│ │ │ │ └── themes.xml
│ │ │ └── test
│ │ │ └── java
│ │ │ └── com
│ │ │ └── doujiao
│ │ │ └── xiaozhihuiassistant
│ │ │ └── ExampleUnitTest.java
│ │ ├── build.gradle
│ │ ├── gradle.properties
│ │ ├── gradle
│ │ └── wrapper
│ │ │ ├── gradle-wrapper.jar
│ │ │ └── gradle-wrapper.properties
│ │ ├── gradlew
│ │ ├── gradlew.bat
│ │ └── settings.gradle
├── FastllmStudio
│ └── cli
│ │ ├── cli.cpp
│ │ ├── ui.cpp
│ │ └── ui.h
├── Qui
│ ├── FastLLM.cpp
│ ├── bin
│ │ ├── Qt5Core.dll
│ │ ├── Qt5Gui.dll
│ │ ├── Qt5Widgets.dll
│ │ ├── Qui.exe
│ │ ├── fastllm_cpu.exe
│ │ ├── fastllm_cuda.exe
│ │ ├── path.txt
│ │ ├── platforms
│ │ │ └── qwindows.dll
│ │ ├── qui_cn.qm
│ │ └── styles
│ │ │ └── qwindowsvistastyle.dll
│ └── src
│ │ ├── Qui.cpp
│ │ ├── Qui.h
│ │ ├── Qui.pro
│ │ ├── Qui.ui
│ │ ├── main.cpp
│ │ └── qui_cn.ts
├── README.md
├── Win32Demo
│ ├── StringUtils.h
│ ├── Win32Demo.cpp
│ ├── Win32Demo.sln
│ ├── Win32Demo.vcxproj
│ ├── bin
│ │ └── web
│ │ │ ├── css
│ │ │ ├── github-markdown-light.min.css
│ │ │ ├── github.min.css
│ │ │ ├── katex.min.css
│ │ │ └── texmath.css
│ │ │ ├── index.html
│ │ │ └── js
│ │ │ ├── highlight.min.js
│ │ │ ├── katex.min.js
│ │ │ ├── markdown-it-link-attributes.min.js
│ │ │ ├── markdown-it.min.js
│ │ │ └── texmath.js
│ ├── fastllm-gpu.vcxproj
│ ├── fastllm-gpu.vcxproj.filters
│ ├── fastllm.vcxproj
│ ├── fastllm.vcxproj.filters
│ └── httplib.h
├── apiserver
│ └── apiserver.cpp
├── benchmark
│ ├── benchmark.cpp
│ └── prompts
│ │ ├── beijing.txt
│ │ └── hello.txt
├── openai_server
│ ├── README.md
│ ├── fastllm_completion.py
│ ├── openai_api_server.py
│ ├── protocal
│ │ ├── __init__.py
│ │ └── openai_protocol.py
│ └── requirements.txt
├── python
│ ├── custom_model.py
│ └── qwen2.py
└── webui
│ ├── httplib.h
│ ├── web
│ ├── css
│ │ ├── github-markdown-light.min.css
│ │ ├── github.min.css
│ │ ├── katex.min.css
│ │ └── texmath.css
│ ├── index.html
│ └── js
│ │ ├── highlight.min.js
│ │ ├── katex.min.js
│ │ ├── markdown-it-link-attributes.min.js
│ │ ├── markdown-it.min.js
│ │ └── texmath.js
│ └── webui.cpp
├── include
├── device.h
├── devices
│ ├── cpu
│ │ ├── alivethreadpool.h
│ │ ├── computeutils.h
│ │ ├── cpudevice.h
│ │ └── cputhreadpool.h
│ ├── cuda
│ │ ├── cudadevice.h
│ │ ├── fastllm-cuda.cuh
│ │ └── fastllm-hip.h
│ ├── multicuda
│ │ ├── fastllm-multicuda.cuh
│ │ └── multicudadevice.h
│ ├── numa
│ │ ├── computeserver.h
│ │ ├── fastllm-numa.h
│ │ ├── kvcache.h
│ │ └── numadevice.h
│ ├── tfacc
│ │ ├── fastllm-tfacc.h
│ │ └── tfaccdevice.h
│ └── tops
│ │ └── topsdevice.h
├── executor.h
├── fastllm.h
├── graph.h
├── model.h
├── models
│ ├── basellm.h
│ ├── bert.h
│ ├── chatglm.h
│ ├── cogvlm.h
│ ├── deepseekv2.h
│ ├── factoryllm.h
│ ├── glm.h
│ ├── graphllm.h
│ ├── internlm2.h
│ ├── llama.h
│ ├── minicpm.h
│ ├── minicpm3.h
│ ├── moe.h
│ ├── moss.h
│ ├── phi3.h
│ ├── qwen.h
│ ├── qwen3.h
│ ├── qwen3_moe.h
│ └── xlmroberta.h
├── template.h
└── utils
│ ├── armMath.h
│ ├── avxMath.h
│ └── utils.h
├── install.sh
├── main.cpp
├── make_whl.sh
├── make_whl_rocm.sh
├── pyfastllm
├── README.md
├── examples
│ ├── cli_low_level.py
│ ├── cli_simple.py
│ ├── convert_model.py
│ ├── test_chatglm2.py
│ ├── test_chatglm2_cpp.py
│ ├── test_chatglm2_func.py
│ ├── test_ops.py
│ ├── web_api.py
│ └── web_api_client.py
├── fastllm
│ ├── __init__.py
│ ├── convert.py
│ ├── functions
│ │ ├── __init__.py
│ │ ├── custom_ops.py
│ │ ├── fastllm_ops.py
│ │ ├── numpy_ops.py
│ │ └── util.py
│ ├── hub
│ │ ├── __init__.py
│ │ └── chatglm2.py
│ ├── models.py
│ ├── nn
│ │ ├── __init__.py
│ │ ├── base_module.py
│ │ └── modules.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── converter.py
│ │ ├── quantizer.py
│ │ └── writer.py
├── install.sh
└── setup.py
├── requirements-server.txt
├── simple_install.sh
├── src
├── device.cpp
├── devices
│ ├── cpu
│ │ ├── avx512bf16.cpp
│ │ ├── avx512vnni.cpp
│ │ ├── cpudevice.cpp
│ │ ├── cpudevicebatch.cpp
│ │ └── linear.cpp
│ ├── cuda
│ │ ├── cudadevice.cpp
│ │ ├── cudadevicebatch.cpp
│ │ └── fastllm-cuda.cu
│ ├── multicuda
│ │ ├── fastllm-multicuda.cu
│ │ └── multicudadevice.cpp
│ ├── numa
│ │ ├── computeserver.cpp
│ │ ├── fastllm-numa.cpp
│ │ ├── kvcache.cpp
│ │ └── numadevice.cpp
│ ├── tfacc
│ │ ├── fastllm-tfacc.cpp
│ │ └── tfaccdevice.cpp
│ └── tops
│ │ └── topsdevice.cpp
├── executor.cpp
├── fastllm.cpp
├── graph.cpp
├── model.cpp
├── models
│ ├── basellm.cpp
│ ├── bert.cpp
│ ├── chatglm.cpp
│ ├── cogvlm.cpp
│ ├── deepseekv2.cpp
│ ├── glm.cpp
│ ├── graph
│ │ ├── fastllmjson.cpp
│ │ ├── gemma2.cpp
│ │ ├── minicpm3.cpp
│ │ ├── phi3.cpp
│ │ ├── qwen2.cpp
│ │ └── telechat.cpp
│ ├── graphllm.cpp
│ ├── internlm2.cpp
│ ├── llama.cpp
│ ├── minicpm.cpp
│ ├── minicpm3.cpp
│ ├── moe.cpp
│ ├── moss.cpp
│ ├── phi3.cpp
│ ├── qwen.cpp
│ ├── qwen3.cpp
│ ├── qwen3_moe.cpp
│ └── xlmroberta.cpp
├── pybinding.cpp
└── template.cpp
├── test
├── basic
│ ├── config.py
│ ├── forward_check.py
│ └── tokenizer_check.py
├── cmmlu
│ ├── README.md
│ ├── baichuan.py
│ ├── categories.py
│ ├── chatglm.py
│ ├── eval.py
│ ├── minicpm3.py
│ ├── qwen.py
│ └── qwq.py
└── ops
│ ├── cppOps.cpp
│ └── tokenizerTest.cpp
├── third_party
├── hipify_torch
│ ├── LICENSE.txt
│ ├── README.md
│ ├── cmake
│ │ └── Hipify.cmake
│ ├── hipify_cli.py
│ ├── hipify_torch
│ │ ├── __init__.py
│ │ ├── constants.py
│ │ ├── cuda_to_hip_mappings.py
│ │ ├── hipify_python.py
│ │ └── version.py
│ ├── setup.py
│ ├── test
│ │ └── test_installation.py
│ └── tools
│ │ └── replace_cuda_with_hip_files.py
├── json11
│ ├── json11.cpp
│ └── json11.hpp
└── tfacc
│ ├── driver
│ └── tfacc2
│ │ ├── Makefile
│ │ ├── build_driver.sh
│ │ ├── modules.order
│ │ ├── tfacc2.c
│ │ └── tfacc2.h
│ ├── launch.py
│ ├── pull.sh
│ └── server
├── tools
├── fastllm_pytools
│ ├── __init__.py
│ ├── chat.py
│ ├── cli.py
│ ├── download.py
│ ├── export.py
│ ├── hf_model.py
│ ├── llm.py
│ ├── openai_server
│ │ ├── fastllm_completion.py
│ │ ├── fastllm_embed.py
│ │ ├── fastllm_model.py
│ │ ├── fastllm_reranker.py
│ │ └── protocal
│ │ │ └── openai_protocol.py
│ ├── server.py
│ ├── torch2flm.py
│ ├── ui.py
│ ├── util.py
│ ├── web_demo.py
│ └── webui.py
├── scripts
│ ├── alpaca2flm.py
│ ├── baichuan2_2flm.py
│ ├── baichuan2flm.py
│ ├── bert2flm.py
│ ├── chatglm_export.py
│ ├── cli_demo.py
│ ├── glm_export.py
│ ├── llama3_to_flm.py
│ ├── llamalike2flm.py
│ ├── minicpm2flm.py
│ ├── moss_export.py
│ ├── qwen2flm.py
│ ├── setup.py
│ ├── setup_rocm.py
│ └── web_demo.py
└── src
│ ├── pytools.cpp
│ ├── pytools_t2s.cpp
│ └── quant.cpp
├── whl_docker
└── Dockerfile
└── whl_docker_rocm
├── 24.04
└── Dockerfile
└── Dockerfile
/.dockerignore:
--------------------------------------------------------------------------------
1 | ./models
2 | ./build/
--------------------------------------------------------------------------------
/.github/workflows/Build.yml:
--------------------------------------------------------------------------------
1 | name: Action Build
2 | on: [push]
3 |
4 | jobs:
5 | build:
6 | runs-on: ubuntu-latest
7 |
8 | steps:
9 | - name: Set up JDK
10 | uses: actions/setup-java@v1
11 | with:
12 | java-version: '11'
13 |
14 | - name: Checkout code
15 | uses: actions/checkout@v2
16 |
17 | #- name: Build with arm64-v8a
18 | # run: |
19 | # wget -q https://dl.google.com/android/repository/android-ndk-r22b-linux-x86_64.zip
20 | # unzip android-ndk-r22b-linux-x86_64.zip
21 | # export NDK=$GITHUB_WORKSPACE/android-ndk-r22b
22 | # mkdir build-android
23 | # cd build-android
24 | #ls ${NDK}/build/cmake/android.toolchain.cmake
25 | # cmake -DCMAKE_MAKE_PROGRAM=/usr/bin/make -DCMAKE_CXX_COMPILER=/usr/bin/g++ -DCMAKE_TOOLCHAIN_FILE=${NDK}/build/cmake/android.toolchain.cmake -DANDROID_ABI=arm64-v8a -DANDROID_PLATFORM=android-23 -DCMAKE_CXX_FLAGS=-march=armv8.2a+dotprod ..
26 | # make -j -B
27 | # cp main fastllm-main-android
28 |
29 | - name: Build with x86
30 | run: |
31 | mkdir build-x86
32 | cd build-x86
33 | cmake .. -DUSE_CUDA=OFF
34 | make -j $(nproc)
35 | cp main fastllm-main-x86_64
36 |
37 | - name: Export and Upload Artifact
38 | uses: actions/upload-artifact@v4
39 | with:
40 | name: Output
41 | path: |
42 | build-android/fastllm-main-android
43 | build-x86/fastllm-main-x86_64
44 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.log
2 | *.pyc
3 | token
4 | /cmake-build-debug/
5 | /build*
6 | /pyfastllm/build/
7 | /pyfastllm/dist/
8 | /.idea/
9 | /.vscode/
10 | /example/Win32Demo/bin/*.*
11 | /example/Win32Demo/Win32
12 | /example/Win32Demo/x64
13 | /example/Win32Demo/*.filters
14 | /example/Win32Demo/*.user
15 | /example/Win32Demo/.vs
16 | /example/Android/LLMAssistant/*.iml
17 | /example/Android/LLMAssistant/.gradle
18 | /example/Android/LLMAssistant/local.properties
19 | /example/Android/LLMAssistant/.idea/caches
20 | /example/Android/LLMAssistant/.idea/libraries
21 | /example/Android/LLMAssistant/.idea/modules.xml
22 | /example/Android/LLMAssistant/.idea/workspace.xml
23 | /example/Android/LLMAssistant/.idea/navEditor.xml
24 | /example/Android/LLMAssistant/.idea/assetWizardSettings.xml
25 | /example/Android/LLMAssistant/.DS_Store
26 | /example/Android/LLMAssistant/build
27 | /example/Android/LLMAssistant/captures
28 | /example/Android/LLMAssistant/.externalNativeBuild
29 | /example/Android/LLMAssistant/.cxx
30 | /example/Android/LLMAssistant/local.properties
31 | /test/cmmlu/results/
32 | /models/
33 | /localtest/
34 | /third_party/tfacc/driver/tfacc2/result
35 | /.chainlit
36 | /.files
37 | /src/devices/hip
38 | /src/devices/multihip
39 | /test/mmlu
40 | *.o
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/pybind11"]
2 | path = third_party/pybind11
3 | url = https://github.com/pybind/pybind11.git
4 | branch = v2.10.5
5 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # syntax=docker/dockerfile:1-labs
2 | FROM nvidia/cuda:12.1.0-devel-ubuntu22.04
3 |
4 | # Update Apt repositories
5 | RUN apt-get update
6 |
7 | # Install and configure Python
8 | RUN apt-get -y --no-install-recommends install wget build-essential python3.10 python3-pip
9 | RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.10 1
10 | RUN pip install setuptools streamlit-chat
11 |
12 | ENV WORKDIR /fastllm
13 |
14 | # Install cmake
15 | RUN wget -c https://cmake.org/files/LatestRelease/cmake-3.28.3-linux-x86_64.sh && bash ./cmake-3.28.3-linux-x86_64.sh --skip-license --prefix=/usr/
16 |
17 | WORKDIR $WORKDIR
18 | ADD . $WORKDIR/
19 |
20 | RUN mkdir $WORKDIR/build && cd build && cmake .. -DUSE_CUDA=ON -DCMAKE_CUDA_ARCHITECTURES=native && make -j && cd tools && python setup.py install
21 |
22 | CMD /fastllm/build/webui -p /models/chatglm2-6b-int8.flm
23 |
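# A typical build-and-run invocation (a hedged sketch: the host ./models directory
# is an assumption, and container port 8081 follows docker-compose.yaml):
#   docker build -t fastllm .
#   docker run --gpus all -p 8081:8081 -v $(pwd)/models:/models fastllm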
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | version: '3.8'
2 | services:
3 | fastllm:
4 | build:
5 | context: .
6 | args:
7 | DOCKER_BUILDKIT: 0
8 | # privileged: true
9 | platforms:
10 | - "linux/amd64"
11 | tags:
12 | - "fastllm:v0.9"
13 | restart: always
14 | ports:
15 | - 11234:8081
16 | volumes:
17 | - ./models/:/models/
18 | command: /fastllm/build/webui -p /models/chatglm2-6b-int8.flm -w ./example/webui/web
19 |
20 |
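# Hedged usage note: assuming a local ./models directory that contains
# chatglm2-6b-int8.flm (adjust the volume and command above otherwise),
# the service can be built and started with:
#   docker compose up --build -d
# and then reached on host port 11234 (mapped to container port 8081 above).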
--------------------------------------------------------------------------------
/docs/benchmark.md:
--------------------------------------------------------------------------------
1 | ## Inference Speed
2 |
3 | The benchmark program can be used to measure inference speed; results vary somewhat with hardware configuration and input.
4 |
5 | For example:
6 |
7 | ``` sh
8 | ./benchmark -p ~/chatglm-6b-int4.flm -f ../example/benchmark/prompts/beijing.txt -b 1
9 | ./benchmark -p ~/chatglm-6b-int8.flm -f ../example/benchmark/prompts/beijing.txt -b 1
10 | ./benchmark -p ~/chatglm-6b-fp16.flm -f ../example/benchmark/prompts/hello.txt -b 512 -l 18
11 | ```
12 |
13 | | Model | Data precision | Platform | Batch | Max inference speed (token / s) |
14 | |-----------------:|---------|--------------------|-----------|---------------------:|
15 | | ChatGLM-6b-int4 | float32 | RTX 4090 | 1 | 176 |
16 | | ChatGLM-6b-int8 | float32 | RTX 4090 | 1 | 121 |
17 | | ChatGLM-6b-fp16 | float32 | RTX 4090 | 64 | 2919 |
18 | | ChatGLM-6b-fp16 | float32 | RTX 4090 | 256 | 7871 |
19 | | ChatGLM-6b-fp16 | float32 | RTX 4090 | 512 | 10209 |
20 | | ChatGLM-6b-int4 | float32 | Xiaomi 10 Pro - 4 Threads | 1 | 4 ~ 5 |
21 |
--------------------------------------------------------------------------------
/docs/custom.md:
--------------------------------------------------------------------------------
1 | ### Custom models
2 |
3 | Models that the Fastllm framework does not support out of the box can be supported by describing a custom model structure.
4 |
5 | A Python custom model only needs a single Python file describing the model structure; see the implementation in [QWEN](../example/python/qwen2.py) for reference.
6 |
7 | ### Using a Python custom model
8 |
9 | When using ftllm.chat, ftllm.webui or ftllm.server, add the --custom argument to specify the custom model file.
10 |
11 | Suppose the model lives in the `~/Qwen2-7B-Instruct/` directory and the custom model file is `~/qwen2.py`.
12 |
13 | Then the command
14 |
15 | ``` sh
16 | python3 -m ftllm.chat -t 16 -p ~/Qwen2-7B-Instruct/ --custom ~/qwen2.py
17 | ```
18 |
19 | loads the Qwen2 model through the custom model file; server and webui work the same way.
20 |
21 | ### Writing a Python custom model
22 |
23 | A custom model is defined by a model description class that inherits from ftllm.llm.ComputeGraph.
24 |
25 | This corresponds to the following code in [QWEN](../example/python/qwen2.py):
26 |
27 | ``` python
28 | from ftllm.llm import ComputeGraph
29 | class Qwen2Model(ComputeGraph):
30 | ```
31 |
32 | At the end of the file, a `__model__` variable must be defined to point to the class of the custom model structure:
33 |
34 | ``` python
35 | __model__ = Qwen2Model
36 | ```
37 |
38 | The model description class must implement a build method, which obtains the model parameters and describes the computation flow.
39 |
40 | The example code is used below to explain this.
41 |
42 | ``` python
43 | class Qwen2Model(ComputeGraph):
44 | def build(self):
45 | # 1. get weight, data and config
46 | weight, data, config = self.weight, self.data, self.config
47 |
48 | # 2. set some config values
49 | config["max_positions"] = 128000
50 |
51 | # 3. describe the computation flow
52 | head_dim = config["hidden_size"] // config["num_attention_heads"]
53 | self.Embedding(data["inputIds"], weight["model.embed_tokens.weight"], data["hiddenStates"]);
54 | # the rest of the computation flow follows; see the example code for details
55 | ```
56 |
57 | #### `self.config`
58 |
59 | 模型配置,默认会从模型文件夹下的 `config.json` 文件中读取
60 |
61 | build方法中可以修改config中的参数,例如改动 `max_positions` 可以修改上下文长度
62 |
63 | 有一些模型的 `config.json` 中使用的变量名不一致,需要在build过程中手动为config赋值。
64 |
65 | 例如在TeleChat7B模型的配置中没有 `max_positions` 变量,而是用 `seq_length` 变量代表长度,那么在build方法中需要用如下代码赋值:
66 |
67 | ``` python
68 | self.config["max_positions"] = self.config["seq_length"]
69 | ```
70 |
71 | config中,有以下变量必须要赋值(如果config.json中变量名一致,可以不处理):
72 |
73 | ``` python
74 | self.config["max_positions"] #代表最长上下文长度
75 | ```
76 |
77 | #### `self.weight`
78 |
79 | 代表权重数据
80 |
81 | `self.weight[weightName]` 代表模型文件中名为weightName的参数(对应HF模型文件夹中.safetensors文件中的参数名)
82 |
83 | #### ```self.data```
84 |
85 | 代表计算流程的中间变量和输入变量
86 |
87 | `self.data[dataName]` 代表名为dataName的中间变量,`dataName` 可以使用除以下输入变量名之外的任意字符串
88 |
89 | 输入变量:
90 |
91 | ``` python
92 | data["inputIds"] # input tokens
93 | data["positionIds"] # position information
94 | data["attentionMask"] # attention mask
95 | data["sin"] # sin values for rotary embedding
96 | data["cos"] # cos values for rotary embedding
97 | data["atype"] # data type used during inference
98 | data["pastKey."][i] # key cache of block i
99 | data["pastValue."][i] # value cache of block i
100 | ```
101 |
102 | #### Computation flow and operators
103 |
104 | The computation flow is described with the operator functions provided by the ComputeGraph base class.
105 |
106 | The currently supported operators are documented in [Custom model operators](./custom_op.md); a minimal end-to-end sketch follows below.
107 |
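The following skeleton pulls the pieces above together. It is only a sketch: the class name `MyModel` is illustrative, it uses only the operators shown in this document, and a real model needs the full set of per-layer operators from the qwen2.py example and the operator reference above.

``` python
# Minimal illustrative skeleton of a custom model file (not a runnable model);
# see example/python/qwen2.py for a complete computation flow.
from ftllm.llm import ComputeGraph

class MyModel(ComputeGraph):
    def build(self):
        # grab the weights, intermediate data and config described above
        weight, data, config = self.weight, self.data, self.config

        # make sure the mandatory config entry exists
        config["max_positions"] = 128000

        # describe the computation flow, starting from the token embedding
        self.Embedding(data["inputIds"], weight["model.embed_tokens.weight"], data["hiddenStates"])
        # ... per-layer attention / MLP operators and the output head go here ...

# required: tell ftllm which class describes the model structure
__model__ = MyModel
```
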
108 | ### C++ custom models
109 |
110 | (The C++ custom-model interface is still being revised...)
111 |
--------------------------------------------------------------------------------
/docs/demo_arguments.md:
--------------------------------------------------------------------------------
1 | # Fastllm Python Demo Argument Reference
2 |
3 | ## Common arguments
4 |
5 | Model-related options; they apply to the OpenAI API Server, the WebUI and the chat demo.
6 |
7 | - **Model path (`-p, --path`)**: Path to the model, either a fastllm model file or a Hugging Face model directory. For example:
8 | ```bash
9 | --path ~/Qwen2-7B-Instruct/ # load the model from ~/Qwen2-7B-Instruct/; this must be a standard Hugging Face-format model downloaded from HuggingFace, ModelScope or another site; AWQ, GPTQ and similar formats are not supported yet
10 | --path ~/model.flm # load the model from ~/model.flm, a model file in Fastllm format
11 | ```
12 | - **Inference type (`--atype`)**: Data type used for intermediate computation; either `float16` or `float32`
13 | - **Weight type (`--dtype`)**: Weight type of the model, used when loading a Hugging Face model. Can be `float16`, `int8`, `int4`, or `int4g` (grouped int4 quantization), for example:
14 | ```bash
15 | --dtype float16 # float16 weights (no quantization)
16 | --dtype int8 # quantize online to int8 weights
17 | --dtype int4g128 # quantize online to grouped int4 weights (128 weights per group)
18 | --dtype int4g256 # quantize online to grouped int4 weights (256 weights per group)
19 | --dtype int4 # quantize online to int4 weights
20 | ```
21 | - **Device (`--device`)**: Device used by the server; `cpu`, `cuda`, or any additionally compiled device type
22 | - **CUDA Embedding (`--cuda_embedding`)**: With this flag and `--device cuda`, the embedding operation also runs on the CUDA device; this is slightly faster but uses more GPU memory, so it is recommended only when GPU memory is plentiful
23 | - **Maximum KV cache usage (`--kv_cache_limit`)**: Maximum amount of memory the KV cache may use. If this argument is omitted or set to `auto`, the framework handles it automatically. Manual examples:
24 | ```bash
25 | --kv_cache_limit 5G # limit to 5G
26 | --kv_cache_limit 100M # limit to 100M
27 | --kv_cache_limit 168K # limit to 168K
28 | ```
29 | - **Maximum batch size (`--max_batch`)**: Number of requests processed concurrently. If omitted, the framework handles it automatically
30 | - **Thread count (`-t, --threads`)**: Number of CPU threads; it strongly affects speed with `--device cpu`, while with `cuda` it mainly affects model loading speed
31 | - **Custom model description file (`--custom`)**: Python file describing a custom model; see [Custom models](custom.md) for details. A combined example of the arguments above follows below.
32 |
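As an illustration, the common arguments above can be combined into one launch command. This is a sketch only: the model path is a placeholder, the values are examples, and `ftllm.server` is invoked in the same module style shown for `ftllm.chat` in [custom.md](custom.md).

```bash
# placeholder model path; replace with your own checkpoint
python3 -m ftllm.server --path ~/Qwen2-7B-Instruct/ \
    --dtype int4g128 --atype float16 \
    --device cuda --cuda_embedding \
    --kv_cache_limit 5G --max_batch 32 -t 16
```
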
33 | ## OpenAI API Server arguments
34 | - **Model name (`--model_name`)**: Name of the deployed model; API calls are validated against this name
35 | - **API server host (`--host`)**: Host address of the API server
36 | - **API server port (`--port`)**: Port of the API server
37 |
38 |
39 | ## Web UI arguments
40 | - **Port (`--port`)**: Port of the WebUI
41 | - **Page title (`--title`)**: Page title of the WebUI
--------------------------------------------------------------------------------
/docs/ftllm.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/docs/ftllm.md
--------------------------------------------------------------------------------
/docs/mixforward.md:
--------------------------------------------------------------------------------
1 | # Mixed Inference Guide
2 |
3 | This document uses the `DeepSeek-V3-0324-INT4` model as an example to show how to use mixed inference to squeeze the most out of your hardware.
4 |
5 | ## Basic usage
6 |
7 | Suppose we deploy the `DeepSeek-V3-0324-INT4` model on a machine with two 48GB GPUs; the usual invocation looks like this:
8 |
9 |
10 | ```
11 | ftllm server fastllm/DeepSeek-V3-0324-INT4
12 | ```
13 |
14 | By default this runs the MoE part of the model on the CPU and the non-MoE part on CUDA, equivalent to the following command:
15 |
16 | ```
17 | ftllm server fastllm/DeepSeek-V3-0324-INT4 --device cuda --moe_device cpu
18 | ```
19 |
20 | (Note: the optimizations below currently only apply to mixed `cuda` + `cpu` inference; `numa` cannot use these features.)
21 |
22 | ## Running some MoE layers on a single GPU
23 |
24 | When running with the command above, plenty of GPU memory is left unused, so we can set `moe_device` to run a portion of the MoE layers
25 | on CUDA instead:
26 |
27 | ```
28 | ftllm server fastllm/DeepSeek-V3-0324-INT4 --device cuda --moe_device "{'cuda':1,'cpu':19}"
29 | ```
30 |
31 | Here `moe_device` is set to `"{'cuda':1,'cpu':19}"`, meaning `1/20` of the MoE layers run on CUDA and `19/20` run on the CPU.
32 |
33 | This slightly improves decode speed, but may reduce the usable context length.
34 |
35 | ## Running some MoE layers on multiple GPUs
36 |
37 | The following command uses multiple GPUs to accelerate a portion of the MoE layers:
38 |
39 | ```
40 | ftllm server fastllm/DeepSeek-V3-0324-INT4 --device cuda --moe_device "{'multicuda:0,1':15,'cpu':85}"
41 | ```
42 |
43 | Here `moe_device` is set to `"{'multicuda:0,1':15,'cpu':85}"`, meaning `15/100` of the MoE layers run with tensor parallelism across GPUs 0 and 1, and `85/100` run on the CPU.
44 |
45 | This improves decode speed further.
46 |
47 | (It is reasonable to stop reading here, but keep going if you want to learn some fancier tricks.)
48 |
49 | ## Running some MoE layers with hybrid tensor parallelism
50 |
51 | The following command uses hybrid tensor parallelism to accelerate a portion of the MoE layers:
52 |
53 | ```
54 | ftllm server fastllm/DeepSeek-V3-0324-INT4 --device cuda --moe_device "{'multicuda:0:3,1:3,cpu:2':15,'cpu':85}"
55 | ```
56 |
57 | Here `moe_device` is set to `"{'multicuda:0:3,1:3,cpu:2':15,'cpu':85}"`, which means:
58 | - `15/100` of the MoE layers use hybrid tensor parallelism: both GPUs and the CPU work simultaneously, with `3/8` of the compute on GPU 0, `3/8` on GPU 1, and `2/8` on the CPU
59 | - `85/100` of the MoE layers run on the CPU
60 |
61 | In theory this should improve decode speed even further, but the current implementation is not very efficient and is actually slower than the previous step; further optimization is planned.
62 |
63 |
--------------------------------------------------------------------------------
/docs/qwen3.md:
--------------------------------------------------------------------------------
1 | ## About Qwen3
2 |
3 | Qwen3 is a family of models released by Alibaba.
4 |
5 | ### Installing Fastllm
6 |
7 | - PIP install
8 |
9 | On Linux you can try installing directly via pip:
10 | ```
11 | pip install ftllm -U
12 | ```
13 | If installation fails, refer to [building from source](../README.md#安装).
14 |
15 | ### Running examples
16 |
17 | #### Command-line chat:
18 |
19 | ```
20 | ftllm run fastllm/Qwen3-235B-A22B-INT4MIX
21 | ftllm run Qwen/Qwen3-30B-A3B
22 | ```
23 |
24 | #### webui:
25 |
26 | ```
27 | ftllm webui fastllm/Qwen3-235B-A22B-INT4MIX
28 | ftllm webui Qwen/Qwen3-30B-A3B
29 | ```
30 |
31 | #### api server (OpenAI style):
32 |
33 | ```
34 | ftllm server fastllm/Qwen3-235B-A22B-INT4MIX
35 | ftllm server Qwen/Qwen3-30B-A3B
36 | ```
37 |
38 | #### Recommended arguments
39 |
40 | If needed, the following arguments can be added to the run command:
41 |
42 | - Thinking mode: a mode specific to Qwen3, enabled by default. It can be turned off with the enable_thinking argument, in which case the model produces no thinking output. For example:
43 |
44 | ```bash
45 | ftllm server Qwen/Qwen3-30B-A3B --enable_thinking false
46 | ```
47 |
48 | - Inference device: non-MoE models use the GPU by default. If GPU memory is insufficient and pure CPU inference is preferred, set `--device cpu`, or use `--device numa` for multi-node NUMA acceleration
49 | - Quantization: for the Qwen3 series, `--dtype int4g256` is currently recommended for 4-bit quantization and `--dtype int8` for 8-bit quantization
50 |
51 |
52 | - MoE models (Qwen3-30B-A3B, Qwen3-235B-A22B) use mixed CPU+GPU inference by default; to use pure CUDA inference, specify the device argument, for example
53 | ``` bash
54 | ftllm server Qwen/Qwen3-30B-A3B --device cuda --dtype int4g256
55 | ftllm server Qwen/Qwen3-30B-A3B --device cuda --dtype int8
56 | ```
57 |
58 | - More argument details are listed in [Common arguments](../README.md#常用参数)
59 |
60 | #### NUMA acceleration
61 |
62 | To use a single NUMA node, it is recommended to pin the process to that node with numactl, as sketched below.
63 |
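A hedged example of such pinning (node 0 here is just an example; adjust it to your machine and make sure numactl is installed):

```
numactl --cpunodebind=0 --membind=0 ftllm run Qwen/Qwen3-30B-A3B
```
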
64 | Multi-node NUMA acceleration can be enabled with an environment variable (the PIP build supports it directly; for source builds, compile with the -DUSE_NUMA=ON option):
65 |
66 | ```
67 | export FASTLLM_USE_NUMA=ON
68 | # export FASTLLM_NUMA_THREADS=27 # optional; sets the number of threads started per NUMA node
69 | ```
70 |
71 | #### Local models
72 |
73 | A locally downloaded model can also be launched (original models, AWQ models and FASTLLM models are supported; GGUF models are not yet supported). Assuming the local model path is `/mnt/Qwen/Qwen3-30B-A3B`,
74 | it can be started with the following command (webui and server work the same way):
75 |
76 | ```
77 | ftllm run /mnt/Qwen/Qwen3-30B-A3B
78 | ```
79 |
80 | ### Downloading models
81 |
82 | The following command downloads a model locally:
83 |
84 | ```
85 | ftllm download Qwen/Qwen3-30B-A3B
86 | ```
87 |
--------------------------------------------------------------------------------
/docs/rocm.md:
--------------------------------------------------------------------------------
1 | # Building with ROCm
2 |
3 | ## 0. Supported platforms
4 |
5 | ROCm builds are currently supported on Linux only.
6 |
7 | Currently supported GPUs:
8 |
9 | - AMD Radeon Instinct MI series, e.g. MI50, MI100, MI210
10 | - AMD Radeon RDNA RX 7000 series gaming and workstation cards, e.g. W7800, W7900
11 | - Hygon GPUs such as the K100 (not verified, but should work in theory)
12 |
13 | ## 1. Install ROCm and find the ROCm arch
14 |
15 | Install ROCm as described in the [ROCm official documentation](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/).
16 |
17 | The GPU's ROCm arch can be found in the LLVM target column of the [architecture list](https://rocm.docs.amd.com/en/latest/reference/gpu-arch-specs.html).
18 |
19 | Architectures of common GPUs:
20 | | Arch | Family | Example products | Recommended ROCm version |
21 | |----------|-----------|---------------------------------------------|----------------|
22 | | gfx900 | GCN5.0 | Radeon Instinct MI25 | ❌ not supported |
23 | | gfx906 | GCN5.1 | Radeon VII, Instinct MI50 | [6.3.3](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.3.3/install/quick-start.html) |
24 | | gfx908 | CDNA | Radeon Instinct MI100 | [6.4.0](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.0/install/quick-start.html) |
25 | | gfx90a | CDNA2 | Radeon Instinct MI210/MI250/MI250X | [6.4.0](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.0/install/quick-start.html) |
26 | | gfx942 | CDNA3 | Instinct MI300A/MI300X/MI325X | [6.4.0](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.0/install/quick-start.html) |
27 | | gfx1030 | RDNA2 | Radeon PRO W6800/V620 | [6.4.0](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.0/install/quick-start.html) |
28 | | gfx1100 | RDNA3 | Radeon PRO W7800/W7900, RX 7900 XT/XTX/GRE | [6.4.0](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.0/install/quick-start.html) |
29 | | gfx1101 | RDNA3 | Radeon PRO V710 | [6.4.0](https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.0/install/quick-start.html) |
30 |
31 |
32 |
33 | Join the GPU architectures you want to build for with `;` and pass them in the `-DROCM_ARCH` parameter. The default is `gfx908;gfx90a;gfx1100`.
34 |
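On a machine where ROCm is already installed, one quick way to confirm the arch of the local GPU (assuming the ROCm tools are on the PATH) is:

``` sh
rocminfo | grep gfx    # prints lines such as "Name: gfx90a"
```
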
35 | Note: some GPUs (e.g. the RX 6000 series and the MI50) do not support the `rocwmma` matrix-multiplication acceleration; if any GPU in the list lacks `rocwmma` support, the build will not use `rocwmma`.
36 |
37 | ## 2. Build
38 |
39 | ``` sh
40 | bash install.sh -DUSE_ROCM=ON -DROCM_ARCH="gfx908;gfx90a;gfx1100"
41 | ```
42 |
43 | ## TODO
44 |
45 | - [ ] Verification on Hygon GPUs
46 | - [ ] `rocwmma` support for matrix-multiplication acceleration
47 |
48 | ## Acknowledgements
49 |
50 | [leavelet](https://github.com/leavelet) contributed the ROCm support
51 |
--------------------------------------------------------------------------------
/docs/tfacc.md:
--------------------------------------------------------------------------------
1 | ## About TFACC
2 |
3 | TFACC is the AI compute platform of ThinkForce's 7000-series processors and can be used to accelerate large-model inference on TF 7000-series processors.
4 |
5 | ## Quick start
6 |
7 | ### Load the driver
8 |
9 | ``` sh
10 | cd fastllm/third_party/tfacc/driver/tfacc2
11 | ./build_driver.sh
12 | modprobe tfacc2
13 | ```
14 |
15 | ### Start the TFACC compute service
16 |
17 | ``` sh
18 | cd fastllm/third_party/tfacc
19 | python3 ./launch.py 4 & # the argument is the number of NUMA nodes; set it according to the specific 7000-series server model
20 | ```
21 |
22 | ### Build
23 |
24 | Building with cmake is recommended; a C++ compiler, make and cmake must be installed beforehand.
25 |
26 | gcc 9.4 or newer and cmake 3.23 or newer are recommended.
27 |
28 | Build with the following command:
29 |
30 | ``` sh
31 | bash install.sh -DUSE_TFACC=ON
32 | ```
33 |
34 | ### Run the demo programs
35 |
36 | We assume a model named `model.flm` has already been obtained (see [obtaining models](#模型获取); for first-time use you can start with a pre-converted downloadable model).
37 |
38 | After building, the following demos are available in the build directory:
39 |
40 | ``` sh
41 | # run these from the fastllm/build directory
42 |
43 | # command-line chat program with typewriter-style streaming output (Linux only)
44 | ./main -p model.flm
45 |
46 | # simple webui with streaming output + dynamic batching, supports multiple concurrent sessions
47 | ./webui -p model.flm --port 1234
48 |
49 | # Python command-line chat program, demonstrating model creation and streaming chat
50 | python tools/cli_demo.py -p model.flm
51 |
52 | # simple Python webui; requires streamlit-chat to be installed first
53 | streamlit run tools/web_demo.py model.flm
54 |
55 | ```
56 |
57 | For more features and interfaces, see the [full documentation](../README.md)
/docs/version.md:
--------------------------------------------------------------------------------
1 | ## V0.1.2.0
2 |
3 | - Standardized the version number as a.b.c.d
4 | - a is a reserved digit, currently 0
5 | - b is the major version
6 | - c is the minor version
7 | - d is the bug-fix release number
8 |
9 | ## V0.0.1.2
10 |
11 | - Optimized NUMA acceleration
12 | - Slightly improved prefill and decode speed
13 | - Added hybrid tensor parallelism for MoE, see the [mixed inference guide](mixforward.md)
14 | - Fixed some multicuda bugs; hybrid tensor parallelism now supports all precisions
15 | - Fixed some bugs in the C++ Jinja template engine; the built-in chat templates of Qwen3, DS and other models are now supported
16 |
17 | ## V0.0.1.1
18 |
19 | - Added `FP8_E4M3` precision (works on both new and old hardware)
20 | - MoE models can set mixed precision with `--moe_dtype`
21 | - pip installation now works in `ROCM` environments
22 | - Fixed some bugs in the C++ Jinja template engine
23 | - The api server's default output token limit was raised from 8K to 32K
24 |
25 | ## V0.0.1.0
26 |
27 | - Added support for the Qwen3 models, see the [deployment guide](qwen3.md)
28 | - Reduced GPU memory usage of the DeepSeek models
29 | - Added the `--cache_fast` argument to control whether the GPU-memory cache is used
30 |
31 | ## V0.0.0.9
32 |
33 | - Improved the multi-turn conversation cache when using DeepSeek models
34 | - Slightly improved multi-request throughput for DeepSeek models
35 | - Reduced GPU memory usage during DeepSeek prefill, allowing longer contexts
36 | - Added INT8 quantization for DeepSeek models (`--dtype int8` when using the original model, or `--dtype int8` when exporting)
37 | - Hid the "None of PyTorch, TensorFlow >= 2.0 ..." warning
38 | - Added the `--cache_dir` argument to specify the cache directory
39 | - server: added the `--hide_input` argument to hide request contents in the log
40 | - webui: added the `--max_token` argument to set the maximum output, and the --think argument to force thinking
41 |
42 | ## V0.0.0.8
43 |
44 | - api server: added the api_key argument to set an API key
45 | - api server: supports some composite inputs
46 | - Improved MoE model prefill speed
47 | - Added the --version argument to show the version number
48 |
49 | ## V0.0.0.7
50 |
51 | - Added the config option; a model can be launched from a config.json file
52 | - Improved MoE model speed
53 |
54 | ## V0.0.0.6
55 |
56 | - Lowered the required GLIBC version; the PIP package is compatible with more systems
57 | - The PIP package supports more architectures (currently down to SM_52)
58 |
59 | ## V0.0.0.5
60 |
61 | - Updated the documentation with notes on cases where pip installation does not work
62 | - Chat mode automatically reads the model's generation config file
63 | - Fixed kv_cache_limit miscalculation in some cases
64 |
65 | ## V0.0.0.4
66 |
67 | - Added the ftllm run, chat, webui, server interfaces
--------------------------------------------------------------------------------
/docs/wechat_group0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/docs/wechat_group0.jpg
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.gitignore:
--------------------------------------------------------------------------------
1 | *.iml
2 | .gradle
3 | /local.properties
4 | /.idea/caches
5 | /.idea/libraries
6 | /.idea/modules.xml
7 | /.idea/workspace.xml
8 | /.idea/navEditor.xml
9 | /.idea/assetWizardSettings.xml
10 | .DS_Store
11 | /build
12 | /captures
13 | .externalNativeBuild
14 | .cxx
15 | local.properties
16 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/.name:
--------------------------------------------------------------------------------
1 | XiaoZhihuiAssistant
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/compiler.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/deploymentTargetDropDown.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/gradle.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
19 |
20 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/.gitignore:
--------------------------------------------------------------------------------
1 | /build
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/build.gradle:
--------------------------------------------------------------------------------
1 | plugins {
2 | id 'com.android.application'
3 | }
4 |
5 | android {
6 | compileSdk 30
7 |
8 | defaultConfig {
9 | applicationId "com.doujiao.xiaozhihuiassistant"
10 | minSdk 21
11 | targetSdk 26
12 | versionCode 1
13 | versionName "1.0"
14 |
15 | testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
16 | externalNativeBuild {
17 | cmake {
18 | cppFlags '-std=c++11'
19 | }
20 | }
21 | ndk {
22 | abiFilters 'arm64-v8a','armeabi-v7a','x86'
23 | }
24 | }
25 |
26 | buildTypes {
27 | release {
28 | minifyEnabled false
29 | proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
30 | }
31 | }
32 | compileOptions {
33 | sourceCompatibility JavaVersion.VERSION_1_8
34 | targetCompatibility JavaVersion.VERSION_1_8
35 | }
36 | externalNativeBuild {
37 | cmake {
38 | path file('src/main/cpp/CMakeLists.txt')
39 | version '3.18.1'
40 | }
41 | }
42 | // sourceSets {
43 | // main {
44 | // // jnilib
45 | // jniLibs.srcDirs = ['libs']
46 | // }
47 | // }
48 | splits {
49 | abi {
50 | enable true
51 | reset()
52 | include 'arm64-v8a', 'armeabi-v7a','x86'
53 | universalApk true
54 | }
55 | }
56 | buildFeatures {
57 | viewBinding true
58 | }
59 | }
60 |
61 | dependencies {
62 |
63 | implementation 'com.android.support:appcompat-v7:28.0.0'
64 | implementation 'com.android.support:recyclerview-v7:28.0.0'
65 | implementation 'com.android.support.constraint:constraint-layout:2.0.4'
66 | testImplementation 'junit:junit:4.13.2'
67 | androidTestImplementation 'com.android.support.test:runner:1.0.2'
68 | androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
69 | }
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/libs/arm64-v8a/libassistant.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/libs/arm64-v8a/libassistant.so
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/libs/armeabi-v7a/libassistant.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/libs/armeabi-v7a/libassistant.so
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/proguard-rules.pro:
--------------------------------------------------------------------------------
1 | # Add project specific ProGuard rules here.
2 | # You can control the set of applied configuration files using the
3 | # proguardFiles setting in build.gradle.
4 | #
5 | # For more details, see
6 | # http://developer.android.com/guide/developing/tools/proguard.html
7 |
8 | # If your project uses WebView with JS, uncomment the following
9 | # and specify the fully qualified class name to the JavaScript interface
10 | # class:
11 | #-keepclassmembers class fqcn.of.javascript.interface.for.webview {
12 | # public *;
13 | #}
14 |
15 | # Uncomment this to preserve the line number information for
16 | # debugging stack traces.
17 | #-keepattributes SourceFile,LineNumberTable
18 |
19 | # If you keep the line number information, uncomment this to
20 | # hide the original source file name.
21 | #-renamesourcefileattribute SourceFile
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/release/app-arm64-v8a-release-unsigned.apk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/release/app-arm64-v8a-release-unsigned.apk
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/release/app-armeabi-v7a-release-unsigned.apk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/release/app-armeabi-v7a-release-unsigned.apk
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/release/app-universal-release-unsigned.apk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/release/app-universal-release-unsigned.apk
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/release/app-x86-release-unsigned.apk:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/release/app-x86-release-unsigned.apk
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/androidTest/java/com/doujiao/xiaozhihuiassistant/ExampleInstrumentedTest.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant;
2 |
3 | import android.content.Context;
4 | import android.support.test.InstrumentationRegistry;
5 | import android.support.test.runner.AndroidJUnit4;
6 |
7 | import org.junit.Test;
8 | import org.junit.runner.RunWith;
9 |
10 | import static org.junit.Assert.*;
11 |
12 | /**
13 | * Instrumented test, which will execute on an Android device.
14 | *
15 | * @see Testing documentation
16 | */
17 | @RunWith(AndroidJUnit4.class)
18 | public class ExampleInstrumentedTest {
19 | @Test
20 | public void useAppContext() {
21 | // Context of the app under test.
22 | Context appContext = InstrumentationRegistry.getInstrumentation().getTargetContext();
23 | assertEquals("com.doujiao.xiaozhihuiassistant", appContext.getPackageName());
24 | }
25 | }
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/AndroidManifest.xml:
--------------------------------------------------------------------------------
1 |
2 |
4 |
5 |
6 |
7 |
8 |
16 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/cpp/LLMChat.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include "LLMChat.h"
6 |
7 | #include "model.h"
8 | //void(^ __nonnull RuntimeChat)(int index,const char* _Nonnull content) = NULL; // real-time callback
9 |
10 | static int modeltype = 0;
11 | static char* modelpath = NULL;
12 | static std::unique_ptr<fastllm::basellm> chatGlm = NULL;
13 | static int sRound = 0;
14 | static std::string history;
15 | static RuntimeResultMobile g_callback = NULL;
16 |
17 | std::string initGptConf(const char* modelPath,int threads) {
18 | fastllm::SetThreads(threads);
19 | LOG_Debug("@@init llmpath:%s\n",modelPath);
20 | chatGlm = fastllm::CreateLLMModelFromFile(modelPath);
21 | if(chatGlm != NULL)
22 | {
23 | std::string modelName = chatGlm->model_type;
24 | LOG_Debug("@@model name:%s\n",modelName.c_str());
25 | return modelName;
26 | }
27 | LOG_Debug("@@CreateLLMModelFromFile failed.");
28 | return "";
29 | }
30 |
31 | int chat(const char* prompt, RuntimeResultMobile chatCallback) {
32 | std::string ret = "";
33 | g_callback = chatCallback;
34 | LOG_Debug("@@init llm:type:%d,prompt:%s\n",modeltype,prompt);
35 | std::string input(prompt);
36 |
37 | if (input == "reset") {
38 | history = "";
39 | sRound = 0;
40 | g_callback(0,"Done!");
41 | g_callback(-1,"");
42 | return 0;
43 | }
44 |
45 | ret = chatGlm->Response(chatGlm->MakeInput(history, sRound, input), [](int index, const char* content) {
46 | g_callback(index,content);
47 | });
48 | history = chatGlm->MakeHistory(history, sRound, input, ret);
49 | sRound++;
50 |
51 | long len = ret.length();
52 | return len;
53 | }
54 |
55 | void uninitLLM()
56 | {
57 | }
58 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/cpp/LLMChat.h:
--------------------------------------------------------------------------------
1 | //
2 | // LLMChat.h
3 | // LLMChat
4 | //
5 | // Created by 胡其斌 on 2023/5/18.
6 | //
7 |
8 | #ifdef __cplusplus
9 | extern "C" {
10 | #endif
11 |
12 | #include <android/log.h>
13 | #define LOG_Debug(...) __android_log_print(ANDROID_LOG_DEBUG, "Assistant", __VA_ARGS__)
14 |
15 | typedef void(* RuntimeResultMobile)(int index,const char* content);
16 |
17 | std::string initGptConf(const char* modelPath,int threads);
18 | int chat(const char* prompt, RuntimeResultMobile chatCallback);
19 | void uninitLLM();
20 |
21 | #ifdef __cplusplus
22 | }
23 | #endif
24 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/core/AssistantCore.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.core;
2 |
3 | import android.support.annotation.Keep;
4 | import android.util.Log;
5 |
6 | public class AssistantCore {
7 |
8 | private static AssistantCore instance = null;
9 | private static runtimeResult mRuntimeRes = null;
10 |
11 | static {
12 | System.loadLibrary("assistant");
13 | }
14 |
15 | /* singleton instance */
16 | public static AssistantCore getInstance(){
17 | if(instance == null){
18 | instance = new AssistantCore();
19 | }
20 |
21 | return instance;
22 | }
23 |
24 | public String initLLM(String path,runtimeResult callback) {
25 | mRuntimeRes = callback;
26 | return initLLMConfig(path,8);
27 | }
28 |
29 | @Keep
30 | public void reportChat(String content,int index) {
31 | Log.d("@@@","recv:"+content+",index:"+index);
32 | if (mRuntimeRes != null) {
33 | mRuntimeRes.callbackResult(index,content);
34 | }
35 | }
36 |
37 | public interface runtimeResult {
38 | void callbackResult(int index,String content);
39 | }
40 |
41 | private native String initLLMConfig(String path,int threads);
42 | public native int chat(String prompt);
43 | public native int uninitLLM();
44 | }
45 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/ChatMessage.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant;
2 |
3 | /**
4 | * Created by chenpengfei on 2016/10/27.
5 | */
6 | public class ChatMessage {
7 |
8 | private String content;
9 |
10 | private int type;
11 |
12 | public ChatMessage(String content, int type) {
13 | this.content = content;
14 | this.type = type;
15 | }
16 |
17 | public ChatMessage(String content) {
18 | this(content, 1);
19 | }
20 |
21 |
22 | public String getContent() {
23 | return content;
24 | }
25 |
26 | public void setContent(String content) {
27 | this.content = content;
28 | }
29 |
30 | public int getType() {
31 | return type;
32 | }
33 |
34 | public void setType(int type) {
35 | this.type = type;
36 | }
37 |
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/adapter/BaseViewHolder.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.adapter;
2 |
3 | import android.support.v7.widget.RecyclerView;
4 | import android.view.View;
5 |
6 | /**
7 | * Created by chenpengfei on 2016/10/27.
8 | */
9 | public class BaseViewHolder extends RecyclerView.ViewHolder {
10 |
11 | private View iv;
12 |
13 | public BaseViewHolder(View itemView) {
14 | super(itemView);
15 | iv = itemView;
16 | }
17 |
18 | public View findViewById(int id) {
19 | return iv.findViewById(id);
20 | }
21 | }
22 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/utils/PrefUtil.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.utils;
2 |
3 | import android.content.Context;
4 | import android.content.SharedPreferences;
5 |
6 | public class PrefUtil {
7 | private static final String SF_NAME = "com.doujiao.llm.config";
8 | private static final String MOLE_PATH = "llm_path";
9 | private static SharedPreferences mPref;
10 |
11 | public static void initPref(Context context) {
12 | if (mPref == null) {
13 | mPref = context.getSharedPreferences(SF_NAME, Context.MODE_PRIVATE);
14 | }
15 | }
16 |
17 | public static void setModelPath(String path) {
18 | if (mPref != null) {
19 | mPref.edit().putString(MOLE_PATH,path).apply();
20 | }
21 | }
22 |
23 | public static String getModelPath() {
24 | if (mPref != null) {
25 | return mPref.getString(MOLE_PATH,"");
26 | }
27 | return "";
28 | }
29 | }
30 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/utils/StatusBarUtils.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.utils;
2 |
3 | import android.content.ClipData;
4 | import android.content.ClipboardManager;
5 | import android.content.Context;
6 | import android.graphics.Color;
7 | import android.os.Build;
8 | import android.support.v7.app.ActionBar;
9 | import android.support.v7.app.AppCompatActivity;
10 | import android.view.View;
11 | import android.view.Window;
12 | import android.view.WindowManager;
13 |
14 | public class StatusBarUtils {
15 | public static void setTranslucentStatus(AppCompatActivity activity) {
16 | if (Build.VERSION.SDK_INT >= 21) {
17 | View decorView = activity.getWindow().getDecorView();
18 | int option = View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN
19 | | View.SYSTEM_UI_FLAG_LAYOUT_STABLE;
20 | decorView.setSystemUiVisibility(option);
21 | activity.getWindow().setStatusBarColor(Color.TRANSPARENT);
22 | }
23 | ActionBar actionBar = activity.getSupportActionBar();
24 | actionBar.hide();
25 | }
26 |
27 | public static void hideStatusBar(Window window, boolean darkText) {
28 | window.clearFlags(WindowManager.LayoutParams.FLAG_TRANSLUCENT_STATUS);
29 | window.addFlags(WindowManager.LayoutParams.FLAG_DRAWS_SYSTEM_BAR_BACKGROUNDS);
30 | window.setStatusBarColor(Color.TRANSPARENT);
31 |
32 | int flag = View.SYSTEM_UI_FLAG_LAYOUT_STABLE;
33 | if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M && darkText) {
34 | flag = View.SYSTEM_UI_FLAG_LIGHT_STATUS_BAR;
35 | }
36 |
37 | window.getDecorView().setSystemUiVisibility(flag |
38 | View.SYSTEM_UI_FLAG_LAYOUT_FULLSCREEN);
39 | }
40 |
41 | public static boolean copyStr2ClibBoard(Context context, String copyStr) {
42 | try {
43 | // get the clipboard manager
44 | ClipboardManager cm = (ClipboardManager) context.getSystemService(Context.CLIPBOARD_SERVICE);
45 | // create a plain-text ClipData
46 | ClipData mClipData = ClipData.newPlainText("Label", copyStr);
47 | // put the ClipData onto the system clipboard
48 | cm.setPrimaryClip(mClipData);
49 | return true;
50 | } catch (Exception e) {
51 | return false;
52 | }
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/ChatPromptViewManager.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget;
2 |
3 | import android.app.Activity;
4 | import android.view.View;
5 |
6 | /**
7 | * Created by chenpengfei on 2016/11/2.
8 | */
9 | public class ChatPromptViewManager extends PromptViewHelper.PromptViewManager {
10 |
11 | public ChatPromptViewManager(Activity activity, String[] dataArray, Location location) {
12 | super(activity, dataArray, location);
13 | }
14 |
15 | public ChatPromptViewManager(Activity activity) {
16 | this(activity, new String[]{"复制"}, Location.TOP_LEFT);
17 | }
18 |
19 | public ChatPromptViewManager(Activity activity, Location location) {
20 | this(activity, new String[]{"复制"}, location);
21 | }
22 |
23 |
24 | @Override
25 | public View inflateView() {
26 | return new PromptView(activity);
27 | }
28 |
29 | @Override
30 | public void bindData(View view, String[] dataArray) {
31 | if(view instanceof PromptView) {
32 | PromptView promptView = (PromptView) view;
33 | promptView.setContentArray(dataArray);
34 | promptView.setOnItemClickListener(new PromptView.OnItemClickListener() {
35 | @Override
36 | public void onItemClick(int position) {
37 | if(onItemClickListener != null) onItemClickListener.onItemClick(position);
38 | }
39 | });
40 | }
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/Location.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget;
2 |
3 | import com.doujiao.xiaozhihuiassistant.widget.location.BottomCenterLocation;
4 | import com.doujiao.xiaozhihuiassistant.widget.location.ICalculateLocation;
5 | import com.doujiao.xiaozhihuiassistant.widget.location.TopCenterLocation;
6 | import com.doujiao.xiaozhihuiassistant.widget.location.TopLeftLocation;
7 | import com.doujiao.xiaozhihuiassistant.widget.location.TopRightLocation;
8 |
9 | /**
10 | * Created by chenpengfei on 2016/11/2.
11 | */
12 | public enum Location {
13 |
14 | TOP_LEFT(1),
15 | TOP_CENTER(2),
16 | TOP_RIGHT(3),
17 | BOTTOM_LEFT(4),
18 | BOTTOM_CENTER(5),
19 | BOTTOM_RIGHT(6);
20 |
21 | ICalculateLocation calculateLocation;
22 |
23 | private Location(int type) {
24 | switch (type) {
25 | case 1:
26 | calculateLocation = new TopLeftLocation();
27 | break;
28 | case 2:
29 | calculateLocation = new TopCenterLocation();
30 | break;
31 | case 3:
32 | calculateLocation = new TopRightLocation();
33 | break;
34 | case 4:
35 | calculateLocation = new TopLeftLocation();
36 | break;
37 | case 5:
38 | calculateLocation = new BottomCenterLocation();
39 | break;
40 | case 6:
41 | calculateLocation = new TopLeftLocation();
42 | break;
43 | }
44 | }
45 |
46 | }
47 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/BottomCenterLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public class BottomCenterLocation implements ICalculateLocation {
9 |
10 | @Override
11 | public int[] calculate(int[] srcViewLocation, View srcView, View promptView) {
12 | int[] location = new int[2];
13 | int offset = (promptView.getWidth() - srcView.getWidth()) / 2;
14 | location[0] = srcViewLocation[0] - offset;
15 | location[1] = srcViewLocation[1] + promptView.getHeight();
16 | return location;
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/ICalculateLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public interface ICalculateLocation {
9 |
10 | int[] calculate(int[] srcViewLocation, View srcView, View promptView);
11 |
12 | }
13 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/TopCenterLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public class TopCenterLocation implements ICalculateLocation {
9 |
10 | @Override
11 | public int[] calculate(int[] srcViewLocation, View srcView, View promptView) {
12 | int[] location = new int[2];
13 | int offset = (promptView.getWidth() - srcView.getWidth()) / 2;
14 | location[0] = srcViewLocation[0] - offset;
15 | location[1] = srcViewLocation[1] - promptView.getHeight();
16 | return location;
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/TopLeftLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public class TopLeftLocation implements ICalculateLocation {
9 |
10 | @Override
11 | public int[] calculate(int[] srcViewLocation, View srcView, View promptView) {
12 | int[] location = new int[2];
13 | location[0] = srcViewLocation[0];
14 | location[1] = srcViewLocation[1] - promptView.getHeight();
15 | return location;
16 | }
17 | }
18 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/java/com/doujiao/xiaozhihuiassistant/widget/location/TopRightLocation.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant.widget.location;
2 |
3 | import android.view.View;
4 |
5 | /**
6 | * Created by chenpengfei on 2016/11/2.
7 | */
8 | public class TopRightLocation implements ICalculateLocation {
9 |
10 | @Override
11 | public int[] calculate(int[] srcViewLocation, View srcView, View promptView) {
12 | int[] location = new int[2];
13 | int offset = promptView.getWidth() - srcView.getWidth();
14 | location[0] = srcViewLocation[0] - offset;
15 | location[1] = srcViewLocation[1] - promptView.getHeight();
16 | return location;
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/drawable-v24/ic_launcher_foreground.xml:
--------------------------------------------------------------------------------
1 |
7 |
8 |
9 |
15 |
18 |
21 |
22 |
23 |
24 |
30 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/drawable/btnbg.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | -
5 |
7 |
8 |
9 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/drawable/editbg.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | -
5 |
7 |
8 |
9 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/layout/activity_item_left.xml:
--------------------------------------------------------------------------------
1 |
2 |
6 |
7 |
11 |
12 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/layout/activity_item_right.xml:
--------------------------------------------------------------------------------
1 |
2 |
6 |
7 |
13 |
14 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/layout/activity_main.xml:
--------------------------------------------------------------------------------
1 |
2 |
9 |
10 |
17 |
21 |
26 |
34 |
35 |
36 |
43 |
44 |
52 |
60 |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-anydpi-v26/ic_launcher.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-anydpi-v26/ic_launcher_round.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-hdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-hdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-mdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-mdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-xhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/glm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/glm.png
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/me.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-xxhdpi/me.png
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/values-night/themes.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
10 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/values/colors.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | #FFBB86FC
4 | #FF6200EE
5 | #FF3700B3
6 | #FF03DAC5
7 | #FF018786
8 | #FF000000
9 | #FFFFFFFF
10 | #FFC0C0C0
11 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/values/strings.xml:
--------------------------------------------------------------------------------
1 |
2 | FastLLM
3 | 选择
4 | 模型
5 | 请从手机存储中选择llm模型
6 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/main/res/values/themes.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
10 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/app/src/test/java/com/doujiao/xiaozhihuiassistant/ExampleUnitTest.java:
--------------------------------------------------------------------------------
1 | package com.doujiao.xiaozhihuiassistant;
2 |
3 | import org.junit.Test;
4 |
5 | import static org.junit.Assert.*;
6 |
7 | /**
8 | * Example local unit test, which will execute on the development machine (host).
9 | *
10 | * @see Testing documentation
11 | */
12 | public class ExampleUnitTest {
13 | @Test
14 | public void addition_isCorrect() {
15 | assertEquals(4, 2 + 2);
16 | }
17 | }
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/build.gradle:
--------------------------------------------------------------------------------
1 | // Top-level build file where you can add configuration options common to all sub-projects/modules.
2 | plugins {
3 | id 'com.android.application' version '7.1.2' apply false
4 | id 'com.android.library' version '7.1.2' apply false
5 | }
6 |
7 | task clean(type: Delete) {
8 | delete rootProject.buildDir
9 | }
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/gradle.properties:
--------------------------------------------------------------------------------
1 | # Project-wide Gradle settings.
2 | # IDE (e.g. Android Studio) users:
3 | # Gradle settings configured through the IDE *will override*
4 | # any settings specified in this file.
5 | # For more details on how to configure your build environment visit
6 | # http://www.gradle.org/docs/current/userguide/build_environment.html
7 | # Specifies the JVM arguments used for the daemon process.
8 | # The setting is particularly useful for tweaking memory settings.
9 | org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
10 | # When configured, Gradle will run in incubating parallel mode.
11 | # This option should only be used with decoupled projects. More details, visit
12 | # http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
13 | # org.gradle.parallel=true
14 | # Enables namespacing of each library's R class so that its R class includes only the
15 | # resources declared in the library itself and none from the library's dependencies,
16 | # thereby reducing the size of the R class for that library
17 | android.nonTransitiveRClass=true
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/gradle/wrapper/gradle-wrapper.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Android/LLMAssistant/gradle/wrapper/gradle-wrapper.jar
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/gradle/wrapper/gradle-wrapper.properties:
--------------------------------------------------------------------------------
1 | #Wed May 24 16:20:51 CST 2023
2 | distributionBase=GRADLE_USER_HOME
3 | distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip
4 | distributionPath=wrapper/dists
5 | zipStorePath=wrapper/dists
6 | zipStoreBase=GRADLE_USER_HOME
7 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/gradlew.bat:
--------------------------------------------------------------------------------
1 | @rem
2 | @rem Copyright 2015 the original author or authors.
3 | @rem
4 | @rem Licensed under the Apache License, Version 2.0 (the "License");
5 | @rem you may not use this file except in compliance with the License.
6 | @rem You may obtain a copy of the License at
7 | @rem
8 | @rem https://www.apache.org/licenses/LICENSE-2.0
9 | @rem
10 | @rem Unless required by applicable law or agreed to in writing, software
11 | @rem distributed under the License is distributed on an "AS IS" BASIS,
12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | @rem See the License for the specific language governing permissions and
14 | @rem limitations under the License.
15 | @rem
16 |
17 | @if "%DEBUG%" == "" @echo off
18 | @rem ##########################################################################
19 | @rem
20 | @rem Gradle startup script for Windows
21 | @rem
22 | @rem ##########################################################################
23 |
24 | @rem Set local scope for the variables with windows NT shell
25 | if "%OS%"=="Windows_NT" setlocal
26 |
27 | set DIRNAME=%~dp0
28 | if "%DIRNAME%" == "" set DIRNAME=.
29 | set APP_BASE_NAME=%~n0
30 | set APP_HOME=%DIRNAME%
31 |
32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter.
33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
34 |
35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
37 |
38 | @rem Find java.exe
39 | if defined JAVA_HOME goto findJavaFromJavaHome
40 |
41 | set JAVA_EXE=java.exe
42 | %JAVA_EXE% -version >NUL 2>&1
43 | if "%ERRORLEVEL%" == "0" goto execute
44 |
45 | echo.
46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
47 | echo.
48 | echo Please set the JAVA_HOME variable in your environment to match the
49 | echo location of your Java installation.
50 |
51 | goto fail
52 |
53 | :findJavaFromJavaHome
54 | set JAVA_HOME=%JAVA_HOME:"=%
55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe
56 |
57 | if exist "%JAVA_EXE%" goto execute
58 |
59 | echo.
60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
61 | echo.
62 | echo Please set the JAVA_HOME variable in your environment to match the
63 | echo location of your Java installation.
64 |
65 | goto fail
66 |
67 | :execute
68 | @rem Setup the command line
69 |
70 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
71 |
72 |
73 | @rem Execute Gradle
74 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
75 |
76 | :end
77 | @rem End local scope for the variables with windows NT shell
78 | if "%ERRORLEVEL%"=="0" goto mainEnd
79 |
80 | :fail
81 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
82 | rem the _cmd.exe /c_ return code!
83 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
84 | exit /b 1
85 |
86 | :mainEnd
87 | if "%OS%"=="Windows_NT" endlocal
88 |
89 | :omega
90 |
--------------------------------------------------------------------------------
/example/Android/LLMAssistant/settings.gradle:
--------------------------------------------------------------------------------
1 | pluginManagement {
2 | repositories {
3 | gradlePluginPortal()
4 | google()
5 | mavenCentral()
6 | }
7 | }
8 | dependencyResolutionManagement {
9 | repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
10 | repositories {
11 | google()
12 | mavenCentral()
13 | }
14 | }
15 | rootProject.name = "XiaoZhihuiAssistant"
16 | include ':app'
17 |
--------------------------------------------------------------------------------
/example/FastllmStudio/cli/ui.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/17/24.
3 | //
4 |
5 | #include "ui.h"
6 |
7 | namespace fastllmui {
8 | inline char getCh(){
9 | static char ch;
10 | int ret = system("stty -icanon -echo");
11 | ret = scanf("%c", &ch);
12 | ret = system("stty icanon echo");
13 | return ch;
14 | }
15 |
16 | void PrintNormalLine(const std::string &line) {
17 | printf("%s", line.c_str());
18 | }
19 |
20 | void PrintHighlightLine(const std::string &line) {
21 | printf("\e[1;31;40m %s \e[0m", line.c_str());
22 | }
23 |
24 | void HideCursor() {
25 | printf("\033[?25l");
26 | }
27 |
28 | void ShowCursor() {
29 | printf("\033[?25h");
30 | }
31 |
32 | void ClearScreen() {
33 | printf("\033c");
34 | }
35 |
36 | void CursorUp() {
37 | printf("\033[F");
38 | }
39 |
40 | void CursorDown() {
41 | printf("\033[B");
42 | }
43 |
44 | void CursorClearLine() {
45 | printf("\033[1G");
46 | printf("\033[K");
47 | }
48 |
49 | int Menu::Show() {
50 | for (int i = 0; i < items.size(); i++) {
51 | if (i == curIndex) {
52 | PrintHighlightLine(items[i]);
53 | printf("\n");
54 | } else {
55 | PrintNormalLine(items[i]);
56 | printf("\n");
57 | }
58 | }
59 |
60 | for (int i = curIndex; i < items.size(); i++) {
61 | printf("\033[F");
62 | }
63 |
64 | std::string upString = {27, 91, 65};
65 | std::string downString = {27, 91, 66};
66 | std::string now = "";
67 | while (true) {
68 | char ch = getCh();
69 | if (ch == '\r' || ch == '\n') {
70 | return curIndex;
71 | } else {
72 | now += ch;
73 | if (now.size() >= 3 && now.substr(now.size() - 3) == downString) {
74 | if (curIndex + 1 < items.size()) {
75 | CursorClearLine();
76 | PrintNormalLine(items[curIndex++]);
77 | CursorDown();
78 | CursorClearLine();
79 | PrintHighlightLine(items[curIndex]);
80 | }
81 | } else if (now.size() >= 3 && now.substr(now.size() - 3) == upString) {
82 | if (curIndex - 1 >= 0) {
83 | CursorClearLine();
84 | PrintNormalLine(items[curIndex--]);
85 | CursorUp();
86 | CursorClearLine();
87 | PrintHighlightLine(items[curIndex]);
88 | }
89 | }
90 | }
91 | }
92 | }
93 | } // namespace fastllmui
94 |
--------------------------------------------------------------------------------
/example/FastllmStudio/cli/ui.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/17/24.
3 | //
4 |
5 | #ifndef FASTLLMUI_H
6 | #define FASTLLMUI_H
7 |
8 | #include
9 | #include
10 | #include
11 | #include
12 | #include
13 |
14 | namespace fastllmui {
15 | void HideCursor();
16 | void ShowCursor();
17 |
18 | void ClearScreen();
19 | void CursorUp();
20 | void CursorDown();
21 | void CursorHome();
22 | void CursorClearLine();
23 |
24 | struct Menu {
25 | std::vector <std::string> items;
26 | int curIndex = 0;
27 |
28 | Menu (std::vector <std::string> items) :
29 | items(items) {}
30 |
31 | int Show();
32 | };
33 | }
34 |
35 | #endif
--------------------------------------------------------------------------------
/example/Qui/bin/Qt5Core.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Qui/bin/Qt5Core.dll
--------------------------------------------------------------------------------
/example/Qui/bin/Qt5Gui.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Qui/bin/Qt5Gui.dll
--------------------------------------------------------------------------------
/example/Qui/bin/Qt5Widgets.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Qui/bin/Qt5Widgets.dll
--------------------------------------------------------------------------------
/example/Qui/bin/Qui.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Qui/bin/Qui.exe
--------------------------------------------------------------------------------
/example/Qui/bin/fastllm_cpu.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Qui/bin/fastllm_cpu.exe
--------------------------------------------------------------------------------
/example/Qui/bin/fastllm_cuda.exe:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Qui/bin/fastllm_cuda.exe
--------------------------------------------------------------------------------
/example/Qui/bin/path.txt:
--------------------------------------------------------------------------------
1 | C:/DEV/Prj/fastllm/fastllm-bin/Release
2 | C:/DEV/Prj/fastllm/fastllm/example/Qui/bin/fastllm_cuda.exe
3 |
--------------------------------------------------------------------------------
/example/Qui/bin/platforms/qwindows.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Qui/bin/platforms/qwindows.dll
--------------------------------------------------------------------------------
/example/Qui/bin/qui_cn.qm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Qui/bin/qui_cn.qm
--------------------------------------------------------------------------------
/example/Qui/bin/styles/qwindowsvistastyle.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/Qui/bin/styles/qwindowsvistastyle.dll
--------------------------------------------------------------------------------
/example/Qui/src/Qui.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include
4 | #include
5 | #include
6 | #include
7 | #include
8 | #include "ui_Qui.h"
9 | #include
10 | #include
11 |
12 | class Qui : public QWidget
13 | {
14 | Q_OBJECT
15 |
16 | public:
17 | Qui(QWidget *parent = nullptr);
18 | ~Qui() {};
19 |
20 | private slots:
21 | void clearChat();
22 | void wirteToFile();
23 | void sendAI();
24 | void onReset();
25 |
26 | void onPathSelectModel();
27 | void onPathSelectFlm();
28 |
29 | void onDeviceCheck();
30 |
31 | void runModelWithSetting();
32 | void stopModelRuning();
33 |
34 | bool eventFilter(QObject *obj, QEvent *event) override;
35 |
36 | void readData();
37 | void finishedProcess();
38 | void errorProcess();
39 |
40 | private:
41 | QProcess *process;
42 | void closeEvent(QCloseEvent *event) override;
43 |
44 | private:
45 | Ui::QuiClass ui;
46 |
47 | QString modelPath = "";
48 | QString flmPath = "";
49 | void updateModelList();
50 | bool inited = false;
51 | };
52 |
--------------------------------------------------------------------------------
/example/Qui/src/Qui.pro:
--------------------------------------------------------------------------------
1 | QT += core gui svg widgets network
2 | CONFIG += c++17
3 |
4 | APP_NAME = Qui
5 | DESTDIR = ../bin
6 |
7 |
8 | SOURCES += \
9 | Qui.cpp \
10 | main.cpp
11 |
12 | HEADERS += \
13 | Qui.h \
14 | resource.h
15 |
16 | FORMS += \
17 | Qui.ui
18 |
19 | TRANSLATIONS = qui_cn.ts
20 |
--------------------------------------------------------------------------------
/example/Qui/src/main.cpp:
--------------------------------------------------------------------------------
1 | #include "Qui.h"
2 | #include
3 |
4 | int
5 | main(int argc, char *argv[])
6 | {
7 | QApplication a(argc, argv);
8 |
9 | QTranslator tran;
10 |
11 | if (tran.load(QString("qui_cn.qm")))
12 | {
13 | a.installTranslator(&tran);
14 | }
15 |
16 | Qui w;
17 | w.show();
18 | return a.exec();
19 | }
20 |
--------------------------------------------------------------------------------
/example/README.md:
--------------------------------------------------------------------------------
1 | # example (Example Projects)
2 |
3 | ## Benchmark
4 |
5 | A speed benchmark example program, making it easy to measure inference performance on different hardware and software. Speeds measured by the author are listed [here](doc/benchmark.md).
6 |
7 | Since real-world usage rarely satisfies the batching conditions and does not use greedy decoding, these numbers differ somewhat from the speed seen in actual use.
8 |
9 | ### Usage:
10 |
11 | CPU:
12 |
13 | `./benchmark -p chatglm-6b-int4.flm -f prompts.txt -t [number of threads] --batch [batch size]`
14 |
15 | GPU:
16 |
17 | `./benchmark -p chatglm-6b-int4.flm -f prompts.txt --batch [batch size]`
18 |
19 |
20 |
21 | ## Web UI
22 |
23 | Contributed by Jacques CHEN, with thanks!
24 |
25 | ## Win32Demo (Windows platform)
26 |
27 | Win32Demo is a Visual Studio project for running the FastLLM program on Windows.
28 | Because the default Windows console encoding is ANSI (GBK for Chinese, code page 936) while FastLLM reads and writes UTF-8, this build differs slightly from `main`, so a dedicated version is provided. To avoid printing tokens that are only half of a character (as can happen with BPE encoding), consecutive Chinese characters are currently flushed together.
29 |
30 | The generated exe is located at: `Win32Demo\bin\Win32Demo.exe`
31 |
32 | Please build the Release configuration whenever possible; it is much faster!
33 |
34 | In addition, a .vcproj file for fastllm with GPU support is provided; this project can be compiled with Visual Studio 2015 Update 3 at minimum.
35 | (However, **building pyfastllm requires at least MSVC 2017**.)
36 |
37 | ### Building
38 |
39 | The fastllm project currently comes in a CPU version and a GPU version. For an easy start when cmake is not available, this project can be built from the Visual Studio project files, with features switched on and off through preprocessor definitions. The CPU version is used by default.
40 |
41 | After checking out the code, **modify include/fastllm.h**: in Visual Studio click "File" -> "Advanced Save Options" and choose the encoding "Unicode (UTF-8 **with signature**) - Codepage 65001", or convert the file to "UTF-8 with BOM" in another text editor. (Since gcc on Linux does not recognize the BOM header, this change can only be done manually.)
42 |
43 | * **CPU version**:
44 |   * If CUDA is not installed on this machine, open the Win32Demo project "Properties", go to "Linker" -> "Input" -> "Additional Dependencies", and click "Inherit from parent or project defaults".
45 |
46 | * **GPU version**:
47 |   - CUDA, including its Visual Studio Integration, must be installed correctly;
48 |   - Set the CUDA_PATH environment variable so that it points to the CUDA version to build against;
49 |   - In Solution Explorer, remove fastllm.vcproj and add fastllm-gpu.vcproj;
50 |   - For the fastllm-gpu project, under "Build Dependencies" -> "Build Customizations", manually add the build customization file of the installed CUDA version;
51 |   - For the fastllm-gpu project, open "Properties" -> "CUDA C/C++" -> "Device" -> "Code Generation" and configure the [GPU compute capabilities](https://developer.nvidia.com/cuda-gpus#compute) the build should support;
52 |   - On the Win32Demo project choose "Add" -> "Reference" and check the fastllm-gpu project;
53 |   - On the Win32Demo project add the preprocessor definition "USE_CUDA".
54 |
55 | ### Usage:
56 |
57 | 1. Open a command prompt (cmd);
58 |
59 | 2. `cd example\Win32Demo\bin` ;
60 |
61 | 3. The runtime arguments are basically the same as for `main`, plus one extra flag `-w`, which starts the web UI; without it the program runs in the console. For example:
62 |
63 | `Win32Demo.exe -p c:\chatglm-6b-v1.1-int4.flm -w`
64 |
65 | ## Android (Android platform)
66 | An example of running an LLM program on the Android platform, built with Android Studio.
67 |
68 | ### Usage:
69 |
70 | 1. Open the project directly in Android Studio and run it.
71 |
72 | 2. Or simply download the apk from the release directory and try it out.
73 |
74 | 3. Alternatively, compile the main binary with the CMake toolchain (see the main README for the exact steps) and run it through adb shell:
75 |
76 |    1. `adb push main /data/local/tmp` to put the main binary into the phone's tmp folder,
77 |    2. `adb shell` ,
78 |    3. `cd /data/local/tmp`
79 |    4. `./main` to run it.
80 |
81 | Note: the demo apk copies the model file into the app's data directory so the native code can read it, so the device needs free storage of at least twice the model size.
--------------------------------------------------------------------------------
/example/Win32Demo/StringUtils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <windows.h>
3 | #include <string>
4 |
5 | std::string utf2Gb(const char* szU8)
6 | {
7 | int wcsLen = ::MultiByteToWideChar(CP_UTF8, NULL, szU8, strlen(szU8), NULL, 0);
8 | wchar_t* wszString = new wchar_t[wcsLen + 1];
9 | ::MultiByteToWideChar(CP_UTF8, NULL, szU8, strlen(szU8), wszString, wcsLen);
10 | wszString[wcsLen] = '\0';
11 |
12 | wcsLen = WideCharToMultiByte(CP_ACP, 0, wszString, -1, NULL, 0, NULL, NULL);
13 | char* gb2312 = new char[wcsLen + 1];
14 | memset(gb2312, 0, wcsLen + 1);
15 | WideCharToMultiByte(CP_ACP, 0, wszString, -1, gb2312, wcsLen, NULL, NULL);
16 |
17 | if (wszString)
18 | delete[] wszString;
19 | std::string gbstr(gb2312);
20 | if (gb2312)
21 | delete[] gb2312;
22 | return gbstr;
23 | }
24 |
25 | std::string Gb2utf(std::string ws)
26 | {
27 | //int dwNum = WideCharToMultiByte(CP_UTF8, 0, ws.c_str(), -1, 0, 0, 0, 0);
28 | int dwNum = MultiByteToWideChar(CP_ACP, 0, ws.c_str(), -1, NULL, 0);
29 |
30 | wchar_t* wstr = new wchar_t[dwNum + 1];
31 | memset(wstr, 0, dwNum + 1);
32 | MultiByteToWideChar(CP_ACP, 0, ws.c_str(), -1, wstr, dwNum);
33 |
34 | dwNum = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, NULL, 0, NULL, NULL);
35 | char* utf8 = new char[dwNum + 1];
36 | memset(utf8, 0, dwNum + 1);
37 | WideCharToMultiByte(CP_UTF8, 0, wstr, -1, utf8, dwNum, NULL, NULL);
38 | if (wstr)
39 | delete[] wstr;
40 | std::string str(utf8);
41 | if (utf8)
42 | delete[] utf8;
43 | return str;
44 | }
--------------------------------------------------------------------------------
/example/Win32Demo/Win32Demo.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio 14
4 | VisualStudioVersion = 14.0.25420.1
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "Win32Demo", "Win32Demo.vcxproj", "{B560ABA3-CCC2-42C6-8A99-43D7F811FEED}"
7 | EndProject
8 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "fastllm", "fastllm.vcxproj", "{BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}"
9 | EndProject
10 | Global
11 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
12 | Debug|x64 = Debug|x64
13 | Debug|x86 = Debug|x86
14 | Release|x64 = Release|x64
15 | Release|x86 = Release|x86
16 | EndGlobalSection
17 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
18 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Debug|x64.ActiveCfg = Debug|x64
19 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Debug|x64.Build.0 = Debug|x64
20 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Debug|x86.ActiveCfg = Debug|Win32
21 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Debug|x86.Build.0 = Debug|Win32
22 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Release|x64.ActiveCfg = Release|x64
23 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Release|x64.Build.0 = Release|x64
24 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Release|x86.ActiveCfg = Release|Win32
25 | {B560ABA3-CCC2-42C6-8A99-43D7F811FEED}.Release|x86.Build.0 = Release|Win32
26 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Debug|x64.ActiveCfg = Debug|x64
27 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Debug|x64.Build.0 = Debug|x64
28 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Debug|x86.ActiveCfg = Debug|Win32
29 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Debug|x86.Build.0 = Debug|Win32
30 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Release|x64.ActiveCfg = Release|x64
31 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Release|x64.Build.0 = Release|x64
32 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Release|x86.ActiveCfg = Release|Win32
33 | {BDA13DDF-572F-4FAD-B7A9-80EA5CAC3F2B}.Release|x86.Build.0 = Release|Win32
34 | EndGlobalSection
35 | GlobalSection(SolutionProperties) = preSolution
36 | HideSolutionNode = FALSE
37 | EndGlobalSection
38 | GlobalSection(ExtensibilityGlobals) = postSolution
39 | SolutionGuid = {48072E1E-35C1-46A9-9F40-FF4CDF872E08}
40 | EndGlobalSection
41 | EndGlobal
42 |
--------------------------------------------------------------------------------
/example/Win32Demo/bin/web/css/github.min.css:
--------------------------------------------------------------------------------
1 | pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}/*!
2 | Theme: GitHub
3 | Description: Light theme as seen on github.com
4 | Author: github.com
5 | Maintainer: @Hirse
6 | Updated: 2021-05-15
7 |
8 | Outdated base version: https://github.com/primer/github-syntax-light
9 | Current colors taken from GitHub's CSS
10 | */.hljs{color:#24292e;background:#fff}.hljs-doctag,.hljs-keyword,.hljs-meta .hljs-keyword,.hljs-template-tag,.hljs-template-variable,.hljs-type,.hljs-variable.language_{color:#d73a49}.hljs-title,.hljs-title.class_,.hljs-title.class_.inherited__,.hljs-title.function_{color:#6f42c1}.hljs-attr,.hljs-attribute,.hljs-literal,.hljs-meta,.hljs-number,.hljs-operator,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-variable{color:#005cc5}.hljs-meta .hljs-string,.hljs-regexp,.hljs-string{color:#032f62}.hljs-built_in,.hljs-symbol{color:#e36209}.hljs-code,.hljs-comment,.hljs-formula{color:#6a737d}.hljs-name,.hljs-quote,.hljs-selector-pseudo,.hljs-selector-tag{color:#22863a}.hljs-subst{color:#24292e}.hljs-section{color:#005cc5;font-weight:700}.hljs-bullet{color:#735c0f}.hljs-emphasis{color:#24292e;font-style:italic}.hljs-strong{color:#24292e;font-weight:700}.hljs-addition{color:#22863a;background-color:#f0fff4}.hljs-deletion{color:#b31d28;background-color:#ffeef0}
--------------------------------------------------------------------------------
/example/Win32Demo/bin/web/css/texmath.css:
--------------------------------------------------------------------------------
1 | /* style for html inside of browsers */
2 | .katex { font-size: 1em !important; } /* align KaTeX font-size to surrounding text */
3 |
4 | eq { display: inline-block; }
5 | eqn { display: block}
6 | section.eqno {
7 | display: flex;
8 | flex-direction: row;
9 | align-content: space-between;
10 | align-items: center;
11 | }
12 | section.eqno > eqn {
13 | width: 100%;
14 | margin-left: 3em;
15 | }
16 | section.eqno > span {
17 | width:3em;
18 | text-align:right;
19 | }
20 |
--------------------------------------------------------------------------------
/example/Win32Demo/bin/web/js/markdown-it-link-attributes.min.js:
--------------------------------------------------------------------------------
1 | (function(f){if(typeof exports==="object"&&typeof module!=="undefined"){module.exports=f()}else if(typeof define==="function"&&define.amd){define([],f)}else{var g;if(typeof window!=="undefined"){g=window}else if(typeof global!=="undefined"){g=global}else if(typeof self!=="undefined"){g=self}else{g=this}g.markdownitLinkAttributes=f()}})(function(){var define,module,exports;return function(){function r(e,n,t){function o(i,f){if(!n[i]){if(!e[i]){var c="function"==typeof require&&require;if(!f&&c)return c(i,!0);if(u)return u(i,!0);var a=new Error("Cannot find module '"+i+"'");throw a.code="MODULE_NOT_FOUND",a}var p=n[i]={exports:{}};e[i][0].call(p.exports,function(r){var n=e[i][1][r];return o(n||r)},p,p.exports,r,e,n,t)}return n[i].exports}for(var u="function"==typeof require&&require,i=0;i
--------------------------------------------------------------------------------
/example/openai_server/README.md:
--------------------------------------------------------------------------------
4 | > This implementation is modeled on the OpenAI-compatible API server in vLLM v0.4.1, simplified on top of that code
5 |
6 |
7 | ## Currently supported endpoints
8 | * Official OpenAI API reference: https://platform.openai.com/docs/api-reference/
9 |
10 | | Type | Endpoint | Method | Options explicitly not supported yet | Official docs and examples |
11 | | :--- | :--------------------- | :------------------ | :--------------------------------------------------------- |:--------------------------------------------------------- |
12 | | Chat | Create chat completion | v1/chat/completions | (n, presence_penalty, tools, functions, logprobs, seed, logit_bias) | https://platform.openai.com/docs/api-reference/chat/create |
13 |
14 |
15 | > Since the Completions endpoint has already been marked as Legacy, it is not implemented here
16 |
17 | ## Dependencies
18 | The following dependencies work without problems on Python 3.12.2
19 | 1. Install the ftllm package first
20 | 2. Then install the following dependencies
21 | ```bash
22 | cd example/openai_server
23 | pip install -r requirements.txt
24 | ```
25 |
26 | ## Usage && Examples
27 | * Server launch command
28 | ```bash
29 | cd example/openai_server
30 | python openai_api_server.py --model_name "model_name" -p "path_to_your_flm_model"
31 | # eg : python openai_api_server.py --model_name "chat-glm2-6b-int4" -p "./chatglm2-6b-int4.flm"
32 | ```
33 |
34 | * Client test command
35 | ```bash
36 | # client test
37 | # test command
38 | curl http://localhost:8080/v1/chat/completions \
39 | -H "Content-Type: application/json" \
40 | -H "Authorization: Bearer something" \
41 | -d '{
42 | "model": "chat-glm2-6b-int4",
43 | "messages": [
44 | {
45 | "role": "system",
46 | "content": "You are a helpful assistant."
47 | },
48 | {
49 | "role": "user",
50 | "content": "Hello!"
51 | }] }'
52 | # response
53 | {"id":"fastllm-chat-glm2-6b-int4-e4fd6bea564548f6ae95f6327218616d","object":"chat.completion","created":1715150460,"model":"chat-glm2-6b-int4","choices":[{"index":0,"message":{"role":"assistant","content":" Hello! How can I assist you today?"},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"total_tokens":0,"completion_tokens":0}}
54 | ```
55 |
56 |
57 |
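58 | The same request can also be issued with the `openai` Python package listed in requirements.txt. Below is a minimal client sketch, assuming an openai>=1.0 style client and the server address and model name from the example above:
59 |
60 | ```python
61 | from openai import OpenAI
62 |
63 | # Point the client at the local fastllm server started above; the API key is not checked.
64 | client = OpenAI(base_url="http://localhost:8080/v1", api_key="something")
65 |
66 | # Send the same chat request as the curl example and print the reply text.
67 | response = client.chat.completions.create(
68 |     model="chat-glm2-6b-int4",
69 |     messages=[
70 |         {"role": "system", "content": "You are a helpful assistant."},
71 |         {"role": "user", "content": "Hello!"},
72 |     ],
73 | )
74 | print(response.choices[0].message.content)
75 | ```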
--------------------------------------------------------------------------------
/example/openai_server/protocal/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/example/openai_server/protocal/__init__.py
--------------------------------------------------------------------------------
/example/openai_server/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi
2 | pydantic
3 | openai
4 | shortuuid
--------------------------------------------------------------------------------
/example/python/custom_model.py:
--------------------------------------------------------------------------------
1 | from ftllm import llm
2 | from qwen2 import Qwen2Model
3 | import os
4 |
5 | root_path = "/mnt/hfmodels/"
6 | model_path = os.path.join(root_path, "Qwen/Qwen2-7B-Instruct")
7 |
8 | model = llm.model(model_path, graph = Qwen2Model)
9 | prompt = "北京有什么景点?"
10 | messages = [
11 | {"role": "system", "content": "你是一个爱说英文的人工智能,不管我跟你说什么语言,你都会用英文回复我"},
12 | {"role": "user", "content": prompt}
13 | ]
14 | for response in model.stream_response(messages, one_by_one = True):
15 | print(response, flush = True, end = "")
--------------------------------------------------------------------------------
/example/python/qwen2.py:
--------------------------------------------------------------------------------
1 | from ftllm.llm import ComputeGraph
2 | import math
3 |
4 | class Qwen2Model(ComputeGraph):
5 | def build(self):
6 | weight, data, config = self.weight, self.data, self.config
7 | config["max_positions"] = 128000
8 |
9 | head_dim = config["hidden_size"] // config["num_attention_heads"]
10 | self.Embedding(data["inputIds"], weight["model.embed_tokens.weight"], data["hiddenStates"]);
11 | self.DataTypeAs(data["hiddenStates"], data["atype"])
12 | for i in range(config["num_hidden_layers"]):
13 | pastKey = data["pastKey."][i]
14 | pastValue = data["pastValue."][i]
15 | layer = weight["model.layers."][i]
16 | self.RMSNorm(data["hiddenStates"], layer[".input_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
17 | self.Linear(data["attenInput"], layer[".self_attn.q_proj.weight"], layer[".self_attn.q_proj.bias"], data["q"])
18 | self.Linear(data["attenInput"], layer[".self_attn.k_proj.weight"], layer[".self_attn.k_proj.bias"], data["k"])
19 | self.Linear(data["attenInput"], layer[".self_attn.v_proj.weight"], layer[".self_attn.v_proj.bias"], data["v"])
20 | self.ExpandHead(data["q"], head_dim)
21 | self.ExpandHead(data["k"], head_dim)
22 | self.ExpandHead(data["v"], head_dim)
23 | self.LlamaRotatePosition2D(data["q"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
24 | self.LlamaRotatePosition2D(data["k"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
25 | self.FusedAttention(data["q"], pastKey, pastValue, data["k"], data["v"], data["attenInput"],
26 | data["attentionMask"], data["attenOutput"], data["seqLens"], 1.0 / math.sqrt(head_dim))
27 | self.Linear(data["attenOutput"], layer[".self_attn.o_proj.weight"], layer[".self_attn.o_proj.bias"], data["attenLastOutput"]);
28 | self.AddTo(data["hiddenStates"], data["attenLastOutput"]);
29 | self.RMSNorm(data["hiddenStates"], layer[".post_attention_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
30 | self.Linear(data["attenInput"], layer[".mlp.gate_proj.weight"], layer[".mlp.gate_proj.bias"], data["w1"])
31 | self.Linear(data["attenInput"], layer[".mlp.up_proj.weight"], layer[".mlp.up_proj.bias"], data["w3"])
32 | self.Silu(data["w1"], data["w1"])
33 | self.MulTo(data["w1"], data["w3"])
34 | self.Linear(data["w1"], layer[".mlp.down_proj.weight"], layer[".mlp.down_proj.bias"], data["w2"])
35 | self.AddTo(data["hiddenStates"], data["w2"])
36 | self.SplitLastTokenStates(data["hiddenStates"], data["seqLens"], data["lastTokensStates"])
37 | self.RMSNorm(data["lastTokensStates"], weight["model.norm.weight"], config["rms_norm_eps"], data["lastTokensStates"])
38 | self.Linear(data["lastTokensStates"], weight["lm_head.weight"], weight["lm_head.bias"], data["logits"])
39 |
40 | __model__ = Qwen2Model
--------------------------------------------------------------------------------
/example/webui/web/css/github.min.css:
--------------------------------------------------------------------------------
1 | pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}/*!
2 | Theme: GitHub
3 | Description: Light theme as seen on github.com
4 | Author: github.com
5 | Maintainer: @Hirse
6 | Updated: 2021-05-15
7 |
8 | Outdated base version: https://github.com/primer/github-syntax-light
9 | Current colors taken from GitHub's CSS
10 | */.hljs{color:#24292e;background:#fff}.hljs-doctag,.hljs-keyword,.hljs-meta .hljs-keyword,.hljs-template-tag,.hljs-template-variable,.hljs-type,.hljs-variable.language_{color:#d73a49}.hljs-title,.hljs-title.class_,.hljs-title.class_.inherited__,.hljs-title.function_{color:#6f42c1}.hljs-attr,.hljs-attribute,.hljs-literal,.hljs-meta,.hljs-number,.hljs-operator,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-variable{color:#005cc5}.hljs-meta .hljs-string,.hljs-regexp,.hljs-string{color:#032f62}.hljs-built_in,.hljs-symbol{color:#e36209}.hljs-code,.hljs-comment,.hljs-formula{color:#6a737d}.hljs-name,.hljs-quote,.hljs-selector-pseudo,.hljs-selector-tag{color:#22863a}.hljs-subst{color:#24292e}.hljs-section{color:#005cc5;font-weight:700}.hljs-bullet{color:#735c0f}.hljs-emphasis{color:#24292e;font-style:italic}.hljs-strong{color:#24292e;font-weight:700}.hljs-addition{color:#22863a;background-color:#f0fff4}.hljs-deletion{color:#b31d28;background-color:#ffeef0}
--------------------------------------------------------------------------------
/example/webui/web/css/texmath.css:
--------------------------------------------------------------------------------
1 | /* style for html inside of browsers */
2 | .katex { font-size: 1em !important; } /* align KaTeX font-size to surrounding text */
3 |
4 | eq { display: inline-block; }
5 | eqn { display: block}
6 | section.eqno {
7 | display: flex;
8 | flex-direction: row;
9 | align-content: space-between;
10 | align-items: center;
11 | }
12 | section.eqno > eqn {
13 | width: 100%;
14 | margin-left: 3em;
15 | }
16 | section.eqno > span {
17 | width:3em;
18 | text-align:right;
19 | }
20 |
--------------------------------------------------------------------------------
/example/webui/web/js/markdown-it-link-attributes.min.js:
--------------------------------------------------------------------------------
1 | (function(f){if(typeof exports==="object"&&typeof module!=="undefined"){module.exports=f()}else if(typeof define==="function"&&define.amd){define([],f)}else{var g;if(typeof window!=="undefined"){g=window}else if(typeof global!=="undefined"){g=global}else if(typeof self!=="undefined"){g=self}else{g=this}g.markdownitLinkAttributes=f()}})(function(){var define,module,exports;return function(){function r(e,n,t){function o(i,f){if(!n[i]){if(!e[i]){var c="function"==typeof require&&require;if(!f&&c)return c(i,!0);if(u)return u(i,!0);var a=new Error("Cannot find module '"+i+"'");throw a.code="MODULE_NOT_FOUND",a}var p=n[i]={exports:{}};e[i][0].call(p.exports,function(r){var n=e[i][1][r];return o(n||r)},p,p.exports,r,e,n,t)}return n[i].exports}for(var u="function"==typeof require&&require,i=0;i
--------------------------------------------------------------------------------
/include/device.h:
--------------------------------------------------------------------------------
11 | typedef std::map <std::string, Data*> DataDict;
12 | typedef std::map <std::string, float> FloatDict;
13 | typedef std::map <std::string, int> IntDict;
14 |
15 | class BaseOperator {
16 | public:
17 | // whether a given operator can run on this device
18 | virtual bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
19 |
20 | // shape inference for a given operator
21 | virtual void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
22 |
23 | // run inference for a given operator
24 | virtual void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams) = 0;
25 | };
26 |
27 | class BaseBatchOperator : BaseOperator {
28 | public:
29 | // shape inference for a given operator
30 | virtual void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams,
31 | const IntDict &intParams);
32 | };
33 |
34 | class BaseDevice {
35 | public:
36 | virtual bool Malloc (void **ret, size_t size) = 0; // allocate a buffer of size bytes
37 | virtual bool Malloc (void **ret, Data &data); // allocate space matching the shape (dims) of data
38 | virtual bool Free(void *ret) = 0; // free ret
39 |
40 | virtual bool CopyDataToCPU(void *dst, void *src, size_t size) = 0; // copy src on the device to dst on the CPU
41 | virtual bool CopyDataToCPU(Data &data); // move data from this device to the CPU
42 |
43 | virtual bool CopyDataFromCPU(void *dst, void *src, size_t size) = 0; // copy src on the CPU to dst on the device
44 | virtual bool CopyDataFromCPU(Data &data); // move data from the CPU to this device
45 |
46 | // whether a given operator can run on this device
47 | virtual bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
48 |
49 | // shape inference for a given operator
50 | virtual void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
51 |
52 | // run inference for a given operator
53 | virtual void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
54 |
55 | std::string deviceType;
56 | std::string deviceName;
57 | std::vector <int> deviceIds;
58 | std::map deviceIdsRatio;
59 |
60 | std::map <std::string, BaseOperator*> ops;
61 | };
62 | }
63 |
64 | #endif //FASTLLM_DEVICE_H
65 |
--------------------------------------------------------------------------------
/include/devices/multicuda/fastllm-multicuda.cuh:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 8/2/24.
3 | //
4 |
5 | #include "fastllm.h"
6 |
7 | std::vector FastllmCudaGetFreeSizes();
8 |
9 | #ifdef __cplusplus
10 | extern "C" {
11 | #endif
12 |
13 | // deviceId -> [[l0, r0), [l1, r1), ...]
14 | using DivisionScheme = std::map <int, std::vector <std::pair <int, int> > >;
15 |
16 | std::vector FastllmMultiCudaGetSplitPoints(std::vector &multiCudaCurrentDevices, std::map &multiCudaCurrentRatios, int total, int unit);
17 | void FastllmGetMulticudaDeviceAndRatio(std::vector &devices, std::map &ratios, bool noSpecial);
18 | bool SplitMultiCudaWeight(fastllm::Data &weight, fastllm::Data &bias,
19 | std::vector &multiCudaCurrentDevices, DivisionScheme divisionScheme, int splitAxis);
20 | void CopyToMultiDevices(fastllm::Data &data, std::vector devices, bool copyData);
21 |
22 | void FastllmMultiCudaSetDevice(std::vector ids);
23 | void FastllmMultiCudaSetDeviceRatio(std::map &deviceRatio);
24 |
25 | bool FastllmMultiCudaHalfMatMul(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
26 | bool FastllmMultiCudaMatMul(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k);
27 |
28 | #ifdef __cplusplus
29 | }
30 | #endif
31 |
--------------------------------------------------------------------------------
/include/devices/multicuda/multicudadevice.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 8/2/24.
3 | //
4 |
5 | #ifndef FASTLLM_MULTICUDADEVICE_H
6 | #define FASTLLM_MULTICUDADEVICE_H
7 |
8 | #include "device.h"
9 |
10 | namespace fastllm {
11 | class MultiCudaDevice : BaseDevice {
12 |
13 | private:
14 | CudaDevice *cudaDevice;
15 |
16 | public:
17 | MultiCudaDevice (CudaDevice *cudaDevice);
18 |
19 | bool Malloc (void **ret, size_t size); // allocate a buffer of size bytes
20 | bool Free(void *ret); // free ret
21 |
22 | bool CopyDataToCPU(void *dst, void *src, size_t size);
23 | bool CopyDataFromCPU(void *dst, void *src, size_t size);
24 |
25 | // whether a given operator can run on this device
26 | bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
27 |
28 | // shape inference for a given operator
29 | void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
30 |
31 | // run inference for a given operator
32 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
33 | };
34 |
35 | class MultiCudaLinearOp : CudaLinearOp {
36 | bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
37 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
38 | };
39 |
40 | class MultiCudaMLPOp : CudaLinearOp {
41 | void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
42 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
43 | };
44 |
45 | class MultiCudaMergeMOE : CpuMergeMOE {
46 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
47 | };
48 |
49 | class MultiCudaMergeAttention : CudaMergeAttention {
50 | void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
51 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
52 | };
53 | }
54 |
55 | #endif //FASTLLM_MULTICUDADEVICE_H
56 |
--------------------------------------------------------------------------------
/include/devices/numa/fastllm-numa.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 4/11/24.
3 | //
4 |
5 | #ifndef FASTLLM_NUMA_COMPUTE_H
6 | #define FASTLLM_NUMA_COMPUTE_H
7 |
8 | #include "fastllm.h"
9 |
10 | namespace fastllm {
11 | struct NumaClient {
12 | int fd;
13 | volatile uint8_t *buf;
14 | volatile uint8_t *result;
15 | volatile int32_t *flag;
16 |
17 | int serverVersion;
18 | int serverNumaCnt;
19 |
20 | std::set <std::string> registerDataNames; // DataNames that have been registered with the server
21 |
22 | NumaClient ();
23 |
24 | ~NumaClient ();
25 |
26 | void Launch(int opType);
27 |
28 | void Wait();
29 |
30 | void SendLongMessage(uint8_t *buffer, uint64_t len);
31 |
32 | void RegisterFastllmData(fastllm::Data *data, const std::string &weightType);
33 |
34 | void UnregisterFastllmData(const std::string &dataName);
35 |
36 | void RunNumaLinearU(int n, int m, int k, int group, int groupCnt,
37 | fastllm::Data *weight, fastllm::Data *bias,
38 | std::vector *inputConfigs,
39 | uint8_t *uinput, float *output,
40 | LinearExType exType,
41 | DataType outputType);
42 |
43 | void RunNumaLinearF(int n, int m, int k, fastllm::Data *weight, fastllm::Data *bias,
44 | float *input, float *output, LinearExType exType, DataType dataType);
45 |
46 | void RunNumaMOEU(int n, int m, int k, int group, int groupCnt,
47 | std::vector weights, std::vector factors,
48 | std::vector *inputConfigs,
49 | uint8_t *uinput, float *output,
50 | DataType outputType);
51 |
52 | void RunNumaMOEUMultiRow(int n, int m, int k, int group, int groupCnt,
53 | std::vector > &weights, std::vector > &factors,
54 | std::vector *inputConfigs,
55 | uint8_t *uinput, float *output,
56 | DataType outputType);
57 |
58 | void RunNumaMOEF(int n, int m, int k,
59 | std::vector weights, std::vector factors,
60 | float *input, float *output,
61 | DataType outputType);
62 |
63 | void RunNumaMOEFMultiRow(int n, int m, int k,
64 | std::vector > &weights, std::vector > &factors,
65 | float *input, float *output,
66 | DataType outputType);
67 |
68 | void AppendKVCache(long long uid, Data *content);
69 |
70 | void Attention(Data *q, Data *k, Data *v, int group, float scale, int maskType, Data *output);
71 | };
72 |
73 | void RegisterFastllmData(fastllm::Data *data, const std::string &weightType);
74 | }
75 |
76 | #endif
--------------------------------------------------------------------------------
/include/devices/numa/kvcache.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 24-4-19.
3 | //
4 |
5 | #ifndef TFACCCOMPUTESERVER_KVCACHE_H
6 | #define TFACCCOMPUTESERVER_KVCACHE_H
7 |
8 | #include "fastllm.h"
9 |
10 | namespace fastllm {
11 | struct KVCache {
12 | std::chrono::system_clock::time_point lastFlushTime;
13 |
14 | DataType dataType;
15 | int unitSize;
16 |
17 | int len;
18 | int head, dim; // the shape is [head, len, dim]
19 | int currentCap; // space of [head, currentCap, dim] is pre-allocated and grown when len exceeds it
20 | int unitLen = 64; // growth granularity
21 | uint8_t *data = nullptr;
22 |
23 | KVCache (DataType dataType, int head, int dim);
24 | ~KVCache();
25 |
26 | void Append(int len, uint8_t *data);
27 | };
28 |
29 | struct KVCacheManager {
30 | std::unordered_map <long long, KVCache*> caches;
31 |
32 | KVCache *Get(long long uid);
33 | KVCache *Get(long long uid, DataType dataType, int head, int dim);
34 | void Delete(long long uid);
35 | };
36 | }
37 |
38 | #endif //TFACCCOMPUTESERVER_KVCACHE_H
39 |
--------------------------------------------------------------------------------
/include/devices/numa/numadevice.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 4/11/24.
3 | //
4 |
5 | #ifndef FASTLLM_NUMADEVICE_H
6 | #define FASTLLM_NUMADEVICE_H
7 |
8 | #include "device.h"
9 | #include "devices/cpu/cpudevice.h"
10 |
11 | namespace fastllm {
12 | class NumaDevice : BaseDevice {
13 | public:
14 | NumaDevice();
15 |
16 | // numa use cpu DDR
17 | bool Malloc (void **ret, size_t size);
18 | bool Free(void *ret);
19 |
20 | bool CopyDataToCPU(void *dst, void *src, size_t size);
21 | bool CopyDataFromCPU(void *dst, void *src, size_t size);
22 | };
23 |
24 | class NumaLinearOp : CpuLinearOp {
25 | bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
26 | void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
27 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
28 | long long int Ops(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
29 | };
30 |
31 | class NumaMergeMOE : CpuMergeMOE {
32 | bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
33 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
34 | };
35 |
36 | class NumaCatDirectOp : CpuCatDirectOp {
37 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
38 | };
39 |
40 | class NumaAttention : CpuAttention {
41 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
42 | };
43 |
44 | class NumaAttentionBatchOp : CpuAttentionBatchOp {
45 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
46 | };
47 |
48 | class NumaCatDirectBatchOp : CpuCatDirectBatchOp {
49 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
50 | };
51 | }
52 |
53 | #endif
--------------------------------------------------------------------------------
/include/devices/tfacc/fastllm-tfacc.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 4/11/24.
3 | //
4 |
5 | #ifndef FASTLLM_TFACC_COMPUTE_H
6 | #define FASTLLM_TFACC_COMPUTE_H
7 |
8 | #include "fastllm.h"
9 |
10 | namespace fastllm {
11 | struct TfaccClient {
12 | int fd;
13 | volatile uint8_t *buf;
14 | volatile uint8_t *result;
15 | volatile int32_t *flag;
16 |
17 | int serverVersion;
18 | int serverNumaCnt;
19 |
20 | std::set <std::string> registerDataNames; // DataNames that have been registered with the server
21 |
22 | TfaccClient ();
23 |
24 | ~TfaccClient ();
25 |
26 | void Launch(int opType);
27 |
28 | void Wait();
29 |
30 | void SendLongMessage(uint8_t *buffer, uint64_t len);
31 |
32 | void RegisterFastllmData(fastllm::Data *data, const std::string &weightType);
33 |
34 | void UnregisterFastllmData(const std::string &dataName);
35 |
36 | void RunTfaccLinearU(int n, int m, int k, int group, int groupCnt,
37 | fastllm::Data *weight, fastllm::Data *bias,
38 | std::vector *inputConfigs,
39 | uint8_t *uinput, float *output,
40 | LinearExType exType,
41 | DataType outputType);
42 |
43 | void RunTfaccLinearF(int n, int m, int k, fastllm::Data *weight, fastllm::Data *bias,
44 | float *input, float *output, LinearExType exType, DataType dataType);
45 |
46 | void RunTfaccMOEU(int n, int m, int k, int group, int groupCnt,
47 | std::vector weights, std::vector factors,
48 | std::vector *inputConfigs,
49 | uint8_t *uinput, float *output,
50 | DataType outputType);
51 |
52 | void AppendKVCache(long long uid, Data *content);
53 |
54 | void Attention(Data *q, Data *k, Data *v, int group, float scale, int maskType, Data *output);
55 | };
56 |
57 | void RegisterFastllmData(fastllm::Data *data, const std::string &weightType);
58 | }
59 |
60 | #endif
--------------------------------------------------------------------------------
/include/devices/tfacc/tfaccdevice.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 4/11/24.
3 | //
4 |
5 | #ifndef FASTLLM_TFACCDEVICE_H
6 | #define FASTLLM_TFACCDEVICE_H
7 |
8 | #include "device.h"
9 | #include "devices/cpu/cpudevice.h"
10 |
11 | namespace fastllm {
12 | class TfaccDevice : BaseDevice {
13 | public:
14 | TfaccDevice();
15 |
16 | // tfacc use cpu DDR
17 | bool Malloc (void **ret, size_t size);
18 | bool Free(void *ret);
19 |
20 | bool CopyDataToCPU(void *dst, void *src, size_t size);
21 | bool CopyDataFromCPU(void *dst, void *src, size_t size);
22 | };
23 |
24 | class TfaccLinearOp : CpuLinearOp {
25 | bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
26 | void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
27 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
28 | long long int Ops(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
29 | };
30 |
31 | class TfaccMergeMOE : CpuMergeMOE {
32 | bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
33 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
34 | };
35 |
36 | class TfaccCatDirectOp : CpuCatDirectOp {
37 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
38 | };
39 |
40 | class TfaccAttention : CpuAttention {
41 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
42 | };
43 |
44 | class TfaccAttentionBatchOp : CpuAttentionBatchOp {
45 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
46 | };
47 |
48 | class TfaccCatDirectBatchOp : CpuCatDirectBatchOp {
49 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
50 | };
51 | }
52 |
53 | #endif
--------------------------------------------------------------------------------
/include/devices/tops/topsdevice.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 2/24/25.
3 | //
4 |
5 | #ifndef FASTLLM_TOPSDEVICE_H
6 | #define FASTLLM_TOPSDEVICE_H
7 |
8 | #include "device.h"
9 | #include "devices/cpu/cpudevice.h"
10 |
11 | namespace fastllm {
12 | class TopsDevice : BaseDevice {
13 | public:
14 | TopsDevice();
15 |
16 | bool Malloc (void **ret, size_t size);
17 | bool Free(void *ret);
18 |
19 | bool CopyDataToCPU(void *dst, void *src, size_t size);
20 | bool CopyDataFromCPU(void *dst, void *src, size_t size);
21 | };
22 |
23 | class TopsLinearOp : CpuLinearOp {
24 | bool CanRun(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
25 | void Reshape(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
26 | void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
27 | long long int Ops(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
28 | };
29 | }
30 |
31 | #endif
--------------------------------------------------------------------------------
/include/executor.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/13/23.
3 | //
4 |
5 | #ifndef FASTLLM_EXECUTOR_H
6 | #define FASTLLM_EXECUTOR_H
7 |
8 | #include "device.h"
9 |
10 | namespace fastllm {
11 | class Executor {
12 | private:
13 | std::vector <BaseDevice*> devices;
14 | std::map <std::string, float> profiler;
15 |
16 | public:
17 | Executor (); // create the default Executor
18 |
19 | ~Executor(); // destructor
20 |
21 | void ClearDevices(); // clear all devices
22 |
23 | void AddDevice(BaseDevice *device); // add a device
24 |
25 | void SetFirstDevice(const std::string &device); // set the preferred device
26 |
27 | std::string GetFirstDeviceType(); // get the type of the preferred device
28 |
29 | std::vector <int> GetDeviceIds(const std::string &device); // get the deviceIds of the given device
30 |
31 | bool CanRunOnFirstDevice(const std::string &opType, const fastllm::DataDict &datas, const fastllm::FloatDict &floatParams,
32 | const fastllm::IntDict &intParams);
33 |
34 | // run a single op
35 | void Run(const std::string &opType, const fastllm::DataDict &datas, const fastllm::FloatDict &floatParams,
36 | const fastllm::IntDict &intParams);
37 |
38 | void ClearProfiler();
39 |
40 | void PrintProfiler();
41 | };
42 | }
43 |
44 | #endif //FASTLLM_EXECUTOR_H
45 |
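A rough Python sketch of the dispatch idea behind Run above, not the actual C++ (the real executor also moves data between devices and records profiling): devices are kept in priority order and the op goes to the first one whose CanRun accepts it.

```python
def run_op(devices, op_type, datas, float_params, int_params):
    # devices are kept in priority order; the preferred device is tried first
    for dev in devices:
        if dev.CanRun(op_type, datas, float_params, int_params):
            dev.Reshape(op_type, datas, float_params, int_params)
            dev.Run(op_type, datas, float_params, int_params)
            return dev
    raise RuntimeError("no device can run op " + op_type)
```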
--------------------------------------------------------------------------------
/include/model.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/20/23.
3 | //
4 |
5 | #ifndef FASTLLM_MODEL_H
6 | #define FASTLLM_MODEL_H
7 |
8 | #include "basellm.h"
9 | #include "bert.h"
10 | #include "xlmroberta.h"
11 |
12 | namespace fastllm {
13 | std::unique_ptr CreateEmbeddingModelFromFile(const std::string &fileName);
14 |
15 | std::unique_ptr CreateLLMModelFromFile(const std::string &fileName);
16 |
17 | std::unique_ptr CreateEmptyLLMModel(const std::string &modelType);
18 |
19 | std::unique_ptr CreateLLMModelFromHF(const std::string &modelPath,
20 | DataType linearDataType,
21 | int groupCnt = -1,
22 | bool skipTokenizer = false,
23 | const std::string &modelConfig = "",
24 | const std::string &loraPath = "",
25 | bool weightOnly = false,
26 | bool useMoeDataType = false,
27 | DataType moeDataType = DataType::FLOAT32,
28 | int moeGroupCnt = -1);
29 |
30 | void ExportLLMModelFromHF(const std::string &modelPath,
31 | DataType linearDataType,
32 | int groupCnt,
33 | const std::string &exportPath,
34 | const std::string &modelConfig = "",
35 | const std::string &loraPath = "",
36 | bool useMoeDataType = false,
37 | DataType moeDataType = DataType::FLOAT32,
38 | int moeGroupCnt = -1);
39 |
40 | std::unique_ptr CreateLLMTokenizerFromHF(const std::string &modelPath);
41 | }
42 |
43 | #endif //FASTLLM_MODEL_H
44 |
--------------------------------------------------------------------------------
/include/models/bert.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef FASTLLM_BERT_H
3 | #define FASTLLM_BERT_H
4 |
5 | #include "basellm.h"
6 | #include "fastllm.h"
7 |
8 | namespace fastllm {
9 | // Base class for BERT-style models
10 | // Supports ComputeScore: computes the similarity of two token sequences (used for rerankers)
11 | // Supports Embedding: produces an embedding vector for a token sequence
12 | class BertModel: public basellm {
13 | public:
14 | BertModel() {};
15 |
16 | ~BertModel() {
17 | this->weight.ReleaseWeight();
18 | };
19 |
20 | void InitParams(); // 初始化参数信息
21 |
22 | void Normalize(float *data, int dataLen);
23 |
24 | // 推理
25 | virtual std::vector > ForwardAll(
26 | const Data &inputIds,
27 | const Data &attentionMask,
28 | const Data &tokenTypeIds,
29 | const Data &positionIds,
30 | bool normalize);
31 |
32 | // 推理
33 | virtual int Forward(
34 | const Data &inputIds,
35 | const Data &attentionMask,
36 | const Data &positionIds,
37 | std::vector > &pastKeyValues,
38 | const GenerationConfig &generationConfig = GenerationConfig(),
39 | const LastTokensManager &lastTokens = LastTokensManager(),
40 | std::vector *logits = nullptr);
41 |
42 | virtual void FillBertInputsBatch(const std::vector > &tokens,
43 | Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds);
44 |
45 | // compute similarity scores
46 | // tokens: input tokens; tokens[i] is the i-th input token sequence
47 | // ret: ret[i] is the similarity score for the i-th input
48 | std::vector ComputeScore(std::vector > tokens);
49 |
50 | std::vector EmbeddingSentence(const std::vector &tokens, bool normalize);
51 |
52 | std::vector > EmbeddingSentenceBatch(const std::vector > &tokens, bool normalize);
53 |
54 | std::vector EmbeddingSentence(const std::string &context, bool normalize);
55 |
56 | std::vector > EmbeddingSentenceBatch(const std::vector &contexts, bool normalize);
57 |
58 | void LoadFromFile(const std::string &fileName); // 从文件读取
59 |
60 | void WarmUp(); // 预热
61 |
62 | virtual std::string MakeInput(const std::string &history, int round, const std::string &input);
63 |
64 | virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output);
65 |
66 | std::string model_type;
67 |
68 | float layer_norm_eps = 1e-12;
69 |
70 | int embed_dim = 512;
71 | int num_attention_heads = 64;
72 | int head_dim = embed_dim / num_attention_heads;
73 | int max_positions = 32768;
74 | int block_cnt = 12;
75 |
76 | std::map deviceMap;
77 | };
78 | }
79 |
80 | #endif //FASTLLM_BERT_H
--------------------------------------------------------------------------------
/include/models/factoryllm.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include "chatglm.h"
3 | #include "moss.h"
4 | #include "basellm.h"
5 | #include "llama.h"
6 | #include "qwen.h"
7 | #include "fastllm.h"
8 |
9 | enum LLM_TYPE {
10 | LLM_TYPE_CHATGLM = 0,
11 | LLM_TYPE_MOSS = 1,
12 | LLM_TYPE_VICUNA = 2,
13 | LLM_TYPE_BAICHUAN = 3,
14 | LLM_TYPE_QWEN = 4
15 | };
16 |
17 | class factoryllm {
18 | public:
19 | factoryllm() {};
20 |
21 | ~factoryllm() {};
22 |
23 | fastllm::basellm *createllm(LLM_TYPE type) {
24 | fastllm::basellm *pLLM = NULL;
25 | switch (type) {
26 | case LLM_TYPE_CHATGLM:
27 | pLLM = new fastllm::ChatGLMModel();
28 | break;
29 | case LLM_TYPE_MOSS:
30 | pLLM = new fastllm::MOSSModel();
31 | break;
32 | case LLM_TYPE_VICUNA:
33 | pLLM = new fastllm::LlamaModel();
34 | break;
35 | default:
36 | pLLM = new fastllm::QWenModel();
37 | break;
38 | }
39 | return pLLM;
40 | };
41 | };
--------------------------------------------------------------------------------
/include/models/glm.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/11/23.
3 | //
4 |
5 | #ifndef FASTLLM_GLM_H
6 | #define FASTLLM_GLM_H
7 |
8 | #include "basellm.h"
9 | #include "cmath"
10 |
11 | #include
12 |
13 | namespace fastllm {
14 | class GLMModel: public basellm {
15 | public:
16 | GLMModel (); // 构造函数
17 |
18 | // 推理
19 | virtual int Forward(
20 | const Data &inputIds,
21 | const Data &attentionMask,
22 | const Data &positionIds,
23 | std::vector > &pastKeyValues,
24 | const GenerationConfig &generationConfig = GenerationConfig(),
25 | const LastTokensManager &lastTokens = LastTokensManager(),
26 | std::vector *logits = nullptr);
27 |
28 | std::vector ForwardBatch(
29 | int batch,
30 | const Data &inputIds,
31 | const Data &attentionMask,
32 | const Data &positionIds,
33 | std::vector > &pastKeyValues,
34 | const GenerationConfig &generationConfig = GenerationConfig(),
35 | const LastTokensManager &lastTokens = LastTokensManager(),
36 | std::vector *> *retLogits = nullptr);
37 |
38 | // build the LLM inference inputs from the input tokens
39 | virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
40 | const std::map <std::string, int> &params,
41 | Data &inputIds, Data &attentionMask, Data &positionIds);
42 |
43 | virtual void InitParams();
44 | virtual void WarmUp(); // warm up
45 |
46 | virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // build the prompt from the history and the current input
47 |
48 | virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // update the history with the current reply
49 |
50 | private:
51 |
52 | float scale_attn_1;
53 |
54 | static constexpr int eot_token_id = 50000;//<|endoftext|>
55 | static constexpr int cls_token_id = 50002;//[CLS]
56 | static constexpr int mask_token_id = 50003;//[MASK]
57 | static constexpr int smask_token_id = 50008;//[sMASK]
58 | static constexpr int gmask_token_id = 50009;//[gMASK]
59 | };
60 | }
61 |
62 | #endif //FASTLLM_GLM_H
63 |
--------------------------------------------------------------------------------
/include/models/internlm2.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by tylunasli on 3/14/24.
3 | //
4 |
5 | #ifndef FASTLLM_INTERNLM2_H
6 | #define FASTLLM_INTERNLM2_H
7 |
8 | #include "basellm.h"
9 | #include "llama.h"
10 | #include "cmath"
11 |
12 | #include
13 |
14 | namespace fastllm {
15 | class Internlm2Model : public LlamaModel {
16 | public:
17 | Internlm2Model(); // 构造函数
18 |
19 | virtual void InitParams(); // 初始化参数信息
20 |
21 | // 推理
22 | virtual int Forward(
23 | const Data &inputIds,
24 | const Data &attentionMask,
25 | const Data &positionIds,
26 | std::vector > &pastKeyValues,
27 | const GenerationConfig &generationConfig = GenerationConfig(),
28 | const LastTokensManager &lastTokens = LastTokensManager(),
29 | std::vector *logits = nullptr);
30 |
31 | std::vector ForwardBatch(
32 | int batch,
33 | const Data &inputIds,
34 | const Data &attentionMask,
35 | const Data &positionIds,
36 | std::vector > &pastKeyValues,
37 | const GenerationConfig &generationConfig = GenerationConfig(),
38 | const LastTokensManager &lastTokens = LastTokensManager(),
39 | std::vector *> *logits = nullptr);
40 |
41 | std::vector ForwardBatch(
42 | int batch,
43 | const Data &inputIds,
44 | const std::vector &attentionMask,
45 | const std::vector &positionIds,
46 | const std::vector &seqLens,
47 | std::vector > &pastKeyValues,
48 | const std::vector &generationConfigs,
49 | const LastTokensManager &lastTokens = LastTokensManager(),
50 | std::vector *> *logits = nullptr);
51 | };
52 | }
53 |
54 | #endif //FASTLLM_INTERNLM2_H
55 |
--------------------------------------------------------------------------------
/include/models/minicpm.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/1/23.
3 | //
4 |
5 | #ifndef FASTLLM_MINICPM_H
6 | #define FASTLLM_MINICPM_H
7 |
8 | #include "basellm.h"
9 | #include "llama.h"
10 | #include "cmath"
11 |
12 | #include
13 |
14 | namespace fastllm {
15 | class MiniCpmModel: public LlamaModel {
16 | public:
17 | MiniCpmModel(); // 构造函数
18 |
19 | virtual void InitParams(); // 初始化参数信息
20 |
21 | // 推理
22 | virtual int Forward(
23 | const Data &inputIds,
24 | const Data &attentionMask,
25 | const Data &positionIds,
26 | std::vector > &pastKeyValues,
27 | const GenerationConfig &generationConfig = GenerationConfig(),
28 | const LastTokensManager &lastTokens = LastTokensManager(),
29 | std::vector *logits = nullptr);
30 |
31 | std::vector ForwardBatch(
32 | int batch,
33 | const Data &inputIds,
34 | const Data &attentionMask,
35 | const Data &positionIds,
36 | std::vector > &pastKeyValues,
37 | const GenerationConfig &generationConfig = GenerationConfig(),
38 | const LastTokensManager &lastTokens = LastTokensManager(),
39 | std::vector *> *logits = nullptr);
40 |
41 | std::vector ForwardBatch(
42 | int batch,
43 | const Data &inputIds,
44 | const std::vector &attentionMask,
45 | const std::vector &positionIds,
46 | const std::vector &seqLens,
47 | std::vector > &pastKeyValues,
48 | const std::vector &generationConfigs,
49 | const LastTokensManager &lastTokens = LastTokensManager(),
50 | std::vector *> *logits = nullptr);
51 |
52 | private:
53 | float embed_scale = 1.f;
54 |
55 | float attention_scale = 1.f / std::sqrt(block_cnt);
56 |
57 | float rms_scale = 1.f / 4096.f;
58 | };
59 | }
60 |
61 | #endif //FASTLLM_MINICPM_H
62 |
--------------------------------------------------------------------------------
/include/models/minicpm3.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 6/1/23.
3 | //
4 |
5 | #ifndef FASTLLM_MINICPM3_H
6 | #define FASTLLM_MINICPM3_H
7 |
8 | #include "basellm.h"
9 | #include "llama.h"
10 | #include "cmath"
11 |
12 | #include
13 |
14 | namespace fastllm {
15 | class MiniCpm3Model: public LlamaModel {
16 | public:
17 | MiniCpm3Model(); // 构造函数
18 |
19 | virtual void InitParams(); // 初始化参数信息
20 |
21 | // 推理
22 | virtual int Forward(
23 | const Data &inputIds,
24 | const Data &attentionMask,
25 | const Data &positionIds,
26 | std::vector > &pastKeyValues,
27 | const GenerationConfig &generationConfig = GenerationConfig(),
28 | const LastTokensManager &lastTokens = LastTokensManager(),
29 | std::vector *logits = nullptr);
30 |
31 | std::vector ForwardBatch(
32 | int batch,
33 | const Data &inputIds,
34 | const Data &attentionMask,
35 | const Data &positionIds,
36 | std::vector > &pastKeyValues,
37 | const GenerationConfig &generationConfig = GenerationConfig(),
38 | const LastTokensManager &lastTokens = LastTokensManager(),
39 | std::vector *> *logits = nullptr);
40 |
41 | std::vector ForwardBatch(
42 | int batch,
43 | const Data &inputIds,
44 | const std::vector &attentionMask,
45 | const std::vector &positionIds,
46 | const std::vector &seqLens,
47 | std::vector > &pastKeyValues,
48 | const std::vector &generationConfigs,
49 | const LastTokensManager &lastTokens = LastTokensManager(),
50 | std::vector *> *logits = nullptr);
51 |
52 | private:
53 | float embed_scale = 1.f;
54 |
55 | float attention_scale = 1.f / std::sqrt(block_cnt);
56 |
57 | float rms_scale = 1.f / 4096.f;
58 |
59 | int hidden_size = 2560;
60 | int qk_nope_head_dim = 64;
61 | int qk_rope_head_dim = 32;
62 | int kv_lora_rank = 256;
63 | };
64 | }
65 |
66 | #endif //FASTLLM_MINICPM3_H
67 |
--------------------------------------------------------------------------------
/include/models/moss.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/12/23.
3 | //
4 |
5 | #ifndef TEST_MOSS_H
6 | #define TEST_MOSS_H
7 |
8 | #include "basellm.h"
9 | #include "cmath"
10 |
11 | namespace fastllm {
12 | class MOSSModel: public basellm {
13 | public:
14 | MOSSModel(); // 构造函数
15 |
16 | // 推理
17 | virtual int Forward(
18 | const Data &inputIds,
19 | const Data &attentionMask,
20 | const Data &positionIds,
21 | std::vector > &pastKeyValues,
22 | const GenerationConfig &generationConfig = GenerationConfig(),
23 | const LastTokensManager &lastTokens = LastTokensManager(),
24 | std::vector *logits = nullptr);
25 |
26 | virtual std::string Response(const std::string &input, RuntimeResult retCb,
27 | const GenerationConfig &generationConfig = GenerationConfig()); // reply to the given input
28 |
29 | virtual std::string MakeInput(const std::string &history, int round, const std::string &input); // build the prompt from the history and the current input
30 |
31 | virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output); // update the history with the current reply
32 |
33 | virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
34 | const std::map <std::string, int> &params,
35 | Data &inputIds, Data &attentionMask, Data &positionIds);
36 |
37 | virtual void WarmUp();
38 | private:
39 | virtual void RotatePosition2D(Data &data, const Data &positionIds); // 2D position encoding
40 |
41 | virtual void CausalMask(Data &data, int start); // causal mask
42 | };
43 | }
44 |
45 | #endif //TEST_MOSS_H
46 |
--------------------------------------------------------------------------------
/include/models/phi3.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by tylunasli on 9/13/24.
3 | //
4 |
5 | #ifndef FASTLLM_PHI3_H
6 | #define FASTLLM_PHI3_H
7 |
8 | #include "basellm.h"
9 | #include "llama.h"
10 | #include "cmath"
11 |
12 | #include
13 |
14 | namespace fastllm {
15 | class Phi3Model : public LlamaModel {
16 | public:
17 | Phi3Model(); // 构造函数
18 |
19 | virtual void InitParams(); // 初始化参数信息
20 |
21 | // 推理
22 | virtual int Forward(
23 | const Data &inputIds,
24 | const Data &attentionMask,
25 | const Data &positionIds,
26 | std::vector > &pastKeyValues,
27 | const GenerationConfig &generationConfig = GenerationConfig(),
28 | const LastTokensManager &lastTokens = LastTokensManager(),
29 | std::vector *logits = nullptr);
30 |
31 | std::vector ForwardBatch(
32 | int batch,
33 | const Data &inputIds,
34 | const Data &attentionMask,
35 | const Data &positionIds,
36 | std::vector > &pastKeyValues,
37 | const GenerationConfig &generationConfig = GenerationConfig(),
38 | const LastTokensManager &lastTokens = LastTokensManager(),
39 | std::vector *> *logits = nullptr);
40 |
41 | std::vector ForwardBatch(
42 | int batch,
43 | const Data &inputIds,
44 | const std::vector &attentionMask,
45 | const std::vector &positionIds,
46 | const std::vector &seqLens,
47 | std::vector > &pastKeyValues,
48 | const std::vector &generationConfigs,
49 | const LastTokensManager &lastTokens = LastTokensManager(),
50 | std::vector *> *logits = nullptr);
51 | };
52 | }
53 |
54 | #endif //FASTLLM_PHI3_H
55 |
--------------------------------------------------------------------------------
/include/models/qwen.h:
--------------------------------------------------------------------------------
1 | //
2 | // Created by siemon on 8/9/23.
3 | //
4 |
5 | #ifndef TEST_QWEN_H
6 | #define TEST_QWEN_H
7 |
8 | #include "basellm.h"
9 |
10 | namespace fastllm {
11 | class QWenModel : public basellm {
12 | public:
13 | QWenModel();
14 |
15 | // 推理
16 | virtual int Forward(
17 | const Data &inputIds,
18 | const Data &attentionMask,
19 | const Data &positionIds,
20 | std::vector > &pastKeyValues,
21 | const GenerationConfig &generationConfig = GenerationConfig(),
22 | const LastTokensManager &lastTokens = LastTokensManager(),
23 | std::vector *logits = nullptr);
24 |
25 | std::vector ForwardBatch(
26 | int batch,
27 | const Data &inputIds,
28 | const Data &attentionMask,
29 | const Data &positionIds,
30 | std::vector > &pastKeyValues,
31 | const GenerationConfig &generationConfig = GenerationConfig(),
32 | const LastTokensManager &lastTokens = LastTokensManager(),
33 | std::vector *> *logits = nullptr);
34 |
35 | std::vector ForwardBatch(
36 | int batch,
37 | const Data &inputIds,
38 | const std::vector &attentionMask,
39 | const std::vector &positionIds,
40 | const std::vector &seqLens,
41 | std::vector > &pastKeyValues,
42 | const std::vector &generationConfigs,
43 | const LastTokensManager &lastTokens = LastTokensManager(),
44 | std::vector *> *retLogits = nullptr);
45 |
46 | virtual std::string MakeInput(const std::string &history, int round, const std::string &input);
47 |
48 | virtual std::string MakeHistory(const std::string &history, int round, const std::string &input, const std::string &output);
49 |
50 | virtual void FillLLMInputs(std::vector <std::vector <float> > &inputTokens,
51 | const std::map <std::string, int> &params,
52 | Data &inputIds, Data &attentionMask, Data &positionIds);
53 |
54 | virtual void FillLLMInputsBatch(std::vector <std::vector <float> > &inputTokens,
55 | const std::vector <std::map <std::string, int> > &params,
56 | Data &inputIds, Data &attentionMask, Data &positionIds);
57 |
58 | virtual void WarmUp();
59 |
60 | void UpdateRotaryPosEmb(float ntk_alpha);
61 |
62 | int seq_length;
63 | float ntk_alpha;
64 |
65 | bool use_log_attn;
66 | Data logn_list;
67 |
68 | private:
69 | std::string im_start = "<|im_start|>";
70 | std::string im_end = "<|im_end|>";
71 | };
72 | }
73 |
74 | #endif //TEST_QWEN_H
--------------------------------------------------------------------------------
/include/models/xlmroberta.h:
--------------------------------------------------------------------------------
1 |
2 | #ifndef FASTLLM_XLMROBERTA_H
3 | #define FASTLLM_XLMROBERTA_H
4 |
5 | #include "basellm.h"
6 | #include "bert.h"
7 | #include "fastllm.h"
8 |
9 | namespace fastllm {
10 | class XlmRobertaModel : BertModel {
11 | public:
12 | XlmRobertaModel();
13 |
14 | ~XlmRobertaModel() {
15 | this->weight.ReleaseWeight();
16 | };
17 |
18 | void InitParams(); // 初始化参数信息
19 |
20 | void FillBertInputsBatch(const std::vector > &tokens,
21 | Data &inputIds, Data &attentionMask, Data &tokenTypeIds, Data &positionIds);
22 |
23 | // 推理
24 | std::vector > ForwardAll(
25 | const Data &inputIds,
26 | const Data &attentionMask,
27 | const Data &tokenTypeIds,
28 | const Data &positionIds,
29 | bool normalize);
30 |
31 | std::string model_type;
32 |
33 | float layer_norm_eps = 1e-12;
34 |
35 | int embed_dim = 512;
36 | int num_attention_heads = 64;
37 | int head_dim = embed_dim / num_attention_heads;
38 | int max_positions = 32768;
39 | int block_cnt = 12;
40 |
41 | std::map deviceMap;
42 | };
43 | }
44 |
45 | #endif //FASTLLM_XLMROBERTA_H
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | folder="build-fastllm"
3 |
4 | # create the build folder
5 | if [ ! -d "$folder" ]; then
6 | mkdir "$folder"
7 | fi
8 |
9 | cd $folder
10 | cmake .. "$@"
11 | make -j$(nproc)
12 |
13 | # stop if the build failed
14 | if [ $? != 0 ]; then
15 | exit -1
16 | fi
17 |
18 | cd tools
19 | pip install .[all]
20 | #python3 setup.py sdist build
21 | #python3 setup.py bdist_wheel
22 | #python3 setup.py install --all
--------------------------------------------------------------------------------
/make_whl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | folder="build-fastllm"
3 |
4 | # create the build folder
5 | rm -rf "$folder"
6 | mkdir "$folder"
7 | cd $folder
8 |
9 | # cpu
10 | rm -rf CMakeCache.txt CMakeFiles
11 | cmake .. -DMAKE_WHL_X86=ON -DUSE_CUDA=OFF -DUSE_NUMA=ON
12 | make fastllm_tools -j$(nproc)
13 | if [ $? != 0 ]; then
14 | exit -1
15 | fi
16 | cp tools/ftllm/libfastllm_tools.so tools/ftllm/libfastllm_tools-cpu.so
17 |
18 | # cuda-10
19 | #rm -rf CMakeCache.txt CMakeFiles
20 | #cmake .. -DMAKE_WHL_X86=ON -DUSE_CUDA=ON -DCUDA_ARCH=70 -D CMAKE_CUDA_COMPILER=/usr/local/cuda-10.1/bin/nvcc
21 | #make fastllm_tools -j$(nproc)
22 | #if [ $? != 0 ]; then
23 | # exit -1
24 | #fi
25 | #cp tools/ftllm/libfastllm_tools.so tools/ftllm/libfastllm_tools-cu10.so
26 |
27 | # cuda-11
28 | rm -rf CMakeCache.txt CMakeFiles
29 | cmake .. -DMAKE_WHL_X86=ON -DUSE_CUDA=ON -DUSE_NUMA=ON -DCUDA_ARCH="52;53;70" -D CMAKE_CXX_COMPILER=g++-10 -D CMAKE_CUDA_HOST_COMPILER=/usr/bin/g++-10 -D CMAKE_CUDA_COMPILER=/usr/local/cuda-11.3/bin/nvcc
30 | make fastllm_tools -j$(nproc)
31 | if [ $? != 0 ]; then
32 | exit -1
33 | fi
34 | cp tools/ftllm/libfastllm_tools.so tools/ftllm/libfastllm_tools-cu11.so
35 |
36 | # cuda-12
37 | rm -rf CMakeCache.txt CMakeFiles
38 | cmake .. -DMAKE_WHL_X86=ON -DUSE_CUDA=ON -DUSE_NUMA=ON -DCUDA_ARCH="52;53;70;89" -D CMAKE_CXX_COMPILER=g++-11 -D CMAKE_CUDA_HOST_COMPILER=/usr/bin/g++-11 -D CMAKE_CUDA_COMPILER=/usr/local/cuda-12.1/bin/nvcc
39 | make fastllm_tools -j$(nproc)
40 | if [ $? != 0 ]; then
41 | exit -1
42 | fi
43 |
44 | cd tools
45 | ldd ftllm/libfastllm_tools.so | grep '=>' | awk '{print $3}' | grep 'libnuma' | xargs -I {} cp -n {} ftllm/.
46 | python3 setup.py sdist build
47 | python3 setup.py bdist_wheel --plat-name manylinux2014_$(uname -m)
48 | #python3 setup.py install --all
--------------------------------------------------------------------------------
/make_whl_rocm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source ~/ftllm/bin/activate
3 | pip install setuptools wheel -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
4 |
5 | folder="build-fastllm-rocm"
6 |
7 | # create the build folder
8 | rm -rf "$folder"
9 | mkdir "$folder"
10 | cd $folder
11 |
12 | # cpu
13 | rm -rf CMakeCache.txt CMakeFiles
14 | cmake .. -DMAKE_WHL_X86=ON -DUSE_ROCM=OFF -DUSE_NUMA=ON
15 | make fastllm_tools -j$(nproc)
16 | if [ $? != 0 ]; then
17 | exit -1
18 | fi
19 | cp tools/ftllm/libfastllm_tools.so tools/ftllm/libfastllm_tools-cpu.so
20 |
21 | # rocm
22 | rm -rf CMakeCache.txt CMakeFiles
23 | cmake .. -DMAKE_WHL_X86=ON -DUSE_ROCM=ON -DUSE_NUMA=ON -DROCM_ARCH="gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101"
24 | make fastllm_tools -j$(nproc)
25 | if [ $? != 0 ]; then
26 | exit -1
27 | fi
28 |
29 | cd tools
30 | ldd ftllm/libfastllm_tools.so | grep '=>' | awk '{print $3}' | grep 'libnuma' | xargs -I {} cp -n {} ftllm/.
31 | python3 setup_rocm.py sdist build
32 | python3 setup_rocm.py bdist_wheel --plat-name manylinux2014_$(uname -m)
--------------------------------------------------------------------------------
/pyfastllm/examples/convert_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from transformers import AutoTokenizer, AutoModel
3 |
4 | import fastllm
5 |
6 | def export():
7 | model_path = '/public/Models/chatglm-6b' # only fp32 model loading is supported
8 | export_path = "chatglm-6b-fp32.flm"
9 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
10 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True).float()
11 | model = model.eval()
12 |
13 | fastllm.utils.convert(model=model, tokenizer=tokenizer, output_path=export_path, verbose=True)
14 |
15 | def response(model, prompt_input:str, stream_output:bool=False):
16 | gmask_token_id = 130001
17 | bos_token_id = 130004
18 |
19 | input_ids = model.weight.tokenizer.encode(prompt_input)
20 | input_ids = input_ids.to_list()
21 | input_ids.extend([gmask_token_id, bos_token_id])
22 | input_ids = [int(v) for v in input_ids]
23 |
24 | handle = model.launch_response(input_ids)
25 | continue_token = True
26 |
27 | ret_byte = b""
28 | ret_str = ""
29 |
30 | while continue_token:
31 | resp_token = model.fetch_response(handle)
32 | continue_token = (resp_token != -1)
33 |
34 | content = model.weight.tokenizer.decode_byte([resp_token])
35 | ret_byte += content
36 | ret_str = ret_byte.decode(errors='ignore')
37 |
38 | if stream_output:
39 | yield ret_str
40 |
41 | return ret_str
42 |
43 | def infer():
44 | model_path = "chatglm-6b-fp32.flm"
45 | model = fastllm.create_llm(model_path)
46 |
47 | prompt = "你好"
48 | outputs = response(model, prompt_input=prompt, stream_output=True)
49 | for output in outputs:
50 | print('\r LLM:' + output, end='', flush=True)
51 |
52 | print()
53 |
54 |
55 | if __name__ == "__main__":
56 | # export()
57 | infer()
--------------------------------------------------------------------------------
/pyfastllm/examples/test_chatglm2.py:
--------------------------------------------------------------------------------
1 | import fastllm
2 | from fastllm.hub.chatglm2 import ChatGLM2, ChatGLMConfig
3 | import fastllm.functions as ops
4 |
5 | from transformers import AutoTokenizer
6 |
7 | def load_weights():
8 | file = "/home/pan/Public/Models/models-flm/chatglm2-6b.flm"
9 | state_dict = ops.load(file)
10 | return state_dict
11 |
12 | def run():
13 | # fastllm.set_device_map({"cuda:0": 28})
14 | state_dict = load_weights()
15 | cfg = ChatGLMConfig()
16 | model = ChatGLM2(cfg)
17 | model.set_weights(state_dict)
18 | print("model loaded!!!")
19 |
20 | model_path = "/home/pan/Public/Models/models-hf/chatglm2-6b"
21 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
22 | # model.warmup()
23 | res = ""
24 | for output in model.stream_chat(query="飞机为什么会飞", tokenizer=tokenizer):
25 | res = output
26 |
27 | print("最终问答", res)
28 |
29 | if __name__ == "__main__":
30 | run()
31 |
32 |
--------------------------------------------------------------------------------
/pyfastllm/examples/test_ops.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pytest
3 | import numpy as np
4 | import fastllm
5 |
6 | import pyfastllm
7 | import gc
8 |
9 | import np_ops
10 | import ops as flm_ops
11 |
12 | # from fastllm import ops as flm_ops
13 | # from fastllm import np_ops
14 |
15 | np.random.seed(42)
16 |
17 | def diff(dataA, dataB):
18 | # print(dataA)
19 | # print(dataB)
20 | mae = np.max(np.abs(dataA - dataB))
21 | print('max abs err is ', mae)
22 | return mae
23 |
24 | def to_tensor(data):
25 | return pyfastllm.from_numpy(data)
26 |
27 | def to_numpy(data):
28 | # return data.numpy()
29 | return np.array(data, copy=False, order='C')
30 |
31 | def test_rms_norm(inputs=None, weights=None, eps=1e-6):
32 | if not inputs:
33 | inputs = np.random.random(size=[1, 256])
34 | weights = np.random.random(size=[1, 256])
35 |
36 | np_out = np_ops.rms_norm(inputs, weights, eps)
37 | flm_out = flm_ops.rms_norm(to_tensor(inputs), to_tensor(weights), eps)
38 | mae = diff(np_out, to_numpy(flm_out))
39 | assert mae <= 1e-6
40 | return flm_out
41 |
42 | def test_swiglu(inputs=None):
43 | if not inputs:
44 | inputs = np.random.random(size=[1, 256])
45 |
46 | np_out = np_ops.swiglu(inputs)
47 | out = flm_ops.activation(inputs=to_tensor(inputs), activate_type="swiglu")
48 | mae = diff(np_out, to_numpy(out))
49 | assert mae <= 1e-6
50 | return out
51 |
52 | def test_attention(q=None, k=None, v=None, mask=None, group=1, scale=1.0):
53 | if q is None:
54 | q = np.random.random(size=[12, 1, 4096])
55 | k = np.random.random(size=[12, 1, 4096])
56 | v = np.random.random(size=[12, 1, 4096])
57 | scale = 1 / np.sqrt(q.shape[-1])
58 |
59 | np_out = np_ops.attention(q, k, v, scale=scale)
60 |
61 | mask = fastllm.Tensor()
62 | flm_out = flm_ops.attention(to_tensor(q), to_tensor(k), to_tensor(v), mask, group=group, scale=scale, attentionType=0)
63 |
64 | mae = diff(np_out, to_numpy(flm_out))
65 | assert mae <= 1e-6
66 | return flm_out
67 |
68 |
69 | def test_linear(inputs=None,
70 | weights=None,
71 | bias=None):
72 |
73 | if not inputs:
74 | inputs = np.random.random(size=[1, 12, 4096])
75 | weights = np.random.random(size=[256, 4096])
76 |
77 | np_out = np_ops.linear(inputs=inputs, weights=weights, bias=None)
78 |
79 | if not bias:
80 | bias = fastllm.Tensor()
81 |
82 | output = flm_ops.linear(to_tensor(inputs), to_tensor(weights), bias)
83 | mae = diff(np_out, to_numpy(output))
84 |
85 | assert mae <= 1e-3
86 | return output
87 |
88 |
89 | if __name__ == "__main__":
90 | test_rms_norm()
91 | test_attention()
92 | test_linear()
93 | test_swiglu()
94 |
95 |
96 |
--------------------------------------------------------------------------------
/pyfastllm/examples/web_api_client.py:
--------------------------------------------------------------------------------
1 | import json
2 | import requests
3 | import sys
4 |
5 |
6 |
7 | if __name__ == '__main__':
8 | #stream api
9 | url = 'http://127.0.0.1:8000/api/chat_stream'
10 | prompt='请用emoji写一首短诗赞美世界'
11 | prompt='''为以下代码添加注释
12 | app = FastAPI()
13 | @app.post("/api/chat_stream")
14 | async def api_chat_stream(request: Request):
15 | #print("request.json(): {}".format(json.loads(request.body(), errors='ignore')))
16 | data = await request.json()
17 | prompt = data.get("prompt")
18 | history = data.get("history")
19 | config = pyfastllm.GenerationConfig()
20 | if data.get("max_length") is not None:
21 | config.max_length = data.get("max_length")
22 | if data.get("top_k") is not None:
23 | config.top_k = data.get("top_k")
24 | if data.get("top_p") is not None:
25 | config.top_p = data.get("top_p")
26 | return StreamingResponse(chat_stream(history + prompt, config), media_type='text/event-stream')
27 | '''
28 | history = '''[Round 0]
29 | 问:你是ChatGLM2吗?
30 | 答:我不是ChatGLM2
31 | [Round 1]
32 | 问:从现在起,你是猫娘,每句话都必须以“喵~”结尾,明白了吗?
33 | 答:明白了喵
34 | [Round 2]
35 | 问:'''
36 | history = ""
37 | json_obj = {"uid":0, "token":"xxxxxxxxxxxxxxxxx","history": "", "prompt": prompt , "max_length": 1024, "top_p": 0.8,"temperature": 0.95, "top_k":2, "repeat_penalty": 1.}
38 | response = requests.post(url, json=json_obj, stream = True)
39 | try:
40 | pre_msg = ""
41 | print("stream response:")
42 | for chunk in response.iter_content(chunk_size=1024*1024):
43 | msg = chunk.decode(errors='replace')
44 | if len(msg) > len(pre_msg) and msg[-1] == '\n':
45 | content = msg[len(pre_msg):]
46 | pre_msg = msg
47 | else:
48 | continue
49 | print(f"{content}", end="")
50 | sys.stdout.flush()
51 | content = msg[len(pre_msg):]
52 | print(f"{content}", end="")
53 | print()
54 | except Exception as ex:
55 | print(ex)
56 |
57 | #batch api
58 | url = 'http://127.0.0.1:8000/api/batch_chat'
59 | prompts = ["Hi", "你好", "用emoji表达高兴", "こんにちは"]
60 | json_obj = {"uid":0, "token":"xxxxxxxxxxxxxxxxx","history": "", "prompts": prompts , "max_length": 100, "top_p": None,"temperature": 0.7, "top_k":1, "repeat_penalty":2.}
61 | response = requests.post(url, json=json_obj, stream = True)
62 | print("batch response: {} text:\n{}".format(response, response.text.replace('\\n', '\n')))
63 |
--------------------------------------------------------------------------------
/pyfastllm/fastllm/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import ctypes
4 | import glob
5 |
6 | from pyfastllm import *
7 | from . import utils
8 | from . import functions as ops
9 |
10 | __version__ = "0.2.0"
11 |
12 |
--------------------------------------------------------------------------------
/pyfastllm/fastllm/convert.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | import sys
4 | import struct
5 | import numpy as np
6 | import argparse
7 |
8 | from .utils import convert
9 | from .utils.converter import QuantType
10 |
11 | def parse_args():
12 | # -p model path or huggingface path
13 | # -o --out_path export path
14 | # -q quantization bits
15 | parser = argparse.ArgumentParser(description='build fastllm libs')
16 | parser.add_argument('-o', dest='export_path', default=None,
17 | help='output export path')
18 | parser.add_argument('-p', dest='model_path', type=str, default='',
19 | help='the model path or huggingface path, such as: -p THUDM/chatglm-6b')
20 | parser.add_argument('--lora', dest='lora_path', default='',
21 | help='lora model path')
22 | parser.add_argument('-m', dest='model', default='chatglm6B',
23 | help='model name (one of: alpaca, baichuan7B, chatglm6B, moss)')
24 | parser.add_argument('-q', dest='q_bit', type=int,
25 | help='model quantization bit')
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | def main(args=None):
31 | if not args: args = parse_args()
32 |
33 | quant_type_to_qbit = {
34 | QuantType.FP32: 32,
35 | QuantType.FP16: 16,
36 | QuantType.INT8: 8,
37 | QuantType.INT4: 4,
38 | }
39 | qbit_to_quant_type = {v: k for k, v in quant_type_to_qbit.items()}
40 | q_type = qbit_to_quant_type[args.q_bit]
41 | convert(args.model_path, args.export_path, q_type=q_type)
42 |
43 | if __name__ == "__main__":
44 | args = parse_args()
45 | main(args)
--------------------------------------------------------------------------------
/pyfastllm/fastllm/functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .fastllm_ops import *
2 | from . import util
--------------------------------------------------------------------------------
/pyfastllm/fastllm/functions/custom_ops.py:
--------------------------------------------------------------------------------
1 | import triton.language as tl
--------------------------------------------------------------------------------
/pyfastllm/fastllm/functions/numpy_ops.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numba import cuda
3 | from numba import jit
4 |
5 | @jit(nopython=True)
6 | def rms_norm(inputs, weights, eps):
7 | channel = inputs.shape[-1]
8 | sqrt_mean = np.sqrt(np.sum(inputs**2)/channel + eps)
9 | return inputs / sqrt_mean *weights
10 |
11 | @jit(nopython=True)
12 | def layer_norm(inputs, gamma, beta, axis=-1):
13 | assert axis < len(inputs.shape), "axis should be less than the number of input dims"
14 | channel = inputs.shape[axis]
15 | mean = np.mean(inputs, axis=axis)
16 | var = np.var(inputs, axis=axis)
17 |
18 | output = (inputs - mean) / np.sqrt(var) * gamma + beta
19 | return output
20 |
21 | # @jit
22 | def softmax(inputs, axis=None):
23 | maxv = inputs.max(axis, keepdims=True)
24 | exp_v = np.exp(inputs - maxv)
25 | exp_sum = np.sum(exp_v, axis=axis)
26 | return exp_v / exp_sum
27 |
28 | @jit(nopython=True)
29 | def silu(inputs, ):
30 | return inputs / (1 + np.exp(-inputs))
31 |
32 | @jit
33 | def swiglu(inputs, ):
34 | dim = inputs.shape[1] // 2
35 | # silu on the first half, gated elementwise by the second half (vectorized over the batch dim)
36 | return inputs[:, :dim] / (1 + np.exp(-inputs[:, :dim])) * inputs[:, dim:]
37 |
38 | # @jit
39 | def linear(inputs, weights, bias):
40 | if len(inputs.shape) == 2:
41 | inputs = inputs[None, :]
42 | weights = weights[None, :]
43 |
44 | output = np.zeros(shape=[inputs.shape[0], inputs.shape[1], weights.shape[0]])
45 | for batch in range(inputs.shape[0]):
46 | output[batch] = np.matmul(inputs[batch], weights.T)
47 |
48 | if bias is not None:
49 | output += bias
50 |
51 | return output
52 |
53 | # @jit
54 | def attention(q, k, v, mask=None, group=None, scale=None):
55 | print("shape:", q.shape)
56 | if len(q.shape) == 2:
57 | q = q[None, :]
58 | k = k[None, :]
59 | v = v[None, :]
60 | # mask = mask[None, :]
61 |
62 | attn = np.zeros_like(q)
63 | for batch in range(q.shape[0]):
64 | qk = softmax(q[batch] @ k[batch].T * scale, axis=-1)
65 | attn[batch, :, :] = qk @ v[batch]
66 | return attn
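A quick sanity-check sketch for the rms_norm reference above: for a single row, the sum-based form used in the file and an explicit mean-over-the-last-axis form agree to floating-point precision.

```python
import numpy as np

def rms_norm_explicit(x, w, eps=1e-6):
    # root-mean-square over the last axis, then scale by the weights
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps) * w

x = np.random.random(size=[1, 256])
w = np.random.random(size=[1, 256])
direct = x / np.sqrt(np.sum(x ** 2) / x.shape[-1] + 1e-6) * w   # the form rms_norm above uses
assert np.max(np.abs(rms_norm_explicit(x, w) - direct)) < 1e-6
```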
--------------------------------------------------------------------------------
/pyfastllm/fastllm/functions/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pyfastllm
3 |
4 | def diff(dataA, dataB):
5 | mae = np.max(np.abs(dataA - dataB))
6 | print('max abs err is ', mae)
7 | return mae
8 |
9 | def to_tensor(data):
10 | if not isinstance(data, np.ndarray):
11 | return None
12 | return pyfastllm.from_numpy(data)
13 |
14 | def to_numpy(data):
15 | if not isinstance(data, pyfastllm.Tensor):
16 | return None
17 |
18 | return np.array(data, copy=False)
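A minimal round-trip sketch for the two helpers above, assuming a built pyfastllm extension is importable; a lossless round trip reports a zero difference.

```python
import numpy as np
import pyfastllm

a = np.random.random(size=[2, 8]).astype(np.float32)
t = pyfastllm.from_numpy(a)        # the call to_tensor() wraps
b = np.array(t, copy=False)        # the view to_numpy() returns
print(np.max(np.abs(a - b)))       # expected: 0.0
```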
--------------------------------------------------------------------------------
/pyfastllm/fastllm/hub/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/pyfastllm/fastllm/hub/__init__.py
--------------------------------------------------------------------------------
/pyfastllm/fastllm/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_module import Module
2 | from .modules import *
3 |
--------------------------------------------------------------------------------
/pyfastllm/fastllm/nn/base_module.py:
--------------------------------------------------------------------------------
1 | import pyfastllm
2 | from typing import Any
3 | from abc import abstractmethod
4 |
5 | class Module():
6 | def __init__(self) -> None:
7 | pass
8 |
9 | def __call__(self, *args: Any, **kwds: Any) -> Any:
10 | return self.forward(*args, **kwds)
11 |
12 | @abstractmethod
13 | def forward(self, ):
14 | pass
15 |
16 | def _init_weight(self, ):
17 | pass
18 |
19 | import numpy as np
20 | from typing import Union, Sequence
21 | from pyfastllm import Tensor
22 | from ..functions import util
23 |
24 | class Parameter(object):
25 | _DEFAULT_DTYPE = pyfastllm.float32
26 |
27 | def __init__(self,
28 | value: Union[np.ndarray, None] = None,
29 | shape: Sequence[int] = None,
30 | dtype: Union[pyfastllm.DataType, None] = None):
31 | dtype = self._DEFAULT_DTYPE if dtype is None else dtype
32 | if value is None:
33 | assert isinstance(shape, (list, tuple))
34 | self._value = pyfastllm.Tensor()
35 | """
36 | value = np.zeros(shape=shape, dtype=np.float32)
37 |
38 | if len(shape) == 2:
39 | v_range = np.sqrt(6) / np.sqrt(shape[0] + shape[1])
40 | else:
41 | v_range = 0.1
42 |
43 | # value ~ U[-1, 1]
44 | value = np.random.random(size=shape) * 2 - 1
45 | value = np.array(value, dtype=np.float32)
46 | # value ~ U[-v_range, v_range]
47 | value *= v_range
48 | """
49 | else:
50 | self._value = util.to_tensor(value)
51 |
52 | @property
53 | def value(self) -> Tensor:
54 | if isinstance(self._value, np.ndarray):
55 | self._value = util.to_tensor(self._value)
56 |
57 | return self._value
58 |
59 | @value.setter
60 | def value(self, v: np.ndarray):
61 | assert isinstance(v, np.ndarray) or isinstance(v, pyfastllm.Tensor)
62 | # assert v.shape == self._value.shape, \
63 | # ('The value updated is not the same shape as the original. ', \
64 | # f'Updated: {v.shape}, original: {self._value.shape}')
65 | self._value = v
66 |
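A minimal sketch of the two Parameter construction paths above, assuming the pyfastllm extension is built: wrap an existing array immediately, or declare only the shape and assign the value later through the setter.

```python
import numpy as np

w = Parameter(value=np.ones((4, 4), dtype=np.float32))   # wrapped as a pyfastllm Tensor right away
b = Parameter(shape=(4,))                                 # starts as an empty Tensor
b.value = np.zeros((4,), dtype=np.float32)                # the setter accepts an ndarray or a Tensor
print(type(w.value), type(b.value))                       # both resolve to pyfastllm.Tensor
```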
--------------------------------------------------------------------------------
/pyfastllm/fastllm/nn/modules.py:
--------------------------------------------------------------------------------
1 | from .base_module import Module
2 | from .base_module import Parameter
3 | from ..functions import fastllm_ops as F
4 | from ..functions import util
5 | import numpy as np
6 |
7 | class Linear(Module):
8 | def __init__(self, in_dim, out_dim, bias=False) -> None:
9 | self.has_bias = bias
10 | self.weights = Parameter(shape=(out_dim, in_dim))
11 | self.bias = None
12 |
13 | if bias:
14 | self.bias = Parameter(shape=(out_dim, ))
15 |
16 | super().__init__()
17 |
18 | def forward(self, x):
19 | if self.has_bias:
20 | return F.linear(x, self.weights.value, self.bias.value)
21 |
22 | return F.linear(x, self.weights.value)
23 |
24 | class SiLU(Module):
25 | def __init__(self) -> None:
26 | super().__init__()
27 |
28 | def forward(self, x, axis=-1):
29 | return F.activation(x, axis=axis, activate_type='silu')
30 |
31 | class SwiGLU(Module):
32 | def __init__(self) -> None:
33 | super().__init__()
34 |
35 | def forward(self, inputs):
36 | return F.activation(inputs=inputs, activate_type="swiglu")
37 |
38 | class Embedding(Module):
39 | def __init__(self, vocab_size, embed_dim) -> None:
40 | super().__init__()
41 | self.vocab_size = vocab_size
42 | self.embed_dim = embed_dim
43 | self.weights = Parameter(shape=[vocab_size, embed_dim])
44 |
45 | def forward(self, inputs):
46 | return F.embedding(inputs, self.weights.value)
47 |
48 | class RMSNorm(Module):
49 | def __init__(self, dim=4096, eps=1e-5) -> None:
50 | super().__init__()
51 | self.weights = Parameter(shape=[dim, ])
52 | self.eps = eps
53 |
54 | def forward(self, inputs):
55 | return F.rms_norm(inputs, self.weights.value, eps=self.eps)
56 |
57 | class Attention(Module):
58 | def __init__(self) -> None:
59 | super().__init__()
60 |
61 | def forward(self, q, k, v, mask, group, scale):
62 | return F.attention(q, k, v, mask, group=group, scale=scale, attentionType=0)
63 |
64 | class RoPE(Module):
65 | def __init__(self, rotary_dim=128) -> None:
66 | super().__init__()
67 | self.rotary_dim = rotary_dim
68 | self.sin_data, self.cos_data = self._get_sin_cos_data()
69 | self.sin_data = util.to_tensor(self.sin_data)
70 | self.cos_data = util.to_tensor(self.cos_data)
71 |
72 | def _get_sin_cos_data(self, base=1e4, seq_len=32768, dim=128):
73 | inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
74 | t = np.arange(0, seq_len)
75 | freqs = np.einsum('i,j->ij', t, inv_freq)
76 | emb = np.concatenate((freqs, freqs), axis=-1)
77 | return np.sin(emb), np.cos(emb)
78 |
79 | def forward(self, data, pos_id):
80 | return F.RotatePosition2D(data, pos_id, self.sin_data, self.cos_data, self.rotary_dim)
81 |
82 | class NearlyRoPE(RoPE):
83 | def __init__(self, rotary_dim=64) -> None:
84 | super().__init__(rotary_dim)
85 |
86 | def forward(self, data, pos_id):
87 | outputs = F.NearlyRotatePosition2D(data, pos_id, self.sin_data, self.cos_data, self.rotary_dim)
88 | return outputs
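A minimal sketch of composing the modules above into a LLaMA-style gated feed-forward block, assuming the pyfastllm extension is built, that F.activation's "swiglu" splits the last dimension into gate/up halves like the NumPy reference op, and that real weights are assigned before the block is called.

```python
class FeedForward(Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        # gate and up projections fused into one Linear of width 2*hidden_dim;
        # SwiGLU splits that output and returns silu(gate) * up
        self.gate_up = Linear(dim, 2 * hidden_dim)
        self.act = SwiGLU()
        self.down = Linear(hidden_dim, dim)

    def forward(self, x):
        return self.down(self.act(self.gate_up(x)))

ffn = FeedForward(dim=4096, hidden_dim=11008)
# assign ffn.gate_up.weights.value and ffn.down.weights.value before calling ffn(x)
```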
--------------------------------------------------------------------------------
/pyfastllm/fastllm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
2 |
3 | from .quantizer import QuantType
4 | from .converter import ChatglmConverter, BaichuanConverter, QwenConverter, MossConverter
5 |
6 | def convert(hf_model_name_or_path:str, save_path:str, q_type=QuantType.INT4):
7 | config = AutoConfig.from_pretrained(hf_model_name_or_path, trust_remote_code=True)
8 | tokenizer = AutoTokenizer.from_pretrained(hf_model_name_or_path, trust_remote_code=True)
9 |
10 | if "Baichuan" in config.architectures:
11 | model = AutoModelForCausalLM.from_pretrained(hf_model_name_or_path, trust_remote_code=True).cpu().eval()
12 | converter = BaichuanConverter(model=model, tokenizer=tokenizer, q_type=q_type)
13 | elif "ChatGLM" in config.architectures:
14 | model = AutoModel.from_pretrained(hf_model_name_or_path, trust_remote_code=True).cpu().eval()
15 | converter = ChatglmConverter(model=model, tokenizer=tokenizer, q_type=q_type)
16 | elif "Qwen" in config.architectures:
17 | model = AutoModelForCausalLM.from_pretrained(hf_model_name_or_path, trust_remote_code=True, fp16=True).cpu().eval()
18 | converter = QwenConverter(model=model, tokenizer=tokenizer, q_type=q_type)
19 | elif "Moss" in config.architectures:
20 | model = AutoModelForCausalLM.from_pretrained(hf_model_name_or_path, trust_remote_code=True).cpu().eval()
21 | converter = MossConverter(model=model, tokenizer=tokenizer, q_type=q_type)
22 | else:
23 | raise NotImplementedError(f"Unsupported model: {config.architectures}")
24 |
25 | converter.dump(save_path)
26 |
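A minimal sketch of the intended call pattern for convert() above; the paths are placeholders, and the checkpoint must belong to one of the architecture families handled in the if/elif chain.

```python
from fastllm.utils import convert, QuantType

# placeholder paths: any local ChatGLM / Baichuan / Qwen / Moss checkpoint works here
convert("/path/to/chatglm2-6b", "chatglm2-6b-int4.flm", q_type=QuantType.INT4)
```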
--------------------------------------------------------------------------------
/pyfastllm/fastllm/utils/writer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import struct
3 | from enum import Enum
4 |
5 | class QuantType(Enum):
6 | FP32 = 0
7 | FP16 = 7
8 | INT8 = 3
9 | INT4 = 8
10 |
11 | def write_int8(fo, v):
12 | c_max = np.expand_dims(np.abs(v).max(axis = -1), -1).clip(0.1, 1e100)
13 | c_scale = c_max / 127.0
14 | v = (v / c_scale + 128.5).clip(1, 255).astype(np.uint8)
15 | fo.write(struct.pack('i', 3))
16 | fo.write(struct.pack('i', 0))
17 | for i in range(c_max.shape[0]):
18 | fo.write(struct.pack('f', -c_max[i][0]))
19 | fo.write(struct.pack('f', c_max[i][0]))
20 | fo.write(v.data)
21 |
22 | def write_int4(fo, v):
23 | c_min = np.expand_dims(-np.abs(v).max(axis = -1), -1)
24 | c_max = np.expand_dims(np.abs(v).max(axis = -1), -1)
25 | c_scale = c_max / 7.0
26 | c_min = c_scale * -8.0
27 | v = (v - c_min) / c_scale
28 | v = (v + 0.5).astype(np.int8).clip(0, 15).astype(np.uint8)
29 | v = v[:, 0::2] * 16 + v[:, 1::2]
30 | fo.write(struct.pack('i', 8))
31 | fo.write(struct.pack('i', 0))
32 | for i in range(c_min.shape[0]):
33 | fo.write(struct.pack('f', c_min[i][0]))
34 | fo.write(struct.pack('f', c_max[i][0]))
35 | fo.write(v.data)
36 |
37 | class Writer():
38 | def __init__(self, outpath) -> None:
39 | self.fd = open(outpath, 'wb')
40 |
41 | def __del__(self, ):
42 | if not self.fd.closed:
43 | self.fd.close()
44 |
45 | def write(self, value):
46 | if isinstance(value, int):
47 | self.fd.write(struct.pack('i', value))
48 | elif isinstance(value, float):
49 | self.fd.write(struct.pack('f', value))
50 | elif isinstance(value, str):
51 | self.write_str(value)
52 | elif isinstance(value, bytes):
53 | self.write_bytes(value)
54 | elif isinstance(value, list):
55 | self.write_list(value)
56 | elif isinstance(value, dict):
57 | self.write_dict(value)
58 | elif isinstance(value, np.ndarray):
59 | self.write_tensor(value)
60 | else:
61 | raise NotImplementedError(f"Unsupported data type: {type(value)}")
62 |
63 | def write_str(self, s):
64 | self.write(len(s))
65 | self.fd.write(s.encode())
66 |
67 | def write_bytes(self, s):
68 | self.write(len(s))
69 | for c in s: self.write(int(c))
70 |
71 | def write_list(self, data):
72 | self.write(len(data))
73 | for d in data: self.write(d)
74 |
75 | def write_dict(self, data):
76 | self.write(len(data))
77 | for key in data:
78 | self.write_str(key)
79 | self.write(data[key])
80 |
81 | def write_tensor(self, data, data_type:QuantType=QuantType.FP32):
82 | self.write(list(data.shape))
83 | if data_type == QuantType.INT4:
84 | write_int4(self.fd, data)
85 | elif data_type == QuantType.INT8:
86 | write_int8(self.fd, data)
87 | else:
88 | self.write(int(data_type.value))
89 | self.fd.write(data.data)
90 |
91 |
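A worked sketch of the quantize-and-pack step inside write_int4 above, with the file I/O removed so the per-row scaling and the two-codes-per-byte packing are visible on a tiny row.

```python
import numpy as np

v = np.array([[-1.0, 0.5, 1.0, -0.25]])
c_max = np.expand_dims(np.abs(v).max(axis=-1), -1)    # per-row max magnitude
c_scale = c_max / 7.0
c_min = c_scale * -8.0                                # asymmetric int4 range: codes 0..15 map to [c_min, c_max]
q = ((v - c_min) / c_scale + 0.5).astype(np.uint8).clip(0, 15)
packed = q[:, 0::2] * 16 + q[:, 1::2]                 # two 4-bit codes per byte, first code in the high nibble
print(q)        # [[ 1 12 15  6]]
print(packed)   # [[ 28 246]]
```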
--------------------------------------------------------------------------------
/pyfastllm/install.sh:
--------------------------------------------------------------------------------
1 | rm -rf build/ && rm -rf dist/
2 | python3 setup.py sdist bdist_wheel
3 | pip install dist/*.whl --force-reinstall
4 | # python3 examples/test_ops.py # coredump when run with cuda backend
--------------------------------------------------------------------------------
/requirements-server.txt:
--------------------------------------------------------------------------------
1 | fastapi
2 | pydantic
3 | openai
4 | shortuuid
5 | uvicorn
6 |
--------------------------------------------------------------------------------
/simple_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | folder="build-fastllm"
3 |
4 | # create the build folder
5 | if [ ! -d "$folder" ]; then
6 | mkdir "$folder"
7 | fi
8 |
9 | cd $folder
10 | cmake .. "$@"
11 | make -j$(nproc)
12 |
13 | # stop if the build failed
14 | if [ $? != 0 ]; then
15 | exit -1
16 | fi
17 |
18 | cd tools
19 | pip install .
20 | #python3 setup.py sdist build
21 | #python3 setup.py bdist_wheel
22 | #python3 setup.py install --all
--------------------------------------------------------------------------------
/src/devices/cpu/avx512bf16.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/8/25.
3 | //
4 |
5 | #include
6 |
7 | #ifdef __AVX2__
8 | #include "immintrin.h"
9 | #endif
10 |
11 | namespace fastllm {
12 | bool LinearBFloat16FP8E4M3_AVX512BF16_Kernel(uint16_t *inputData, uint8_t *weightData, float *biasData, float *outputData,
13 | int n, int m, int k, int st, int end, int blockK, int blockM, float *scales,
14 | int ks, int ms, float magicScale) {
15 | if (!(m % blockM == 0 && blockM % 32 == 0)) {
16 | return false;
17 | }
18 | #ifdef __AVX512BF16__
19 | for (int i = 0; i < n; i++) {
20 | int j = st;
21 | __m256i v_a_mask_byte = _mm256_set1_epi8(0x80);
22 | __m256i v_b_mask_byte = _mm256_set1_epi8(0x7F);
23 | for (; j < end; j++) {
24 | float now = biasData ? biasData[j] : 0.0f;
25 | __m512 last_sum = _mm512_setzero_ps(); // Accumulator for 16 parallel sums
26 |
27 | for (int midx = 0; midx < ms; midx++) {
28 | float curScale = scales[j / blockK * ms + midx];
29 | __m512 vScale = _mm512_set1_ps(curScale);
30 |
31 | int l = midx * blockM;
32 | __m512 v_sum = _mm512_setzero_ps(); // Accumulator for 16 parallel sums
33 | for (; l + 31 < m && l + 31 < (midx + 1) * blockM; l += 32) {
34 | // 1. Load 32 BF16 inputs
35 | // Treat uint16_t* as __m512bh* - use loadu for unaligned access
36 | __m512bh v_input_bf16 = (__m512bh)_mm512_loadu_si512((__m512i const*)(inputData + i * m + l));
37 | // 2. Load 32 FP8 weights
38 | __m256i va_bytes = _mm256_loadu_si256((__m256i*)&weightData[j * m + l]);
39 |
40 | __m256i va_masked_bytes = _mm256_and_si256(va_bytes, v_a_mask_byte);
41 | __m512i va_promoted_words = _mm512_cvtepu8_epi16(va_masked_bytes);
42 | __m512i v_a_term_shifted = _mm512_slli_epi16(va_promoted_words, 8);
43 |
44 | __m256i vb_masked_bytes = _mm256_and_si256(va_bytes, v_b_mask_byte);
45 | __m512i vb_promoted_words = _mm512_cvtepu8_epi16(vb_masked_bytes);
46 | __m512i v_b_term_shifted = _mm512_slli_epi16(vb_promoted_words, 4);
47 |
48 | __m512i v_result = _mm512_or_si512(v_a_term_shifted, v_b_term_shifted);
49 | __m512bh v_weights_bf16 = (__m512bh)v_result;
50 |
51 | // 3. Compute dot product: v_sum += v_input_bf16 * v_weights_bf16
52 | v_sum = _mm512_dpbf16_ps(v_sum, v_input_bf16, v_weights_bf16);
53 | }
54 |
55 | last_sum = _mm512_fmadd_ps(v_sum, vScale, last_sum);
56 | }
57 |
58 | now += _mm512_reduce_add_ps(last_sum) * magicScale;
59 | outputData[i * k + j] = now;
60 | }
61 | }
62 | return true;
63 | #endif
64 | return false;
65 | }
66 | }
--------------------------------------------------------------------------------
/src/devices/numa/kvcache.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 24-4-19.
3 | //
4 |
5 | #include "kvcache.h"
6 |
7 | namespace fastllm {
8 | KVCache::KVCache(fastllm::DataType dataType, int head, int dim) {
9 | this->dataType = dataType;
10 | if (dataType == DataType::FLOAT32) {
11 | this->unitSize = 4;
12 | } else if (dataType == DataType::FLOAT16 || dataType == DataType::BFLOAT16) {
13 | this->unitSize = 2;
14 | } else if (dataType == DataType::INT8) {
15 | this->unitSize = 1;
16 | }
17 |
18 | this->head = head;
19 | this->dim = dim;
20 | this->currentCap = 0;
21 | this->len = 0;
22 | }
23 |
24 | KVCache::~KVCache() {
25 | delete[] this->data;
26 | }
27 |
28 | void KVCache::Append(int len, uint8_t *data) {
29 | this->lastFlushTime = std::chrono::system_clock::now();
30 | if (this->len + len > this->currentCap) {
31 | int newCap = ((this->len + len - 1) / unitLen + 1) * unitLen;
32 | if (this->currentCap != 0) {
33 | uint8_t *old = this->data;
34 | this->data = new uint8_t [head * newCap * dim * unitSize];
35 | for (int h = 0; h < head; h++) {
36 | memcpy(this->data + h * newCap * dim * unitSize,
37 | old + h * this->currentCap * dim * unitSize,
38 | this->currentCap * dim * unitSize);
39 | }
40 | delete[] old;
41 | } else {
42 | this->data = new uint8_t [head * newCap * dim * unitSize];
43 | }
44 | this->currentCap = newCap;
45 | }
46 | for (int h = 0; h < head; h++) {
47 | memcpy(this->data + (h * this->currentCap + this->len) * dim * unitSize,
48 | data + h * len * dim * unitSize,
49 | len * dim * unitSize);
50 | }
51 | this->len += len;
52 | }
53 |
54 | KVCache *KVCacheManager::Get(long long uid) {
55 | if (this->caches.find(uid) == this->caches.end()) {
56 | return nullptr;
57 | }
58 | return this->caches[uid];
59 | }
60 |
61 | KVCache *KVCacheManager::Get(long long uid, fastllm::DataType dataType, int head, int dim) {
62 | if (this->caches.find(uid) == this->caches.end()) {
63 | this->caches[uid] = new KVCache(dataType, head, dim);
64 | }
65 | return this->caches[uid];
66 | }
67 |
68 | void KVCacheManager::Delete(long long uid) {
69 | if (this->caches.find(uid) != this->caches.end()) {
70 | delete this->caches[uid];
71 | this->caches.erase(uid);
72 | }
73 | }
74 | }
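KVCache stores all heads in one contiguous buffer laid out as [head][currentCap][dim], so both growing the capacity and appending copy data head by head. A rough numpy sketch of the same layout and append step (illustrative only; the real unitLen comes from kvcache.h, and 128 here is just a placeholder):

```python
import numpy as np

# Illustrative re-creation of KVCache::Append's per-head growth and copy.
class PyKVCache:
    def __init__(self, head, dim, unit_len=128):
        self.head, self.dim, self.unit_len = head, dim, unit_len
        self.cap = self.len = 0
        self.data = np.empty((head, 0, dim), dtype=np.float32)

    def append(self, chunk):                     # chunk shape: [head, new_len, dim]
        new_len = chunk.shape[1]
        if self.len + new_len > self.cap:
            # round the new capacity up to a multiple of unit_len
            new_cap = ((self.len + new_len - 1) // self.unit_len + 1) * self.unit_len
            grown = np.empty((self.head, new_cap, self.dim), dtype=np.float32)
            grown[:, :self.len] = self.data[:, :self.len]   # per-head copy of old rows
            self.data, self.cap = grown, new_cap
        self.data[:, self.len:self.len + new_len] = chunk   # per-head append
        self.len += new_len
```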
--------------------------------------------------------------------------------
/test/basic/config.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | default_messages_list = [
4 | [
5 | {"role": "system", "content": "You are a helpful assistant."},
6 | {"role": "user", "content": '''北京有什么景点?'''}
7 | ]
8 | ]
9 |
10 | def compute_cosine_similarity(a, b):
11 | l = min(len(a), len(b))
12 | a = a[:l]
13 | b = b[:l]
14 | dot_product = sum(v1 * v2 for v1, v2 in zip(a, b))
15 | norm_vec1 = math.sqrt(sum(v ** 2 for v in a))
16 | norm_vec2 = math.sqrt(sum(v ** 2 for v in b))
17 | cosine_similarity = dot_product / (norm_vec1 * norm_vec2)
18 | return cosine_similarity
19 |
20 | model_list = [
21 | "Qwen/Qwen2-0.5B-Instruct/",
22 | "Qwen/Qwen2-1.5B-Instruct/",
23 | "Qwen/Qwen2-7B-Instruct/"
24 | ]
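compute_cosine_similarity truncates both vectors to the shorter length before comparing, so embeddings of slightly different sizes can still be scored. A quick usage example:

```python
from config import compute_cosine_similarity

a = [1.0, 2.0, 3.0]
b = [1.0, 2.0, 3.0, 4.0]                    # trailing extra value is ignored
print(compute_cosine_similarity(a, b))      # 1.0: compared over the first 3 entries only
```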
--------------------------------------------------------------------------------
/test/basic/tokenizer_check.py:
--------------------------------------------------------------------------------
1 | from config import default_messages_list
2 |
3 | import argparse
4 | import logging
5 | import os
6 | from transformers import AutoTokenizer
7 | from ftllm import llm
8 |
9 | def args_parser():
10 | parser = argparse.ArgumentParser(description = 'fastllm_test')
11 |     parser.add_argument('--model', type = str, required = True, default = '', help = 'model directory')
12 | args = parser.parse_args()
13 | return args
14 |
15 | if __name__ == "__main__":
16 | args = args_parser()
17 | messages_list = default_messages_list
18 |
19 | logger = logging.getLogger()
20 | logging.basicConfig(level = logging.INFO, format = '%(asctime)s - %(levelname)s - %(message)s')
21 |
22 | model_path = args.model
23 |
24 |     logger.info("Start testing model " + model_path)
25 |     logger.info("Loading tokenizer with transformers")
26 |     tokenizer = AutoTokenizer.from_pretrained(model_path)
27 |     logger.info("Loaded successfully")
28 |     logger.info("Loading tokenizer with fastllm")
29 |     fastllm_tokenizer = llm.tokenizer(model_path)
30 |     logger.info("Loaded successfully")
31 |
32 | check_succ = True
33 | for messages in messages_list:
34 |         hf_text = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
35 |         fastllm_text = fastllm_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
36 | if (hf_text != fastllm_text):
37 | check_succ = False
38 |             logger.error("apply_chat_template results do not match" +
39 |                 "\n\ninput:\n" + str(messages) +
40 |                 "\n\nhf result:\n" + hf_text +
41 |                 "\nfastllm result:\n" + fastllm_text)
42 | break
43 | hf_tokens = tokenizer.encode(hf_text)
44 | fastllm_tokens = fastllm_tokenizer.encode(fastllm_text)
45 | if (hf_tokens != fastllm_tokens):
46 | check_succ = False
47 |             logger.error("encode results do not match" +
48 |                 "\n\ninput:\n" + hf_text +
49 |                 "\n\nhf result:\n" + str(hf_tokens) +
50 |                 "\nfastllm result:\n" + str(fastllm_tokens))
51 | break
52 |
53 | if check_succ:
54 |         logger.info("Tokenization results match")
55 |
--------------------------------------------------------------------------------
/test/cmmlu/README.md:
--------------------------------------------------------------------------------
1 | CMMLU is a comprehensive Chinese evaluation benchmark designed to measure a language model's knowledge and reasoning ability in Chinese contexts.
2 | 
3 | Project homepage: https://github.com/haonan-li/CMMLU
4 | 
5 | The chatglm.py and baichuan.py scripts in this directory run the benchmark through the fastllm framework.
6 | 
7 | The test steps are as follows:
8 | 
9 | - 1. Clone the CMMLU repository
10 | 
11 | ``` sh
12 | git clone https://github.com/haonan-li/CMMLU
13 | ```
14 | 
15 | - 2. Run the tests
16 | 
17 | ```
18 | # ChatGLM test script
19 | # model_name_or_path can be the official ChatGLM2-6b original model or its int4 model; dtype supports float16, int8, int4
20 | python3 chatglm.py --model_name_or_path <path to the model> --save_dir <directory to save results> --dtype float16
21 | 
22 | # Baichuan-13B test script
23 | # model_name_or_path can be the official Baichuan13B-Base or Baichuan13B-Chat original model; dtype supports float16, int8, int4
24 | python3 baichuan.py --model_name_or_path <path to the model> --save_dir <directory to save results> --dtype float16
25 | ```
26 | 
27 | The test set is large and the run takes quite a while; during the run you can check the scores of already finished subjects with:
28 | 
29 | ```
30 | python3 eval.py <directory where results were saved>
31 | ```
32 | 
33 | - 3. Reference results
34 | 
35 | | Model | Data precision | Shot | CMMLU score |
36 | |-----------------------: |-------- |----------|-----------|
37 | | ChatGLM2-6b-fp16 | float32 |0 | 50.16 |
38 | | ChatGLM2-6b-int8 | float32 |0 | 50.14 |
39 | | ChatGLM2-6b-int4 | float32 |0 | 49.63 |
40 | | QWen-7b-Base-fp16 | float32 |0 | 57.43 |
41 | | QWen-7b-Chat-fp16 | float32 |0 | 54.82 |
42 | | Baichuan-13b-Base-int8 | float32 |5 | 55.12 |
43 | | Baichuan-13b-Base-int4 | float32 |5 | 52.22 |
44 |
--------------------------------------------------------------------------------
/test/cmmlu/eval.py:
--------------------------------------------------------------------------------
1 | import CMMLU.src.mp_utils as mp
2 | import sys
3 | print(mp.get_results(sys.argv[1]))
4 |
--------------------------------------------------------------------------------
/third_party/hipify_torch/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021-2024, Advanced Micro Devices, Inc. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/third_party/hipify_torch/hipify_torch/__init__.py:
--------------------------------------------------------------------------------
1 | from .version import __version__
2 |
--------------------------------------------------------------------------------
/third_party/hipify_torch/hipify_torch/constants.py:
--------------------------------------------------------------------------------
1 | """ Constants for annotations in the mapping.
2 | The constants defined here are used to annotate the mapping tuples in cuda_to_hip_mappings.py.
3 | They are based on
4 | https://github.com/ROCm-Developer-Tools/HIP/blob/master/hipify-clang/src/Statistics.h
5 | and fall in three categories: 1) type of mapping, 2) API of mapping, 3) unsupported
6 | mapping.
7 | """
8 |
9 | CONV_VERSION = 0
10 | CONV_INIT = 1
11 | CONV_DEVICE = 2
12 | CONV_MEM = 3
13 | CONV_KERN = 4
14 | CONV_COORD_FUNC = 5
15 | CONV_MATH_FUNC = 6
16 | CONV_DEVICE_FUNC = 7
17 | CONV_SPECIAL_FUNC = 8
18 | CONV_STREAM = 9
19 | CONV_EVENT = 10
20 | CONV_OCCUPANCY = 11
21 | CONV_CONTEXT = 12
22 | CONV_PEER = 13
23 | CONV_MODULE = 14
24 | CONV_CACHE = 15
25 | CONV_EXEC = 16
26 | CONV_ERROR = 17
27 | CONV_DEF = 18
28 | CONV_TEX = 19
29 | CONV_GL = 20
30 | CONV_GRAPHICS = 21
31 | CONV_SURFACE = 22
32 | CONV_JIT = 23
33 | CONV_D3D9 = 24
34 | CONV_D3D10 = 25
35 | CONV_D3D11 = 26
36 | CONV_VDPAU = 27
37 | CONV_EGL = 28
38 | CONV_THREAD = 29
39 | CONV_OTHER = 30
40 | CONV_INCLUDE = 31
41 | CONV_INCLUDE_CUDA_MAIN_H = 32
42 | CONV_TYPE = 33
43 | CONV_LITERAL = 34
44 | CONV_NUMERIC_LITERAL = 35
45 | CONV_LAST = 36
46 |
47 | API_DRIVER = 37
48 | API_RUNTIME = 38
49 | API_BLAS = 39
50 | API_SPECIAL = 40
51 | API_RAND = 41
52 | API_LAST = 42
53 | API_FFT = 43
54 | API_RTC = 44
55 | API_ROCTX = 45
56 | API_ROCMSMI = 46
57 | API_PYT_EXT = 47
58 |
59 | API_WMMA = 48
60 |
61 | HIP_UNSUPPORTED = 49
62 | API_PYTORCH = 1337
63 | API_CAFFE2 = 1338
64 | API_C10 = 1339
65 |
--------------------------------------------------------------------------------
/third_party/hipify_torch/hipify_torch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.0.0'
2 |
--------------------------------------------------------------------------------
/third_party/hipify_torch/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import setuptools.command.install
3 |
4 | ## extending the install command functionality.
5 | class install(setuptools.command.install.install):
6 | def run(self):
7 | print ("INFO: Installing hipify_torch")
8 | setuptools.command.install.install.run(self)
9 | print ("OK: Successfully installed hipify_torch")
10 |
11 | cmd_class = {
12 | "install" : install,
13 | }
14 |
15 | setup(
16 | name='hipify_torch',
17 | version='1.0',
18 | cmdclass=cmd_class,
19 | packages=['hipify_torch',],
20 | long_description=open('README.md').read(),
21 | )
22 |
--------------------------------------------------------------------------------
/third_party/hipify_torch/test/test_installation.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | class TestHipifyInstallation(unittest.TestCase):
4 | def test_hipify_torch_installation(self):
5 | try:
6 | from hipify_torch import hipify_python
7 | except ImportError:
8 | print ("ERROR: please install hipify_torch using setup.py install")
9 | raise ImportError('Install hipify_torch module')
10 |
11 | if __name__ == '__main__':
12 | unittest.main()
13 |
--------------------------------------------------------------------------------
/third_party/hipify_torch/tools/replace_cuda_with_hip_files.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import os
10 | import sys
11 | import argparse
12 | import json
13 |
14 | def main():
15 | parser = argparse.ArgumentParser(description="")
16 | parser.add_argument(
17 | '--io-file',
18 | type=str,
19 | help="Input file containing list of files which will be overwritten by hipified file names",
20 | required=True)
21 |
22 | parser.add_argument(
23 | '--dump-dict-file',
24 | type=str,
25 | help="Input file where the dictionary output of hipify is stored",
26 | required=True)
27 |
28 | args = parser.parse_args()
29 |
30 | file_obj = open(args.dump_dict_file, mode='r')
31 | json_string = file_obj.read()
32 | file_obj.close()
33 | hipified_result = json.loads(json_string)
34 |
35 | out_list = []
36 | with open(args.io_file) as inp_file:
37 | for line in inp_file:
38 | line = line.strip()
39 | line = os.path.abspath(line)
40 | if line in hipified_result:
41 | out_list.append(hipified_result[line]['hipified_path'])
42 | else:
43 | out_list.append(line)
44 |
45 | w_file_obj = open(args.io_file, mode='w')
46 | for f in out_list:
47 | w_file_obj.write(f+"\n")
48 | w_file_obj.close()
49 |
50 | if __name__ == "__main__":
51 | main()
52 |
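The --dump-dict-file is the JSON dictionary emitted by hipify, keyed by absolute source path, where each value carries at least a 'hipified_path' entry; the --io-file is a plain list of paths that gets rewritten in place. A hypothetical minimal example of the two inputs:

```python
import json

# Hypothetical dump dict produced by hipify: original absolute path -> hipified path.
dump_dict = {"/repo/src/kernel.cu": {"hipified_path": "/repo/src/kernel.hip"}}
with open("dump_dict.json", "w") as f:
    json.dump(dump_dict, f)

# io file: one path per line; lines found in the dict get replaced, others are kept.
with open("files.txt", "w") as f:
    f.write("/repo/src/kernel.cu\n/repo/src/other.cpp\n")

# python replace_cuda_with_hip_files.py --io-file files.txt --dump-dict-file dump_dict.json
# afterwards files.txt contains /repo/src/kernel.hip and /repo/src/other.cpp
```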
--------------------------------------------------------------------------------
/third_party/tfacc/driver/tfacc2/Makefile:
--------------------------------------------------------------------------------
1 | obj-m += tfacc2.o
2 | PWD=$(shell pwd)
3 |
4 | #LINUX_VERSION=$(LINUX_VERSION)
5 | KERNEL_VERSION=$(shell uname -r)
6 |
7 | ifeq ($(LINUX_VERSION),Ubuntu)
8 | KDIR=/usr/src/linux-headers-$(KERNEL_VERSION)
9 | else
10 | KDIR=/usr/src/kernels/$(KERNEL_VERSION)/
11 | endif
12 |
13 |
14 | install:
15 | bash ./build_driver.sh
16 |
17 | tfacc2:
18 | make -C $(KDIR) M=$(PWD) modules
19 | make -C $(KDIR) M=$(PWD) modules_install
20 | depmod -A
21 | clean:
22 | make -C $(KDIR) M=$(PWD) clean
23 |
--------------------------------------------------------------------------------
/third_party/tfacc/driver/tfacc2/build_driver.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | base_dir=$(
4 | cd "$(dirname "$0")" || exit
5 | pwd
6 | )
7 | cd "${base_dir}" || exit
8 |
9 | cp ../../tfsmi /usr/local/bin/tfsmi
10 | cp ../../tfsmbios /usr/local/bin/tfsmbios
11 |
12 | output_path=result/$1
13 |
14 | rm -rf "${output_path}"
15 | mkdir -p "${output_path}"
16 | cp -r $(ls | grep -v result | xargs) "${output_path}"
17 |
18 | cd "${output_path}" || exit
19 |
20 | export LINUX_VERSION=$(cat /etc/issue | awk -F ' ' '{print $1}' | awk 'NR==1')
21 | make tfacc2
22 |
--------------------------------------------------------------------------------
/third_party/tfacc/driver/tfacc2/modules.order:
--------------------------------------------------------------------------------
1 | /home/huangyuyang/Downloads/new/tfdl2/driver/tfacc2/tfacc2.ko
2 |
--------------------------------------------------------------------------------
/third_party/tfacc/launch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import threading
3 | import sys
4 |
5 | def run(cmd):
6 | os.system(cmd)
7 |
8 | #modelName = "/root/flmModels/chatglm3-6b-int4.flm"
9 | #modelName = "/root/flmModels/qwen1.5-36B-chat-int4.flm"
10 |
11 | total = int(sys.argv[1])
12 |
13 | for i in range(total):
14 | st = i * 40
15 | end = st + 39
16 | cmd = "numactl -C " + str(st) + "-" + str(end) + " -m " + str(i) + " ./server " + str(i) + " " + str(total)
17 | print(cmd)
18 | (threading.Thread(target = run, args = ([cmd]) )).start()
19 |
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/third_party/tfacc/pull.sh:
--------------------------------------------------------------------------------
1 | cp ../../../TFACCComputeServer/build/server .
2 | cp ../../../TFACCComputeServer/launch.py .
3 | cp -r ../../../TFACCComputeServer/driver .
--------------------------------------------------------------------------------
/third_party/tfacc/server:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ztxz16/fastllm/170090720d2585f065b3dd95bdafa79dcf6913e8/third_party/tfacc/server
--------------------------------------------------------------------------------
/tools/fastllm_pytools/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ["llm"]
2 |
3 | from importlib.metadata import version
4 | try:
5 |     __version__ = version("ftllm") # read the version from the installed package metadata
6 | except Exception:
7 |     __version__ = version("ftllm-rocm") # fall back to the rocm build's package name
--------------------------------------------------------------------------------
/tools/fastllm_pytools/chat.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from .util import make_normal_parser
3 | import readline
4 |
5 | def args_parser():
6 | parser = make_normal_parser('fastllm_chat')
7 | args = parser.parse_args()
8 | return args
9 |
10 | def fastllm_chat(args):
11 | from .util import make_normal_llm_model
12 | model = make_normal_llm_model(args)
13 |
14 | generation_config = {
15 | 'repetition_penalty': 1.0,
16 | 'top_p': 0.8,
17 | 'top_k': 1,
18 | 'temperature': 1.0
19 | }
20 | import os
21 | import json
22 | if (os.path.exists(os.path.join(args.path, "generation_config.json"))):
23 | with open(os.path.join(args.path, "generation_config.json"), "r", encoding="utf-8") as file:
24 | config = json.load(file)
25 | if ('do_sample' in config and config['do_sample']):
26 | for it in ["repetition_penalty", "top_p", "top_k", "temperature"]:
27 | if (it in config):
28 |                     generation_config[it] = config[it]
29 |
30 |     hint = "Type a message to start chatting.\n'clear' clears the history.\n'stop' exits the program."
31 | history = []
32 |
33 | print(hint)
34 | while True:
35 | query = input("\nUser:")
36 | if query.strip() == "stop":
37 | break
38 | if query.strip() == "clear":
39 | history = []
40 | print(hint)
41 | continue
42 |         print("AI:", end = "")
43 |         curResponse = ""
44 | for response in model.stream_response(query, history = history,
45 | repeat_penalty = generation_config["repetition_penalty"],
46 | top_p = generation_config["top_p"],
47 | top_k = generation_config["top_k"],
48 | temperature = generation_config["temperature"]):
49 |             curResponse += response
50 | print(response, flush = True, end = "")
51 | history.append((query, curResponse))
52 | model.release_memory()
53 |
54 | if __name__ == "__main__":
55 | args = args_parser()
56 | fastllm_chat(args)
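The sampling defaults above are only overridden when the model directory contains a generation_config.json with do_sample enabled; only repetition_penalty, top_p, top_k and temperature are read. A minimal file that would trigger the override (illustrative values, written to a placeholder path):

```python
import json, os

# Minimal generation_config.json that chat.py would pick up; values are illustrative.
config = {
    "do_sample": True,
    "repetition_penalty": 1.05,
    "top_p": 0.9,
    "top_k": 40,
    "temperature": 0.7,
}
with open(os.path.join("/path/to/model", "generation_config.json"), "w", encoding="utf-8") as f:
    json.dump(config, f)
```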
--------------------------------------------------------------------------------
/tools/fastllm_pytools/export.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from ftllm import llm
3 | from .util import make_normal_parser
4 | from .util import make_normal_llm_model
5 | import readline
6 |
7 | def args_parser():
8 | parser = make_normal_parser('fastllm_export')
9 |     parser.add_argument('-o', '--output', type = str, required = True, help = 'output path')
10 | args = parser.parse_args()
11 | return args
12 |
13 | if __name__ == "__main__":
14 | args = args_parser()
15 | llm.export_llm_model_fromhf(path = args.path, dtype = args.dtype, moe_dtype = args.moe_dtype, lora = args.lora, output = args.output)
--------------------------------------------------------------------------------
/tools/fastllm_pytools/openai_server/fastllm_embed.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import json
4 | import traceback
5 | from fastapi import Request
6 |
7 | from .protocal.openai_protocol import *
8 |
9 | class FastLLmEmbed:
10 | def __init__(self,
11 | model_name,
12 | model):
13 | self.model_name = model_name
14 | self.model = model
15 |
16 | def embedding_sentence(self, request: EmbedRequest, raw_request: Request):
17 | return self.model.embedding_sentence(request.inputs, request.normalize)
18 |
--------------------------------------------------------------------------------
/tools/fastllm_pytools/openai_server/fastllm_model.py:
--------------------------------------------------------------------------------
1 |
2 | class FastLLmModel:
3 | def __init__(self,
4 | model_name,
5 | ):
6 | self.model_name = model_name
7 | data = [
8 | {
9 | "id": model_name,
10 | "object": "model",
11 | "owned_by": "fastllm",
12 | "permission": []
13 | }
14 | ]
15 | self.response = {
16 | "data": data,
17 | "object": "list"
18 | }
--------------------------------------------------------------------------------
/tools/fastllm_pytools/openai_server/fastllm_reranker.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import json
4 | import traceback
5 | from fastapi import Request
6 |
7 | from .protocal.openai_protocol import *
8 |
9 | class FastLLmReranker:
10 | def __init__(self,
11 | model_name,
12 | model):
13 | self.model_name = model_name
14 | self.model = model
15 |
16 | def rerank(self, request: RerankRequest, raw_request: Request):
17 | query = request.query
18 | pairs = []
19 | for text in request.texts:
20 | pairs.append([query, text])
21 | scores = self.model.reranker_compute_score(pairs = pairs)
22 | ret = []
23 | for i in range(len(request.texts)):
24 | now = {'index': i, 'score': scores[i]}
25 | if (request.return_text):
26 | now['text'] = request.texts[i]
27 | ret.append(now)
28 | ret = sorted(ret, key = lambda x : -x['score'])
29 | return ret
30 |
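A sketch of how the reranker is driven, with a stub standing in for the fastllm model (the only method used above is reranker_compute_score, which takes [query, text] pairs and returns one score per pair):

```python
# Stub model exposing just the method FastLLmReranker relies on.
class StubModel:
    def reranker_compute_score(self, pairs):
        # Pretend longer passages score higher, only to make the sort visible.
        return [float(len(text)) for _query, text in pairs]

query = "what is fastllm"
texts = ["a short note", "a much longer passage about the fastllm inference framework"]
scores = StubModel().reranker_compute_score([[query, t] for t in texts])
ranked = sorted(
    ({"index": i, "score": s, "text": texts[i]} for i, s in enumerate(scores)),
    key=lambda x: -x["score"],
)
print(ranked)   # highest score first, mirroring FastLLmReranker.rerank
```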
--------------------------------------------------------------------------------
/tools/fastllm_pytools/webui.py:
--------------------------------------------------------------------------------
1 | try:
2 | import streamlit as st
3 | except:
4 |     print("Please install streamlit-chat. (pip install streamlit-chat)")
5 | exit(0)
6 |
7 | import os
8 | import sys
9 |
10 | if __name__ == "__main__":
11 | current_path = os.path.dirname(os.path.abspath(__file__))
12 | web_demo_path = os.path.join(current_path, 'web_demo.py')
13 | port = ""
14 | for i in range(len(sys.argv)):
15 | if sys.argv[i] == "--port":
16 | port = "--server.port " + sys.argv[i + 1]
17 | if sys.argv[i] == "--help" or sys.argv[i] == "-h":
18 | os.system("python3 " + web_demo_path + " --help")
19 | exit(0)
20 | os.system("streamlit run " + port + " " + web_demo_path + ' -- ' + ' '.join(sys.argv[1:]))
--------------------------------------------------------------------------------
/tools/scripts/alpaca2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoTokenizer, LlamaForCausalLM
4 | from ftllm import torch2flm
5 |
6 | if __name__ == "__main__":
7 | model_name = sys.argv[3] if len(sys.argv) >= 4 else 'minlik/chinese-alpaca-33b-merged'
8 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
9 |     # `torch_dtype=torch.float16` is set by default; if it will not cause an OOM error, you can load the model in float32.
10 | model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
11 | conf = model.config.__dict__
12 | conf["model_type"] = "llama"
13 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
14 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "alpaca-33b-" + dtype + ".flm"
15 | # add custom code here
16 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
17 |
--------------------------------------------------------------------------------
/tools/scripts/baichuan2_2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from transformers.generation.utils import GenerationConfig
5 | from ftllm import torch2flm
6 |
7 | if __name__ == "__main__":
8 | modelpath = "baichuan-inc/Baichuan2-7B-Chat"
9 | tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False, trust_remote_code=True)
10 | model = AutoModelForCausalLM.from_pretrained(modelpath, device_map="auto", torch_dtype=torch.float32, trust_remote_code=True)
11 |
12 | # normalize lm_head
13 | state_dict = model.state_dict()
14 | state_dict['lm_head.weight'] = torch.nn.functional.normalize(state_dict['lm_head.weight'])
15 | model.load_state_dict(state_dict)
16 |
17 | try:
18 | model.generation_config = GenerationConfig.from_pretrained(modelpath)
19 | except:
20 | pass
21 |
22 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
23 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan2-7b-" + dtype + ".flm"
24 | torch2flm.tofile(exportPath, model.to('cpu'), tokenizer, dtype=dtype)
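The normalization step above L2-normalizes each row of lm_head.weight before export, presumably so that Baichuan2's NormHead behavior is baked into a plain linear head. A tiny torch check of what that call does:

```python
import torch
import torch.nn.functional as F

w = torch.randn(4, 8)                       # stand-in for lm_head.weight [vocab, hidden]
w_norm = F.normalize(w)                     # default p=2, dim=1: L2-normalize each row
print(torch.linalg.norm(w_norm, dim=1))     # all ones
```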
--------------------------------------------------------------------------------
/tools/scripts/baichuan2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoModelForCausalLM, AutoTokenizer
4 | from transformers.generation.utils import GenerationConfig
5 | from ftllm import torch2flm
6 |
7 | if __name__ == "__main__":
8 | modelpath = "baichuan-inc/baichuan-13B-Chat"
9 | tokenizer = AutoTokenizer.from_pretrained(modelpath, trust_remote_code=True)
10 | model = AutoModelForCausalLM.from_pretrained(modelpath, device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
11 | model.to("cpu")
12 | try:
13 | model.generation_config = GenerationConfig.from_pretrained(modelpath)
14 | except:
15 | pass
16 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
17 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "baichuan-13b-" + dtype + ".flm"
18 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
19 |
--------------------------------------------------------------------------------
/tools/scripts/bert2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from transformers import AutoTokenizer, AutoModel
3 | from ftllm import torch2flm
4 |
5 | if __name__ == "__main__":
6 | modelpath = sys.argv[3] if len(sys.argv) >= 4 else 'BAAI/bge-small-zh-v1.5'
7 | tokenizer = AutoTokenizer.from_pretrained(modelpath)
8 | model = AutoModel.from_pretrained(modelpath).cpu().float()
9 | model = model.eval()
10 |
11 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
12 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "bert-" + dtype + ".flm"
13 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
--------------------------------------------------------------------------------
/tools/scripts/chatglm_export.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from transformers import AutoTokenizer, AutoModel
3 | from ftllm import torch2flm
4 |
5 | if __name__ == "__main__":
6 | modelNameOrPath = sys.argv[3] if len(sys.argv) >= 4 else 'THUDM/chatglm2-6b'
7 | tokenizer = AutoTokenizer.from_pretrained(modelNameOrPath, trust_remote_code=True)
8 | model = AutoModel.from_pretrained(modelNameOrPath, trust_remote_code=True)
9 | model = model.eval()
10 |
11 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
12 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "chatglm-6b-" + dtype + ".flm"
13 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
14 |
--------------------------------------------------------------------------------
/tools/scripts/cli_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from ftllm import llm
3 | import readline
4 |
5 | def args_parser():
6 | parser = argparse.ArgumentParser(description = 'fastllm_chat_demo')
7 |     parser.add_argument('-p', '--path', type = str, required = True, default = '', help = 'path to the model file')
8 |     parser.add_argument('-t', '--threads', type=int, default=4, help='number of threads to use')
9 |     parser.add_argument('-l', '--low', action='store_true', help='use low-memory mode')
10 | args = parser.parse_args()
11 | return args
12 |
13 | if __name__ == "__main__":
14 | args = args_parser()
15 | llm.set_cpu_threads(args.threads)
16 | llm.set_cpu_low_mem(args.low)
17 | model = llm.model(args.path)
18 | model.set_save_history(True)
19 |
20 | history = []
21 |     print("Type a message to chat, 'clear' clears the history, 'stop' exits the program")
22 | while True:
23 |         query = input("\nUser:")
24 | if query.strip() == "stop":
25 | break
26 | if query.strip() == "clear":
27 | history = []
28 |             print("Type a message to chat, 'clear' clears the history, 'stop' exits the program")
29 | continue
30 |         print("AI:", end = "")
31 |         curResponse = ""
32 | for response in model.stream_response(query, history = history):
33 |             curResponse += response
34 | print(response, flush = True, end = "")
35 | history.append((query, curResponse))
36 | model.release_memory()
--------------------------------------------------------------------------------
/tools/scripts/llama3_to_flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoTokenizer, AutoModelForCausalLM
4 | from ftllm import torch2flm
5 |
6 | if __name__ == "__main__":
7 | modelNameOrPath = sys.argv[3] if len(sys.argv) >= 4 else 'meta-llama/Meta-Llama-3-8B'
8 |     tokenizer = AutoTokenizer.from_pretrained(modelNameOrPath, trust_remote_code=True)
9 |     # `torch_dtype=torch.float16` is set by default; if it will not cause an OOM error, you can load the model in float32.
10 | model = AutoModelForCausalLM.from_pretrained(modelNameOrPath, trust_remote_code=True, torch_dtype=torch.float16)
11 | model = model.eval()
12 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
13 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else model.config.model_type + "-7b-" + dtype + ".flm"
14 | torch2flm.tofile(exportPath, model, tokenizer,
15 | pre_prompt="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are a helpful assistant.<|eot_id|>",
16 | user_role="<|start_header_id|>user<|end_header_id|>\n",
17 | bot_role="<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
18 | history_sep="<|eot_id|>\n",
19 | eos_id = tokenizer.convert_tokens_to_ids("<|eot_id|>"),
20 | dtype = dtype)
21 |
--------------------------------------------------------------------------------
/tools/scripts/llamalike2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoTokenizer, AutoModelForCausalLM
4 | from ftllm import torch2flm
5 |
6 | if __name__ == "__main__":
7 | modelNameOrPath = sys.argv[3] if len(sys.argv) >= 4 else 'qwen/Qwen1.5-7B-Chat'
8 |     tokenizer = AutoTokenizer.from_pretrained(modelNameOrPath, trust_remote_code=True)
9 |     # `torch_dtype=torch.float16` is set by default; if it will not cause an OOM error, you can load the model in float32.
10 | model = AutoModelForCausalLM.from_pretrained(modelNameOrPath, trust_remote_code=True, torch_dtype=torch.float16)
11 | model = model.eval()
12 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
13 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else model.config.model_type + "-7b-" + dtype + ".flm"
14 | if model.config.model_type == "internlm":
15 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "",
16 | user_role = "<|User|>:", bot_role = "\n<|Bot|>:",
17 | history_sep = "\n", dtype = dtype)
18 | elif model.config.model_type == "internlm2":
19 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="<|im_start|>system\nYou are an AI assistant whose name is InternLM (书生·浦语).\n<|im_end|>",
20 | user_role="<|im_start|>user\n", bot_role="<|im_end|><|im_start|>assistant\n", history_sep="<|im_end|>\n", dtype = dtype)
21 | elif model.config.model_type == "qwen2":
22 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>", user_role="<|im_start|>user\n",
23 | bot_role="<|im_end|><|im_start|>assistant\n", history_sep="<|im_end|>\n", dtype = dtype)
24 | elif model.config.model_type == "qwen2_moe":
25 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n", user_role="<|im_start|>user\n",
26 | bot_role="<|im_end|>\n<|im_start|>assistant\n", history_sep="<|im_end|>\n", eos_id = tokenizer.eos_token_id, dtype = dtype)
27 | # add custom code here
28 | else:
29 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "", user_role = "",
30 | bot_role = "", history_sep = "", dtype = dtype)
31 |
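The pre_prompt / user_role / bot_role / history_sep strings passed to torch2flm.tofile describe how a chat history is stitched into a single prompt. A rough sketch of how such fields typically compose, mirroring the qwen2 template above (the composition order is an assumption for illustration, not fastllm's exact implementation):

```python
# Hedged sketch of how pre_prompt / user_role / bot_role / history_sep might compose.
def build_prompt(history, query,
                 pre_prompt="<|im_start|>system\nYou are a helpful assistant.<|im_end|>",
                 user_role="<|im_start|>user\n",
                 bot_role="<|im_end|><|im_start|>assistant\n",
                 history_sep="<|im_end|>\n"):
    prompt = pre_prompt
    for old_query, old_response in history:
        prompt += user_role + old_query + bot_role + old_response + history_sep
    return prompt + user_role + query + bot_role

print(build_prompt([("hi", "hello!")], "introduce fastllm"))
```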
--------------------------------------------------------------------------------
/tools/scripts/minicpm2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | from transformers import AutoTokenizer, AutoModelForCausalLM
4 | from ftllm import torch2flm
5 |
6 | if __name__ == "__main__":
7 | modelNameOrPath = sys.argv[3] if len(sys.argv) >= 4 else "openbmb/MiniCPM-2B-dpo-fp16"
8 | tokenizer = AutoTokenizer.from_pretrained(modelNameOrPath, use_fast=False, trust_remote_code=True)
9 |     # `torch_dtype=torch.float16` is set by default; if it will not cause an OOM error, you can load the model in float32.
10 | model = AutoModelForCausalLM.from_pretrained(modelNameOrPath, trust_remote_code=True, torch_dtype=torch.float16)
11 | model = model.eval()
12 |
13 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
14 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "minicpm-2b-" + dtype + ".flm"
15 |
16 | if model.config.architectures == ["MiniCPMForCausalLM"]:
17 | model.config.model_type = "minicpm"
18 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt = "", user_role = "<用户>",
19 | bot_role = "", history_sep = "", dtype = dtype)
20 | else:
21 | torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="", user_role="<|im_start|>user\n",
22 | bot_role="<|im_end|>\n<|im_start|>assistant\n", history_sep="<|im_end|>\n", eos_id = tokenizer.eos_token_id, dtype = dtype)
23 |
--------------------------------------------------------------------------------
/tools/scripts/moss_export.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 | from ftllm import torch2flm
4 |
5 | tokenizer = AutoTokenizer.from_pretrained("fnlp/moss-moon-003-sft", trust_remote_code=True)
6 | model = AutoModelForCausalLM.from_pretrained("fnlp/moss-moon-003-sft", trust_remote_code=True).float()
7 | model = model.eval()
8 |
9 | if __name__ == "__main__":
10 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
11 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "moss-" + dtype + ".flm"
12 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
--------------------------------------------------------------------------------
/tools/scripts/qwen2flm.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from transformers import AutoModelForCausalLM, AutoTokenizer
3 | from transformers.generation import GenerationConfig
4 | from ftllm import torch2flm
5 |
6 | if __name__ == "__main__":
7 | tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
8 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="cpu", trust_remote_code=True, fp32=True).eval()
9 |     model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True) # you can specify a different generation length, top_p and other related hyperparameters here
10 |
11 | dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
12 | exportPath = sys.argv[1] if len(sys.argv) >= 2 else "qwen-7b-" + dtype + ".flm"
13 | torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)
--------------------------------------------------------------------------------
/tools/scripts/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | server_require = ['fastapi', 'pydantic', 'openai', 'shortuuid', 'uvicorn']
4 | webui_require = ['streamlit-chat']
5 | download_require = ['aria2']
6 | all_require = server_require + webui_require + download_require
7 |
8 | setup (
9 | name = "ftllm",
10 | version = "0.1.2.0",
11 | author = "huangyuyang",
12 | author_email = "ztxz16@foxmail.com",
13 | description = "Fastllm",
14 | url = "https://github.com/ztxz16/fastllm",
15 | entry_points = {
16 | 'console_scripts' : [
17 | 'ftllm=ftllm.cli:main'
18 | ]
19 | },
20 | packages = ['ftllm', 'ftllm/openai_server', 'ftllm/openai_server/protocal'],
21 | package_data = {
22 | '': ['*.dll', '*.so', '*.dylib', '*.so.*']
23 | },
24 | install_requires=[
25 | 'pyreadline3',
26 | 'transformers',
27 | 'jinja2>=3.1.0'
28 | ] + all_require,
29 | extras_require={
30 | 'all': all_require,
31 | 'server': server_require,
32 | 'webui': webui_require
33 | },
34 | )
35 |
--------------------------------------------------------------------------------
/tools/scripts/setup_rocm.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | server_require = ['fastapi', 'pydantic', 'openai', 'shortuuid', 'uvicorn']
4 | webui_require = ['streamlit-chat']
5 | download_require = ['aria2']
6 | all_require = server_require + webui_require + download_require
7 |
8 | setup (
9 | name = "ftllm_rocm",
10 | version = "0.1.2.0",
11 | author = "huangyuyang",
12 | author_email = "ztxz16@foxmail.com",
13 | description = "Fastllm",
14 | url = "https://github.com/ztxz16/fastllm",
15 | entry_points = {
16 | 'console_scripts' : [
17 | 'ftllm=ftllm.cli:main'
18 | ]
19 | },
20 | packages = ['ftllm', 'ftllm/openai_server', 'ftllm/openai_server/protocal'],
21 | package_data = {
22 | '': ['*.dll', '*.so', '*.dylib', '*.so.*']
23 | },
24 | install_requires=[
25 | 'pyreadline3',
26 | 'transformers',
27 | 'jinja2>=3.1.0'
28 | ] + all_require,
29 | extras_require={
30 | 'all': all_require,
31 | 'server': server_require,
32 | 'webui': webui_require
33 | },
34 | )
35 |
--------------------------------------------------------------------------------
/tools/src/quant.cpp:
--------------------------------------------------------------------------------
1 | //
2 | // Created by huangyuyang on 5/13/23.
3 | //
4 |
5 | #include <iostream>
6 | #include "model.h"
7 |
8 | struct QuantConfig {
9 |     std::string path; // path to the model file
10 |     std::string output; // output file path
11 |     int bits; // quantization bit width
12 | };
13 |
14 | void Usage() {
15 | std::cout << "Usage:" << std::endl;
16 |     std::cout << "[-h|--help]: show this help" << std::endl;
17 |     std::cout << "<-p|--path> : path to the model file" << std::endl;
18 |     std::cout << "<-b|--bits> : quantization bits, 4 = int4, 8 = int8, 16 = fp16" << std::endl;
19 |     std::cout << "<-o|--output> : output file path" << std::endl;
20 | }
21 |
22 | void ParseArgs(int argc, char **argv, QuantConfig &config) {
23 |     std::vector<std::string> sargv;
24 | for (int i = 0; i < argc; i++) {
25 | sargv.push_back(std::string(argv[i]));
26 | }
27 | for (int i = 1; i < argc; i++) {
28 | if (sargv[i] == "-h" || sargv[i] == "--help") {
29 | Usage();
30 | exit(0);
31 | } else if (sargv[i] == "-p" || sargv[i] == "--path") {
32 | config.path = sargv[++i];
33 | } else if (sargv[i] == "-b" || sargv[i] == "--bits") {
34 | config.bits = atoi(sargv[++i].c_str());
35 | } else if (sargv[i] == "-o" || sargv[i] == "--output") {
36 | config.output = sargv[++i];
37 | } else if (sargv[i] == "-m" || sargv[i] == "--model") {
38 | i++;
39 | } else {
40 | Usage();
41 | exit(-1);
42 | }
43 | }
44 | }
45 |
46 | int main(int argc, char **argv) {
47 | QuantConfig config;
48 | ParseArgs(argc, argv, config);
49 | auto model = fastllm::CreateLLMModelFromFile(config.path);
50 | model->SaveLowBitModel(config.output, config.bits);
51 | return 0;
52 | }
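A typical invocation of the resulting tool, assuming the binary built from quant.cpp is named `quant` (hypothetical file names; flags follow ParseArgs above):

```python
import subprocess

# Quantize an existing .flm model to int4; paths and the binary name are placeholders.
subprocess.run(["./quant", "-p", "model.flm", "-b", "4", "-o", "model-int4.flm"], check=True)
```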
--------------------------------------------------------------------------------
/whl_docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use Ubuntu 20.04 as the base image
2 | FROM ubuntu:20.04
3 | 
4 | # Set up a non-interactive install environment
5 | ENV DEBIAN_FRONTEND=noninteractive
6 | 
7 | # Install base tools and GCC versions
8 | RUN apt-get update && apt-get install -y --no-install-recommends \
9 | build-essential \
10 | software-properties-common \
11 | wget \
12 | ca-certificates \
13 | && add-apt-repository -y ppa:ubuntu-toolchain-r/test \
14 | && apt-get update \
15 | && apt-get install -y \
16 | g++-10 \
17 | g++-11 \
18 | && rm -rf /var/lib/apt/lists/*
19 |
20 | # Configure g++ alternatives for multiple versions
21 | RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 50 \
22 | && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 60 \
23 | && update-alternatives --set g++ /usr/bin/g++-11
24 |
25 | # Install CUDA 11.3
26 | RUN wget https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda_11.3.0_465.19.01_linux.run \
27 | && sh cuda_11.3.0_465.19.01_linux.run --silent --toolkit --override \
28 | --no-drm --no-man-page --no-opengl-libs --installpath=/usr/local/cuda-11.3 \
29 | && rm cuda_11.3.0_465.19.01_linux.run
30 |
31 | # Install CUDA 12.1
32 | RUN wget https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run \
33 | && sh cuda_12.1.0_530.30.02_linux.run --silent --toolkit --override \
34 | --no-drm --no-man-page --no-opengl-libs --installpath=/usr/local/cuda-12.1 \
35 | && rm cuda_12.1.0_530.30.02_linux.run
36 |
37 | # Configure CUDA environment variables (11.3 is the default)
38 | ENV PATH=/usr/local/cuda-11.3/bin:${PATH}
39 | ENV LD_LIBRARY_PATH=/usr/local/cuda-11.3/lib64:${LD_LIBRARY_PATH}
40 | 
41 | # Create a symlink for switching CUDA versions
42 | RUN ln -sf /usr/local/cuda-11.3 /usr/local/cuda
43 | 
44 | # Verify the installation
45 | RUN g++ --version
46 | RUN nvcc --version
47 |
48 | RUN apt-get update
49 | RUN apt-get install cmake -y
50 | RUN apt-get install libnuma-dev -y
51 | RUN apt-get install python3-pip -y
52 | RUN pip install setuptools wheel
53 |
54 | RUN apt remove --purge cmake -y
55 | RUN pip install cmake==3.25.0 -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
56 |
--------------------------------------------------------------------------------
/whl_docker_rocm/24.04/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use Ubuntu 24.04 as the base image
2 | FROM ubuntu:24.04
3 | 
4 | # Set up a non-interactive install environment
5 | ENV DEBIAN_FRONTEND=noninteractive
6 | 
7 | # Install base tools and GCC versions
8 | RUN apt-get update && apt-get install -y --no-install-recommends \
9 | build-essential \
10 | software-properties-common \
11 | wget \
12 | ca-certificates \
13 | && add-apt-repository -y ppa:ubuntu-toolchain-r/test \
14 | && apt-get update \
15 | && apt-get install -y \
16 | g++-10 \
17 | g++-11 \
18 | && rm -rf /var/lib/apt/lists/*
19 |
20 | # Configure g++ alternatives for multiple versions
21 | RUN update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-10 50 \
22 | && update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-11 60 \
23 | && update-alternatives --set g++ /usr/bin/g++-11
24 |
25 | # Verify the installation
26 | RUN g++ --version
27 |
28 | RUN apt-get update
29 | RUN apt-get install cmake -y
30 | RUN apt-get install libnuma-dev -y
31 | RUN apt-get install python3-pip -y
32 |
33 | RUN wget https://repo.radeon.com/amdgpu-install/6.3.3/ubuntu/noble/amdgpu-install_6.3.60303-1_all.deb
34 | RUN apt install ./amdgpu-install_6.3.60303-1_all.deb -y
35 | RUN amdgpu-install --usecase=hiplibsdk,rocm,dkms -y
36 | RUN apt-get install python-is-python3 -y
37 | RUN apt-get install python3.12-venv -y
38 | RUN python3 -m venv ~/ftllm
39 | #RUN source ~/ftllm/bin/activate
40 | #RUN pip install setuptools wheel -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
41 |
42 | #RUN apt remove --purge cmake -y
43 | #RUN pip install cmake==3.25.0 -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
44 |
--------------------------------------------------------------------------------
/whl_docker_rocm/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use Ubuntu 22.04 as the base image
2 | FROM ubuntu:22.04
3 | 
4 | # Set up a non-interactive install environment
5 | ENV DEBIAN_FRONTEND=noninteractive
6 | 
7 | # Install base tools and GCC
8 | RUN apt-get update && apt-get install gcc g++ make -y
9 | # Verify the installation
10 | RUN g++ --version
11 |
12 | RUN apt-get update
13 | RUN apt-get install cmake -y
14 | RUN apt-get install libnuma-dev -y
15 | RUN apt-get install python3-pip -y
16 | RUN apt-get install wget -y
17 |
18 | RUN wget https://repo.radeon.com/amdgpu-install/6.3.3/ubuntu/jammy/amdgpu-install_6.3.60303-1_all.deb
19 | RUN apt install ./amdgpu-install_6.3.60303-1_all.deb -y
20 | RUN amdgpu-install --usecase=hiplibsdk,rocm,dkms -y
21 | RUN apt-get install python-is-python3 -y
22 | RUN pip install setuptools wheel -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
23 | RUN apt install libstdc++-12-dev -y
24 |
--------------------------------------------------------------------------------