├── .devcontainer
├── Dockerfile-client
├── Dockerfile-cuda
├── Dockerfile-prod-cpp
├── Dockerfile-prod-vllm
├── Dockerfile-vllm
├── client
│ └── devcontainer.json
├── common.dockerfile
├── cuda-settings.dockerfile
├── cuda
│ └── devcontainer.json
├── vllm-requirements.txt
└── vllm
│ └── devcontainer.json
├── .gitattributes
├── .github
└── workflows
│ ├── aicirt-release.yml
│ ├── aicirt.yml
│ ├── codeql.yml
│ ├── links.yml
│ └── rllm-cuda.yml
├── .gitignore
├── .gitmodules
├── .vscode
├── launch.json
├── settings.json
└── tasks.json
├── CODE_OF_CONDUCT.md
├── Cargo.lock
├── Cargo.toml
├── LICENSE
├── NOTICE.md
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── TRANSPARENCY.md
├── aici.sh
├── aicirt
├── Cargo.toml
├── README.md
├── src
│ ├── api.rs
│ ├── bench.rs
│ ├── futexshm.rs
│ ├── hostimpl.rs
│ ├── lib.rs
│ ├── macos.rs
│ ├── main.rs
│ ├── moduleinstance.rs
│ ├── msgchannel.rs
│ ├── semaphore.rs
│ ├── shm.rs
│ └── worker.rs
└── vllm.md
├── controllers
├── aici_abi
│ ├── .cargo
│ │ └── config.toml
│ ├── Cargo.toml
│ ├── README.md
│ ├── grammars
│ │ ├── c.y
│ │ ├── json0.guidance
│ │ └── sample.c
│ ├── implementation.md
│ └── src
│ │ ├── cfg.rs
│ │ ├── dlex.rs
│ │ ├── host.rs
│ │ ├── lex.rs
│ │ ├── lib.rs
│ │ ├── rx.rs
│ │ ├── substring.rs
│ │ └── yesno.rs
├── aici_native
│ ├── Cargo.toml
│ ├── README.md
│ └── src
│ │ ├── bintokens.rs
│ │ ├── lib.rs
│ │ ├── log.rs
│ │ └── variables.rs
├── declctrl
│ ├── .cargo
│ │ └── config.toml
│ ├── Cargo.toml
│ ├── README.md
│ ├── arg.json
│ ├── arg2.json
│ ├── genarg.py
│ ├── native.sh
│ ├── size.js
│ ├── src
│ │ └── declctrl.rs
│ └── wasm.sh
├── jsctrl
│ ├── .cargo
│ │ └── config.toml
│ ├── Cargo.toml
│ ├── README.md
│ ├── build.rs
│ ├── gen-dts.mjs
│ ├── samples
│ │ ├── aici-types.d.ts
│ │ ├── count-yellow.js
│ │ ├── hello.js
│ │ ├── hellots.ts
│ │ ├── mapping.js
│ │ ├── schema.js
│ │ ├── test.ts
│ │ └── tsconfig.json
│ ├── src
│ │ └── jsctrl.rs
│ └── ts
│ │ ├── aici.ts
│ │ ├── native.d.ts
│ │ └── tsconfig.json
├── llguidance_ctrl
│ ├── .cargo
│ │ └── config.toml
│ ├── Cargo.toml
│ ├── go.sh
│ ├── run_g.py
│ ├── src
│ │ └── gctrl.rs
│ └── text_req.py
├── pyctrl
│ ├── .cargo
│ │ └── config.toml
│ ├── Cargo.toml
│ ├── Lib
│ │ ├── LICENSE
│ │ ├── _collections_abc.py
│ │ ├── _py_abc.py
│ │ ├── _weakrefset.py
│ │ ├── abc.py
│ │ ├── collections
│ │ │ ├── __init__.py
│ │ │ ├── _defaultdict.py
│ │ │ └── abc.py
│ │ ├── contextlib.py
│ │ ├── copy.sh
│ │ ├── copyreg.py
│ │ ├── enum.py
│ │ ├── functools.py
│ │ ├── genericpath.py
│ │ ├── keyword.py
│ │ ├── operator.py
│ │ ├── os.py
│ │ ├── posixpath.py
│ │ ├── re.py
│ │ ├── reprlib.py
│ │ ├── sre_compile.py
│ │ ├── sre_constants.py
│ │ ├── sre_parse.py
│ │ ├── stat.py
│ │ ├── types.py
│ │ └── typing.py
│ ├── README.md
│ ├── build.rs
│ ├── driver.py
│ ├── samples
│ │ ├── cfg.py
│ │ ├── forkbomb.py
│ │ ├── idents.py
│ │ ├── phi.py
│ │ ├── substr.py
│ │ ├── test.py
│ │ ├── tla.py
│ │ ├── warsaw.py
│ │ └── yesno.py
│ ├── src
│ │ └── pyctrl.rs
│ └── wasm.sh
└── uppercase
│ ├── .cargo
│ └── config.toml
│ ├── Cargo.toml
│ ├── README.md
│ └── src
│ └── main.rs
├── docs
├── FAQ.md
├── REST.md
├── aicirt-proto.md
└── proxy.md
├── py
├── promptlib
│ ├── README.md
│ ├── __init__.py
│ ├── notebooks
│ │ ├── basics_tutorial.ipynb
│ │ └── information_flow_examples.ipynb
│ └── promptlib
│ │ ├── __init__.py
│ │ ├── aici.py
│ │ ├── gen.py
│ │ ├── prompt.py
│ │ └── vars.py
├── pyaici
│ ├── __init__.py
│ ├── _vllm_protocol.py
│ ├── _vllm_runner.py
│ ├── _vllm_sampling_ctrl.py
│ ├── ast_.py
│ ├── cli.py
│ ├── comms.py
│ ├── jssrc.py
│ ├── rest.py
│ ├── server.py
│ ├── server_native.py
│ ├── util.py
│ ├── vllm.py
│ └── vllm_server.py
├── setup.py
└── tests
│ ├── conftest.py
│ ├── declctrl_test.py
│ └── test-prompt.txt
├── pytest.ini
├── rllm
├── llama-cpp-low
│ ├── Cargo.toml
│ ├── build.rs
│ └── src
│ │ ├── lib.rs
│ │ └── main.rs
├── rllm-base
│ ├── Cargo.toml
│ └── src
│ │ ├── config.rs
│ │ ├── engine.rs
│ │ ├── exec.rs
│ │ ├── expected.rs
│ │ ├── iface.rs
│ │ ├── lib.rs
│ │ ├── logits.rs
│ │ ├── scheduler.rs
│ │ ├── seq.rs
│ │ ├── server
│ │ ├── api.rs
│ │ ├── completion.rs
│ │ ├── mod.rs
│ │ └── openai
│ │ │ ├── LICENSE
│ │ │ ├── mod.rs
│ │ │ ├── requests.rs
│ │ │ └── responses.rs
│ │ └── util.rs
├── rllm-cuda
│ ├── Cargo.toml
│ ├── README.md
│ ├── expected
│ │ ├── codellama
│ │ │ ├── args.txt
│ │ │ ├── cats.safetensors
│ │ │ ├── lighthouse.safetensors
│ │ │ └── primes.safetensors
│ │ ├── codellama34
│ │ │ ├── args.txt
│ │ │ ├── cats.safetensors
│ │ │ ├── lighthouse.safetensors
│ │ │ └── primes.safetensors
│ │ ├── go.sh
│ │ ├── llama
│ │ │ ├── args.txt
│ │ │ ├── cats.safetensors
│ │ │ ├── lighthouse.safetensors
│ │ │ └── primes.safetensors
│ │ ├── orca
│ │ │ ├── args.txt
│ │ │ ├── cats.safetensors
│ │ │ ├── lighthouse.safetensors
│ │ │ └── primes.safetensors
│ │ ├── phi-1_5
│ │ │ ├── args.txt
│ │ │ ├── cats.safetensors
│ │ │ ├── lighthouse.safetensors
│ │ │ └── primes.safetensors
│ │ └── phi-2
│ │ │ ├── args.txt
│ │ │ ├── cats.safetensors
│ │ │ ├── lighthouse.safetensors
│ │ │ └── primes.safetensors
│ ├── scripts
│ │ ├── cmp2.sh
│ │ ├── convert.py
│ │ ├── tensorcmp.py
│ │ └── testgen.py
│ ├── server.sh
│ ├── src
│ │ ├── llm
│ │ │ ├── config.rs
│ │ │ ├── kernels.rs
│ │ │ ├── llama.rs
│ │ │ ├── loader.rs
│ │ │ ├── mod.rs
│ │ │ ├── paged
│ │ │ │ ├── batch_info.rs
│ │ │ │ ├── blocks.rs
│ │ │ │ ├── cache_engine.rs
│ │ │ │ ├── cuda_stub.rs
│ │ │ │ └── mod.rs
│ │ │ ├── phi.rs
│ │ │ ├── refkernels.rs
│ │ │ ├── tmodel.rs
│ │ │ └── util.rs
│ │ └── rllm-cuda.rs
│ └── test.sh
├── rllm-llamacpp
│ ├── Cargo.toml
│ ├── README.md
│ ├── server.sh
│ └── src
│ │ ├── llamacpp
│ │ ├── blocks.rs
│ │ ├── loader.rs
│ │ ├── mod.rs
│ │ ├── seqid.rs
│ │ └── tmodel.rs
│ │ └── rllm-llamacpp.rs
└── tch-cuda
│ ├── Cargo.toml
│ ├── README.md
│ ├── build.rs
│ ├── convhd.js
│ ├── kernels
│ ├── cuda.cpp
│ ├── flash_attn
│ │ ├── AUTHORS
│ │ ├── LICENSE
│ │ ├── block_info.h
│ │ ├── flash.h
│ │ ├── flash_api.cpp
│ │ ├── flash_fwd_kernel.h
│ │ ├── flash_fwd_launch_template.h
│ │ ├── flash_fwd_split_hdim128_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim128_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim160_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim160_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim192_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim192_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim224_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim224_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim256_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim256_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim32_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim32_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim64_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim64_fp16_sm80.cu
│ │ ├── flash_fwd_split_hdim96_bf16_sm80.cu
│ │ ├── flash_fwd_split_hdim96_fp16_sm80.cu
│ │ ├── kernel_traits.h
│ │ ├── philox.cuh
│ │ ├── softmax.h
│ │ ├── static_switch.h
│ │ └── utils.h
│ ├── vllm
│ │ ├── LICENSE
│ │ ├── activation_kernels.cu
│ │ ├── attention
│ │ │ ├── attention_dtypes.h
│ │ │ ├── attention_generic.cuh
│ │ │ ├── attention_kernels.cu
│ │ │ ├── attention_utils.cuh
│ │ │ ├── dtype_bfloat16.cuh
│ │ │ ├── dtype_float16.cuh
│ │ │ └── dtype_float32.cuh
│ │ ├── cache.h
│ │ ├── cache_kernels.cu
│ │ ├── cuda_compat.h
│ │ ├── cuda_utils.h
│ │ ├── cuda_utils_kernels.cu
│ │ ├── dispatch_utils.h
│ │ ├── layernorm_kernels.cu
│ │ ├── ops.h
│ │ ├── pos_encoding_kernels.cu
│ │ ├── pybind.cpp
│ │ ├── quantization
│ │ │ ├── awq
│ │ │ │ ├── dequantize.cuh
│ │ │ │ └── gemm_kernels.cu
│ │ │ └── squeezellm
│ │ │ │ └── quant_cuda_kernel.cu
│ │ └── reduction_utils.cuh
│ └── vllm_bindings.cpp
│ ├── src
│ ├── event.rs
│ ├── lib.rs
│ └── stream.rs
│ └── tests
│ └── flash_attn_tests.rs
├── rustfmt.toml
└── scripts
├── aici.sh
├── bench-comms.sh
├── bench-earley.sh
├── bench-server.sh
├── bump.sh
├── checkdeps.js
├── checklinks.js
├── checklinks.sh
├── docker-build.sh
├── docker-cpp-run.sh
├── docker-run.sh
├── docker-vllm-build.sh
├── hf.sh
├── host.sh
├── kill-rt.sh
├── kill-server.sh
├── py
├── bench_server.py
├── hgwells.txt
├── run_hf.py
├── run_vllm.py
└── vllm_server.py
├── random
├── disasm.js
├── parse-tokenizer-automaton.js
└── tokenizer-stats.js
├── release.sh
├── sample-uppercase.sh
├── sample-yesno.sh
├── tag-ctrls.sh
├── test-all.sh
├── test-guidance.sh
├── test-infer1.sh
├── test-jsctrl.sh
├── test-llg1.sh
├── test-parallel.sh
├── test-pyctrl.sh
├── tokenizer-stats.js
├── upload-all.sh
├── vllm-init.sh
└── vllm-server.sh
/.devcontainer/Dockerfile-client:
--------------------------------------------------------------------------------
1 | # syntax = edrevo/dockerfile-plus
2 | # ^^^ this line enables the INCLUDE+ directive
3 |
4 | FROM mcr.microsoft.com/devcontainers/cpp:ubuntu-22.04
5 |
6 | INCLUDE+ common.dockerfile
7 |
--------------------------------------------------------------------------------
/.devcontainer/Dockerfile-cuda:
--------------------------------------------------------------------------------
1 | # syntax = edrevo/dockerfile-plus
2 | # ^^^ this line enables the INCLUDE+ directive
3 |
4 | FROM nvcr.io/nvidia/pytorch:23.09-py3
5 |
6 | INCLUDE+ cuda-settings.dockerfile
7 | INCLUDE+ common.dockerfile
8 |
9 | RUN pip install torch==2.1.0 nvidia-cuda-runtime
10 | # the .so file seems to be missing
11 | RUN ln -s /usr/local/lib/python3.10/dist-packages/nvidia/cuda_runtime/lib/libcudart.so{.12,}
12 |
13 | # perf tool
14 | RUN apt-get install -y linux-tools-`uname -r`
15 |
16 | RUN source /usr/local/nvm/nvm.sh && npm install -g yarn
17 |
18 | # we mostly need guidance deps
19 | RUN pip install guidance
20 |
--------------------------------------------------------------------------------
/.devcontainer/Dockerfile-prod-cpp:
--------------------------------------------------------------------------------
1 | # docker container with aicirt and CPU-only rllm (llama.cpp)
2 | # TAG: aici/rllm-llamacpp
3 |
4 | FROM bitnami/minideb:bookworm
5 |
6 | RUN apt-get update && apt-get install -y libssl3 && apt-get clean
7 |
8 | # install aicirt and rllm
9 | COPY target/dist/aicirt/aicirt /usr/bin/aicirt
10 | COPY target/dist/rllm-llamacpp/rllm-llamacpp /usr/bin/rllm-llamacpp
11 |
12 | RUN mkdir /workspace
13 |
14 | # copy the controllers
15 | WORKDIR /workspace
16 |
17 | # RUN mkdir wasm
18 | # COPY target/dist/aici_*.wasm wasm/
19 | # # "upload" and tag the controllers
20 | # RUN aicirt --module wasm/aici_guidance_ctrl.wasm --tag guidance
21 | # RUN aicirt --module wasm/aici_pyctrl.wasm --tag pyctrl --gh-module gh:microsoft/aici/pyctrl
22 | # RUN aicirt --module wasm/aici_jsctrl.wasm --tag jsctrl --gh-module gh:microsoft/aici/jsctrl
23 |
24 | ENV RUST_LOG info,tokenizers=error
25 |
26 | ENTRYPOINT ["rllm-llamacpp", "--aicirt=/usr/bin/aicirt"]
27 |
--------------------------------------------------------------------------------
/.devcontainer/Dockerfile-prod-vllm:
--------------------------------------------------------------------------------
1 | FROM rust:1.75.0-bookworm AS aicirt
2 |
3 | WORKDIR /workspace
4 |
5 | RUN rustup target add wasm32-wasi
6 | RUN curl -fsSL https://deb.nodesource.com/setup_18.x | bash -
7 | RUN apt-get install -y nodejs
8 |
9 | COPY controllers controllers
10 | COPY aicirt aicirt
11 | COPY scripts scripts
12 | COPY py/pyaici py/pyaici
13 | COPY Cargo.toml Cargo.lock /workspace/
14 |
15 | # make sure we rebuild these
16 | RUN rm -rf controllers/jsctrl/samples/dist controllers/jsctrl/ts/dist
17 |
18 | RUN grep -v rllm Cargo.toml > Cargo.toml.tmp && mv Cargo.toml.tmp Cargo.toml
19 |
20 | RUN --mount=type=cache,target=/usr/local/cargo/git \
21 | --mount=type=cache,target=/usr/local/cargo/registry \
22 | --mount=type=cache,target=/workspace/target \
23 | cargo fetch
24 |
25 | ARG tag=latest
26 | ENV BUILD_TAG=$tag
27 | RUN --mount=type=cache,target=/usr/local/cargo/git \
28 | --mount=type=cache,target=/usr/local/cargo/registry \
29 | --mount=type=cache,target=/workspace/target \
30 | SKIP_LLAMA_CPP=1 \
31 | ./scripts/release.sh && cp -r /workspace/target/dist /workspace/
32 |
33 |
34 | FROM vllm/vllm-openai as vllm-base
35 |
36 | # install pyaici pre-requisites
37 | RUN pip install posix_ipc ujson
38 |
39 | # install pyaici
40 | RUN mkdir /tmp/pyaici
41 | COPY py/setup.py /tmp/pyaici/
42 | COPY py/pyaici /tmp/pyaici/pyaici
43 | RUN cd /tmp/pyaici && pip install . && rm -rf /tmp/pyaici
44 |
45 | # patch the vllm python files
46 | RUN --mount=source=py/vllm/vllm,target=/tmp/vllm \
47 | (cd /tmp/vllm && find . -name '*.py' -print0 | tar -cf - --null -T -) | \
48 | tar -C /usr/local/lib/python3.10/dist-packages/vllm -xf -
49 |
50 | # copy the controllers and aicirt
51 | WORKDIR /vllm-workspace
52 | RUN mkdir wasm
53 |
54 | RUN --mount=from=aicirt,source=/workspace/dist,target=/tmp/dist \
55 | cp /tmp/dist/aicirt/aicirt /usr/bin/aicirt && \
56 | cp /tmp/dist/aici_*.wasm wasm/
57 |
58 | RUN ls -l wasm/
59 |
60 | ENV RUST_LOG info,tokenizers=error
61 |
62 | ENTRYPOINT ["python3", "-m", "pyaici.vllm_server", "--enforce-eager", "--use-v2-block-manager", "--enable-chunked-prefill", "--aici-rt=/usr/bin/aicirt", "-A--restricted", "-A--wasm-timer-resolution-us=10"]
63 |
64 | FROM vllm-base AS vllm-guidance
65 |
66 | RUN aicirt --module wasm/aici_llguidance_ctrl.wasm --tag llguidance
67 |
68 | FROM vllm-base as vllm-general
69 |
70 | RUN aicirt --module wasm/aici_llguidance_ctrl.wasm --tag llguidance
71 | RUN aicirt --module wasm/aici_pyctrl.wasm --tag pyctrl --gh-module gh:microsoft/aici/pyctrl
72 | RUN aicirt --module wasm/aici_jsctrl.wasm --tag jsctrl --gh-module gh:microsoft/aici/jsctrl
73 |
--------------------------------------------------------------------------------
/.devcontainer/Dockerfile-vllm:
--------------------------------------------------------------------------------
1 | # syntax = edrevo/dockerfile-plus
2 | # ^^^ this line enables the INCLUDE+ directive
3 |
4 | FROM nvcr.io/nvidia/pytorch:23.10-py3
5 |
6 | INCLUDE+ cuda-settings.dockerfile
7 | INCLUDE+ common.dockerfile
8 |
9 | ENV NVTE_FRAMEWORK=pytorch
10 |
11 | COPY vllm-requirements.txt /tmp/requirements.txt
12 | RUN pip install -r /tmp/requirements.txt
13 |
14 | # Uninstall the transformer engine that comes with the base image.
15 | # Otherwise it will cause error when importing vLLM (LLAVA models).
16 | RUN pip uninstall -y transformer_engine
17 |
18 | # crashes docker?
19 | # RUN pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
20 |
21 | # takes forever!
22 | # RUN pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
23 | # RUN pip install typing_extensions==4.5.0
24 | # RUN pip install -U flash-attn
25 |
26 | # RUN pip install torch==2.1.0 nvidia-cuda-runtime
27 | # the .so file seems to be missing
28 | RUN ln -s /usr/local/lib/python3.10/dist-packages/nvidia/cuda_runtime/lib/libcudart.so{.12,}
29 |
30 | # perf tool
31 | RUN apt-get install -y linux-tools-`uname -r`
32 |
33 | RUN source /usr/local/nvm/nvm.sh && npm install -g yarn
34 |
--------------------------------------------------------------------------------
/.devcontainer/client/devcontainer.json:
--------------------------------------------------------------------------------
1 | // For format details, see https://aka.ms/devcontainer.json
2 | {
3 | "name": "AICI Client-side",
4 | "build": {
5 | "dockerfile": "../Dockerfile-client",
6 | "context": ".."
7 | },
8 | "customizations": {
9 | "vscode": {
10 | "extensions": [
11 | "ms-python.python",
12 | "ms-python.black-formatter",
13 | "1YiB.rust-bundle",
14 | "dtsvet.vscode-wasm",
15 | "ms-vscode.cpptools",
16 | "esbenp.prettier-vscode",
17 | "streetsidesoftware.code-spell-checker",
18 | "GitHub.copilot"
19 | ]
20 | }
21 | },
22 | "remoteUser": "root",
23 | "containerUser": "root",
24 | "mounts": [
25 | "source=profile,target=/root,type=volume",
26 | "target=/root/.vscode-server,type=volume"
27 | ]
28 | }
--------------------------------------------------------------------------------
/.devcontainer/common.dockerfile:
--------------------------------------------------------------------------------
1 | # makes it easier to diagnose ccache issues
2 | ENV CCACHE_DEBUG="1"
3 |
4 | # need git 2.41 for GCM/Github EMU account switching
5 | # https://askubuntu.com/questions/568591/how-do-i-install-the-latest-version-of-git-with-apt
6 | RUN apt-get update && apt-get install -y software-properties-common
7 | RUN apt-add-repository ppa:git-core/ppa
8 | RUN apt-get update && apt-get install -y git
9 |
10 | RUN DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
11 | build-essential ca-certificates ccache \
12 | cmake curl libjpeg-dev libpng-dev \
13 | strace linux-tools-common linux-tools-generic \
14 | llvm-dev libclang-dev clang ccache apache2-utils git-lfs \
15 | screen bsdmainutils pip python3-dev python-is-python3 \
16 | nodejs npm pkg-config
17 |
18 | RUN pip install pytest pytest-forked ujson posix_ipc numpy requests
19 |
20 | # RUN curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz \
21 | # | tar zxf - --strip-components=1 -C /usr/local
22 |
23 | RUN cd /tmp && \
24 | curl -L https://github.com/WebAssembly/wabt/releases/download/1.0.33/wabt-1.0.33.tar.xz | tar Jxf - && \
25 | cd wabt-1.0.33 && make gcc-release && cp -v bin/wasm-* /usr/bin && cd .. && rm -rf wabt-1.0.33
26 |
27 | ENV RUSTUP_HOME=/usr/local/rustup \
28 | CARGO_HOME=/usr/local/cargo \
29 | PATH=/usr/local/cargo/bin:$PATH \
30 | RUST_VERSION=1.75.0
31 |
32 | RUN curl https://sh.rustup.rs -sSf | sh -s -- \
33 | -y --no-modify-path --profile minimal --default-toolchain $RUST_VERSION
34 | RUN rustup target add wasm32-wasi
35 | RUN rustup component add rustfmt
36 |
37 | # run as root please; note that settings in devcontainer.json are also needed...
38 | USER root
39 |
--------------------------------------------------------------------------------
/.devcontainer/cuda-settings.dockerfile:
--------------------------------------------------------------------------------
1 | # A100:
2 | ENV TORCH_CUDA_ARCH_LIST="8.0"
3 | ENV CUDA_COMPUTE_CAP="80"
4 |
5 | # candle is slow without caching; pytorch image has this on by default
6 | ENV CUDA_CACHE_DISABLE=""
7 |
8 | ENV LIBTORCH_USE_PYTORCH="1"
9 | ENV LIBTORCH_BYPASS_VERSION_CHECK="1"
10 |
--------------------------------------------------------------------------------
/.devcontainer/cuda/devcontainer.json:
--------------------------------------------------------------------------------
1 | // For format details, see https://aka.ms/devcontainer.json
2 | {
3 | "name": "AICI with CUDA",
4 | "build": {
5 | "dockerfile": "../Dockerfile-cuda",
6 | "context": ".."
7 | },
8 | "runArgs": [
9 | "--privileged",
10 | "--gpus",
11 | "all",
12 | "--shm-size=8g"
13 | ],
14 | "mounts": [
15 | "source=profile,target=/root,type=volume",
16 | "target=/root/.vscode-server,type=volume"
17 | ],
18 | "customizations": {
19 | "vscode": {
20 | "extensions": [
21 | "ms-python.python",
22 | "ms-python.black-formatter",
23 | "1YiB.rust-bundle",
24 | "dtsvet.vscode-wasm",
25 | "ms-vscode.cpptools",
26 | "esbenp.prettier-vscode",
27 | "streetsidesoftware.code-spell-checker",
28 | "GitHub.copilot"
29 | ]
30 | }
31 | },
32 | "forwardPorts": [
33 | 4242
34 | ]
35 | }
--------------------------------------------------------------------------------
/.devcontainer/vllm-requirements.txt:
--------------------------------------------------------------------------------
1 | # vllm: requirements.txt
2 | ninja # For faster builds.
3 | psutil
4 | ray >= 2.9
5 | sentencepiece # Required for LLaMA tokenizer.
6 | numpy
7 | torch == 2.1.2
8 | transformers >= 4.37.0 # Required for Qwen2
9 | xformers == 0.0.23.post1 # Required for CUDA 12.1.
10 | fastapi
11 | uvicorn[standard]
12 | pydantic >= 2.0 # Required for OpenAI server.
13 | aioprometheus[starlette]
14 | pynvml == 11.5.0
15 | triton >= 2.1.0
16 | cupy-cuda12x == 12.3.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
17 |
18 | # vllm: requirements-dev.txt
19 | # formatting
20 | yapf==0.32.0
21 | toml==0.10.2
22 | ruff==0.1.5
23 |
24 | # type checking
25 | mypy==0.991
26 | types-PyYAML
27 | types-requests
28 | types-setuptools
29 |
30 | # testing
31 | pytest
32 | pytest-forked
33 | pytest-asyncio
34 | httpx
35 | einops # required for MPT
36 | flash_attn # required for HuggingFace's llama implementation
37 | openai
38 | requests
39 | # ray - XXX
40 |
41 | # vllm: requirements-build.txt
42 | # Should be mirrored in pyproject.toml
43 | ninja
44 | packaging
45 | setuptools>=49.4.0
46 | # torch==2.1.2 - XXX
47 | wheel
48 |
49 | # non-vllm:
50 | ujson
51 | posix_ipc
52 | accelerate
53 | fschat
54 |
--------------------------------------------------------------------------------
/.devcontainer/vllm/devcontainer.json:
--------------------------------------------------------------------------------
1 | // For format details, see https://aka.ms/devcontainer.json
2 | {
3 | "name": "AICI with CUDA and vLLM (experimental)",
4 | "build": {
5 | "dockerfile": "../Dockerfile-vllm",
6 | "context": ".."
7 | },
8 | "runArgs": [
9 | "--privileged",
10 | "--gpus",
11 | "all",
12 | "--shm-size=8g"
13 | ],
14 | "mounts": [
15 | "source=profile,target=/root,type=volume",
16 | "target=/root/.vscode-server,type=volume"
17 | ],
18 | "customizations": {
19 | "vscode": {
20 | "extensions": [
21 | "ms-python.python",
22 | "ms-python.black-formatter",
23 | "eeyore.yapf",
24 | "1YiB.rust-bundle",
25 | "dtsvet.vscode-wasm",
26 | "ms-vscode.cpptools",
27 | "esbenp.prettier-vscode",
28 | "streetsidesoftware.code-spell-checker",
29 | "GitHub.copilot"
30 | ]
31 | }
32 | },
33 | "forwardPorts": [
34 | 4242
35 | ]
36 | }
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | rllm/tch-cuda/kernels/vllm/** linguist-vendored
2 | rllm/tch-cuda/kernels/flash_attn/** linguist-vendored
3 | controllers/pyctrl/Lib/** linguist-vendored
4 |
--------------------------------------------------------------------------------
/.github/workflows/aicirt-release.yml:
--------------------------------------------------------------------------------
1 | name: AICIrt release
2 |
3 | on:
4 | push:
5 | tags:
6 | - "v*.*.*"
7 |
8 | env:
9 | CARGO_TERM_COLOR: always
10 |
11 | jobs:
12 | build:
13 | runs-on: ubuntu-latest
14 |
15 | permissions:
16 | contents: write
17 |
18 | steps:
19 | - uses: actions/checkout@v3
20 | with:
21 | submodules: true
22 | - run: rustup target add wasm32-wasi
23 | - uses: hendrikmuhs/ccache-action@v1.2
24 | - uses: Swatinem/rust-cache@v2
25 | with:
26 | cache-on-failure: true
27 | - name: Release script
28 | run: ./scripts/release.sh --xz
29 | - name: Release
30 | uses: softprops/action-gh-release@v1
31 | if: startsWith(github.ref, 'refs/tags/')
32 | with:
33 | body_path: target/dist/README.md
34 | files: |
35 | target/dist/*.tar.gz
36 | target/dist/*.tar.xz
37 | target/dist/*.wasm
38 |
--------------------------------------------------------------------------------
/.github/workflows/aicirt.yml:
--------------------------------------------------------------------------------
1 | name: AICIrt
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 |
9 | env:
10 | CARGO_TERM_COLOR: always
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v3
19 | with:
20 | submodules: true
21 | - run: rustup target add wasm32-wasi
22 | - uses: hendrikmuhs/ccache-action@v1.2
23 | - uses: Swatinem/rust-cache@v2
24 | with:
25 | cache-on-failure: true
26 | - name: Build aici_abi
27 | run: cargo build --verbose --release
28 | working-directory: controllers/aici_abi
29 | - name: Build uppercase
30 | run: cargo build --verbose --release
31 | working-directory: controllers/uppercase
32 | - name: Build pyctrl
33 | run: cargo build --verbose --release
34 | working-directory: controllers/pyctrl
35 | - name: Build jsctrl
36 | run: cargo build --verbose --release
37 | working-directory: controllers/jsctrl
38 | - name: Build declctrl
39 | run: cargo build --verbose --release
40 | working-directory: controllers/declctrl
41 | - name: Build aicirt
42 | run: cargo build --verbose --release
43 | working-directory: aicirt
44 | - name: Build rllm-llamacpp
45 | run: cargo build --verbose --release --no-default-features
46 | working-directory: rllm/rllm-llamacpp
47 | - name: Release script
48 | run: ./scripts/release.sh --xz
49 | - name: Artifact upload
50 | uses: actions/upload-artifact@v4
51 | with:
52 | name: aicirt-xz
53 | path: target/dist/*.tar.xz
54 |
--------------------------------------------------------------------------------
/.github/workflows/links.yml:
--------------------------------------------------------------------------------
1 | name: Markdown link check
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 |
9 | jobs:
10 | build:
11 |
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v3
16 | - run: ./scripts/checklinks.sh
17 |
--------------------------------------------------------------------------------
/.github/workflows/rllm-cuda.yml:
--------------------------------------------------------------------------------
1 | name: rLLM with CUDA
2 |
3 | on:
4 | push:
5 | branches: [ "disabled-main" ]
6 | pull_request:
7 | branches: [ "disabled-main" ]
8 |
9 | env:
10 | CARGO_TERM_COLOR: always
11 | TORCH_CUDA_ARCH_LIST: 8.0
12 | CUDA_COMPUTE_CAP: 80
13 | LIBTORCH_USE_PYTORCH: 1
14 | LIBTORCH_BYPASS_VERSION_CHECK: 1
15 |
16 | jobs:
17 | build:
18 |
19 | runs-on: ubuntu-latest
20 |
21 | steps:
22 | - uses: actions/checkout@v3
23 | with:
24 | submodules: true
25 | - run: sudo apt-get install ccache
26 |
27 | - run: sudo df -h
28 | - run: sudo rm -rf /usr/share/dotnet /opt/ghc /usr/local/share/boost
29 | - run: sudo df -h
30 |
31 | - run: pip install torch==2.1.0
32 | - run: sudo df -h
33 |
34 | - uses: Jimver/cuda-toolkit@v0.2.13
35 | id: cuda-toolkit
36 | with:
37 | cuda: '12.3.2'
38 |
39 | - run: echo "Installed cuda version is ${{ steps.cuda-toolkit.outputs.cuda }}"
40 | - run: echo "Cuda install location ${{ steps.cuda-toolkit.outputs.CUDA_PATH }}"
41 |
42 | - run: nvcc -V
43 |
44 | - name: Build rLLM
45 | run: cargo build --verbose --release
46 | working-directory: rllm/rllm-cuda
47 |
48 | - run: strip target/release/rllm-cuda
49 | - name: Artifact upload
50 | uses: actions/upload-artifact@v4
51 | with:
52 | name: rllm-cuda
53 | path: target/release/rllm-cuda
54 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | .env
4 | .venv
5 | built
6 | tmp
7 | cache
8 | target
9 | tokenizer.json
10 | .hypothesis
11 | perf.data
12 | tokenizer.bin
13 | log.txt
14 | rllm/rllm-cuda/*.safetensor
15 | rllm/rllm-cuda/report*.sqlite
16 | *.ncu-rep
17 | *.nsys-rep
18 | .DS_Store
19 | build/
20 | dist/
21 | *.egg-info
22 | .vscode/c_cpp_properties.json
23 | logs
24 | controllers/RustPython/
25 | *.cpython*.so
26 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "vllm"]
2 | path = py/vllm
3 | url = https://github.com/mmoskal/vllm
4 | branch = hooks
5 | [submodule "tch-cuda/cutlass"]
6 | path = rllm/tch-cuda/cutlass
7 | url = https://github.com/NVIDIA/cutlass.git
8 | [submodule "llama-cpp-low/llama.cpp"]
9 | path = rllm/llama-cpp-low/llama.cpp
10 | url = https://github.com/ggerganov/llama.cpp
11 | [submodule "py/guidance"]
12 | path = py/guidance
13 | url = https://github.com/hudson-ai/guidance
14 | branch = lazy_grammars
15 | [submodule "controllers/toktrie"]
16 | path = controllers/toktrie
17 | url = https://github.com/microsoft/toktrie
18 | [submodule "controllers/derivre"]
19 | path = controllers/derivre
20 | url = https://github.com/microsoft/derivre
21 | [submodule "controllers/llguidance"]
22 | path = controllers/llguidance
23 | url = https://github.com/microsoft/llguidance
24 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "vLLM server",
9 | "type": "debugpy",
10 | "request": "launch",
11 | "module": "pyaici.vllm_server",
12 | "env": {
13 | "RUST_LOG": "info,tokenizers=error,aicirt=info",
14 | "RUST_BACKTRACE": "1",
15 | "PYTHONPATH": "${workspaceFolder}/py:${workspaceFolder}/py/vllm"
16 | },
17 | "args": [
18 | "--enforce-eager",
19 | "--use-v2-block-manager",
20 | "--enable-chunked-prefill",
21 | "--served-model-name=model",
22 | "--aici-rt",
23 | "./target/release/aicirt",
24 | "-A--wasm-timer-resolution-us=10",
25 | "--model",
26 | "microsoft/Phi-3-mini-128k-instruct",
27 | "--trust-remote-code",
28 | "--port",
29 | "4242",
30 | "--host",
31 | "127.0.0.1",
32 | "--trust-remote-code"
33 | ]
34 | },
35 | {
36 | "type": "lldb",
37 | "request": "launch",
38 | "name": "rllm-llamacpp phi",
39 | "cwd": "rllm/rllm-llamacpp",
40 | "preLaunchTask": "rllm-llamacpp: build",
41 | "program": "${workspaceFolder}/target/debug/rllm-llamacpp",
42 | "env": {
43 | "RUST_LOG": "info,tokenizers=error,rllm=trace,aicirt=info,llama_cpp_low=trace"
44 | },
45 | "args": [
46 | "--verbose",
47 | "--aicirt=${workspaceFolder}/target/release/aicirt",
48 | "--model=https://huggingface.co/TheBloke/phi-2-GGUF/blob/main/phi-2.Q8_0.gguf",
49 | "--gpu-layers=100"
50 | ]
51 | }
52 | ]
53 | }
--------------------------------------------------------------------------------
/.vscode/tasks.json:
--------------------------------------------------------------------------------
1 | {
2 | "version": "2.0.0",
3 | "tasks": [
4 | {
5 | "type": "cargo",
6 | "command": "build",
7 | "options": {
8 | "cwd": "${workspaceFolder}/rllm/rllm-llamacpp"
9 | },
10 | "problemMatcher": [
11 | "$rustc"
12 | ],
13 | "group": "build",
14 | "label": "rllm-llamacpp: build"
15 | }
16 | ]
17 | }
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [workspace]
2 | exclude = ["tch-rs"]
3 | members = [
4 | "aicirt",
5 | "controllers/toktrie/core",
6 | "controllers/toktrie/hf_tokenizers",
7 | "controllers/aici_abi",
8 | "controllers/aici_native",
9 | "controllers/declctrl",
10 | "controllers/pyctrl",
11 | "controllers/jsctrl",
12 | "controllers/llguidance/rust",
13 | "controllers/llguidance/parser",
14 | "controllers/llguidance_ctrl",
15 | "controllers/uppercase",
16 | "controllers/derivre",
17 | "rllm/rllm-base",
18 | "rllm/rllm-cuda",
19 | "rllm/rllm-llamacpp",
20 | "rllm/tch-cuda",
21 | "rllm/llama-cpp-low",
22 | ]
23 | resolver = "2"
24 |
25 | [profile.release]
26 | debug = 1
27 |
28 | [patch.'https://github.com/microsoft/toktrie']
29 | toktrie = { path = "controllers/toktrie/core" }
30 |
31 | [patch.'https://github.com/microsoft/derivre']
32 | derivre = { path = "controllers/derivre" }
33 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # Support
2 |
3 | ## How to file issues and get help
4 |
5 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
6 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
7 | feature request as a new Issue.
8 |
9 | For other help and questions about using this project, please use GitHub Discussions.
10 |
11 | ## Microsoft Support Policy
12 |
13 | Support for the AI Controller Interface (AICI) is limited to the resources listed above.
14 |
--------------------------------------------------------------------------------
/aici.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [ "X$AICI_API_BASE" = "X" ] ; then
4 | export AICI_API_BASE="http://127.0.0.1:4242/v1/"
5 | fi
6 |
7 | PYTHONPATH=`dirname $0`/py \
8 | python3 -m pyaici.cli "$@"
9 |
--------------------------------------------------------------------------------
/aicirt/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "aicirt"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7 |
8 | [dependencies]
9 | aici_abi = { path = "../controllers/aici_abi" }
10 | aici_native = { path = "../controllers/aici_native" }
11 | anyhow = "1.0.75"
12 | base64 = "0.21.4"
13 | clap = { version = "4.4.4", features = ["derive"] }
14 | hex = "0.4.3"
15 | libc = "0.2.148"
16 | log = "0.4.20"
17 | rayon = "1.7.0"
18 | serde = { version = "1.0.192", features = ["derive"] }
19 | serde_json = { version = "1.0.108", features = ["preserve_order"] }
20 | sha2 = "0.10.7"
21 | wasmtime = { version = "16.0.0", default-features = false, features = ["cranelift", "parallel-compilation", "pooling-allocator"] }
22 | tokenizers = { version = "0.15.0", features = ["http"] }
23 | thread-priority = "0.15.1"
24 | cap = "0.1.2"
25 | bincode = "1.3.3"
26 | uuid = { version = "1.6.1", features = ["v4"] }
27 | regex = "1.10.3"
28 | ureq = "2.9.5"
29 |
30 | [target.'cfg(target_os = "linux")'.dependencies]
31 | linux-futex = "0.2.0"
32 |
33 | [target.'cfg(target_os = "macos")'.dependencies]
34 | ulock-sys = "0.1.0"
35 |
36 |
--------------------------------------------------------------------------------
/aicirt/README.md:
--------------------------------------------------------------------------------
1 | # AICI Runtime (aicirt)
2 |
3 | Multi-threaded runner for AICI Controllers, built on top of [Wasmtime](https://wasmtime.dev/).
4 |
5 | ```mermaid
6 | graph TD
7 | User1 <-- HTTP --> LLM
8 | User2 <-- HTTP --> LLM
9 | UserN <-- HTTP --> LLM["LLM Server<br>(batching)"]
10 | LLM <-- CUDA/pytorch --> GPU
11 | LLM <-- POSIX SHM --> aicirt[AICI-runtime]
12 | aicirt <-- Sockets+SHM --> Worker1[Worker1<br>Running Wasm]
13 | aicirt <-- Sockets+SHM --> Worker2[Worker2<br>Running Wasm]
14 | aicirt <-- Sockets+SHM --> WorkerM[WorkerM<br>Running Wasm]
15 | ```
16 |
17 | ```mermaid
18 | sequenceDiagram
19 | actor User
20 | participant GPU
21 | participant LLM
22 | participant aicirt as AICI-runtime
23 | LLM -->> GPU: Model
24 | User -->> LLM: Request (Prompt + Wasm)
25 | LLM -->>+ aicirt: Prompt + Wasm
26 | aicirt -->>- LLM: logit bias 1
27 | LLM -->>+ GPU: Prompt
28 | LLM -->> GPU: logit bias 1
29 | GPU -->> LLM: token 1
30 | LLM -->>+ aicirt: token 1
31 | LLM -->> User: token 1
32 | aicirt -->>- LLM: logit bias 2
33 | LLM -->> GPU: logit bias 2
34 | GPU -->>- LLM: token 2
35 | LLM -->> User: token 2
36 | ```
37 |
38 | Below is the process structure.
39 |
40 | - dotted arrow from A to B indicates that A sends requests to B (and gets responses)
41 | - solid arrow from A to B indicates that A spawns (forks) B
42 | - `spawner` is a special process, forked from `aicirt` at the beginning;
43 | for every user request it spawns a process for the top-level controller and a `common state` process
44 | for handling shared state between
45 | all controller instances for that request (they can talk to the `common state` process)
46 | - the top-level constraint can spawn more constraints, which can spawn yet more;
47 | `aicirt`, however, has a direct connection to all of these constraints
48 |
49 | ```mermaid
50 | graph TD
51 | LLM ---> aicirt[AICI-runtime]
52 | LLM -..-> aicirt
53 | aicirt -..-> spawner
54 | aicirt -..-> A0((A0))
55 | aicirt -..-> A1((A1))
56 | aicirt -..-> A2((A2))
57 | aicirt -..-> A3((A3))
58 | aicirt -..-> A4((A4))
59 | aicirt ---> spawner
60 | spawner --> A0
61 | spawner --> CommsA[Common state for A]
62 | subgraph User Request A
63 | A0 --> A1
64 | A0 --> A2
65 | A2 --> A3
66 | A2 --> A4
67 | A0 -..-> CommsA
68 | A1 -..-> CommsA
69 | A2 -..-> CommsA
70 | A3 -..-> CommsA
71 | A4 -..-> CommsA
72 | end
73 | aicirt -..-> B0((B0))
74 | aicirt -..-> B1((B1))
75 | spawner --> B0
76 | spawner --> CommsB[Common state for B]
77 | subgraph User Request B
78 | B0 -..-> CommsB
79 | B1 -..-> CommsB
80 | B0 --> B1
81 | end
82 | ```
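
To make the sequence diagram above concrete, here is a minimal, hypothetical sketch of the per-token round trip. None of the names below (`AiciRuntime`, `DummyRt`, `forward_pass`, `sample_with_bias`) come from this crate; they only illustrate the data flow in which the LLM server obtains a logit bias from the runtime before every sampling step and reports the sampled token back.

```rust
// Hypothetical sketch only; these names are not part of aicirt. They mirror the
// sequence diagram: request a bias, sample with it, report the token, repeat.

struct Request {
    prompt: Vec<u32>,
}

trait AiciRuntime {
    /// Start a controller for the request; returns the first logit bias.
    fn start(&mut self, req: &Request) -> Vec<f32>;
    /// Report the sampled token; returns the bias for the next step.
    fn post_token(&mut self, token: u32) -> Vec<f32>;
}

/// Stand-in runtime that never constrains anything (all-zero bias).
struct DummyRt {
    vocab: usize,
}

impl AiciRuntime for DummyRt {
    fn start(&mut self, _req: &Request) -> Vec<f32> {
        vec![0.0; self.vocab]
    }
    fn post_token(&mut self, _token: u32) -> Vec<f32> {
        vec![0.0; self.vocab]
    }
}

/// Stand-in for the GPU forward pass: uniform logits over the vocabulary.
fn forward_pass(_prompt: &[u32], vocab: usize) -> Vec<f32> {
    vec![1.0; vocab]
}

/// Greedy sampling after adding the controller-provided bias to the logits.
fn sample_with_bias(logits: &[f32], bias: &[f32]) -> u32 {
    let mut best = 0;
    let mut best_val = f32::NEG_INFINITY;
    for (i, (l, b)) in logits.iter().zip(bias.iter()).enumerate() {
        if l + b > best_val {
            best_val = l + b;
            best = i;
        }
    }
    best as u32
}

fn main() {
    let req = Request { prompt: vec![1, 2, 3] };
    let mut rt = DummyRt { vocab: 8 };
    let mut bias = rt.start(&req);
    for _ in 0..4 {
        let logits = forward_pass(&req.prompt, bias.len());
        let token = sample_with_bias(&logits, &bias);
        bias = rt.post_token(token); // the Wasm controller would compute this
    }
}
```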
--------------------------------------------------------------------------------
/aicirt/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod api;
2 | mod bench;
3 | pub mod futexshm;
4 | pub mod msgchannel;
5 | pub mod semaphore;
6 | pub mod shm;
7 |
8 | pub use aici_native::*;
9 |
10 | #[cfg(target_os = "macos")]
11 | mod macos;
12 |
13 | pub use bench::*;
14 | use thread_priority::{
15 | set_thread_priority_and_policy, thread_native_id, RealtimeThreadSchedulePolicy, ThreadPriority,
16 | ThreadSchedulePolicy,
17 | };
18 |
19 | pub fn get_unix_time() -> u64 {
20 | std::time::SystemTime::now()
21 | .duration_since(std::time::UNIX_EPOCH)
22 | .unwrap()
23 | .as_secs()
24 | }
25 |
26 | /// An error thrown from the WASM runtime or otherwise originating from user error
27 | /// - should not generate additional stacktraces from where it's caught.
28 | #[derive(Debug)]
29 | pub struct UserError {
30 | pub msg: String,
31 | }
32 |
33 | impl UserError {
34 | pub fn new(msg: String) -> Self {
35 | Self { msg }
36 | }
37 |
38 | pub fn anyhow(msg: String) -> anyhow::Error {
39 | anyhow::anyhow!(Self::new(msg))
40 | }
41 |
42 | pub fn is_self(e: &anyhow::Error) -> bool {
43 | e.downcast_ref::<Self>().is_some()
44 | }
45 |
46 | pub fn maybe_stacktrace(e: &anyhow::Error) -> String {
47 | if let Some(e) = e.downcast_ref::<Self>() {
48 | format!("{}", e)
49 | } else {
50 | format!("{:?}", e)
51 | }
52 | }
53 | }
54 |
55 | impl std::fmt::Display for UserError {
56 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 | write!(f, "{}", self.msg)
58 | }
59 | }
60 |
61 | impl std::error::Error for UserError {}
62 |
63 | #[macro_export]
64 | macro_rules! user_error {
65 | ($($tt:tt)*) => {
66 | $crate::UserError::anyhow(format!($($tt)*))
67 | };
68 | }
69 |
70 | #[macro_export]
71 | macro_rules! bail_user {
72 | ($($tt:tt)*) => {
73 | return Err($crate::UserError::anyhow(format!($($tt)*)))
74 | };
75 | }
76 |
77 | #[macro_export]
78 | macro_rules! ensure_user {
79 | ($cond:expr, $($tt:tt)*) => {
80 | if !$cond {
81 | return Err($crate::UserError::anyhow(format!($($tt)*)))
82 | }
83 | };
84 | }
85 |
86 | pub fn is_hex_string(s: &str) -> bool {
87 | s.chars().all(|c| c.is_digit(16))
88 | }
89 |
90 | pub fn valid_module_or_tag(s: &str) -> bool {
91 | valid_module_id(s) || valid_tagname(s)
92 | }
93 |
94 | pub fn valid_module_id(s: &str) -> bool {
95 | s.len() == 64 && is_hex_string(s)
96 | }
97 |
98 | pub fn valid_tagname(s: &str) -> bool {
99 | match s.chars().next() {
100 | Some(c) if c.is_alphabetic() => {
101 | !valid_module_id(s)
102 | && s.chars().all(|c| {
103 | c == '_' || c == '-' || c == '.' || c.is_digit(10) || c.is_alphabetic()
104 | })
105 | }
106 | _ => false,
107 | }
108 | }
109 |
110 | fn set_priority(pri: ThreadPriority) {
111 | // this fails on WSL
112 | let _ = set_thread_priority_and_policy(
113 | thread_native_id(),
114 | pri,
115 | ThreadSchedulePolicy::Realtime(RealtimeThreadSchedulePolicy::Fifo),
116 | );
117 | }
118 |
119 | pub fn set_max_priority() {
120 | set_priority(ThreadPriority::Max);
121 | }
122 |
123 | pub fn set_min_priority() {
124 | set_priority(ThreadPriority::Min);
125 | }
126 |
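
A short, hedged usage sketch for the error helpers defined above, assuming `aicirt` is used as a library and `anyhow` is available (as it is in this workspace); the validation scenario itself is made up for illustration.

```rust
use aicirt::{bail_user, ensure_user, valid_module_or_tag, UserError};

/// Illustrative only: reject malformed controller references as user errors.
fn resolve_controller(spec: &str) -> anyhow::Result<String> {
    ensure_user!(!spec.is_empty(), "empty controller reference");
    if !valid_module_or_tag(spec) {
        bail_user!("not a module id or tag: {:?}", spec);
    }
    Ok(spec.to_string())
}

fn report(e: &anyhow::Error) {
    if UserError::is_self(e) {
        // user mistake: print just the message, no stack trace
        eprintln!("error: {}", UserError::maybe_stacktrace(e));
    } else {
        // internal error: keep the full anyhow context
        eprintln!("internal error: {:?}", e);
    }
}

fn main() {
    if let Err(e) = resolve_controller("not a valid tag!") {
        report(&e);
    }
}
```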
--------------------------------------------------------------------------------
/aicirt/src/macos.rs:
--------------------------------------------------------------------------------
1 | use anyhow::{anyhow, Result};
2 | use std::{sync::atomic::AtomicU32, time::Duration};
3 | use ulock_sys::{
4 | __ulock_wait, __ulock_wake, darwin19::UL_COMPARE_AND_WAIT_SHARED, ULF_NO_ERRNO, ULF_WAKE_ALL,
5 | };
6 |
7 | pub trait AsFutex {
8 | fn as_futex(&self) -> &Futex;
9 | }
10 |
11 | impl AsFutex for AtomicU32 {
12 | #[must_use]
13 | #[inline]
14 | fn as_futex(&self) -> &Futex {
15 | unsafe { std::mem::transmute(self) }
16 | }
17 | }
18 |
19 | #[repr(transparent)]
20 | pub struct Futex {
21 | pub value: AtomicU32,
22 | }
23 |
24 | impl Futex {
25 | fn wait_core(&self, expected_value: u32, micros: u32) -> Result<()> {
26 | let r = unsafe {
27 | __ulock_wait(
28 | UL_COMPARE_AND_WAIT_SHARED | ULF_NO_ERRNO,
29 | self.value.as_ptr() as *mut libc::c_void,
30 | expected_value as u64,
31 | micros,
32 | )
33 | };
34 |
35 | if r >= 0 {
36 | Ok(())
37 | } else {
38 | // TODO: can copy errors from https://github.com/ziglang/zig/blob/9e684e8d1af39904055abe64a9afda69a3d44a59/lib/std/Thread/Futex.zig#L192
39 | Err(anyhow!("__ulock_wait failed: {}", r))
40 | }
41 | }
42 |
43 | /// Wait until this futex is awoken by a `wake` call.
44 | /// The thread will only be sent to sleep if the futex's value matches the
45 | /// expected value.
46 | pub fn wait(&self, expected_value: u32) -> Result<()> {
47 | self.wait_core(expected_value, 0)
48 | }
49 |
50 | /// Wait until this futex is awoken by a `wake` call, or until the timeout expires.
51 | /// The thread will only be sent to sleep if the futex's value matches the
52 | /// expected value.
53 | pub fn wait_for(&self, expected_value: u32, timeout: Duration) -> Result<()> {
54 | if timeout >= Duration::from_micros(u32::MAX as u64) {
55 | self.wait_core(expected_value, 0)
56 | } else {
57 | self.wait_core(expected_value, timeout.as_micros() as u32)
58 | }
59 | }
60 |
61 | /// Wake up `n` waiters.
62 | pub fn wake(&self, _n: i32) -> i32 {
63 | loop {
64 | let r = unsafe {
65 | __ulock_wake(
66 | UL_COMPARE_AND_WAIT_SHARED | ULF_NO_ERRNO | ULF_WAKE_ALL,
67 | self.value.as_ptr() as *mut libc::c_void,
68 | 0,
69 | )
70 | };
71 | if r == -libc::ENOENT {
72 | return 0;
73 | }
74 | if r >= 0 {
75 | return 1;
76 | }
77 | }
78 | }
79 | }
80 |
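
An illustrative sketch of how this futex shim can be used (macOS only). The `macos` module is crate-private, so code like this would live inside `aicirt` itself; it assumes `use crate::macos::AsFutex;` is in scope.

```rust
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

/// Standard futex pattern: re-check the value around every wait, since wait()
/// returns as soon as the value no longer matches the expected one.
fn wait_until_nonzero(flag: &AtomicU32) {
    while flag.load(Ordering::Acquire) == 0 {
        let _ = flag.as_futex().wait(0);
    }
}

fn main() {
    let flag = Arc::new(AtomicU32::new(0));
    let waiter = {
        let flag = flag.clone();
        std::thread::spawn(move || wait_until_nonzero(&flag))
    };
    flag.store(1, Ordering::Release);
    flag.as_futex().wake(1); // wake any sleeping waiter
    waiter.join().unwrap();
}
```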
--------------------------------------------------------------------------------
/aicirt/src/msgchannel.rs:
--------------------------------------------------------------------------------
1 | use crate::{semaphore::Semaphore, shm::{Shm, Unlink}};
2 | use anyhow::Result;
3 | use std::time::Duration;
4 |
5 | pub struct MessageChannel {
6 | shm: Shm,
7 | write_sem: Semaphore,
8 | read_sem: Semaphore,
9 | }
10 |
11 | unsafe impl Send for MessageChannel {}
12 |
13 | impl MessageChannel {
14 | pub fn shm_name(name: &str) -> String {
15 | format!("{0}-shm", name)
16 | }
17 |
18 | pub fn new_cmd(name: &str, size: usize) -> Result<Self> {
19 | Self::new_ex(name, size, true)
20 | }
21 |
22 | pub fn new(name: &str, size: usize) -> Result<Self> {
23 | Self::new_ex(name, size, false)
24 | }
25 |
26 | fn new_ex(name: &str, size: usize, unlink: bool) -> Result<Self> {
27 | log::debug!("msg ch: {} size={}k", name, size / 1024);
28 |
29 | let shm = Shm::new(
30 | &Self::shm_name(name),
31 | size,
32 | if unlink { Unlink::Pre } else { Unlink::None },
33 | )?;
34 | let write_sem = Semaphore::new(&format!("{0}-wr", name), 1, unlink)?;
35 | let read_sem = Semaphore::new(&format!("{0}-rd", name), 0, unlink)?;
36 |
37 | Ok(Self {
38 | shm,
39 | write_sem,
40 | read_sem,
41 | })
42 | }
43 |
44 | pub fn send(&self, msg: &[u8]) -> Result<()> {
45 | self.shm.fits_msg(msg)?;
46 | self.write_sem.wait()?;
47 | self.shm.write_msg(msg).unwrap();
48 | self.read_sem.post()?;
49 | Ok(())
50 | }
51 |
52 | #[allow(dead_code)]
53 | pub fn busy_reset(&self) {
54 | unsafe { std::ptr::write_volatile(self.shm.ptr_at(0) as *mut u32, 0) };
55 | }
56 |
57 | #[allow(dead_code)]
58 | pub fn busy_send(&self, msg: &[u8]) -> Result<()> {
59 | self.shm.fits_msg(msg)?;
60 | loop {
61 | let len = self.shm.read_len()?;
62 | if len != 0 {
63 | std::hint::spin_loop();
64 | continue;
65 | }
66 | return Ok(self.shm.write_msg(msg).unwrap());
67 | }
68 | }
69 |
70 | #[allow(dead_code)]
71 | pub fn busy_recv(&self) -> Result<Vec<u8>> {
72 | loop {
73 | let len = self.shm.read_len()?;
74 | if len == 0 {
75 | std::hint::spin_loop();
76 | continue;
77 | }
78 | let res = self.shm.read_msg();
79 | return res;
80 | }
81 | }
82 |
83 | pub fn recv(&self, busy_wait_duration: &Duration) -> Result<Vec<u8>> {
84 | self.read_sem.busy_wait(busy_wait_duration)?;
85 | let res = self.shm.read_msg();
86 | self.write_sem.post()?;
87 | res
88 | }
89 | }
90 |
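
A hedged usage sketch for `MessageChannel` (not part of this file): one end creates the channel with `new_cmd`, the other attaches with `new` under the same name, and a single message passes through shared memory. The channel name and sizes are arbitrary, and `anyhow` is assumed to be available.

```rust
use aicirt::msgchannel::MessageChannel;
use std::time::Duration;

fn main() -> anyhow::Result<()> {
    // The "command" side creates (and pre-unlinks) the shared objects...
    let server = MessageChannel::new_cmd("/aici-demo", 64 * 1024)?;
    // ...and the other side attaches to the same name afterwards.
    let client = MessageChannel::new("/aici-demo", 64 * 1024)?;

    let handle = std::thread::spawn(move || -> anyhow::Result<Vec<u8>> {
        // Busy-wait briefly on the read semaphore, then fall back to blocking.
        server.recv(&Duration::from_micros(200))
    });

    client.send(b"ping")?;
    let msg = handle.join().unwrap()?;
    assert_eq!(msg, b"ping".to_vec());
    Ok(())
}
```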
--------------------------------------------------------------------------------
/aicirt/src/semaphore.rs:
--------------------------------------------------------------------------------
1 | use anyhow::Result;
2 | use std::{
3 | ffi::CString,
4 | io,
5 | time::{Duration, Instant},
6 | };
7 |
8 | pub struct Semaphore {
9 | sem: *mut libc::sem_t,
10 | }
11 |
12 | impl Semaphore {
13 | fn last_error<T>() -> Result<T> {
14 | Err(io::Error::last_os_error().into())
15 | }
16 |
17 | pub fn new(name: &str, initial_value: u32, unlink: bool) -> Result<Self> {
18 | log::trace!("sem_open: {}", name);
19 | let c_name = CString::new(name).unwrap();
20 | if unlink {
21 | unsafe {
22 | libc::sem_unlink(c_name.as_ptr());
23 | };
24 | }
25 | let sem = unsafe { libc::sem_open(c_name.as_ptr(), libc::O_CREAT, 0o666, initial_value) };
26 |
27 | if sem.is_null() {
28 | return Self::last_error();
29 | }
30 |
31 | Ok(Self { sem })
32 | }
33 |
34 | pub fn wait(&self) -> Result<()> {
35 | let ret = unsafe { libc::sem_wait(self.sem) };
36 | if ret < 0 {
37 | return Self::last_error();
38 | }
39 | Ok(())
40 | }
41 |
42 | pub fn busy_wait(&self, wait_duration: &Duration) -> Result<()> {
43 | let deadline = Instant::now() + *wait_duration;
44 | loop {
45 | let ret = unsafe { libc::sem_trywait(self.sem) };
46 | if ret < 0 {
47 | #[cfg(target_os = "linux")]
48 | let last_error = unsafe { *libc::__errno_location() };
49 | #[cfg(not(target_os = "linux"))]
50 | let last_error = unsafe { *libc::__error() };
51 | if last_error == libc::EAGAIN {
52 | if Instant::now() > deadline {
53 | return self.wait();
54 | } else {
55 | // std::hint::spin_loop();
56 | continue;
57 | }
58 | } else {
59 | return Self::last_error();
60 | }
61 | } else {
62 | return Ok(());
63 | }
64 | }
65 | }
66 |
67 | pub fn post(&self) -> Result<()> {
68 | let ret = unsafe { libc::sem_post(self.sem) };
69 | if ret < 0 {
70 | return Self::last_error();
71 | }
72 | Ok(())
73 | }
74 | }
75 |
76 | impl Drop for Semaphore {
77 | fn drop(&mut self) {
78 | unsafe {
79 | libc::sem_close(self.sem);
80 | }
81 | }
82 | }
83 |
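
A tiny, hedged illustration of the semaphore API above; the name and timing are made up. In the runtime, semaphores like these guard the shared-memory message slot used by `MessageChannel`.

```rust
use aicirt::semaphore::Semaphore;
use std::time::Duration;

fn main() -> anyhow::Result<()> {
    // unlink=true: start from a fresh named semaphore with initial value 0
    let ready = Semaphore::new("/aici-demo-ready", 0, true)?;
    ready.post()?; // producer side: signal readiness
    // consumer side: spin for up to 100us, then fall back to a blocking sem_wait
    ready.busy_wait(&Duration::from_micros(100))?;
    Ok(())
}
```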
--------------------------------------------------------------------------------
/aicirt/vllm.md:
--------------------------------------------------------------------------------
1 | # Notes on integration with vLLM
2 |
3 | The following are the callbacks in the vLLM flow:
4 |
5 | ```python
6 | LLMEngine.step()
7 | Scheduler.schedule()
8 | scheduler_outputs = Scheduler._schedule()
9 | SamplingParams.initiate_step(scheduler, llm_engine.counter, scheduler_outputs)
10 | return SequenceGroupMetadata(scheduler_outputs)
11 | samples = LLMEngine._run_workers("execute_model")
12 | Worker.execute_model
13 | Worker._prepare_inputs
14 | # ...
15 | SamplingParams.recv_attention_mask()
16 | # ...
17 | self.model()
18 | # ...
19 | Sampler.forward()
20 | logits = ...
21 | SamplingParams.apply_dynamic_logit_bias(logits)
22 | return _sample(logits) : SequenceOutputs
23 | return LLMEngine._process_model_outputs(samples)
24 | LLMEngine._process_sequence_group_samples()
25 | SamplingParams.append_ff_tokens(seq_group)
26 | # free and fork sequences as needed
27 | SamplingParams.finish_sampling()
28 | json_output = ...
29 | return json_output
30 | ```
31 |
32 | Thoughts:
33 | - expose Scheduler._schedule() and call it from LLMEngine; move initiate_step to LLMEngine
34 | - return logits from Sampler.forward() and call _sample() from LLMEngine; move apply_dynamic_logit_bias to LLMEngine
35 | - pass attn_mask to execute model from LLMEngine
36 |
37 | - vllm forks sequences in _process_sequence_group_samples(); this means fork processing in AICI is done
38 | in pre_process(), not process(), so it blocks; in a full AICI environment you would only fork from the AICI module,
39 | not via the n= parameter
--------------------------------------------------------------------------------
/controllers/aici_abi/.cargo/config.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | target = "wasm32-wasi"
3 |
4 | [profile.dev]
5 | strip = "debuginfo"
6 |
7 | [profile.release]
8 | strip = "debuginfo"
9 |
--------------------------------------------------------------------------------
/controllers/aici_abi/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "aici_abi"
3 | version = "0.1.0"
4 | edition = "2021"
5 | rust-version = "1.75.0"
6 |
7 | [lib]
8 | name = "aici_abi"
9 |
10 | [dependencies]
11 | toktrie = { path = "../toktrie/core" }
12 | serde = { version = "1.0.192", features = ["derive"] }
13 | serde_json = "1.0.108"
14 | anyhow = "1.0.75"
15 | regex-automata = { version = "0.4.6", default-features = false, features = ["std", "dfa", "syntax", "perf", "meta"], optional = true }
16 | cfgrammar = { version = "0.13.3", optional = true }
17 | lrtable = { version = "0.13.3", optional = true }
18 | vob = { version = "3.0.3", optional = true }
19 | rustc-hash = { version = "2.0.0", optional = true }
20 | bytemuck = "1.16.0"
21 | bytemuck_derive = "1.6.0"
22 |
23 | [features]
24 | default = ["cfg", "rx"]
25 | cfg = ["dep:cfgrammar", "dep:lrtable", "dep:vob", "dep:rustc-hash"]
26 | rx = ["dep:regex-automata"]
27 |
28 | [[bin]]
29 | name = "yesno"
30 | path = "src/yesno.rs"
31 |
--------------------------------------------------------------------------------
/controllers/aici_abi/grammars/json0.guidance:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/controllers/aici_abi/grammars/json0.guidance
--------------------------------------------------------------------------------
/controllers/aici_abi/implementation.md:
--------------------------------------------------------------------------------
1 | # Implementation notes
2 |
3 | ## LR(1) parsing
4 |
5 | The LR(1) parsing consists of DFA-based lexer and the actual LR(1) parser.
6 | DFA has a single number as the state, while the state of the LR(1) is a stack of numbers.
7 | The LR(1) action is determined based on the next token from the lexer and the top of the stack.
8 |
9 | The `Recognizer` interface also has a concept of stack, however every entry on that
10 | stack contains a DFA state and an LR(1) stack.
11 |
12 | Most of the time (~98.5% for the C grammar), pushing a byte involves only updating the DFA state,
13 | while the LR(1) stack is copied unchanged (the memory is shared).
14 |
15 |
16 | ### Early error detection
17 |
18 | Consider the following invalid C program:
19 |
20 | ```c
21 | int 123456;
22 | ```
23 |
24 | The lexer would produce `int` keyword, whitespace, `123456` constant and `;` keyword.
25 | The parser would reject `123456`, however only after all six characters of it have been read.
26 | This is too late for the LLM.
27 |
28 | To detect such errors early, we compute a set of reachable tokens for each DFA state.
29 | For example, consider a DFA that recognizes `int`, `if`, `ID` (`/[a-z][a-z0-9]*/`) and `INTLIT` (`/[0-9]+/`).
30 | The initial DFA state has a full set of tokens, while a state after `'i'`
31 | has only `int`, `if`, and `ID`,
32 | and a state after `'1'` includes only `INTLIT`.
33 | In the picture below, each state is labelled by its reachable set,
34 | and the token for which it is a match (if any) is postfixed with `*`. We only use lower-case letters and digits for simplicity.
35 |
36 | ```mermaid
37 | graph LR
38 | 0["{int,if,ID,INTLIT}"] -- "[i]" --> i(("{int,if,ID*}"))
39 | 0 -- "[a-z] - [i]" --> id(("{ID*}"))
40 | 0 -- "[0-9]" --> const(("{INTLIT*}"))
41 | const -- "[0-9]" --> const
42 | const -- "[a-z]" --> bot["{}"]
43 | i -- "[a-z0-9] - [nf]" --> id
44 | id -- "[a-z0-9]" --> id
45 | i -- "[n]" --> in(("{int,ID*}"))
46 | in -- "[t]" --> int(("{int*,ID}"))
47 | in -- "[a-z0-9] - [t]" --> id
48 | int -- "[a-z0-9]" --> id
49 | i -- "[f]" --> if(("{if*,ID}"))
50 | if -- "[a-z0-9]" --> id
51 | ```
52 |
53 | For each LR(1) automaton state we compute a set of viable tokens, i.e., ones that do
54 | not immediately lead to an error.
55 |
56 | While parsing input, if the intersection of viable and reachable tokens is empty, we report an error.
57 |
58 | In the example above, the viable tokens after `int` do not include `INTLIT`,
59 | and thus the parser fails immediately at `1`.
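
In code, this check is a per-byte set intersection. A simplified sketch, with token
sets as plain bit masks rather than the bit vectors used in the real implementation:

```rust
// Illustrative only: token sets as 64-bit masks, one bit per lexer token.
type TokenSet = u64;

const ID: TokenSet = 1 << 0;
const INTLIT: TokenSet = 1 << 1;

// A byte is allowed only if some token is both still reachable by the lexer
// (from the DFA state after the byte) and viable in the current LR(1) state.
fn byte_allowed(reachable_after_byte: TokenSet, viable_in_lr_state: TokenSet) -> bool {
    reachable_after_byte & viable_in_lr_state != 0
}

fn main() {
    let reachable_after_digit = INTLIT; // the DFA state after '1' can only yield INTLIT
    let viable_after_int = ID;          // right after `int`, INTLIT is not viable
    // so `int 1...` is rejected as soon as '1' is pushed:
    assert!(!byte_allowed(reachable_after_digit, viable_after_int));
}
```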
60 |
61 |
--------------------------------------------------------------------------------
/controllers/aici_abi/src/yesno.rs:
--------------------------------------------------------------------------------
1 | use aici_abi::{host_trie, tokenize, toktrie::TokTrie, AiciCtrl, MidProcessArg, MidProcessResult, TokenId};
2 |
3 | pub struct Runner {
4 | toktrie: TokTrie,
5 | tokens: Vec<TokenId>,
6 | yes: TokenId,
7 | no: TokenId,
8 | }
9 |
10 | impl Runner {
11 | pub fn new() -> Self {
12 | let yes = tokenize("Yes")[0];
13 | let no = tokenize("No")[0];
14 | // ignore user-passed arg
15 | Runner {
16 | toktrie: host_trie(),
17 | tokens: Vec::new(),
18 | yes,
19 | no,
20 | }
21 | }
22 | }
23 |
24 | impl AiciCtrl for Runner {
25 | fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
26 | arg.save_tokens(&mut self.tokens);
27 | if self.tokens.len() >= 1 {
28 | // we only want the first token
29 | MidProcessResult::stop()
30 | } else {
31 | let mut set = self.toktrie.alloc_token_set();
32 | set.allow_token(self.yes);
33 | set.allow_token(self.no);
34 | MidProcessResult::sample(set)
35 | }
36 | }
37 | }
38 |
39 | fn main() {
40 | // test code here?
41 | }
42 |
43 | aici_abi::aici_expose_all!(Runner, Runner::new());
44 |
--------------------------------------------------------------------------------
/controllers/aici_native/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "aici_native"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [lib]
7 | name = "aici_native"
8 |
9 | [dependencies]
10 | aici_abi = { path = "../aici_abi" }
11 | toktrie_hf_tokenizers = { path = "../toktrie/hf_tokenizers" }
12 | serde = { version = "1.0.192", features = ["derive"] }
13 | serde_json = "1.0.108"
14 | anyhow = "1.0.75"
15 | rustc-hash = "2.0.0"
16 | tokenizers = { version = "0.15.0", features = ["http"] }
17 | log = "0.4.21"
18 | flexi_logger = "0.28.0"
19 |
--------------------------------------------------------------------------------
/controllers/aici_native/README.md:
--------------------------------------------------------------------------------
1 | # AICI native
2 |
3 | Utilities for building native (non-Wasm) AICI Controllers.
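
The crate currently bundles logging setup (`log.rs`), a simple versioned
variable store (`variables.rs`), and tokenizer-related helpers (`bintokens.rs`).
A minimal usage sketch (field names follow `variables.rs`; the exact payload
types are assumed here):

```rust
use aici_abi::{StorageCmd, StorageOp, StorageResp};
use aici_native::{init_log, variables::Variables, LogMode};

fn main() -> anyhow::Result<()> {
    // Set up human-readable logging on stdout.
    init_log(LogMode::Normal)?;

    // Unconditionally set a variable; versions start at 1.
    let mut vars = Variables::default();
    match vars.process_cmd(StorageCmd::WriteVar {
        name: "greeting".to_string(),
        value: b"hello".to_vec(),
        when_version_is: None,
        op: StorageOp::Set,
    }) {
        StorageResp::WriteVar { version } => log::info!("stored at version {version}"),
        _ => log::warn!("unexpected response"),
    }
    Ok(())
}
```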
4 |
--------------------------------------------------------------------------------
/controllers/aici_native/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod bintokens;
2 | mod log;
3 | pub mod variables;
4 |
5 | pub use log::*;
6 |
7 | pub use rustc_hash::FxHashMap as HashMap;
8 | pub use rustc_hash::FxHashSet as HashSet;
9 |
--------------------------------------------------------------------------------
/controllers/aici_native/src/log.rs:
--------------------------------------------------------------------------------
1 | use std::fmt::Write;
2 |
3 | use anyhow::Result;
4 | use flexi_logger::style;
5 | use flexi_logger::{DeferredNow, Logger, WriteMode};
6 | use log::Record;
7 |
8 | pub enum LogMode {
9 | Normal,
10 | Test,
11 | Daemon,
12 | }
13 |
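// Formatting sink that caps output at `limit` bytes; once the limit is hit it
// appends " (...)" and returns an error so the formatter stops early.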
14 | struct LimitedWrite {
15 | limit: usize,
16 | dst: Vec<u8>,
17 | }
18 |
19 | impl Write for LimitedWrite {
20 | fn write_str(&mut self, s: &str) -> std::fmt::Result {
21 | if self.dst.len() > self.limit {
22 | return Err(std::fmt::Error);
23 | }
24 | if self.dst.len() + s.len() < self.limit {
25 | self.dst.extend_from_slice(s.as_bytes());
26 | Ok(())
27 | } else {
28 | let remaining = self.limit - self.dst.len();
29 | self.dst.extend_from_slice(&s.as_bytes()[..remaining]);
30 | self.dst.extend_from_slice(b" (...)");
31 | Err(std::fmt::Error)
32 | }
33 | }
34 | }
35 |
36 | fn args_to_str(limit: usize, args: &std::fmt::Arguments) -> String {
37 | // let capacity = args.estimated_capacity();
38 | let mut output = LimitedWrite {
39 | limit,
40 | dst: Vec::with_capacity(128),
41 | };
42 | if output.write_fmt(*args).is_err() {
43 | assert!(output.dst.len() > limit);
44 | }
45 | match String::from_utf8(output.dst) {
46 | Ok(s) => s,
47 | Err(err) => String::from_utf8_lossy(err.as_bytes()).to_string(),
48 | }
49 | }
50 |
51 | fn truncated_format(
52 | w: &mut dyn std::io::Write,
53 | _now: &mut DeferredNow,
54 | record: &Record,
55 | ) -> Result<(), std::io::Error> {
56 | let level = record.level();
57 | write!(
58 | w,
59 | "{} [{}] {}",
60 | style(level).paint(level.to_string()),
61 | record.module_path().unwrap_or(""),
62 | style(level).paint(args_to_str(1000, record.args()))
63 | )
64 | }
65 |
66 | fn daemon_format(
67 | w: &mut dyn std::io::Write,
68 | now: &mut DeferredNow,
69 | record: &Record,
70 | ) -> Result<(), std::io::Error> {
71 | write!(
72 | w,
73 | "{} {} [{}] {}",
74 | now.format("%Y-%m-%d %H:%M:%S%.3f"),
75 | record.level(),
76 | record.module_path().unwrap_or(""),
77 | args_to_str(5000, record.args())
78 | )
79 | }
80 |
81 | pub fn init_log(mode: LogMode) -> Result<()> {
82 | let logger = match mode {
83 | LogMode::Normal => Logger::try_with_env_or_str("info")?
84 | .format(truncated_format)
85 | .log_to_stdout(),
86 | LogMode::Test => {
87 | Logger::try_with_env_or_str("debug")?.write_mode(WriteMode::SupportCapture)
88 | }
89 | LogMode::Daemon => Logger::try_with_env_or_str("info")?
90 | .format(daemon_format)
91 | .log_to_stdout(),
92 | };
93 |
94 | logger.start()?;
95 | Ok(())
96 | }
97 |
98 | pub fn setup_log() {
99 | init_log(LogMode::Normal).expect("Failed to initialize log")
100 | }
101 |
--------------------------------------------------------------------------------
/controllers/aici_native/src/variables.rs:
--------------------------------------------------------------------------------
1 | use aici_abi::{StorageCmd, StorageOp, StorageResp};
2 | use rustc_hash::FxHashMap;
3 |
4 | #[derive(Default)]
5 | pub struct Variables {
6 | pub variables: FxHashMap<String, (u64, Vec<u8>)>,
7 | }
8 |
9 | impl Variables {
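/// Execute a storage command against the in-memory store. Writes are
/// versioned: `when_version_is` makes the write conditional (compare-and-swap
/// style); on a version mismatch the current value and version are returned
/// instead of writing, and `StorageOp::Append` extends the existing value.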
10 | pub fn process_cmd(&mut self, cmd: StorageCmd) -> StorageResp {
11 | match cmd {
12 | StorageCmd::ReadVar { name } => match self.variables.get(&name).map(|x| x.clone()) {
13 | None => StorageResp::VariableMissing {},
14 | Some((version, value)) => StorageResp::ReadVar { value, version },
15 | },
16 | StorageCmd::WriteVar {
17 | name,
18 | value,
19 | when_version_is,
20 | op,
21 | } => {
22 | let curr = self.variables.get(&name).map(|x| x.clone());
23 | match curr {
24 | Some((prev_version, prev_val)) => match when_version_is {
25 | Some(v) if v != prev_version => StorageResp::ReadVar {
26 | version: prev_version,
27 | value: prev_val,
28 | },
29 | _ => {
30 | let value = match op {
31 | StorageOp::Append => {
32 | let mut v = prev_val.clone();
33 | v.extend(value);
34 | v
35 | }
36 | StorageOp::Set => value,
37 | };
38 | let version = prev_version + 1;
39 | self.variables.insert(name, (version, value));
40 | StorageResp::WriteVar { version }
41 | }
42 | },
43 |
44 | None => match when_version_is {
45 | None => {
46 | self.variables.insert(name, (1, value));
47 | StorageResp::WriteVar { version: 1 }
48 | }
49 | Some(_) => StorageResp::VariableMissing {},
50 | },
51 | }
52 | }
53 | }
54 | }
55 | }
56 |
57 |
--------------------------------------------------------------------------------
/controllers/declctrl/.cargo/config.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | target = "wasm32-wasi"
3 |
4 | [profile.dev]
5 | strip = "debuginfo"
6 |
7 | [profile.release]
8 | strip = "debuginfo"
9 |
--------------------------------------------------------------------------------
/controllers/declctrl/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "aici_declctrl"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [dependencies]
7 | aici_abi = { path = "../aici_abi" }
8 | regex-automata = { version = "0.3.8", default-features = false, features = ["std", "dfa", "syntax", "perf", "meta"] }
9 | serde = { version = "1.0.192", features = ["derive"] }
10 | serde_json = "1.0.108"
11 | anyhow = "1.0.75"
12 |
13 | [[bin]]
14 | name = "aici_declctrl"
15 | path = "src/declctrl.rs"
16 |
--------------------------------------------------------------------------------
/controllers/declctrl/README.md:
--------------------------------------------------------------------------------
1 | # DeclCtrl
2 |
3 | The [DeclCtrl](src/declctrl.rs) exposes similar constraints
4 | to [PyCtrl](../pyctrl), but the gluing is done via a JSON AST (Abstract Syntax Tree) and thus is
5 | more restrictive.
6 |
7 | There is no reason to use it as is, but it can be used as a base for other controllers.
8 |
9 |
--------------------------------------------------------------------------------
/controllers/declctrl/arg2.json:
--------------------------------------------------------------------------------
1 | {
2 | "steps": [
3 | {
4 | "Fixed": {
5 | "text": {
6 | "String": {
7 | "str": "The word 'hello'"
8 | }
9 | }
10 | }
11 | },
12 | {
13 | "Fixed": {
14 | "text": {
15 | "String": {
16 | "str": " in French is translated as"
17 | }
18 | },
19 | "label": "lang"
20 | }
21 | },
22 | {
23 | "Gen": {
24 | "rx": " '[^']*'",
25 | "inner": [],
26 | "max_tokens": 15,
27 | "stmts": [
28 | {
29 | "Set": {
30 | "var": "french",
31 | "expr": {
32 | "Current": {}
33 | }
34 | }
35 | }
36 | ]
37 | }
38 | },
39 | {
40 | "Fixed": {
41 | "text": {
42 | "String": {
43 | "str": " or"
44 | }
45 | },
46 | "following": "lang"
47 | }
48 | },
49 | {
50 | "Gen": {
51 | "rx": " '[^']*'",
52 | "inner": [],
53 | "max_tokens": 15,
54 | "stmts": [
55 | {
56 | "Set": {
57 | "var": "blah",
58 | "expr": {
59 | "Current": {}
60 | }
61 | }
62 | }
63 | ]
64 | }
65 | },
66 | {
67 | "Fixed": {
68 | "text": {
69 | "Concat": {
70 | "parts": [
71 | {
72 | "String": {
73 | "str": "\nResults: "
74 | }
75 | },
76 | {
77 | "Var": {
78 | "var": "french"
79 | }
80 | },
81 | {
82 | "String": {
83 | "str": " "
84 | }
85 | },
86 | {
87 | "Var": {
88 | "var": "blah"
89 | }
90 | }
91 | ]
92 | }
93 | }
94 | }
95 | }
96 | ]
97 | }
--------------------------------------------------------------------------------
/controllers/declctrl/genarg.py:
--------------------------------------------------------------------------------
1 | # run as:
2 | # PYTHONPATH=.. python genarg.py
3 | import ujson
4 | import pyaici.ast_ as ast
5 |
6 | def main():
7 | aa = {
8 | "steps": [ast.fixed("Here's some JSON about J.R.Hacker from Seattle:\n")]
9 | + ast.json_to_steps(
10 | {
11 | "name": "",
12 | "valid": True,
13 | "description": "",
14 | "type": "foo|bar|baz|something|else",
15 | "address": {"street": "", "city": "", "state": "[A-Z][A-Z]"},
16 | "age": 1,
17 | "fraction": 1.5,
18 | }
19 | )
20 | }
21 |
22 | aa = {
23 | "steps": [
24 | ast.fixed("The word 'hello'"),
25 | ast.label("lang", ast.fixed(" in French is translated as")),
26 | ast.gen(rx=r" '[^']*'", max_tokens=15, set_var="french"),
27 | ast.fixed(" or", following="lang"),
28 | ast.gen(rx=r" '[^']*'", max_tokens=15, set_var="blah"),
29 | ast.fixed("\nResults: {{french}} {{blah}}", expand_vars=True),
30 | ]
31 | }
32 |
33 | ast.clear_none(aa)
34 | print(ujson.dumps(aa))
35 |
36 | main()
--------------------------------------------------------------------------------
/controllers/declctrl/native.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | T=llama
4 | T=gpt4
5 |
6 | set -x
7 | set -e
8 | if test -f tokenizer.bin ; then
9 | echo "Skipping tokenizer"
10 | else
11 | (cd ../aicirt && cargo run --release -- --tokenizer $T --save-tokenizer ../declctrl/tokenizer.bin)
12 | fi
13 | cargo build --release
14 | if [ `uname` = Linux ] ; then
15 | perf stat ./target/release/declctrl
16 | else
17 | ./target/release/declctrl
18 | fi
19 |
--------------------------------------------------------------------------------
/controllers/declctrl/size.js:
--------------------------------------------------------------------------------
1 | const fs = require("fs");
2 |
3 | // run "cmd" and capture the output
4 | function run(cmd) {
5 | return new Promise((resolve, reject) => {
6 | require("child_process").exec(cmd, (error, stdout, stderr) => {
7 | if (error) {
8 | reject(error);
9 | } else {
10 | resolve(stdout.trim());
11 | }
12 | });
13 | });
14 | }
15 |
16 | function fmt(perc) {
17 | return perc.toFixed(2).padStart(6) + "%";
18 | }
19 |
20 | function gethd(o) {
21 | return (
22 | fmt(o.retained_size_percent) +
23 | " " +
24 | fmt(o.shallow_size_percent) +
25 | " " +
26 | o.name
27 | );
28 | }
29 |
30 | function cchildren(o) {
31 | const r = {};
32 | (o.children ?? []).forEach((c) => {
33 | r[gethd(c)] = cchildren(c);
34 | });
35 | return r;
36 | }
37 |
38 | async function main() {
39 | const o = JSON.parse(
40 | await run("twiggy dominators target/strip.wasm -f json")
41 | );
42 | o.root = cchildren({ name: "ROOT", children: o.items });
43 | delete o.items;
44 | fs.writeFileSync(
45 | "target/dominators.json",
46 | JSON.stringify(o, null, 2)
47 | );
48 | }
49 | main();
50 |
--------------------------------------------------------------------------------
/controllers/declctrl/wasm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -x
4 | set -e
5 | cargo build --release
6 | BIN=$(cd ../target; pwd)
7 | cp $BIN/wasm32-wasi/release/aici_declctrl.wasm $BIN/opt.wasm
8 | ls -l $BIN/opt.wasm
9 | if [ "X$1" = "Xbuild" ] ; then
10 | exit
11 | fi
12 | if [ "X$1" = "Xsize" ] ; then
13 | node size.js
14 | fx $BIN/dominators.json
15 | exit
16 | fi
17 |
18 | (cd ../aicirt; cargo build --release)
19 |
20 | mkdir -p tmp
21 | if [ "X$1" = "Xcache" ] ; then
22 | $BIN/release/aicirt --module $BIN/opt.wasm | tee tmp/runlog.txt
23 | exit
24 | fi
25 |
26 | PERF=
27 | if [ `uname` = Linux ] ; then
28 | PERF="perf stat"
29 | fi
30 | RUST_LOG=info $PERF $BIN/release/aicirt --tokenizer gpt4 --module $BIN/opt.wasm
31 | RUST_LOG=info $PERF $BIN/release/aicirt \
32 | --tokenizer gpt4 --module $BIN/opt.wasm --run | tee tmp/runlog.txt
33 | ls -l $BIN/opt.wasm
34 |
--------------------------------------------------------------------------------
/controllers/jsctrl/.cargo/config.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | target = "wasm32-wasi"
3 |
4 | [profile.dev]
5 | strip = "debuginfo"
6 |
7 | [profile.release]
8 | strip = "debuginfo"
9 |
--------------------------------------------------------------------------------
/controllers/jsctrl/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "aici_jsctrl"
3 | version = "0.1.0"
4 | edition = "2021"
5 | build = "build.rs"
6 |
7 | [dependencies]
8 | aici_abi = { path = "../aici_abi" }
9 | serde = { version = "1.0.192", features = ["derive"] }
10 | serde_json = "1.0.108"
11 | anyhow = "1.0.75"
12 | lazy_static = "1.4.0"
13 | rquickjs = { git = "https://github.com/DelSkayn/rquickjs", rev = "343b21b742d3bb052710dc53144b79dc61bb592d", features = ["array-buffer", "macro"] }
14 |
15 | [[bin]]
16 | name = "aici_jsctrl"
17 | path = "src/jsctrl.rs"
18 |
19 | [build-dependencies]
20 | glob = "0.3.1"
21 |
--------------------------------------------------------------------------------
/controllers/jsctrl/README.md:
--------------------------------------------------------------------------------
1 | # JsCtrl
2 |
3 | This crate implements the AI Controller Interface by embedding
4 | [QuickJS](https://bellard.org/quickjs/) (a JavaScript (ES2023) interpreter)
5 | via [rquickjs](https://github.com/DelSkayn/rquickjs)
6 | in a Wasm module together with native
7 | primitives for specific kinds of output constraints:
8 | fixed token output, regexps, LR(1) grammars, substring constraints, etc.
9 | JavaScript code is typically only used lightly, for gluing the primitives together,
10 | and thus is not performance critical.
11 |
12 | There are [some sample scripts](./samples/) available.
13 | The scripts use the [aici module](./samples/aici-types.d.ts) to communicate with the AICI runtime
14 | and use the native constraints.
15 |
16 | This is quite similar to [PyCtrl](../pyctrl/README.md) but with JavaScript instead of Python.
17 | It is also smaller, at 1.3MiB without regex and CFG, 1.8MiB with regex, and 3.3MiB with regex and CFG.
18 | For comparison, pyctrl is 14MiB.
19 | Also, the [PyCtrl samples](../pyctrl/samples/) translate 1:1 to JsCtrl.
20 |
21 | ## Usage
22 |
23 | To run a JsCtrl sample use:
24 |
25 | ```bash
26 | ../../aici.sh run samples/hello.js
27 | ```
28 |
29 | If you write your sample in TypeScript, compile it first with `tsc -p samples`.
30 |
31 | If you want to build the interpreter yourself, use:
32 |
33 | ```bash
34 | ../../aici.sh run --build . samples/hello.js
35 | ```
36 |
37 | You will see the console output of the program.
38 |
39 |
--------------------------------------------------------------------------------
/controllers/jsctrl/build.rs:
--------------------------------------------------------------------------------
1 | use std::process::Command;
2 |
3 | fn run(cmd: &mut Command, msg: &str) {
4 | let status = cmd.status().expect(&format!("failed to execute: {msg}"));
5 | if !status.success() {
6 | panic!("process exited with status: {}; {}", status, msg);
7 | }
8 | }
9 |
10 | fn rerun_if_glob(pattern: &str) {
11 | for entry in glob::glob(pattern).expect(pattern).flatten() {
12 | let display = entry.display();
13 | println!("cargo:rerun-if-changed={display}");
14 | }
15 | }
16 |
17 | fn main() {
18 | rerun_if_glob("ts/*.ts");
19 | rerun_if_glob("ts/*.json");
20 | rerun_if_glob("gen-dts.mjs");
21 |
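// Make sure the TypeScript compiler is available; if not, install it globally via npm.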
22 | if Command::new("tsc").arg("--version").status().is_err() {
23 | // println!("cargo:warning=typescript not found, installing...");
24 | run(
25 | Command::new("npm")
26 | .arg("install")
27 | .arg("-g")
28 | .arg("typescript"),
29 | "npm install failed",
30 | );
31 | }
32 |
33 | run(Command::new("tsc").arg("-p").arg("ts"), "build failed");
34 | run(Command::new("node").arg("gen-dts.mjs"), "gen-dts failed");
35 | }
36 |
--------------------------------------------------------------------------------
/controllers/jsctrl/gen-dts.mjs:
--------------------------------------------------------------------------------
1 | import * as fs from 'fs';
2 |
3 | function gen() {
4 | const ts = "./ts/"
5 | const native = fs.readFileSync(ts + '/native.d.ts', 'utf8')
6 | let aici = fs.readFileSync(ts + '/dist/aici.d.ts', 'utf8')
7 | aici = aici.replace(/ pre + aici + '"""')
16 |
17 | const tsconfig = fs.readFileSync("./samples/tsconfig.json", "utf-8")
18 | jssrc = jssrc.replace(/(tsconfig_json = r""")[^]*?"""/g, (_, pre) => pre + tsconfig + '"""')
19 |
20 | const hello = fs.readFileSync("./samples/hello.js", "utf-8")
21 | jssrc = jssrc.replace(/(hello_js = r""")[^]*?"""/g, (_, pre) => pre + hello + '"""')
22 |
23 | fs.writeFileSync("../../py/pyaici/jssrc.py", jssrc)
24 | }
25 |
26 | gen()
--------------------------------------------------------------------------------
/controllers/jsctrl/samples/count-yellow.js:
--------------------------------------------------------------------------------
1 | import { Label, setLogLevel } from "aici"
2 |
3 | const colors = [
4 | // "red",
5 | // "green",
6 | "blue",
7 | // "violet",
8 | // "white",
9 | // "black",
10 | ]
11 |
12 | function gencolors(k) {
13 | let s = []
14 |
15 | for (let i = 0; i < k; ++i) {
16 | s.push(colors[Math.floor(Math.random() * colors.length)])
17 | }
18 |
19 | return s.join(", ")
20 | }
21 |
22 | async function countYellow() {
23 | setLogLevel(10)
24 | const q = `Does the color yellow appear in the following list of colors?`
25 | await $`<|user|>\n${q}\n`
26 | const l = new Label()
27 | for (let i = 0; i < 100; ++i) {
28 | const hasYellow = Math.random() < 0.5
29 | const A = (10 + Math.random() * 100) | 0
30 | const B = (10 + Math.random() * 130) | 0
31 | const text = gencolors(A) + (hasYellow ? ", yellow, " : ", blue, ") + gencolors(B)
32 | await l.fixedAfter(`${text}<|end|>\n<|assistant|>`)
33 | const r = await gen({ maxTokens: 10, regex: /(Yes|No)/ })
34 | console.log(q)
35 | console.log(text)
36 | console.log(hasYellow ? "Yes" : "No", r)
37 | assert(r === "Yes" || r === "No")
38 | assert(r === (hasYellow ? "Yes" : "No"))
39 | await $`\n`
40 | }
41 | }
42 |
43 | start(countYellow)
44 |
--------------------------------------------------------------------------------
/controllers/jsctrl/samples/hello.js:
--------------------------------------------------------------------------------
1 | async function main() {
2 | await $`The ultimate answer to life, the universe and everything is `
3 | await gen({ regex: /\d\d/ })
4 | }
5 |
6 | start(main)
7 |
--------------------------------------------------------------------------------
/controllers/jsctrl/samples/hellots.ts:
--------------------------------------------------------------------------------
1 | async function main() {
2 | await $`Hello`
3 | await gen({ regex: / [A-Z]+/ })
4 | }
5 |
6 | start(main)
7 |
--------------------------------------------------------------------------------
/controllers/jsctrl/samples/mapping.js:
--------------------------------------------------------------------------------
1 | import { Label, getTokens, setLogLevel } from "aici"
2 |
3 | function randomInt(min, max) {
4 | return Math.floor(Math.random() * (max - min + 1) + min)
5 | }
6 |
7 | async function countYellow() {
8 | setLogLevel(0)
9 | const q = `Tell me the value of x42?`
10 | await $`<|user|>\n${q}\n`
11 | const l = new Label()
12 | let numok = 0
13 | for (let i = 0; i < 20; ++i) {
14 | let text = ""
15 | let x42 = randomInt(10, 99)
16 | for (let i = 10; i < 300; ++i) {
17 | text += `The value of x${i} is ${i == 42 ? x42 : randomInt(10, 99)}.\n`
18 | }
19 | await l.fixedAfter(`${text}\nTell me x42.<|end|>\n<|assistant|>The value of x42 is `)
20 | const r = await gen({ maxTokens: 10, regex: /\d\d/ })
21 | // console.log(q)
22 | // console.log(text)
23 | console.log(getTokens().length, x42, r, r === x42.toString())
24 | if (r === x42.toString()) {
25 | numok++
26 | }
27 | // assert(r === x42.toString())
28 | await $`\n`
29 | }
30 | console.log("numok", numok)
31 | }
32 |
33 | start(countYellow)
34 |
--------------------------------------------------------------------------------
/controllers/jsctrl/samples/schema.js:
--------------------------------------------------------------------------------
1 | // Doesn't seem to work too well...
2 |
3 | import { Label } from "aici"
4 |
5 | async function jsonString() {
6 | await gen({
7 | maxTokens: 50,
8 | regex: /"(\\(["\\\/bfnrt]|u[a-fA-F0-9]{4})|[^"\\\x00-\x1F\x7F]+)+"/
9 | })
10 | }
11 |
12 | async function jsonInt() {
13 | await gen({ regex: /\d+/ })
14 | }
15 |
16 | /**
17 | * @param {string} name
18 | */
19 | async function jsonField(name) {
20 | await $` "${name}": `
21 | }
22 |
23 | async function cityList() {
24 | const start = new Label()
25 | await $`[`
26 | const maxNodes = 3;
27 | for (let i = 0; i < maxNodes; i++) {
28 | await $`{\n`
29 | await jsonField("name")
30 | await jsonString()
31 | await $`,\n`
32 | await jsonField("population")
33 | await jsonInt()
34 | await $`,\n`
35 | await jsonField("url")
36 | await jsonString()
37 | await $`\n`
38 | const nextChar = await gen({ options: ['},\n', ']'] })
39 | if (nextChar === ']') {
40 | break
41 | }
42 | }
43 | console.log(start.textSince())
44 | }
45 |
46 | async function main() {
47 | await $`Here are JSON objects for five European cities:\n`
48 | await cityList()
49 | }
50 |
51 | start(main)
52 |
--------------------------------------------------------------------------------
/controllers/jsctrl/samples/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | /* Visit https://aka.ms/tsconfig to read more about this file */
4 | "target": "ES2020",
5 | "lib": [
6 | "ES2020"
7 | ],
8 | "moduleDetection": "force",
9 | "module": "ES2020",
10 | "allowJs": true,
11 | "checkJs": true,
12 | "strict": true,
13 | "noImplicitThis": true,
14 | "noImplicitReturns": true,
15 | "outDir": "./dist",
16 | "skipDefaultLibCheck": true,
17 | }
18 | }
--------------------------------------------------------------------------------
/controllers/llguidance_ctrl/.cargo/config.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | target = "wasm32-wasi"
3 |
4 | [profile.dev]
5 | strip = "debuginfo"
6 |
7 | [profile.release]
8 | strip = "debuginfo"
9 |
--------------------------------------------------------------------------------
/controllers/llguidance_ctrl/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "aici_llguidance_ctrl"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [dependencies]
7 | aici_abi = { path = "../aici_abi" }
8 | llguidance_parser = { path = "../llguidance/parser" }
9 | serde = { version = "1.0.192", features = ["derive"] }
10 | serde_json = "1.0.108"
11 | anyhow = "1.0.75"
12 |
13 | [[bin]]
14 | name = "aici_llguidance_ctrl"
15 | path = "src/gctrl.rs"
16 |
17 | [features]
18 | default = []
19 | logging = ["llguidance_parser/logging"]
--------------------------------------------------------------------------------
/controllers/llguidance_ctrl/go.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | F="$1"
4 | if [ -z "$F" ] ; then
5 | F=run_g.py
6 | fi
7 |
8 | set -x
9 | cd `dirname $0`
10 | HERE=`pwd`
11 | PYTHONPATH=$HERE/../../py:$HERE/../../py/guidance \
12 | python3 $HERE/$F
13 |
14 |
--------------------------------------------------------------------------------
/controllers/llguidance_ctrl/src/gctrl.rs:
--------------------------------------------------------------------------------
1 | use std::sync::Arc;
2 |
3 | use aici_abi::{
4 | arg_bytes, get_config,
5 | toktrie::{InferenceCapabilities, StepArg},
6 | AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult,
7 | };
8 | use serde::{Deserialize, Serialize};
9 |
10 | use llguidance_parser::{api::TopLevelGrammar, output::Reporter, Logger, TokenParser};
11 |
12 | const INFO: bool = true;
13 |
14 | macro_rules! infoln {
15 | ($($arg:tt)*) => {
16 | if INFO {
17 | println!($($arg)*);
18 | }
19 | };
20 | }
21 |
22 | pub struct Runner {
23 | tok_parser: TokenParser,
24 | reporter: Reporter,
25 | }
26 |
27 | #[derive(Serialize, Deserialize)]
28 | struct RunnerArg {
29 | grammar: TopLevelGrammar,
30 | }
31 |
32 | impl Runner {
33 | pub fn new() -> Self {
34 | infoln!("building runner...");
35 | let arg: RunnerArg = serde_json::from_slice(&arg_bytes()).expect("invalid JSON arg");
36 | let log_level = 2;
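// Advertise only the capabilities the host reports via get_config();
// forking is not supported by this controller.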
37 | let inf = InferenceCapabilities {
38 | backtrack: get_config("backtrack") != 0,
39 | ff_tokens: get_config("ff_tokens") != 0,
40 | conditional_ff_tokens: get_config("ff_tokens") != 0,
41 | fork: false,
42 | };
43 | let tok_parser = TokenParser::from_llguidance_json(
44 | Arc::new(aici_abi::WasmTokenizerEnv::default()),
45 | arg.grammar,
46 | Logger::new(0, log_level),
47 | inf,
48 | )
49 | .expect("invalid guidance protobuf");
50 |
51 | let reporter = Reporter::new(&tok_parser);
52 | Runner {
53 | tok_parser,
54 | reporter,
55 | }
56 | }
57 | }
58 |
59 | fn json_out<T: Serialize>(obj: &T) {
60 | println!("JSON-OUT: {}", serde_json::to_string(obj).unwrap());
61 | }
62 |
63 | impl AiciCtrl for Runner {
64 | fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
65 | InitPromptResult {
66 | prompt: self.tok_parser.process_prompt(arg.prompt),
67 | }
68 | }
69 | fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
70 | let r = self.tok_parser.mid_process(StepArg {
71 | backtrack: arg.backtrack,
72 | tokens: arg.tokens,
73 | sampled: arg.sampled,
74 | });
75 | for v in self.reporter.get_progress(&mut self.tok_parser, &r) {
76 | json_out(&v);
77 | }
78 | MidProcessResult::from_branch(r)
79 | }
80 | }
81 |
82 | fn main() {}
83 |
84 | aici_abi::aici_expose_all!(Runner, Runner::new());
85 |
--------------------------------------------------------------------------------
/controllers/llguidance_ctrl/text_req.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import os
3 |
4 | base = os.getenv("AICI_API_BASE", "http://localhost:4242/v1")
5 | url = base + '/completions'
6 |
7 | headers = {
8 | 'Content-Type': 'application/json',
9 | }
10 |
11 | data = {
12 | 'model': 'model',
13 | 'prompt': 'Once upon a time,',
14 | 'max_tokens': 5,
15 | 'temperature': 0,
16 | 'stream': True
17 | }
18 |
19 | # read tmp/prompt.txt
20 | with open('tmp/prompt.txt', 'r') as file:
21 | data['prompt'] = file.read()
22 |
23 | response = requests.post(url, headers=headers, json=data, stream=True)
24 |
25 | if response.status_code == 200:
26 | for line in response.iter_lines():
27 | if line:
28 | decoded_line = line.decode('utf-8')
29 | print(decoded_line)
30 | else:
31 | print(f"Request failed with status code {response.status_code}")
32 | print(response.text)
33 |
--------------------------------------------------------------------------------
/controllers/pyctrl/.cargo/config.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | target = "wasm32-wasi"
3 |
4 | [profile.dev]
5 | strip = "debuginfo"
6 |
7 | [profile.release]
8 | strip = "debuginfo"
9 |
--------------------------------------------------------------------------------
/controllers/pyctrl/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "aici_pyctrl"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [dependencies]
7 | aici_abi = { path = "../aici_abi" }
8 | serde = { version = "1.0.192", features = ["derive"] }
9 | serde_json = "1.0.108"
10 | anyhow = "1.0.75"
11 | rustpython-vm = { git = "https://github.com/RustPython/RustPython", rev = "317f44945420e", default-features = false, features = ["compiler"] }
12 | rustpython-derive = { git = "https://github.com/RustPython/RustPython", rev = "317f44945420e" }
13 | lazy_static = "1.4.0"
14 | num-traits = "0.2.17"
15 | crossbeam-utils = "0.8.16"
16 | once_cell = "1.18.0"
17 |
18 | [[bin]]
19 | name = "aici_pyctrl"
20 | path = "src/pyctrl.rs"
21 |
22 | [build-dependencies]
23 | glob = "0.3.1"
24 |
--------------------------------------------------------------------------------
/controllers/pyctrl/Lib/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 RustPython Team
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/controllers/pyctrl/Lib/collections/_defaultdict.py:
--------------------------------------------------------------------------------
1 | from reprlib import recursive_repr as _recursive_repr
2 |
3 | class defaultdict(dict):
4 | def __init__(self, *args, **kwargs):
5 | if len(args) >= 1:
6 | default_factory = args[0]
7 | if default_factory is not None and not callable(default_factory):
8 | raise TypeError("first argument must be callable or None")
9 | args = args[1:]
10 | else:
11 | default_factory = None
12 | super().__init__(*args, **kwargs)
13 | self.default_factory = default_factory
14 |
15 | def __missing__(self, key):
16 | if self.default_factory is not None:
17 | val = self.default_factory()
18 | else:
19 | raise KeyError(key)
20 | self[key] = val
21 | return val
22 |
23 | @_recursive_repr()
24 | def __repr_factory(factory):
25 | return repr(factory)
26 |
27 | def __repr__(self):
28 | return f"{type(self).__name__}({defaultdict.__repr_factory(self.default_factory)}, {dict.__repr__(self)})"
29 |
30 | def copy(self):
31 | return type(self)(self.default_factory, self)
32 |
33 | __copy__ = copy
34 |
35 | def __reduce__(self):
36 | if self.default_factory is not None:
37 | args = self.default_factory,
38 | else:
39 | args = ()
40 | return type(self), args, None, None, iter(self.items())
41 |
42 | def __or__(self, other):
43 | if not isinstance(other, dict):
44 | return NotImplemented
45 |
46 | new = defaultdict(self.default_factory, self)
47 | new.update(other)
48 | return new
49 |
50 | def __ror__(self, other):
51 | if not isinstance(other, dict):
52 | return NotImplemented
53 |
54 | new = defaultdict(self.default_factory, other)
55 | new.update(self)
56 | return new
57 |
58 | defaultdict.__module__ = 'collections'
59 |
--------------------------------------------------------------------------------
/controllers/pyctrl/Lib/collections/abc.py:
--------------------------------------------------------------------------------
1 | from _collections_abc import *
2 | from _collections_abc import __all__
3 | from _collections_abc import _CallableGenericAlias
4 |
--------------------------------------------------------------------------------
/controllers/pyctrl/Lib/copy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | for f in `find -name \*.py | xargs cat | grep -E '^(import |from \S+ import)' | awk '{print $2}' | sort | uniq` ; do
4 | test -f $f.py && continue
5 | test -f $f/__init__.py && continue
6 | if test -f ../../RustPython/pylib/Lib/$f.py ; then
7 | echo cp ../../RustPython/pylib/Lib/$f.py $f.py
8 | continue
9 | fi
10 | echo "? $f"
11 | done
12 |
--------------------------------------------------------------------------------
/controllers/pyctrl/Lib/keyword.py:
--------------------------------------------------------------------------------
1 | """Keywords (from "Grammar/python.gram")
2 |
3 | This file is automatically generated; please don't muck it up!
4 |
5 | To update the symbols in this file, 'cd' to the top directory of
6 | the python source tree and run:
7 |
8 | PYTHONPATH=Tools/peg_generator python3 -m pegen.keywordgen \
9 | Grammar/python.gram \
10 | Grammar/Tokens \
11 | Lib/keyword.py
12 |
13 | Alternatively, you can run 'make regen-keyword'.
14 | """
15 |
16 | __all__ = ["iskeyword", "issoftkeyword", "kwlist", "softkwlist"]
17 |
18 | kwlist = [
19 | 'False',
20 | 'None',
21 | 'True',
22 | 'and',
23 | 'as',
24 | 'assert',
25 | 'async',
26 | 'await',
27 | 'break',
28 | 'class',
29 | 'continue',
30 | 'def',
31 | 'del',
32 | 'elif',
33 | 'else',
34 | 'except',
35 | 'finally',
36 | 'for',
37 | 'from',
38 | 'global',
39 | 'if',
40 | 'import',
41 | 'in',
42 | 'is',
43 | 'lambda',
44 | 'nonlocal',
45 | 'not',
46 | 'or',
47 | 'pass',
48 | 'raise',
49 | 'return',
50 | 'try',
51 | 'while',
52 | 'with',
53 | 'yield'
54 | ]
55 |
56 | softkwlist = [
57 | '_',
58 | 'case',
59 | 'match'
60 | ]
61 |
62 | iskeyword = frozenset(kwlist).__contains__
63 | issoftkeyword = frozenset(softkwlist).__contains__
64 |
--------------------------------------------------------------------------------
/controllers/pyctrl/build.rs:
--------------------------------------------------------------------------------
1 | fn main() {
2 | for entry in glob::glob("Lib/**/*.py").expect("Lib/ exists?").flatten() {
3 | let display = entry.display();
4 | println!("cargo:rerun-if-changed={display}");
5 | }
6 | for entry in glob::glob("../../py/pyaici/server*.py")
7 | .expect("exists?")
8 | .flatten()
9 | {
10 | let display = entry.display();
11 | println!("cargo:rerun-if-changed={display}");
12 | }
13 | }
14 |
--------------------------------------------------------------------------------
/controllers/pyctrl/samples/forkbomb.py:
--------------------------------------------------------------------------------
1 | import pyaici.server as aici
2 |
3 |
4 | async def fork_bomb():
5 | await aici.FixedTokens("The value of")
6 | id = await aici.fork(20)
7 | await aici.FixedTokens(f" {id} is")
8 | await aici.gen_text(max_tokens=5, store_var=f"x{id}")
9 |
10 |
11 | async def deadlock():
12 | await aici.wait_vars("foo")
13 |
14 |
15 | async def burn_tokens():
16 | id = await aici.fork(10)
17 | await aici.FixedTokens(f"The value of {id} in the universe is and everything is ")
18 | await aici.gen_text(max_tokens=200, store_var=f"foo{id}")
19 |
20 | aici.test(burn_tokens())
21 |
--------------------------------------------------------------------------------
/controllers/pyctrl/samples/idents.py:
--------------------------------------------------------------------------------
1 | import pyaici.server as aici
2 | import re
3 |
4 | # asserts for microsoft/Orca-2-13b
5 |
6 | aici.log_level = 1
7 |
8 | async def test_id():
9 | await aici.FixedTokens("Here's a fib function\n```python\n")
10 |
11 | max_tokens = 60
12 | dyn_lex = aici.DynamicLexer("")
13 | for id in ["def", "fibo", "n", "return", "if"]:
14 | dyn_lex.add(id)
15 | next_token = aici.ConstrainedToken(lambda: dyn_lex.constraint())
16 | res = []
17 | text = ""
18 | for _ in range(max_tokens):
19 | tokens = await next_token
20 | if tokens:
21 | res += tokens
22 | print("GEN-STEP:", aici.tokens_repr(tokens))
23 | text = aici.detokenize(res).decode(errors="replace")
24 | if next_token.finished:
25 | break
26 | print("RESULT:", text)
27 |
28 |
29 | aici.test(test_id())
30 |
--------------------------------------------------------------------------------
/controllers/pyctrl/samples/phi.py:
--------------------------------------------------------------------------------
1 | import pyaici.server as aici
2 |
3 |
4 | async def test_42():
5 | await aici.FixedTokens(
6 | "The ultimate answer to life, the universe and everything is"
7 | )
8 | s = await aici.gen_text(regex=r" \d+\.", max_tokens=5, store_var="x")
9 | aici.check_vars({"x": " 42."})
10 |
11 |
12 | aici.test(test_42())
13 |
--------------------------------------------------------------------------------
/controllers/pyctrl/samples/substr.py:
--------------------------------------------------------------------------------
1 | import pyaici.server as aici
2 |
3 | earth = """
4 | Earth is rounded into an ellipsoid with a circumference of about 40,000 km. It is the densest planet in the Solar System. Of the four rocky planets, it is the largest and most massive. Earth is about eight light-minutes away from the Sun and orbits it, taking a year (about 365.25 days) to complete one revolution. Earth rotates around its own axis in slightly less than a day (in about 23 hours and 56 minutes). Earth's axis of rotation is tilted with respect to the perpendicular to its orbital plane around the Sun, producing seasons. Earth is orbited by one permanent natural satellite, the Moon, which orbits Earth at 384,400 km (1.28 light seconds) and is roughly a quarter as wide as Earth. Through tidal locking, the Moon always faces Earth with the same side, which causes tides, stabilizes Earth's axis, and gradually slows its rotation.
5 | """
6 |
7 | prompt = f"""[INST] Here's some text:
8 | {earth}
9 |
10 | Based on the text answer the question: Is Earth axis aligned with its orbit?
11 | Provide a quote from the text, prefixed by 'Source: "', to support your answer.
12 | [/INST]
13 | """
14 |
15 | async def test_substr():
16 | await aici.FixedTokens(prompt)
17 | await aici.gen_tokens(max_tokens=60, stop_at="Source: \"", store_var="answer")
18 | await aici.gen_tokens(substring=earth, substring_end="\"", max_tokens=60, store_var="source")
19 | # make sure we can continue generating afterwards
20 | await aici.FixedTokens("\nThe tilt is")
21 | await aici.gen_tokens(max_tokens=6, store_var="tilt")
22 |
23 | aici.test(test_substr())
--------------------------------------------------------------------------------
/controllers/pyctrl/samples/tla.py:
--------------------------------------------------------------------------------
1 | import pyaici.server as aici
2 |
3 | apalache = r"""
4 |
5 | %start tnl
6 | %%
7 |
8 | List
9 | : T
10 | | List ", " T
11 | ;
12 |
13 | tnl: T "/\n/";
14 |
15 | T
16 | // integers
17 | : "Int"
18 | // immutable constant strings
19 | | "Str"
20 | // boolean
21 | | "Bool"
22 | // functions
23 | | T " -> " T
24 | // sets
25 | | "Set" "(" T ")"
26 | // sequences
27 | | "Seq" "(" T ")"
28 | // tuples
29 | | "<<" List ">>"
30 | // parentheses, e.g., to change associativity of functions
31 | | "(" T ")"
32 | // operators
33 | | "(" List ") => " T
34 | ;
35 |
36 | %%
37 | """
38 |
39 |
40 | async def gen_and_test_grammar():
41 | aici.log_level = 3
42 | await aici.FixedTokens(
43 | """
44 | Here's a TLA+ spec:
45 |
46 | ---- MODULE Counter ----
47 | VARIABLE
48 | b
49 | q
50 |
51 | Init == b = TRUE
52 | Next == q = q + 1
53 | ====
54 |
55 | Now with added types:
56 |
57 | ---- MODULE Counter ----
58 | VARIABLE
59 | b: """
60 | )
61 | await aici.gen_tokens(yacc=apalache, max_tokens=10, store_var="b")
62 | await aici.FixedTokens(" q: ")
63 | await aici.gen_tokens(yacc=apalache, max_tokens=10, store_var="q")
64 |
65 |
66 | aici.test(gen_and_test_grammar())
--------------------------------------------------------------------------------
/controllers/pyctrl/samples/warsaw.py:
--------------------------------------------------------------------------------
1 | import pyaici.server as aici
2 |
3 | system_message = "You are a helpful assistant."
4 |
5 | async def fixed(user_message: str):
6 | prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n"
7 | prompt += f"<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant\n"
8 | await aici.FixedTokens(prompt)
9 |
10 | async def main():
11 | await fixed("What is the capital of Poland?")
12 | await aici.gen_tokens(max_tokens=10, store_var="capital")
13 |
14 | aici.start(main())
15 |
--------------------------------------------------------------------------------
/controllers/pyctrl/samples/yesno.py:
--------------------------------------------------------------------------------
1 | import pyaici.server as aici
2 |
3 | async def main():
4 | tokens = await aici.GetPrompt()
5 | assert len(tokens) > 2, "prompt too short"
6 | await aici.FixedTokens("\n")
7 | await aici.gen_tokens(options=["Yes", "No"])
8 |
9 | aici.start(main())
10 |
--------------------------------------------------------------------------------
/controllers/pyctrl/wasm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -x
4 | set -e
5 | cargo build --release
6 | BIN=$(cd ../target; pwd)
7 | cp $BIN/wasm32-wasi/release/aici_pyctrl.wasm $BIN/opt.wasm
8 | ls -l $BIN/opt.wasm
9 | if [ "X$1" = "Xbuild" ] ; then
10 | exit
11 | fi
12 | if [ "X$1" = "Xsize" ] ; then
13 | node size.js
14 | fx $BIN/dominators.json
15 | exit
16 | fi
17 |
18 | (cd ../aicirt; cargo build --release)
19 |
20 | mkdir -p tmp
21 | if [ "X$1" = "Xcache" ] ; then
22 | $BIN/release/aicirt --module $BIN/opt.wasm | tee tmp/runlog.txt
23 | exit
24 | fi
25 |
26 | PERF=
27 | if [ `uname` = Linux ] ; then
28 | _PERF="perf stat"
29 | fi
30 | RUST_LOG=info $PERF $BIN/release/aicirt --tokenizer gpt4 --module $BIN/opt.wasm
31 | RUST_LOG=info $PERF $BIN/release/aicirt \
32 | --tokenizer gpt4 --module $BIN/opt.wasm --run --run-arg ./samples/test.py | tee tmp/runlog.txt
33 | ls -l $BIN/opt.wasm
34 |
--------------------------------------------------------------------------------
/controllers/uppercase/.cargo/config.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | target = "wasm32-wasi"
3 |
4 | [profile.dev]
5 | strip = "debuginfo"
6 |
7 | [profile.release]
8 | strip = "debuginfo"
9 |
--------------------------------------------------------------------------------
/controllers/uppercase/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "aici_uppercase"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [dependencies]
7 | aici_abi = { path = "../aici_abi" }
8 | anyhow = "1.0.75"
9 |
--------------------------------------------------------------------------------
/controllers/uppercase/README.md:
--------------------------------------------------------------------------------
1 | # AICI Uppercase
2 |
3 | This folder provides a simple, hello-world-like AICI Controller
4 | that you can clone to build your own controller.
5 |
6 | When you clone it, make sure to keep [.cargo/config.toml](.cargo/config.toml)
7 | as it sets up the linker flags for the Wasm target.
8 |
9 | The [main.rs](src/main.rs) shows usage of the `FunctionalRecognizer` interface.
10 | It forces every 4th letter of the model output to be uppercase.
11 |
12 | ```
13 | $ ../../aici.sh run --build .
14 | will build aici_uppercase from /workspaces/aici/uppercase/Cargo.toml
15 | Compiling aici_abi v0.1.0 (/workspaces/aici/aici_abi)
16 | Compiling aici_uppercase v0.1.0 (/workspaces/aici/uppercase)
17 | Finished release [optimized + debuginfo] target(s) in 1.81s
18 | built: /workspaces/aici/target/wasm32-wasi/release/aici_uppercase.wasm, 0.189 MiB
19 | upload module... 193kB -> 675kB id:a4000d9b
20 | [DONE]
21 | [Response] Here's a tweet:
22 | I'm SO EXCITED! I'm GoinG toBe aMom!I'm GoinG toHaVeA BaBy!
23 | ```
24 |
25 | This is only meant as a sample - it could be done with [PyCtrl](../pyctrl) or
26 | [JsCtrl](../jsctrl) and a simple regex.
27 |
28 | ## Yes/no controller
29 |
30 | You can also take a look at the [yes/no controller](../aici_abi/src/yesno.rs), which
31 | only allows the model to say "Yes" or "No" in answer to the question in the prompt.
32 |
33 | ```
34 | $ ./scripts/sample-yesno.sh "Can orcas sing?"
35 | + echo 'Can orcas sing?'
36 | + ./aici.sh run --build aici_abi::yesno -
37 | will build yesno from /workspaces/aici/aici_abi/Cargo.toml
38 | Compiling aici_abi v0.1.0 (/workspaces/aici/aici_abi)
39 | Finished release [optimized + debuginfo] target(s) in 0.58s
40 | upload module... 192kB -> 671kB id:c65e78e9
41 | [DONE]
42 | [Response] Can orcas sing?
43 |
44 | Yes
45 | ```
46 |
47 | A similar effect can be achieved with PyCtrl and [10x fewer lines of code](../pyctrl/samples/yesno.py),
48 | but it illustrates the raw token APIs.
49 |
50 |
51 | ```
52 | $ ./aici.sh run pyctrl/samples/yesno.py
53 | Running with tagged AICI Controller: pyctrl-latest
54 | [0]: FIXED 'Are dolphins fish?\n'
55 | [0]: GEN 'No'
56 | [DONE]
57 | [Response] Are dolphins fish?
58 | No
59 | ```
60 |
61 | ## DeclCtrl
62 |
63 | For a more full-fledged example, take a look at the [DeclCtrl](../declctrl/src/declctrl.rs).
64 |
--------------------------------------------------------------------------------
/controllers/uppercase/src/main.rs:
--------------------------------------------------------------------------------
1 | use aici_abi::{
2 | host_trie,
3 | recognizer::{FunctionalRecognizer, StackRecognizer},
4 | tokenize,
5 | toktrie::{SpecialToken, TokTrie},
7 | AiciCtrl, InitPromptArg, InitPromptResult, MidProcessArg, MidProcessResult, TokenId,
7 | };
8 |
9 | // This constraint enforces an upper-case letter every 4th byte
10 | // The state is the position in the output stream
11 | struct QuadUpper {}
12 | impl FunctionalRecognizer<usize> for QuadUpper {
13 | fn initial(&self) -> usize {
14 | 0
15 | }
16 |
17 | fn try_append(&self, state: usize, byte: u8) -> Option<usize> {
18 | if state % 4 == 0 && !byte.is_ascii_uppercase() {
19 | None
20 | } else {
21 | Some(state + 1)
22 | }
23 | }
24 |
25 | fn special_allowed(&self, _state: usize, tok: SpecialToken) -> bool {
26 | match tok {
27 | SpecialToken::EndOfSentence => false,
28 | _ => false,
29 | }
30 | }
31 | }
32 |
33 | pub struct Runner {
34 | toktrie: TokTrie,
35 | tokens: Vec<TokenId>,
36 | recognizer: StackRecognizer<usize, QuadUpper>,
37 | }
38 |
39 | impl Runner {
40 | pub fn new() -> Self {
41 | Runner {
42 | toktrie: host_trie(),
43 | tokens: Vec::new(),
44 | recognizer: StackRecognizer::from(QuadUpper {}),
45 | }
46 | }
47 | }
48 |
49 | impl AiciCtrl for Runner {
50 | fn init_prompt(&mut self, arg: InitPromptArg) -> InitPromptResult {
51 | if arg.prompt.len() <= 1 {
52 | // in case no prompt was provided, invent some
53 | InitPromptResult {
54 | prompt: tokenize("Here's a tweet:\n"),
55 | }
56 | } else {
57 | InitPromptResult::from_arg(arg)
58 | }
59 | }
60 |
61 | fn mid_process(&mut self, arg: MidProcessArg) -> MidProcessResult {
62 | // store our tokens
63 | arg.save_tokens(&mut self.tokens);
64 | // and update the state of our recognizer
65 | self.toktrie
66 | .append_tokens(&mut self.recognizer, &arg.tokens)
67 | .unwrap();
68 |
69 | // stop after 50 tokens
70 | if self.tokens.len() > 50 || arg.has_eos() {
71 | return MidProcessResult::stop();
72 | }
73 |
74 | // otherwise, compute bias according to our recognizer
75 | let mut set = self.toktrie.alloc_token_set();
76 | self.toktrie.compute_bias(&mut self.recognizer, &mut set);
77 | MidProcessResult::sample(set)
78 | }
79 | }
80 |
81 | fn main() {
82 | // test code here?
83 | }
84 |
85 | aici_abi::aici_expose_all!(Runner, Runner::new());
86 |
--------------------------------------------------------------------------------
/docs/FAQ.md:
--------------------------------------------------------------------------------
1 | # Frequently Asked Questions
2 |
3 | ## How does system prompt or chat mode work with AICI?
4 |
5 | AICI interacts with models at the level of sequences of tokens.
6 | Models themselves do not have a distinct input for "system prompt" or "chat message",
7 | instead they are wrapped in model-specific tokens.
8 | You need to find the model's "Instruction format", typically on model's page on HuggingFace.
9 |
10 | For example, the [Orca-2-13b model](https://huggingface.co/microsoft/Orca-2-13b) has the following instruction format:
11 | ```
12 | <|im_start|>system
13 | {system_message}<|im_end|>
14 | <|im_start|>user
15 | {user_message}<|im_end|>
16 | <|im_start|>assistant
17 | ```
18 |
19 | The [Mistral-Instruct](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) and [Mixtral-Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1), as well as [CodeLlama-Instruct](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) models use:
20 | ```
21 | [INST]{instruction}[/INST]
22 | ```
23 |
24 | Interestingly, `<|im_start|>` and `<|im_end|>` are special tokens, while `[INST]` and `[/INST]` are regular strings.
25 |
26 | The start token (typically denoted `<s>`) is always implicit in AICI!
27 |
28 | For example, for Orca model you can use the following:
29 |
30 | ```python
31 | import pyaici.server as aici
32 |
33 | system_message = "You are a helpful assistant."
34 |
35 | async def ask(user_message: str):
36 | prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n"
37 | prompt += f"<|im_start|>user\n{user_message}<|im_end|>\n"
38 | prompt += "<|im_start|>assistant\n"
39 | await aici.FixedTokens(prompt)
40 |
41 | async def main():
42 | await ask("What is the capital of Poland?")
43 | await aici.gen_tokens(max_tokens=10, store_var="capital")
44 |
45 | aici.start(main())
46 | ```
--------------------------------------------------------------------------------
/docs/proxy.md:
--------------------------------------------------------------------------------
1 | # Client-side access to AICI
2 |
3 | The [Artificial Intelligence Controller Interface (AICI)](https://github.com/microsoft/aici)
4 | can be used to constrain output of an LLM in real time.
5 | While the GPU is working on the next token of the output, the AICI runtime can use the CPU to
6 | compute a user-provided constraint on the next token.
7 | This adds minimal latency to the LLM generation.
8 |
9 | ## Setup
10 |
11 | Install the `pyaici` package, export credentials, and see if the connection is working:
12 |
13 | ```bash
14 | pip uninstall pyaici
15 | pip install -e "git+https://github.com/microsoft/aici#egg=pyaici&subdirectory=py"
16 | export AICI_API_BASE="https://inference.example.com/v1/#key=wht_..."
17 | aici infer --max-tokens=10 "Answer to the Ultimate Question of Life, the Universe, and Everything is"
18 | ```
19 |
20 | To test out `pyctrl`, create an `answer.py` file with:
21 |
22 | ```python
23 | import pyaici.server as aici
24 |
25 | async def main():
26 | await aici.FixedTokens("The ultimate answer to the universe is ")
27 | await aici.gen_text(regex=r'\d\d', max_tokens=2)
28 |
29 | aici.start(main())
30 | ```
31 |
32 | You can run it with `aici run answer.py`. Try `aici run --help` for available options.
33 |
34 | You can use `aici --log-level=5 run answer.py` to see arguments to the REST requests,
35 | if you want to do them yourself.
36 |
--------------------------------------------------------------------------------
/py/promptlib/README.md:
--------------------------------------------------------------------------------
1 | Promptlib is a sample client-side Python library that interacts with the DeclCtrl AIController, using the pyaici package for communication.
2 | See the notebooks folder for examples of using promptlib.
3 |
--------------------------------------------------------------------------------
/py/promptlib/__init__.py:
--------------------------------------------------------------------------------
1 | from .promptlib import PromptNode, append, begin, end
2 | from .promptlib import PromptProgram
3 | from .promptlib import gen, choose, wait
4 | from .promptlib import AICI
5 |
--------------------------------------------------------------------------------
/py/promptlib/promptlib/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .prompt import PromptNode, append, begin, end, PromptProgram
3 | from .gen import gen, choose, wait
4 |
5 | from .aici import AICI
6 |
7 | setattr(PromptNode, "append", append)
8 | setattr(PromptNode, "begin", begin)
9 | setattr(PromptNode, "end", end)
10 | setattr(PromptNode, "gen", gen)
11 | setattr(PromptNode, "choose", choose)
12 | setattr(PromptNode, "wait", wait)
13 |
14 |
--------------------------------------------------------------------------------
/py/promptlib/promptlib/aici.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import requests
3 | import json
4 | import sys
5 | import os
6 | import re
7 | from typing import Optional
8 |
9 | import pyaici.rest as aici_rest
10 |
11 | class AICI:
12 | def __init__(self, base_url=None, wasm_runner_id=None, wasm_runner_path=None, wasm_runner_buildsh=None ):
13 | self.base_url = base_url
14 |
15 | if wasm_runner_id is None:
16 | if wasm_runner_path is None:
17 | if wasm_runner_buildsh is None:
18 | raise RuntimeError("Must specify wasm_runner_id or wasm_runner_path or wasm_runner_buildsh")
19 | wasm_runner_path = _compile_wasm(wasm_runner_buildsh)
20 | wasm_runner_id = _upload_wasm(self.base_url, wasm_runner_path)
21 | self.wasm_runner_id = wasm_runner_id
22 |
23 | def run(self, prompt_plan):
24 | return _submit_program(self.base_url, self.wasm_runner_id, prompt_plan, log=True)
25 |
26 |
27 | def _compile_wasm(wasm_runner_buildsh, scriptargs=["build"]):
28 | # separate wasm_runner_buildsh into the script filename and the directory
29 | # containing the script
30 | script_dir = os.path.dirname(wasm_runner_buildsh)
31 | script_name = os.path.basename(wasm_runner_buildsh)
32 |
33 | r = subprocess.run(["sh", script_name] + scriptargs, cwd=script_dir)
34 | if r.returncode != 0:
35 | raise RuntimeError(f"error compiling aici promptlib module")
36 |
37 | file_path = script_dir + "/target/strip.wasm"
38 | return file_path
39 |
40 |
41 | def _upload_wasm(base_url, wasm_runner_path):
42 | print("upload module... ", end="")
43 | with open(wasm_runner_path, "rb") as f:
44 | resp = requests.post(base_url + "aici_modules", data=f)
45 | if resp.status_code == 200:
46 | d = resp.json()
47 | dd = d["data"]
48 | mod_id = dd["module_id"]
49 | print(
50 | f"{dd['wasm_size']//1024}kB -> {dd['compiled_size']//1024}kB id:{mod_id[0:8]}"
51 | )
52 | return mod_id
53 | else:
54 | raise RuntimeError(
55 | f"bad response to model upload: {resp.status_code} {resp.reason}: {resp.text}"
56 | )
57 |
58 |
59 | def _submit_program(base_url, aici_module, aici_arg, temperature=0, max_tokens=None, log=False):
60 | return aici_rest.run_controller(controller=aici_module, controller_arg=aici_arg, temperature=temperature, max_tokens=max_tokens, base_url=base_url)
61 |
--------------------------------------------------------------------------------
/py/promptlib/promptlib/vars.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/py/promptlib/promptlib/vars.py
--------------------------------------------------------------------------------
/py/pyaici/__init__.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def runner_from_cli(args, dtype: str = 'f32'):
5 | from pyaici.comms import AiciRunner
6 |
7 | tokenizer = args.aici_tokenizer
8 |
9 | # when no explicit --aici-tokenizer, we look for:
10 | # --tokenizer + --tokenizer-revision
11 | # --model + --revision
12 | if not tokenizer:
13 | model_tokenizer = getattr(args, "tokenizer", None)
14 | if model_tokenizer:
15 | rev = getattr(args, "tokenizer_revision", None)
16 | if rev:
17 | model_tokenizer += f"@{rev}"
18 | tokenizer = model_tokenizer
19 | else:
20 | model = getattr(args, "model", None)
21 | if model:
22 | rev = getattr(args, "revision", None)
23 | if rev:
24 | model += f"@{rev}"
25 | tokenizer = model
26 |
27 | if not tokenizer:
28 | raise ValueError("No AICIrt tokenizer specified")
29 | if not args.aici_rt:
30 | raise ValueError("No AICIrt path specified")
31 |
32 | aici = AiciRunner(
33 | rtpath=args.aici_rt,
34 | tokenizer=tokenizer,
35 | trace_file=args.aici_trace,
36 | rtargs=args.aici_rtarg,
37 | pref=args.aici_shm_prefix,
38 | dtype=dtype,
39 | )
40 | return aici
41 |
42 |
43 | def add_cli_args(parser: argparse.ArgumentParser, single=False):
44 | parser.add_argument(
45 | "--aici-rt",
46 | type=str,
47 | required=True,
48 | help="path to aicirt",
49 | )
50 | parser.add_argument(
51 | "--aici-shm-prefix",
52 | type=str,
53 | default="/aici0-",
54 | help="prefix for shared memory communication channels",
55 | )
56 | parser.add_argument(
57 | "--aici-tokenizer",
58 | type=str,
59 | default="",
60 | help=
61 | "tokenizer to use; llama, phi, ...; can also use HF tokenizer name",
62 | )
63 | parser.add_argument(
64 | "--aici-trace",
65 | type=str,
66 | help="save trace of aicirt interaction to a JSONL file",
67 | )
68 | parser.add_argument(
69 | "--aici-rtarg",
70 | "-A",
71 | type=str,
72 | default=[],
73 | action="append",
74 | help="pass argument to aicirt process",
75 | )
76 |
77 | if single:
78 | parser.add_argument(
79 | "--controller",
80 | type=str,
81 | required=True,
82 | help="id of the module to run",
83 | )
84 | parser.add_argument(
85 | "--controller-arg",
86 | type=str,
87 | default="",
88 | help="arg passed to module (filename)",
89 | )
90 |
91 | return parser
92 |
--------------------------------------------------------------------------------
/py/pyaici/_vllm_protocol.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import List, Optional, Union
3 | from vllm.sampling_params import SamplingParams
4 | from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
5 |
6 | EPSILON_TEMP = 1.5e-5
7 |
8 | class RunRequest(BaseModel):
9 | model: Optional[str] = None
10 | prompt: Optional[str] = None
11 | messages: Optional[List[ChatCompletionMessageParam]] = None
12 | controller: str
13 | controller_arg: Union[str, dict]
14 | temperature: float = 0.0
15 | top_p: float = 1.0
16 | top_k: int = -1
17 | max_tokens: Optional[int] = None
18 |
19 | def to_sampling_params(self):
20 | r = SamplingParams(
21 | temperature=max(self.temperature, EPSILON_TEMP),
22 | top_p=self.top_p,
23 | top_k=self.top_k,
24 | max_tokens=self.max_tokens,
25 | ignore_eos=False,
26 | logprobs=10,
27 | )
28 | r.has_aici = True # type: ignore
29 | return r
30 |
31 |
32 | class SetTagsRequest(BaseModel):
33 | module_id: str
34 | tags: List[str]
--------------------------------------------------------------------------------
/py/pyaici/util.py:
--------------------------------------------------------------------------------
1 | llama_sys_prompt = """
2 | You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. You are concise.
3 | """
4 |
5 |
6 | def llama_prompt(prompt: str) -> str:
7 | return f"[INST] <>\n{llama_sys_prompt}\n<>\n\n [/INST]\n[INST] {prompt} [/INST]\n"
8 |
9 |
10 | def codellama_prompt(prompt: str) -> str:
11 | return f"[INST] {prompt} [/INST]\n"
12 |
13 |
14 | system_message = "You are a helpful assistant."
15 | orca_prefix = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n"
16 | orca_suffix = "<|im_end|>\n<|im_start|>assistant\n"
17 |
18 |
19 | def orca_prompt(prompt: str) -> str:
20 | return orca_prefix + prompt + orca_suffix
21 |
--------------------------------------------------------------------------------
/py/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name='pyaici',
5 | version='0.1',
6 | packages=['pyaici'],
7 | entry_points={
8 | 'console_scripts': [
9 | 'aici = pyaici.cli:main'
10 | ]
11 | },
12 | install_requires=[
13 | 'requests'
14 | ],
15 | )
16 |
--------------------------------------------------------------------------------
/py/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import subprocess
4 |
5 | prj_dir = os.path.dirname(os.path.abspath(__file__)) + "/../.."
6 | sys.path.append(prj_dir + "/py")
7 |
8 | ast_module_path = prj_dir + "/tmp/ast_module.txt"
9 |
10 |
11 | def upload_wasm():
12 | import pyaici.rest
13 | prog = prj_dir + "/controllers/declctrl"
14 | r = subprocess.run(
15 | ["cargo", "build", "--release"],
16 | cwd=prog,
17 | stdout=subprocess.DEVNULL,
18 | stderr=subprocess.DEVNULL,
19 | )
20 | if r.returncode != 0:
21 | sys.exit(1)
22 | file_path = prj_dir + "/target/wasm32-wasi/release/aici_declctrl.wasm"
23 | pyaici.rest.log_level = 0
24 | ast_module = pyaici.rest.upload_module(file_path)
25 |
26 | os.makedirs(prj_dir + "/tmp", exist_ok=True)
27 | with open(ast_module_path, "w") as f:
28 | f.write(ast_module)
29 |
30 |
31 | def pytest_configure(config):
32 | import pyaici.rest
33 |
34 | if not hasattr(config, "workerinput"):
35 | upload_wasm()
36 | with open(ast_module_path, "r") as f:
37 | pyaici.rest.ast_module = f.read()
38 |
--------------------------------------------------------------------------------
/py/tests/test-prompt.txt:
--------------------------------------------------------------------------------
1 | ```python
2 | def print_prime(n):
3 | """
4 | Print all primes between 1 and n
5 | """
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | ;[pytest]
2 | ;testpaths = py/tests
3 | ;addopts = -n 1
4 |
--------------------------------------------------------------------------------
/rllm/llama-cpp-low/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "llama_cpp_low"
3 | version = "0.0.1"
4 | edition = "2021"
5 |
6 | [dependencies]
7 | anyhow = "1.0.79"
8 | link-cplusplus = "1.0.9"
9 | log = "0.4.20"
10 | num_cpus = "1.16.0"
11 |
12 | [build-dependencies]
13 | bindgen = "0.69.2"
14 | cmake = "0.1.50"
15 |
16 | [features]
17 | default = []
18 | cuda = []
19 |
--------------------------------------------------------------------------------
/rllm/llama-cpp-low/src/main.rs:
--------------------------------------------------------------------------------
1 | fn main() {}
2 |
3 | #[cfg(feature = "disabled")]
4 | fn main() {
5 | use llama_cpp_low::*;
6 |
7 | let mparams = ModelParams::default();
8 | let mut cparams = ContextParams::default();
9 | cparams.n_ctx = 2048;
10 | let model = Model::from_file("tmp/llama-2-7b-chat.Q5_K_M.gguf", mparams, cparams);
11 | let mut batch = Batch::new(512);
12 | let s = model.new_sequence();
13 | for (idx, tok) in model
14 | .tokenize("Hello, my name is".as_bytes(), true, true)
15 | .iter()
16 | .enumerate()
17 | {
18 | batch.add_token(*tok, idx, &s, false)
19 | }
20 |
21 | let mut logit_idx = batch.len() - 1;
22 | let mut pos = batch.len();
23 |
24 | batch.enable_logits(logit_idx);
25 |
26 | for _ in 0..10 {
27 | model.decode(&mut batch).unwrap();
28 | let logits = model.get_logits(logit_idx);
29 | let top_idx = logits
30 | .iter()
31 | .enumerate()
32 | .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
33 | .map(|(index, _)| index)
34 | .unwrap();
35 | println!(
36 | "top_idx: {:?} {:?}",
37 | top_idx,
38 | String::from_utf8_lossy(&model.token_to_bytes(top_idx as u32))
39 | );
40 |
41 | logit_idx = 0;
42 | batch.clear();
43 | batch.add_token(top_idx as u32, pos, &s, true);
44 | pos += 1;
45 | }
46 | }
47 |
--------------------------------------------------------------------------------
/rllm/rllm-base/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "rllm"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [dependencies]
7 | anyhow = "1.0.75"
8 | clap = "4.4.8"
9 | hf-hub = "0.3.2"
10 | tokenizers = { version = "0.15.0", features = ["hf-hub"] }
11 | serde_json = "1.0.108"
12 | serde = { version = "1.0.193", features = ["derive"] }
13 | rand = "0.8.5"
14 | half = "2.3.1"
15 | log = "0.4.20"
16 | actix-web = "4.4.0"
17 | tokio = { version = "1.34.0", features = ["sync"] }
18 | futures = "0.3.29"
19 | uuid = { version = "1.6.1", features = ["v4"] }
20 |
21 | aicirt = { path = "../../aicirt" }
22 | aici_abi = { path = "../../controllers/aici_abi" }
23 | libc = "0.2.150"
24 | base64 = "0.21.5"
25 | memmap2 = "0.9.0"
26 | safetensors = "0.4.1"
27 | lazy_static = "1.4.0"
28 | percent-encoding = "2.3.1"
29 |
--------------------------------------------------------------------------------
/rllm/rllm-base/src/exec.rs:
--------------------------------------------------------------------------------
1 | use std::{fmt::Display, sync::Arc};
2 |
3 | use aicirt::TimerRef;
4 | use anyhow::Result;
5 |
6 | use crate::{
7 | config::{ModelMeta, RllmConfig},
8 | scheduler::SchedulerOutputs,
9 | seq::{Sequence, SequenceGroup},
10 | HashMap, LoaderArgs, LogitsProcessor, RllmEngine,
11 | };
12 |
13 | #[derive(Debug, Clone, Copy)]
14 | pub enum BlockLocation {
15 | GPU,
16 | CPU,
17 | }
18 |
19 | pub trait AiciBias<T> {
20 | fn apply(&self, logits: &mut T, seq_id: usize);
21 | }
22 |
23 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
24 | pub struct SeqId(pub usize);
25 |
26 | impl SeqId {
27 | pub fn to_num(&self) -> usize {
28 | self.0
29 | }
30 | }
31 |
32 | impl Display for SeqId {
33 | fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
34 | write!(f, "{}", self.0)
35 | }
36 | }
37 |
38 | pub trait SequenceManager {
39 | fn new_sequence(&self) -> SeqId;
40 | fn copy(&self, src: SeqId, dst: SeqId, length: usize);
41 | fn trim(&self, seq: SeqId, length: usize);
42 | fn delete(&self, seq: SeqId);
43 | }
44 |
45 | pub trait ModelExec: Sized {
46 | type Tensor;
47 | type BlockSpaceManager: TBlockSpaceManager<Self>;
48 | type AiciBias: AiciBias<Self::Tensor>;
49 | type ModelConfig;
50 | type ModelLoaderArgs: Send + 'static;
51 | type SequenceManager: SequenceManager;
52 |
53 | fn tensor_to_vec1(tensor: &Self::Tensor) -> Vec<f32>;
54 |
55 | fn load_model_config(
56 | args: &LoaderArgs,
57 | model_args: &mut Self::ModelLoaderArgs,
58 | ) -> Result<(ModelMeta, Self::ModelConfig)>;
59 | fn verify_args(args: &RllmConfig) -> Result<()>;
60 | fn load_rllm_engine(
61 | args: LoaderArgs,
62 | model_args: Self::ModelLoaderArgs,
63 | ) -> Result<RllmEngine<Self>>;
64 |
65 | fn sequence_manager(&self) -> Arc<Self::SequenceManager>;
66 |
67 | fn run(
68 | &mut self,
69 | _vocab_size: usize,
70 | tim: &TimerRef,
71 | step_no: usize,
72 | sched_out: &mut SchedulerOutputs,
73 | ) -> Result<()>;
74 | fn get_logits(&self, seq_id: usize) -> Self::Tensor;
75 | fn finalize_run(&mut self) -> Result<()>;
76 |
77 | fn empty_bias(&self, vocab_size: usize) -> Self::AiciBias;
78 | fn new_bias(&self, slice: &'static [f32], num_seqs: usize, vocab_size: usize)
79 | -> Self::AiciBias;
80 |
81 | fn sample(&self, processor: &mut LogitsProcessor, logits: &Self::Tensor) -> Result<u32>;
82 | }
83 |
84 | pub trait TBlockSpaceManager<ME: ModelExec> {
85 | fn can_allocate(&self, _seq_group: &SequenceGroup) -> bool;
86 | fn allocate(&mut self, seq_group: &mut SequenceGroup);
87 |
88 | fn can_append_slot(&self, _seq_group: &SequenceGroup) -> bool;
89 | fn append_slots(&mut self, _seq: &mut Sequence, _outputs: &mut SchedulerOutputs);
90 | fn get_num_free_gpu_blocks(&self) -> usize;
91 | fn get_num_free_cpu_blocks(&self) -> usize;
92 |
93 | fn can_swap_in(&self, _seq_group: &SequenceGroup) -> bool {
94 | false
95 | }
96 |
97 | fn swap_in(&mut self, _seq_group: &mut SequenceGroup) -> HashMap<usize, usize> {
98 | panic!("no swap_in")
99 | }
100 |
101 | fn swap_out(&mut self, _seq_group: &mut SequenceGroup) -> HashMap<usize, usize> {
102 | panic!("no swap_out")
103 | }
104 |
105 | fn can_swap_out(&self, _seq_group: &SequenceGroup) -> bool {
106 | false
107 | }
108 | }
109 |
--------------------------------------------------------------------------------
/rllm/rllm-base/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod seq;
2 |
3 | // vllm modules
4 | pub mod config;
5 | mod engine;
6 | mod exec;
7 | mod expected;
8 | pub mod iface;
9 | mod logits;
10 | mod scheduler;
11 | pub mod server;
12 | pub mod util;
13 |
14 | use config::AiciConfig;
15 | pub use engine::*;
16 | pub use exec::*;
17 | pub use logits::LogitsProcessor;
18 | pub use scheduler::*;
19 | use std::sync::atomic::AtomicBool;
20 |
21 | pub use aicirt::HashMap;
22 | pub use aicirt::HashSet;
23 |
24 | pub struct LoaderArgs {
25 | pub tokenizer: String, // one of aici_tokenizer; eg "llama"
26 | pub model_id: String,
27 | pub revision: Option<String>,
28 | pub file: Option<String>,
29 | pub local_weights: Option<String>,
30 | pub alt: usize,
31 | pub aici: AiciConfig,
32 | }
33 |
34 | impl Default for LoaderArgs {
35 | fn default() -> Self {
36 | Self {
37 | tokenizer: "llama".to_string(),
38 | model_id: "NousResearch/Llama-2-7b-hf".to_string(),
39 | revision: None,
40 | local_weights: None,
41 | file: None,
42 | aici: AiciConfig::default(),
43 | alt: 0,
44 | }
45 | }
46 | }
47 |
48 | static mut TRACE: AtomicBool = AtomicBool::new(false);
49 |
50 | pub fn set_trace(trace_enabled: bool) {
51 | unsafe {
52 | TRACE.store(trace_enabled, std::sync::atomic::Ordering::Relaxed);
53 | }
54 | }
55 |
56 | pub fn get_trace() -> bool {
57 | unsafe { TRACE.load(std::sync::atomic::Ordering::Relaxed) }
58 | }
59 |
60 | #[macro_export]
61 | macro_rules! rtrace {
62 | ($($arg:tt)*) => {{
63 | if $crate::get_trace() {
64 | println!($($arg)*);
65 | }
66 | }};
67 | }
68 |
--------------------------------------------------------------------------------
/rllm/rllm-base/src/logits.rs:
--------------------------------------------------------------------------------
1 | // based on https://github.com/huggingface/candle/blob/main/candle-transformers/src/generation/mod.rs
2 |
3 | use crate::config::{SamplingParams, SAMPLING_EPS};
4 | use rand::SeedableRng;
5 |
6 | pub struct LogitsProcessor {
7 | pub rng: rand::rngs::StdRng,
8 | pub temperature: Option<f32>,
9 | pub top_p: f32,
10 | }
11 |
12 | impl LogitsProcessor {
13 | pub fn new(sampling_params: &SamplingParams) -> Self {
14 | let temperature = if sampling_params.temperature < SAMPLING_EPS {
15 | None
16 | } else {
17 | Some(sampling_params.temperature)
18 | };
19 |
20 | Self {
21 | rng: rand::rngs::StdRng::from_entropy(),
22 | // seed_from_u64(42),
23 | temperature,
24 | top_p: sampling_params.top_p,
25 | }
26 | }
27 |
28 | pub fn set_temperature(&mut self, temperature: f32) {
29 | if temperature < SAMPLING_EPS {
30 | self.temperature = None;
31 | } else {
32 | self.temperature = Some(temperature);
33 | }
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/rllm/rllm-base/src/server/api.rs:
--------------------------------------------------------------------------------
1 | use aici_abi::StorageCmd;
2 | use serde::{Deserialize, Serialize};
3 |
4 | #[derive(Debug, Clone, Serialize, Deserialize)]
5 | pub struct RunRequest {
6 | pub controller: String,
7 | pub controller_arg: serde_json::Value,
8 | #[serde(default)]
9 | pub prompt: String,
10 | pub temperature: Option<f32>, // defl 0.0
11 | pub top_p: Option<f32>, // defl 1.0
12 | pub top_k: Option<isize>, // defl -1
13 | pub max_tokens: Option<usize>, // defl context size
14 | }
15 |
16 | #[derive(Debug, Clone, Serialize, Deserialize)]
17 | pub struct RunUsageResponse {
18 | pub sampled_tokens: usize,
19 | pub ff_tokens: usize,
20 | pub cost: usize,
21 | }
22 |
23 | #[derive(Debug, Clone, Serialize, Deserialize)]
24 | pub struct InitialRunResponse {
25 | pub id: String,
26 | pub object: &'static str, // "initial-run"
27 | pub created: u64,
28 | pub model: String,
29 | }
30 |
31 | #[derive(Debug, Clone, Serialize, Deserialize)]
32 | pub struct RunResponse {
33 | pub object: &'static str, // "run"
34 | pub forks: Vec<RunForkResponse>,
35 | pub usage: RunUsageResponse,
36 | }
37 |
38 | #[derive(Debug, Clone, Serialize, Deserialize)]
39 | pub struct RunForkResponse {
40 | pub index: usize,
41 | #[serde(skip_serializing_if = "Option::is_none")]
42 | pub finish_reason: Option<String>,
43 | pub text: String,
44 | pub error: String,
45 | pub logs: String,
46 | pub storage: Vec<StorageCmd>,
47 | pub micros: u64,
48 | }
49 |
--------------------------------------------------------------------------------
/rllm/rllm-base/src/server/openai/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Eric Buehler
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/rllm/rllm-base/src/server/openai/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod requests;
2 | pub mod responses;
3 |
--------------------------------------------------------------------------------
/rllm/rllm-base/src/server/openai/requests.rs:
--------------------------------------------------------------------------------
1 | use serde::{Deserialize, Serialize};
2 | use crate::HashMap;
3 |
4 | #[derive(Debug, Clone, Serialize, Deserialize)]
5 | #[serde(untagged)]
6 | pub enum Messages {
7 | Map(Vec<HashMap<String, String>>),
8 | Literal(String),
9 | }
10 |
11 | #[derive(Debug, Clone, Serialize, Deserialize)]
12 | pub enum StopTokens {
13 | Multi(Vec<String>),
14 | Single(String),
15 | }
16 |
17 | #[derive(Debug, Clone, Serialize, Deserialize)]
18 | pub struct ChatCompletionRequest {
19 | pub model: String,
20 | pub messages: Messages,
21 | #[serde(default)]
22 | pub temperature: Option<f32>, //0.7
23 | #[serde(default)]
24 | pub top_p: Option<f32>, //1.0
25 | #[serde(default)]
26 | pub n: Option<usize>, //1
27 | #[serde(default)]
28 | pub max_tokens: Option<usize>, //None
29 | #[serde(default)]
30 | pub stop: Option<StopTokens>,
31 | #[serde(default)]
32 | pub stream: Option<bool>, //false
33 | #[serde(default)]
34 | pub presence_penalty: Option<f32>, //0.0
35 | #[serde(default)]
36 | pub frequency_penalty: Option<f32>, //0.0
37 | #[serde(default)]
38 | pub logit_bias: Option<HashMap<String, f32>>, //None
39 | #[serde(default)]
40 | pub user: Option<String>, //None
41 | #[serde(default)]
42 | //Additional candle-vllm params
43 | pub top_k: Option<isize>, //-1
44 | #[serde(default)]
45 | pub best_of: Option<usize>, //None
46 | #[serde(default)]
47 | pub use_beam_search: Option<bool>, //false
48 | #[serde(default)]
49 | pub ignore_eos: Option<bool>, //false
50 | #[serde(default)]
51 | pub skip_special_tokens: Option<bool>, //false
52 | #[serde(default)]
53 | pub stop_token_ids: Option<Vec<u32>>, //[]
54 | }
55 |
56 | #[derive(Debug, Clone, Serialize, Deserialize)]
57 | pub struct CompletionRequest {
58 | pub model: String,
59 | pub prompt: String,
60 |
61 | #[serde(default)]
62 | pub temperature: Option<f32>, //0.7
63 | #[serde(default)]
64 | pub top_p: Option<f32>, //1.0
65 | #[serde(default)]
66 | pub n: Option<usize>, //1
67 | #[serde(default)]
68 | pub max_tokens: Option<usize>, //None
69 | #[serde(default)]
70 | pub stop: Option<Vec<String>>,
71 | #[serde(default)]
72 | pub stream: Option<bool>, //false
73 | #[serde(default)]
74 | pub presence_penalty: Option<f32>, //0.0
75 | #[serde(default)]
76 | pub frequency_penalty: Option<f32>, //0.0
77 | #[serde(default)]
78 | pub logit_bias: Option<HashMap<String, f32>>, //None
79 | #[serde(default)]
80 | pub user: Option<String>, //None
81 | #[serde(default)]
82 | pub top_k: Option<isize>, //-1
83 | #[serde(default)]
84 | pub best_of: Option<usize>, //None
85 | #[serde(default)]
86 | pub use_beam_search: Option<bool>, //false
87 | #[serde(default)]
88 | pub ignore_eos: Option<bool>, //false
89 | #[serde(default)]
90 | pub skip_special_tokens: Option<bool>, //false
91 | #[serde(default)]
92 | pub stop_token_ids: Option<Vec<u32>>, //[]
93 | }
94 |
--------------------------------------------------------------------------------
/rllm/rllm-base/src/server/openai/responses.rs:
--------------------------------------------------------------------------------
1 | use aici_abi::StorageCmd;
2 | use serde::{Deserialize, Serialize};
3 | use std::fmt::Debug;
4 |
5 | #[derive(Debug, Clone, Serialize, Deserialize)]
6 | pub struct ChatCompletionUsageResponse {
7 | pub completion_tokens: usize,
8 | pub prompt_tokens: usize,
9 | pub total_tokens: usize,
10 | pub fuel_tokens: usize,
11 | }
12 |
13 | // tool_calls, function_call not supported!
14 | #[derive(Debug, Clone, Serialize, Deserialize)]
15 | pub struct ChatChoiceData {
16 | pub content: Option<String>,
17 | pub role: String,
18 | }
19 |
20 | #[derive(Debug, Clone, Serialize, Deserialize)]
21 | pub struct ChatChoice {
22 | pub message: ChatChoiceData,
23 | pub finish_reason: Option<String>,
24 | pub index: usize,
25 | }
26 |
27 | #[derive(Debug, Clone, Serialize, Deserialize)]
28 | pub struct ChatCompletionResponse {
29 | pub id: String,
30 | pub choices: Vec<ChatChoice>,
31 | pub created: u64,
32 | pub model: String,
33 | pub object: &'static str,
34 | pub usage: ChatCompletionUsageResponse,
35 | }
36 |
37 | #[derive(Debug, Clone, Serialize, Deserialize)]
38 | pub struct CompletionChoice {
39 | pub text: String,
40 | pub finish_reason: Option<String>,
41 | pub index: usize,
42 | }
43 |
44 | #[derive(Debug, Clone, Serialize, Deserialize)]
45 | pub struct CompletionResponse {
46 | pub id: String,
47 | pub choices: Vec<CompletionChoice>,
48 | pub created: u64,
49 | pub model: String,
50 | pub object: &'static str, // "text_completion"
51 | pub usage: ChatCompletionUsageResponse,
52 | }
53 |
54 | // tool_calls, function_call not supported!
55 | #[derive(Debug, Clone, Serialize, Deserialize)]
56 | pub struct StreamingChoiceData {
57 | pub content: Option<String>,
58 | pub role: String,
59 | }
60 |
61 | #[derive(Debug, Clone, Serialize, Deserialize)]
62 | pub struct StreamingChatChoice {
63 | pub delta: StreamingChoiceData,
64 | pub finish_reason: Option<String>,
65 | pub index: usize,
66 | }
67 |
68 | #[derive(Debug, Clone, Serialize, Deserialize)]
69 | pub struct StreamingChatCompletionResponse {
70 | pub id: String,
71 | pub choices: Vec<StreamingChatChoice>,
72 | pub created: u64,
73 | pub model: String,
74 | pub object: &'static str,
75 | }
76 |
77 | #[derive(Debug, Clone, Serialize, Deserialize)]
78 | pub struct List<T> {
79 | pub object: &'static str, // "list"
80 | pub data: Vec<T>,
81 | }
82 |
83 | impl<T> List<T> {
84 | pub fn new(data: Vec<T>) -> Self {
85 | Self {
86 | object: "list",
87 | data,
88 | }
89 | }
90 | }
91 |
92 | #[derive(Debug, Clone, Serialize, Deserialize)]
93 | pub struct Model {
94 | pub object: &'static str, // "model"
95 | pub id: String,
96 | pub created: u64,
97 | pub owned_by: String,
98 | }
99 |
100 | #[derive(Debug, Clone, Serialize, Deserialize)]
101 | pub struct StreamingCompletionChoice {
102 | pub index: usize,
103 | pub finish_reason: Option<String>,
104 | pub text: String,
105 |
106 | pub error: String,
107 | pub logs: String,
108 | pub storage: Vec<StorageCmd>,
109 | // pub logprobs: Option,
110 | }
111 |
112 | #[derive(Debug, Clone, Serialize, Deserialize)]
113 | pub struct StreamingCompletionResponse {
114 | pub object: &'static str, // "text_completion"
115 | pub id: String,
116 | pub model: String,
117 | pub created: u64,
118 | pub choices: Vec<StreamingCompletionChoice>,
119 | pub usage: ChatCompletionUsageResponse,
120 | }
121 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "rllm-cuda"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [dependencies]
7 | anyhow = "1.0.75"
8 | clap = "4.4.8"
9 | serde_json = "1.0.108"
10 | serde = { version = "1.0.193", features = ["derive"] }
11 | rand = "0.8.5"
12 | log = "0.4.20"
13 | actix-web = "4.4.0"
14 | tch = { version = "0.14.0" }
15 |
16 | cudarc = { version = "0.10.0", features = ["f16"], optional = true }
17 | tch-cuda = { path = "../tch-cuda", optional = true }
18 |
19 | rllm = { path = "../rllm-base" }
20 | aicirt = { path = "../../aicirt" }
21 | indicatif = "0.17.7"
22 | memmap2 = "0.9.0"
23 | safetensors = "0.4.1"
24 |
25 | [[bin]]
26 | name = "rllm-cuda"
27 | path = "src/rllm-cuda.rs"
28 |
29 | [features]
30 | default = ["cuda"]
31 | cuda = ["dep:tch-cuda", "dep:cudarc"]
32 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/README.md:
--------------------------------------------------------------------------------
1 | # rLLM
2 |
3 | This is a partial port of [vLLM](https://github.com/vllm-project/vllm)
4 | to Rust and [tch-rs](https://github.com/LaurentMazare/tch-rs)
5 | (bindings for [libtorch](https://github.com/pytorch/pytorch/blob/main/docs/libtorch.rst)
6 | which is the basis of [PyTorch](https://github.com/pytorch/pytorch)).
7 | It is mostly meant as a proving ground for AICI (AI Controller Interface) integration.
8 |
9 |
10 | ## Building
11 |
12 | If you're not using the supplied docker container, make sure
13 | that the following environment variables are set:
14 |
15 | ```bash
16 | export CUDA_COMPUTE_CAP="80"
17 | export LIBTORCH_USE_PYTORCH="1"
18 | ```
19 |
20 | You can run the server with the `./server.sh` script; have a look inside to see
21 | how to run it with different options.
22 |
23 | ## Tests
24 |
25 | The `expected/` directory contains sample prompts along with the expected model output -
26 | the top 128 logits for the first few tokens of output.
27 | Running `./expected/tests.sh` will run rLLM on these test cases and make sure it gets the
28 | same logits, within some tolerance.
29 |
30 | You can inspect test cases like so:
31 |
32 | ```
33 | $ python scripts/testgen.py show expected/phi-1_5/lighthouse.safetensors
34 | Prompt: 'Write a detailed analogy between mathematics and a lighthouse.\n\nAnswer:'
35 | Output: ' In mathematics, logic is like a beacon of the lighthouse. It guides us'
36 | logits: torch.Size([15, 128]) min: tensor(12.7188) avg: tensor(17.6671) max: tensor(36.0938)
37 | prob_mass: torch.Size([15]) min: tensor(0.9795) avg: tensor(0.9944) max: tensor(0.9999)
38 | $
39 | ```
40 |
41 | `prob_mass` refers to the sum of probabilities of the top 128 logits after softmax
42 | (for every token of output). It should be very close to 1.
43 |
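
As a point of reference, this number can be reproduced from a full logits row along the following lines (a small illustrative sketch; the exact tensor layout of the stored test files is not assumed here):

```python
# Illustrative sketch: probability mass covered by the top-128 logits of one
# output position, computed from a full vocab-sized logits row.
import torch

def top_k_prob_mass(logits_row: torch.Tensor, k: int = 128) -> torch.Tensor:
    probs = torch.softmax(logits_row.float(), dim=-1)
    return torch.topk(probs, k=k).values.sum()

# Random logits are only a stand-in; real model logits concentrate most of the
# probability mass in the top 128 entries, so the value should be close to 1.
print(top_k_prob_mass(torch.randn(32000)))
```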
44 | ## Models
45 |
46 | The following models have been tested:
47 |
48 | * [CodeLlama-13b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf)
49 | * [CodeLlama-34b-Instruct-hf](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf)
50 | - this barely fits in an 80GB A100, leaving little room for the KV cache
51 | * [Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf)
52 | * [Orca-2-13b](https://huggingface.co/microsoft/Orca-2-13b)
53 | * [phi-1_5](https://huggingface.co/microsoft/phi-1_5)
54 | * [phi-2](https://huggingface.co/microsoft/phi-2)
55 |
56 | In general all Llama models should work.
57 |
58 | ## Acknowledgements
59 |
60 | See [top-level README.md](../../README.md#acknowledgements).
61 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/codellama/args.txt:
--------------------------------------------------------------------------------
1 | --model codellama/CodeLlama-13b-Instruct-hf --tokenizer llama16
2 | -s test_maxtol=0.20 -s test_avgtol=0.10
3 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/codellama/cats.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama/cats.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/codellama/lighthouse.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama/lighthouse.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/codellama/primes.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama/primes.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/codellama34/args.txt:
--------------------------------------------------------------------------------
1 | --model codellama/CodeLlama-34b-Instruct-hf --tokenizer llama
2 | -s test_maxtol=8 -s test_avgtol=0.5
3 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/codellama34/cats.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama34/cats.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/codellama34/lighthouse.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama34/lighthouse.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/codellama34/primes.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/codellama34/primes.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/go.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -e
4 |
5 | BIN=$(cd ../target; pwd)
6 |
7 | # the PORT is in fact unused
8 | COMMON_ARGS="--verbose --aicirt $BIN/release/aicirt"
9 |
10 | (cd ../aicirt; cargo build --release)
11 |
12 | RLLM_LOG=debug
13 |
14 | FILES=
15 | for f in "$@" ; do
16 | if [ -f "$f" ] ; then
17 | FILES="$FILES $f"
18 | elif [ -f "$f/args.txt" ] ; then
19 | FILES="$FILES $f/args.txt"
20 | else
21 | echo "File $f not found"
22 | exit 1
23 | fi
24 | done
25 |
26 | for A in $FILES ; do
27 | echo
28 | echo
29 | echo
30 | echo "*** $A ***"
31 | echo
32 | ARGS="$COMMON_ARGS `cat $A`"
33 | for S in $(dirname $A)/*.safetensors ; do
34 | ARGS="$ARGS --test $S"
35 | done
36 | RUST_BACKTRACE=1 \
37 | RUST_LOG=info,rllm=$RLLM_LOG,aicirt=info \
38 | cargo run $REL -- $ARGS
39 | done
40 |
41 | echo "All OK!"
42 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/llama/args.txt:
--------------------------------------------------------------------------------
1 | --model NousResearch/Llama-2-7b-hf --tokenizer llama
2 | -s test_maxtol=0.30 -s test_avgtol=0.10
3 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/llama/cats.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/llama/cats.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/llama/lighthouse.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/llama/lighthouse.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/llama/primes.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/llama/primes.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/orca/args.txt:
--------------------------------------------------------------------------------
1 | --model microsoft/Orca-2-13b@refs/pr/22 --tokenizer orca
2 | -s test_maxtol=0.50 -s test_avgtol=0.10
3 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/orca/cats.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/orca/cats.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/orca/lighthouse.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/orca/lighthouse.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/orca/primes.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/orca/primes.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/phi-1_5/args.txt:
--------------------------------------------------------------------------------
1 | --model microsoft/phi-1_5@refs/pr/66 --tokenizer phi
2 | -s test_maxtol=0.05 -s test_avgtol=0.01
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/phi-1_5/cats.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-1_5/cats.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/phi-1_5/lighthouse.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-1_5/lighthouse.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/phi-1_5/primes.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-1_5/primes.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/phi-2/args.txt:
--------------------------------------------------------------------------------
1 | --model microsoft/phi-2 --tokenizer phi -s test_maxtol=2.0 -s test_avgtol=0.15
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/phi-2/cats.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-2/cats.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/phi-2/lighthouse.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-2/lighthouse.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/expected/phi-2/primes.safetensors:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/aici/ecc50362fe2c620c3c8487740267f6fae49397d3/rllm/rllm-cuda/expected/phi-2/primes.safetensors
--------------------------------------------------------------------------------
/rllm/rllm-cuda/scripts/cmp2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -x
4 | set -e
5 | cargo run -- --sample-len 1
6 | mv step-1.safetensor single.safetensor
7 | cargo run -- --sample-len 1 --alt=7
8 | mv step-1.safetensor multi.safetensor
9 | python3 tensorcmp.py single.safetensor multi.safetensor
10 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/scripts/tensorcmp.py:
--------------------------------------------------------------------------------
1 | import safetensors
2 | import sys
3 | import torch
4 |
5 | args = sys.argv[1:]
6 |
7 |
8 | def check_all_close(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-05, atol=1e-10):
9 | # assert torch.allclose(tensor1, tensor2, rtol=rtol, atol=atol)
10 | if torch.equal(tensor1, tensor2):
11 | return
12 | d = tensor1.sub(tensor2).abs()
13 | print(d)
14 | print(d.max())
15 | assert False
16 |
17 |
18 | def cmp(a: torch.Tensor, b: torch.Tensor):
19 | # if b.shape[0] == 27:
20 | # b = b[4:16, :, :]
21 | # elif b.shape[1] == 27:
22 | # b = b[:, 4:16, :]
23 |
24 | if a.shape == b.shape:
25 | print(a.shape)
26 | check_all_close(a, b)
27 | else:
28 | print("Size wrong", a.shape, b.shape)
29 | assert False
30 |
31 | def load_all(fn: str) -> dict[str, torch.Tensor]:
32 | r = {}
33 | with safetensors.safe_open(fn, framework="pt", device="cuda") as a:
34 | keys = a.keys()
35 | keys.sort()
36 | for key in keys:
37 | kk = key.split("_")[1]
38 | if kk not in r:
39 | r[kk] = a.get_tensor(key)
40 | return r
41 |
42 | def main():
43 | a = load_all(args[0])
44 | b = load_all(args[1])
45 |
46 | x1 = a["x1"]
47 | w1 = a["w1"].t().unsqueeze(0)
48 |
49 | m1 = a["m1"]
50 | print("X", x1.shape, w1.shape)
51 | m1p = x1.matmul(w1)
52 | #cmp(m1, m1p)
53 |
54 | bx1 = b["x1"]
55 | bw1 = b["w1"].t().unsqueeze(0)
56 | cmp(w1, bw1)
57 | cmp(bx1[:,4:16,:], x1)
58 | bm1 = bx1.matmul(bw1)
59 | cmp(bm1[:,4:16,:], m1p)
60 |
61 |
62 |
63 |
64 | # for key in a.keys():
65 | # print(key)
66 | # cmp(a[key], b[key])
67 |
68 | main()
--------------------------------------------------------------------------------
/rllm/rllm-cuda/src/llm/kernels.rs:
--------------------------------------------------------------------------------
1 | #[cfg(not(feature = "cuda"))]
2 | pub use super::refkernels::*;
3 | use tch::{Device, Tensor};
4 | #[cfg(feature = "cuda")]
5 | pub use tch_cuda::flash_attn_varlen as varlen_attn;
6 | #[cfg(feature = "cuda")]
7 | pub use tch_cuda::*;
8 |
9 | /// Convert a vector of lengths into a tensor of offsets, as expected by flash attn.
10 | pub fn to_offsets(seqlens: impl Iterator<Item = usize>, device: Device) -> (usize, Tensor) {
11 | let mut offsets = Vec::new();
12 | let mut offset = 0;
13 | let mut max = 0;
14 | for len in seqlens {
15 | max = std::cmp::max(len, max);
16 | offsets.push(offset as i32);
17 | offset += len;
18 | }
19 | offsets.push(offset as i32);
20 | (max, Tensor::from_slice(offsets.as_slice()).to(device))
21 | }
22 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/src/llm/paged/cuda_stub.rs:
--------------------------------------------------------------------------------
1 | use tch::Device;
2 |
3 | pub struct CudaEvent {}
4 |
5 | impl CudaEvent {
6 | pub fn new() -> Self {
7 | CudaEvent {}
8 | }
9 |
10 | pub fn wait(&self, _stream: &CudaStream) {}
11 | }
12 |
13 | pub struct CudaStream {}
14 |
15 | impl CudaStream {
16 | pub fn new(_device: Device) -> Self {
17 | CudaStream {}
18 | }
19 |
20 | pub fn current(_device: Device) -> Self {
21 | CudaStream {}
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/src/llm/paged/mod.rs:
--------------------------------------------------------------------------------
1 | #[cfg(not(feature = "cuda"))]
2 | mod cuda_stub;
3 |
4 | mod batch_info;
5 | mod blocks;
6 | mod cache_engine;
7 |
8 | pub use batch_info::*;
9 | pub use blocks::*;
10 | pub use cache_engine::*;
11 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/src/rllm-cuda.rs:
--------------------------------------------------------------------------------
1 | mod llm;
2 |
3 | use clap::Parser;
4 | use llm::{
5 | tmodel::{TModel, TchLoaderArgs},
6 | DType,
7 | };
8 | use rllm::util::parse_with_settings;
9 | use tch::Device;
10 |
11 | /// Serve LLMs with AICI over HTTP with tch (torch) backend.
12 | #[derive(Parser, Debug)]
13 | #[command(version, about, long_about = None)]
14 | pub struct DriverArgs {
15 | #[clap(flatten)]
16 | pub args: rllm::server::RllmCliArgs,
17 |
18 | /// Specify which type to use in the model (bf16, f16, f32)
19 | #[arg(long, default_value = "", help_heading = "Model")]
20 | pub dtype: String,
21 |
22 | /// Enable nvprof profiling for given engine step (if available)
23 | #[arg(long, default_value_t = 0, help_heading = "Development")]
24 | pub profile_step: usize,
25 | }
26 |
27 | #[actix_web::main]
28 | async fn main() -> () {
29 | let args = parse_with_settings::<DriverArgs>();
30 | let _ = args;
31 |
32 | let (device, dtype) = if tch::Cuda::is_available() {
33 | (Device::Cuda(0), None)
34 | } else {
35 | // At least on AMD 5500m MPS is 3x slower than CPU
36 | // #[cfg(target_os = "macos")]
37 | // let r = (Device::Mps, DType::Half);
38 | // #[cfg(not(target_os = "macos"))]
39 | let r = (Device::Cpu, Some(DType::Float));
40 | r
41 | };
42 |
43 | let dtype = match args.dtype.as_str() {
44 | "bf16" => Some(DType::BFloat16),
45 | "f16" => Some(DType::Half),
46 | "f32" => Some(DType::Float),
47 | "" => dtype,
48 | _ => panic!("invalid dtype; try one of bf16, f16, f32"),
49 | };
50 |
51 | let model_args = TchLoaderArgs {
52 | device,
53 | dtype,
54 | profile_step_no: args.profile_step,
55 | };
56 | rllm::server::server_main::<TModel>(args.args, model_args).await;
57 | }
58 |
--------------------------------------------------------------------------------
/rllm/rllm-cuda/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | set -e
4 |
5 | ./expected/go.sh \
6 | expected/phi-1_5 \
7 | expected/orca
8 |
9 | if [ "$1" = "all" ] ; then
10 | ./expected/go.sh \
11 | expected/codellama34 \
12 | expected/codellama \
13 | expected/phi-2 \
14 | expected/llama
15 | fi
16 |
17 |
--------------------------------------------------------------------------------
/rllm/rllm-llamacpp/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "rllm-llamacpp"
3 | version = "0.1.0"
4 | edition = "2021"
5 | rust-version = "1.75.0"
6 |
7 | [dependencies]
8 | actix-web = "4.4.0"
9 | anyhow = "1.0.79"
10 | clap = { version = "4.4.18", features = ["derive"] }
11 | llama_cpp_low = { path = "../llama-cpp-low" }
12 | log = "0.4.20"
13 | rllm = { path = "../rllm-base" }
14 | aicirt = { path = "../../aicirt" }
15 | rand = "0.8.5"
16 |
17 | [[bin]]
18 | name = "rllm-llamacpp"
19 | path = "src/rllm-llamacpp.rs"
20 |
21 | [features]
22 | default = []
23 | cuda = ["llama_cpp_low/cuda"]
24 |
--------------------------------------------------------------------------------
/rllm/rllm-llamacpp/README.md:
--------------------------------------------------------------------------------
1 | # rLLM for llama.cpp
2 |
3 | This is similar to the [CUDA-based rLLM](../rllm-cuda/)
4 | but built on top of [llama.cpp](https://github.com/ggerganov/llama.cpp).
5 |
6 | ## Building
7 |
8 | If you're not using the supplied docker container, follow the
9 | [build setup instructions](../../README.md#development-environment-setup).
10 |
11 | To compile and run aicirt and then the rLLM server, run:
12 |
13 | ```bash
14 | ./server.sh phi2
15 | ```
16 |
17 | Run `./server.sh --help` for more options.
18 |
19 | You can also try passing `--cuda` before `phi2`, which will enable cuBLAS in llama.cpp.
20 | Note that this is different from [rllm-cuda](../rllm-cuda/),
21 | which may give better performance for batched inference.
22 |
--------------------------------------------------------------------------------
/rllm/rllm-llamacpp/server.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | CPP=1 exec `dirname $0`/../rllm-cuda/server.sh "$@"
4 |
--------------------------------------------------------------------------------
/rllm/rllm-llamacpp/src/llamacpp/blocks.rs:
--------------------------------------------------------------------------------
1 | use rllm::{
2 | seq::{Sequence, SequenceGroup},
3 | SchedulerOutputs, TBlockSpaceManager,
4 | };
5 |
6 | use super::tmodel::TModel;
7 |
8 | /// Manages the mapping between logical and physical token blocks.
9 | pub struct CppBlockSpaceManager {}
10 |
11 | impl TBlockSpaceManager<TModel> for CppBlockSpaceManager {
12 | fn can_allocate(&self, _seq_group: &SequenceGroup) -> bool {
13 | true
14 | }
15 |
16 | fn allocate(&mut self, seq_group: &mut SequenceGroup) {
17 | let seq = seq_group.only_seq();
18 | assert!(seq.num_kv_computed == 0);
19 | }
20 |
21 | fn can_append_slot(&self, _seq_group: &SequenceGroup) -> bool {
22 | true
23 | }
24 |
25 | fn append_slots(&mut self, _seq: &mut Sequence, _outputs: &mut SchedulerOutputs) {}
26 |
27 | fn get_num_free_gpu_blocks(&self) -> usize {
28 | 0
29 | }
30 |
31 | fn get_num_free_cpu_blocks(&self) -> usize {
32 | 0
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/rllm/rllm-llamacpp/src/llamacpp/loader.rs:
--------------------------------------------------------------------------------
1 | use std::sync::Arc;
2 |
3 | use anyhow::{bail, Result};
4 | use rllm::{config::ModelMeta, LoaderArgs, Repo, RllmEngine};
5 |
6 | use llama_cpp_low as cpp;
7 |
8 | use super::{
9 | blocks::CppBlockSpaceManager,
10 | tmodel::{CppLoaderArgs, TModel},
11 | };
12 |
13 | pub(super) fn load_rllm_engine(
14 | args: LoaderArgs,
15 | mut model_args: CppLoaderArgs,
16 | ) -> Result<RllmEngine<TModel>> {
17 | let model = do_load(&args, &mut model_args)?;
18 | let rllm_config = RllmEngine::<TModel>::build_config(&args, &mut model_args)?;
19 |
20 | let mut cparams = cpp::ContextParams::default();
21 | cparams.n_batch = rllm_config.scheduler.max_num_batched_tokens as u32;
22 | cparams.n_ctx = 10000; // TODO
23 | model.setup_context(cparams);
24 |
25 | let rllm_config = Arc::new(rllm_config);
26 | let tmodel = TModel::new(rllm_config.clone(), model);
27 | let block_mgr = CppBlockSpaceManager {};
28 | RllmEngine::build(args, tmodel, block_mgr, rllm_config)
29 | }
30 |
31 | fn do_load(args: &LoaderArgs, model_args: &mut CppLoaderArgs) -> Result<cpp::Model> {
32 | if model_args.cached_model.is_none() {
33 | let repo = Repo::from(args)?;
34 | log::info!("loading the model from {}", repo);
35 |
36 | let gguf = match args.file.as_ref() {
37 | Some(gguf) => gguf,
38 | None => {
39 | bail!("--gguf file.gguf or --model user/model::file.gguf is required for loading the model")
40 | }
41 | };
42 |
43 | let file = repo.get(gguf)?;
44 |
45 | let mut mparams = cpp::ModelParams::default();
46 | // TODO: make this configurable
47 | mparams.set_split_mode(cpp::SplitMode::Layer);
48 | match model_args.n_gpu_layers {
49 | Some(n) => mparams.n_gpu_layers = n as i32,
50 | None => {
51 | mparams.n_gpu_layers = 999;
52 | // by default, don't GPU offload on Intel macs - it's much slower than CPU
53 | #[cfg(all(target_os = "macos", target_arch = "x86_64"))]
54 | {
55 | mparams.n_gpu_layers = 0;
56 | }
57 | }
58 | }
59 | log::info!("{} layer(s) offloaded to GPU", mparams.n_gpu_layers);
60 |
61 | let m = cpp::Model::from_file(file.to_str().unwrap(), mparams)?;
62 | model_args.cached_model = Some(m);
63 | }
64 |
65 | let model = model_args.cached_model.as_ref().unwrap().clone();
66 | Ok(model)
67 | }
68 |
69 | pub(super) fn load_model_config(
70 | args: &LoaderArgs,
71 | model_args: &mut CppLoaderArgs,
72 | ) -> Result<ModelMeta> {
73 | let model = do_load(args, model_args)?;
74 |
75 | let info = model.model_info();
76 | let vocab_size = info.n_vocab.try_into().unwrap();
77 | let max_sequence_length = info.n_ctx_train.try_into().unwrap();
78 |
79 | let mut meta = ModelMeta {
80 | id: args.model_id.clone(),
81 | max_sequence_length,
82 | vocab_size,
83 | tok_vocab_size: vocab_size,
84 | };
85 |
86 | // hidden_size: info.n_embd.try_into().unwrap(),
87 | // rope_theta: info.rope,
88 | // rotary_dim: max_sequence_length,
89 |
90 | let tok = aicirt::bintokens::find_tokenizer(&args.tokenizer)?;
91 | meta.tok_vocab_size = tok.tokrx_info().vocab_size as usize;
92 |
93 | Ok(meta)
94 | }
95 |
--------------------------------------------------------------------------------
/rllm/rllm-llamacpp/src/llamacpp/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod blocks;
2 | pub mod loader;
3 | pub mod tmodel;
4 | pub mod seqid;
5 |
6 | #[derive(Clone)]
7 | pub struct Tensor {
8 | ptr: *mut f32,
9 | size: usize,
10 | }
11 |
12 | impl Tensor {
13 | pub fn from_slice(slice: &'static [f32]) -> Self {
14 | Tensor {
15 | ptr: slice.as_ptr() as *mut f32,
16 | size: slice.len(),
17 | }
18 | }
19 |
20 | pub fn as_slice(&self) -> &[f32] {
21 | unsafe { std::slice::from_raw_parts(self.ptr, self.size) }
22 | }
23 |
24 | pub fn as_mut_slice(&mut self) -> &mut [f32] {
25 | unsafe { std::slice::from_raw_parts_mut(self.ptr, self.size) }
26 | }
27 |
28 | pub fn to_vec1(&self) -> Vec<f32> {
29 | self.as_slice().to_vec()
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/rllm/rllm-llamacpp/src/llamacpp/seqid.rs:
--------------------------------------------------------------------------------
1 | use std::sync::Mutex;
2 |
3 | use rllm::{HashMap, SeqId, SequenceManager};
4 | use llama_cpp_low as cpp;
5 |
6 | pub struct CppSequenceManager {
7 | model: cpp::Model,
8 | seqs: Mutex<HashMap<SeqId, cpp::Sequence>>,
9 | }
10 |
11 | impl CppSequenceManager {
12 | pub fn new(model: cpp::Model) -> Self {
13 | Self {
14 | model,
15 | seqs: Mutex::new(HashMap::default()),
16 | }
17 | }
18 |
19 | pub fn with_cpp(&self, seq: SeqId, cb: impl FnOnce(&cpp::Sequence)) {
20 | let seqs = self.seqs.lock().unwrap();
21 | let seq = seqs.get(&seq).unwrap();
22 | cb(seq);
23 | }
24 | }
25 |
26 | impl SequenceManager for CppSequenceManager {
27 | fn new_sequence(&self) -> SeqId {
28 | let r = self.model.new_sequence();
29 | let id = SeqId(r.id() as usize);
30 | self.seqs.lock().unwrap().insert(id, r);
31 | id
32 | }
33 |
34 | fn copy(&self, src: SeqId, dst: SeqId, length: usize) {
35 | let seqs = self.seqs.lock().unwrap();
36 | let src = seqs.get(&src).unwrap();
37 | let dst = seqs.get(&dst).unwrap();
38 | dst.cp_from(src, 0, length as i32);
39 | }
40 |
41 | fn trim(&self, seq: SeqId, length: usize) {
42 | let seqs = self.seqs.lock().unwrap();
43 | let seq = seqs.get(&seq).unwrap();
44 | seq.rm(length as i32, -1);
45 | }
46 |
47 | fn delete(&self, seq: SeqId) {
48 | self.seqs.lock().unwrap().remove(&seq);
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
/rllm/rllm-llamacpp/src/rllm-llamacpp.rs:
--------------------------------------------------------------------------------
1 | mod llamacpp;
2 | use clap::Parser;
3 | use llamacpp::tmodel::{CppLoaderArgs, TModel};
4 | use rllm::util::parse_with_settings;
5 |
6 | /// Serve LLMs with AICI over HTTP with llama.cpp backend.
7 | #[derive(Parser, Debug)]
8 | #[command(version, about, long_about = None)]
9 | pub struct CppArgs {
10 | #[clap(flatten)]
11 | pub args: rllm::server::RllmCliArgs,
12 |
13 | /// Name of .gguf file inside of the model folder/repo.
14 | #[arg(long, help_heading = "Model")]
15 | pub gguf: Option<String>,
16 |
17 | /// How many model layers to offload to GPU (if available)
18 | #[arg(long, short = 'g', help_heading = "Model")]
19 | pub gpu_layers: Option<usize>,
20 | }
21 |
22 | #[actix_web::main]
23 | async fn main() -> () {
24 | let mut args = parse_with_settings::<CppArgs>();
25 | args.args.file = args.gguf;
26 | let model_args = CppLoaderArgs::new(args.gpu_layers);
27 | rllm::server::server_main::<TModel>(args.args, model_args).await;
28 | }
29 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "tch-cuda"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | description = "Flash attention layer for the tch-rs"
7 | license = "MIT"
8 | readme = "README.md"
9 |
10 | [dependencies]
11 | half = { version = "2.3.1", features = ["num-traits"] }
12 | libc = "0.2.151"
13 | tch = "0.14.0"
14 | torch-sys = "0.14.0"
15 | rustc-hash = "2.0.0"
16 |
17 | [build-dependencies]
18 | anyhow = { version = "1", features = ["backtrace"] }
19 | num_cpus = "1.15.0"
20 | rayon = "1.7.0"
21 | glob = "0.3.1"
22 |
23 | [dev-dependencies]
24 | anyhow = { version = "1", features = ["backtrace"] }
25 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/README.md:
--------------------------------------------------------------------------------
1 | # tch-cuda
2 |
3 | based on
4 | https://github.com/Dao-AILab/flash-attention/tree/9356a1c0389660d7e231ff3163c1ac17d9e3824a/csrc/flash_attn/src
5 |
6 |
7 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/convhd.js:
--------------------------------------------------------------------------------
1 | const fs = require('fs');
2 |
3 | const args = process.argv.slice(2);
4 |
5 | let all_ext_c = "";
6 | let all_rust = "";
7 |
8 | const tpmap = {
9 | "torch::Tensor&": ["tensor", "*mut C_tensor"],
10 | "const torch::Tensor&": ["tensor", "*const C_tensor"],
11 | "const c10::optional&": ["tensor", "*const C_tensor"],
12 | };
13 |
14 | const r_tpmap = {
15 | "float": "f32",
16 | "int": "i32",
17 | "bool": "bool",
18 | }
19 |
20 | let left = fs.readFileSync(args[0], 'utf-8').replace(/^(\S+) (\w+)\(([^)]*)\);/mg, (_, rettp, fname, args) => {
21 | let skip = false;
22 |
23 | if (rettp != "void")
24 | skip = true;
25 |
26 | if (args.indexOf("std::map<") >= 0) {
27 | return ""
28 | }
29 |
30 | let ext_c = `char* ${fname}_C(\n`;
31 | let rust = `fn ${fname}_C(\n`;
32 | let ext_c_inner = `${fname}(`;
33 |
34 | args.split(/,\s+/).forEach((x, i, arr) => {
35 | const m = /^(.*)\s+(\w+)$/.exec(x.trim());
36 | let tp = m[1];
37 | let aname = m[2];
38 |
39 | if (tp == "torch::Tensor")
40 | skip = true;
41 |
42 |
43 | const tp0 = tpmap[tp];
44 |
45 | let rtp = tp0?.[1] ?? r_tpmap[tp] ?? `C__${tp}`;
46 | let ctp = tp0?.[0] ?? tp;
47 |
48 | ext_c += ` ${ctp} ${aname}`;
49 |
50 | if (tp0) {
51 | ext_c_inner += `*${aname}`
52 | } else {
53 | ext_c_inner += `${aname}`
54 | }
55 |
56 | rust += ` ${aname}: ${rtp}`;
57 |
58 | if (i < arr.length - 1) {
59 | ext_c += ",\n";
60 | ext_c_inner += ",\n";
61 | rust += ",\n";
62 | }
63 | })
64 |
65 | ext_c += ") {\nPROTECT(" + ext_c_inner + "));\n}\n\n";
66 | rust += ") -> *mut libc::c_char;\n\n";
67 |
68 | if (!skip) {
69 | all_ext_c += ext_c;
70 | all_rust += rust;
71 | }
72 |
73 | return "";
74 | });
75 |
76 | left = left.replace(/^\s*#.*/mg, "").trim();
77 | if (left) {
78 | console.log("left", left);
79 | } else {
80 | console.log(all_ext_c);
81 | console.log(all_rust);
82 | }
83 |
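
Note on the script above: it scans a C++ header for void-returning declarations and, for each one, prints an extern "C" wrapper plus a matching Rust extern declaration (functions returning non-void, taking torch::Tensor by value, or taking std::map arguments are skipped). As a rough sketch of the generated C side, using a hypothetical declaration void foo_fwd(torch::Tensor &out, const torch::Tensor &in, float scale); (here "tensor" and PROTECT are the torch-sys-style tensor typedef and error-trapping macro the generated code expects):

  char* foo_fwd_C(
      tensor out,
      tensor in,
      float scale) {
  PROTECT(foo_fwd(*out, *in, scale));
  }

A matching Rust fn foo_fwd_C(...) -> *mut libc::c_char; declaration is printed alongside it.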
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/AUTHORS:
--------------------------------------------------------------------------------
1 | Tri Dao, trid@cs.stanford.edu
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/block_info.h:
--------------------------------------------------------------------------------
1 | /******************************************************************************
2 | * Copyright (c) 2023, Tri Dao.
3 | ******************************************************************************/
4 |
5 | #pragma once
6 |
7 | namespace flash {
8 |
9 | ////////////////////////////////////////////////////////////////////////////////////////////////////
10 |
11 | template<bool Varlen=true>
12 | struct BlockInfo {
13 |
14 | template<typename Params>
15 | __device__ BlockInfo(const Params &params, const int bidb)
16 | : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
17 | , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
18 | , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
19 | // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
20 | // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
21 | , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
22 | , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
23 | {
24 | }
25 |
26 | template<typename index_t>
27 | inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
28 | return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
29 | }
30 |
31 | template<typename index_t>
32 | inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
33 | return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
34 | }
35 |
36 | const int sum_s_q;
37 | const int sum_s_k;
38 | const int actual_seqlen_q;
39 | // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
40 | const int seqlen_k_cache;
41 | const int actual_seqlen_k;
42 | };
43 |
44 | ////////////////////////////////////////////////////////////////////////////////////////////////////
45 |
46 | } // namespace flash
47 |
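
For orientation, a minimal sketch of how BlockInfo is consumed in the upstream flash-attention forward kernels (the params field names, m_block, kBlockM and index_t are assumptions taken from that code, not defined here):

  // One BlockInfo per (batch, block): translate between padded and packed varlen layouts.
  const flash::BlockInfo</*Varlen=*/true> binfo(params, bidb);
  if (m_block * kBlockM >= binfo.actual_seqlen_q) return;   // this block has no valid query rows
  const index_t q_off = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb);
  const index_t k_off = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb);

The offsets fall back to batch-strided addressing when cu_seqlens pointers are null, and use the cumulative sequence lengths otherwise.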
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim128_bf16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
8 |
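
The fifteen sibling .cu files that follow repeat this pattern: each translation unit explicitly instantiates the split-KV dispatcher for a single (element type, head dimension) pair, so the sixteen combinations compile in parallel rather than in one slow translation unit. The generic template being instantiated is declared in flash_fwd_launch_template.h, roughly as (a sketch of the upstream signature):

  template<typename T, int Headdim>
  void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);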
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim128_fp16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim160_bf16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim160_fp16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim192_bf16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim192_fp16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim224_bf16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim224_fp16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim256_bf16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim256_fp16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim32_bf16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim32_fp16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim64_bf16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim64_fp16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim96_bf16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/flash_fwd_split_hdim96_fp16_sm80.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2023, Tri Dao.
2 | // Splitting the different head dimensions to different files to speed up compilation.
3 | // This file is auto-generated. See "generate_kernels.py"
4 |
5 | #include "flash_fwd_launch_template.h"
6 |
7 | template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
8 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/flash_attn/static_switch.h:
--------------------------------------------------------------------------------
1 | // Inspired by
2 | // https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
3 | // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
4 |
5 | #pragma once
6 |
7 | /// @param COND - a boolean expression to switch by
8 | /// @param CONST_NAME - a name given for the constexpr bool variable.
9 | /// @param ... - code to execute for true and false
10 | ///
11 | /// Usage:
12 | /// ```
13 | /// BOOL_SWITCH(flag, BoolConst, [&] {
14 | /// some_function(...);
15 | /// });
16 | /// ```
17 | #define BOOL_SWITCH(COND, CONST_NAME, ...) \
18 | [&] { \
19 | if (COND) { \
20 | constexpr static bool CONST_NAME = true; \
21 | return __VA_ARGS__(); \
22 | } else { \
23 | constexpr static bool CONST_NAME = false; \
24 | return __VA_ARGS__(); \
25 | } \
26 | }()
27 |
28 | #define FP16_SWITCH(COND, ...) \
29 | [&] { \
30 | if (COND) { \
31 | using elem_type = cutlass::half_t; \
32 | return __VA_ARGS__(); \
33 | } else { \
34 | using elem_type = cutlass::bfloat16_t; \
35 | return __VA_ARGS__(); \
36 | } \
37 | }()
38 |
39 | #define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
40 | [&] { \
41 | if (HEADDIM <= 32) { \
42 | constexpr static int kHeadDim = 32; \
43 | return __VA_ARGS__(); \
44 | } else if (HEADDIM <= 64) { \
45 | constexpr static int kHeadDim = 64; \
46 | return __VA_ARGS__(); \
47 | } else if (HEADDIM <= 96) { \
48 | constexpr static int kHeadDim = 96; \
49 | return __VA_ARGS__(); \
50 | } else if (HEADDIM <= 128) { \
51 | constexpr static int kHeadDim = 128; \
52 | return __VA_ARGS__(); \
53 | } else if (HEADDIM <= 160) { \
54 | constexpr static int kHeadDim = 160; \
55 | return __VA_ARGS__(); \
56 | } else if (HEADDIM <= 192) { \
57 | constexpr static int kHeadDim = 192; \
58 | return __VA_ARGS__(); \
59 | } else if (HEADDIM <= 224) { \
60 | constexpr static int kHeadDim = 224; \
61 | return __VA_ARGS__(); \
62 | } else if (HEADDIM <= 256) { \
63 | constexpr static int kHeadDim = 256; \
64 | return __VA_ARGS__(); \
65 | } \
66 | }()
67 |
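
These switches turn runtime flags into constexpr values so a single call site can reach every template instantiation of the kernels. In the upstream flash-attention API layer they nest roughly like this (a sketch; run_mha_fwd_ and the params fields is_bf16 and d are assumptions taken from that code):

  void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
      FP16_SWITCH(!params.is_bf16, [&] {          // picks elem_type: half_t or bfloat16_t
          FWD_HEADDIM_SWITCH(params.d, [&] {      // picks kHeadDim: 32 .. 256
              run_mha_fwd_<elem_type, kHeadDim>(params, stream);
          });
      });
  }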
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/vllm/attention/attention_dtypes.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "attention_generic.cuh"
4 | #include "dtype_float16.cuh"
5 | #include "dtype_float32.cuh"
6 | #include "dtype_bfloat16.cuh"
7 |
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/vllm/attention/attention_generic.cuh:
--------------------------------------------------------------------------------
1 | /*
2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
3 | * Copyright (c) 2023, The vLLM team.
4 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
5 | *
6 | * Licensed under the Apache License, Version 2.0 (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include <stdint.h>
21 |
22 | namespace vllm {
23 |
24 | // A vector type to store Q, K, V elements.
25 | template<typename T, int VEC_SIZE>
26 | struct Vec {};
27 |
28 | // A vector type to store FP32 accumulators.
29 | template<typename T>
30 | struct FloatVec {};
31 |
32 | // Template vector operations.
33 | template<typename Acc, typename A, typename B>
34 | inline __device__ Acc mul(A a, B b);
35 |
36 | template<typename T>
37 | inline __device__ float sum(T v);
38 |
39 | template<typename T>
40 | inline __device__ float dot(T a, T b) {
41 | return sum(mul<T, T, T>(a, b));
42 | }
43 |
44 | template<typename A, typename T>
45 | inline __device__ float dot(T a, T b) {
46 | return sum(mul<A, T, T>(a, b));
47 | }
48 |
49 | template<typename T>
50 | inline __device__ void zero(T& dst) {
51 | constexpr int WORDS = sizeof(T) / 4;
52 | union {
53 | T raw;
54 | uint32_t words[WORDS];
55 | } tmp;
56 |
57 | #pragma unroll
58 | for (int ii = 0; ii < WORDS; ++ii) {
59 | tmp.words[ii] = 0u;
60 | }
61 | dst = tmp.raw;
62 | }
63 |
64 | } // namespace vllm
65 |
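
Vec and FloatVec are intentionally empty here: each dtype_*.cuh header pulled in by attention_dtypes.h specializes them so the kernels can pick the widest vector load for a given element count. A sketch of the float specializations (following the upstream dtype_float32.cuh):

  // A float "vector" of 1, 2 or 4 elements maps onto float, float2, float4.
  template<> struct Vec<float, 1> { using Type = float;  };
  template<> struct Vec<float, 2> { using Type = float2; };
  template<> struct Vec<float, 4> { using Type = float4; };

  // FP32 accumulators for float vectors are the vectors themselves.
  template<> struct FloatVec<float>  { using Type = float;  };
  template<> struct FloatVec<float2> { using Type = float2; };
  template<> struct FloatVec<float4> { using Type = float4; };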
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/vllm/attention/attention_utils.cuh:
--------------------------------------------------------------------------------
1 | /*
2 | * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
3 | * Copyright (c) 2023, The vLLM team.
4 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
5 | *
6 | * Licensed under the Apache License, Version 2.0 (the "License");
7 | * you may not use this file except in compliance with the License.
8 | * You may obtain a copy of the License at
9 | *
10 | * http://www.apache.org/licenses/LICENSE-2.0
11 | *
12 | * Unless required by applicable law or agreed to in writing, software
13 | * distributed under the License is distributed on an "AS IS" BASIS,
14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | * See the License for the specific language governing permissions and
16 | * limitations under the License.
17 | */
18 | #pragma once
19 |
20 | #include "../cuda_compat.h"
21 | #include "attention_dtypes.h"
22 |
23 | #include <float.h>
24 | #include <type_traits>
25 |
26 | namespace vllm {
27 |
28 | // Q*K^T operation.
29 | template<int THREAD_GROUP_SIZE, typename Vec, int N>
30 | inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
31 | using A_vec = typename FloatVec<Vec>::Type;
32 | // Compute the parallel products for Q*K^T (treat vector lanes separately).
33 | A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
34 | #pragma unroll
35 | for (int ii = 1; ii < N; ++ii) {
36 | qk_vec = fma(q[ii], k[ii], qk_vec);
37 | }
38 |
39 | // Finalize the reduction across lanes.
40 | float qk = sum(qk_vec);
41 | #pragma unroll
42 | for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
43 | qk += VLLM_SHFL_XOR_SYNC(qk, mask);
44 | }
45 | return qk;
46 | }
47 |
48 | template<typename T, int THREAD_GROUP_SIZE>
49 | struct Qk_dot {
50 | template<typename Vec, int N>
51 | static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
52 | return qk_dot_<THREAD_GROUP_SIZE>(q, k);
53 | }
54 | };
55 |
56 | } // namespace vllm
57 |
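
Qk_dot hides the thread-group shuffle reduction so callers only name the scalar type and group size. Its call site in the paged-attention kernel looks roughly like this (a sketch; q_vecs, k_vecs, scale and thread_group_offset are assumptions based on the upstream vLLM kernel):

  // One thread group cooperatively produces a single Q.K^T logit.
  float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
      q_vecs[thread_group_offset], k_vecs);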
--------------------------------------------------------------------------------
/rllm/tch-cuda/kernels/vllm/cache.h:
--------------------------------------------------------------------------------
1 | #include
2 |
3 | #include