├── .gitattributes
├── .github
    ├── dependabot.yml
    └── workflows
    │   ├── docker.yml
    │   ├── pre-release.yml
    │   └── release.yml
├── .gitignore
├── .vscode
    ├── launch.json
    ├── settings.json
    └── tasks.json
├── CURRENT_CHANGE.md
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── README_JA.md
├── README_ZH.md
├── assets
    ├── default_sound_font.sf2
    ├── sound-font
    │   ├── sound_fetch.py
    │   └── soundfont.json
    └── soundfont_builder.rb
├── backend-golang
    ├── app.go
    ├── cmd_interactive.go
    ├── cmd_interactive_unix.go
    ├── cmd_interactive_windows.go
    ├── download.go
    ├── file.go
    ├── hw_info.go
    ├── midi.go
    ├── rwkv.go
    ├── utils.go
    ├── utils_unix.go
    ├── utils_windows.go
    ├── wsl_unix.go
    └── wsl_windows.go
├── backend-python
    ├── convert_model.py
    ├── convert_pytorch_to_ggml.py
    ├── convert_safetensors.py
    ├── dep_check.py
    ├── get-pip.py
    ├── global_var.py
    ├── main.py
    ├── requirements.txt
    ├── requirements_without_cyac.txt
    ├── routes
    │   ├── completion.py
    │   ├── config.py
    │   ├── file_process.py
    │   ├── midi.py
    │   ├── misc.py
    │   ├── schema.py
    │   └── state_cache.py
    ├── rwkv_pip
    │   ├── 20B_tokenizer.json
    │   ├── cpp
    │   │   ├── librwkv.dylib
    │   │   ├── librwkv.so
    │   │   ├── model.py
    │   │   ├── rwkv.dll
    │   │   ├── rwkv_cpp_model.py
    │   │   └── rwkv_cpp_shared_library.py
    │   ├── cuda
    │   │   ├── gemm_fp16_cublas.cpp
    │   │   ├── operators.cu
    │   │   ├── rwkv5.cu
    │   │   ├── rwkv5_op.cpp
    │   │   ├── rwkv6.cu
    │   │   ├── rwkv6_op.cpp
    │   │   ├── rwkv7.cu
    │   │   ├── rwkv7_op.cpp
    │   │   └── wrapper.cpp
    │   ├── kernels
    │   │   ├── torch-1.13.1+cu117
    │   │   │   ├── rwkv5.pyd
    │   │   │   ├── rwkv6.pyd
    │   │   │   ├── wkv7s.pyd
    │   │   │   └── wkv_cuda.pyd
    │   │   └── torch-2.7.1+cu128
    │   │   │   ├── rwkv5.pyd
    │   │   │   ├── rwkv6.pyd
    │   │   │   ├── wkv7s.pyd
    │   │   │   └── wkv_cuda.pyd
    │   ├── model.py
    │   ├── rwkv_tokenizer.py
    │   ├── rwkv_vocab_v20230424.txt
    │   ├── tokenizer-midi.json
    │   ├── tokenizer-midipiano.json
    │   ├── utils.py
    │   └── webgpu
    │   │   ├── model.py
    │   │   └── web_rwkv_py.cp310-win_amd64.pyd
    ├── tests
    │   ├── function_call.py
    │   ├── function_call_stream.py
    │   └── postprocess_response.py
    ├── utils
    │   ├── llama.py
    │   ├── log.py
    │   ├── midi.py
    │   ├── midi_filter_config.json
    │   ├── midi_vocab_config.json
    │   ├── ngrok.py
    │   ├── rwkv.py
    │   ├── torch.py
    │   └── vocab_config_piano.json
    └── webui_server.py
├── backend-rust
    └── assets
    │   └── rwkv_vocab_v20230424.json
├── build
    ├── README.md
    ├── appicon.png
    ├── darwin
    │   ├── Info.dev.plist
    │   ├── Info.plist
    │   ├── Readme_Install.txt
    │   ├── entitlements.plist
    │   └── gon-sign.json
    ├── linux
    │   └── Readme_Install.txt
    └── windows
    │   ├── Readme_Install.txt
    │   ├── WELCOMEFINISHPAGE.bmp
    │   ├── icon.ico
    │   ├── info.json
    │   ├── installer
    │       ├── project.nsi
    │       └── wails_tools.nsh
    │   └── wails.exe.manifest
├── components
    └── gitkeep
├── deploy-examples
    ├── ChatGPT-Next-Web
    │   ├── setup.bat
    │   └── setup.sh
    └── RWKV-Runner-WebUI
    │   ├── setup.bat
    │   └── setup.sh
├── docker-compose.yml
├── exportModelsJson.js
├── finetune
    ├── data
    │   └── sample.jsonl
    ├── get_layer_and_embd.py
    ├── install-wsl-dep-and-train.sh
    ├── json2binidx_tool
    │   └── tools
    │   │   ├── indexed_dataset.py
    │   │   ├── preprocess_data.py
    │   │   ├── rwkv_tokenizer.py
    │   │   └── tokenizer.py
    ├── lora
    │   ├── merge_lora.py
    │   ├── v4
    │   │   ├── cuda
    │   │   │   ├── wkv_cuda.cu
    │   │   │   ├── wkv_cuda_bf16.cu
    │   │   │   ├── wkv_op.cpp
    │   │   │   └── wkv_op_bf16.cpp
    │   │   ├── src
    │   │   │   ├── __init__.py
    │   │   │   ├── binidx.py
    │   │   │   ├── dataset.py
    │   │   │   ├── model.py
    │   │   │   ├── trainer.py
    │   │   │   └── utils.py
    │   │   └── train.py
    │   ├── v5
    │   │   ├── cuda
    │   │   │   ├── wkv5_cuda.cu
    │   │   │   └── wkv5_op.cpp
    │   │   ├── src
    │   │   │   ├── __init__.py
    │   │   │   ├── binidx.py
    │   │   │   ├── dataset.py
    │   │   │   ├── model.py
    │   │   │   ├── trainer.py
    │   │   │   └── utils.py
    │   │   └── train.py
    │   └── v6
    │   │   ├── cuda
    │   │       ├── wkv5_cuda.cu
    │   │       ├── wkv5_op.cpp
    │   │       ├── wkv6_cuda.cu
    │   │       ├── wkv6_op.cpp
    │   │       ├── wkv6infctx_cuda.cu
    │   │       ├── wkv6infctx_op.cpp
    │   │       ├── wkv6state_cuda.cu
    │   │       └── wkv6state_op.cpp
    │   │   ├── demo
    │   │       ├── demo-lora-merge.sh
    │   │       ├── demo-lora.sh
    │   │       ├── demo-pissa-merge.sh
    │   │       ├── demo-pissa.sh
    │   │       ├── demo-qpissa-pt.sh
    │   │       ├── demo-state-merge.sh
    │   │       ├── demo-state-tuning.sh
    │   │       ├── demo-training-prepare.sh
    │   │       ├── demo-training-run.sh
    │   │       └── infctx.sh
    │   │   ├── fla
    │   │       ├── __init__.py
    │   │       ├── layers
    │   │       │   ├── __init__.py
    │   │       │   ├── abc.py
    │   │       │   ├── based.py
    │   │       │   ├── delta_net.py
    │   │       │   ├── gated_abc.py
    │   │       │   ├── gla.py
    │   │       │   ├── hgrn.py
    │   │       │   ├── hgrn2.py
    │   │       │   ├── linear_attn.py
    │   │       │   ├── multiscale_retention.py
    │   │       │   ├── rebased.py
    │   │       │   ├── rwkv6.py
    │   │       │   └── simple_gla.py
    │   │       ├── models
    │   │       │   ├── __init__.py
    │   │       │   ├── abc
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_abc.py
    │   │       │   │   └── modeling_abc.py
    │   │       │   ├── delta_net
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_delta_net.py
    │   │       │   │   └── modeling_delta_net.py
    │   │       │   ├── gla
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_gla.py
    │   │       │   │   └── modeling_gla.py
    │   │       │   ├── hgrn
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_hgrn.py
    │   │       │   │   └── modeling_hgrn.py
    │   │       │   ├── hgrn2
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_hgrn2.py
    │   │       │   │   └── modeling_hgrn2.py
    │   │       │   ├── linear_attn
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_linear_attn.py
    │   │       │   │   └── modeling_linear_attn.py
    │   │       │   ├── mamba
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_mamba.py
    │   │       │   │   └── modeling_mamba.py
    │   │       │   ├── retnet
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_retnet.py
    │   │       │   │   └── modeling_retnet.py
    │   │       │   ├── rwkv6
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_rwkv6.py
    │   │       │   │   └── modeling_rwkv6.py
    │   │       │   ├── transformer
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── configuration_transformer.py
    │   │       │   │   └── modeling_transformer.py
    │   │       │   └── utils.py
    │   │       ├── modules
    │   │       │   ├── __init__.py
    │   │       │   ├── activations.py
    │   │       │   ├── convolution.py
    │   │       │   ├── feature_map.py
    │   │       │   ├── fused_cross_entropy.py
    │   │       │   ├── fused_norm_gate.py
    │   │       │   ├── l2norm.py
    │   │       │   ├── layernorm.py
    │   │       │   └── rotary.py
    │   │       ├── ops
    │   │       │   ├── __init__.py
    │   │       │   ├── abc
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── chunk.py
    │   │       │   │   ├── chunk_gate.py
    │   │       │   │   ├── naive.py
    │   │       │   │   └── recurrent_fuse.py
    │   │       │   ├── based
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── chunk_fuse.py
    │   │       │   │   ├── naive.py
    │   │       │   │   └── parallel.py
    │   │       │   ├── delta_rule
    │   │       │   │   ├── README.md
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── chunk.py
    │   │       │   │   ├── chunk_fuse.py
    │   │       │   │   ├── naive.py
    │   │       │   │   ├── recurrent_fuse.py
    │   │       │   │   ├── utils.py
    │   │       │   │   └── wy_fast.py
    │   │       │   ├── gla
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── chunk.py
    │   │       │   │   ├── chunk_fuse.py
    │   │       │   │   ├── chunk_util.py
    │   │       │   │   ├── naive.py
    │   │       │   │   └── recurrent_fuse.py
    │   │       │   ├── hgrn
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── chunk.py
    │   │       │   │   ├── naive.py
    │   │       │   │   └── recurrent_fuse.py
    │   │       │   ├── linear_attn
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── chunk.py
    │   │       │   │   ├── chunk_fuse.py
    │   │       │   │   ├── naive.py
    │   │       │   │   └── recurrent_fuse.py
    │   │       │   ├── rebased
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── naive.py
    │   │       │   │   └── parallel.py
    │   │       │   ├── retention
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── chunk.py
    │   │       │   │   ├── chunk_fuse.py
    │   │       │   │   ├── naive.py
    │   │       │   │   ├── parallel.py
    │   │       │   │   └── recurrent_fuse.py
    │   │       │   ├── rotary.py
    │   │       │   ├── rwkv4
    │   │       │   │   ├── __init__.py
    │   │       │   │   └── recurrent_fuse.py
    │   │       │   ├── rwkv6
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── chunk.py
    │   │       │   │   ├── chunk_naive.py
    │   │       │   │   ├── recurrent_fuse.py
    │   │       │   │   └── recurrent_naive.py
    │   │       │   ├── simple_gla
    │   │       │   │   ├── README.md
    │   │       │   │   ├── __init__.py
    │   │       │   │   ├── chunk.py
    │   │       │   │   └── naive.py
    │   │       │   └── utils.py
    │   │       └── utils.py
    │   │   ├── merge
    │   │       ├── merge.py
    │   │       ├── merge_lora.py
    │   │       ├── merge_pissa.py
    │   │       └── merge_state.py
    │   │   ├── requirements.txt
    │   │   ├── src
    │   │       ├── __init__.py
    │   │       ├── binidx.py
    │   │       ├── dataset.py
    │   │       ├── infctx_module.py
    │   │       ├── model.py
    │   │       ├── trainer.py
    │   │       └── utils.py
    │   │   └── train.py
    └── requirements.txt
├── frontend
    ├── i18nally.json
    ├── index.html
    ├── package-lock.json
    ├── package.json
    ├── postcss.config.js
    ├── prettier.config.js
    ├── src
    │   ├── App.tsx
    │   ├── _locales
    │   │   ├── i18n-react.ts
    │   │   ├── i18n.ts
    │   │   ├── ja
    │   │   │   └── main.json
    │   │   ├── resources.ts
    │   │   └── zh-hans
    │   │   │   └── main.json
    │   ├── apis
    │   │   └── index.ts
    │   ├── assets
    │   │   └── images
    │   │   │   ├── banner.jpg
    │   │   │   ├── logo.png
    │   │   │   ├── strategy.jpg
    │   │   │   └── strategy_zh.jpg
    │   ├── components
    │   │   ├── BottomLogger.tsx
    │   │   ├── ConfigSelector.tsx
    │   │   ├── CopyButton.tsx
    │   │   ├── CustomToastContainer.tsx
    │   │   ├── DebugModeIndicator.tsx
    │   │   ├── DialogButton.tsx
    │   │   ├── Labeled.tsx
    │   │   ├── LazyImportComponent.tsx
    │   │   ├── MarkdownRender.tsx
    │   │   ├── MobileFloatingNavigator.tsx
    │   │   ├── NumberInput.tsx
    │   │   ├── Page.tsx
    │   │   ├── ReadButton.tsx
    │   │   ├── ResetConfigsButton.tsx
    │   │   ├── RunButton.tsx
    │   │   ├── Section.tsx
    │   │   ├── ToolTipButton.tsx
    │   │   ├── ValuedSlider.tsx
    │   │   └── WorkHeader.tsx
    │   ├── main.tsx
    │   ├── pages
    │   │   ├── About.tsx
    │   │   ├── AudiotrackManager
    │   │   │   ├── AudiotrackButton.tsx
    │   │   │   └── AudiotrackEditor.tsx
    │   │   ├── AutoConfig.tsx
    │   │   ├── Chat.tsx
    │   │   ├── Completion.tsx
    │   │   ├── Composition.tsx
    │   │   ├── Configs.tsx
    │   │   ├── Downloads.tsx
    │   │   ├── Home.tsx
    │   │   ├── Models.tsx
    │   │   ├── PresetsManager
    │   │   │   ├── MessagesEditor.tsx
    │   │   │   └── PresetsButton.tsx
    │   │   ├── Settings.tsx
    │   │   ├── Train.tsx
    │   │   ├── defaultConfigs.ts
    │   │   └── index.tsx
    │   ├── startup.ts
    │   ├── stores
    │   │   ├── cmdTaskChainStore.ts
    │   │   └── commonStore.ts
    │   ├── style.scss
    │   ├── types
    │   │   ├── about.ts
    │   │   ├── chat.ts
    │   │   ├── completion.ts
    │   │   ├── composition.ts
    │   │   ├── configs.ts
    │   │   ├── downloads.ts
    │   │   ├── home.ts
    │   │   ├── html-midi-player.d.ts
    │   │   ├── models.ts
    │   │   ├── presets.ts
    │   │   ├── settings.ts
    │   │   └── train.ts
    │   ├── utils
    │   │   ├── convert-model.ts
    │   │   ├── copy-cuda-kernels.ts
    │   │   ├── filter-function-properties.ts
    │   │   ├── generate-strategy.ts
    │   │   ├── get-available-torch-cu-version.ts
    │   │   ├── index.tsx
    │   │   ├── rwkv-task.ts
    │   │   └── web-file-operations.ts
    │   ├── vite-env.d.ts
    │   └── webWails.js
    ├── tailwind.config.js
    ├── tsconfig.json
    ├── tsconfig.node.json
    ├── vite.config.ts
    └── wailsjs
    │   ├── go
    │       ├── backend_golang
    │       │   ├── App.d.ts
    │       │   └── App.js
    │       └── models.ts
    │   └── runtime
    │       ├── package.json
    │       ├── runtime.d.ts
    │       └── runtime.js
├── go.mod
├── go.sum
├── main.go
├── manifest.json
├── midi
    └── sample.txt
├── parse_api_log.py
├── py310
    └── Lib
    │   └── site-packages
    │       ├── cyac-1.9.dist-info
    │           └── .gitkeep
    │       ├── cyac
    │           └── .gitkeep
    │       ├── diskcache-5.6.3.dist-info
    │           ├── INSTALLER
    │           ├── LICENSE
    │           ├── METADATA
    │           ├── RECORD
    │           ├── REQUESTED
    │           ├── WHEEL
    │           └── top_level.txt
    │       ├── diskcache
    │           ├── __init__.py
    │           ├── cli.py
    │           ├── core.py
    │           ├── djangocache.py
    │           ├── fanout.py
    │           ├── persistent.py
    │           └── recipes.py
    │       ├── llama_cpp
    │           ├── __init__.py
    │           ├── _ctypes_extensions.py
    │           ├── _ggml.py
    │           ├── _internals.py
    │           ├── _logger.py
    │           ├── _utils.py
    │           ├── lib
    │           │   ├── ggml-base.dll
    │           │   ├── ggml-base.lib
    │           │   ├── ggml-cpu.dll
    │           │   ├── ggml-cpu.lib
    │           │   ├── ggml-vulkan.dll
    │           │   ├── ggml-vulkan.lib
    │           │   ├── ggml.dll
    │           │   ├── ggml.lib
    │           │   ├── llama.dll
    │           │   ├── llama.lib
    │           │   ├── llava.dll
    │           │   └── llava.lib
    │           ├── llama.py
    │           ├── llama_cache.py
    │           ├── llama_chat_format.py
    │           ├── llama_cpp.py
    │           ├── llama_grammar.py
    │           ├── llama_speculative.py
    │           ├── llama_tokenizer.py
    │           ├── llama_types.py
    │           ├── llava_cpp.py
    │           ├── py.typed
    │           └── server
    │           │   ├── __init__.py
    │           │   ├── __main__.py
    │           │   ├── app.py
    │           │   ├── cli.py
    │           │   ├── errors.py
    │           │   ├── model.py
    │           │   ├── settings.py
    │           │   └── types.py
    │       └── llama_cpp_python-0.3.9.dist-info
    │           ├── INSTALLER
    │           ├── METADATA
    │           ├── RECORD
    │           ├── REQUESTED
    │           ├── WHEEL
    │           ├── direct_url.json
    │           └── licenses
    │               └── LICENSE.md
├── scripts
    └── merge_manifest.py
└── wails.json


/.gitattributes:
--------------------------------------------------------------------------------
 1 | * text=auto eol=lf
 2 | 
 3 | backend-python/rwkv_pip/** linguist-vendored
 4 | backend-python/wkv_cuda_utils/** linguist-vendored
 5 | backend-python/get-pip.py linguist-vendored
 6 | backend-python/convert_model.py linguist-vendored
 7 | backend-python/convert_safetensors.py linguist-vendored
 8 | backend-python/convert_pytorch_to_ggml.py linguist-vendored
 9 | backend-python/utils/midi.py linguist-vendored
10 | build/** linguist-vendored
11 | finetune/lora/** linguist-vendored
12 | finetune/json2binidx_tool/** linguist-vendored
13 | py310/** linguist-vendored
14 | frontend/wailsjs/** linguist-generated


--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
 1 | version: 2
 2 | updates:
 3 |   - package-ecosystem: "github-actions"
 4 |     directory: "/"
 5 |     schedule:
 6 |       interval: "weekly"
 7 |     commit-message:
 8 |       prefix: "chore"
 9 |       include: "scope"
10 | 


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | build/bin
 2 | node_modules
 3 | frontend/dist
 4 | __pycache__
 5 | .idea
 6 | .vs
 7 | *.pth
 8 | *.st
 9 | *.safetensors
10 | *.bin
11 | *.mid
12 | /config.json
13 | /cache.json
14 | /durable.json
15 | /presets.json
16 | /frontend/stats.html
17 | /frontend/package.json.md5
18 | /py310
19 | *.zip
20 | /cmd-helper.bat
21 | /install-py-dep.bat
22 | /backend-python/wkv_cuda
23 | /backend-python/rwkv5
24 | /backend-python/rwkv6
25 | /backend-python/wkv7s
26 | /backend-python/rwkv_pip/wkv_cuda.pyd
27 | /backend-python/rwkv_pip/rwkv5.pyd
28 | /backend-python/rwkv_pip/rwkv6.pyd
29 | /backend-python/rwkv_pip/wkv7s.pyd
30 | *.exe
31 | *.old
32 | .DS_Store
33 | *.log.*
34 | *.log
35 | train_log.txt
36 | finetune/json2binidx_tool/data
37 | /wsl.state
38 | /components
39 | error.txt
40 | 


--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     // Use IntelliSense to learn about possible attributes.
 3 |     // Hover to view descriptions of existing attributes.
 4 |     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
 5 |     //
 6 |     // Use Ctrl+Shift+P to Select Interpreter
 7 |     "version": "0.2.0",
 8 |     "configurations": [
 9 |         {
10 |             "name": "Python",
11 |             "type": "python",
12 |             "request": "launch",
13 |             "program": "${workspaceFolder}/backend-python/main.py",
14 |             "console": "integratedTerminal",
15 |             "justMyCode": false
16 |         },
17 |         {
18 |             "name": "Golang",
19 |             "type": "go",
20 |             "request": "launch",
21 |             "mode": "exec",
22 |             "program": "${workspaceFolder}/build/bin/testwails.exe",
23 |             "console": "integratedTerminal",
24 |             "preLaunchTask": "build dev"
25 |         },
26 |         {
27 |             "name": "Frontend",
28 |             "type": "node-terminal",
29 |             "request": "launch",
30 |             "command": "wails dev -browser"
31 |         }
32 |     ]
33 | }


--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "prettier.configPath": "./frontend/prettier.config.js",
 3 |   "prettier.prettierPath": "./frontend/node_modules/prettier",
 4 |   "prettier.requireConfig": true,
 5 |   "editor.defaultFormatter": "esbenp.prettier-vscode",
 6 |   "[go]": {
 7 |     "editor.defaultFormatter": "golang.go"
 8 |   },
 9 |   "[python]": {
10 |     "editor.defaultFormatter": "ms-python.black-formatter"
11 |   },
12 |   "python.formatting.provider": "none",
13 |   "editor.formatOnSave": true,
14 | }
15 | 


--------------------------------------------------------------------------------
/.vscode/tasks.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "version": "2.0.0",
 3 |   "tasks": [
 4 |     {
 5 |       "label": "build dev",
 6 |       "type": "shell",
 7 |       "options": {
 8 |         "cwd": "${workspaceFolder}",
 9 |         "env": {
10 |           "CGO_ENABLED": "1"
11 |         }
12 |       },
13 |       "osx": {
14 |         "options": {
15 |           "env": {
16 |             "CGO_CFLAGS": "-mmacosx-version-min=10.13",
17 |             "CGO_LDFLAGS": "-framework UniformTypeIdentifiers -mmacosx-version-min=10.13"
18 |           }
19 |         }
20 |       },
21 |       "windows": {
22 |         "options": {
23 |           "env": {
24 |             "CGO_ENABLED": "0"
25 |           }
26 |         }
27 |       },
28 |       "command": "go",
29 |       "args": [
30 |         "build",
31 |         "-tags",
32 |         "dev",
33 |         "-gcflags",
34 |         "all=-N -l",
35 |         "-o",
36 |         "build/bin/testwails.exe"
37 |       ]
38 |     }
39 |   ]
40 | }


--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
 1 | FROM node:21-slim AS frontend
 2 | 
 3 | RUN echo "registry=https://registry.npmmirror.com/" > ~/.npmrc
 4 | 
 5 | WORKDIR /app
 6 | 
 7 | COPY manifest.json manifest.json
 8 | COPY frontend frontend
 9 | 
10 | WORKDIR /app/frontend
11 | 
12 | RUN npm ci
13 | RUN npm run build
14 | 
15 | FROM nvidia/cuda:11.6.1-devel-ubuntu20.04 AS runtime
16 | 
17 | ENV DEBIAN_FRONTEND=noninteractive
18 | 
19 | RUN apt update && \
20 |     apt install -yq git curl wget build-essential ninja-build aria2 jq software-properties-common
21 | 
22 | RUN add-apt-repository -y ppa:deadsnakes/ppa && \
23 |     add-apt-repository -y ppa:ubuntu-toolchain-r/test && \
24 |     apt install -y g++-11 python3.10 python3.10-distutils python3.10-dev && \
25 |     curl -sS http://mirrors.aliyun.com/pypi/get-pip.py | python3.10
26 | 
27 | RUN python3.10 -m pip install cmake
28 | 
29 | FROM runtime AS librwkv
30 | 
31 | WORKDIR /app
32 | 
33 | RUN git clone https://github.com/RWKV/rwkv.cpp.git && \
34 |     cd rwkv.cpp && \
35 |     git submodule update --init --recursive && \
36 |     mkdir -p build && \
37 |     cd build && \
38 |     cmake -G Ninja .. && \
39 |     cmake --build .
40 | 
41 | FROM runtime AS final
42 | 
43 | WORKDIR /app
44 | 
45 | COPY ./backend-python/requirements.txt ./backend-python/requirements.txt
46 | 
47 | RUN python3.10 -m pip install --quiet -r ./backend-python/requirements.txt
48 | 
49 | COPY . .
50 | COPY --from=frontend /app/frontend/dist /app/frontend/dist
51 | COPY --from=librwkv /app/rwkv.cpp/build/librwkv.so /app/backend-python/rwkv_pip/cpp/librwkv.so
52 | 
53 | EXPOSE 27777
54 | 
55 | CMD ["python3.10", "./backend-python/main.py", "--port", "27777", "--host", "0.0.0.0", "--webui"]
56 | 
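Note: the final stage serves backend-python/main.py on port 27777 with the web UI enabled. Below is a minimal client sketch, assuming the container is published on localhost:27777 and that the backend exposes an OpenAI-style /chat/completions route (the deploy-examples/ChatGPT-Next-Web setup suggests such compatibility); the endpoint path, payload shape, and message content are illustrative assumptions, not taken from this dump.

    import requests

    # Hypothetical request against a locally running container, e.g.:
    #   docker run --gpus all -p 27777:27777 <image>
    resp = requests.post(
        "http://127.0.0.1:27777/chat/completions",   # assumed OpenAI-compatible route
        json={
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": False,
        },
        timeout=60,
    )
    print(resp.json())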


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 josStorer
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.


--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
 1 | ifeq ($(OS), Windows_NT)
 2 | build: build-windows
 3 | else ifeq ($(shell uname -s), Darwin)
 4 | build: build-macos
 5 | else
 6 | build: build-linux
 7 | endif
 8 | 
 9 | windows_build = wails build -ldflags '-s -w -extldflags "-static"' -platform windows/amd64 -devtools -upx -upxflags "-9 --lzma"
10 | 
11 | build-windows:
12 | 	@echo ---- build for windows
13 | 	$(windows_build) -nsis
14 | 
15 | debug:
16 | 	$(windows_build) -windowsconsole
17 | 
18 | build-macos:
19 | 	@echo ---- build for macos
20 | 	wails build -ldflags '-s -w' -platform darwin/universal -devtools
21 | 
22 | build-linux:
23 | 	@echo ---- build for linux
24 | 	wails build -ldflags '-s -w' -platform linux/amd64 -devtools -upx -upxflags "-9 --lzma"
25 | 
26 | build-linux_webkit2_41:
27 | 	@echo ---- build for linux with webkit2_41
28 | 	wails build -tags webkit2_41 -ldflags '-s -w' -platform linux/amd64 -devtools -upx -upxflags "-9 --lzma"
29 | 
30 | build-web:
31 | 	@echo ---- build for web
32 | 	cd frontend && npm run build
33 | 
34 | dev:
35 | 	wails dev
36 | 
37 | # go install github.com/josStorer/wails/v2/cmd/wails@v2.9.2x
38 | devq:
39 | 	wails dev -s -m -skipembedcreate -skipbindings
40 | 
41 | devq2:
42 | 	wails dev -s -m -skipembedcreate
43 | 
44 | dev-web:
45 | 	cd frontend && npm run dev
46 | 
47 | preview:
48 | 	cd frontend && npm run preview
49 | 
50 | 


--------------------------------------------------------------------------------
/assets/default_sound_font.sf2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/assets/default_sound_font.sf2


--------------------------------------------------------------------------------
/backend-golang/cmd_interactive.go:
--------------------------------------------------------------------------------
 1 | package backend_golang
 2 | 
 3 | import "os"
 4 | 
 5 | var cmds = make(map[string][]string)
 6 | var cmdProcesses = make(map[string]*os.Process)
 7 | 
 8 | func (a *App) GetCmds() map[string][]string {
 9 | 	return cmds
10 | }
11 | 
12 | func (a *App) KillCmd(eventId string) error {
13 | 	cmd, ok := cmdProcesses[eventId]
14 | 	if !ok {
15 | 		return nil
16 | 	}
17 | 	delete(cmds, eventId)
18 | 	delete(cmdProcesses, eventId)
19 | 	return cmd.Kill()
20 | }
21 | 
22 | func (a *App) IsCmdRunning(eventId string) bool {
23 | 	_, ok := cmds[eventId]
24 | 	return ok
25 | }
26 | 


--------------------------------------------------------------------------------
/backend-golang/cmd_interactive_unix.go:
--------------------------------------------------------------------------------
 1 | //go:build darwin || linux
 2 | 
 3 | package backend_golang
 4 | 
 5 | import (
 6 | 	"bufio"
 7 | 	"io"
 8 | 	"os/exec"
 9 | 
10 | 	wruntime "github.com/wailsapp/wails/v2/pkg/runtime"
11 | )
12 | 
13 | func (a *App) CmdInteractive(args []string, eventId string) error {
14 | 	cmd := exec.Command(args[0], args[1:]...)
15 | 
16 | 	stdout, err := cmd.StdoutPipe()
17 | 	if err != nil {
18 | 		return err
19 | 	}
20 | 
21 | 	cmd.Stderr = cmd.Stdout
22 | 
23 | 	err = cmd.Start()
24 | 
25 | 	if err != nil {
26 | 		return err
27 | 	}
28 | 
29 | 	cmds[eventId] = args
30 | 	cmdProcesses[eventId] = cmd.Process
31 | 	reader := bufio.NewReader(stdout)
32 | 
33 | 	for {
34 | 		line, _, err := reader.ReadLine()
35 | 		if err != nil {
36 | 			delete(cmds, eventId)
37 | 			delete(cmdProcesses, eventId)
38 | 			if err == io.EOF {
39 | 				wruntime.EventsEmit(a.ctx, eventId+"-finish")
40 | 				return nil
41 | 			}
42 | 			return err
43 | 		}
44 | 		strLine := string(line)
45 | 		wruntime.EventsEmit(a.ctx, eventId+"-output", strLine)
46 | 	}
47 | }
48 | 


--------------------------------------------------------------------------------
/backend-golang/cmd_interactive_windows.go:
--------------------------------------------------------------------------------
 1 | //go:build windows
 2 | 
 3 | package backend_golang
 4 | 
 5 | import (
 6 | 	"bufio"
 7 | 	"bytes"
 8 | 	"io"
 9 | 	"os/exec"
10 | 	"syscall"
11 | 
12 | 	wruntime "github.com/wailsapp/wails/v2/pkg/runtime"
13 | 	"golang.org/x/text/encoding/simplifiedchinese"
14 | 	"golang.org/x/text/transform"
15 | )
16 | 
17 | func (a *App) CmdInteractive(args []string, eventId string) error {
18 | 	cmd := exec.Command(args[0], args[1:]...)
19 | 
20 | 	stdout, err := cmd.StdoutPipe()
21 | 	if err != nil {
22 | 		return err
23 | 	}
24 | 
25 | 	cmd.Stderr = cmd.Stdout
26 | 	cmd.SysProcAttr = &syscall.SysProcAttr{}
27 | 	cmd.SysProcAttr.HideWindow = true
28 | 
29 | 	err = cmd.Start()
30 | 
31 | 	if err != nil {
32 | 		return err
33 | 	}
34 | 
35 | 	cmds[eventId] = args
36 | 	cmdProcesses[eventId] = cmd.Process
37 | 	reader := bufio.NewReader(stdout)
38 | 
39 | 	for {
40 | 		line, _, err := reader.ReadLine()
41 | 		if err != nil {
42 | 			delete(cmds, eventId)
43 | 			delete(cmdProcesses, eventId)
44 | 			if err == io.EOF {
45 | 				wruntime.EventsEmit(a.ctx, eventId+"-finish")
46 | 				return nil
47 | 			}
48 | 			return err
49 | 		}
50 | 		reader := transform.NewReader(bytes.NewReader(line), simplifiedchinese.GBK.NewDecoder())
51 | 		line2, err := io.ReadAll(reader)
52 | 		if err == nil {
53 | 			line = line2
54 | 		}
55 | 		strLine := string(line)
56 | 		wruntime.EventsEmit(a.ctx, eventId+"-output", strLine)
57 | 	}
58 | }
59 | 


--------------------------------------------------------------------------------
/backend-golang/hw_info.go:
--------------------------------------------------------------------------------
 1 | package backend_golang
 2 | 
 3 | import (
 4 | 	"errors"
 5 | 	"strconv"
 6 | 	"strings"
 7 | )
 8 | 
 9 | func (a *App) GetNvidiaGpuCount() (int, error) {
10 | 	// temperature.gpu,utilization.gpu,utilization.memory,memory.total,memory.free,memory.used
11 | 	// gpu_name,gpu_bus_id,driver_version
12 | 	// nvidia-smi --help-query-gpu
13 | 	output, err := a.CommandOutput("nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits")
14 | 	if err != nil {
15 | 		return 0, err
16 | 	}
17 | 	return strconv.Atoi(output)
18 | }
19 | 
20 | func (a *App) GetCudaComputeCapability(index int) (string, error) {
21 | 	output, err := a.CommandOutput("nvidia-smi", "-i="+strconv.Itoa(index), "--query-gpu=compute_cap", "--format=csv,noheader,nounits")
22 | 	if err != nil {
23 | 		return "", err
24 | 	}
25 | 
26 | 	if output == "" {
27 | 		return "", errors.New("compute capability is empty")
28 | 	}
29 | 
30 | 	return output, nil
31 | }
32 | 
33 | func (a *App) GetMaxCudaComputeCapability() (string, error) {
34 | 	gpuCount, err := a.GetNvidiaGpuCount()
35 | 	if err != nil {
36 | 		return "", err
37 | 	}
38 | 	maxComputeCap := "0.0"
39 | 	for i := 0; i < gpuCount; i++ {
40 | 		computeCap, err := a.GetCudaComputeCapability(i)
41 | 		if err != nil {
42 | 			return "", err
43 | 		}
44 | 		computeCapFloat, err := strconv.ParseFloat(computeCap, 64)
45 | 		if err != nil {
46 | 			return "", err
47 | 		}
48 | 		maxComputeCapFloat, err := strconv.ParseFloat(maxComputeCap, 64)
49 | 		if err != nil {
50 | 			return "", err
51 | 		}
52 | 		if computeCapFloat > maxComputeCapFloat {
53 | 			maxComputeCap = computeCap
54 | 		}
55 | 	}
56 | 	if maxComputeCap == "0.0" {
57 | 		return "", errors.New("no cuda compute capability")
58 | 	}
59 | 	return maxComputeCap, nil
60 | }
61 | 
62 | func (a *App) GetSupportedCudaVersion() (string, error) {
63 | 	output, err := a.CommandOutput("nvidia-smi", "--query")
64 | 	if err != nil {
65 | 		return "", err
66 | 	}
67 | 
68 | 	lines := strings.Split(output, "\n")
69 | 
70 | 	for _, line := range lines {
71 | 		if strings.Contains(line, "CUDA Version") {
72 | 			return strings.TrimSpace(strings.Split(line, ":")[1]), nil
73 | 		}
74 | 	}
75 | 
76 | 	return "", errors.New("cuda version is empty")
77 | }
78 | 
79 | func (a *App) GetTorchVersion(python string) (string, error) {
80 | 	var err error
81 | 	if python == "" {
82 | 		python, err = a.GetPython()
83 | 		if err != nil {
84 | 			return "", err
85 | 		}
86 | 	}
87 | 
88 | 	output, err := a.CommandOutput(python, "-c", "import torch; print(torch.__version__)")
89 | 	if err != nil {
90 | 		return "", err
91 | 	}
92 | 
93 | 	if output == "" {
94 | 		return "", errors.New("torch version is empty")
95 | 	}
96 | 
97 | 	return output, nil
98 | }
99 | 
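Note: GetMaxCudaComputeCapability queries every GPU reported by nvidia-smi and keeps the numerically largest compute capability. A rough Python equivalent of the same nvidia-smi calls, useful for checking the parsing outside the Go build; this is a sketch that assumes nvidia-smi is on PATH, and it takes only the first output line of --query-gpu=count since nvidia-smi may print one row per GPU there.

    import subprocess

    def _smi(*args: str) -> str:
        # Mirrors App.CommandOutput: run nvidia-smi and trim whitespace.
        return subprocess.check_output(["nvidia-smi", *args], text=True).strip()

    def max_compute_capability() -> str:
        count = int(_smi("--query-gpu=count", "--format=csv,noheader,nounits").splitlines()[0])
        caps = [
            _smi(f"-i={i}", "--query-gpu=compute_cap", "--format=csv,noheader,nounits")
            for i in range(count)
        ]
        return max(caps, key=float)  # e.g. "8.9"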


--------------------------------------------------------------------------------
/backend-golang/utils_unix.go:
--------------------------------------------------------------------------------
 1 | //go:build darwin || linux
 2 | 
 3 | package backend_golang
 4 | 
 5 | import (
 6 | 	"os/exec"
 7 | 	"strings"
 8 | )
 9 | 
10 | func CmdSetHideWindow(cmd *exec.Cmd, hideWindow bool) {
11 | }
12 | 
13 | func (a *App) CommandOutput(name string, args ...string) (string, error) {
14 | 	cmd := exec.Command(name, args...)
15 | 	output, err := cmd.CombinedOutput()
16 | 	if err != nil {
17 | 		return "", err
18 | 	}
19 | 	return strings.TrimSpace(string(output)), nil
20 | }
21 | 


--------------------------------------------------------------------------------
/backend-golang/utils_windows.go:
--------------------------------------------------------------------------------
 1 | //go:build windows
 2 | 
 3 | package backend_golang
 4 | 
 5 | import (
 6 | 	"os/exec"
 7 | 	"strings"
 8 | 	"syscall"
 9 | )
10 | 
11 | func CmdSetHideWindow(cmd *exec.Cmd, hideWindow bool) {
12 | 	if cmd.SysProcAttr == nil {
13 | 		cmd.SysProcAttr = &syscall.SysProcAttr{}
14 | 	}
15 | 	cmd.SysProcAttr.HideWindow = hideWindow
16 | }
17 | 
18 | func (a *App) CommandOutput(name string, args ...string) (string, error) {
19 | 	cmd := exec.Command(name, args...)
20 | 	CmdSetHideWindow(cmd, true)
21 | 	output, err := cmd.CombinedOutput()
22 | 	if err != nil {
23 | 		return "", err
24 | 	}
25 | 	return strings.TrimSpace(string(output)), nil
26 | }
27 | 


--------------------------------------------------------------------------------
/backend-golang/wsl_unix.go:
--------------------------------------------------------------------------------
 1 | //go:build darwin || linux
 2 | 
 3 | package backend_golang
 4 | 
 5 | import (
 6 | 	"errors"
 7 | )
 8 | 
 9 | func (a *App) WslStart() error {
10 | 	return errors.New("wsl not supported")
11 | }
12 | 
13 | func (a *App) WslCommand(command string) error {
14 | 	return errors.New("wsl not supported")
15 | }
16 | 
17 | func (a *App) WslStop() error {
18 | 	return errors.New("wsl not supported")
19 | }
20 | 
21 | func (a *App) WslIsEnabled() error {
22 | 	return errors.New("wsl not supported")
23 | }
24 | 
25 | func (a *App) WslEnable(forceMode bool) error {
26 | 	return errors.New("wsl not supported")
27 | }
28 | 
29 | func (a *App) WslInstallUbuntu() error {
30 | 	return errors.New("wsl not supported")
31 | }
32 | 


--------------------------------------------------------------------------------
/backend-python/dep_check.py:
--------------------------------------------------------------------------------
 1 | import setuptools
 2 | 
 3 | if setuptools.__version__ >= "70.0.0":
 4 |     raise ImportError("setuptools>=70.0.0 is not supported")
 5 | 
 6 | import multipart
 7 | import fitz
 8 | import safetensors
 9 | import midi2audio
10 | import mido
11 | import lm_dataformat
12 | import ftfy
13 | import tqdm
14 | import tiktoken
15 | 
16 | import torch
17 | import rwkv
18 | import langchain
19 | import numpy
20 | import tokenizers
21 | import fastapi
22 | import uvicorn
23 | import sse_starlette
24 | import pydantic
25 | import psutil
26 | 
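Note: setuptools.__version__ >= "70.0.0" is a lexicographic string comparison. It happens to work for the two-digit major versions setuptools currently ships, but it is not a true version comparison. A hedged sketch of the more robust form, using the packaging library (an extra dependency not listed in requirements.txt):

    from packaging.version import Version
    import setuptools

    if Version(setuptools.__version__) >= Version("70.0.0"):
        raise ImportError("setuptools>=70.0.0 is not supported")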


--------------------------------------------------------------------------------
/backend-python/global_var.py:
--------------------------------------------------------------------------------
 1 | from enum import Enum, auto
 2 | 
 3 | Args = "args"
 4 | Model = "model"
 5 | Model_Status = "model_status"
 6 | Model_Config = "model_config"
 7 | Deploy_Mode = "deploy_mode"
 8 | Midi_Vocab_Config_Type = "midi_vocab_config_type"
 9 | 
10 | 
11 | class ModelStatus(Enum):
12 |     Offline = 0
13 |     Loading = 2
14 |     Working = 3
15 | 
16 | 
17 | class MidiVocabConfig(Enum):
18 |     Default = auto()
19 |     Piano = auto()
20 | 
21 | 
22 | def init():
23 |     global GLOBALS
24 |     GLOBALS = {}
25 |     set(Model_Status, ModelStatus.Offline)
26 |     set(Deploy_Mode, False)
27 |     set(Midi_Vocab_Config_Type, MidiVocabConfig.Default)
28 | 
29 | 
30 | def set(key, value):
31 |     GLOBALS[key] = value
32 | 
33 | 
34 | def get(key):
35 |     if key in GLOBALS:
36 |         return GLOBALS[key]
37 |     else:
38 |         return None
39 | 
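Note: this module is a process-wide key-value store; callers run init() once at startup and then use set/get (which deliberately shadow the built-ins inside this module). A minimal usage sketch; the config value is illustrative:

    import global_var

    global_var.init()
    global_var.set(global_var.Model_Config, {"max_tokens": 1000})   # illustrative value
    print(global_var.get(global_var.Model_Status))                  # ModelStatus.Offline
    print(global_var.get("missing-key"))                            # None for unknown keys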


--------------------------------------------------------------------------------
/backend-python/requirements.txt:
--------------------------------------------------------------------------------
 1 | torch
 2 | torchvision
 3 | torchaudio
 4 | setuptools==69.5.1
 5 | rwkv==0.8.29
 6 | langchain==0.0.322
 7 | fastapi==0.109.1
 8 | uvicorn==0.23.2
 9 | sse-starlette==1.6.5
10 | pydantic==2.4.2
11 | psutil==5.9.6
12 | gputil==1.4.0
13 | tiktoken==0.5.1
14 | ftfy==6.1.1
15 | lm-dataformat==0.0.20
16 | numpy==1.24.4
17 | tokenizers==0.14.1
18 | tqdm==4.66.1
19 | midi2audio==0.1.1
20 | mido==1.3.0
21 | safetensors==0.4.0
22 | PyMuPDF==1.23.5
23 | python-multipart==0.0.7
24 | Cython==3.0.4
25 | 


--------------------------------------------------------------------------------
/backend-python/requirements_without_cyac.txt:
--------------------------------------------------------------------------------
 1 | torch
 2 | torchvision
 3 | torchaudio
 4 | setuptools==69.5.1
 5 | rwkv==0.8.29
 6 | langchain==0.0.322
 7 | fastapi==0.109.1
 8 | uvicorn==0.23.2
 9 | sse-starlette==1.6.5
10 | pydantic==2.4.2
11 | psutil==5.9.6
12 | gputil==1.4.0
13 | tiktoken==0.5.1
14 | ftfy==6.1.1
15 | lm-dataformat==0.0.20
16 | numpy==1.24.4
17 | tokenizers==0.14.1
18 | tqdm==4.66.1
19 | midi2audio==0.1.1
20 | mido==1.3.0
21 | safetensors==0.4.0
22 | PyMuPDF==1.23.5
23 | python-multipart==0.0.7
24 | Cython==3.0.4
25 | 


--------------------------------------------------------------------------------
/backend-python/routes/file_process.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from fastapi import (
 3 |     APIRouter,
 4 |     HTTPException,
 5 |     status,
 6 |     Depends,
 7 |     File,
 8 |     UploadFile,
 9 | )
10 | from pydantic import BaseModel
11 | from typing import Iterator
12 | 
13 | router = APIRouter()
14 | 
15 | 
16 | class FileToTextParams(BaseModel):
17 |     file_name: str
18 |     file_encoding: str = "utf-8"
19 | 
20 | 
21 | @router.post("/file-to-text", tags=["File Process"])
22 | async def file_to_text(
23 |     params: FileToTextParams = Depends(), file_data: UploadFile = File(...)
24 | ):
25 |     from langchain.schema import Document
26 |     from langchain.document_loaders.blob_loaders import Blob
27 | 
28 |     # from langchain
29 |     def parse_text(blob: Blob) -> Iterator[Document]:
30 |         yield Document(page_content=blob.as_string(), metadata={"source": blob.source})
31 | 
32 |     # from langchain
33 |     def parse_pdf(blob: Blob) -> Iterator[Document]:
34 |         import fitz
35 | 
36 |         with blob.as_bytes_io() as stream:
37 |             doc = fitz.Document(stream=stream)
38 | 
39 |             yield from [
40 |                 Document(
41 |                     page_content=page.get_text(),
42 |                     metadata=dict(
43 |                         {
44 |                             "source": blob.source,
45 |                             "file_path": blob.source,
46 |                             "page": page.number,
47 |                             "total_pages": len(doc),
48 |                         },
49 |                         **{
50 |                             k: doc.metadata[k]
51 |                             for k in doc.metadata
52 |                             if type(doc.metadata[k]) in [str, int]
53 |                         },
54 |                     ),
55 |                 )
56 |                 for page in doc
57 |             ]
58 | 
59 |     file_parsers = {".txt": parse_text, ".pdf": parse_pdf}
60 | 
61 |     file_name = file_data.filename or params.file_name
62 |     file_ext = os.path.splitext(file_name)[-1]
63 | 
64 |     if file_ext not in file_parsers:
65 |         raise HTTPException(status.HTTP_400_BAD_REQUEST, "file type not supported")
66 | 
67 |     try:
68 |         pages: Iterator[Document] = file_parsers[file_ext](
69 |             Blob.from_data(
70 |                 await file_data.read(),
71 |                 encoding=params.file_encoding,
72 |                 path=file_name,
73 |             )
74 |         )
75 |         pages = list(pages)
76 |     except Exception as e:
77 |         raise HTTPException(status.HTTP_400_BAD_REQUEST, f"{e}")
78 | 
79 |     return {"pages": pages}
80 | 
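Note: because FileToTextParams is injected with Depends(), file_name and file_encoding arrive as query parameters while the document itself is a multipart upload named file_data; only .txt and .pdf extensions are accepted. A client sketch with requests (the server address is an assumption):

    import requests

    with open("notes.txt", "rb") as f:
        resp = requests.post(
            "http://127.0.0.1:8000/file-to-text",            # assumed local backend address
            params={"file_name": "notes.txt", "file_encoding": "utf-8"},
            files={"file_data": ("notes.txt", f, "text/plain")},
        )
    print(resp.json()["pages"])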


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/cpp/librwkv.dylib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/cpp/librwkv.dylib


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/cpp/librwkv.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/cpp/librwkv.so


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/cpp/model.py:
--------------------------------------------------------------------------------
 1 | from typing import Any, List, Union
 2 | from . import rwkv_cpp_model
 3 | from . import rwkv_cpp_shared_library
 4 | 
 5 | 
 6 | class RWKV:
 7 |     def __init__(self, model_path: str, strategy=None):
 8 |         self.library = rwkv_cpp_shared_library.load_rwkv_shared_library()
 9 |         self.model = rwkv_cpp_model.RWKVModel(self.library, model_path)
10 |         self.w = {}  # fake weight
11 |         self.w["emb.weight"] = [0] * self.model.n_vocab
12 |         self.version = (
13 |             self.model.arch_version_major + self.model.arch_version_minor / 10
14 |         )
15 | 
16 |     def forward(self, tokens: List[int], state: Union[Any, None] = None):
17 |         return self.model.eval_sequence_in_chunks(tokens, state, use_numpy=True)
18 | 
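Note: this wrapper gives the rwkv.cpp backend the same surface as the CUDA path: a fake w["emb.weight"] list stands in for the vocabulary size, and forward() delegates to eval_sequence_in_chunks, which returns logits plus an updated state. A usage sketch assuming that (logits, state) return shape; the model path is hypothetical:

    from rwkv_pip.cpp.model import RWKV

    model = RWKV("models/rwkv-model.bin")         # hypothetical ggml model path
    logits, state = model.forward([0], None)      # first call: no prior state
    logits, state = model.forward([1, 2], state)  # later calls reuse the returned state
    print(len(model.w["emb.weight"]))             # vocabulary size exposed via the fake weight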


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/cpp/rwkv.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/cpp/rwkv.dll


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/cuda/rwkv5_op.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | #include "ATen/ATen.h"
 3 | #include <c10/cuda/CUDAGuard.h>
 4 | typedef at::BFloat16 bf16;
 5 | typedef at::Half fp16;
 6 | typedef float fp32;
 7 | 
 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11 | 
12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
13 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
14 |     cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
15 | }
16 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
17 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
18 |     cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
19 | }
20 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
21 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
22 |     cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
23 | }
24 | 
25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
26 |     m.def("forward_bf16", &forward_bf16, "rwkv5 forward_bf16");
27 |     m.def("forward_fp16", &forward_fp16, "rwkv5 forward_fp16");
28 |     m.def("forward_fp32", &forward_fp32, "rwkv5 forward_fp32");
29 | }
30 | TORCH_LIBRARY(rwkv5, m) {
31 |     m.def("forward_bf16", forward_bf16);
32 |     m.def("forward_fp16", forward_fp16);
33 |     m.def("forward_fp32", forward_fp32);
34 | }
35 | 
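Note: the extension is exposed twice, through pybind11 (module-level forward_* functions) and through TORCH_LIBRARY under the rwkv5 namespace, so once loaded it is also reachable as torch.ops.rwkv5.forward_fp16. A JIT-compilation sketch with torch.utils.cpp_extension.load; the source paths, the _N_ head-size define, and its value are assumptions, loosely mirroring the prebuilt kernels shipped under rwkv_pip/kernels:

    from torch.utils.cpp_extension import load

    rwkv5 = load(
        name="rwkv5",
        sources=[
            "backend-python/rwkv_pip/cuda/rwkv5_op.cpp",
            "backend-python/rwkv_pip/cuda/rwkv5.cu",
        ],
        verbose=True,
        extra_cuda_cflags=["-O3", "-D_N_=64"],   # head-size value is an assumption
    )
    # Afterwards the op is callable either as rwkv5.forward_fp16(...) or as
    # torch.ops.rwkv5.forward_fp16(...).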


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/cuda/rwkv6_op.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | #include "ATen/ATen.h"
 3 | #include <c10/cuda/CUDAGuard.h>
 4 | typedef at::BFloat16 bf16;
 5 | typedef at::Half fp16;
 6 | typedef float fp32;
 7 | 
 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *k, fp16 *v, float *w, fp16 *u, fp16 *y);
10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *k, fp32 *v, float *w, fp32 *u, fp32 *y);
11 | 
12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
13 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
14 |     cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
15 | }
16 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
17 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
18 |     cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), w.data_ptr<float>(), u.data_ptr<fp16>(), y.data_ptr<fp16>());
19 | }
20 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
21 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(state));
22 |     cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), w.data_ptr<float>(), u.data_ptr<fp32>(), y.data_ptr<fp32>());
23 | }
24 | 
25 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
26 |     m.def("forward_bf16", &forward_bf16, "rwkv6 forward_bf16");
27 |     m.def("forward_fp16", &forward_fp16, "rwkv6 forward_fp16");
28 |     m.def("forward_fp32", &forward_fp32, "rwkv6 forward_fp32");
29 | }
30 | TORCH_LIBRARY(rwkv6, m) {
31 |     m.def("forward_bf16", forward_bf16);
32 |     m.def("forward_fp16", forward_fp16);
33 |     m.def("forward_fp32", forward_fp32);
34 | }
35 | 


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/cuda/rwkv7.cu:
--------------------------------------------------------------------------------
 1 | #include <stdio.h>
 2 | #include <assert.h>
 3 | #include "ATen/ATen.h"
 4 | 
 5 | typedef at::Half fp16;
 6 | typedef at::BFloat16 bf16;
 7 | typedef float fp32;
 8 | 
 9 | template <typename F>
10 | __global__ void kernel_forward(const int B, const int T, const int C, const int H,
11 |                                float *__restrict__ _state, const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
12 |                                F *__restrict__ const _y)
13 | {
14 |     const int e = blockIdx.x / H;
15 |     const int h = blockIdx.x % H;
16 |     const int i = threadIdx.x;
17 |     _state += h*_N_*_N_ + i*_N_; // wrong if B > 1 !!!
18 | 
19 |     float state[_N_];
20 |     #pragma unroll
21 |     for (int j = 0; j < _N_; j++)
22 |         state[j] = _state[j];
23 | 
24 |     __shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];
25 | 
26 |     for (int _t = 0; _t < T; _t++)
27 |     {
28 |         const int t = e*T*C + h*_N_ + i + _t * C;
29 |         __syncthreads();
30 |         r[i] = float(_r[t]);
31 |         w[i] = __expf(-__expf(float(_w[t])));
32 |         k[i] = float(_k[t]);
33 |         a[i] = float(_a[t]);
34 |         b[i] = float(_b[t]);
35 |         __syncthreads();
36 | 
37 |         float sa = 0;
38 |         #pragma unroll
39 |         for (int j = 0; j < _N_; j++)
40 |         {
41 |             sa += a[j] * state[j];
42 |         }
43 | 
44 |         float vv = float(_v[t]);
45 |         float y = 0;
46 |         #pragma unroll
47 |         for (int j = 0; j < _N_; j++)
48 |         {
49 |             float& s = state[j];
50 |             s = s * w[j] + k[j] * vv + sa * b[j];
51 |             y += s * r[j];
52 |         }
53 |         _y[t] = F(y);
54 |     }
55 |     #pragma unroll
56 |     for (int j = 0; j < _N_; j++)
57 |         _state[j] = state[j];    
58 | }
59 | 
60 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
61 | {
62 |     assert(H*_N_ == C);
63 |     assert(B == 1); // only for B=1
64 |     kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
65 | }
66 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16* w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y)
67 | {
68 |     assert(H*_N_ == C);
69 |     assert(B == 1); // only for B=1
70 |     kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
71 | }
72 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32* w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y)
73 | {
74 |     assert(H*_N_ == C);
75 |     assert(B == 1); // only for B=1
76 |     kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, state, r, w, k, v, a, b, y);
77 | }
78 | 
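Note: per output channel i of a head, the kernel keeps one row of the N-by-N head state and applies, at each timestep, sa = sum_j a_j * S_ij, then S_ij <- S_ij * exp(-exp(w_j)) + k_j * v_i + sa * b_j, and y_i = sum_j S_ij * r_j. A NumPy reference for a single head, written only to make that recurrence explicit (the single-batch restriction follows the kernel's B == 1 assert):

    import numpy as np

    def wkv7_head(r, w, k, v, a, b, state):
        """r, w, k, v, a, b: (T, N) float arrays; state: (N, N) carried across calls."""
        T, N = r.shape
        y = np.zeros((T, N), dtype=np.float32)
        for t in range(T):
            decay = np.exp(-np.exp(w[t]))                  # w[i] = __expf(-__expf(_w[t]))
            sa = state @ a[t]                              # per-row dot product with a
            state = state * decay + np.outer(v[t], k[t]) + np.outer(sa, b[t])
            y[t] = state @ r[t]
        return y, state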


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/cuda/rwkv7_op.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | #include "ATen/ATen.h"
 3 | 
 4 | typedef at::Half fp16;
 5 | typedef at::BFloat16 bf16;
 6 | typedef float fp32;
 7 | 
 8 | void cuda_forward_bf16(int B, int T, int C, int H, float *state, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);
 9 | void cuda_forward_fp16(int B, int T, int C, int H, float *state, fp16 *r, fp16 *w, fp16 *k, fp16 *v, fp16 *a, fp16 *b, fp16 *y);
10 | void cuda_forward_fp32(int B, int T, int C, int H, float *state, fp32 *r, fp32 *w, fp32 *k, fp32 *v, fp32 *a, fp32 *b, fp32 *y);
11 | 
12 | void forward_bf16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
13 |     cuda_forward_bf16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
14 | }
15 | void forward_fp16(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
16 |     cuda_forward_fp16(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp16>(), w.data_ptr<fp16>(), k.data_ptr<fp16>(), v.data_ptr<fp16>(), a.data_ptr<fp16>(), b.data_ptr<fp16>(), y.data_ptr<fp16>());
17 | }
18 | void forward_fp32(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &state, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
19 |     cuda_forward_fp32(B, T, C, H, state.data_ptr<float>(), r.data_ptr<fp32>(), w.data_ptr<fp32>(), k.data_ptr<fp32>(), v.data_ptr<fp32>(), a.data_ptr<fp32>(), b.data_ptr<fp32>(), y.data_ptr<fp32>());
20 | }
21 | 
22 | TORCH_LIBRARY(wkv7s, m) {
23 |     m.def("forward_bf16", forward_bf16);
24 |     m.def("forward_fp16", forward_fp16);
25 |     m.def("forward_fp32", forward_fp32);
26 | }
27 | 


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/rwkv5.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/rwkv5.pyd


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/rwkv6.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/rwkv6.pyd


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/wkv7s.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/wkv7s.pyd


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/wkv_cuda.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-1.13.1+cu117/wkv_cuda.pyd


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/rwkv5.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/rwkv5.pyd


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/rwkv6.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/rwkv6.pyd


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/wkv7s.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/wkv7s.pyd


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/wkv_cuda.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/kernels/torch-2.7.1+cu128/wkv_cuda.pyd


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/webgpu/model.py:
--------------------------------------------------------------------------------
 1 | from typing import Any, List, Union
 2 | 
 3 | try:
 4 |     import web_rwkv_py as wrp
 5 | except ModuleNotFoundError:
 6 |     try:
 7 |         from . import web_rwkv_py as wrp
 8 |     except ImportError:
 9 |         raise ModuleNotFoundError(
10 |             "web_rwkv_py not found, install it from https://github.com/cryscan/web-rwkv-py"
11 |         )
12 | 
13 | 
14 | class RWKV:
15 |     def __init__(self, model_path: str, strategy: str = None):
16 |         layer = (
17 |             int(s.lstrip("layer"))
18 |             for s in strategy.split()
19 |             for s in s.split(",")
20 |             if s.startswith("layer")
21 |         )
22 | 
23 |         chunk_size = (
24 |             int(s.lstrip("chunk"))
25 |             for s in strategy.split()
26 |             for s in s.split(",")
27 |             if s.startswith("chunk")
28 |         )
29 |         self.token_chunk_size = next(chunk_size, 32)
30 | 
31 |         args = {
32 |             "path": model_path,
33 |             "quant": next(layer, 31) if "i8" in strategy else 0,
34 |             "quant_nf4": next(layer, 26) if "i4" in strategy else 0,
35 |         }
36 |         self.model = wrp.Model(**args)
37 |         self.info = self.model.info()
38 |         self.w = {}  # fake weight
39 |         self.w["emb.weight"] = [0] * self.info.num_vocab
40 |         self.version = str(self.info.version).lower()
41 |         self.version = float(self.version.replace("v", ""))
42 | 
43 |     def forward(self, tokens: List[int], state: Union[Any, None] = None):
44 |         if state is None:
45 |             self.model.clear_state()
46 |         elif type(state).__name__ == "State_Cpu":
47 |             self.model.load_state(state)
48 |         logits = self.model.run(tokens, self.token_chunk_size)
49 |         ret_state = "State_Gpu"
50 |         return logits, ret_state
51 | 
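
Note: the strategy string parsed above uses space-separated tokens such as "layer<N>" (used as the quantized layer count), "chunk<N>" (token chunk size), and "i8"/"i4" quantization flags. A minimal standalone sketch of that parsing, using a hypothetical strategy value (everything below is illustrative, not part of the module):

strategy = "webgpu i8 layer31 chunk64"  # hypothetical example value

# Same generator-based extraction as RWKV.__init__ above.
layer = (int(s.lstrip("layer")) for s in strategy.split() if s.startswith("layer"))
chunk = (int(s.lstrip("chunk")) for s in strategy.split() if s.startswith("chunk"))

print(next(chunk, 32))                             # 64 -> token_chunk_size
print(next(layer, 31) if "i8" in strategy else 0)  # 31 -> "quant" layer count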


--------------------------------------------------------------------------------
/backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/backend-python/rwkv_pip/webgpu/web_rwkv_py.cp310-win_amd64.pyd


--------------------------------------------------------------------------------
/backend-python/tests/postprocess_response.py:
--------------------------------------------------------------------------------
 1 | import re
 2 | 
 3 | 
 4 | def postprocess_response(s):
 5 |     REGEX_BLOCKS = r"([\w]+)[\s]*```[\w]*(.*?)```"
 6 |     REGEX_ARGS = r'"([^"]+)"\s*=\s*"([^"]+)"'
 7 | 
 8 |     name = re.search(REGEX_BLOCKS, s, re.DOTALL).group(1)
 9 |     function = re.search(REGEX_BLOCKS, s, re.DOTALL).group(2).strip()
10 |     arguments = dict(re.findall(REGEX_ARGS, function))
11 | 
12 |     print(f"Name:\n{name}")
13 |     print(f"Function:\n{function}")
14 |     print(f"arguments:\n{arguments}")
15 |     print()
16 | 
17 |     return
18 | 
19 | 
20 | def postprocess_response_reserved(s):
21 |     REGEX_BLOCKS = r"```[\w]*(.*?)```"
22 |     REGEX_FUNCTIONS = r"(\w+)*\("
23 |     REGEX_ARGS = r'"([^"]+)"\s*=\s*"([^"]+)"'
24 | 
25 |     blocks = re.findall(REGEX_BLOCKS, s, re.DOTALL)
26 |     print(f"Blocks:\n{blocks}")
27 |     for block in blocks:
28 |         functions = block.strip().split("\n")
29 |         print(f"Functions:\n{functions}")
30 |         print()
31 |         for function in functions:
32 |             name = re.search(REGEX_FUNCTIONS, function).group(1)
33 |             arguments = f"{dict(re.findall(REGEX_ARGS, function))}"
34 | 
35 |             print(function)
36 |             print(name)
37 |             print(arguments)
38 |             print()
39 | 
40 |     return
41 | 
42 | 
43 | if __name__ == "__main__":
44 |     str = """
45 |     some texts
46 |     some texts
47 |     some texts
48 |     some texts
49 | 
50 |     ```python\n
51 |     get_current_weather("location"= "Tokyo", "unit" ="None")\n
52 |     ```
53 | 
54 |     some texts
55 |     some texts
56 |     some texts
57 |     some texts
58 |     """
59 |     postprocess_response(str)
60 | 
61 |     str = """ get_exchange_rate
62 | ```python
63 | tool_call("base_currency"= "func_as_param('Hello World!')", "target_currency"= "CNY")
64 | ```"""
65 |     postprocess_response(str)
66 | 
67 |     str = """\
68 | get_current_weather
69 | ```python\n
70 | tool_call("location"= "Tokyo", "unit"= "None")\n
71 | ```"""
72 |     postprocess_response(str)
73 | 


--------------------------------------------------------------------------------
/backend-python/utils/log.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | import logging.handlers  # RotatingFileHandler requires the handlers submodule
 3 | from typing import Any, Union
 4 | from fastapi import Request
 5 | from pydantic import BaseModel
 6 | from enum import Enum
 7 | 
 8 | 
 9 | logger = logging.getLogger()
10 | logger.setLevel(logging.INFO)
11 | formatter = logging.Formatter("%(asctime)s - %(levelname)s\n%(message)s")
12 | fh = logging.handlers.RotatingFileHandler(
13 |     "api.log", mode="a", maxBytes=3 * 1024 * 1024, backupCount=3, encoding="utf-8"
14 | )
15 | fh.setFormatter(formatter)
16 | logger.addHandler(fh)
17 | 
18 | 
19 | class ClsEncoder(json.JSONEncoder):
20 |     def default(self, obj):
21 |         if isinstance(obj, BaseModel):
22 |             return obj.dict()
23 |         if isinstance(obj, Enum):
24 |             return obj.value
25 |         return super().default(obj)
26 | 
27 | 
28 | def quick_log(request: Union[Request, None], body: Any, response: str):
29 |     try:
30 |         logger.info(
31 |             f"Client: {request.client if request else ''}\nUrl: {request.url if request else ''}\n"
32 |             + (
33 |                 f"Body: {json.dumps(body.__dict__, ensure_ascii=False, cls=ClsEncoder)}\n"
34 |                 if body
35 |                 else ""
36 |             )
37 |             + (f"Data:\n{response}\n" if response else "")
38 |         )
39 |     except Exception as e:
40 |         logger.info(f"Error quick_log request:\n{e}")
41 | 
42 | 
43 | async def log_middleware(request: Request):
44 |     try:
45 |         logger.info(
46 |             f"Client: {request.client}\nUrl: {request.url}\nBody: {await request.body()}\n"
47 |         )
48 |     except Exception as e:
49 |         logger.info(f"Error log_middleware request:\n{e}")
50 | 
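
Note: a small usage sketch for ClsEncoder above, assuming it is run from the backend-python directory with the requirements installed; the Role enum and Msg model below are made up for illustration:

import json
from enum import Enum
from pydantic import BaseModel

from utils.log import ClsEncoder  # the encoder defined above


class Role(Enum):  # hypothetical enum
    USER = "user"


class Msg(BaseModel):  # hypothetical pydantic model
    role: Role
    content: str


# BaseModel instances are serialized via .dict(), Enum members via .value,
# so nested models/enums become plain JSON.
print(json.dumps({"body": Msg(role=Role.USER, content="hi")}, cls=ClsEncoder))
# -> {"body": {"role": "user", "content": "hi"}}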


--------------------------------------------------------------------------------
/backend-python/utils/midi_filter_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "deduplicate_md5": true,
3 |     "piece_split_delay": 10000,
4 |     "min_piece_length": 0
5 | }


--------------------------------------------------------------------------------
/backend-python/utils/ngrok.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import global_var
 3 | 
 4 | 
 5 | def ngrok_connect():
 6 |     from pyngrok import ngrok, conf
 7 | 
 8 |     conf.set_default(
 9 |         conf.PyngrokConfig(ngrok_path="./ngrok.exe" if os.name == "nt" else "./ngrok")
10 |     )
11 |     ngrok.set_auth_token(os.environ["ngrok_token"])
12 |     http_tunnel = ngrok.connect(global_var.get(global_var.Args).port)
13 |     print(f"ngrok url: {http_tunnel.public_url}")
14 | 


--------------------------------------------------------------------------------
/backend-python/utils/torch.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sysconfig
 3 | 
 4 | 
 5 | def set_torch():
 6 |     torch_path = os.path.join(sysconfig.get_paths()["purelib"], f"torch{os.sep}lib")
 7 |     paths = os.environ.get("PATH", "")
 8 |     if os.path.exists(torch_path):
 9 |         print(f"torch found: {torch_path}")
10 |         if torch_path in paths:
11 |             print("torch already set")
12 |         else:
13 |             os.environ["PATH"] = paths + os.pathsep + torch_path + os.pathsep
14 |             print("torch set")
15 |             # print("run:")
16 |             # print(f"set Path={paths + os.pathsep + torch_path + os.pathsep}")
17 |     else:
18 |         print("torch not found")
19 | 
20 | 
21 | def torch_gc():
22 |     try:
23 |         import torch
24 | 
25 |         if torch.cuda.is_available():
26 |             with torch.cuda.device(0):
27 |                 torch.cuda.empty_cache()
28 |                 torch.cuda.ipc_collect()
29 |     except:
30 |         pass  # prevent 'torch' has no attribute 'cuda' error, so user can use CPU or WebGPU
31 | 


--------------------------------------------------------------------------------
/backend-python/webui_server.py:
--------------------------------------------------------------------------------
 1 | from fastapi import FastAPI
 2 | from fastapi.middleware.gzip import GZipMiddleware
 3 | from fastapi.staticfiles import StaticFiles
 4 | import uvicorn
 5 | 
 6 | webui_server = FastAPI()
 7 | 
 8 | webui_server.add_middleware(GZipMiddleware, minimum_size=1000)
 9 | webui_server.mount(
10 |     "/", StaticFiles(directory="frontend/dist", html=True), name="static"
11 | )
12 | 
13 | if __name__ == "__main__":
14 |     uvicorn.run("webui_server:webui_server")
15 | 
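
Note: uvicorn.run also accepts host/port keyword arguments, so a sketch for serving the built frontend on another interface could look like this (the values are placeholders, not defaults of this project):

import uvicorn

# Serve frontend/dist on all interfaces instead of uvicorn's default 127.0.0.1:8000.
uvicorn.run("webui_server:webui_server", host="0.0.0.0", port=8000)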


--------------------------------------------------------------------------------
/build/README.md:
--------------------------------------------------------------------------------
 1 | # Build Directory
 2 | 
 3 | The build directory is used to house all the build files and assets for your application. 
 4 | 
 5 | The structure is:
 6 | 
 7 | * bin - Output directory
 8 | * darwin - macOS specific files
 9 | * windows - Windows specific files
10 | 
11 | ## Mac
12 | 
13 | The `darwin` directory holds files specific to Mac builds.
14 | These may be customised and used as part of the build. To return these files to the default state,
15 | simply delete them and build with `wails build`.
16 | 
17 | 
18 | The directory contains the following files:
19 | 
20 | - `Info.plist` - the main plist file used for Mac builds. It is used when building using `wails build`.
21 | - `Info.dev.plist` - same as the main plist file but used when building using `wails dev`.
22 | 
23 | ## Windows
24 | 
25 | The `windows` directory contains the manifest and rc files used when building with `wails build`.
26 | These may be customised for your application. To return these files to the default state, simply delete them and
27 | build with `wails build`.
28 | 
29 | - `icon.ico` - The icon used for the application. This is used when building using `wails build`. If you wish to
30 |   use a different icon, simply replace this file with your own. If it is missing, a new `icon.ico` file
31 |   will be created using the `appicon.png` file in the build directory.
32 | - `installer/*` - The files used to create the Windows installer. These are used when building using `wails build`.
33 | - `info.json` - Application details used for Windows builds. The data here will be used by the Windows installer,
34 |   as well as the application itself (right click the exe -> properties -> details)
35 | - `wails.exe.manifest` - The main application manifest file.


--------------------------------------------------------------------------------
/build/appicon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/build/appicon.png


--------------------------------------------------------------------------------
/build/darwin/Info.dev.plist:
--------------------------------------------------------------------------------
 1 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
 2 | <plist version="1.0">
 3 |     <dict>
 4 |         <key>CFBundlePackageType</key>
 5 |         <string>APPL</string>
 6 |         <key>CFBundleName</key>
 7 |         <string>{{.Info.ProductName}}</string>
 8 |         <key>CFBundleExecutable</key>
 9 |         <string>{{.Name}}</string>
10 |         <key>CFBundleIdentifier</key>
11 |         <string>dev.josStorer.RWKV-Runner</string>
12 |         <key>CFBundleVersion</key>
13 |         <string>{{.Info.ProductVersion}}</string>
14 |         <key>CFBundleGetInfoString</key>
15 |         <string>{{.Info.Comments}}</string>
16 |         <key>CFBundleShortVersionString</key>
17 |         <string>{{.Info.ProductVersion}}</string>
18 |         <key>CFBundleIconFile</key>
19 |         <string>iconfile</string>
20 |         <key>LSMinimumSystemVersion</key>
21 |         <string>10.13.0</string>
22 |         <key>NSHighResolutionCapable</key>
23 |         <string>true</string>
24 |         <key>NSHumanReadableCopyright</key>
25 |         <string>{{.Info.Copyright}}</string>
26 |         <key>NSAppTransportSecurity</key>
27 |         <dict>
28 |             <key>NSAllowsLocalNetworking</key>
29 |             <true/>
30 |         </dict>
31 |     </dict>
32 | </plist>


--------------------------------------------------------------------------------
/build/darwin/Info.plist:
--------------------------------------------------------------------------------
 1 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
 2 | <plist version="1.0">
 3 |     <dict>
 4 |         <key>CFBundlePackageType</key>
 5 |         <string>APPL</string>
 6 |         <key>CFBundleName</key>
 7 |         <string>{{.Info.ProductName}}</string>
 8 |         <key>CFBundleExecutable</key>
 9 |         <string>{{.Name}}</string>
10 |         <key>CFBundleIdentifier</key>
11 |         <string>dev.josStorer.RWKV-Runner</string>
12 |         <key>CFBundleVersion</key>
13 |         <string>{{.Info.ProductVersion}}</string>
14 |         <key>CFBundleGetInfoString</key>
15 |         <string>{{.Info.Comments}}</string>
16 |         <key>CFBundleShortVersionString</key>
17 |         <string>{{.Info.ProductVersion}}</string>
18 |         <key>CFBundleIconFile</key>
19 |         <string>iconfile</string>
20 |         <key>LSMinimumSystemVersion</key>
21 |         <string>10.13.0</string>
22 |         <key>NSHighResolutionCapable</key>
23 |         <string>true</string>
24 |         <key>NSHumanReadableCopyright</key>
25 |         <string>{{.Info.Copyright}}</string>
26 |     </dict>
27 | </plist>


--------------------------------------------------------------------------------
/build/darwin/Readme_Install.txt:
--------------------------------------------------------------------------------
 1 | Client Download URL:
 2 | 客户端下载地址:
 3 | クライアントのダウンロードURL:
 4 | https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner_macos_universal.zip
 5 | 
 6 | For Mac and Linux users, please manually install Python 3.10 (usually the latest systems come with it built-in). You can specify the Python interpreter to use in Settings. (which python3)
 7 | 对于Mac和Linux用户,请手动安装 Python3.10 (通常最新的系统已经内置了). 你可以在设置中指定使用的Python解释器. (which python3)
 8 | MacおよびLinuxのユーザーの方は、Python3.10を手動でインストールしてください(通常、最新のシステムには既に組み込まれています)。 設定メニューで使用するPythonインタプリタを指定することができます。 (which python3)
 9 | 
10 | Please execute this program in an empty directory. All related dependencies will be placed in this directory.
11 | 请将本程序放在一个空目录内执行, 所有相关依赖均会放置于此目录.
12 | このプログラムを空のディレクトリで実行してください。関連するすべての依存関係は、このディレクトリに配置されます。
13 | 
14 | Please execute the following command in the terminal to remove the permission restrictions of this app, and then this program can work properly:
15 | 请在终端执行以下命令解除本app的权限限制, 然后本程序才可以正常工作:
16 | このアプリの権限制限を解除するために、ターミナルで以下のコマンドを実行してください。その後、このプログラムは正常に動作するようになります:
17 | 
18 | sudo xattr -r -d com.apple.quarantine ./RWKV-Runner.app
19 | 


--------------------------------------------------------------------------------
/build/darwin/entitlements.plist:
--------------------------------------------------------------------------------
 1 | <?xml version="1.0" encoding="UTF-8"?>
 2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
 3 | <plist version="1.0">
 4 | <dict>
 5 |   <key>com.apple.security.app-sandbox</key>
 6 |   <false/>
 7 |   <key>com.apple.security.network.client</key>
 8 |   <true/>
 9 |   <key>com.apple.security.network.server</key>
10 |   <true/>
11 |   <key>com.apple.security.files.user-selected.read-write</key>
12 |   <true/>
13 |   <key>com.apple.security.files.downloads.read-write</key>
14 |   <true/>
15 | </dict>
16 | </plist>


--------------------------------------------------------------------------------
/build/darwin/gon-sign.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "source": [
 3 |         "./build/bin/RWKV-Runner_darwin.app"
 4 |     ],
 5 |     "bundle_id": "dev.josStorer.RWKV-Runner",
 6 |     "apple_id": {
 7 |         "username": "joshua1466587594@outlook.com",
 8 |         "password": ""
 9 |     },
10 |     "sign": {
11 |         "application_identity": "D00A983569B4EAA2A008B963254F385F42A493FD",
12 |         "entitlements_file": "./build/darwin/entitlements.plist"
13 |     },
14 |     "zip": {
15 |         "output_path": "./build/bin/RWKV-Runner_darwin.archive.zip"
16 |     }
17 | }


--------------------------------------------------------------------------------
/build/linux/Readme_Install.txt:
--------------------------------------------------------------------------------
 1 | Client Download URL:
 2 | 客户端下载地址:
 3 | クライアントのダウンロードURL:
 4 | https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner_linux_x64
 5 | 
 6 | For Mac and Linux users, please manually install Python 3.10 (usually the latest systems come with it built-in). You can specify the Python interpreter to use in Settings.
 7 | 对于Mac和Linux用户,请手动安装 Python3.10 (通常最新的系统已经内置了). 你可以在设置中指定使用的Python解释器.
 8 | MacおよびLinuxのユーザーの方は、Python3.10を手動でインストールしてください(通常、最新のシステムには既に組み込まれています)。 設定メニューで使用するPythonインタプリタを指定することができます。
 9 | 
10 | Please execute this program in an empty directory. All related dependencies will be placed in this directory.
11 | 请将本程序放在一个空目录内执行, 所有相关依赖均会放置于此目录.
12 | このプログラムを空のディレクトリで実行してください。関連するすべての依存関係は、このディレクトリに配置されます。
13 | 
14 | On Linux systems, this program cannot invoke the terminal to install dependencies automatically. You must run the following commands manually before it can be used normally:
15 | 在Linux系统下, 本程序无法调用终端自动安装依赖, 你必须手动执行以下命令进行安装, 之后方可正常使用:
16 | Linuxシステムでは、このプログラムはターミナルを自動的に呼び出して依存関係をインストールすることができません。以下のコマンドを手動で実行する必要があります。それが完了した後に、正常に使用することができます:
17 | 
18 | sudo apt install python3-dev
19 | chmod +x ./RWKV-Runner
20 | ./RWKV-Runner
21 | cd backend-python
22 | pip3 install -r requirements.txt # or pip3 install -r requirements_without_cyac.txt
23 | 
24 | # See More: https://github.com/josStorer/RWKV-Runner/tree/master/deploy-examples
25 | 


--------------------------------------------------------------------------------
/build/windows/Readme_Install.txt:
--------------------------------------------------------------------------------
 1 | Client Download URL:
 2 | 客户端下载地址:
 3 | クライアントのダウンロードURL:
 4 | https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner-amd64-installer.exe
 5 | https://github.com/josStorer/RWKV-Runner/releases/latest/download/RWKV-Runner_windows_x64.exe
 6 | 
 7 | Please execute this program in an empty directory. All related dependencies will be placed in this directory.
 8 | 请将本程序放在一个空目录内执行, 所有相关依赖均会放置于此目录.
 9 | このプログラムを空のディレクトリで実行してください。関連するすべての依存関係は、このディレクトリに配置されます。
10 | 


--------------------------------------------------------------------------------
/build/windows/WELCOMEFINISHPAGE.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/build/windows/WELCOMEFINISHPAGE.bmp


--------------------------------------------------------------------------------
/build/windows/icon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/build/windows/icon.ico


--------------------------------------------------------------------------------
/build/windows/info.json:
--------------------------------------------------------------------------------
 1 | {
 2 | 	"fixed": {
 3 | 		"file_version": "{{.Info.ProductVersion}}"
 4 | 	},
 5 | 	"info": {
 6 | 		"0000": {
 7 | 			"ProductVersion": "{{.Info.ProductVersion}}",
 8 | 			"CompanyName": "{{.Info.CompanyName}}",
 9 | 			"FileDescription": "{{.Info.ProductName}}",
10 | 			"LegalCopyright": "{{.Info.Copyright}}",
11 | 			"ProductName": "{{.Info.ProductName}}",
12 | 			"Comments": "{{.Info.Comments}}"
13 | 		}
14 | 	}
15 | }


--------------------------------------------------------------------------------
/build/windows/wails.exe.manifest:
--------------------------------------------------------------------------------
 1 | <?xml version="1.0" encoding="UTF-8" standalone="yes"?>
 2 | <assembly manifestVersion="1.0" xmlns="urn:schemas-microsoft-com:asm.v1" xmlns:asmv3="urn:schemas-microsoft-com:asm.v3">
 3 |     <assemblyIdentity type="win32" name="com.wails.{{.Name}}" version="{{.Info.ProductVersion}}.0" processorArchitecture="*"/>
 4 |     <dependency>
 5 |         <dependentAssembly>
 6 |             <assemblyIdentity type="win32" name="Microsoft.Windows.Common-Controls" version="6.0.0.0" processorArchitecture="*" publicKeyToken="6595b64144ccf1df" language="*"/>
 7 |         </dependentAssembly>
 8 |     </dependency>
 9 |     <asmv3:application>
10 |         <asmv3:windowsSettings>
11 |             <dpiAware xmlns="http://schemas.microsoft.com/SMI/2005/WindowsSettings">true/pm</dpiAware> <!-- fallback for Windows 7 and 8 -->
12 |             <dpiAwareness xmlns="http://schemas.microsoft.com/SMI/2016/WindowsSettings">permonitorv2,permonitor</dpiAwareness> <!-- falls back to per-monitor if per-monitor v2 is not supported -->
13 |         </asmv3:windowsSettings>
14 |     </asmv3:application>
15 | </assembly>


--------------------------------------------------------------------------------
/components/gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/components/gitkeep


--------------------------------------------------------------------------------
/deploy-examples/ChatGPT-Next-Web/setup.bat:
--------------------------------------------------------------------------------
 1 | : install git python3.10 yarn by yourself
 2 | : change model and strategy according to your hardware
 3 | 
 4 | mkdir RWKV-Next-Web
 5 | cd RWKV-Next-Web
 6 | 
 7 | git clone https://github.com/josStorer/RWKV-Runner --depth=1
 8 | python -m pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --index-url https://download.pytorch.org/whl/cu117
 9 | python -m pip install -r RWKV-Runner/backend-python/requirements.txt
10 | start python ./RWKV-Runner/backend-python/main.py
11 | 
12 | powershell -Command "(Test-Path ./RWKV-Runner/models) -or (mkdir RWKV-Runner/models)"
13 | powershell -Command "Import-Module BitsTransfer"
14 | powershell -Command "(Test-Path ./RWKV-Runner/models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth) -or (Start-BitsTransfer https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth ./RWKV-Runner/models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth)"
15 | powershell -Command "Invoke-WebRequest http://127.0.0.1:8000/switch-model -Method POST -ContentType 'application/json' -Body '{\"model\":\"./RWKV-Runner/models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth\",\"strategy\":\"cuda fp32 *20+\"}'"
16 | 
17 | git clone https://github.com/Yidadaa/ChatGPT-Next-Web --depth=1
18 | cd ChatGPT-Next-Web
19 | call yarn install
20 | call yarn build
21 | set PROXY_URL=""
22 | set BASE_URL=http://127.0.0.1:8000
23 | start "C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe" "http://127.0.0.1:3000"
24 | yarn start
25 | 


--------------------------------------------------------------------------------
/deploy-examples/ChatGPT-Next-Web/setup.sh:
--------------------------------------------------------------------------------
 1 | # install git python3.10 yarn by yourself
 2 | # change model and strategy according to your hardware
 3 | 
 4 | sudo apt install python3-dev
 5 | 
 6 | mkdir RWKV-Next-Web
 7 | cd RWKV-Next-Web
 8 | 
 9 | git clone https://github.com/josStorer/RWKV-Runner --depth=1
10 | python3 -m pip install torch torchvision torchaudio
11 | python3 -m pip install -r RWKV-Runner/backend-python/requirements.txt
12 | python3 ./RWKV-Runner/backend-python/main.py > log.txt & # this is only an example, you should use screen or other tools to run it in background
13 | 
14 | if [ ! -d RWKV-Runner/models ]; then
15 |     mkdir RWKV-Runner/models
16 | fi
17 | wget -N https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth -P RWKV-Runner/models/
18 | 
19 | git clone https://github.com/Yidadaa/ChatGPT-Next-Web --depth=1
20 | cd ChatGPT-Next-Web
21 | yarn install
22 | yarn build
23 | export PROXY_URL=""
24 | export BASE_URL=http://127.0.0.1:8000
25 | yarn start & # this is only an example, you should use screen or other tools to run it in background
26 | 
27 | curl http://127.0.0.1:8000/switch-model -X POST -H "Content-Type: application/json" -d '{"model":"./RWKV-Runner/models/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth","strategy":"cpu fp32"}'
28 | 
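
Note: the final curl call above can also be issued from Python; a rough equivalent with the requests package (the model path and strategy are just the values used in this example, and the commented "deploy" field is the extra flag used by the RWKV-Runner-WebUI example below):

import requests

resp = requests.post(
    "http://127.0.0.1:8000/switch-model",
    json={
        "model": "./RWKV-Runner/models/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth",
        "strategy": "cpu fp32",
        # "deploy": "true",  # used by the RWKV-Runner-WebUI deploy example
    },
)
print(resp.status_code, resp.text)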


--------------------------------------------------------------------------------
/deploy-examples/RWKV-Runner-WebUI/setup.bat:
--------------------------------------------------------------------------------
 1 | : install git python3.10 npm by yourself
 2 | : change model and strategy according to your hardware
 3 | 
 4 | git clone https://github.com/josStorer/RWKV-Runner --depth=1
 5 | python -m pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 --index-url https://download.pytorch.org/whl/cu117
 6 | python -m pip install -r RWKV-Runner/backend-python/requirements.txt
 7 | cd RWKV-Runner/frontend
 8 | call npm ci
 9 | call npm run build
10 | cd ..
11 | 
12 | : optional: set ngrok_token=YOUR_NGROK_TOKEN
13 | start python ./backend-python/main.py --webui
14 | start "C:\Program Files (x86)\Microsoft\Edge\Application\msedge.exe" "http://127.0.0.1:8000"
15 | 
16 | powershell -Command "(Test-Path ./models) -or (mkdir models)"
17 | powershell -Command "Import-Module BitsTransfer"
18 | powershell -Command "(Test-Path ./models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth) -or (Start-BitsTransfer https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth ./models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth)"
19 | powershell -Command "Invoke-WebRequest http://127.0.0.1:8000/switch-model -Method POST -ContentType 'application/json' -Body '{\"model\":\"./models/RWKV-4-World-1.5B-v1-fixed-20230612-ctx4096.pth\",\"strategy\":\"cuda fp32 *20+\",\"deploy\":\"true\"}'"
20 | 


--------------------------------------------------------------------------------
/deploy-examples/RWKV-Runner-WebUI/setup.sh:
--------------------------------------------------------------------------------
 1 | # install git python3.10 npm by yourself
 2 | # change model and strategy according to your hardware
 3 | 
 4 | sudo apt install python3-dev
 5 | 
 6 | git clone https://github.com/josStorer/RWKV-Runner --depth=1
 7 | python3 -m pip install torch torchvision torchaudio
 8 | python3 -m pip install -r RWKV-Runner/backend-python/requirements.txt
 9 | cd RWKV-Runner/frontend
10 | npm ci
11 | npm run build
12 | cd ..
13 | 
14 | # optional: export ngrok_token=YOUR_NGROK_TOKEN
15 | python3 ./backend-python/main.py --webui > log.txt & # this is only an example, you should use screen or other tools to run it in background
16 | 
17 | if [ ! -d models ]; then
18 |     mkdir models
19 | fi
20 | wget -N https://huggingface.co/BlinkDL/rwkv-4-world/resolve/main/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth -P models/
21 | 
22 | curl http://127.0.0.1:8000/switch-model -X POST -H "Content-Type: application/json" -d '{"model":"./models/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth","strategy":"cpu fp32","deploy":"true"}'
23 | 


--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
 1 | services:
 2 |   rwkv_runner:
 3 |     image: rwkv-runner:latest
 4 |     build: .
 5 |     # Append "--rwkv.cpp" parameter to use rwkv.cpp
 6 |     # command: python3.10 ./backend-python/main.py  --port 27777 --host 0.0.0.0 --webui --rwkv.cpp
 7 |     volumes:
 8 |       - /mnt:/mnt
 9 |     ports:
10 |       - "27777:27777"
11 |     # Comment out the following lines if using rwkv.cpp
12 |     deploy:
13 |       resources:
14 |         reservations:
15 |           devices:
16 |             - driver: nvidia
17 |               count: 1
18 |               capabilities: [gpu]
19 | 


--------------------------------------------------------------------------------
/exportModelsJson.js:
--------------------------------------------------------------------------------
 1 | // Execute this script on the Hugging Face files list page to export JSON data. Don't forget to click "Load more files".
 2 | // Run console.log(JSON.stringify(modelsJson, null, 2)) to output the JSON to the console.
 3 | 
 4 | let modelsJson = []
 5 | 
 6 | function extractValue(text, prefix) {
 7 |   let ret
 8 |   text.split('\n').forEach(line => {
 9 |     if (!ret && line.startsWith(prefix))
10 |       ret = line.replace(prefix, '').trim()
11 |   })
12 |   return ret || ''
13 | }
14 | 
15 | document.querySelectorAll('.grid.h-10.grid-cols-12.place-content-center.gap-x-3.border-t.px-3.dark\\:border-gray-800').forEach(async e => {
16 |   let data = {}
17 |   data.name = e.children[0].children[0].textContent.trim()
18 | 
19 |   if (!data.name.endsWith('.bin') && !data.name.endsWith('.pth') && !data.name.endsWith('.gguf'))
20 |     return
21 | 
22 |   data.desc = { en: '', zh: '', ja: '' }
23 |   const rawText = await (await fetch(e.children[1].href.replace('/resolve/', '/raw/'))).text()
24 | 
25 |   data.size = parseInt(extractValue(rawText, 'size'))
26 |   data.SHA256 = extractValue(rawText, 'oid sha256:')
27 |   data.lastUpdated = e.children[3].children[0].getAttribute('datetime')
28 |   data.url = e.children[1].href.replace('/resolve/', '/blob/').replace('?download=true', '')
29 |   data.downloadUrl = e.children[1].href.replace('?download=true', '')
30 |   data.tags = []
31 | 
32 |   modelsJson.push(data)
33 | })
34 | 
35 | setTimeout(() => {
36 |   console.log(JSON.stringify(modelsJson, null, 2))
37 | }, 500)
38 | 


--------------------------------------------------------------------------------
/finetune/data/sample.jsonl:
--------------------------------------------------------------------------------
1 | {"text": "The following is an epic science fiction masterpiece that is immortalized, with delicate descriptions and grand depictions of interstellar civilization wars.\nChapter 1.\nAs I sit down to write here amidst the shadows of vine-leaves under the blue sky of southern Italy, it comes to me with a certain quality of astonishment that my participation in these amazing adventures of Mr. Cavor was, after all, the outcome of the purest accident. It might have been any one. I fell into these things at a time when I thought myself removed from the slightest possibility of disturbing experiences. I had gone to Lympne because I had imagined it the most uneventful place in the world. “Here, at any rate,” said I, “I shall find peace and a chance to work!”"}
2 | {"text": "Translate the following into Chinese.\n\nEnglish: What rooms do you have available?\nChinese: 你们有哪些房间可以提供"}
3 | {"text": "User: Hello.\n\nAssistant: I'm here, meow~.\n\nUser: Can you tell some jokes?\n\nAssistant: Of course, master. What kind of jokes would you like to hear?"}
4 | {"text": "Instruction: Write a story using the following information\n\nInput: A man named Alex chops a tree down\n\nResponse: Once upon a time, there was a man named Alex who lived in the heart of the forest. He had always been fascinated by trees and spent most of his days exploring the forest and learning about its many wonders. One day, while wandering through the woods, he stumbled upon an old oak tree that stood tall and proud in the middle of a clearing."}
5 | {"text": "def get_args(args: Union[Sequence[str], None] = None):\n    parser = argparse.ArgumentParser()\n    group = parser.add_argument_group(title=\"server arguments\")\n    group.add_argument(\n        \"--port\",\n        type=int,\n        default=8000,\n        help=\"port to run the server on (default: 8000)\",\n    )\n    group.add_argument(\n        \"--host\",\n        type=str,\n        default=\"127.0.0.1\",\n        help=\"host to run the server on (default: 127.0.0.1)\",\n    )"}


--------------------------------------------------------------------------------
/finetune/get_layer_and_embd.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import sys
 3 | import time
 4 | import os
 5 | import threading
 6 | import gc
 7 | 
 8 | 
 9 | def file_cleaner(file):
10 |     last_pos = 0
11 | 
12 |     def cleaner():
13 |         nonlocal last_pos
14 |         while True:
15 |             time.sleep(0.1)
16 |             pos = file.tell()
17 |             if pos > last_pos:
18 |                 os.posix_fadvise(
19 |                     file.fileno(), last_pos, pos - last_pos, os.POSIX_FADV_DONTNEED
20 |                 )
21 |             last_pos = pos
22 | 
23 |     return cleaner
24 | 
25 | 
26 | expected_max_version = float(sys.argv[2]) if len(sys.argv) > 2 else 100
27 | model_file = open(sys.argv[1], "rb")
28 | cleaner = file_cleaner(model_file)
29 | cleaner_thread = threading.Thread(target=cleaner, daemon=True)
30 | cleaner_thread.start()
31 | 
32 | w = torch.load(model_file, map_location="cpu")
33 | gc.collect()
34 | 
35 | vocab_size = w["emb.weight"].shape[0]
36 | n_embd = w["emb.weight"].shape[1]
37 | n_layer = 0
38 | keys = list(w.keys())
39 | version = 4
40 | for x in keys:
41 |     layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
42 |     n_layer = max(n_layer, layer_id + 1)
43 | 
44 |     if "ln_x" in x:
45 |         version = max(5, version)
46 |     if "gate.weight" in x:
47 |         version = max(5.1, version)
48 |     if int(version) == 5 and "att.time_decay" in x:
49 |         if len(w[x].shape) > 1:
50 |             if w[x].shape[1] > 1:
51 |                 version = max(5.2, version)
52 |     if "time_maa" in x:
53 |         version = max(6, version)
54 | 
55 | params = f"--vocab_size {vocab_size} --n_layer {n_layer} --n_embd {n_embd}"
56 | 
57 | if version <= expected_max_version:
58 |     if version == 6:
59 |         params += ' --my_testing "x060"'
60 |     print(
61 |         f"v{int(version)}/train.py {params}",
62 |         end="",
63 |     )
64 | else:
65 |     raise Exception(f"RWKV{version} is not supported")
66 | 
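
Note: a self-contained sketch of the version heuristic above, run against made-up key names instead of a real checkpoint (the key names are illustrative only; the v5.2 check is omitted because it needs tensor shapes):

# Fake state-dict keys mimicking an RWKV v6 checkpoint; only the names matter here.
keys = [
    "emb.weight",
    "blocks.0.att.ln_x.weight",   # "ln_x"        -> at least v5
    "blocks.0.att.gate.weight",   # "gate.weight" -> at least v5.1
    "blocks.0.att.time_maa_x",    # "time_maa"    -> at least v6
    "blocks.23.ffn.value.weight",
]

n_layer, version = 0, 4
for x in keys:
    layer_id = int(x.split(".")[1]) if "blocks." in x else 0
    n_layer = max(n_layer, layer_id + 1)
    if "ln_x" in x:
        version = max(5, version)
    if "gate.weight" in x:
        version = max(5.1, version)
    if "time_maa" in x:
        version = max(6, version)

print(n_layer, version)  # -> 24 6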


--------------------------------------------------------------------------------
/finetune/install-wsl-dep-and-train.sh:
--------------------------------------------------------------------------------
 1 | echo $@
 2 | 
 3 | if [[ ${cnMirror} == 1 ]]; then
 4 |   export PIP_INDEX_URL="https://mirrors.aliyun.com/pypi/simple"
 5 |   if grep -q "mirrors.aliyun.com" /etc/apt/sources.list; then
 6 |     echo "apt cnMirror already set"
 7 |   else
 8 |     sudo sed -i 's/http:\/\/archive.ubuntu.com\/ubuntu\//http:\/\/mirrors.aliyun.com\/ubuntu\//g' /etc/apt/sources.list
 9 |     sudo apt update
10 |   fi
11 | fi
12 | 
13 | if dpkg -s "gcc" >/dev/null 2>&1; then
14 |   echo "gcc installed"
15 | else
16 |   sudo apt -y install gcc
17 | fi
18 | 
19 | if dpkg -s "python3-pip" >/dev/null 2>&1; then
20 |   echo "pip installed"
21 | else
22 |   sudo apt -y install python3-pip
23 | fi
24 | 
25 | if dpkg -s "python3-dev" >/dev/null 2>&1; then
26 |   echo "python3-dev installed"
27 | else
28 |   sudo apt -y install python3-dev
29 | fi
30 | 
31 | if dpkg -s "ninja-build" >/dev/null 2>&1; then
32 |   echo "ninja installed"
33 | else
34 |   sudo apt -y install ninja-build
35 | fi
36 | 
37 | if dpkg -s "cuda" >/dev/null 2>&1 && dpkg -s "cuda" | grep Version | awk '{print $2}' | grep -q "12"; then
38 |   echo "cuda 12 installed"
39 | else
40 |   wget -N https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-wsl-ubuntu.pin
41 |   sudo mv cuda-wsl-ubuntu.pin /etc/apt/preferences.d/cuda-repository-pin-600
42 |   wget -N https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda-repo-wsl-ubuntu-12-2-local_12.2.0-1_amd64.deb
43 |   sudo dpkg -i cuda-repo-wsl-ubuntu-12-2-local_12.2.0-1_amd64.deb
44 |   sudo cp /var/cuda-repo-wsl-ubuntu-12-2-local/cuda-*-keyring.gpg /usr/share/keyrings/
45 |   sudo apt-get update
46 |   sudo apt-get -y install cuda
47 | fi
48 | 
49 | if python3 -c "import pkg_resources; pkg_resources.require(open('./finetune/requirements.txt',mode='r'))" &>/dev/null; then
50 |   echo "requirements satisfied"
51 | else
52 |   python3 -m pip install -r ./finetune/requirements.txt
53 | fi
54 | 
55 | echo "loading $loadModel"
56 | modelInfo=$(python3 ./finetune/get_layer_and_embd.py $loadModel 6.0)
57 | echo $modelInfo
58 | if [[ $modelInfo =~ "--n_layer" ]]; then
59 |   sudo rm -rf /root/.cache/torch_extensions
60 |   python3 ./finetune/lora/$modelInfo $@ --proj_dir lora-models --data_type binidx --lora \
61 |     --lora_parts=att,ffn,time,ln --strategy deepspeed_stage_2 --accelerator gpu --ds_bucket_mb 2
62 | else
63 |   echo "modelInfo is invalid"
64 |   exit 1
65 | fi
66 | 


--------------------------------------------------------------------------------
/finetune/lora/merge_lora.py:
--------------------------------------------------------------------------------
 1 | from collections import OrderedDict
 2 | import os
 3 | import sys
 4 | from typing import Dict
 5 | import typing
 6 | import torch
 7 | 
 8 | try:
 9 |     if "-h" in sys.argv or "--help" in sys.argv:
10 |         print(
11 |             f"Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>"
12 |         )
13 | 
14 |     if sys.argv[1] == "--use-gpu":
15 |         device = "cuda"
16 |         lora_alpha, base_model, lora, output = (
17 |             float(sys.argv[2]),
18 |             sys.argv[3],
19 |             sys.argv[4],
20 |             sys.argv[5],
21 |         )
22 |     else:
23 |         device = "cpu"
24 |         lora_alpha, base_model, lora, output = (
25 |             float(sys.argv[1]),
26 |             sys.argv[2],
27 |             sys.argv[3],
28 |             sys.argv[4],
29 |         )
30 | 
31 |     with torch.no_grad():
32 |         w: Dict[str, torch.Tensor] = torch.load(base_model, map_location="cpu")
33 |         # merge LoRA-only slim checkpoint into the main weights
34 |         w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location="cpu")
35 |         for k in w_lora.keys():
36 |             w[k] = w_lora[k]
37 |         output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
38 |         # merge LoRA weights
39 |         keys = list(w.keys())
40 |         for k in keys:
41 |             if k.endswith(".weight"):
42 |                 prefix = k[: -len(".weight")]
43 |                 lora_A = prefix + ".lora_A"
44 |                 lora_B = prefix + ".lora_B"
45 |                 if lora_A in keys:
46 |                     assert lora_B in keys
47 |                     print(f"merging {lora_A} and {lora_B} into {k}")
48 |                     assert w[lora_B].shape[1] == w[lora_A].shape[0]
49 |                     lora_r = w[lora_B].shape[1]
50 |                     w[k] = w[k].to(device=device)
51 |                     w[lora_A] = w[lora_A].to(device=device)
52 |                     w[lora_B] = w[lora_B].to(device=device)
53 |                     w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
54 |                     output_w[k] = w[k].to(device="cpu", copy=True)
55 |                     del w[k]
56 |                     del w[lora_A]
57 |                     del w[lora_B]
58 |                     continue
59 | 
60 |             if "lora" not in k:
61 |                 print(f"retaining {k}")
62 |                 output_w[k] = w[k].clone()
63 |                 del w[k]
64 | 
65 |         torch.save(output_w, output)
66 | except Exception as e:
67 |     print(e)
68 |     with open("error.txt", "w") as f:
69 |         f.write(str(e))
70 | 
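
Note: the core of the merge above is W += B @ A * (lora_alpha / lora_r); a tiny torch sketch with arbitrary illustrative sizes (assuming torch is installed):

import torch

out_dim, in_dim, lora_r, lora_alpha = 4, 3, 2, 128  # made-up dimensions
W = torch.randn(out_dim, in_dim)  # base "....weight"
A = torch.randn(lora_r, in_dim)   # "....lora_A"
B = torch.randn(out_dim, lora_r)  # "....lora_B"

# Same update merge_lora.py applies to every weight that has lora_A/lora_B companions.
merged = W + B @ A * (lora_alpha / lora_r)
print(merged.shape)  # torch.Size([4, 3])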


--------------------------------------------------------------------------------
/finetune/lora/v4/cuda/wkv_op.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | 
 3 | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
 4 | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
 5 | 
 6 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
 7 |     cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
 8 | }
 9 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
10 |     cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
11 | }
12 | 
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 |     m.def("forward", &forward, "wkv forward");
15 |     m.def("backward", &backward, "wkv backward");
16 | }
17 | 
18 | TORCH_LIBRARY(wkv, m) {
19 |     m.def("forward", forward);
20 |     m.def("backward", backward);
21 | }
22 | 


--------------------------------------------------------------------------------
/finetune/lora/v4/cuda/wkv_op_bf16.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | #include "ATen/ATen.h"
 3 | typedef at::BFloat16 bf16;
 4 | 
 5 | void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
 6 | void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);
 7 | 
 8 | void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
 9 |     cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
12 |     torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
13 |     cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
14 |         gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
15 | }
16 | 
17 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
18 |     m.def("forward", &forward, "wkv forward");
19 |     m.def("backward", &backward, "wkv backward");
20 | }
21 | 
22 | TORCH_LIBRARY(wkv, m) {
23 |     m.def("forward", forward);
24 |     m.def("backward", backward);
25 | }
26 | 


--------------------------------------------------------------------------------
/finetune/lora/v4/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/finetune/lora/v4/src/__init__.py


--------------------------------------------------------------------------------
/finetune/lora/v5/cuda/wkv5_op.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | #include "ATen/ATen.h"
 3 | typedef at::BFloat16 bf16;
 4 | 
 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
 7 | 
 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
 9 |     cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
12 |     cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
13 | }
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 |     m.def("forward", &forward, "wkv5 forward");
16 |     m.def("backward", &backward, "wkv5 backward");
17 | }
18 | 
19 | TORCH_LIBRARY(wkv5, m) {
20 |     m.def("forward", forward);
21 |     m.def("backward", backward);
22 | }
23 | 


--------------------------------------------------------------------------------
/finetune/lora/v5/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/finetune/lora/v5/src/__init__.py


--------------------------------------------------------------------------------
/finetune/lora/v6/cuda/wkv5_op.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | #include "ATen/ATen.h"
 3 | typedef at::BFloat16 bf16;
 4 | 
 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
 7 | 
 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
 9 |     cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
12 |     cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
13 | }
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 |     m.def("forward", &forward, "wkv5 forward");
16 |     m.def("backward", &backward, "wkv5 backward");
17 | }
18 | 
19 | TORCH_LIBRARY(wkv5, m) {
20 |     m.def("forward", forward);
21 |     m.def("backward", backward);
22 | }
23 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/cuda/wkv6_op.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | #include "ATen/ATen.h"
 3 | typedef at::BFloat16 bf16;
 4 | 
 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
 7 | 
 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
 9 |     cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
12 |     cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
13 | }
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 |     m.def("forward", &forward, "wkv6 forward");
16 |     m.def("backward", &backward, "wkv6 backward");
17 | }
18 | 
19 | TORCH_LIBRARY(wkv6, m) {
20 |     m.def("forward", forward);
21 |     m.def("backward", backward);
22 | }
23 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/cuda/wkv6infctx_op.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | #include "ATen/ATen.h"
 3 | typedef at::BFloat16 bf16;
 4 | 
 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y);
 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs);
 7 | 
 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) {
 9 |     cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) {
12 |     cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>());
13 | }
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 |     m.def("forward", &forward, "wkv6state forward");
16 |     m.def("backward", &backward, "wkv6state backward");
17 | }
18 | 
19 | TORCH_LIBRARY(wkv6state, m) {
20 |     m.def("forward", forward);
21 |     m.def("backward", backward);
22 | }
23 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/cuda/wkv6state_op.cpp:
--------------------------------------------------------------------------------
 1 | #include <torch/extension.h>
 2 | #include "ATen/ATen.h"
 3 | typedef at::BFloat16 bf16;
 4 | 
 5 | void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *y);
 6 | void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, bf16 *w, bf16 *u, bf16 *s, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu, bf16 *gs);
 7 | 
 8 | void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &y) {
 9 |     cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), y.data_ptr<bf16>());
10 | }
11 | void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &s, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gs) {
12 |     cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<bf16>(), u.data_ptr<bf16>(), s.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gs.data_ptr<bf16>());
13 | }
14 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
15 |     m.def("forward", &forward, "wkv6state forward");
16 |     m.def("backward", &backward, "wkv6state backward");
17 | }
18 | 
19 | TORCH_LIBRARY(wkv6state, m) {
20 |     m.def("forward", forward);
21 |     m.def("backward", backward);
22 | }
23 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/demo-lora-merge.sh:
--------------------------------------------------------------------------------
 1 | 
 2 | base_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth'
 3 | lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth'
 4 | lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth'
 5 | output='/home/rwkv/JL/model/nf4-world.pth'
 6 | QUANT='nf4' #follow train
 7 | TYPE='lora'
 8 | Lora_alpha=128
 9 | 
10 | python merge/merge.py --base_model $base_model \
11 | --lora_init $lora_init \
12 | --lora_checkpoint $lora_checkpoint \
13 | --output $output \
14 | --quant $QUANT \
15 | --type $TYPE \
16 | --lora_alpha $Lora_alpha


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/demo-lora.sh:
--------------------------------------------------------------------------------
 1 | load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth'
 2 | proj_dir='/home/rwkv/JL/out_model/nf4'
 3 | data_file='/home/rwkv/JL/data/roleplay'
 4 | 
 5 | QUANT='nf4'  #4bit nf4 fp4 none
 6 | 
 7 | lora_r=64
 8 | lora_alpha=128
 9 | 
10 | n_layer=32
11 | n_embd=4096
12 | 
13 | micro_bsz=8
14 | epoch_save=1
15 | epoch_steps=1000
16 | ctx_len=1024
17 | 
18 | python train.py --load_model $load_model \
19 | --proj_dir $proj_dir --data_file $data_file \
20 | --data_type binidx --vocab_size 65536 \
21 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
22 | --n_layer $n_layer --n_embd $n_embd \
23 | --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
24 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
25 | --my_testing "x060" \
26 | --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha $lora_alpha --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
27 | --quant $QUANT


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/demo-pissa-merge.sh:
--------------------------------------------------------------------------------
 1 | 
 2 | 
 3 | base_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2-20240208-ctx4096.pth'
 4 | lora_init='/home/rwkv/JL/out_model/nf4/init_lora.pth'
 5 | lora_checkpoint='/home/rwkv/JL/out_model/nf4/rwkv-0.pth'
 6 | output='/home/rwkv/JL/model/end-world.pth'
 7 | QUANT='nf4' # must match the quantization used during training
 8 | TYPE='pissa'
 9 | 
10 | python merge/merge.py --base_model $base_model \
11 | --lora_init $lora_init \
12 | --lora_checkpoint $lora_checkpoint \
13 | --output $output \
14 | --quant $QUANT \
15 | --type $TYPE 


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/demo-pissa.sh:
--------------------------------------------------------------------------------
 1 | 
 2 | load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth'
 3 | proj_dir='/home/rwkv/JL/out_model/nf4'
 4 | data_file='/home/rwkv/JL/data/end_text_document'
 5 | 
 6 | QUANT='nf4'  # quantization: 4bit / nf4 / fp4 / none
 7 | svd_niter=4  
 8 | lora_r=64
 9 | 
10 | n_layer=24
11 | n_embd=2048
12 | 
13 | micro_bsz=8
14 | epoch_save=1
15 | epoch_steps=1000
16 | ctx_len=1024
17 | 
18 | python train.py --load_model $load_model \
19 | --proj_dir $proj_dir --data_file $data_file \
20 | --data_type binidx --vocab_size 65536 \
21 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
22 | --n_layer $n_layer --n_embd $n_embd \
23 | --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
24 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
25 | --my_testing "x060" \
26 | --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
27 | --PISSA --svd_niter $svd_niter \
28 | --dataload pad
29 | 
30 | ### Variant without --load_model:
31 | # python train.py --proj_dir $proj_dir --data_file $data_file \
32 | # --data_type binidx --vocab_size 65536 \
33 | # --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
34 | # --n_layer $n_layer --n_embd $n_embd \
35 | # --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
36 | # --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
37 | # --my_testing "x060" \
38 | # --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
39 | # --PISSA --svd_niter $svd_niter \
40 | # --quant $QUANT


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/demo-qpissa-pt.sh:
--------------------------------------------------------------------------------
 1 | load_model='/home/rwkv/JL/model/rwkv-x060-7b-world-v2.1-36%trained-20240413-ctx4k.pth'
 2 | proj_dir='/home/rwkv/JL/out_model/nf4'
 3 | data_file='/home/rwkv/JL/data/roleplay'
 4 | 
 5 | QUANT='nf4'  # quantization: 4bit / nf4 / fp4 / none
 6 | svd_niter=4  
 7 | lora_r=64
 8 | 
 9 | n_layer=32
10 | n_embd=4096
11 | 
12 | micro_bsz=4
13 | epoch_save=1
14 | epoch_steps=1000
15 | ctx_len=1024
16 | 
17 | 
18 | python train.py --proj_dir $proj_dir --data_file $data_file \
19 | --data_type binidx --vocab_size 65536 \
20 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 20 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
21 | --n_layer $n_layer --n_embd $n_embd \
22 | --pre_ffn 0 --head_qk 0 --lr_init 5e-5 --lr_final 5e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
23 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
24 | --my_testing "x060" \
25 | --lora_load rwkv-0 --lora --lora_r $lora_r --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
26 | --PISSA --svd_niter $svd_niter \
27 | --quant $QUANT


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/demo-state-merge.sh:
--------------------------------------------------------------------------------
1 | base_model='/home/rwkv/JL/model/RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth'
2 | state_checkpoint='/home/rwkv/JL/out_model/state/rwkv-9.pth'
3 | output='/home/rwkv/JL/model/state-0.pth'
4 | 
5 | 
6 | python merge/merge_state.py --base_model $base_model \
7 | --state_checkpoint $state_checkpoint \
8 | --output $output


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/demo-state-tuning.sh:
--------------------------------------------------------------------------------
 1 | load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth'
 2 | proj_dir='/home/rwkv/JL/out_model/state'
 3 | data_file='/home/rwkv/JL/data/end_text_document'
 4 | 
 5 | 
 6 | n_layer=24
 7 | n_embd=2048
 8 | 
 9 | micro_bsz=1
10 | epoch_save=1
11 | epoch_steps=1000
12 | ctx_len=1024
13 | 
14 | python train.py --load_model $load_model \
15 | --proj_dir $proj_dir --data_file $data_file \
16 | --data_type binidx --vocab_size 65536 \
17 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
18 | --n_layer $n_layer --n_embd $n_embd \
19 | --pre_ffn 0 --head_qk 0 --lr_init 1 --lr_final 1e-1 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
20 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 0 \
21 | --my_testing "x060" \
22 | --train_type "state"  --dataload pad --wandb fla --fla


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/demo-training-prepare.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | # Create data directory
 4 | 
 5 | mkdir -p data
 6 | 
 7 | # Download minipile (1498226207 tokens, around 3GB)
 8 | 
 9 | wget --continue -O data/minipile.idx https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.idx
10 | wget --continue -O data/minipile.bin https://huggingface.co/datasets/BlinkDL/minipile-tokenized/resolve/main/rwkv_vocab_v20230424/minipile.bin
11 | 
12 | # Generate initial model (L12-D768 = 169M)
13 | 
14 | BASE_NAME="model/0.1-1"
15 | N_LAYER="12"
16 | N_EMBD="768"
17 | 
18 | # magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case)
19 | # use https://www.dcode.fr/prime-numbers-search (or the small helper sketched after this file)
20 | 
21 | python train.py --wandb "" --proj_dir $BASE_NAME \
22 |  --data_file "data/minipile" --data_type "binidx" --vocab_size 65536 \
23 |  --ctx_len 512 --my_pile_stage 1 --epoch_count 1 --epoch_begin 0 \
24 |  --epoch_save 1 --weight_decay 0 --head_size_a 64 \
25 |  --num_nodes 1 --micro_bsz 1 --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 --my_exit_tokens 1498226207 --magic_prime 2926181 \
26 |  --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 \
27 |  --accelerator cpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 --enable_progress_bar False --ds_bucket_mb 200
28 | 
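
The magic_prime value in the comment above can also be computed programmatically; a small self-contained sketch (the helper names and the trial-division primality test are ours, only the rule "largest prime p with p % 3 == 2 below datalen/ctx_len - 1" comes from the script):

    def is_prime(n: int) -> bool:
        # simple trial division; fast enough for values in the low millions
        if n < 2:
            return False
        if n % 2 == 0:
            return n == 2
        f = 3
        while f * f <= n:
            if n % f == 0:
                return False
            f += 2
        return True

    def magic_prime(datalen: int, ctx_len: int) -> int:
        # largest prime p with p % 3 == 2 (a "3n+2 prime") below datalen/ctx_len - 1
        n = int(datalen / ctx_len - 1)
        while n > 2:
            if n % 3 == 2 and is_prime(n):
                return n
            n -= 1
        raise ValueError("no suitable prime found")

    # e.g. magic_prime(1498226207, 512) for the minipile run above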


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/demo-training-run.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 | 
 3 | BASE_NAME="model/0.1-1"
 4 | N_LAYER="12"
 5 | N_EMBD="768"
 6 | M_BSZ="16" # takes 16G VRAM (reduce this to save VRAM)
 7 | LR_INIT="6e-4"
 8 | LR_FINAL="6e-5"
 9 | GRAD_CP=0 # set to 1 to save VRAM (will be slower)
10 | EPOCH_SAVE=10
11 | 
12 | # magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case)
13 | # use https://www.dcode.fr/prime-numbers-search
14 | 
15 | python train.py --load_model "0" --wandb "RWKV-5-Test" --proj_dir $BASE_NAME \
16 |  --ctx_len 512 --my_pile_stage 3 --epoch_count 999999 --epoch_begin 0 \
17 |  --data_file "data/minipile" --my_exit_tokens 1498226207 --magic_prime 2926181 \
18 |  --num_nodes 1 --micro_bsz $M_BSZ --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 \
19 |  --lr_init $LR_INIT --lr_final $LR_FINAL --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 --data_type "binidx" --vocab_size 65536 \
20 |  --weight_decay 0.001 --epoch_save $EPOCH_SAVE --head_size_a 64 \
21 |  --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp $GRAD_CP --enable_progress_bar True --ds_bucket_mb 200
22 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/demo/infctx.sh:
--------------------------------------------------------------------------------
 1 | load_model='/home/rwkv/JL/model/RWKV-x060-World-1B6-v2.1-20240328-ctx4096.pth'
 2 | proj_dir='/home/rwkv/JL/out_model/infctx'
 3 | data_file='/home/rwkv/JL/data/roleplay'
 4 | 
 5 | 
 6 | n_layer=24
 7 | n_embd=2048
 8 | 
 9 | micro_bsz=8
10 | epoch_save=5
11 | epoch_steps=1000
12 | ctx_len=16384
13 | chunk_ctx=2048
14 | 
15 | 
16 | python train.py --load_model $load_model \
17 | --proj_dir $proj_dir --data_file $data_file \
18 | --data_type binidx --vocab_size 65536 \
19 | --ctx_len $ctx_len --epoch_steps $epoch_steps --epoch_count 1 --epoch_begin 0 --epoch_save $epoch_save --micro_bsz $micro_bsz \
20 | --n_layer $n_layer --n_embd $n_embd \
21 | --pre_ffn 0 --head_qk 0 --lr_init 1e-4 --lr_final 1e-4 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
22 | --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_1 --grad_cp 1 \
23 | --lora_load rwkv-0 --lora --lora_r 64 --lora_alpha 128 --lora_dropout 0.01 --lora_parts=att,ffn,time,ln \
24 | --my_testing "x060"  --dataload pad \
25 | --train_type infctx --chunk_ctx $chunk_ctx --fla --wandb infctx


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from fla.layers import (ABCAttention, BasedLinearAttention, DeltaNet,
 4 |                         GatedLinearAttention, HGRN2Attention, LinearAttention,
 5 |                         MultiScaleRetention, ReBasedLinearAttention)
 6 | from fla.models import (ABCForCausalLM, ABCModel, DeltaNetForCausalLM,
 7 |                         DeltaNetModel, GLAForCausalLM, GLAModel,
 8 |                         HGRN2ForCausalLM, HGRN2Model, HGRNForCausalLM,
 9 |                         HGRNModel, LinearAttentionForCausalLM,
10 |                         LinearAttentionModel, RetNetForCausalLM, RetNetModel,
11 |                         RWKV6ForCausalLM, RWKV6Model, TransformerForCausalLM,
12 |                         TransformerModel)
13 | from fla.ops import (chunk_gla, chunk_retention, fused_chunk_based,
14 |                      fused_chunk_gla, fused_chunk_retention)
15 | 
16 | __all__ = [
17 |     'ABCAttention',
18 |     'BasedLinearAttention',
19 |     'DeltaNet',
20 |     'HGRN2Attention',
21 |     'GatedLinearAttention',
22 |     'LinearAttention',
23 |     'MultiScaleRetention',
24 |     'ReBasedLinearAttention',
25 |     'ABCForCausalLM',
26 |     'ABCModel',
27 |     'DeltaNetForCausalLM',
28 |     'DeltaNetModel',
29 |     'HGRNForCausalLM',
30 |     'HGRNModel',
31 |     'HGRN2ForCausalLM',
32 |     'HGRN2Model',
33 |     'GLAForCausalLM',
34 |     'GLAModel',
35 |     'LinearAttentionForCausalLM',
36 |     'LinearAttentionModel',
37 |     'RetNetForCausalLM',
38 |     'RetNetModel',
39 |     'RWKV6ForCausalLM',
40 |     'RWKV6Model',
41 |     'TransformerForCausalLM',
42 |     'TransformerModel',
43 |     'chunk_gla',
44 |     'chunk_retention',
45 |     'fused_chunk_based',
46 |     'fused_chunk_gla',
47 |     'fused_chunk_retention'
48 | ]
49 | 
50 | __version__ = '0.1'
51 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/layers/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .abc import ABCAttention
 4 | from .based import BasedLinearAttention
 5 | from .delta_net import DeltaNet
 6 | from .gla import GatedLinearAttention
 7 | from .hgrn import HGRNAttention
 8 | from .hgrn2 import HGRN2Attention
 9 | from .linear_attn import LinearAttention
10 | from .multiscale_retention import MultiScaleRetention
11 | from .rebased import ReBasedLinearAttention
12 | from .rwkv6 import RWKV6Attention
13 | 
14 | __all__ = [
15 |     'ABCAttention',
16 |     'BasedLinearAttention',
17 |     'DeltaNet',
18 |     'GatedLinearAttention',
19 |     'HGRNAttention',
20 |     'HGRN2Attention',
21 |     'LinearAttention',
22 |     'MultiScaleRetention',
23 |     'ReBasedLinearAttention',
24 |     'RWKV6Attention'
25 | ]
26 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from fla.models.abc import ABCConfig, ABCForCausalLM, ABCModel
 4 | from fla.models.delta_net import (DeltaNetConfig, DeltaNetForCausalLM,
 5 |                                   DeltaNetModel)
 6 | from fla.models.gla import GLAConfig, GLAForCausalLM, GLAModel
 7 | from fla.models.hgrn import HGRNConfig, HGRNForCausalLM, HGRNModel
 8 | from fla.models.hgrn2 import HGRN2Config, HGRN2ForCausalLM, HGRN2Model
 9 | from fla.models.linear_attn import (LinearAttentionConfig,
10 |                                     LinearAttentionForCausalLM,
11 |                                     LinearAttentionModel)
12 | from fla.models.mamba import MambaConfig, MambaForCausalLM, MambaModel
13 | from fla.models.retnet import RetNetConfig, RetNetForCausalLM, RetNetModel
14 | from fla.models.rwkv6 import RWKV6Config, RWKV6ForCausalLM, RWKV6Model
15 | from fla.models.transformer import (TransformerConfig, TransformerForCausalLM,
16 |                                     TransformerModel)
17 | 
18 | __all__ = [
19 |     'ABCConfig', 'ABCForCausalLM', 'ABCModel',
20 |     'DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel',
21 |     'GLAConfig', 'GLAForCausalLM', 'GLAModel',
22 |     'HGRNConfig', 'HGRNForCausalLM', 'HGRNModel',
23 |     'HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model',
24 |     'LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel',
25 |     'MambaConfig', 'MambaForCausalLM', 'MambaModel',
26 |     'RetNetConfig', 'RetNetForCausalLM', 'RetNetModel',
27 |     'RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model',
28 |     'TransformerConfig', 'TransformerForCausalLM', 'TransformerModel'
29 | ]
30 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/abc/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.abc.configuration_abc import ABCConfig
 6 | from fla.models.abc.modeling_abc import ABCForCausalLM, ABCModel
 7 | 
 8 | AutoConfig.register(ABCConfig.model_type, ABCConfig)
 9 | AutoModel.register(ABCConfig, ABCModel)
10 | AutoModelForCausalLM.register(ABCConfig, ABCForCausalLM)
11 | 
12 | 
13 | __all__ = ['ABCConfig', 'ABCForCausalLM', 'ABCModel']
14 | 
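
The effect of these register() calls is that the transformers Auto classes can resolve the custom config and model types; a minimal sketch (the tiny hyperparameter values are arbitrary and only keep the example cheap):

    from transformers import AutoModelForCausalLM
    from fla.models import ABCConfig

    # AutoModelForCausalLM dispatches on the registered config class.
    config = ABCConfig(vocab_size=1000, hidden_size=128,
                       num_hidden_layers=2, num_heads=2)
    model = AutoModelForCausalLM.from_config(config)
    print(type(model).__name__)   # ABCForCausalLM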


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/abc/configuration_abc.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from typing import Optional
 4 | 
 5 | from transformers.configuration_utils import PretrainedConfig
 6 | 
 7 | 
 8 | class ABCConfig(PretrainedConfig):
 9 | 
10 |     model_type = 'abc'
11 |     keys_to_ignore_at_inference = ['past_key_values']
12 | 
13 |     def __init__(
14 |         self,
15 |         vocab_size: int = 32000,
16 |         hidden_size: int = 2048,
17 |         gate_low_rank_dim: int = 16,
18 |         clamp_min: float = -32,
19 |         clamp_max: float = 32,
20 |         hidden_ratio: Optional[int] = 4,
21 |         intermediate_size: Optional[int] = None,
22 |         num_hidden_layers: int = 24,
23 |         num_heads: int = 4,
24 |         num_slots: Optional[int] = 64,
25 |         use_short_conv: bool = True,
26 |         conv_size: int = 4,
27 |         share_conv_kernel: bool = True,
28 |         expand_k: float = 0.5,
29 |         expand_v: float = 1,
30 |         hidden_act: str = "swish",
31 |         max_position_embeddings: int = 2048,
32 |         elementwise_affine: Optional[bool] = True,
33 |         norm_eps: float = 1e-6,
34 |         use_cache: bool = True,
35 |         pad_token_id: int = None,
36 |         bos_token_id: int = 1,
37 |         eos_token_id: int = 2,
38 |         initializer_range: float = 0.02,
39 |         tie_word_embeddings: bool = False,
40 |         fuse_norm: bool = True,
41 |         fuse_cross_entropy: bool = True,
42 |         **kwargs
43 |     ):
44 |         self.vocab_size = vocab_size
45 |         self.max_position_embeddings = max_position_embeddings
46 |         self.hidden_size = hidden_size
47 |         self.gate_low_rank_dim = gate_low_rank_dim
48 |         self.clamp_min = clamp_min
49 |         self.clamp_max = clamp_max
50 |         self.hidden_ratio = hidden_ratio
51 |         self.intermediate_size = intermediate_size
52 |         self.num_hidden_layers = num_hidden_layers
53 |         self.num_heads = num_heads
54 |         self.num_slots = num_slots
55 |         self.use_short_conv = use_short_conv
56 |         self.conv_size = conv_size
57 |         self.share_conv_kernel = share_conv_kernel
58 |         self.expand_k = expand_k
59 |         self.expand_v = expand_v
60 |         self.hidden_act = hidden_act
61 |         self.elementwise_affine = elementwise_affine
62 |         self.norm_eps = norm_eps
63 |         self.use_cache = use_cache
64 |         self.initializer_range = initializer_range
65 |         self.fuse_cross_entropy = fuse_cross_entropy
66 |         self.fuse_norm = fuse_norm
67 | 
68 |         super().__init__(
69 |             pad_token_id=pad_token_id,
70 |             bos_token_id=bos_token_id,
71 |             eos_token_id=eos_token_id,
72 |             tie_word_embeddings=tie_word_embeddings,
73 |             **kwargs,
74 |         )
75 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/delta_net/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.delta_net.configuration_delta_net import \
 6 |     DeltaNetConfig
 7 | from fla.models.delta_net.modeling_delta_net import (
 8 |     DeltaNetForCausalLM, DeltaNetModel)
 9 | 
10 | AutoConfig.register(DeltaNetConfig.model_type, DeltaNetConfig)
11 | AutoModel.register(DeltaNetConfig, DeltaNetModel)
12 | AutoModelForCausalLM.register(DeltaNetConfig, DeltaNetForCausalLM)
13 | 
14 | __all__ = ['DeltaNetConfig', 'DeltaNetForCausalLM', 'DeltaNetModel']
15 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/gla/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.gla.configuration_gla import GLAConfig
 6 | from fla.models.gla.modeling_gla import GLAForCausalLM, GLAModel
 7 | 
 8 | AutoConfig.register(GLAConfig.model_type, GLAConfig)
 9 | AutoModel.register(GLAConfig, GLAModel)
10 | AutoModelForCausalLM.register(GLAConfig, GLAForCausalLM)
11 | 
12 | 
13 | __all__ = ['GLAConfig', 'GLAForCausalLM', 'GLAModel']
14 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/hgrn/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.hgrn.configuration_hgrn import HGRNConfig
 6 | from fla.models.hgrn.modeling_hgrn import HGRNForCausalLM, HGRNModel
 7 | 
 8 | AutoConfig.register(HGRNConfig.model_type, HGRNConfig)
 9 | AutoModel.register(HGRNConfig, HGRNModel)
10 | AutoModelForCausalLM.register(HGRNConfig, HGRNForCausalLM)
11 | 
12 | 
13 | __all__ = ['HGRNConfig', 'HGRNForCausalLM', 'HGRNModel']
14 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/hgrn/configuration_hgrn.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from typing import Optional
 4 | 
 5 | from transformers.configuration_utils import PretrainedConfig
 6 | 
 7 | 
 8 | class HGRNConfig(PretrainedConfig):
 9 | 
10 |     model_type = 'hgrn'
11 |     keys_to_ignore_at_inference = ['past_key_values']
12 | 
13 |     def __init__(
14 |         self,
15 |         attn_mode: str = "chunk",
16 |         vocab_size: int = 32000,
17 |         hidden_size: int = 2048,
18 |         num_hidden_layers: int = 24,
19 |         num_heads: Optional[int] = 1,
20 |         expand_ratio: Optional[int] = 1,
21 |         use_short_conv: bool = False,
22 |         conv_size: int = 4,
23 |         share_conv_kernel: bool = True,
24 |         use_lower_bound: bool = True,
25 |         hidden_ratio: Optional[int] = 4,
26 |         intermediate_size: Optional[int] = None,
27 |         hidden_act: str = "swish",
28 |         max_position_embeddings: int = 2048,
29 |         elementwise_affine: Optional[bool] = True,
30 |         norm_eps: float = 1e-6,
31 |         use_cache: bool = True,
32 |         pad_token_id: int = None,
33 |         bos_token_id: int = 1,
34 |         eos_token_id: int = 2,
35 |         tie_word_embeddings: bool = False,
36 |         initializer_range: float = 0.02,
37 |         fuse_cross_entropy: bool = True,
38 |         **kwargs
39 |     ):
40 |         self.attn_mode = attn_mode
41 |         self.vocab_size = vocab_size
42 |         self.max_position_embeddings = max_position_embeddings
43 |         self.hidden_size = hidden_size
44 |         self.num_hidden_layers = num_hidden_layers
45 |         self.num_heads = num_heads
46 |         self.expand_ratio = expand_ratio
47 |         self.use_short_conv = use_short_conv
48 |         self.conv_size = conv_size
49 |         self.share_conv_kernel = share_conv_kernel
50 |         self.use_lower_bound = use_lower_bound
51 |         self.hidden_ratio = hidden_ratio
52 |         self.intermediate_size = intermediate_size
53 |         self.hidden_act = hidden_act
54 |         self.elementwise_affine = elementwise_affine
55 |         self.norm_eps = norm_eps
56 |         self.use_cache = use_cache
57 |         self.initializer_range = initializer_range
58 |         self.fuse_cross_entropy = fuse_cross_entropy
59 | 
60 |         super().__init__(
61 |             pad_token_id=pad_token_id,
62 |             bos_token_id=bos_token_id,
63 |             eos_token_id=eos_token_id,
64 |             tie_word_embeddings=tie_word_embeddings,
65 |             **kwargs,
66 |         )
67 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/hgrn2/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.hgrn2.configuration_hgrn2 import HGRN2Config
 6 | from fla.models.hgrn2.modeling_hgrn2 import HGRN2ForCausalLM, HGRN2Model
 7 | 
 8 | AutoConfig.register(HGRN2Config.model_type, HGRN2Config)
 9 | AutoModel.register(HGRN2Config, HGRN2Model)
10 | AutoModelForCausalLM.register(HGRN2Config, HGRN2ForCausalLM)
11 | 
12 | 
13 | __all__ = ['HGRN2Config', 'HGRN2ForCausalLM', 'HGRN2Model']
14 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/hgrn2/configuration_hgrn2.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from typing import Optional
 4 | 
 5 | from transformers.configuration_utils import PretrainedConfig
 6 | 
 7 | 
 8 | class HGRN2Config(PretrainedConfig):
 9 | 
10 |     model_type = 'hgrn2'
11 |     keys_to_ignore_at_inference = ['past_key_values']
12 | 
13 |     def __init__(
14 |         self,
15 |         vocab_size: int = 32000,
16 |         hidden_size: int = 2048,
17 |         num_hidden_layers: int = 24,
18 |         attn_mode: str = "chunk",
19 |         num_heads: Optional[int] = None,
20 |         expand_ratio: Optional[int] = 128,
21 |         use_short_conv: bool = False,
22 |         conv_size: int = 4,
23 |         share_conv_kernel: bool = True,
24 |         use_lower_bound: bool = True,
25 |         hidden_ratio: Optional[int] = 4,
26 |         intermediate_size: Optional[int] = None,
27 |         hidden_act: str = "swish",
28 |         max_position_embeddings: int = 2048,
29 |         elementwise_affine: Optional[bool] = True,
30 |         norm_eps: float = 1e-6,
31 |         use_cache: bool = True,
32 |         pad_token_id: int = None,
33 |         bos_token_id: int = 1,
34 |         eos_token_id: int = 2,
35 |         tie_word_embeddings: bool = False,
36 |         initializer_range: float = 0.02,
37 |         fuse_cross_entropy: bool = True,
38 |         **kwargs
39 |     ):
40 |         self.vocab_size = vocab_size
41 |         self.max_position_embeddings = max_position_embeddings
42 |         self.hidden_size = hidden_size
43 |         self.num_hidden_layers = num_hidden_layers
44 |         self.attn_mode = attn_mode
45 |         self.num_heads = num_heads
46 |         self.expand_ratio = expand_ratio
47 |         self.use_short_conv = use_short_conv
48 |         self.conv_size = conv_size
49 |         self.share_conv_kernel = share_conv_kernel
50 |         self.use_lower_bound = use_lower_bound
51 |         self.hidden_ratio = hidden_ratio
52 |         self.intermediate_size = intermediate_size
53 |         self.hidden_act = hidden_act
54 |         self.elementwise_affine = elementwise_affine
55 |         self.norm_eps = norm_eps
56 |         self.use_cache = use_cache
57 |         self.initializer_range = initializer_range
58 |         self.fuse_cross_entropy = fuse_cross_entropy
59 | 
60 |         super().__init__(
61 |             pad_token_id=pad_token_id,
62 |             bos_token_id=bos_token_id,
63 |             eos_token_id=eos_token_id,
64 |             tie_word_embeddings=tie_word_embeddings,
65 |             **kwargs,
66 |         )
67 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/linear_attn/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.linear_attn.configuration_linear_attn import \
 6 |     LinearAttentionConfig
 7 | from fla.models.linear_attn.modeling_linear_attn import (
 8 |     LinearAttentionForCausalLM, LinearAttentionModel)
 9 | 
10 | AutoConfig.register(LinearAttentionConfig.model_type, LinearAttentionConfig)
11 | AutoModel.register(LinearAttentionConfig, LinearAttentionModel)
12 | AutoModelForCausalLM.register(LinearAttentionConfig, LinearAttentionForCausalLM)
13 | 
14 | __all__ = ['LinearAttentionConfig', 'LinearAttentionForCausalLM', 'LinearAttentionModel']
15 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/linear_attn/configuration_linear_attn.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from typing import Optional
 4 | 
 5 | from transformers.configuration_utils import PretrainedConfig
 6 | 
 7 | 
 8 | class LinearAttentionConfig(PretrainedConfig):
 9 | 
10 |     model_type = 'linear_attn'
11 |     keys_to_ignore_at_inference = ['past_key_values']
12 | 
13 |     def __init__(
14 |         self,
15 |         vocab_size: int = 32000,
16 |         hidden_size: int = 2048,
17 |         expand_k: int = 1,
18 |         expand_v: int = 1,
19 |         hidden_ratio: Optional[int] = 4,
20 |         intermediate_size: Optional[int] = None,
21 |         num_hidden_layers: int = 24,
22 |         num_heads: int = 4,
23 |         attn_mode: str = "fused_chunk",
24 |         feature_map: str = "elementwise_product",
25 |         tie_feature_map_qk: bool = False,
26 |         norm_q: bool = False,
27 |         norm_k: bool = False,
28 |         norm_feature_map: bool = False,
29 |         hidden_act: str = "swish",
30 |         max_position_embeddings: int = 2048,
31 |         elementwise_affine: Optional[bool] = True,
32 |         norm_eps: float = 1e-6,
33 |         use_cache: bool = True,
34 |         pad_token_id: int = None,
35 |         bos_token_id: int = 1,
36 |         eos_token_id: int = 2,
37 |         tie_word_embeddings: bool = False,
38 |         initializer_range: float = 0.02,
39 |         fuse_cross_entropy: bool = True,
40 |         **kwargs
41 |     ):
42 |         self.vocab_size = vocab_size
43 |         self.max_position_embeddings = max_position_embeddings
44 |         self.hidden_size = hidden_size
45 |         self.expand_k = expand_k
46 |         self.expand_v = expand_v
47 |         self.hidden_ratio = hidden_ratio
48 |         self.intermediate_size = intermediate_size
49 |         self.num_hidden_layers = num_hidden_layers
50 |         self.num_heads = num_heads
51 |         self.attn_mode = attn_mode
52 |         self.feature_map = feature_map
53 |         self.tie_feature_map_qk = tie_feature_map_qk
54 |         self.norm_q = norm_q
55 |         self.norm_k = norm_k
56 |         self.norm_feature_map = norm_feature_map
57 |         self.hidden_act = hidden_act
58 |         self.elementwise_affine = elementwise_affine
59 |         self.norm_eps = norm_eps
60 |         self.use_cache = use_cache
61 |         self.initializer_range = initializer_range
62 |         self.fuse_cross_entropy = fuse_cross_entropy
63 | 
64 |         super().__init__(
65 |             pad_token_id=pad_token_id,
66 |             bos_token_id=bos_token_id,
67 |             eos_token_id=eos_token_id,
68 |             tie_word_embeddings=tie_word_embeddings,
69 |             **kwargs,
70 |         )
71 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/mamba/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.mamba.configuration_mamba import MambaConfig
 6 | from fla.models.mamba.modeling_mamba import (MambaBlock, MambaForCausalLM,
 7 |                                              MambaModel)
 8 | 
 9 | AutoConfig.register(MambaConfig.model_type, MambaConfig, True)
10 | AutoModel.register(MambaConfig, MambaModel, True)
11 | AutoModelForCausalLM.register(MambaConfig, MambaForCausalLM, True)
12 | 
13 | 
14 | __all__ = ['MambaConfig', 'MambaForCausalLM', 'MambaModel', 'MambaBlock']
15 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/retnet/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.retnet.configuration_retnet import RetNetConfig
 6 | from fla.models.retnet.modeling_retnet import RetNetForCausalLM, RetNetModel
 7 | 
 8 | AutoConfig.register(RetNetConfig.model_type, RetNetConfig)
 9 | AutoModel.register(RetNetConfig, RetNetModel)
10 | AutoModelForCausalLM.register(RetNetConfig, RetNetForCausalLM)
11 | 
12 | 
13 | __all__ = ['RetNetConfig', 'RetNetForCausalLM', 'RetNetModel']
14 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/rwkv6/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.rwkv6.configuration_rwkv6 import RWKV6Config
 6 | from fla.models.rwkv6.modeling_rwkv6 import RWKV6ForCausalLM, RWKV6Model
 7 | 
 8 | AutoConfig.register(RWKV6Config.model_type, RWKV6Config)
 9 | AutoModel.register(RWKV6Config, RWKV6Model)
10 | AutoModelForCausalLM.register(RWKV6Config, RWKV6ForCausalLM)
11 | 
12 | 
13 | __all__ = ['RWKV6Config', 'RWKV6ForCausalLM', 'RWKV6Model']
14 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/rwkv6/configuration_rwkv6.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from typing import Optional
 4 | 
 5 | from transformers.configuration_utils import PretrainedConfig
 6 | 
 7 | 
 8 | class RWKV6Config(PretrainedConfig):
 9 | 
10 |     model_type = 'rwkv6'
11 |     keys_to_ignore_at_inference = ['past_key_values']
12 | 
13 |     def __init__(
14 |         self,
15 |         attn_mode: str = "chunk",
16 |         vocab_size: int = 32000,
17 |         hidden_size: int = 2048,
18 |         expand_k: float = 0.5,
19 |         expand_v: int = 1,
20 |         hidden_ratio: Optional[float] = 3.5,
21 |         intermediate_size: Optional[int] = None,
22 |         use_glu: Optional[bool] = False,
23 |         num_hidden_layers: int = 24,
24 |         num_heads: int = 4,
25 |         proj_low_rank_dim: int = 32,
26 |         gate_low_rank_dim: int = 64,
27 |         hidden_act: str = "sqrelu",
28 |         max_position_embeddings: int = 2048,
29 |         eps: float = 1e-6,
30 |         use_cache: bool = True,
31 |         pad_token_id: int = None,
32 |         bos_token_id: int = 1,
33 |         eos_token_id: int = 2,
34 |         tie_word_embeddings: bool = False,
35 |         initializer_range: float = 0.02,
36 |         fuse_norm: bool = True,
37 |         fuse_cross_entropy: bool = True,
38 |         **kwargs
39 |     ):
40 |         self.vocab_size = vocab_size
41 |         self.max_position_embeddings = max_position_embeddings
42 |         self.hidden_size = hidden_size
43 |         self.expand_k = expand_k
44 |         self.expand_v = expand_v
45 |         self.hidden_ratio = hidden_ratio
46 |         self.intermediate_size = intermediate_size
47 |         self.use_glu = use_glu
48 |         self.num_hidden_layers = num_hidden_layers
49 |         self.num_heads = num_heads
50 |         self.proj_low_rank_dim = proj_low_rank_dim
51 |         self.gate_low_rank_dim = gate_low_rank_dim
52 |         self.attn_mode = attn_mode
53 |         self.hidden_act = hidden_act
54 |         self.eps = eps
55 |         self.use_cache = use_cache
56 |         self.initializer_range = initializer_range
57 |         self.fuse_norm = fuse_norm
58 |         self.fuse_cross_entropy = fuse_cross_entropy
59 | 
60 |         super().__init__(
61 |             pad_token_id=pad_token_id,
62 |             bos_token_id=bos_token_id,
63 |             eos_token_id=eos_token_id,
64 |             tie_word_embeddings=tie_word_embeddings,
65 |             **kwargs,
66 |         )
67 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/transformer/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 4 | 
 5 | from fla.models.transformer.configuration_transformer import TransformerConfig
 6 | from fla.models.transformer.modeling_transformer import (
 7 |     TransformerForCausalLM, TransformerModel)
 8 | 
 9 | AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
10 | AutoModel.register(TransformerConfig, TransformerModel)
11 | AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)
12 | 
13 | 
14 | __all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel']
15 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/models/transformer/configuration_transformer.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from typing import Optional
 4 | 
 5 | from transformers.configuration_utils import PretrainedConfig
 6 | 
 7 | 
 8 | class TransformerConfig(PretrainedConfig):
 9 | 
10 |     model_type = 'transformer'
11 |     keys_to_ignore_at_inference = ['past_key_values']
12 | 
13 |     def __init__(
14 |         self,
15 |         vocab_size: int = 32000,
16 |         hidden_size: int = 2048,
17 |         hidden_ratio: Optional[int] = 4,
18 |         intermediate_size: Optional[int] = None,
19 |         num_hidden_layers: int = 24,
20 |         num_heads: int = 32,
21 |         num_kv_heads: int = None,
22 |         hidden_act: str = "swish",
23 |         max_position_embeddings: int = 2048,
24 |         initializer_range: float = 0.02,
25 |         elementwise_affine: Optional[bool] = True,
26 |         norm_eps: float = 1e-6,
27 |         use_cache: bool = True,
28 |         pad_token_id: int = None,
29 |         bos_token_id: int = 1,
30 |         eos_token_id: int = 2,
31 |         tie_word_embeddings: bool = False,
32 |         attention_bias: bool = False,
33 |         fuse_norm: bool = True,
34 |         fuse_cross_entropy: bool = True,
35 |         **kwargs,
36 |     ):
37 |         self.vocab_size = vocab_size
38 |         self.max_position_embeddings = max_position_embeddings
39 |         self.hidden_size = hidden_size
40 |         self.hidden_ratio = hidden_ratio
41 |         self.intermediate_size = intermediate_size
42 |         self.num_hidden_layers = num_hidden_layers
43 |         self.num_heads = num_heads
44 |         self.num_kv_heads = num_kv_heads
45 | 
46 |         self.hidden_act = hidden_act
47 |         self.initializer_range = initializer_range
48 |         self.elementwise_affine = elementwise_affine
49 |         self.norm_eps = norm_eps
50 |         self.use_cache = use_cache
51 |         self.attention_bias = attention_bias
52 |         self.fuse_cross_entropy = fuse_cross_entropy
53 |         self.fuse_norm = fuse_norm
54 | 
55 |         super().__init__(
56 |             pad_token_id=pad_token_id,
57 |             bos_token_id=bos_token_id,
58 |             eos_token_id=eos_token_id,
59 |             tie_word_embeddings=tie_word_embeddings,
60 |             **kwargs,
61 |         )
62 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/modules/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from fla.modules.convolution import (ImplicitLongConvolution, LongConvolution,
 4 |                                      ShortConvolution)
 5 | from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss
 6 | from fla.modules.fused_norm_gate import (FusedLayerNormSwishGate,
 7 |                                          FusedLayerNormSwishGateLinear,
 8 |                                          FusedRMSNormSwishGate,
 9 |                                          FusedRMSNormSwishGateLinear)
10 | from fla.modules.layernorm import (LayerNorm, LayerNormLinear, RMSNorm,
11 |                                    RMSNormLinear)
12 | from fla.modules.rotary import RotaryEmbedding
13 | 
14 | __all__ = [
15 |     'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
16 |     'FusedCrossEntropyLoss',
17 |     'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
18 |     'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
19 |     'RotaryEmbedding'
20 | ]
21 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .based import fused_chunk_based, parallel_based
 4 | from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla
 5 | from .retention import (chunk_retention, fused_chunk_retention,
 6 |                         fused_recurrent_retention, parallel_retention)
 7 | 
 8 | __all__ = [
 9 |     'fused_chunk_based',
10 |     'parallel_based',
11 |     'chunk_gla',
12 |     'fused_chunk_gla',
13 |     'fused_recurrent_gla',
14 |     'chunk_retention',
15 |     'fused_chunk_retention',
16 |     'fused_recurrent_retention',
17 |     'parallel_retention'
18 | ]
19 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/abc/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .chunk import chunk_abc
 4 | from .chunk_gate import chunk_gated_abc
 5 | from .recurrent_fuse import fused_recurrent_gated_abc
 6 | 
 7 | __all__ = [
 8 |     'chunk_abc',
 9 |     'chunk_gated_abc',
10 |     'fused_recurrent_gated_abc'
11 | ]
12 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/based/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .chunk_fuse import fused_chunk_based
 4 | from .parallel import parallel_based
 5 | 
 6 | __all__ = [
 7 |     'fused_chunk_based',
 8 |     'parallel_based'
 9 | ]
10 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/delta_rule/README.md:
--------------------------------------------------------------------------------
1 | - Delta Rule
2 | 
3 | The implementation of the delta rule described in https://arxiv.org/abs/2102.11174
4 | 
5 | 
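
For orientation, a deliberately slow recurrent form of that delta-rule update in plain PyTorch, using the same (batch, heads, time, dim) layout as the naive references elsewhere in fla/ops; the helper name is illustrative and is not part of this package (the exported kernels are the chunked/fused ones listed in __init__.py below):

    import torch

    def naive_delta_rule(q, k, v, beta):
        # Fast-weight update from arXiv:2102.11174:
        #   S_t = S_{t-1} + k_t outer [beta_t * (v_t - S_{t-1}^T k_t)], S stored as (Dk, Dv)
        # q, k: (B, H, T, Dk); v: (B, H, T, Dv); beta: (B, H, T)
        B, H, T, Dk = k.shape
        Dv = v.shape[-1]
        S = k.new_zeros(B, H, Dk, Dv)
        o = v.new_zeros(B, H, T, Dv)
        for t in range(T):
            k_t, v_t = k[:, :, t], v[:, :, t]
            b_t = beta[:, :, t, None]
            v_pred = torch.einsum('bhkv,bhk->bhv', S, k_t)   # what S currently stores for k_t
            S = S + torch.einsum('bhk,bhv->bhkv', k_t, b_t * (v_t - v_pred))
            o[:, :, t] = torch.einsum('bhkv,bhk->bhv', S, q[:, :, t])
        return o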


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/delta_rule/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .chunk_fuse import fused_chunk_delta_rule
 4 | from .recurrent_fuse import fused_recurrent_linear_attn_delta_rule
 5 | from .chunk import chunk_delta_rule
 6 | 
 7 | __all__ = [
 8 |     'fused_chunk_delta_rule',
 9 |     'fused_recurrent_linear_attn_delta_rule',
10 |     'chunk_delta_rule'
11 | ]
12 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/gla/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .chunk import chunk_gla
 4 | from .chunk_fuse import fused_chunk_gla
 5 | from .recurrent_fuse import fused_recurrent_gla
 6 | 
 7 | __all__ = [
 8 |     'chunk_gla',
 9 |     'fused_chunk_gla',
10 |     'fused_recurrent_gla'
11 | ]
12 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/hgrn/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .chunk import chunk_hgrn
 4 | from .recurrent_fuse import fused_recurrent_hgrn
 5 | 
 6 | __all__ = [
 7 |     'chunk_hgrn',
 8 |     'fused_recurrent_hgrn'
 9 | ]
10 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/hgrn/naive.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from typing import Optional
 4 | 
 5 | import torch
 6 | 
 7 | 
 8 | def naive_recurrent_hgrn(
 9 |     x: torch.Tensor,
10 |     g: torch.Tensor,
11 |     initial_state: Optional[torch.Tensor] = None,
12 |     output_final_state: Optional[bool] = False
13 | ) -> torch.Tensor:
14 |     dtype = x.dtype
15 |     x, g = map(lambda i: i.float(), (x, g))
16 |     B, H, T, D = x.shape
17 |     # recurrence: h_t = exp(g_t) * h_{t-1} + x_t, accumulated in float32
18 |     h = torch.zeros(B, H, D, dtype=torch.float, device=x.device)
19 |     o = torch.zeros_like(x)
20 | 
21 |     final_state = None
22 |     if initial_state is not None:
23 |         h += initial_state.detach()
24 | 
25 |     for i in range(T):
26 |         h = g[:, :, i].exp() * h + x[:, :, i]
27 |         o[:, :, i] = h
28 | 
29 |     if output_final_state:
30 |         final_state = h
31 |     return o.to(dtype), final_state
32 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/linear_attn/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .chunk import chunk_linear_attn
 4 | from .chunk_fuse import fused_chunk_linear_attn
 5 | from .recurrent_fuse import fused_recurrent_linear_attn
 6 | 
 7 | __all__ = [
 8 |     'chunk_linear_attn',
 9 |     'fused_chunk_linear_attn',
10 |     'fused_recurrent_linear_attn'
11 | ]
12 | 
13 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/linear_attn/naive.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | import torch
 4 | from einops import rearrange
 5 | 
 6 | # Reference chunkwise linear attention: the inter-chunk term reads the running (exclusive) cumsum of K^T V, the intra-chunk term applies a causal mask within each chunk.
 7 | def torch_chunk_linear_attn(q, k, v, chunk_size=64):
 8 |     q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] **-0.5)
 9 |     k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
10 |     v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)
11 |     kv = k.transpose(-1, -2) @ v
12 |     kv = kv.cumsum(2)
13 |     kv = torch.cat([
14 |         torch.zeros_like(kv[:, :, :1]),
15 |         kv[:, :, :-1]
16 |     ], dim=2)
17 |     inter = q @ kv
18 |     intra = ((q @ k.transpose(-1, -2)).masked_fill_(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)) @ v
19 |     o = inter + intra
20 |     return rearrange(o, 'b h n c d -> b h (n c) d')
21 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/rebased/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from .parallel import parallel_rebased
4 | 
5 | __all__ = [
6 |     'parallel_rebased'
7 | ]
8 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/retention/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .chunk import chunk_retention
 4 | from .chunk_fuse import fused_chunk_retention
 5 | from .parallel import parallel_retention
 6 | from .recurrent_fuse import fused_recurrent_retention
 7 | 
 8 | __all__ = [
 9 |     'chunk_retention',
10 |     'fused_chunk_retention',
11 |     'parallel_retention',
12 |     'fused_recurrent_retention'
13 | ]
14 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/retention/naive.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | import torch
 4 | 
 5 | 
 6 | def naive_retention(q, k, v):
 7 |     orig_type = q.dtype
 8 |     q, k, v = q.float(), k.float(), v.float()
 9 |     _, n_heads, seq_len, d_head = q.shape
10 |     s = (1 - q.new_tensor(2., dtype=torch.float).pow(-5. - q.new_tensor(range(n_heads), dtype=torch.float))).log2()
11 |     n = q.new_tensor(range(seq_len), dtype=torch.float)
12 |     n = torch.exp2((n.unsqueeze(-1) - n) * s.view(-1, 1, 1)) * n.unsqueeze(-1).ge(n)
13 |     s = torch.einsum('bhqd,bhkd,hqk->bhqk', q * d_head ** -0.5, k, n.to(q.dtype))
14 |     o = torch.einsum('bhqk,bhkd->bhqd', s, v)
15 |     return o.to(orig_type)
16 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/rwkv4/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from .recurrent_fuse import fused_recurrent_rwkv4
4 | 
5 | __all__ = [
6 |     'fused_recurrent_rwkv4'
7 | ]
8 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/rwkv6/__init__.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | from .chunk import chunk_rwkv6
 4 | from .recurrent_fuse import fused_recurrent_rwkv6
 5 | 
 6 | __all__ = [
 7 |     'chunk_rwkv6',
 8 |     'fused_recurrent_rwkv6'
 9 | ]
10 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/simple_gla/README.md:
--------------------------------------------------------------------------------
1 | - Simple GLA
2 | 
3 | The gating mechanism from https://arxiv.org/abs/2103.02143. Compared to GLA, the gating is head-wise rather than elementwise, so the RetNet kernel can be adapted for matmul-based training without numerical instability. It is faster than GLA but less expressive, and serves as a baseline for GLA.
4 | 
5 | $S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar.


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/simple_gla/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | from .chunk import chunk_simple_gla
4 | 
5 | __all__ = [
6 |     'chunk_simple_gla'
7 | ]
8 | 
9 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/ops/simple_gla/naive.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | import torch
 4 | from einops import rearrange
 5 | 
 6 | # Chunkwise reference for simple GLA: g is a per-head scalar log-gate, cumsummed within each chunk and applied to both the inter-chunk state recurrence and the intra-chunk attention term.
 7 | def torch_simple_gla(q, k, v, g, chunk_size=64):
 8 |     q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5)
 9 |     k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
10 |     v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)    
11 |     g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size)
12 |     g = g.cumsum(-1)
13 |     kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None])
14 |     S = torch.zeros_like(kv)
15 | 
16 |     for i in range(1, g.shape[-2]):
17 |         S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1]
18 |     
19 |     inter = (q * g[..., None].exp()) @ S
20 |     attn = q @ k.transpose(-1, -2)
21 |     attn = attn * (g[..., None] - g[..., None, :]).exp()
22 |     attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0)
23 |     intra = attn @ v
24 |     o = inter + intra
25 |     return rearrange(o, 'b h n c d -> b h (n c) d')
26 | 
27 | 
28 | def torch_simple_gla_recurrent(q, k, v, g, chunk_size=64):
29 |     # q = rearrange(q, 'b h (n c) d -> b h n c d', c = chunk_size) * (q.shape[-1] ** -0.5)
30 |     # k = rearrange(k, 'b h (n c) d -> b h n c d', c = chunk_size)
31 |     # v = rearrange(v, 'b h (n c) d -> b h n c d', c = chunk_size)    
32 |     # g = rearrange(g, 'b h (n c) -> b h n c', c = chunk_size)
33 |     # g = g.cumsum(-1)
34 |     # kv = k.transpose(-1, -2) @ v
35 | 
36 |     B, H, T, DK = q.shape
37 |     q = q * (DK ** -0.5)
38 |     _, _, _, DV = v.shape
39 |     S = torch.zeros(B, H, DK, DV).to(q)
40 |     o = torch.zeros(B, H, T, DV).to(q)
41 |     for i in range(T):
42 |         gate = g[:, :, i].exp()
43 |         key = k[:, :, i]
44 |         value = v[:, :, i]
45 |         kv = key.unsqueeze(-1) * value.unsqueeze(-2)
46 |         S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv
47 |         q_i = q[:, :, i, :] 
48 |         o_i = (q_i.unsqueeze(-1) * S).sum(-2)
49 |         o[:, :, i] = o_i
50 | 
51 |     return o 
52 | 
53 | 


--------------------------------------------------------------------------------
/finetune/lora/v6/fla/utils.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | 
 3 | import functools
 4 | 
 5 | import torch
 6 | 
 7 | 
 8 | def contiguous(fn):
 9 |     @functools.wraps(fn)
10 |     def wrapper(ctx, *args, **kwargs):
11 |         return fn(ctx,
12 |                   *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
13 |                   **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
14 |     return wrapper
15 | 
16 | 
17 | def require_version(version, hint):
18 |     def decorator(fn):
19 |         @functools.wraps(fn)
20 |         def wrapper(ctx, *args, **kwargs):
21 |             from transformers.utils.versions import require_version
22 |             require_version(version, hint)
23 |             return fn(ctx,
24 |                       *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),
25 |                       **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})
26 |         return wrapper
27 |     return decorator
28 | 
29 | 
30 | def checkpoint(func):
31 |     def wrapper(*args, **kwargs):
32 |         return torch.utils.checkpoint.checkpoint(func, *args, **kwargs)
33 |     return wrapper
34 | 
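
The contiguous decorator above is written for torch.autograd.Function-style methods (the first positional argument is ctx); a hypothetical usage sketch, not code from this package:

    import torch
    from fla.utils import contiguous

    class Scale(torch.autograd.Function):
        @staticmethod
        @contiguous
        def forward(ctx, x):
            # x has been made contiguous by the decorator, which is what a
            # Triton/CUDA kernel launched here would require
            return x * 2

        @staticmethod
        @contiguous
        def backward(ctx, grad_out):
            return grad_out * 2

    y = Scale.apply(torch.randn(4, 4).t())   # non-contiguous input is fine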


--------------------------------------------------------------------------------
/finetune/lora/v6/merge/merge_lora.py:
--------------------------------------------------------------------------------
 1 | from collections import OrderedDict
 2 | import os
 3 | import sys
 4 | from typing import Dict
 5 | import typing
 6 | import torch
 7 | 
 8 | if '-h' in sys.argv or '--help' in sys.argv:
 9 |     print(f'Usage: python3 {sys.argv[0]} [--use-gpu] <lora_alpha> <base_model.pth> <lora_checkpoint.pth> <output.pth>')
10 |     sys.exit(0)
11 | if sys.argv[1] == '--use-gpu':
12 |     device = 'cuda'
13 |     lora_alpha, base_model, lora, output = float(sys.argv[2]), sys.argv[3], sys.argv[4], sys.argv[5]
14 | else:
15 |     device = 'cpu'
16 |     lora_alpha, base_model, lora, output = float(sys.argv[1]), sys.argv[2], sys.argv[3], sys.argv[4]
17 | 
18 | 
19 | with torch.no_grad():
20 |     w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
21 |     # merge LoRA-only slim checkpoint into the main weights
22 |     w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
23 |     for k in w_lora.keys():
24 |         w[k] = w_lora[k]
25 |     output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
26 |     # merge LoRA weights
27 |     keys = list(w.keys())
28 |     for k in keys:
29 |         if k.endswith('.weight'):
30 |             prefix = k[:-len('.weight')]
31 |             lora_A = prefix + '.lora_A'
32 |             lora_B = prefix + '.lora_B'
33 |             if lora_A in keys:
34 |                 assert lora_B in keys
35 |                 print(f'merging {lora_A} and {lora_B} into {k}')
36 |                 assert w[lora_B].shape[1] == w[lora_A].shape[0]
37 |                 lora_r = w[lora_B].shape[1]
38 |                 w[k] = w[k].to(device=device)
39 |                 w[lora_A] = w[lora_A].to(device=device)
40 |                 w[lora_B] = w[lora_B].to(device=device)
41 |                 w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
42 |                 output_w[k] = w[k].to(device='cpu', copy=True)
43 |                 del w[k]
44 |                 del w[lora_A]
45 |                 del w[lora_B]
46 |                 continue
47 | 
48 |         if 'lora' not in k:
49 |             print(f'retaining {k}')
50 |             output_w[k] = w[k].clone()
51 |             del w[k]
52 |     torch.save(output_w, output)
53 | 
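
For reference, the per-weight update the loop above performs is W' = W + lora_B @ lora_A * (lora_alpha / lora_r), where lora_r is read from lora_B's second dimension. Below is a small self-contained check of that arithmetic (illustrative only; the shapes and alpha value are made up):

import torch

out_features, in_features, r, alpha = 6, 4, 2, 8.0

W = torch.randn(out_features, in_features)   # base .weight
A = torch.randn(r, in_features)              # .lora_A
B = torch.randn(out_features, r)             # .lora_B

merged = W + B @ A * (alpha / r)             # same update as the script

# Using the merged weight is equivalent to the base output plus the scaled LoRA branch.
x = torch.randn(in_features)
assert torch.allclose(merged @ x, W @ x + B @ (A @ x) * (alpha / r), atol=1e-5)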


--------------------------------------------------------------------------------
/finetune/lora/v6/merge/merge_pissa.py:
--------------------------------------------------------------------------------
 1 | from collections import OrderedDict
 2 | import os
 3 | import sys
 4 | from typing import Dict
 5 | import typing
 6 | import torch
 7 | 
 8 | if '-h' in sys.argv or '--help' in sys.argv:
 9 |     sys.exit(f'Usage: python3 {sys.argv[0]} [--use-gpu] <base_model.pth> <lora_init.pth> <lora_checkpoint.pth> <output.pth>')
10 | 
11 | if sys.argv[1] == '--use-gpu':
12 |     device = 'cuda'
13 |     base_model, init_lora, lora, output = sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5]
14 | else:
15 |     device = 'cpu'
16 |     base_model, init_lora, lora, output = sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]
17 | 
18 | 
19 | with torch.no_grad():
20 |     w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
21 |     # merge LoRA-only slim checkpoint into the main weights
22 |     w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
23 |     w_init_lora: Dict[str, torch.Tensor] = torch.load(init_lora, map_location='cpu')
24 |     for k in w_lora.keys():
25 |         w[k] = w_lora[k]
26 |     output_w: typing.OrderedDict[str, torch.Tensor] = OrderedDict()
27 |     # merge LoRA weights
28 |     keys = list(w.keys())
29 |     for k in keys:
30 |         if k.endswith('.weight'):
31 |             prefix = k[:-len('.weight')]
32 |             lora_A = prefix + '.lora_A'
33 |             lora_B = prefix + '.lora_B'
34 |             init_lora_A = prefix + '.init_lora_A'
35 |             init_lora_B = prefix + '.init_lora_B'
36 |             if lora_A in keys:
37 |                 assert lora_B in keys
38 |                 print(f'merging {lora_A} and {lora_B} into {k}')
39 |                 assert w[lora_B].shape[1] == w[lora_A].shape[0]
40 |                 lora_r = w[lora_B].shape[1]
41 |                 w[k] = w[k].to(device=device)
42 |                 w[lora_A] = w[lora_A].to(device=device)
43 |                 w[lora_B] = w[lora_B].to(device=device)
44 |                 w_init_lora[init_lora_A] = w_init_lora[init_lora_A].to(device=device)
45 |                 w_init_lora[init_lora_B] = w_init_lora[init_lora_B].to(device=device)
46 |                 w[k] = (w[k] - w_init_lora[init_lora_B] @ w_init_lora[init_lora_A]).to(dtype=torch.bfloat16)
47 |                 w[k] += w[lora_B] @ w[lora_A]
48 |                 output_w[k] = w[k].to(device='cpu', copy=True)
49 |                 del w[k]
50 |                 del w[lora_A]
51 |                 del w[lora_B]
52 |                 continue
53 | 
54 |         if 'lora' not in k:
55 |             print(f'retaining {k}')
56 |             output_w[k] = w[k].clone()
57 |             del w[k]
58 |     torch.save(output_w, output)
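
The PiSSA merge above differs from the plain LoRA merge in two ways: the principal components captured at initialisation (init_lora_B @ init_lora_A) are first subtracted from the base weight, with the residual cast to bfloat16, and the trained adapter is then added back without an alpha / r scale. A minimal sketch of that arithmetic with made-up shapes (not part of the repository):

import torch

out_features, in_features, r = 6, 4, 2

W = torch.randn(out_features, in_features)
A_init = torch.randn(r, in_features)                  # .init_lora_A
B_init = torch.randn(out_features, r)                 # .init_lora_B
A = torch.randn(r, in_features)                       # .lora_A (trained)
B = torch.randn(out_features, r)                      # .lora_B (trained)

residual = (W - B_init @ A_init).to(torch.bfloat16)   # strip the initial principal components
merged = residual + (B @ A).to(torch.bfloat16)        # add the trained adapter, no alpha/r scale
print(merged.dtype, merged.shape)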


--------------------------------------------------------------------------------
/finetune/lora/v6/merge/merge_state.py:
--------------------------------------------------------------------------------
 1 | from collections import OrderedDict
 2 | import os
 3 | import sys
 4 | from typing import Dict
 5 | import typing
 6 | import torch
 7 | import bitsandbytes as bnb
 8 | from argparse import ArgumentParser
 9 | 
10 | parser = ArgumentParser()
11 | parser.add_argument("--base_model", default="", type=str)
12 | parser.add_argument("--state_checkpoint", default="", type=str)
13 | parser.add_argument("--output", default="", type=str)
14 | # parser.add_argument("--quant", default="none", type=str)
15 | parser.add_argument("--device", default="cuda", type=str)
16 | # parser.add_argument("--lora_alpha", default=16, type=int)
17 | args = parser.parse_args()
18 | device = args.device
19 | base_model = args.base_model
20 | state = args.state_checkpoint
21 | output = args.output
22 | 
23 | 
24 | with torch.no_grad():
25 |     w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
26 |     # load the trained state checkpoint and copy it over the base weights
27 |     w_state: Dict[str, torch.Tensor] = torch.load(state, map_location='cpu')
28 | 
29 |     for k in w_state.keys():
30 |         print(k)
31 |         w[k] = w_state[k]
32 |     # print all keys of the merged result
33 |     for k in w.keys():
34 |         print(k)
35 |     
36 |     torch.save(w, output)
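
Unlike the two merge scripts above, merge_state.py does no matrix arithmetic: it copies every tensor from the trained state checkpoint over the matching key in the base weights and saves the result (a typical invocation would be `python3 merge_state.py --base_model base.pth --state_checkpoint state.pth --output merged.pth`, using the flags defined by the parser above). A toy sketch of that overwrite, with made-up keys:

import torch
from collections import OrderedDict

base = OrderedDict(emb=torch.zeros(2, 2), time_state=torch.zeros(2, 2))
trained_state = OrderedDict(time_state=torch.ones(2, 2))  # hypothetical key from state tuning

for k in trained_state.keys():
    base[k] = trained_state[k]      # overwrite, exactly as the script does

# torch.save(base, 'merged.pth')
print(base['time_state'])           # now the trained tensor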


--------------------------------------------------------------------------------
/finetune/lora/v6/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.9.5
2 | bitsandbytes
3 | deepspeed
4 | einops
5 | triton==2.2.0


--------------------------------------------------------------------------------
/finetune/lora/v6/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/finetune/lora/v6/src/__init__.py


--------------------------------------------------------------------------------
/finetune/lora/v6/src/infctx_module.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | ######state
 3 | class TimeMixState:
 4 |     def __init__(self, shift_state: torch.Tensor, wkv_state: torch.Tensor):
 5 |         self.shift_state = shift_state
 6 |         self.wkv_state = wkv_state
 7 | 
 8 | 
 9 | class ChannelMixState:
10 |     def __init__(self, shift_state: torch.Tensor):
11 |         self.shift_state = shift_state
12 | 
13 | 
14 | class BlockState:
15 |     def __init__(self, time_mix_state: TimeMixState,
16 |                  channel_mix_state: ChannelMixState):
17 |         self.time_mix_state = time_mix_state
18 |         self.channel_mix_state = channel_mix_state
19 | 
20 | class BlockStateList:
21 | 
22 |     def __init__(self, shift_states, wkv_states):
23 |         self.wkv_states = wkv_states
24 |         self.shift_states = shift_states
25 | 
26 |     @staticmethod
27 |     def create(N, B, C, H, device, dtype):
28 |         result = BlockStateList.empty(N, B, C, H, device, dtype)
29 |         result.wkv_states[:] = 0
30 |         result.shift_states[:] = 0
31 |         return result
32 | 
33 | 
34 |     @staticmethod
35 |     def empty(N, B, C, H, device, dtype):
36 |         wkv_states = torch.empty((N, B, H, C//H, C//H),
37 |                                  device=device,
38 |                                  dtype=torch.bfloat16)
39 |         shift_states = torch.empty((N, 2, B, C), device=device, dtype=dtype)
40 |         return BlockStateList(shift_states, wkv_states)
41 | 
42 |     def __getitem__(self, layer: int):
43 |         return BlockState(
44 |             TimeMixState(self.shift_states[layer, 0], self.wkv_states[layer]),
45 |             ChannelMixState(self.shift_states[layer, 1]))
46 | 
47 |     def __setitem__(self, layer: int, state: BlockState):
48 |         self.shift_states[layer, 0] = state.time_mix_state.shift_state
49 |         self.wkv_states[layer] = state.time_mix_state.wkv_state
50 |         self.shift_states[layer, 1] = state.channel_mix_state.shift_state
51 | 
52 | 
53 | 
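
A small usage sketch for BlockStateList (illustrative only; it assumes src/infctx_module.py is importable and the dimension values are made up): create() zero-initialises the per-layer WKV and token-shift states, and indexing by layer packs them into a BlockState.

import torch

from src.infctx_module import BlockStateList  # the class defined above

n_layer, batch, n_embd, n_head = 2, 1, 8, 2    # head size = n_embd // n_head = 4

states = BlockStateList.create(n_layer, batch, n_embd, n_head,
                               device='cpu', dtype=torch.float32)

layer0 = states[0]                                    # BlockState for layer 0
print(layer0.time_mix_state.wkv_state.shape)          # (B, H, C//H, C//H) -> (1, 2, 4, 4)
print(layer0.channel_mix_state.shift_state.shape)     # (B, C) -> (1, 8)

states[0] = layer0                                    # write an (updated) state back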


--------------------------------------------------------------------------------
/finetune/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.2.0
2 | pytorch_lightning==1.9.5
3 | deepspeed==0.12.6
4 | bitsandbytes==0.43.1
5 | einops==0.8.0
6 | triton==2.2.0
7 | transformers==4.41.1
8 | numpy==1.26.4


--------------------------------------------------------------------------------
/frontend/index.html:
--------------------------------------------------------------------------------
 1 | <!DOCTYPE html>
 2 | <html lang="en">
 3 | <head>
 4 |   <meta charset="UTF-8" />
 5 |   <meta content="width=device-width, initial-scale=1.0" name="viewport" />
 6 |   <title>RWKV-Runner</title>
 7 |   <link href="./src/assets/images/logo.png" rel="icon" type="image/x-icon">
 8 | </head>
 9 | <body>
10 | <div id="root"></div>
11 | <script src="./src/main.tsx" type="module"></script>
12 | </body>
13 | </html>
14 | 
15 | 


--------------------------------------------------------------------------------
/frontend/package.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "name": "frontend",
 3 |   "private": true,
 4 |   "version": "0.0.0",
 5 |   "scripts": {
 6 |     "dev": "vite --host",
 7 |     "build": "tsc && vite build",
 8 |     "preview": "vite preview --host"
 9 |   },
10 |   "dependencies": {
11 |     "@fluentui/react-components": "^9.47.2",
12 |     "@fluentui/react-icons": "^2.0.201",
13 |     "@magenta/music": "^1.23.1",
14 |     "@microsoft/fetch-event-source": "^2.0.1",
15 |     "@primer/octicons-react": "^19.1.0",
16 |     "@tabler/icons-react": "^3.31.0",
17 |     "abcjs": "^6.2.3",
18 |     "chart.js": "^4.3.0",
19 |     "classnames": "^2.3.2",
20 |     "compare-versions": "^6.1.1",
21 |     "delay": "^6.0.0",
22 |     "file-saver": "^2.0.5",
23 |     "html-midi-player": "^1.5.0",
24 |     "html-to-image": "^1.11.13",
25 |     "i18next": "^22.4.15",
26 |     "katex": "^0.16.9",
27 |     "lodash-es": "^4.17.21",
28 |     "mobx": "^6.9.0",
29 |     "mobx-react-lite": "^3.4.3",
30 |     "pdfjs-dist": "^4.0.189",
31 |     "react": "^18.2.0",
32 |     "react-beautiful-dnd": "^13.1.1",
33 |     "react-chartjs-2": "^5.2.0",
34 |     "react-dom": "^18.2.0",
35 |     "react-draggable": "^4.4.6",
36 |     "react-i18next": "^12.2.2",
37 |     "react-markdown": "^8.0.7",
38 |     "react-router": "^6.11.1",
39 |     "react-router-dom": "^6.11.1",
40 |     "react-toastify": "^9.1.3",
41 |     "rehype-highlight": "^6.0.0",
42 |     "rehype-katex": "^6.0.3",
43 |     "rehype-raw": "^6.1.1",
44 |     "remark-breaks": "^3.0.3",
45 |     "remark-gfm": "^3.0.1",
46 |     "remark-math": "^5.1.1",
47 |     "usehooks-ts": "^2.9.1",
48 |     "uuid": "^9.0.0"
49 |   },
50 |   "devDependencies": {
51 |     "@ianvs/prettier-plugin-sort-imports": "^4.3.1",
52 |     "@tailwindcss/typography": "^0.5.10",
53 |     "@types/file-saver": "^2.0.7",
54 |     "@types/lodash-es": "^4.17.12",
55 |     "@types/react": "^18.2.6",
56 |     "@types/react-beautiful-dnd": "^13.1.4",
57 |     "@types/react-dom": "^18.2.4",
58 |     "@types/uuid": "^9.0.1",
59 |     "@vitejs/plugin-react": "^4.0.0",
60 |     "autoprefixer": "^10.4.14",
61 |     "postcss": "^8.4.23",
62 |     "prettier": "^3.3.3",
63 |     "prettier-plugin-tailwindcss": "^0.6.6",
64 |     "rollup-plugin-visualizer": "^5.9.0",
65 |     "sass": "^1.62.1",
66 |     "tailwindcss": "^3.3.2",
67 |     "typescript": "^5.0.4",
68 |     "vite": "^4.5.9",
69 |     "vite-plugin-top-level-await": "^1.3.1"
70 |   }
71 | }
72 | 


--------------------------------------------------------------------------------
/frontend/postcss.config.js:
--------------------------------------------------------------------------------
1 | module.exports = {
2 |   plugins: {
3 |     tailwindcss: {},
4 |     autoprefixer: {},
5 |   },
6 | }
7 | 


--------------------------------------------------------------------------------
/frontend/prettier.config.js:
--------------------------------------------------------------------------------
 1 | /** @type {import('prettier').Config} */
 2 | module.exports = {
 3 |   endOfLine: 'lf',
 4 |   semi: false,
 5 |   singleQuote: true,
 6 |   tabWidth: 2,
 7 |   trailingComma: 'es5',
 8 |   importOrder: ['^(react/(.*)$)|^(react$)', '<THIRD_PARTY_MODULES>', '^[./]'],
 9 |   importOrderSeparation: false,
10 |   importOrderSortSpecifiers: true,
11 |   importOrderBuiltinModulesToTop: true,
12 |   importOrderParserPlugins: ['typescript', 'jsx', 'decorators-legacy'],
13 |   importOrderMergeDuplicateImports: true,
14 |   importOrderCombineTypeAndValueImports: true,
15 |   plugins: [
16 |     '@ianvs/prettier-plugin-sort-imports',
17 |     'prettier-plugin-tailwindcss',
18 |   ],
19 | }
20 | 


--------------------------------------------------------------------------------
/frontend/src/_locales/i18n-react.ts:
--------------------------------------------------------------------------------
 1 | import i18n, { changeLanguage } from 'i18next'
 2 | import { initReactI18next } from 'react-i18next'
 3 | import { getUserLanguage } from '../utils'
 4 | import { resources } from './resources'
 5 | 
 6 | i18n
 7 |   .use(initReactI18next)
 8 |   .init({
 9 |     resources,
10 |     interpolation: {
11 |       escapeValue: false, // not needed for react as it escapes by default
12 |     },
13 |   })
14 |   .then(() => {
15 |     changeLanguage(getUserLanguage())
16 |   })
17 | 


--------------------------------------------------------------------------------
/frontend/src/_locales/i18n.ts:
--------------------------------------------------------------------------------
 1 | import i18n, { changeLanguage } from 'i18next'
 2 | import { getUserLanguage } from '../utils'
 3 | import { resources } from './resources'
 4 | 
 5 | i18n
 6 |   .init({
 7 |     resources,
 8 |   })
 9 |   .then(() => {
10 |     changeLanguage(getUserLanguage())
11 |   })
12 | 


--------------------------------------------------------------------------------
/frontend/src/_locales/resources.ts:
--------------------------------------------------------------------------------
 1 | import ja from './ja/main.json'
 2 | import zhHans from './zh-hans/main.json'
 3 | 
 4 | export const resources = {
 5 |   zh: {
 6 |     translation: zhHans,
 7 |   },
 8 |   // de: {
 9 |   //   translation: de,
10 |   // },
11 |   // es: {
12 |   //   translation: es,
13 |   // },
14 |   // fr: {
15 |   //   translation: fr,
16 |   // },
17 |   // in: {
18 |   //   translation: inTrans,
19 |   // },
20 |   // it: {
21 |   //   translation: it,
22 |   // },
23 |   ja: {
24 |     translation: ja,
25 |   },
26 |   // ko: {
27 |   //   translation: ko,
28 |   // },
29 |   // pt: {
30 |   //   translation: pt,
31 |   // },
32 |   // ru: {
33 |   //   translation: ru,
34 |   // },
35 |   // zhHant: {
36 |   //   translation: zhHant,
37 |   // },
38 | }
39 | 


--------------------------------------------------------------------------------
/frontend/src/assets/images/banner.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/frontend/src/assets/images/banner.jpg


--------------------------------------------------------------------------------
/frontend/src/assets/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/frontend/src/assets/images/logo.png


--------------------------------------------------------------------------------
/frontend/src/assets/images/strategy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/frontend/src/assets/images/strategy.jpg


--------------------------------------------------------------------------------
/frontend/src/assets/images/strategy_zh.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/frontend/src/assets/images/strategy_zh.jpg


--------------------------------------------------------------------------------
/frontend/src/components/ConfigSelector.tsx:
--------------------------------------------------------------------------------
 1 | import { FC } from 'react'
 2 | import { Dropdown, Option, PresenceBadge } from '@fluentui/react-components'
 3 | import { observer } from 'mobx-react-lite'
 4 | import commonStore from '../stores/commonStore'
 5 | 
 6 | export const ConfigSelector: FC<{ size?: 'small' | 'medium' | 'large' }> =
 7 |   observer(({ size }) => {
 8 |     return (
 9 |       <Dropdown
10 |         size={size}
11 |         style={{ minWidth: 0 }}
12 |         listbox={{ style: { minWidth: 'fit-content' } }}
13 |         value={commonStore.getCurrentModelConfig().name}
14 |         selectedOptions={[commonStore.currentModelConfigIndex.toString()]}
15 |         onOptionSelect={(_, data) => {
16 |           if (data.optionValue)
17 |             commonStore.setCurrentConfigIndex(Number(data.optionValue))
18 |         }}
19 |       >
20 |         {commonStore.modelConfigs.map((config, index) => (
21 |           <Option key={index} value={index.toString()} text={config.name}>
22 |             <div className="flex grow justify-between">
23 |               {config.name}
24 |               {commonStore.modelSourceList.find(
25 |                 (item) => item.name === config.modelParameters.modelName
26 |               )?.isComplete && <PresenceBadge status="available" />}
27 |             </div>
28 |           </Option>
29 |         ))}
30 |       </Dropdown>
31 |     )
32 |   })
33 | 


--------------------------------------------------------------------------------
/frontend/src/components/CopyButton.tsx:
--------------------------------------------------------------------------------
 1 | import { FC, useState } from 'react'
 2 | import { CheckIcon, CopyIcon } from '@primer/octicons-react'
 3 | import { useTranslation } from 'react-i18next'
 4 | import { ClipboardSetText } from '../../wailsjs/runtime'
 5 | import { ToolTipButton } from './ToolTipButton'
 6 | 
 7 | export const CopyButton: FC<{ content: string; showDelay?: number }> = ({
 8 |   content,
 9 |   showDelay = 0,
10 | }) => {
11 |   const { t } = useTranslation()
12 |   const [copied, setCopied] = useState(false)
13 | 
14 |   const onClick = () => {
15 |     ClipboardSetText(content)
16 |       .then(() => setCopied(true))
17 |       .then(() =>
18 |         setTimeout(() => {
19 |           setCopied(false)
20 |         }, 600)
21 |       )
22 |   }
23 | 
24 |   return (
25 |     <ToolTipButton
26 |       desc={t('Copy')}
27 |       size="small"
28 |       appearance="subtle"
29 |       showDelay={showDelay}
30 |       icon={copied ? <CheckIcon /> : <CopyIcon />}
31 |       onClick={onClick}
32 |     />
33 |   )
34 | }
35 | 


--------------------------------------------------------------------------------
/frontend/src/components/CustomToastContainer.tsx:
--------------------------------------------------------------------------------
 1 | import { ToastContainer } from 'react-toastify'
 2 | import commonStore from '../stores/commonStore'
 3 | 
 4 | export const CustomToastContainer = () => (
 5 |   <ToastContainer
 6 |     style={{ width: '350px' }}
 7 |     position="top-center"
 8 |     autoClose={4000}
 9 |     pauseOnHover={true}
10 |     hideProgressBar={true}
11 |     newestOnTop={true}
12 |     closeOnClick={false}
13 |     rtl={false}
14 |     pauseOnFocusLoss={false}
15 |     draggable={false}
16 |     theme={commonStore.settings.darkMode ? 'dark' : 'light'}
17 |   />
18 | )
19 | 


--------------------------------------------------------------------------------
/frontend/src/components/DebugModeIndicator.tsx:
--------------------------------------------------------------------------------
 1 | import { FC } from 'react'
 2 | import classNames from 'classnames'
 3 | 
 4 | export const DebugModeIndicator: FC<{ showInDebugMode?: boolean }> = ({
 5 |   showInDebugMode = true,
 6 | }) => {
 7 |   if (import.meta.env.PROD) return <></>
 8 |   if (!showInDebugMode) return <></>
 9 |   return (
10 |     <div
11 |       className={classNames(
12 |         'absolute',
13 |         'right-1',
14 |         'top-1',
15 |         'p-1',
16 |         'rounded',
17 |         'bg-red-600',
18 |         'bg-opacity-50',
19 |         'text-white',
20 |         'font-semibold',
21 |         'text-opacity-50',
22 |         'text-xs'
23 |       )}
24 |     >
25 |       Debug Mode
26 |     </div>
27 |   )
28 | }
29 | 


--------------------------------------------------------------------------------
/frontend/src/components/Labeled.tsx:
--------------------------------------------------------------------------------
 1 | import { FC, ReactElement } from 'react'
 2 | import { Label, Tooltip } from '@fluentui/react-components'
 3 | import classnames from 'classnames'
 4 | 
 5 | export const Labeled: FC<{
 6 |   label: string
 7 |   desc?: string | null
 8 |   descComponent?: ReactElement
 9 |   content: ReactElement
10 |   flex?: boolean
11 |   spaceBetween?: boolean
12 |   breakline?: boolean
13 |   onMouseEnter?: () => void
14 |   onMouseLeave?: () => void
15 | }> = ({
16 |   label,
17 |   desc,
18 |   descComponent,
19 |   content,
20 |   flex,
21 |   spaceBetween,
22 |   breakline,
23 |   onMouseEnter,
24 |   onMouseLeave,
25 | }) => {
26 |   return (
27 |     <div
28 |       className={classnames(
29 |         !breakline ? 'items-center' : '',
30 |         flex ? 'flex' : 'grid grid-cols-2',
31 |         breakline ? 'flex-col' : '',
32 |         spaceBetween && 'justify-between'
33 |       )}
34 |     >
35 |       {desc || descComponent ? (
36 |         <Tooltip
37 |           content={descComponent ? descComponent : desc!}
38 |           showDelay={0}
39 |           hideDelay={0}
40 |           relationship="description"
41 |         >
42 |           <Label onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
43 |             {label}
44 |           </Label>
45 |         </Tooltip>
46 |       ) : (
47 |         <Label onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
48 |           {label}
49 |         </Label>
50 |       )}
51 |       {content}
52 |     </div>
53 |   )
54 | }
55 | 


--------------------------------------------------------------------------------
/frontend/src/components/LazyImportComponent.tsx:
--------------------------------------------------------------------------------
 1 | import { FC, LazyExoticComponent, ReactNode, Suspense } from 'react'
 2 | import { Spinner } from '@fluentui/react-components'
 3 | import { useTranslation } from 'react-i18next'
 4 | 
 5 | interface LazyImportComponentProps {
 6 |   lazyChildren: LazyExoticComponent<FC<any>>
 7 |   lazyProps?: any
 8 |   children?: ReactNode
 9 |   disableFallback?: boolean
10 | }
11 | 
12 | export const LazyImportComponent: FC<LazyImportComponentProps> = (props) => {
13 |   const { t } = useTranslation()
14 | 
15 |   return (
16 |     <Suspense
17 |       fallback={
18 |         !props.disableFallback && (
19 |           <div className="flex h-full w-full items-center justify-center">
20 |             <Spinner size="huge" label={t('Loading...')} />
21 |           </div>
22 |         )
23 |       }
24 |     >
25 |       <props.lazyChildren {...props.lazyProps}>
26 |         {props.children}
27 |       </props.lazyChildren>
28 |     </Suspense>
29 |   )
30 | }
31 | 


--------------------------------------------------------------------------------
/frontend/src/components/NumberInput.tsx:
--------------------------------------------------------------------------------
 1 | import React, { CSSProperties, FC } from 'react'
 2 | import { Input } from '@fluentui/react-components'
 3 | import { SliderOnChangeData } from '@fluentui/react-slider'
 4 | 
 5 | export const NumberInput: FC<{
 6 |   value: number
 7 |   min: number
 8 |   max: number
 9 |   step?: number
10 |   onChange?: (
11 |     ev: React.ChangeEvent<HTMLInputElement>,
12 |     data: SliderOnChangeData
13 |   ) => void
14 |   style?: CSSProperties
15 |   toFixed?: number
16 |   disabled?: boolean
17 | }> = ({ value, min, max, step, onChange, style, toFixed = 2, disabled }) => {
18 |   return (
19 |     <Input
20 |       type="number"
21 |       style={style}
22 |       value={value.toString()}
23 |       min={min}
24 |       max={max}
25 |       step={step}
26 |       disabled={disabled}
27 |       onChange={(e, data) => {
28 |         onChange?.(e, { value: Number(data.value) })
29 |       }}
30 |       onBlur={(e) => {
31 |         if (onChange) {
32 |           if (step) {
33 |             const offset = (min > 0 ? min : 0) - (max < 0 ? max : 0)
34 |             value = Number(
35 |               (Math.round((value - offset) / step) * step + offset).toFixed(
36 |                 toFixed
37 |               )
38 |             ) // avoid precision issues
39 |           }
40 |           onChange(e, { value: Math.max(Math.min(value, max), min) })
41 |         }
42 |       }}
43 |     />
44 |   )
45 | }
46 | 


--------------------------------------------------------------------------------
/frontend/src/components/Page.tsx:
--------------------------------------------------------------------------------
 1 | import React, { FC, ReactElement } from 'react'
 2 | import { Divider, Text } from '@fluentui/react-components'
 3 | 
 4 | export const Page: FC<{ title: string; content: ReactElement }> = ({
 5 |   title,
 6 |   content,
 7 | }) => {
 8 |   return (
 9 |     <div className="flex h-full flex-col gap-2 p-2">
10 |       <Text size={600}>{title}</Text>
11 |       <Divider style={{ flexGrow: 0 }} />
12 |       {content}
13 |     </div>
14 |   )
15 | }
16 | 


--------------------------------------------------------------------------------
/frontend/src/components/ReadButton.tsx:
--------------------------------------------------------------------------------
 1 | import { FC, useState } from 'react'
 2 | import { MuteIcon, UnmuteIcon } from '@primer/octicons-react'
 3 | import { observer } from 'mobx-react-lite'
 4 | import { useTranslation } from 'react-i18next'
 5 | import commonStore from '../stores/commonStore'
 6 | import { ToolTipButton } from './ToolTipButton'
 7 | 
 8 | const synth = window.speechSynthesis
 9 | 
10 | export const ReadButton: FC<{
11 |   content: string
12 |   inSpeaking?: boolean
13 |   showDelay?: number
14 |   setSpeakingOuter?: (speaking: boolean) => void
15 | }> = observer(
16 |   ({ content, inSpeaking = false, showDelay = 0, setSpeakingOuter }) => {
17 |     const { t } = useTranslation()
18 |     const [speaking, setSpeaking] = useState(inSpeaking)
19 |     let lang: string = commonStore.settings.language
20 |     if (lang === 'dev') lang = 'en'
21 | 
22 |     const setSpeakingInner = (speaking: boolean) => {
23 |       setSpeakingOuter?.(speaking)
24 |       setSpeaking(speaking)
25 |     }
26 | 
27 |     const startSpeak = () => {
28 |       synth.cancel()
29 | 
30 |       const utterance = new SpeechSynthesisUtterance(content)
31 |       const voices = synth.getVoices()
32 | 
33 |       let voice
34 |       if (lang === 'en')
35 |         voice = voices.find((v) =>
36 |           v.name.toLowerCase().includes('microsoft aria')
37 |         )
38 |       else if (lang === 'zh')
39 |         voice = voices.find((v) => v.name.toLowerCase().includes('xiaoyi'))
40 |       else if (lang === 'ja')
41 |         voice = voices.find((v) => v.name.toLowerCase().includes('nanami'))
42 |       if (!voice) voice = voices.find((v) => v.lang.substring(0, 2) === lang)
43 |       if (!voice) voice = voices.find((v) => v.lang === navigator.language)
44 | 
45 |       Object.assign(utterance, {
46 |         rate: 1,
47 |         volume: 1,
48 |         onend: () => setSpeakingInner(false),
49 |         onerror: () => setSpeakingInner(false),
50 |         voice: voice,
51 |       })
52 | 
53 |       synth.speak(utterance)
54 |       setSpeakingInner(true)
55 |     }
56 | 
57 |     const stopSpeak = () => {
58 |       synth.cancel()
59 |       setSpeakingInner(false)
60 |     }
61 | 
62 |     return (
63 |       <ToolTipButton
64 |         desc={t('Read Aloud')}
65 |         size="small"
66 |         appearance="subtle"
67 |         showDelay={showDelay}
68 |         icon={speaking ? <MuteIcon /> : <UnmuteIcon />}
69 |         onClick={speaking ? stopSpeak : startSpeak}
70 |       />
71 |     )
72 |   }
73 | )
74 | 


--------------------------------------------------------------------------------
/frontend/src/components/ResetConfigsButton.tsx:
--------------------------------------------------------------------------------
 1 | import React, { FC } from 'react'
 2 | import { ArrowReset20Regular } from '@fluentui/react-icons'
 3 | import { useTranslation } from 'react-i18next'
 4 | import {
 5 |   defaultModelConfigs,
 6 |   defaultModelConfigsMac,
 7 | } from '../pages/defaultConfigs'
 8 | import commonStore from '../stores/commonStore'
 9 | import { DialogButton } from './DialogButton'
10 | 
11 | export const ResetConfigsButton: FC<{ afterConfirm?: () => void }> = ({
12 |   afterConfirm,
13 | }) => {
14 |   const { t } = useTranslation()
15 |   return (
16 |     <DialogButton
17 |       icon={<ArrowReset20Regular />}
18 |       tooltip={t('Reset All Configs')}
19 |       title={t('Reset All Configs')}
20 |       content={t(
21 |         'Are you sure you want to reset all configs? This will obtain the latest preset configs, but will override your custom configs and cannot be undone.'
22 |       )}
23 |       onConfirm={() => {
24 |         commonStore.setModelConfigs(
25 |           commonStore.platform !== 'darwin'
26 |             ? defaultModelConfigs
27 |             : defaultModelConfigsMac,
28 |           false
29 |         )
30 |         commonStore.setCurrentConfigIndex(0, true)
31 |         afterConfirm?.()
32 |       }}
33 |     />
34 |   )
35 | }
36 | 


--------------------------------------------------------------------------------
/frontend/src/components/Section.tsx:
--------------------------------------------------------------------------------
 1 | import { FC, ReactElement } from 'react'
 2 | import { Card, Text } from '@fluentui/react-components'
 3 | 
 4 | export const Section: FC<{
 5 |   title: string
 6 |   desc?: string | null
 7 |   content: ReactElement
 8 |   outline?: boolean
 9 | }> = ({ title, desc, content, outline = true }) => {
10 |   return (
11 |     <Card size="small" appearance={outline ? 'outline' : 'subtle'}>
12 |       <div className="flex flex-col gap-5">
13 |         <div className="flex flex-col gap-1">
14 |           <Text weight="medium">{title}</Text>
15 |           {desc && <Text size={100}>{desc}</Text>}
16 |         </div>
17 |       </div>
18 |       <div className="overflow-y-auto overflow-x-hidden p-1">{content}</div>
19 |     </Card>
20 |   )
21 | }
22 | 


--------------------------------------------------------------------------------
/frontend/src/components/ToolTipButton.tsx:
--------------------------------------------------------------------------------
 1 | import React, {
 2 |   CSSProperties,
 3 |   FC,
 4 |   MouseEventHandler,
 5 |   ReactElement,
 6 | } from 'react'
 7 | import { Button, Tooltip } from '@fluentui/react-components'
 8 | 
 9 | export const ToolTipButton: FC<{
10 |   text?: string | null
11 |   desc: string
12 |   icon?: ReactElement
13 |   className?: string
14 |   style?: CSSProperties
15 |   size?: 'small' | 'medium' | 'large'
16 |   shape?: 'rounded' | 'circular' | 'square'
17 |   appearance?: 'secondary' | 'primary' | 'outline' | 'subtle' | 'transparent'
18 |   disabled?: boolean
19 |   onClick?: MouseEventHandler
20 |   showDelay?: number
21 | }> = ({
22 |   text,
23 |   desc,
24 |   icon,
25 |   className,
26 |   style,
27 |   size,
28 |   shape,
29 |   appearance,
30 |   disabled,
31 |   onClick,
32 |   showDelay = 0,
33 | }) => {
34 |   return desc ? (
35 |     <Tooltip
36 |       content={desc}
37 |       showDelay={showDelay}
38 |       hideDelay={0}
39 |       relationship="label"
40 |     >
41 |       <Button
42 |         style={style}
43 |         className={className}
44 |         disabled={disabled}
45 |         icon={icon}
46 |         onClick={onClick}
47 |         size={size}
48 |         shape={shape}
49 |         appearance={appearance}
50 |       >
51 |         {text}
52 |       </Button>
53 |     </Tooltip>
54 |   ) : (
55 |     <Button
56 |       style={style}
57 |       className={className}
58 |       disabled={disabled}
59 |       icon={icon}
60 |       onClick={onClick}
61 |       size={size}
62 |       shape={shape}
63 |       appearance={appearance}
64 |     >
65 |       {text}
66 |     </Button>
67 |   )
68 | }
69 | 


--------------------------------------------------------------------------------
/frontend/src/components/ValuedSlider.tsx:
--------------------------------------------------------------------------------
 1 | import React, { FC, useEffect, useRef } from 'react'
 2 | import { Slider, Text } from '@fluentui/react-components'
 3 | import { SliderOnChangeData } from '@fluentui/react-slider'
 4 | import { NumberInput } from './NumberInput'
 5 | 
 6 | export const ValuedSlider: FC<{
 7 |   value: number
 8 |   min: number
 9 |   max: number
10 |   step?: number
11 |   input?: boolean
12 |   onChange?: (
13 |     ev: React.ChangeEvent<HTMLInputElement>,
14 |     data: SliderOnChangeData
15 |   ) => void
16 |   toFixed?: number
17 |   disabled?: boolean
18 | }> = ({ value, min, max, step, input, onChange, toFixed, disabled }) => {
19 |   const sliderRef = useRef<HTMLInputElement>(null)
20 |   useEffect(() => {
21 |     if (step && sliderRef.current && sliderRef.current.parentElement) {
22 |       if ((max - min) / step > 10)
23 |         sliderRef.current.parentElement.style.removeProperty(
24 |           '--fui-Slider--steps-percent'
25 |         )
26 |     }
27 |   }, [])
28 | 
29 |   return (
30 |     <div className="flex items-center">
31 |       <Slider
32 |         ref={sliderRef}
33 |         className="grow"
34 |         style={{ minWidth: '50%' }}
35 |         value={value}
36 |         min={min}
37 |         max={max}
38 |         step={step}
39 |         onChange={onChange}
40 |         disabled={disabled}
41 |       />
42 |       {input ? (
43 |         <NumberInput
44 |           style={{ minWidth: 0 }}
45 |           value={value}
46 |           min={min}
47 |           max={max}
48 |           step={step}
49 |           onChange={onChange}
50 |           toFixed={toFixed}
51 |           disabled={disabled}
52 |         />
53 |       ) : (
54 |         <Text>{value}</Text>
55 |       )}
56 |     </div>
57 |   )
58 | }
59 | 


--------------------------------------------------------------------------------
/frontend/src/components/WorkHeader.tsx:
--------------------------------------------------------------------------------
 1 | import React, { FC } from 'react'
 2 | import { PresenceBadgeStatus } from '@fluentui/react-badge'
 3 | import { Divider, PresenceBadge, Text } from '@fluentui/react-components'
 4 | import { observer } from 'mobx-react-lite'
 5 | import { useTranslation } from 'react-i18next'
 6 | import { useMediaQuery } from 'usehooks-ts'
 7 | import commonStore, { ModelStatus } from '../stores/commonStore'
 8 | import { ConfigSelector } from './ConfigSelector'
 9 | import { RunButton } from './RunButton'
10 | 
11 | const statusText = {
12 |   [ModelStatus.Offline]: 'Offline',
13 |   [ModelStatus.Starting]: 'Starting',
14 |   [ModelStatus.Loading]: 'Loading',
15 |   [ModelStatus.Working]: 'Working',
16 | }
17 | 
18 | const badgeStatus: { [modelStatus: number]: PresenceBadgeStatus } = {
19 |   [ModelStatus.Offline]: 'unknown',
20 |   [ModelStatus.Starting]: 'away',
21 |   [ModelStatus.Loading]: 'away',
22 |   [ModelStatus.Working]: 'available',
23 | }
24 | 
25 | export const WorkHeader: FC = observer(() => {
26 |   const { t } = useTranslation()
27 |   const mq = useMediaQuery('(min-width: 640px)')
28 |   const port = commonStore.getCurrentModelConfig().apiParameters.apiPort
29 | 
30 |   return commonStore.platform === 'web' ? (
31 |     <div />
32 |   ) : (
33 |     <div className="flex flex-col gap-1">
34 |       <div className="flex items-center justify-between">
35 |         <div className="flex items-center gap-2">
36 |           <PresenceBadge status={badgeStatus[commonStore.status.status]} />
37 |           <Text size={100}>
38 |             {t('Model Status') +
39 |               ': ' +
40 |               t(statusText[commonStore.status.status])}
41 |           </Text>
42 |         </div>
43 |         {commonStore.lastModelName && mq && (
44 |           <Text size={100}>{commonStore.lastModelName}</Text>
45 |         )}
46 |         <div className="flex items-center gap-2">
47 |           <ConfigSelector size="small" />
48 |           <RunButton iconMode />
49 |         </div>
50 |       </div>
51 |       <Text size={100}>
52 |         {t(
53 |           "This tool's API is compatible with OpenAI API. It can be used with any ChatGPT tool you like. Go to the settings of some ChatGPT tool, replace the 'https://api.openai.com' part in the API address with '"
54 |         ) +
55 |           `http://127.0.0.1:${port}` +
56 |           "'."}
57 |       </Text>
58 |       <Divider style={{ flexGrow: 0 }} />
59 |     </div>
60 |   )
61 | })
62 | 


--------------------------------------------------------------------------------
/frontend/src/main.tsx:
--------------------------------------------------------------------------------
 1 | import './webWails'
 2 | import React from 'react'
 3 | import { createRoot } from 'react-dom/client'
 4 | import './style.scss'
 5 | import 'react-toastify/dist/ReactToastify.css'
 6 | import { HashRouter } from 'react-router-dom'
 7 | import App from './App'
 8 | import { startup } from './startup'
 9 | import './_locales/i18n-react'
10 | import { WindowShow } from '../wailsjs/runtime'
11 | 
12 | startup().then(() => {
13 |   const container = document.getElementById('root')
14 | 
15 |   const root = createRoot(container!)
16 | 
17 |   root.render(
18 |     <HashRouter>
19 |       <App />
20 |     </HashRouter>
21 |   )
22 | 
23 |   // force display the window
24 |   WindowShow()
25 | })
26 | 


--------------------------------------------------------------------------------
/frontend/src/pages/About.tsx:
--------------------------------------------------------------------------------
 1 | import React, { FC } from 'react'
 2 | import { observer } from 'mobx-react-lite'
 3 | import { useTranslation } from 'react-i18next'
 4 | import MarkdownRender from '../components/MarkdownRender'
 5 | import { Page } from '../components/Page'
 6 | import commonStore from '../stores/commonStore'
 7 | 
 8 | const About: FC = observer(() => {
 9 |   const { t } = useTranslation()
10 |   const lang: string = commonStore.settings.language
11 | 
12 |   return (
13 |     <Page
14 |       title={t('About')}
15 |       content={
16 |         <div className="overflow-y-auto overflow-x-hidden p-1">
17 |           <MarkdownRender>
18 |             {lang in commonStore.about
19 |               ? commonStore.about[lang]
20 |               : commonStore.about['en']}
21 |           </MarkdownRender>
22 |         </div>
23 |       }
24 |     />
25 |   )
26 | })
27 | 
28 | export default About
29 | 


--------------------------------------------------------------------------------
/frontend/src/pages/AudiotrackManager/AudiotrackButton.tsx:
--------------------------------------------------------------------------------
 1 | import React, { FC, lazy } from 'react'
 2 | import {
 3 |   Button,
 4 |   Dialog,
 5 |   DialogBody,
 6 |   DialogContent,
 7 |   DialogSurface,
 8 |   DialogTrigger,
 9 | } from '@fluentui/react-components'
10 | import { useTranslation } from 'react-i18next'
11 | import { CustomToastContainer } from '../../components/CustomToastContainer'
12 | import { LazyImportComponent } from '../../components/LazyImportComponent'
13 | import commonStore from '../../stores/commonStore'
14 | import { flushMidiRecordingContent } from '../../utils'
15 | 
16 | const AudiotrackEditor = lazy(() => import('./AudiotrackEditor'))
17 | 
18 | export const AudiotrackButton: FC<{
19 |   size?: 'small' | 'medium' | 'large'
20 |   shape?: 'rounded' | 'circular' | 'square'
21 |   appearance?: 'secondary' | 'primary' | 'outline' | 'subtle' | 'transparent'
22 |   setPrompt: (prompt: string) => void
23 | }> = ({ size, shape, appearance, setPrompt }) => {
24 |   const { t } = useTranslation()
25 | 
26 |   return (
27 |     <Dialog
28 |       onOpenChange={(e, data) => {
29 |         if (!data.open) {
30 |           flushMidiRecordingContent()
31 |           commonStore.setRecordingTrackId('')
32 |           commonStore.setPlayingTrackId('')
33 |         }
34 |       }}
35 |     >
36 |       <DialogTrigger disableButtonEnhancement>
37 |         <Button size={size} shape={shape} appearance={appearance}>
38 |           {t('Open MIDI Input Audio Tracks')}
39 |         </Button>
40 |       </DialogTrigger>
41 |       <DialogSurface
42 |         style={{
43 |           paddingTop: 0,
44 |           maxWidth: '90vw',
45 |           width: 'fit-content',
46 |           transform: 'unset',
47 |         }}
48 |       >
49 |         <DialogBody>
50 |           <DialogContent className="overflow-hidden">
51 |             <CustomToastContainer />
52 |             <LazyImportComponent
53 |               lazyChildren={AudiotrackEditor}
54 |               lazyProps={{ setPrompt }}
55 |             />
56 |           </DialogContent>
57 |         </DialogBody>
58 |       </DialogSurface>
59 |     </Dialog>
60 |   )
61 | }
62 | 


--------------------------------------------------------------------------------
/frontend/src/pages/index.tsx:
--------------------------------------------------------------------------------
 1 | import { FC, lazy, LazyExoticComponent, ReactElement } from 'react'
 2 | import {
 3 |   ArrowDownload20Regular,
 4 |   Chat20Regular,
 5 |   ClipboardEdit20Regular,
 6 |   DataUsageSettings20Regular,
 7 |   DocumentSettings20Regular,
 8 |   Home20Regular,
 9 |   Info20Regular,
10 |   MusicNote220Regular,
11 |   Settings20Regular,
12 |   Storage20Regular,
13 | } from '@fluentui/react-icons'
14 | 
15 | type NavigationItem = {
16 |   label: string
17 |   path: string
18 |   icon: ReactElement
19 |   element: LazyExoticComponent<FC>
20 |   top: boolean
21 | }
22 | 
23 | export const pages: NavigationItem[] = [
24 |   {
25 |     label: 'Home',
26 |     path: '/',
27 |     icon: <Home20Regular />,
28 |     element: lazy(() => import('./Home')),
29 |     top: true,
30 |   },
31 |   {
32 |     label: 'Chat',
33 |     path: '/chat',
34 |     icon: <Chat20Regular />,
35 |     element: lazy(() => import('./Chat')),
36 |     top: true,
37 |   },
38 |   {
39 |     label: 'Completion',
40 |     path: '/completion',
41 |     icon: <ClipboardEdit20Regular />,
42 |     element: lazy(() => import('./Completion')),
43 |     top: true,
44 |   },
45 |   {
46 |     label: 'Composition',
47 |     path: '/composition',
48 |     icon: <MusicNote220Regular />,
49 |     element: lazy(() => import('./Composition')),
50 |     top: true,
51 |   },
52 |   {
53 |     label: 'Configs',
54 |     path: '/configs',
55 |     icon: <DocumentSettings20Regular />,
56 |     element: lazy(() => import('./Configs')),
57 |     top: true,
58 |   },
59 |   {
60 |     label: 'Models',
61 |     path: '/models',
62 |     icon: <DataUsageSettings20Regular />,
63 |     element: lazy(() => import('./Models')),
64 |     top: true,
65 |   },
66 |   {
67 |     label: 'Downloads',
68 |     path: '/downloads',
69 |     icon: <ArrowDownload20Regular />,
70 |     element: lazy(() => import('./Downloads')),
71 |     top: true,
72 |   },
73 |   {
74 |     label: 'Train',
75 |     path: '/train',
76 |     icon: <Storage20Regular />,
77 |     element: lazy(() => import('./Train')),
78 |     top: true,
79 |   },
80 |   {
81 |     label: 'Settings',
82 |     path: '/settings',
83 |     icon: <Settings20Regular />,
84 |     element: lazy(() => import('./Settings')),
85 |     top: false,
86 |   },
87 |   {
88 |     label: 'About',
89 |     path: '/about',
90 |     icon: <Info20Regular />,
91 |     element: lazy(() => import('./About')),
92 |     top: false,
93 |   },
94 | ]
95 | 


--------------------------------------------------------------------------------
/frontend/src/style.scss:
--------------------------------------------------------------------------------
 1 | [data-theme='dark'] {
 2 |   @import 'highlight.js/scss/github-dark.scss';
 3 |   --color-neutral-muted: rgba(110, 118, 129, 0.4);
 4 | }
 5 | 
 6 | [data-theme='light'] {
 7 |   @import 'highlight.js/scss/github.scss';
 8 |   --color-neutral-muted: rgba(150, 160, 170, 0.3);
 9 | }
10 | 
11 | @tailwind base;
12 | @tailwind components;
13 | @tailwind utilities;
14 | 
15 | body {
16 |   margin: 0;
17 |   width: 100%;
18 |   height: 100%;
19 | }
20 | 
21 | * {
22 |   scrollbar-width: thin;
23 | }
24 | 
25 | /* Works on Chrome, Edge, and Safari */
26 | *::-webkit-scrollbar {
27 |   width: 9px;
28 |   height: 9px;
29 | }
30 | 
31 | *::-webkit-scrollbar-thumb {
32 |   background-color: rgba(155, 155, 155, 0.5);
33 |   border-radius: 20px;
34 |   border: transparent;
35 | }
36 | 
37 | *::-webkit-scrollbar-track {
38 |   background: transparent;
39 | }
40 | 
41 | *::-webkit-scrollbar-corner {
42 |   background: transparent;
43 | }
44 | 
45 | .markdown-body {
46 |   overflow-y: auto;
47 |   overflow-x: hidden;
48 | 
49 |   pre {
50 |     padding: 0;
51 |     background: transparent;
52 | 
53 |     code {
54 |       font-size: 14px;
55 |     }
56 |   }
57 | 
58 |   code {
59 |     white-space: pre-wrap;
60 |     word-break: break-word;
61 |     border-radius: 8px;
62 |     background-color: var(--color-neutral-muted);
63 |   }
64 | 
65 |   details summary {
66 |     cursor: pointer;
67 |   }
68 | }
69 | 
70 | midi-player {
71 |   &::part(control-panel) {
72 |     background: none;
73 |   }
74 | }
75 | 
76 | midi-visualizer {
77 |   $instrument-colors: #007bff, #20c997, #dc3545, #6610f2, #ffc107, #e83e8c,
78 |     #17a2b8, #fd7e14, #28a745;
79 | 
80 |   svg {
81 |     @for $i from 0 to 200 {
82 |       $color: nth($instrument-colors, ($i % length($instrument-colors)) + 1);
83 |       rect.note[data-instrument='#{$i}'] {
84 |         fill: $color;
85 |       }
86 |     }
87 |   }
88 | }
89 | 


--------------------------------------------------------------------------------
/frontend/src/types/about.ts:
--------------------------------------------------------------------------------
1 | export type AboutContent = { [lang: string]: string }
2 | 


--------------------------------------------------------------------------------
/frontend/src/types/chat.ts:
--------------------------------------------------------------------------------
 1 | import { ApiParameters } from './configs'
 2 | 
 3 | export const userName = 'M E'
 4 | export const botName = 'A I'
 5 | export const systemName = 'System'
 6 | export const welcomeUuid = 'welcome'
 7 | 
 8 | export enum MessageType {
 9 |   Normal,
10 |   Error,
11 | }
12 | 
13 | export type Side = 'left' | 'right' | 'center'
14 | export type Color = 'neutral' | 'brand' | 'colorful'
15 | export type MessageItem = {
16 |   sender: string
17 |   toolName?: string
18 |   type: MessageType
19 |   color: Color
20 |   avatarImg?: string
21 |   time: string
22 |   content: string
23 |   side: Side
24 |   done: boolean
25 |   thinking?: boolean
26 |   thinkingEnded?: boolean
27 | }
28 | export type Conversation = {
29 |   [uuid: string]: MessageItem
30 | }
31 | export type Role = 'assistant' | 'user' | 'system' | 'tool'
32 | export type ConversationMessage = {
33 |   role: Role
34 |   content: string
35 |   prefix?: boolean
36 |   tool_call_id?: string
37 |   tool_calls?: Array<{
38 |     id: string
39 |     type: 'function'
40 |     function: {
41 |       name: string
42 |       arguments: string
43 |     }
44 |   }>
45 | }
46 | export type Attachment = {
47 |   name: string
48 |   size: number
49 |   content: string
50 | }
51 | export type ChatParams = Omit<ApiParameters, 'apiPort'> & {
52 |   historyN: number
53 |   markdown: boolean
54 |   functionCall: boolean
55 |   toolDefinition: string
56 |   toolReturn: string
57 | }
58 | 


--------------------------------------------------------------------------------
/frontend/src/types/completion.ts:
--------------------------------------------------------------------------------
 1 | import { ApiParameters } from './configs'
 2 | 
 3 | export type CompletionParams = Omit<ApiParameters, 'apiPort'> & {
 4 |   stop: string
 5 |   injectStart: string
 6 |   injectEnd: string
 7 | }
 8 | export type CompletionPreset = {
 9 |   name: string
10 |   prompt: string
11 |   params: CompletionParams
12 | }
13 | 


--------------------------------------------------------------------------------
/frontend/src/types/composition.ts:
--------------------------------------------------------------------------------
 1 | import { NoteSequence } from '@magenta/music/esm/protobuf'
 2 | 
 3 | export const tracksMinimalTotalTime = 5000
 4 | 
 5 | export type CompositionParams = {
 6 |   prompt: string
 7 |   maxResponseToken: number
 8 |   temperature: number
 9 |   topP: number
10 |   autoPlay: boolean
11 |   useLocalSoundFont: boolean
12 |   externalPlay: boolean
13 |   midi: ArrayBuffer | null
14 |   ns: NoteSequence | null
15 |   generationStartTime: number
16 |   playOnlyGeneratedContent: boolean
17 | }
18 | export type Track = {
19 |   id: string
20 |   mainInstrument: string
21 |   content: string
22 |   rawContent: MidiMessage[]
23 |   offsetTime: number
24 |   contentTime: number
25 | }
26 | export type MidiPort = {
27 |   name: string
28 | }
29 | 
30 | export type MessageType = 'NoteOff' | 'NoteOn' | 'ElapsedTime' | 'ControlChange'
31 | 
32 | export type MidiMessage = {
33 |   messageType: MessageType
34 |   channel: number
35 |   note: number
36 |   velocity: number
37 |   control: number
38 |   value: number
39 |   instrument: InstrumentType
40 | }
41 | 
42 | export enum InstrumentType {
43 |   Piano,
44 |   Percussion,
45 |   Drum,
46 |   Tuba,
47 |   Marimba,
48 |   Bass,
49 |   Guitar,
50 |   Violin,
51 |   Trumpet,
52 |   Sax,
53 |   Flute,
54 |   Lead,
55 |   Pad,
56 | }
57 | 
58 | export const InstrumentTypeNameMap = [
59 |   'Piano',
60 |   'Percussion',
61 |   'Drum',
62 |   'Tuba',
63 |   'Marimba',
64 |   'Bass',
65 |   'Guitar',
66 |   'Violin',
67 |   'Trumpet',
68 |   'Sax',
69 |   'Flute',
70 |   'Lead',
71 |   'Pad',
72 | ]
73 | 
74 | export const InstrumentTypeTokenMap = [
75 |   'pi',
76 |   'p',
77 |   'd',
78 |   't',
79 |   'm',
80 |   'b',
81 |   'g',
82 |   'v',
83 |   'tr',
84 |   's',
85 |   'f',
86 |   'l',
87 |   'pa',
88 | ]
89 | 


--------------------------------------------------------------------------------
/frontend/src/types/configs.ts:
--------------------------------------------------------------------------------
 1 | export type ApiParameters = {
 2 |   apiPort: number
 3 |   maxResponseToken: number
 4 |   temperature: number
 5 |   topP: number
 6 |   presencePenalty: number
 7 |   frequencyPenalty: number
 8 |   penaltyDecay?: number
 9 |   globalPenalty?: boolean
10 |   stateModel?: string
11 | }
12 | export type Device =
13 |   | 'CPU'
14 |   | 'CPU (rwkv.cpp)'
15 |   | 'CUDA'
16 |   | 'CUDA-Beta'
17 |   | 'WebGPU'
18 |   | 'WebGPU (Python)'
19 |   | 'MPS'
20 |   | 'Custom'
21 | export type Precision = 'fp16' | 'int8' | 'fp32' | 'nf4' | 'Q5_1'
22 | export type GGUFMode = 'CPU' | 'Vulkan GPU'
23 | export type ModelParameters = {
24 |   // different models can not have the same name
25 |   modelName: string
26 |   device: Device
27 |   precision: Precision
28 |   storedLayers: number
29 |   maxStoredLayers: number
30 |   quantizedLayers?: number
31 |   tokenChunkSize?: number
32 |   useCustomCuda?: boolean
33 |   customStrategy?: string
34 |   useCustomTokenizer?: boolean
35 |   customTokenizer?: string
36 |   ggufMode?: GGUFMode
37 |   llamaContext?: number
38 | }
39 | export type ModelConfig = {
40 |   // different configs can have the same name
41 |   name: string
42 |   apiParameters: ApiParameters
43 |   modelParameters: ModelParameters
44 |   enableWebUI?: boolean
45 | }
46 | 


--------------------------------------------------------------------------------
/frontend/src/types/downloads.ts:
--------------------------------------------------------------------------------
 1 | export type DownloadStatus = {
 2 |   name: string
 3 |   path: string
 4 |   url: string
 5 |   transferred: number
 6 |   size: number
 7 |   speed: number
 8 |   progress: number
 9 |   downloading: boolean
10 |   done: boolean
11 | }
12 | 


--------------------------------------------------------------------------------
/frontend/src/types/home.ts:
--------------------------------------------------------------------------------
 1 | import { ReactElement } from 'react'
 2 | 
 3 | export type IntroductionContent = {
 4 |   [lang: string]: string
 5 | }
 6 | export type NavCard = {
 7 |   label: string
 8 |   desc: string
 9 |   path: string
10 |   icon: ReactElement
11 | }
12 | 


--------------------------------------------------------------------------------
/frontend/src/types/html-midi-player.d.ts:
--------------------------------------------------------------------------------
 1 | declare module JSX {
 2 |   import { PlayerElement } from 'html-midi-player'
 3 |   import { VisualizerElement } from 'html-midi-player'
 4 | 
 5 |   interface IntrinsicElements {
 6 |     'midi-player': PlayerElement
 7 |     'midi-visualizer': VisualizerElement
 8 |   }
 9 | }
10 | 


--------------------------------------------------------------------------------
/frontend/src/types/models.ts:
--------------------------------------------------------------------------------
 1 | export type ModelSourceItem = {
 2 |   name: string
 3 |   desc?: { [lang: string]: string | undefined }
 4 |   size: number
 5 |   SHA256?: string
 6 |   lastUpdated: string
 7 |   url?: string
 8 |   downloadUrl?: string
 9 |   tags?: string[]
10 |   customTokenizer?: string
11 |   hide?: boolean
12 |   functionCall?: boolean
13 | 
14 |   lastUpdatedMs?: number
15 |   isComplete?: boolean
16 |   isLocal?: boolean
17 |   localSize?: number
18 | }
19 | 


--------------------------------------------------------------------------------
/frontend/src/types/presets.ts:
--------------------------------------------------------------------------------
 1 | import { ReactElement } from 'react'
 2 | import { ConversationMessage } from './chat'
 3 | 
 4 | export type PresetType = 'chat' | 'completion' | 'chatInCompletion'
 5 | export type Preset = {
 6 |   name: string
 7 |   tag: string
 8 |   // if name and sourceUrl are same, it will be overridden when importing
 9 |   sourceUrl: string
10 |   desc: string
11 |   avatarImg: string
12 |   userAvatarImg?: string
13 |   type: PresetType
14 |   // chat
15 |   welcomeMessage: string
16 |   messages: ConversationMessage[]
17 |   displayPresetMessages: boolean
18 |   // completion
19 |   prompt: string
20 |   stop: string
21 |   injectStart: string
22 |   injectEnd: string
23 |   presystem?: boolean
24 |   userName?: string
25 |   assistantName?: string
26 | }
27 | export type PresetsNavigationItem = {
28 |   icon: ReactElement
29 |   element: ReactElement
30 | }
31 | 


--------------------------------------------------------------------------------
/frontend/src/types/settings.ts:
--------------------------------------------------------------------------------
 1 | export const Languages = {
 2 |   dev: 'English', // i18n default
 3 |   zh: '简体中文',
 4 |   ja: '日本語',
 5 | }
 6 | export type Language = keyof typeof Languages
 7 | export type SettingsType = {
 8 |   language: Language
 9 |   darkMode: boolean
10 |   autoUpdatesCheck: boolean
11 |   giteeUpdatesSource: boolean
12 |   cnMirror: boolean
13 |   useHfMirror: boolean
14 |   host: string
15 |   dpiScaling: number
16 |   customModelsPath: string
17 |   customPythonPath: string
18 |   apiUrl: string
19 |   apiKey: string
20 |   apiChatModelName: string
21 |   apiCompletionModelName: string
22 |   coreApiUrl: string
23 |   rememberAllDurableData: boolean
24 | }
25 | 


--------------------------------------------------------------------------------
/frontend/src/types/train.ts:
--------------------------------------------------------------------------------
 1 | import { ReactElement } from 'react'
 2 | 
 3 | export type DataProcessParameters = {
 4 |   dataPath: string
 5 |   vocabPath: string
 6 | }
 7 | export type LoraFinetunePrecision = 'bf16' | 'fp16' | 'tf32'
 8 | export type LoraFinetuneParameters = {
 9 |   baseModel: string
10 |   ctxLen: number
11 |   epochSteps: number
12 |   epochCount: number
13 |   epochBegin: number
14 |   epochSave: number
15 |   microBsz: number
16 |   accumGradBatches: number
17 |   preFfn: boolean
18 |   headQk: boolean
19 |   lrInit: string
20 |   lrFinal: string
21 |   warmupSteps: number
22 |   beta1: number
23 |   beta2: number
24 |   adamEps: string
25 |   devices: number
26 |   precision: LoraFinetunePrecision
27 |   gradCp: boolean
28 |   loraR: number
29 |   loraAlpha: number
30 |   loraDropout: number
31 |   loraLoad: string
32 | }
33 | export type TrainNavigationItem = {
34 |   element: ReactElement
35 | }
36 | 


--------------------------------------------------------------------------------
/frontend/src/utils/copy-cuda-kernels.ts:
--------------------------------------------------------------------------------
 1 | import { CopyFolderFiles } from '../../wailsjs/go/backend_golang/App'
 2 | 
 3 | export async function copyCudaKernels(torchVersion: string) {
 4 |   const copyRoot = './backend-python/rwkv_pip'
 5 |   if (torchVersion === '2.7.1+cu128') {
 6 |     await CopyFolderFiles(
 7 |       copyRoot + '/kernels/torch-2.7.1+cu128',
 8 |       copyRoot,
 9 |       true
10 |     )
11 |   } else {
12 |     await CopyFolderFiles(
13 |       copyRoot + '/kernels/torch-1.13.1+cu117',
14 |       copyRoot,
15 |       true
16 |     )
17 |   }
18 | }
19 | 
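A minimal usage sketch (TypeScript): copyCudaKernels is driven by whatever torch version string the app has detected; detectTorchVersion below is a hypothetical stand-in for that detection step, not an API from this repo.

import { copyCudaKernels } from './copy-cuda-kernels'

// Hypothetical caller: only an exact '2.7.1+cu128' match selects the new
// kernels; every other version string falls back to torch-1.13.1+cu117.
async function ensureKernels(detectTorchVersion: () => Promise<string>) {
  const torchVersion = await detectTorchVersion() // e.g. '2.7.1+cu128'
  await copyCudaKernels(torchVersion)
}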


--------------------------------------------------------------------------------
/frontend/src/utils/filter-function-properties.ts:
--------------------------------------------------------------------------------
1 | export type FilterFunctionProperties<T> = {
2 |   // eslint-disable-next-line
3 |   [K in keyof T as T[K] extends Function ? never : K]: T[K]
4 | }
5 | 
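For illustration, the mapped type above remaps every function-valued key to `never`, which removes it and leaves only data properties. The Demo type here is hypothetical:

import { FilterFunctionProperties } from './filter-function-properties'

type Demo = {
  name: string
  size: number
  refresh: () => void
}

// Function-valued keys are dropped:
// DemoData is { name: string; size: number }
type DemoData = FilterFunctionProperties<Demo>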


--------------------------------------------------------------------------------
/frontend/src/utils/get-available-torch-cu-version.ts:
--------------------------------------------------------------------------------
 1 | import { compare } from 'compare-versions'
 2 | 
 3 | export const torchVersions = ['1.13.1', '2.7.1']
 4 | 
 5 | export function getAvailableTorchCuVersion(
 6 |   torchVersion: string,
 7 |   driverCudaVersion: string
 8 | ) {
 9 |   let retTorchVersion = ''
10 |   let retCuSourceVersion = ''
11 |   const targetTorchVersion = torchVersion.split('+')[0]
12 |   if (compare(targetTorchVersion, '2.7.1', '>=')) {
13 |     retTorchVersion = '2.7.1'
14 |     if (compare(driverCudaVersion, '12.8', '>=')) {
15 |       retCuSourceVersion = '12.8'
16 |     } else if (compare(driverCudaVersion, '12.6', '>=')) {
17 |       retCuSourceVersion = '12.6'
18 |     } else {
19 |       retCuSourceVersion = '11.8'
20 |     }
21 |   } else {
22 |     retTorchVersion = '1.13.1'
23 |     if (compare(driverCudaVersion, '11.7', '>=')) {
24 |       retCuSourceVersion = '11.7'
25 |     } else {
26 |       retCuSourceVersion = '11.6'
27 |     }
28 |   }
29 |   return { torchVersion: retTorchVersion, cuSourceVersion: retCuSourceVersion }
30 | }
31 | 
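A short illustrative sketch of the mapping: given the installed torch version and the CUDA version reported by the driver, the helper picks the newest kernel/source combination that driver can support. The sample values are hypothetical inputs, not data from this repo.

import { getAvailableTorchCuVersion } from './get-available-torch-cu-version'

// torch 2.7.1 with a driver reporting CUDA 12.6:
// -> { torchVersion: '2.7.1', cuSourceVersion: '12.6' }
const modern = getAvailableTorchCuVersion('2.7.1+cu128', '12.6')

// older torch builds stay on the 1.13.1 line:
// -> { torchVersion: '1.13.1', cuSourceVersion: '11.7' }
const legacy = getAvailableTorchCuVersion('1.13.1+cu117', '12.2')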


--------------------------------------------------------------------------------
/frontend/src/vite-env.d.ts:
--------------------------------------------------------------------------------
1 | /// <reference types="vite/client" />
2 | 


--------------------------------------------------------------------------------
/frontend/tailwind.config.js:
--------------------------------------------------------------------------------
  1 | const markdownElements = [
  2 |   'div',
  3 |   'p',
  4 |   'span',
  5 | 
  6 |   'video',
  7 |   'img',
  8 | 
  9 |   'abbr',
 10 |   'acronym',
 11 |   'b',
 12 |   'blockquote',
 13 |   'code',
 14 |   'em',
 15 |   'i',
 16 |   'li',
 17 |   'ol',
 18 |   'ul',
 19 |   'strong',
 20 |   'table',
 21 |   'tr',
 22 |   'td',
 23 |   'th',
 24 | 
 25 |   'details',
 26 |   'summary',
 27 |   'kbd',
 28 |   'samp',
 29 |   'sub',
 30 |   'sup',
 31 |   'ins',
 32 |   'del',
 33 |   'var',
 34 |   'q',
 35 |   'dl',
 36 |   'dt',
 37 |   'dd',
 38 |   'ruby',
 39 |   'rt',
 40 |   'rp',
 41 | 
 42 |   'br',
 43 |   'hr',
 44 | 
 45 |   'h1',
 46 |   'h2',
 47 |   'h3',
 48 |   'h4',
 49 |   'h5',
 50 |   'h6',
 51 | 
 52 |   'thead',
 53 |   'tbody',
 54 |   'tfoot',
 55 |   'u',
 56 |   's',
 57 |   'a',
 58 |   'pre',
 59 |   'cite',
 60 | ]
 61 | 
 62 | const markdownPseudoElements = ['::marker', '::before', '::after']
 63 | 
 64 | const tableElements = ['table', 'tr', 'td', 'th', 'thead', 'tbody', 'tfoot']
 65 | 
 66 | const proseStyles = {
 67 |   color: 'inherit',
 68 | }
 69 | 
 70 | const tableProseStyles = {
 71 |   ...proseStyles,
 72 |   borderWidth: 'thin',
 73 |   borderColor: '#d2d2d5',
 74 | }
 75 | 
 76 | const elementsStyles = markdownElements.reduce((acc, element) => {
 77 |   let styles = proseStyles
 78 |   if (tableElements.includes(element)) styles = tableProseStyles
 79 | 
 80 |   acc[element] = styles
 81 |   markdownPseudoElements.forEach((pseudo) => {
 82 |     acc[element + pseudo] = styles
 83 |   })
 84 |   return acc
 85 | }, {})
 86 | 
 87 | /** @type {import('tailwindcss').Config} */
 88 | export default {
 89 |   content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'],
 90 |   theme: {
 91 |     extend: {
 92 |       typography: {
 93 |         DEFAULT: {
 94 |           css: {
 95 |             color: 'inherit',
 96 |             fontSize: 'inherit',
 97 |             ...elementsStyles,
 98 |           },
 99 |         },
100 |       },
101 |     },
102 |   },
103 |   plugins: [require('@tailwindcss/typography')],
104 | }
105 | 
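To make the reduce above concrete: elementsStyles ends up with one entry per markdown element plus one per pseudo-element, and table elements additionally carry the thin border. An illustrative (non-exhaustive) excerpt of the generated object:

const elementsStylesExcerpt = {
  p: { color: 'inherit' },
  'p::marker': { color: 'inherit' },
  table: { color: 'inherit', borderWidth: 'thin', borderColor: '#d2d2d5' },
  'table::before': { color: 'inherit', borderWidth: 'thin', borderColor: '#d2d2d5' },
  // ...and so on for every element in markdownElements
}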


--------------------------------------------------------------------------------
/frontend/tsconfig.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "compilerOptions": {
 3 |     "target": "ESNext",
 4 |     "useDefineForClassFields": true,
 5 |     "lib": ["DOM", "DOM.Iterable", "ESNext"],
 6 |     "allowJs": false,
 7 |     "skipLibCheck": true,
 8 |     "esModuleInterop": false,
 9 |     "allowSyntheticDefaultImports": true,
10 |     "strict": true,
11 |     "forceConsistentCasingInFileNames": true,
12 |     "module": "ESNext",
13 |     "moduleResolution": "Node",
14 |     "resolveJsonModule": true,
15 |     "isolatedModules": true,
16 |     "noEmit": true,
17 |     "jsx": "react-jsx"
18 |   },
19 |   "include": ["src"],
20 |   "references": [
21 |     {
22 |       "path": "./tsconfig.node.json"
23 |     }
24 |   ]
25 | }
26 | 


--------------------------------------------------------------------------------
/frontend/tsconfig.node.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "compilerOptions": {
 3 |     "composite": true,
 4 |     "module": "ESNext",
 5 |     "moduleResolution": "Node",
 6 |     "allowSyntheticDefaultImports": true
 7 |   },
 8 |   "include": ["vite.config.ts"]
 9 | }
10 | 


--------------------------------------------------------------------------------
/frontend/vite.config.ts:
--------------------------------------------------------------------------------
 1 | import react from '@vitejs/plugin-react'
 2 | import { visualizer } from 'rollup-plugin-visualizer'
 3 | import { defineConfig } from 'vite'
 4 | import topLevelAwait from 'vite-plugin-top-level-await'
 5 | // @ts-ignore
 6 | import { dependencies } from './package.json'
 7 | 
 8 | // dependencies that are used throughout the app (bundled into the vendor chunk)
 9 | const vendor = [
10 |   'react',
11 |   'react-dom',
12 |   'react-router',
13 |   'react-router-dom',
14 |   '@fluentui/react-icons',
15 |   'mobx',
16 |   'mobx-react-lite',
17 |   'i18next',
18 |   'react-i18next',
19 |   'usehooks-ts',
20 |   'react-toastify',
21 |   'classnames',
22 |   'lodash-es',
23 | ]
24 | 
25 | const embedded = [
26 |   // split @fluentui/react-components by components
27 |   '@fluentui/react-components',
28 |   '@tabler/icons-react',
29 | 
30 |   // dependencies that exist only in a single component
31 |   'react-beautiful-dnd',
32 |   'react-draggable',
33 |   '@magenta/music',
34 |   'html-midi-player',
35 |   'react-markdown',
36 |   'rehype-highlight',
37 |   'rehype-raw',
38 |   'remark-breaks',
39 |   'remark-gfm',
40 |   'remark-math',
41 |   'rehype-katex',
42 |   'katex',
43 | ]
44 | 
45 | function renderChunks(deps: Record<string, string>) {
46 |   let chunks = {}
47 |   Object.keys(deps).forEach((key) => {
48 |     if ([...vendor, ...embedded].includes(key)) return
49 |     chunks[key] = [key]
50 |   })
51 |   return chunks
52 | }
53 | 
54 | // https://vitejs.dev/config/
55 | export default defineConfig({
56 |   plugins: [
57 |     react(),
58 |     visualizer({
59 |       template: 'treemap',
60 |       gzipSize: true,
61 |       brotliSize: true,
62 |     }),
63 |     topLevelAwait({
64 |       promiseExportName: '__tla',
65 |       promiseImportName: (i) => `__tla_${i}`,
66 |     }),
67 |   ],
68 |   resolve: {
69 |     alias: {
70 |       // /esm/icons/index.mjs only exports the icons statically, so no separate chunks are created
71 |       '@tabler/icons-react': '@tabler/icons-react/dist/esm/icons/index.mjs',
72 |     },
73 |   },
74 |   build: {
75 |     chunkSizeWarningLimit: 3000,
76 |     rollupOptions: {
77 |       output: {
78 |         manualChunks: {
79 |           vendor,
80 |           ...renderChunks(dependencies),
81 |         },
82 |         entryFileNames: `assets/[name].js`,
83 |         chunkFileNames: `assets/[name].js`,
84 |         assetFileNames: `assets/[name].[ext]`,
85 |       },
86 |     },
87 |   },
88 | })
89 | 
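Rough sketch of the resulting chunking: the vendor list becomes one shared chunk, embedded packages are excluded from manualChunks (so Rollup splits them on its own), and renderChunks gives every remaining dependency from package.json its own chunk. With a hypothetical dependency set, the chunk map would look roughly like:

// Illustrative only; the real map depends on package.json contents.
const exampleManualChunks = {
  vendor: ['react', 'react-dom', 'mobx' /* ...rest of the vendor list */],
  // one chunk per remaining dependency, produced by renderChunks:
  'compare-versions': ['compare-versions'],
}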


--------------------------------------------------------------------------------
/frontend/wailsjs/go/models.ts:
--------------------------------------------------------------------------------
 1 | export namespace backend_golang {
 2 | 	
 3 | 	export class FileInfo {
 4 | 	    name: string;
 5 | 	    size: number;
 6 | 	    isDir: boolean;
 7 | 	    modTime: string;
 8 | 	
 9 | 	    static createFrom(source: any = {}) {
10 | 	        return new FileInfo(source);
11 | 	    }
12 | 	
13 | 	    constructor(source: any = {}) {
14 | 	        if ('string' === typeof source) source = JSON.parse(source);
15 | 	        this.name = source["name"];
16 | 	        this.size = source["size"];
17 | 	        this.isDir = source["isDir"];
18 | 	        this.modTime = source["modTime"];
19 | 	    }
20 | 	}
21 | 	export class MIDIMessage {
22 | 	    messageType: string;
23 | 	    channel: number;
24 | 	    note: number;
25 | 	    velocity: number;
26 | 	    control: number;
27 | 	    value: number;
28 | 	
29 | 	    static createFrom(source: any = {}) {
30 | 	        return new MIDIMessage(source);
31 | 	    }
32 | 	
33 | 	    constructor(source: any = {}) {
34 | 	        if ('string' === typeof source) source = JSON.parse(source);
35 | 	        this.messageType = source["messageType"];
36 | 	        this.channel = source["channel"];
37 | 	        this.note = source["note"];
38 | 	        this.velocity = source["velocity"];
39 | 	        this.control = source["control"];
40 | 	        this.value = source["value"];
41 | 	    }
42 | 	}
43 | 
44 | }
45 | 
46 | 


--------------------------------------------------------------------------------
/frontend/wailsjs/runtime/package.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "name": "@wailsapp/runtime",
 3 |   "version": "2.0.0",
 4 |   "description": "Wails Javascript runtime library",
 5 |   "main": "runtime.js",
 6 |   "types": "runtime.d.ts",
 7 |   "scripts": {
 8 |   },
 9 |   "repository": {
10 |     "type": "git",
11 |     "url": "git+https://github.com/wailsapp/wails.git"
12 |   },
13 |   "keywords": [
14 |     "Wails",
15 |     "Javascript",
16 |     "Go"
17 |   ],
18 |   "author": "Lea Anthony <lea.anthony@gmail.com>",
19 |   "license": "MIT",
20 |   "bugs": {
21 |     "url": "https://github.com/wailsapp/wails/issues"
22 |   },
23 |   "homepage": "https://github.com/wailsapp/wails#readme"
24 | }
25 | 


--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
 1 | module rwkv-runner
 2 | 
 3 | go 1.20
 4 | 
 5 | require (
 6 | 	github.com/cavaliergopher/grab/v3 v3.0.1
 7 | 	github.com/fsnotify/fsnotify v1.8.0
 8 | 	github.com/mattrtaylor/go-rtmidi v0.0.0-20220428034745-af795b1c1a79
 9 | 	github.com/minio/selfupdate v0.6.0
10 | 	github.com/nyaosorg/go-windows-su v0.2.1
11 | 	github.com/ubuntu/gowsl v0.0.0-20230615094051-94945650cc1e
12 | 	github.com/wailsapp/wails/v2 v2.9.2
13 | 	golang.org/x/text v0.22.0
14 | )
15 | 
16 | require (
17 | 	aead.dev/minisign v0.2.0 // indirect
18 | 	github.com/bep/debounce v1.2.1 // indirect
19 | 	github.com/go-ole/go-ole v1.3.0 // indirect
20 | 	github.com/godbus/dbus/v5 v5.1.0 // indirect
21 | 	github.com/google/uuid v1.6.0 // indirect
22 | 	github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect
23 | 	github.com/labstack/echo/v4 v4.13.3 // indirect
24 | 	github.com/labstack/gommon v0.4.2 // indirect
25 | 	github.com/leaanthony/go-ansi-parser v1.6.1 // indirect
26 | 	github.com/leaanthony/gosod v1.0.4 // indirect
27 | 	github.com/leaanthony/slicer v1.6.0 // indirect
28 | 	github.com/leaanthony/u v1.1.1 // indirect
29 | 	github.com/mattn/go-colorable v0.1.13 // indirect
30 | 	github.com/mattn/go-isatty v0.0.20 // indirect
31 | 	github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
32 | 	github.com/pkg/errors v0.9.1 // indirect
33 | 	github.com/rivo/uniseg v0.4.7 // indirect
34 | 	github.com/samber/lo v1.49.1 // indirect
35 | 	github.com/sirupsen/logrus v1.9.0 // indirect
36 | 	github.com/tkrajina/go-reflector v0.5.8 // indirect
37 | 	github.com/ubuntu/decorate v0.0.0-20230125165522-2d5b0a9bb117 // indirect
38 | 	github.com/valyala/bytebufferpool v1.0.0 // indirect
39 | 	github.com/valyala/fasttemplate v1.2.2 // indirect
40 | 	github.com/wailsapp/go-webview2 v1.0.21 // indirect
41 | 	github.com/wailsapp/mimetype v1.4.1 // indirect
42 | 	golang.org/x/crypto v0.33.0 // indirect
43 | 	golang.org/x/net v0.35.0 // indirect
44 | 	golang.org/x/sys v0.30.0 // indirect
45 | )
46 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/cyac-1.9.dist-info/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/cyac-1.9.dist-info/.gitkeep


--------------------------------------------------------------------------------
/py310/Lib/site-packages/cyac/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/cyac/.gitkeep


--------------------------------------------------------------------------------
/py310/Lib/site-packages/diskcache-5.6.3.dist-info/INSTALLER:
--------------------------------------------------------------------------------
1 | pip
2 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/diskcache-5.6.3.dist-info/LICENSE:
--------------------------------------------------------------------------------
 1 | Copyright 2016-2022 Grant Jenks
 2 | 
 3 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use
 4 | this file except in compliance with the License.  You may obtain a copy of the
 5 | License at
 6 | 
 7 |     http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | Unless required by applicable law or agreed to in writing, software distributed
10 | under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
11 | CONDITIONS OF ANY KIND, either express or implied.  See the License for the
12 | specific language governing permissions and limitations under the License.
13 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/diskcache-5.6.3.dist-info/RECORD:
--------------------------------------------------------------------------------
 1 | diskcache-5.6.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
 2 | diskcache-5.6.3.dist-info/LICENSE,sha256=WDVGuqP9k2B9hFEmVwZ3pAH1COIotQRP77w5Sa8XlnI,559
 3 | diskcache-5.6.3.dist-info/METADATA,sha256=b8cM_0OjxmZlMebCaOkAWfuTrliu1wgWPAzuyrJNuu8,20458
 4 | diskcache-5.6.3.dist-info/RECORD,,
 5 | diskcache-5.6.3.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 6 | diskcache-5.6.3.dist-info/WHEEL,sha256=yQN5g4mg4AybRjkgi-9yy4iQEFibGQmlz78Pik5Or-A,92
 7 | diskcache-5.6.3.dist-info/top_level.txt,sha256=A5fqg_AHgOQc_0o1NZ-Uo5Bsb7CV3fR99J-p1-F4yuA,10
 8 | diskcache/__init__.py,sha256=JcEtz224G-ysHmgyxWId28A0wQfCIif0mdv4tqC-Zxk,1262
 9 | diskcache/__pycache__/__init__.cpython-310.pyc,,
10 | diskcache/__pycache__/cli.cpython-310.pyc,,
11 | diskcache/__pycache__/core.cpython-310.pyc,,
12 | diskcache/__pycache__/djangocache.cpython-310.pyc,,
13 | diskcache/__pycache__/fanout.cpython-310.pyc,,
14 | diskcache/__pycache__/persistent.cpython-310.pyc,,
15 | diskcache/__pycache__/recipes.cpython-310.pyc,,
16 | diskcache/cli.py,sha256=RVv6Fyn7h_0E5BoAVJzbbjOWPDHXO0ZjNgh9g5CBRP8,44
17 | diskcache/core.py,sha256=oinmHnYCXvR2QXVRaBoqPf8iM79M5pAeixajlaJ9TNs,81867
18 | diskcache/djangocache.py,sha256=SX8jl2d-vfOMG7hXZuuh2VCnZEUP-mHGjYLpM2RdMaQ,16110
19 | diskcache/fanout.py,sha256=E4Gk-puAGsaiqyf45YlLdVTf2RM1g7Qy-TPSug1Elpo,22725
20 | diskcache/persistent.py,sha256=Cir3CPetAfACHlCEzY1P5B_YrZPPeWqsoKrkuBBxcCo,34681
21 | diskcache/recipes.py,sha256=dr6CtZ67nRe3XMUBj2J6D_uftquMi_rNpZBg8i6lfRM,14922
22 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/diskcache-5.6.3.dist-info/REQUESTED:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/diskcache-5.6.3.dist-info/REQUESTED


--------------------------------------------------------------------------------
/py310/Lib/site-packages/diskcache-5.6.3.dist-info/WHEEL:
--------------------------------------------------------------------------------
1 | Wheel-Version: 1.0
2 | Generator: bdist_wheel (0.41.2)
3 | Root-Is-Purelib: true
4 | Tag: py3-none-any
5 | 
6 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/diskcache-5.6.3.dist-info/top_level.txt:
--------------------------------------------------------------------------------
1 | diskcache
2 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/diskcache/__init__.py:
--------------------------------------------------------------------------------
 1 | """
 2 | DiskCache API Reference
 3 | =======================
 4 | 
 5 | The :doc:`tutorial` provides a helpful walkthrough of most methods.
 6 | """
 7 | 
 8 | from .core import (
 9 |     DEFAULT_SETTINGS,
10 |     ENOVAL,
11 |     EVICTION_POLICY,
12 |     UNKNOWN,
13 |     Cache,
14 |     Disk,
15 |     EmptyDirWarning,
16 |     JSONDisk,
17 |     Timeout,
18 |     UnknownFileWarning,
19 | )
20 | from .fanout import FanoutCache
21 | from .persistent import Deque, Index
22 | from .recipes import (
23 |     Averager,
24 |     BoundedSemaphore,
25 |     Lock,
26 |     RLock,
27 |     barrier,
28 |     memoize_stampede,
29 |     throttle,
30 | )
31 | 
32 | __all__ = [
33 |     'Averager',
34 |     'BoundedSemaphore',
35 |     'Cache',
36 |     'DEFAULT_SETTINGS',
37 |     'Deque',
38 |     'Disk',
39 |     'ENOVAL',
40 |     'EVICTION_POLICY',
41 |     'EmptyDirWarning',
42 |     'FanoutCache',
43 |     'Index',
44 |     'JSONDisk',
45 |     'Lock',
46 |     'RLock',
47 |     'Timeout',
48 |     'UNKNOWN',
49 |     'UnknownFileWarning',
50 |     'barrier',
51 |     'memoize_stampede',
52 |     'throttle',
53 | ]
54 | 
55 | try:
56 |     from .djangocache import DjangoCache  # noqa
57 | 
58 |     __all__.append('DjangoCache')
59 | except Exception:  # pylint: disable=broad-except  # pragma: no cover
60 |     # Django not installed or not setup so ignore.
61 |     pass
62 | 
63 | __title__ = 'diskcache'
64 | __version__ = '5.6.3'
65 | __build__ = 0x050603
66 | __author__ = 'Grant Jenks'
67 | __license__ = 'Apache 2.0'
68 | __copyright__ = 'Copyright 2016-2023 Grant Jenks'
69 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/diskcache/cli.py:
--------------------------------------------------------------------------------
1 | """Command line interface to disk cache."""
2 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/__init__.py:
--------------------------------------------------------------------------------
1 | from .llama_cpp import *
2 | from .llama import *
3 | 
4 | __version__ = "0.3.9"
5 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/_ggml.py:
--------------------------------------------------------------------------------
 1 | """Internal module use at your own risk
 2 | 
 3 | This module provides a minimal interface for working with ggml tensors from llama-cpp-python
 4 | """
 5 | import os
 6 | import pathlib
 7 | 
 8 | import llama_cpp._ctypes_extensions as ctypes_ext
 9 | 
10 | libggml_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
11 | libggml = ctypes_ext.load_shared_library("ggml", libggml_base_path)
12 | 
13 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/_logger.py:
--------------------------------------------------------------------------------
 1 | import sys
 2 | import ctypes
 3 | import logging
 4 | 
 5 | import llama_cpp
 6 | 
 7 | # enum ggml_log_level {
 8 | #     GGML_LOG_LEVEL_NONE  = 0,
 9 | #     GGML_LOG_LEVEL_INFO  = 1,
10 | #     GGML_LOG_LEVEL_WARN  = 2,
11 | #     GGML_LOG_LEVEL_ERROR = 3,
12 | #     GGML_LOG_LEVEL_DEBUG = 4,
13 | #     GGML_LOG_LEVEL_CONT  = 5, // continue previous log
14 | # };
15 | GGML_LOG_LEVEL_TO_LOGGING_LEVEL = {
16 |     0: logging.CRITICAL,
17 |     1: logging.INFO,
18 |     2: logging.WARNING,
19 |     3: logging.ERROR,
20 |     4: logging.DEBUG,
21 |     5: logging.DEBUG,
22 | }
23 | 
24 | logger = logging.getLogger("llama-cpp-python")
25 | 
26 | _last_log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[0]
27 | 
28 | # typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
29 | @llama_cpp.llama_log_callback
30 | def llama_log_callback(
31 |     level: int,
32 |     text: bytes,
33 |     user_data: ctypes.c_void_p,
34 | ):
35 |     # TODO: Correctly implement continue previous log
36 |     global _last_log_level
37 |     log_level = GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level] if level != 5 else _last_log_level
38 |     if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]:
39 |         print(text.decode("utf-8"), end="", flush=True, file=sys.stderr)
40 |     _last_log_level = log_level
41 | 
42 | 
43 | llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0))
44 | 
45 | 
46 | def set_verbose(verbose: bool):
47 |     logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
48 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/_utils.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | 
 4 | from typing import Any, Dict
 5 | 
 6 | # Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor
 7 | outnull_file = open(os.devnull, "w")
 8 | errnull_file = open(os.devnull, "w")
 9 | 
10 | STDOUT_FILENO = 1
11 | STDERR_FILENO = 2
12 | 
13 | 
14 | class suppress_stdout_stderr(object):
15 |     # NOTE: these must be "saved" here to avoid exceptions when using
16 |     #       this context manager inside of a __del__ method
17 |     sys = sys
18 |     os = os
19 | 
20 |     def __init__(self, disable: bool = True):
21 |         self.disable = disable
22 | 
23 |     # Oddly enough this works better than the contextlib version
24 |     def __enter__(self):
25 |         if self.disable:
26 |             return self
27 | 
28 |         self.old_stdout_fileno_undup = STDOUT_FILENO
29 |         self.old_stderr_fileno_undup = STDERR_FILENO
30 | 
31 |         self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup)
32 |         self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup)
33 | 
34 |         self.old_stdout = self.sys.stdout
35 |         self.old_stderr = self.sys.stderr
36 | 
37 |         self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup)
38 |         self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup)
39 | 
40 |         self.sys.stdout = outnull_file
41 |         self.sys.stderr = errnull_file
42 |         return self
43 | 
44 |     def __exit__(self, *_):
45 |         if self.disable:
46 |             return
47 | 
48 |         # Check if sys.stdout and sys.stderr have fileno method
49 |         self.sys.stdout = self.old_stdout
50 |         self.sys.stderr = self.old_stderr
51 | 
52 |         self.os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
53 |         self.os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
54 | 
55 |         self.os.close(self.old_stdout_fileno)
56 |         self.os.close(self.old_stderr_fileno)
57 | 
58 | 
59 | class MetaSingleton(type):
60 |     """
61 |     Metaclass for implementing the Singleton pattern.
62 |     """
63 | 
64 |     _instances: Dict[type, Any] = {}
65 | 
66 |     def __call__(cls, *args: Any, **kwargs: Any) -> Any:
67 |         if cls not in cls._instances:
68 |             cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
69 |         return cls._instances[cls]
70 | 
71 | 
72 | class Singleton(object, metaclass=MetaSingleton):
73 |     """
74 |     Base class for implementing the Singleton pattern.
75 |     """
76 | 
77 |     def __init__(self):
78 |         super(Singleton, self).__init__()
79 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/ggml-base.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-base.dll


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/ggml-base.lib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-base.lib


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/ggml-cpu.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-cpu.dll


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/ggml-cpu.lib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-cpu.lib


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/ggml-vulkan.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-vulkan.dll


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/ggml-vulkan.lib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml-vulkan.lib


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/ggml.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml.dll


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/ggml.lib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/ggml.lib


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/llama.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/llama.dll


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/llama.lib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/llama.lib


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/llava.dll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/llava.dll


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/lib/llava.lib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/lib/llava.lib


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/llama_speculative.py:
--------------------------------------------------------------------------------
 1 | import abc
 2 | 
 3 | from typing import Any
 4 | 
 5 | import numpy as np
 6 | import numpy.typing as npt
 7 | 
 8 | 
 9 | class LlamaDraftModel(abc.ABC):
10 |     @abc.abstractmethod
11 |     def __call__(
12 |         self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
13 |     ) -> npt.NDArray[np.intc]:
14 |         raise NotImplementedError()
15 | 
16 | 
17 | class LlamaPromptLookupDecoding(LlamaDraftModel):
18 |     """Based on https://github.com/apoorvumang/prompt-lookup-decoding"""
19 | 
20 |     def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
21 |         self.max_ngram_size = max_ngram_size
22 |         self.num_pred_tokens = num_pred_tokens
23 | 
24 |     @staticmethod
25 |     def find_candidate_pred_tokens(
26 |         input_ids: npt.NDArray[np.intc],
27 |         max_ngram_size: int,
28 |         num_pred_tokens: int,
29 |     ):
30 |         input_length = input_ids.shape[0]
31 | 
32 |         for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
33 |             # Create sliding windows of size ngram_size
34 |             windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))
35 | 
36 |             # Convert ngram to an array for comparison
37 |             ngram_array = input_ids[-ngram_size:]
38 | 
39 |             # Find where the windows match the ngram
40 |             matches = np.all(windows == ngram_array, axis=1)
41 | 
42 |             # Get the indices of matches
43 |             match_indices = np.nonzero(matches)[0]
44 | 
45 |             # Iterate through match indices to find a valid continuation
46 |             for idx in match_indices:
47 |                 start_idx = idx + ngram_size
48 |                 end_idx = start_idx + num_pred_tokens
49 |                 end_idx = min(end_idx, input_length)
50 | 
51 |                 if start_idx < end_idx:
52 |                     return input_ids[start_idx:end_idx]
53 | 
54 |         # If no match is found, return an empty array
55 |         return np.array([], dtype=np.intc)
56 | 
57 |     def __call__(
58 |         self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
59 |     ) -> npt.NDArray[np.intc]:
60 |         return self.find_candidate_pred_tokens(
61 |             input_ids=input_ids,
62 |             max_ngram_size=self.max_ngram_size,
63 |             num_pred_tokens=self.num_pred_tokens,
64 |         )
65 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/py.typed


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp/server/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp/server/__init__.py


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/INSTALLER:
--------------------------------------------------------------------------------
1 | pip
2 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/REQUESTED:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/josStorer/RWKV-Runner/b7991a8adc5179a7ba29113e3826cb59438d47c4/py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/REQUESTED


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/WHEEL:
--------------------------------------------------------------------------------
1 | Wheel-Version: 1.0
2 | Generator: scikit-build-core 0.11.4
3 | Root-Is-Purelib: false
4 | Tag: cp310-cp310-win_amd64
5 | 
6 | 


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/direct_url.json:
--------------------------------------------------------------------------------
1 | {"archive_info": {"hash": "sha256=ab2c671a12fd73b27844145c7283290815c1fde80123391211f28c9328109d63", "hashes": {"sha256": "ab2c671a12fd73b27844145c7283290815c1fde80123391211f28c9328109d63"}}, "url": "file:///C:/Users/%E7%94%A8%E6%88%B7/Downloads/llama_cpp_python-0.3.9-cp310-cp310-win_amd64.whl"}


--------------------------------------------------------------------------------
/py310/Lib/site-packages/llama_cpp_python-0.3.9.dist-info/licenses/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Andrei Betlen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6 | 
7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8 | 
9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


--------------------------------------------------------------------------------
/scripts/merge_manifest.py:
--------------------------------------------------------------------------------
 1 | import glob
 2 | import os
 3 | import sys
 4 | 
 5 | MAIN_IMAGE_NAME = sys.argv[1]
 6 | TARGET_TAG = "latest" if len(sys.argv) < 3 else sys.argv[2]
 7 | 
 8 | # Each /tmp/images/*/*.txt file holds one image digest; amend them all into a single manifest.
 9 | args = ["docker manifest create {}:{}".format(MAIN_IMAGE_NAME, TARGET_TAG)]
10 | for i in glob.glob("/tmp/images/*/*.txt"):
11 |     with open(i, "r") as file:
12 |         args.append(" --amend {}@{}".format(MAIN_IMAGE_NAME, file.readline().strip()))
13 | cmd_create = "".join(args)
14 | cmd_push = "docker manifest push {}:{}".format(MAIN_IMAGE_NAME, TARGET_TAG)
15 | os.system(cmd_create)
16 | os.system(cmd_push)


--------------------------------------------------------------------------------
/wails.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "$schema": "https://wails.io/schemas/config.v2.json",
 3 |   "name": "RWKV-Runner",
 4 |   "outputfilename": "RWKV-Runner",
 5 |   "frontend:install": "npm install",
 6 |   "frontend:build": "npm run build",
 7 |   "frontend:dev:watcher": "npm run dev",
 8 |   "frontend:dev:serverUrl": "auto",
 9 |   "author": {
10 |     "name": "josc146",
11 |     "email": "josStorer@outlook.com"
12 |   },
13 |   "Info": {
14 |     "companyName": "RWKV-Runner",
15 |     "productName": "RWKV-Runner",
16 |     "productVersion": "1.0.0",
17 |     "copyright": "Copyright © 2023 RWKV-Runner"
18 |   }
19 | }
20 | 


--------------------------------------------------------------------------------