├── .gitattributes ├── .github ├── dependabot.yml └── workflows │ ├── docker.yml │ ├── pre-release.yml │ └── release.yml ├── .gitignore ├── .vscode ├── launch.json ├── settings.json └── tasks.json ├── CURRENT_CHANGE.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── README_JA.md ├── README_ZH.md ├── assets ├── default_sound_font.sf2 ├── sound-font │ ├── sound_fetch.py │ └── soundfont.json └── soundfont_builder.rb ├── backend-golang ├── app.go ├── cmd_interactive.go ├── cmd_interactive_unix.go ├── cmd_interactive_windows.go ├── download.go ├── file.go ├── hw_info.go ├── midi.go ├── rwkv.go ├── utils.go ├── utils_unix.go ├── utils_windows.go ├── wsl_unix.go └── wsl_windows.go ├── backend-python ├── convert_model.py ├── convert_pytorch_to_ggml.py ├── convert_safetensors.py ├── dep_check.py ├── get-pip.py ├── global_var.py ├── main.py ├── requirements.txt ├── requirements_without_cyac.txt ├── routes │ ├── completion.py │ ├── config.py │ ├── file_process.py │ ├── midi.py │ ├── misc.py │ ├── schema.py │ └── state_cache.py ├── rwkv_pip │ ├── 20B_tokenizer.json │ ├── cpp │ │ ├── librwkv.dylib │ │ ├── librwkv.so │ │ ├── model.py │ │ ├── rwkv.dll │ │ ├── rwkv_cpp_model.py │ │ └── rwkv_cpp_shared_library.py │ ├── cuda │ │ ├── gemm_fp16_cublas.cpp │ │ ├── operators.cu │ │ ├── rwkv5.cu │ │ ├── rwkv5_op.cpp │ │ ├── rwkv6.cu │ │ ├── rwkv6_op.cpp │ │ ├── rwkv7.cu │ │ ├── rwkv7_op.cpp │ │ └── wrapper.cpp │ ├── kernels │ │ ├── torch-1.13.1+cu117 │ │ │ ├── rwkv5.pyd │ │ │ ├── rwkv6.pyd │ │ │ ├── wkv7s.pyd │ │ │ └── wkv_cuda.pyd │ │ └── torch-2.7.1+cu128 │ │ │ ├── rwkv5.pyd │ │ │ ├── rwkv6.pyd │ │ │ ├── wkv7s.pyd │ │ │ └── wkv_cuda.pyd │ ├── model.py │ ├── rwkv_tokenizer.py │ ├── rwkv_vocab_v20230424.txt │ ├── tokenizer-midi.json │ ├── tokenizer-midipiano.json │ ├── utils.py │ └── webgpu │ │ ├── model.py │ │ └── web_rwkv_py.cp310-win_amd64.pyd ├── tests │ ├── function_call.py │ ├── function_call_stream.py │ └── postprocess_response.py ├── utils │ ├── llama.py │ ├── log.py │ ├── midi.py │ ├── midi_filter_config.json │ ├── midi_vocab_config.json │ ├── ngrok.py │ ├── rwkv.py │ ├── torch.py │ └── vocab_config_piano.json └── webui_server.py ├── backend-rust └── assets │ └── rwkv_vocab_v20230424.json ├── build ├── README.md ├── appicon.png ├── darwin │ ├── Info.dev.plist │ ├── Info.plist │ ├── Readme_Install.txt │ ├── entitlements.plist │ └── gon-sign.json ├── linux │ └── Readme_Install.txt └── windows │ ├── Readme_Install.txt │ ├── WELCOMEFINISHPAGE.bmp │ ├── icon.ico │ ├── info.json │ ├── installer │ ├── project.nsi │ └── wails_tools.nsh │ └── wails.exe.manifest ├── components └── gitkeep ├── deploy-examples ├── ChatGPT-Next-Web │ ├── setup.bat │ └── setup.sh └── RWKV-Runner-WebUI │ ├── setup.bat │ └── setup.sh ├── docker-compose.yml ├── exportModelsJson.js ├── finetune ├── data │ └── sample.jsonl ├── get_layer_and_embd.py ├── install-wsl-dep-and-train.sh ├── json2binidx_tool │ └── tools │ │ ├── indexed_dataset.py │ │ ├── preprocess_data.py │ │ ├── rwkv_tokenizer.py │ │ └── tokenizer.py ├── lora │ ├── merge_lora.py │ ├── v4 │ │ ├── cuda │ │ │ ├── wkv_cuda.cu │ │ │ ├── wkv_cuda_bf16.cu │ │ │ ├── wkv_op.cpp │ │ │ └── wkv_op_bf16.cpp │ │ ├── src │ │ │ ├── __init__.py │ │ │ ├── binidx.py │ │ │ ├── dataset.py │ │ │ ├── model.py │ │ │ ├── trainer.py │ │ │ └── utils.py │ │ └── train.py │ ├── v5 │ │ ├── cuda │ │ │ ├── wkv5_cuda.cu │ │ │ └── wkv5_op.cpp │ │ ├── src │ │ │ ├── __init__.py │ │ │ ├── binidx.py │ │ │ ├── dataset.py │ │ │ ├── model.py │ │ │ ├── trainer.py │ │ │ └── utils.py │ │ └── train.py │ └── v6 │ 
│ ├── cuda │ │ ├── wkv5_cuda.cu │ │ ├── wkv5_op.cpp │ │ ├── wkv6_cuda.cu │ │ ├── wkv6_op.cpp │ │ ├── wkv6infctx_cuda.cu │ │ ├── wkv6infctx_op.cpp │ │ ├── wkv6state_cuda.cu │ │ └── wkv6state_op.cpp │ │ ├── demo │ │ ├── demo-lora-merge.sh │ │ ├── demo-lora.sh │ │ ├── demo-pissa-merge.sh │ │ ├── demo-pissa.sh │ │ ├── demo-qpissa-pt.sh │ │ ├── demo-state-merge.sh │ │ ├── demo-state-tuning.sh │ │ ├── demo-training-prepare.sh │ │ ├── demo-training-run.sh │ │ └── infctx.sh │ │ ├── fla │ │ ├── __init__.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── abc.py │ │ │ ├── based.py │ │ │ ├── delta_net.py │ │ │ ├── gated_abc.py │ │ │ ├── gla.py │ │ │ ├── hgrn.py │ │ │ ├── hgrn2.py │ │ │ ├── linear_attn.py │ │ │ ├── multiscale_retention.py │ │ │ ├── rebased.py │ │ │ ├── rwkv6.py │ │ │ └── simple_gla.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── abc │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_abc.py │ │ │ │ └── modeling_abc.py │ │ │ ├── delta_net │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_delta_net.py │ │ │ │ └── modeling_delta_net.py │ │ │ ├── gla │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_gla.py │ │ │ │ └── modeling_gla.py │ │ │ ├── hgrn │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_hgrn.py │ │ │ │ └── modeling_hgrn.py │ │ │ ├── hgrn2 │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_hgrn2.py │ │ │ │ └── modeling_hgrn2.py │ │ │ ├── linear_attn │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_linear_attn.py │ │ │ │ └── modeling_linear_attn.py │ │ │ ├── mamba │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_mamba.py │ │ │ │ └── modeling_mamba.py │ │ │ ├── retnet │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_retnet.py │ │ │ │ └── modeling_retnet.py │ │ │ ├── rwkv6 │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_rwkv6.py │ │ │ │ └── modeling_rwkv6.py │ │ │ ├── transformer │ │ │ │ ├── __init__.py │ │ │ │ ├── configuration_transformer.py │ │ │ │ └── modeling_transformer.py │ │ │ └── utils.py │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── activations.py │ │ │ ├── convolution.py │ │ │ ├── feature_map.py │ │ │ ├── fused_cross_entropy.py │ │ │ ├── fused_norm_gate.py │ │ │ ├── l2norm.py │ │ │ ├── layernorm.py │ │ │ └── rotary.py │ │ ├── ops │ │ │ ├── __init__.py │ │ │ ├── abc │ │ │ │ ├── __init__.py │ │ │ │ ├── chunk.py │ │ │ │ ├── chunk_gate.py │ │ │ │ ├── naive.py │ │ │ │ └── recurrent_fuse.py │ │ │ ├── based │ │ │ │ ├── __init__.py │ │ │ │ ├── chunk_fuse.py │ │ │ │ ├── naive.py │ │ │ │ └── parallel.py │ │ │ ├── delta_rule │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── chunk.py │ │ │ │ ├── chunk_fuse.py │ │ │ │ ├── naive.py │ │ │ │ ├── recurrent_fuse.py │ │ │ │ ├── utils.py │ │ │ │ └── wy_fast.py │ │ │ ├── gla │ │ │ │ ├── __init__.py │ │ │ │ ├── chunk.py │ │ │ │ ├── chunk_fuse.py │ │ │ │ ├── chunk_util.py │ │ │ │ ├── naive.py │ │ │ │ └── recurrent_fuse.py │ │ │ ├── hgrn │ │ │ │ ├── __init__.py │ │ │ │ ├── chunk.py │ │ │ │ ├── naive.py │ │ │ │ └── recurrent_fuse.py │ │ │ ├── linear_attn │ │ │ │ ├── __init__.py │ │ │ │ ├── chunk.py │ │ │ │ ├── chunk_fuse.py │ │ │ │ ├── naive.py │ │ │ │ └── recurrent_fuse.py │ │ │ ├── rebased │ │ │ │ ├── __init__.py │ │ │ │ ├── naive.py │ │ │ │ └── parallel.py │ │ │ ├── retention │ │ │ │ ├── __init__.py │ │ │ │ ├── chunk.py │ │ │ │ ├── chunk_fuse.py │ │ │ │ ├── naive.py │ │ │ │ ├── parallel.py │ │ │ │ └── recurrent_fuse.py │ │ │ ├── rotary.py │ │ │ ├── rwkv4 │ │ │ │ ├── __init__.py │ │ │ │ └── recurrent_fuse.py │ │ │ ├── rwkv6 │ │ │ │ ├── __init__.py │ │ │ │ ├── chunk.py │ │ │ │ ├── chunk_naive.py │ │ │ │ ├── recurrent_fuse.py │ │ │ │ 
└── recurrent_naive.py │ │ │ ├── simple_gla │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── chunk.py │ │ │ │ └── naive.py │ │ │ └── utils.py │ │ └── utils.py │ │ ├── merge │ │ ├── merge.py │ │ ├── merge_lora.py │ │ ├── merge_pissa.py │ │ └── merge_state.py │ │ ├── requirements.txt │ │ ├── src │ │ ├── __init__.py │ │ ├── binidx.py │ │ ├── dataset.py │ │ ├── infctx_module.py │ │ ├── model.py │ │ ├── trainer.py │ │ └── utils.py │ │ └── train.py └── requirements.txt ├── frontend ├── i18nally.json ├── index.html ├── package-lock.json ├── package.json ├── postcss.config.js ├── prettier.config.js ├── src │ ├── App.tsx │ ├── _locales │ │ ├── i18n-react.ts │ │ ├── i18n.ts │ │ ├── ja │ │ │ └── main.json │ │ ├── resources.ts │ │ └── zh-hans │ │ │ └── main.json │ ├── apis │ │ └── index.ts │ ├── assets │ │ └── images │ │ │ ├── banner.jpg │ │ │ ├── logo.png │ │ │ ├── strategy.jpg │ │ │ └── strategy_zh.jpg │ ├── components │ │ ├── BottomLogger.tsx │ │ ├── ConfigSelector.tsx │ │ ├── CopyButton.tsx │ │ ├── CustomToastContainer.tsx │ │ ├── DebugModeIndicator.tsx │ │ ├── DialogButton.tsx │ │ ├── Labeled.tsx │ │ ├── LazyImportComponent.tsx │ │ ├── MarkdownRender.tsx │ │ ├── MobileFloatingNavigator.tsx │ │ ├── NumberInput.tsx │ │ ├── Page.tsx │ │ ├── ReadButton.tsx │ │ ├── ResetConfigsButton.tsx │ │ ├── RunButton.tsx │ │ ├── Section.tsx │ │ ├── ToolTipButton.tsx │ │ ├── ValuedSlider.tsx │ │ └── WorkHeader.tsx │ ├── main.tsx │ ├── pages │ │ ├── About.tsx │ │ ├── AudiotrackManager │ │ │ ├── AudiotrackButton.tsx │ │ │ └── AudiotrackEditor.tsx │ │ ├── AutoConfig.tsx │ │ ├── Chat.tsx │ │ ├── Completion.tsx │ │ ├── Composition.tsx │ │ ├── Configs.tsx │ │ ├── Downloads.tsx │ │ ├── Home.tsx │ │ ├── Models.tsx │ │ ├── PresetsManager │ │ │ ├── MessagesEditor.tsx │ │ │ └── PresetsButton.tsx │ │ ├── Settings.tsx │ │ ├── Train.tsx │ │ ├── defaultConfigs.ts │ │ └── index.tsx │ ├── startup.ts │ ├── stores │ │ ├── cmdTaskChainStore.ts │ │ └── commonStore.ts │ ├── style.scss │ ├── types │ │ ├── about.ts │ │ ├── chat.ts │ │ ├── completion.ts │ │ ├── composition.ts │ │ ├── configs.ts │ │ ├── downloads.ts │ │ ├── home.ts │ │ ├── html-midi-player.d.ts │ │ ├── models.ts │ │ ├── presets.ts │ │ ├── settings.ts │ │ └── train.ts │ ├── utils │ │ ├── convert-model.ts │ │ ├── copy-cuda-kernels.ts │ │ ├── filter-function-properties.ts │ │ ├── generate-strategy.ts │ │ ├── get-available-torch-cu-version.ts │ │ ├── index.tsx │ │ ├── rwkv-task.ts │ │ └── web-file-operations.ts │ ├── vite-env.d.ts │ └── webWails.js ├── tailwind.config.js ├── tsconfig.json ├── tsconfig.node.json ├── vite.config.ts └── wailsjs │ ├── go │ ├── backend_golang │ │ ├── App.d.ts │ │ └── App.js │ └── models.ts │ └── runtime │ ├── package.json │ ├── runtime.d.ts │ └── runtime.js ├── go.mod ├── go.sum ├── main.go ├── manifest.json ├── midi └── sample.txt ├── parse_api_log.py ├── py310 └── Lib │ └── site-packages │ ├── cyac-1.9.dist-info │ └── .gitkeep │ ├── cyac │ └── .gitkeep │ ├── diskcache-5.6.3.dist-info │ ├── INSTALLER │ ├── LICENSE │ ├── METADATA │ ├── RECORD │ ├── REQUESTED │ ├── WHEEL │ └── top_level.txt │ ├── diskcache │ ├── __init__.py │ ├── cli.py │ ├── core.py │ ├── djangocache.py │ ├── fanout.py │ ├── persistent.py │ └── recipes.py │ ├── llama_cpp │ ├── __init__.py │ ├── _ctypes_extensions.py │ ├── _ggml.py │ ├── _internals.py │ ├── _logger.py │ ├── _utils.py │ ├── lib │ │ ├── ggml-base.dll │ │ ├── ggml-base.lib │ │ ├── ggml-cpu.dll │ │ ├── ggml-cpu.lib │ │ ├── ggml-vulkan.dll │ │ ├── ggml-vulkan.lib │ │ ├── ggml.dll │ │ ├── ggml.lib │ │ ├── 
llama.dll │ │ ├── llama.lib │ │ ├── llava.dll │ │ └── llava.lib │ ├── llama.py │ ├── llama_cache.py │ ├── llama_chat_format.py │ ├── llama_cpp.py │ ├── llama_grammar.py │ ├── llama_speculative.py │ ├── llama_tokenizer.py │ ├── llama_types.py │ ├── llava_cpp.py │ ├── py.typed │ └── server │ │ ├── __init__.py │ │ ├── __main__.py │ │ ├── app.py │ │ ├── cli.py │ │ ├── errors.py │ │ ├── model.py │ │ ├── settings.py │ │ └── types.py │ └── llama_cpp_python-0.3.9.dist-info │ ├── INSTALLER │ ├── METADATA │ ├── RECORD │ ├── REQUESTED │ ├── WHEEL │ ├── direct_url.json │ └── licenses │ └── LICENSE.md ├── scripts └── merge_manifest.py └── wails.json /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | 3 | backend-python/rwkv_pip/** linguist-vendored 4 | backend-python/wkv_cuda_utils/** linguist-vendored 5 | backend-python/get-pip.py linguist-vendored 6 | backend-python/convert_model.py linguist-vendored 7 | backend-python/convert_safetensors.py linguist-vendored 8 | backend-python/convert_pytorch_to_ggml.py linguist-vendored 9 | backend-python/utils/midi.py linguist-vendored 10 | build/** linguist-vendored 11 | finetune/lora/** linguist-vendored 12 | finetune/json2binidx_tool/** linguist-vendored 13 | py310/** linguist-vendored 14 | frontend/wailsjs/** linguist-generated -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | commit-message: 8 | prefix: "chore" 9 | include: "scope" 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/bin 2 | node_modules 3 | frontend/dist 4 | __pycache__ 5 | .idea 6 | .vs 7 | *.pth 8 | *.st 9 | *.safetensors 10 | *.bin 11 | *.mid 12 | /config.json 13 | /cache.json 14 | /durable.json 15 | /presets.json 16 | /frontend/stats.html 17 | /frontend/package.json.md5 18 | /py310 19 | *.zip 20 | /cmd-helper.bat 21 | /install-py-dep.bat 22 | /backend-python/wkv_cuda 23 | /backend-python/rwkv5 24 | /backend-python/rwkv6 25 | /backend-python/wkv7s 26 | /backend-python/rwkv_pip/wkv_cuda.pyd 27 | /backend-python/rwkv_pip/rwkv5.pyd 28 | /backend-python/rwkv_pip/rwkv6.pyd 29 | /backend-python/rwkv_pip/wkv7s.pyd 30 | *.exe 31 | *.old 32 | .DS_Store 33 | *.log.* 34 | *.log 35 | train_log.txt 36 | finetune/json2binidx_tool/data 37 | /wsl.state 38 | /components 39 | error.txt 40 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | // 6 | // Use Ctrl+Shift+P to Select Interpreter 7 | "version": "0.2.0", 8 | "configurations": [ 9 | { 10 | "name": "Python", 11 | "type": "python", 12 | "request": "launch", 13 | "program": "${workspaceFolder}/backend-python/main.py", 14 | "console": "integratedTerminal", 15 | "justMyCode": false 16 | }, 17 | { 18 | "name": "Golang", 19 | "type": "go", 20 | "request": "launch", 21 | "mode": "exec", 22 | "program": "${workspaceFolder}/build/bin/testwails.exe", 23 | "console": "integratedTerminal", 24 | "preLaunchTask": "build dev" 25 | }, 26 | { 27 | "name": "Frontend", 28 | "type": "node-terminal", 29 | "request": "launch", 30 | "command": "wails dev -browser" 31 | } 32 | ] 33 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "prettier.configPath": "./frontend/prettier.config.js", 3 | "prettier.prettierPath": "./frontend/node_modules/prettier", 4 | "prettier.requireConfig": true, 5 | "editor.defaultFormatter": "esbenp.prettier-vscode", 6 | "[go]": { 7 | "editor.defaultFormatter": "golang.go" 8 | }, 9 | "[python]": { 10 | "editor.defaultFormatter": "ms-python.black-formatter" 11 | }, 12 | "python.formatting.provider": "none", 13 | "editor.formatOnSave": true, 14 | } 15 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "2.0.0", 3 | "tasks": [ 4 | { 5 | "label": "build dev", 6 | "type": "shell", 7 | "options": { 8 | "cwd": "${workspaceFolder}", 9 | "env": { 10 | "CGO_ENABLED": "1" 11 | } 12 | }, 13 | "osx": { 14 | "options": { 15 | "env": { 16 | "CGO_CFLAGS": "-mmacosx-version-min=10.13", 17 | "CGO_LDFLAGS": "-framework UniformTypeIdentifiers -mmacosx-version-min=10.13" 18 | } 19 | } 20 | }, 21 | "windows": { 22 | "options": { 23 | "env": { 24 | "CGO_ENABLED": "0" 25 | } 26 | } 27 | }, 28 | "command": "go", 29 | "args": [ 30 | "build", 31 | "-tags", 32 | "dev", 33 | "-gcflags", 34 | "all=-N -l", 35 | "-o", 36 | "build/bin/testwails.exe" 37 | ] 38 | } 39 | ] 40 | } -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM node:21-slim AS frontend 2 | 3 | RUN echo "registry=https://registry.npmmirror.com/" > ~/.npmrc 4 | 5 | WORKDIR /app 6 | 7 | COPY manifest.json manifest.json 8 | COPY frontend frontend 9 | 10 | WORKDIR /app/frontend 11 | 12 | RUN npm ci 13 | RUN npm run build 14 | 15 | FROM nvidia/cuda:11.6.1-devel-ubuntu20.04 AS runtime 16 | 17 | ENV DEBIAN_FRONTEND=noninteractive 18 | 19 | RUN apt update && \ 20 | apt install -yq git curl wget build-essential ninja-build aria2 jq software-properties-common 21 | 22 | RUN add-apt-repository -y ppa:deadsnakes/ppa && \ 23 | add-apt-repository -y ppa:ubuntu-toolchain-r/test && \ 24 | apt install -y g++-11 python3.10 python3.10-distutils python3.10-dev && \ 25 | curl -sS http://mirrors.aliyun.com/pypi/get-pip.py | python3.10 26 | 27 | RUN python3.10 -m pip install cmake 28 | 29 | FROM runtime AS librwkv 30 | 31 | WORKDIR /app 32 | 33 | RUN git clone https://github.com/RWKV/rwkv.cpp.git && \ 34 | cd rwkv.cpp && \ 35 | git submodule update --init --recursive && \ 36 | mkdir -p build && \ 37 | cd build && \ 38 | cmake -G Ninja .. 
&& \ 39 | cmake --build . 40 | 41 | FROM runtime AS final 42 | 43 | WORKDIR /app 44 | 45 | COPY ./backend-python/requirements.txt ./backend-python/requirements.txt 46 | 47 | RUN python3.10 -m pip install --quiet -r ./backend-python/requirements.txt 48 | 49 | COPY . . 50 | COPY --from=frontend /app/frontend/dist /app/frontend/dist 51 | COPY --from=librwkv /app/rwkv.cpp/build/librwkv.so /app/backend-python/rwkv_pip/cpp/librwkv.so 52 | 53 | EXPOSE 27777 54 | 55 | CMD ["python3.10", "./backend-python/main.py", "--port", "27777", "--host", "0.0.0.0", "--webui"] 56 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 josStorer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ifeq ($(OS), Windows_NT) 2 | build: build-windows 3 | else ifeq ($(shell uname -s), Darwin) 4 | build: build-macos 5 | else 6 | build: build-linux 7 | endif 8 | 9 | windows_build = wails build -ldflags '-s -w -extldflags "-static"' -platform windows/amd64 -devtools -upx -upxflags "-9 --lzma" 10 | 11 | build-windows: 12 | @echo ---- build for windows 13 | $(windows_build) -nsis 14 | 15 | debug: 16 | $(windows_build) -windowsconsole 17 | 18 | build-macos: 19 | @echo ---- build for macos 20 | wails build -ldflags '-s -w' -platform darwin/universal -devtools 21 | 22 | build-linux: 23 | @echo ---- build for linux 24 | wails build -ldflags '-s -w' -platform linux/amd64 -devtools -upx -upxflags "-9 --lzma" 25 | 26 | build-linux_webkit2_41: 27 | @echo ---- build for linux with webkit2_41 28 | wails build -tags webkit2_41 -ldflags '-s -w' -platform linux/amd64 -devtools -upx -upxflags "-9 --lzma" 29 | 30 | build-web: 31 | @echo ---- build for web 32 | cd frontend && npm run build 33 | 34 | dev: 35 | wails dev 36 | 37 | # go install github.com/josStorer/wails/v2/cmd/wails@v2.9.2x 38 | devq: 39 | wails dev -s -m -skipembedcreate -skipbindings 40 | 41 | devq2: 42 | wails dev -s -m -skipembedcreate 43 | 44 | dev-web: 45 | cd frontend && npm run dev 46 | 47 | preview: 48 | cd frontend && npm run preview 49 | 50 | -------------------------------------------------------------------------------- /assets/default_sound_font.sf2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/assets/default_sound_font.sf2 -------------------------------------------------------------------------------- /backend-golang/cmd_interactive.go: -------------------------------------------------------------------------------- 1 | package backend_golang 2 | 3 | import "os" 4 | 5 | var cmds = make(map[string][]string) 6 | var cmdProcesses = make(map[string]*os.Process) 7 | 8 | func (a *App) GetCmds() map[string][]string { 9 | return cmds 10 | } 11 | 12 | func (a *App) KillCmd(eventId string) error { 13 | cmd, ok := cmdProcesses[eventId] 14 | if !ok { 15 | return nil 16 | } 17 | delete(cmds, eventId) 18 | delete(cmdProcesses, eventId) 19 | return cmd.Kill() 20 | } 21 | 22 | func (a *App) IsCmdRunning(eventId string) bool { 23 | _, ok := cmds[eventId] 24 | return ok 25 | } 26 | -------------------------------------------------------------------------------- /backend-golang/cmd_interactive_unix.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || linux 2 | 3 | package backend_golang 4 | 5 | import ( 6 | "bufio" 7 | "io" 8 | "os/exec" 9 | 10 | wruntime "github.com/wailsapp/wails/v2/pkg/runtime" 11 | ) 12 | 13 | func (a *App) CmdInteractive(args []string, eventId string) error { 14 | cmd := exec.Command(args[0], args[1:]...) 
15 | 16 | stdout, err := cmd.StdoutPipe() 17 | if err != nil { 18 | return err 19 | } 20 | 21 | cmd.Stderr = cmd.Stdout 22 | 23 | err = cmd.Start() 24 | 25 | if err != nil { 26 | return err 27 | } 28 | 29 | cmds[eventId] = args 30 | cmdProcesses[eventId] = cmd.Process 31 | reader := bufio.NewReader(stdout) 32 | 33 | for { 34 | line, _, err := reader.ReadLine() 35 | if err != nil { 36 | delete(cmds, eventId) 37 | delete(cmdProcesses, eventId) 38 | if err == io.EOF { 39 | wruntime.EventsEmit(a.ctx, eventId+"-finish") 40 | return nil 41 | } 42 | return err 43 | } 44 | strLine := string(line) 45 | wruntime.EventsEmit(a.ctx, eventId+"-output", strLine) 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /backend-golang/cmd_interactive_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package backend_golang 4 | 5 | import ( 6 | "bufio" 7 | "bytes" 8 | "io" 9 | "os/exec" 10 | "syscall" 11 | 12 | wruntime "github.com/wailsapp/wails/v2/pkg/runtime" 13 | "golang.org/x/text/encoding/simplifiedchinese" 14 | "golang.org/x/text/transform" 15 | ) 16 | 17 | func (a *App) CmdInteractive(args []string, eventId string) error { 18 | cmd := exec.Command(args[0], args[1:]...) 19 | 20 | stdout, err := cmd.StdoutPipe() 21 | if err != nil { 22 | return err 23 | } 24 | 25 | cmd.Stderr = cmd.Stdout 26 | cmd.SysProcAttr = &syscall.SysProcAttr{} 27 | cmd.SysProcAttr.HideWindow = true 28 | 29 | err = cmd.Start() 30 | 31 | if err != nil { 32 | return err 33 | } 34 | 35 | cmds[eventId] = args 36 | cmdProcesses[eventId] = cmd.Process 37 | reader := bufio.NewReader(stdout) 38 | 39 | for { 40 | line, _, err := reader.ReadLine() 41 | if err != nil { 42 | delete(cmds, eventId) 43 | delete(cmdProcesses, eventId) 44 | if err == io.EOF { 45 | wruntime.EventsEmit(a.ctx, eventId+"-finish") 46 | return nil 47 | } 48 | return err 49 | } 50 | reader := transform.NewReader(bytes.NewReader(line), simplifiedchinese.GBK.NewDecoder()) 51 | line2, err := io.ReadAll(reader) 52 | if err == nil { 53 | line = line2 54 | } 55 | strLine := string(line) 56 | wruntime.EventsEmit(a.ctx, eventId+"-output", strLine) 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /backend-golang/hw_info.go: -------------------------------------------------------------------------------- 1 | package backend_golang 2 | 3 | import ( 4 | "errors" 5 | "strconv" 6 | "strings" 7 | ) 8 | 9 | func (a *App) GetNvidiaGpuCount() (int, error) { 10 | // temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used 11 | // gpu_name,gpu_bus_id,driver_version 12 | // nvidia-smi --help-query-gpu 13 | output, err := a.CommandOutput("nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits") 14 | if err != nil { 15 | return 0, err 16 | } 17 | return strconv.Atoi(output) 18 | } 19 | 20 | func (a *App) GetCudaComputeCapability(index int) (string, error) { 21 | output, err := a.CommandOutput("nvidia-smi", "-i="+strconv.Itoa(index), "--query-gpu=compute_cap", "--format=csv,noheader,nounits") 22 | if err != nil { 23 | return "", err 24 | } 25 | 26 | if output == "" { 27 | return "", errors.New("compute capability is empty") 28 | } 29 | 30 | return output, nil 31 | } 32 | 33 | func (a *App) GetMaxCudaComputeCapability() (string, error) { 34 | gpuCount, err := a.GetNvidiaGpuCount() 35 | if err != nil { 36 | return "", err 37 | } 38 | maxComputeCap := "0.0" 39 | for i := 0; i < gpuCount; 
i++ { 40 | computeCap, err := a.GetCudaComputeCapability(i) 41 | if err != nil { 42 | return "", err 43 | } 44 | computeCapFloat, err := strconv.ParseFloat(computeCap, 64) 45 | if err != nil { 46 | return "", err 47 | } 48 | maxComputeCapFloat, err := strconv.ParseFloat(maxComputeCap, 64) 49 | if err != nil { 50 | return "", err 51 | } 52 | if computeCapFloat > maxComputeCapFloat { 53 | maxComputeCap = computeCap 54 | } 55 | } 56 | if maxComputeCap == "0.0" { 57 | return "", errors.New("no cuda compute capability") 58 | } 59 | return maxComputeCap, nil 60 | } 61 | 62 | func (a *App) GetSupportedCudaVersion() (string, error) { 63 | output, err := a.CommandOutput("nvidia-smi", "--query") 64 | if err != nil { 65 | return "", err 66 | } 67 | 68 | lines := strings.Split(output, "\n") 69 | 70 | for _, line := range lines { 71 | if strings.Contains(line, "CUDA Version") { 72 | return strings.TrimSpace(strings.Split(line, ":")[1]), nil 73 | } 74 | } 75 | 76 | return "", errors.New("cuda version is empty") 77 | } 78 | 79 | func (a *App) GetTorchVersion(python string) (string, error) { 80 | var err error 81 | if python == "" { 82 | python, err = a.GetPython() 83 | if err != nil { 84 | return "", err 85 | } 86 | } 87 | 88 | output, err := a.CommandOutput(python, "-c", "import torch; print(torch.__version__)") 89 | if err != nil { 90 | return "", err 91 | } 92 | 93 | if output == "" { 94 | return "", errors.New("torch version is empty") 95 | } 96 | 97 | return output, nil 98 | } 99 | -------------------------------------------------------------------------------- /backend-golang/utils_unix.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || linux 2 | 3 | package backend_golang 4 | 5 | import ( 6 | "os/exec" 7 | "strings" 8 | ) 9 | 10 | func CmdSetHideWindow(cmd *exec.Cmd, hideWindow bool) { 11 | } 12 | 13 | func (a *App) CommandOutput(name string, args ...string) (string, error) { 14 | cmd := exec.Command(name, args...) 15 | output, err := cmd.CombinedOutput() 16 | if err != nil { 17 | return "", err 18 | } 19 | return strings.TrimSpace(string(output)), nil 20 | } 21 | -------------------------------------------------------------------------------- /backend-golang/utils_windows.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package backend_golang 4 | 5 | import ( 6 | "os/exec" 7 | "strings" 8 | "syscall" 9 | ) 10 | 11 | func CmdSetHideWindow(cmd *exec.Cmd, hideWindow bool) { 12 | if cmd.SysProcAttr == nil { 13 | cmd.SysProcAttr = &syscall.SysProcAttr{} 14 | } 15 | cmd.SysProcAttr.HideWindow = hideWindow 16 | } 17 | 18 | func (a *App) CommandOutput(name string, args ...string) (string, error) { 19 | cmd := exec.Command(name, args...) 
20 | CmdSetHideWindow(cmd, true) 21 | output, err := cmd.CombinedOutput() 22 | if err != nil { 23 | return "", err 24 | } 25 | return strings.TrimSpace(string(output)), nil 26 | } 27 | -------------------------------------------------------------------------------- /backend-golang/wsl_unix.go: -------------------------------------------------------------------------------- 1 | //go:build darwin || linux 2 | 3 | package backend_golang 4 | 5 | import ( 6 | "errors" 7 | ) 8 | 9 | func (a *App) WslStart() error { 10 | return errors.New("wsl not supported") 11 | } 12 | 13 | func (a *App) WslCommand(command string) error { 14 | return errors.New("wsl not supported") 15 | } 16 | 17 | func (a *App) WslStop() error { 18 | return errors.New("wsl not supported") 19 | } 20 | 21 | func (a *App) WslIsEnabled() error { 22 | return errors.New("wsl not supported") 23 | } 24 | 25 | func (a *App) WslEnable(forceMode bool) error { 26 | return errors.New("wsl not supported") 27 | } 28 | 29 | func (a *App) WslInstallUbuntu() error { 30 | return errors.New("wsl not supported") 31 | } 32 | -------------------------------------------------------------------------------- /backend-python/dep_check.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | if setuptools.__version__ >= "70.0.0": 4 | raise ImportError("setuptools>=70.0.0 is not supported") 5 | 6 | import multipart 7 | import fitz 8 | import safetensors 9 | import midi2audio 10 | import mido 11 | import lm_dataformat 12 | import ftfy 13 | import tqdm 14 | import tiktoken 15 | 16 | import torch 17 | import rwkv 18 | import langchain 19 | import numpy 20 | import tokenizers 21 | import fastapi 22 | import uvicorn 23 | import sse_starlette 24 | import pydantic 25 | import psutil 26 | -------------------------------------------------------------------------------- /backend-python/global_var.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | Args = "args" 4 | Model = "model" 5 | Model_Status = "model_status" 6 | Model_Config = "model_config" 7 | Deploy_Mode = "deploy_mode" 8 | Midi_Vocab_Config_Type = "midi_vocab_config_type" 9 | 10 | 11 | class ModelStatus(Enum): 12 | Offline = 0 13 | Loading = 2 14 | Working = 3 15 | 16 | 17 | class MidiVocabConfig(Enum): 18 | Default = auto() 19 | Piano = auto() 20 | 21 | 22 | def init(): 23 | global GLOBALS 24 | GLOBALS = {} 25 | set(Model_Status, ModelStatus.Offline) 26 | set(Deploy_Mode, False) 27 | set(Midi_Vocab_Config_Type, MidiVocabConfig.Default) 28 | 29 | 30 | def set(key, value): 31 | GLOBALS[key] = value 32 | 33 | 34 | def get(key): 35 | if key in GLOBALS: 36 | return GLOBALS[key] 37 | else: 38 | return None 39 | -------------------------------------------------------------------------------- /backend-python/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | setuptools==69.5.1 5 | rwkv==0.8.29 6 | langchain==0.0.322 7 | fastapi==0.109.1 8 | uvicorn==0.23.2 9 | sse-starlette==1.6.5 10 | pydantic==2.4.2 11 | psutil==5.9.6 12 | gputil==1.4.0 13 | tiktoken==0.5.1 14 | ftfy==6.1.1 15 | lm-dataformat==0.0.20 16 | numpy==1.24.4 17 | tokenizers==0.14.1 18 | tqdm==4.66.1 19 | midi2audio==0.1.1 20 | mido==1.3.0 21 | safetensors==0.4.0 22 | PyMuPDF==1.23.5 23 | python-multipart==0.0.7 24 | Cython==3.0.4 25 | -------------------------------------------------------------------------------- 
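For orientation, the Dockerfile earlier in this dump consumes the requirements file above directly; a minimal sketch of the same steps on a local machine, assuming a Python 3.10 interpreter is available as `python3.10` and that `frontend/dist` has already been built (the interpreter name, port, and flags below are taken from the Dockerfile, not mandated by this file):

```sh
# Install the backend dependencies (mirrors the Dockerfile's pip step)
python3.10 -m pip install -r ./backend-python/requirements.txt

# Start the API server with the WebUI enabled (mirrors the Dockerfile's CMD;
# --webui serves the prebuilt frontend/dist via webui_server.py)
python3.10 ./backend-python/main.py --port 27777 --host 0.0.0.0 --webui
```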
/backend-python/requirements_without_cyac.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | torchaudio 4 | setuptools==69.5.1 5 | rwkv==0.8.29 6 | langchain==0.0.322 7 | fastapi==0.109.1 8 | uvicorn==0.23.2 9 | sse-starlette==1.6.5 10 | pydantic==2.4.2 11 | psutil==5.9.6 12 | gputil==1.4.0 13 | tiktoken==0.5.1 14 | ftfy==6.1.1 15 | lm-dataformat==0.0.20 16 | numpy==1.24.4 17 | tokenizers==0.14.1 18 | tqdm==4.66.1 19 | midi2audio==0.1.1 20 | mido==1.3.0 21 | safetensors==0.4.0 22 | PyMuPDF==1.23.5 23 | python-multipart==0.0.7 24 | Cython==3.0.4 25 | -------------------------------------------------------------------------------- /backend-python/routes/file_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | from fastapi import ( 3 | APIRouter, 4 | HTTPException, 5 | status, 6 | Depends, 7 | File, 8 | UploadFile, 9 | ) 10 | from pydantic import BaseModel 11 | from typing import Iterator 12 | 13 | router = APIRouter() 14 | 15 | 16 | class FileToTextParams(BaseModel): 17 | file_name: str 18 | file_encoding: str = "utf-8" 19 | 20 | 21 | @router.post("/file-to-text", tags=["File Process"]) 22 | async def file_to_text( 23 | params: FileToTextParams = Depends(), file_data: UploadFile = File(...) 24 | ): 25 | from langchain.schema import Document 26 | from langchain.document_loaders.blob_loaders import Blob 27 | 28 | # from langchain 29 | def parse_text(blob: Blob) -> Iterator[Document]: 30 | yield Document(page_content=blob.as_string(), metadata={"source": blob.source}) 31 | 32 | # from langchain 33 | def parse_pdf(blob: Blob) -> Iterator[Document]: 34 | import fitz 35 | 36 | with blob.as_bytes_io() as stream: 37 | doc = fitz.Document(stream=stream) 38 | 39 | yield from [ 40 | Document( 41 | page_content=page.get_text(), 42 | metadata=dict( 43 | { 44 | "source": blob.source, 45 | "file_path": blob.source, 46 | "page": page.number, 47 | "total_pages": len(doc), 48 | }, 49 | **{ 50 | k: doc.metadata[k] 51 | for k in doc.metadata 52 | if type(doc.metadata[k]) in [str, int] 53 | }, 54 | ), 55 | ) 56 | for page in doc 57 | ] 58 | 59 | file_parsers = {".txt": parse_text, ".pdf": parse_pdf} 60 | 61 | file_name = file_data.filename or params.file_name 62 | file_ext = os.path.splitext(file_name)[-1] 63 | 64 | if file_ext not in file_parsers: 65 | raise HTTPException(status.HTTP_400_BAD_REQUEST, "file type not supported") 66 | 67 | try: 68 | pages: Iterator[Document] = file_parsers[file_ext]( 69 | Blob.from_data( 70 | await file_data.read(), 71 | encoding=params.file_encoding, 72 | path=file_name, 73 | ) 74 | ) 75 | pages = list(pages) 76 | except Exception as e: 77 | raise HTTPException(status.HTTP_400_BAD_REQUEST, f"{e}") 78 | 79 | return {"pages": pages} 80 | -------------------------------------------------------------------------------- /backend-python/rwkv_pip/cpp/librwkv.dylib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/cpp/librwkv.dylib -------------------------------------------------------------------------------- /backend-python/rwkv_pip/cpp/librwkv.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/cpp/librwkv.so 
-------------------------------------------------------------------------------- /backend-python/rwkv_pip/cpp/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Union 2 | from . import rwkv_cpp_model 3 | from . import rwkv_cpp_shared_library 4 | 5 | 6 | class RWKV: 7 | def __init__(self, model_path: str, strategy=None): 8 | self.library = rwkv_cpp_shared_library.load_rwkv_shared_library() 9 | self.model = rwkv_cpp_model.RWKVModel(self.library, model_path) 10 | self.w = {} # fake weight 11 | self.w["emb.weight"] = [0] * self.model.n_vocab 12 | self.version = ( 13 | self.model.arch_version_major + self.model.arch_version_minor / 10 14 | ) 15 | 16 | def forward(self, tokens: List[int], state: Union[Any, None] = None): 17 | return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True) 18 | -------------------------------------------------------------------------------- /backend-python/rwkv_pip/cpp/rwkv.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/cpp/rwkv.dll -------------------------------------------------------------------------------- /backend-python/rwkv_pip/cuda/rwkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | #include <c10/cuda/CUDAGuard.h> 4 | typedef at::BFloat16 bf16; 5 | typedef at::Half fp16; 6 | typedef float fp32; 7 | 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y); 10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); 11 | 12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 13 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 14 | cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>()); 15 | } 16 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 17 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 18 | cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>()); 19 | } 20 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 21 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 22 | cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>()); 23 | } 24 | 25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 26 | m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16"); 27 | m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16"); 28 | 
m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32"); 29 | } 30 | TORCH_LIBRARY(rwkv5, m) { 31 | m.def("forward_bf16", forward_bf16); 32 | m.def("forward_fp16", forward_fp16); 33 | m.def("forward_fp32", forward_fp32); 34 | } 35 | -------------------------------------------------------------------------------- /backend-python/rwkv_pip/cuda/rwkv6_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | #include <c10/cuda/CUDAGuard.h> 4 | typedef at::BFloat16 bf16; 5 | typedef at::Half fp16; 6 | typedef float fp32; 7 | 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y); 10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y); 11 | 12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 13 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 14 | cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>()); 15 | } 16 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 17 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 18 | cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>()); 19 | } 20 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 21 | const at::cuda::OptionalCUDAGuard device_guard(device_of(state)); 22 | cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>()); 23 | } 24 | 25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 26 | m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16"); 27 | m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16"); 28 | m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32"); 29 | } 30 | TORCH_LIBRARY(rwkv6, m) { 31 | m.def("forward_bf16", forward_bf16); 32 | m.def("forward_fp16", forward_fp16); 33 | m.def("forward_fp32", forward_fp32); 34 | } 35 | -------------------------------------------------------------------------------- /backend-python/rwkv_pip/cuda/rwkv7.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <assert.h> 3 | #include "ATen/ATen.h" 4 | 5 | typedef at::Half fp16; 6 | typedef at::BFloat16 bf16; 7 | typedef float fp32; 8 | 9 | template <typename F> 10 | __global__ void kernel_forward(const int B, const int T, const int C, const int H, 11 | float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b, 12 | F *__restrict__ const _y) 13 | { 14 | 
const int e = blockIdx.x / H; 15 | const int h = blockIdx.x % H; 16 | const int i = threadIdx.x; 17 | _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!! 18 | 19 | float state[_N_]; 20 | #pragma unroll 21 | for (int j = 0; j < _N_; j++) 22 | state[j] = _state[j]; 23 | 24 | __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_]; 25 | 26 | for (int _t = 0; _t < T; _t++) 27 | { 28 | const int t = e*T*C + h*_N_ + i + _t * C; 29 | __syncthreads(); 30 | r[i] = float(_r[t]); 31 | w[i] = __expf(-__expf(float(_w[t]))); 32 | k[i] = float(_k[t]); 33 | a[i] = float(_a[t]); 34 | b[i] = float(_b[t]); 35 | __syncthreads(); 36 | 37 | float sa = 0; 38 | #pragma unroll 39 | for (int j = 0; j < _N_; j++) 40 | { 41 | sa += a[j] * state[j]; 42 | } 43 | 44 | float vv = float(_v[t]); 45 | float y = 0; 46 | #pragma unroll 47 | for (int j = 0; j < _N_; j++) 48 | { 49 | float& s = state[j]; 50 | s = s * w[j] + k[j] * vv + sa * b[j]; 51 | y += s * r[j]; 52 | } 53 | _y[t] = F(y); 54 | } 55 | #pragma unroll 56 | for (int j = 0; j < _N_; j++) 57 | _state[j] = state[j]; 58 | } 59 | 60 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y) 61 | { 62 | assert(H*_N_ == C); 63 | assert(B == 1); // only for B=1 64 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y); 65 | } 66 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y) 67 | { 68 | assert(H*_N_ == C); 69 | assert(B == 1); // only for B=1 70 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y); 71 | } 72 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y) 73 | { 74 | assert(H*_N_ == C); 75 | assert(B == 1); // only for B=1 76 | kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y); 77 | } 78 | -------------------------------------------------------------------------------- /backend-python/rwkv_pip/cuda/rwkv7_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | 4 | typedef at::Half fp16; 5 | typedef at::BFloat16 bf16; 6 | typedef float fp32; 7 | 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y); 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y); 10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y); 11 | 12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) { 13 | cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>()); 14 | } 15 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) { 16 | cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), 
y.data_ptr<fp16>()); 17 | } 18 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) { 19 | cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>()); 20 | } 21 | 22 | TORCH_LIBRARY(wkv7s, m) { 23 | m.def("forward_bf16", forward_bf16); 24 | m.def("forward_fp16", forward_fp16); 25 | m.def("forward_fp32", forward_fp32); 26 | } 27 | -------------------------------------------------------------------------------- /backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/rwkv5.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/rwkv5.pyd -------------------------------------------------------------------------------- /backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/rwkv6.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/rwkv6.pyd -------------------------------------------------------------------------------- /backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/wkv7s.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/wkv7s.pyd -------------------------------------------------------------------------------- /backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/wkv_cuda.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/wkv_cuda.pyd -------------------------------------------------------------------------------- /backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/rwkv5.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/rwkv5.pyd -------------------------------------------------------------------------------- /backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/rwkv6.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/rwkv6.pyd -------------------------------------------------------------------------------- /backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/wkv7s.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/wkv7s.pyd -------------------------------------------------------------------------------- /backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/wkv_cuda.pyd: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/wkv_cuda.pyd -------------------------------------------------------------------------------- /backend-python/rwkv_pip/webgpu/model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Union 2 | 3 | try: 4 | import web_rwkv_py as wrp 5 | except ModuleNotFoundError: 6 | try: 7 | from . import web_rwkv_py as wrp 8 | except ImportError: 9 | raise ModuleNotFoundError( 10 | "web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py" 11 | ) 12 | 13 | 14 | class RWKV: 15 | def __init__(self, model_path: str, strategy: str = None): 16 | layer = ( 17 | int(s.lstrip("layer")) 18 | for s in strategy.split() 19 | for s in s.split(",") 20 | if s.startswith("layer") 21 | ) 22 | 23 | chunk_size = ( 24 | int(s.lstrip("chunk")) 25 | for s in strategy.split() 26 | for s in s.split(",") 27 | if s.startswith("chunk") 28 | ) 29 | self.token_chunk_size = next(chunk_size, 32) 30 | 31 | args = { 32 | "path": model_path, 33 | "quant": next(layer, 31) if "i8" in strategy else 0, 34 | "quant_nf4": next(layer, 26) if "i4" in strategy else 0, 35 | } 36 | self.model = wrp.Model(**args) 37 | self.info = self.model.info() 38 | self.w = {} # fake weight 39 | self.w["emb.weight"] = [0] * self.info.num_vocab 40 | self.version = str(self.info.version).lower() 41 | self.version = float(self.version.lower().replace("v", "")) 42 | 43 | def forward(self, tokens: List[int], state: Union[Any, None] = None): 44 | if state is None: 45 | self.model.clear_state() 46 | elif type(state).__name__ == "State_Cpu": 47 | self.model.load_state(state) 48 | logits = self.model.run(tokens, self.token_chunk_size) 49 | ret_state = "State_Gpu" 50 | return logits, ret_state 51 | -------------------------------------------------------------------------------- /backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd -------------------------------------------------------------------------------- /backend-python/tests/postprocess_response.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def postprocess_response(s): 5 | REGEX_BLOCKS = r"([\w]+)[\s]*```[\w]*(.*?)```" 6 | REGEX_ARGS = r'"([^"]+)"\s*=\s*"([^"]+)"' 7 | 8 | name = re.search(REGEX_BLOCKS, s, re.DOTALL).group(1) 9 | function = re.search(REGEX_BLOCKS, s, re.DOTALL).group(2).strip() 10 | arguments = dict(re.findall(REGEX_ARGS, function)) 11 | 12 | print(f"Name:\n{name}") 13 | print(f"Function:\n{function}") 14 | print(f"arguments:\n{arguments}") 15 | print() 16 | 17 | return 18 | 19 | 20 | def postprocess_response_reserved(s): 21 | REGEX_BLOCKS = r"```[\w]*(.*?)```" 22 | REGEX_FUNCTIONS = r"(\w+)*\(" 23 | REGEX_ARGS = r'"([^"]+)"\s*=\s*"([^"]+)"' 24 | 25 | blocks = re.findall(REGEX_BLOCKS, s, re.DOTALL) 26 | print(f"Blocks:\n{blocks}") 27 | for block in blocks: 28 | functions = block.strip().split("\n") 29 | print(f"Functions:\n{functions}") 30 | print() 31 | for function in functions: 32 | name = re.search(REGEX_FUNCTIONS, function).group(1) 33 | arguments = f"{dict(re.findall(REGEX_ARGS, function))}" 34 | 35 | print(function) 36 | print(name) 37 | 
print(arguments) 38 | print() 39 | 40 | return 41 | 42 | 43 | if __name__ == "__main__": 44 | str = """ 45 | some texts 46 | some texts 47 | some texts 48 | some texts 49 | 50 | ```python\n 51 | get_current_wether("location"= "Tokyo", "unit" ="None")\n 52 | ``` 53 | 54 | some texts 55 | some texts 56 | some texts 57 | some texts 58 | """ 59 | postprocess_response(str) 60 | 61 | str = """ get_exchange_rate 62 | ```python 63 | tool_call("base_currency"= "func_as_param('Hello World!')", "target_currency"= "CNY") 64 | ```""" 65 | postprocess_response(str) 66 | 67 | str = """\ 68 | get_current_weather 69 | ```python\n 70 | tool_call("location"= "Tokyo", "unit"= "None")\n 71 | ```""" 72 | postprocess_response(str) 73 | -------------------------------------------------------------------------------- /backend-python/utils/log.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Any, Union 4 | from fastapi import Request 5 | from pydantic import BaseModel 6 | from enum import Enum 7 | 8 | 9 | logger = logging.getLogger() 10 | logger.setLevel(logging.INFO) 11 | formatter = logging.Formatter("%(asctime)s - %(levelname)s\n%(message)s") 12 | fh = logging.handlers.RotatingFileHandler( 13 | "api.log", mode="a", maxBytes=3 * 1024 * 1024, backupCount=3, encoding="utf-8" 14 | ) 15 | fh.setFormatter(formatter) 16 | logger.addHandler(fh) 17 | 18 | 19 | class ClsEncoder(json.JSONEncoder): 20 | def default(self, obj): 21 | if isinstance(obj, BaseModel): 22 | return obj.dict() 23 | if isinstance(obj, Enum): 24 | return obj.value 25 | return super().default(obj) 26 | 27 | 28 | def quick_log(request: Union[Request, None], body: Any, response: str): 29 | try: 30 | logger.info( 31 | f"Client: {request.client if request else ''}\nUrl: {request.url if request else ''}\n" 32 | + ( 33 | f"Body: {json.dumps(body.__dict__, ensure_ascii=False, cls=ClsEncoder)}\n" 34 | if body 35 | else "" 36 | ) 37 | + (f"Data:\n{response}\n" if response else "") 38 | ) 39 | except Exception as e: 40 | logger.info(f"Error quick_log request:\n{e}") 41 | 42 | 43 | async def log_middleware(request: Request): 44 | try: 45 | logger.info( 46 | f"Client: {request.client}\nUrl: {request.url}\nBody: {await request.body()}\n" 47 | ) 48 | except Exception as e: 49 | logger.info(f"Error log_middleware request:\n{e}") 50 | -------------------------------------------------------------------------------- /backend-python/utils/midi_filter_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "deduplicate_md5": true, 3 | "piece_split_delay": 10000, 4 | "min_piece_length": 0 5 | } -------------------------------------------------------------------------------- /backend-python/utils/ngrok.py: -------------------------------------------------------------------------------- 1 | import os 2 | import global_var 3 | 4 | 5 | def ngrok_connect(): 6 | from pyngrok import ngrok, conf 7 | 8 | conf.set_default( 9 | conf.PyngrokConfig(ngrok_path="./ngrok.exe" if os.name == "nt" else "./ngrok") 10 | ) 11 | ngrok.set_auth_token(os.environ["ngrok_token"]) 12 | http_tunnel = ngrok.connect(global_var.get(global_var.Args).port) 13 | print(f"ngrok url: {http_tunnel.public_url}") 14 | -------------------------------------------------------------------------------- /backend-python/utils/torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sysconfig 3 | 4 | 5 | def set_torch(): 6 
| torch_path = os.path.join(sysconfig.get_paths()["purelib"], f"torch{os.sep}lib") 7 | paths = os.environ.get("PATH", "") 8 | if os.path.exists(torch_path): 9 | print(f"torch found: {torch_path}") 10 | if torch_path in paths: 11 | print("torch already set") 12 | else: 13 | os.environ["PATH"] = paths + os.pathsep + torch_path + os.pathsep 14 | print("torch set") 15 | # print("run:") 16 | # print(f"set Path={paths + os.pathsep + torch_path + os.pathsep}") 17 | else: 18 | print("torch not found") 19 | 20 | 21 | def torch_gc(): 22 | try: 23 | import torch 24 | 25 | if torch.cuda.is_available(): 26 | with torch.cuda.device(0): 27 | torch.cuda.empty_cache() 28 | torch.cuda.ipc_collect() 29 | except: 30 | pass # prevent 'torch' has no attribute 'cuda' error, so user can use CPU or WebGPU 31 | -------------------------------------------------------------------------------- /backend-python/webui_server.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI 2 | from fastapi.middleware.gzip import GZipMiddleware 3 | from fastapi.staticfiles import StaticFiles 4 | import uvicorn 5 | 6 | webui_server = FastAPI() 7 | 8 | webui_server.add_middleware(GZipMiddleware, minimum_size=1000) 9 | webui_server.mount( 10 | "/", StaticFiles(directory="frontend/dist", html=True), name="static" 11 | ) 12 | 13 | if __name__ == "__main__": 14 | uvicorn.run("webui_server:webui_server") 15 | -------------------------------------------------------------------------------- /build/README.md: -------------------------------------------------------------------------------- 1 | # Build Directory 2 | 3 | The build directory is used to house all the build files and assets for your application. 4 | 5 | The structure is: 6 | 7 | * bin - Output directory 8 | * darwin - macOS specific files 9 | * windows - Windows specific files 10 | 11 | ## Mac 12 | 13 | The `darwin` directory holds files specific to Mac builds. 14 | These may be customised and used as part of the build. To return these files to the default state, simply delete them 15 | and 16 | build with `wails build`. 17 | 18 | The directory contains the following files: 19 | 20 | - `Info.plist` - the main plist file used for Mac builds. It is used when building using `wails build`. 21 | - `Info.dev.plist` - same as the main plist file but used when building using `wails dev`. 22 | 23 | ## Windows 24 | 25 | The `windows` directory contains the manifest and rc files used when building with `wails build`. 26 | These may be customised for your application. To return these files to the default state, simply delete them and 27 | build with `wails build`. 28 | 29 | - `icon.ico` - The icon used for the application. This is used when building using `wails build`. If you wish to 30 | use a different icon, simply replace this file with your own. If it is missing, a new `icon.ico` file 31 | will be created using the `appicon.png` file in the build directory. 32 | - `installer/*` - The files used to create the Windows installer. These are used when building using `wails build`. 33 | - `info.json` - Application details used for Windows builds. The data here will be used by the Windows installer, 34 | as well as the application itself (right click the exe -> properties -> details) 35 | - `wails.exe.manifest` - The main application manifest file. 
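For context, the per-platform assets described above are picked up by the `wails build` targets in the repository Makefile; a minimal sketch of those invocations (copied from the Makefile, assuming `wails` and, for the compressed targets, `upx` are on the PATH):

```sh
# Windows: bundles build/windows/* (icon.ico, wails.exe.manifest, NSIS installer scripts)
wails build -ldflags '-s -w -extldflags "-static"' -platform windows/amd64 -devtools -upx -upxflags "-9 --lzma" -nsis

# macOS: universal binary using build/darwin/Info.plist
wails build -ldflags '-s -w' -platform darwin/universal -devtools

# Linux: uses the build/linux assets, UPX-compressed
wails build -ldflags '-s -w' -platform linux/amd64 -devtools -upx -upxflags "-9 --lzma"
```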
-------------------------------------------------------------------------------- /build/appicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/build/appicon.png -------------------------------------------------------------------------------- /build/darwin/Info.dev.plist: -------------------------------------------------------------------------------- 1 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd"> 2 | <plist version="1.0"> 3 | <dict> 4 | <key>CFBundlePackageType</key> 5 | <string>APPL</string> 6 | <key>CFBundleName</key> 7 | <string>{{.Info.ProductName}}</string> 8 | <key>CFBundleExecutable</key> 9 | <string>{{.Name}}</string> 10 | <key>CFBundleIdentifier</key> 11 | <string>dev.josStorer.RWKV-Runner</string> 12 | <key>CFBundleVersion</key> 13 | <string>{{.Info.ProductVersion}}</string> 14 | <key>CFBundleGetInfoString</key> 15 | <string>{{.Info.Comments}}</string> 16 | <key>CFBundleShortVersionString</key> 17 | <string>{{.Info.ProductVersion}}</string> 18 | <key>CFBundleIconFile</key> 19 | <string>iconfile</string> 20 | <key>LSMinimumSystemVersion</key> 21 | <string>10.13.0</string> 22 | <key>NSHighResolutionCapable</key> 23 | <string>true</string> 24 | <key>NSHumanReadableCopyright</key> 25 | <string>{{.Info.Copyright}}</string> 26 | <key>NSAppTransportSecurity</key> 27 | <dict> 28 | <key>NSAllowsLocalNetworking</key> 29 | <true/> 30 | </dict> 31 | </dict> 32 | </plist> -------------------------------------------------------------------------------- /build/darwin/Info.plist: -------------------------------------------------------------------------------- 1 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd"> 2 | <plist version="1.0"> 3 | <dict> 4 | <key>CFBundlePackageType</key> 5 | <string>APPL</string> 6 | <key>CFBundleName</key> 7 | <string>{{.Info.ProductName}}</string> 8 | <key>CFBundleExecutable</key> 9 | <string>{{.Name}}</string> 10 | <key>CFBundleIdentifier</key> 11 | <string>dev.josStorer.RWKV-Runner</string> 12 | <key>CFBundleVersion</key> 13 | <string>{{.Info.ProductVersion}}</string> 14 | <key>CFBundleGetInfoString</key> 15 | <string>{{.Info.Comments}}</string> 16 | <key>CFBundleShortVersionString</key> 17 | <string>{{.Info.ProductVersion}}</string> 18 | <key>CFBundleIconFile</key> 19 | <string>iconfile</string> 20 | <key>LSMinimumSystemVersion</key> 21 | <string>10.13.0</string> 22 | <key>NSHighResolutionCapable</key> 23 | <string>true</string> 24 | <key>NSHumanReadableCopyright</key> 25 | <string>{{.Info.Copyright}}</string> 26 | </dict> 27 | </plist> -------------------------------------------------------------------------------- /build/darwin/Readme_Install.txt: -------------------------------------------------------------------------------- 1 | Client Download URL: 2 | 客户端下载地址: 3 | クライアントのダウンロードURL: 4 | https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner_macos_universal.zip 5 | 6 | For Mac and Linux users, please manually install Python 3.10 (usually the latest systems come with it built-in). You can specify the Python interpreter to use in Settings. (which python3) 7 | 对于Mac和Linux用户,请手动安装 Python3.10 (通常最新的系统已经内置了). 你可以在设置中指定使用的Python解释器. 
(which python3) 8 | MacおよびLinuxのユーザーの方は、Python3.10を手動でインストールしてください(通常、最新のシステムには既に組み込まれています)。 設定メニューで使用するPythonインタプリタを指定することができます。 (which python3) 9 | 10 | Please execute this program in an empty directory. All related dependencies will be placed in this directory. 11 | 请将本程序放在一个空目录内执行, 所有相关依赖均会放置于此目录. 12 | このプログラムを空のディレクトリで実行してください。関連するすべての依存関係は、このディレクトリに配置されます。 13 | 14 | Please execute the following command in the terminal to remove the permission restrictions of this app, and then this program can work properly: 15 | 请在终端执行以下命令解除本app的权限限制, 然后本程序才可以正常工作: 16 | このアプリの権限制限を解除するために、ターミナルで以下のコマンドを実行してください。その後、このプログラムは正常に動作するようになります: 17 | 18 | sudo xattr -r -d com.apple.quarantine ./RWKV-Runner.app 19 | -------------------------------------------------------------------------------- /build/darwin/entitlements.plist: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8"?> 2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd"> 3 | <plist version="1.0"> 4 | <dict> 5 | <key>com.apple.security.app-sandbox</key> 6 | <false/> 7 | <key>com.apple.security.network.client</key> 8 | <true/> 9 | <key>com.apple.security.network.server</key> 10 | <true/> 11 | <key>com.apple.security.files.user-selected.read-write</key> 12 | <true/> 13 | <key>com.apple.security.files.downloads.read-write</key> 14 | <true/> 15 | </dict> 16 | </plist> -------------------------------------------------------------------------------- /build/darwin/gon-sign.json: -------------------------------------------------------------------------------- 1 | { 2 | "source": [ 3 | "./build/bin/RWKV-Runner_darwin.app" 4 | ], 5 | "bundle_id": "dev.josStorer.RWKV-Runner", 6 | "apple_id": { 7 | "username": "joshua1466587594@outlook.com", 8 | "password": "" 9 | }, 10 | "sign": { 11 | "application_identity": "D00A983569B4EAA2A008B963254F385F42A493FD", 12 | "entitlements_file": "./build/darwin/entitlements.plist" 13 | }, 14 | "zip": { 15 | "output_path": "./build/bin/RWKV-Runner_darwin.archive.zip" 16 | } 17 | } -------------------------------------------------------------------------------- /build/linux/Readme_Install.txt: -------------------------------------------------------------------------------- 1 | Client Download URL: 2 | 客户端下载地址: 3 | クライアントのダウンロードURL: 4 | https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner_linux_x64 5 | 6 | For Mac and Linux users, please manually install Python 3.10 (usually the latest systems come with it built-in). You can specify the Python interpreter to use in Settings. 7 | 对于Mac和Linux用户,请手动安装 Python3.10 (通常最新的系统已经内置了). 你可以在设置中指定使用的Python解释器. 8 | MacおよびLinuxのユーザーの方は、Python3.10を手動でインストールしてください(通常、最新のシステムには既に組み込まれています)。 設定メニューで使用するPythonインタプリタを指定することができます。 9 | 10 | Please execute this program in an empty directory. All related dependencies will be placed in this directory. 11 | 请将本程序放在一个空目录内执行, 所有相关依赖均会放置于此目录. 12 | このプログラムを空のディレクトリで実行してください。関連するすべての依存関係は、このディレクトリに配置されます。 13 | 14 | On Linux system, this program cannot invoke the terminal for automatic dependency installation. 
You must manually execute the following commands for installation so that it can be used normally: 15 | 在Linux系统下, 本程序无法调用终端自动安装依赖, 你必须手动执行以下命令进行安装, 之后方可正常使用: 16 | Linuxシステムでは、このプログラムはターミナルを自動的に呼び出して依存関係をインストールすることができません。以下のコマンドを手動で実行する必要があります。それが完了した後に、正常に使用することができます: 17 | 18 | sudo apt install python3-dev 19 | chmod +x ./RWKV-Runner 20 | ./RWKV-Runner 21 | cd backend-python 22 | pip3 install -r requirements.txt # or pip3 install -r requirements_without_cyac.txt 23 | 24 | # See More: https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples 25 | -------------------------------------------------------------------------------- /build/windows/Readme_Install.txt: -------------------------------------------------------------------------------- 1 | Client Download URL: 2 | 客户端下载地址: 3 | クライアントのダウンロードURL: 4 | https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner-amd64-installer.exe 5 | https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner_windows_x64.exe 6 | 7 | Please execute this program in an empty directory. All related dependencies will be placed in this directory. 8 | 请将本程序放在一个空目录内执行, 所有相关依赖均会放置于此目录. 9 | このプログラムを空のディレクトリで実行してください。関連するすべての依存関係は、このディレクトリに配置されます。 10 | -------------------------------------------------------------------------------- /build/windows/WELCOMEFINISHPAGE.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/build/windows/WELCOMEFINISHPAGE.bmp -------------------------------------------------------------------------------- /build/windows/icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/build/windows/icon.ico -------------------------------------------------------------------------------- /build/windows/info.json: -------------------------------------------------------------------------------- 1 | { 2 | "fixed": { 3 | "file_version": "{{.Info.ProductVersion}}" 4 | }, 5 | "info": { 6 | "0000": { 7 | "ProductVersion": "{{.Info.ProductVersion}}", 8 | "CompanyName": "{{.Info.CompanyName}}", 9 | "FileDescription": "{{.Info.ProductName}}", 10 | "LegalCopyright": "{{.Info.Copyright}}", 11 | "ProductName": "{{.Info.ProductName}}", 12 | "Comments": "{{.Info.Comments}}" 13 | } 14 | } 15 | } -------------------------------------------------------------------------------- /build/windows/wails.exe.manifest: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8" standalone="yes"?> 2 | <assembly manifestVersion="1.0" xmlns="urn:schemas-microsoft-com:asm.v1" xmlns:asmv3="urn:schemas-microsoft-com:asm.v3"> 3 | <assemblyIdentity type="win32" name="com.wails.{{.Name}}" version="{{.Info.ProductVersion}}.0" processorArchitecture="*"/> 4 | <dependency> 5 | <dependentAssembly> 6 | <assemblyIdentity type="win32" name="Microsoft.Windows.Common-Controls" version="6.0.0.0" processorArchitecture="*" publicKeyToken="6595b64144ccf1df" language="*"/> 7 | </dependentAssembly> 8 | </dependency> 9 | <asmv3:application> 10 | <asmv3:windowsSettings> 11 | <dpiAware xmlns="http://schemas.microsoft.com/SMI/2005/WindowsSettings">true/pm</dpiAware> <!-- fallback for Windows 7 and 8 --> 12 | <dpiAwareness xmlns="http://schemas.microsoft.com/SMI/2016/WindowsSettings">permonitorv2,permonitor</dpiAwareness> <!-- falls back to 
per-monitor if per-monitor v2 is not supported --> 13 | </asmv3:windowsSettings> 14 | </asmv3:application> 15 | </assembly> -------------------------------------------------------------------------------- /components/gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/components/gitkeep -------------------------------------------------------------------------------- /deploy-examples/ChatGPT-Next-Web/setup.bat: -------------------------------------------------------------------------------- 1 | : install git python3.10 yarn by yourself 2 | : change model and strategy according to your hardware 3 | 4 | mkdir RWKV-Next-Web 5 | cd RWKV-Next-Web 6 | 7 | git clone https://github.com/josStorer/RWKV-Runner --depth=1 8 | python -m pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --index-url https://download.pytorch.org/whl/cu117 9 | python -m pip install -r RWKV-Runner/backend-python/requirements.txt 10 | start python ./RWKV-Runner/backend-python/main.py 11 | 12 | powershell -Command "(Test-Path ./RWKV-Runner/models) -or (mkdir RWKV-Runner/models)" 13 | powershell -Command "Import-Module BitsTransfer" 14 | powershell -Command "(Test-Path ./RWKV-Runner/models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth) -or (Start-BitsTransfer https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth ./RWKV-Runner/models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth)" 15 | powershell -Command "Invoke-WebRequest http://127.0.0.1:8000/switch-model -Method POST -ContentType 'application/json' -Body '{\"model\":\"./RWKV-Runner/models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth\",\"strategy\":\"cuda fp32 *20+\"}'" 16 | 17 | git clone https://github.com/Yidadaa/ChatGPT-Next-Web --depth=1 18 | cd ChatGPT-Next-Web 19 | call yarn install 20 | call yarn build 21 | set PROXY_URL="" 22 | set BASE_URL=http://127.0.0.1:8000 23 | start "C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe" "http://127.0.0.1:3000" 24 | yarn start 25 | -------------------------------------------------------------------------------- /deploy-examples/ChatGPT-Next-Web/setup.sh: -------------------------------------------------------------------------------- 1 | # install git python3.10 yarn by yourself 2 | # change model and strategy according to your hardware 3 | 4 | sudo apt install python3-dev 5 | 6 | mkdir RWKV-Next-Web 7 | cd RWKV-Next-Web 8 | 9 | git clone https://github.com/josStorer/RWKV-Runner --depth=1 10 | python3 -m pip install torch torchvision torchaudio 11 | python3 -m pip install -r RWKV-Runner/backend-python/requirements.txt 12 | python3 ./RWKV-Runner/backend-python/main.py > log.txt & # this is only an example, you should use screen or other tools to run it in background 13 | 14 | if [ ! 
-d RWKV-Runner/models ]; then 15 | mkdir RWKV-Runner/models 16 | fi 17 | wget -N https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth -P RWKV-Runner/models/ 18 | 19 | git clone https://github.com/Yidadaa/ChatGPT-Next-Web --depth=1 20 | cd ChatGPT-Next-Web 21 | yarn install 22 | yarn build 23 | export PROXY_URL="" 24 | export BASE_URL=http://127.0.0.1:8000 25 | yarn start & # this is only an example, you should use screen or other tools to run it in background 26 | 27 | curl http://127.0.0.1:8000/switch-model -X POST -H "Content-Type: application/json" -d '{"model":"./RWKV-Runner/models/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth","strategy":"cpu fp32"}' 28 | -------------------------------------------------------------------------------- /deploy-examples/RWKV-Runner-WebUI/setup.bat: -------------------------------------------------------------------------------- 1 | : install git python3.10 npm by yourself 2 | : change model and strategy according to your hardware 3 | 4 | git clone https://github.com/josStorer/RWKV-Runner --depth=1 5 | python -m pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --index-url https://download.pytorch.org/whl/cu117 6 | python -m pip install -r RWKV-Runner/backend-python/requirements.txt 7 | cd RWKV-Runner/frontend 8 | call npm ci 9 | call npm run build 10 | cd .. 11 | 12 | : optional: set ngrok_token=YOUR_NGROK_TOKEN 13 | start python ./backend-python/main.py --webui 14 | start "C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe" "http://127.0.0.1:8000" 15 | 16 | powershell -Command "(Test-Path ./models) -or (mkdir models)" 17 | powershell -Command "Import-Module BitsTransfer" 18 | powershell -Command "(Test-Path ./models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth) -or (Start-BitsTransfer https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth ./models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth)" 19 | powershell -Command "Invoke-WebRequest http://127.0.0.1:8000/switch-model -Method POST -ContentType 'application/json' -Body '{\"model\":\"./models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth\",\"strategy\":\"cuda fp32 *20+\",\"deploy\":\"true\"}'" 20 | -------------------------------------------------------------------------------- /deploy-examples/RWKV-Runner-WebUI/setup.sh: -------------------------------------------------------------------------------- 1 | # install git python3.10 npm by yourself 2 | # change model and strategy according to your hardware 3 | 4 | sudo apt install python3-dev 5 | 6 | git clone https://github.com/josStorer/RWKV-Runner --depth=1 7 | python3 -m pip install torch torchvision torchaudio 8 | python3 -m pip install -r RWKV-Runner/backend-python/requirements.txt 9 | cd RWKV-Runner/frontend 10 | npm ci 11 | npm run build 12 | cd .. 13 | 14 | # optional: export ngrok_token=YOUR_NGROK_TOKEN 15 | python3 ./backend-python/main.py --webui > log.txt & # this is only an example, you should use screen or other tools to run it in background 16 | 17 | if [ ! 
-d models ]; then 18 | mkdir models 19 | fi 20 | wget -N https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth -P models/ 21 | 22 | curl http://127.0.0.1:8000/switch-model -X POST -H "Content-Type: application/json" -d '{"model":"./models/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth","strategy":"cpu fp32","deploy":"true"}' 23 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | rmkv_runner: 3 | image: rwkv-runner:latest 4 | build: . 5 | # Append "--rwkv.cpp" parameter to use rwkv.cpp 6 | # command: python3.10 ./backend-python/main.py --port 27777 --host 0.0.0.0 --webui --rwkv.cpp 7 | volumes: 8 | - /mnt:/mnt 9 | ports: 10 | - "27777:27777" 11 | # Comment the following lines if use rwkv.cpp 12 | deploy: 13 | resources: 14 | reservations: 15 | devices: 16 | - driver: nvidia 17 | count: 1 18 | capabilities: [gpu] 19 | -------------------------------------------------------------------------------- /exportModelsJson.js: -------------------------------------------------------------------------------- 1 | // Execute this script on the Hugging Face files list page to export JSON data. Don't forget to click "Load more files". 2 | // Run console.log(JSON.stringify(modelsJson, null, 2)) to output the JSON to the console. 3 | 4 | let modelsJson = [] 5 | 6 | function extractValue(text, prefix) { 7 | let ret 8 | text.split('\n').forEach(line => { 9 | if (!ret && line.startsWith(prefix)) 10 | ret = line.replace(prefix, '').trim() 11 | }) 12 | return ret || '' 13 | } 14 | 15 | document.querySelectorAll('.grid.h-10.grid-cols-12.place-content-center.gap-x-3.border-t.px-3.dark\\:border-gray-800').forEach(async e => { 16 | let data = {} 17 | data.name = e.children[0].children[0].textContent.trim() 18 | 19 | if (!data.name.endsWith('.bin') && !data.name.endsWith('.pth') && !data.name.endsWith('.gguf')) 20 | return 21 | 22 | data.desc = { en: '', zh: '', ja: '' } 23 | const rawText = await (await fetch(e.children[1].href.replace('/resolve/', '/raw/'))).text() 24 | 25 | data.size = parseInt(extractValue(rawText, 'size')) 26 | data.SHA256 = extractValue(rawText, 'oid sha256:') 27 | data.lastUpdated = e.children[3].children[0].getAttribute('datetime') 28 | data.url = e.children[1].href.replace('/resolve/', '/blob/').replace('?download=true', '') 29 | data.downloadUrl = e.children[1].href.replace('?download=true', '') 30 | data.tags = [] 31 | 32 | modelsJson.push(data) 33 | }) 34 | 35 | setTimeout(() => { 36 | console.log(JSON.stringify(modelsJson, null, 2)) 37 | }, 500) 38 | -------------------------------------------------------------------------------- /finetune/data/sample.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "The following is an epic science fiction masterpiece that is immortalized, with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\nAs I sit down to write here amidst the shadows of vine-leaves under the blue sky of southern Italy, it comes to me with a certain quality of astonishment that my participation in these amazing adventures of Mr. Cavor was, after all, the outcome of the purest accident. It might have been any one. I fell into these things at a time when I thought myself removed from the slightest possibility of disturbing experiences. 
I had gone to Lympne because I had imagined it the most uneventful place in the world. “Here, at any rate,” said I, “I shall find peace and a chance to work!”"} 2 | {"text": "Translate the following into Chinese.\n\nEnglish: What rooms do you have available?\nChinese: 你们有哪些房间可以提供"} 3 | {"text": "User: Hello.\n\nAssistant: I'm here, meow~.\n\nUser: Can you tell some jokes?\n\nAssistant: Of course, master. What kind of jokes would you like to hear?"} 4 | {"text": "Instruction: Write a story using the following information\n\nInput: A man named Alex chops a tree down\n\nResponse: Once upon a time, there was a man named Alex who lived in the heart of the forest. He had always been fascinated by trees and spent most of his days exploring the forest and learning about its many wonders. One day, while wandering through the woods, he stumbled upon an old oak tree that stood tall and proud in the middle of a clearing."} 5 | {"text": "def get_args(args: Union[Sequence[str], None] = None):\n parser = argparse.ArgumentParser()\n group = parser.add_argument_group(title=\"server arguments\")\n group.add_argument(\n \"--port\",\n type=int,\n default=8000,\n help=\"port to run the server on (default: 8000)\",\n )\n group.add_argument(\n \"--host\",\n type=str,\n default=\"127.0.0.1\",\n help=\"host to run the server on (default: 127.0.0.1)\",\n )"} -------------------------------------------------------------------------------- /finetune/get_layer_and_embd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import time 4 | import os 5 | import threading 6 | import gc 7 | 8 | 9 | def file_cleaner(file): 10 | last_pos = 0 11 | 12 | def cleaner(): 13 | nonlocal last_pos 14 | while True: 15 | time.sleep(0.1) 16 | pos = file.tell() 17 | if pos > last_pos: 18 | os.posix_fadvise( 19 | file.fileno(), last_pos, pos - last_pos, os.POSIX_FADV_DONTNEED 20 | ) 21 | last_pos = pos 22 | 23 | return cleaner 24 | 25 | 26 | expected_max_version = float(sys.argv[2]) if len(sys.argv) > 2 else 100 27 | model_file = open(sys.argv[1], "rb") 28 | cleaner = file_cleaner(model_file) 29 | cleaner_thread = threading.Thread(target=cleaner, daemon=True) 30 | cleaner_thread.start() 31 | 32 | w = torch.load(model_file, map_location="cpu") 33 | gc.collect() 34 | 35 | vocab_size = w["emb.weight"].shape[0] 36 | n_embd = w["emb.weight"].shape[1] 37 | n_layer = 0 38 | keys = list(w.keys()) 39 | version = 4 40 | for x in keys: 41 | layer_id = int(x.split(".")[1]) if ("blocks." 
in x) else 0 42 | n_layer = max(n_layer, layer_id + 1) 43 | 44 | if "ln_x" in x: 45 | version = max(5, version) 46 | if "gate.weight" in x: 47 | version = max(5.1, version) 48 | if int(version) == 5 and "att.time_decay" in x: 49 | if len(w[x].shape) > 1: 50 | if w[x].shape[1] > 1: 51 | version = max(5.2, version) 52 | if "time_maa" in x: 53 | version = max(6, version) 54 | 55 | params = f"--vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}" 56 | 57 | if version <= expected_max_version: 58 | if version == 6: 59 | params += ' --my_testing "x060"' 60 | print( 61 | f"v{int(version)}/train.py {params}", 62 | end="", 63 | ) 64 | else: 65 | raise Exception(f"RWKV{version} is not supported") 66 | -------------------------------------------------------------------------------- /finetune/install-wsl-dep-and-train.sh: -------------------------------------------------------------------------------- 1 | echo $@ 2 | 3 | if [[ ${cnMirror} == 1 ]]; then 4 | export PIP_INDEX_URL="https://mirrors.aliyun.com/pypi/simple" 5 | if grep -q "mirrors.aliyun.com" /etc/apt/sources.list; then 6 | echo "apt cnMirror already set" 7 | else 8 | sudo sed -i 's/http:\/\/archive.ubuntu.com\/ubuntu\//http:\/\/mirrors.aliyun.com\/ubuntu\//g' /etc/apt/sources.list 9 | sudo apt update 10 | fi 11 | fi 12 | 13 | if dpkg -s "gcc" >/dev/null 2>&1; then 14 | echo "gcc installed" 15 | else 16 | sudo apt -y install gcc 17 | fi 18 | 19 | if dpkg -s "python3-pip" >/dev/null 2>&1; then 20 | echo "pip installed" 21 | else 22 | sudo apt -y install python3-pip 23 | fi 24 | 25 | if dpkg -s "python3-dev" >/dev/null 2>&1; then 26 | echo "python3-dev installed" 27 | else 28 | sudo apt -y install python3-dev 29 | fi 30 | 31 | if dpkg -s "ninja-build" >/dev/null 2>&1; then 32 | echo "ninja installed" 33 | else 34 | sudo apt -y install ninja-build 35 | fi 36 | 37 | if dpkg -s "cuda" >/dev/null 2>&1 && dpkg -s "cuda" | grep Version | awk '{print $2}' | grep -q "12"; then 38 | echo "cuda 12 installed" 39 | else 40 | wget -N https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin 41 | sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600 42 | wget -N https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda-repo-wsl-ubuntu-12-2-local_12.2.0-1_amd64.deb 43 | sudo dpkg -i cuda-repo-wsl-ubuntu-12-2-local_12.2.0-1_amd64.deb 44 | sudo cp /var/cuda-repo-wsl-ubuntu-12-2-local/cuda-*-keyring.gpg /usr/share/keyrings/ 45 | sudo apt-get update 46 | sudo apt-get -y install cuda 47 | fi 48 | 49 | if python3 -c "import pkg_resources; pkg_resources.require(open('./finetune/requirements.txt',mode='r'))" &>/dev/null; then 50 | echo "requirements satisfied" 51 | else 52 | python3 -m pip install -r ./finetune/requirements.txt 53 | fi 54 | 55 | echo "loading $loadModel" 56 | modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 6.0) 57 | echo $modelInfo 58 | if [[ $modelInfo =~ "--n_layer" ]]; then 59 | sudo rm -rf /root/.cache/torch_extensions 60 | python3 ./finetune/lora/$modelInfo $@ --proj_dir lora-models --data_type binidx --lora \ 61 | --lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu --ds_bucket_mb 2 62 | else 63 | echo "modelInfo is invalid" 64 | exit 1 65 | fi 66 | -------------------------------------------------------------------------------- /finetune/lora/merge_lora.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | import sys 4 | from 
typing import Dict 5 | import typing 6 | import torch 7 | 8 | try: 9 | if "-h" in sys.argv or "--help" in sys.argv: 10 | print( 11 | f"Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>" 12 | ) 13 | 14 | if sys.argv[1] == "--use-gpu": 15 | device = "cuda" 16 | lora_alpha, base_model, lora, output = ( 17 | float(sys.argv[2]), 18 | sys.argv[3], 19 | sys.argv[4], 20 | sys.argv[5], 21 | ) 22 | else: 23 | device = "cpu" 24 | lora_alpha, base_model, lora, output = ( 25 | float(sys.argv[1]), 26 | sys.argv[2], 27 | sys.argv[3], 28 | sys.argv[4], 29 | ) 30 | 31 | with torch.no_grad(): 32 | w: Dict[str, torch.Tensor] = torch.load(base_model, map_location="cpu") 33 | # merge LoRA-only slim checkpoint into the main weights 34 | w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location="cpu") 35 | for k in w_lora.keys(): 36 | w[k] = w_lora[k] 37 | output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() 38 | # merge LoRA weights 39 | keys = list(w.keys()) 40 | for k in keys: 41 | if k.endswith(".weight"): 42 | prefix = k[: -len(".weight")] 43 | lora_A = prefix + ".lora_A" 44 | lora_B = prefix + ".lora_B" 45 | if lora_A in keys: 46 | assert lora_B in keys 47 | print(f"merging {lora_A} and {lora_B} into {k}") 48 | assert w[lora_B].shape[1] == w[lora_A].shape[0] 49 | lora_r = w[lora_B].shape[1] 50 | w[k] = w[k].to(device=device) 51 | w[lora_A] = w[lora_A].to(device=device) 52 | w[lora_B] = w[lora_B].to(device=device) 53 | w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) 54 | output_w[k] = w[k].to(device="cpu", copy=True) 55 | del w[k] 56 | del w[lora_A] 57 | del w[lora_B] 58 | continue 59 | 60 | if "lora" not in k: 61 | print(f"retaining {k}") 62 | output_w[k] = w[k].clone() 63 | del w[k] 64 | 65 | torch.save(output_w, output) 66 | except Exception as e: 67 | print(e) 68 | with open("error.txt", "w") as f: 69 | f.write(str(e)) 70 | -------------------------------------------------------------------------------- /finetune/lora/v4/cuda/wkv_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); 4 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); 5 | 6 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 7 | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); 8 | } 9 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 10 | cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); 11 | } 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("forward", &forward, "wkv forward"); 15 | m.def("backward", &backward, "wkv backward"); 16 | } 17 | 18 | TORCH_LIBRARY(wkv, m) { 19 | m.def("forward", forward); 20 | m.def("backward", backward); 21 | } 22 | -------------------------------------------------------------------------------- 
/finetune/lora/v4/cuda/wkv_op_bf16.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y); 6 | void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { 9 | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, 12 | torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { 13 | cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(), 14 | gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>()); 15 | } 16 | 17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 18 | m.def("forward", &forward, "wkv forward"); 19 | m.def("backward", &backward, "wkv backward"); 20 | } 21 | 22 | TORCH_LIBRARY(wkv, m) { 23 | m.def("forward", forward); 24 | m.def("backward", backward); 25 | } 26 | -------------------------------------------------------------------------------- /finetune/lora/v4/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/finetune/lora/v4/src/__init__.py -------------------------------------------------------------------------------- /finetune/lora/v5/cuda/wkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv5 forward"); 16 | m.def("backward", &backward, "wkv5 backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv5, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 
| } 23 | -------------------------------------------------------------------------------- /finetune/lora/v5/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/finetune/lora/v5/src/__init__.py -------------------------------------------------------------------------------- /finetune/lora/v6/cuda/wkv5_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv5 forward"); 16 | m.def("backward", &backward, "wkv5 backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv5, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /finetune/lora/v6/cuda/wkv6_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) { 12 | cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 
15 | m.def("forward", &forward, "wkv6 forward"); 16 | m.def("backward", &backward, "wkv6 backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /finetune/lora/v6/cuda/wkv6infctx_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { 12 | cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6state forward"); 16 | m.def("backward", &backward, "wkv6state backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6state, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /finetune/lora/v6/cuda/wkv6state_op.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "ATen/ATen.h" 3 | typedef at::BFloat16 bf16; 4 | 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y); 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs); 7 | 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) { 9 | cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>()); 10 | } 11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) { 12 | cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), 
gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>()); 13 | } 14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 15 | m.def("forward", &forward, "wkv6state forward"); 16 | m.def("backward", &backward, "wkv6state backward"); 17 | } 18 | 19 | TORCH_LIBRARY(wkv6state, m) { 20 | m.def("forward", forward); 21 | m.def("backward", backward); 22 | } 23 | -------------------------------------------------------------------------------- /finetune/lora/v6/demo/demo-lora-merge.sh: -------------------------------------------------------------------------------- 1 | 2 | base_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth' 3 | lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth' 4 | lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth' 5 | output='/home/rwkv/JL/model/nf4-world.pth' 6 | QUANT='nf4' #follow train 7 | TYPE='lora' 8 | Lora_alpha=128 9 | 10 | python merge/merge.py --base_model $base_model \ 11 | --lora_init $lora_init \ 12 | --lora_checkpoint $lora_checkpoint \ 13 | --output $output \ 14 | --quant $QUANT \ 15 | --type $TYPE \ 16 | --lora_alpha $Lora_alpha -------------------------------------------------------------------------------- /finetune/lora/v6/demo/demo-lora.sh: -------------------------------------------------------------------------------- 1 | load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth' 2 | proj_dir='/home/rwkv/JL/out_model/nf4' 3 | data_file='/home/rwkv/JL/data/roleplay' 4 | 5 | QUANT='nf4' #4bit nf4 fp4 none 6 | 7 | lora_r=64 8 | lora_alpha=128 9 | 10 | n_layer=32 11 | n_embd=4096 12 | 13 | micro_bsz=8 14 | epoch_save=1 15 | epoch_steps=1000 16 | ctx_len=1024 17 | 18 | python train.py --load_model $load_model \ 19 | --proj_dir $proj_dir --data_file $data_file \ 20 | --data_type binidx --vocab_size 65536 \ 21 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ 22 | --n_layer $n_layer --n_embd $n_embd \ 23 | --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ 24 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ 25 | --my_testing "x060" \ 26 | --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha $lora_alpha --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ 27 | --quant $QUANT -------------------------------------------------------------------------------- /finetune/lora/v6/demo/demo-pissa-merge.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | base_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2-20240208-ctx4096.pth' 4 | lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth' 5 | lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth' 6 | output='/home/rwkv/JL/model/end-world.pth' 7 | QUANT='nf4' #follow train 8 | TYPE='pissa' 9 | 10 | python merge/merge.py --base_model $base_model \ 11 | --lora_init $lora_init \ 12 | --lora_checkpoint $lora_checkpoint \ 13 | --output $output \ 14 | --quant $QUANT \ 15 | --type $TYPE -------------------------------------------------------------------------------- /finetune/lora/v6/demo/demo-pissa.sh: -------------------------------------------------------------------------------- 1 | 2 | load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth' 3 | proj_dir='/home/rwkv/JL/out_model/nf4' 4 | data_file='/home/rwkv/JL/data/end_text_document' 5 | 6 | QUANT='nf4' #4bit nf4 fp4 none 7 | svd_niter=4 
8 | lora_r=64 9 | 10 | n_layer=24 11 | n_embd=2048 12 | 13 | micro_bsz=8 14 | epoch_save=1 15 | epoch_steps=1000 16 | ctx_len=1024 17 | 18 | python train.py --load_model $load_model \ 19 | --proj_dir $proj_dir --data_file $data_file \ 20 | --data_type binidx --vocab_size 65536 \ 21 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ 22 | --n_layer $n_layer --n_embd $n_embd \ 23 | --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ 24 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ 25 | --my_testing "x060" \ 26 | --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ 27 | --PISSA --svd_niter $svd_niter \ 28 | --dataload pad 29 | 30 | ###remove load_model 31 | # python train.py --proj_dir $proj_dir --data_file $data_file \ 32 | # --data_type binidx --vocab_size 65536 \ 33 | # --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ 34 | # --n_layer $n_layer --n_embd $n_embd \ 35 | # --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ 36 | # --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ 37 | # --my_testing "x060" \ 38 | # --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ 39 | # --PISSA --svd_niter $svd_niter \ 40 | # --quant $QUANT -------------------------------------------------------------------------------- /finetune/lora/v6/demo/demo-qpissa-pt.sh: -------------------------------------------------------------------------------- 1 | load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth' 2 | proj_dir='/home/rwkv/JL/out_model/nf4' 3 | data_file='/home/rwkv/JL/data/roleplay' 4 | 5 | QUANT='nf4' #4bit nf4 fp4 none 6 | svd_niter=4 7 | lora_r=64 8 | 9 | n_layer=32 10 | n_embd=4096 11 | 12 | micro_bsz=4 13 | epoch_save=1 14 | epoch_steps=1000 15 | ctx_len=1024 16 | 17 | 18 | python train.py --proj_dir $proj_dir --data_file $data_file \ 19 | --data_type binidx --vocab_size 65536 \ 20 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ 21 | --n_layer $n_layer --n_embd $n_embd \ 22 | --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ 23 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ 24 | --my_testing "x060" \ 25 | --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ 26 | --PISSA --svd_niter $svd_niter \ 27 | --quant $QUANT -------------------------------------------------------------------------------- /finetune/lora/v6/demo/demo-state-merge.sh: -------------------------------------------------------------------------------- 1 | base_model='/home/rwkv/JL/model/RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth' 2 | state_checkpoint='/home/rwkv/JL/out_model/state/rwkv-9.pth' 3 | output='/home/rwkv/JL/model/state-0.pth' 4 | 5 | 6 | python merge/merge_state.py --base_model $base_model \ 7 | --state_checkpoint $state_checkpoint \ 8 | --output $output -------------------------------------------------------------------------------- 
/finetune/lora/v6/demo/demo-state-tuning.sh: -------------------------------------------------------------------------------- 1 | load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth' 2 | proj_dir='/home/rwkv/JL/out_model/state' 3 | data_file='/home/rwkv/JL/data/end_text_document' 4 | 5 | 6 | n_layer=24 7 | n_embd=2048 8 | 9 | micro_bsz=1 10 | epoch_save=1 11 | epoch_steps=1000 12 | ctx_len=1024 13 | 14 | python train.py --load_model $load_model \ 15 | --proj_dir $proj_dir --data_file $data_file \ 16 | --data_type binidx --vocab_size 65536 \ 17 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ 18 | --n_layer $n_layer --n_embd $n_embd \ 19 | --pre_ffn 0 --head_qk 0 --lr_init 1 --lr_final 1e-1 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ 20 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 0 \ 21 | --my_testing "x060" \ 22 | --train_type "state" --dataload pad --wandb fla --fla -------------------------------------------------------------------------------- /finetune/lora/v6/demo/demo-training-prepare.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create data directory 4 | 5 | mkdir -p data 6 | 7 | # Download minipile (1498226207 tokens, around 3GB) 8 | 9 | wget --continue -O data/minipile.idx https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.idx 10 | wget --continue -O data/minipile.bin https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.bin 11 | 12 | # Generate initial model (L12-D768 = 169M) 13 | 14 | BASE_NAME="model/0.1-1" 15 | N_LAYER="12" 16 | N_EMBD="768" 17 | 18 | # magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case) 19 | # use https://www.dcode.fr/prime-numbers-search 20 | 21 | python train.py --wandb "" --proj_dir $BASE_NAME \ 22 | --data_file "data/minipile" --data_type "binidx" --vocab_size 65536 \ 23 | --ctx_len 512 --my_pile_stage 1 --epoch_count 1 --epoch_begin 0 \ 24 | --epoch_save 1 --weight_decay 0 --head_size_a 64 \ 25 | --num_nodes 1 --micro_bsz 1 --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 --my_exit_tokens 1498226207 --magic_prime 2926181 \ 26 | --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 \ 27 | --accelerator cpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 --enable_progress_bar False --ds_bucket_mb 200 28 | -------------------------------------------------------------------------------- /finetune/lora/v6/demo/demo-training-run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | BASE_NAME="model/0.1-1" 4 | N_LAYER="12" 5 | N_EMBD="768" 6 | M_BSZ="16" # takes 16G VRAM (reduce this to save VRAM) 7 | LR_INIT="6e-4" 8 | LR_FINAL="6e-5" 9 | GRAD_CP=0 # set to 1 to save VRAM (will be slower) 10 | EPOCH_SAVE=10 11 | 12 | # magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case) 13 | # use https://www.dcode.fr/prime-numbers-search 14 | 15 | python train.py --load_model "0" --wandb "RWKV-5-Test" --proj_dir $BASE_NAME \ 16 | --ctx_len 512 --my_pile_stage 3 --epoch_count 999999 --epoch_begin 0 \ 17 | --data_file "data/minipile" --my_exit_tokens 1498226207 --magic_prime 2926181 \ 18 | 
--num_nodes 1 --micro_bsz $M_BSZ --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 \ 19 | --lr_init $LR_INIT --lr_final $LR_FINAL --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 --data_type "binidx" --vocab_size 65536 \ 20 | --weight_decay 0.001 --epoch_save $EPOCH_SAVE --head_size_a 64 \ 21 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp $GRAD_CP --enable_progress_bar True --ds_bucket_mb 200 22 | -------------------------------------------------------------------------------- /finetune/lora/v6/demo/infctx.sh: -------------------------------------------------------------------------------- 1 | load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth' 2 | proj_dir='/home/rwkv/JL/out_model/infctx' 3 | data_file='/home/rwkv/JL/data/roleplay' 4 | 5 | 6 | n_layer=24 7 | n_embd=2048 8 | 9 | micro_bsz=8 10 | epoch_save=5 11 | epoch_steps=1000 12 | ctx_len=16384 13 | chunk_ctx=2048 14 | 15 | 16 | python train.py --load_model $load_model \ 17 | --proj_dir $proj_dir --data_file $data_file \ 18 | --data_type binidx --vocab_size 65536 \ 19 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \ 20 | --n_layer $n_layer --n_embd $n_embd \ 21 | --pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ 22 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \ 23 | --lora_load rwkv-0 --lora --lora_r 64 --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \ 24 | --my_testing "x060" --dataload pad \ 25 | --train_type infctx --chunk_ctx $chunk_ctx --fla --wandb infctx -------------------------------------------------------------------------------- /finetune/lora/v6/fla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.layers import (ABCAttention, BasedLinearAttention, DeltaNet, 4 | GatedLinearAttention, HGRN2Attention, LinearAttention, 5 | MultiScaleRetention, ReBasedLinearAttention) 6 | from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM, 7 | DeltaNetModel, GLAForCausalLM, GLAModel, 8 | HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM, 9 | HGRNModel, LinearAttentionForCausalLM, 10 | LinearAttentionModel, RetNetForCausalLM, RetNetModel, 11 | RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM, 12 | TransformerModel) 13 | from fla.ops import (chunk_gla, chunk_retention, fused_chunk_based, 14 | fused_chunk_gla, fused_chunk_retention) 15 | 16 | __all__ = [ 17 | 'ABCAttention', 18 | 'BasedLinearAttention', 19 | 'DeltaNet', 20 | 'HGRN2Attention', 21 | 'GatedLinearAttention', 22 | 'LinearAttention', 23 | 'MultiScaleRetention', 24 | 'ReBasedLinearAttention', 25 | 'ABCForCausalLM', 26 | 'ABCModel', 27 | 'DeltaNetForCausalLM', 28 | 'DeltaNetModel', 29 | 'HGRNForCausalLM', 30 | 'HGRNModel', 31 | 'HGRN2ForCausalLM', 32 | 'HGRN2Model', 33 | 'GLAForCausalLM', 34 | 'GLAModel', 35 | 'LinearAttentionForCausalLM', 36 | 'LinearAttentionModel', 37 | 'RetNetForCausalLM', 38 | 'RetNetModel', 39 | 'RWKV6ForCausalLM', 40 | 'RWKV6Model', 41 | 'TransformerForCausalLM', 42 | 'TransformerModel', 43 | 'chunk_gla', 44 | 'chunk_retention', 45 | 'fused_chunk_based', 46 | 'fused_chunk_gla', 47 | 'fused_chunk_retention' 48 | ] 49 | 50 | __version__ = '0.1' 51 | -------------------------------------------------------------------------------- 
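The magic_prime comments in the demo training scripts above (demo-training-prepare.sh and demo-training-run.sh) define it as the largest 3n+2 prime smaller than datalen/ctxlen - 1 (1498226207/512 - 1 = 2926222.06 in that example). The following is a minimal sketch, not part of the repository, showing how the value 2926181 passed as --magic_prime above can be reproduced:

#!/bin/bash
# Sketch: find the largest prime p with p % 3 == 2 below datalen/ctxlen - 1.
datalen=1498226207
ctxlen=512
limit=$(( datalen / ctxlen - 1 ))   # 2926222, integer part of the bound

is_prime() {
  local n=$1 i=2
  (( n < 2 )) && return 1
  while (( i * i <= n )); do
    (( n % i == 0 )) && return 1
    (( i++ ))
  done
  return 0
}

p=$limit
while (( p % 3 != 2 )) || ! is_prime "$p"; do
  (( p-- ))
done
echo "magic_prime = $p"   # prints 2926181, matching the --magic_prime used above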
/finetune/lora/v6/fla/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .abc import ABCAttention 4 | from .based import BasedLinearAttention 5 | from .delta_net import DeltaNet 6 | from .gla import GatedLinearAttention 7 | from .hgrn import HGRNAttention 8 | from .hgrn2 import HGRN2Attention 9 | from .linear_attn import LinearAttention 10 | from .multiscale_retention import MultiScaleRetention 11 | from .rebased import ReBasedLinearAttention 12 | from .rwkv6 import RWKV6Attention 13 | 14 | __all__ = [ 15 | 'ABCAttention', 16 | 'BasedLinearAttention', 17 | 'DeltaNet', 18 | 'GatedLinearAttention', 19 | 'HGRNAttention', 20 | 'HGRN2Attention', 21 | 'LinearAttention', 22 | 'MultiScaleRetention', 23 | 'ReBasedLinearAttention', 24 | 'RWKV6Attention' 25 | ] 26 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel 4 | from fla.models.delta_net import (DeltaNetConfig, DeltaNetForCausalLM, 5 | DeltaNetModel) 6 | from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel 7 | from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel 8 | from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model 9 | from fla.models.linear_attn import (LinearAttentionConfig, 10 | LinearAttentionForCausalLM, 11 | LinearAttentionModel) 12 | from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel 13 | from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel 14 | from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model 15 | from fla.models.transformer import (TransformerConfig, TransformerForCausalLM, 16 | TransformerModel) 17 | 18 | __all__ = [ 19 | 'ABCConfig', 'ABCForCausalLM', 'ABCModel', 20 | 'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel', 21 | 'GLAConfig', 'GLAForCausalLM', 'GLAModel', 22 | 'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel', 23 | 'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model', 24 | 'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel', 25 | 'MambaConfig', 'MambaForCausalLM', 'MambaModel', 26 | 'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel', 27 | 'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model', 28 | 'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel' 29 | ] 30 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/abc/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.abc.configuration_abc import ABCConfig 6 | from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel 7 | 8 | AutoConfig.register(ABCConfig.model_type, ABCConfig) 9 | AutoModel.register(ABCConfig, ABCModel) 10 | AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM) 11 | 12 | 13 | __all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel'] 14 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/abc/configuration_abc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from 
transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class ABCConfig(PretrainedConfig): 9 | 10 | model_type = 'abc' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | gate_low_rank_dim: int = 16, 18 | clamp_min: float = -32, 19 | clamp_max: float = 32, 20 | hidden_ratio: Optional[int] = 4, 21 | intermediate_size: Optional[int] = None, 22 | num_hidden_layers: int = 24, 23 | num_heads: int = 4, 24 | num_slots: Optional[int] = 64, 25 | use_short_conv: bool = True, 26 | conv_size: int = 4, 27 | share_conv_kernel: bool = True, 28 | exapnd_k: float = 0.5, 29 | exapnd_v: float = 1, 30 | hidden_act: str = "swish", 31 | max_position_embeddings: int = 2048, 32 | elementwise_affine: Optional[bool] = True, 33 | norm_eps: float = 1e-6, 34 | use_cache: bool = True, 35 | pad_token_id: int = None, 36 | bos_token_id: int = 1, 37 | eos_token_id: int = 2, 38 | initializer_range: float = 0.02, 39 | tie_word_embeddings: bool = False, 40 | fuse_norm: bool = True, 41 | fuse_cross_entropy: bool = True, 42 | **kwargs 43 | ): 44 | self.vocab_size = vocab_size 45 | self.max_position_embeddings = max_position_embeddings 46 | self.hidden_size = hidden_size 47 | self.gate_low_rank_dim = gate_low_rank_dim 48 | self.clamp_min = clamp_min 49 | self.clamp_max = clamp_max 50 | self.hidden_ratio = hidden_ratio 51 | self.intermediate_size = intermediate_size 52 | self.num_hidden_layers = num_hidden_layers 53 | self.num_heads = num_heads 54 | self.num_slots = num_slots 55 | self.use_short_conv = use_short_conv 56 | self.conv_size = conv_size 57 | self.share_conv_kernel = share_conv_kernel 58 | self.expand_k = exapnd_k 59 | self.expand_v = exapnd_v 60 | self.hidden_act = hidden_act 61 | self.elementwise_affine = elementwise_affine 62 | self.norm_eps = norm_eps 63 | self.use_cache = use_cache 64 | self.initializer_range = initializer_range 65 | self.fuse_cross_entropy = fuse_cross_entropy 66 | self.fuse_norm = fuse_norm 67 | 68 | super().__init__( 69 | pad_token_id=pad_token_id, 70 | bos_token_id=bos_token_id, 71 | eos_token_id=eos_token_id, 72 | tie_word_embeddings=tie_word_embeddings, 73 | **kwargs, 74 | ) 75 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/delta_net/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.delta_net.configuration_delta_net import \ 6 | DeltaNetConfig 7 | from fla.models.delta_net.modeling_delta_net import ( 8 | DeltaNetForCausalLM, DeltaNetModel) 9 | 10 | AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig) 11 | AutoModel.register(DeltaNetConfig, DeltaNetModel) 12 | AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM) 13 | 14 | __all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel'] 15 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.gla.configuration_gla import GLAConfig 6 | from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel 7 | 8 | AutoConfig.register(GLAConfig.model_type, GLAConfig) 9 | 
AutoModel.register(GLAConfig, GLAModel) 10 | AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM) 11 | 12 | 13 | __all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel'] 14 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/hgrn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.hgrn.configuration_hgrn import HGRNConfig 6 | from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel 7 | 8 | AutoConfig.register(HGRNConfig.model_type, HGRNConfig) 9 | AutoModel.register(HGRNConfig, HGRNModel) 10 | AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM) 11 | 12 | 13 | __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel'] 14 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/hgrn/configuration_hgrn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class HGRNConfig(PretrainedConfig): 9 | 10 | model_type = 'hgrn' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "chunk", 16 | vocab_size: int = 32000, 17 | hidden_size: int = 2048, 18 | num_hidden_layers: int = 24, 19 | num_heads: Optional[int] = 1, 20 | expand_ratio: Optional[int] = 1, 21 | use_short_conv: bool = False, 22 | conv_size: int = 4, 23 | share_conv_kernel: bool = True, 24 | use_lower_bound: bool = True, 25 | hidden_ratio: Optional[int] = 4, 26 | intermediate_size: Optional[int] = None, 27 | hidden_act: str = "swish", 28 | max_position_embeddings: int = 2048, 29 | elementwise_affine: Optional[bool] = True, 30 | norm_eps: float = 1e-6, 31 | use_cache: bool = True, 32 | pad_token_id: int = None, 33 | bos_token_id: int = 1, 34 | eos_token_id: int = 2, 35 | tie_word_embeddings: bool = False, 36 | initializer_range: float = 0.02, 37 | fuse_cross_entropy: bool = True, 38 | **kwargs 39 | ): 40 | self.attn_mode = attn_mode 41 | self.vocab_size = vocab_size 42 | self.max_position_embeddings = max_position_embeddings 43 | self.hidden_size = hidden_size 44 | self.num_hidden_layers = num_hidden_layers 45 | self.num_heads = num_heads 46 | self.expand_ratio = expand_ratio 47 | self.use_short_conv = use_short_conv 48 | self.conv_size = conv_size 49 | self.share_conv_kernel = share_conv_kernel 50 | self.use_lower_bound = use_lower_bound 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.hidden_act = hidden_act 54 | self.elementwise_affine = elementwise_affine 55 | self.norm_eps = norm_eps 56 | self.use_cache = use_cache 57 | self.initializer_range = initializer_range 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/hgrn2/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from 
fla.models.hgrn2.configuration_hgrn2 import HGRN2Config 6 | from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model 7 | 8 | AutoConfig.register(HGRN2Config.model_type, HGRN2Config) 9 | AutoModel.register(HGRN2Config, HGRN2Model) 10 | AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM) 11 | 12 | 13 | __all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model'] 14 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/hgrn2/configuration_hgrn2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class HGRN2Config(PretrainedConfig): 9 | 10 | model_type = 'hgrn2' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | num_hidden_layers: int = 24, 18 | attn_mode: str = "chunk", 19 | num_heads: Optional[int] = None, 20 | expand_ratio: Optional[int] = 128, 21 | use_short_conv: bool = False, 22 | conv_size: int = 4, 23 | share_conv_kernel: bool = True, 24 | use_lower_bound: bool = True, 25 | hidden_ratio: Optional[int] = 4, 26 | intermediate_size: Optional[int] = None, 27 | hidden_act: str = "swish", 28 | max_position_embeddings: int = 2048, 29 | elementwise_affine: Optional[bool] = True, 30 | norm_eps: float = 1e-6, 31 | use_cache: bool = True, 32 | pad_token_id: int = None, 33 | bos_token_id: int = 1, 34 | eos_token_id: int = 2, 35 | tie_word_embeddings: bool = False, 36 | initializer_range: float = 0.02, 37 | fuse_cross_entropy: bool = True, 38 | **kwargs 39 | ): 40 | self.vocab_size = vocab_size 41 | self.max_position_embeddings = max_position_embeddings 42 | self.hidden_size = hidden_size 43 | self.num_hidden_layers = num_hidden_layers 44 | self.attn_mode = attn_mode 45 | self.num_heads = num_heads 46 | self.expand_ratio = expand_ratio 47 | self.use_short_conv = use_short_conv 48 | self.conv_size = conv_size 49 | self.share_conv_kernel = share_conv_kernel 50 | self.use_lower_bound = use_lower_bound 51 | self.hidden_ratio = hidden_ratio 52 | self.intermediate_size = intermediate_size 53 | self.hidden_act = hidden_act 54 | self.elementwise_affine = elementwise_affine 55 | self.norm_eps = norm_eps 56 | self.use_cache = use_cache 57 | self.initializer_range = initializer_range 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/linear_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.linear_attn.configuration_linear_attn import \ 6 | LinearAttentionConfig 7 | from fla.models.linear_attn.modeling_linear_attn import ( 8 | LinearAttentionForCausalLM, LinearAttentionModel) 9 | 10 | AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig) 11 | AutoModel.register(LinearAttentionConfig, LinearAttentionModel) 12 | AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM) 13 | 14 | __all__ = ['LinearAttentionConfig', 
'LinearAttentionForCausalLM', 'LinearAttentionModel'] 15 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/linear_attn/configuration_linear_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class LinearAttentionConfig(PretrainedConfig): 9 | 10 | model_type = 'linear_attn' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | expand_k: int = 1, 18 | expand_v: int = 1, 19 | hidden_ratio: Optional[int] = 4, 20 | intermediate_size: Optional[int] = None, 21 | num_hidden_layers: int = 24, 22 | num_heads: int = 4, 23 | attn_mode: str = "fused_chunk", 24 | feature_map: str = "elementwise_product", 25 | tie_feature_map_qk: bool = False, 26 | norm_q: bool = False, 27 | norm_k: bool = False, 28 | norm_feature_map: bool = False, 29 | hidden_act: str = "swish", 30 | max_position_embeddings: int = 2048, 31 | elementwise_affine: Optional[bool] = True, 32 | norm_eps: float = 1e-6, 33 | use_cache: bool = True, 34 | pad_token_id: int = None, 35 | bos_token_id: int = 1, 36 | eos_token_id: int = 2, 37 | tie_word_embeddings: bool = False, 38 | initializer_range: float = 0.02, 39 | fuse_cross_entropy: bool = True, 40 | **kwargs 41 | ): 42 | self.vocab_size = vocab_size 43 | self.max_position_embeddings = max_position_embeddings 44 | self.hidden_size = hidden_size 45 | self.expand_k = expand_k 46 | self.expand_v = expand_v 47 | self.hidden_ratio = hidden_ratio 48 | self.intermediate_size = intermediate_size 49 | self.num_hidden_layers = num_hidden_layers 50 | self.num_heads = num_heads 51 | self.attn_mode = attn_mode 52 | self.feature_map = feature_map 53 | self.tie_feature_map_qk = tie_feature_map_qk 54 | self.norm_q = norm_q 55 | self.norm_k = norm_k 56 | self.norm_feature_map = norm_feature_map 57 | self.hidden_act = hidden_act 58 | self.elementwise_affine = elementwise_affine 59 | self.norm_eps = norm_eps 60 | self.use_cache = use_cache 61 | self.initializer_range = initializer_range 62 | self.fuse_cross_entropy = fuse_cross_entropy 63 | 64 | super().__init__( 65 | pad_token_id=pad_token_id, 66 | bos_token_id=bos_token_id, 67 | eos_token_id=eos_token_id, 68 | tie_word_embeddings=tie_word_embeddings, 69 | **kwargs, 70 | ) 71 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/mamba/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.mamba.configuration_mamba import MambaConfig 6 | from fla.models.mamba.modeling_mamba import (MambaBlock, MambaForCausalLM, 7 | MambaModel) 8 | 9 | AutoConfig.register(MambaConfig.model_type, MambaConfig, True) 10 | AutoModel.register(MambaConfig, MambaModel, True) 11 | AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True) 12 | 13 | 14 | __all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock'] 15 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/retnet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, 
AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.retnet.configuration_retnet import RetNetConfig 6 | from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel 7 | 8 | AutoConfig.register(RetNetConfig.model_type, RetNetConfig) 9 | AutoModel.register(RetNetConfig, RetNetModel) 10 | AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM) 11 | 12 | 13 | __all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel'] 14 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/rwkv6/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config 6 | from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model 7 | 8 | AutoConfig.register(RWKV6Config.model_type, RWKV6Config) 9 | AutoModel.register(RWKV6Config, RWKV6Model) 10 | AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM) 11 | 12 | 13 | __all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model'] 14 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/rwkv6/configuration_rwkv6.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class RWKV6Config(PretrainedConfig): 9 | 10 | model_type = 'rwkv6' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | attn_mode: str = "chunk", 16 | vocab_size: int = 32000, 17 | hidden_size: int = 2048, 18 | expand_k: int = 0.5, 19 | expand_v: int = 1, 20 | hidden_ratio: Optional[int] = 3.5, 21 | intermediate_size: Optional[int] = None, 22 | use_glu: Optional[bool] = False, 23 | num_hidden_layers: int = 24, 24 | num_heads: int = 4, 25 | proj_low_rank_dim: int = 32, 26 | gate_low_rank_dim: int = 64, 27 | hidden_act: str = "sqrelu", 28 | max_position_embeddings: int = 2048, 29 | eps: float = 1e-6, 30 | use_cache: bool = True, 31 | pad_token_id: int = None, 32 | bos_token_id: int = 1, 33 | eos_token_id: int = 2, 34 | tie_word_embeddings: bool = False, 35 | initializer_range: float = 0.02, 36 | fuse_norm: bool = True, 37 | fuse_cross_entropy: bool = True, 38 | **kwargs 39 | ): 40 | self.vocab_size = vocab_size 41 | self.max_position_embeddings = max_position_embeddings 42 | self.hidden_size = hidden_size 43 | self.expand_k = expand_k 44 | self.expand_v = expand_v 45 | self.hidden_ratio = hidden_ratio 46 | self.intermediate_size = intermediate_size 47 | self.use_glu = use_glu 48 | self.num_hidden_layers = num_hidden_layers 49 | self.num_heads = num_heads 50 | self.proj_low_rank_dim = proj_low_rank_dim 51 | self.gate_low_rank_dim = gate_low_rank_dim 52 | self.attn_mode = attn_mode 53 | self.hidden_act = hidden_act 54 | self.eps = eps 55 | self.use_cache = use_cache 56 | self.initializer_range = initializer_range 57 | self.fuse_norm = fuse_norm 58 | self.fuse_cross_entropy = fuse_cross_entropy 59 | 60 | super().__init__( 61 | pad_token_id=pad_token_id, 62 | bos_token_id=bos_token_id, 63 | eos_token_id=eos_token_id, 64 | tie_word_embeddings=tie_word_embeddings, 65 | **kwargs, 66 | ) 67 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/transformer/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 4 | 5 | from fla.models.transformer.configuration_transformer import TransformerConfig 6 | from fla.models.transformer.modeling_transformer import ( 7 | TransformerForCausalLM, TransformerModel) 8 | 9 | AutoConfig.register(TransformerConfig.model_type, TransformerConfig) 10 | AutoModel.register(TransformerConfig, TransformerModel) 11 | AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM) 12 | 13 | 14 | __all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'] 15 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/models/transformer/configuration_transformer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | 7 | 8 | class TransformerConfig(PretrainedConfig): 9 | 10 | model_type = 'transformer' 11 | keys_to_ignore_at_inference = ['past_key_values'] 12 | 13 | def __init__( 14 | self, 15 | vocab_size: int = 32000, 16 | hidden_size: int = 2048, 17 | hidden_ratio: Optional[int] = 4, 18 | intermediate_size: Optional[int] = None, 19 | num_hidden_layers: int = 24, 20 | num_heads: int = 32, 21 | num_kv_heads: int = None, 22 | hidden_act: str = "swish", 23 | max_position_embeddings: int = 2048, 24 | initializer_range: float = 0.02, 25 | elementwise_affine: Optional[bool] = True, 26 | norm_eps: float = 1e-6, 27 | use_cache: bool = True, 28 | pad_token_id: int = None, 29 | bos_token_id: int = 1, 30 | eos_token_id: int = 2, 31 | tie_word_embeddings: bool = False, 32 | attention_bias: bool = False, 33 | fuse_norm: bool = True, 34 | fuse_cross_entropy: bool = True, 35 | **kwargs, 36 | ): 37 | self.vocab_size = vocab_size 38 | self.max_position_embeddings = max_position_embeddings 39 | self.hidden_size = hidden_size 40 | self.hidden_ratio = hidden_ratio 41 | self.intermediate_size = intermediate_size 42 | self.num_hidden_layers = num_hidden_layers 43 | self.num_heads = num_heads 44 | self.num_kv_heads = num_kv_heads 45 | 46 | self.hidden_act = hidden_act 47 | self.initializer_range = initializer_range 48 | self.elementwise_affine = elementwise_affine 49 | self.norm_eps = norm_eps 50 | self.use_cache = use_cache 51 | self.attention_bias = attention_bias 52 | self.fuse_cross_entropy = fuse_cross_entropy 53 | self.fuse_norm = fuse_norm 54 | 55 | super().__init__( 56 | pad_token_id=pad_token_id, 57 | bos_token_id=bos_token_id, 58 | eos_token_id=eos_token_id, 59 | tie_word_embeddings=tie_word_embeddings, 60 | **kwargs, 61 | ) 62 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution, 4 | ShortConvolution) 5 | from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss 6 | from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate, 7 | FusedLayerNormSwishGateLinear, 8 | FusedRMSNormSwishGate, 9 | FusedRMSNormSwishGateLinear) 10 | from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm, 11 | RMSNormLinear) 12 | from fla.modules.rotary import RotaryEmbedding 13 | 14 | __all__ = [ 15 | 
'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution', 16 | 'FusedCrossEntropyLoss', 17 | 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear', 18 | 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear', 19 | 'RotaryEmbedding' 20 | ] 21 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .based import fused_chunk_based, parallel_based 4 | from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla 5 | from .retention import (chunk_retention, fused_chunk_retention, 6 | fused_recurrent_retention, parallel_retention) 7 | 8 | __all__ = [ 9 | 'fused_chunk_based', 10 | 'parallel_based', 11 | 'chunk_gla', 12 | 'fused_chunk_gla', 13 | 'fused_recurrent_gla', 14 | 'chunk_retention', 15 | 'fused_chunk_retention', 16 | 'fused_recurrent_retention', 17 | 'parallel_retention' 18 | ] 19 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/abc/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_abc 4 | from .chunk_gate import chunk_gated_abc 5 | from .recurrent_fuse import fused_recurrent_gated_abc 6 | 7 | __all__ = [ 8 | 'chunk_abc', 9 | 'chunk_gated_abc', 10 | 'fused_recurrent_gated_abc' 11 | ] 12 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/based/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk_fuse import fused_chunk_based 4 | from .parallel import parallel_based 5 | 6 | __all__ = [ 7 | 'fused_chunk_based', 8 | 'parallel_based' 9 | ] 10 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/delta_rule/README.md: -------------------------------------------------------------------------------- 1 | - Delta Rule 2 | 3 | The implementation of delta rule described in https://arxiv.org/abs/2102.11174 4 | 5 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/delta_rule/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk_fuse import fused_chunk_delta_rule 4 | from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule 5 | from .chunk import chunk_delta_rule 6 | 7 | __all__ = [ 8 | 'fused_chunk_delta_rule', 9 | 'fused_recurrent_linear_attn_delta_rule', 10 | 'chunk_delta_rule' 11 | ] 12 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_gla 4 | from .chunk_fuse import fused_chunk_gla 5 | from .recurrent_fuse import fused_recurrent_gla 6 | 7 | __all__ = [ 8 | 'chunk_gla', 9 | 'fused_chunk_gla', 10 | 'fused_recurrent_gla' 11 | ] 12 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/hgrn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_hgrn 4 | from 
.recurrent_fuse import fused_recurrent_hgrn 5 | 6 | __all__ = [ 7 | 'chunk_hgrn', 8 | 'fused_recurrent_hgrn' 9 | ] 10 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/hgrn/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | 7 | 8 | def naive_recurrent_hgrn( 9 | x: torch.Tensor, 10 | g: torch.Tensor, 11 | initial_state: Optional[torch.Tensor] = None, 12 | output_final_state: Optional[bool] = False 13 | ) -> torch.Tensor: 14 | dtype = x.dtype 15 | x, g = map(lambda i: i.float(), (x, g)) 16 | B, H, T, D = x.shape 17 | 18 | h = torch.zeros(B, H, D, dtype=torch.float, device=x.device) 19 | o = torch.zeros_like(x) 20 | 21 | final_state = None 22 | if initial_state is not None: 23 | h += initial_state.detach() 24 | 25 | for i in range(T): 26 | h = g[:, :, i].exp() * h + x[:, :, i] 27 | o[:, :, i] = h 28 | 29 | if output_final_state: 30 | final_state = h 31 | return o.to(dtype), final_state 32 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/linear_attn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_linear_attn 4 | from .chunk_fuse import fused_chunk_linear_attn 5 | from .recurrent_fuse import fused_recurrent_linear_attn 6 | 7 | __all__ = [ 8 | 'chunk_linear_attn', 9 | 'fused_chunk_linear_attn', 10 | 'fused_recurrent_linear_attn' 11 | ] 12 | 13 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/linear_attn/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def torch_chunk_linear_attn(q, k, v, chunk_size=64): 8 | q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] **-0.5) 9 | k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 10 | v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 11 | kv = k.transpose(-1, -2) @ v 12 | kv = kv.cumsum(2) 13 | kv = torch.cat([ 14 | torch.zeros_like(kv[:, :, :1]), 15 | kv[:, :, :-1] 16 | ], dim=2) 17 | inter = q @ kv 18 | intra = ((q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)) @ v 19 | o = inter + intra 20 | return rearrange(o, 'b h n c d -> b h (n c) d') 21 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/rebased/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .parallel import parallel_rebased 4 | 5 | __all__ = [ 6 | 'parallel_rebased' 7 | ] 8 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/retention/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_retention 4 | from .chunk_fuse import fused_chunk_retention 5 | from .parallel import parallel_retention 6 | from .recurrent_fuse import fused_recurrent_retention 7 | 8 | __all__ = [ 9 | 'chunk_retention', 10 | 'fused_chunk_retention', 11 | 'parallel_retention', 12 | 'fused_recurrent_retention' 13 | ] 14 | 
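The retention ops exported above all share the same per-head exponential decay, and the naive reference implementation that follows (naive.py) encodes it rather tersely through log2/exp2. As a reading aid, here is an equivalent but more explicit sketch of that decay mask; it is an illustrative rewrite under the same conventions, not code from this repository.

import torch

def retention_decay_mask(n_heads: int, seq_len: int) -> torch.Tensor:
    # Per-head decay rate gamma_h = 1 - 2^(-5 - h), h = 0..H-1, as in naive.py.
    gamma = 1.0 - torch.pow(2.0, -5.0 - torch.arange(n_heads, dtype=torch.float))
    idx = torch.arange(seq_len, dtype=torch.float)
    rel = idx.unsqueeze(-1) - idx        # (seq, seq): query index minus key index
    causal = idx.unsqueeze(-1).ge(idx)   # keep only positions with q >= k
    # D[h, q, k] = gamma_h ** (q - k) for q >= k, else 0; naive.py builds the same
    # matrix as exp2((q - k) * log2(gamma_h)) before the two einsum contractions.
    return gamma.view(-1, 1, 1).pow(rel) * causal  # (heads, seq, seq)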
-------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/retention/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def naive_retention(q, k, v): 7 | orig_type = q.dtype 8 | q, k, v = q.float(), k.float(), v.float() 9 | _, n_heads, seq_len, d_head = q.shape 10 | s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2() 11 | n = q.new_tensor(range(seq_len), dtype=torch.float) 12 | n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n) 13 | s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype)) 14 | o = torch.einsum('bhqk,bhkd->bhqd', s, v) 15 | return o.to(orig_type) 16 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/rwkv4/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .recurrent_fuse import fused_recurrent_rwkv4 4 | 5 | __all__ = [ 6 | 'fused_recurrent_rwkv4' 7 | ] 8 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/rwkv6/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_rwkv6 4 | from .recurrent_fuse import fused_recurrent_rwkv6 5 | 6 | __all__ = [ 7 | 'chunk_rwkv6', 8 | 'fused_recurrent_rwkv6' 9 | ] 10 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/simple_gla/README.md: -------------------------------------------------------------------------------- 1 | - Simple GLA 2 | 3 | Gating mechanism from https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise rather than element-wise, so the RetNet kernel can be adapted for training with plain matmuls and without numerical instability. Simple GLA is faster than GLA but less expressive, and serves as a baseline for GLA. 4 | 5 | $S_{t+1} = g_{t+1} S_{t} + K_{t+1} V_{t+1}^{\top}$, where $g_{t+1}$ is a per-head scalar gate.
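As a rough usage sketch of the exported kernel (the calling convention, gate shape, and return value are inferred from the naive reference implementation in this directory, so treat them as assumptions rather than a documented API):

import torch
from fla.ops.simple_gla import chunk_simple_gla

B, H, T, D = 2, 4, 256, 64
q = torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)
# one log-space scalar gate per head and time step, i.e. shape (B, H, T),
# matching the g consumed by torch_simple_gla_recurrent below
g = torch.nn.functional.logsigmoid(torch.randn(B, H, T, device='cuda'))
o = chunk_simple_gla(q, k, v, g)  # newer fla revisions may return (output, final_state)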
-------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/simple_gla/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .chunk import chunk_simple_gla 4 | 5 | __all__ = [ 6 | 'chunk_simple_gla' 7 | ] 8 | 9 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/ops/simple_gla/naive.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from einops import rearrange 5 | 6 | 7 | def torch_simple_gla(q, k, v, g, chunk_size=64): 8 | q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5) 9 | k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 10 | v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 11 | g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size) 12 | g = g.cumsum(-1) 13 | kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None]) 14 | S = torch.zeros_like(kv) 15 | 16 | for i in range(1, g.shape[-2]): 17 | S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1] 18 | 19 | inter = (q * g[..., None].exp()) @ S 20 | attn = q @ k.transpose(-1, -2) 21 | attn = attn * (g[..., None] - g[..., None, :]).exp() 22 | attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) 23 | intra = attn @ v 24 | o = inter + intra 25 | return rearrange(o, 'b h n c d -> b h (n c) d') 26 | 27 | 28 | def torch_simple_gla_recurrent(q, k, v, g, chunk_size=64): 29 | # q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5) 30 | # k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size) 31 | # v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size) 32 | # g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size) 33 | # g = g.cumsum(-1) 34 | # kv = k.transpose(-1, -2) @ v 35 | 36 | B, H, T, DK = q.shape 37 | q = q * (DK ** -0.5) 38 | _, _, _, DV = v.shape 39 | S = torch.zeros(B, H, DK, DV).to(q) 40 | o = torch.zeros(B, H, T, DV).to(q) 41 | for i in range(T): 42 | gate = g[:, :, i].exp() 43 | key = k[:, :, i] 44 | value = v[:, :, i] 45 | kv = key.unsqueeze(-1) * value.unsqueeze(-2) 46 | S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv 47 | q_i = q[:, :, i, :] 48 | o_i = (q_i.unsqueeze(-1) * S).sum(-2) 49 | o[:, :, i] = o_i 50 | 51 | return o 52 | 53 | -------------------------------------------------------------------------------- /finetune/lora/v6/fla/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import functools 4 | 5 | import torch 6 | 7 | 8 | def contiguous(fn): 9 | @functools.wraps(fn) 10 | def wrapper(ctx, *args, **kwargs): 11 | return fn(ctx, 12 | *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), 13 | **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) 14 | return wrapper 15 | 16 | 17 | def require_version(version, hint): 18 | def decorator(fn): 19 | @functools.wraps(fn) 20 | def wrapper(ctx, *args, **kwargs): 21 | from transformers.utils.versions import require_version 22 | require_version(version, hint) 23 | return fn(ctx, 24 | *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args), 25 | **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}) 26 | 
return wrapper 27 | return decorator 28 | 29 | 30 | def checkpoint(func): 31 | def wrapper(*args, **kwargs): 32 | return torch.utils.checkpoint.checkpoint(func, *args, **kwargs) 33 | return wrapper 34 | -------------------------------------------------------------------------------- /finetune/lora/v6/merge/merge_lora.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | import sys 4 | from typing import Dict 5 | import typing 6 | import torch 7 | 8 | if '-h' in sys.argv or '--help' in sys.argv: 9 | print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>') 10 | 11 | if sys.argv[1] == '--use-gpu': 12 | device = 'cuda' 13 | lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5] 14 | else: 15 | device = 'cpu' 16 | lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4] 17 | 18 | 19 | with torch.no_grad(): 20 | w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') 21 | # merge LoRA-only slim checkpoint into the main weights 22 | w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') 23 | for k in w_lora.keys(): 24 | w[k] = w_lora[k] 25 | output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() 26 | # merge LoRA weights 27 | keys = list(w.keys()) 28 | for k in keys: 29 | if k.endswith('.weight'): 30 | prefix = k[:-len('.weight')] 31 | lora_A = prefix + '.lora_A' 32 | lora_B = prefix + '.lora_B' 33 | if lora_A in keys: 34 | assert lora_B in keys 35 | print(f'merging {lora_A} and {lora_B} into {k}') 36 | assert w[lora_B].shape[1] == w[lora_A].shape[0] 37 | lora_r = w[lora_B].shape[1] 38 | w[k] = w[k].to(device=device) 39 | w[lora_A] = w[lora_A].to(device=device) 40 | w[lora_B] = w[lora_B].to(device=device) 41 | w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) 42 | output_w[k] = w[k].to(device='cpu', copy=True) 43 | del w[k] 44 | del w[lora_A] 45 | del w[lora_B] 46 | continue 47 | 48 | if 'lora' not in k: 49 | print(f'retaining {k}') 50 | output_w[k] = w[k].clone() 51 | del w[k] 52 | torch.save(output_w, output) 53 | -------------------------------------------------------------------------------- /finetune/lora/v6/merge/merge_pissa.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | import sys 4 | from typing import Dict 5 | import typing 6 | import torch 7 | 8 | if '-h' in sys.argv or '--help' in sys.argv: 9 | print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <base_model.pth> <lora_init.pth> <lora_checkpoint.pth> <output.pth>') 10 | 11 | if sys.argv[1] == '--use-gpu': 12 | device = 'cuda' 13 | base_model, init_lora, lora, output = sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5] 14 | else: 15 | device = 'cpu' 16 | base_model, init_lora, lora, output = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4] 17 | 18 | 19 | with torch.no_grad(): 20 | w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') 21 | # merge LoRA-only slim checkpoint into the main weights 22 | w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu') 23 | w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu') 24 | for k in w_lora.keys(): 25 | w[k] = w_lora[k] 26 | output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict() 27 | # merge LoRA weights 28 | keys = list(w.keys()) 29 | for k in keys: 30 | if 
k.endswith('.weight'): 31 | prefix = k[:-len('.weight')] 32 | lora_A = prefix + '.lora_A' 33 | lora_B = prefix + '.lora_B' 34 | init_lora_A = prefix + '.init_lora_A' 35 | init_lora_B = prefix + '.init_lora_B' 36 | if lora_A in keys: 37 | assert lora_B in keys 38 | print(f'merging {lora_A} and {lora_B} into {k}') 39 | assert w[lora_B].shape[1] == w[lora_A].shape[0] 40 | lora_r = w[lora_B].shape[1] 41 | w[k] = w[k].to(device=device) 42 | w[lora_A] = w[lora_A].to(device=device) 43 | w[lora_B] = w[lora_B].to(device=device) 44 | w_init_lora[init_lora_A] = w_init_lora[init_lora_A].to(device=device) 45 | w_init_lora[init_lora_B] = w_init_lora[init_lora_B].to(device=device) 46 | w[k] = (w[k]- w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]).to(dtype=torch.bfloat16) 47 | w[k] += w[lora_B] @ w[lora_A] 48 | output_w[k] = w[k].to(device='cpu', copy=True) 49 | del w[k] 50 | del w[lora_A] 51 | del w[lora_B] 52 | continue 53 | 54 | if 'lora' not in k: 55 | print(f'retaining {k}') 56 | output_w[k] = w[k].clone() 57 | del w[k] 58 | torch.save(output_w, output) -------------------------------------------------------------------------------- /finetune/lora/v6/merge/merge_state.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import os 3 | import sys 4 | from typing import Dict 5 | import typing 6 | import torch 7 | import bitsandbytes as bnb 8 | from argparse import ArgumentParser 9 | 10 | parser = ArgumentParser() 11 | parser.add_argument("--base_model", default="", type=str) 12 | parser.add_argument("--state_checkpoint", default="", type=str) 13 | parser.add_argument("--output", default="", type=str) 14 | # parser.add_argument("--quant", default="none", type=str) 15 | parser.add_argument("--device", default="cuda", type=str) 16 | # parser.add_argument("--lora_alpha", default=16, type=int) 17 | args = parser.parse_args() 18 | device= args.device 19 | base_model = args.base_model 20 | state= args.state_checkpoint 21 | output= args.output 22 | 23 | 24 | with torch.no_grad(): 25 | w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu') 26 | # merge LoRA-only slim checkpoint into the main weights 27 | w_state: Dict[str, torch.Tensor] = torch.load(state, map_location='cpu') 28 | 29 | for k in w_state.keys(): 30 | print(k) 31 | w[k] = w_state[k] 32 | # merge LoRA weights 33 | for k in w.keys(): 34 | print(k) 35 | 36 | torch.save(w, output) -------------------------------------------------------------------------------- /finetune/lora/v6/requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==1.9.5 2 | bitsandbytes 3 | deepspeed 4 | einops 5 | triton==2.2.0 -------------------------------------------------------------------------------- /finetune/lora/v6/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/finetune/lora/v6/src/__init__.py -------------------------------------------------------------------------------- /finetune/lora/v6/src/infctx_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | ######state 3 | class TimeMixState: 4 | def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor): 5 | self.shift_state = shift_state 6 | self.wkv_state = wkv_state 7 | 8 | 9 | class ChannelMixState: 10 | def __init__(self, shift_state: 
torch.Tensor): 11 | self.shift_state = shift_state 12 | 13 | 14 | class BlockState: 15 | def __init__(self, time_mix_state: TimeMixState, 16 | channel_mix_state: ChannelMixState): 17 | self.time_mix_state = time_mix_state 18 | self.channel_mix_state = channel_mix_state 19 | 20 | class BlockStateList: 21 | 22 | def __init__(self, shift_states, wkv_states): 23 | self.wkv_states = wkv_states 24 | self.shift_states = shift_states 25 | 26 | @staticmethod 27 | def create(N, B, C, H, device, dtype): 28 | result = BlockStateList.empty(N, B, C, H, device, dtype) 29 | result.wkv_states[:] = 0 30 | result.wkv_states[:] = 0 31 | result.shift_states[:] = 0 32 | return result 33 | 34 | @staticmethod 35 | def empty(N, B, C, H, device, dtype): 36 | wkv_states = torch.empty((N, B, H, C//H, C//H), 37 | device=device, 38 | dtype=torch.bfloat16) 39 | shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype) 40 | return BlockStateList(shift_states, wkv_states) 41 | 42 | def __getitem__(self, layer: int): 43 | return BlockState( 44 | TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]), 45 | ChannelMixState(self.shift_states[layer, 1])) 46 | 47 | def __setitem__(self, layer: int, state: BlockState): 48 | self.shift_states[layer, 0] = state.time_mix_state.shift_state 49 | self.wkv_states[layer] = state.time_mix_state.wkv_state 50 | self.shift_states[layer, 1] = state.channel_mix_state.shift_state 51 | 52 | 53 | -------------------------------------------------------------------------------- /finetune/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.2.0 2 | pytorch_lightning==1.9.5 3 | deepspeed==0.12.6 4 | bitsandbytes==0.43.1 5 | einops==0.8.0 6 | triton==2.2.0 7 | transformers==4.41.1 8 | numpy==1.26.4 -------------------------------------------------------------------------------- /frontend/index.html: -------------------------------------------------------------------------------- 1 | <!DOCTYPE html> 2 | <html lang="en"> 3 | <head> 4 | <meta charset="UTF-8" /> 5 | <meta content="width=device-width, initial-scale=1.0" name="viewport" /> 6 | <title>RWKV-Runner</title> 7 | <link href="./src/assets/images/logo.png" rel="icon" type="image/x-icon"> 8 | </head> 9 | <body> 10 | <div id="root"></div> 11 | <script src="./src/main.tsx" type="module"></script> 12 | </body> 13 | </html> 14 | 15 | -------------------------------------------------------------------------------- /frontend/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "frontend", 3 | "private": true, 4 | "version": "0.0.0", 5 | "scripts": { 6 | "dev": "vite --host", 7 | "build": "tsc && vite build", 8 | "preview": "vite preview --host" 9 | }, 10 | "dependencies": { 11 | "@fluentui/react-components": "^9.47.2", 12 | "@fluentui/react-icons": "^2.0.201", 13 | "@magenta/music": "^1.23.1", 14 | "@microsoft/fetch-event-source": "^2.0.1", 15 | "@primer/octicons-react": "^19.1.0", 16 | "@tabler/icons-react": "^3.31.0", 17 | "abcjs": "^6.2.3", 18 | "chart.js": "^4.3.0", 19 | "classnames": "^2.3.2", 20 | "compare-versions": "^6.1.1", 21 | "delay": "^6.0.0", 22 | "file-saver": "^2.0.5", 23 | "html-midi-player": "^1.5.0", 24 | "html-to-image": "^1.11.13", 25 | "i18next": "^22.4.15", 26 | "katex": "^0.16.9", 27 | "lodash-es": "^4.17.21", 28 | "mobx": "^6.9.0", 29 | "mobx-react-lite": "^3.4.3", 30 | "pdfjs-dist": "^4.0.189", 31 | "react": "^18.2.0", 32 | "react-beautiful-dnd": "^13.1.1", 33 | "react-chartjs-2": "^5.2.0", 
34 | "react-dom": "^18.2.0", 35 | "react-draggable": "^4.4.6", 36 | "react-i18next": "^12.2.2", 37 | "react-markdown": "^8.0.7", 38 | "react-router": "^6.11.1", 39 | "react-router-dom": "^6.11.1", 40 | "react-toastify": "^9.1.3", 41 | "rehype-highlight": "^6.0.0", 42 | "rehype-katex": "^6.0.3", 43 | "rehype-raw": "^6.1.1", 44 | "remark-breaks": "^3.0.3", 45 | "remark-gfm": "^3.0.1", 46 | "remark-math": "^5.1.1", 47 | "usehooks-ts": "^2.9.1", 48 | "uuid": "^9.0.0" 49 | }, 50 | "devDependencies": { 51 | "@ianvs/prettier-plugin-sort-imports": "^4.3.1", 52 | "@tailwindcss/typography": "^0.5.10", 53 | "@types/file-saver": "^2.0.7", 54 | "@types/lodash-es": "^4.17.12", 55 | "@types/react": "^18.2.6", 56 | "@types/react-beautiful-dnd": "^13.1.4", 57 | "@types/react-dom": "^18.2.4", 58 | "@types/uuid": "^9.0.1", 59 | "@vitejs/plugin-react": "^4.0.0", 60 | "autoprefixer": "^10.4.14", 61 | "postcss": "^8.4.23", 62 | "prettier": "^3.3.3", 63 | "prettier-plugin-tailwindcss": "^0.6.6", 64 | "rollup-plugin-visualizer": "^5.9.0", 65 | "sass": "^1.62.1", 66 | "tailwindcss": "^3.3.2", 67 | "typescript": "^5.0.4", 68 | "vite": "^4.5.9", 69 | "vite-plugin-top-level-await": "^1.3.1" 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /frontend/postcss.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | } 7 | -------------------------------------------------------------------------------- /frontend/prettier.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('prettier').Config} */ 2 | module.exports = { 3 | endOfLine: 'lf', 4 | semi: false, 5 | singleQuote: true, 6 | tabWidth: 2, 7 | trailingComma: 'es5', 8 | importOrder: ['^(react/(.*)$)|^(react$)', '<THIRD_PARTY_MODULES>', '^[./]'], 9 | importOrderSeparation: false, 10 | importOrderSortSpecifiers: true, 11 | importOrderBuiltinModulesToTop: true, 12 | importOrderParserPlugins: ['typescript', 'jsx', 'decorators-legacy'], 13 | importOrderMergeDuplicateImports: true, 14 | importOrderCombineTypeAndValueImports: true, 15 | plugins: [ 16 | '@ianvs/prettier-plugin-sort-imports', 17 | 'prettier-plugin-tailwindcss', 18 | ], 19 | } 20 | -------------------------------------------------------------------------------- /frontend/src/_locales/i18n-react.ts: -------------------------------------------------------------------------------- 1 | import i18n, { changeLanguage } from 'i18next' 2 | import { initReactI18next } from 'react-i18next' 3 | import { getUserLanguage } from '../utils' 4 | import { resources } from './resources' 5 | 6 | i18n 7 | .use(initReactI18next) 8 | .init({ 9 | resources, 10 | interpolation: { 11 | escapeValue: false, // not needed for react as it escapes by default 12 | }, 13 | }) 14 | .then(() => { 15 | changeLanguage(getUserLanguage()) 16 | }) 17 | -------------------------------------------------------------------------------- /frontend/src/_locales/i18n.ts: -------------------------------------------------------------------------------- 1 | import i18n, { changeLanguage } from 'i18next' 2 | import { getUserLanguage } from '../utils' 3 | import { resources } from './resources' 4 | 5 | i18n 6 | .init({ 7 | resources, 8 | }) 9 | .then(() => { 10 | changeLanguage(getUserLanguage()) 11 | }) 12 | -------------------------------------------------------------------------------- /frontend/src/_locales/resources.ts: 
-------------------------------------------------------------------------------- 1 | import ja from './ja/main.json' 2 | import zhHans from './zh-hans/main.json' 3 | 4 | export const resources = { 5 | zh: { 6 | translation: zhHans, 7 | }, 8 | // de: { 9 | // translation: de, 10 | // }, 11 | // es: { 12 | // translation: es, 13 | // }, 14 | // fr: { 15 | // translation: fr, 16 | // }, 17 | // in: { 18 | // translation: inTrans, 19 | // }, 20 | // it: { 21 | // translation: it, 22 | // }, 23 | ja: { 24 | translation: ja, 25 | }, 26 | // ko: { 27 | // translation: ko, 28 | // }, 29 | // pt: { 30 | // translation: pt, 31 | // }, 32 | // ru: { 33 | // translation: ru, 34 | // }, 35 | // zhHant: { 36 | // translation: zhHant, 37 | // }, 38 | } 39 | -------------------------------------------------------------------------------- /frontend/src/assets/images/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/frontend/src/assets/images/banner.jpg -------------------------------------------------------------------------------- /frontend/src/assets/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/frontend/src/assets/images/logo.png -------------------------------------------------------------------------------- /frontend/src/assets/images/strategy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/frontend/src/assets/images/strategy.jpg -------------------------------------------------------------------------------- /frontend/src/assets/images/strategy_zh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/frontend/src/assets/images/strategy_zh.jpg -------------------------------------------------------------------------------- /frontend/src/components/ConfigSelector.tsx: -------------------------------------------------------------------------------- 1 | import { FC } from 'react' 2 | import { Dropdown, Option, PresenceBadge } from '@fluentui/react-components' 3 | import { observer } from 'mobx-react-lite' 4 | import commonStore from '../stores/commonStore' 5 | 6 | export const ConfigSelector: FC<{ size?: 'small' | 'medium' | 'large' }> = 7 | observer(({ size }) => { 8 | return ( 9 | <Dropdown 10 | size={size} 11 | style={{ minWidth: 0 }} 12 | listbox={{ style: { minWidth: 'fit-content' } }} 13 | value={commonStore.getCurrentModelConfig().name} 14 | selectedOptions={[commonStore.currentModelConfigIndex.toString()]} 15 | onOptionSelect={(_, data) => { 16 | if (data.optionValue) 17 | commonStore.setCurrentConfigIndex(Number(data.optionValue)) 18 | }} 19 | > 20 | {commonStore.modelConfigs.map((config, index) => ( 21 | <Option key={index} value={index.toString()} text={config.name}> 22 | <div className="flex grow justify-between"> 23 | {config.name} 24 | {commonStore.modelSourceList.find( 25 | (item) => item.name === config.modelParameters.modelName 26 | )?.isComplete && <PresenceBadge status="available" />} 27 | </div> 28 | </Option> 29 | ))} 30 | </Dropdown> 31 | ) 32 | }) 33 | -------------------------------------------------------------------------------- 
/frontend/src/components/CopyButton.tsx: -------------------------------------------------------------------------------- 1 | import { FC, useState } from 'react' 2 | import { CheckIcon, CopyIcon } from '@primer/octicons-react' 3 | import { useTranslation } from 'react-i18next' 4 | import { ClipboardSetText } from '../../wailsjs/runtime' 5 | import { ToolTipButton } from './ToolTipButton' 6 | 7 | export const CopyButton: FC<{ content: string; showDelay?: number }> = ({ 8 | content, 9 | showDelay = 0, 10 | }) => { 11 | const { t } = useTranslation() 12 | const [copied, setCopied] = useState(false) 13 | 14 | const onClick = () => { 15 | ClipboardSetText(content) 16 | .then(() => setCopied(true)) 17 | .then(() => 18 | setTimeout(() => { 19 | setCopied(false) 20 | }, 600) 21 | ) 22 | } 23 | 24 | return ( 25 | <ToolTipButton 26 | desc={t('Copy')} 27 | size="small" 28 | appearance="subtle" 29 | showDelay={showDelay} 30 | icon={copied ? <CheckIcon /> : <CopyIcon />} 31 | onClick={onClick} 32 | /> 33 | ) 34 | } 35 | -------------------------------------------------------------------------------- /frontend/src/components/CustomToastContainer.tsx: -------------------------------------------------------------------------------- 1 | import { ToastContainer } from 'react-toastify' 2 | import commonStore from '../stores/commonStore' 3 | 4 | export const CustomToastContainer = () => ( 5 | <ToastContainer 6 | style={{ width: '350px' }} 7 | position="top-center" 8 | autoClose={4000} 9 | pauseOnHover={true} 10 | hideProgressBar={true} 11 | newestOnTop={true} 12 | closeOnClick={false} 13 | rtl={false} 14 | pauseOnFocusLoss={false} 15 | draggable={false} 16 | theme={commonStore.settings.darkMode ? 'dark' : 'light'} 17 | /> 18 | ) 19 | -------------------------------------------------------------------------------- /frontend/src/components/DebugModeIndicator.tsx: -------------------------------------------------------------------------------- 1 | import { FC } from 'react' 2 | import classNames from 'classnames' 3 | 4 | export const DebugModeIndicator: FC<{ showInDebugMode?: boolean }> = ({ 5 | showInDebugMode = true, 6 | }) => { 7 | if (import.meta.env.PROD) return <></> 8 | if (!showInDebugMode) return <></> 9 | return ( 10 | <div 11 | className={classNames( 12 | 'absolute', 13 | 'right-1', 14 | 'top-1', 15 | 'p-1', 16 | 'rounded', 17 | 'bg-red-600', 18 | 'bg-opacity-50', 19 | 'text-white', 20 | 'font-semibold', 21 | 'text-opacity-50', 22 | 'text-xs' 23 | )} 24 | > 25 | Debug Mode 26 | </div> 27 | ) 28 | } 29 | -------------------------------------------------------------------------------- /frontend/src/components/Labeled.tsx: -------------------------------------------------------------------------------- 1 | import { FC, ReactElement } from 'react' 2 | import { Label, Tooltip } from '@fluentui/react-components' 3 | import classnames from 'classnames' 4 | 5 | export const Labeled: FC<{ 6 | label: string 7 | desc?: string | null 8 | descComponent?: ReactElement 9 | content: ReactElement 10 | flex?: boolean 11 | spaceBetween?: boolean 12 | breakline?: boolean 13 | onMouseEnter?: () => void 14 | onMouseLeave?: () => void 15 | }> = ({ 16 | label, 17 | desc, 18 | descComponent, 19 | content, 20 | flex, 21 | spaceBetween, 22 | breakline, 23 | onMouseEnter, 24 | onMouseLeave, 25 | }) => { 26 | return ( 27 | <div 28 | className={classnames( 29 | !breakline ? 'items-center' : '', 30 | flex ? 'flex' : 'grid grid-cols-2', 31 | breakline ? 
'flex-col' : '', 32 | spaceBetween && 'justify-between' 33 | )} 34 | > 35 | {desc || descComponent ? ( 36 | <Tooltip 37 | content={descComponent ? descComponent : desc!} 38 | showDelay={0} 39 | hideDelay={0} 40 | relationship="description" 41 | > 42 | <Label onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}> 43 | {label} 44 | </Label> 45 | </Tooltip> 46 | ) : ( 47 | <Label onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}> 48 | {label} 49 | </Label> 50 | )} 51 | {content} 52 | </div> 53 | ) 54 | } 55 | -------------------------------------------------------------------------------- /frontend/src/components/LazyImportComponent.tsx: -------------------------------------------------------------------------------- 1 | import { FC, LazyExoticComponent, ReactNode, Suspense } from 'react' 2 | import { Spinner } from '@fluentui/react-components' 3 | import { useTranslation } from 'react-i18next' 4 | 5 | interface LazyImportComponentProps { 6 | lazyChildren: LazyExoticComponent<FC<any>> 7 | lazyProps?: any 8 | children?: ReactNode 9 | disableFallback?: boolean 10 | } 11 | 12 | export const LazyImportComponent: FC<LazyImportComponentProps> = (props) => { 13 | const { t } = useTranslation() 14 | 15 | return ( 16 | <Suspense 17 | fallback={ 18 | !props.disableFallback && ( 19 | <div className="flex h-full w-full items-center justify-center"> 20 | <Spinner size="huge" label={t('Loading...')} /> 21 | </div> 22 | ) 23 | } 24 | > 25 | <props.lazyChildren {...props.lazyProps}> 26 | {props.children} 27 | </props.lazyChildren> 28 | </Suspense> 29 | ) 30 | } 31 | -------------------------------------------------------------------------------- /frontend/src/components/NumberInput.tsx: -------------------------------------------------------------------------------- 1 | import React, { CSSProperties, FC } from 'react' 2 | import { Input } from '@fluentui/react-components' 3 | import { SliderOnChangeData } from '@fluentui/react-slider' 4 | 5 | export const NumberInput: FC<{ 6 | value: number 7 | min: number 8 | max: number 9 | step?: number 10 | onChange?: ( 11 | ev: React.ChangeEvent<HTMLInputElement>, 12 | data: SliderOnChangeData 13 | ) => void 14 | style?: CSSProperties 15 | toFixed?: number 16 | disabled?: boolean 17 | }> = ({ value, min, max, step, onChange, style, toFixed = 2, disabled }) => { 18 | return ( 19 | <Input 20 | type="number" 21 | style={style} 22 | value={value.toString()} 23 | min={min} 24 | max={max} 25 | step={step} 26 | disabled={disabled} 27 | onChange={(e, data) => { 28 | onChange?.(e, { value: Number(data.value) }) 29 | }} 30 | onBlur={(e) => { 31 | if (onChange) { 32 | if (step) { 33 | const offset = (min > 0 ? min : 0) - (max < 0 ? 
max : 0) 34 | value = Number( 35 | (Math.round((value - offset) / step) * step + offset).toFixed( 36 | toFixed 37 | ) 38 | ) // avoid precision issues 39 | } 40 | onChange(e, { value: Math.max(Math.min(value, max), min) }) 41 | } 42 | }} 43 | /> 44 | ) 45 | } 46 | -------------------------------------------------------------------------------- /frontend/src/components/Page.tsx: -------------------------------------------------------------------------------- 1 | import React, { FC, ReactElement } from 'react' 2 | import { Divider, Text } from '@fluentui/react-components' 3 | 4 | export const Page: FC<{ title: string; content: ReactElement }> = ({ 5 | title, 6 | content, 7 | }) => { 8 | return ( 9 | <div className="flex h-full flex-col gap-2 p-2"> 10 | <Text size={600}>{title}</Text> 11 | <Divider style={{ flexGrow: 0 }} /> 12 | {content} 13 | </div> 14 | ) 15 | } 16 | -------------------------------------------------------------------------------- /frontend/src/components/ReadButton.tsx: -------------------------------------------------------------------------------- 1 | import { FC, useState } from 'react' 2 | import { MuteIcon, UnmuteIcon } from '@primer/octicons-react' 3 | import { observer } from 'mobx-react-lite' 4 | import { useTranslation } from 'react-i18next' 5 | import commonStore from '../stores/commonStore' 6 | import { ToolTipButton } from './ToolTipButton' 7 | 8 | const synth = window.speechSynthesis 9 | 10 | export const ReadButton: FC<{ 11 | content: string 12 | inSpeaking?: boolean 13 | showDelay?: number 14 | setSpeakingOuter?: (speaking: boolean) => void 15 | }> = observer( 16 | ({ content, inSpeaking = false, showDelay = 0, setSpeakingOuter }) => { 17 | const { t } = useTranslation() 18 | const [speaking, setSpeaking] = useState(inSpeaking) 19 | let lang: string = commonStore.settings.language 20 | if (lang === 'dev') lang = 'en' 21 | 22 | const setSpeakingInner = (speaking: boolean) => { 23 | setSpeakingOuter?.(speaking) 24 | setSpeaking(speaking) 25 | } 26 | 27 | const startSpeak = () => { 28 | synth.cancel() 29 | 30 | const utterance = new SpeechSynthesisUtterance(content) 31 | const voices = synth.getVoices() 32 | 33 | let voice 34 | if (lang === 'en') 35 | voice = voices.find((v) => 36 | v.name.toLowerCase().includes('microsoft aria') 37 | ) 38 | else if (lang === 'zh') 39 | voice = voices.find((v) => v.name.toLowerCase().includes('xiaoyi')) 40 | else if (lang === 'ja') 41 | voice = voices.find((v) => v.name.toLowerCase().includes('nanami')) 42 | if (!voice) voice = voices.find((v) => v.lang.substring(0, 2) === lang) 43 | if (!voice) voice = voices.find((v) => v.lang === navigator.language) 44 | 45 | Object.assign(utterance, { 46 | rate: 1, 47 | volume: 1, 48 | onend: () => setSpeakingInner(false), 49 | onerror: () => setSpeakingInner(false), 50 | voice: voice, 51 | }) 52 | 53 | synth.speak(utterance) 54 | setSpeakingInner(true) 55 | } 56 | 57 | const stopSpeak = () => { 58 | synth.cancel() 59 | setSpeakingInner(false) 60 | } 61 | 62 | return ( 63 | <ToolTipButton 64 | desc={t('Read Aloud')} 65 | size="small" 66 | appearance="subtle" 67 | showDelay={showDelay} 68 | icon={speaking ? <MuteIcon /> : <UnmuteIcon />} 69 | onClick={speaking ? 
stopSpeak : startSpeak} 70 | /> 71 | ) 72 | } 73 | ) 74 | -------------------------------------------------------------------------------- /frontend/src/components/ResetConfigsButton.tsx: -------------------------------------------------------------------------------- 1 | import React, { FC } from 'react' 2 | import { ArrowReset20Regular } from '@fluentui/react-icons' 3 | import { useTranslation } from 'react-i18next' 4 | import { 5 | defaultModelConfigs, 6 | defaultModelConfigsMac, 7 | } from '../pages/defaultConfigs' 8 | import commonStore from '../stores/commonStore' 9 | import { DialogButton } from './DialogButton' 10 | 11 | export const ResetConfigsButton: FC<{ afterConfirm?: () => void }> = ({ 12 | afterConfirm, 13 | }) => { 14 | const { t } = useTranslation() 15 | return ( 16 | <DialogButton 17 | icon={<ArrowReset20Regular />} 18 | tooltip={t('Reset All Configs')} 19 | title={t('Reset All Configs')} 20 | content={t( 21 | 'Are you sure you want to reset all configs? This will obtain the latest preset configs, but will override your custom configs and cannot be undone.' 22 | )} 23 | onConfirm={() => { 24 | commonStore.setModelConfigs( 25 | commonStore.platform !== 'darwin' 26 | ? defaultModelConfigs 27 | : defaultModelConfigsMac, 28 | false 29 | ) 30 | commonStore.setCurrentConfigIndex(0, true) 31 | afterConfirm?.() 32 | }} 33 | /> 34 | ) 35 | } 36 | -------------------------------------------------------------------------------- /frontend/src/components/Section.tsx: -------------------------------------------------------------------------------- 1 | import { FC, ReactElement } from 'react' 2 | import { Card, Text } from '@fluentui/react-components' 3 | 4 | export const Section: FC<{ 5 | title: string 6 | desc?: string | null 7 | content: ReactElement 8 | outline?: boolean 9 | }> = ({ title, desc, content, outline = true }) => { 10 | return ( 11 | <Card size="small" appearance={outline ? 'outline' : 'subtle'}> 12 | <div className="flex flex-col gap-5"> 13 | <div className="flex flex-col gap-1"> 14 | <Text weight="medium">{title}</Text> 15 | {desc && <Text size={100}>{desc}</Text>} 16 | </div> 17 | </div> 18 | <div className="overflow-y-auto overflow-x-hidden p-1">{content}</div> 19 | </Card> 20 | ) 21 | } 22 | -------------------------------------------------------------------------------- /frontend/src/components/ToolTipButton.tsx: -------------------------------------------------------------------------------- 1 | import React, { 2 | CSSProperties, 3 | FC, 4 | MouseEventHandler, 5 | ReactElement, 6 | } from 'react' 7 | import { Button, Tooltip } from '@fluentui/react-components' 8 | 9 | export const ToolTipButton: FC<{ 10 | text?: string | null 11 | desc: string 12 | icon?: ReactElement 13 | className?: string 14 | style?: CSSProperties 15 | size?: 'small' | 'medium' | 'large' 16 | shape?: 'rounded' | 'circular' | 'square' 17 | appearance?: 'secondary' | 'primary' | 'outline' | 'subtle' | 'transparent' 18 | disabled?: boolean 19 | onClick?: MouseEventHandler 20 | showDelay?: number 21 | }> = ({ 22 | text, 23 | desc, 24 | icon, 25 | className, 26 | style, 27 | size, 28 | shape, 29 | appearance, 30 | disabled, 31 | onClick, 32 | showDelay = 0, 33 | }) => { 34 | return desc ? 
( 35 | <Tooltip 36 | content={desc} 37 | showDelay={showDelay} 38 | hideDelay={0} 39 | relationship="label" 40 | > 41 | <Button 42 | style={style} 43 | className={className} 44 | disabled={disabled} 45 | icon={icon} 46 | onClick={onClick} 47 | size={size} 48 | shape={shape} 49 | appearance={appearance} 50 | > 51 | {text} 52 | </Button> 53 | </Tooltip> 54 | ) : ( 55 | <Button 56 | style={style} 57 | className={className} 58 | disabled={disabled} 59 | icon={icon} 60 | onClick={onClick} 61 | size={size} 62 | shape={shape} 63 | appearance={appearance} 64 | > 65 | {text} 66 | </Button> 67 | ) 68 | } 69 | -------------------------------------------------------------------------------- /frontend/src/components/ValuedSlider.tsx: -------------------------------------------------------------------------------- 1 | import React, { FC, useEffect, useRef } from 'react' 2 | import { Slider, Text } from '@fluentui/react-components' 3 | import { SliderOnChangeData } from '@fluentui/react-slider' 4 | import { NumberInput } from './NumberInput' 5 | 6 | export const ValuedSlider: FC<{ 7 | value: number 8 | min: number 9 | max: number 10 | step?: number 11 | input?: boolean 12 | onChange?: ( 13 | ev: React.ChangeEvent<HTMLInputElement>, 14 | data: SliderOnChangeData 15 | ) => void 16 | toFixed?: number 17 | disabled?: boolean 18 | }> = ({ value, min, max, step, input, onChange, toFixed, disabled }) => { 19 | const sliderRef = useRef<HTMLInputElement>(null) 20 | useEffect(() => { 21 | if (step && sliderRef.current && sliderRef.current.parentElement) { 22 | if ((max - min) / step > 10) 23 | sliderRef.current.parentElement.style.removeProperty( 24 | '--fui-Slider--steps-percent' 25 | ) 26 | } 27 | }, []) 28 | 29 | return ( 30 | <div className="flex items-center"> 31 | <Slider 32 | ref={sliderRef} 33 | className="grow" 34 | style={{ minWidth: '50%' }} 35 | value={value} 36 | min={min} 37 | max={max} 38 | step={step} 39 | onChange={onChange} 40 | disabled={disabled} 41 | /> 42 | {input ? 
( 43 | <NumberInput 44 | style={{ minWidth: 0 }} 45 | value={value} 46 | min={min} 47 | max={max} 48 | step={step} 49 | onChange={onChange} 50 | toFixed={toFixed} 51 | disabled={disabled} 52 | /> 53 | ) : ( 54 | <Text>{value}</Text> 55 | )} 56 | </div> 57 | ) 58 | } 59 | -------------------------------------------------------------------------------- /frontend/src/components/WorkHeader.tsx: -------------------------------------------------------------------------------- 1 | import React, { FC } from 'react' 2 | import { PresenceBadgeStatus } from '@fluentui/react-badge' 3 | import { Divider, PresenceBadge, Text } from '@fluentui/react-components' 4 | import { observer } from 'mobx-react-lite' 5 | import { useTranslation } from 'react-i18next' 6 | import { useMediaQuery } from 'usehooks-ts' 7 | import commonStore, { ModelStatus } from '../stores/commonStore' 8 | import { ConfigSelector } from './ConfigSelector' 9 | import { RunButton } from './RunButton' 10 | 11 | const statusText = { 12 | [ModelStatus.Offline]: 'Offline', 13 | [ModelStatus.Starting]: 'Starting', 14 | [ModelStatus.Loading]: 'Loading', 15 | [ModelStatus.Working]: 'Working', 16 | } 17 | 18 | const badgeStatus: { [modelStatus: number]: PresenceBadgeStatus } = { 19 | [ModelStatus.Offline]: 'unknown', 20 | [ModelStatus.Starting]: 'away', 21 | [ModelStatus.Loading]: 'away', 22 | [ModelStatus.Working]: 'available', 23 | } 24 | 25 | export const WorkHeader: FC = observer(() => { 26 | const { t } = useTranslation() 27 | const mq = useMediaQuery('(min-width: 640px)') 28 | const port = commonStore.getCurrentModelConfig().apiParameters.apiPort 29 | 30 | return commonStore.platform === 'web' ? ( 31 | <div /> 32 | ) : ( 33 | <div className="flex flex-col gap-1"> 34 | <div className="flex items-center justify-between"> 35 | <div className="flex items-center gap-2"> 36 | <PresenceBadge status={badgeStatus[commonStore.status.status]} /> 37 | <Text size={100}> 38 | {t('Model Status') + 39 | ': ' + 40 | t(statusText[commonStore.status.status])} 41 | </Text> 42 | </div> 43 | {commonStore.lastModelName && mq && ( 44 | <Text size={100}>{commonStore.lastModelName}</Text> 45 | )} 46 | <div className="flex items-center gap-2"> 47 | <ConfigSelector size="small" /> 48 | <RunButton iconMode /> 49 | </div> 50 | </div> 51 | <Text size={100}> 52 | {t( 53 | "This tool's API is compatible with OpenAI API. It can be used with any ChatGPT tool you like. Go to the settings of some ChatGPT tool, replace the 'https://api.openai.com' part in the API address with '" 54 | ) + 55 | `http://127.0.0.1:${port}` + 56 | "'."} 57 | </Text> 58 | <Divider style={{ flexGrow: 0 }} /> 59 | </div> 60 | ) 61 | }) 62 | -------------------------------------------------------------------------------- /frontend/src/main.tsx: -------------------------------------------------------------------------------- 1 | import './webWails' 2 | import React from 'react' 3 | import { createRoot } from 'react-dom/client' 4 | import './style.scss' 5 | import 'react-toastify/dist/ReactToastify.css' 6 | import { HashRouter } from 'react-router-dom' 7 | import App from './App' 8 | import { startup } from './startup' 9 | import './_locales/i18n-react' 10 | import { WindowShow } from '../wailsjs/runtime' 11 | 12 | startup().then(() => { 13 | const container = document.getElementById('root') 14 | 15 | const root = createRoot(container!) 
16 | 17 | root.render( 18 | <HashRouter> 19 | <App /> 20 | </HashRouter> 21 | ) 22 | 23 | // force display the window 24 | WindowShow() 25 | }) 26 | -------------------------------------------------------------------------------- /frontend/src/pages/About.tsx: -------------------------------------------------------------------------------- 1 | import React, { FC } from 'react' 2 | import { observer } from 'mobx-react-lite' 3 | import { useTranslation } from 'react-i18next' 4 | import MarkdownRender from '../components/MarkdownRender' 5 | import { Page } from '../components/Page' 6 | import commonStore from '../stores/commonStore' 7 | 8 | const About: FC = observer(() => { 9 | const { t } = useTranslation() 10 | const lang: string = commonStore.settings.language 11 | 12 | return ( 13 | <Page 14 | title={t('About')} 15 | content={ 16 | <div className="overflow-y-auto overflow-x-hidden p-1"> 17 | <MarkdownRender> 18 | {lang in commonStore.about 19 | ? commonStore.about[lang] 20 | : commonStore.about['en']} 21 | </MarkdownRender> 22 | </div> 23 | } 24 | /> 25 | ) 26 | }) 27 | 28 | export default About 29 | -------------------------------------------------------------------------------- /frontend/src/pages/AudiotrackManager/AudiotrackButton.tsx: -------------------------------------------------------------------------------- 1 | import React, { FC, lazy } from 'react' 2 | import { 3 | Button, 4 | Dialog, 5 | DialogBody, 6 | DialogContent, 7 | DialogSurface, 8 | DialogTrigger, 9 | } from '@fluentui/react-components' 10 | import { useTranslation } from 'react-i18next' 11 | import { CustomToastContainer } from '../../components/CustomToastContainer' 12 | import { LazyImportComponent } from '../../components/LazyImportComponent' 13 | import commonStore from '../../stores/commonStore' 14 | import { flushMidiRecordingContent } from '../../utils' 15 | 16 | const AudiotrackEditor = lazy(() => import('./AudiotrackEditor')) 17 | 18 | export const AudiotrackButton: FC<{ 19 | size?: 'small' | 'medium' | 'large' 20 | shape?: 'rounded' | 'circular' | 'square' 21 | appearance?: 'secondary' | 'primary' | 'outline' | 'subtle' | 'transparent' 22 | setPrompt: (prompt: string) => void 23 | }> = ({ size, shape, appearance, setPrompt }) => { 24 | const { t } = useTranslation() 25 | 26 | return ( 27 | <Dialog 28 | onOpenChange={(e, data) => { 29 | if (!data.open) { 30 | flushMidiRecordingContent() 31 | commonStore.setRecordingTrackId('') 32 | commonStore.setPlayingTrackId('') 33 | } 34 | }} 35 | > 36 | <DialogTrigger disableButtonEnhancement> 37 | <Button size={size} shape={shape} appearance={appearance}> 38 | {t('Open MIDI Input Audio Tracks')} 39 | </Button> 40 | </DialogTrigger> 41 | <DialogSurface 42 | style={{ 43 | paddingTop: 0, 44 | maxWidth: '90vw', 45 | width: 'fit-content', 46 | transform: 'unset', 47 | }} 48 | > 49 | <DialogBody> 50 | <DialogContent className="overflow-hidden"> 51 | <CustomToastContainer /> 52 | <LazyImportComponent 53 | lazyChildren={AudiotrackEditor} 54 | lazyProps={{ setPrompt }} 55 | /> 56 | </DialogContent> 57 | </DialogBody> 58 | </DialogSurface> 59 | </Dialog> 60 | ) 61 | } 62 | -------------------------------------------------------------------------------- /frontend/src/pages/index.tsx: -------------------------------------------------------------------------------- 1 | import { FC, lazy, LazyExoticComponent, ReactElement } from 'react' 2 | import { 3 | ArrowDownload20Regular, 4 | Chat20Regular, 5 | ClipboardEdit20Regular, 6 | DataUsageSettings20Regular, 7 | 
DocumentSettings20Regular, 8 | Home20Regular, 9 | Info20Regular, 10 | MusicNote220Regular, 11 | Settings20Regular, 12 | Storage20Regular, 13 | } from '@fluentui/react-icons' 14 | 15 | type NavigationItem = { 16 | label: string 17 | path: string 18 | icon: ReactElement 19 | element: LazyExoticComponent<FC> 20 | top: boolean 21 | } 22 | 23 | export const pages: NavigationItem[] = [ 24 | { 25 | label: 'Home', 26 | path: '/', 27 | icon: <Home20Regular />, 28 | element: lazy(() => import('./Home')), 29 | top: true, 30 | }, 31 | { 32 | label: 'Chat', 33 | path: '/chat', 34 | icon: <Chat20Regular />, 35 | element: lazy(() => import('./Chat')), 36 | top: true, 37 | }, 38 | { 39 | label: 'Completion', 40 | path: '/completion', 41 | icon: <ClipboardEdit20Regular />, 42 | element: lazy(() => import('./Completion')), 43 | top: true, 44 | }, 45 | { 46 | label: 'Composition', 47 | path: '/composition', 48 | icon: <MusicNote220Regular />, 49 | element: lazy(() => import('./Composition')), 50 | top: true, 51 | }, 52 | { 53 | label: 'Configs', 54 | path: '/configs', 55 | icon: <DocumentSettings20Regular />, 56 | element: lazy(() => import('./Configs')), 57 | top: true, 58 | }, 59 | { 60 | label: 'Models', 61 | path: '/models', 62 | icon: <DataUsageSettings20Regular />, 63 | element: lazy(() => import('./Models')), 64 | top: true, 65 | }, 66 | { 67 | label: 'Downloads', 68 | path: '/downloads', 69 | icon: <ArrowDownload20Regular />, 70 | element: lazy(() => import('./Downloads')), 71 | top: true, 72 | }, 73 | { 74 | label: 'Train', 75 | path: '/train', 76 | icon: <Storage20Regular />, 77 | element: lazy(() => import('./Train')), 78 | top: true, 79 | }, 80 | { 81 | label: 'Settings', 82 | path: '/settings', 83 | icon: <Settings20Regular />, 84 | element: lazy(() => import('./Settings')), 85 | top: false, 86 | }, 87 | { 88 | label: 'About', 89 | path: '/about', 90 | icon: <Info20Regular />, 91 | element: lazy(() => import('./About')), 92 | top: false, 93 | }, 94 | ] 95 | -------------------------------------------------------------------------------- /frontend/src/style.scss: -------------------------------------------------------------------------------- 1 | [data-theme='dark'] { 2 | @import 'highlight.js/scss/github-dark.scss'; 3 | --color-neutral-muted: rgba(110, 118, 129, 0.4); 4 | } 5 | 6 | [data-theme='light'] { 7 | @import 'highlight.js/scss/github.scss'; 8 | --color-neutral-muted: rgba(150, 160, 170, 0.3); 9 | } 10 | 11 | @tailwind base; 12 | @tailwind components; 13 | @tailwind utilities; 14 | 15 | body { 16 | margin: 0; 17 | width: 100%; 18 | height: 100%; 19 | } 20 | 21 | * { 22 | scrollbar-width: thin; 23 | } 24 | 25 | /* Works on Chrome, Edge, and Safari */ 26 | *::-webkit-scrollbar { 27 | width: 9px; 28 | height: 9px; 29 | } 30 | 31 | *::-webkit-scrollbar-thumb { 32 | background-color: rgba(155, 155, 155, 0.5); 33 | border-radius: 20px; 34 | border: transparent; 35 | } 36 | 37 | *::-webkit-scrollbar-track { 38 | background: transparent; 39 | } 40 | 41 | *::-webkit-scrollbar-corner { 42 | background: transparent; 43 | } 44 | 45 | .markdown-body { 46 | overflow-y: auto; 47 | overflow-x: hidden; 48 | 49 | pre { 50 | padding: 0; 51 | background: transparent; 52 | 53 | code { 54 | font-size: 14px; 55 | } 56 | } 57 | 58 | code { 59 | white-space: pre-wrap; 60 | word-break: break-word; 61 | border-radius: 8px; 62 | background-color: var(--color-neutral-muted); 63 | } 64 | 65 | details summary { 66 | cursor: pointer; 67 | } 68 | } 69 | 70 | midi-player { 71 | &::part(control-panel) { 72 | background: 
none; 73 | } 74 | } 75 | 76 | midi-visualizer { 77 | $instrument-colors: #007bff, #20c997, #dc3545, #6610f2, #ffc107, #e83e8c, 78 | #17a2b8, #fd7e14, #28a745; 79 | 80 | svg { 81 | @for $i from 0 to 200 { 82 | $color: nth($instrument-colors, ($i % length($instrument-colors)) + 1); 83 | rect.note[data-instrument='#{$i}'] { 84 | fill: $color; 85 | } 86 | } 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /frontend/src/types/about.ts: -------------------------------------------------------------------------------- 1 | export type AboutContent = { [lang: string]: string } 2 | -------------------------------------------------------------------------------- /frontend/src/types/chat.ts: -------------------------------------------------------------------------------- 1 | import { ApiParameters } from './configs' 2 | 3 | export const userName = 'M E' 4 | export const botName = 'A I' 5 | export const systemName = 'System' 6 | export const welcomeUuid = 'welcome' 7 | 8 | export enum MessageType { 9 | Normal, 10 | Error, 11 | } 12 | 13 | export type Side = 'left' | 'right' | 'center' 14 | export type Color = 'neutral' | 'brand' | 'colorful' 15 | export type MessageItem = { 16 | sender: string 17 | toolName?: string 18 | type: MessageType 19 | color: Color 20 | avatarImg?: string 21 | time: string 22 | content: string 23 | side: Side 24 | done: boolean 25 | thinking?: boolean 26 | thinkingEnded?: boolean 27 | } 28 | export type Conversation = { 29 | [uuid: string]: MessageItem 30 | } 31 | export type Role = 'assistant' | 'user' | 'system' | 'tool' 32 | export type ConversationMessage = { 33 | role: Role 34 | content: string 35 | prefix?: boolean 36 | tool_call_id?: string 37 | tool_calls?: Array<{ 38 | id: string 39 | type: 'function' 40 | function: { 41 | name: string 42 | arguments: string 43 | } 44 | }> 45 | } 46 | export type Attachment = { 47 | name: string 48 | size: number 49 | content: string 50 | } 51 | export type ChatParams = Omit<ApiParameters, 'apiPort'> & { 52 | historyN: number 53 | markdown: boolean 54 | functionCall: boolean 55 | toolDefinition: string 56 | toolReturn: string 57 | } 58 | -------------------------------------------------------------------------------- /frontend/src/types/completion.ts: -------------------------------------------------------------------------------- 1 | import { ApiParameters } from './configs' 2 | 3 | export type CompletionParams = Omit<ApiParameters, 'apiPort'> & { 4 | stop: string 5 | injectStart: string 6 | injectEnd: string 7 | } 8 | export type CompletionPreset = { 9 | name: string 10 | prompt: string 11 | params: CompletionParams 12 | } 13 | -------------------------------------------------------------------------------- /frontend/src/types/composition.ts: -------------------------------------------------------------------------------- 1 | import { NoteSequence } from '@magenta/music/esm/protobuf' 2 | 3 | export const tracksMinimalTotalTime = 5000 4 | 5 | export type CompositionParams = { 6 | prompt: string 7 | maxResponseToken: number 8 | temperature: number 9 | topP: number 10 | autoPlay: boolean 11 | useLocalSoundFont: boolean 12 | externalPlay: boolean 13 | midi: ArrayBuffer | null 14 | ns: NoteSequence | null 15 | generationStartTime: number 16 | playOnlyGeneratedContent: boolean 17 | } 18 | export type Track = { 19 | id: string 20 | mainInstrument: string 21 | content: string 22 | rawContent: MidiMessage[] 23 | offsetTime: number 24 | contentTime: number 25 | } 26 | export type MidiPort = { 
27 | name: string 28 | } 29 | 30 | export type MessageType = 'NoteOff' | 'NoteOn' | 'ElapsedTime' | 'ControlChange' 31 | 32 | export type MidiMessage = { 33 | messageType: MessageType 34 | channel: number 35 | note: number 36 | velocity: number 37 | control: number 38 | value: number 39 | instrument: InstrumentType 40 | } 41 | 42 | export enum InstrumentType { 43 | Piano, 44 | Percussion, 45 | Drum, 46 | Tuba, 47 | Marimba, 48 | Bass, 49 | Guitar, 50 | Violin, 51 | Trumpet, 52 | Sax, 53 | Flute, 54 | Lead, 55 | Pad, 56 | } 57 | 58 | export const InstrumentTypeNameMap = [ 59 | 'Piano', 60 | 'Percussion', 61 | 'Drum', 62 | 'Tuba', 63 | 'Marimba', 64 | 'Bass', 65 | 'Guitar', 66 | 'Violin', 67 | 'Trumpet', 68 | 'Sax', 69 | 'Flute', 70 | 'Lead', 71 | 'Pad', 72 | ] 73 | 74 | export const InstrumentTypeTokenMap = [ 75 | 'pi', 76 | 'p', 77 | 'd', 78 | 't', 79 | 'm', 80 | 'b', 81 | 'g', 82 | 'v', 83 | 'tr', 84 | 's', 85 | 'f', 86 | 'l', 87 | 'pa', 88 | ] 89 | -------------------------------------------------------------------------------- /frontend/src/types/configs.ts: -------------------------------------------------------------------------------- 1 | export type ApiParameters = { 2 | apiPort: number 3 | maxResponseToken: number 4 | temperature: number 5 | topP: number 6 | presencePenalty: number 7 | frequencyPenalty: number 8 | penaltyDecay?: number 9 | globalPenalty?: boolean 10 | stateModel?: string 11 | } 12 | export type Device = 13 | | 'CPU' 14 | | 'CPU (rwkv.cpp)' 15 | | 'CUDA' 16 | | 'CUDA-Beta' 17 | | 'WebGPU' 18 | | 'WebGPU (Python)' 19 | | 'MPS' 20 | | 'Custom' 21 | export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1' 22 | export type GGUFMode = 'CPU' | 'Vulkan GPU' 23 | export type ModelParameters = { 24 | // different models can not have the same name 25 | modelName: string 26 | device: Device 27 | precision: Precision 28 | storedLayers: number 29 | maxStoredLayers: number 30 | quantizedLayers?: number 31 | tokenChunkSize?: number 32 | useCustomCuda?: boolean 33 | customStrategy?: string 34 | useCustomTokenizer?: boolean 35 | customTokenizer?: string 36 | ggufMode?: GGUFMode 37 | llamaContext?: number 38 | } 39 | export type ModelConfig = { 40 | // different configs can have the same name 41 | name: string 42 | apiParameters: ApiParameters 43 | modelParameters: ModelParameters 44 | enableWebUI?: boolean 45 | } 46 | -------------------------------------------------------------------------------- /frontend/src/types/downloads.ts: -------------------------------------------------------------------------------- 1 | export type DownloadStatus = { 2 | name: string 3 | path: string 4 | url: string 5 | transferred: number 6 | size: number 7 | speed: number 8 | progress: number 9 | downloading: boolean 10 | done: boolean 11 | } 12 | -------------------------------------------------------------------------------- /frontend/src/types/home.ts: -------------------------------------------------------------------------------- 1 | import { ReactElement } from 'react' 2 | 3 | export type IntroductionContent = { 4 | [lang: string]: string 5 | } 6 | export type NavCard = { 7 | label: string 8 | desc: string 9 | path: string 10 | icon: ReactElement 11 | } 12 | -------------------------------------------------------------------------------- /frontend/src/types/html-midi-player.d.ts: -------------------------------------------------------------------------------- 1 | declare module JSX { 2 | import { PlayerElement } from 'html-midi-player' 3 | import { VisualizerElement } from 
'html-midi-player' 4 | 5 | interface IntrinsicElements { 6 | 'midi-player': PlayerElement 7 | 'midi-visualizer': VisualizerElement 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /frontend/src/types/models.ts: -------------------------------------------------------------------------------- 1 | export type ModelSourceItem = { 2 | name: string 3 | desc?: { [lang: string]: string | undefined } 4 | size: number 5 | SHA256?: string 6 | lastUpdated: string 7 | url?: string 8 | downloadUrl?: string 9 | tags?: string[] 10 | customTokenizer?: string 11 | hide?: boolean 12 | functionCall?: boolean 13 | 14 | lastUpdatedMs?: number 15 | isComplete?: boolean 16 | isLocal?: boolean 17 | localSize?: number 18 | } 19 | -------------------------------------------------------------------------------- /frontend/src/types/presets.ts: -------------------------------------------------------------------------------- 1 | import { ReactElement } from 'react' 2 | import { ConversationMessage } from './chat' 3 | 4 | export type PresetType = 'chat' | 'completion' | 'chatInCompletion' 5 | export type Preset = { 6 | name: string 7 | tag: string 8 | // if name and sourceUrl are same, it will be overridden when importing 9 | sourceUrl: string 10 | desc: string 11 | avatarImg: string 12 | userAvatarImg?: string 13 | type: PresetType 14 | // chat 15 | welcomeMessage: string 16 | messages: ConversationMessage[] 17 | displayPresetMessages: boolean 18 | // completion 19 | prompt: string 20 | stop: string 21 | injectStart: string 22 | injectEnd: string 23 | presystem?: boolean 24 | userName?: string 25 | assistantName?: string 26 | } 27 | export type PresetsNavigationItem = { 28 | icon: ReactElement 29 | element: ReactElement 30 | } 31 | -------------------------------------------------------------------------------- /frontend/src/types/settings.ts: -------------------------------------------------------------------------------- 1 | export const Languages = { 2 | dev: 'English', // i18n default 3 | zh: '简体中文', 4 | ja: '日本語', 5 | } 6 | export type Language = keyof typeof Languages 7 | export type SettingsType = { 8 | language: Language 9 | darkMode: boolean 10 | autoUpdatesCheck: boolean 11 | giteeUpdatesSource: boolean 12 | cnMirror: boolean 13 | useHfMirror: boolean 14 | host: string 15 | dpiScaling: number 16 | customModelsPath: string 17 | customPythonPath: string 18 | apiUrl: string 19 | apiKey: string 20 | apiChatModelName: string 21 | apiCompletionModelName: string 22 | coreApiUrl: string 23 | rememberAllDurableData: boolean 24 | } 25 | -------------------------------------------------------------------------------- /frontend/src/types/train.ts: -------------------------------------------------------------------------------- 1 | import { ReactElement } from 'react' 2 | 3 | export type DataProcessParameters = { 4 | dataPath: string 5 | vocabPath: string 6 | } 7 | export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'tf32' 8 | export type LoraFinetuneParameters = { 9 | baseModel: string 10 | ctxLen: number 11 | epochSteps: number 12 | epochCount: number 13 | epochBegin: number 14 | epochSave: number 15 | microBsz: number 16 | accumGradBatches: number 17 | preFfn: boolean 18 | headQk: boolean 19 | lrInit: string 20 | lrFinal: string 21 | warmupSteps: number 22 | beta1: number 23 | beta2: number 24 | adamEps: string 25 | devices: number 26 | precision: LoraFinetunePrecision 27 | gradCp: boolean 28 | loraR: number 29 | loraAlpha: number 30 | loraDropout: number 31 | loraLoad: 
string 32 | } 33 | export type TrainNavigationItem = { 34 | element: ReactElement 35 | } 36 | -------------------------------------------------------------------------------- /frontend/src/utils/copy-cuda-kernels.ts: -------------------------------------------------------------------------------- 1 | import { CopyFolderFiles } from '../../wailsjs/go/backend_golang/App' 2 | 3 | export async function copyCudaKernels(torchVersion: string) { 4 | const copyRoot = './backend-python/rwkv_pip' 5 | if (torchVersion === '2.7.1+cu128') { 6 | await CopyFolderFiles( 7 | copyRoot + '/kernels/torch-2.7.1+cu128', 8 | copyRoot, 9 | true 10 | ) 11 | } else { 12 | await CopyFolderFiles( 13 | copyRoot + '/kernels/torch-1.13.1+cu117', 14 | copyRoot, 15 | true 16 | ) 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /frontend/src/utils/filter-function-properties.ts: -------------------------------------------------------------------------------- 1 | export type FilterFunctionProperties<T> = { 2 | // eslint-disable-next-line 3 | [K in keyof T as T[K] extends Function ? never : K]: T[K] 4 | } 5 | -------------------------------------------------------------------------------- /frontend/src/utils/get-available-torch-cu-version.ts: -------------------------------------------------------------------------------- 1 | import { compare } from 'compare-versions' 2 | 3 | export const torchVersions = ['1.13.1', '2.7.1'] 4 | 5 | export function getAvailableTorchCuVersion( 6 | torchVersion: string, 7 | driverCudaVersion: string 8 | ) { 9 | let retTorchVersion = '' 10 | let retCuSourceVersion = '' 11 | const targetTorchVersion = torchVersion.split('+')[0] 12 | if (compare(targetTorchVersion, '2.7.1', '>=')) { 13 | retTorchVersion = '2.7.1' 14 | if (compare(driverCudaVersion, '12.8', '>=')) { 15 | retCuSourceVersion = '12.8' 16 | } else if (compare(driverCudaVersion, '12.6', '>=')) { 17 | retCuSourceVersion = '12.6' 18 | } else { 19 | retCuSourceVersion = '11.8' 20 | } 21 | } else { 22 | retTorchVersion = '1.13.1' 23 | if (compare(driverCudaVersion, '11.7', '>=')) { 24 | retCuSourceVersion = '11.7' 25 | } else { 26 | retCuSourceVersion = '11.6' 27 | } 28 | } 29 | return { torchVersion: retTorchVersion, cuSourceVersion: retCuSourceVersion } 30 | } 31 | -------------------------------------------------------------------------------- /frontend/src/vite-env.d.ts: -------------------------------------------------------------------------------- 1 | /// <reference types="vite/client" /> 2 | -------------------------------------------------------------------------------- /frontend/tailwind.config.js: -------------------------------------------------------------------------------- 1 | const markdownElements = [ 2 | 'div', 3 | 'p', 4 | 'span', 5 | 6 | 'video', 7 | 'img', 8 | 9 | 'abbr', 10 | 'acronym', 11 | 'b', 12 | 'blockquote', 13 | 'code', 14 | 'em', 15 | 'i', 16 | 'li', 17 | 'ol', 18 | 'ul', 19 | 'strong', 20 | 'table', 21 | 'tr', 22 | 'td', 23 | 'th', 24 | 25 | 'details', 26 | 'summary', 27 | 'kbd', 28 | 'samp', 29 | 'sub', 30 | 'sup', 31 | 'ins', 32 | 'del', 33 | 'var', 34 | 'q', 35 | 'dl', 36 | 'dt', 37 | 'dd', 38 | 'ruby', 39 | 'rt', 40 | 'rp', 41 | 42 | 'br', 43 | 'hr', 44 | 45 | 'h1', 46 | 'h2', 47 | 'h3', 48 | 'h4', 49 | 'h5', 50 | 'h6', 51 | 52 | 'thead', 53 | 'tbody', 54 | 'tfoot', 55 | 'u', 56 | 's', 57 | 'a', 58 | 'pre', 59 | 'cite', 60 | ] 61 | 62 | const markdownPseudoElements = ['::marker', '::before', '::after'] 63 | 64 | const tableElements = ['table', 'tr', 'td', 
'th', 'thead', 'tbody', 'tfoot'] 65 | 66 | const proseStyles = { 67 | color: 'inherit', 68 | } 69 | 70 | const tableProseStyles = { 71 | ...proseStyles, 72 | borderWidth: 'thin', 73 | borderColor: '#d2d2d5', 74 | } 75 | 76 | const elementsStyles = markdownElements.reduce((acc, element) => { 77 | let styles = proseStyles 78 | if (tableElements.includes(element)) styles = tableProseStyles 79 | 80 | acc[element] = styles 81 | markdownPseudoElements.forEach((pseudo) => { 82 | acc[element + pseudo] = styles 83 | }) 84 | return acc 85 | }, {}) 86 | 87 | /** @type {import('tailwindcss').Config} */ 88 | export default { 89 | content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'], 90 | theme: { 91 | extend: { 92 | typography: { 93 | DEFAULT: { 94 | css: { 95 | color: 'inherit', 96 | fontSize: 'inherit', 97 | ...elementsStyles, 98 | }, 99 | }, 100 | }, 101 | }, 102 | }, 103 | plugins: [require('@tailwindcss/typography')], 104 | } 105 | -------------------------------------------------------------------------------- /frontend/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ESNext", 4 | "useDefineForClassFields": true, 5 | "lib": ["DOM", "DOM.Iterable", "ESNext"], 6 | "allowJs": false, 7 | "skipLibCheck": true, 8 | "esModuleInterop": false, 9 | "allowSyntheticDefaultImports": true, 10 | "strict": true, 11 | "forceConsistentCasingInFileNames": true, 12 | "module": "ESNext", 13 | "moduleResolution": "Node", 14 | "resolveJsonModule": true, 15 | "isolatedModules": true, 16 | "noEmit": true, 17 | "jsx": "react-jsx" 18 | }, 19 | "include": ["src"], 20 | "references": [ 21 | { 22 | "path": "./tsconfig.node.json" 23 | } 24 | ] 25 | } 26 | -------------------------------------------------------------------------------- /frontend/tsconfig.node.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "composite": true, 4 | "module": "ESNext", 5 | "moduleResolution": "Node", 6 | "allowSyntheticDefaultImports": true 7 | }, 8 | "include": ["vite.config.ts"] 9 | } 10 | -------------------------------------------------------------------------------- /frontend/vite.config.ts: -------------------------------------------------------------------------------- 1 | import react from '@vitejs/plugin-react' 2 | import { visualizer } from 'rollup-plugin-visualizer' 3 | import { defineConfig } from 'vite' 4 | import topLevelAwait from 'vite-plugin-top-level-await' 5 | // @ts-ignore 6 | import { dependencies } from './package.json' 7 | 8 | // dependencies that exist anywhere 9 | const vendor = [ 10 | 'react', 11 | 'react-dom', 12 | 'react-router', 13 | 'react-router-dom', 14 | '@fluentui/react-icons', 15 | 'mobx', 16 | 'mobx-react-lite', 17 | 'i18next', 18 | 'react-i18next', 19 | 'usehooks-ts', 20 | 'react-toastify', 21 | 'classnames', 22 | 'lodash-es', 23 | ] 24 | 25 | const embedded = [ 26 | // split @fluentui/react-components by components 27 | '@fluentui/react-components', 28 | '@tabler/icons-react', 29 | 30 | // dependencies that exist in single component 31 | 'react-beautiful-dnd', 32 | 'react-draggable', 33 | '@magenta/music', 34 | 'html-midi-player', 35 | 'react-markdown', 36 | 'rehype-highlight', 37 | 'rehype-raw', 38 | 'remark-breaks', 39 | 'remark-gfm', 40 | 'remark-math', 41 | 'rehype-katex', 42 | 'katex', 43 | ] 44 | 45 | function renderChunks(deps: Record<string, string>) { 46 | let chunks = {} 47 | Object.keys(deps).forEach((key) => { 48 | if ([...vendor, 
...embedded].includes(key)) return 49 | chunks[key] = [key] 50 | }) 51 | return chunks 52 | } 53 | 54 | // https://vitejs.dev/config/ 55 | export default defineConfig({ 56 | plugins: [ 57 | react(), 58 | visualizer({ 59 | template: 'treemap', 60 | gzipSize: true, 61 | brotliSize: true, 62 | }), 63 | topLevelAwait({ 64 | promiseExportName: '__tla', 65 | promiseImportName: (i) => `__tla_${i}`, 66 | }), 67 | ], 68 | resolve: { 69 | alias: { 70 | // /esm/icons/index.mjs only exports the icons statically, so no separate chunks are created 71 | '@tabler/icons-react': '@tabler/icons-react/dist/esm/icons/index.mjs', 72 | }, 73 | }, 74 | build: { 75 | chunkSizeWarningLimit: 3000, 76 | rollupOptions: { 77 | output: { 78 | manualChunks: { 79 | vendor, 80 | ...renderChunks(dependencies), 81 | }, 82 | entryFileNames: `assets/[name].js`, 83 | chunkFileNames: `assets/[name].js`, 84 | assetFileNames: `assets/[name].[ext]`, 85 | }, 86 | }, 87 | }, 88 | }) 89 | -------------------------------------------------------------------------------- /frontend/wailsjs/go/models.ts: -------------------------------------------------------------------------------- 1 | export namespace backend_golang { 2 | 3 | export class FileInfo { 4 | name: string; 5 | size: number; 6 | isDir: boolean; 7 | modTime: string; 8 | 9 | static createFrom(source: any = {}) { 10 | return new FileInfo(source); 11 | } 12 | 13 | constructor(source: any = {}) { 14 | if ('string' === typeof source) source = JSON.parse(source); 15 | this.name = source["name"]; 16 | this.size = source["size"]; 17 | this.isDir = source["isDir"]; 18 | this.modTime = source["modTime"]; 19 | } 20 | } 21 | export class MIDIMessage { 22 | messageType: string; 23 | channel: number; 24 | note: number; 25 | velocity: number; 26 | control: number; 27 | value: number; 28 | 29 | static createFrom(source: any = {}) { 30 | return new MIDIMessage(source); 31 | } 32 | 33 | constructor(source: any = {}) { 34 | if ('string' === typeof source) source = JSON.parse(source); 35 | this.messageType = source["messageType"]; 36 | this.channel = source["channel"]; 37 | this.note = source["note"]; 38 | this.velocity = source["velocity"]; 39 | this.control = source["control"]; 40 | this.value = source["value"]; 41 | } 42 | } 43 | 44 | } 45 | 46 | -------------------------------------------------------------------------------- /frontend/wailsjs/runtime/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "@wailsapp/runtime", 3 | "version": "2.0.0", 4 | "description": "Wails Javascript runtime library", 5 | "main": "runtime.js", 6 | "types": "runtime.d.ts", 7 | "scripts": { 8 | }, 9 | "repository": { 10 | "type": "git", 11 | "url": "git+https://github.com/wailsapp/wails.git" 12 | }, 13 | "keywords": [ 14 | "Wails", 15 | "Javascript", 16 | "Go" 17 | ], 18 | "author": "Lea Anthony <lea.anthony@gmail.com>", 19 | "license": "MIT", 20 | "bugs": { 21 | "url": "https://github.com/wailsapp/wails/issues" 22 | }, 23 | "homepage": "https://github.com/wailsapp/wails#readme" 24 | } 25 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module rwkv-runner 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/cavaliergopher/grab/v3 v3.0.1 7 | github.com/fsnotify/fsnotify v1.8.0 8 | github.com/mattrtaylor/go-rtmidi v0.0.0-20220428034745-af795b1c1a79 9 | github.com/minio/selfupdate v0.6.0 10 | github.com/nyaosorg/go-windows-su v0.2.1 11 
| github.com/ubuntu/gowsl v0.0.0-20230615094051-94945650cc1e 12 | github.com/wailsapp/wails/v2 v2.9.2 13 | golang.org/x/text v0.22.0 14 | ) 15 | 16 | require ( 17 | aead.dev/minisign v0.2.0 // indirect 18 | github.com/bep/debounce v1.2.1 // indirect 19 | github.com/go-ole/go-ole v1.3.0 // indirect 20 | github.com/godbus/dbus/v5 v5.1.0 // indirect 21 | github.com/google/uuid v1.6.0 // indirect 22 | github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect 23 | github.com/labstack/echo/v4 v4.13.3 // indirect 24 | github.com/labstack/gommon v0.4.2 // indirect 25 | github.com/leaanthony/go-ansi-parser v1.6.1 // indirect 26 | github.com/leaanthony/gosod v1.0.4 // indirect 27 | github.com/leaanthony/slicer v1.6.0 // indirect 28 | github.com/leaanthony/u v1.1.1 // indirect 29 | github.com/mattn/go-colorable v0.1.13 // indirect 30 | github.com/mattn/go-isatty v0.0.20 // indirect 31 | github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect 32 | github.com/pkg/errors v0.9.1 // indirect 33 | github.com/rivo/uniseg v0.4.7 // indirect 34 | github.com/samber/lo v1.49.1 // indirect 35 | github.com/sirupsen/logrus v1.9.0 // indirect 36 | github.com/tkrajina/go-reflector v0.5.8 // indirect 37 | github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect 38 | github.com/valyala/bytebufferpool v1.0.0 // indirect 39 | github.com/valyala/fasttemplate v1.2.2 // indirect 40 | github.com/wailsapp/go-webview2 v1.0.21 // indirect 41 | github.com/wailsapp/mimetype v1.4.1 // indirect 42 | golang.org/x/crypto v0.33.0 // indirect 43 | golang.org/x/net v0.35.0 // indirect 44 | golang.org/x/sys v0.30.0 // indirect 45 | ) 46 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/cyac-1.9.dist-info/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/cyac-1.9.dist-info/.gitkeep -------------------------------------------------------------------------------- /py310/Lib/site-packages/cyac/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/cyac/.gitkeep -------------------------------------------------------------------------------- /py310/Lib/site-packages/diskcache-5.6.3.dist-info/INSTALLER: -------------------------------------------------------------------------------- 1 | pip 2 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/diskcache-5.6.3.dist-info/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2016-2022 Grant Jenks 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use 4 | this file except in compliance with the License. You may obtain a copy of the 5 | License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software distributed 10 | under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 11 | CONDITIONS OF ANY KIND, either express or implied. See the License for the 12 | specific language governing permissions and limitations under the License. 
13 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/diskcache-5.6.3.dist-info/RECORD: -------------------------------------------------------------------------------- 1 | diskcache-5.6.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 2 | diskcache-5.6.3.dist-info/LICENSE,sha256=WDVGuqP9k2B9hFEmVwZ3pAH1COIotQRP77w5Sa8XlnI,559 3 | diskcache-5.6.3.dist-info/METADATA,sha256=b8cM_0OjxmZlMebCaOkAWfuTrliu1wgWPAzuyrJNuu8,20458 4 | diskcache-5.6.3.dist-info/RECORD,, 5 | diskcache-5.6.3.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 6 | diskcache-5.6.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92 7 | diskcache-5.6.3.dist-info/top_level.txt,sha256=A5fqg_AHgOQc_0o1NZ-Uo5Bsb7CV3fR99J-p1-F4yuA,10 8 | diskcache/__init__.py,sha256=JcEtz224G-ysHmgyxWId28A0wQfCIif0mdv4tqC-Zxk,1262 9 | diskcache/__pycache__/__init__.cpython-310.pyc,, 10 | diskcache/__pycache__/cli.cpython-310.pyc,, 11 | diskcache/__pycache__/core.cpython-310.pyc,, 12 | diskcache/__pycache__/djangocache.cpython-310.pyc,, 13 | diskcache/__pycache__/fanout.cpython-310.pyc,, 14 | diskcache/__pycache__/persistent.cpython-310.pyc,, 15 | diskcache/__pycache__/recipes.cpython-310.pyc,, 16 | diskcache/cli.py,sha256=RVv6Fyn7h_0E5BoAVJzbbjOWPDHXO0ZjNgh9g5CBRP8,44 17 | diskcache/core.py,sha256=oinmHnYCXvR2QXVRaBoqPf8iM79M5pAeixajlaJ9TNs,81867 18 | diskcache/djangocache.py,sha256=SX8jl2d-vfOMG7hXZuuh2VCnZEUP-mHGjYLpM2RdMaQ,16110 19 | diskcache/fanout.py,sha256=E4Gk-puAGsaiqyf45YlLdVTf2RM1g7Qy-TPSug1Elpo,22725 20 | diskcache/persistent.py,sha256=Cir3CPetAfACHlCEzY1P5B_YrZPPeWqsoKrkuBBxcCo,34681 21 | diskcache/recipes.py,sha256=dr6CtZ67nRe3XMUBj2J6D_uftquMi_rNpZBg8i6lfRM,14922 22 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/diskcache-5.6.3.dist-info/REQUESTED: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/diskcache-5.6.3.dist-info/REQUESTED -------------------------------------------------------------------------------- /py310/Lib/site-packages/diskcache-5.6.3.dist-info/WHEEL: -------------------------------------------------------------------------------- 1 | Wheel-Version: 1.0 2 | Generator: bdist_wheel (0.41.2) 3 | Root-Is-Purelib: true 4 | Tag: py3-none-any 5 | 6 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/diskcache-5.6.3.dist-info/top_level.txt: -------------------------------------------------------------------------------- 1 | diskcache 2 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/diskcache/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | DiskCache API Reference 3 | ======================= 4 | 5 | The :doc:`tutorial` provides a helpful walkthrough of most methods. 
6 | """ 7 | 8 | from .core import ( 9 | DEFAULT_SETTINGS, 10 | ENOVAL, 11 | EVICTION_POLICY, 12 | UNKNOWN, 13 | Cache, 14 | Disk, 15 | EmptyDirWarning, 16 | JSONDisk, 17 | Timeout, 18 | UnknownFileWarning, 19 | ) 20 | from .fanout import FanoutCache 21 | from .persistent import Deque, Index 22 | from .recipes import ( 23 | Averager, 24 | BoundedSemaphore, 25 | Lock, 26 | RLock, 27 | barrier, 28 | memoize_stampede, 29 | throttle, 30 | ) 31 | 32 | __all__ = [ 33 | 'Averager', 34 | 'BoundedSemaphore', 35 | 'Cache', 36 | 'DEFAULT_SETTINGS', 37 | 'Deque', 38 | 'Disk', 39 | 'ENOVAL', 40 | 'EVICTION_POLICY', 41 | 'EmptyDirWarning', 42 | 'FanoutCache', 43 | 'Index', 44 | 'JSONDisk', 45 | 'Lock', 46 | 'RLock', 47 | 'Timeout', 48 | 'UNKNOWN', 49 | 'UnknownFileWarning', 50 | 'barrier', 51 | 'memoize_stampede', 52 | 'throttle', 53 | ] 54 | 55 | try: 56 | from .djangocache import DjangoCache # noqa 57 | 58 | __all__.append('DjangoCache') 59 | except Exception: # pylint: disable=broad-except # pragma: no cover 60 | # Django not installed or not setup so ignore. 61 | pass 62 | 63 | __title__ = 'diskcache' 64 | __version__ = '5.6.3' 65 | __build__ = 0x050603 66 | __author__ = 'Grant Jenks' 67 | __license__ = 'Apache 2.0' 68 | __copyright__ = 'Copyright 2016-2023 Grant Jenks' 69 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/diskcache/cli.py: -------------------------------------------------------------------------------- 1 | """Command line interface to disk cache.""" 2 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama_cpp import * 2 | from .llama import * 3 | 4 | __version__ = "0.3.9" 5 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/_ggml.py: -------------------------------------------------------------------------------- 1 | """Internal module use at your own risk 2 | 3 | This module provides a minimal interface for working with ggml tensors from llama-cpp-python 4 | """ 5 | import os 6 | import pathlib 7 | 8 | import llama_cpp._ctypes_extensions as ctypes_ext 9 | 10 | libggml_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" 11 | libggml = ctypes_ext.load_shared_library("ggml", libggml_base_path) 12 | 13 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/_logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ctypes 3 | import logging 4 | 5 | import llama_cpp 6 | 7 | # enum ggml_log_level { 8 | # GGML_LOG_LEVEL_NONE = 0, 9 | # GGML_LOG_LEVEL_INFO = 1, 10 | # GGML_LOG_LEVEL_WARN = 2, 11 | # GGML_LOG_LEVEL_ERROR = 3, 12 | # GGML_LOG_LEVEL_DEBUG = 4, 13 | # GGML_LOG_LEVEL_CONT = 5, // continue previous log 14 | # }; 15 | GGML_LOG_LEVEL_TO_LOGGING_LEVEL = { 16 | 0: logging.CRITICAL, 17 | 1: logging.INFO, 18 | 2: logging.WARNING, 19 | 3: logging.ERROR, 20 | 4: logging.DEBUG, 21 | 5: logging.DEBUG, 22 | } 23 | 24 | logger = logging.getLogger("llama-cpp-python") 25 | 26 | _last_log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[0] 27 | 28 | # typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); 29 | @llama_cpp.llama_log_callback 30 | def llama_log_callback( 31 | level: int, 32 | text: bytes, 33 | user_data: 
ctypes.c_void_p, 34 | ): 35 | # TODO: Correctly implement continue previous log 36 | global _last_log_level 37 | log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level] if level != 5 else _last_log_level 38 | if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]: 39 | print(text.decode("utf-8"), end="", flush=True, file=sys.stderr) 40 | _last_log_level = log_level 41 | 42 | 43 | llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0)) 44 | 45 | 46 | def set_verbose(verbose: bool): 47 | logger.setLevel(logging.DEBUG if verbose else logging.ERROR) 48 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from typing import Any, Dict 5 | 6 | # Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor 7 | outnull_file = open(os.devnull, "w") 8 | errnull_file = open(os.devnull, "w") 9 | 10 | STDOUT_FILENO = 1 11 | STDERR_FILENO = 2 12 | 13 | 14 | class suppress_stdout_stderr(object): 15 | # NOTE: these must be "saved" here to avoid exceptions when using 16 | # this context manager inside of a __del__ method 17 | sys = sys 18 | os = os 19 | 20 | def __init__(self, disable: bool = True): 21 | self.disable = disable 22 | 23 | # Oddly enough this works better than the contextlib version 24 | def __enter__(self): 25 | if self.disable: 26 | return self 27 | 28 | self.old_stdout_fileno_undup = STDOUT_FILENO 29 | self.old_stderr_fileno_undup = STDERR_FILENO 30 | 31 | self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup) 32 | self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup) 33 | 34 | self.old_stdout = self.sys.stdout 35 | self.old_stderr = self.sys.stderr 36 | 37 | self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup) 38 | self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup) 39 | 40 | self.sys.stdout = outnull_file 41 | self.sys.stderr = errnull_file 42 | return self 43 | 44 | def __exit__(self, *_): 45 | if self.disable: 46 | return 47 | 48 | # Check if sys.stdout and sys.stderr have fileno method 49 | self.sys.stdout = self.old_stdout 50 | self.sys.stderr = self.old_stderr 51 | 52 | self.os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) 53 | self.os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) 54 | 55 | self.os.close(self.old_stdout_fileno) 56 | self.os.close(self.old_stderr_fileno) 57 | 58 | 59 | class MetaSingleton(type): 60 | """ 61 | Metaclass for implementing the Singleton pattern. 62 | """ 63 | 64 | _instances: Dict[type, Any] = {} 65 | 66 | def __call__(cls, *args: Any, **kwargs: Any) -> Any: 67 | if cls not in cls._instances: 68 | cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs) 69 | return cls._instances[cls] 70 | 71 | 72 | class Singleton(object, metaclass=MetaSingleton): 73 | """ 74 | Base class for implementing the Singleton pattern. 
75 | """ 76 | 77 | def __init__(self): 78 | super(Singleton, self).__init__() 79 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/ggml-base.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-base.dll -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/ggml-base.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-base.lib -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/ggml-cpu.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-cpu.dll -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/ggml-cpu.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-cpu.lib -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/ggml-vulkan.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-vulkan.dll -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/ggml-vulkan.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-vulkan.lib -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/ggml.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml.dll -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/ggml.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml.lib -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/llama.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/llama.dll -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/llama.lib: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/llama.lib -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/llava.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/llava.dll -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/lib/llava.lib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/llava.lib -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/llama_speculative.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | from typing import Any 4 | 5 | import numpy as np 6 | import numpy.typing as npt 7 | 8 | 9 | class LlamaDraftModel(abc.ABC): 10 | @abc.abstractmethod 11 | def __call__( 12 | self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any 13 | ) -> npt.NDArray[np.intc]: 14 | raise NotImplementedError() 15 | 16 | 17 | class LlamaPromptLookupDecoding(LlamaDraftModel): 18 | """Based on https://github.com/apoorvumang/prompt-lookup-decoding""" 19 | 20 | def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10): 21 | self.max_ngram_size = max_ngram_size 22 | self.num_pred_tokens = num_pred_tokens 23 | 24 | @staticmethod 25 | def find_candidate_pred_tokens( 26 | input_ids: npt.NDArray[np.intc], 27 | max_ngram_size: int, 28 | num_pred_tokens: int, 29 | ): 30 | input_length = input_ids.shape[0] 31 | 32 | for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1): 33 | # Create sliding windows of size ngram_size 34 | windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,)) 35 | 36 | # Convert ngram to an array for comparison 37 | ngram_array = input_ids[-ngram_size:] 38 | 39 | # Find where the windows match the ngram 40 | matches = np.all(windows == ngram_array, axis=1) 41 | 42 | # Get the indices of matches 43 | match_indices = np.nonzero(matches)[0] 44 | 45 | # Iterate through match indices to find a valid continuation 46 | for idx in match_indices: 47 | start_idx = idx + ngram_size 48 | end_idx = start_idx + num_pred_tokens 49 | end_idx = min(end_idx, input_length) 50 | 51 | if start_idx < end_idx: 52 | return input_ids[start_idx:end_idx] 53 | 54 | # If no match is found, return an empty array 55 | return np.array([], dtype=np.intc) 56 | 57 | def __call__( 58 | self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any 59 | ) -> npt.NDArray[np.intc]: 60 | return self.find_candidate_pred_tokens( 61 | input_ids=input_ids, 62 | max_ngram_size=self.max_ngram_size, 63 | num_pred_tokens=self.num_pred_tokens, 64 | ) 65 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/py.typed 
-------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp/server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/server/__init__.py -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/INSTALLER: -------------------------------------------------------------------------------- 1 | pip 2 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/REQUESTED: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/REQUESTED -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/WHEEL: -------------------------------------------------------------------------------- 1 | Wheel-Version: 1.0 2 | Generator: scikit-build-core 0.11.4 3 | Root-Is-Purelib: false 4 | Tag: cp310-cp310-win_amd64 5 | 6 | -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/direct_url.json: -------------------------------------------------------------------------------- 1 | {"archive_info": {"hash": "sha256=ab2c671a12fd73b27844145c7283290815c1fde80123391211f28c9328109d63", "hashes": {"sha256": "ab2c671a12fd73b27844145c7283290815c1fde80123391211f28c9328109d63"}}, "url": "file:///C:/Users/%E7%94%A8%E6%88%B7/Downloads/llama_cpp_python-0.3.9-cp310-cp310-win_amd64.whl"} -------------------------------------------------------------------------------- /py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/licenses/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Andrei Betlen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
--------------------------------------------------------------------------------
/scripts/merge_manifest.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import sys
4 | 
5 | MAIN_IMAGE_NAME = sys.argv[1]
6 | TARGET_TAG = "latest" if len(sys.argv) < 3 else sys.argv[2]
7 | 
8 | # Each file under /tmp/images/*/ holds an image digest on its first line;
9 | # combine every digest into a single manifest list, then push it.
10 | args = ["docker manifest create {}:{}".format(MAIN_IMAGE_NAME, TARGET_TAG)]
11 | for i in glob.glob("/tmp/images/*/*.txt"):
12 |     with open(i, "r") as file:
13 |         args.append(" --amend {}@{}".format(MAIN_IMAGE_NAME, file.readline().strip()))
14 | cmd_create = "".join(args)
15 | cmd_push = "docker manifest push {}:{}".format(MAIN_IMAGE_NAME, TARGET_TAG)
16 | os.system(cmd_create)
17 | os.system(cmd_push)
--------------------------------------------------------------------------------
/wails.json:
--------------------------------------------------------------------------------
1 | {
2 |   "$schema": "https://wails.io/schemas/config.v2.json",
3 |   "name": "RWKV-Runner",
4 |   "outputfilename": "RWKV-Runner",
5 |   "frontend:install": "npm install",
6 |   "frontend:build": "npm run build",
7 |   "frontend:dev:watcher": "npm run dev",
8 |   "frontend:dev:serverUrl": "auto",
9 |   "author": {
10 |     "name": "josc146",
11 |     "email": "josStorer@outlook.com"
12 |   },
13 |   "Info": {
14 |     "companyName": "RWKV-Runner",
15 |     "productName": "RWKV-Runner",
16 |     "productVersion": "1.0.0",
17 |     "copyright": "Copyright © 2023 RWKV-Runner"
18 |   }
19 | }
20 | 
--------------------------------------------------------------------------------
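Note on scripts/merge_manifest.py above: each build job is expected to leave a text file under /tmp/images/*/ whose first line is an image digest, and the script amends all of those digests into one manifest for MAIN_IMAGE_NAME before pushing it. A minimal dry-run sketch of the same command assembly follows; the image name and digests are hypothetical and no Docker command is actually run.

# Hypothetical values purely for illustration; the real script reads digests
# from /tmp/images/*/*.txt and takes the image name and tag from sys.argv.
image, tag = "example/rwkv-runner", "latest"
digests = ["sha256:digest-a", "sha256:digest-b"]

cmd_create = "docker manifest create {}:{}".format(image, tag) + "".join(
    " --amend {}@{}".format(image, d) for d in digests
)
cmd_push = "docker manifest push {}:{}".format(image, tag)

print(cmd_create)
print(cmd_push)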