├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── README.md ├── assets ├── 20B_tokenizer.json └── models │ └── .gitkeep ├── build.bat ├── convert ├── opslist.py ├── requirements.txt ├── to_onnx.py └── to_torchscript.py ├── main.cpp ├── src ├── http │ └── httplib.h ├── model │ ├── block.cpp │ ├── block.h │ ├── rwkv_interface.h │ ├── rwkv_onnx.cpp │ ├── rwkv_onnx.h │ ├── rwkv_torch.cpp │ └── rwkv_torch.h ├── pipeline.cpp ├── pipeline.h ├── rwkv_server.cpp ├── rwkv_server.h ├── rwkv_tokenizer.cpp └── rwkv_tokenizer.h └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.pt 3 | CMakeFiles 4 | libtorch 5 | CMakeCache.txt 6 | cmake_install.cmake 7 | *.zip 8 | RWKVCPP 9 | .vscode 10 | Makefile 11 | lib 12 | Debug 13 | *.vcxproj* 14 | *.sln 15 | release 16 | Microsoft.* 17 | .vs 18 | *.dll 19 | tokenizer/release 20 | build 21 | *.onnx -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "tokenizer"] 2 | path = tokenizer 3 | url = https://github.com/ZeldaHuang/c_tokenizer 4 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0) 2 | project(rwkv-server) 3 | find_package(Torch REQUIRED) 4 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 5 | 6 | include_directories("Microsoft.ML.OnnxRuntime.DirectML.1.14.1/build/native/include") 7 | link_directories("Microsoft.ML.OnnxRuntime.DirectML.1.14.1/runtimes/win-x64/native") 8 | 9 | include_directories("tokenizer/bindings/c/include") 10 | link_directories("tokenizer/release") 11 | 12 | include_directories(${PROJECT_SOURCE_DIR}/src) 13 | include_directories(${PROJECT_SOURCE_DIR}/src/model) 14 | 15 | file(GLOB SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp) 16 | file(GLOB MODLE_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/model/*.cpp) 17 | 18 | set(DEP_LIBS "${TORCH_LIBRARIES}" onnxruntime "../tokenizer/release/tokenizers.dll") 19 | 20 | add_executable(rwkv-server main.cpp ${SRCS} ${MODLE_SRCS} ) 21 | 22 | 23 | target_link_libraries(rwkv-server ${DEP_LIBS}) 24 | set_property(TARGET rwkv-server PROPERTY CXX_STANDARD 17) 25 | 26 | set_target_properties(rwkv-server PROPERTIES 27 | INSTALL_RPATH "$ORIGIN/lib") 28 | if (MSVC) 29 | file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll") 30 | 31 | add_custom_command(TARGET rwkv-server 32 | POST_BUILD 33 | COMMAND ${CMAKE_COMMAND} -E copy_if_different 34 | ${TORCH_DLLS} 35 | $<TARGET_FILE_DIR:rwkv-server>) 36 | endif(MSVC) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # rwkv-cpp-server 3 | This project enables running RWKV models on Windows with C++ (**CPU/GPU**). You can run your own RWKV model service without any Python dependency (just click an exe file).
It provides the following features: 4 | - supports a C tokenizer 5 | - supports libtorch and onnxruntime inference 6 | - supports a server API via [cpp-httplib](https://github.com/yhirose/cpp-httplib) 7 | - provides model conversion scripts to convert RWKV checkpoints to TorchScript/ONNX files 8 | - provides client and server release files to use from scratch 9 | ## Build from source 10 | ### Prerequisite 11 | - Visual Studio 2022 12 | - cmake (version >= 3.0) 13 | - [cargo](https://doc.rust-lang.org/cargo/getting-started/installation.html) 14 | ### Clone the repo 15 | ``` 16 | git clone --recursive https://github.com/ZeldaHuang/rwkv-cpp-server.git 17 | cd rwkv-cpp-server 18 | ``` 19 | ### Download libtorch 20 | Download libtorch with `curl -O https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-2.0.0%2Bcpu.zip` and unzip it to the source folder. 21 | ### Download onnxruntime 22 | [Download onnxruntime](https://github.com/microsoft/onnxruntime/releases/download/v1.14.1/Microsoft.ML.OnnxRuntime.DirectML.1.14.1.zip) and unzip it to the source folder. 23 | ### Compile 24 | Run `build.bat`. The release directory is `build/release`; it contains `rwkv-server.exe` and all dependencies. 25 | 26 | ## Deploy rwkv server 27 | 28 | ### Convert models 29 | Download an RWKV model from [huggingface](https://huggingface.co/BlinkDL), then convert the `.pth` checkpoint to TorchScript/ONNX. 30 | ``` 31 | python convert/to_onnx.py 32 | python convert/to_torchscript.py 33 | ``` 34 | Place the TorchScript/ONNX model in `release/assets/models`. By default the first `.pt` or `.onnx` file in this directory will be loaded. 35 | ### Run server 36 | Run `rwkv-server.exe` in the `release` directory with `rwkv-server.exe ${model_path} ${ip} ${port}`. You can test the service with `test.py` or open the client app to chat. 37 | -------------------------------------------------------------------------------- /assets/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZeldaHuang/rwkv-cpp-server/b82805fc4129297975c22c4a38914b9de7cb418f/assets/models/.gitkeep -------------------------------------------------------------------------------- /build.bat: -------------------------------------------------------------------------------- 1 | cd tokenizer 2 | cmake . 3 | cmake --build . --config Release 4 | cd .. 5 | mkdir build 6 | cd build 7 | rmdir /s/q Debug 8 | rmdir /s/q Release 9 | cmake .. 10 | cmake -DCMAKE_PREFIX_PATH="libtorch" .. 11 | mkdir release 12 | mkdir release\assets 13 | cmake --build . 
--config Release 14 | copy ..\Microsoft.ML.OnnxRuntime.DirectML.1.14.1\runtimes\win-x64\native\onnxruntime.dll .\release\onnxruntime.dll 15 | copy ..\tokenizer\release\tokenizers.dll .\release\tokenizers.dll 16 | xcopy ..\assets\ .\release\assets\ /E -------------------------------------------------------------------------------- /convert/opslist.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RWKVOnnxOps(): 5 | 6 | def __init__(self, layers, embed, *args, dtype=None, **kwargs): 7 | import onnx 8 | self.n_layers = layers 9 | self.n_embed = embed 10 | 11 | print("embed ", embed) 12 | 13 | dtype = onnx.TensorProto.FLOAT if dtype == np.float32 else onnx.TensorProto.FLOAT16 if dtype == np.float16 else onnx.TensorProto.BFLOAT16 if dtype == np.bfloat16 else onnx.TensorProto.FLOAT 14 | nptype = np.float32 if dtype == onnx.TensorProto.FLOAT else np.float16 if dtype == onnx.TensorProto.FLOAT16 else np.float16 if dtype == onnx.TensorProto.BFLOAT16 else np.float32 15 | 16 | self.nm = 0 17 | exportname = f"RWKV_{layers}_{embed}_{'32' if dtype == onnx.TensorProto.FLOAT else '16'}.onnx" 18 | externalname = f"RWKV_{layers}_{embed}_{'32' if dtype == onnx.TensorProto.FLOAT else '16'}.bin" 19 | 20 | # remove old files 21 | import os 22 | if os.path.exists(exportname): 23 | os.remove(exportname) 24 | if os.path.exists(externalname): 25 | os.remove(externalname) 26 | 27 | self.TensorList = [] 28 | self.NodeList = [] 29 | 30 | def initTensor(x): 31 | name = f"PreTrainedTensor_{self.nm}" 32 | self.nm += 1 33 | if isinstance(x, list): 34 | xx = np.array(x).astype(nptype) 35 | else: 36 | xx = x.squeeze().float().cpu().numpy() 37 | # convert to float32 38 | xx = xx.astype(nptype) 39 | rrx = onnx.helper.make_tensor( 40 | name, 41 | dtype, 42 | xx.shape, 43 | xx.tobytes(), 44 | raw=True 45 | 46 | ) 47 | 48 | onnx.external_data_helper.set_external_data( 49 | rrx, 50 | location=externalname, 51 | 52 | ) 53 | 54 | self.TensorList.append(rrx) 55 | return name 56 | 57 | self.initTensor = initTensor 58 | 59 | def sqrt(x): 60 | name = f"sqrt_{self.nm}_out" 61 | self.nm += 1 62 | node = onnx.helper.make_node( 63 | 'Sqrt', 64 | inputs=[x], 65 | outputs=[name] 66 | ) 67 | self.NodeList.append(node) 68 | 69 | return name 70 | 71 | self.sqrt = sqrt 72 | 73 | def mean(x): 74 | name = f"mean_{self.nm}_out" 75 | self.nm += 1 76 | node = onnx.helper.make_node( 77 | 'ReduceMean', 78 | inputs=[x], 79 | outputs=[name] 80 | ) 81 | self.NodeList.append(node) 82 | 83 | return name 84 | 85 | self.mean = mean 86 | 87 | def relu(x): 88 | name = f"relu_{self.nm}_out" 89 | self.nm += 1 90 | node = onnx.helper.make_node( 91 | 'Relu', 92 | inputs=[x], 93 | outputs=[name] 94 | ) 95 | self.NodeList.append(node) 96 | 97 | return name 98 | 99 | self.relu = relu 100 | 101 | def exp(x): 102 | name = f"exp_{self.nm}_out" 103 | self.nm += 1 104 | node = onnx.helper.make_node( 105 | 'Exp', 106 | inputs=[x], 107 | outputs=[name] 108 | ) 109 | self.NodeList.append(node) 110 | 111 | return name 112 | 113 | self.exp = exp 114 | 115 | def stack(x): 116 | return [initTensor(r) for r in x] 117 | 118 | self.stack = stack 119 | 120 | def matvec(x, y, is_output=False): 121 | if is_output: 122 | name= "output_token" 123 | else: 124 | name = f"matvec_{self.nm}_out" 125 | oname = f"matvec_g_{self.nm}_out" 126 | self.nm += 1 127 | node = onnx.helper.make_node( 128 | 'MatMul', 129 | inputs=[x, y], 130 | outputs=[name] 131 | ) 132 | self.NodeList.append(node) 133 | return name 134 | 135 | 
self.matvec = matvec 136 | 137 | def prod(x): 138 | name = f"prod_{self.nm}_out" 139 | self.nm += 1 140 | node = onnx.helper.make_node( 141 | 'ReduceProd', 142 | inputs=[x], 143 | outputs=[name], 144 | axes=[1], 145 | keepdims=0 146 | 147 | 148 | ) 149 | self.NodeList.append(node) 150 | 151 | return name 152 | 153 | self.prod = prod 154 | 155 | def mul(x, y): 156 | name = f"mul_{self.nm}_out" 157 | self.nm += 1 158 | node = onnx.helper.make_node( 159 | 'Mul', 160 | inputs=[x, y], 161 | outputs=[name] 162 | ) 163 | self.NodeList.append(node) 164 | 165 | return name 166 | 167 | self.multiply = mul 168 | 169 | def squeeze(x): 170 | name = f"squeeze_{self.nm}_out" 171 | self.nm += 1 172 | node = onnx.helper.make_node( 173 | 'Squeeze', 174 | inputs=[x], 175 | outputs=[name] 176 | ) 177 | self.NodeList.append(node) 178 | 179 | return name 180 | 181 | def add(x, y): 182 | 183 | name = f"add_{self.nm}_out" 184 | self.nm += 1 185 | node = onnx.helper.make_node( 186 | 'Add', 187 | inputs=[x, y], 188 | outputs=[name] 189 | ) 190 | self.NodeList.append(node) 191 | 192 | return name 193 | 194 | self.add = add 195 | 196 | def sub(x, y): 197 | name = f"sub_{self.nm}_out" 198 | self.nm += 1 199 | node = onnx.helper.make_node( 200 | 'Sub', 201 | inputs=[x, y], 202 | outputs=[name] 203 | ) 204 | self.NodeList.append(node) 205 | 206 | return name 207 | 208 | self.subtract = sub 209 | 210 | self.one = initTensor([1.0]*embed) 211 | 212 | def lerpx(x, y, z): 213 | return self.add(x, self.multiply(self.subtract(y, x), z)) 214 | 215 | self.lerp = lerpx 216 | 217 | def minimum(x, y): 218 | name = f"minimum_{self.nm}_out" 219 | self.nm += 1 220 | node = onnx.helper.make_node( 221 | 'Min', 222 | inputs=[x, y], 223 | outputs=[name] 224 | ) 225 | self.NodeList.append(node) 226 | 227 | return name 228 | self.minimum = minimum 229 | # module def 230 | self.module = object 231 | 232 | def log(x): 233 | name = f"log_{self.nm}_out" 234 | self.nm += 1 235 | node = onnx.helper.make_node( 236 | 'Log', 237 | inputs=[x], 238 | outputs=[name] 239 | ) 240 | self.NodeList.append(node) 241 | 242 | return name 243 | 244 | self.log = log 245 | 246 | # pytorch function defs 247 | self.initfunc = lambda x: x 248 | self.layerdef = lambda x: x 249 | self.mainfunc = lambda x: x 250 | 251 | def divide(x, y): 252 | name = f"divide_{self.nm}_out" 253 | self.nm += 1 254 | node = onnx.helper.make_node( 255 | 'Div', 256 | inputs=[x, y], 257 | outputs=[name] 258 | ) 259 | self.NodeList.append(node) 260 | 261 | return name 262 | 263 | self.divide = divide 264 | 265 | def layernorm(x, w, b): 266 | name = f"layernorm_{self.nm}_out" 267 | self.nm += 1 268 | node = onnx.helper.make_node( 269 | 'LayerNormalization', 270 | inputs=[x, w, b], 271 | outputs=[name] 272 | ) 273 | self.NodeList.append(node) 274 | 275 | return name 276 | 277 | self.layernorm = layernorm 278 | 279 | def getIndex(x, y): 280 | name = f"getIndex_{self.nm}_out" 281 | self.nm += 1 282 | node = onnx.helper.make_node( 283 | 'Gather', 284 | inputs=[x, y], 285 | outputs=[name] 286 | ) 287 | self.NodeList.append(node) 288 | 289 | return squeeze(name) 290 | 291 | self.stackEmbed = False 292 | 293 | def neg(x): 294 | name = f"neg_{self.nm}_out" 295 | self.nm += 1 296 | node = onnx.helper.make_node( 297 | 'Neg', 298 | inputs=[x], 299 | outputs=[name] 300 | ) 301 | self.NodeList.append(node) 302 | 303 | return name 304 | 305 | self.neg = neg 306 | 307 | def logistic(x): 308 | name = f"logistic_{self.nm}_out" 309 | self.nm += 1 310 | node = onnx.helper.make_node( 311 | 'Sigmoid', 312 | 
inputs=[x], 313 | outputs=[name] 314 | ) 315 | self.NodeList.append(node) 316 | 317 | return name 318 | self.logistical = logistic 319 | 320 | def maximum(x, y): 321 | name = f"maximum_{self.nm}_out" 322 | self.nm += 1 323 | node = onnx.helper.make_node( 324 | 'Max', 325 | inputs=[x, y], 326 | outputs=[name] 327 | ) 328 | self.NodeList.append(node) 329 | 330 | return name 331 | 332 | self.maximum = maximum 333 | 334 | self.getIndex = getIndex 335 | 336 | # convert to float32 337 | self.emptyState = np.array((([[0.00]*embed, [0.00]*embed, [0.00]*embed, [ 338 | 0.00]*embed]+[[-1e30]*embed]))*layers) 339 | self.emptyState = np.array(self.emptyState, dtype=nptype) 340 | 341 | # self.zero = initTensor([0.0]*embed) 342 | def reshape(x): 343 | name = f"output_state" 344 | self.nm += 1 345 | node = onnx.helper.make_node( 346 | 'Reshape', 347 | inputs=[x,"state_shape"], 348 | outputs=[name], 349 | ) 350 | self.NodeList.append(node) 351 | 352 | return name 353 | def concat(x): 354 | name = f"concat{self.nm}_out" 355 | self.nm += 1 356 | node = onnx.helper.make_node( 357 | 'Concat', 358 | inputs=[*x], 359 | outputs=[name], 360 | axis=0, 361 | ) 362 | 363 | self.NodeList.append(node) 364 | 365 | return reshape(name) 366 | 367 | def getStateIndex(x, y): 368 | name = f"getStateIndex_{self.nm}_out" 369 | self.nm += 1 370 | node = onnx.helper.make_node( 371 | 'Gather', 372 | inputs=[x,y], 373 | outputs=[name] 374 | ) 375 | self.NodeList.append(node) 376 | 377 | return name 378 | self.getStateIndex= getStateIndex 379 | self.concat=concat 380 | self.reshape=reshape 381 | def ppm(x): 382 | inputtensor = onnx.helper.make_tensor_value_info("input_token", 383 | onnx.TensorProto.INT32, 384 | [1]), "input_token" 385 | 386 | # emptyState = list(map(lambda x: (onnx.helper.make_tensor_value_info("instate"+str(x), 387 | # dtype, 388 | # [embed]), "instate"+str(x)), range(5*layers))) 389 | emptyState = onnx.helper.make_tensor_value_info("input_state", 390 | dtype, 391 | [5*layers,embed]), "input_state" 392 | print(inputtensor[1],emptyState[1]) 393 | 394 | outs = x.forward( 395 | inputtensor[1], emptyState[1]) 396 | for i in range(5*layers): 397 | self.TensorList.append(onnx.helper.make_tensor(str(i),onnx.TensorProto.INT32,[],[i])) 398 | self.TensorList.append(onnx.helper.make_tensor('state_shape',onnx.TensorProto.INT64,[2],[5*layers,embed])) 399 | print(self.TensorList.__len__()) 400 | print(self.NodeList.__len__()) 401 | print(outs) 402 | logits = onnx.helper.make_tensor_value_info(outs[0], 403 | dtype, 404 | [50277]) 405 | 406 | state = onnx.helper.make_tensor_value_info(outs[1], 407 | dtype, 408 | [5*layers,embed]) 409 | 410 | # Create the graph (GraphProto) 411 | graph_def = onnx.helper.make_graph( 412 | nodes=self.NodeList, # The list of nodes in the graph. 
413 | name="RWKV", 414 | # Graph input 415 | 416 | inputs=[inputtensor[0], emptyState[0]], 417 | 418 | outputs=[logits, state], # Graph output 419 | 420 | initializer=self.TensorList, # initializer 421 | 422 | 423 | 424 | # did not work, needs to be external 425 | 426 | ) 427 | 428 | modelDef = onnx.helper.make_model( 429 | graph_def, producer_name="rwkvstic", 430 | 431 | 432 | ) 433 | 434 | modelDef.opset_import[0].version = 17 435 | 436 | onnx.save(modelDef, exportname) 437 | 438 | # run model 439 | print("Model saved to: ", exportname, " and is ready to be run") 440 | print("Data type: ", dtype) 441 | print("Embedding size: ", embed) 442 | print("Number of layers: ", layers) 443 | print("external data: ", externalname) 444 | exit() 445 | self.postProcessModule = ppm 446 | -------------------------------------------------------------------------------- /convert/requirements.txt: -------------------------------------------------------------------------------- 1 | rwkv 2 | torch 3 | numpy 4 | requests 5 | tkinter -------------------------------------------------------------------------------- /convert/to_onnx.py: -------------------------------------------------------------------------------- 1 | 2 | def RnnRWKV(ops, *args): 3 | class myRWKV(ops.module): 4 | 5 | @ ops.initfunc 6 | def __init__(self, w): 7 | super(myRWKV, self).__init__() 8 | print("Legacy RWKV") 9 | 10 | self.ops = ops 11 | self.postprocess0 = ops.initTensor((w["ln_out.weight"])) 12 | self.postprocess1 = ops.initTensor((w["ln_out.bias"])) 13 | self.postprocess2 = ops.initTensor((w["head.weight"])) 14 | self.emb = ops.initTensor(w["emb.weight"]) 15 | self.emb1 = ops.initTensor(w["blocks.0.ln0.weight"]) 16 | self.emb2 = ops.initTensor(w["blocks.0.ln0.bias"]) 17 | self.ln1w = (ops.stack( 18 | [w[f"blocks.{x}.ln1.weight"] for x in range(ops.n_layers)])) 19 | self.ln1b = (ops.stack( 20 | [w[f"blocks.{x}.ln1.bias"] for x in range(ops.n_layers)])) 21 | self.ln2w = (ops.stack( 22 | [w[f"blocks.{x}.ln2.weight"] for x in range(ops.n_layers)])) 23 | self.ln2b = (ops.stack( 24 | [w[f"blocks.{x}.ln2.bias"] for x in range(ops.n_layers)])) 25 | self.time_decay = (ops.stack([ 26 | w[f"blocks.{x}.att.time_decay"].double().exp().neg() for x in range(ops.n_layers)])) 27 | self.time_first = (ops.stack([ 28 | w[f"blocks.{x}.att.time_first"] for x in range(ops.n_layers)])) 29 | self.kktk = (ops.stack( 30 | [w[f"blocks.{x}.att.time_mix_k"] for x in range(ops.n_layers)])) 31 | self.vvtv = (ops.stack( 32 | [w[f"blocks.{x}.att.time_mix_v"] for x in range(ops.n_layers)])) 33 | self.rrtr = (ops.stack( 34 | [w[f"blocks.{x}.att.time_mix_r"] for x in range(ops.n_layers)])) 35 | self.key = (ops.stack( 36 | [w[f"blocks.{x}.att.key.weight"] for x in range(ops.n_layers)])) 37 | self.value = (ops.stack( 38 | [w[f"blocks.{x}.att.value.weight"] for x in range(ops.n_layers)])) 39 | self.receptance = (ops.stack([ 40 | w[f"blocks.{x}.att.receptance.weight"] for x in range(ops.n_layers)])) 41 | self.outputvv = (ops.stack([ 42 | w[f"blocks.{x}.att.output.weight"] for x in range(ops.n_layers)])) 43 | self.time_mix_k_ffn = (ops.stack([ 44 | w[f"blocks.{x}.ffn.time_mix_k"] for x in range(ops.n_layers)])) 45 | self.time_mix_r_ffn = (ops.stack([ 46 | w[f"blocks.{x}.ffn.time_mix_r"] for x in range(ops.n_layers)])) 47 | self.key_ffn = (ops.stack( 48 | [w[f"blocks.{x}.ffn.key.weight"] for x in range(ops.n_layers)])) 49 | self.receptance_ffn = (ops.stack([ 50 | w[f"blocks.{x}.ffn.receptance.weight"] for x in range(ops.n_layers)])) 51 | self.value_ffn = (ops.stack([ 52 | 
w[f"blocks.{x}.ffn.value.weight"] for x in range(ops.n_layers)])) 53 | 54 | @ops.layerdef 55 | def doLayer(self, x, statea, stateb, statec, stated, statee, xx): 56 | 57 | xy = ops.layernorm(x, self.ln1w[xx], self.ln1b[xx]) 58 | 59 | k = ops.matvec( 60 | self.key[xx], ops.lerp(statea, xy, self.kktk[xx])) 61 | 62 | v = ops.matvec(self.value[xx], ops.lerp( 63 | statea, xy, self.vvtv[xx])) 64 | rr = ops.matvec( 65 | self.receptance[xx], ops.lerp(statea, xy, self.rrtr[xx])) 66 | r = ops.logistical((rr)) 67 | 68 | ww = ops.add(k, self.time_first[xx]) 69 | p = ops.maximum(statee, ww) 70 | 71 | e1 = ops.exp(ops.subtract(statee, p)) 72 | e2 = ops.exp(ops.subtract(ww, p)) 73 | a = ops.add(ops.multiply(e1, stateb), ops.multiply(e2, v)) 74 | b = ops.add(ops.multiply(e1, statec), e2) 75 | ww = ops.add(statee, self.time_decay[xx]) 76 | 77 | p = ops.maximum(ww, k) 78 | 79 | e1 = ops.exp(ops.subtract(ww, p)) 80 | e2 = ops.exp(ops.subtract(k, p)) 81 | outb = ops.add(ops.multiply(e1, stateb), ops.multiply(e2, v)) 82 | outc = ops.add(ops.multiply(e1, statec), e2) 83 | eee = p 84 | wkv = ops.divide(a, b) 85 | 86 | mvv = ops.add(x, ops.matvec( 87 | self.outputvv[xx], ops.multiply(r, wkv))) 88 | 89 | ddd = ops.layernorm(mvv, self.ln2w[xx], self.ln2b[xx]) 90 | 91 | km = ops.relu(ops.matvec(self.key_ffn[xx], ops.lerp( 92 | stated, ddd, self.time_mix_k_ffn[xx]))) 93 | 94 | rt = ops.logistical((ops.matvec(self.receptance_ffn[xx], ops.lerp( 95 | stated, ddd, self.time_mix_r_ffn[xx])))) 96 | 97 | x = ops.add(mvv, ops.multiply( 98 | ops.matvec(self.value_ffn[xx], ops.multiply(km, km)), rt)) 99 | 100 | return x, xy, outb, outc, ddd, eee 101 | 102 | @ ops.mainfunc 103 | def forward(self, x, state = None): 104 | 105 | if (state is None): 106 | state = ops.emptyState 107 | 108 | x = ops.layernorm( 109 | ops.getIndex(self.emb, x), 110 | self.emb1, self.emb2) 111 | 112 | # statea = state[0::5] 113 | # stateb = state[1::5] 114 | # statec = state[2::5] 115 | # stated = state[3::5] 116 | # statee = state[4::5] 117 | 118 | ot = [] 119 | 120 | for i in range(ops.n_layers): 121 | statea = ops.getStateIndex(state,str(i*5)) 122 | stateb = ops.getStateIndex(state,str(i*5+1)) 123 | statec = ops.getStateIndex(state,str(i*5+2)) 124 | stated = ops.getStateIndex(state,str(i*5+3)) 125 | statee = ops.getStateIndex(state,str(i*5+4)) 126 | x, aaa, bbb, ccc, ddd, eee = self.doLayer( 127 | x, statea, stateb,statec, stated, statee, i) 128 | ot = ot + [aaa, bbb, ccc, ddd, eee] 129 | 130 | x = ops.matvec(self.postprocess2, ops.layernorm(x, self.postprocess0, 131 | self.postprocess1),True) 132 | ot = ops.concat(ot) 133 | return x, ot 134 | 135 | 136 | ops.postProcessModule(myRWKV(*args)) 137 | 138 | 139 | import opslist 140 | 141 | import torch 142 | 143 | def convert_model(path, dtype): 144 | w = torch.load(path, map_location="cpu") 145 | dims = len(w["blocks.0.att.key.weight"]) 146 | layers = len( 147 | list(filter(lambda x: "blocks" in x and "ln1.bias" in x, w.keys()))) 148 | 149 | 150 | ops = opslist.RWKVOnnxOps(layers,dims,dtype=dtype) 151 | 152 | RnnRWKV(ops,w) 153 | 154 | 155 | import tkinter as tk 156 | from tkinter import filedialog 157 | 158 | 159 | # Create the main window 160 | root = tk.Tk() 161 | root.title("File Converter") 162 | 163 | # Define the functions 164 | def choose_input_file(): 165 | input_file = filedialog.askopenfilename() 166 | input_path.set(input_file) 167 | 168 | import numpy as np 169 | def convert(): 170 | path = input_path.get() 171 | dtype = np.float16 if use_fp16.get() else np.float32 172 | convert_model(path, 
dtype) 173 | 174 | # Define the variables 175 | input_path = tk.StringVar() 176 | use_fp16 = tk.BooleanVar(value=True) 177 | 178 | # Create the widgets 179 | input_label = tk.Label(root, text="Input Path:") 180 | input_entry = tk.Entry(root, textvariable=input_path) 181 | input_button = tk.Button(root, text="Browse...", command=choose_input_file) 182 | 183 | 184 | check_button = tk.Checkbutton(root, text="Use fp16", variable=use_fp16) 185 | 186 | convert_button = tk.Button(root, text="Convert", command=convert) 187 | 188 | # Add the widgets to the window 189 | input_label.grid(row=0, column=0) 190 | input_entry.grid(row=0, column=1) 191 | input_button.grid(row=0, column=2) 192 | 193 | check_button.grid(row=2, column=0) 194 | 195 | convert_button.grid(row=3, column=1) 196 | 197 | # Start the main event loop 198 | root.mainloop() 199 | -------------------------------------------------------------------------------- /convert/to_torchscript.py: -------------------------------------------------------------------------------- 1 | from tkinter import filedialog 2 | import torch 3 | 4 | 5 | class Container(torch.nn.Module): 6 | def __init__(self, my_values): 7 | super().__init__() 8 | for key in my_values: 9 | setattr(self, key, my_values[key]) 10 | dims = my_values["blocks.0.ln0.bias"].shape[0] 11 | layers = len(list(filter(lambda k: k.startswith( 12 | "blocks.") and k.endswith(".ln1.bias"), my_values.keys()))) 13 | 14 | print("dims", dims) 15 | print("layers", layers) 16 | 17 | emptyState = torch.zeros(layers, 5, dims) 18 | for i in range(layers): 19 | emptyState[i][4] -= 1e30 20 | setattr(self, "emptyState", emptyState) 21 | 22 | 23 | # open file selector, only show .pth files 24 | path = filedialog.askopenfilename( 25 | initialdir="./", title="Select file", filetypes=(("pth files", "*.pth"), ("all files", "*.*"))) 26 | 27 | my_values = torch.load(path, map_location="cpu") 28 | 29 | # Save arbitrary values supported by TorchScript 30 | # https://pytorch.org/docs/master/jit.html#supported-type 31 | container = torch.jit.script(Container(my_values)) 32 | output_path = filedialog.asksaveasfilename( 33 | initialdir="./", title="Select file", filetypes=(("pt files", "*.pt"), ("all files", "*.*"))) 34 | container.save(output_path) -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "src/rwkv_server.h" 5 | 6 | void tokenizer_test() 7 | { 8 | std::string path = "./assets/20B_tokenizer.json"; 9 | RWKVTokenizer t = RWKVTokenizer("./assets/20B_tokenizer.json"); 10 | std::string s = "\n"; 11 | std::vector input_tokens = t.encodeTokens(s); 12 | std::string output_str = t.decodeTokens(input_tokens); 13 | std::cout << "done: " << output_str << std::endl; 14 | } 15 | void inference_speed_test(std::string model_path) 16 | { 17 | torch::Device device = torch::kCPU; 18 | std::cout << "CUDA DEVICE COUNT: " << torch::cuda::device_count() << std::endl; 19 | if (torch::cuda::is_available()) { 20 | std::cout << "CUDA is available! Inference on GPU." 
<< std::endl; 21 | device = torch::kCUDA; 22 | } 23 | // Convert to torch dtype 24 | auto torch_dtype = torch::kFloat32; 25 | auto torch_runtimedtype = torch::kFloat32; 26 | torch::NoGradGuard no_grad; 27 | RWKVTorch rwkv(model_path, torch_dtype, torch_runtimedtype, device); 28 | // warm up 29 | torch::Tensor x; 30 | torch::Tensor state = rwkv.empty_state_; 31 | for (int i = 0; i < 10; i++) 32 | { 33 | std::tie(x, state) = rwkv.forward(torch::ones(1).to(torch::kInt32).to(device), rwkv.empty_state_.clone()); 34 | } 35 | std::cout << "finish warmup" << std::endl; 36 | auto time = std::chrono::high_resolution_clock::now(); 37 | for (int i = 0; i < 10; ++i) 38 | { 39 | std::tie(x, state) = rwkv.forward(torch::ones(10).to(torch::kInt32).to(device) * 178, rwkv.empty_state_.clone()); 40 | } 41 | auto time2 = std::chrono::high_resolution_clock::now(); 42 | std::cout << "sequence inference Time: " << std::chrono::duration_cast<std::chrono::milliseconds>(time2 - time).count() / 10 << "ms / 10 tokens" << std::endl; 43 | time = std::chrono::high_resolution_clock::now(); 44 | for (int i = 0; i < 10; ++i) 45 | { 46 | std::tie(x, state) = rwkv.forward(torch::ones(100).to(torch::kInt32).to(device) * 178, rwkv.empty_state_.clone()); 47 | } 48 | time2 = std::chrono::high_resolution_clock::now(); 49 | std::cout << "sequence inference Time: " << std::chrono::duration_cast<std::chrono::milliseconds>(time2 - time).count() / 10 << "ms / 100 tokens" << std::endl; 50 | time = std::chrono::high_resolution_clock::now(); 51 | for (int i = 0; i < 10; ++i) 52 | { 53 | std::tie(x, state) = rwkv.forward(torch::ones(1000).to(torch::kInt32).to(device) * 178, rwkv.empty_state_.clone()); 54 | } 55 | time2 = std::chrono::high_resolution_clock::now(); 56 | std::cout << "sequence inference Time: " << std::chrono::duration_cast<std::chrono::milliseconds>(time2 - time).count() / 10 << "ms / 1000 tokens" << std::endl; 57 | time = std::chrono::high_resolution_clock::now(); 58 | for (int i = 0; i < 100; ++i) 59 | { 60 | std::tie(x, state) = rwkv.forward(torch::ones(1).to(torch::kInt32).to(device) * 178, rwkv.empty_state_.clone()); 61 | } 62 | time2 = std::chrono::high_resolution_clock::now(); 63 | std::cout << "single token inference Time: " << std::chrono::duration_cast<std::chrono::milliseconds>(time2 - time).count() / 100 << "ms / token" << std::endl; 64 | } 65 | 66 | void onnx_test() 67 | { 68 | std::string model_path = "./assets/models/rwkv_model.onnx"; 69 | 70 | RWKVONNX onnx_model = RWKVONNX(model_path); 71 | torch::Tensor out; 72 | torch::Tensor state; 73 | for (int i = 0; i < 10; ++i) 74 | { 75 | onnx_model.forward(torch::ones(1).to(torch::kInt32) * 178, onnx_model.empty_state_.clone()); 76 | } 77 | auto time = std::chrono::high_resolution_clock::now(); 78 | for (int i = 0; i < 10; ++i) 79 | { 80 | std::tie(out, state) = onnx_model.forward(torch::ones(100).to(torch::kInt32), onnx_model.empty_state_.clone()); 81 | } 82 | auto time2 = std::chrono::high_resolution_clock::now(); 83 | std::cout << "Time: " << std::chrono::duration_cast<std::chrono::milliseconds>(time2 - time).count() / 10 << "ms / 100 tokens" << std::endl; 84 | } 85 | 86 | std::string find_model_path(char *argv[]) 87 | { 88 | std::string exe_path = argv[0]; 89 | std::size_t pos = exe_path.find_last_of("\\/"); 90 | exe_path = exe_path.substr(0, pos); 91 | std::filesystem::path models_assets_path(exe_path + "\\/assets\\/models"); 92 | for (auto &entry : std::filesystem::directory_iterator(models_assets_path)) 93 | { 94 | std::string model_path = entry.path().string(); 95 | std::size_t pos = model_path.find_last_of("."); 96 | if (pos >= model_path.size()) 97 | { 98 | continue; 99 | } 100 | if 
(model_path.substr(pos) == ".pt" || model_path.substr(pos) == ".onnx") 101 | { 102 | return model_path; 103 | } 104 | } 105 | return ""; 106 | } 107 | void start_server(std::string &model_path, std::string &ip, int port) 108 | { 109 | std::string tokenizer_path = "./assets/20B_tokenizer.json"; 110 | std::string model_type; 111 | std::size_t pos = model_path.find_last_of("."); 112 | if (pos < model_path.size() && model_path.substr(pos) == ".pt") 113 | { 114 | model_type = "libtorch"; 115 | } 116 | else if (pos < model_path.size() && model_path.substr(pos) == ".onnx") 117 | { 118 | model_type = "onnx"; 119 | } 120 | else 121 | { 122 | std::cout << "invalid model_path, support xxx.pt or xxx.onnx" << std::endl; 123 | return; 124 | } 125 | std::cout << "using model:" << model_path << std::endl; 126 | RWKVPipeline pipeline = RWKVPipeline(model_path, tokenizer_path, model_type); 127 | RWKVServer server = RWKVServer(pipeline); 128 | server.start(ip, port); 129 | } 130 | 131 | int main(int argc, char *argv[]) 132 | { 133 | std::string model_path; 134 | std::string port; 135 | std::string ip; 136 | 137 | try 138 | { 139 | if (argc <= 1) 140 | { 141 | throw ""; 142 | } 143 | model_path = argv[1]; 144 | } 145 | catch (...) 146 | { 147 | std::cout << "No model_path specified, finding in assets/models" << std::endl; 148 | model_path = find_model_path(argv); 149 | } 150 | 151 | try 152 | { 153 | if (argc <= 2) 154 | { 155 | throw ""; 156 | } 157 | ip = std::string(argv[2]); 158 | } 159 | catch (...) 160 | { 161 | std::cout << "No ip specified, default localhost" << std::endl; 162 | ip = "0.0.0.0"; 163 | } 164 | 165 | try 166 | { 167 | if (argc <= 3) 168 | { 169 | throw ""; 170 | } 171 | port = std::string(argv[3]); 172 | } 173 | catch (...) 174 | { 175 | std::cout << "No port specified, default 5000" << std::endl; 176 | port = "5000"; 177 | } 178 | torch::NoGradGuard no_grad; 179 | try 180 | { 181 | /* code */ 182 | start_server(model_path, ip, std::stoi(port)); 183 | } 184 | catch(const std::exception& e) 185 | { 186 | std::cerr << e.what() << '\n'; 187 | } 188 | // inference_speed_test(model_path); 189 | // tokenizer_test(); 190 | // onnx_test(); 191 | } -------------------------------------------------------------------------------- /src/model/block.cpp: -------------------------------------------------------------------------------- 1 | #include "block.h" 2 | 3 | Block::Block(int dims) 4 | { 5 | 6 | ln1 = torch::nn::LayerNorm(torch::nn::LayerNormOptions({dims})); 7 | ln2 = torch::nn::LayerNorm(torch::nn::LayerNormOptions({dims})); 8 | att_key = torch::nn::Linear(dims, dims); 9 | att_value = torch::nn::Linear(dims, dims); 10 | att_receptance = torch::nn::Linear(dims, dims); 11 | att_out = torch::nn::Linear(dims, dims); 12 | ffn_key = torch::nn::Linear(dims, dims * 4); 13 | ffn_value = torch::nn::Linear(dims * 4, dims); 14 | ffn_receptance = torch::nn::Linear(dims, dims); 15 | time_first = torch::zeros({dims}); 16 | time_decay = torch::zeros({dims}); 17 | att_time_mix_k = torch::zeros({dims}); 18 | att_time_mix_v = torch::zeros({dims}); 19 | att_time_mix_r = torch::zeros({dims}); 20 | ffn_time_mix_k = torch::zeros({dims}); 21 | ffn_time_mix_r = torch::zeros({dims}); 22 | } 23 | 24 | Block::Block(int i, torch::jit::script::Module w, c10::ScalarType dtype, c10::ScalarType runtime_dtype, torch::Device device) 25 | { 26 | int dims = w.attr("blocks." 
+ std::to_string(i) + ".att.key.weight").toTensor().size(0); 27 | ln1 = torch::nn::LayerNorm(torch::nn::LayerNormOptions({dims})); 28 | ln2 = torch::nn::LayerNorm(torch::nn::LayerNormOptions({dims})); 29 | att_key = torch::nn::Linear(dims, dims); 30 | att_value = torch::nn::Linear(dims, dims); 31 | att_receptance = torch::nn::Linear(dims, dims); 32 | att_out = torch::nn::Linear(dims, dims); 33 | ffn_key = torch::nn::Linear(dims, dims * 4); 34 | ffn_value = torch::nn::Linear(dims * 4, dims); 35 | ffn_receptance = torch::nn::Linear(dims, dims); 36 | time_first = w.attr("blocks." + std::to_string(i) + ".att.time_first").toTensor().squeeze().to(runtime_dtype).to(device); 37 | time_decay = w.attr("blocks." + std::to_string(i) + ".att.time_decay").toTensor().squeeze().exp().neg().to(runtime_dtype).to(device); 38 | 39 | att_time_mix_k = w.attr("blocks." + std::to_string(i) + ".att.time_mix_k").toTensor().squeeze().to(runtime_dtype).to(device); 40 | att_time_mix_v = w.attr("blocks." + std::to_string(i) + ".att.time_mix_v").toTensor().squeeze().to(runtime_dtype).to(device); 41 | att_time_mix_r = w.attr("blocks." + std::to_string(i) + ".att.time_mix_r").toTensor().squeeze().to(runtime_dtype).to(device); 42 | ffn_time_mix_k = w.attr("blocks." + std::to_string(i) + ".ffn.time_mix_k").toTensor().squeeze().to(runtime_dtype).to(device); 43 | ffn_time_mix_r = w.attr("blocks." + std::to_string(i) + ".ffn.time_mix_r").toTensor().squeeze().to(runtime_dtype).to(device); 44 | 45 | ln1->weight = w.attr("blocks." + std::to_string(i) + ".ln1.weight").toTensor().squeeze().to(runtime_dtype).to(device); 46 | ln1->bias = w.attr("blocks." + std::to_string(i) + ".ln1.bias").toTensor().squeeze().to(runtime_dtype).to(device); 47 | ln2->weight = w.attr("blocks." + std::to_string(i) + ".ln2.weight").toTensor().squeeze().to(runtime_dtype).to(device); 48 | ln2->bias = w.attr("blocks." + std::to_string(i) + ".ln2.bias").toTensor().squeeze().to(runtime_dtype).to(device); 49 | att_key->weight = w.attr("blocks." + std::to_string(i) + ".att.key.weight").toTensor().squeeze().to(dtype).to(device).t(); 50 | att_value->weight = w.attr("blocks." + std::to_string(i) + ".att.value.weight").toTensor().squeeze().to(dtype).to(device).t(); 51 | att_receptance->weight = w.attr("blocks." + std::to_string(i) + ".att.receptance.weight").toTensor().squeeze().to(dtype).to(device).t(); 52 | att_out->weight = w.attr("blocks." + std::to_string(i) + ".att.output.weight").toTensor().squeeze().to(dtype).to(device).t(); 53 | ffn_key->weight = w.attr("blocks." + std::to_string(i) + ".ffn.key.weight").toTensor().squeeze().to(dtype).to(device).t(); 54 | ffn_value->weight = w.attr("blocks." + std::to_string(i) + ".ffn.value.weight").toTensor().squeeze().to(dtype).to(device).t(); 55 | ffn_receptance->weight = w.attr("blocks." 
+ std::to_string(i) + ".ffn.receptance.weight").toTensor().squeeze().to(dtype).to(device).t(); 56 | 57 | 58 | dtype_ = dtype; 59 | runtime_dtype_ = runtime_dtype; 60 | 61 | block_idx_ = i; 62 | } 63 | 64 | torch::Tensor Block::FF_seq(torch::Tensor x, torch::Tensor state, torch::Tensor time_mix_k, torch::Tensor time_mix_r, torch::Tensor kw, torch::Tensor vw, torch::Tensor rw) 65 | { 66 | auto xx = torch::cat({state[0].to(dtype_).unsqueeze(0), x.slice(0, 0, -1)}); 67 | auto xk = x * time_mix_k + xx * (1 - time_mix_k); 68 | auto xr = x * time_mix_r + xx * (1 - time_mix_r); 69 | state[0] = x[-1].to(dtype_); 70 | 71 | auto r = torch::sigmoid(xr.matmul(rw)); 72 | auto k = torch::square(torch::relu(xk.matmul(kw))); 73 | auto kv = k.matmul(vw); 74 | return r * kv; 75 | } 76 | 77 | torch::Tensor Block::FF_one(torch::Tensor x, torch::Tensor state, torch::Tensor time_mix_k, torch::Tensor time_mix_r, torch::Tensor kw, torch::Tensor vw, torch::Tensor rw) 78 | { 79 | auto xx = state[0].to(dtype_); 80 | auto xk = x * time_mix_k + xx * (1 - time_mix_k); 81 | auto xr = x * time_mix_r + xx * (1 - time_mix_r); 82 | state[0] = x.to(dtype_); 83 | 84 | auto r = torch::sigmoid(xr.matmul(rw)); 85 | auto k = torch::square(torch::relu(xk.matmul(kw))); 86 | auto kv = k.matmul(vw); 87 | return r * kv; 88 | } 89 | 90 | torch::Tensor Block::SA_one(torch::Tensor x, torch::Tensor state, torch::Tensor time_mix_k, torch::Tensor time_mix_v, torch::Tensor time_mix_r, torch::Tensor time_first, torch::Tensor time_decay, torch::Tensor kw, torch::Tensor vw, torch::Tensor rw, torch::Tensor ow) 91 | { 92 | auto xx = state[1].to(dtype_); 93 | auto xk = x * time_mix_k + xx * (1 - time_mix_k); 94 | auto xv = x * time_mix_v + xx * (1 - time_mix_v); 95 | auto xr = x * time_mix_r + xx * (1 - time_mix_r); 96 | state[1] = x.to(dtype_); 97 | 98 | auto r = torch::sigmoid(xr.matmul(rw)); 99 | auto k = (xk.matmul(kw)).to(dtype_); 100 | auto v = (xv.matmul(vw)).to(dtype_); 101 | 102 | auto aa = state[2]; 103 | auto bb = state[3]; 104 | auto pp = state[4]; 105 | auto ww = time_first + k; 106 | auto p = torch::max(pp, ww); 107 | auto e1 = torch::exp(pp - p); 108 | auto e2 = torch::exp(ww - p); 109 | auto a = e1 * aa + e2 * v; 110 | auto b = e1 * bb + e2; 111 | ww = pp + time_decay; 112 | p = torch::max(ww, k); 113 | e1 = torch::exp(ww - p); 114 | e2 = torch::exp(k - p); 115 | state[2] = e1 * aa + e2 * v; 116 | state[3] = e1 * bb + e2; 117 | state[4] = p; 118 | auto wkv = (a / b).to(dtype_); 119 | return (r * wkv).matmul(ow); 120 | } 121 | 122 | torch::Tensor Block::SA_seq(torch::Tensor x, torch::Tensor state, torch::Tensor time_mix_k, torch::Tensor time_mix_v, torch::Tensor time_mix_r, torch::Tensor time_first, torch::Tensor time_decay, torch::Tensor kw, torch::Tensor vw, torch::Tensor rw, torch::Tensor ow) 123 | { 124 | auto xx = torch::cat({state[1].to(dtype_).unsqueeze(0), x.slice(0, 0, -1)}); 125 | auto xk = x * time_mix_k + xx * (1 - time_mix_k); 126 | auto xv = x * time_mix_v + xx * (1 - time_mix_v); 127 | auto xr = x * time_mix_r + xx * (1 - time_mix_r); 128 | state[1] = x[-1].to(dtype_); 129 | 130 | auto r = torch::sigmoid(xr.matmul(rw)); 131 | auto k = (xk.matmul(kw)).to(dtype_); 132 | auto v = (xv.matmul(vw)).to(dtype_); 133 | 134 | auto aa = state[2]; 135 | auto bb = state[3]; 136 | auto pp = state[4]; 137 | int T = x.size(0); 138 | for (int t = 0; t < T; t++) 139 | { 140 | auto ww = time_first + k[t]; 141 | auto p = torch::max(pp, ww); 142 | auto e1 = torch::exp(pp - p); 143 | auto e2 = torch::exp(ww - p); 144 | auto a = e1 * aa + e2 
* v[t]; 145 | auto b = e1 * bb + e2; 146 | ww = pp + time_decay; 147 | p = torch::max(ww, k[t]); 148 | e1 = torch::exp(ww - p); 149 | e2 = torch::exp(k[t] - p); 150 | if (t != T - 1) 151 | { 152 | aa = e1 * aa + e2 * v[t]; 153 | bb = e1 * bb + e2; 154 | pp = p; 155 | } 156 | else 157 | { 158 | state[2] = e1 * aa + e2 * v[t]; 159 | state[3] = e1 * bb + e2; 160 | state[4] = p; 161 | } 162 | xx[t] = (a / b).to(dtype_); 163 | } 164 | return (r * xx).matmul(ow); 165 | } 166 | 167 | std::tuple Block::forward_seq(torch::Tensor x, torch::Tensor state) 168 | { 169 | x = x + SA_seq(ln1(x), state, att_time_mix_k, att_time_mix_v, att_time_mix_r, time_first, time_decay, att_key->weight, att_value->weight, att_receptance->weight, att_out->weight); 170 | x = x + FF_seq(ln2(x), state, ffn_time_mix_k, ffn_time_mix_r, ffn_key->weight, ffn_value->weight, ffn_receptance->weight); 171 | 172 | return std::make_tuple(x, state); 173 | } 174 | 175 | std::tuple Block::forward_one(torch::Tensor x, torch::Tensor state) 176 | { 177 | x = x + SA_one(ln1(x), state, att_time_mix_k, att_time_mix_v, att_time_mix_r, time_first, time_decay, att_key->weight, att_value->weight, att_receptance->weight, att_out->weight); 178 | x = x + FF_one(ln2(x), state, ffn_time_mix_k, ffn_time_mix_r, ffn_key->weight, ffn_value->weight, ffn_receptance->weight); 179 | 180 | return std::make_tuple(x, state); 181 | } 182 | -------------------------------------------------------------------------------- /src/model/block.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | // RWKV libtorch Block implementation, reference to pytorch implementation:https://github.com/BlinkDL/ChatRWKV/blob/main/RWKV_in_150_lines.py 9 | class Block : public torch::nn::Module 10 | { 11 | public: 12 | Block(int dims); 13 | 14 | Block(int i, torch::jit::script::Module w, c10::ScalarType dtype, c10::ScalarType runtime_dtype, torch::Device device); 15 | 16 | torch::Tensor FF_seq(torch::Tensor x, torch::Tensor state, torch::Tensor time_mix_k, torch::Tensor time_mix_r, torch::Tensor kw, torch::Tensor vw, torch::Tensor rw); 17 | 18 | torch::Tensor FF_one(torch::Tensor x, torch::Tensor state, torch::Tensor time_mix_k, torch::Tensor time_mix_r, torch::Tensor kw, torch::Tensor vw, torch::Tensor rw); 19 | 20 | torch::Tensor SA_seq(torch::Tensor x, torch::Tensor state, torch::Tensor time_mix_k, torch::Tensor time_mix_v, torch::Tensor time_mix_r, torch::Tensor time_first, torch::Tensor time_decay, torch::Tensor kw, torch::Tensor vw, torch::Tensor rw, torch::Tensor ow); 21 | 22 | torch::Tensor SA_one(torch::Tensor x, torch::Tensor state, torch::Tensor time_mix_k, torch::Tensor time_mix_v, torch::Tensor time_mix_r, torch::Tensor time_first, torch::Tensor time_decay, torch::Tensor kw, torch::Tensor vw, torch::Tensor rw, torch::Tensor ow); 23 | 24 | std::tuple forward_seq(torch::Tensor x, torch::Tensor state); 25 | 26 | std::tuple forward_one(torch::Tensor x, torch::Tensor state); 27 | 28 | private: 29 | torch::nn::LayerNorm ln1 = nullptr; 30 | torch::nn::LayerNorm ln2 = nullptr; 31 | torch::nn::Linear att_key = nullptr; 32 | torch::nn::Linear att_value = nullptr; 33 | torch::nn::Linear att_receptance = nullptr; 34 | torch::nn::Linear att_out = nullptr; 35 | torch::nn::Linear ffn_key = nullptr; 36 | torch::nn::Linear ffn_value = nullptr; 37 | torch::nn::Linear ffn_receptance = nullptr; 38 | torch::Tensor time_first; 39 | torch::Tensor time_decay; 40 | torch::Tensor 
att_time_mix_k; 41 | torch::Tensor att_time_mix_v; 42 | torch::Tensor att_time_mix_r; 43 | torch::Tensor ffn_time_mix_k; 44 | torch::Tensor ffn_time_mix_r; 45 | c10::ScalarType dtype_ = torch::kFloat32; 46 | c10::ScalarType runtime_dtype_ = torch::kFloat64; 47 | 48 | int block_idx_; 49 | }; -------------------------------------------------------------------------------- /src/model/rwkv_interface.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // common interface of RWKV model, Base class of onnx/libtorch model 6 | class RWKVInterface{ 7 | public: 8 | virtual std::tuple forward(torch::Tensor x, torch::Tensor state) = 0; 9 | torch::Tensor empty_state_; 10 | }; -------------------------------------------------------------------------------- /src/model/rwkv_onnx.cpp: -------------------------------------------------------------------------------- 1 | #include "rwkv_onnx.h" 2 | 3 | RWKVONNX::RWKVONNX(std::string path) : env(ORT_LOGGING_LEVEL_ERROR, "rwkv_onnx"), session_{env, std::wstring(path.begin(), path.end()).c_str(), Ort::SessionOptions{nullptr}} 4 | { 5 | 6 | Ort::TypeInfo state_type_info = session_.GetInputTypeInfo(1); 7 | state_shape_ = state_type_info.GetTensorTypeAndShapeInfo().GetShape(); 8 | 9 | Ort::TypeInfo output_tokens_type_info = session_.GetOutputTypeInfo(0); 10 | output_tokens_shape_ = output_tokens_type_info.GetTensorTypeAndShapeInfo().GetShape(); 11 | 12 | empty_state_ = torch::zeros({state_shape_[0], state_shape_[1]}); 13 | for (int i = 0; i < state_shape_[0] / 5; ++i) 14 | { 15 | empty_state_[i * 5 + 4] -= 1e30; 16 | } 17 | std::cout << "init onnx model success" << std::endl; 18 | } 19 | std::tuple RWKVONNX::forward(torch::Tensor x, torch::Tensor state) 20 | { 21 | x = x.to(torch::kInt32); 22 | state = state.to(torch::kFloat32); 23 | if (x.size(0) == 1) 24 | { 25 | return this->forward_single_token(x, state); 26 | } 27 | else 28 | { 29 | torch::Tensor input_x = torch::empty({1}).to(torch::kInt32); 30 | torch::Tensor res_x; 31 | for (int i = 0; i < x.size(0); ++i) 32 | { 33 | input_x[0] = x[i]; 34 | std::tie(res_x, state) = this->forward_single_token(input_x, state); 35 | } 36 | return std::make_tuple(res_x, state); 37 | } 38 | } 39 | 40 | std::tuple RWKVONNX::forward_single_token(torch::Tensor x, torch::Tensor state) 41 | { 42 | Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu( 43 | OrtAllocatorType::OrtDeviceAllocator, OrtMemType::OrtMemTypeCPU); 44 | Ort::Value input_tokens_tensor = Ort::Value::CreateTensor(memory_info, x.data(), input_tokens_shape_[0], 45 | input_tokens_shape_.data(), input_tokens_shape_.size()); 46 | // onnx input_tensor 47 | assert(input_tokens_tensor.IsTensor()); 48 | std::vector input_tensor; 49 | input_tensor.push_back(std::move(input_tokens_tensor)); 50 | 51 | size_t state_size = state_shape_[0] * state_shape_[1]; 52 | Ort::Value input_state_tensor = Ort::Value::CreateTensor(memory_info, state.data(), state_size, 53 | state_shape_.data(), state_shape_.size()); 54 | assert(input_state_tensor.IsTensor()); 55 | input_tensor.push_back(std::move(input_state_tensor)); 56 | // onnx output_tensor 57 | 58 | auto ort_output = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input_tensor.data(), input_tensor.size(), output_names_.data(), output_names_.size()); 59 | // std::cout << "inference done" << std::endl; 60 | 61 | // std::cout<<"onnx inference done "<(), {output_tokens_shape_[0]}); 63 | torch::Tensor output_state = 
torch::from_blob(ort_output[1].GetTensorMutableData(), {state_shape_[0], state_shape_[1]}); 64 | return std::make_tuple(output_x, output_state); 65 | } -------------------------------------------------------------------------------- /src/model/rwkv_onnx.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | #include "rwkv_interface.h" 10 | 11 | //rwkv onnx model 12 | class RWKVONNX : public RWKVInterface 13 | { 14 | public: 15 | RWKVONNX(std::string path); 16 | 17 | std::tuple forward(torch::Tensor x, torch::Tensor state); 18 | 19 | std::tuple forward_single_token(torch::Tensor x, torch::Tensor state); 20 | 21 | private: 22 | Ort::Env env; 23 | Ort::Session session_; 24 | std::vector input_tokens_shape_ = {1}; 25 | std::vector output_tokens_shape_; 26 | std::vector state_shape_; 27 | std::vector input_names_ = {"input_token", "input_state"}; 28 | std::vector output_names_ = {"output_token", "output_state"}; 29 | }; 30 | -------------------------------------------------------------------------------- /src/model/rwkv_torch.cpp: -------------------------------------------------------------------------------- 1 | #include "rwkv_torch.h" 2 | 3 | RWKVTorch::RWKVTorch(int dims, int layers, int headsize) 4 | { 5 | head = torch::nn::Linear(dims, headsize); 6 | emb = torch::nn::Embedding(headsize, dims); 7 | ln_out = torch::nn::LayerNorm(torch::nn::LayerNormOptions({dims})); 8 | ln_in = torch::nn::LayerNorm(torch::nn::LayerNormOptions({dims})); 9 | for (int i = 0; i < layers; i++) 10 | { 11 | blocks.push_back(Block(dims)); 12 | } 13 | this->eval(); 14 | } 15 | 16 | RWKVTorch::RWKVTorch(std::string path, c10::ScalarType dtype, c10::ScalarType runtime_dtype, torch::Device device) 17 | { 18 | torch::jit::script::Module w = torch::jit::load(path); 19 | head = torch::nn::Linear(w.attr("head.weight").toTensor().sizes()[1], w.attr("head.weight").toTensor().sizes()[0]); 20 | head->weight = w.attr("head.weight").toTensor().to(dtype).to(device); 21 | head->bias = head->bias.zero_().to(device); 22 | ln_in = torch::nn::LayerNorm(torch::nn::LayerNormOptions({w.attr("blocks.0.ln0.bias").toTensor().sizes()[0]})); 23 | ln_in->bias = w.attr("blocks.0.ln0.bias").toTensor().to(runtime_dtype).to(device); 24 | ln_in->weight = w.attr("blocks.0.ln0.weight").toTensor().to(runtime_dtype).to(device); 25 | ln_out = torch::nn::LayerNorm(torch::nn::LayerNormOptions({w.attr("ln_out.weight").toTensor().sizes()[0]})); 26 | ln_out->weight = w.attr("ln_out.weight").toTensor().to(runtime_dtype).to(device); 27 | ln_out->bias = w.attr("ln_out.bias").toTensor().to(runtime_dtype).to(device); 28 | emb = torch::nn::Embedding(w.attr("emb.weight").toTensor().sizes()[0], w.attr("emb.weight").toTensor().sizes()[1]); 29 | emb->weight = w.attr("emb.weight").toTensor().to(runtime_dtype).to(device); 30 | 31 | for (int i = 0; i < 100; i++) 32 | { 33 | if (w.hasattr("blocks." + std::to_string(i) + ".ln1.bias")) 34 | { 35 | blocks.push_back(Block(i, w, dtype, runtime_dtype,device)); 36 | } 37 | else 38 | { 39 | break; 40 | } 41 | } 42 | empty_state_ = w.attr("emptyState").toTensor().to(runtime_dtype).to(device); 43 | dtype_ = dtype; 44 | this->eval(); 45 | } 46 | 47 | std::tuple RWKVTorch::forward(torch::Tensor x, torch::Tensor state) 48 | { 49 | bool seq_mode = x.size(0) > 1; 50 | x = seq_mode ? 
emb(x) : emb(x[0]); 51 | x = ln_in(x); 52 | for (int i = 0; i < blocks.size(); i++) 53 | { 54 | torch::Tensor rstate; 55 | std::tie(x, rstate) = seq_mode ? blocks[i].forward_seq(x, state[i]) : blocks[i].forward_one(x, state[i]); 56 | state[i] = rstate; 57 | } 58 | x = seq_mode ? ln_out(x[-1]).to(dtype_) : ln_out(x).to(dtype_); 59 | torch::Tensor outx = head(x); 60 | return std::make_tuple(outx, state); 61 | } -------------------------------------------------------------------------------- /src/model/rwkv_torch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "rwkv_interface.h" 4 | #include "block.h" 5 | 6 | // rwkv libtorch model 7 | class RWKVTorch : public torch::nn::Module,public RWKVInterface 8 | { 9 | public: 10 | RWKVTorch(int dims, int layers, int headsize); 11 | 12 | RWKVTorch(std::string path, c10::ScalarType dtype, c10::ScalarType runtime_dtype, torch::Device device); 13 | 14 | std::tuple forward(torch::Tensor x, torch::Tensor state) override; 15 | 16 | private: 17 | torch::nn::Linear head = nullptr; 18 | torch::nn::Embedding emb = nullptr; 19 | torch::nn::LayerNorm ln_out = nullptr; 20 | torch::nn::LayerNorm ln_in = nullptr; 21 | 22 | std::vector blocks = {}; 23 | c10::ScalarType dtype_; 24 | }; -------------------------------------------------------------------------------- /src/pipeline.cpp: -------------------------------------------------------------------------------- 1 | #include "pipeline.h" 2 | 3 | RWKVPipeline::RWKVPipeline(std::string &model_path, std::string &tokenizer_path, std::string model_type) : tokenizer_(tokenizer_path) 4 | { 5 | if(model_type=="libtorch"){ 6 | std::cout << "CUDA DEVICE COUNT: " << torch::cuda::device_count() << std::endl; 7 | if (torch::cuda::is_available()) { 8 | std::cout << "CUDA is available! Inference on GPU." 
<< std::endl; 9 | device_ = torch::kCUDA; 10 | } 11 | model_ptr_ = std::make_shared(model_path, torch::kFloat32, torch::kFloat32, device_); 12 | } 13 | else{ 14 | model_ptr_ = std::make_shared(model_path); 15 | } 16 | prev_state_ = model_ptr_->empty_state_; 17 | } 18 | 19 | uint32_t RWKVPipeline::sample_logits(torch::Tensor logits, float temperature=1.0, float top_p=0.85, int top_k=0) 20 | { 21 | auto probs = torch::softmax(logits, -1); 22 | auto sorted_ids = torch::argsort(probs); 23 | auto sorted_probs = probs.index({sorted_ids}); 24 | sorted_probs = torch::flip(sorted_probs, {0}); 25 | auto cumulative_probs = torch::cumsum(sorted_probs, -1); 26 | // cumulative_probs = cumulative_probs.masked_fill(cumulative_probs < top_p, 0.0); 27 | auto cutoff = sorted_probs[torch::argmax((cumulative_probs > top_p).to(torch::kInt32))].to(torch::kFloat32); 28 | probs = probs.masked_fill(probs < cutoff, 0.0); 29 | if (top_k < probs.size(0) && top_k > 0) { 30 | probs.index({sorted_ids.slice(0, -top_k)}) = 0; 31 | } 32 | if (temperature != 1.0) { 33 | probs = torch::pow(probs, 1.0 / temperature); 34 | } 35 | auto out = torch::multinomial(probs, 1)[0]; 36 | return (uint32_t)out.item(); 37 | } 38 | 39 | std::string RWKVPipeline::generate(std::string &context, float temperature, float top_p, int token_count, float alpha_presence, float alpha_frequency, std::vector &stop_tokens, std::vector &ban_tokens) 40 | { 41 | std::vector input_tokens = tokenizer_.encodeTokens(context); 42 | torch::TensorOptions opts = torch::TensorOptions().dtype(torch::kInt32); 43 | torch::Tensor state = prev_state_.clone(); 44 | torch::Tensor last_out; 45 | std::map occurrence; 46 | torch::Tensor t = torch::from_blob(input_tokens.data(), {(int)input_tokens.size()}, opts).to(device_); 47 | std::tie(last_out, state) = model_ptr_->forward(t, state); 48 | std::vector output_tokens; 49 | for (int i = 0; i < token_count; ++i) 50 | { 51 | for (uint32_t ban_tok : ban_tokens) 52 | { 53 | last_out[ban_tok] -=1e30; 54 | } 55 | for (auto &par : occurrence) 56 | { 57 | uint32_t occurent_tok = par.first; 58 | uint32_t occurent_cnt = par.second; 59 | last_out[occurent_tok] -= (alpha_presence + occurent_cnt * alpha_frequency); 60 | } 61 | 62 | uint32_t tok = sample_logits(last_out, temperature, top_p, 100); 63 | if (tok == 0 || std::count(stop_tokens.begin(), stop_tokens.end(), tok)) 64 | { 65 | std::cout<<"break"<forward(t, state); 72 | } 73 | prev_state_ = state; 74 | std::string output_str = tokenizer_.decodeTokens(output_tokens); 75 | 76 | return output_str; 77 | } -------------------------------------------------------------------------------- /src/pipeline.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "model/rwkv_interface.h" 4 | #include "model/rwkv_onnx.h" 5 | #include "model/rwkv_torch.h" 6 | #include "rwkv_tokenizer.h" 7 | 8 | // rwkv pipeline, reference to :https://github.com/BlinkDL/ChatRWKV/blob/db57d70bd151fbbd3fd1fb7e67e18a052cfaab6e/rwkv_pip_package/src/rwkv/utils.py 9 | class RWKVPipeline 10 | { 11 | public: 12 | RWKVPipeline(std::string &model_path, std::string &tokenizer_path, std::string model_type); 13 | 14 | uint32_t sample_logits(torch::Tensor logits, float temperature, float top_p, int top_k); 15 | 16 | std::string generate(std::string &context, float temperature, float top_p, int token_count, float alpha_presence, float alpha_frequency, std::vector &stop_tokens, std::vector &ban_tokens); 17 | 18 | private: 19 | std::shared_ptr model_ptr_; 20 | 
RWKVTokenizer tokenizer_; 21 | torch::Tensor prev_state_; 22 | torch::Device device_ = torch::kCPU; 23 | }; 24 | -------------------------------------------------------------------------------- /src/rwkv_server.cpp: -------------------------------------------------------------------------------- 1 | #include "rwkv_server.h" 2 | 3 | RWKVServer::RWKVServer(RWKVPipeline &pipeline):pipeline_(pipeline) 4 | { 5 | svr_.Get("/api/chat", 6 | [this](const httplib::Request &req, httplib::Response &res){ 7 | std::string context = req.get_param_value("text"); 8 | float temperature = std::stof(req.get_param_value("temperature")); 9 | float top_p = std::stof(req.get_param_value("topP")); 10 | int token_count = std::stoi(req.get_param_value("tokenCount")); 11 | float fresence_penalty = std::stof(req.get_param_value("presencePenalty")); 12 | float count_penalty = std::stof(req.get_param_value("countPenalty")); 13 | // std::cout<<"request:"<pipeline_.generate(context, temperature, top_p, token_count, fresence_penalty, count_penalty, std::vector{}, std::vector{0}); 18 | } 19 | catch(const std::exception& e) 20 | { 21 | std::cerr << e.what() << '\n'; 22 | } 23 | res.set_content(out_str,"text/plain"); 24 | } 25 | ); 26 | svr_.Get("/api/write", 27 | [this](const httplib::Request &req, httplib::Response &res){ 28 | std::string context = req.get_param_value("text"); 29 | float temperature = std::stof(req.get_param_value("temperature")); 30 | float top_p = std::stof(req.get_param_value("topP")); 31 | int token_count = std::stoi(req.get_param_value("tokenCount")); 32 | float fresence_penalty = std::stof(req.get_param_value("presencePenalty")); 33 | float count_penalty = std::stof(req.get_param_value("countPenalty")); 34 | std::string out_str; 35 | try 36 | { 37 | out_str = this->pipeline_.generate(context, temperature, top_p, token_count, fresence_penalty, count_penalty, std::vector{}, std::vector{0}); 38 | } 39 | catch(const std::exception& e) 40 | { 41 | std::cerr << e.what() << '\n'; 42 | } 43 | 44 | res.set_content(out_str,"text/plain"); 45 | } 46 | ); 47 | } 48 | 49 | void RWKVServer::start(std::string ip, uint32_t port) 50 | { 51 | std::cout << "start listening" << std::endl; 52 | svr_.listen(ip, port); 53 | std::cout << "end"; 54 | } -------------------------------------------------------------------------------- /src/rwkv_server.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "http/httplib.h" 4 | 5 | #include "pipeline.h" 6 | 7 | class RWKVServer 8 | { 9 | public: 10 | RWKVServer(RWKVPipeline &pipeline); 11 | 12 | void start(std::string ip, uint32_t port); 13 | 14 | private: 15 | httplib::Server svr_; 16 | RWKVPipeline pipeline_; 17 | }; -------------------------------------------------------------------------------- /src/rwkv_tokenizer.cpp: -------------------------------------------------------------------------------- 1 | #include "rwkv_tokenizer.h" 2 | 3 | RWKVTokenizer::RWKVTokenizer(const std::string &path) 4 | { 5 | tokenizer_ = tokenizer_create_from_file(path.c_str()); 6 | } 7 | 8 | RWKVTokenizer::~RWKVTokenizer() 9 | { 10 | if (tokenizer_ != nullptr) 11 | { 12 | tokenizer_destroy(tokenizer_); 13 | } 14 | } 15 | 16 | std::string RWKVTokenizer::decodeTokens(std::vector &tokens) 17 | { 18 | auto allocator = [](size_t size, void *payload) -> void * 19 | { 20 | return malloc(size); 21 | }; 22 | CArrayRef input_ids = CArrayRef(tokens); 23 | c_str_t c_str; 24 | tokenizer_decode(tokenizer_, true, 0, allocator, &input_ids, &c_str); 25 | 
std::string decode_str = c_str.to_string(); 26 | free((void *)c_str.ptr); 27 | return decode_str; 28 | } 29 | 30 | std::vector RWKVTokenizer::encodeTokens(std::string &str) 31 | { 32 | Encoding_t *encoded = tokenizer_encode(tokenizer_, str.c_str(), true); 33 | CArrayRef array; 34 | std::vector tokens_vec; 35 | encoding_get_ids(encoded, &array); 36 | for (int i = 0; i < array.length; i++) 37 | { 38 | tokens_vec.push_back(array[i]); 39 | } 40 | return tokens_vec; 41 | } -------------------------------------------------------------------------------- /src/rwkv_tokenizer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "tokenizer.h" 9 | 10 | class RWKVTokenizer 11 | { 12 | public: 13 | 14 | RWKVTokenizer(const std::string &path); 15 | 16 | ~RWKVTokenizer(); 17 | 18 | std::string decodeTokens(std::vector &tokens); 19 | 20 | std::vector encodeTokens(std::string &str); 21 | 22 | Tokenizer_t *tokenizer_ = nullptr; 23 | }; 24 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import requests 2 | params={ 3 | "tokenCount":150, 4 | "temperature":1.2, 5 | "topP":0.5, 6 | "presencePenalty":0.4, 7 | "countPenalty":0.4, 8 | # "text":"\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." 9 | "text":"hello", 10 | } 11 | r = requests.get("http://127.0.0.1:5000/api/chat", params=params) 12 | print(r.content) 13 | 14 | 15 | r = requests.get("http://127.0.0.1:5000/api/write", params=params) 16 | print(r.content) --------------------------------------------------------------------------------
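
A quick Python-side sanity check for an exported ONNX model, before placing it in `release/assets/models`: the sketch below is not part of the repo. It assumes the input/output names emitted by `convert/opslist.py` (`input_token`, `input_state`, `output_token`, `output_state`) and a hypothetical filename following the `RWKV_{layers}_{embed}_{bits}.onnx` pattern, and it rebuilds the same empty state that `RWKVONNX` constructs (zeros, with every fifth row pushed to -1e30).
```python
# Hedged sketch (not part of the repo): verify an exported RWKV ONNX model with onnxruntime.
# The filename is hypothetical; the I/O names follow convert/opslist.py.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("RWKV_32_2560_16.onnx")      # hypothetical export name

state_info = sess.get_inputs()[1]                        # "input_state", shape [5 * layers, embed]
rows, embed = state_info.shape
dtype = np.float16 if "float16" in state_info.type else np.float32

# Empty state, mirroring RWKVONNX: zeros, with every fifth row set to -1e30 (the "pp" entries).
state = np.zeros((rows, embed), dtype=np.float32)
state[4::5] = -1e30
state = state.astype(dtype)                              # saturates to -inf in fp16, which is fine here

token = np.array([178], dtype=np.int32)                  # same token id the C++ speed tests feed
logits, state = sess.run(None, {"input_token": token, "input_state": state})
print(logits.shape)                                      # expect (50277,), the 20B tokenizer vocab size
```
Note that the C++ `RWKVONNX` wrapper casts the state to float32 before each call, so a float32 export is the safer choice for serving.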